A question that comes up constantly: "How do I quantify whether a model is any good?"
This post systematically walks through the multiclass metrics in sklearn.metrics that are both common and sufficient in practice, with a ready-to-run example covering: accuracy, macro/micro/weighted F1, Kappa, MCC, confusion matrices (counts/normalized), Top-K accuracy, ROC-AUC (OvR/OvO), PR-AUC, log loss, a (multiclass) Brier score, and ROC/PR curve plotting.
🧭 Metric overview and when to use each
- Overall validation:
accuracy_score
(OA, overall accuracy) and balanced_accuracy_score
(more meaningful under class imbalance)
- Per-class and weighted:
precision_recall_fscore_support
/ classification_report
, with averaging modes
average="macro" | "micro" | "weighted"
- Agreement / robustness:
cohen_kappa_score
(Kappa) and matthews_corrcoef
(MCC, robust to imbalance)
- Confusion matrix:
confusion_matrix
(counts & normalized)
- Probability / ranking quality:
roc_auc_score
(multiclass: multi_class="ovr"|"ovo"
; average="macro"|"weighted"
), average_precision_score
(PR-AUC), top_k_accuracy_score
(Top-K), log_loss
(log loss, sensitive to calibration), and a multiclass Brier score (custom: mean squared error between one-hot labels and
predict_proba
)
- Curves:
- ROC curve (micro/macro)
- Precision-Recall curve (micro)
Rule of thumb: class imbalance → look at
balanced_accuracy
/ macro-F1
/ Kappa
/ MCC
;
probability quality → look at log_loss
/ ROC-AUC
/ PR-AUC
;
Top-K retrieval / multi-candidate tasks → look at top_k_accuracy_score
.
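The three averaging modes diverge sharply once classes are imbalanced: macro gives every class an equal vote, micro pools all samples (and for multiclass equals accuracy), and weighted scales each per-class F1 by its support. A minimal toy sketch (the labels below are made up for illustration):

```python
# Toy illustration of averaging modes on an imbalanced label set.
from sklearn.metrics import f1_score

y_true = [0, 0, 0, 0, 1, 1, 2]   # class 0 dominates, class 2 has one sample
y_pred = [0, 0, 0, 1, 1, 1, 0]   # class 2 is never predicted

# Per-class F1: class0=0.75, class1=0.80, class2=0.00
f1_macro    = f1_score(y_true, y_pred, average="macro", zero_division=0)
f1_micro    = f1_score(y_true, y_pred, average="micro", zero_division=0)
f1_weighted = f1_score(y_true, y_pred, average="weighted", zero_division=0)

print(f"macro={f1_macro:.3f} micro={f1_micro:.3f} weighted={f1_weighted:.3f}")
# macro=0.517 micro=0.714 weighted=0.657
```

The macro score is dragged down by the missed minority class, while micro (= accuracy here) barely notices it, which is exactly why macro-F1 belongs in an imbalance report.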
💻 Runnable end-to-end code (change DATA_DIR
and run)
# -*- coding: utf-8 -*-
"""
sklearn case 8: metrics walkthrough (multiclass / probabilities and curves)
Data: KSC (change DATA_DIR to your data path)
"""
import os
import numpy as np
import scipy.io as sio
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, label_binarize
from sklearn.decomposition import PCA
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import (
    accuracy_score, balanced_accuracy_score, precision_recall_fscore_support,
    classification_report, confusion_matrix, cohen_kappa_score,
    matthews_corrcoef, top_k_accuracy_score, roc_auc_score,
    average_precision_score, log_loss, roc_curve, precision_recall_curve, auc,
)

# ============ Parameters ============
DATA_DIR = "your_path"  # <- change to the directory containing KSC.mat / KSC_gt.mat
PCA_DIM, TRAIN_RATIO, SEED = 30, 0.3, 42

# ============ 1. Load and preprocess ============
X = sio.loadmat(os.path.join(DATA_DIR, "KSC.mat"))["KSC"].astype(np.float32)  # (H, W, B)
Y = sio.loadmat(os.path.join(DATA_DIR, "KSC_gt.mat"))["KSC_gt"].astype(int)   # (H, W)
coords = np.argwhere(Y != 0)                # keep labeled pixels only
Xpix = X[coords[:, 0], coords[:, 1]]        # (N, B)
y = Y[coords[:, 0], coords[:, 1]] - 1       # labels 0..C-1
num_classes = int(y.max() + 1)

Xtr, Xte, ytr, yte = train_test_split(Xpix, y, train_size=TRAIN_RATIO,
                                      stratify=y, random_state=SEED)
scaler = StandardScaler().fit(Xtr)
pca = PCA(n_components=PCA_DIM, random_state=SEED).fit(scaler.transform(Xtr))
Xtr = pca.transform(scaler.transform(Xtr))
Xte = pca.transform(scaler.transform(Xte))

# ============ 2. Train a model that outputs probabilities ============
# RF as the demo (SVC(probability=True), LogisticRegression, etc. work too)
clf = RandomForestClassifier(n_estimators=300, random_state=SEED, n_jobs=-1)
clf.fit(Xtr, ytr)
y_pred = clf.predict(Xte)
y_proba = clf.predict_proba(Xte)  # (N_test, C)

# ============ 3. Basic / robust metrics ============
oa = accuracy_score(yte, y_pred)
boa = balanced_accuracy_score(yte, y_pred)
kappa = cohen_kappa_score(yte, y_pred)
mcc = matthews_corrcoef(yte, y_pred)
prec_m, rec_m, f1_m, _ = precision_recall_fscore_support(yte, y_pred, average="macro", zero_division=0)
prec_w, rec_w, f1_w, _ = precision_recall_fscore_support(yte, y_pred, average="weighted", zero_division=0)

print("=== Basic evaluation ===")
print(f"OA                : {oa*100:.2f}%")
print(f"Balanced Acc      : {boa*100:.2f}%")
print(f"Macro-F1          : {f1_m*100:.2f}% (P={prec_m*100:.1f} R={rec_m*100:.1f})")
print(f"Weighted-F1       : {f1_w*100:.2f}% (P={prec_w*100:.1f} R={rec_w*100:.1f})")
print(f"Cohen's Kappa     : {kappa:.4f}")
print(f"Matthews Corrcoef : {mcc:.4f}")
print("\n=== Classification report (per class) ===")
print(classification_report(yte, y_pred, digits=4, zero_division=0))

# ============ 4. Confusion matrix (counts / normalized) ============
cm = confusion_matrix(yte, y_pred, labels=np.arange(num_classes))
cm_norm = cm / np.maximum(cm.sum(axis=1, keepdims=True), 1)  # row-normalize, guard empty rows

plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plt.imshow(cm, interpolation="nearest")
plt.title("Confusion Matrix (Counts)")
plt.xlabel("Pred"); plt.ylabel("True")
plt.colorbar(fraction=0.046, pad=0.04)

plt.subplot(1, 2, 2)
plt.imshow(cm_norm, vmin=0, vmax=1, interpolation="nearest")
plt.title("Confusion Matrix (Normalized)")
plt.xlabel("Pred"); plt.ylabel("True")
plt.colorbar(fraction=0.046, pad=0.04)
plt.tight_layout(); plt.show()

# ============ 5. Probability / ranking quality ============
# 5.1 Multiclass ROC-AUC: OvR & OvO (macro / weighted)
y_bin = label_binarize(yte, classes=np.arange(num_classes))  # (N, C) one-hot
auc_ovr_macro  = roc_auc_score(yte, y_proba, multi_class="ovr", average="macro")
auc_ovr_weight = roc_auc_score(yte, y_proba, multi_class="ovr", average="weighted")
auc_ovo_macro  = roc_auc_score(yte, y_proba, multi_class="ovo", average="macro")
print("\n=== Probability / ranking quality ===")
print(f"ROC-AUC OvR (macro)   : {auc_ovr_macro:.4f}")
print(f"ROC-AUC OvR (weighted): {auc_ovr_weight:.4f}")
print(f"ROC-AUC OvO (macro)   : {auc_ovo_macro:.4f}")

# 5.2 PR-AUC (macro)
ap_macro = average_precision_score(y_bin, y_proba, average="macro")
print(f"PR-AUC (macro)        : {ap_macro:.4f}")

# 5.3 Log loss
ll = log_loss(yte, y_proba, labels=np.arange(num_classes))
print(f"Log Loss              : {ll:.4f}")

# 5.4 Multiclass Brier (custom: mean squared error between one-hot labels and predict_proba)
brier_multi = np.mean((y_bin - y_proba) ** 2)
print(f"Brier Score (multi)   : {brier_multi:.4f}")

# 5.5 Top-K accuracy (K=3 as an example)
top3 = top_k_accuracy_score(yte, y_proba, k=3, labels=np.arange(num_classes))
print(f"Top-3 Accuracy        : {top3*100:.2f}%")

# ============ 6. Curves: micro-ROC and micro-PR ============
# micro: pool all classes into one "overall binary" problem, so one plot can compare models
fpr_micro, tpr_micro, _ = roc_curve(y_bin.ravel(), y_proba.ravel())
roc_auc_micro = auc(fpr_micro, tpr_micro)

prec_micro, rec_micro, _ = precision_recall_curve(y_bin.ravel(), y_proba.ravel())
pr_auc_micro = auc(rec_micro, prec_micro)

plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plt.plot(fpr_micro, tpr_micro, lw=2, label=f"micro-ROC AUC={roc_auc_micro:.3f}")
plt.plot([0, 1], [0, 1], "--", lw=1)
plt.xlabel("FPR"); plt.ylabel("TPR")
plt.title("ROC (micro-average)")
plt.legend(frameon=False)

plt.subplot(1, 2, 2)
plt.plot(rec_micro, prec_micro, lw=2, label=f"micro-PR AUC={pr_auc_micro:.3f}")
plt.xlabel("Recall"); plt.ylabel("Precision")
plt.title("Precision-Recall (micro-average)")
plt.legend(frameon=False)
plt.tight_layout(); plt.show()
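The Top-K semantics in step 5.5 are easy to verify by hand: a sample counts as a hit when its true class appears among the K highest-scoring classes. A small sketch with made-up scores, checking top_k_accuracy_score against an argsort-based computation:

```python
import numpy as np
from sklearn.metrics import top_k_accuracy_score

y_true = np.array([0, 1, 2, 2])
scores = np.array([
    [0.6, 0.3, 0.1],   # true 0: top-1 hit
    [0.4, 0.5, 0.1],   # true 1: top-1 hit
    [0.2, 0.5, 0.3],   # true 2: only a top-2 hit
    [0.7, 0.2, 0.1],   # true 2: miss even at K=2
])

top2 = top_k_accuracy_score(y_true, scores, k=2, labels=[0, 1, 2])

# Manual check: is the true label among the indices of the 2 largest scores?
top2_idx = np.argsort(scores, axis=1)[:, -2:]
manual = np.mean([t in row for t, row in zip(y_true, top2_idx)])

print(top2, manual)  # 0.75 0.75
```

Both agree: three of four samples have the true class in the top two, so Top-2 accuracy is 0.75.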
✅ Practical takeaways (how to choose metrics)
- A one-page report:
OA + macro-F1 + Kappa + MCC + confusion matrix (normalized)
Together these cover overall performance, per-class behavior, and robustness, and they stay meaningful under imbalance. - When probability quality matters: add
log_loss + ROC-AUC (ovr, macro) + PR-AUC (macro)
;
for "hit within multiple candidates" tasks, also add Top-K
. - For presentation and communication: curves (ROC/PR) are more intuitive, and a normalized confusion matrix pinpoints easily confused classes.
- Avoiding pitfalls: under severe class imbalance,
accuracy
alone is misleading; for threshold-tunable tasks (alerting/retrieval), favor PR-AUC and the Precision-Recall curve.
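The pitfall in the last point is easy to demonstrate on a made-up 95:5 dataset: a degenerate classifier that always predicts the majority class still scores high accuracy, while balanced accuracy and MCC immediately expose it.

```python
import numpy as np
from sklearn.metrics import accuracy_score, balanced_accuracy_score, matthews_corrcoef

y_true = np.array([0] * 95 + [1] * 5)   # 95:5 imbalance
y_pred = np.zeros(100, dtype=int)       # always predicts the majority class

print(accuracy_score(y_true, y_pred))           # 0.95 — looks great
print(balanced_accuracy_score(y_true, y_pred))  # 0.5  — chance level
print(matthews_corrcoef(y_true, y_pred))        # 0.0  — no correlation at all
```

This is why the report page above pairs OA with balanced accuracy, macro-F1, Kappa, and MCC instead of trusting accuracy alone.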