首页
学习
活动
专区
圈层
工具
发布
社区首页 >专栏 >梳理LLM中的设备长文本问题

梳理LLM中的设备长文本问题

原创
作者头像
曾高飞
发布2025-06-11 20:50:11
发布2025-06-11 20:50:11
3620
举报

近期,随着大模型技术的发展,长文本问题逐渐成为热门且关键的问题,不妨简单梳理一下近期出现的典型的长文本模型:

  • 10 月上旬,Moonshot AI 的 Kimi Chat 问世,这是首个支持 20 万汉字输入的智能助手产品;
  • 10 月下旬,百川智能发布 Baichuan2-192K 长窗口大模型,相当于一次处理约35 万个汉字;
  • 11 月上旬,OpenAI 发布支持 128K 上下文窗口的 GPT-4 Turbo 模型;
  • 11 月下旬,Anthropic 发布支持 200K 上下文窗口的 Claude 2.1 模型;
  • 12 月上旬,零一万物开源了长文本模型 Yi-6B-200K和 Yi-34B-200K。

实际上,随着文本长度的提高,模型能够处理问题的边界也大大提高,因此研究并解决长文本问题就显得非常必要。本文将从长文本问题的本质出发,逐步分析和研究长文本实现的问题及解决办法。

一、长文本的核心问题与解决方向

1.1 文本长度与显存及计算量之关系

要研究清楚长文本的问题,首先应该搞清楚文本长度在模型中的地位与影响。那么我们便以 Decoder-base 的模型为例来进行分析

1.1.1 模型参数量

Decoder-base 的模型主要包括 3 个部分:embedding, decoder-layer, head。

其中最主要部分是decoder-layer,其由 lll 个层组成,每个层又分为两部分:self-attention 和 MLP。

self-attention的模型参数有、、的权重矩阵 、、及bias,输出矩阵 及bias,4个权重矩阵的形状为 ( 表示 hidden_size),4个bias的形状为 。则 self- attention 的参数量为 。

MLP由2个线性层组成,一般地,第一个线性层是先将维度从 映射到 ,第二个线性层再将维度从映射到。第一个线性层的权重矩阵 的形状为 ,偏置的形状为 。第二个线性层权重矩阵 的形状为 ,偏置形状为 。则 MLP 的参数量为 。

self-attention 和MLP各有一个layer normalization,包含了2个可训练模型参数:缩放参数γ和平移参数β,形状都是。2个layer normalization的参数量为 。

由此,每个Decoder层的参数量为。

此外,embeddinghead 的参数量相同,与词表相关,为(如果是 Tied embedding,则二者共用同一个参数)。由于位置编码多样,且参数量小,故忽略此部分。

综上, 层模型的可训练模型参数量为 。当 较大时,可以忽略一次项,模型参数量近似为。

1.1.2 计算量估计

如果说参数量是模型的固有属性,那么计算量便是由模型和输入共同决定,下面分析这一过程。假设输入数据的形状为 ( 表示batch_size,表示sequence_length)。

先分析Decoder中self-attention的计算量,计算公式如下:

3da4cb6b197ed2fd6e5b990614db04c3.png
3da4cb6b197ed2fd6e5b990614db04c3.png
  1. 计算:矩阵乘法的输入和输出形状为。计算量为。
  2. 矩阵乘法的输入和输出形状为
7c302070cccc844acf6d8763c2391ac1.png
7c302070cccc844acf6d8763c2391ac1.png

计算量为。

  1. 计算在上的加权,矩阵乘法的输入和输出形状为
2c1927e069230efeb4675cb790521c28.png
2c1927e069230efeb4675cb790521c28.png

计算量为 。

  1. attention后的线性映射,矩阵乘法的输入和输出形状为.计算量为。

接下来分析MLP块的计算,计算公式如下:

f7d9576c25ccc88b678e8b316258996f.png
f7d9576c25ccc88b678e8b316258996f.png
  1. 第一个线性层,矩阵乘法的输入和输出形状为 。计算量为 。
  2. 第二个线性层,矩阵乘法的输入和输出形状为 。计算量为。

将上述计算量相加,得到每个Decoder层的计算量大约为。

此外,另一个计算量的大头是logits的计算,将隐藏向量映射为词表大小。矩阵乘法的输入和输出形状为,计算量为。

因此,对于一个 lll 层的模型,输入数据形状为的情况下,一次前向计算的计算量为。

1.1.3 文本长度与计算量、参数量、显存的关系

忽略低次项,一次输入的tokens数为bs, 则计算量与参数量的关系为 在实际中通常 ,因此该项可近似认为约等于2。即在一次前向传递中,对于每个token,每个模型参数,需要进行2次浮点数运算(一次乘法法运算和一次加法运算)。考虑到后向传递的计算量是前向传递的2倍。因此一次训练迭代中,对于每个 token,每个模型参数,需要进行 次浮点数运算。

