
论文链接:https://arxiv.org/pdf/2509.01085
亮点直击

结果
由于 DiT 模型采用Full Attention机制,计算量随序列长度增加而呈二次方增长,计算复杂度为)(其中 L 为 token 序列长度)。这直接导致在训练与推理过程中的计算成本急剧攀升,严重制约了 DiT 模型在高分辨率长视频生成任务中的实用性与效率,因此亟待针对性的优化方案来解决这一核心限制。为了解决上述问题,提出了一种可训练的双向动态稀疏注意力加速框架,首次对3D Full Attention中的Query和Key-Value 对分别进行动态稀疏化计算,同时设计了不同的动态稀疏化策略来提升训练、推理效率。
大量实验表明,该方法显著加速了视频扩散模型在不同长序列上的端到端训练速度,获得了最大20倍的FLOPs减少和17.7倍的注意力训练加速,同时获得了与Full Attention相当甚至更好的生成质量,除此之外,也可以在不降低推理质量的情况下加速推理速度,在H100上将端到端的推理延迟从31s降低到5.2s ( 6.2x )。
视频 DiT 在训练全分辨率、长序列数据时,大部分计算资源都耗费在注意力上,它可以消耗高达95 %的处理时间,且训练后的 DiT 在推理阶段仍速度缓慢,这使得注意力计算成为视频 DiT 缩放的首要瓶颈。为了改善这一状况,近期很多工作提出了多种稀疏注意力机制。它们的核心思路是让每个查询Query仅与KV键值对的部分子集进行交互,以此来降低计算的复杂程度。它们只关注KV键值对中的部分冗余子集,却忽略了Query查询序列中同样存在大量的冗余信息,这会导致大量的重复计算。除此之外,绝大多数稀疏注意力机制大多被设计成无需训练的形式。这些未经过训练的方法通过直接截取部分KV子集来进行注意力计算,在实际训练中往往只能得出欠佳的结果。

发现
为了设计高效的注意力训练框架,对当前Full Attention的训练延迟进行了特异性分析,并揭示了以下两个关键发现:
(1)Full Attention中的查询Query和Key-Value序列均具有较大稀疏性而导致过多的计算浪费。
(2)DiT中的注意力计算呈现动态稀疏性。动态稀疏性分别体现在Query和KV的时间、空间动态稀疏性。
为了解决上述挑战,提出了一种可训练的双向动态稀疏注意力(BSA,Bidirectional Sparse Attention for Faster Video Diffusion Training)加速框架,首次对3D Full Attention中的Query和Key-Value 对分别进行动态稀疏化,同时设计了不同的动态稀疏化策略来提升训练、推理效率。
现代视频扩散 Transformer(DiT)使用 3D Full Attention来捕捉整个视频体积内的依赖关系,在Full Attention中,Q、K、V中的所有序列令牌都参与交互和计算。而Sparse Attention通过从KV对中选择关键子集和来减少总体计算量,旨在提高效率。注意力输出O计算如下:

