文章目录
- 什么是 Warmup?
- 实现示例
- 科学设置 Warmup 的黄金法则
- 直观例子
什么是 Warmup?
Warmup 是一种学习率调度策略,在训练初期逐步增加学习率(LR),而不是直接使用目标学习率。它解决了两个关键问题:
- 避免早期震荡:模型参数初始化为随机值,直接高LR会导致不稳定更新。
- 稳定Adam优化器:Adam的动量估计在初始阶段不准确,需要渐进调整。
实现示例
def get_lr(step, warmup_steps, d_model):# 1. 预热阶段:线性增长if step < warmup_steps:return base_lr * (step / warmup_steps)# 2. 衰减阶段:反平方根衰减scale = (warmup_steps ** 0.5) * min(step ** (-0.5), step * (warmup_steps ** (-1.5))return base_lr * scale
科学设置 Warmup 的黄金法则
- NNN = 总样本数(条)
- BBB = 每次 forward 的原始 batch(每卡)
- AAA = 梯度累积步数(
accum_grad
) - EEE = epoch 数
- WWW = 卡数(
world_size
,单卡 = 1)
- 每 epoch 的 forward 次数(向上取整):
I=⌈N/(B×W)⌉I = \lceil N / (B \times W) \rceil I=⌈N/(B×W)⌉
- 每 epoch 的 optimizer 更新次数(每 AAA 次 forward 做一次 update,向上取整):
S=⌈I/A⌉S = \lceil I / A \rceil S=⌈I/A⌉
- 总 optimizer step(也就是 scheduler 用的 total steps):
T=S×ET = S \times E T=S×E
- 推荐的 warmup 步数:
warmup={max{⌈0.10×T⌉,10},T<4000clamp(⌊0.05×T⌉,4000,20000),T≥4000\text{warmup} = \begin{cases} \max\{\lceil 0.10 \times T \rceil, 10\}, & T < 4000\\[4pt] \operatorname{clamp}(\, \lfloor 0.05 \times T \rceil,\; 4000,\; 20000 \,), & T \ge 4000 \end{cases} warmup={max{⌈0.10×T⌉,10},clamp(⌊0.05×T⌉,4000,20000),T<4000T≥4000
并且最终确保 warmup≤T−1\text{warmup} \le T-1warmup≤T−1。
解释:小训练用 10%,大训练用 5%,并在 4k–20k 之间限制
直观例子
假设 B=16,A=8,E=120,W=1B=16, A=8, E=120, W=1B=16,A=8,E=120,W=1:
- 若 N=100,000N = 100{,}000N=100,000:
- I=⌈100000/16⌉=6250I=\lceil100000/16\rceil=6250I=⌈100000/16⌉=6250
- S=⌈6250/8⌉=782S=\lceil6250/8\rceil=782S=⌈6250/8⌉=782
- T=782×120=93,840T=782\times120=93{,}840T=782×120=93,840
- warmup ≈ ⌊0.05×93840⌉=4,692\lfloor0.05\times93840\rceil=4{,}692⌊0.05×93840⌉=4,692(取 4000–20000 区间内 → 4692)
- 若 N=1,000,000N = 1{,}000{,}000N=1,000,000:
- I=⌈1000000/16⌉=62500I=\lceil1000000/16\rceil=62500I=⌈1000000/16⌉=62500
- S=⌈62500/8⌉=7813S=\lceil62500/8\rceil=7813S=⌈62500/8⌉=7813
- T=7813×120=937,560T=7813\times120=937{,}560T=7813×120=937,560
- 0.05×T ≫ 20000 → clamp → 20000
- 若 N=100N = 100N=100(极小样本、仅作示例):
- I=⌈100/16⌉=7I=\lceil100/16\rceil=7I=⌈100/16⌉=7
- S=⌈7/8⌉=1S=\lceil7/8\rceil=1S=⌈7/8⌉=1
- T=1×120=120T=1\times120=120T=1×120=120
- 因为 T<4000T<4000T<4000,warmup = max(ceil(0.1×120),10) = 12
import mathdef suggest_warmup(N,B,A,E,W=1):I = math.ceil(N / (B*W))S = math.ceil(I / A)T = S * Eif T < 4000:w = max(math.ceil(0.10*T), 10)else:w = round(0.05*T)w = max(4000, min(w, 20000))w = min(w, T-1)return {"iters_per_epoch":I, "opt_steps_per_epoch":S, "total_steps":T, "warmup":w}print(suggest_warmup(403733, 16, 8, 120, 4))
# {'iters_per_epoch': 6309, 'opt_steps_per_epoch': 789, 'total_steps': 94680, 'warmup': 4734}print(suggest_warmup(403733, 16, 4, 100, 4))
# {'iters_per_epoch': 6309, 'opt_steps_per_epoch': 1578, 'total_steps': 157800, 'warmup': 7890}