
大模型训练和推理大多是内存访问密集型,尤其推理时动辄几K长度的输入,内存大小,读写频次成为难以逾越的性能瓶颈。如何突破内存墙限制?本文围绕FlashAttention通过分块计算和动态重计算技术,对注意力机制进行优化,在保证准确性的同时,如何大幅优化注意力机制的内存效率和计算速度。主要包括:
1,GPU的内存层级架构特点介绍
2,为什么标准的attention需要频繁的访问HBM内存,进行读写操作,
3, Flash-attention增量式更新softmax值的公式推导过程
训练数据最终需要从主机内存流转到GPU内存中,了解不同内存的大小,访问速度等,有助于理解在模型训练中,不同内存所起到的作用,训练瓶颈内存墙的存在,以及对应优化方案所解决的问题。
位于 GPU芯片上的超高速静态内存,用于缓存频繁访问的数据和指令,减少对显存(HBM)的访问延迟。在FlashAttention中用于缓存分块后的Q、K、V矩阵,实现局部Softmax计算。
GPU的板载显存,用于存储模型参数、激活值、梯度等大规模数据,通过高带宽接口与计算核心直接交互。HBM存储完整的Q、K、V矩阵。
位于 CPU侧的系统内存,存储操作系统、应用程序和未加载到GPU的数据。GPU 通过PCle 或NVLink总线访问主内存,需显式拷贝数据到 HBM。在GPU计算中,受PCle带宽限制,数据从CPU DRAM拷贝到GPU HBM,成为异构训练的性能瓶颈。
总结:
内存类型 | 位置 | 容量 | 带宽 | 特点 |
|---|---|---|---|---|
GPU SRAM | 片上 SM内 | ~17MB | ~19TB/s | 最快 最小 |
GPU HBM | GPU板载内存 | 40/80G | ~1.5TB/s | 次快 中等 |
CPU DRAM | 主机内存 | 512G-2T | 50-400GB/s | 最慢 最大 |
标准注意力计算的核心问题是:显存 (HBM) 与片上内存(SRAM) 的频繁数据交换,导致高延迟和低效率。
标准的注意力计算分为三步,以单头注意力为例。
输入和输出矩阵维度信息 ,其中N表述序列长度,d表示多头注意力单头维度。在进行标准的注意力计算时,需要引入两个变量 S 和 P。S 为注意力分值 Attention Scores,P 对 S 逐行的 softmax 得到,可理解为 Normalized Attention Scores。
公式:
操作:从HBM 分块加载Q、K到 SRAM,计算分块乘积后写回S到HBM。
内存交互:需将 Q、K从HBM 加载到 SRAM ,因为S太大无法全留在SRAM,生成的S需完整写回HBM,占用 空间。对Q和K各读一次,对S写一次,共三次交互。
公式:
操作:从 HBM读取 S到 SRAM, 逐行计算softmax后,将 P值写回到HBM
内存交互:需加载 S的分块到 SRAM,逐行计算Softmax得到 P,生成的P需完整写回HBM,再次占用 空间。对S读一次,对P 写一次,共两个交互。
公式:
操作:从HBM 分块加载P、V到 SRAM,计算后写回O到HBM
内存交互:需加载P、V的分块到SRAM进行分块计算,生成的 O 写回 HBM,占用 空间。对P和V各读一次,对O写一次,共三次操作。
解决问题:为了防止指数最大值溢出,在稳定版的softmax计算中,每个元素减去最大值 ,然后再除以所有元素的和。
softmax 公式:
最大值定义:
所有元素和:
其中:
公式(1),可改写为公式:
公式(5)中的分子为向量,分母为标量,计算过程为逐元素相除。
对于softmax不能分块计算,最大的障碍来自于公式(5)的分母,因为它需要所有元素的和
小结
标准注意力的性能瓶颈主要来自HBM的频繁读写。因为GPU片上SRAM内存容量有限,每个计算核心SM(streaming multiprocessor)更是小的可怜!无法缓存完整的中间值 S 和 P。每次计算需分块加载数据到SRAM,计算后写回HBM。
具体表现为,计算 S 和 P 需两次全量读写HBM即,读 QK -> 写S -> 读 S ->写 P -> 读 PV -> 写 O。总HBM访问量为 与序列长度平方成正比。
在计算attention时,最大程度的减少读写HBM内存,成为进一步提高模型训练和推理效率的关键。
核心思想:通过分块计算 tiling,动态重计算和内核融合,将计算保留在SRAM中,大幅减少HBM的访问,从而突破内存墙限制。
假设输入的向量 , 将其切分为两个向量 , 两个分块向量依次计算。
最大值:
逐元素减去最大值:
求和:
局部softmax:
之所以称之为局部:因为公式(7)并没有减去全局的最大值,而是当前块的最大值,公式(9)的分母应该是全局的和,而不仅仅是当前块全部元素的和。
每个分块计算后的局最大值和局部元素和,两个标量也需要保存,即:。
定义两个全局标量,全局最大值和全局元素的和 。
此时只处理了分块1,全局标量和局部标量相同,即
最大值:
逐元素减去最大值:
求和:
局部softmax:
此时计算得到的softmax 也是局部的。
更新全局最大值:
更新全局元素和:
问题:为什么公式(15)中多乘上一个多项式 就更新到全局了,推导下:
此时公式(16)为全局。
分块的全局和已经减去了局部的最大值,要更新为全局的必须先加回来,然后再减去更新后的全局最大值。
以上可以当需要把局部的求和 l 更新为全局时,只需要多加一个多项式
此时通过公式(15)可以得到当前全局元素的和就不难理解了。
下面通过softmax 公式(13)直接推导,先看分子:
此时重新计算公式(13)
公式(18)中更新分块2的softmax时,不再需要分块2的原始向量元素,只需要保存下来的标量值 和当前需要更新的局部softmax值 。
同理将以上标号换成分块1,即可更新分块1的局部softmax到全局。
标准注意力的计算受限于HBM与SRAM的频繁数据交换,尤其是中间矩阵S和P的的开销。Flash attention 通过分块计算,动态重计算和内核融合,将计算保留在SRAM中,大幅减少HBM的访问,从而突破内存墙限制。
步骤 | 标准注意力 | Flash attention |
|---|---|---|
中间矩阵存储 | 必须存储 到HBM | 不存储S和P,只缓存部分统计量 |
HBM访问次数 | 读写S和P | 不存储中间值 |
softmax计算 | 全局逐行计算,依赖HBM同步 | 增量式分块计算,全部在SRAM中完成 |
元素融合操作 | 多次独立读写HBM | 单次分块融合 softmax,masking,Dropout |
参考:
[1]Tri Dao et al. FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness [2]https://www.zhihu.com/question/611236756/answer/3132304304?utm_psn=1698839910906404864
更多精彩:
历史文章: