内存瓶颈:
在训练过程中显存的用一般是四部分组成参数,梯度,优化器和中间激活值,前三项和参数量的的关系约为16倍。假如一个3B的模型,至少要48G的显存,加上中间激活着占用的显存,多张80G显存的A100,也力不从心!怎样在显存一定的情况下,克服内存墙是训练模型的关键之一。
优化方案:
即然参数,梯度和优化器GPU显存开销大,那就分级划片分成更小维度后,将它们放在不同 的GPU设备上,用到时候再进行读取。
Ps:如果内存紧张,中间激活值可采用激活重新计算的方式来节省内存,用时间换显存。
采用混合精度计算,参数,梯度和优化器占用显存和总参数量 的关系为:

对显存的进一步优化也就从这三方面下手,即零冗余优化,分为三个层次:
优化后的内存占比为:
当N比较大时, 显存占用相当于原来的
优化后的内存占比为:
当N比较大时, 显存占用相当于原来的
优化后的内存占比为:
当N比较大时, 显存占用非常的小。
通信量变化: 三种的内存优化措施,相应的通信量是如何? zero-1 和zero-2 相对于baseline 数据并行,通信量是没有变化的,zero-3变为1.5倍。 。
参考
[1] Samyam Rajbhandari , Jeff Rasley , Olatunji Ruwase, Yuxiong He. "ZeRO: Memory Optimizations Toward Training Trillion Parameter Models" . arXiv:1910.02054
[2] Samyam Rajbhandari, Olatunji Ruwase, Jeff Rasley, Shaden Smith, Yuxiong He, 2021."ZeRO-Infinity: Breaking the GPU Memory Wall for Extreme Scale Deep Learning". arXiv:2104.07857