跳到主要内容
AIInfra前置基础

3.3 Self-Attention机制深入理解

Self-Attention 是 Transformer 的心脏,也是当代大模型中计算量最集中、优化手段最丰富的模块

Transformer Self-Attention Multi-Head Attention FlashAttention GQA

Self-Attention 是 Transformer 的心脏,也是当代大模型中计算量最集中、优化手段最丰富的模块。无论你是想理解 FlashAttention 背后的 IO 优化思想,还是想搞清楚 GQA、MLA 这些 Attention 变种为什么能减少推理开销,都绕不开对 Self-Attention 机制的深入理解。本文将从 Attention 的历史起源讲起,逐步拆解 Scaled Dot-Product Attention 的每一步数学原理,手写 PyTorch 实现,分析计算瓶颈,最后延伸到 FlashAttention 和各种 Attention 变种,力求让读者建立从直觉到公式再到工程实现的完整认知链条。

📑 目录


1. Attention 的历史演进

在 2017 年 Transformer 横空出世之前,Attention 机制已经在序列建模领域酝酿了好几年。理解它的演进脉络,有助于我们把握 Self-Attention 到底解决了什么问题,以及为什么它的设计长成今天这个样子。

1.1 Bahdanau Attention (2014):开山之作

传统的 Encoder-Decoder 架构(比如用 RNN 做机器翻译)有一个致命缺陷:Encoder 把整个输入序列压缩成一个固定长度的向量,再交给 Decoder 去生成输出。当输入序列很长时,这个固定向量根本装不下所有信息,翻译质量会急剧下降。

Bahdanau 等人在 2014 年提出了一个关键改进:别硬压缩了,让 Decoder 在生成每个词的时候,自己回头”看”Encoder 的所有隐状态,按需取用。具体做法是,Decoder 当前时刻的隐状态作为”查询”,跟 Encoder 每个时刻的隐状态做”匹配”(通过一个小型前馈网络计算匹配分数),然后用匹配分数对 Encoder 的隐状态加权求和,得到一个”上下文向量”,作为当前解码步的额外输入。

这就是 Attention 的雏形:让模型学会”注意”输入序列的不同部分。不过 Bahdanau Attention 的匹配函数是一个额外的前馈网络(additive attention),计算效率不高。

1.2 Luong Attention (2015):简化计算

Luong 在 2015 年对 Bahdanau Attention 做了两个关键简化:一是提出了更高效的匹配函数(直接用点积代替前馈网络),二是探索了多种对齐方式(全局 vs 局部)。其中,点积 Attention 的计算方式——两个向量直接做内积来衡量相似度——后来成为了 Transformer 的基础。

点积的计算效率远高于前馈网络:它可以被表示为矩阵乘法,天然适合 GPU 并行加速。这个看似简单的改进,为后来 Self-Attention 的大规模应用铺平了道路。

1.3 Self-Attention (2017):Attention Is All You Need

2017 年的 Transformer 论文做了一个大胆的决定:完全抛弃 RNN,只用 Attention 来建模序列

之前的 Attention 都是”跨序列”的——Decoder 去关注 Encoder。而 Self-Attention 是”自关注”——序列中的每个位置去关注同一个序列中的所有位置(包括自己)。这样做有两大优势:

  • 并行性:RNN 必须逐步计算(第 tt 步依赖第 t1t-1 步的结果),而 Self-Attention 中所有位置可以同时计算,天然适合 GPU 并行。
  • 长距离依赖:RNN 中距离较远的两个词需要通过多步传递才能交互信息,信号会逐步衰减。Self-Attention 中任意两个位置之间只需要一步就能直接交互,理论上不存在距离衰减问题。

代价是什么?Self-Attention 的计算复杂度是 O(N2)O(N^2)——序列中每对位置都需要计算匹配分数,而 RNN 的复杂度是 O(N)O(N)。这个平方复杂度成为后续无数优化工作的出发点。


2. 从信息检索角度理解 QKV

Attention 中最核心的三个概念是 Query(Q)、Key(K)和 Value(V)。这三个词直接借鉴了信息检索领域的术语,用一个日常场景来类比可以帮助建立直觉。

2.1 图书馆检索的类比

想象你走进一座图书馆,你脑子里有一个模糊的需求:“我想了解关于并行计算的内容”。这个需求就是你的 Query

图书馆里每一本书的书脊上都贴着标签——“操作系统”、“计算机网络”、“并行计算导论”、“数据结构”等等。这些标签就是每本书的 Key

你拿着自己的 Query,去和每本书的 Key 做比对。“并行计算导论”这个 Key 和你的 Query 高度匹配,匹配分数最高;“操作系统”可能有一些相关性,分数中等;“数据结构”和你的需求关系不大,分数很低。

确定了匹配分数之后,你不是把书脊标签(Key)拿走,而是把每本书的实际内容(Value)按照匹配分数加权汇总。“并行计算导论”的内容占大比重,“操作系统”的内容占一点,“数据结构”几乎不占——最终你得到的就是一份以并行计算为主、兼顾一点操作系统知识的综合信息。

2.2 三元组的形式化定义

回到 Self-Attention 的语境。给定输入序列 XX(形状为 (N,dmodel)(N, d_{model})),三个线性变换将 XX 映射到不同的空间:

  • Q=XWQQ = X \cdot W_Q:每个 token 生成自己的”查询向量”——“我需要什么信息”
  • K=XWKK = X \cdot W_K:每个 token 生成自己的”索引向量”——“我能提供什么信息的线索”
  • V=XWVV = X \cdot W_V:每个 token 生成自己的”内容向量”——“我实际携带的信息”

