第2章:优化器
理解主流优化器的演进逻辑与内部状态组成,掌握优化器显存开销分析方法,为 ZeRO 显存优化和混合精度训练打下基础
📖 本章概述
优化器是训练的”舵手”——模型算出梯度之后,怎么更新参数全靠它。本章不讲优化器的数学推导细节,而是聚焦于 AI Infra 工程师最需要关心的两个问题:优化器内部到底存了什么(直接决定显存占用)、分布式场景下优化器有什么特殊要求(大 Batch 收敛问题)。这两个问题是理解后续 ZeRO 显存优化和混合精度训练的直接前提。
📑 章节结构
0. 什么是优化器
神经网络训练的本质是最小化损失函数——模型对训练数据的预测结果与真实标签之间的差距。反向传播(backward)算出的梯度告诉我们损失函数在当前参数位置的”下坡方向”,但往哪走只是问题的一半,怎么走、走多远才是决定训练成败的关键。优化器就是回答后半个问题的组件。
优化器在训练循环中的位置
for batch in dataloader:
loss = model(batch) # 前向:算预测 → 算损失
loss.backward() # 反向:算梯度
optimizer.step() # 优化器:用梯度 + 内部状态更新参数
optimizer.zero_grad() # 清零梯度,准备下一轮
optimizer.step() 这一行看似简单,背后做了三件事:(1)读取每个参数的梯度;(2)结合优化器内部状态(动量、二阶矩等)计算更新量;(3)更新参数。优化器的内部状态在 step() 之间持久保存,这就是为什么它会持续占用显存。
1. 优化器演进:从 SGD 到 AdamW
- SGD:最朴素的优化器,参数更新规则 ,无额外状态,每个参数只需存一份梯度。好比沿着当前最陡的方向迈一步,简单直接但容易在山谷里来回震荡
- Momentum SGD:引入动量 ,每个参数多存一个动量缓冲区(一阶动量)。好比给下山的球加了惯性,不容易被小坑绊住,但需要额外存储动量状态
- Adam:同时维护一阶动量(梯度的指数移动平均 )和二阶动量(梯度平方的指数移动平均 ),相当于同时记住”往哪走”和”路有多颠”。自适应学习率让不同参数有不同的更新步长
- AdamW:修正了 Adam 中权重衰减(Weight Decay)的实现方式——Adam 将 Weight Decay 混入梯度中,与自适应学习率耦合导致正则化效果不对;AdamW 将其解耦为独立的参数缩放步骤。这是目前大模型训练的事实标准
2. 优化器状态的显存开销(本章重点)
理解优化器的显存开销是理解 ZeRO 的关键。以下分析以混合精度训练为前提(前向/反向用 FP16/BF16,优化器内部用 FP32 保持精度):
各优化器的每参数显存开销
| 优化器 | 额外状态 | 每参数显存 | 7B 模型总开销 |
|---|---|---|---|
| SGD | 无 | 4B(FP32 参数副本) | ~28 GB |
| Momentum SGD | 一阶动量 | 8B(FP32 参数副本 + FP32 动量) | ~56 GB |
| Adam / AdamW | 一阶动量 + 二阶动量 | 12B(FP32 参数副本 + FP32 + FP32 ) | ~84 GB |
显存占比分析
以 AdamW + BF16 混合精度训练 7B 模型为例:
- BF16 参数: = 14 GB
- BF16 梯度: = 14 GB
- 优化器状态: = 84 GB(FP32 参数副本 28GB + 一阶动量 28GB + 二阶动量 28GB)
- 总计:~112 GB,其中优化器状态占 75%
这就解释了为什么 ZeRO-1 首先拿优化器状态开刀——它是显存占比最大的部分,切分它性价比最高。
为什么优化器必须用 FP32
混合精度训练中,前向和反向用 BF16 提速,但优化器状态必须保持 FP32:
- 参数更新量通常极小(学习率 × 梯度 ≈ 1e-4 × 梯度),BF16 的精度不足以表示这些微小变化
- 多次累加微小更新后,FP16/BF16 的舍入误差会积累导致训练发散
- 因此优化器需要保存一份 FP32 的”主权重”(master weights),每步用 FP32 精度更新后再转回 BF16 供下一步前向使用
3. 大 Batch 优化器:LAMB 与 LARS
分布式数据并行会线性放大有效 Batch Size( 卡 = 倍 Batch),这引入了一个新问题:
大 Batch 训练的困境
- Batch Size 从几百增长到几千甚至几万时,梯度的方差下降、信噪比上升,理论上可以用更大的学习率加速收敛
- 但实际操作中,简单地线性放大学习率往往导致训练不稳定甚至发散
- 不同层的参数尺度差异巨大(如 Embedding 层 vs. LayerNorm 层),全局统一的学习率放大策略很难兼顾所有层
LARS(Layer-wise Adaptive Rate Scaling)
- 对每一层独立计算学习率缩放因子:(参数范数 / 梯度范数)
- 参数范数大、梯度小的层步子迈大一些,反之迈小一些
- 最初用于 ResNet 大 Batch 训练
LAMB(Layer-wise Adaptive Moments)
- 在 Adam 基础上引入 LARS 的逐层缩放思想
- 每层的更新量先由 Adam 计算(含自适应学习率),再乘以信赖域缩放因子
- 使 BERT 在 Batch Size 65536 下仍能稳定收敛,训练时间从 3 天压缩到 76 分钟
- 适用场景:数据并行卡数多(64+)、有效 Batch Size 极大时
4. 优化器与分布式训练的交叉点
本节梳理优化器知识如何衔接后续章节:
- → 第3章 数据并行:DDP 中每卡各存一份完整优化器状态,是显存冗余的最大来源
- → 第4章 ZeRO 系列:ZeRO-1 的核心就是切分优化器状态,将 75% 的显存冗余降为
- → 第7章 其他显存优化技术:混合精度训练中,FP32 主权重存在优化器内,理解优化器状态才能算清完整的显存账本
- → 大 Batch 训练:当数据并行度增大,可能需要从 AdamW 切换到 LAMB 以维持收敛性
🎯 本章学习目标
- 能说清 SGD → Momentum → Adam → AdamW 每一步演进解决了什么问题、引入了什么额外状态
- 能手算任意规模模型在 AdamW + 混合精度下的优化器状态显存开销,并解释为什么它占训练总显存的 75%
- 能解释混合精度训练中优化器为什么必须用 FP32(master weights 的必要性)
- 能回答”Batch Size 放大 10 倍,学习率怎么调”这类问题,知道 LAMB/LARS 在什么场景下有必要