跳到主要内容
分布式训练

第5章:张量并行TP与序列并行SP

掌握 Megatron-LM 的张量并行方案(Column/Row Parallel Linear)、通信插入位置推导,以及序列并行对激活显存的优化

张量并行 序列并行 Megatron-LM TP SP NVLink

📖 本章概述

当单个 Transformer 层的参数或激活值就超出单卡显存时,需要将层内的矩阵运算切分到多卡——这就是张量并行(Tensor Parallelism, TP)。序列并行(SP)则进一步解决 TP 未覆盖区域的激活值冗余。


📑 章节结构

1. 为什么需要张量并行

  • 数据并行的局限:DP/FSDP 只切分训练状态的冗余副本,但如果单层的前向计算本身就超出单卡算力/显存(如 FFN 的中间激活),DP 无能为力
  • TP 的定位:将单个矩阵乘法 Y=XAY = XA 拆分到多卡并行计算,每卡只需存 1N\frac{1}{N} 的权重和对应的部分激活
  • 带宽约束:TP 每层都需要通信,对延迟极其敏感 → 通常限制在 NVLink 互联的单机内(8卡以内)

2. Column Parallel Linear(按列切分)

  • 切分方式:权重 AA 按列切分为 [A1,A2,...,AN][A_1, A_2, ..., A_N],每卡持有 AiA_i
  • 计算过程Yi=XAiY_i = X \cdot A_i(每卡用完整输入 XX 乘自己的列分片)
  • 输出:各卡得到 YY 的不同列分片 YiY_i
  • 通信:输入 XX 需要通过 Broadcast/复制到各卡(或本身各卡已有完整输入)
  • 应用位置:FFN 的第一个线性层、Attention 的 QKV 投影

3. Row Parallel Linear(按行切分)

  • 切分方式:权重 AA 按行切分,每卡持有 AiA_i(行方向的分片)
  • 输入要求:输入 XX 也需按列分片为 XiX_i(与 Column Parallel 的输出对接)
  • 计算过程Yi=XiAiY_i = X_i \cdot A_i(每卡计算部分结果)
  • 通信:各卡的 YiY_i 通过 AllReduce(求和)得到完整输出 YY
  • 应用位置:FFN 的第二个线性层、Attention 的输出投影

4. Transformer Block 的 TP 切分方案

  • Attention 层:QKV 按 Column Parallel(多头天然可切分,每卡负责部分 Head),输出投影按 Row Parallel
  • FFN 层:第一个线性层 Column Parallel(切分中间维度),第二个线性层 Row Parallel(合并输出)
  • 通信插入点:每个 Transformer Block 的前向过程中需要 2 次 AllReduce(Attention 输出 + FFN 输出)
  • 反向传播:前向的 AllReduce 对应反向的 AllReduce,通信量对称
  • 每层通信量:前向 2×2bsh2 \times 2bsh(2 次 AllReduce,每次传输 2bsh2bsh 数据)

5. 序列并行(Sequence Parallelism, SP)

  • 问题:LayerNorm 和 Dropout 不参与 TP 切分,它们的输入/输出在每卡上是完整的 → 激活值冗余
  • 解决方案:在非 TP 区域(LayerNorm、Dropout),沿序列维度切分激活值
  • 通信变化:TP 区域的 AllReduce 变为 ReduceScatter(进入非 TP 区域时)+ AllGather(回到 TP 区域时)
  • 收益:Non-TP 区域的激活显存从完整 bshbsh 降至 bshN\frac{bsh}{N},总激活显存接近线性缩放
  • 通信量变化:总通信量不变(AllReduce = ReduceScatter + AllGather),但激活显存大幅降低

6. GQA/MQA 下的 TP 切分

  • 问题:GQA 的 KV Head 数量可能 < TP 度(如 KV Head = 4, TP = 8)
  • 解决策略
    • 复制 KV Head(每个 TP rank 复制一份完整 KV Head)
    • 部分 rank 共享 KV Head(分组策略)
  • 影响分析:KV Head 复制导致 TP 的参数切分不完全均匀,但计算仍可并行

7. 动手实验

  • 画出一个 Transformer Block 在 TP=4 下的切分图(标注每一步的通信操作)
  • 阅读 Megatron-LM ColumnParallelLinearRowParallelLinear 源码

🎯 本章学习目标

  • 能推导 Column Parallel 和 Row Parallel 的输入输出形状与通信位置
  • 能画出 Transformer Block(Attention + FFN)在 TP 下的完整切分图
  • 能解释序列并行如何将 AllReduce 拆为 ReduceScatter + AllGather 以节省激活显存
  • 能分析 GQA 模型在 TP 切分时 KV Head 不足的处理策略