🚀 增强版SwarmFormer:高效序列建模的智能新方法
本项目持续致力于优化高效序列建模架构,对 SwarmFormer 进行了一系列更新,显著提升了其性能、可扩展性和稳定性。新设计引入了 分层注意力、动态聚类和门控反馈机制,使模型能够更有效地处理长序列,同时降低计算开销。
🚀 快速开始
本项目主要聚焦于对SwarmFormer模型的改进,暂未提供具体的代码运行等快速开始的操作步骤。若有相关代码及使用需求,可根据后续补充的代码部分进行操作。
✨ 主要特性
为何改进SwarmFormer?
原始的SwarmFormer引入了 令牌 - 集群交互模型,其中令牌自组织成集群,在更高层次上交换信息,然后传播回精炼的表示。虽然这种方法有效地处理了长距离依赖关系,但它存在一些局限性:
- ❌ 固定的集群分配 导致令牌分组僵化。
- ❌ 用于局部注意力的滚动移位 并非捕捉细粒度依赖关系的最佳方式。
- ❌ 集群到令牌的更新缺乏门控,导致更新嘈杂。
- ❌ 注意力层中没有权重共享,增加了参数数量。
为了解决这些问题,我们引入了 一系列关键改进,在保持计算效率的同时提高了模型的表达能力。
新SwarmFormer架构的关键改进
- 使用局部窗口注意力取代滚动移位:我们用 局部窗口注意力(类似于滑动变换器和卷积滤波器)取代了滚动移位注意力。这允许更有效地提取局部特征,而无需冗余移位,从而改善局部建模。
- 对集群应用多头注意力:我们没有对集群使用单一的注意力机制,而是应用了 多头自注意力(MHA)。这使每个注意力头能够学习 不同的集群 - 令牌关系,从而改善上下文表示。
- 使用令牌到集群的门控取代均匀分块:以前,令牌是 均匀分配到集群 的,这限制了灵活性。我们现在使用基于注意力的动态路由机制,允许令牌 自适应地选择其集群。这提高了集群形成中的 语义连贯性。
- 引入门控反馈机制以实现稳定的令牌更新:我们不再直接从集群更新令牌嵌入,而是引入了 残差MLP门控机制。这过滤掉了 嘈杂的集群更新,确保只有 相关信息 被传播回令牌。
- 在每个MLP和注意力块之前进行层归一化:我们发现,在每个前馈和注意力层之前添加 层归一化(LayerNorm) 显著稳定了训练,改善了梯度流和收敛性。
- 在集群注意力中进行线性投影的权重绑定:为了 在不影响表达能力的情况下减小模型大小,我们现在在 GlobalClusterAttention 模块中的 查询、键和值投影 之间共享权重。这种优化减少了可训练参数的数量,同时保持了性能。
- 采用金字塔结构的分层聚类:我们不再在所有层使用 固定的集群大小,而是实现了 分层金字塔:
- ✅ 较低层 专注于 细粒度的局部交互(更多集群)。
- ✅ 较高层 处理 抽象的、粗粒度的表示(较少集群)。
这种 多尺度集群形成 允许模型在不丢失局部细节的情况下有效地 传播高层信息。
- 使用Gumbel - 软最大化进行可微聚类:为了提高 集群分配的可训练性,我们实现了 Gumbel - 软最大化采样。这使模型能够通过反向传播学习 集群分配,允许强化信号(如集群连贯性)指导优化。
🔧 技术细节
计算复杂度分析
原始SwarmFormer的计算复杂度
- 令牌到集群的注意力:在原始的SwarmFormer中,每个令牌关注所有集群,复杂度为 (O(NCd)),其中 (N) 是序列长度,(C) 是集群数量,(d) 是隐藏维度。
- 集群到令牌的广播:每个集群更新所有令牌,复杂度同样为 (O(NCd))。
- 总复杂度:原始SwarmFormer的总复杂度为 (O(2NCd))。
新SwarmFormer的计算复杂度
- 局部窗口注意力取代滚动移位注意力:每个令牌只关注大小为 (w)(通常 (w \ll N))的局部窗口,复杂度为 (O(Nwd)),取代了滚动移位操作,显著降低了成本。
- 多头集群注意力与权重共享:在原始版本中,查询、键和值投影有单独的权重。现在,我们在这些投影之间共享权重,减少了集群注意力层中的参数数量和浮点运算次数。注意力复杂度仍为 (O(NCd)),但矩阵乘法次数减少。
- 令牌到集群的门控:令牌不再统一分配到集群,而是根据学习到的路由选择性地更新集群。这将从所有令牌到所有集群的更新数量减少到只有一小部分 (p) 的令牌参与,复杂度为 (O(pNCd)),其中 (p < 1)。由于 (p) 通常为 0.5 或更低,这显著减少了计算量。
- 门控反馈机制(MLP过滤):我们不再直接从集群更新令牌嵌入,而是引入了残差MLP门控机制。MLP的复杂度为 (O(Nd^2)),但它过滤掉了嘈杂的集群更新,确保只有相关信息被传播回令牌,减少了后续层的有效计算量。
- 金字塔结构的分层聚类:我们不再在所有层使用固定的集群大小,而是实现了分层金字塔。较低层有 (C) 个集群,中间层有 (C/2) 个集群,顶层有 (C/4) 个集群。这导致聚类计算的有效减少,复杂度为 (O(NCd + NC/2d + NC/4d + …)),形成一个几何级数,降低了总计算成本。
最终复杂度比较
模型 |
复杂度 |
原始SwarmFormer |
(O(2NCd)) |
新SwarmFormer |
(O(Nwd + pNCd + Nd^2)) |
由于 (w \ll N)(窗口注意力降低了成本),(p < 1)(较少的集群更新),(d^2) 项仅在小的MLP中,而不是在完整的注意力层中,并且分层聚类减少了总的集群交互,我们得到 (O(NCd) > O(Nwd + pNCd + Nd^2)),这表明新架构在计算上更高效。
结论:新SwarmFormer更高效
- ✅ 更低的浮点运算次数:由于窗口注意力和分层聚类,新架构的浮点运算次数更低。
- ✅ 更少的冗余更新:门控反馈和令牌到集群的门控减少了冗余更新。
- ✅ 权重共享:进一步减少了参数数量。
总结:🚀 新的SwarmFormer架构在保持或提高性能的同时,实现了更快的训练和推理!
参考资料
@article{legg2025swarmformer,
title={SwarmFormer: Local-Global Hierarchical Attention via Swarming Token Representations},
author={Legg, Jordan and Sturmanis, Mikus and {Takara.ai}},
journal={Takara.ai Research},
year={2025},
url={https://takara.ai/papers/SwarmFormer-Local-Global-Hierarchical-Attention-via-Swarming-Token-Representations.pdf}
}
📄 许可证
本项目采用Apache - 2.0许可证。