In this part we're going to train our first SGDRegressor.
Let's start by simulating a regression dataset that we'll use.
import numpy as np
import pandas as pd
from sklearn.datasets import make_regression
from sklearn.linear_model import SGDRegressor, LinearRegression
from sklearn.model_selection import train_test_split
# Simulate a two-feature regression dataset. Asking for coef=True makes
# make_regression also return the coefficients of the underlying linear
# model, which we keep in `w`.
X, y, w = make_regression(
    n_samples=4000,
    n_features=2,
    noise=1.0,
    coef=True,
    random_state=42,
)
# Shift every target by a constant 1.5.
y = y + 1.5
# Hold out half of the samples as a test set.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.5, random_state=42
)
We'll run a baseline LinearRegression model first, so we have something to
compare against once we train an SGDRegressor on a stream of data.
# Fit the baseline linear regression on the whole training set at once.
mod_lm = LinearRegression().fit(X_train, y_train)
# Store the baseline's test-set mean squared error so the streaming model
# can be compared against it later.
baseline_errors = mod_lm.predict(X_test) - y_test
normal_mse_test = np.mean(baseline_errors ** 2)
Given our benchmark we can start learning via .partial_fit().
# Train on the training set one sample at a time and record, after every
# update, the current parameters and the test-set MSE.
mod_pac = SGDRegressor()
data = []
for i, (x, target) in enumerate(zip(X_train, y_train)):
    # This is where we learn on a single datapoint. The sample and its
    # target are wrapped in lists so partial_fit receives a batch of one.
    mod_pac.partial_fit([x], [target])
    # This is where we measure and save stats. The baseline MSE is repeated
    # on every row so both curves can be plotted from the same frame.
    data.append({
        'c0': mod_pac.intercept_[0],
        'c1': mod_pac.coef_.flatten()[0],
        'c2': mod_pac.coef_.flatten()[1],
        'mse_test': np.mean((mod_pac.predict(X_test) - y_test)**2),
        'normal_mse_test': normal_mse_test,
        'i': i,
    })
# One row per training sample seen so far.
df_stats = pd.DataFrame(data)
These stats can be inspected via:
import altair as alt

# Turn off altair's maximum-row check for chart data.
alt.data_transformers.disable_max_rows()


def _line_chart(frame, title, width):
    """Return an interactive line chart of `value` against `i`, one line per `variable`."""
    chart = alt.Chart(frame, title=title).mark_line()
    chart = chart.encode(
        x='i',
        y='value',
        color='variable',
        tooltip=['i', 'value', 'variable'],
    )
    return chart.properties(width=width, height=150).interactive()


# Reshape to long format (columns: i, variable, value) for plotting.
pltr1 = df_stats[['i', 'c1', 'c2']].melt(id_vars=["i"])
pltr2 = df_stats[['i', 'normal_mse_test', 'mse_test']].melt(id_vars=["i"])

p1 = _line_chart(pltr1, 'SGD evolution of weights', 300)
p2 = _line_chart(pltr2, 'SGD evolution of mse', 350)

# Show both charts side by side.
p1 | p2
If you're unfamiliar with the altair API, you may appreciate our course on it.