Using the MultiOutputClassifier
If you're interested in using the same dataset X to predict two labels, y1 and y2, then you can take a shortcut. Instead of building two pipelines, you can use a single one that contains a copy of a model for each label. Scikit-Learn lets you do this by wrapping your model in a MultiOutputClassifier.
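Conceptually, the wrapper clones the model once per label column and fits each copy on its own column. The sketch below checks that idea on a small synthetic dataset. The random data and variable names are made up for illustration and are not part of the titanic example. Fitting clones by hand with `sklearn.base.clone` should give the same predictions as the MultiOutputClassifier.

```python
import numpy as np
from sklearn.base import clone
from sklearn.linear_model import LogisticRegression
from sklearn.multioutput import MultiOutputClassifier

# Hypothetical synthetic data: 200 rows, 3 features, two binary labels.
rng = np.random.default_rng(42)
X = rng.normal(size=(200, 3))
y1 = (X[:, 0] + X[:, 1] > 0).astype(int)
y2 = (X[:, 2] > 0).astype(int)
labels = np.column_stack([y1, y2])

# One meta-estimator that fits a copy of the model per label column.
multi = MultiOutputClassifier(LogisticRegression()).fit(X, labels)

# The manual equivalent: clone the model and fit one copy per label column.
separate = [
    clone(LogisticRegression()).fit(X, labels[:, i])
    for i in range(labels.shape[1])
]
manual_preds = np.column_stack([model.predict(X) for model in separate])

# Both routes should agree column by column.
print(np.array_equal(multi.predict(X), manual_preds))
```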
Let's demonstrate this method by using the titanic dataset.
import numpy as np
import pandas as pd
from sklearn.neighbors import KNeighborsClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.multioutput import MultiOutputClassifier, MultiOutputRegressor
df = pd.read_csv("https://calmcode.io/datasets/titanic.csv")
The idea is to predict the survived column and the pclass column from the sex, age and fare columns.
labels = df[['survived', 'pclass']].values
X = df.assign(sex=lambda d: d['sex'] == 'male')[['sex', 'age', 'fare']]
We can pass X and the labels array into the .fit() method of our MultiOutputClassifier.
clf = MultiOutputClassifier(LogisticRegression()).fit(X, labels)
This will train a LogisticRegression for each label. Note that you're also free to use any other scikit-learn compatible classifier here.
Here's another example with the KNeighborsClassifier.
clf = MultiOutputClassifier(KNeighborsClassifier()).fit(X, labels)
You can also explore the predicted probabilities via .predict_proba(). Note that you'll get a list of two arrays as output here, one for each label.
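In the following sketch, the probabilities come from .predict_proba(), which returns a list with one array per label. Each array has one column per class of that label. To let it run without downloading anything, it uses hypothetical random data shaped like the titanic example: a binary label standing in for survived and a three-class label standing in for pclass.

```python
import numpy as np
from sklearn.multioutput import MultiOutputClassifier
from sklearn.neighbors import KNeighborsClassifier

# Hypothetical stand-in data: 100 rows, 3 numeric features.
rng = np.random.default_rng(0)
X = rng.normal(size=(100, 3))
labels = np.column_stack([
    (X[:, 0] > 0).astype(int),  # binary label, playing the role of "survived"
    np.arange(100) % 3 + 1,     # classes 1, 2, 3, playing the role of "pclass"
])

clf = MultiOutputClassifier(KNeighborsClassifier()).fit(X, labels)

# One (n_samples, n_classes) array per label column.
probas = clf.predict_proba(X)
for i, proba in enumerate(probas):
    print(f"label {i}: {proba.shape}")
```

Here the first array has two columns (two classes) and the second has three.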
If you'd like, you can also inspect both trained models individually. They are stored in the .estimators_ property of the model.
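Here is a sketch of pulling the fitted sub-models out of .estimators_, again on hypothetical random data rather than the titanic set. Each element is a fitted copy of the wrapped estimator, so the usual attributes are available, such as coef_ on a LogisticRegression. Its predictions should match the corresponding column of clf.predict().

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.multioutput import MultiOutputClassifier

# Hypothetical data: 150 rows, 3 features, two binary labels.
rng = np.random.default_rng(1)
X = rng.normal(size=(150, 3))
labels = np.column_stack([
    (X[:, 0] > 0).astype(int),
    (X[:, 1] + X[:, 2] > 0).astype(int),
])

clf = MultiOutputClassifier(LogisticRegression()).fit(X, labels)

# estimators_ holds one fitted LogisticRegression per label column.
first_model, second_model = clf.estimators_
print(first_model.coef_)   # weights learned for the first label
print(second_model.coef_)  # weights learned for the second label

# A sub-model used on its own reproduces its column of the combined prediction.
print(np.array_equal(first_model.predict(X), clf.predict(X)[:, 0]))
```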
If you're looking for a quick way to predict multiple labels from a single dataset, then this meta-trick will work for you. It works best when you want sensible predictions but don't need state-of-the-art performance. Every label gets a clone of the same estimator with the same settings. If you want the best possible pipeline for each label, you may want to build two separate pipelines instead, so that each label can get its own preprocessing, model and hyperparameters.
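For comparison, here is a minimal sketch of the two-pipeline route, assuming hypothetical random data with labels y1 and y2. Each pipeline can use its own preprocessing and model: a scaled LogisticRegression for one label, and a KNeighborsClassifier with an arbitrarily chosen n_neighbors=15 for the other. The model choices are illustrative only.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Hypothetical data: 200 rows, 3 features, two binary labels.
rng = np.random.default_rng(2)
X = rng.normal(size=(200, 3))
y1 = (X[:, 0] > 0).astype(int)
y2 = (X[:, 1] * X[:, 2] > 0).astype(int)

# Separate pipelines: each label gets its own steps and hyperparameters,
# and each could be tuned on its own (for example with a grid search).
pipe_y1 = make_pipeline(StandardScaler(), LogisticRegression()).fit(X, y1)
pipe_y2 = make_pipeline(StandardScaler(), KNeighborsClassifier(n_neighbors=15)).fit(X, y2)

# Stack the per-label predictions into the same layout MultiOutputClassifier uses.
preds = np.column_stack([pipe_y1.predict(X), pipe_y2.predict(X)])
print(preds.shape)
```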
Note that there's also a regressor variant of this trick, MultiOutputRegressor, if that's of interest.
If you're looking for an interesting use case for this technique, consider confidence intervals! When your dataset contains an upper and a lower bound, you can predict both of them as well. This technique is explored in more detail in this blogpost.
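Here is a minimal sketch of the bounds idea with MultiOutputRegressor. It assumes a hypothetical dataset in which every row comes with a known lower and upper bound. To make the result easy to check, the bounds are generated as exact linear functions of x with no noise. LinearRegression is used only to keep the example small, since it already supports multiple targets natively. With a single-target model such as SVR, the wrapper is what makes predicting both bounds possible.

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.multioutput import MultiOutputRegressor

# Hypothetical data: one feature x and a known lower/upper bound per row.
rng = np.random.default_rng(3)
x = rng.uniform(0, 10, size=(300, 1))
center = 2.0 * x[:, 0]
width = 0.5 + 0.3 * x[:, 0]  # the interval widens as x grows
bounds = np.column_stack([center - width, center + width])

# One regressor per target column: column 0 is the lower bound, column 1 the upper.
reg = MultiOutputRegressor(LinearRegression()).fit(x, bounds)

new_x = np.array([[1.0], [5.0], [9.0]])
lower, upper = reg.predict(new_x).T
for xi, lo, hi in zip(new_x[:, 0], lower, upper):
    print(f"x={xi:.1f}: predicted interval [{lo:.2f}, {hi:.2f}]")
```

Because the data is noiseless, the fitted bounds recover lower = 1.7x - 0.5 and upper = 2.3x + 0.5. With real data you would expect only an approximation.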