首页
学习
活动
专区
圈层
工具
发布
社区首页 >专栏 >ZeRo零冗余优化器参数更新过程和通信量计算

ZeRo零冗余优化器参数更新过程和通信量计算

作者头像
AI老马
发布2026-01-13 14:50:03
发布2026-01-13 14:50:03
740
举报
文章被收录于专栏:AI前沿技术AI前沿技术
1,ZeRo-2 参数更新和通信量计算

策略:将模型的梯度和优化器状态进行分片。每个GPU上,保存一份完整的参数副本,以及分片后的梯度和优化器状态。

内存:相对于数据并行,内存最大能减少到原来的1/8倍。

通信量:通信量与数据并行一致,为2M (M为模型参数量)。

整体更新流程:

  • • 数据分片和梯度划分。模型梯度自动分片,batch数据分配到不同GPU上。
  • • 前向计算。每个GPU上有模型参数副本,计算模型损失。因为有完整参数副本,前向计算无需参数通信,。
  • • 反向传递损失,计算梯度。假设:模型有N层,3个GPU设备,模型梯度均分成3份, 层,由GPU0负责,层,由GPU1负责,层,由GPU2负责。更新过程:3个GPU设备前向计算完成后,同时反向传播,当3个GPU设备计算完层的梯度时,开始使用集合通信算法 Reduce-Scatter 通信,GPU0获得每个设备层的平均梯度,并保存下来,其他GPU设备上的层梯度就被丢弃,这部分内存被节省下来。同样过程,GPU1获得层的梯度平均,GPU2获得层的梯度平均。整个过程每个GPU的梯度占用内存,只相当于原来的1/3。优化器状态:对应的优化器状态也仅仅保留每个GPU负责梯度对应的部分即可。
  • • 完整模型参数更新 每个GPU获得负责层的平均梯度后,即可对参数更新,参数更新后,通过 All-gather算法,使得每个GPU上获得一致的参数。如图所示。
  • • 下一轮前向计算开始,如此往复。

以上流程中总的通信量: 前向计算 0 + 反向传播 M + 参数更新同步 M = 2M

2,ZeRo-3 参数更新和通信量计算

策略:将模型的参数,梯度和优化器状态进行分片。每个GPU上保存分片后的参数,梯度和优化器状态。

内存:相对于数据并行,内存最大能减少到原来的1/N倍(N为GPU数量)。

通信量:通信量是数据并行的1.5倍,为3M (M为模型参数量)。

整体更新流程:

  • • 参数和梯度划分。不同GPU保存对应分片后的参数和梯度。数据分片,batch数据分配到不同GPU上。
  • • 前向计算时,进行参数通信获得需要参数。 比如,需要进行 层的前向计算,此部分参数仅保存在GPU2上,需要GPU2进行broadcast通信,将 层的参数广播到GPU0和GPU1,GPU0和GPU1用完后,即可丢弃此部分参数。同样,进行 层时,GPU1进行广播参数,进行 层时,GPU0进行广播参数。每个GPU获得各自的前向损失。总的通信量为M。
  • • 反向传递损失,计算梯度。 与ZeRo-2一样,3个GPU设备前向计算完成后,同时反向传播,对梯度进行 Reduce-Scatter通信,每个GPU设备获得负责层的梯度平均,通信量为M。注意!此时每个GPU设备上并没有完整的参数,需要再次的进行参数通信,1)参与计算梯度和2)重新计算前向的激活值,参数的通信量为 M。 反向时总的通信量为2M
  • • 参数更新 每个设备获得的根据平均梯度,更新对应负责参数和优化器状态。
  • • 下一轮前向计算开始,如此往复。

以上流程中总的通信量: 前向计算 M + 反向传播 2M + 参数更新同步 0 = 3M

3,ZeRo-2&3策略总结

ZeRo通过分片策略实现内存与通信的权衡,其核心思想是 “以通信换内存”

维度

ZeRo-2

ZeRo-3

分片策略

仅分片梯度(Gradients)和优化器状态(Optimizer States)

分片参数(Parameters)、梯度和优化器状态

内存优化倍数

最大减少至数据并行的 1/8

最大减少至数据并行的 1/N(N为GPU数量)

总通信量

2M (反向传播M + 参数同步M)

3M (前向M + 反向2M)

前向计算

无需参数通信(每GPU保留完整参数副本)

需动态广播分片参数(如逐层广播,通信量M)

反向传播

使用 Reduce-Scatter 分片梯度(通信量M)

分两次通信:1)梯度Reduce-Scatter(M)2)参数计算梯度和激活值(M)

参数更新

各GPU更新完整参数后,通过 All-Gather 同步(通信量M)

仅更新本地分片参数(无需同步)

4,利用ZeRo优化内存-代码实现

使用pytorch 官方集成的零冗余优化器[2],默认是对梯度和优化器状态进行分片存储,如果也需要对参数进行分片,需要结合FSDP进行。

代码语言:javascript
复制
if use_zero:
  optimizer = ZeroRedundancyOptimizer(
    ddp_model.parameters(), optimizer_class=torch.optim.Adam, lr=0.001
  )
  else:
    optimizer = torch.optim.Adam(ddp_model.parameters(), lr=0.001)

从结果上看,使用优化器后,模型训练占用的显存从2003M 降到了 1248M。

代码语言:javascript
复制
max memory allocated after creating local model: 335.0MB
max memory allocated after creating DDP: 656.0MB
Max memory allocated before optimizer step(): 996.0MB
Max memory allocated after optimizer step(): 1248.0MB
params sum is: 80040000
------- not using Zero ---------
max memory allocated after creating local model: 335.0MB
max memory allocated after creating DDP: 656.0MB
Max memory allocated before optimizer step(): 996.0MB
Max memory allocated after optimizer step(): 2003.0MB
params sum is: 80040000

参考:

代码语言:javascript
复制
[1] Samyam Rajbhandari, Jeff Rasley, Olatunji Ruwase, Yuxiong He. "ZeRO: Memory Optimizations Toward Training Trillion Parameter Models".arXiv:1910.02054 
[2] 张奇、桂韬、郑锐、黄萱菁 《大规模语言模型从理论到实践》 中国工信出版社。2024.1
本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2025-03-14,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 AI老马啊 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 1,ZeRo-2 参数更新和通信量计算
  • 2,ZeRo-3 参数更新和通信量计算
  • 3,ZeRo-2&3策略总结
  • 4,利用ZeRo优化内存-代码实现
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档