
human learn: classifier



Notes

Let's expand our work from the previous video by making a classifier.

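The fare_based function defined below makes predictions with a simple rule: a passenger is predicted to have survived (1) when their fare exceeds threshold, and not to have survived (0) otherwise.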

import numpy as np
from hulearn.datasets import load_titanic

df = load_titanic(as_frame=True)
X, y = df.drop(columns=['survived']), df['survived']

def fare_based(dataf, threshold=10):
    # Predict "survived" (1) for every passenger whose fare exceeds the threshold.
    return np.array(dataf['fare'] > threshold).astype(int)

from hulearn.classification import FunctionClassifier

# Wrap the function in a scikit-learn compatible model. Keyword arguments such as
# threshold are passed on to fare_based and exposed as tunable hyperparameters.
mod = FunctionClassifier(fare_based, threshold=10)
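To make the idea of "turning a function into a model" concrete, here is a minimal plain-Python sketch of what such a wrapper needs: it stores the function and its keyword arguments, reports and updates them through get_params/set_params (which is what lets a grid search vary threshold), does nothing in fit, and calls the function in predict. TinyFunctionClassifier and fare_rule are hypothetical names, this is not hulearn's actual implementation, and the fares are made-up values.

```python
# Hypothetical sketch: NOT hulearn's FunctionClassifier implementation.
class TinyFunctionClassifier:
    def __init__(self, func, **kwargs):
        self.func = func      # the hand-written rule that produces predictions
        self.kwargs = kwargs  # hyperparameters forwarded to the rule

    def get_params(self, deep=True):
        # Expose the keyword arguments as parameters, in the scikit-learn style.
        return {"func": self.func, **self.kwargs}

    def set_params(self, **params):
        for key, value in params.items():
            if key == "func":
                self.func = value
            else:
                self.kwargs[key] = value
        return self

    def fit(self, X, y=None):
        # A hand-written rule has nothing to learn, so fitting is a no-op.
        return self

    def predict(self, X):
        return self.func(X, **self.kwargs)


def fare_rule(rows, threshold=10):
    # Same idea as fare_based, but on a list of dicts instead of a DataFrame.
    return [int(row["fare"] > threshold) for row in rows]


# Made-up fares, for illustration only.
rows = [{"fare": 7.25}, {"fare": 71.28}, {"fare": 8.05}, {"fare": 53.10}]
clf = TinyFunctionClassifier(fare_rule, threshold=10).fit(rows)
print(clf.predict(rows))                           # [0, 1, 0, 1]
print(clf.set_params(threshold=60).predict(rows))  # [0, 1, 0, 0]
```

Because the threshold lives in the parameter dictionary rather than inside the function body, a search tool can change it with set_params and re-run predict without touching fare_rule itself.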

Because mod behaves like any other scikit-learn model, we can also tune its threshold with GridSearchCV.

import numpy as np
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import precision_score, recall_score, accuracy_score, make_scorer

grid = GridSearchCV(mod,
                    cv=2,
                    # Try 30 evenly spaced thresholds between 0 and 100.
                    param_grid={'threshold': np.linspace(0, 100, 30)},
                    # Record three metrics for every threshold ...
                    scoring={'accuracy': make_scorer(accuracy_score),
                             'precision': make_scorer(precision_score),
                             'recall': make_scorer(recall_score)},
                    # ... but select (and refit) the best threshold by accuracy.
                    refit='accuracy')
grid.fit(X, y)
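The grid search above can be sketched in plain Python to show what happens for each threshold. Because the rule learns nothing, there is no real training step: each held-out fold is simply scored, the per-fold scores are averaged, and the threshold with the best mean accuracy wins (the role of refit='accuracy'). This is an illustration under simplifying assumptions: the folds here are contiguous halves, whereas scikit-learn stratifies the folds for classifiers, and the fares and labels are made up rather than taken from the Titanic data.

```python
# Plain-Python sketch of a 2-fold grid search over the fare threshold.
# Toy data, made up for illustration (not the Titanic dataset).
fares = [5, 8, 12, 30, 60, 7, 9, 25, 80, 15]
survived = [0, 0, 1, 1, 1, 0, 1, 0, 1, 1]


def predict(fare_list, threshold):
    # Same rule as fare_based: 1 ("survived") when the fare exceeds the threshold.
    return [int(f > threshold) for f in fare_list]


def accuracy(y_true, y_pred):
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def precision(y_true, y_pred):
    true_pos = sum(t == 1 and p == 1 for t, p in zip(y_true, y_pred))
    predicted_pos = sum(y_pred)
    # 0.0 when nothing is predicted positive (scikit-learn's default also
    # scores this case as 0.0, with a warning).
    return true_pos / predicted_pos if predicted_pos else 0.0


def recall(y_true, y_pred):
    true_pos = sum(t == 1 and p == 1 for t, p in zip(y_true, y_pred))
    actual_pos = sum(y_true)
    return true_pos / actual_pos if actual_pos else 0.0


METRICS = {"accuracy": accuracy, "precision": precision, "recall": recall}


def grid_search(fares, survived, thresholds):
    half = len(fares) // 2
    # Simplification: two contiguous, unshuffled folds.
    folds = [range(0, half), range(half, len(fares))]
    results = {}
    for threshold in thresholds:
        per_fold = {name: [] for name in METRICS}
        for fold in folds:
            # The rule has nothing to learn, so we only score it on the held-out fold.
            y_true = [survived[i] for i in fold]
            y_pred = predict([fares[i] for i in fold], threshold)
            for name, metric in METRICS.items():
                per_fold[name].append(metric(y_true, y_pred))
        # Average over folds, like the mean_test_* columns in cv_results_.
        results[threshold] = {name: sum(s) / len(s) for name, s in per_fold.items()}
    # Like refit='accuracy': pick the threshold with the highest mean accuracy.
    best = max(thresholds, key=lambda t: results[t]["accuracy"])
    return results, best


results, best = grid_search(fares, survived, thresholds=[0, 10, 20, 50])
print(best)  # 10
for threshold, scores in results.items():
    print(threshold, {name: round(v, 2) for name, v in scores.items()})
```

With this toy data, threshold 10 wins on accuracy, and the mean recall falls from 1.0 at threshold 0 to about 0.33 at threshold 50. That pattern holds for this rule in general: raising the threshold can only shrink the set of passengers predicted to survive, so recall never increases, while precision can move in either direction.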

After fitting, grid.cv_results_ holds the mean test score of every metric for each threshold, so we can chart how the threshold affects accuracy, precision and recall.

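import pandas as pd

# One row per threshold, one column per metric (averaged over the two folds).
score_df = (pd.DataFrame(grid.cv_results_)
            .set_index('param_threshold')
            [['mean_test_accuracy', 'mean_test_precision', 'mean_test_recall']])

# pandas draws the chart with matplotlib, which must be installed.
score_df.plot(figsize=(12, 5), title="scores vs. fare-threshold");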
