PyTorch DDP 随机卡死复盘:最后一个 batch 挂起,NCCL 等待不返回,三步修复 Sampler & drop_last

很多人在接触深度学习的过程往往都是从自己的笔记本开始的,但是从接触工作后,更多的是通过分布式的训练来模型。由于经验的不足常常会遇到分布式训练“玄学卡死”:多卡的训练偶发在 epoch 尾部停住不动,并且GPU 利用率掉到 0%,日志无异常。为了解决首次接触分布式训练的人的疑问,本文从bug现象以及调试逐一分析。

❓ Bug 现象

在我们进行多卡训练的时候,偶尔会出现随机在某些 epoch 尾部卡住,无异常栈;nvidia-smi 显示两卡功耗接近空闲。偶尔能看到 NCCL 打印(并不总出现):

NCCL WARN Reduce failed: ... Async operation timed out

接着通过kill -SIGQUIT 打印 Python 栈后发现停在 反向传播的梯度 allreduce*上(DistributedDataParallel 内部)。

但是这个现象在关掉 DDP(单卡训练)完全正常;把 batch_size 改小/大,卡住概率改变但仍会发生。

📽️ 场景重现

当我们的问题在单卡不会出现,但是多卡会出现问题的时候,问题点集中在数据的问题上。主要原因以下:

1️⃣ shuffle=TrueDistributedSampler 混用(会被忽略但容易误导)。

2️⃣ drop_last=False 时,最后一个小批的样本数在不同 rank 上可能不一致(当 len(dataset) 不是 world_size 的整数倍且某些数据被过滤/增强丢弃时尤其明显)。

3️⃣ 每个 epoch 忘记调用 sampler.set_epoch(epoch),导致各 rank 的随机顺序不一致

以下是笔者在多卡训练遇到的问题代码

import os, random, torch, torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, Dataset, DistributedSamplerclass DummyDS(Dataset):def __init__(self, N=1003):  # 刻意设成非 world_size 整数倍self.N = Ndef __len__(self): return self.Ndef __getitem__(self, i):x = torch.randn(32, 3, 224, 224)y = torch.randint(0, 10, (32,))   # 模拟有时会丢弃某些样本的增强(省略)return x, ydef setup():dist.init_process_group("nccl")torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))def main():setup()rank = dist.get_rank()device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))ds = DummyDS()sampler = DistributedSampler(ds, shuffle=True, drop_last=False)  # ❌ drop_last=False# ❌ DataLoader 里又写了 shuffle=True(被忽略,但容易误以为生效)loader = DataLoader(ds, batch_size=2, shuffle=True, sampler=sampler, num_workers=4)model = torch.nn.Linear(3*224*224, 10).to(device)model = DDP(model, device_ids=[device.index])opt = torch.optim.SGD(model.parameters(), lr=0.1)for epoch in range(5):# ❌ 忘记 sampler.set_epoch(epoch)for x, y in loader:x = x.view(x.size(0), -1).to(device)y = y.to(device)opt.zero_grad()loss = torch.nn.functional.cross_entropy(model(x), y)loss.backward()      # 🔥 偶发卡在这里(allreduce)opt.step()if rank == 0:print(f"epoch {epoch} done")dist.destroy_process_group()if __name__ == "__main__":main()

🔴触发条件(满足一两个就可能复现):

1️⃣ len(dataset) 不是 world_size 的整数倍。

2️⃣ 动态数据过滤/增强(例如有时返回 None 或丢样),导致各 rank 实际步数不同。

3️⃣ 忘记 sampler.set_epoch(epoch),各 rank 洗牌次序不同。

4️⃣ drop_last=False,导致最后一个 batch 在各 rank 的样本数不同。

5️⃣ 某些自定义 collate_fn 在“空 batch”时直接 continue

✔️ Debug

1️⃣ 先确认“各 rank 步数一致”

在训练 loop 里加统计(不要只在 rank0 打印):

