CS336 第七讲 · 分布式训练:数据、张量、流水线、序列并行全景

June 15, 2026

CS336 第七讲 · 分布式训练:数据、张量、流水线、序列并行全景

第五讲第六讲我们一直在一张卡上较劲:GPU 算得太快、喂不饱数据,于是所有优化都收敛到「减少对慢速 HBM 的访问」。但有一个前提被悄悄假设了——模型能装进一张卡

这个前提正在崩塌。今天的前沿模型动辄几千亿参数(DeepSeek-V3 高达 671B),光是参数加上训练所需的梯度与优化器状态,就要 TB 级显存,而单卡 HBM 不过 80–192 GB。一张卡装不下,就只能把模型摊到很多张卡、很多台机器上——这就是分布式训练。

本文对应斯坦福 CS336 第七讲(Parallelism),也对应 Datawhale diy-llm 第八章。原文把硬件、通信、四种并行、PyTorch 实现讲得非常完整,但也铺得很长。这里把它压成一条主线 + 一个框架:先看清多卡时代的新敌人,再用「一份负载、四种切法」把所有并行策略串成一棵决策树。

一条新主线:敌人从「访存」换成了「通信」

整个 CS336 的 GPU 三讲,其实在反复讲同一件事:算力一直在疯涨,真正的瓶颈永远是「喂数据」。只是到了多卡场景,"喂"的对象变了:

  • 单卡(第五、六讲):计算单元喂不饱,瓶颈是 GPU 内部的 HBM 访存带宽。
  • 多卡(本讲):每张卡都要等别人的中间结果,瓶颈变成了卡与卡之间的通信带宽

而卡间通信,比 HBM 还要慢得多、还要分层。要理解所有并行策略的取舍,先得看清这座新的「延迟金字塔」。

通信带宽金字塔:从片内 HBM 到节点内 NVLink 再到跨节点网络,带宽逐级暴跌

关键不在某一个具体数字,而在每跨一层边界,带宽就掉一个量级:片内 HBM 有约 3.9 TB/s,同一台机器里 GPU 之间走 NVLink 还有约 900 GB/s,可一旦跨出机器、走以太网或 InfiniBand,就只剩几十 GB/s 甚至更低。

这条金字塔直接决定了后面所有策略的「摆放位置」:通信越频繁的策略,越要关进越快的硬件层级。需要每层都同步的张量并行,只能挤在一台机器的 8 张卡里吃 NVLink;而只在每步同步一次的数据并行,才敢放到跨节点的慢速网络上。记住这一点,本文后半段的所有结论都会顺理成章。

通信的语言:一条恒等式讲透集合通信

多卡之间怎么"对话"?靠的是一组标准化的集合通信原语(Collective Communication)——Broadcast(广播)、Scatter(分发)、Gather(汇聚)、Reduce(归约)、All-Gather(全收集)、Reduce-Scatter(归约分散)、All-Reduce(全归约)。名字一大串,但真正撑起分布式训练的,是其中一条恒等式:

集合通信原语:All-Reduce 等价于 Reduce-Scatter 加 All-Gather

All-Reduce=Reduce-Scatter+All-Gather\text{All-Reduce} = \text{Reduce-Scatter} + \text{All-Gather}

把它拆开看就一目了然:每张卡都有一份完整向量(比如各自算出的梯度),第一步 Reduce-Scatter 让每张卡只负责把"自己那一段"在所有卡上求和、其余丢弃;第二步 All-Gather 再把这些求好和的分片广播回所有卡。两步下来,人人都拿到了完整的全局求和结果。

这条恒等式之所以重要,是因为它点破了通信成本:两步各传输约 1× 的数据量,一次 All-Reduce 的通信量约等于 2× 数据量,与卡的数量基本无关。这正是数据并行"每步同步一次梯度"的成本来源;而把这两步拆开单独使用,又恰好是 FSDP 的核心手法——后面马上会用到。

工程上,这些原语由 NVIDIA 的 NCCL 库实现:它自动探测硬件拓扑、规划最优传输路径,PyTorch 的 torch.distributed 再在其上包一层更易用的接口(GPU 用 NCCL 后端,CPU 用 Gloo)。

