分布式训练
第7章:其他显存优化技术
掌握 TP+PP+DP 的 3D 并行拓扑设计,以及混合精度、梯度累积、Activation Checkpointing、MoE 并行、长序列并行等关键训练优化技术
3D并行 混合精度 Activation Checkpointing MoE 长序列训练 Context Parallel
📖 本章概述
实际大模型训练需要多种并行策略和显存优化技术的组合。本章先讲清 3D 并行的拓扑设计原则,再逐一介绍配套的训练优化技术。
📑 章节结构
1. 3D 并行拓扑设计
- 通信域划分原则:
- TP 组:放在机内 NVLink 互联的 GPU 上(高带宽、低延迟)
- PP 组:跨机 InfiniBand 互联(只传激活值,通信量相对小)
- DP 组:跨节点(每步一次 AllReduce/ReduceScatter)
- 拓扑映射公式:
- 设计实例:64 卡集群(8 节点 × 8 GPU/节点)
- 方案 A:TP=8, PP=4, DP=2(TP 占满单机)
- 方案 B:TP=4, PP=8, DP=2(TP 半机,PP 跨更多机)
- 如何选择:取决于模型层数、单层大小、网络带宽
- 通信组构建:Megatron-LM 中
mpu.initialize_model_parallel()的分组逻辑
2. 混合精度训练
-
数据类型对比:
类型 位宽 指数位 尾数位 数值范围 适用场景 FP32 32 8 23 大 Master weight、优化器 FP16 16 5 10 小,易溢出 旧方案 BF16 16 8 7 与 FP32 相同 大模型训练首选 FP8 (E4M3) 8 4 3 很小 H100+ 的前向计算 -
BF16 混合精度流程:前向/反向用 BF16 计算,参数更新用 FP32 Master Weight
-
FP8 训练:H100 Transformer Engine,前向用 FP8 + 动态缩放(per-tensor scaling),反向梯度用 BF16
-
Loss Scaling:FP16 训练必须的梯度缩放技巧(BF16 因范围大通常不需要)
3. 梯度累积(Gradient Accumulation)
- 目的:在有限显存下模拟更大的 Effective Batch Size
- 机制:连续 步前向+反向累积梯度,每 步做一次参数更新
- Effective Batch Size = 单卡 batch × world_size × accumulation_steps
- 与 PP 的配合:PP 的 micro-batch 本身就实现了类似效果;梯度累积进一步放大
- 注意事项:Loss 需要除以 (或等价地梯度除以 )
4. Activation Checkpointing(激活重计算)
- 问题:前向保存的激活值随层数和 batch size 线性增长,大模型训练中激活显存可能超过参数显存
- 核心思想:只保存部分层的激活值(checkpoint),其余层反向时重新前向计算
- Full Checkpointing:只保存每个 Transformer Block 的输入,反向时重算整个 Block
- 显存节省:激活显存从 降至 (每 层存一个 checkpoint)
- 计算代价:约增加 33% 的前向计算量
- Selective Checkpointing:只重算计算量小但激活大的操作(如 Attention 的 softmax 输出),保留计算量大的操作的激活
- 配合 TP/PP 使用:PP 天然每个 Stage 只需存自己的激活;TP 下 SP 已减少激活冗余
5. MoE 并行(Expert Parallelism)
- MoE 模型结构:Router 选择 Top-K Expert,非密集计算
- Expert Parallelism(EP):不同 Expert 放在不同 GPU 上
- All-to-All 通信:token 根据 Router 选择发送到对应 Expert 所在 GPU,计算完成后再 All-to-All 发回
- 组合策略:EP + DP + TP + PP 的四维并行
- EP 组与 DP 组正交:通常在 DP 维度内进一步划分 EP
- 通信量分析:All-to-All 通信量取决于 Expert 数量和 Top-K
- 负载均衡:Auxiliary Loss 引导 Router 均匀分发 token
6. 长序列训练
- 问题:Attention 的计算量和激活显存与序列长度 呈 关系,超长序列(128K+)单卡放不下
- Ring Attention:将序列切分到多卡,每卡只计算部分 QK 交互,通过 Ring 传递 KV 分片
- Ulysses(DeepSpeed):沿 Head 维度切分,All-to-All 重排为序列维度切分
- Context Parallel(Megatron-LM):沿序列维度切分,配合 FlashAttention 的分块计算
- 与 TP/SP 的区别:SP 切分的是非 Attention 区域的激活;CP/Ring Attention 切分的是 Attention 计算本身
🎯 本章学习目标
- 能为给定规模的集群设计 3D 并行方案并说明通信组划分理由
- 能解释 BF16 混合精度训练的完整流程(含 Master Weight)
- 能分析 Activation Checkpointing 的显存节省量和计算代价
- 能描述 MoE Expert Parallelism 的 All-to-All 通信模式
- 能区分 SP、CP、Ring Attention 的切分维度和适用场景