Calmcode - scikit save: h5

Storing Scikit-Learn Weights Manually

If we consider a simple logistic regression model, we might wonder whether we really need to store the entire Python object on disk. After all, the trained weights are all we need to reconstruct the model.

That means we could manually grab the parts that we need and store those as data instead. This way, we can load the data safely, without any concern about Python code running on our behalf. In the case of a logistic regression, we may only need to store a few numpy arrays, which a library called h5py handles very well.
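As a minimal sketch of that idea, the snippet below writes a numpy array to an .h5 file and reads it back. The array values, the dataset name "weights" and the temporary file path are assumptions made up for this illustration.

```python
import os
import tempfile

import h5py
import numpy as np

# Hypothetical weights, made up for this illustration.
weights = np.array([[0.5, -1.2, 3.0], [2.0, 0.1, -0.7]])

with tempfile.TemporaryDirectory() as tmpdir:
    path = os.path.join(tmpdir, "weights.h5")

    # Write the array to disk as a named dataset.
    with h5py.File(path, "w") as hf:
        hf.create_dataset("weights", data=weights)

    # Read it back; slicing with [:] copies the dataset into a numpy array.
    with h5py.File(path, "r") as hf:
        restored = hf["weights"][:]

print(type(restored).__name__, restored.shape)
```

What comes back is a plain array of numbers rather than a Python object.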

Quick Demo

Suppose that we want to store the following model.

from sklearn.linear_model import LogisticRegression
from sklearn.datasets import load_wine

X, y = load_wine(return_X_y=True)

clf = LogisticRegression(max_iter=10_000).fit(X, y)

Then we might write some helper functions that grab all the required data from the model.

import h5py

def save_coefficients(classifier, filename):
    """Save the coefficients of a linear model into a .h5 file."""
    with h5py.File(filename, 'w') as hf:
        hf.create_dataset("coef",  data=classifier.coef_)
        hf.create_dataset("intercept",  data=classifier.intercept_)
        hf.create_dataset("classes", data=classifier.classes_)

def load_coefficients(classifier, filename):
    """Attach the saved coefficients to a linear model."""
    with h5py.File(filename, 'r') as hf:
        coef = hf['coef'][:]
        intercept = hf['intercept'][:]
        classes = hf['classes'][:]
    classifier.coef_ = coef
    classifier.intercept_ = intercept
    classifier.classes_ = classes

We can use the save_coefficients function to store the coefficients into a file.

save_coefficients(clf, "clf.h5")

This file can be loaded into a new classifier by running:

lr = LogisticRegression()
load_coefficients(lr, "clf.h5")

The new lr model can now be used to make predictions.
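One way to gain some confidence in the round trip is to check that the restored model predicts the same labels as the original. The sketch below repeats the steps above in a self-contained form and compares the two. The temporary directory is an assumption added so that the example cleans up after itself. Note that assigning fitted attributes by hand relies on scikit-learn internals, so it may behave differently across versions.

```python
import os
import tempfile

import h5py
import numpy as np
from sklearn.datasets import load_wine
from sklearn.linear_model import LogisticRegression

X, y = load_wine(return_X_y=True)
clf = LogisticRegression(max_iter=10_000).fit(X, y)

with tempfile.TemporaryDirectory() as tmpdir:
    path = os.path.join(tmpdir, "clf.h5")

    # Store the same three arrays that save_coefficients writes.
    with h5py.File(path, "w") as hf:
        hf.create_dataset("coef", data=clf.coef_)
        hf.create_dataset("intercept", data=clf.intercept_)
        hf.create_dataset("classes", data=clf.classes_)

    # Attach them to a fresh, unfitted model, as load_coefficients does.
    lr = LogisticRegression()
    with h5py.File(path, "r") as hf:
        lr.coef_ = hf["coef"][:]
        lr.intercept_ = hf["intercept"][:]
        lr.classes_ = hf["classes"][:]

# Compare the predictions of the original and the restored model.
same = np.array_equal(clf.predict(X), lr.predict(X))
print("identical predictions:", same)
```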


Not Perfect

While this solution works pretty well for this simple logistic regression model, it may not suffice for our entire pipeline.

Pipeline(steps=[('countvectorizer', CountVectorizer()),
                ('logisticregression', LogisticRegression())])
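To make the difficulty concrete, here is a small sketch that fits a pipeline like the one above on a made-up two-sentence corpus and inspects what the CountVectorizer learned. The texts and labels are assumptions chosen purely for illustration.

```python
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline

# Tiny made-up corpus and labels, purely for illustration.
texts = ["red wine", "white wine"]
labels = [0, 1]

pipe = make_pipeline(CountVectorizer(), LogisticRegression()).fit(texts, labels)

# The vectorizer's learned state is a dict that maps tokens to column indices.
vocab = pipe.named_steps["countvectorizer"].vocabulary_
print(type(vocab).__name__, sorted(vocab))
```

A dict keyed by strings does not fit the array-only helpers from before, so a component like this would need its own save and load logic.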

A scikit-learn pipeline can have many different components, some of which won't be as easy to reconstruct with our approach. So perhaps it's also good to consider another alternative.