SGDClassifier logo partial_fit: sgdclassifier

1 2 3 4 5 6 7 8

You can also train a classifier incrementally via .partial_fit()! You do need to pay attention to the `classes` argument when you call the method, though.

Let's start by first generating data and training a baseline model.

import numpy as np
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.linear_model import SGDClassifier, PassiveAggressiveClassifier, LogisticRegression
from sklearn.model_selection import train_test_split

# Synthetic dataset: 20,000 samples, two features, none of them redundant.
X, y = make_classification(
    n_samples=20000,
    n_features=2,
    n_redundant=0,
    random_state=42,
    n_clusters_per_class=1,
)

# Keep half of the samples aside for evaluation; the seed makes the split
# reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.5, random_state=42
)

# Baseline: a logistic regression fitted on the whole training set at once.
# Its accuracy is the reference the incrementally trained model is compared to.
mod_lmc = LogisticRegression().fit(X_train, y_train)

# Accuracy = fraction of predictions that match the true labels.
normal_acc_train = (mod_lmc.predict(X_train) == y_train).mean()
normal_acc_test = (mod_lmc.predict(X_test) == y_test).mean()

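From here we can train a classifier one example at a time. Pay attention to the classes parameter of .partial_fit() though: because the model only sees a small part of the data on each call, it cannot infer the full set of labels by itself, so we have to pass all possible classes explicitly.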

# Linear classifier that is updated one training example at a time.
mod_sgd = SGDClassifier()
data = []

for i, (x, label) in enumerate(zip(X_train, y_train)):
    # A single example never shows every label, so the complete label set
    # is handed over explicitly through `classes`.
    mod_sgd.partial_fit([x], [label], classes=[0, 1])

    # Snapshot of the model after this update: its first two weights and
    # its accuracy on the full held-out test set.
    weights = mod_sgd.coef_.flatten()
    test_accuracy = np.mean(mod_sgd.predict(X_test) == y_test)

    data.append(dict(
        c1=weights[0],
        c2=weights[1],
        mod_sgd=test_accuracy,
        # The baseline accuracy is constant; repeating it per row gives the
        # stats table a reference column next to the SGD accuracy.
        normal_acc_test=normal_acc_test,
        i=i,
    ))

# One row per training step.
df_stats = pd.DataFrame(data)

The plotting code is also a bit different than before.

# `alt` is Altair. It is not imported by the snippets above, so import it here
# to keep the plotting code runnable on its own.
import altair as alt

# Each melted frame below holds two rows per training step, which is more than
# Altair's default limit of 5000 rows for data embedded in a chart. Lift the
# limit so the charts render instead of raising MaxRowsError.
alt.data_transformers.disable_max_rows()

# Reshape to long format (one row per step and variable) so that `variable`
# can drive the line colour.
pltr1 = pd.melt(df_stats[['i', 'c1', 'c2']], id_vars=["i"])
pltr2 = pd.melt(df_stats[['i', 'normal_acc_test', 'mod_sgd']], id_vars=["i"])

# Left chart: how the two model weights move as more examples are seen.
q1 = (alt.Chart(pltr1, title='SGD evolution of weights')
        .mark_line()
        .encode(x='i', y='value', color='variable', tooltip=['i', 'value', 'variable'])
        .properties(width=300, height=150)
        .interactive())

# Right chart: test accuracy of the SGD model next to the constant baseline.
# The title names SGD because that is the model being plotted here.
q2 = (alt.Chart(pltr2, title='SGD evolution of accuracy')
        .mark_line()
        .encode(x='i', y='value', color='variable', tooltip=['i', 'value', 'variable'])
        .properties(width=350, height=150)
        .interactive())

# Show both charts side by side.
(q1 | q2)