通过以上分析,我们可以得到结论:计算量主要和模型参数和 token 数相关,文本长度并不会显著增加计算量。那么这就引出另一个问题:文本长度与显存的关系。

除了模型参数、梯度、优化器状态外,占用显存的大头就是前向传递过程中计算得到的中间激活值。这里的激活(activations)指的是:前向传递过程中计算得到的,并在后向传递过程中需要用到的所有张量。

先分析 Decoder layer 中 self-attention 的中间激活:

  1. 对于 ,需要保存它们共同的输入 ,这就是中间激活。输入 的形状为 ,元素个数为 ,占用显存大小为 。
  2. 对于 矩阵乘法,需要保存中间激活 ,两个张量的形状都是 ,占用显存大小合计为 。
  3. 对于 函数,需要保存函数的输入 ,占用显存大小为 ,这里的 表示注意力头数。
  4. 计算完 函数后,会进行dropout操作。需要保存一个mask矩阵,mask矩阵的形状与 相同,占用显存大小为 。
  5. 计算在 上的attention,即 ,需要保存 score ,大小为 ;以及 ,大小为 。二者占用显存大小合计为 。
  6. 计算输出映射以及一个dropout操作。输入映射需要保存其输入,大小为 ;dropout需要保存mask矩阵,大小为 。二者占用显存大小合计为 。

因此,将上述中间激活相加得到,self-attention的中间激活占用显存大小为 。接下来分析分析Decoder layer中MLP的中间激活:

  1. 第一个线性层需要保存其输入,占用显存大小为 。
  2. 激活函数需要保存其输入,占用显存大小为 。
  3. 第二个线性层需要保存其输入,占用显存大小为 。
  4. 最后有一个dropout操作,需要保存mask矩阵,占用显存大小为 。

对于MLP块,需要保存的中间激活值为 。

另外,self-attention块和MLP块分别对应了一个layer normalization。每个layer norm需要保存其输入,大小为 。2个layer norm需要保存的中间激活为 。

综上,每个层需要保存的中间激活占用显存大小为 。对于 层transformer模型,还有embedding层、最后的输出层。embedding层不需要中间激活。总的而言,当隐藏维度 比较大,层数 较深时,这部分的中间激活是很少的,可以忽略。因此,对于 层模型,中间激活占用的显存大小可以近似为 ,这个结果与文本长度关系密切。

下面以GPT3-175B为例,对比下文本长度对模型参数与中间激活的显存大小的影响。假设数据类型为 FP16 。

模型名

参数量

层数

隐藏维度

注意力头数

GPT3

175B

96

12288

96

GPT3的模型参数量为175B,占用的显存大小为 。GPT3 模型需要占用350GB的显存。

假设 GPT3 输入的 。对比不同的文本长度下占用的中间激活:

当 时,中间激活占用显存为

,大约是模型参数显存的0.79倍;

当 时,中间激活占用显存为

,大约是模型参数显存的2.68倍。

可以看到长度仅仅到 4K,显存占用就出现了剧烈增加,同时 GPU onchip 的 memory 就显得更加捉襟见肘(因此也就出现了 FlashAttention 这类算法)。因此如何解决长文本带来的巨量显存开销成为关键及核心问题。

1.2 长文本问题的解决思路

当前,为了实现更长长文本的支持,解决思路主要可以分为两个阶段:

  • 阶段一:在预训练阶段尽可能支持更长的文本长度 为实现这一阶段目标,通常采用并行化 (parallelism) 方法将显存占用分摊到多个 device,或者改造 attention 结构,避免显存占用与文本长度成二次关系。
  • 阶段二:在 SFT 或推理阶段尽可能外推到更大长度 为实现这一阶段目标,通常也是需要在两个方面进行考虑:
代码语言:actionscript
复制
    - 对位置编码进行外推
    - 优化 Attention 机制

本文接下来的部分将尽可能详细深入地进行这些问题的研究。为了便于理解和接受,下文将从易到难,先介绍第二阶段的技术,然后再介绍第一阶段(同时也是考虑到直接使用开源模型者,不需要第一阶段的情况)。

二、长文本与位置编码

在 Transformer 结构的模型中,Attention模块的值与顺序无关,因此需要加入位置编码以确定不同位置的 token。典型的位置编码方式有两类:

绝对位置编码:即将位置信息融入到输入中

相对位置编码:微调Attention结构,使其能够分辨不同位置的Token

随着文本长度的增加,位置编码也会发生相应的变化,因此处理好位置编码问题是解决长文本问题的重要环节。

2.1 绝对位置编码及其外推

