from sklearn.calibration import calibration_curve
= calibration_curve(y_true, y_pred, n_bins=10,strategy) y_means, pred_means
calibration
在我们利用机器学习模型来建模分类预测时,首要关注的指标能力当然是dircrimination,即模型的预测区分能力。常见的指标有sensitivity、specificity、AUROC等。我们在上一篇文章中介绍了如何选择最优分类阈值,这里我们接着介绍在选择了最优阈值后,如何评估模型的校准能力。
所谓校准能力,即模型预测的概率与实际发生的概率一致。
通俗来解释这个事情:比如说,我们模型预测某个病人患病的概率是0.8,那么,按照概率定义理解,模型预测概率为0.8时,100个人中应该有80个人最终患病,这个结果体现了模型的校准能力和稳定性。如果模型预测概率为0.8时,实际只有20个人患病,那么,模型的校准能力就不够好,你也不会信任这个模型在实际应用中的预测结果。这就是校准能力的重要性,即你的模型最终输出的概率值要准确反映出事件实际发生的概率。
如何评价calibration
calibration plot
上图是一个典型的calibration curve,也是我们在文章中常见的图。
我们将模型预测概率cut或者quantile成5或者10个区间(bin),每个区间预测概率的均值作为x轴,每个区间实际发生的概率作为y轴,然后画出来这个曲线。这个图是评价模型校准能力的一个直观指标。python中可以轻松实现这个工作:
理想情况下,所有点都在对角线上,即模型预测的概率与实际发生的概率完全一致。如果点在对角线上方,说明模型低估,反之,高估。
calibration level的定义有:
其他指标
除了calibration plot,我们还可以用其他指标来评价模型的校准能力,比如说Brier score、Hosmer-Lemeshow test、calibration in the large等。这里不做详细介绍。
我们感兴趣的是,当我们通过上述方法评价了模型的校准能力后,如果发现模型的校准能力不够好,我们应该怎么办?
calibrate model
我们已经发现,模型输出值并不能代表概率。python中一般有predict_proba方法,即这个方法其实并不能保证输出的概率是真实的概率。
from sklearn.ensemble import RandomForestClassifier
= RandomForestClassifier().fit(X_train, y_train)
model= model.predict_proba(X_test)[:,1] y_pred
所以,我们需要对模型进行校准。
Platt scaling
Platt scaling是一种常见的校准方法,其原理是对模型输出的概率以及真实标签,用一个logistic regression模型来拟合,从而实现对模型输出的概率进行校准,拿到最终的概率。
from sklearn.calibration import CalibratedClassifierCV
= CalibratedClassifierCV(model, method='sigmoid', cv=5)
calibrated calibrated.fit(X_train, y_train)
Isotonic regression
Isotonic regression是另一种校准方法,其原理是对模型输出的概率以及真实标签,用一个isotonic regression模型来拟合,从而实现对模型输出的概率进行校准。
from sklearn.isotonic import IsotonicRegression
= IsotonicRegression().fit(y_pred, y_test) ir
bayesian binning into quantiles
BBQ是一种基于贝叶斯的校准方法,其原理是将预测概率分成若干个区间,然后在每个区间内对概率进行校准。该方法结合了分箱(binning)和贝叶斯推断的优点,可以在样本量较小时仍然保持较好的校准效果。
还有其他方法可以供尝试。
take home message
在利用机器学习模型进行分类预测时,我们不可忽视模型的校准能力。
代码已经放进了星球里。
References
Citation
@online{lu2024,
author = {Lu, Zhen},
title = {Python中机器学习模型的校准},
date = {2024-05-17},
url = {https://leslie-lu.github.io/blog/2024/05/17/calibration/},
langid = {en}
}