轻量级实现:ai-algorithms/DIFF_Transformer.ipynb at main · Jaykef/ai-algorithms · GitHub
前言
“Differential Transformer.pdf”介绍了一种名为DIFF Transformer的新型架构,该架构通过差分注意力机制减少无关上下文的干扰,在语言建模和多种下游任务中展现出优势。
1. 研究背景和动机
- Transformer在注意力分配上存在问题,倾向于过度关注无关上下文,导致在关键信息检索等方面面临挑战。
- 提出DIFF Transformer,旨在解决这些问题,提高模型性能和效率。
2. DIFF Transformer架构
2.1 差分注意力机制(Differential Attention)
- 计算方式
- 将输入投影到查询、键和值向量(Q1,Q2,K1,K2,V)。
- 通过两个softmax函数计算注意力得分,即$DiffAttn(X)=(softmax(\frac{Q_{1}K_{1}^{T}}{\sqrt{d}})-\lambda softmax(\frac{Q_{2}K_{2}^{T}}{\sqrt{d}}))V$,其中λ是可学习的标量,经过特定的重新参数化方式初始化。
- 多头部机制(Multi-Head Differential Attention)
- 使用多个注意力头,每个头采用相同的差分注意力计算方式,头之间共享λ。
- 对每个头的输出进行归一化和投影,得到最终结果。
2.2 整体架构(Overall Architecture)
- 由L层堆叠而成,每层包含一个多头部差分注意力模块和一个前馈网络模块。
- 采用预RMSNorm和SwiGLU等改进措施。
3. 实验
3.1 语言建模评估(Language Modeling Evaluation)
- 实验设置
- 训练3B大小的DIFF Transformer和Transformer语言模型,在隐藏大小、层数、头维度等方面设置相同,采用相同的优化器和训练数据等。
- 结果
- 在LM Eval Harness基准测试中,DIFF Transformer相比其他基于Transformer的语言模型取得了较好的性能。
3.2 与Transformer的可扩展性比较(Scalability Compared with Transformer)
- 模型大小扩展
- 训练不同参数数量的语言模型,结果显示DIFF Transformer在各种模型大小下都优于Transformer,例如6.8B大小的DIFF Transformer可达到与11B大小的Transformer相当的验证损失,但参数仅需其62.2%。
- 训练令牌数量扩展
- 对3B模型评估不同训练令牌数量下的性能,DIFF Transformer用160B令牌训练可达到与用251B令牌训练的Transformer相当的性能,仅消耗63.7%的训练令牌。
3.3 长上下文评估(Long-Context Evaluation)
- 将3B大小的模型扩展到64K上下文长度继续训练,结果表明DIFF Transformer能更有效地利用长上下文,在累积平均负对数似然(NLL)指标上优于Transformer。
3.4 关键信息检索(Key Information Retrieval)
- 实验设置
- 采用多针检索评估协议,在不同长度的上下文中插入“针”(包含关键信息的句子),并在不同深度和数量的情况下进行检索。
- 结果
- 在4K和64K上下文长度的检索实验中,DIFF Transformer在关键信息检索方面表现出优越的能力,随着“针”数量和查询城市数量增加,其准确性保持稳定,而Transformer性能显著下降。同时,DIFF Transformer分配给答案跨度的注意力得分更高,注意力噪声更低。
3.5 上下文学习(In-Context Learning)
- 多示例分类(Many-Shot In-Context Learning)
- 在支持64K输入长度的3B大小语言模型上进行实验,结果显示DIFF Transformer在不同数据集和不同数量的示例下,多示例分类的准确性均优于Transformer,平均准确率提高幅度从5.2%到21.6%。
- 上下文学习的鲁棒性(Robustness of In-Context Learning)
- 在TREC数据集上评估两种提示格式(示例随机排列和按类别交替排列)下上下文学习的鲁棒性,结果表明DIFF Transformer的性能方差远小于Transformer,对顺序排列更具鲁棒性。
3.6 上下文幻觉评估(Contextual Hallucination Evaluation)
- 在文本摘要和问答任务上评估3B大小模型的上下文幻觉,通过GPT - 4o进行判断,结果显示DIFF Transformer相比Transformer在文本摘要和问答任务上减少了上下文幻觉,这可能是因为DIFF Transformer更好地聚焦于任务所需的关键信息。
3.7 激活异常值分析(Activation Outliers Analysis)
- 激活值统计
- 分析Transformer和DIFF Transformer模型中注意力对数和隐藏状态的激活值,结果表明DIFF Transformer的顶级激活值远低于Transformer,产生的激活异常值更少。
- 注意力对数量化
- 对注意力对数进行量化实验,结果显示DIFF Transformer在降低比特宽度时仍能保持高性能,而Transformer在6位量化时准确性显著下降。
3.8 消融研究(Ablation Studies)
- 对1.4B大小的语言模型进行消融研究,结果表明DIFF Transformer的性能优势主要来自差分注意力机制,而非配置或归一化模块,并且模型对λ的初始化选择具有鲁棒性。
4. 结论
- DIFF Transformer在语言建模的多个方面优于Transformer,包括缩放特性、长上下文建模、关键信息检索、幻觉缓解、上下文学习和激活异常值减少等。
- 差分注意力机制可通过FlashAttention轻松实现,未来可基于激活异常值减少的特性开发高效的低比特注意力内核和压缩关键值缓存。