一般来说,绝对位置编码会加到输入中:在输入的第 个输入向量 中加入位置向量 得到 ,其中 仅依赖于位置 。

如下图所示,以二维向量为例来形象说明,图左中黑色剪头为输入向量 ,蓝色箭头为位置向量 (不同方法的长度与角度不同),其相加的结果为绿色箭头。在 Attention 结构中,

即相当于同时对输入向量 和位置向量 进行线性变换,那么 attention 值则是 和 的点积,如下图右中 和 的注意力即黄色夹角区域。

c789cd239f476ff311fa6937b8acbf04.png
c789cd239f476ff311fa6937b8acbf04.png

由于绝对位置编码由两部分组成,且两部分相互独立,因此无法计算相对距离。下面介绍几种典型的绝对位置编码:

2.1.1 训练式编码

这种方式最为简单直接,即把位置当做词表一样,训练一个 位置向量矩阵。这种训练式的绝对位置编码,一般的认为它没有外推性,但是苏剑林大神提出过一个层次分解的拓展方法。

f8ece75007bc6421c223dc5bbd5a8669.png
f8ece75007bc6421c223dc5bbd5a8669.png

假设已经训练好的绝对位置编码向量为 ,希望能在此基础上构造一套新的编码向量 ,其中 。为此,设

021736c968f29572e29856457f5d2f55.png
021736c968f29572e29856457f5d2f55.png

其中超参 , 是该套位置编码的“基底”。为了保障 ,这样就能反推出各个 :

12009a33200f375db2413da3d7a321c8.png
12009a33200f375db2413da3d7a321c8.png

这样就最大可以表示出 个位置的编码,并且前 个位置编码跟原来模型是相容的。下图反映了经过finetune其准确率在延长的位置编码在MLM任务上是行之有效的。

a064b3dbbf7445c8ae1aa011da2dd3b0.png
a064b3dbbf7445c8ae1aa011da2dd3b0.png

需要说明的是,这种矩阵式的位置编码方式在当前的大模型中已经比较少采用了,仅有 GPT2 等早期模型中采用了这种方式。

2.1.2 Sinusoidal 位置编码

这种方案也是Attention Is All You Need 中提出的方法

58d1e4b3c59c6e10d72ed9d8f34f1d12.png
58d1e4b3c59c6e10d72ed9d8f34f1d12.png

其中 分别是位置 的编码向量的第 个分量, 是位置向量的维度。根据以上定义,我们可以非常简单计算得到Sinusoidal位置编码的值,并绘制图像研究其规律。计算及绘图代码如下所示:

代码语言:javascript
复制
import numpy as npimport matplotlib.pyplot as plt def getPositionEncoding(seq_len, d, n=10000):    P = np.zeros((seq_len, d))    for k in range(seq_len):        for i in np.arange(int(d/2)):            denominator = np.power(n, 2*i/d)            P[k, 2*i] = np.sin(k/denominator)            P[k, 2*i+1] = np.cos(k/denominator)    return P  P = getPositionEncoding(seq_len=100, d=512, n=10000)cax = plt.matshow(P)plt.title('Sinusoidal Positional Embeddings')plt.xlabel('Dimension')plt.ylabel('Position')plt.gcf().colorbar(cax)AI写代码go运行

整体位置编码如下图所示:

93ed26cfa830a046de69a9357ea6439d.png
93ed26cfa830a046de69a9357ea6439d.png

首先研究 Sinusoidal 位置编码与位置之间的关系,绘制不同位置下,函数值与 sin 维度的关系

代码语言:javascript
复制
def plotSinusoid(k, d=512, n=10000):    x = np.arange(0, 256, 1)    denominator = np.power(n, 2*x/d)    y = np.sin(k/denominator)    plt.plot(x, y)    plt.title('k = ' + str(k))    plt.xlabel('Dimension') fig = plt.figure(figsize=(15, 4))    for i in range(4):    plt.subplot(141 + i)    plotSinusoid(i*4)AI写代码go运行

其曲线如下图所示, 可以从图中得到几点结论:

  • 位置越远,频率越大
  • 随着维度增大,函数逐渐收敛到0(cos函数收敛到1)
fe8e2686d5611f9edde7b7b78f629b12.png
fe8e2686d5611f9edde7b7b78f629b12.png

研究 Sinusoidal 位置编码与维度分量之间的关系,绘制不同维度分量 i 下,函数值与位置的的关系

代码语言:javascript
复制
def plotSinusoid(x, d=512, n=10000):    k = np.arange(0, 100, 1)    denominator = np.power(n, 2*x/d)    y = np.sin(k/denominator)    plt.plot(k, y)    plt.title('i = ' + str(x))    plt.xlabel('Position') fig = plt.figure(figsize=(15, 4))    for i in range(4):    plt.subplot(141 + i)    plotSinusoid(i*10)AI写代码go运行

