HelloAI
L7 第 6 篇 🐥 难度 🕒 14 分钟

训练优化进阶:让大模型训得动

梯度检查点 / 混合精度 / Activation Recomputation / ZeRO Offload——这些工程技巧让 70B 模型在单卡上能微调。

阿莱
2026/8/17

L7-01/02/03/04/05 我们覆盖了 GPU、分布式、推理、量化、部署。 这一篇讲训练优化的硬核技巧—— 让你的训练显存更省、速度更快、质量更高

训练时的”四大消耗”

训神经网络时,显存被四个东西占用:

占比(典型)
模型参数(FP16)30%
梯度(FP16)30%
优化器状态(Adam 用 FP32 m, v)30%
激活值(forward 中间结果)10-50%

训 7B 模型 = 14GB 参数 + 14GB 梯度 + 28GB Adam 状态 = 56GB 已经满了 A100。 还没算激活值!

所有训练优化技巧 = 围绕这四个的省 / 重算 / 卸载

技巧 1:Mixed Precision(混合精度)

最基础 + 最重要的优化。

朴素 FP32

model = MyModel().cuda()  # 默认 FP32
optimizer = torch.optim.Adam(model.parameters())

每个参数 4 字节 + 梯度 4 字节 + Adam m, v 各 4 字节 = 16 字节/参数。 7B 模型 = 112 GB——单卡跑不动。

FP16 / BF16 Mixed Precision

让 forward / backward 用 FP16(2 字节),优化器状态保留 FP32:

from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()

for batch in data:
    optimizer.zero_grad()
    with autocast(dtype=torch.bfloat16):
        loss = model(batch)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

显存约降一半——速度也快(FP16 算力 2× FP32)。

选 FP16 还是 BF16

类型精度范围适用
FP16小(±10⁴)旧硬件
BF16大(±10³⁸)H100, A100, TPU

BF16 在新硬件上更稳(范围大 = 不溢出)—— 现代训练几乎全用 BF16。

FP8 训练(最新)

H100 + B100 原生支持 FP8:

  • 进一步减半显存
  • 训练再快 2×

DeepSeek V3 大规模用了 FP8—— 省了一半训练成本。

技巧 2:Gradient Checkpointing(梯度检查点)

也叫 Activation Recomputation —— 经典 “time-memory tradeoff”。

朴素 backward

前向时存储所有激活值—— 反向时用它们算梯度。

问题:激活值占巨大显存—— 100 层网络的激活可能比参数还大

检查点

前向时只存部分激活值(“检查点”)—— 反向时重新计算中间激活。

import torch.utils.checkpoint as checkpoint

class MyBlock(nn.Module):
    def forward(self, x):
        # 不存中间激活,反向时重算
        return checkpoint.checkpoint(self._forward, x)
    
    def _forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        return x

Trade-off

  • 显存:减少 60-80%
  • 速度:慢 20-30%(重算开销)

几乎所有大模型训练都开—— 显存 vs 速度,显存赢

技巧 3:ZeRO Optimizer(DeepSpeed)

L7-02 详讲过—— 把优化器状态、梯度、参数分片到多卡

ZeRO Stage分片
1优化器状态
2+ 梯度
3+ 参数(= FSDP)

ZeRO-Offload(关键升级)

不只分片到多卡——还能 offload 到 CPU 内存

GPU 显存:参数(forward/backward 用)
CPU 内存:优化器状态(更新参数时再换上来)

好处

  • GPU 显存进一步省 4×
  • 训 13B 模型用 24GB 显存 = 消费级 GPU

代价:训练慢 30-50%(数据搬运)。

ZeRO-Infinity

极端版本——还能 offload 到 NVMe SSD

GPU:当前 batch
CPU 内存:少数层
NVMe SSD:大多数层

训 100B+ 模型在单机—— 代价:极慢。

个人开发者训大模型的极限——ZeRO-Infinity + 量化 + LoRA

技巧 4:Activation Compression

激活值 占巨大显存——能压缩吗?

Activation Quantization

把激活从 FP16 量化到 INT8 / FP8:

# 中间激活值用低位数存
x = layer1(x)  # FP16
x_quantized = quantize_8bit(x)  # INT8,省一半显存
# 反向时再 dequantize

显存减半,质量损失小

Activation Sparsity

观察:很多激活值是 0(ReLU 之后)—— 只存非零部分 + 索引。

省 30-70% 取决于稀疏度。

技巧 5:Gradient Accumulation

GPU 太小,装不下大 batch—— 怎么模拟大 batch 训练?

朴素

for batch in data:
    loss = model(batch)  # batch_size=4 (GPU 极限)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

Gradient Accumulation

accumulation_steps = 8

for i, batch in enumerate(data):
    loss = model(batch) / accumulation_steps  # 缩放损失
    loss.backward()  # 梯度累积

    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()  # 每 8 步才更新

等效于 batch_size=32—— 显存只用 batch_size=4 的量。