Q、K、V 之所以要从同一个 XX 做三次不同的线性变换,而不是直接用 XX 本身,原因在于解耦不同的角色。一个 token “需要什么”和它”能提供什么”往往是不同的。比如在一个句子里,动词可能需要关注它的主语和宾语(Query 的方向),但它作为被别人关注的对象时,提供的是动作语义(Key/Value 的方向)。三个独立的投影矩阵让模型有自由度去学习这些不同的映射。


3. Scaled Dot-Product Attention 详解

有了 Q、K、V 之后,Attention 的计算过程可以分解为清晰的五步。

3.1 第一步:线性投影

输入 XX 的形状为 (N,dmodel)(N, d_{model})。三个权重矩阵 WQW_QWKW_KWVW_V 的形状都是 (dmodel,dmodel)(d_{model}, d_{model})(以单头为例):

Q=XWQ(N,dmodel)×(dmodel,dmodel)=(N,dmodel)\begin{aligned} Q &= X \cdot W_Q \quad & (N, d_{model}) \times (d_{model}, d_{model}) = (N, d_{model}) \\ \end{aligned} K=XWK(N,dmodel)×(dmodel,dmodel)=(N,dmodel)\begin{aligned} K &= X \cdot W_K \quad & (N, d_{model}) \times (d_{model}, d_{model}) = (N, d_{model}) \\ \end{aligned} V=XWV(N,dmodel)×(dmodel,dmodel)=(N,dmodel)\begin{aligned} V &= X \cdot W_V \quad & (N, d_{model}) \times (d_{model}, d_{model}) = (N, d_{model}) \end{aligned}

这三次矩阵乘法是三次独立的 GEMM(General Matrix Multiply)操作。在实际实现中,为了提高 GPU 利用率,通常会把 WQW_QWKW_KWVW_V 合并成一个大矩阵 WQKVW_{QKV}(形状 (dmodel,3×dmodel)(d_{model}, 3 \times d_{model})),做一次 GEMM 然后 split,这样能更好地利用 GPU 的算力。

3.2 第二步:计算原始注意力分数

QQKK 的转置做矩阵乘法,衡量每对 token 之间的匹配程度:

S=QK(N,dmodel)×(dmodel,N)=(N,N)S = Q K^\top \quad (N, d_{model}) \times (d_{model}, N) = (N, N)

结果矩阵 SS 的每个元素 S[i][j]S[i][j] 是第 ii 个 token 的查询向量和第 jj 个 token 的索引向量的点积,代表 token ii 对 token jj 的原始关注程度。

这一步产生了一个 N×NN \times N 的矩阵,这就是 Self-Attention 平方复杂度的根源。

3.3 第三步:缩放

将原始分数除以 dk\sqrt{d_k}dkd_k 是 Key 向量的维度):

Sscaled=SdkS_{\text{scaled}} = \frac{S}{\sqrt{d_k}}

为什么需要缩放?简单来说,当维度 dkd_k 较大时,点积的结果会变得很大,导致后续 Softmax 的梯度趋近于零。详细的数学推导见第 6 节。

3.4 第四步:Softmax 归一化

对缩放后的分数矩阵按行做 Softmax,将原始分数转换为概率分布:

A=softmax(Sscaled)RN×N, 每行和为 1A = \text{softmax}(S_{\text{scaled}}) \quad \in \mathbb{R}^{N \times N}, \text{ 每行和为 1}

Softmax 的作用是双重的:一方面把任意实数映射到 (0,1)(0, 1) 区间,使其可以作为权重;另一方面保证每行的权重之和为 1,形成一个合法的概率分布。

3.5 第五步:加权求和

用注意力权重对 Value 矩阵做加权求和:

Output=AV(N,N)×(N,dmodel)=(N,dmodel)\text{Output} = A \cdot V \quad (N, N) \times (N, d_{model}) = (N, d_{model})

最终每个 token 位置得到一个 dmodeld_{model} 维的向量,里面融合了它所”关注”的所有 token 的信息,关注程度由 AA 的权重决定。

3.6 完整公式

将以上步骤合并,Self-Attention 的计算可以用一行公式概括:

Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V

最后还要再过一个输出投影矩阵 WOW_O(形状 (dmodel,dmodel)(d_{model}, d_{model})),将结果映射回模型的隐藏空间。这个投影在 Multi-Head Attention 中尤其重要——它负责将多个头拼接后的表示重新混合。


4. PyTorch 手动实现

理论讲得再多,不如亲手写一遍代码。下面分别实现 Single-Head 和 Multi-Head Self-Attention。

4.1 Single-Head Self-Attention

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class SingleHeadSelfAttention(nn.Module):
    """手动实现单头 Self-Attention,不依赖 PyTorch 内置的 MHA 模块"""
    def __init__(self, d_model: int):
        super().__init__()
        self.d_model = d_model
        # 三个线性投影:输入维度 d_model,输出维度 d_model
        self.W_Q = nn.Linear(d_model, d_model, bias=False)
        self.W_K = nn.Linear(d_model, d_model, bias=False)
        self.W_V = nn.Linear(d_model, d_model, bias=False)
        # 输出投影
        self.W_O = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None):
        # x: (batch, seq_len, d_model)
        # Step 1: 线性投影得到 Q, K, V
        Q = self.W_Q(x)  # (batch, seq_len, d_model)
        K = self.W_K(x)  # (batch, seq_len, d_model)
        V = self.W_V(x)  # (batch, seq_len, d_model)

        # Step 2: 计算注意力分数 = Q @ K^T
        scores = torch.matmul(Q, K.transpose(-2, -1))  # (batch, seq_len, seq_len)

        # Step 3: 缩放,防止点积值过大导致 softmax 梯度消失
        scores = scores / math.sqrt(self.d_model)

        # Step 4: 如果提供了 mask,将被遮蔽位置的分数设为负无穷
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        # Step 5: Softmax 归一化,得到注意力权重
        attn_weights = F.softmax(scores, dim=-1)  # (batch, seq_len, seq_len)

        # Step 6: 用权重对 V 加权求和
        context = torch.matmul(attn_weights, V)  # (batch, seq_len, d_model)

        # Step 7: 输出投影
        output = self.W_O(context)  # (batch, seq_len, d_model)
        return output