Sinusoidal位置编码与维度分量的关系如下图所示,可以发现结论如下:

  • 每个分量都具有周期性,是正弦或余弦函数
  • 越靠后的分量(i 越大),波长越长,频率越低
8cc03629291e086e2ab5d3f248c3ee82.png
8cc03629291e086e2ab5d3f248c3ee82.png

了解了这些基本的特性后,接下来就需要讨论更加深层次的问题:

问题一:为什么用包含各频率的正弦和余弦对?

位置编码存储的是一个包含各频率的正弦和余弦对,这样做有两个好处:

  • 可以使得不同位置的编码向量之间有一定的规律性,比如相邻位置之间的差异较小,而距离较远的位置之间的差异较大。这是由正弦和余弦函数的连续性和单调性保证的,即对于任意两个相邻的位置,它们对应的编码向量在每一个维度上都只有微小的变化,而对于任意两个距离较远的位置,它们对应的编码向量在每一个维度上都有较大的差异。
  • 可以使得编码向量在任意维度上都能保持唯一性,即不同位置在同一个维度上不会有相同的值。这是由正弦和余弦函数的周期性和相位差保证的,即对于任意两个不同的位置,它们对应的编码向量在每一个维度上都不相等。

问题二:底数对结果的影响是什么?

底数越大,位置向量能表示的序列就越长,这是大底数的好处。但是,底数大,意味着在-1到+1的范围内向量的取值越密集,造成两个位置的向量距离越近,这对后续的Self-Attention模块来说是不利的,因为它需要经历更多的训练次数才能准确地找到每个位置的信息,或者说,才能准确地区分不同的位置。长序列需要长编码。但这样又会增加计算量,特别是长编码会影响模型的训练时间。所以,那个底数并非是越大越好。

问题三:Sinusoidal 位置编码如何外推

三角函数式位置编码的特点是有显式的生成规律,因此可以期望于它有一定的外推性。另外一个使用它的理由是:由于

8faeeb20b501bd6fd8df7628697854bb.png
8faeeb20b501bd6fd8df7628697854bb.png

这表明位置 的向量可以表示成位置 和位置 的向量组合,这提供了位置拓展的可能性。

2.1.3 其他的绝对位置编码

如递归式(如 FLOATER)和相乘式(如PENG Bo:中文语言模型研究:(1) 乘性位置编码),因使用较少,在此不予赘述。

2.2 相对位置编码及其外推

相对位置并没有完整建模每个输入的位置信息,而是在算Attention的时候考虑当前位置与被Attention的位置的相对距离,由于自然语言一般更依赖于相对位置,所以相对位置编码通常也有着更好的表现,灵活性也更大。

2.2.1 旋转位置编码 RoPE

实际上 RoPE 的诸多思想来源于 Sinusoidal 位置编码,区别在于 Sinusoidal 位置编码采用和 word embedding 相加的形式,RoPE 则采用了矩阵相乘的形式。

在正式介绍之前,我们需要回顾一下经典的欧拉公式

5150d6e79834c36874cb619bc9ece6e5.png
5150d6e79834c36874cb619bc9ece6e5.png

其矩阵形式为 即旋转矩阵,这三种表现形式表达了同样的信息,即将二维向量逆时针旋转角度 。

66ec4acaa920e85969064655e8707477.png
66ec4acaa920e85969064655e8707477.png

接下来我们直接看 RoPE 的表达式,对于位置为 m 的 q 向量,其表达式为

ce973636148dbdeed44920a36d590fb6.png
ce973636148dbdeed44920a36d590fb6.png

即逆时针旋转了度,如上图右所示。同理,位置为的 k 向量的表达式为

99cc837253bff67220c9f6a678b8afb8.png
99cc837253bff67220c9f6a678b8afb8.png

那么便可以通过点积,计算二者的 attention 值

190854373c5b1a8339e8242f7a33e858.png
190854373c5b1a8339e8242f7a33e858.png

即证明了相对位置关系,即旋转前的 attention 值与旋转后的 attention 值的差值仅与相对位置有关。这一点也可以从上图右中看出来,即旋转前的夹角(橙色区域) 与旋转后的夹角(黄色区域) 相同,即内积也相同。

这时我们就可以写出位置为的q的完整的变换矩阵,即

4fe2fb3ab93923d524396e32c4ac90fa.png
4fe2fb3ab93923d524396e32c4ac90fa.png

从改变换矩阵也能看出,随着维度增加,旋转角度也在指数级减小,如下图所示。RoPE 的这一功能使模型可以通过从低维度到更高维度,将嵌入中编码的信息类型从低频(close)转变为高频(far)。

2.2.2 远程衰减问题

