FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness
为什么这篇论文重要
注意力(Attention)公式简单:
但朴素实现极其低效——序列稍长就内存爆炸 + 运行龟速。
2022 年这篇论文让所有人意识到:
注意力计算的瓶颈不是算力,是显存读写。
通过重新组织计算来最小化 HBM 访问,FlashAttention 把同一个数学公式跑得快 2-4 倍。 所有现代 LLM 推理框架都默认开启它。
GPU 内存层级(关键背景)
GPU 内有不同速度的内存:
速度
↑
SRAM 几 KB/SM 19 TB/s 超快
HBM 80 GB 1.5 TB/s 慢(相对)
PCIe 系统 RAM 10 GB/s 超慢
↓ 越下面越大越慢
注意 SRAM 比 HBM 快 12 倍—— 但容量小 100 万倍。
这意味着:如果你能把计算保留在 SRAM 里,速度就能飞。
朴素 Attention 慢在哪
# 标准实现
def naive_attention(Q, K, V):
S = Q @ K.T # (n, n) ← 写入 HBM
P = softmax(S) # (n, n) ← 写入 HBM
O = P @ V # (n, d) ← 写入 HBM
return O
序列长度 ,每个矩阵大小 ——
| 步骤 | HBM 读写 |
|---|---|
| 计算 | 读 (~2nd),写 () |
| Softmax | 读 ,写 () |
| 加权 V | 读 (),写 |
总 HBM 访问 ~5n²—— 当 n=8K,这是 2.5 亿次内存操作——爆炸。
GPU 计算很快,但被内存带宽卡死了。
FlashAttention 的核心思路
不要写入中间结果到 HBM—— 在 SRAM 里把完整计算做完。
分块计算(Tiling)
把 切成小块,每次只在 SRAM 里处理一小块:
Q [n × d] 切成 Q_1, Q_2, ..., Q_T (每块小到能放 SRAM)
K [n × d] 切成 K_1, K_2, ..., K_T
V [n × d] 切成 V_1, V_2, ..., V_T
对每个 Q_i:
for j = 1 to T:
从 HBM 加载 K_j, V_j 到 SRAM
在 SRAM 内计算 Q_i K_j^T、softmax、加权 V_j
累积结果
写一次 Q_i 的最终输出到 HBM
关键技巧:在累积时计算”running softmax”—— 不需要算完整 softmax 再加权。
Recomputation
反向传播时,不存储中间的 softmax 矩阵—— 重新计算它(因为快内存访问 < 重新算)。
这是经典的 “trade computation for memory” —— 算多点没事,内存便宜了。
数学正确性
最神奇的事——FlashAttention 和朴素 Attention 给出完全相同的输出。 不是近似,是数学上等价。
它只是改变了计算顺序—— 让内存访问模式更高效。
这是优秀系统优化的标志:算法不变,硬件友好。
性能数据
速度对比
| 序列长度 | 朴素 | FlashAttention |
|---|---|---|
| 1k | 1× baseline | 1.6× faster |
| 4k | 1× | 2.4× faster |
| 16k | 1× | 3.0× faster |
| 64k | OOM | works + 4× faster |
序列越长,加速越明显——因为内存压力越大。
显存对比
| 序列长度 | 朴素显存 | FlashAttention 显存 |
|---|---|---|
| 1k | 4 MB | 0.4 MB |
| 8k | 256 MB | 3 MB |
| 32k | 4 GB | 12 MB |
长上下文场景显存能省 100 倍以上。
FlashAttention 2 (2023)
一年后的升级版:
- 更优的工作分配
- 减少非矩阵乘运算
- 在 H100 上速度再翻 2 倍
“FlashAttention” 现在通常指 v2。
FlashAttention 3 (2024)
针对 H100 的 FP8 + 异步执行优化:
- 1.5-2× 比 FlashAttention 2 还快
- 支持 FP8 精度(不损失质量)
- 利用 H100 的 Hopper 架构特性
论文之后
FlashAttention 改变了:
1. 长上下文成为可能
GPT-4 128k、Claude 2M、Gemini 1M—— 这些长上下文 LLM 都用 FlashAttention。 没有它,长上下文显存费用会让模型不可用。
2. 训练效率翻倍
H100 + FlashAttention 训 70B 模型—— 比 A100 + naive attention 快 6-8 倍。
3. 推理框架的标配
vLLM、TGI、TensorRT-LLM、llama.cpp—— 全部用 FlashAttention 作为默认 attention 实现。
不开它的代价:你的推理服务慢 2-4 倍。
一些有趣的事
Tri Dao 是谁
第一作者 Tri Dao—— Stanford 博士生做出 FlashAttention(2022)。 同一年又做出 Mamba(2023)。 毕业后去 Princeton 当教授。
26 岁做出两项基础工作—— 被誉为系统 ML 领域最有创造力的研究者之一。
“硬件感知 ML”
FlashAttention 开创了一个新方向—— ML 算法设计要考虑硬件特性。
之前 ML 研究主要由”什么数学结构强”驱动—— FlashAttention 教会大家:算法-硬件协同设计同等重要。
后续 SnowflakeKernels、xFormers、Triton 优化等都受启发。
它有时被叫”系统级的 ResNet”
意思是——一个简单但深刻的工程改进,影响整个领域。 ResNet 是算法上的,FlashAttention 是系统上的。
实际使用
在 PyTorch 里默认开启
# PyTorch 2.0+ 默认用 FlashAttention(如果硬件支持)
import torch
import torch.nn.functional as F
q = torch.randn(1, 8, 1000, 64, device='cuda', dtype=torch.float16)
k = torch.randn(1, 8, 1000, 64, device='cuda', dtype=torch.float16)
v = torch.randn(1, 8, 1000, 64, device='cuda', dtype=torch.float16)
# 这一行自动用 FlashAttention(如果可用)
output = F.scaled_dot_product_attention(q, k, v)
HuggingFace Transformers
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
"Llama-3-70B",
attn_implementation="flash_attention_2", # ← 显式启用
torch_dtype=torch.float16,
)
推荐配套阅读
- HelloAI: L3-05 注意力机制 + L7-01 GPU 速览 + L7-03 推理优化
- FlashAttention 论文(原版)
- FlashAttention 2 论文
- Tri Dao 的 NeurIPS 演讲
- PyTorch 2.0 SDPA 文档
FlashAttention 不是新算法——是更聪明地用硬件。
这告诉我们:MLOps / 系统优化 / 硬件理解 在 AI 时代极其重要。 不是只懂数学就够——懂 GPU 内存层级、能写 CUDA、能用 Triton——这些都是关键技能。
如果你的目标是 ML 系统工程师——这是必读论文。
想要更多论文精读
订阅每周精选 —— 下一篇论文笔记直接送邮箱。
讨论区
· 用 GitHub 账号登录评论src/components/Comments.astro 顶部填入
仓库 ID 和分类 ID(见组件注释里的配置步骤)。