from collections import Counter
steps = Counter()
for i, _ in enumerate(loader):steps[rank] += 1
dist.all_reduce(torch.tensor([steps[rank]], device=device), op=dist.ReduceOp.SUM)
# 或每个 rank 各自 print,检查是否相等

我的现象:有的 epoch,rank0 比 rank1 多 1–2 个 step

2️⃣开启 NCCL 调试

在启动前设置:

export NCCL_DEBUG=INFO
export NCCL_ASYNC_ERROR_HANDLING=1
export NCCL_BLOCKING_WAIT=1

再跑一遍,可看到某些 allreduce 一直等不到某 rank 进来。

3️⃣检查 Sampler 与 DataLoader 参数
  • DistributedSampler 必须搭配 sampler.set_epoch(epoch)
  • DataLoader 里不要再写 shuffle=True
  • 若数据不可整除,优先 drop_last=True;否则确保各 rank 最后一个 batch 大小一致(例如补齐/填充)。

🟢 解决方案(修复版)

  • 严格对齐 Sampler 语义 + 丢最后不齐整的 batch
import os, torch, torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler, Datasetclass DummyDS(Dataset):def __init__(self, N=1003): self.N=Ndef __len__(self): return self.Ndef __getitem__(self, i):x = torch.randn(32, 3, 224, 224)y = torch.randint(0, 10, (32,))return x, ydef setup():dist.init_process_group("nccl")torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))def main():setup()rank = dist.get_rank()device = torch.device("cuda", int(os.environ["LOCAL_RANK"]))ds = DummyDS()# 关键 1:使用 DistributedSampler,统一交给它洗牌sampler = DistributedSampler(ds, shuffle=True, drop_last=True)  # ✅# 关键 2:DataLoader 里不要再写 shuffleloader = DataLoader(ds, batch_size=2, sampler=sampler, num_workers=4, pin_memory=True)model = torch.nn.Linear(3*224*224, 10).to(device)ddp = DDP(model, device_ids=[device.index], find_unused_parameters=False)  # 如无动态分支,关掉更稳更快opt = torch.optim.SGD(ddp.parameters(), lr=0.1)for epoch in range(5):sampler.set_epoch(epoch)  # ✅ 关键 3:每个 epoch 设置不同随机种子for x, y in loader:x = x.view(x.size(0), -1).to(device, non_blocking=True)y = y.to(device, non_blocking=True)opt.zero_grad(set_to_none=True)loss = torch.nn.functional.cross_entropy(ddp(x), y)loss.backward()opt.step()if rank == 0:print(f"epoch {epoch} ok")dist.barrier()  # ✅ 收尾同步,避免 rank 提前退出dist.destroy_process_group()if __name__ == "__main__":main()
  • 必须保留最后一批

如果确实不能 drop_last=True(例如小数据集),可考虑对齐 batch 大小

  1. Padding/Repeat:在 collate_fn 里把最后一批补齐到一致大小
  2. EvenlyDistributedSampler:自定义 sampler,确保各 rank 拿到完全等长的 index 列表(对总长度做上采样)。

示例(最简单的“循环补齐”):

class EvenSampler(DistributedSampler):def __iter__(self):# 先拿到原始 index,再做均匀补齐indices = list(super().__iter__())# 使得 len(indices) 可整除 num_replicasrem = len(indices) % self.num_replicasif rem != 0:pad = self.num_replicas - remindices += indices[:pad]     # 简单重复前几个样本return iter(indices)

总结

以上是这次 DDP 卡死问题从现象 → 排查 → 解决的完整记录。这个坑非常高频,尤其在课程项目/科研代码里常被忽视。希望这篇复盘能让你在分布式训练时少掉一把汗。最终定位是 DistributedSampler 使用不当 + drop_last=False + 忘记 set_epoch引发各 rank 步数不一致,导致 allreduce 永久等待。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。
如若转载,请注明出处:http://www.pswp.cn/bicheng/95858.shtml
繁体地址,请注明出处:http://hk.pswp.cn/bicheng/95858.shtml
英文地址,请注明出处:http://en.pswp.cn/bicheng/95858.shtml

