首页
学习
活动
专区
圈层
工具
发布
社区首页 >专栏 >内存优化黑科技|Flash attention 为什么那么快?

内存优化黑科技|Flash attention 为什么那么快?

作者头像
AI老马
发布2026-01-13 15:02:09
发布2026-01-13 15:02:09
2260
举报
文章被收录于专栏:AI前沿技术AI前沿技术

大模型训练和推理大多是内存访问密集型,尤其推理时动辄几K长度的输入,内存大小,读写频次成为难以逾越的性能瓶颈。如何突破内存墙限制?本文围绕FlashAttention通过分块计算和动态重计算技术,对注意力机制进行优化,在保证准确性的同时,如何大幅优化注意力机制的内存效率和计算速度。主要包括:

1,GPU的内存层级架构特点介绍

2,为什么标准的attention需要频繁的访问HBM内存,进行读写操作,

3, Flash-attention增量式更新softmax值的公式推导过程

1, GPU内存层级结构

训练数据最终需要从主机内存流转到GPU内存中,了解不同内存的大小,访问速度等,有助于理解在模型训练中,不同内存所起到的作用,训练瓶颈内存墙的存在,以及对应优化方案所解决的问题。

  • • GPU SRAM (Static Random-Access Memory)

位于 GPU芯片上的超高速静态内存,用于缓存频繁访问的数据和指令,减少对显存(HBM)的访问延迟。在FlashAttention中用于缓存分块后的Q、K、V矩阵,实现局部Softmax计算。

  • • GPU HBM (High Bandwidth Memory)

GPU的板载显存,用于存储模型参数、激活值、梯度等大规模数据,通过高带宽接口与计算核心直接交互。HBM存储完整的Q、K、V矩阵。

  • • Main Memory (CPU DRAM)

位于 CPU侧的系统内存,存储操作系统、应用程序和未加载到GPU的数据。GPU 通过PCle 或NVLink总线访问主内存,需显式拷贝数据到 HBM。在GPU计算中,受PCle带宽限制,数据从CPU DRAM拷贝到GPU HBM,成为异构训练的性能瓶颈。

总结:

内存类型

位置

容量

带宽

特点

GPU SRAM

片上 SM内

~17MB

~19TB/s

最快 最小

GPU HBM

GPU板载内存

40/80G

~1.5TB/s

次快 中等

CPU DRAM

主机内存

512G-2T

50-400GB/s

最慢 最大

2,标准注意力和稳定版 softmax

2.1 标准注意力计算流程

标准注意力计算的核心问题是:显存 (HBM) 与片上内存(SRAM) 的频繁数据交换,导致高延迟和低效率。

标准的注意力计算分为三步,以单头注意力为例。

  • • 符号对齐

输入和输出矩阵维度信息 ,其中N表述序列长度,d表示多头注意力单头维度。在进行标准的注意力计算时,需要引入两个变量 S 和 P。S 为注意力分值 Attention Scores,P 对 S 逐行的 softmax 得到,可理解为 Normalized Attention Scores。

  • • 计算相似度矩阵

公式:

操作:从HBM 分块加载Q、K到 SRAM,计算分块乘积后写回S到HBM。

内存交互:需将 Q、K从HBM 加载到 SRAM ,因为S太大无法全留在SRAM,生成的S需完整写回HBM,占用 空间。对Q和K各读一次,对S写一次,共三次交互。

  • • Softmax归一化

公式:

操作:从 HBM读取 S到 SRAM, 逐行计算softmax后,将 P值写回到HBM

内存交互:需加载 S的分块到 SRAM,逐行计算Softmax得到 P,生成的P需完整写回HBM,再次占用 空间。对S读一次,对P 写一次,共两个交互。

  • • 加权聚合输出

公式:

操作:从HBM 分块加载P、V到 SRAM,计算后写回O到HBM

内存交互:需加载P、V的分块到SRAM进行分块计算,生成的 O 写回 HBM,占用 空间。对P和V各读一次,对O写一次,共三次操作。

2.2 稳定版 softmax 计算

解决问题:为了防止指数最大值溢出,在稳定版的softmax计算中,每个元素减去最大值 ,然后再除以所有元素的和。

