混合精度训练过程中显存占用主要来自参数、梯度、优化器和中间激活值。仅参数、梯度和优化器占用内存为参数量的16倍,假如全参训练一个7.5B的模型,至少要120G的显存,传统的训练策略,有极大的内存优化空间。
1)三种零冗余优化器的分层优化方法。 2)不同的内存优化方案,梯度和参数是如何同步,如何更新? 3)每种优化策略的优势和适用场景介绍。
内存占用:采用混合精度训练,参数,梯度和优化器占用显存和总参数量M 的关系为:
大模型训练内存占用参考[模型训练占用显存分析]
核心思想:即然参数、梯度和优化器GPU显存开销大,那就分层划片成更小维度后,将它们放在不同 的GPU设备上,用到时候再进行读取。
对显存的进一步优化也就从这三方面下手,即零冗余优化,分为三个层次:
优化后的内存占比为:,当N比较大时, 显存占用相当于原来的 。
优化后的内存占比为:,当N比较大时, 显存占用相当于原来的 。
优化后的内存占比为:,当N比较大时, 显存占用非常的小。

策略:将模型的梯度和优化器状态进行分片。每个GPU上,保存一份完整的参数副本,以及分片后的梯度和优化器状态。
内存:相对于数据并行,内存最大能减少到原来的1/8倍。
通信量:通信量与数据并行一致,为2M (M为模型参数量)。
更新流程:
以上流程中总的通信量: 前向计算 0 + 反向传播 M + 参数更新同步 M = 2M

策略:将模型的参数,梯度和优化器状态进行分片。每个GPU上保存分片后的参数,梯度和优化器状态。
内存:相对于数据并行,内存最大能减少到原来的1/N倍(N为GPU数量)。
通信量:通信量是数据并行的1.5倍,为3M (M为模型参数量)。
更新流程:
以上流程中总的通信量: 前向计算 M + 反向传播 2M + 参数更新同步 0 = 3M

ZeRo通过分片策略实现内存与通信的权衡,其核心思想是 “以通信换内存”
维度 | ZeRo-2 | ZeRo-3 |
|---|---|---|
分片策略 | 仅分片梯度(Gradients)和优化器状态(Optimizer States) | 分片参数(Parameters)、梯度和优化器状态 |
内存优化倍数 | 最大减少至数据并行的 1/8 | 最大减少至数据并行的 1/N(N为GPU数量) |
总通信量 | 2M (反向传播M + 参数同步M) | 3M (前向M + 反向2M) |
前向计算 | 无需参数通信(每GPU保留完整参数副本) | 需动态广播分片参数(如逐层广播,通信量M) |
反向传播 | 使用 Reduce-Scatter 分片梯度(通信量M) | 分两次通信:1)梯度Reduce-Scatter(M)2)参数计算梯度和激活值(M) |
参数更新 | 各GPU更新完整参数后,通过 All-Gather 同步(通信量M) | 仅更新本地分片参数(无需同步) |
适用场景 | 中等规模模型(内存压力主要来自梯度和优化器状态) | 超大规模模型(内存压力来自参数本身,如GPT-3等) |
使用pytorch 官方集成的零冗余优化器,默认是对梯度和优化器状态进行分片存储,如果也需要对参数进行分片,需要结合FSDP进行。
if use_zero:
optimizer = ZeroRedundancyOptimizer(
ddp_model.parameters(), optimizer_class=torch.optim.Adam, lr=0.001
)
else:
optimizer = torch.optim.Adam(ddp_model.parameters(), lr=0.001)从结果上看,使用优化器后,模型训练占用的显存从2003M 降到了 1248M。
max memory allocated after creating local model: 335.0MB
max memory allocated after creating DDP: 656.0MB
Max memory allocated before optimizer step(): 996.0MB
Max memory allocated after optimizer step(): 1248.0MB
params sum is: 80040000
------- not using Zero ---------
max memory allocated after creating local model: 335.0MB
max memory allocated after creating DDP: 656.0MB
Max memory allocated before optimizer step(): 996.0MB
Max memory allocated after optimizer step(): 2003.0MB
params sum is: 80040000参考: [1] arXiv:1910.02054 [2] arXiv:1910.02054 [3] arXiv:2104.07857