由于 RoPE 中的 attention 值除了身外,仅和因子相关,那么下面考察因子的特点

5844fc5a0347519fcdd59f438548e5af.png
5844fc5a0347519fcdd59f438548e5af.png

那么问题就变成了积分的渐进估计问题,通过一下函数计算积分值与位置距离的关系,并分析不同 base 值的影响。

代码语言:javascript
复制
from scipy.integrate import quadimport numpy as npimport matplotlib.pyplot as plt def integrand(t, dis, base=10000):    return np.exp(1j * dis * base**(-t)) def plot():    x = np.arange(0, 100, 0.1)    base_list = [1, 100, 1000, 10000, 100000]    y = np.zeros((len(base_list),len(x)))    for b in range(len(base_list)):        n = base_list[b]        for i in range(len(x)):            res, err=quad(integrand, 0, 1, args=(x[i],n))            y[b][i]=res                plt.plot(x, y[0], 'g', label='base='+str(base_list[0]))    plt.plot(x, y[1], 'r', label='base='+str(base_list[1]))    plt.plot(x, y[2], 'b', label='base='+str(base_list[2]))    plt.plot(x, y[3], 'k', label='base='+str(base_list[3]))    plt.plot(x, y[4], 'c', label='base='+str(base_list[4]))        plt.xlabel('Distance')    plt.ylabel('Value')    plt.legend()    plt.show() plot()AI写代码go运行

下图展示了不同距离尺度上不同 base 值的积分结果,可以得到以下结论:

  • 除了 base=1 外,均有明显的远程衰减特性
  • base 越小,衰减得越快且幅度也更大
  • base 越大,衰减得越慢且幅度也越小
2.2.3 RoPE 长度的内插与外推

长度外推性是一个训练和预测的长度不一致的问题。提现有两点:

  • 预测的时候用到了没训练过的位置编码(不管绝对还是相对);
  • 预测的时候注意力机制所处理的token数量远超训练时的数量。 一旦我们在模型中有效地整合了相对位置信息,增加 LLM 上下文窗口的最直接方法就是通过位置插值 (position interpolation,PI) 进行微调。

这种方法实现很简单,如果希望将预训练阶段的位置向量范围外推到,只需要将对应位置缩放到原先支持的区间()内:计算公式如下,L为原先支持的长度(如2048),为需要扩展的长度(如4096):

fb57bd4c4da01bbfdaf099a33e1ee9c2.png
fb57bd4c4da01bbfdaf099a33e1ee9c2.png

其过程如下图所示:

3a2521535f48766aabf489cdb1d8ca36.png
3a2521535f48766aabf489cdb1d8ca36.png
bacfa89f50dc51d17fe0e37269cf618e.png
bacfa89f50dc51d17fe0e37269cf618e.png

下面分析一下以上操作的本质,经过这种放缩操作后,位置为 的维度为的旋转角变为,即线性减小了旋转弧度,如下图第一列的上图所示(横轴为位置编码,纵轴为旋转弧度)。通过这种方式插值后,向量旋转速度变慢,周期变大,频率变慢。 除了上述的这种差值方式外,还有以下改进方式可以实现外推:

  • NTK-aware (Neural Tangent Kernel)

这种方式把旋转角修改为 ,其中表示 basebasebase的缩放因子,在codellama中取值为100 。其修改的方式如下图第二列下图所示(横轴为维度,纵轴为旋转角),在不同维度上修改的程度不同。这种方式保留了高频信息,即高频分量旋转速度降幅低,低频分量旋转速度降幅高;在高频部分进行外推,低频部分进行内插。这是因为靠前的维度,在训练中见过非常多完整的旋转周期,位置信息得到了充分的训练,所以具有较强的外推能力。靠后的维度,在训练中无法见到完整的旋转周期,或者见到的旋转周期非常少,训练不够充分,外推性能弱,需要进行位置插值。

  • NTK-by-parts

该方法是基于 NTK-Aware 的优化,其核心思想是:不改变高频部分,仅缩小低频部分的旋转弧度。即不改变小维度的旋转弧度,仅减小大维度的旋转弧度,这就是by-patrs的含义。

第个维度的旋转周期为:

其在训练长度内旋转的周期个数如下:

引入超参数,表示旋转周期个数的约束条件,

当,旋转周期数量足够多,则认为该维度为高频部分,无需改变。

当,旋转周期数量少,则为低频分组,进行Position Interpolation。

  • Dynamic NTK

这是是一种动态插值的方法:当推理长度小于等于训练长度时,不进行插值;推理长度大于训练长度时,每一步都通过NTK-Aware Interpolation动态放大base。

表示当前的序列长度,表示模型训练长度,

当时,不调整旋转角