主图
如图3 所示,方法框架主要分成三部分: (a)为注意力序列立方体划分,将视频 latent 划分为时空立方体(Block),通过均值池化生成块级表示来有效地筛选关键信息。 (b)提出的Query-Sparse方法,分别基于Query的语义冗余特征来高效的选取最优query token,并根据时间空间动态稀疏性设计动态稀疏策略。 (c)提出的动态KV-sparse方法,对不同的Q选择对应最关键的KV token,动态选择关键 token 直至累积分数达到目标阈值p,无需预设固定稀疏模式,适应不同输入内容的稀疏需求。
给定一个形状为()的视频,为了可以高效地以较低的计算成本来选择关键token子集,采用将多个token组合成一个较大的立方体block的形式来进行初步的选择。对于输入查询 、键、值 ,将视频 latent()划分为大小为的立方体,每个立方体对应 GPU 上的一个块(block),块大小。然后对每个立方体的 tokens 进行均值池化,得到块级查询 、键、值 。视频中的每个立方体映射为GPU SM上的单个瓦片来协同设计稀疏注意力算法及其核心实现。
视频数据本身具有多帧的时间相关性和每帧帧内的空间相关性,因此存在时空信息冗余。实验测试显示在视频扩散模型中,约 4% 的空间邻近 token 贡献了 80% 的注意力分数,可以去除冗余token的情况下实现无损性能。因此考虑到每个query查询序列中也会存在很大的信息冗余(如静态背景、连续动作的相似帧),主要的语义(如物体类别、动作趋势)由少量关键 token 主导,丢弃相似语义的冗余 token 不会破坏整体语义结构。
基于此发现,提出了基于特征冗余的query token稀疏化方法。详细地说,对于查询Query设查询分成 个块 ,块 的token集合为 ,对应中心token为。 发现基于分块后的同一block内的 token(如空间邻近的像素块)通常包含很多语义高度相似的特征,中心 token 在时间空间维度上可作为该区域的语义代表,可以计算块内其他token与中心语义代表token之间的特征相似度,使用余弦相似度或点积衡量中心 token 与周围 token 的语义相似性,避免平均池化的 “一刀切” 信息损失,对于每个block之内的token进行局部时空窗口内计算相似性,然后对每个block内保留部分不冗余的tokens,这些token便可以贡献关键的注意力分数,而去除的冗余token由于所代表的特征信息与其他token重复,因此即便去除了也可以实现无损性能,不会破坏语义结构。对每个块分别按剪枝率 保留部分token,最后将所有block内的保留下来的关键token进行拼接,构成新的无冗余的查询Query ,具体生成方式如下所示:
其中, 表示在块b 内根据从大到小排序后的排名,是块b中的 token 数量,是保留比例。
基于立方体划分后的块级表示,可以让每个查询Query仅与KV键值对的部分子集进行交互,以此来大量降低计算的复杂程度。但是如何确定每个查询Query对应的关键KV键值对子集是一个非常重要的问题。在实验中发现,稀疏性在注意力块之间和同一块内之间存在显著差异,并且对于每一个query查询对应的关键kv对也是动态变化的,不应该采用固定的top-k选择方式来统一固定对每个query进行关键kv的选择。
因此提出了基于统计阈值的动态KV-Sparse稀疏方法,分别针对每个Query选取动态的关键KV对,并通过输入注意力分数的统计特性来计算得到动态的稀疏阈值来选取关键KV对,无需预设固定稀疏模式,适应不同输入内容的稀疏需求。
首先先对每个立方体的 tokens 进行均值池化,得到块级查询 、键、值 ,然后进行块选择Key Block Select ,计算块间注意力得分 ,通过动态统计阈值 选择关键块(保留高注意力值的块)。然后再将稀疏化的每个查询Query block 分别与选取到的关键KV对仅在关键块内进行 token 级注意力计算。动态稀疏分别体现在两方面:
最终的稀疏注意力:设稀疏化后的查询矩阵为,(其中为稀疏化后的查询 token 数量),筛选出的关键键矩阵为 、关键值矩阵为 。其中对应所有 query block 选出的关键 KV 对键集合;对应相应的值集合;稀疏掩码矩阵为 ,稀疏掩码矩阵,保证只计算选中选中的 query 与 KV 交互对应的注意力。稀疏注意力输出可以表示为:
其中,为稀疏化后的注意力分数矩阵(维度 , 为关键 KV token 数量),是缩放因子,为最终的稀疏注意力输出,维度和输入保持一致。
基于Wan2.1-1.3B模型架构进行T2V任务的模型训练,重新初始化进行training from scatch,所有的模型训练均训练至完全收敛,以保证公平比较。
如图4所示,Sparse Attention与Full Attention基线的预训练损失曲线相重合,均表现出稳定且平滑的下降趋势,并且大部分优于Full Attention 模型。

loss
如表1 所示,在2个不同的分辨率上对Sparse Attention 和Full Attention 进行from strach训练,分别为61 × 448 × 832,23K令牌)的原始分辨率,和扩展的更长token长度( 157x768x1280 , 153K令牌)。进行Sparse Attention和Full Attention在效率和生成质量上的对比。

为了评估BSA在不同序列长度上的训练加速效果,分别在5种不同序列长度上进行训练加速比测试。所有的模型训练设置均保持一致来保证训练的公平性,结果如图6所示。详细地说,分别测试了23k、44k、59k、117k、153k序列长度,加速比随着序列长度的增加逐渐增大。当序列长度为最小的23k的时候,加速比也可以达到12.85x,当序列长度增加为其2倍的44k的时候,加速比可以增加至14.72x。对于当前测试的最长的序列长度153k时,最大加速比可以达到17.79倍,由此说明对于更长的序列长度,Sparse Attention可以更有效地缩短模型训练的时间。