如若内容造成侵权/违法违规/事实不符,请联系英文站点网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

计算机专业考研备考建议

对于全国硕士研究生招生考试(考研),考试科目主要由两大部分组成:全国统一命题的公共课 和 由招生单位自主命题的专业课。具体的考试科目取决于你报考的专业和学校。下面我为你详细拆解:一、考试科目构成(绝…

关于嵌入式学习——单片机1

基础整体概念以应用为中心:消费电子(手机、蓝牙耳机、智能音响)、医疗电子(心率脉搏、呼吸机)、无人机(大疆D)、机器人(人形四足机器人) 计算机技术:计算机五大组成:运算器(数据运算)、控制器(指令控制)、存储器(内存外存)、输入设备(鼠标、键盘、摄像头)、输出设备(显示器)软件…

LightDock.server liunx 双跑比较

LightDock: a new multi-scale approach to protein–protein docking The LightDock server is free and open to all users and there is no login requirement server 1示例 故去除约束 next step 结果有正有负合理 2.常见警告⚠ Structure contains HETATM entries. P…

SQL面试题及详细答案150道(61-80) --- 多表连接查询篇

《前后端面试题》专栏集合了前后端各个知识模块的面试题,包括html,javascript,css,vue,react,java,Openlayers,leaflet,cesium,mapboxGL,threejs,nodejs,mangoDB,MySQL,Linux… 。 前后端面试题-专栏总目录 文章目录 一、本文面试题目录 61. 什么是内连接(INNE…

【实操】Noej4图数据库安装和mysql表衔接实操

目录 一、图数据库介绍 二、安装Neo4j 2.1 安装java环境 2.2 安装 Neo4j(社区版) 2.3 修改配置 2.4 验证测试 2.5 卸载 2.6 基本用法 2.7 windows连接服务器可视化 三、neo4j和mysql对比 3.1 场景对比 3.2 Mysql和neo4j的映射对比 3.3 mys…

【mysql】SQL查询全解析:从基础分组到高级自连接技巧

SQL查询全解析:从基础分组到高级自连接技巧详解玩家首次登录查询的多种实现方式与优化技巧在数据库查询中,同一个需求往往有多种实现方式。本文将通过"查询每个玩家第一次登录的日期"这一常见需求,深入解析SQL查询的多种实现方法&a…

MySQL常见报错分析及解决方案总结(9)---出现interactive_timeout/wait_timeout

关于超时报错,一共有五种超时参数,详见:MySQL常见报错分析及解决方案总结(7)---超时参数connect_timeout、interactive_timeout/wait_timeout、lock_wait_timeout、net等-CSDN博客 以下是当前报错的排查方法和解决方案: MySQL 中…

第13章 Jenkins性能优化

13.1 性能优化概述 性能问题识别 常见性能瓶颈: Jenkins性能问题分类:1. 系统资源瓶颈- CPU使用率过高- 内存不足或泄漏- 磁盘I/O瓶颈- 网络带宽限制2. 应用层面问题- JVM配置不当- 垃圾回收频繁- 线程池配置问题- 数据库连接池不足3. 架构设计问题- 单点…

Python+DRVT 从外部调用 Revit:批量创建梁

今天让我们继续,看看如何批量创建常用的基础元素:梁。 跳过轴线为直线段形的,先从圆弧形的开始: from typing import List, Tuple import math # drvt_pybind 支持多会话、多文档,先从简单的单会话、单文档开始 # My…

水上乐园票务管理系统设计与开发(代码+数据库+LW)

摘 要 随着旅游业的蓬勃发展,水上乐园作为夏日娱乐的重要组成部分,其票务管理效率和服务质量直接影响游客体验。然而,传统的票务管理模式往往面临信息更新不及时、服务响应慢等问题。因此,本研究旨在通过设计并实现一个基于Spri…

【前端教程】JavaScript DOM 操作实战案例详解

案例1&#xff1a;操作div子节点并修改样式与内容 功能说明 获取div下的所有子节点&#xff0c;设置它们的背景颜色为红色&#xff1b;如果是p标签&#xff0c;将其内容设置为"我爱中国"。 实现代码 <!DOCTYPE html> <html> <head><meta ch…

qiankun+vite+react配置微前端

微前端框架&#xff1a;qiankun。 主应用&#xff1a;react19vite7&#xff0c;子应用1&#xff1a;react19vite7&#xff0c;子应用2 &#xff1a;react19vite7 一、主应用 1. 安装依赖 pnpm i qiankun 2. 注册子应用 (1) 在src目录下创建个文件夹&#xff0c;用来存储关于微…

git: 取消文件跟踪

场景&#xff1a;第一次初始化仓库的时候没有忽略.env或者node_modules&#xff0c;导致后面将.env加入.gitignore也不生效。 取消文件跟踪&#xff1a;如果是因为 node_modules 已被跟踪导致忽略无效&#xff0c; 可以使用命令git rm -r --cached node_modules来删除缓存&…

开讲啦|MBSE公开课:第五集 MBSE中期设想(下)

第五集 在本集课程中&#xff0c;刘玉生教授以MBSE建模工具选型及二次定制开发为核心切入点&#xff0c;系统阐释了"为何需要定制开发"与"如何实施定制开发"的实践逻辑&#xff0c;并提炼出MBSE中期实施的四大核心要素&#xff1a;高效高质建摸、跨域协同…

CSDN个人博客文章全面优化过程

两天前达到博客专家申请条件&#xff0c;兴高采烈去申请博客专家&#xff1a; 结果今天一看&#xff0c;申请被打回了&#xff1a; 我根据“是Yu欸”大神的博客&#xff1a; 【2024-完整版】python爬虫 批量查询自己所有CSDN文章的质量分&#xff1a;附整个实现流程_抓取csdn的…

Websocket的Key多少个字节

在WebSocket协议中&#xff0c;握手过程中的Sec-WebSocket-Key是一个由客户端生成的随机字符串&#xff0c;用于安全地建立WebSocket连接。这个Sec-WebSocket-Key是基于Base64编码的&#xff0c;并且通常由客户端在WebSocket握手请求的头部字段中发送。根据WebSocket协议规范&a…

SVT-AV1编码器中实现WPP依赖管理核心调度

一 assign_enc_dec_segments 函数。这个函数是 SVT-AV1 编码器中实现波前并行处理&#xff08;WPP&#xff09; 和分段依赖管理的核心调度器之一。//函数功能&#xff1a;分配编码解码段任务//返回值Bool//True 成功分配了一个段给当前线程&#xff0c;调用者应该处理这个段//F…

直接让前端请求代理到自己的本地服务器,告别CV报文到自己的API工具,解放双手

直接使用前端直接调用本地服务器&#xff0c;在自己的浏览器搜索插件proxyVerse&#xff0c;类似的插件应该还有一些&#xff0c;可以选择自己喜欢的这类插件可以将浏览器请求&#xff0c;直接转发到本地服务器&#xff0c;这样在本地调试的时候&#xff0c;不需要前端项目&…

Golang Goroutine 与 Channel:构建高效并发程序的基石

在当今这个多核处理器日益普及的时代&#xff0c;利用并发来提升程序的性能和响应能力已经成为软件开发的必然趋势。而Go语言&#xff0c;作为一门为并发而生的语言&#xff0c;其设计哲学中将“并发”置于核心地位。其中&#xff0c;Goroutines 和 Channels 是Go实现并发编程的…

17 C 语言宏进阶必看:从宏替换避坑到宏函数用法,不定参数模拟实现一次搞定

预处理详解1. 预定义符号//C语⾔设置了⼀些预定义符号&#xff0c;可以直接使⽤&#xff0c;预定义符号也是在预处理期间处理的。 __FILE__ //进⾏编译的源⽂件--预处理阶段被替换成指向文件名字符串的指针--char* 类型的变量 __LINE__ //⽂件当前的⾏号 --预处理阶段替换成使用…