Calmcode - partial_fit: sgdregression

How to train an SGDRegressor on micro-batches via partial_fit.


In this part we're going to train our first SGDRegressor. Let's start by simulating a regression dataset that we'll use.

import numpy as np
import pandas as pd
from sklearn.datasets import make_regression
from sklearn.linear_model import SGDRegressor, LinearRegression
from sklearn.model_selection import train_test_split

# Prepare Data
X, y, w = make_regression(n_features=2, n_samples=4000,
                          random_state=42, coef=True, noise=1.0)
y = y + 1.5

X_train, X_test, y_train, y_test = train_test_split(X, y,
                                                    test_size=0.5,
                                                    random_state=42)

We'll run a baseline LinearRegression model first, so we have something to compare against once we train an SGDRegressor on a stream of data.

# Run a Baseline Model
mod_lm = LinearRegression()
mod_lm.fit(X_train, y_train)

# Keep the baseline test MSE around so we can compare against it later.
normal_mse_test = np.mean((mod_lm.predict(X_test) - y_test)**2)
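Before moving on, it can help to confirm that the baseline recovers the parameters we simulated. The sketch below repeats the setup so it runs on its own, then prints the true slopes `w` returned by make_regression next to the fitted ones. It assumes the true intercept is 1.5, which follows from the `y = y + 1.5` shift and make_regression's default bias of zero.

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split

# Same simulated data as above; `w` holds the true slopes used by make_regression.
X, y, w = make_regression(n_features=2, n_samples=4000,
                          random_state=42, coef=True, noise=1.0)
y = y + 1.5  # shifts the true intercept from 0 to 1.5

X_train, X_test, y_train, y_test = train_test_split(X, y,
                                                    test_size=0.5,
                                                    random_state=42)

# Fit the baseline and compute its test MSE, as in the snippet above.
mod_lm = LinearRegression().fit(X_train, y_train)
normal_mse_test = np.mean((mod_lm.predict(X_test) - y_test)**2)

# Compare what we simulated with what the baseline learned.
print("true slopes:       ", w)
print("fitted slopes:     ", mod_lm.coef_)
print("fitted intercept:  ", mod_lm.intercept_)
print("baseline test MSE: ", normal_mse_test)
```

Since the noise has a standard deviation of 1.0, the baseline test MSE should land near 1.0. That is roughly the floor any linear model can reach on this data.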

With our benchmark in place, we can start learning via .partial_fit().

# Run for Stats
mod_pac = SGDRegressor()
data = []


for i, x in enumerate(X_train):
    # This is where we learn on a single datapoint
    mod_pac.partial_fit([x], [y_train[i]])

    # This is where we measure and save stats
    data.append({
        'c0': mod_pac.intercept_[0],
        'c1': mod_pac.coef_.flatten()[0],
        'c2': mod_pac.coef_.flatten()[1],
        'mse_test': np.mean((mod_pac.predict(X_test) - y_test)**2),
        'normal_mse_test': normal_mse_test,
        'i': i
    })

df_stats = pd.DataFrame(data)
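The loop above feeds the model one datapoint at a time, but .partial_fit() also accepts several rows per call. Here is a minimal sketch of the micro-batch variant. The batch size of 50 and the names `batch_size`, `mod_batch` and `df_batch_stats` are our own illustrative choices, not part of the original lesson.

```python
import numpy as np
import pandas as pd
from sklearn.datasets import make_regression
from sklearn.linear_model import SGDRegressor
from sklearn.model_selection import train_test_split

# Recreate the same simulated dataset so this block runs on its own.
X, y, w = make_regression(n_features=2, n_samples=4000,
                          random_state=42, coef=True, noise=1.0)
y = y + 1.5

X_train, X_test, y_train, y_test = train_test_split(X, y,
                                                    test_size=0.5,
                                                    random_state=42)

batch_size = 50  # hypothetical choice; tune it for your own stream
mod_batch = SGDRegressor()
batch_stats = []

for start in range(0, len(X_train), batch_size):
    stop = start + batch_size

    # Learn on one micro-batch; slicing past the end of the array is safe.
    mod_batch.partial_fit(X_train[start:stop], y_train[start:stop])

    # Measure on the held-out test set after every batch.
    batch_stats.append({
        'n_seen': min(stop, len(X_train)),
        'c0': mod_batch.intercept_[0],
        'c1': mod_batch.coef_[0],
        'c2': mod_batch.coef_[1],
        'mse_test': np.mean((mod_batch.predict(X_test) - y_test)**2),
    })

df_batch_stats = pd.DataFrame(batch_stats)
print(df_batch_stats.tail())
```

Evaluating once per batch instead of once per datapoint also makes the bookkeeping much cheaper, at the cost of a coarser learning curve.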

These stats can be inspected via:

import altair as alt

alt.data_transformers.disable_max_rows()

# Reshape to long format so altair can color each line by variable.
pltr1 = pd.melt(df_stats[['i', 'c1', 'c2']], id_vars=["i"])
pltr2 = pd.melt(df_stats[['i', 'normal_mse_test', 'mse_test']], id_vars=["i"])

p1 = (alt.Chart(pltr1, title='SGD evolution of weights')
        .mark_line()
        .encode(x='i', y='value', color='variable', tooltip=['i', 'value', 'variable'])
        .properties(width=300, height=150)
        .interactive())

p2 = (alt.Chart(pltr2, title='SGD evolution of mse')
        .mark_line()
        .encode(x='i', y='value', color='variable', tooltip=['i', 'value', 'variable'])
        .properties(width=350, height=150)
        .interactive())

p1 | p2

If you're unfamiliar with the altair API, you may appreciate our course on it.