How to set a threshold for a sklearn classifier based on ROC results?

This is what I have done:

from sklearn.metrics import confusion_matrix, roc_curve

model = SomeSklearnModel()
model.fit(X_train, y_train)
predict = model.predict(X_test)
# predict_proba returns one column per class; roc_curve expects the
# probability of the positive class, hence the [:, 1] slice
predict_probabilities = model.predict_proba(X_test)[:, 1]
fpr, tpr, thresholds = roc_curve(y_test, predict_probabilities)

However, I am annoyed that predict uses the default 0.5 threshold, which in my case corresponds to a true positive rate of only 0.4% (with zero false positives). The ROC curve shows an operating point I like better for my problem, where the true positive rate is approximately 20% (false positive rate around 4%). I then scan predict_probabilities to find the probability value that corresponds to my favourite ROC point; in my case it is 0.21. Then I create my own predict array:

import numpy as np

predict_mine = np.where(predict_probabilities > 0.21, 1, 0)

and there you go:

confusion_matrix(y_test, predict_mine)

returns what I wanted:

array([[6927,  309],
       [ 621,  121]])
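
For what it's worth, the scan can also be done programmatically: roc_curve returns the thresholds that produce each (fpr, tpr) point, so you can pick the one closest to the operating point you want. A minimal sketch, reusing fpr, tpr, thresholds and predict_probabilities from above (the target rates are the ones quoted above):

import numpy as np

# Index of the ROC point closest to the desired operating point.
target_fpr, target_tpr = 0.04, 0.20
idx = np.argmin((fpr - target_fpr) ** 2 + (tpr - target_tpr) ** 2)
chosen_threshold = thresholds[idx]  # ~0.21 in the example above

predict_mine = np.where(predict_probabilities > chosen_threshold, 1, 0)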
Comments

  • Colis over 1 year

    I trained an ExtraTreesClassifier (Gini index) using scikit-learn and it suits my needs fairly well. The accuracy is not great, but with 10-fold cross-validation the AUC is 0.95. I would like to use this classifier in my work. I am quite new to ML, so please forgive me if I'm asking something conceptually wrong.

    I plotted some ROC curves, and from them it seems I have a specific threshold where my classifier starts performing well. I'd like to set this value on the fitted classifier, so that every time I call predict, the classifier uses that threshold and I can trust the FP and TP rates.

    I also came across this post (scikit .predict() default threshold), where it is stated that a threshold is not a generic concept for classifiers. But since ExtraTreesClassifier has the method predict_proba, and the ROC curve is also defined in terms of thresholds, it seems to me I should be able to specify it.

    I did not find any parameter, nor any class/interface, for doing this. How can I set a threshold for a trained ExtraTreesClassifier (or any other classifier) using scikit-learn? (One possible wrapper is sketched after these comments.)

    Many thanks, Colis

  • Philipp over 4 years
    Keep in mind that the resulting confusion matrix is not a proper measure of out-of-sample performance, because the threshold was chosen on the test data, which is data leakage. The proper way is to split the data into train/validate/test: train the classifier on the train set, choose the threshold on the validation set, and evaluate the final model (threshold included) on the test set (a split sketch follows after these comments).
  • famargar over 2 years
    Yes, you are right; I oversimplified the answer.
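
On the question of baking the threshold into the classifier itself: scikit-learn estimators do not expose a decision-threshold parameter (newer releases, 1.5+, do ship a FixedThresholdClassifier in sklearn.model_selection), but a thin wrapper gets the same effect. A minimal sketch, assuming a fitted binary classifier with predict_proba; ThresholdClassifier is a hypothetical name, not a scikit-learn class:

import numpy as np

class ThresholdClassifier:
    # Hypothetical wrapper: applies a fixed probability threshold to a
    # fitted binary classifier's predict_proba output.
    def __init__(self, base_classifier, threshold=0.5):
        self.base_classifier = base_classifier  # must already be fitted
        self.threshold = threshold

    def predict(self, X):
        # Probability of the positive class (second column).
        proba = self.base_classifier.predict_proba(X)[:, 1]
        return np.where(proba > self.threshold, 1, 0)

# Usage: wrap the fitted model with the threshold chosen from the ROC curve.
clf = ThresholdClassifier(model, threshold=0.21)
predict_mine = clf.predict(X_test)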
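
On Philipp's point about leakage: the threshold must be chosen on data that is not used for the final evaluation. A minimal sketch of the three-way split using two calls to train_test_split (the split proportions are illustrative):

from sklearn.model_selection import train_test_split

# First carve off the final test set, then split the rest into train/validate.
X_rest, X_test, y_rest, y_test = train_test_split(X, y, test_size=0.2, random_state=0)
X_train, X_val, y_train, y_val = train_test_split(X_rest, y_rest, test_size=0.25, random_state=0)  # 0.25 * 0.8 = 20% of the total

model.fit(X_train, y_train)  # fit on train
# ...then choose the threshold on (X_val, y_val) and report on (X_test, y_test)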