4.2 Multi-Head Self-Attention

class MultiHeadSelfAttention(nn.Module):
    """手动实现多头 Self-Attention:将 d_model 切分为 num_heads 个子空间"""
    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        assert d_model % num_heads == 0, "d_model 必须能被 num_heads 整除"
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads  # 每个头的维度

        # QKV 投影:输入 d_model,输出 d_model(内部包含所有头的投影)
        self.W_Q = nn.Linear(d_model, d_model, bias=False)
        self.W_K = nn.Linear(d_model, d_model, bias=False)
        self.W_V = nn.Linear(d_model, d_model, bias=False)
        self.W_O = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None):
        batch_size, seq_len, _ = x.shape

        # Step 1: 线性投影
        Q = self.W_Q(x)  # (batch, seq_len, d_model)
        K = self.W_K(x)
        V = self.W_V(x)

        # Step 2: 重塑为多头形状,并转置使 head 维度在 seq_len 前面
        # (batch, seq_len, d_model) -> (batch, seq_len, num_heads, head_dim)
        # -> (batch, num_heads, seq_len, head_dim)
        Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Step 3: 每个头独立计算注意力分数
        # Q @ K^T: (batch, num_heads, seq_len, head_dim)
        #        @ (batch, num_heads, head_dim, seq_len)
        #        = (batch, num_heads, seq_len, seq_len)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # Step 4: 可选的因果遮蔽
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        # Step 5: Softmax + 加权求和
        attn_weights = F.softmax(scores, dim=-1)
        # (batch, num_heads, seq_len, seq_len) @ (batch, num_heads, seq_len, head_dim)
        # = (batch, num_heads, seq_len, head_dim)
        context = torch.matmul(attn_weights, V)

        # Step 6: 多头拼接——先转回 (batch, seq_len, num_heads, head_dim),再合并最后两维
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)

        # Step 7: 输出投影,将拼接后的多头表示混合
        output = self.W_O(context)  # (batch, seq_len, d_model)
        return output

上面的代码完整展示了从输入到输出的每一步。值得注意的是,.transpose(1, 2) 这一步将 (batch, seq_len, num_heads, head_dim) 变成 (batch, num_heads, seq_len, head_dim),这样后续的矩阵乘法就能在每个头内部独立进行,利用 batch 维度的并行性一次算完所有头。


5. 详细维度推导

抽象的维度符号容易让人迷糊。下面用一组具体数值,完整跟踪 Multi-Head Self-Attention 中每一步的张量形状。

5.1 参数设定

seq_len   = 6        # 序列长度(6 个 token)
d_model   = 512      # 模型隐藏维度
num_heads = 8        # 注意力头数
head_dim  = 512 / 8 = 64   # 每个头的维度
batch     = 1        # 简化为 batch=1

5.2 逐步跟踪

输入

X: (1, 6, 512)    # 1 个样本,6 个 token,每个 token 512 维

线性投影

W_Q: (512, 512)    # nn.Linear 的权重矩阵
Q = X @ W_Q^T:  (1, 6, 512) @ (512, 512) = (1, 6, 512)
K = X @ W_K^T:  (1, 6, 512) @ (512, 512) = (1, 6, 512)
V = X @ W_V^T:  (1, 6, 512) @ (512, 512) = (1, 6, 512)

切分多头 + 转置

Q.view(1, 6, 8, 64):         (1, 6, 8, 64)
Q.transpose(1, 2):           (1, 8, 6, 64)    # 8 个头,每个头看到 6 个 token 的 64 维向量

K reshape 同理:              (1, 8, 6, 64)
V reshape 同理:              (1, 8, 6, 64)

计算注意力分数

K.transpose(-2, -1):         (1, 8, 64, 6)    # 转置最后两维
Q @ K^T:                     (1, 8, 6, 64) @ (1, 8, 64, 6) = (1, 8, 6, 6)
                             # 每个头产生一个 6x6 的注意力分数矩阵

缩放

scores / sqrt(64):           (1, 8, 6, 6)    # sqrt(64) = 8,每个分数除以 8

Softmax

softmax(scores, dim=-1):     (1, 8, 6, 6)    # 每行(最后一维)归一化为概率分布
                             # 每个 6 维行的元素之和为 1

加权求和

attn_weights @ V:            (1, 8, 6, 6) @ (1, 8, 6, 64) = (1, 8, 6, 64)
                             # 每个头输出 6 个 token 的 64 维表示

多头拼接

transpose(1, 2):             (1, 6, 8, 64)    # 把 head 维度挪回去
view(1, 6, 512):             (1, 6, 512)      # 8 * 64 = 512,拼接回 d_model

输出投影

