梯度检查点(Gradient Checkpointing)是一种在深度学习训练中优化显存使用的技术,尤其适用于处理大型模型(如Transformer架构)时显存不足的情况。下面用简单的例子解释其工作原理和优缺点:
核心原理
深度学习训练中的显存占用主要来自三个方面:
- 模型参数(如权重、偏置)
- 优化器状态(如Adam的动量项)
- 中间激活值(forward过程中产生的张量,如注意力图、隐藏层输出等)
其中,中间激活值通常占用最大的显存空间,尤其是在深层网络中。梯度检查点的核心思想是:
- 在正向传播时:只保存少量关键的中间结果(称为“检查点”),其余中间值在计算后立即丢弃。
- 在反向传播时:利用保存的检查点重新计算被丢弃的中间值,从而获得计算梯度所需的全部信息。
如下图所示:
这种方法通过牺牲计算时间(重新计算)来节省显存空间(无需保存所有中间值)。
为什么需要梯度检查点?
假设你有一个包含100层的Transformer模型,每层在forward过程中产生1GB的中间激活值:
- 传统训练:需要保存所有100层的中间值,总显存需求为100GB。
- 梯度检查点:只保存10个检查点(每层1GB),反向传播时通过检查点重新计算其余90层,总显存需求降至10GB。
代码中的应用
在你的代码中,gradient_checkpointing=True
的配置会使模型在训练时启用梯度检查点:
trainable_model = Model(# ...其他参数gradient_checkpointing=training_config.get("gradient_checkpointing", False),# ...
)
这意味着:
- 正向传播时,模型不会保存所有注意力图和隐藏层输出。
- 反向传播时,PyTorch会利用检查点重新计算这些值,从而减少显存占用。
优缺点
- 优点:显著减少显存使用(通常能节省30%-50%的显存),允许训练更大的模型或使用更大的批次大小。
- 缺点:增加训练时间(通常慢20%-30%),因为需要重新计算中间值。
何时使用?
- 显存不足:当模型因显存限制无法训练时,梯度检查点是一种有效的解决方案。
- 计算资源充足:如果你的GPU算力充足但显存有限,可以通过延长训练时间换取更小的显存占用。
技术细节
在PyTorch中,梯度检查点通过torch.utils.checkpoint
模块实现。例如:
from torch.utils.checkpoint import checkpointdef forward(self, x):# 普通forward:保存所有中间值x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)return x# 使用梯度检查点:只保存关键检查点
def forward(self, x):x = checkpoint(self.layer1, x) # 只保存layer1的输出x = checkpoint(self.layer2, x) # 只保存layer2的输出x = self.layer3(x)return x
PyTorch Lightning的gradient_checkpointing
参数会自动为模型的所有层应用这种优化。