HelloAI
L3 第 5 篇 🐣 难度 🕒 18 分钟

注意力机制详解:从直觉到完整推导

Attention is all you need. 这一篇带你从"它到底在干嘛"到"它的每行公式",一次性吃透 Transformer 的核心。

阿莱
2026/6/15

2017 年那篇标题嚣张的论文 《Attention Is All You Need》 改变了一切。

L0-05 教过你怎么用 attention 写 prompt。L1-02 让你看到了 Q、K、V 这些矩阵在干什么。这一篇我们打通最后一公里——一行公式都不能含糊

读完,你会真正”看懂” Transformer 的核心。

🎮 强烈建议先去 注意力实时计算可视化 玩 5 分钟。看着热力图变化读这篇,事半功倍。

第一站:注意力解决什么问题

回到一个具体的例子。一句话:

“小猫追着小球跑,因为很好奇。”

“它”指谁?人脑读到”它”瞬间就知道是”小猫”——但这需要看回前面、找指代对象

RNN 怎么做(不好)

2014 年前主流的 RNN 是这么做的:

从左往右读,每个词把”自己 + 之前的语义”压缩成一个 hidden state,传给下一个词。

问题:信息在传递过程中会衰减。等读到”它”的时候,“小猫”已经在第 N 个 hidden state 里被稀释了。

Attention 怎么做(好)

每个词在被处理时,可以”直接看”句子里所有其它词,根据相关性决定关注哪几个。

读”它”时——

  • 看一眼”小猫”——相关性 0.62 ✓
  • 看一眼”追”——相关性 0.08
  • 看一眼”球”——相关性 0.14
  • ……

最相关的”小猫”被赋予最大权重——“它”的语义就从”小猫”那里”借”过来了

这就是注意力机制的本质:信息的有向流动

第二站:Q、K、V 是什么

为了让”注意力”能算,我们给每个词配三个向量:

  • Q(Query):当前词在”问”——我该关注谁?
  • K(Key):每个词的”名片”——告诉别人”我是谁、我代表什么”
  • V(Value):每个词真正”携带的信息”——如果你注意我,就拿走我这份

类比一个图书馆查询:

  • 你提问(Q)
  • 每本书有个标题(K)
  • 找最匹配的几本,把它们的内容(V)抄到你的答案里

关键问题:Q、K、V 从哪儿来?

答:从输入词向量做线性变换

每个词最初是一个 embedding 向量 xix_i(比如 768 维)。我们用三个可训练的矩阵把它们投影成 Q、K、V:

Qi=xiWQ,Ki=xiWK,Vi=xiWVQ_i = x_i W^Q, \quad K_i = x_i W^K, \quad V_i = x_i W^V

其中:

  • xix_i1×7681 \times 768
  • WQ,WK,WVW^Q, W^K, W^V 都是 768×64768 \times 64(投影到一个更小的维度)
  • 结果 Qi,Ki,ViQ_i, K_i, V_i 都是 1×641 \times 64

这三个矩阵 WQ,WK,WVW^Q, W^K, W^V整个 Transformer 唯一要学的东西之一。

第三站:注意力计算(核心公式)

OK,每个词都有了 Q、K、V。怎么算注意力?

整个公式——记好,这是 AI 历史上最重要的一行:

Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V

我们一步步拆。

Step 1:算原始分数 QKTQK^T

把所有 Q 摞起来(一个矩阵,每行是一个词的 Q),所有 K 摞起来。

scores=QKT\text{scores} = Q K^T

如果有 nn 个词,结果是一个 n×nn \times n 矩阵——每一格 [i,j][i,j] 是”词 i 应该多关注词 j”的原始分数

数学上这就是点积——还记得 L1-02 说的”点积越大方向越像”吗?这里就是用点积测 Q 和 K 的相似度。

Step 2:缩放 dk\sqrt{d_k}

为什么要除以 dk\sqrt{d_k}dkd_k 是 Q/K 的维度,比如 64)?

数学原因:点积是 dkd_k 项之和。dkd_k 越大,分数方差越大——softmax 后会变得极端(一个 1,其它接近 0),梯度消失。

除以 dk\sqrt{d_k} 把方差控制在合理范围。

这是工程细节,不影响理解,但很关键。

Step 3:Softmax 归一化

对每一行做 softmax,让每行加起来 = 1:

attention_weights[i,j]=exp(scores[i,j])kexp(scores[i,k])\text{attention\_weights}[i,j] = \frac{\exp(\text{scores}[i,j])}{\sum_k \exp(\text{scores}[i,k])}

每行就是一个概率分布——告诉你词 i 把多大比例的”注意力预算”分给了每个词。

Step 4:用权重加权 V

output=attention_weightsV\text{output} = \text{attention\_weights} \cdot V

每个词的输出 = 用注意力权重对所有 V 做加权平均。

最相关的词的 V 权重大——它的信息”贡献”最多。

💡 一句话翻译整个公式

每个词都把所有其它词的 V 做加权平均,权重来自 Q 和 K 的相似度,最后归一化。 就是它。

第四站:用 PyTorch 实现 30 行

