知识蒸馏 - 通过引入温度参数T调整 Softmax 的输出
flyfish
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np# 设置中文字体支持
plt.rcParams["font.family"] = ['AR PL UMing CN'] # Linux
plt.rcParams['axes.unicode_minus'] = False # 解决负号显示问题# 模拟模型输出的logits
logits = torch.tensor([10.0, 4.0, 1.0])# 定义不同的温度值
temperatures = [0.5, 1.0, 5.0, 10.0, 20.0]# 计算不同温度下的softmax输出
results = {}
for T in temperatures:soft_labels = F.softmax(logits / T, dim=0)results[T] = soft_labels.numpy()# 打印结果(保留四位小数)
print("原始logits:", logits.numpy())
for T, soft_labels in results.items():# 使用列表推导式和格式化字符串保留四位小数formatted_probs = [f"{p:.4f}" for p in soft_labels]print(f"温度 T={T} 时的软标签: [{', '.join(formatted_probs)}]")# 可视化不同温度下的概率分布
plt.figure(figsize=(14, 7))
x = np.arange(len(logits))
width = 0.8 / len(temperatures)for i, (T, soft_labels) in enumerate(results.items()):bars = plt.bar(x + i * width - 0.4 + width/2, soft_labels, width, label=f'T={T}')# 在每个柱子上方添加保留四位小数的概率值for bar in bars:height = bar.get_height()plt.text(bar.get_x() + bar.get_width()/2., height + 0.01,f'{height:.4f}', ha='center', va='bottom', rotation=90)plt.xticks(x, ['猫', '狗', '狐狸'])
plt.ylabel('概率')
plt.title('不同温度T下的softmax概率分布')
plt.legend()
plt.grid(axis='y', linestyle='--', alpha=0.7)
plt.ylim(0, 1.1) # 调整y轴范围,使标签显示完整
plt.tight_layout()
plt.show()
原始logits: [10. 4. 1.]
温度 T=0.5 时的软标签: [1.0000, 0.0000, 0.0000]
温度 T=1.0 时的软标签: [0.9974, 0.0025, 0.0001]
温度 T=5.0 时的软标签: [0.6819, 0.2054, 0.1127]
温度 T=10.0 时的软标签: [0.5114, 0.2807, 0.2079]
温度 T=20.0 时的软标签: [0.4204, 0.3115, 0.2681]
低温(T=0.5):分布极陡峭,几乎只保留最大值对应的类别(猫)
标准温度(T=1.0):接近传统 softmax,突出最大值但保留少量其他类别概率
高温(T=10.0):分布非常平滑,所有类别概率接近均等
对于给定的logits向量z=[z1,z2,...,zk]\mathbf{z} = [z_1, z_2, ..., z_k]z=[z1,z2,...,zk](其中ziz_izi是模型对第iii类的原始输出分数,比如代码中的logits = [10.0, 4.0, 1.0]
),以及温度参数TTT,第iii类的软标签概率pip_ipi计算公式为:
pi=ezi/T∑j=1kezj/T p_i = \frac{e^{z_i / T}}{\sum_{j=1}^{k} e^{z_j / T}} pi=∑j=1kezj/Tezi/T
解释:
ziz_izi:代码中的logits[i]
(如logits[0] = 10.0
对应“猫”的原始分数);
TTT:代码中的温度参数(如T=0.5,1.0,5.0
等);
ezi/Te^{z_i / T}ezi/T:对“原始分数除以温度”做指数运算(代码中由F.softmax
内部实现);
分母∑j=1kezj/T\sum_{j=1}^{k} e^{z_j / T}∑j=1kezj/T:所有类别的指数结果之和,用于归一化(确保所有概率之和为1);
pip_ipi:最终的软标签概率(代码中soft_labels[i]
,如“猫”在T=5.0
时的概率约为0.6811)。
作用:
通过温度TTT缩放logits的“差异幅度”:
当T→0+T \to 0^+T→0+时,指数部分对大的ziz_izi更敏感,概率分布会极度陡峭(接近硬标签);
当T→+∞T \to +\inftyT→+∞时,所有zi/Tz_i / Tzi/T趋近于0,指数结果趋近于1,概率分布会趋近均匀(所有类别概率接近相等)。
如T=0.5
时“猫”的概率接近1,T=20
时三类概率更均匀。
在知识蒸馏(Knowledge Distillation)中,引入温度参数TTT 调整 Softmax 输出的核心目的是获取更有信息量的“软标签”(Soft Labels),以便让学生模型(Student Model)更好地学习教师模型(Teacher Model)的“知识”。温度TTT 的核心作用是通过“软化”教师模型的输出分布,保留更多关于类别间关系的细粒度知识,让学生模型能更有效地学习教师的经验。
原因
1. 原始 Softmax(T=1T=1T=1)的局限性
原始 Softmax 函数的公式为:
pi=ezi∑jezj
p_i = \frac{e^{z_i}}{\sum_{j} e^{z_j}}
pi=∑jezjezi
其中ziz_izi 是模型输出的 logits(未归一化的分数)。
当模型对正确类别有较高置信度时(比如教师模型很“确信”某个样本是“猫”),原始 Softmax 的输出会极度集中在最大 logits 对应的类别上,其他类别的概率几乎为 0(例如:p猫≈0.999p_{\text{猫}} \approx 0.999p猫≈0.999,p狗≈0.001p_{\text{狗}} \approx 0.001p狗≈0.001,p狐狸≈0p_{\text{狐狸}} \approx 0p狐狸≈0)。
这种“陡峭”的概率分布(接近硬标签)丢失了很多有价值的信息:教师模型可能认为“狗”比“狐狸”更接近“猫”(即p狗>p狐狸p_{\text{狗}} > p_{\text{狐狸}}p狗>p狐狸),但原始 Softmax 会将这种差异压缩到几乎不可见。
2. 温度TTT 的作用:“软化”概率分布,保留更多知识
当引入温度TTT 后,Softmax 公式变为:
pi=ezi/T∑jezj/T
p_i = \frac{e^{z_i / T}}{\sum_{j} e^{z_j / T}}
pi=∑jezj/Tezi/T
当T>1T > 1T>1 时:logits 被“缩放”(除以TTT),导致指数函数的“敏感度”降低,不同类别的概率差异被拉平(分布更平缓)。
例如,教师模型对“猫”“狗”“狐狸”的 logits 为 [10, 4, 1]:
T=1T=1T=1 时,输出可能是 [0.997, 0.002, 0.001](几乎只有“猫”有概率);
T=10T=10T=10 时,输出可能是 [0.607, 0.242, 0.151](保留了“狗比狐狸更接近猫”的信息)。
这种“软化”的软标签包含了教师模型对类别间相似性的判断(哪些类别容易混淆、哪些类别差异大),这些信息比单纯的硬标签(如“猫”)更丰富,能帮助学生模型学习到更鲁棒的特征。
3. 知识蒸馏中的配合使用
在知识蒸馏中,教师模型用高温TTT 生成软标签,学生模型在训练时既学习软标签(用相同的TTT),也学习原始硬标签(可选)。推理时,学生模型再用T=1T=1T=1 输出最终的硬预测。
通过这种方式,学生模型不仅学到了“正确答案”,还学到了教师模型的“推理过程”(如何权衡不同类别的可能性),从而在参数更少的情况下达到接近教师模型的性能。