分布式训练
第5章:张量并行TP与序列并行SP
掌握 Megatron-LM 的张量并行方案(Column/Row Parallel Linear)、通信插入位置推导,以及序列并行对激活显存的优化
张量并行 序列并行 Megatron-LM TP SP NVLink
📖 本章概述
当单个 Transformer 层的参数或激活值就超出单卡显存时,需要将层内的矩阵运算切分到多卡——这就是张量并行(Tensor Parallelism, TP)。序列并行(SP)则进一步解决 TP 未覆盖区域的激活值冗余。
📑 章节结构
1. 为什么需要张量并行
- 数据并行的局限:DP/FSDP 只切分训练状态的冗余副本,但如果单层的前向计算本身就超出单卡算力/显存(如 FFN 的中间激活),DP 无能为力
- TP 的定位:将单个矩阵乘法 拆分到多卡并行计算,每卡只需存 的权重和对应的部分激活
- 带宽约束:TP 每层都需要通信,对延迟极其敏感 → 通常限制在 NVLink 互联的单机内(8卡以内)
2. Column Parallel Linear(按列切分)
- 切分方式:权重 按列切分为 ,每卡持有
- 计算过程:(每卡用完整输入 乘自己的列分片)
- 输出:各卡得到 的不同列分片
- 通信:输入 需要通过 Broadcast/复制到各卡(或本身各卡已有完整输入)
- 应用位置:FFN 的第一个线性层、Attention 的 QKV 投影
3. Row Parallel Linear(按行切分)
- 切分方式:权重 按行切分,每卡持有 (行方向的分片)
- 输入要求:输入 也需按列分片为 (与 Column Parallel 的输出对接)
- 计算过程:(每卡计算部分结果)
- 通信:各卡的 通过 AllReduce(求和)得到完整输出
- 应用位置:FFN 的第二个线性层、Attention 的输出投影
4. Transformer Block 的 TP 切分方案
- Attention 层:QKV 按 Column Parallel(多头天然可切分,每卡负责部分 Head),输出投影按 Row Parallel
- FFN 层:第一个线性层 Column Parallel(切分中间维度),第二个线性层 Row Parallel(合并输出)
- 通信插入点:每个 Transformer Block 的前向过程中需要 2 次 AllReduce(Attention 输出 + FFN 输出)
- 反向传播:前向的 AllReduce 对应反向的 AllReduce,通信量对称
- 每层通信量:前向 (2 次 AllReduce,每次传输 数据)
5. 序列并行(Sequence Parallelism, SP)
- 问题:LayerNorm 和 Dropout 不参与 TP 切分,它们的输入/输出在每卡上是完整的 → 激活值冗余
- 解决方案:在非 TP 区域(LayerNorm、Dropout),沿序列维度切分激活值
- 通信变化:TP 区域的 AllReduce 变为 ReduceScatter(进入非 TP 区域时)+ AllGather(回到 TP 区域时)
- 收益:Non-TP 区域的激活显存从完整 降至 ,总激活显存接近线性缩放
- 通信量变化:总通信量不变(AllReduce = ReduceScatter + AllGather),但激活显存大幅降低
6. GQA/MQA 下的 TP 切分
- 问题:GQA 的 KV Head 数量可能 < TP 度(如 KV Head = 4, TP = 8)
- 解决策略:
- 复制 KV Head(每个 TP rank 复制一份完整 KV Head)
- 部分 rank 共享 KV Head(分组策略)
- 影响分析:KV Head 复制导致 TP 的参数切分不完全均匀,但计算仍可并行
7. 动手实验
- 画出一个 Transformer Block 在 TP=4 下的切分图(标注每一步的通信操作)
- 阅读 Megatron-LM
ColumnParallelLinear和RowParallelLinear源码
🎯 本章学习目标
- 能推导 Column Parallel 和 Row Parallel 的输入输出形状与通信位置
- 能画出 Transformer Block(Attention + FFN)在 TP 下的完整切分图
- 能解释序列并行如何将 AllReduce 拆为 ReduceScatter + AllGather 以节省激活显存
- 能分析 GQA 模型在 TP 切分时 KV Head 不足的处理策略