WheatField
WheatField

AI 推理加速利器:提示缓存技术解析

August 15, 20241756 words, 9 min read
Authors

Claude 最近也支持提示缓存(prompt caching)了,为什么说“也”,因为前几日 DeepSeek 就率先在 API 模型上支持了硬盘缓存,把本来就白菜价的 API 价格又压了一个数量级。在缓存命中的情况下,百万 tokens 仅 0.1 元,四舍五入等于不要钱。DeepSeek 已经成了众多开发者的默认测试 API,不是 GPT-4o 用不起,而是 DeepSeek 更有性价比。估计后面会有更多平台跟进,不得不说,"AI 界拼多多"在产品体验这块儿跑的还挺快。

当然,提示缓存不是新概念,23 年就由 《Prompt Cache: Modular Attention Reuse for Low-Latency Inference》 提到过,跟 KV cache 的想法比较类似,本质上都是空间换时间的思路。Prompt cache 不仅能缓存上下文、提示模板、系统消息,也能缓存文档,在不损失推理效果的前提下,做到尽可能的提高推理速度、降低延迟

Claude 这次主要缓存的是上下文,适用于长对话、编程辅助,及使用同一个长提示作为模板进行批量任务等场景。举个栗子,假如每次调用 API 时都使用一个 5000 字的长 PROMPT,告诉模型它是谁、遵守哪些规则、以什么形式返回结果、参考哪些示例等,如果短时间内(官方数据是 5 分钟内)有一万次 API 调用,那这个 PROMPT 就只计费一次,这样费用上立省一台 4090,延迟也能大幅降低。

原理

那 prompt caching 是怎么实现的呢?

概括的说,即是复用多条输入重叠部分的注意力状态(attention states),减少计算量。也即是在推理阶段,以文本片段(text segment)为单位,把频繁出现的文本段对应的注意力状态存储下来,在下次推理时,如果遇到相同的文本段,就直接使用上次存储的注意力状态。

序列内缓存——KV cache

稍微回顾一下 KV cache 是怎么工作的。众所周知,在推理阶段,LLM 会把生成的 token 跟输入拼接起来,继续生成下一个 token,直到生成结束。假如初始序列 s={s1,s2,...,sn}s=\{s_1, s_2, ..., s_n\},然后逐渐生成 kk 个 token {sn+1,sn+2,...,sn+k}\{s_{n+1}, s_{n+2}, ..., s_{n+k}\}

在自回归模型最初始版本中,每生成一个新的 token 时都需要重复计算注意力状态 {(k1,v1),(k2,v2),...,(kn,vn)}\{(k_1, v_1), (k_2, v_2), ..., (k_n, v_n)\},既然如此,那能不能把这些注意力状态存储起来,在生成下一个 token 时,直接复用上次存储的注意力状态呢?

当然可以,这就是 KV cache 的想法。给定输入 s={s1,s2,...,sn}s=\{s_1, s_2, ..., s_n\},先将注意力层计算出来的键值对存储起来,后面推理时直接使用(如下图 (b))。

跨序列的缓存——Prompt cache

KV cache 的局限性在于它只能在同一条序列中使用,而 prompt cache 更进一步,把 attention states 的复用从单条序列中解放出来,并用到其他序列中。

从上图可以看到,对比 KV cache,Prompt cache 多了一个组件,不难想到,这个组件得负责处理两件事:

  1. 复用:在推理过程中遇到一个新的序列时,如何从 cache 中直接得到 attention states?
  2. 缓存:另外,cache 本身又是怎么来的?

Cache 的诞生

如果不参考已有方案,让我们自己去设计,跨序列的缓存应该怎么实现?

当遇到一条新的序列时,我们首先要“拆解”它才能考虑复用,也就是先把它按特定的“结构”进行拆分,找出重复的部分,然后复用这些“结构”对应的注意力状态。那作为 LLM 的输入,一条序列有哪些重复的部分?

以日常的对话为例,LLM API 输入通常包含系统提示、元模板、背景文档、对话历史等元素,以及用户自定义的部分。

(SystemPrompt | MetaTemplate | Context | ConversationHistory | ...)+UserPrompt(\text{SystemPrompt | MetaTemplate | Context | ConversationHistory | ...})^* + \text{UserPrompt}

这其中,系统提示、元模板、背景文档等都是通用的元素,同一类型的任务仅在用户自定义的部分不同,因此一个合理的想法就当把这些通用的元素看作上面提及的特定“结构”。接下来要做的事情就是把这个结构抽象出来,并把它形式化。

编码系统

要形式化这个系统,一种方法就是使用一套编码系统将 prompt 的结构识别出来并进行编码。论文中定义了提示标记语言(Prompt Markup Language, PML),将一类 prompt(通常对应一类任务)定义成一个模式(schema),这个模式将一个“结构”定义成一个模块(module)。即形如:

schema=module\text{schema} = \text{module}^*

每一条序列都可以通过这个 schema 重新编码成一个新的 prompt。比如我们关于一个任务的 prompt 是这样的:

(SystemPrompt | MetaTemplate | Context | Examples | Others)+UserPrompt(\text{SystemPrompt | MetaTemplate | Context | Examples | Others})^* + \text{UserPrompt}

其中,系统提示、元模板、背景文档、示例等是预定义的模块,多个任务之前可以共用。UserPrompt 是用户自定义的,每次任务不同。我们可以为这类任务定义一个可能的如下的 schema:

<schema name="TaskPrompt">
  <module name="SystemPrompt">...</module>
  <module name="MetaTemplate">...</module>
  <module name="Context">...</module>
  <module name="Examples">...</module>
</schema>

这个 schema 定义了若干个可供用户直接调用的模块,比如一个 SystemPrompt 模块编码的内容可能为:

You are a helpful assistant.

假设我们有一条包含 SystemPrompt、Context、Examples 的输入序列:

PROMPT = """
SystemPrompt
Context
Examples

Query: q.
Answer: ?
"""

通过以上 schema 定义,我们可以将这条序列编码成如下形式的 prompt:

<prompt schema="TaskPrompt">
  <SystemPrompt/>
  <Context/>
  <Examples/>
  Query: q. Answer: ?
</prompt>

在推理时,prompt cache 处理流程:

  • 检索缓存的注意力状态: 从缓存中获取 SystemPrompt, Context, Examples 的注意力状态。
  • 处理新文本: 用户输入的文本都未缓存,需要重新计算注意力状态。上例中的 Query: q. Answer: ? 就是未缓存的部分。
  • 合并注意力状态: 按顺序拼接 SystemPrompt + Context + Examples + UserPrompt 的注意力状态。
  • 生成响应: 使用合并后的注意力状态来生成 LLM 的响应。

编码方式

在模型支持 prompt caching 的情况下,主要有两种方式。一种是手动改写:用户根据 schema 的定义,手动将输入序列改写成上面形式的 prompt。

另一种方式是自动生成,从提示程序自动生成(Python-to-PML 编译): 论文提到可以自动将 Python 函数转换为相应的 PML schema。

小结

本文探讨了 prompt cache 的基本原理及实现,一种方式是通过 PML 形式化输入序列,定义并可复用部分。在推理时,只计算未缓存的文本。