上篇文章 flash-linear-attention中的Chunkwise并行算法的理解 根据GLA Transformer Paper(https://arxiv.org/pdf/2312.06635 作者是这位大佬 @sonta)通过对Linear Attention的完全并行和RNN以及Chunkwise形式的介绍理解了Linear Attention的Chunkwise并行算法的原理。但是paper还没有读完,后续在paper里面提出了Gated Linear Attention Transformer,它正是基于Chunkwise Linear Attention的思想来做的,不过仍有很多的工程细节需要明了。这篇文章就来继续阅读一下paper剩下的部分,把握下GLA的计算流程以及PyTorch实现。下面对Paper的第三节和第四节进行理解,由于个人感觉Paper公式有点多,所以并没有对paper进行大量直接翻译,更多的是读了一些部分之后直接大白话一点写一下我对各个部分的理解和总结。这样可能会忽略一些细节,建议读者结合原Paper阅读。
这里需要说明的是,在上篇文章里面介绍到的Chunk并行算法实际上不是GLA这篇Paper首次提出的idea。GLA这篇paper是在工程上极大改进了Chunk并行算法,使得它的效率更高。改进的细节正是paper的第三节和第四节介绍的核心内容。不过我在 https://github.com/sustcsonglin/flash-linear-attention 官方仓库以及Paper给出的GLA算法伪代码中都看到只有一次分块,不太清楚原因。此外,Paper的实验中也没有把GLA Transformer Scale Up到更大的规模,这个可能是受限于算力之类的原因,不过最近看到 https://arxiv.org/abs/2405.18428 和 https://arxiv.org/abs/2405.18425 2篇比较新的Paper都是用GLA重做了一些经典的大模型架构,所以相信它是未来可期的。
paper描述了一种名为FLASHLINEARATTENTION的算法,这是一种面向输入/输出且硬件高效的线性注意力算法,它和与FLASHATTENTION相似。这一节讨论在实际高效的实现中需要考虑的硬件方面的问题。
一个高效的算法应考虑现代硬件上的计算模型、内存层次结构和专用计算单元。
我自己总结下,这一节主要是对递归形式,并行形式,以及Chunkwise并行形式进行了再次说明,paper中提到对于递归形式来说虽然flops较低但是由于要在时间步上频繁访问HBM并且无法使用Tensor Core导致实际效率很低。而对于并行形式来说,它的效率可以做到和FLASHATTENTION一致,但是当序列长度很长时,训练成本会快速增加。最后,对于Chunk形式的并行,它可以利用上Tensor Core,但是之前提出的一些实现效率较低,比如在2k-4k序列长度下是比FLASHATTENTION更慢的。
FLASHLINEARATTENTION的Forward Pass伪代码,materialize表示是否对隐藏状态S进行重计算
FLASHLINEARATTENTION的Backward Pass伪代码,materialize表示是否对隐藏状态S进行重计算
这一节直接读paper还不是很好懂,其实讲的就是说FLASHLINEARATTENTION算法有一个materialize参数来控制是否要重计算S,然后在计算过程中无论是否要重计算S都会遵循分块加载Q,K,V到共享内存中,然后我们就可以重用共享内存上的块状Tensor来避免多次加载HBM I/O。例如,对于Algorithm1中的materialize为True的情况,当 被加载到SRAM时, 和 可以在芯片上计算,这样可以避免再次加载 (从而节省HBM I/O)。
对于materialize为False的情况(非重计算版本),算法首先在HBM中把块间递归的结果存下来(对应Paper里的方程2),然后将所有 (对所有 )都并行计算在HBM中。该方法有更好的并行性,但略微增加了内存占用。非重计算版本顺序计算 (对所有 ),并使用SRAM暂时存储 。这种策略在内存上更高效,但缺乏序列级别的并行性。然而,在后向Pass过程中重计算隐藏状态 会引入大约30%的多余FLOPs。因此,非重计算版本通常比重计算版本速度更慢,但节省了更多GPU内存。
图1展示了这两种方法。
在这里插入图片描述
这张图画得挺好的,我们可以清楚的看到对于materialize为False的情况下,Q,K,V都是从HBM中加载到SRAM,每次都会计算出一个新的隐藏状态S出来,注意这个S无需保存所以它一直存在于SRAM上,整体的计算过程是一个Sequential的。而对于materialize为True的情况,首先通过KV计算出S并将S保存到HBM中,这部分也是Sequence的。计算完S之后就可以Chunk并行的计算出。这里的箭头表示每个操作需要的操作数,和上文的公式是完全对得上的。
图2展示了FLASHLINEARATTENTION实现的速度和内存占用情况。两种版本的FLASHLINEARATTENTION都比FlashAttention-2(Dao, 2023)和纯PyTorch(即不I/O感知)实现的chunkwise线性注意力快得多,展示了I/O感知的好处。所有方法都具有线性空间复杂度。非重计算版本具有最小的内存占用,而重计算版本的内存占用略高于FlashAttention-2。
在这里插入图片描述
方程1
方程1中的线性递归没有衰减项或遗忘门,而这在RNN中已被证明是至关重要的。缺少衰减项使得模型难以“忘记”信息,这被假设为部分导致线性注意力在长上下文任务中不稳定的原因。最近的研究通过在线性注意力中加入一个全局的、与数据无关的衰减因子 获得了更好的性能:。使用单一的 旨在保持注意力样式的并行形式,以实现高效训练。在paper中,作者考虑了一种与数据相关的门控机制用于线性注意力。我们展示了尽管有一个更具表达力的门控因子,所得到的门控线性注意力(GLA)层仍然可以采用硬件高效的chunkwise方式进行高效训练。
递归形式。GLA 有一个二维遗忘门 :
方程3
其中我们使用外积来获得 以实现参数效率,其中 。在初步实验中,我们发现简单地设置 是足够的,因此我们采用了以下简化的 GLA 递归形式:
其中 是通过 sigmoid 应用于 后由低秩线性层获得的(参见paper的§4.4)。
并行形式。上述递归形式有一个等效的并行形式。通过展开方程 3 我们有
设 ,我们可以将上述公式重写为
在这里插入图片描述
其中除法是按元素进行的。设 为通过堆叠 的转置获得的矩阵,则并行形式为:
在这里插入图片描述
但是,这种形式在数值上是不稳定的,因为 是在 中累积的gate值,并且当 很大时, 的值可能非常小。为了解决这个问题,我们可以以对数形式计算 :
公式4
上面推导了与线性注意力中chunkwise形式类似的GLA chunkwise形式。对于块内的仍然是完全并行的方式,而对于块间有:
在这里插入图片描述
直观地说, 编码了从一个块的开始处的累积衰减,这将用于传播来自前一个块 的隐藏状态,而 编码了到块结束处的衰减,这将用于累积信息以添加到下一个隐藏状态 。
有了Chunkwise形式之后,我们可以将paper里面第三节提出的Forward/Backward Pass应用于适应gate的情况。这个应用还依赖下面两种关键的技术,paper这里给出更直觉的解释,具体的算法推导再附录C。
次级级别Chunk化 与普通线性注意力不同,GLA中的块内计算无法利用半精度矩阵乘法(因此无法使用Tensor Core),因为涉及对数空间计算(公式4)。为了更好地利用Tensor Core,我们采用次级级别Chunk化方案,即一个块进一步划分为子块(即,另一层次的分块)。然后以块状方式计算类似注意力的矩阵 ,如图3所示。
图3:注意力风格的图示,用于说明GLA中的块状计算。块间依赖(灰色部分)并未在块状形式中直接计算(仅在并行形式中计算)。块内依赖通过次级Chunking/Tiling建模,其中块内子块部分(橙色部分)通过半精度矩阵乘法计算,而块内子块部分(粉红色部分)在对数空间中以全精度计算。其中 表示特征索引。然而,与普通线性注意力不同,公式4不能通过标准矩阵乘法表示,并且无法在张量核心上使用半精度矩阵乘法。我们将在第4.3节展示次级级别块化机制如何在保持数值稳定性的同时,使大部分计算可以使用张量核心上的半精度矩阵乘法。
具体而言,子块之间的交互是通过半精度矩阵乘法计算的:
这对应于图3中的橙色线条。对于块内子块部分(图3中的粉红色块),我们必须使用公式4并以全精度执行矩阵乘法以确保稳定性。通过这种两级块化策略,非半精度矩阵乘法FLOPs的总量大大减少。paper在附录C的图7中提供了PyTorch风格的伪代码。
内存高效的 计算 过去的工作声称GLA类模型必须将大小为 的矩阵值隐藏状态存储在HBM中,以计算所有梯度 ,因为 。这排除了使用Katharopoulos等的重新计算技术,因为重新计算需要从头构建(即,从 开始)。我们提供以下公式的封闭形式:
在这里插入图片描述
可以通过将其对公式4取导数容易地得到(参见附录C中的全导数)。并且和可以如算法2中所编写的那样计算。
在附录C中有一段gated_linear_attention
的代码,对应了上述GLA工程实现的所有技巧。将其OCR之后得到可编辑的代码,然后找一下每行代码在上面的对应位置:
def gated_linear_attention(Q, K, V, B, C, c):
'''
Q/K/V: query/key/value
B: cumprod of gates
C/c: chunk size, subchunk size
'''
# 这里不考虑batch以及attention的头的个数,只有seq和head_dim维度
seq_len, head_dim = Q.shape
# 隐藏层S的维度为(head_dim, head_dim)
S = torch.zeros(head_dim, head_dim)
# 输出的维度,也是(seq_len, head_dim)
O = torch.empty_like(V)
# 在seq_len维度上第一次分块
for i in range(0, seq_len // C):
# 当前块的下标范围
r = range(i*C, (i+1)*C)
# (C, head_dim) chunking
# 获取当前块的Q, K, V, B,其中B是gate的cumsum
bq, bk, bv, bb = Q[r], K[r], V[r], B[r]
# b1对应GLA的Chunkwise形式中的b_{iC}
b1 = B[i*C-1] if i > 0 else 1
# b2对应GLA的Chunkwise形式中的b_{(i+1)C}
b2 = bb[-1,None]
# inter-chunk w/ matmul
# q对应了GLA的Chunkwise形式中$Q_{i} \odot Λ_{[iC+j]}=b_{iC+j}/b_{iC}$
# k对应了GLA的Chunkwise形式中$K_{i} \odot \frac{b_{(i+1)C}}{b_{iC+j}}$
# g对应了GLA的Chunkwise形式中$\gamma_{i}=\frac{b_{(i+1)C}}{b_{iC}}$
q, k, g = bq*bb/b1, bk*b2/bb, b2/b1
# 对应了GLA的Chunkwise形式中计算块内的$O_{intra}=q @ S$
o = q @ S
# hidden state update
# 对应了GLA的Chunkwise形式中的隐藏层更新
S = g.t() * S + k.t() @ bv
# intra-chunk (secondary chunking)
# 计算第一次分块块内部输出的时候进行第二次分块
for j in range(0, C // c):
# 第二次分块中当前子块的下标范围
t = range(j*c, (j+1)*c)
#(c, head_dim) subchunking
# 获取当前子块的q, k, v, b
q, k, v, b = bq[t], bk[t], bv[t], bb[t]
# 计算当前子块的注意力矩阵p
p = torch.zeros(c, c)
# intra-subchunk w/o matmul.
# 子块内部的注意力矩阵p计算,无法使用矩阵乘法
for m in range(c):
for n in range(m+1):
p[m,n] = torch.sum(
q[m]*k[n]*(b[m]/b[n]))
o[t] += p @ v
# inter-subchunk w/ matmul
# 子块间的注意力矩阵p计算,可以用矩阵乘法
z = b[0, None]
q = q * b / z
for u in range(0, j):
y = range(u*c, (u+1)*c)
p = q @ (bk[y]*z/bb[y]).t()
o[t] += p @ bv[y]
O[r] = o
return O
需要对其中子块代码进行说明,下面这段代码对应了GLA递归形式中的这个公式:
在这里插入图片描述
for m in range(c):
for n in range(m+1):
p[m,n] = torch.sum(
q[m]*k[n]*(b[m]/b[n]))
可以看到这里是直接计算P的,没有考虑数值稳定性而使用公式(4),这和paper的描述似乎是不想符的。
子块之间的交互是通过半精度矩阵乘法计算的,公式如下:
代码对应:
z = b[0, None] # 相当于$b_{iC}$
# 对应了上面公式中的$Q_{i} \odot Λ_{i}=b_{iC+j}/b_{iC}$
q = q * b / z
# 遍历截止到当前子块之前的所有子块
for u in range(0, j):
# 取出当前子块之前所有子块的索引
y = range(u*c, (u+1)*c)
# 对应了上面公式的$K[j] \odot \Gamma[j] \odot \frac{b_{iC}}{b_{(j+1)C}} $,这里有代数化简
p = q @ (bk[y]*z/bb[y]).t()
o[t] += p @ bv[y]
我们需要把展开并和它化简之后才能得到p的计算代码,因为抵消了一个。
这里个人有个疑问就是附录里面的GLA伪代码算法描述是不包含二次分块的:
在这里插入图片描述
在官方代码实现中似乎也没有见到二级分块,是二级分块在工程实现中发现效果一般么?
paper在4.4节对GLA Transformer的一层的详细结构进行了介绍,paper中通过标准神经网络模块将GLA层推广到多头。给定个头,对于每一个头有如下的公式,其中。
在这里插入图片描述
在这里插入图片描述
这里不仅仅是以单个注意力头来描述公式,也忽略了Batch和Seq维度,实际训练的时候是有这两个维度的。
后面实验部分就是一些常规的东西了,说明GLA Transformer在训练上高效并且可以达到较好的性能,这里就不做冗余介绍了。
这篇文章主要是对GLA Transformer这篇Paper进行了阅读,进一步学习Chunkwise Linear Attention的思想以及GLA特殊的2级分块Chunkwise并行。不过我在 https://github.com/sustcsonglin/flash-linear-attention 官方仓库以及Paper给出的GLA算法伪代码中都看到只有一次分块,不太清楚原因。此外,Paper的实验中也没有把GLA Transformer Scale Up到更大的规模,这个可能是受限于算力之类的原因,不过最近看到 https://arxiv.org/abs/2405.18428 和 https://arxiv.org/abs/2405.18425 2篇比较新的Paper都是用GLA重做了一些经典的大模型架构,所以相信它是未来可期的。
本文分享自 GiantPandaCV 微信公众号,前往查看
如有侵权,请联系 cloudcommunity@tencent.com 删除。
本文参与 腾讯云自媒体同步曝光计划 ,欢迎热爱写作的你一起参与!