W_O: (512, 512)
context @ W_O^T:             (1, 6, 512) @ (512, 512) = (1, 6, 512)

最终输出

Output: (1, 6, 512)    # 与输入 X 形状完全一致

从头到尾,输入和输出的形状都是 (batch,seq_len,dmodel)(batch, seq\_len, d_{model}),这保证了多个 Transformer Block 可以像积木一样层层堆叠。


6. 为什么要除以 dk\sqrt{d_k}:方差稳定性的数学推导

除以 dk\sqrt{d_k} 这一步在直觉上容易被当作”调参技巧”一笔带过,但它背后有严格的数学原因。

6.1 问题提出

假设 QQKK 的每个分量都是均值为 0、方差为 1 的独立随机变量。我们来计算点积 qkq \cdot k 的方差。

q=(q1,q2,,qdk)q = (q_1, q_2, \ldots, q_{d_k})k=(k1,k2,,kdk)k = (k_1, k_2, \ldots, k_{d_k}),其中每个 qiq_ikjk_j 独立同分布,满足 E[qi]=0E[q_i] = 0Var(qi)=1\text{Var}(q_i) = 1

点积的定义是:

qk=i=1dkqikiq \cdot k = \sum_{i=1}^{d_k} q_i \cdot k_i

6.2 推导过程

首先,单个分量乘积的期望和方差:

E[qiki]=E[qi]E[ki]=0×0=0E[q_i \cdot k_i] = E[q_i] \cdot E[k_i] = 0 \times 0 = 0 Var(qiki)=E[qi2ki2](E[qiki])2\text{Var}(q_i \cdot k_i) = E[q_i^2 \cdot k_i^2] - (E[q_i \cdot k_i])^2

由于 qiq_ikik_i 独立:

E[qi2ki2]=E[qi2]E[ki2]=Var(qi)Var(ki)=1×1=1E[q_i^2 \cdot k_i^2] = E[q_i^2] \cdot E[k_i^2] = \text{Var}(q_i) \cdot \text{Var}(k_i) = 1 \times 1 = 1

所以:

Var(qiki)=10=1\text{Var}(q_i \cdot k_i) = 1 - 0 = 1

点积是 dkd_k 个独立随机变量之和,根据方差的可加性:

Var(qk)=i=1dkVar(qiki)=dk\text{Var}(q \cdot k) = \sum_{i=1}^{d_k} \text{Var}(q_i \cdot k_i) = d_k

6.3 结论

点积 qkq \cdot k 的方差是 dkd_k。这意味着当 dk=64d_k = 64 时,点积值的标准差大约是 8;当 dk=128d_k = 128 时,标准差大约是 11.3。维度越大,点积值的绝对值越大,分布越分散。

Softmax 函数 softmax(zi)=exp(zi)/exp(zj)\text{softmax}(z_i) = \exp(z_i) / \sum \exp(z_j) 对输入值的量级非常敏感。当输入值之间的差距很大时(比如一个值是 50,另一个是 -10),Softmax 的输出会极度集中在最大值上,接近 one-hot 分布。此时梯度几乎为零,参数无法更新,训练陷入停滞。

除以 dk\sqrt{d_k} 之后:

Var(qkdk)=Var(qk)dk=dkdk=1\text{Var}\left(\frac{q \cdot k}{\sqrt{d_k}}\right) = \frac{\text{Var}(q \cdot k)}{d_k} = \frac{d_k}{d_k} = 1

方差被拉回到 1,无论 dkd_k 取什么值,Softmax 的输入始终在一个合理的范围内波动,梯度保持健康。这就是 Scaled Dot-Product Attention 名字中 “Scaled” 一词的由来。


7. Softmax 的数值稳定性

Softmax 看起来是一个简单的公式,但在实际的 GPU 实现中,它隐藏着一个容易导致数值溢出的陷阱。

7.1 溢出问题

标准 Softmax 公式:

softmax(zi)=ezij=1Nezj\text{softmax}(z_i) = \frac{e^{z_i}}{\sum_{j=1}^{N} e^{z_j}}

问题在于 exp\exp 函数增长极快。当 ziz_i 值较大时(比如 zi=1000z_i = 1000),exp(1000)\exp(1000) 直接超出 float32 甚至 float16 的表示范围,得到 inf。即使分子分母同时溢出,inf / inf 会得到 NaN,整个计算就废了。

7.2 减最大值技巧

解决方案是利用 Softmax 的平移不变性:对所有输入减去最大值,不改变结果。

softmax(zi)=ezimax(z)j=1Nezjmax(z)\text{softmax}(z_i) = \frac{e^{z_i - \max(z)}}{\sum_{j=1}^{N} e^{z_j - \max(z)}}

数学证明很简单——分子分母同乘以 exp(max(z))\exp(-\max(z)),等价于分子分母各自除以 exp(max(z))\exp(\max(z)),值不变。减去最大值后,最大的指数输入是 0,所以 exp\exp 的结果最大为 1,不会溢出。

但这带来了一个效率问题:标准实现需要对数据做三遍扫描

  1. 第一遍:遍历所有元素,找最大值 mm
  2. 第二遍:遍历所有元素,计算 exp(zim)\exp(z_i - m) 和它们的总和 sum\text{sum}
  3. 第三遍:遍历所有元素,每个 exp(zim)\exp(z_i - m) 除以 sum\text{sum}

每遍扫描都要从 HBM(高带宽显存)读取数据,三遍就是三次 HBM 读取。当 NN 很大时(比如 128K 的序列长度),这个 IO 开销非常大。

7.3 Online Softmax