一份负载,四种切法:所有并行策略的总框架

理解了通信成本,现在可以俯瞰全局了。一次训练负载,本质是一批序列,流过一摞层,每层都是一堆大矩阵。要把它摊到多卡上,无非是沿某个维度"切一刀"——而切哪个维度,就对应哪种并行策略:

一份负载,四种切法:数据、宽度、深度、序列四个维度分别对应四种并行

  • 切数据维度(批次) → 数据并行(DP):模型整份复制,每卡喂不同批次。
  • 切宽度维度(矩阵) → 张量并行(TP):把每层的大矩阵沿宽度切给多卡同算。
  • 切深度维度(层) → 流水线并行(PP):把层摞成的栈切成前段、后段。
  • 切序列维度(token) → 序列并行(SP):把一条长序列切成几段。

这四刀彼此正交,可以叠加使用(即所谓 3D / 4D 并行)。下面就按"最常用、最先考虑"的顺序,逐一拆开。

数据并行与 ZeRO:把重复的状态切掉

数据并行(Data Parallelism, DDP)是最朴素、也最常用的一招:每张卡放一份完整模型副本,把一个大批次切成若干小份各算各的,反向传播后用一次 All-Reduce 把梯度取平均同步,然后各卡用相同的梯度独立更新——更新后所有副本依然一致。

数据并行 DDP 流程:复制模型、切分批次、All-Reduce 同步梯度、独立更新

它的优点是逻辑简单、通信少(每步只在反向后同步一次)。但它有个致命的浪费:每张卡都存了一份完整的「参数 + 梯度 + 优化器状态」。以混合精度 Adam 为例,每个参数要占约 16 字节(2 字节 FP16 参数 + 2 字节梯度 + 12 字节 FP32 的主权重和一二阶动量),一个 7.5B 的模型单卡就要 120 GB——而且 N 张卡上存的是 N 份一模一样的东西,纯属冗余。

ZeRO(Zero Redundancy Optimizer)的洞见是:既然这些状态在所有卡上都一样,何必每张卡都存全份?不如切片分摊,每张卡只存 1/N,要用时再临时取回。它分三个阶段,逐级把更多状态切开:

ZeRO 三阶段:从基线 120GB 逐级切到 31.4GB、16.6GB、1.9GB

  • Stage 1(切优化器状态):占大头的 12 字节 FP32 状态最先切片。120 GB → 31.4 GB。
  • Stage 2(再切梯度):梯度也按分片只保留自己那段。31.4 GB → 16.6 GB。
  • Stage 3(再切参数):连参数本身都切片,这就是 FSDP(Fully Sharded Data Parallel)。16.6 GB → 1.9 GB,约 64× 收缩

天下没有免费的午餐:参数被切开后,每层前向/反向都得先 All-Gather 把这一层的完整参数临时拼回来用,用完即弃;梯度则用 Reduce-Scatter 更新回各自的分片。算下来通信量约是 DDP 的 1.5 倍(理论 3× vs 2×)。FSDP 靠一个关键技巧把这笔开销藏了起来——预取(prefetch)与计算重叠:在算当前层时,提前异步把下一层的参数 All-Gather 过来,让通信悄悄发生在计算的"阴影"里。实测下来,实际性能损耗只有 10%–20%,却换回了几十倍的显存。

一句话记住数据并行这一节:DDP 用通信换简单,ZeRO/FSDP 用通信换显存。当模型的「参数+状态」单卡装不下时,FSDP 几乎是第一选择。

不过数据并行有个天花板:它的并行度受批大小约束(再多的卡也不能让每卡分到的样本少于 1),而且批大小过了某个临界点后,梯度更新的次数会变少,收益递减。要继续扩展,就得动模型本身了。

张量并行:把一层矩阵切给多卡同算

如果说数据并行切的是"数据",张量并行(Tensor Parallelism, TP)切的就是"模型本身"——它把每一层内部的大矩阵沿宽度维度劈开,分给多张卡同时算同一层。

以 Transformer 里的 MLP 块 Y=GELU(XA)BY = \text{GELU}(X A) B 为例,Megatron 的经典切法巧妙地让通信降到最低:

张量并行:A 按列切分无需通信,B 按行切分后用 All-Reduce 求和

第一个矩阵 AA 按列切分,每张卡拿一部分列,各自算出一段隐藏向量。由于 GELU 是逐元素的、列与列互不影响,这一整段计算全程无需通信。第二个矩阵 BB按行切分,每张卡用自己那段隐藏向量算出一个"部分和",最后一次 All-Reduce 把各卡的部分和加起来,就拼回了完整输出。

代价非常直白:每一层的前向、反向都各要一次 All-Reduce。模型有几十层,一步就是几十上百次同步——这是所有并行里通信最频繁的。回到开头那座金字塔:如此高频的同步只有 NVLink 那个量级的带宽扛得住。所以张量并行有一条铁律:只在单节点内用(通常 ≤ 8 卡)。一旦跨出节点走慢速网络,性能就会断崖式下跌——经验数据是 8 卡损耗约 10%,16 卡骤降 42%,32 卡再降 65%。

流水线并行:气泡,与如何挤掉它

流水线并行(Pipeline Parallelism, PP)换了个维度:它按层的边界把模型切成几段,每张卡负责连续的几层,像工厂流水线一样,前一张卡算完把激活值传给下一张卡。

听起来很美,但朴素实现有个致命缺陷——气泡(bubble)

流水线气泡:朴素流水线利用率约 1/n,微批次填充后气泡只剩首尾小三角

如果只丢一个批次进去,那么 GPU0 算第一段时,后面三张卡全在干等;轮到 GPU3 时,前三张又闲下来了。任意时刻只有一张卡在干活,利用率仅约 1/n,其余全是空转的气泡。

解法是微批次(micro-batching):把一个大批次拆成 m 个小微批次,像流水一样连续送进管道。前一个微批次往下游走的同时,后一个紧跟着进入上游,很快所有卡就都忙起来了,气泡只剩开头"灌注"和结尾"排空"的两个小三角。气泡占比由一个简洁的公式刻画:

气泡占比=n1m+n1\text{气泡占比} = \frac{n-1}{m + n - 1}

其中 nn 是流水线段数、mm 是微批次数。微批次越多,气泡越小。更进一步的零气泡(zero-bubble)调度,则把反向传播拆成"算激活梯度(B)"和"算权重梯度(W)"两部分,用不急的 W 计算去填补流水线里的空隙,把气泡压到接近于零。

流水线并行的通信很轻(只在段与段的边界点对点传一次激活值,约 1× 激活量),所以它可以跨节点。但调度复杂、实现繁琐,经验法则是:只有当模型大到单卡(哪怕用了 TP 和 FSDP)仍然装不下时,才动用 PP

序列并行:顺手再省一笔激活值显存

张量并行已经把矩阵乘法切开了,但 Transformer 里还有 LayerNorm、Dropout 这类逐点运算没被并行——它们会在每张卡上完整地重复一遍,对应的激活值也得完整存一份。序列并行(Sequence Parallelism, SP)补上了这一刀:既然这些运算沿序列维度是独立的,那就把激活值沿序列长度切开,每张卡只处理一段 token。

它常与张量并行配合,在两者交界处用 All-Gather(前向)和 Reduce-Scatter(反向)切换数据布局——注意这两步合起来恰好就是一次 All-Reduce,所以序列并行几乎没有增加额外通信量,却把激活值显存又除以了张量并行度 tt。单层激活内存大致是:

每层激活内存sbh(34+5as/h)t\text{每层激活内存} \approx \frac{s\,b\,h\,(34 + 5\,a\,s/h)}{t}

ss 序列长度,bb 微批大小,hh 隐藏维,aa 注意力头数,tt 张量并行度。)当序列长到 LayerNorm 的激活也成为显存负担时,这一刀很划算。处理超长上下文还有它的近亲——上下文并行 / 环形注意力(Ring Attention):把 Q、K、V 沿序列切块,在卡间"环形"传递 K/V 来增量计算注意力,让单条序列的长度也能横向扩展。