import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Q, K, V: (batch, n_tokens, d_k)
    返回: (batch, n_tokens, d_v) 和注意力权重
    """
    d_k = Q.size(-1)

    # Step 1: 算 Q · K^T
    scores = torch.matmul(Q, K.transpose(-2, -1))   # (batch, n, n)

    # Step 2: 缩放
    scores = scores / (d_k ** 0.5)

    # Step 3 (optional): 因果掩码(生成模型用)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)

    # Step 4: softmax
    weights = F.softmax(scores, dim=-1)   # (batch, n, n)

    # Step 5: 加权 V
    output = torch.matmul(weights, V)     # (batch, n, d_v)

    return output, weights


# 实战
batch, n, d = 1, 10, 64
Q = K = V = torch.randn(batch, n, d)

output, attn = scaled_dot_product_attention(Q, K, V)
print(output.shape)   # torch.Size([1, 10, 64])
print(attn.shape)     # torch.Size([1, 10, 10])
print(attn[0, 5].sum())  # ≈ 1.0 (每行加起来等于 1)

这 30 行代码就是 Transformer 的核心。剩下都是工程细节。

第五站:多头注意力(Multi-Head)

但实际中我们不只用 1 个注意力——用很多(GPT-3 用 96 个,PaLM 用 48 个)。

为什么?一个头只能学一种”关注模式”

可视化已经让你看到了:

  • 有的头看”前一个词”
  • 有的头看”句首”
  • 有的头看”相似词”
  • 有的头做”指代消解”

让模型同时跑 N 个独立的注意力,就能并行捕捉 N 种不同的依赖关系

数学上:

def multi_head_attention(X, n_heads=8, d_model=512):
    d_k = d_model // n_heads  # 每个头的维度

    # 每个头有自己的 W^Q, W^K, W^V
    Q = X @ W_Q   # (batch, n, d_model)
    K = X @ W_K
    V = X @ W_V

    # 切分成 n_heads 个头
    Q = Q.view(batch, n, n_heads, d_k).transpose(1, 2)  # (batch, n_heads, n, d_k)
    K = K.view(batch, n, n_heads, d_k).transpose(1, 2)
    V = V.view(batch, n, n_heads, d_k).transpose(1, 2)

    # 每个头独立做注意力
    out, _ = scaled_dot_product_attention(Q, K, V)

    # 把所有头的结果拼回来
    out = out.transpose(1, 2).contiguous().view(batch, n, d_model)

    # 最后过一个线性层混合
    return out @ W_O

这就是 nn.MultiheadAttention 的核心

第六站:自注意力 vs 交叉注意力

到目前为止,Q、K、V 都来自同一个输入——这叫 自注意力(self-attention)

但 Transformer 在不同位置用不同方式:

位置Q 来自K, V 来自
Encoder 自注意力encoder 输入encoder 输入
Decoder 自注意力(带掩码)decoder 输入decoder 输入
Decoder 交叉注意力decoder 输入encoder 输出

最后一种”交叉注意力”,让 decoder 在生成时回头看 encoder 编码的信息——这就是 Seq2Seq、翻译模型的核心。

第七站:因果掩码(Causal Mask)

GPT 这类自回归生成模型有个限制:生成第 i 个词时,不能看第 i+1 之后的词——不然就是”作弊”。

实现方式:在 softmax 之前,把”未来位置”的分数设为 -\infty

# 上三角掩码(不含对角线)
mask = torch.triu(torch.ones(n, n), diagonal=1).bool()
scores.masked_fill_(mask, float('-inf'))

这样 softmax 后未来位置的权重就是 0——模型只能看到已经生成的内容。

这就是 GPT 的工作方式——一次生成一个 token,永远只看过去。

一些常见误解

”注意力就是加权求和”——不完全对

更准确:注意力是让每个位置”采购”其它位置的信息。 权重只是采购过程中的副产品。

“Q 和 K 越像权重越大”——对,但要补一句

它们的相似度是通过点积测的——你需要在合适的”投影空间”里测才有意义。这就是为什么 Q 和 K 有自己的投影矩阵。

“多头是不是有冗余?“——确实有

研究表明,BERT 的 12 头中只有 4-6 个真正工作,剩下的可剪掉。这是个开放问题。

一句话总结

Attention = 信息的有向流动 + softmax 加权融合。

它的革命性在于:任何两个位置之间都能直接通信——无论它们隔了几千个词。

RNN 是接力赛,Attention 是会场广播。

想”看见”它

👀 注意力实时计算可视化 —— 玩 4 种注意力头,看 attention pattern 在不同任务下的差异。

💡 读完这一篇你解锁了什么
  • 能看懂 BERT、GPT 论文里的 attention 部分
  • 能在 PyTorch 里手写一个 transformer block
  • 能对”你说 attention pattern” 等讨论应得上话

下一步:搭一个最小可运行的 Transformer。

下一篇:《CNN 卷积原理:从滤镜到 ResNet》 —— 视觉领域的”attention 之前的王者”。

📬

读到这里说明你认真在学 🎯

订阅每周精选 —— 下一篇新文章 / 新可视化第一时间送到邮箱。

💬

讨论区

· 用 GitHub 账号登录评论
⚠️ Giscus 评论未配置 —— 在 src/components/Comments.astro 顶部填入 仓库 ID 和分类 ID(见组件注释里的配置步骤)。