Milakov 和 Gimelshein 在 2018 年提出了 Online Softmax 算法,将三遍扫描合并为一遍扫描,核心思想是在遍历数据的同时动态维护最大值和归一化分母。

算法流程如下。在遍历到第 ii 个元素时,维护两个变量:

  • mim_i:前 ii 个元素的最大值
  • did_i:前 ii 个元素的指数和(以 mim_i 为基准)

递推公式:

mi=max(mi1,zi)\begin{aligned} m_i &= \max(m_{i-1}, z_i) \end{aligned} di=di1emi1mi+ezimi\begin{aligned} d_i &= d_{i-1} \cdot e^{m_{i-1} - m_i} + e^{z_i - m_i} \end{aligned}

关键技巧在第二行:当最大值从 mi1m_{i-1} 更新到 mim_i 时,之前累积的指数和 di1d_{i-1} 需要乘以一个修正因子 exp(mi1mi)\exp(m_{i-1} - m_i) 来”换基”。这样只需一遍扫描就能同时得到最大值和指数和,然后再做一遍扫描完成归一化——总共两遍,比标准实现少一遍。

这个思想对 FlashAttention 至关重要:FlashAttention 将 Attention 矩阵分成小块(tile)逐块处理,每处理一块就需要更新 Softmax 的中间结果。如果不能在线更新 Softmax,就必须把整个 N×NN \times N 矩阵写入 HBM 后再做全局 Softmax,那就失去了分块计算的意义。Online Softmax 让分块计算和精确 Softmax 得以兼容。


8. Multi-Head Attention

8.1 为什么需要多头

单头 Attention 只有一组 QKV 投影,意味着模型只能学习一种”关注模式”。但语言中 token 之间的关系是多维度的——同一个词和其他词之间可能同时存在句法关系(主谓一致)、语义关系(同义替换)、位置关系(相邻词的局部模式)等。

打个比方:在一次项目评审会议上,只派一个评审员去审阅整个项目,他只能从自己擅长的角度提出意见。如果派出一个评审团——一位看技术架构,一位看代码质量,一位看测试覆盖率,一位看文档完整性——每个人独立给出评分和建议,最后汇总成一份综合评审报告,覆盖面就远比单人评审要全面得多。

Multi-Head Attention 就是这个”评审团”机制。每个头有自己独立的 WQW_QWKW_KWVW_V 投影参数,在不同的子空间中捕捉不同类型的关系。

8.2 数学原理

假设 dmodel=512d_{model} = 512num_heads=8num\_heads = 8,则 head_dim=64head\_dim = 64

Multi-Head Attention 的完整公式:

MultiHead(Q,K,V)=Concat(head1,,headh)WO\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h) \cdot W_O headi=Attention(QWQi,KWKi,VWVi)\text{head}_i = \text{Attention}(Q \cdot W_Q^i, K \cdot W_K^i, V \cdot W_V^i)

其中每个头的投影矩阵 WQiW_Q^iWKiW_K^iWViW_V^i 的形状是 (dmodel,head_dim)=(512,64)(d_{model}, head\_dim) = (512, 64)。但实际实现中,并不会真的维护 hh 组小矩阵——而是用一个大矩阵 WQW_Q(形状 (512,512)(512, 512))做一次投影,然后 reshape 成 (seq_len,8,64)(seq\_len, 8, 64) 来切分。这样做是等价的:大矩阵可以看作 8 个小矩阵纵向拼接。

8.3 参数量分析

dmodel=512d_{model} = 512 为例:

参数矩阵形状参数量
WQW_Q(512,512)(512, 512)262,144
WKW_K(512,512)(512, 512)262,144
WVW_V(512,512)(512, 512)262,144
WOW_O(512,512)(512, 512)262,144
合计4×5122=1,048,5764 \times 512^2 = 1{,}048{,}576

通用公式:MHA 的参数量为 4 \times d_{model}^2\(不算 bias 的情况)。不管有多少个头,总参数量不变——头数只影响切分方式,不影响总参数。这是因为每增加一个头,每个头的维度相应减小,两者乘积(即总投影维度)始终等于 dmodeld_{model}

8.4 多头的工程意义

多头结构天然适合并行化。8 个头的计算完全独立,可以:

  • GPU 内并行:利用 batch 维度,在一次 CUDA kernel 启动中同时处理所有头
  • 多卡张量并行(Tensor Parallelism):将不同头分配到不同 GPU 上,每张 GPU 只计算自己负责的若干头。比如 8 个头分到 4 张 GPU,每张处理 2 个头。最后通过一次 AllReduce 通信汇总输出投影的结果

9. MHA vs MQA vs GQA vs MLA

随着大模型进入推理效率至上的时代,Attention 头的结构也在不断演化。核心驱动力是一个问题:KV Cache 太大了

9.1 MHA (Multi-Head Attention)

标准的 MHA 中,每个注意力头都有独立的 Q、K、V。如果有 hh 个头,那就有 hh 组 K 和 hh 组 V 需要缓存。

  • hh 个 Q 头,hh 个 K 头,hh 个 V 头
  • 每组 Q/K/V 独立,互不共享

9.2 MQA (Multi-Query Attention)

Shazeer 在 2019 年提出:所有 Q 头共享同一组 K 和 V。也就是说,只有 1 个 K 头和 1 个 V 头,但仍然有 hh 个 Q 头。

  • hh 个 Q 头,1 个 K 头,1 个 V 头
  • 所有 Q 头共享同一份 K 和 V