softmax 公式:

最大值定义:

所有元素和:

其中:

公式(1),可改写为公式:

公式(5)中的分子为向量,分母为标量,计算过程为逐元素相除。

对于softmax不能分块计算,最大的障碍来自于公式(5)的分母,因为它需要所有元素的和

小结

标准注意力的性能瓶颈主要来自HBM的频繁读写。因为GPU片上SRAM内存容量有限,每个计算核心SM(streaming multiprocessor)更是小的可怜!无法缓存完整的中间值 S 和 P。每次计算需分块加载数据到SRAM,计算后写回HBM。

具体表现为,计算 S 和 P 需两次全量读写HBM即,读 QK -> 写S -> 读 S ->写 P -> 读 PV -> 写 O。总HBM访问量为 与序列长度平方成正比。

在计算attention时,最大程度的减少读写HBM内存,成为进一步提高模型训练和推理效率的关键。

3,flash attention 注意力计算

核心思想:通过分块计算 tiling,动态重计算和内核融合,将计算保留在SRAM中,大幅减少HBM的访问,从而突破内存墙限制。

假设输入的向量 , 将其切分为两个向量 , 两个分块向量依次计算。

  • • 分块 局部softmax 计算过程

最大值:

逐元素减去最大值:

求和:

局部softmax:

之所以称之为局部:因为公式(7)并没有减去全局的最大值,而是当前块的最大值,公式(9)的分母应该是全局的和,而不仅仅是当前块全部元素的和。

  • • 保存标量结果用于更新全局值

每个分块计算后的局最大值和局部元素和,两个标量也需要保存,即:。

定义两个全局标量,全局最大值和全局元素的和 。

此时只处理了分块1,全局标量和局部标量相同,即

  • • 分块 局部softmax 计算过程

最大值:

逐元素减去最大值:

求和:

局部softmax:

此时计算得到的softmax 也是局部的。

  • • 更新全局标量值

更新全局最大值:

更新全局元素和:

问题:为什么公式(15)中多乘上一个多项式 就更新到全局了,推导下:

此时公式(16)为全局。

分块的全局和已经减去了局部的最大值,要更新为全局的必须先加回来,然后再减去更新后的全局最大值。

以上可以当需要把局部的求和 l 更新为全局时,只需要多加一个多项式

此时通过公式(15)可以得到当前全局元素的和就不难理解了。

下面通过softmax 公式(13)直接推导,先看分子:

此时重新计算公式(13)

公式(18)中更新分块2的softmax时,不再需要分块2的原始向量元素,只需要保存下来的标量值 和当前需要更新的局部softmax值 。

同理将以上标号换成分块1,即可更新分块1的局部softmax到全局。

4,总结

标准注意力的计算受限于HBM与SRAM的频繁数据交换,尤其是中间矩阵S和P的的开销。Flash attention 通过分块计算,动态重计算和内核融合,将计算保留在SRAM中,大幅减少HBM的访问,从而突破内存墙限制。

步骤

标准注意力

Flash attention

中间矩阵存储

必须存储 到HBM

不存储S和P,只缓存部分统计量

HBM访问次数

读写S和P

不存储中间值

softmax计算

全局逐行计算,依赖HBM同步

增量式分块计算,全部在SRAM中完成

元素融合操作

多次独立读写HBM

单次分块融合 softmax,masking,Dropout

参考:

[1]Tri Dao et al. FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness [2]https://www.zhihu.com/question/611236756/answer/3132304304?utm_psn=1698839910906404864

更多精彩:

历史文章:

大模型推理-page attention 内存分页术

大模型推理-极致化的批处理策略介绍

大模型推理- PD分离部署,势在必行!

大模型推理-高效推理必备KV cache

大模型训练-混合专家系统MoE

大模型训练-Nvidia GPU 互联技术全景图

大模型训练-流水线并行PP

大模型训练-张量并行TP

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2025-05-26,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 AI老马啊 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 1, GPU内存层级结构
  • 2,标准注意力和稳定版 softmax
    • 2.1 标准注意力计算流程
    • 2.2 稳定版 softmax 计算
    • 3,flash attention 注意力计算
    • 4,总结
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档