3.3 Self-Attention机制深入理解
Self-Attention 是 Transformer 的心脏,也是当代大模型中计算量最集中、优化手段最丰富的模块
Self-Attention 是 Transformer 的心脏,也是当代大模型中计算量最集中、优化手段最丰富的模块。无论你是想理解 FlashAttention 背后的 IO 优化思想,还是想搞清楚 GQA、MLA 这些 Attention 变种为什么能减少推理开销,都绕不开对 Self-Attention 机制的深入理解。本文将从 Attention 的历史起源讲起,逐步拆解 Scaled Dot-Product Attention 的每一步数学原理,手写 PyTorch 实现,分析计算瓶颈,最后延伸到 FlashAttention 和各种 Attention 变种,力求让读者建立从直觉到公式再到工程实现的完整认知链条。
📑 目录
- 1. Attention 的历史演进
- 2. 从信息检索角度理解 QKV
- 3. Scaled Dot-Product Attention 详解
- 4. PyTorch 手动实现
- 5. 详细维度推导
- 6. 为什么要除以 :方差稳定性的数学推导
- 7. Softmax 的数值稳定性
- 8. Multi-Head Attention
- 9. MHA vs MQA vs GQA vs MLA
- 10. Causal Mask 的作用与实现
- 11. Self-Attention 的计算瓶颈分析
- 12. FlashAttention 的核心思想
- 自我检验清单
- 参考资料
1. Attention 的历史演进
在 2017 年 Transformer 横空出世之前,Attention 机制已经在序列建模领域酝酿了好几年。理解它的演进脉络,有助于我们把握 Self-Attention 到底解决了什么问题,以及为什么它的设计长成今天这个样子。
1.1 Bahdanau Attention (2014):开山之作
传统的 Encoder-Decoder 架构(比如用 RNN 做机器翻译)有一个致命缺陷:Encoder 把整个输入序列压缩成一个固定长度的向量,再交给 Decoder 去生成输出。当输入序列很长时,这个固定向量根本装不下所有信息,翻译质量会急剧下降。
Bahdanau 等人在 2014 年提出了一个关键改进:别硬压缩了,让 Decoder 在生成每个词的时候,自己回头”看”Encoder 的所有隐状态,按需取用。具体做法是,Decoder 当前时刻的隐状态作为”查询”,跟 Encoder 每个时刻的隐状态做”匹配”(通过一个小型前馈网络计算匹配分数),然后用匹配分数对 Encoder 的隐状态加权求和,得到一个”上下文向量”,作为当前解码步的额外输入。
这就是 Attention 的雏形:让模型学会”注意”输入序列的不同部分。不过 Bahdanau Attention 的匹配函数是一个额外的前馈网络(additive attention),计算效率不高。
1.2 Luong Attention (2015):简化计算
Luong 在 2015 年对 Bahdanau Attention 做了两个关键简化:一是提出了更高效的匹配函数(直接用点积代替前馈网络),二是探索了多种对齐方式(全局 vs 局部)。其中,点积 Attention 的计算方式——两个向量直接做内积来衡量相似度——后来成为了 Transformer 的基础。
点积的计算效率远高于前馈网络:它可以被表示为矩阵乘法,天然适合 GPU 并行加速。这个看似简单的改进,为后来 Self-Attention 的大规模应用铺平了道路。
1.3 Self-Attention (2017):Attention Is All You Need
2017 年的 Transformer 论文做了一个大胆的决定:完全抛弃 RNN,只用 Attention 来建模序列。
之前的 Attention 都是”跨序列”的——Decoder 去关注 Encoder。而 Self-Attention 是”自关注”——序列中的每个位置去关注同一个序列中的所有位置(包括自己)。这样做有两大优势:
- 并行性:RNN 必须逐步计算(第 步依赖第 步的结果),而 Self-Attention 中所有位置可以同时计算,天然适合 GPU 并行。
- 长距离依赖:RNN 中距离较远的两个词需要通过多步传递才能交互信息,信号会逐步衰减。Self-Attention 中任意两个位置之间只需要一步就能直接交互,理论上不存在距离衰减问题。
代价是什么?Self-Attention 的计算复杂度是 ——序列中每对位置都需要计算匹配分数,而 RNN 的复杂度是 。这个平方复杂度成为后续无数优化工作的出发点。
2. 从信息检索角度理解 QKV
Attention 中最核心的三个概念是 Query(Q)、Key(K)和 Value(V)。这三个词直接借鉴了信息检索领域的术语,用一个日常场景来类比可以帮助建立直觉。
2.1 图书馆检索的类比
想象你走进一座图书馆,你脑子里有一个模糊的需求:“我想了解关于并行计算的内容”。这个需求就是你的 Query。
图书馆里每一本书的书脊上都贴着标签——“操作系统”、“计算机网络”、“并行计算导论”、“数据结构”等等。这些标签就是每本书的 Key。
你拿着自己的 Query,去和每本书的 Key 做比对。“并行计算导论”这个 Key 和你的 Query 高度匹配,匹配分数最高;“操作系统”可能有一些相关性,分数中等;“数据结构”和你的需求关系不大,分数很低。
确定了匹配分数之后,你不是把书脊标签(Key)拿走,而是把每本书的实际内容(Value)按照匹配分数加权汇总。“并行计算导论”的内容占大比重,“操作系统”的内容占一点,“数据结构”几乎不占——最终你得到的就是一份以并行计算为主、兼顾一点操作系统知识的综合信息。
2.2 三元组的形式化定义
回到 Self-Attention 的语境。给定输入序列 (形状为 ),三个线性变换将 映射到不同的空间:
- :每个 token 生成自己的”查询向量”——“我需要什么信息”
- :每个 token 生成自己的”索引向量”——“我能提供什么信息的线索”
- :每个 token 生成自己的”内容向量”——“我实际携带的信息”
Q、K、V 之所以要从同一个 做三次不同的线性变换,而不是直接用 本身,原因在于解耦不同的角色。一个 token “需要什么”和它”能提供什么”往往是不同的。比如在一个句子里,动词可能需要关注它的主语和宾语(Query 的方向),但它作为被别人关注的对象时,提供的是动作语义(Key/Value 的方向)。三个独立的投影矩阵让模型有自由度去学习这些不同的映射。
3. Scaled Dot-Product Attention 详解
有了 Q、K、V 之后,Attention 的计算过程可以分解为清晰的五步。
3.1 第一步:线性投影
输入 的形状为 。三个权重矩阵 、、 的形状都是 (以单头为例):
这三次矩阵乘法是三次独立的 GEMM(General Matrix Multiply)操作。在实际实现中,为了提高 GPU 利用率,通常会把 、、 合并成一个大矩阵 (形状 ),做一次 GEMM 然后 split,这样能更好地利用 GPU 的算力。
3.2 第二步:计算原始注意力分数
用 和 的转置做矩阵乘法,衡量每对 token 之间的匹配程度:
结果矩阵 的每个元素 是第 个 token 的查询向量和第 个 token 的索引向量的点积,代表 token 对 token 的原始关注程度。
这一步产生了一个 的矩阵,这就是 Self-Attention 平方复杂度的根源。
3.3 第三步:缩放
将原始分数除以 ( 是 Key 向量的维度):
为什么需要缩放?简单来说,当维度 较大时,点积的结果会变得很大,导致后续 Softmax 的梯度趋近于零。详细的数学推导见第 6 节。
3.4 第四步:Softmax 归一化
对缩放后的分数矩阵按行做 Softmax,将原始分数转换为概率分布:
Softmax 的作用是双重的:一方面把任意实数映射到 区间,使其可以作为权重;另一方面保证每行的权重之和为 1,形成一个合法的概率分布。
3.5 第五步:加权求和
用注意力权重对 Value 矩阵做加权求和:
最终每个 token 位置得到一个 维的向量,里面融合了它所”关注”的所有 token 的信息,关注程度由 的权重决定。
3.6 完整公式
将以上步骤合并,Self-Attention 的计算可以用一行公式概括:
最后还要再过一个输出投影矩阵 (形状 ),将结果映射回模型的隐藏空间。这个投影在 Multi-Head Attention 中尤其重要——它负责将多个头拼接后的表示重新混合。
4. PyTorch 手动实现
理论讲得再多,不如亲手写一遍代码。下面分别实现 Single-Head 和 Multi-Head Self-Attention。
4.1 Single-Head Self-Attention
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class SingleHeadSelfAttention(nn.Module):
"""手动实现单头 Self-Attention,不依赖 PyTorch 内置的 MHA 模块"""
def __init__(self, d_model: int):
super().__init__()
self.d_model = d_model
# 三个线性投影:输入维度 d_model,输出维度 d_model
self.W_Q = nn.Linear(d_model, d_model, bias=False)
self.W_K = nn.Linear(d_model, d_model, bias=False)
self.W_V = nn.Linear(d_model, d_model, bias=False)
# 输出投影
self.W_O = nn.Linear(d_model, d_model, bias=False)
def forward(self, x: torch.Tensor, mask: torch.Tensor = None):
# x: (batch, seq_len, d_model)
# Step 1: 线性投影得到 Q, K, V
Q = self.W_Q(x) # (batch, seq_len, d_model)
K = self.W_K(x) # (batch, seq_len, d_model)
V = self.W_V(x) # (batch, seq_len, d_model)
# Step 2: 计算注意力分数 = Q @ K^T
scores = torch.matmul(Q, K.transpose(-2, -1)) # (batch, seq_len, seq_len)
# Step 3: 缩放,防止点积值过大导致 softmax 梯度消失
scores = scores / math.sqrt(self.d_model)
# Step 4: 如果提供了 mask,将被遮蔽位置的分数设为负无穷
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
# Step 5: Softmax 归一化,得到注意力权重
attn_weights = F.softmax(scores, dim=-1) # (batch, seq_len, seq_len)
# Step 6: 用权重对 V 加权求和
context = torch.matmul(attn_weights, V) # (batch, seq_len, d_model)
# Step 7: 输出投影
output = self.W_O(context) # (batch, seq_len, d_model)
return output
4.2 Multi-Head Self-Attention
class MultiHeadSelfAttention(nn.Module):
"""手动实现多头 Self-Attention:将 d_model 切分为 num_heads 个子空间"""
def __init__(self, d_model: int, num_heads: int):
super().__init__()
assert d_model % num_heads == 0, "d_model 必须能被 num_heads 整除"
self.d_model = d_model
self.num_heads = num_heads
self.head_dim = d_model // num_heads # 每个头的维度
# QKV 投影:输入 d_model,输出 d_model(内部包含所有头的投影)
self.W_Q = nn.Linear(d_model, d_model, bias=False)
self.W_K = nn.Linear(d_model, d_model, bias=False)
self.W_V = nn.Linear(d_model, d_model, bias=False)
self.W_O = nn.Linear(d_model, d_model, bias=False)
def forward(self, x: torch.Tensor, mask: torch.Tensor = None):
batch_size, seq_len, _ = x.shape
# Step 1: 线性投影
Q = self.W_Q(x) # (batch, seq_len, d_model)
K = self.W_K(x)
V = self.W_V(x)
# Step 2: 重塑为多头形状,并转置使 head 维度在 seq_len 前面
# (batch, seq_len, d_model) -> (batch, seq_len, num_heads, head_dim)
# -> (batch, num_heads, seq_len, head_dim)
Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
K = K.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
V = V.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
# Step 3: 每个头独立计算注意力分数
# Q @ K^T: (batch, num_heads, seq_len, head_dim)
# @ (batch, num_heads, head_dim, seq_len)
# = (batch, num_heads, seq_len, seq_len)
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)
# Step 4: 可选的因果遮蔽
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
# Step 5: Softmax + 加权求和
attn_weights = F.softmax(scores, dim=-1)
# (batch, num_heads, seq_len, seq_len) @ (batch, num_heads, seq_len, head_dim)
# = (batch, num_heads, seq_len, head_dim)
context = torch.matmul(attn_weights, V)
# Step 6: 多头拼接——先转回 (batch, seq_len, num_heads, head_dim),再合并最后两维
context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
# Step 7: 输出投影,将拼接后的多头表示混合
output = self.W_O(context) # (batch, seq_len, d_model)
return output
上面的代码完整展示了从输入到输出的每一步。值得注意的是,.transpose(1, 2) 这一步将 (batch, seq_len, num_heads, head_dim) 变成 (batch, num_heads, seq_len, head_dim),这样后续的矩阵乘法就能在每个头内部独立进行,利用 batch 维度的并行性一次算完所有头。
5. 详细维度推导
抽象的维度符号容易让人迷糊。下面用一组具体数值,完整跟踪 Multi-Head Self-Attention 中每一步的张量形状。
5.1 参数设定
seq_len = 6 # 序列长度(6 个 token)
d_model = 512 # 模型隐藏维度
num_heads = 8 # 注意力头数
head_dim = 512 / 8 = 64 # 每个头的维度
batch = 1 # 简化为 batch=1
5.2 逐步跟踪
输入
X: (1, 6, 512) # 1 个样本,6 个 token,每个 token 512 维
线性投影
W_Q: (512, 512) # nn.Linear 的权重矩阵
Q = X @ W_Q^T: (1, 6, 512) @ (512, 512) = (1, 6, 512)
K = X @ W_K^T: (1, 6, 512) @ (512, 512) = (1, 6, 512)
V = X @ W_V^T: (1, 6, 512) @ (512, 512) = (1, 6, 512)
切分多头 + 转置
Q.view(1, 6, 8, 64): (1, 6, 8, 64)
Q.transpose(1, 2): (1, 8, 6, 64) # 8 个头,每个头看到 6 个 token 的 64 维向量
K reshape 同理: (1, 8, 6, 64)
V reshape 同理: (1, 8, 6, 64)
计算注意力分数
K.transpose(-2, -1): (1, 8, 64, 6) # 转置最后两维
Q @ K^T: (1, 8, 6, 64) @ (1, 8, 64, 6) = (1, 8, 6, 6)
# 每个头产生一个 6x6 的注意力分数矩阵
缩放
scores / sqrt(64): (1, 8, 6, 6) # sqrt(64) = 8,每个分数除以 8
Softmax
softmax(scores, dim=-1): (1, 8, 6, 6) # 每行(最后一维)归一化为概率分布
# 每个 6 维行的元素之和为 1
加权求和
attn_weights @ V: (1, 8, 6, 6) @ (1, 8, 6, 64) = (1, 8, 6, 64)
# 每个头输出 6 个 token 的 64 维表示
多头拼接
transpose(1, 2): (1, 6, 8, 64) # 把 head 维度挪回去
view(1, 6, 512): (1, 6, 512) # 8 * 64 = 512,拼接回 d_model
输出投影
W_O: (512, 512)
context @ W_O^T: (1, 6, 512) @ (512, 512) = (1, 6, 512)
最终输出
Output: (1, 6, 512) # 与输入 X 形状完全一致
从头到尾,输入和输出的形状都是 ,这保证了多个 Transformer Block 可以像积木一样层层堆叠。
6. 为什么要除以 :方差稳定性的数学推导
除以 这一步在直觉上容易被当作”调参技巧”一笔带过,但它背后有严格的数学原因。
6.1 问题提出
假设 和 的每个分量都是均值为 0、方差为 1 的独立随机变量。我们来计算点积 的方差。
设 ,,其中每个 和 独立同分布,满足 ,。
点积的定义是:
6.2 推导过程
首先,单个分量乘积的期望和方差:
由于 和 独立:
所以:
点积是 个独立随机变量之和,根据方差的可加性:
6.3 结论
点积 的方差是 。这意味着当 时,点积值的标准差大约是 8;当 时,标准差大约是 11.3。维度越大,点积值的绝对值越大,分布越分散。
Softmax 函数 对输入值的量级非常敏感。当输入值之间的差距很大时(比如一个值是 50,另一个是 -10),Softmax 的输出会极度集中在最大值上,接近 one-hot 分布。此时梯度几乎为零,参数无法更新,训练陷入停滞。
除以 之后:
方差被拉回到 1,无论 取什么值,Softmax 的输入始终在一个合理的范围内波动,梯度保持健康。这就是 Scaled Dot-Product Attention 名字中 “Scaled” 一词的由来。
7. Softmax 的数值稳定性
Softmax 看起来是一个简单的公式,但在实际的 GPU 实现中,它隐藏着一个容易导致数值溢出的陷阱。
7.1 溢出问题
标准 Softmax 公式:
问题在于 函数增长极快。当 值较大时(比如 ), 直接超出 float32 甚至 float16 的表示范围,得到 inf。即使分子分母同时溢出,inf / inf 会得到 NaN,整个计算就废了。
7.2 减最大值技巧
解决方案是利用 Softmax 的平移不变性:对所有输入减去最大值,不改变结果。
数学证明很简单——分子分母同乘以 ,等价于分子分母各自除以 ,值不变。减去最大值后,最大的指数输入是 0,所以 的结果最大为 1,不会溢出。
但这带来了一个效率问题:标准实现需要对数据做三遍扫描:
- 第一遍:遍历所有元素,找最大值
- 第二遍:遍历所有元素,计算 和它们的总和
- 第三遍:遍历所有元素,每个 除以
每遍扫描都要从 HBM(高带宽显存)读取数据,三遍就是三次 HBM 读取。当 很大时(比如 128K 的序列长度),这个 IO 开销非常大。
7.3 Online Softmax
Milakov 和 Gimelshein 在 2018 年提出了 Online Softmax 算法,将三遍扫描合并为一遍扫描,核心思想是在遍历数据的同时动态维护最大值和归一化分母。
算法流程如下。在遍历到第 个元素时,维护两个变量:
- :前 个元素的最大值
- :前 个元素的指数和(以 为基准)
递推公式:
关键技巧在第二行:当最大值从 更新到 时,之前累积的指数和 需要乘以一个修正因子 来”换基”。这样只需一遍扫描就能同时得到最大值和指数和,然后再做一遍扫描完成归一化——总共两遍,比标准实现少一遍。
这个思想对 FlashAttention 至关重要:FlashAttention 将 Attention 矩阵分成小块(tile)逐块处理,每处理一块就需要更新 Softmax 的中间结果。如果不能在线更新 Softmax,就必须把整个 矩阵写入 HBM 后再做全局 Softmax,那就失去了分块计算的意义。Online Softmax 让分块计算和精确 Softmax 得以兼容。
8. Multi-Head Attention
8.1 为什么需要多头
单头 Attention 只有一组 QKV 投影,意味着模型只能学习一种”关注模式”。但语言中 token 之间的关系是多维度的——同一个词和其他词之间可能同时存在句法关系(主谓一致)、语义关系(同义替换)、位置关系(相邻词的局部模式)等。
打个比方:在一次项目评审会议上,只派一个评审员去审阅整个项目,他只能从自己擅长的角度提出意见。如果派出一个评审团——一位看技术架构,一位看代码质量,一位看测试覆盖率,一位看文档完整性——每个人独立给出评分和建议,最后汇总成一份综合评审报告,覆盖面就远比单人评审要全面得多。
Multi-Head Attention 就是这个”评审团”机制。每个头有自己独立的 、、 投影参数,在不同的子空间中捕捉不同类型的关系。
8.2 数学原理
假设 ,,则 。
Multi-Head Attention 的完整公式:
其中每个头的投影矩阵 、、 的形状是 。但实际实现中,并不会真的维护 组小矩阵——而是用一个大矩阵 (形状 )做一次投影,然后 reshape 成 来切分。这样做是等价的:大矩阵可以看作 8 个小矩阵纵向拼接。
8.3 参数量分析
以 为例:
| 参数矩阵 | 形状 | 参数量 |
|---|---|---|
| 262,144 | ||
| 262,144 | ||
| 262,144 | ||
| 262,144 | ||
| 合计 |
通用公式:MHA 的参数量为 4 \times d_{model}^2\(不算 bias 的情况)。不管有多少个头,总参数量不变——头数只影响切分方式,不影响总参数。这是因为每增加一个头,每个头的维度相应减小,两者乘积(即总投影维度)始终等于 。
8.4 多头的工程意义
多头结构天然适合并行化。8 个头的计算完全独立,可以:
- GPU 内并行:利用 batch 维度,在一次 CUDA kernel 启动中同时处理所有头
- 多卡张量并行(Tensor Parallelism):将不同头分配到不同 GPU 上,每张 GPU 只计算自己负责的若干头。比如 8 个头分到 4 张 GPU,每张处理 2 个头。最后通过一次 AllReduce 通信汇总输出投影的结果
9. MHA vs MQA vs GQA vs MLA
随着大模型进入推理效率至上的时代,Attention 头的结构也在不断演化。核心驱动力是一个问题:KV Cache 太大了。
9.1 MHA (Multi-Head Attention)
标准的 MHA 中,每个注意力头都有独立的 Q、K、V。如果有 个头,那就有 组 K 和 组 V 需要缓存。
- 个 Q 头, 个 K 头, 个 V 头
- 每组 Q/K/V 独立,互不共享
9.2 MQA (Multi-Query Attention)
Shazeer 在 2019 年提出:所有 Q 头共享同一组 K 和 V。也就是说,只有 1 个 K 头和 1 个 V 头,但仍然有 个 Q 头。
- 个 Q 头,1 个 K 头,1 个 V 头
- 所有 Q 头共享同一份 K 和 V
好处是 KV Cache 缩小为原来的 ,推理速度大幅提升。代价是模型表达能力下降——所有头被迫从同一组 KV 中提取信息,多样性受限。
9.3 GQA (Grouped-Query Attention)
GQA 是 MHA 和 MQA 的折中方案(Ainslie et al., 2023)。将 个 Q 头分成 个组(每组 个 Q 头),每组共享一组 K 和 V。
- 个 Q 头, 个 K 头, 个 V 头
- 每 个 Q 头共享一组 K 和 V
- 当 时退化为 MHA,当 时退化为 MQA
GQA 在模型质量和推理效率之间取得了更好的平衡。LLaMA-2-70B、Mistral-7B 等主流模型都采用了 GQA。
9.4 对比表格
以 ,num_heads = 32,head_dim = 128 为例,假设序列长度 ,FP16 精度:
| 指标 | MHA () | GQA () | MQA () |
|---|---|---|---|
| Q 头数 | 32 | 32 | 32 |
| KV 头数 | 32 | 8 | 1 |
| 参数量 | |||
| 参数量 | |||
| 单 token KV Cache | |||
| 4096 token KV Cache / 层 | |||
| 模型质量 | 最好 | 接近 MHA | 有下降 |
注:Cache size = 2 × B × L × num_heads × head_dim × d_type_bytes (其中 2 是 K 和 V,B 是 batch_size,L 是已生成序列长度) 可以看到,从 MHA 到 GQA(),KV Cache 缩小为 1/4,参数量减少,但模型质量几乎不受影响。这就是 GQA 成为主流选择的原因。
9.5 MLA (Multi-Latent Attention)
DeepSeek-V2 提出了 MLA(Multi-Latent Attention),代表了另一种压缩 KV Cache 的思路。
MLA 的核心思想是:不直接缓存完整的 K 和 V 向量,而是将它们压缩到一个低维潜在空间(latent space),只缓存这个低维表示。推理时再将低维表示解压回完整的 K 和 V。
- 标准 MHA:(缓存 ,维度 = num_kv_heads head_dim)
- MLA:(缓存 ,维度远小于 的维度),推理时 按需解压
MLA 通过低秩压缩将 KV Cache 大幅缩小(DeepSeek-V2 报告了约 93.3% 的压缩率),同时通过精心设计的上投影矩阵保持了模型质量。它与 GQA 的思路不同——GQA 是减少 KV 头的数量,MLA 是降低每个表示的维度——但目标一致:让推理时的 KV Cache 尽可能小。
10. Causal Mask 的作用与实现
10.1 为什么需要遮蔽
在自回归语言模型(如 GPT 系列、LLaMA 系列)中,模型的训练目标是”根据前文预测下一个 token”。这要求在计算 Attention 时,每个位置的 token 只能看到自己和它之前的 token,不能”偷看”未来的信息。否则,模型在训练时就能看到答案,学不到任何有意义的预测能力。
10.2 实现方式
Causal Mask(因果遮蔽)是一个上三角矩阵,覆盖在注意力分数矩阵上,将未来位置的分数设为负无穷:
# 构造因果遮蔽矩阵
def create_causal_mask(seq_len: int) -> torch.Tensor:
"""
生成一个下三角矩阵(含对角线),上三角部分为 False。
为 True 的位置保留,为 False 的位置在 Attention 分数中被设为 -inf。
"""
mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
return mask
# 使用示例:seq_len = 4
# mask:
# [[True, False, False, False],
# [True, True, False, False],
# [True, True, True, False],
# [True, True, True, True ]]
#
# 被遮蔽后的 Attention 分数矩阵(上三角为 -inf):
# [[s00, -inf, -inf, -inf],
# [s10, s11, -inf, -inf],
# [s20, s21, s22, -inf],
# [s30, s31, s32, s33 ]]
#
# Softmax 之后,-inf 位置的权重变为 0,未来信息被完全屏蔽
Causal Mask 的另一个工程意义在于:它使得 Attention 矩阵只有下三角部分有效,理论上可以跳过上三角部分的计算。FlashAttention-2 正是利用了这一点——在处理上三角区域的 tile 时直接跳过,减少了接近一半的计算量。
11. Self-Attention 的计算瓶颈分析
理解 Self-Attention 的性能瓶颈,需要区分两种不同的场景。
11.1 计算量分析
Self-Attention 中有四次主要的矩阵乘法:
| 运算 | 形状 | FLOPs |
|---|---|---|
总计:。
当 较小(比如 )时, 项主导——瓶颈在 QKV 投影的 GEMM 运算。当 较大(比如 )时, 项主导——瓶颈在 Attention 矩阵的计算。
11.2 Compute Bound vs Memory Bound
GPU 的性能受两个指标限制:
- 算力(Compute):每秒能做多少次浮点运算(FLOPS)
- 带宽(Memory Bandwidth):每秒能从 HBM 读写多少数据(GB/s)
一个运算是 Compute Bound 还是 Memory Bound,取决于它的算术强度(Arithmetic Intensity)= FLOPs / Bytes。
- 如果算术强度 > GPU 的算力/带宽比值,运算是 Memory Bound——GPU 算力用不满,大部分时间在等数据搬运
- 如果算术强度 < GPU 的算力/带宽比值,运算是 Compute Bound——GPU 带宽够用,大部分时间在做计算
以 A100 为例:FP16 算力约 312 TFLOPS,HBM 带宽约 2 TB/s,算力/带宽比值约 156 FLOP/Byte。
Prefill 阶段(处理完整 prompt):
QKV 投影是大矩阵乘法,batch 维度(seq_len)很大,算术强度高,通常是 Compute Bound。Attention 的 也是大矩阵乘,同样是 Compute Bound。
Decode 阶段(逐 token 生成):
每步只有 1 个 token,QKV 投影退化为矩阵-向量乘法(GEMV),算术强度极低,是 Memory Bound。Attention 退化为一个向量和整个 KV Cache 的运算,同样 Memory Bound。大部分时间花在从 HBM 搬运模型权重和 KV Cache 上。
11.3 的显存瓶颈
标准 Attention 实现需要显式地计算并存储完整的 注意力矩阵。以 seq_len = 128、num_heads = 32 为例:
这显然远超任何单张 GPU 的显存容量。即使序列长度”只有”8K,注意力矩阵也需要 ,已经是一个不可忽视的开销。
这个 的显存瓶颈,正是 FlashAttention 要解决的核心问题。
12. FlashAttention 的核心思想
FlashAttention(Dao et al., 2022)是过去几年 Attention 优化领域最具影响力的工作。它不改变 Attention 的计算结果(是精确计算,不是近似),但通过重新编排计算顺序,大幅减少了对 HBM 的访问量。
12.1 GPU 存储层次回顾
要理解 FlashAttention,需要先了解 GPU 的存储层次:
- HBM(High Bandwidth Memory):GPU 的主显存,容量大(如 A100 的 80 GB)但访问速度相对慢(2 TB/s)
- SRAM(片上缓存):每个 SM(Streaming Multiprocessor)上的共享内存和寄存器,容量很小(如 A100 每个 SM 约 192 KB 共享内存,全部 SM 合计约 20 MB)但访问速度极快(约 19 TB/s)
标准 Attention 的流程是:在 HBM 中完成 ,把完整的 注意力矩阵写回 HBM;然后从 HBM 读取注意力矩阵做 Softmax,结果写回 HBM;最后从 HBM 读取 Softmax 结果和 做矩阵乘法。每一步都涉及对 大矩阵的 HBM 读写——这是巨大的带宽浪费。
12.2 Tiling:分块计算
FlashAttention 的第一个关键技术是 Tiling(分块):不一次性计算完整的 注意力矩阵,而是将 Q、K、V 分成若干小块(tile),每次只加载一小块 Q 和一小块 K、V 到 SRAM 中,在 SRAM 内完成该块的 Attention 计算,然后将结果写回 HBM。
标准实现:
Q, K 全部加载到 HBM → 计算完整 N x N 矩阵 → 写回 HBM → 读取做 Softmax → 写回 HBM → ...
FlashAttention:
将 Q 分成 T_r 块,K/V 分成 T_c 块
For 每一块 Q_i:
For 每一块 K_j, V_j:
从 HBM 加载 Q_i, K_j, V_j 到 SRAM(每块很小,能装下)
在 SRAM 内计算 Q_i @ K_j^T → 局部 Softmax → 乘以 V_j
累积到该块 Q_i 对应的输出中
将 Q_i 块的最终输出写回 HBM
每次只处理一个 tile,注意力矩阵的这一小块始终驻留在 SRAM 中,永远不需要把完整的 矩阵写入 HBM。
12.3 Online Softmax 的关键作用
Tiling 带来一个棘手的问题:Softmax 需要对每一行的所有元素做归一化(需要全局最大值和全局求和),但分块计算时每次只看到一行的一部分——怎么在只看到局部数据的情况下计算全局 Softmax?
这就是 Online Softmax 派上用场的地方。回忆第 7.3 节的 Online Softmax 递推公式:当新数据块到来时,可以动态更新全局最大值和归一化分母,并修正之前块的计算结果。
具体来说,处理 对 块的 Attention 后,得到一个局部输出 和对应的 Softmax 统计量(局部最大值 和局部分母 )。当继续处理 对 块时,会得到新的局部统计量 和 。通过 Online Softmax 的修正:
不断累积,直到所有 块处理完毕,最终结果与标准 Attention 的全局 Softmax 在数学上完全一致。
12.4 为什么能减少 HBM 访问
标准 Attention 的 HBM 访问量:
- 写入 (): 次写
- 读取 做 Softmax: 次读
- 写入 : 次写
- 读取 做 : 次读
- 总 HBM 访问量:O(N^2)\(加上 Q, K, V 本身的 读取)
FlashAttention 的 HBM 访问量:
- 读取 Q, K, V:
- 写入最终输出 O:
- 中间的注意力矩阵始终在 SRAM 中,不写入 HBM
- 总 HBM 访问量:
从 降到 ——当 远大于 时(长序列场景),这是一个数量级的改进。注意,计算量(FLOPs)没有变——仍然是 ——改变的只是 IO 模式。FlashAttention 的加速本质上来自于减少了对慢速 HBM 的访问,让计算尽可能在快速 SRAM 中完成。
12.5 FlashAttention-2 和 FlashAttention-3
FlashAttention-2(Dao, 2023)在 FlashAttention 的基础上进一步优化了并行策略和 warp 调度:
- 调整了内外循环的顺序(外循环遍历 Q 块,内循环遍历 K/V 块),减少了共享内存的读写
- 更好的 warp 级工作分配,提升了 GPU 的占用率
- 利用 Causal Mask 跳过全零的 tile,进一步减少计算量
FlashAttention-3(Dao et al., 2024)则针对 Hopper 架构(H100)进行了优化,利用了 TMA(Tensor Memory Accelerator)和 wgmma(Warp Group MMA)指令,进一步提升了硬件利用率。
目前 FlashAttention 已经是所有主流推理和训练框架的标配——PyTorch 2.0+ 内置了 torch.nn.functional.scaled_dot_product_attention,底层默认调用 FlashAttention kernel。
🎯 自我检验清单
完成本文学习后,用以下问题检验自己的理解深度:
- 能从 Bahdanau Attention 到 Self-Attention 梳理出 Attention 机制的演进脉络,说清每一步解决了什么问题
- 能默写 Scaled Dot-Product Attention 的完整公式,并解释 Q、K、V 三元组各自的作用
- 能从零推导 “除以 ” 的数学原因:假设 独立标准正态,推出点积方差为
- 能手写 Single-Head 和 Multi-Head Self-Attention 的 PyTorch 实现,能说清每一步的张量维度变化
- 能解释 Softmax 为什么要减最大值,以及 Online Softmax 的核心递推思想
- 给定 seq_len、、num_heads 的具体数值,能完整跟踪每一步的张量形状
- 能画出 MHA / MQA / GQA 的结构差异,并估算各自的 KV Cache 大小
- 能说清 Causal Mask 的作用和实现方式(上三角设为负无穷,Softmax 后归零)
- 能区分 Compute Bound 和 Memory Bound,并说清 Prefill 和 Decode 阶段各自的瓶颈类型
- 能说清 FlashAttention 的核心思想:Tiling + Online Softmax 如何将 HBM 访问从 降到
📚 参考资料
论文
- Neural Machine Translation by Jointly Learning to Align and Translate (Bahdanau et al., 2014):https://arxiv.org/abs/1409.0473 — Bahdanau Attention,注意力机制的开山之作
- Effective Approaches to Attention-based Neural Machine Translation (Luong et al., 2015):https://arxiv.org/abs/1508.04025 — Luong Attention,引入点积注意力
- Attention Is All You Need (Vaswani et al., 2017):https://arxiv.org/abs/1706.03762 — Transformer 原始论文
- Fast Transformer Decoding: One Write-Head is All You Need (Shazeer, 2019):https://arxiv.org/abs/1911.02150 — Multi-Query Attention
- GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints (Ainslie et al., 2023):https://arxiv.org/abs/2305.13245 — Grouped-Query Attention
- DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model (DeepSeek-AI, 2024):https://arxiv.org/abs/2405.04434 — Multi-Latent Attention
- FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness (Dao et al., 2022):https://arxiv.org/abs/2205.14135 — FlashAttention v1
- FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning (Dao, 2023):https://arxiv.org/abs/2307.08691 — FlashAttention v2
- Online Normalizer Calculation for Softmax (Milakov & Gimelshein, 2018):https://arxiv.org/abs/1805.02867 — Online Softmax 算法
教程与博客
- The Illustrated Transformer (Jay Alammar):https://jalammar.github.io/illustrated-transformer/ — 图文并茂的 Transformer 入门
- The Annotated Transformer (Harvard NLP):https://nlp.seas.harvard.edu/annotated-transformer/ — 论文逐行对应 PyTorch 实现
- Andrej Karpathy: Let’s build GPT from scratch:https://www.youtube.com/watch?v=kCc8FmEb1nY — 从零手写 GPT
- ELI5: FlashAttention (Aleksa Gordic):https://gordicaleksa.medium.com/eli5-flash-attention-5c44017022ad — FlashAttention 通俗讲解