好处是 KV Cache 缩小为原来的 1/h1/h,推理速度大幅提升。代价是模型表达能力下降——所有头被迫从同一组 KV 中提取信息,多样性受限。

9.3 GQA (Grouped-Query Attention)

GQA 是 MHA 和 MQA 的折中方案(Ainslie et al., 2023)。将 hh 个 Q 头分成 gg 个组(每组 h/gh/g 个 Q 头),每组共享一组 K 和 V。

  • hh 个 Q 头,gg 个 K 头,gg 个 V 头
  • h/gh/g 个 Q 头共享一组 K 和 V
  • g=hg = h 时退化为 MHA,当 g=1g = 1 时退化为 MQA

GQA 在模型质量和推理效率之间取得了更好的平衡。LLaMA-2-70B、Mistral-7B 等主流模型都采用了 GQA。

9.4 对比表格

dmodel=4096d_{model} = 4096,num_heads = 32,head_dim = 128 为例,假设序列长度 N=4096N = 4096,FP16 精度:

指标MHA (g=32g=32)GQA (g=8g=8)MQA (g=1g=1)
Q 头数323232
KV 头数3281
WKW_K 参数量32×128×4096=16M32 \times 128 \times 4096 = 16\text{M}8×128×4096=4M8 \times 128 \times 4096 = 4\text{M}1×128×4096=0.5M1 \times 128 \times 4096 = 0.5\text{M}
WVW_V 参数量16M16\text{M}4M4\text{M}0.5M0.5\text{M}
单 token KV Cache2×32×128×2B=16KB2 \times 32 \times 128 \times 2\text{B} = 16\text{KB}2×8×128×2B=4KB2 \times 8 \times 128 \times 2\text{B} = 4\text{KB}2×1×128×2B=0.5KB2 \times 1 \times 128 \times 2\text{B} = 0.5\text{KB}
4096 token KV Cache / 层64MB64\text{MB}16MB16\text{MB}2MB2\text{MB}
模型质量最好接近 MHA有下降

注:Cache size = 2 × B × L × num_heads × head_dim × d_type_bytes (其中 2 是 K 和 V,B 是 batch_size,L 是已生成序列长度) 可以看到,从 MHA 到 GQA(g=8g=8),KV Cache 缩小为 1/4,参数量减少,但模型质量几乎不受影响。这就是 GQA 成为主流选择的原因。

9.5 MLA (Multi-Latent Attention)

DeepSeek-V2 提出了 MLA(Multi-Latent Attention),代表了另一种压缩 KV Cache 的思路。

MLA 的核心思想是:不直接缓存完整的 K 和 V 向量,而是将它们压缩到一个低维潜在空间(latent space),只缓存这个低维表示。推理时再将低维表示解压回完整的 K 和 V。

  • 标准 MHAXWKKX \to W_K \to K(缓存 KK,维度 = num_kv_heads ×\times head_dim)
  • MLAXWDKVcX \to W_{DKV} \to c(缓存 cc,维度远小于 KK 的维度),推理时 cWUKKc \to W_{UK} \to K 按需解压

MLA 通过低秩压缩将 KV Cache 大幅缩小(DeepSeek-V2 报告了约 93.3% 的压缩率),同时通过精心设计的上投影矩阵保持了模型质量。它与 GQA 的思路不同——GQA 是减少 KV 头的数量,MLA 是降低每个表示的维度——但目标一致:让推理时的 KV Cache 尽可能小。


10. Causal Mask 的作用与实现

10.1 为什么需要遮蔽

在自回归语言模型(如 GPT 系列、LLaMA 系列)中,模型的训练目标是”根据前文预测下一个 token”。这要求在计算 Attention 时,每个位置的 token 只能看到自己和它之前的 token,不能”偷看”未来的信息。否则,模型在训练时就能看到答案,学不到任何有意义的预测能力。

10.2 实现方式

Causal Mask(因果遮蔽)是一个上三角矩阵,覆盖在注意力分数矩阵上,将未来位置的分数设为负无穷:

# 构造因果遮蔽矩阵
def create_causal_mask(seq_len: int) -> torch.Tensor:
    """
    生成一个下三角矩阵(含对角线),上三角部分为 False。
    为 True 的位置保留,为 False 的位置在 Attention 分数中被设为 -inf。
    """
    mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    return mask

# 使用示例:seq_len = 4
# mask:
# [[True,  False, False, False],
#  [True,  True,  False, False],
#  [True,  True,  True,  False],
#  [True,  True,  True,  True ]]
#
# 被遮蔽后的 Attention 分数矩阵(上三角为 -inf):
# [[s00,  -inf, -inf, -inf],
#  [s10,  s11,  -inf, -inf],
#  [s20,  s21,  s22,  -inf],
#  [s30,  s31,  s32,  s33 ]]
#
# Softmax 之后,-inf 位置的权重变为 0,未来信息被完全屏蔽

Causal Mask 的另一个工程意义在于:它使得 Attention 矩阵只有下三角部分有效,理论上可以跳过上三角部分的计算。FlashAttention-2 正是利用了这一点——在处理上三角区域的 tile 时直接跳过,减少了接近一半的计算量。


11. Self-Attention 的计算瓶颈分析

理解 Self-Attention 的性能瓶颈,需要区分两种不同的场景。

11.1 计算量分析

Self-Attention 中有四次主要的矩阵乘法:

运算形状FLOPs
Q=XWQQ = X \cdot W_Q(N,d)×(d,d)(N, d) \times (d, d)2Nd22Nd^2
K=XWKK = X \cdot W_K(N,d)×(d,d)(N, d) \times (d, d)2Nd22Nd^2
V=XWVV = X \cdot W_V(N,d)×(d,d)(N, d) \times (d, d)2Nd22Nd^2
S=QKTS = Q \cdot K^T(N,d)×(d,N)(N, d) \times (d, N)2N2d2N^2d
O=SVO = S \cdot V(N,N)×(N,d)(N, N) \times (N, d)2N2d2N^2d
Final=OWOFinal = O \cdot W_O(N,d)×(d,d)(N, d) \times (d, d)2Nd22Nd^2

总计:8Nd2+4N2d8Nd^2 + 4N^2d

NN 较小(比如 N<dN < d)时,8Nd28Nd^2 项主导——瓶颈在 QKV 投影的 GEMM 运算。当 NN 较大(比如 N>dN > d)时,4N2d4N^2d 项主导——瓶颈在 Attention 矩阵的计算。

11.2 Compute Bound vs Memory Bound

GPU 的性能受两个指标限制:

  • 算力(Compute):每秒能做多少次浮点运算(FLOPS)
  • 带宽(Memory Bandwidth):每秒能从 HBM 读写多少数据(GB/s)

一个运算是 Compute Bound 还是 Memory Bound,取决于它的算术强度(Arithmetic Intensity)= FLOPs / Bytes。

  • 如果算术强度 > GPU 的算力/带宽比值,运算是 Memory Bound——GPU 算力用不满,大部分时间在等数据搬运
  • 如果算术强度 < GPU 的算力/带宽比值,运算是 Compute Bound——GPU 带宽够用,大部分时间在做计算

以 A100 为例:FP16 算力约 312 TFLOPS,HBM 带宽约 2 TB/s,算力/带宽比值约 156 FLOP/Byte。

Prefill 阶段(处理完整 prompt):

QKV 投影是大矩阵乘法,batch 维度(seq_len)很大,算术强度高,通常是 Compute Bound。Attention 的 QKTQ \cdot K^T 也是大矩阵乘,同样是 Compute Bound。

Decode 阶段(逐 token 生成):

每步只有 1 个 token,QKV 投影退化为矩阵-向量乘法(GEMV),算术强度极低,是 Memory Bound。Attention 退化为一个向量和整个 KV Cache 的运算,同样 Memory Bound。大部分时间花在从 HBM 搬运模型权重和 KV Cache 上。

11.3 O(N2)O(N^2) 的显存瓶颈

标准 Attention 实现需要显式地计算并存储完整的 N×NN \times N 注意力矩阵。以 seq_len = 128K\text{K}、num_heads = 32 为例:

32×128K×128K×2B=32×16,384M×2=1,048,576 MB=1 TB32 \times 128\text{K} \times 128\text{K} \times 2\text{B} = 32 \times 16{,}384\text{M} \times 2 = 1{,}048{,}576\text{ MB} = 1\text{ TB}

这显然远超任何单张 GPU 的显存容量。即使序列长度”只有”8K,注意力矩阵也需要 32×8K×8K×2B=4 GB32 \times 8\text{K} \times 8\text{K} \times 2\text{B} = 4\text{ GB},已经是一个不可忽视的开销。

这个 O(N2)O(N^2) 的显存瓶颈,正是 FlashAttention 要解决的核心问题。


12. FlashAttention 的核心思想

FlashAttention(Dao et al., 2022)是过去几年 Attention 优化领域最具影响力的工作。它不改变 Attention 的计算结果(是精确计算,不是近似),但通过重新编排计算顺序,大幅减少了对 HBM 的访问量。

12.1 GPU 存储层次回顾

要理解 FlashAttention,需要先了解 GPU 的存储层次:

  • HBM(High Bandwidth Memory):GPU 的主显存,容量大(如 A100 的 80 GB)但访问速度相对慢(2 TB/s)
  • SRAM(片上缓存):每个 SM(Streaming Multiprocessor)上的共享内存和寄存器,容量很小(如 A100 每个 SM 约 192 KB 共享内存,全部 SM 合计约 20 MB)但访问速度极快(约 19 TB/s)

标准 Attention 的流程是:在 HBM 中完成 QKTQ \cdot K^T,把完整的 N×NN \times N 注意力矩阵写回 HBM;然后从 HBM 读取注意力矩阵做 Softmax,结果写回 HBM;最后从 HBM 读取 Softmax 结果和 VV 做矩阵乘法。每一步都涉及对 N×NN \times N 大矩阵的 HBM 读写——这是巨大的带宽浪费。

12.2 Tiling:分块计算

FlashAttention 的第一个关键技术是 Tiling(分块):不一次性计算完整的 N×NN \times N 注意力矩阵,而是将 Q、K、V 分成若干小块(tile),每次只加载一小块 Q 和一小块 K、V 到 SRAM 中,在 SRAM 内完成该块的 Attention 计算,然后将结果写回 HBM。

标准实现:
  Q, K 全部加载到 HBM → 计算完整 N x N 矩阵 → 写回 HBM → 读取做 Softmax → 写回 HBM → ...

FlashAttention:
  将 Q 分成 T_r 块,K/V 分成 T_c 块
  For 每一块 Q_i:
    For 每一块 K_j, V_j:
      从 HBM 加载 Q_i, K_j, V_j 到 SRAM(每块很小,能装下)
      在 SRAM 内计算 Q_i @ K_j^T → 局部 Softmax → 乘以 V_j
      累积到该块 Q_i 对应的输出中
    将 Q_i 块的最终输出写回 HBM