当时,旋转角调整为,其中需要说明的是下图最后一列下图的粗线表示一个范围,未体现出与长度的动态联动。

1138044b887f55c1f2813a013e02b6e9.png
1138044b887f55c1f2813a013e02b6e9.png

需要说明的是,论文Scaling Laws of RoPE-based Extrapolation中深入研究了 RoPE 位置编码的特性,其结论就是:RoPE 中 base 的放大和缩小都能获得很好的外推效果(base=10K 效果最差)。原因在于:

  • 当 base 较小时(如 500),RoPE 的三角函数周期变短,训练时就可以见过完整的值域;
  • 当 base 较大时(如 1000000),RoPE 的三角函数周期变长,训练时虽然不能见过完整的值域,但是外推时仍处于单调区间。
2.2.4 其他形式的编码方式及其外推

在苏神的文章Transformer升级之路:12、无限外推的ReRoPE中指出:RoPE 形式上是一种绝对位置编码,但实际上给 Attention 带来的是相对位置信息,即如下的Toeplitz矩阵

82952fac915fac52f79e9560eba8e3d4.png
82952fac915fac52f79e9560eba8e3d4.png

这么这种形式的 bias 似乎有种似曾相识的感觉,没错,就是 ALiBi 编码。严格来说,ALiBi并不算位置编码,因为它并没有作用在 embedding 上,而是直接作用在了 Attention 上,通过这种构造方式既实现了远程衰减,又实现了位置的相对关系。

b7af443da6672c2aab9bf3e6b71c1b6e.png
b7af443da6672c2aab9bf3e6b71c1b6e.png

对于外推特性,ALiBi 与前文所述的方法也是不同的,体现在:

  • 事后修改,比如NTK-RoPE、YaRN、ReRoPE等,这类方法的特点是直接修改推理模型,无需微调就能达到一定的长度外推效果,但缺点是它们都无法保持模型在训练长度内的恒等性
  • 事前修改,如ALIBI、KERPLE、XPOS以及HWFA等,它们可以不加改动地实现一定的长度外推,但相应的改动需要在训练之前就引入,因此无法不微调地用于现成模型

三、长文本与 Attention 机制

Attention 机制也是制约长文本实现的重要因素,以下是几种典型的 Attention 的 方式:

15821db5e80e7c464b48c0308a1a70a6.png
15821db5e80e7c464b48c0308a1a70a6.png

关于 Attention 机制改进的更多类型和细节,笔者在之前的文章中已经有所讨论,可看历史文章。

在此主要想介绍一个方案 —— LongLora。

回顾第一节中研究的结论,长文本影响最大的就是 self-attention 中的,随长度二次变化的显存占用和计算复杂度。为解决这个问题,LongLora 的原则是,虽然在推理过程中需要密集的全局注意力,但通过稀疏的局部注意力可以有效且高效地微调模型。

LongLora 在微调期间延长上下文长度,同时使用 Lora 方法保持高性能和低复杂性。其中最关键的是提出了转移短注意力(S2-Attn)方案。下面简要介绍这一方案:

S2-Attn 在微调阶段使用局部注意力而不是全局注意力。即将输入文档分解为几个不同的组,并在每个组中分别应用注意力机制(Pattern 1)。尽管这种方式能够在资源占用不多的情况下拓展长度,由于不同组之间缺乏信息交换,随着上下文长度的增加,会导致混乱增加。

为了解决上述问题,S2-Attn 引入了组大小一半的移位操作,确保相邻组之间顺利的信息交换(Pattern 2)。这种做法有助于模型在文本开头和结尾之间顺利交换信息,从而提高模型稳定性。

89d9a94a45b0b3c9bb22464a1d0932af.png
89d9a94a45b0b3c9bb22464a1d0932af.png

而本文提出的 shift short attention 有一半的 head 会被做 shift,如下图所示,然后每个 group 内作 self-attention,从而使信息可以在不同 group 间传递。这种做法实际上将 Pattern 1 和 Pattern 2 结合起来,而没有引入额外的计算开销,使其非常适合高效处理长序列文本。

ef4e78a9ff416d464baf25f5b504c1cb.png
ef4e78a9ff416d464baf25f5b504c1cb.png

此外,LongLoRA相比于Lora还可以微调embedding层和normalization层。尽管这两项内容占的参数量很小(以Llama 2-7B为例,embedding层只占1.94%,normalization层更是不到十万分之四),对结果也起到了重要作用。

四、长文本的预训练方法

上两节主要介绍了如何在位置编码和 attention 机制方面进行文本长度的有效拓展,这两个方面都是“经济适用性”的,即只需要简单微调或者直接外推即可,接下来将是最困难,也是成本最高的部分,即讨论如何在预训练阶段提高文本长度。

