训练优化进阶:让大模型训得动
梯度检查点 / 混合精度 / Activation Recomputation / ZeRO Offload——这些工程技巧让 70B 模型在单卡上能微调。
L7-01/02/03/04/05 我们覆盖了 GPU、分布式、推理、量化、部署。 这一篇讲训练优化的硬核技巧—— 让你的训练显存更省、速度更快、质量更高。
训练时的”四大消耗”
训神经网络时,显存被四个东西占用:
| 项 | 占比(典型) |
|---|---|
| 模型参数(FP16) | 30% |
| 梯度(FP16) | 30% |
| 优化器状态(Adam 用 FP32 m, v) | 30% |
| 激活值(forward 中间结果) | 10-50% |
训 7B 模型 = 14GB 参数 + 14GB 梯度 + 28GB Adam 状态 = 56GB 已经满了 A100。 还没算激活值!
所有训练优化技巧 = 围绕这四个的省 / 重算 / 卸载。
技巧 1:Mixed Precision(混合精度)
最基础 + 最重要的优化。
朴素 FP32
model = MyModel().cuda() # 默认 FP32
optimizer = torch.optim.Adam(model.parameters())
每个参数 4 字节 + 梯度 4 字节 + Adam m, v 各 4 字节 = 16 字节/参数。 7B 模型 = 112 GB——单卡跑不动。
FP16 / BF16 Mixed Precision
让 forward / backward 用 FP16(2 字节),优化器状态保留 FP32:
from torch.cuda.amp import autocast, GradScaler
scaler = GradScaler()
for batch in data:
optimizer.zero_grad()
with autocast(dtype=torch.bfloat16):
loss = model(batch)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
显存约降一半——速度也快(FP16 算力 2× FP32)。
选 FP16 还是 BF16
| 类型 | 精度 | 范围 | 适用 |
|---|---|---|---|
| FP16 | 高 | 小(±10⁴) | 旧硬件 |
| BF16 | 低 | 大(±10³⁸) | H100, A100, TPU |
BF16 在新硬件上更稳(范围大 = 不溢出)—— 现代训练几乎全用 BF16。
FP8 训练(最新)
H100 + B100 原生支持 FP8:
- 进一步减半显存
- 训练再快 2×
DeepSeek V3 大规模用了 FP8—— 省了一半训练成本。
技巧 2:Gradient Checkpointing(梯度检查点)
也叫 Activation Recomputation —— 经典 “time-memory tradeoff”。
朴素 backward
前向时存储所有激活值—— 反向时用它们算梯度。
问题:激活值占巨大显存—— 100 层网络的激活可能比参数还大。
检查点
前向时只存部分激活值(“检查点”)—— 反向时重新计算中间激活。
import torch.utils.checkpoint as checkpoint
class MyBlock(nn.Module):
def forward(self, x):
# 不存中间激活,反向时重算
return checkpoint.checkpoint(self._forward, x)
def _forward(self, x):
x = self.layer1(x)
x = self.layer2(x)
return x
Trade-off:
- 显存:减少 60-80%
- 速度:慢 20-30%(重算开销)
几乎所有大模型训练都开—— 显存 vs 速度,显存赢。
技巧 3:ZeRO Optimizer(DeepSpeed)
L7-02 详讲过—— 把优化器状态、梯度、参数分片到多卡:
| ZeRO Stage | 分片 |
|---|---|
| 1 | 优化器状态 |
| 2 | + 梯度 |
| 3 | + 参数(= FSDP) |
ZeRO-Offload(关键升级)
不只分片到多卡——还能 offload 到 CPU 内存:
GPU 显存:参数(forward/backward 用)
CPU 内存:优化器状态(更新参数时再换上来)
好处:
- GPU 显存进一步省 4×
- 训 13B 模型用 24GB 显存 = 消费级 GPU
代价:训练慢 30-50%(数据搬运)。
ZeRO-Infinity
极端版本——还能 offload 到 NVMe SSD:
GPU:当前 batch
CPU 内存:少数层
NVMe SSD:大多数层
训 100B+ 模型在单机—— 代价:极慢。
个人开发者训大模型的极限——ZeRO-Infinity + 量化 + LoRA。
技巧 4:Activation Compression
激活值 占巨大显存——能压缩吗?
Activation Quantization
把激活从 FP16 量化到 INT8 / FP8:
# 中间激活值用低位数存
x = layer1(x) # FP16
x_quantized = quantize_8bit(x) # INT8,省一半显存
# 反向时再 dequantize
显存减半,质量损失小。
Activation Sparsity
观察:很多激活值是 0(ReLU 之后)—— 只存非零部分 + 索引。
省 30-70% 取决于稀疏度。
技巧 5:Gradient Accumulation
GPU 太小,装不下大 batch—— 怎么模拟大 batch 训练?
朴素
for batch in data:
loss = model(batch) # batch_size=4 (GPU 极限)
loss.backward()
optimizer.step()
optimizer.zero_grad()
Gradient Accumulation
accumulation_steps = 8
for i, batch in enumerate(data):
loss = model(batch) / accumulation_steps # 缩放损失
loss.backward() # 梯度累积
if (i + 1) % accumulation_steps == 0:
optimizer.step()
optimizer.zero_grad() # 每 8 步才更新
等效于 batch_size=32—— 显存只用 batch_size=4 的量。
Trade-off:训得慢(计算量不变)但 batch 大。
技巧 6:Efficient Attention
L7-01 提过 Flash Attention —— 让 attention 在 GPU 上极快 + 极省内存。
Flash Attention 2
现在默认开:
# PyTorch 2.0+ 自动用
output = F.scaled_dot_product_attention(q, k, v)
比朴素实现快 2-4 倍,显存少 10×—— 长上下文场景必备。
Memory-Efficient Attention
xFormers 等库提供—— 更省显存的注意力实现。
技巧 7:Learning Rate 调度
不只是显存——收敛速度也优化。
Cosine Schedule + Warmup
def get_lr(step, warmup=1000, total=100000, max_lr=1e-4):
if step < warmup:
return max_lr * step / warmup
else:
progress = (step - warmup) / (total - warmup)
return max_lr * 0.5 * (1 + cos(pi * progress))
- Warmup:开始慢慢加 lr(防止初期发散)
- Cosine:余弦下降到 0
几乎所有现代大模型训练标配。
Curriculum Learning
简单数据 → 难数据。 不是均匀采样——按难度顺序。
收敛快 + 质量好—— 但实施复杂。
技巧 8:LoRA(详见 L4-05)
不微调全部参数——只调一小部分。
from peft import LoraConfig, get_peft_model
config = LoraConfig(r=16, lora_alpha=32, target_modules=["q_proj", "v_proj"])
model = get_peft_model(base_model, config)
# 现在 base_model 参数冻结,只训 LoRA adapters
# 显存:base 模型可量化到 4-bit + LoRA FP16
QLoRA(4-bit base + LoRA)让单卡训 70B 成为可能。
技巧 9:Distillation(蒸馏)
训不动大模型?让大模型当老师,训小模型。
# 学生模型尝试模仿老师模型的输出分布
teacher = LargeModel() # 不训
student = SmallModel() # 训这个
teacher_logits = teacher(batch).detach()
student_logits = student(batch)
# KL divergence loss
loss = F.kl_div(F.log_softmax(student_logits / T), F.softmax(teacher_logits / T))
T(temperature)控制软化程度。
好处:
- 小模型也能学到大模型的知识
- 训练 + 推理都便宜
- 可以蒸馏多个老师(ensemble distillation)
Phi-3、Gemma 等小模型都用过蒸馏。
技巧 10:监控 / Profiling
光优化不监控 = 黑盒。 必须 profile:
PyTorch Profiler
from torch.profiler import profile, record_function
with profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof:
train_one_step()
print(prof.key_averages().table(sort_by="cuda_time_total"))
看哪一步耗时最多—— 针对性优化。
Nsight Systems(NVIDIA)
更专业的 GPU profiling—— 看 kernel 执行、内存读写、通信。
优化前先 profile —— 不然你可能在错的地方使劲。
一个真实工作流
某团队训 70B 模型的实战配置:
硬件:8 张 H100 (一台 NVIDIA DGX)
框架:DeepSpeed + Megatron-LM
优化技巧:
- BF16 mixed precision ✓
- ZeRO Stage 3 + Activation Checkpointing ✓
- Flash Attention 2 ✓
- Gradient Accumulation (effective batch 256) ✓
- Cosine LR schedule + warmup ✓
- DeepSpeed Offload to CPU(部分) ✓
效果:
- 单步:3.2 秒
- 显存:72 GB / 80 GB(每卡)
- 速度:约 50 tokens/sec/GPU
- 总训练时间:6 周(1T token)
所有技巧叠加 = 让”训 70B” 从不可能变成可能。
选哪些技巧
按场景:
单卡(消费级 GPU)
- BF16 ✓
- Gradient checkpointing ✓
- 8-bit optimizer(bitsandbytes)
- LoRA + 4-bit base model
多卡(DGX 工作站)
- BF16 ✓
- ZeRO-2 或 Stage 3
- Flash Attention 2
- Gradient accumulation
集群(数千卡)
- 全部上 ✓
- Megatron-LM TP + PP
- 仔细 profile + 调度
- 容错恢复
训大模型的工程是个艺术—— 没有”一刀切”配置。
最佳实践:
- 先把小模型跑通——验证代码正确
- 逐步加优化——每加一个 profile 一下
- 从小数据集开始 —— 几小时迭代
- 写好 checkpoint —— 失败能恢复
训练失败的成本比生产 bug 高 100×—— 谨慎、有计划、有备份。
下一篇推荐:L7-07 监控与可观测性 或 L7-08 模型生命周期管理。
读到这里说明你认真在学 🎯
订阅每周精选 —— 下一篇新文章 / 新可视化第一时间送到邮箱。
讨论区
· 用 GitHub 账号登录评论src/components/Comments.astro 顶部填入
仓库 ID 和分类 ID(见组件注释里的配置步骤)。