浅聊 Mixtral-8x7B 的一些亮点
- Authors
- @SLIPPERTOPIA
Mistral.ai 在 arXiv 上放出来论文 Mixtral of Experts,整体结构上跟 Mistral 7B 一样,唯一的区别在于每个层都由 8 个前馈块(即专家)组成。
Mistral 7B
首先聊一下 Mistral 7B(论文链接: https://arxiv.org/pdf/2310.06825.pdf, arXiv 23.10),这篇工作在当时的亮点有:
- 性能:- 如下图所示, Mistral 7B 在多项基准测试中都优于当时最佳开源模型 Llama 2 13B,在推理、数学和代码生成方面优于 Llama 1 34B - 提供了一个经过指令微调的模型,Mistral 7B-Instruct,在 Chatbot Arena Leaderboard 上得分 1031 分(论文发布时),超过了 Llama 2 13B 的 1012 分。 排行榜上的得分是通过人类对模型生成的对话打分来计算的,分数越高,模型生成的对话越接近人类的对话。
- 结构特点: 利用分组查询注意力(GQA)来实现更快的推理,并结合滑动窗口注意力(SWA)来有效地处理任意长度的序列,同时降低推理成本。
做为对比,截止到 2024/01/20, Mistral 8x7b instruct 在 Chatbot Arena Leaderboard 仍然排在前十的位置,前一段时间的排名更高,居第4名,仅次于 GPT-4。Mistral 的另一个模型 Mistral medium 效果更好,在 Arena ELO 上排到第 5 名,在 MT-bench 上得分 8.61 分,除了 GPT 4 之外,超过了所有其他模型,但这个模型目前还没有开源。 有意思的是,Claud 的两个新版本 2.0, 2.1 反而一代不如一代。
注:MT-bench, multi-turn question set, 是一个用于评估模型多轮对话及指令遵循能力的数据集。目前在这个 set 上,能力最强的模型是 GPT-4-Turbo,得分 9.32 分。
当然,强如 mistral medium,也还是没有通过“鲁迅为什么要打周树人”的幻觉测试。
Sliding Window Attention (SWA)
在 vanilla attention 中,每个 token 都会与所有其他 token 进行交互,这样的话,计算复杂度就是 ,其中 是序列长度。这样的计算复杂度在序列很长时会很高,一种更高效的方法是 Sliding Window Attention (SWA)。
SWA 最早在 LongFormer中提出,属于 sparse attention 的一种,其目标是让 LLM 更有效的处理长序列。思想是把每个 token 的注意力跨度限制到它周围的固定窗口中,LongFormer中的做法是定义一个宽度为的窗口,使得 query 节点只能注意到对等的 key 节点以及 key nodes 左右 个节点。这样的话,计算复杂度就是 ,其中 是序列长度。当 ,计算复杂度就会大大降低。
这样,其他 key nodes 的信息是不是丢失了?也不是,多个层堆叠在一起时,在高层上,query 节点会间接的注意到远处的 key 节点。假设 层,每层的窗口大小为 ,那么在第 层,query 节点可以间接的注意到 范围内的 key 节点。
Mistral 7B 中的 window size 是 4096.
Grouped Query Attention (GQA)
分组查询注意力(GQA,论文 :https://arxiv.org/pdf/2305.13245.pdf )是在多查询注意力(Multi-Query Attention,MQA)和多头注意力(Multi-Head Attention,MHA)之间找到一个平衡点,其目标是在保持MQA速度的同时实现MHA的质量。
MQA 使用单一的键值头,这能加快解码器的推理速度,但可能导致解码质量下降。具体来说,MQA减少了模型在执行注意力操作时可利用的信息量 从而可能导致在处理复杂任务或长序列生成时的性能下降。例如,在文本摘要任务中,MHA允许模型在计算注意力时考虑多个不同的特征组合 从而捕捉更丰富的语言特征。而MQA由于头的限制,可能无法捕获这样的复杂特征,导致生成的摘要可能不够准确。
GQA 将查询头分为 G 组,每个组共享一个键头和值头。GQA 通过减少解码器推理期间所需的内存带宽来提高 LLMs 的效率。
优势
相对 MHA 来说,GQA 显著降低 LLM的计算复杂性,从而提高推理速度;对内存的占用也更少,适用于限制内存大小的LLMs。相对 MQA 来说,GQA 提高了解码质量,但速度稍慢。还有一点,MQA 支持多GPU并行,有效地利用计算资源。
Byte-fallback BPE algorithm
BPE算法大家应该都比较熟悉,一开始 BPE是做为压缩文本的算法而开发的,后来由 OpenAI 在预训练 GPT 模型时用于标记化。这是一种基于统计的分词算法,它将词汇表中的每个单词拆分成子单词,然后将子单词组合成新的单词。这样的话,就可 处理未知词汇,提高模型的泛化能力。很多Transformer 模型中都有应用,比如 GPT、GPT-2、RoBERTa、BART 和 DeBERTa。
那这里的 byte-fallback 又是什么?我们知道分词器在遇到未知词汇时,会将其标记为 Unk,但这样的话,模型就无法学习到未知词汇的信息。Byte-fallback BPE 算法就是为了解决这个问题而提出的,它把所有 256 个 UTF-8 字节码单元添加到词汇表中,任何未知的 unicode 字符可以被分解为字节码单元,这种方法为单词提供了一种独特且意义更加明确的表达形式。
Mistral 中使用了 byte-fallback BPE 算法,词汇表大小为 32K,其中 256 个字节码单元占了 1K,剩下的 31K 是由 BPE 算法生成的。
Sparse Mixtures of Experts
SMoE 是 mixtral-8x7B 的重心之一。
给定输入 ,MoE 模块的输出 ,其中 是专家网络的个数, 是第 个专家的权重, 是第 个专家的输出。 的每个元素都是一个概率值,且满足 。
那 SMoE 中的 S(sparse) 是什么意思呢?其实就是指只有少数专家参与决策,即 中的大部分元素都是 0,只有少部分非零。这样的话,就可以减少计算量。既然只有少数专家参与决策,那就需要一个策略决定哪些专家参与决策,方法有很多种 一种简单高效的方法是对线性层的前 个 logits 应用 softmax 函数:
做为一个超参数,可以通过平衡效果与计算量来调整。Mistral 8x7B 中使用的是 ,即只有两个专家参与决策。 关于 MoE,huggingface 上有一篇很好的文章 Mixture of Experts Explained,可以参考。
📝 一些与决策相关的工作
- Unified scaling laws for routed language models
- Dselect-k: Differentiable selection in the mixture of experts with applications to multi-task learning
- CommonsenseQA: A Question Answering Challenge Targeting Commonsense Knowledge
Mistral 的 MoE 层结构如下图所示:
直接偏好优化
直接偏好优化(Direct Perference Optimization,DPO)是一种 LM 偏好对齐算法,最初 Rafailov 等人在 《Direct Preference Optimization: Your Language Model is Secretly a Reward Model》 中提出。
在此之前,让LM对人类偏好对齐常用的算法是RLHF,思路是先根据人类偏好拟合一个奖励模型,再用强化学习的方法去微调一个LM,使得LM的 出尽可能奖励最大化。但RLHF复杂,且不稳定。对比基于 Proximal Policy Optimization(PPO) 的 RLHF的方案,DPO优点是稳定、效果好、计算量小。
Q & A 环节
Router network 是怎么工作的?怎么选择,怎么组合?
从上面的结构图里可以知道每一个 router 是一个 gating network,它的输入是 ,输出是 ,即每个专家的权重。
这个 gating network 做为整个模型的一部分,在训练时参数 也是通过反向传播来更新的。一旦训练完成,推理过程中直接使用 $G(x 的值即可。
为什么要用 router network?
理论上个专家共同决策,效果更稳定一些(泛化更好),但要平衡计算量和效果,所以只选择了部分专家。Router network 就是这样的一个策略,它决定哪些专家参与决策。