注意力机制详解:从直觉到完整推导
Attention is all you need. 这一篇带你从"它到底在干嘛"到"它的每行公式",一次性吃透 Transformer 的核心。
2017 年那篇标题嚣张的论文 《Attention Is All You Need》 改变了一切。
L0-05 教过你怎么用 attention 写 prompt。L1-02 让你看到了 Q、K、V 这些矩阵在干什么。这一篇我们打通最后一公里——一行公式都不能含糊。
读完,你会真正”看懂” Transformer 的核心。
🎮 强烈建议先去 注意力实时计算可视化 玩 5 分钟。看着热力图变化读这篇,事半功倍。
第一站:注意力解决什么问题
回到一个具体的例子。一句话:
“小猫追着小球跑,因为它很好奇。”
“它”指谁?人脑读到”它”瞬间就知道是”小猫”——但这需要看回前面、找指代对象。
RNN 怎么做(不好)
2014 年前主流的 RNN 是这么做的:
从左往右读,每个词把”自己 + 之前的语义”压缩成一个 hidden state,传给下一个词。
问题:信息在传递过程中会衰减。等读到”它”的时候,“小猫”已经在第 N 个 hidden state 里被稀释了。
Attention 怎么做(好)
每个词在被处理时,可以”直接看”句子里所有其它词,根据相关性决定关注哪几个。
读”它”时——
- 看一眼”小猫”——相关性 0.62 ✓
- 看一眼”追”——相关性 0.08
- 看一眼”球”——相关性 0.14
- ……
最相关的”小猫”被赋予最大权重——“它”的语义就从”小猫”那里”借”过来了。
这就是注意力机制的本质:信息的有向流动。
第二站:Q、K、V 是什么
为了让”注意力”能算,我们给每个词配三个向量:
- Q(Query):当前词在”问”——我该关注谁?
- K(Key):每个词的”名片”——告诉别人”我是谁、我代表什么”
- V(Value):每个词真正”携带的信息”——如果你注意我,就拿走我这份
类比一个图书馆查询:
- 你提问(Q)
- 每本书有个标题(K)
- 找最匹配的几本,把它们的内容(V)抄到你的答案里
关键问题:Q、K、V 从哪儿来?
答:从输入词向量做线性变换。
每个词最初是一个 embedding 向量 (比如 768 维)。我们用三个可训练的矩阵把它们投影成 Q、K、V:
其中:
- 是
- 都是 (投影到一个更小的维度)
- 结果 都是
这三个矩阵 是整个 Transformer 唯一要学的东西之一。
第三站:注意力计算(核心公式)
OK,每个词都有了 Q、K、V。怎么算注意力?
整个公式——记好,这是 AI 历史上最重要的一行:
我们一步步拆。
Step 1:算原始分数
把所有 Q 摞起来(一个矩阵,每行是一个词的 Q),所有 K 摞起来。
如果有 个词,结果是一个 矩阵——每一格 是”词 i 应该多关注词 j”的原始分数。
数学上这就是点积——还记得 L1-02 说的”点积越大方向越像”吗?这里就是用点积测 Q 和 K 的相似度。
Step 2:缩放
为什么要除以 ( 是 Q/K 的维度,比如 64)?
数学原因:点积是 项之和。 越大,分数方差越大——softmax 后会变得极端(一个 1,其它接近 0),梯度消失。
除以 把方差控制在合理范围。
这是工程细节,不影响理解,但很关键。
Step 3:Softmax 归一化
对每一行做 softmax,让每行加起来 = 1:
每行就是一个概率分布——告诉你词 i 把多大比例的”注意力预算”分给了每个词。
Step 4:用权重加权 V
每个词的输出 = 用注意力权重对所有 V 做加权平均。
最相关的词的 V 权重大——它的信息”贡献”最多。
每个词都把所有其它词的 V 做加权平均,权重来自 Q 和 K 的相似度,最后归一化。 就是它。
第四站:用 PyTorch 实现 30 行
import torch
import torch.nn.functional as F
def scaled_dot_product_attention(Q, K, V, mask=None):
"""
Q, K, V: (batch, n_tokens, d_k)
返回: (batch, n_tokens, d_v) 和注意力权重
"""
d_k = Q.size(-1)
# Step 1: 算 Q · K^T
scores = torch.matmul(Q, K.transpose(-2, -1)) # (batch, n, n)
# Step 2: 缩放
scores = scores / (d_k ** 0.5)
# Step 3 (optional): 因果掩码(生成模型用)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
# Step 4: softmax
weights = F.softmax(scores, dim=-1) # (batch, n, n)
# Step 5: 加权 V
output = torch.matmul(weights, V) # (batch, n, d_v)
return output, weights
# 实战
batch, n, d = 1, 10, 64
Q = K = V = torch.randn(batch, n, d)
output, attn = scaled_dot_product_attention(Q, K, V)
print(output.shape) # torch.Size([1, 10, 64])
print(attn.shape) # torch.Size([1, 10, 10])
print(attn[0, 5].sum()) # ≈ 1.0 (每行加起来等于 1)
这 30 行代码就是 Transformer 的核心。剩下都是工程细节。
第五站:多头注意力(Multi-Head)
但实际中我们不只用 1 个注意力——用很多(GPT-3 用 96 个,PaLM 用 48 个)。
为什么?一个头只能学一种”关注模式”。
可视化已经让你看到了:
- 有的头看”前一个词”
- 有的头看”句首”
- 有的头看”相似词”
- 有的头做”指代消解”
让模型同时跑 N 个独立的注意力,就能并行捕捉 N 种不同的依赖关系。
数学上:
def multi_head_attention(X, n_heads=8, d_model=512):
d_k = d_model // n_heads # 每个头的维度
# 每个头有自己的 W^Q, W^K, W^V
Q = X @ W_Q # (batch, n, d_model)
K = X @ W_K
V = X @ W_V
# 切分成 n_heads 个头
Q = Q.view(batch, n, n_heads, d_k).transpose(1, 2) # (batch, n_heads, n, d_k)
K = K.view(batch, n, n_heads, d_k).transpose(1, 2)
V = V.view(batch, n, n_heads, d_k).transpose(1, 2)
# 每个头独立做注意力
out, _ = scaled_dot_product_attention(Q, K, V)
# 把所有头的结果拼回来
out = out.transpose(1, 2).contiguous().view(batch, n, d_model)
# 最后过一个线性层混合
return out @ W_O
这就是 nn.MultiheadAttention 的核心。
第六站:自注意力 vs 交叉注意力
到目前为止,Q、K、V 都来自同一个输入——这叫 自注意力(self-attention)。
但 Transformer 在不同位置用不同方式:
| 位置 | Q 来自 | K, V 来自 |
|---|---|---|
| Encoder 自注意力 | encoder 输入 | encoder 输入 |
| Decoder 自注意力(带掩码) | decoder 输入 | decoder 输入 |
| Decoder 交叉注意力 | decoder 输入 | encoder 输出 |
最后一种”交叉注意力”,让 decoder 在生成时回头看 encoder 编码的信息——这就是 Seq2Seq、翻译模型的核心。
第七站:因果掩码(Causal Mask)
GPT 这类自回归生成模型有个限制:生成第 i 个词时,不能看第 i+1 之后的词——不然就是”作弊”。
实现方式:在 softmax 之前,把”未来位置”的分数设为 :
# 上三角掩码(不含对角线)
mask = torch.triu(torch.ones(n, n), diagonal=1).bool()
scores.masked_fill_(mask, float('-inf'))
这样 softmax 后未来位置的权重就是 0——模型只能看到已经生成的内容。
这就是 GPT 的工作方式——一次生成一个 token,永远只看过去。
一些常见误解
”注意力就是加权求和”——不完全对
更准确:注意力是让每个位置”采购”其它位置的信息。 权重只是采购过程中的副产品。
“Q 和 K 越像权重越大”——对,但要补一句
它们的相似度是通过点积测的——你需要在合适的”投影空间”里测才有意义。这就是为什么 Q 和 K 有自己的投影矩阵。
“多头是不是有冗余?“——确实有
研究表明,BERT 的 12 头中只有 4-6 个真正工作,剩下的可剪掉。这是个开放问题。
一句话总结
Attention = 信息的有向流动 + softmax 加权融合。
它的革命性在于:任何两个位置之间都能直接通信——无论它们隔了几千个词。
RNN 是接力赛,Attention 是会场广播。
想”看见”它
👀 注意力实时计算可视化 —— 玩 4 种注意力头,看 attention pattern 在不同任务下的差异。
- 能看懂 BERT、GPT 论文里的 attention 部分
- 能在 PyTorch 里手写一个 transformer block
- 能对”你说 attention pattern” 等讨论应得上话
下一步:搭一个最小可运行的 Transformer。
下一篇:《CNN 卷积原理:从滤镜到 ResNet》 —— 视觉领域的”attention 之前的王者”。
读到这里说明你认真在学 🎯
订阅每周精选 —— 下一篇新文章 / 新可视化第一时间送到邮箱。
讨论区
· 用 GitHub 账号登录评论src/components/Comments.astro 顶部填入
仓库 ID 和分类 ID(见组件注释里的配置步骤)。