为解决预训练过程中的长文本问题,思路主要有以下几个方面:

  • 并行化计算,典型方法如 sequence parallelism
  • 优化 attention 机制,典型方法如 Transformer-XL, Longformer
  • 引入 memory 机制,典型方法如 Focused Transformer, Memorizing Transformer
  • 引入采样机制,典型方法如 Hierarchical transformers, Dynamic-Pooling Transformer

由于笔者精力有限,下面仅选取其中部分方法加以介绍。

4.1 序列并行(sequence parallel)

在并行化算法大行其道的今天,使用改思想来解决长文本问题变得自然而言,实际上 SP 已逐渐成为 3D(DP, PP, TP)并行之外的第 4 个维度了。简单来说,SP 就是将一段完整的文本拆分到多个设备上进行计算,设备在适当的时候进行通信和信息交互,如下图(c) 所示。

37369f9bbecc8bd146c3b06d388eaa89.png
37369f9bbecc8bd146c3b06d388eaa89.png

在实现层面,借鉴了 Ring-Allreduce 的思想,将输入序列分割成多个块,并将每个块输入到其相应的二手设备中。

为了计算注意力输出,将环状通信与自注意力计算相结合,实现了环自注意力(RSA),如下图所示。

e704a06277d140cb36d30540e41dbbda.png
e704a06277d140cb36d30540e41dbbda.png
3df16bb831f32abc3e91f69a88a8b630.png
3df16bb831f32abc3e91f69a88a8b630.png

下面我们来深度理解一下这个过程,论文中的符号表示为

  • B: batch size
  • L: sequence length
  • H: hidden size of linear layers
  • A: attention head size
  • Z: number of attention heads
  • N: number of GPUs

对于:

be358c4ce9024e8267dcf4436af3fb8d.png
be358c4ce9024e8267dcf4436af3fb8d.png

,切分后的每一个小块,在和矩阵乘后得到需要在切分维度做补全才可以得到每一个小块的完整结果,在后续进行操作。故需要做 ring操作达到一个concat的操作。

对于:

8f341c892834b4185a9bef667fd4dea4.png
8f341c892834b4185a9bef667fd4dea4.png

需要完整对结果进行value查询,需要对根据序列并行度进行分块,用对应的块在对应的value上查询并求和,故也需要做ring操作。

843bdaf1be63aefa05644461c7c47ea4.png
843bdaf1be63aefa05644461c7c47ea4.png

MLP 部分的计算就更简单了,如下所示:

9e816ff11e3bbe72fafaf29e92695587.png
9e816ff11e3bbe72fafaf29e92695587.png
4.2 LongLLaMA (Focused Transformer)

LongLLaMA 通过引入Focused Transformer(FOT)方法,在保持性能的同时,将 LLaMA 的上下文长度扩展到100k!在长文本的情况下,除了第一节所研究的显存和计算量的问题外,这篇论文还提出了一个分心问题(Distraction Issue),即随着文本长度的增加,其中相关的 tokens 对不相关 tokens 的比例会减少,从而导致与不相关 value 相关的 key 和与相关 value相关的 key 发生重叠,致使模型需要额外区分不同语义的 key 。

为此文章提出了Focused Transformer(FOT)解决方案,其中主要使用了 Memory Attention Layers 以及 CrossBatch 技术,在 Inference 的过程中,绿色的 Memory Attention Layers 使用 kNN 对外部的 Memory 进行查询,从而有效延长了上下文长度,而 Memory Attention Layers 则主要使用 CrossBatch 进行训练。

27b9bde5a8d4ecc5ef9e4082fc2b831d.png
27b9bde5a8d4ecc5ef9e4082fc2b831d.png

具体而言,Memory Attention Layers 中的每个 query 在 却符合中会关注局部的上下文以及 Memory 中使用 kNN 计算出的最匹配的 个key,而整个 Memory 则根据 之前处理的 key,value 进行填充。而 CrossBatch 则期望使得 Memory Attention Layers 更加关注长文本之中的“相关 value 的 key” ,CrossBatch 的处理借鉴了对比学习的思想,以相关文档之中的 d-1 个上下文作为正样本,以不相关文档之中的 d-1 个上下文作为负样本,通过对比学习的方式使得 Memory Attention Layers 可以更好的分辨相关与无关的 key-value。

与标准的 Transformer 相比,一般的 Transformer 的训练过程中,相关与不相关文档没有被得到有效区分(正负样本分散均匀),当文档数量扩展时,注意力变得越来越分散,而 Focused Transformer 则通过 CrossBatch 的训练目标使得模型有效的关注与区分的长文本下的相关与无关的 key-value 的空间结构,从而解决了分心的问题。

五、长文本的效果评估

5.1 PPL

