|os: 昨天文章标题写错,实在抱歉~
大型语言模型时代下,面对海量的文本数据,扩展序列长度已然成为一个关键问题。现有算法下,序列长度受限主要受模型表达能力、计算复杂度的影响。在此背景下,微软研究提出了一种Transformer变体:LONGNET,该架构将序列标记长度扩展到了10亿+,且并不会影响较短序列的性能。LONGNET的核心是扩展注意力,将计算复杂度从二次降低到线性。LONGNET可以用作分布式训练器,「跨多个GPU」设备并行训练序列。
Paper:https://arxiv.org/pdf/2307.02486.pdf
Code:https://github.com/microsoft/torchscale
纵观深度学习发展趋势,随着模型框架层数的增加,模型表达能力也逐步增强,由此产生许多强大的深度网络,例如:BMR+20, KMH+20, ZKHB22, CND+22, DDM+23等;然后,随着稀疏MoE模型和模型并行化方法的出现,模型隐藏维度得到了有效扩展。「而序列长度作为神经网络的最后一个原子维,我们希望是无限的」。
打破序列长度的限制能够带来显著的优势。首先,它为模型提供了大的记忆和接受场,这对它们与人类和世界的互动是实用的。其次,更长的上下文包含更复杂的因果关系和推理路径,模型可以在训练数据中利用这些信息。相反,较短的上下文中的例外会存在伪相关的信息,这不利于模型的泛化。最后,它能够探索上下文学习的局限性,这有可能成为many-shot学习的范式转变,因为较长的上下文有助于模型减轻灾难性遗忘。
「扩大序列长度的主要挑战是在计算复杂性和模型表达能力之间取得适当的平衡」。RNN风格的模型主要是为了增加长度。然而,它的顺序性质限制了训练过程中的并行化,而这在长序列建模中至关重要。最近,状态空间模型对序列建模很有吸引力。它可以在训练期间作为CNN运行,并在测试时转换为高效的RNN。虽然它们在远程基准测试中表现良好,但它们在常规长度上的性能不如Transformers,这主要受到模型表达能力的限制。
缩放序列长度的另一个原因是降低Transformer的复杂性,即自注意力的二次复杂度。在注意力上实现滑动窗口或卷积模块是使复杂性接近线性的直接方法。然而,这牺牲了回忆早期标记的能力,忘记了序列开头的提示。稀疏注意力通过稀疏注意力矩阵来减少计算量,保留回忆远距离信息的可能性。目前最新的一些方法的序列扩展长度都没有达到10亿+的水平。如下图所示:
基于以上背景,微软研究提出了一种新的Transformer变体:LONGNET,该架构将序列标记长度扩展到了10亿+,并不会影响较短序列的性能。它采用用一个名为扩展注意力的新颖组件取代了普通Transformers的注意力,其设计原则为:注意力分配随着Token之间距离的增加呈指数减少。这使得LONGNET可以获得线性计算复杂度和对数依赖性,从而解决了有限的注意力资源和每个标记的可访问性之间的矛盾。
扩展注意力由一系列用于建模短程和长程依赖关系的注意力模式组成,注意力模式的数量可以根据序列长度进行扩展。在每个注意力模式中,查询向量和键向量之间的点积被分解为多个子点积,每个子点积仅涉及到一小部分的键向量。这种分解方式可以减少计算复杂度,同时也可以使模型更好地处理长序列。具体如下图所示:
扩张注意力还引入了“多头”机制,可以在不同的头之间分别计算注意力。每个头都有自己的偏移量,这样就可以在不同的位置上计算注意力,从而更好地捕捉序列中的信息。通过这种方式,扩张注意力可以更好地处理长序列,同时保持较短序列的性能。具体如下图所示:
分布式训练方法,利用LONGNET的线性计算复杂度,将序列维度分布式地进行训练。具体而言,算法首先将输入序列沿着序列维度进行切分,每个序列片段被分配到不同的设备上进行计算。然后,每个设备将序列片段投影为查询、键和值,并使用本地计算得到局部的注意力权重。对于超出本地设备序列长度的部分,键和值将被发送到其他设备上进行计算。最后,所有设备将局部的注意力权重进行汇总,得到全局的注意力权重,并使用全局的注意力权重计算每个标记的表示。具体如下图所示:
该算法可以在任意数量的设备上进行扩展,并且可以通过并行计算来加速训练过程。由于LONGNET具有线性计算复杂度,因此该算法可以有效地处理超长序列,而不会牺牲训练速度和模型性能。此外,该算法还支持标准Transformer的优化技术,例如内核融合、量化和分布式训练,从而使得LONGNET可以无缝地与现有的深度学习框架进行集成。
LONGNET能够在几乎恒定的运行时间下有效地将序列长度扩展到1B个Token,如下图所示,而普通Transformer则面临着二次复杂度的问题。
将LONGNET与普通Transformer和稀疏Transformer进行比较。架构之间的差异在于注意力层,而其他部分保持不变。将这些模型的序列长度从2K扩展到32K,同时减小批次大小以保持每批次的Token数量恒定。如下图所示,实验结果表明:1)在训练期间增加序列长度通常会产生更好的语言模型;2)推理中序列长度的外推不适用于长度远大于模型支持的情况;3)LONGNET始终优于基线模型,证明了其在语言建模方面的有效性。