scikit meta: multi output

If you're interested in using the same dataset X to predict two labels, y1 and y2, then you can take a shortcut. Instead of building two separate pipelines, you can use a single model that holds one copy of a base estimator per label. Scikit-Learn supports this through the MultiOutputClassifier meta-estimator.
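
To make the "one copy per label" idea concrete, here's a small sketch that compares a hand-written loop with MultiOutputClassifier. It uses synthetic data from make_classification plus a second, made-up label column (an assumption for illustration; this is not the titanic data). The manual version clones the base model once per label column, which is roughly the bookkeeping the wrapper does for you.

```python
import numpy as np
from sklearn.base import clone
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.multioutput import MultiOutputClassifier

# Synthetic stand-in data: one feature matrix and two binary label columns.
X, y1 = make_classification(n_samples=200, n_features=4, random_state=0)
rng = np.random.default_rng(0)
y2 = (X[:, 0] + 0.5 * rng.normal(size=200) > 0).astype(int)
Y = np.column_stack([y1, y2])

# Manual approach: clone the base model once per label column and fit each copy.
base = LogisticRegression()
manual_models = [clone(base).fit(X, Y[:, j]) for j in range(Y.shape[1])]
manual_preds = np.column_stack([m.predict(X) for m in manual_models])

# Wrapper approach: MultiOutputClassifier clones and fits one copy per column itself.
wrapped = MultiOutputClassifier(base).fit(X, Y)
wrapped_preds = wrapped.predict(X)

print(np.array_equal(manual_preds, wrapped_preds))
```

Both routes should produce the same predictions here because they fit the same estimator on the same columns; the wrapper mainly saves you the loop.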

Let's demonstrate this method by using the titanic dataset.

import numpy as np
import pandas as pd
from sklearn.neighbors import KNeighborsClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.multioutput import MultiOutputClassifier, MultiOutputRegressor


df = pd.read_csv("https://calmcode.io/datasets/titanic.csv")
df.head()

The idea is to predict both the survived column and the pclass column from the sex, age and fare columns.

labels = df[['survived', 'pclass']].values
X = df.assign(sex=lambda d: d['sex'] == 'male')[['sex', 'age', 'fare']]
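
If you'd like to see what these two lines produce without downloading anything, here's a sketch on a tiny hand-made DataFrame. It reuses the titanic column names, but the rows are made up for illustration and are not real passengers. The point is that the labels become a 2D array with one column per target, and sex turns into a boolean "is male" feature.

```python
import pandas as pd

# A tiny hand-made frame with the same column names as the titanic data.
# The values are invented for illustration; they are not real passengers.
toy = pd.DataFrame({
    "survived": [0, 1, 1, 0],
    "pclass": [3, 1, 2, 3],
    "sex": ["male", "female", "female", "male"],
    "age": [22.0, 38.0, 26.0, 35.0],
    "fare": [7.25, 71.28, 7.92, 8.05],
})

# Two label columns give a 2D array: one row per passenger, one column per target.
toy_labels = toy[["survived", "pclass"]].values
# Encode sex as a boolean "is male" feature, then keep three feature columns.
toy_X = toy.assign(sex=lambda d: d["sex"] == "male")[["sex", "age", "fare"]]

print(toy_labels.shape)  # (4, 2)
print(toy_X)
```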

We can pass X and the two-column labels array into the .fit() method of our MultiOutputClassifier. The wrapper needs a base estimator to copy; here we use a LogisticRegression.

clf = MultiOutputClassifier(LogisticRegression()).fit(X, labels)
clf.predict(X)
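
The output of .predict() is an array with one column per label, in the same order as the label columns you passed in. Each column keeps its own set of classes, so a binary "survived"-style column and a three-class "pclass"-style column can live side by side. The sketch below assumes synthetic random data with hypothetical targets shaped like the titanic setup.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.multioutput import MultiOutputClassifier

rng = np.random.default_rng(42)
X = rng.normal(size=(300, 3))
# Hypothetical targets: a binary column and a three-class column with values 1, 2, 3.
y_binary = (X[:, 0] > 0).astype(int)
y_three = np.digitize(X[:, 1], bins=[-0.5, 0.5]) + 1
Y = np.column_stack([y_binary, y_three])

clf = MultiOutputClassifier(LogisticRegression()).fit(X, Y)
preds = clf.predict(X)

print(preds.shape)  # (300, 2): one row per sample, one column per label
# Each fitted copy only knows about the classes of its own label column.
for j, est in enumerate(clf.estimators_):
    print(j, est.classes_)
```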

This trains one LogisticRegression per label column. You're also free to use any other scikit-learn compatible classifier as the base estimator. Here's another example with the KNeighborsClassifier.

clf = MultiOutputClassifier(KNeighborsClassifier()).fit(X, labels)
clf.predict(X)
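
Since any scikit-learn compatible classifier works as the base estimator, a whole pipeline works too. Here's a sketch that wraps a scaler plus a KNeighborsClassifier, which can help because KNN relies on distances between rows. The data is synthetic (an assumption for illustration), and n_neighbors=5 is just the default spelled out.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.multioutput import MultiOutputClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Synthetic data with two binary label columns.
X, y1 = make_classification(
    n_samples=150, n_features=3, n_informative=2, n_redundant=0, random_state=1
)
y2 = (X[:, 1] > 0).astype(int)
Y = np.column_stack([y1, y2])

# A pipeline is a valid base estimator: each label gets its own scaler + KNN copy.
base = make_pipeline(StandardScaler(), KNeighborsClassifier(n_neighbors=5))
clf = MultiOutputClassifier(base).fit(X, Y)

print(clf.predict(X[:5]))
print(type(clf.estimators_[0]).__name__)  # Pipeline
```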

You can also explore the predicted probabilities. Note that .predict_proba() returns a list with one array per label, and each array has one column per class of that label (so the pclass array has three columns).

clf.predict_proba(X)
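
Because the result is a list rather than a single array, you index it per label first and per class second. The sketch below assumes synthetic data with a binary first label and a three-class second label; the variable name p_first is hypothetical and just stands in for something like "probability of survived == 1".

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.multioutput import MultiOutputClassifier

rng = np.random.default_rng(7)
X = rng.normal(size=(120, 2))
# Hypothetical targets: binary first column, three classes (0, 1, 2) in the second.
Y = np.column_stack([
    (X[:, 0] > 0).astype(int),
    np.digitize(X[:, 1], bins=[-0.5, 0.5]),
])

clf = MultiOutputClassifier(LogisticRegression()).fit(X, Y)
probas = clf.predict_proba(X)

# One array per label; its columns follow that label's own classes_ order.
print(type(probas).__name__, len(probas))
print([p.shape for p in probas])
# Probability of class 1 for the first label, one value per row.
p_first = probas[0][:, 1]
```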

If you'd like, you can also inspect both trained models individually. They are stored, in the same order as the label columns, in the .estimators_ attribute of the fitted model.

clf.estimators_
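
One handy use of .estimators_ is comparing what each model learned. Here's a sketch that puts the LogisticRegression coefficients of both copies side by side. The feature frame reuses titanic-like column names, but its values and both labels are random, made-up data; the row names label_0 and label_1 are hypothetical, and max_iter=1000 is only there to avoid convergence warnings on unscaled features.

```python
import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.multioutput import MultiOutputClassifier

rng = np.random.default_rng(3)
n = 200
# Hypothetical features with titanic-like names; the values are random.
X = pd.DataFrame({
    "sex": rng.integers(0, 2, size=n),
    "age": rng.uniform(1, 80, size=n),
    "fare": rng.uniform(5, 100, size=n),
})
# Two made-up binary labels: one driven by fare, one driven by age.
Y = np.column_stack([
    (X["fare"] + rng.normal(0, 20, size=n) > 50).astype(int),
    (X["age"] + rng.normal(0, 10, size=n) > 40).astype(int),
])

clf = MultiOutputClassifier(LogisticRegression(max_iter=1000)).fit(X, Y)

# One fitted LogisticRegression per label column, in the same order as Y.
coefs = pd.DataFrame(
    [est.coef_.ravel() for est in clf.estimators_],
    index=["label_0", "label_1"],
    columns=X.columns,
)
print(coefs)
```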

If you're looking for a quick way to predict multiple labels from a single dataset, then this meta-trick will work for you. It works best when you want sensible predictions but don't need state-of-the-art performance, because every label gets a copy of the same model with the same hyperparameters. If you want the best possible pipeline for each label, you may want to build two separate pipelines instead.
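
For comparison, here's a sketch of the "two separate pipelines" route, where each label gets its own model family and its own hyperparameter search. The data is synthetic, and the specific models and grids are assumptions chosen for illustration. A grid search wrapped around a MultiOutputClassifier could only pick one shared setting (through the estimator__ parameter prefix) for both labels.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Synthetic data: one feature matrix, two binary labels.
X, y1 = make_classification(n_samples=200, n_features=4, random_state=0)
y2 = (X[:, 2] + X[:, 3] > 0).astype(int)

# Label 1: a scaled logistic regression, tuning its regularisation strength C.
search_1 = GridSearchCV(
    make_pipeline(StandardScaler(), LogisticRegression()),
    param_grid={"logisticregression__C": [0.1, 1.0, 10.0]},
    cv=3,
).fit(X, y1)

# Label 2: a completely different model family with its own grid.
search_2 = GridSearchCV(
    make_pipeline(StandardScaler(), KNeighborsClassifier()),
    param_grid={"kneighborsclassifier__n_neighbors": [3, 5, 11]},
    cv=3,
).fit(X, y2)

print(search_1.best_params_)
print(search_2.best_params_)
# Stack the two predictions to get the same shape MultiOutputClassifier returns.
preds = np.column_stack([search_1.predict(X), search_2.predict(X)])
```

The trade-off is more code to maintain in exchange for per-label tuning.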

Note that there's also a regressor variant of this trick, MultiOutputRegressor, if that's of interest.
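
Here's a sketch of the regressor variant on synthetic data from make_regression with two continuous targets (an assumption for illustration). Some regressors, such as LinearRegression, already handle multiple targets natively, so the wrapper is mostly useful for single-target models like GradientBoostingRegressor.

```python
from sklearn.datasets import make_regression
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.multioutput import MultiOutputRegressor

# Synthetic regression data with two continuous targets.
X, Y = make_regression(
    n_samples=200, n_features=5, n_targets=2, noise=0.1, random_state=0
)

# GradientBoostingRegressor predicts a single target, so wrap it: one copy per target.
reg = MultiOutputRegressor(GradientBoostingRegressor(random_state=0)).fit(X, Y)

preds = reg.predict(X)
print(Y.shape, preds.shape)  # (200, 2) (200, 2)
print(len(reg.estimators_))  # 2
```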