分布式训练:DP / DDP / FSDP / Tensor Parallel 怎么选
一张 H100 装不下 70B 模型。怎么把训练任务分给几千张卡?这一篇梳理 4 种主流并行策略。
L7-01 我们看到训 70B 模型一张 H100 装不下—— 70B 模型仅参数就占 140GB(FP16),加上梯度、优化器状态、激活,总共 ~700GB 显存。
那 GPT-4 是 1.7T 参数怎么训的?靠”分布式训练”——把一个模型的训练任务拆给多卡甚至多机器。
4 大并行策略
| 策略 | 切什么 | 通信开销 | 适用 |
|---|---|---|---|
| DP/DDP(数据并行) | 切数据 | 中 | 模型能塞单卡 |
| FSDP/ZeRO(分片数据并行) | 切参数 | 高 | 模型大但能放 N 卡 |
| TP(张量并行) | 切单层 | 极高 | 单层都装不下 |
| PP(流水线并行) | 切层 | 低 | 跨多机器 |
实战中通常混用——这叫 “3D Parallelism”。
一、数据并行(DP / DDP)
最简单的并行:每个 GPU 一份模型副本,切分 batch。
GPU 0: 模型副本 → 处理 batch 的 0-31 样本 → 梯度 g₀
GPU 1: 模型副本 → 处理 batch 的 32-63 样本 → 梯度 g₁
GPU 2: 模型副本 → 处理 batch 的 64-95 样本 → 梯度 g₂
GPU 3: 模型副本 → 处理 batch 的 96-127 样本 → 梯度 g₃
↓ All-Reduce(梯度求平均)
每个 GPU 用平均梯度更新参数(保持模型副本一致)
DP vs DDP
DP(DataParallel):单进程多线程——Python GIL 卡住,慢。 DDP(DistributedDataParallel):多进程——快得多。
DDP 是 PyTorch 数据并行的标准做法——
torch.distributed。
PyTorch 代码
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
def train(rank, world_size):
# 初始化进程组
dist.init_process_group("nccl", rank=rank, world_size=world_size)
# 创建模型并搬到对应 GPU
model = MyModel().to(rank)
model = DDP(model, device_ids=[rank])
# 训练循环(和单卡一样)
for batch in dataloader:
loss = model(batch)
loss.backward()
optimizer.step()
dist.destroy_process_group()
if __name__ == "__main__":
world_size = 4
mp.spawn(train, args=(world_size,), nprocs=world_size)
DP 的局限
每张卡都要有完整模型副本 —— 模型必须能塞进单卡显存。
70B 模型 + DP 不行——一张 H100 装不下。
二、FSDP(完全分片数据并行)/ ZeRO
解决 DP 的局限——把模型参数也切到不同卡。
ZeRO(Zero Redundancy Optimizer)
DeepSpeed 2019 年提出,分 3 个阶段:
| ZeRO Stage | 切什么 | 显存节省 |
|---|---|---|
| Stage 1 | 优化器状态(Adam 的 m, v) | 4× |
| Stage 2 | + 梯度 | 8× |
| Stage 3 | + 参数 | N× (正比卡数) |
Stage 3 = FSDP:每张卡只存模型的 1/N 参数。
工作流程(FSDP / ZeRO-3)
正常时刻:每张卡只有 1/N 参数
↓
前向计算第 K 层时:
- 所有卡互相 all-gather,凑出第 K 层完整参数
- 算这一层的输出
- 算完立刻丢掉除自己那份外的参数
↓
反向算第 K 层时:
- 又 all-gather 一次(已经丢了)
- 算梯度
- reduce-scatter 梯度(每张卡只保留自己那份的梯度)
↓
更新:每张卡更新自己那 1/N 的参数
牺牲:通信开销大(每层需要 all-gather)——网络带宽成为瓶颈。
NVLink 集群上 FSDP 高效,普通以太网会被通信拖垮。
PyTorch 代码
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
model = MyHugeModel()
model = FSDP(model, sharding_strategy=ShardingStrategy.FULL_SHARD)
# 训练代码和 DDP 几乎一样
PyTorch 2.0+ 的 FSDP 是 ZeRO-3 的原生实现。
训 7B-70B 模型常用方案:FSDP 单机 8 卡,或多机 32 卡。
三、张量并行(Tensor Parallel)
如果单层都装不下一张卡怎么办?
比如 GPT-3 175B 的一个 attention 层,权重就 36GB——单层超过单卡显存了。
TP 把单层的矩阵切到多卡:
怎么切矩阵乘法
,假设 是 4096×4096——可以切成两半:
W = [W_1; W_2] (按列切)
GPU 0: y_0 = W_1 · x
GPU 1: y_1 = W_2 · x
合并: y = concat(y_0, y_1) ← 需要 all-gather
或按行切:
W = [W_1; W_2]^T (按行切)
GPU 0: y_0 = W_1 · x
GPU 1: y_1 = W_2 · x
合并: y = y_0 + y_1 ← 需要 all-reduce
Megatron-LM 的设计
NVIDIA 的 Megatron-LM 把 Transformer 的 attention 和 FFN 都做了 TP:
Attention 切多头:每个头放一张卡
FFN 切中间维度:4d 切给多张卡
优点:能训单层超大的模型。 缺点:通信极频繁——每层都要 all-reduce 或 all-gather。需要 NVLink/NVSwitch 这种高速互联。
GPT-3 / GPT-4 / Llama 3 405B 训练都用 TP——必须有。
四、流水线并行(Pipeline Parallel)
切层而不是切单层——
GPU 0: 处理第 1-8 层
GPU 1: 处理第 9-16 层
GPU 2: 处理第 17-24 层
GPU 3: 处理第 25-32 层
每张卡只存一部分层。
问题:流水线泡(Pipeline Bubble)
朴素流水线:
时间 →
GPU 0: L1-8(batch1) ─────────
GPU 1: L9-16(batch1) ──────
GPU 2: L17-24(batch1) ─────
GPU 3: L25-32(batch1)
← 大部分时间 GPU 在闲置!这就是 "bubble"
1F1B / Interleaved
优化:把 batch 拆成 micro-batches,让每张卡同时处理不同的 micro-batches:
GPU 0: m1 m2 m3 m4 (forward) ... m4 m3 m2 m1 (backward)
GPU 1: m1 m2 m3 m4 (forward) ... m4 m3 m2 m1 (backward)
GPU 2: m1 m2 m3 m4 ...
GPU 3: m1 m2 m3 m4 ...
← bubble 显著减少
优点:通信少(只在层与层间)——适合跨多机器(网络带宽低)。 缺点:实现复杂,调度难。
五、3D Parallelism(混合)
训 GPT-3 / GPT-4 这种巨型模型,单一并行不够:
3D = TP + PP + DP
例子:训 GPT-3 用 1024 张 GPU
- TP: 8 张卡组成"张量并行组"(一层切 8 块)
- PP: 16 个流水线阶段
- DP: 8 个数据并行副本
每个维度都做:单层切 + 多层流水 + 多数据。
真实例子
Megatron-Turing 530B 模型:
- 模型大小:530B 参数
- 集群:560 张 A100
- TP=8, PP=35, DP=2
- 训练时长:~3 个月
- 估算成本:~$10M
PyTorch 工具
- PyTorch FSDP:内置 FSDP
- DeepSpeed:微软出品,ZeRO + 流水 + TP
- Megatron-LM:NVIDIA 出品,最快 TP
- Colossal-AI:清华开源,相对易用
- Lightning AI:高层 API
训大模型大多用 DeepSpeed + Megatron-LM 组合。
第六站:怎么选
决策树
模型能塞单卡?
├── 能 → DDP(数据并行,最简单)
└── 不能
├── N 卡 ZeRO-3 / FSDP 能塞?
│ └── 能 → FSDP
│
└── 单层都装不下?
└── Megatron-LM TP + PP
经验值
| 模型大小 | 推荐 |
|---|---|
| < 1B | DDP(单机多卡) |
| 1-13B | FSDP / DeepSpeed ZeRO-3 |
| 13-70B | FSDP + Activation Checkpointing |
| 70-200B | TP + FSDP(混合) |
| 200B+ | TP + PP + DP(3D) |
第七站:通信瓶颈
分布式训练的性能瓶颈通常不是计算,是通信。
关键指标
- NVLink (单机内): 900 GB/s(H100 NVLink 4.0)
- InfiniBand (跨机器): 400 Gb/s(最快)
- 以太网 (普通服务器): 10-25 Gb/s
普通以太网比 NVLink 慢 100 倍——普通服务器上做大规模分布式训练完全不行。
通信优化
- 梯度压缩:梯度从 FP32 压到 FP8 传,省 4 倍
- 梯度累积:先在卡内累积,再 all-reduce
- 重叠通信和计算:上一层算完梯度,下一层还在算前向时就 all-reduce
这些工程优化让训练效率从 30%(基础实现)提到 60-70%(优化版)。
GPT-3 训练 175B 模型用了约 3.14 × 10²³ FLOPs。 理论上 1024 张 A100 全速跑约 1 个月。 实际跑了 3-4 个月——通信效率和重启容错占了 70% 时间。
分布式训练的工程比算法本身更难。
下一篇:《推理优化:vLLM / 量化 / 投机解码 / KV Cache》 —— 训完之后怎么把它跑快、跑省。
读到这里说明你认真在学 🎯
订阅每周精选 —— 下一篇新文章 / 新可视化第一时间送到邮箱。
讨论区
· 用 GitHub 账号登录评论src/components/Comments.astro 顶部填入
仓库 ID 和分类 ID(见组件注释里的配置步骤)。