speed
为了探究稀疏度与训练Loss和计算量之间的关系,还测试了不同稀疏度下的验证损失Validation Loss和计算量FLOPs的实验,如图7所示。模型的稀疏度与Query-sparse中的保留token比例r和KV-sparse中的动态阈值p(动态阈值通过每一次计算得到的注意力分数来选取的k个关键值得到)相关,并且也存在trade-off的权衡。当sparsity为0时,代表的是Full Attention的训练结果。从图7中可以发现,当Sparse Attention的稀疏度在0-0.93时,validation loss与Full Attention的Validation loss几乎没有区别,并且FLOPs随着稀疏度的增加而下降。但是当Sparse Attention的稀疏度超过0.95,虽然计算量FLOPs仍在减少,但是validation loss却变得很大,这说明在这个稀疏度下无法实现无损的生成质量。而当稀疏度为0.93附近时,是一个最优的结果,即既可以实现无损甚至更好的生成效果,还可以减少13x的计算量FLOPs。
如图5所示,展示了4个分别在不同序列长度上的生成视频不同帧下的T2V生成结果,分别包括不同帧数下较低分辨率(448✖️832)和高分辨率(782✖️1280)。如图中4个不同的例子展示所示,所提出的Sparse Attention生成的视频与Full attention相比可以达到无损的效果。

vis

sota
如表2所示,与最相关的基于训练的稀疏注意方法(如MoBA和VSA)进行了详细的比较。BSA在加速比方面比MoBA和VSA都有明显的优势,对于23k序列长度,可以达到12.85x的attention加速,但是目前training-based最优的VSA仅可以实现4.5x的attention加速比。并且与这些稀疏注意力方法相比,也提供了更好的生成质量。
为了探究Query-sparse和KV-sparse对加速效果和生成质量的影响,分别对其进行了详尽的消融实验,如表3所示。采取Full Attention为基线在表2的第5行,总体的方法展示在最后一行,并且分别在第1-4行来计算Query-sparse及其window窗口、KV-sparse及其统计动态阈值对加速效果和生成质量的影响。
如表2的最后一行显示,结合了Query-Sparse 和KV-Sparse的方法在相当的validation loss和生成质量的情况下实现了最大的稀疏度0.93和最大的加速比12.85倍。这得益于Query-Sparse 和KV-Sparse是可以正交实现的,两者达到的稀疏效果可以进行叠加,达到最优的加速效果,并且不会损害生成质量,验证了稀疏注意力的有效性。并且需要强调的是,稀疏方法所增加的计算量很小,几乎可以忽略不计,这也显示了Sparse Attention方法的高效性。
视频扩散Transformer(DiT)模型在生成质量方面表现优异,但在生成高分辨率长视频时遇到了主要的计算瓶颈。Full Attention的二次复杂度会增加训练/推理成本。 为了克服这一限制,提出了一个双向稀疏注意(BSA)框架,用于更快的视频DiT训练,这是第一个提出双向Query-KV动态稀疏化的框架,从而提高了训练和推理效率。完全关注效率低下源于两个关键挑战:由于查询和键值对固有的稀疏性而导致的过度计算,以及由于固定的稀疏模式无法利用DiT的动态关注而导致的冗余计算 。BSA通过两个关键组件来解决这些问题,查询稀疏性通过语义相似度和动态时空训练策略选择信息量最大的查询令牌来优化,而KV稀疏性通过计算统计动态阈值并仅保留关键KV块进行计算来实现。 大量实验表明,BSA显著加速了长序列的DiT训练,将FLOPs降低了20倍,实现了17.79倍的注意力训练速度,同时保持甚至超过了完Full Attention的生成质量。
[1] Bidirectional Sparse Attention for Faster Video Diffusion Training
如果您觉得这篇文章对你有帮助或启发,请不吝点赞、在看、转发,让更多人受益。同时,欢迎给个星标⭐,以便第一时间收到我的最新推送。每一个互动都是对我最大的鼓励。让我们携手并进,共同探索未知,见证一个充满希望和伟大的未来!