WheatField
WheatField

浅聊 Mixtral-8x7B 的一些亮点

January 11, 20242144 words, 11 min read
Authors

Mistral.ai 在 arXiv 上放出来论文 Mixtral of Experts,整体结构上跟 Mistral 7B 一样,唯一的区别在于每个层都由 8 个前馈块(即专家)组成。

Mistral 7B

首先聊一下 Mistral 7B(论文链接: https://arxiv.org/pdf/2310.06825.pdf, arXiv 23.10),这篇工作在当时的亮点有:

  • 性能:- 如下图所示, Mistral 7B 在多项基准测试中都优于当时最佳开源模型 Llama 2 13B,在推理、数学和代码生成方面优于 Llama 1 34B - 提供了一个经过指令微调的模型,Mistral 7B-Instruct,在 Chatbot Arena Leaderboard 上得分 1031 分(论文发布时),超过了 Llama 2 13B 的 1012 分。 排行榜上的得分是通过人类对模型生成的对话打分来计算的,分数越高,模型生成的对话越接近人类的对话。
  • 结构特点: 利用分组查询注意力(GQA)来实现更快的推理,并结合滑动窗口注意力(SWA)来有效地处理任意长度的序列,同时降低推理成本。
Mistral 7B 基准测试结果

做为对比,截止到 2024/01/20, Mistral 8x7b instruct 在 Chatbot Arena Leaderboard 仍然排在前十的位置,前一段时间的排名更高,居第4名,仅次于 GPT-4。Mistral 的另一个模型 Mistral medium 效果更好,在 Arena ELO 上排到第 5 名,在 MT-bench 上得分 8.61 分,除了 GPT 4 之外,超过了所有其他模型,但这个模型目前还没有开源。 有意思的是,Claud 的两个新版本 2.0, 2.1 反而一代不如一代。

注:MT-bench, multi-turn question set, 是一个用于评估模型多轮对话及指令遵循能力的数据集。目前在这个 set 上,能力最强的模型是 GPT-4-Turbo,得分 9.32 分。

Chatbot arena ELO ranking (2024/01/20)

当然,强如 mistral medium,也还是没有通过“鲁迅为什么要打周树人”的幻觉测试。

Mistral medium 幻觉
鲁迅 AI

Sliding Window Attention (SWA)

在 vanilla attention 中,每个 token 都会与所有其他 token 进行交互,这样的话,计算复杂度就是 O(n2)O(n^2),其中 nn 是序列长度。这样的计算复杂度在序列很长时会很高,一种更高效的方法是 Sliding Window Attention (SWA)。

SWA 最早在 LongFormer中提出,属于 sparse attention 的一种,其目标是让 LLM 更有效的处理长序列。思想是把每个 token 的注意力跨度限制到它周围的固定窗口中,LongFormer中的做法是定义一个宽度为WW的窗口,使得 query 节点只能注意到对等的 key 节点以及 key nodes 左右 W/2W/2 个节点。这样的话,计算复杂度就是 O(Wcdotn)O(W cdot n),其中 nn 是序列长度。当 W<<nW<<n,计算复杂度就会大大降低。

这样,其他 key nodes 的信息是不是丢失了?也不是,多个层堆叠在一起时,在高层上,query 节点会间接的注意到远处的 key 节点。假设 ll 层,每层的窗口大小为 WW,那么在第 ll 层,query 节点可以间接的注意到 WlW*l 范围内的 key 节点。

Mistral 7B 中的 window size 是 4096.

Grouped Query Attention (GQA)

分组查询注意力(GQA,论文 :https://arxiv.org/pdf/2305.13245.pdf )是在多查询注意力(Multi-Query Attention,MQA)和多头注意力(Multi-Head Attention,MHA)之间找到一个平衡点,其目标是在保持MQA速度的同时实现MHA的质量。

GQA, MQA, MHA 三者的关联

MQA 使用单一的键值头,这能加快解码器的推理速度,但可能导致解码质量下降。具体来说,MQA减少了模型在执行注意力操作时可利用的信息量 从而可能导致在处理复杂任务或长序列生成时的性能下降。例如,在文本摘要任务中,MHA允许模型在计算注意力时考虑多个不同的特征组合 从而捕捉更丰富的语言特征。而MQA由于头的限制,可能无法捕获这样的复杂特征,导致生成的摘要可能不够准确。

GQA 将查询头分为 G 组,每个组共享一个键头和值头。GQA 通过减少解码器推理期间所需的内存带宽来提高 LLMs 的效率。

优势

相对 MHA 来说,GQA 显著降低 LLM的计算复杂性,从而提高推理速度;对内存的占用也更少,适用于限制内存大小的LLMs。相对 MQA 来说,GQA 提高了解码质量,但速度稍慢。还有一点,MQA 支持多GPU并行,有效地利用计算资源。

Byte-fallback BPE algorithm

BPE算法大家应该都比较熟悉,一开始 BPE是做为压缩文本的算法而开发的,后来由 OpenAI 在预训练 GPT 模型时用于标记化。这是一种基于统计的分词算法,它将词汇表中的每个单词拆分成子单词,然后将子单词组合成新的单词。这样的话,就可 处理未知词汇,提高模型的泛化能力。很多Transformer 模型中都有应用,比如 GPT、GPT-2、RoBERTa、BART 和 DeBERTa。

那这里的 byte-fallback 又是什么?我们知道分词器在遇到未知词汇时,会将其标记为 Unk,但这样的话,模型就无法学习到未知词汇的信息。Byte-fallback BPE 算法就是为了解决这个问题而提出的,它把所有 256 个 UTF-8 字节码单元添加到词汇表中,任何未知的 unicode 字符可以被分解为字节码单元,这种方法为单词提供了一种独特且意义更加明确的表达形式。

Mistral 中使用了 byte-fallback BPE 算法,词汇表大小为 32K,其中 256 个字节码单元占了 1K,剩下的 31K 是由 BPE 算法生成的。

Sparse Mixtures of Experts

SMoE 是 mixtral-8x7B 的重心之一。

给定输入 xx,MoE 模块的输出 y=i=0n1Gi(x)Ei(x)y=\sum_{i=0}^{n-1} G_i(x) \cdot E_i(x),其中 nn 是专家网络的个数,G(x)G(x) 是第 ii 个专家的权重,Ei(x)E_i(x) 是第 ii 个专家的输出。G(x)G(x) 的每个元素都是一个概率值,且满足 i=0n1G(x)i=1\sum_{i=0}^{n-1} G(x)_i = 1

那 SMoE 中的 S(sparse) 是什么意思呢?其实就是指只有少数专家参与决策,即 G(x)G(x) 中的大部分元素都是 0,只有少部分非零。这样的话,就可以减少计算量。既然只有少数专家参与决策,那就需要一个策略决定哪些专家参与决策,方法有很多种 一种简单高效的方法是对线性层的前 kk 个 logits 应用 softmax 函数:

G(x)=Softmax(TopK(xWg))G(x)=\text{Softmax}(\text{TopK}(x \cdot W_g))

kk 做为一个超参数,可以通过平衡效果与计算量来调整。Mistral 8x7B 中使用的是 k=2k=2,即只有两个专家参与决策。 关于 MoE,huggingface 上有一篇很好的文章 Mixture of Experts Explained,可以参考。

📝 一些与决策相关的工作

Mistral 的 MoE 层结构如下图所示:

Mistral MoE layer 结构图

直接偏好优化

直接偏好优化(Direct Perference Optimization,DPO)是一种 LM 偏好对齐算法,最初 Rafailov 等人在 《Direct Preference Optimization: Your Language Model is Secretly a Reward Model》 中提出。

在此之前,让LM对人类偏好对齐常用的算法是RLHF,思路是先根据人类偏好拟合一个奖励模型,再用强化学习的方法去微调一个LM,使得LM的 出尽可能奖励最大化。但RLHF复杂,且不稳定。对比基于 Proximal Policy Optimization(PPO) 的 RLHF的方案,DPO优点是稳定、效果好、计算量小。

Q & A 环节

Router network 是怎么工作的?怎么选择,怎么组合?

从上面的结构图里可以知道每一个 router 是一个 gating network,它的输入是 xx,输出是 G(x)G(x),即每个专家的权重。

这个 gating network 做为整个模型的一部分,在训练时参数 WgW_g 也是通过反向传播来更新的。一旦训练完成,推理过程中直接使用 $G(x 的值即可。

为什么要用 router network?

理论上nn个专家共同决策,效果更稳定一些(泛化更好),但要平衡计算量和效果,所以只选择了部分专家。Router network 就是这样的一个策略,它决定哪些专家参与决策。

参考资料