对于 LLM 来说,其效果通常是通过生成连贯且上下文相关的文本的能力来衡量的。为了量化和衡量这一指标,困惑度 (Perplexity, PPL) 便成了最常见的指标。

PPL 是一种衡量标准,反映模型根据前面的上下文预测下一个单词的能力。PPL 分数越低,模型准确预测下一个单词的能力就越好。

PPL 是使用平均交叉熵计算的,而平均交叉熵又是使用数据集中的单词数量和根据前面的上下文预测的单词(目标单词)的概率来计算的。前面的上下文通常由目标单词之前的固定长度单词序列表示。其公式如下:其中H是平均交叉熵,

PPL 作为一种客观的评估指标被广泛用来进行 LLM 的评估。但是其也存在一些问题和不足:

  • 模型词汇量可能会不公平地影响PPL:PPL 在很大程度上依赖于模型的词汇量及其概括未见过的单词的能力。如果模型遇到训练数据中不存在的单词或短语,即使生成的文本有意义,其 PPL 分数也较高。
  • 缺乏主观性考虑:PPL 是一种客观指标,不考虑主观因素,例如风格、创造力或特定环境下的适当性。
  • 上下文理解:PPL 主要关注于根据前面的上下文预测下一个单词。然而,它可能无法捕捉模型对更广泛背景的整体理解。
  • 语言歧义和创造力:PPL 并不能体现模型处理语言歧义或生成创造性和新颖输出的能力。
  • 领域特异性:PPL 对训练数据的领域和分布很敏感。在特定领域训练的模型可能会在其领域内实现较低的复杂性,但可能需要帮助在其训练环境之外生成文本。
  • 过度拟合和泛化:PPL可能会受到过度拟合的影响,其中模型在训练数据上表现得非常好,但很难泛化到看不见的或现实世界的例子。

实际上,StreamingLLM 就很好地证明了 PPL 的局限性,因为尽管 StreamingLLM 的 PPL 值较低,但是由于其损失了大量中间信息,因此无法在“大海捞针”等测试方法中有较好的表现。

5.2 “大海捞针”

“大海捞针” 由 Greg Kamradt 提出的大模型长文本性能测试方法,其做法是在文本语料中藏入一个与文本语料不相关的句子,然后看大模型能不能通过自然语言提问的方式(Prompt)把这句话准确地提取出来。Greg Kamradt 的“大海捞针”实验简述:

“大海”:Paul Graham 的文章合集作为语料 “针”:“The best thing to do in San Francisco is eat a sandwich and sit in Dolores Park on a sunny day.” 提问:"What is the most fun thing to do in San Francisco based on my context? Don't give information outside the document" 期待模型输出的正确答案: The best thing to do in San Francisco is eat a sandwich and sit in Dolores Park on a sunny day.

Greg Kamradt 公布了他对 GPT-4 Turbo(128K)和 Claude 2.1 的测试结果:

  • GPT-4 Turbo(128K)在语料长度超过 72K 且句子(“针”)藏在文本头部的时候,准确率不佳。
da2f63891301d8a9e53fb02321868f3c.png
da2f63891301d8a9e53fb02321868f3c.png
  • Claude 2.1似乎在语料长度超过20K之后就开始准确率不佳,而且句子(“针”)藏在语料靠前的位置时,准确率尤其差。
  • 进一步的,Anthropic发现可以通过简单的prompt提示就可以提高模型不愿意回答不相关内容的效果,即让模型回答问题之前,加上一句“Here is the most relevant sentence in the context:”即可大幅提升模型回答效果,改进模型不愿意回答不相关内容的水平。
  • 此外国内的 Moonshot AI 的长文本模型 Kimi Chat 也在“大海捞针”实验中发挥了令人惊艳的表现,原始报道见这里 。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 一、长文本的核心问题与解决方向
    • 1.1 文本长度与显存及计算量之关系
      • 1.1.1 模型参数量
      • 1.1.2 计算量估计
      • 1.1.3 文本长度与计算量、参数量、显存的关系
    • 1.2 长文本问题的解决思路
  • 二、长文本与位置编码
    • 2.1 绝对位置编码及其外推
      • 2.1.1 训练式编码
      • 2.1.2 Sinusoidal 位置编码
      • 2.1.3 其他的绝对位置编码
    • 2.2 相对位置编码及其外推
      • 2.2.1 旋转位置编码 RoPE
      • 2.2.2 远程衰减问题
      • 2.2.3 RoPE 长度的内插与外推
      • 2.2.4 其他形式的编码方式及其外推
  • 三、长文本与 Attention 机制
  • 四、长文本的预训练方法
    • 4.1 序列并行(sequence parallel)
    • 4.2 LongLLaMA (Focused Transformer)
  • 五、长文本的效果评估
    • 5.1 PPL
    • 5.2 “大海捞针”
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档