跳到主要内容
分布式训练

第7章:其他显存优化技术

掌握 TP+PP+DP 的 3D 并行拓扑设计,以及混合精度、梯度累积、Activation Checkpointing、MoE 并行、长序列并行等关键训练优化技术

3D并行 混合精度 Activation Checkpointing MoE 长序列训练 Context Parallel

📖 本章概述

实际大模型训练需要多种并行策略和显存优化技术的组合。本章先讲清 3D 并行的拓扑设计原则,再逐一介绍配套的训练优化技术。


📑 章节结构

1. 3D 并行拓扑设计

  • 通信域划分原则
    • TP 组:放在机内 NVLink 互联的 GPU 上(高带宽、低延迟)
    • PP 组:跨机 InfiniBand 互联(只传激活值,通信量相对小)
    • DP 组:跨节点(每步一次 AllReduce/ReduceScatter)
  • 拓扑映射公式world_size=TP×PP×DP\text{world\_size} = TP \times PP \times DP
  • 设计实例:64 卡集群(8 节点 × 8 GPU/节点)
    • 方案 A:TP=8, PP=4, DP=2(TP 占满单机)
    • 方案 B:TP=4, PP=8, DP=2(TP 半机,PP 跨更多机)
    • 如何选择:取决于模型层数、单层大小、网络带宽
  • 通信组构建:Megatron-LM 中 mpu.initialize_model_parallel() 的分组逻辑

2. 混合精度训练

  • 数据类型对比

    类型位宽指数位尾数位数值范围适用场景
    FP3232823Master weight、优化器
    FP1616510小,易溢出旧方案
    BF161687与 FP32 相同大模型训练首选
    FP8 (E4M3)843很小H100+ 的前向计算
  • BF16 混合精度流程:前向/反向用 BF16 计算,参数更新用 FP32 Master Weight

  • FP8 训练:H100 Transformer Engine,前向用 FP8 + 动态缩放(per-tensor scaling),反向梯度用 BF16

  • Loss Scaling:FP16 训练必须的梯度缩放技巧(BF16 因范围大通常不需要)

3. 梯度累积(Gradient Accumulation)

  • 目的:在有限显存下模拟更大的 Effective Batch Size
  • 机制:连续 KK 步前向+反向累积梯度,每 KK 步做一次参数更新
  • Effective Batch Size = 单卡 batch × world_size × accumulation_steps
  • 与 PP 的配合:PP 的 micro-batch 本身就实现了类似效果;梯度累积进一步放大
  • 注意事项:Loss 需要除以 KK(或等价地梯度除以 KK

4. Activation Checkpointing(激活重计算)

  • 问题:前向保存的激活值随层数和 batch size 线性增长,大模型训练中激活显存可能超过参数显存
  • 核心思想:只保存部分层的激活值(checkpoint),其余层反向时重新前向计算
  • Full Checkpointing:只保存每个 Transformer Block 的输入,反向时重算整个 Block
    • 显存节省:激活显存从 O(L)O(L) 降至 O(L)O(\sqrt{L})(每 L\sqrt{L} 层存一个 checkpoint)
    • 计算代价:约增加 33% 的前向计算量
  • Selective Checkpointing:只重算计算量小但激活大的操作(如 Attention 的 softmax 输出),保留计算量大的操作的激活
  • 配合 TP/PP 使用:PP 天然每个 Stage 只需存自己的激活;TP 下 SP 已减少激活冗余

5. MoE 并行(Expert Parallelism)

  • MoE 模型结构:Router 选择 Top-K Expert,非密集计算
  • Expert Parallelism(EP):不同 Expert 放在不同 GPU 上
  • All-to-All 通信:token 根据 Router 选择发送到对应 Expert 所在 GPU,计算完成后再 All-to-All 发回
  • 组合策略:EP + DP + TP + PP 的四维并行
    • EP 组与 DP 组正交:通常在 DP 维度内进一步划分 EP
    • 通信量分析:All-to-All 通信量取决于 Expert 数量和 Top-K
  • 负载均衡:Auxiliary Loss 引导 Router 均匀分发 token

6. 长序列训练

  • 问题:Attention 的计算量和激活显存与序列长度 ssO(s2)O(s^2) 关系,超长序列(128K+)单卡放不下
  • Ring Attention:将序列切分到多卡,每卡只计算部分 QK 交互,通过 Ring 传递 KV 分片
  • Ulysses(DeepSpeed):沿 Head 维度切分,All-to-All 重排为序列维度切分
  • Context Parallel(Megatron-LM):沿序列维度切分,配合 FlashAttention 的分块计算
  • 与 TP/SP 的区别:SP 切分的是非 Attention 区域的激活;CP/Ring Attention 切分的是 Attention 计算本身

🎯 本章学习目标

  • 能为给定规模的集群设计 3D 并行方案并说明通信组划分理由
  • 能解释 BF16 混合精度训练的完整流程(含 Master Weight)
  • 能分析 Activation Checkpointing 的显存节省量和计算代价
  • 能描述 MoE Expert Parallelism 的 All-to-All 通信模式
  • 能区分 SP、CP、Ring Attention 的切分维度和适用场景