每次只处理一个 tile,注意力矩阵的这一小块始终驻留在 SRAM 中,永远不需要把完整的 N×NN \times N 矩阵写入 HBM。

12.3 Online Softmax 的关键作用

Tiling 带来一个棘手的问题:Softmax 需要对每一行的所有元素做归一化(需要全局最大值和全局求和),但分块计算时每次只看到一行的一部分——怎么在只看到局部数据的情况下计算全局 Softmax?

这就是 Online Softmax 派上用场的地方。回忆第 7.3 节的 Online Softmax 递推公式:当新数据块到来时,可以动态更新全局最大值和归一化分母,并修正之前块的计算结果。

具体来说,处理 QiQ_iK1K_1 块的 Attention 后,得到一个局部输出 O1O_1 和对应的 Softmax 统计量(局部最大值 m1m_1 和局部分母 l1l_1)。当继续处理 QiQ_iK2K_2 块时,会得到新的局部统计量 m2m_2l2l_2。通过 Online Softmax 的修正:

mnew=max(m1,m2)\begin{aligned} m_{\text{new}} &= \max(m_1, m_2) \end{aligned} lnew=l1em1mnew+l2em2mnew\begin{aligned} l_{\text{new}} &= l_1 \cdot e^{m_1 - m_{\text{new}}} + l_2 \cdot e^{m_2 - m_{\text{new}}} \end{aligned} Onew=O1l1em1mnew+O2locall2em2mnewlnew\begin{aligned} O_{\text{new}} &= \frac{O_1 \cdot l_1 \cdot e^{m_1 - m_{\text{new}}} + O_2^{\text{local}} \cdot l_2 \cdot e^{m_2 - m_{\text{new}}}}{l_{\text{new}}} \end{aligned}

不断累积,直到所有 KK 块处理完毕,最终结果与标准 Attention 的全局 Softmax 在数学上完全一致。

12.4 为什么能减少 HBM 访问

标准 Attention 的 HBM 访问量:

  • 写入 SS (N×NN \times N):O(N2)O(N^2) 次写
  • 读取 SS 做 Softmax:O(N2)O(N^2) 次读
  • 写入 P=Softmax(S)P = \text{Softmax}(S)O(N2)O(N^2) 次写
  • 读取 PPPVP \cdot VO(N2)O(N^2) 次读
  • 总 HBM 访问量:O(N^2)\(加上 Q, K, V 本身的 O(Nd)O(Nd) 读取)

FlashAttention 的 HBM 访问量:

  • 读取 Q, K, V:O(Nd)O(Nd)
  • 写入最终输出 O:O(Nd)O(Nd)
  • 中间的注意力矩阵始终在 SRAM 中,不写入 HBM
  • 总 HBM 访问量:O(Nd)O(Nd)

O(N2)O(N^2) 降到 O(Nd)O(Nd)——当 NN 远大于 dd 时(长序列场景),这是一个数量级的改进。注意,计算量(FLOPs)没有变——仍然是 O(N2d)O(N^2d)——改变的只是 IO 模式。FlashAttention 的加速本质上来自于减少了对慢速 HBM 的访问,让计算尽可能在快速 SRAM 中完成

12.5 FlashAttention-2 和 FlashAttention-3

FlashAttention-2(Dao, 2023)在 FlashAttention 的基础上进一步优化了并行策略和 warp 调度:

  • 调整了内外循环的顺序(外循环遍历 Q 块,内循环遍历 K/V 块),减少了共享内存的读写
  • 更好的 warp 级工作分配,提升了 GPU 的占用率
  • 利用 Causal Mask 跳过全零的 tile,进一步减少计算量

FlashAttention-3(Dao et al., 2024)则针对 Hopper 架构(H100)进行了优化,利用了 TMA(Tensor Memory Accelerator)和 wgmma(Warp Group MMA)指令,进一步提升了硬件利用率。

目前 FlashAttention 已经是所有主流推理和训练框架的标配——PyTorch 2.0+ 内置了 torch.nn.functional.scaled_dot_product_attention,底层默认调用 FlashAttention kernel。


🎯 自我检验清单

完成本文学习后,用以下问题检验自己的理解深度:

  • 能从 Bahdanau Attention 到 Self-Attention 梳理出 Attention 机制的演进脉络,说清每一步解决了什么问题
  • 能默写 Scaled Dot-Product Attention 的完整公式,并解释 Q、K、V 三元组各自的作用
  • 能从零推导 “除以 dk\sqrt{d_k}” 的数学原因:假设 qi,kiq_i, k_i 独立标准正态,推出点积方差为 dkd_k
  • 能手写 Single-Head 和 Multi-Head Self-Attention 的 PyTorch 实现,能说清每一步的张量维度变化
  • 能解释 Softmax 为什么要减最大值,以及 Online Softmax 的核心递推思想
  • 给定 seq_len、dmodeld_{model}、num_heads 的具体数值,能完整跟踪每一步的张量形状
  • 能画出 MHA / MQA / GQA 的结构差异,并估算各自的 KV Cache 大小
  • 能说清 Causal Mask 的作用和实现方式(上三角设为负无穷,Softmax 后归零)
  • 能区分 Compute Bound 和 Memory Bound,并说清 Prefill 和 Decode 阶段各自的瓶颈类型
  • 能说清 FlashAttention 的核心思想:Tiling + Online Softmax 如何将 HBM 访问从 O(N2)O(N^2) 降到 O(Nd)O(Nd)

📚 参考资料

论文

教程与博客