如您有工作需要分享,欢迎联系:aigc_to_future
作者:Chenlu Zhan、Wen Li等
解读:AI生成未来
论文链接:https://arxiv.org/pdf/2509.01085
亮点直击
BSA——一种可训练的双向动态稀疏注意力框架,该框架首次在视频扩散训练中对全注意力机制中的查询(Query)及键值对(Key-Value)进行正交稀疏化处理以加速训练过程。 为查询块和键值块设计了不同的动态稀疏化策略,有效捕捉训练过程中的注意力变化特性,实现超越固定模式的自适应标记选择。 在Wan2.1-1.3B数据集表明:BSA可实现最高20倍的浮点运算量降低、17.7倍的训练加速以及6倍的推理加速,同时保持或超越全注意力机制的生成质量。
由于 DiT 模型采用Full Attention机制,计算量随序列长度增加而呈二次方增长,计算复杂度为(其中 L 为 token 序列长度)。这直接导致在训练与推理过程中的计算成本急剧攀升,严重制约了 DiT 模型在高分辨率长视频生成任务中的实用性与效率,因此亟待针对性的优化方案来解决这一核心限制。为了解决上述问题,提出了一种可训练的双向动态稀疏注意力加速框架,首次对3D Full Attention中的Query和Key-Value 对分别进行动态稀疏化计算,同时设计了不同的动态稀疏化策略来提升训练、推理效率。
双向Query-Key稀疏注意力:对于Query稀疏,通过对比token之间语义相似度来高效的选取Query内部关键的query token,动态优化query的稀疏性。对于Key-Value稀疏,只计算选取的关键KVBlock中的token。 动态稀疏注意力训练策略:分别针对KV block和Query block的动态稀疏性均设计了不同的动态策略。对于KV block稀疏,对不同的Query动态选择对应最关键的KV token,根据每一个训练step输入的block之间的注意力分数,动态选择关键 token 直至累积分数达到目标阈值p。 对于Query动态稀疏策略,分别针对时间、空间动态稀疏来选择不同的block稀疏度。
大量实验表明,该方法显著加速了视频扩散模型在不同长序列上的端到端训练速度,获得了最大20倍的FLOPs减少和17.7倍的注意力训练加速,同时获得了与Full Attention相当甚至更好的生成质量,除此之外,也可以在不降低推理质量的情况下加速推理速度,在H100上将端到端的推理延迟从31s降低到5.2s ( 6.2x )。
问题与发现
解决的问题
视频 DiT 在训练全分辨率、长序列数据时,大部分计算资源都耗费在注意力上,它可以消耗高达95 %的处理时间,且训练后的 DiT 在推理阶段仍速度缓慢,这使得注意力计算成为视频 DiT 缩放的首要瓶颈。为了改善这一状况,近期很多工作提出了多种稀疏注意力机制。它们的核心思路是让每个查询Query仅与KV键值对的部分子集进行交互,以此来降低计算的复杂程度。它们只关注KV键值对中的部分冗余子集,却忽略了Query查询序列中同样存在大量的冗余信息,这会导致大量的重复计算。除此之外,绝大多数稀疏注意力机制大多被设计成无需训练的形式。这些未经过训练的方法通过直接截取部分KV子集来进行注意力计算,在实际训练中往往只能得出欠佳的结果。
关键发现
为了设计高效的注意力训练框架,对当前Full Attention的训练延迟进行了特异性分析,并揭示了以下两个关键发现:
(1)Full Attention中的查询Query和Key-Value序列均具有较大稀疏性而导致过多的计算浪费。
对于查询Query来说,视频帧之间及帧内 token 存在大量重复语义(如静态背景、连续动作的相似帧)导致 token 数庞大。如图2所示,Frame3和Frame12中的Full Attention中的query热力图呈现高度相似的表法,说明这些token提供相同的语义特征,对所有Query的序列token进行注意力查询计算会导致严重的计算浪费。 对于KV键值对来说,token序列计算得到的注意力分数具有长尾效应,只有部分关键KV子集于每个查询Query具有强相关性,这一小部分计算显著影响最终的输出。因此,只需要计算小部分关键令牌就可以在不影响生成质量的情况下显著降低的计算成本。
(2)DiT中的注意力计算呈现动态稀疏性。动态稀疏性分别体现在Query和KV的时间、空间动态稀疏性。
空间动态稀疏:不同的 Query 所对应的关键 KV 对子集本应是动态变化的,如果采用固定的稀疏化策略,则无法适应时空的动态稀疏,过选会造成计算冗余,漏选则会产生精度损失,因此需要设计动态稀疏策略来适配DiT中本身的动态稀疏性。 时间动态稀疏:随着训练training step 的推进,稀疏度是随时间变化的,前期注意力会获取主要的全局信息,而后期注意力查询则只关注于更高语义层次的局部特征,稀疏度随着训练逐渐增大。
为了解决上述挑战,提出了一种可训练的双向动态稀疏注意力(BSA,Bidirectional Sparse Attention for Faster Video Diffusion Training)加速框架,首次对3D Full Attention中的Query和Key-Value 对分别进行动态稀疏化,同时设计了不同的动态稀疏化策略来提升训练、推理效率。
方法
1. Sparse Attention 回顾
现代视频扩散 Transformer(DiT)使用 3D Full Attention来捕捉整个视频体积内的依赖关系,在Full Attention中,Q、K、V中的所有序列令牌都参与交互和计算。而Sparse Attention通过从KV对中选择关键子集和来减少总体计算量,旨在提高效率。注意力输出O计算如下:
2. 方法架构
2.1 整体框架
如图3 所示,方法框架主要分成三部分: (a)为注意力序列立方体划分,将视频 latent 划分为时空立方体(Block),通过均值池化生成块级表示来有效地筛选关键信息。 (b)提出的Query-Sparse方法,分别基于Query的语义冗余特征来高效的选取最优query token,并根据时间空间动态稀疏性设计动态稀疏策略。 (c)提出的动态KV-sparse方法,对不同的Q选择对应最关键的KV token,动态选择关键 token 直至累积分数达到目标阈值p,无需预设固定稀疏模式,适应不同输入内容的稀疏需求。
2.2 立方体划分
给定一个形状为的视频,为了可以高效地以较低的计算成本来选择关键token子集,采用将多个token组合成一个较大的立方体block的形式来进行初步的选择。对于输入查询 、键、值 ,将视频 latent划分为大小为的立方体,每个立方体对应 GPU 上的一个块(block),块大小。然后对每个立方体的 tokens 进行均值池化,得到块级查询 、键、值 。视频中的每个立方体映射为GPU SM上的单个瓦片来协同设计稀疏注意力算法及其核心实现。
3. Query-Sparse
视频数据本身具有多帧的时间相关性和每帧帧内的空间相关性,因此存在时空信息冗余。实验测试显示在视频扩散模型中,约 4% 的空间邻近 token 贡献了 80% 的注意力分数,可以去除冗余token的情况下实现无损性能。因此考虑到每个query查询序列中也会存在很大的信息冗余(如静态背景、连续动作的相似帧),主要的语义(如物体类别、动作趋势)由少量关键 token 主导,丢弃相似语义的冗余 token 不会破坏整体语义结构。
基于此发现,提出了基于特征冗余的query token稀疏化方法。详细地说,对于查询Query设查询分成 个块 ,块 的token集合为 ,对应中心token为。 发现基于分块后的同一block内的 token(如空间邻近的像素块)通常包含很多语义高度相似的特征,中心 token 在时间空间维度上可作为该区域的语义代表,可以计算块内其他token与中心语义代表token之间的特征相似度,使用余弦相似度或点积衡量中心 token 与周围 token 的语义相似性,避免平均池化的 "一刀切" 信息损失,对于每个block之内的token进行局部时空窗口内计算相似性,然后对每个block内保留部分不冗余的tokens,这些token便可以贡献关键的注意力分数,而去除的冗余token由于所代表的特征信息与其他token重复,因此即便去除了也可以实现无损性能,不会破坏语义结构。对每个块分别按剪枝率 保留部分token,最后将所有block内的保留下来的关键token进行拼接,构成新的无冗余的查询Query ,具体生成方式如下所示:
其中, 表示在块b 内根据从大到小排序后的排名,是块b中的 token 数量,是保留比例。
4. KV-Sparse
基于立方体划分后的块级表示,可以让每个查询Query仅与KV键值对的部分子集进行交互,以此来大量降低计算的复杂程度。但是如何确定每个查询Query对应的关键KV键值对子集是一个非常重要的问题。在实验中发现,稀疏性在注意力块之间和同一块内之间存在显著差异,并且对于每一个query查询对应的关键kv对也是动态变化的,不应该采用固定的top-k选择方式来统一固定对每个query进行关键kv的选择。
因此提出了基于统计阈值的动态KV-Sparse稀疏方法,分别针对每个Query选取动态的关键KV对,并通过输入注意力分数的统计特性来计算得到动态的稀疏阈值来选取关键KV对,无需预设固定稀疏模式,适应不同输入内容的稀疏需求。
首先先对每个立方体的 tokens 进行均值池化,得到块级查询 、键、值 ,然后进行块选择Key Block Select ,计算块间注意力得分 ,通过动态统计阈值 选择关键块(保留高注意力值的块)。然后再将稀疏化的每个查询Query block 分别与选取到的关键KV对仅在关键块内进行 token 级注意力计算。动态稀疏分别体现在两方面:
获取基于统计的动态阈值p。对于每一次计算得到的块间注意力得分 ,可以通过计算query与KV对每次得到的注意力得分 中所有分数的均值和标准差,计算出一个可以选出k个关键样本的动态阈值p,也就是说根据输入注意力分数的统计特征去选出根据统计分布的关键KV对,而不是人为的截取对应关键kV对。
通过动态阈值选取不同Query的关键KV对。针对选取的关键block(假设K个),针对每一个query block分别和KV做计算,并且分别根据超过统计的动态阈值来进行动态索引选择:对每个query block i,选择最小索引集,确保所选注意力分数之和不低于阈值。
最终的稀疏注意力:设稀疏化后的查询矩阵为,(其中为稀疏化后的查询 token 数量),筛选出的关键键矩阵为 、关键值矩阵为 。其中对应所有 query block 选出的关键 KV 对键集合;对应相应的值集合;稀疏掩码矩阵为 ,稀疏掩码矩阵,保证只计算选中选中的 query 与 KV 交互对应的注意力。稀疏注意力输出可以表示为:
实验
基于Wan2.1-1.3B模型架构进行T2V任务的模型训练,重新初始化进行training from scatch,所有的模型训练均训练至完全收敛,以保证公平比较。
Loss比较
如图4所示,Sparse Attention与Full Attention基线的预训练损失曲线相重合,均表现出稳定且平滑的下降趋势,并且大部分优于Full Attention 模型。
Efficiency和Quality比较
如表1 所示,在2个不同的分辨率上对Sparse Attention 和Full Attention 进行from strach训练,分别为61 × 448 × 832,23K令牌)的原始分辨率,和扩展的更长token长度( 157x768x1280 , 153K令牌)。进行Sparse Attention和Full Attention在效率和生成质量上的对比。
在原始序列长度(23k tokens)下,Sparse Attention比Full Attention的获得了12.85倍的加速比,并且实现了93%的稀疏度,FLOPs为Full Attention的7%。除此之外,在加速的同时,BSA体现出了强大的生成质量,它在Vbench的4个一致性测量指标上优于Full Attention,尤其是在背景一致性上。这说明了Sparse Attention 可以在较短序列长度上也可以实现较大的加速训练,同时也可以达到更好的生成效果。 在更长的序列长度(153k tokens)下,Sparse Attention在加速比和生成质量的优势上更加明显。具体来说,BSA与Full Attention模型训练相比,获得了17.79倍的加速比,稀疏度可以达到95%,FLOPs计算也可以达到Full Attention的5%。并且它在生成质量上相对于Full Attention的提升幅度也更大,尤其是文本一致性和背景一致性。这种改进主要是源于对于更长的序列长度,那么模型训练时Attention计算的占比也更多,由此可以达到的稀疏度和加速比都会随之增大。
Training on Longer Sequences 在不同序列长度上的对比
为了评估BSA在不同序列长度上的训练加速效果,分别在5种不同序列长度上进行训练加速比测试。所有的模型训练设置均保持一致来保证训练的公平性,结果如图6所示。详细地说,分别测试了23k、44k、59k、117k、153k序列长度,加速比随着序列长度的增加逐渐增大。当序列长度为最小的23k的时候,加速比也可以达到12.85x,当序列长度增加为其2倍的44k的时候,加速比可以增加至14.72x。对于当前测试的最长的序列长度153k时,最大加速比可以达到17.79倍,由此说明对于更长的序列长度,Sparse Attention可以更有效地缩短模型训练的时间。
Sparse Adaptation 稀疏度讨论
为了探究稀疏度与训练Loss和计算量之间的关系,还测试了不同稀疏度下的验证损失Validation Loss和计算量FLOPs的实验,如图7所示。模型的稀疏度与Query-sparse中的保留token比例r和KV-sparse中的动态阈值p(动态阈值通过每一次计算得到的注意力分数来选取的k个关键值得到)相关,并且也存在trade-off的权衡。当sparsity为0时,代表的是Full Attention的训练结果。从图7中可以发现,当Sparse Attention的稀疏度在0-0.93时,validation loss与Full Attention的Validation loss几乎没有区别,并且FLOPs随着稀疏度的增加而下降。但是当Sparse Attention的稀疏度超过0.95,虽然计算量FLOPs仍在减少,但是validation loss却变得很大,这说明在这个稀疏度下无法实现无损的生成质量。而当稀疏度为0.93附近时,是一个最优的结果,即既可以实现无损甚至更好的生成效果,还可以减少13x的计算量FLOPs。
Qualitative Results 定性实验结果
如图5所示,展示了4个分别在不同序列长度上的生成视频不同帧下的T2V生成结果,分别包括不同帧数下较低分辨率(448✖️832)和高分辨率(782✖️1280)。如图中4个不同的例子展示所示,所提出的Sparse Attention生成的视频与Full attention相比可以达到无损的效果。
Comparison with Other Sparse Attentions 与其他SparseAttention方法对比
如表2所示,与最相关的基于训练的稀疏注意方法(如MoBA和VSA)进行了详细的比较。BSA在加速比方面比MoBA和VSA都有明显的优势,对于23k序列长度,可以达到12.85x的attention加速,但是目前training-based最优的VSA仅可以实现4.5x的attention加速比。并且与这些稀疏注意力方法相比,也提供了更好的生成质量。
Ablation Study
为了探究Query-sparse和KV-sparse对加速效果和生成质量的影响,分别对其进行了详尽的消融实验,如表3所示。采取Full Attention为基线在表2的第5行,总体的方法展示在最后一行,并且分别在第1-4行来计算Query-sparse及其window窗口、KV-sparse及其统计动态阈值对加速效果和生成质量的影响。
Query-Sparse
Original Query-sparse:在没有进行KV-sparse的基础上,通过表2的第1行可以发现,当保持prune rate为0.5时,可以达到无损的验证结果,在effciency方面,并且可以实现1.96x的加速比,减少50%的计算量。除此之外,在Vbench上的测试结果也都优于Full Attention。 Query-sparse with window size selection:还测试了采用window size来根据多个center token来选取有效token的方法。这说明了with window block selection可以更好地选取包含有效语义的tokens,而不会被冗余token干扰。
KV-Sparse
Original KV-sparse:在没有Query-sparse的基础上,基于阈值的KV-sparse可以实现0.86的稀疏度和6.05x倍的训练加速,还节省了将近8.6倍的计算量。除此之外,总体生成效果与Full Attention相比还是可以达到无损的结果。 KV-sparse with stastic dynamic threshold:还测试了加上动态统计阈值的KV-sparse。从表2中的结果可以验证,这种基于统计信息的动态阈值可以在相同validation loss的基础上实现更高的稀疏度,并且在生成质量相当的情况下实现更高的训练加速比和更少的计算量FLOPs。
Query-Sparse + KV-Sparse
如表2的最后一行显示,结合了Query-Sparse 和KV-Sparse的方法在相当的validation loss和生成质量的情况下实现了最大的稀疏度0.93和最大的加速比12.85倍。这得益于Query-Sparse 和KV-Sparse是可以正交实现的,两者达到的稀疏效果可以进行叠加,达到最优的加速效果,并且不会损害生成质量,验证了稀疏注意力的有效性。并且需要强调的是,稀疏方法所增加的计算量很小,几乎可以忽略不计,这也显示了Sparse Attention方法的高效性。
结论
视频扩散Transformer(DiT)模型在生成质量方面表现优异,但在生成高分辨率长视频时遇到了主要的计算瓶颈。Full Attention的二次复杂度会增加训练/推理成本。 为了克服这一限制,提出了一个双向稀疏注意(BSA)框架,用于更快的视频DiT训练,这是第一个提出双向Query-KV动态稀疏化的框架,从而提高了训练和推理效率。完全关注效率低下源于两个关键挑战:由于查询和键值对固有的稀疏性而导致的过度计算,以及由于固定的稀疏模式无法利用DiT的动态关注而导致的冗余计算 。BSA通过两个关键组件来解决这些问题,查询稀疏性通过语义相似度和动态时空训练策略选择信息量最大的查询令牌来优化,而KV稀疏性通过计算统计动态阈值并仅保留关键KV块进行计算来实现。 大量实验表明,BSA显著加速了长序列的DiT训练,将FLOPs降低了20倍,实现了17.79倍的注意力训练速度,同时保持甚至超过了完Full Attention的生成质量。
参考文献
[1] Bidirectional Sparse Attention for Faster Video Diffusion Training
致谢
如果您觉得这篇文章对你有帮助或启发,请不吝点赞、在看、转发,让更多人受益。同时,欢迎给个星标⭐,以便第一时间收到我的最新推送。每一个互动都是对我最大的鼓励。让我们携手并进,共同探索未知,见证一个充满希望和伟大的未来!
技术交流
加入「AI生成未来社区」群聊,一起交流讨论,涉及 图像生成、视频生成、3D生成、具身智能等多个不同方向,备注不同方向邀请入群!可添加小助手备注方向加群!
没有评论:
发表评论