Trade-off:训得慢(计算量不变)但 batch 大。

技巧 6:Efficient Attention

L7-01 提过 Flash Attention —— 让 attention 在 GPU 上极快 + 极省内存。

Flash Attention 2

现在默认开:

# PyTorch 2.0+ 自动用
output = F.scaled_dot_product_attention(q, k, v)

比朴素实现快 2-4 倍,显存少 10×—— 长上下文场景必备。

Memory-Efficient Attention

xFormers 等库提供—— 更省显存的注意力实现。

技巧 7:Learning Rate 调度

不只是显存——收敛速度也优化。

Cosine Schedule + Warmup

def get_lr(step, warmup=1000, total=100000, max_lr=1e-4):
    if step < warmup:
        return max_lr * step / warmup
    else:
        progress = (step - warmup) / (total - warmup)
        return max_lr * 0.5 * (1 + cos(pi * progress))
  • Warmup:开始慢慢加 lr(防止初期发散)
  • Cosine:余弦下降到 0

几乎所有现代大模型训练标配

Curriculum Learning

简单数据 → 难数据。 不是均匀采样——按难度顺序。

收敛快 + 质量好—— 但实施复杂。

技巧 8:LoRA(详见 L4-05)

不微调全部参数——只调一小部分。

from peft import LoraConfig, get_peft_model

config = LoraConfig(r=16, lora_alpha=32, target_modules=["q_proj", "v_proj"])
model = get_peft_model(base_model, config)

# 现在 base_model 参数冻结,只训 LoRA adapters
# 显存:base 模型可量化到 4-bit + LoRA FP16

QLoRA(4-bit base + LoRA)让单卡训 70B 成为可能。

技巧 9:Distillation(蒸馏)

训不动大模型?让大模型当老师,训小模型

# 学生模型尝试模仿老师模型的输出分布

teacher = LargeModel()  # 不训
student = SmallModel()  # 训这个

teacher_logits = teacher(batch).detach()
student_logits = student(batch)

# KL divergence loss
loss = F.kl_div(F.log_softmax(student_logits / T), F.softmax(teacher_logits / T))

T(temperature)控制软化程度。

好处

  • 小模型也能学到大模型的知识
  • 训练 + 推理都便宜
  • 可以蒸馏多个老师(ensemble distillation)

Phi-3、Gemma 等小模型都用过蒸馏。

技巧 10:监控 / Profiling

光优化不监控 = 黑盒。 必须 profile

PyTorch Profiler

from torch.profiler import profile, record_function

with profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof:
    train_one_step()

print(prof.key_averages().table(sort_by="cuda_time_total"))

看哪一步耗时最多—— 针对性优化

Nsight Systems(NVIDIA)

更专业的 GPU profiling—— 看 kernel 执行、内存读写、通信。

优化前先 profile —— 不然你可能在错的地方使劲。

一个真实工作流

某团队训 70B 模型的实战配置:

硬件:8 张 H100 (一台 NVIDIA DGX)
框架:DeepSpeed + Megatron-LM

优化技巧:
  - BF16 mixed precision ✓
  - ZeRO Stage 3 + Activation Checkpointing ✓
  - Flash Attention 2 ✓
  - Gradient Accumulation (effective batch 256) ✓
  - Cosine LR schedule + warmup ✓
  - DeepSpeed Offload to CPU(部分) ✓

效果:
  - 单步:3.2 秒
  - 显存:72 GB / 80 GB(每卡)
  - 速度:约 50 tokens/sec/GPU
  - 总训练时间:6 周(1T token)

所有技巧叠加 = 让”训 70B” 从不可能变成可能

选哪些技巧

按场景:

单卡(消费级 GPU)

  • BF16 ✓
  • Gradient checkpointing ✓
  • 8-bit optimizer(bitsandbytes)
  • LoRA + 4-bit base model

多卡(DGX 工作站)

  • BF16 ✓
  • ZeRO-2 或 Stage 3
  • Flash Attention 2
  • Gradient accumulation

集群(数千卡)

  • 全部上 ✓
  • Megatron-LM TP + PP
  • 仔细 profile + 调度
  • 容错恢复
💡 一个心法

训大模型的工程是个艺术—— 没有”一刀切”配置。

最佳实践

  1. 先把小模型跑通——验证代码正确
  2. 逐步加优化——每加一个 profile 一下
  3. 从小数据集开始 —— 几小时迭代
  4. 写好 checkpoint —— 失败能恢复

训练失败的成本比生产 bug 高 100×—— 谨慎、有计划、有备份。

下一篇推荐:L7-07 监控与可观测性L7-08 模型生命周期管理

📬

读到这里说明你认真在学 🎯

订阅每周精选 —— 下一篇新文章 / 新可视化第一时间送到邮箱。

💬

讨论区

· 用 GitHub 账号登录评论
⚠️ Giscus 评论未配置 —— 在 src/components/Comments.astro 顶部填入 仓库 ID 和分类 ID(见组件注释里的配置步骤)。