怎么组合:先快后慢,由内向外

四种并行不是单选题,真实的大规模训练是把它们叠起来用。怎么叠?答案全写在开头那座通信金字塔里——通信越频繁的策略,越要放进越快的硬件层级

并行策略的组合与放置:TP 在节点内吃 NVLink,DP/PP 跨节点

一个常用的加挂顺序是:① 先在节点内开张量并行(≤ 8 卡,吃满 NVLink);② 需要长序列就叠上序列 / 上下文并行;③ 模型仍然装不下,才按层切流水线并行(可跨节点);④ 最后用数据并行 / FSDP 横向扩展吞吐(跨节点)。落到硬件上,就是把高频的 TP 关进每台机器内部,把低频的 DP/PP 留给机器之间的慢速网络。

把四种策略的取舍并到一张表,一眼看清它们在"批大小、显存、通信"三角里的不同站位:

策略主要切什么省显存通信成本通信频率放在哪
数据并行 DDP批次不省2× 参数每步 1 次跨节点
FSDP(ZeRO-3)批次 + 全部状态线性下降~3× 参数每层跨节点
张量并行 TP矩阵宽度省(参数+激活)2× 激活每层 2 次,最频繁节点内 ≤8 卡
流水线并行 PP层(深度)线性下降1× 激活段边界点对点可跨节点
序列并行 SP序列长度省激活≈ 0(并入 TP)随 TP节点内

别忘了「先测准」:基准测试通信

最后呼应第六讲的那条铁律——优化不能靠猜,要靠测,对分布式同样成立。理论带宽和实测带宽往往差得离谱:在 4 卡上对一亿元素的张量做 All-Reduce,理论有 900 GB/s,实测往往只有约 277 GB/s;而 Reduce-Scatter 由于产生更多同步开销,实测可能只有约 70 GB/s

这提醒我们两件事:其一,通信的真实成本严重依赖张量大小、卡数和硬件拓扑,必须在自己的集群上实测才算数;其二,并行策略的选择从来不是纯理论题——同一套策略在 NVLink 全互联的机器和普通以太网集群上,最优配置可能完全不同。

尾声:硬件在变,第一性原理不变

GPU 之外,Google 的 TPU 走了另一条路:它用环面网格(torus)拓扑把芯片连起来,相邻芯片通信极快,特别适合超大规模;配套的 JAX 则让你用声明式的方式描述分片策略,由编译器自动生成底层通信原语,把 PyTorch 里要手写的集合操作藏进了编译器。生态在演进,但物理限制(散热、带宽、片上内存)始终都在。

回到 CS336 GPU 三讲的总主线:从单卡到万卡,算力始终在涨,瓶颈始终是"喂数据"。第五讲我们对抗 HBM 访存,第六讲学会测量与落地,这一讲对抗的则是卡间通信。所有花哨的并行策略,剥开外壳都在回答同一个问题——怎么切,才能在把模型摊开的同时,让卡与卡之间少说话、说得快

总结:分布式训练速查表

关键点一句话核心
新瓶颈单卡对抗 HBM,多卡对抗卡间通信——每跨一层边界,带宽掉一个量级
集合通信All-Reduce = Reduce-Scatter + All-Gather,通信量 ≈ 2× 数据
数据并行复制模型、切批次、All-Reduce 梯度;简单但状态重复存储
ZeRO / FSDP把参数/梯度/状态切片分摊,单卡显存 120GB → 1.9GB,靠预取重叠藏通信
张量并行切矩阵宽度,每层两次 All-Reduce,最频繁→只在节点内 ≤8 卡
流水线并行切层传激活,气泡占比 (n−1)/(m+n−1),微批次/零气泡来填
序列并行切序列省激活值,并入 TP 几乎不增通信
组合法则先快后慢、由内向外:TP 关节点内,DP/PP 才跨节点

记住这张表背后的那条主线,再看任何一种新的并行方案,你都能一眼判断它在和"通信"的哪一面较劲——是省了显存却多了同步,还是少了同步却限了扩展。权衡,永远在"存储、传输、重算"这个三角里打转。

参考资料