AI 推理加速利器:提示缓存技术解析
- Authors
- @SLIPPERTOPIA
Claude 最近也支持提示缓存(prompt caching)了,为什么说“也”,因为前几日 DeepSeek 就率先在 API 模型上支持了硬盘缓存,把本来就白菜价的 API 价格又压了一个数量级。在缓存命中的情况下,百万 tokens 仅 0.1 元,四舍五入等于不要钱。DeepSeek 已经成了众多开发者的默认测试 API,不是 GPT-4o 用不起,而是 DeepSeek 更有性价比。估计后面会有更多平台跟进,不得不说,"AI 界拼多多"在产品体验这块儿跑的还挺快。
当然,提示缓存不是新概念,23 年就由 《Prompt Cache: Modular Attention Reuse for Low-Latency Inference》 提到过,跟 KV cache 的想法比较类似,本质上都是空间换时间的思路。Prompt cache 不仅能缓存上下文、提示模板、系统消息,也能缓存文档,在不损失推理效果的前提下,做到尽可能的提高推理速度、降低延迟。
Claude 这次主要缓存的是上下文,适用于长对话、编程辅助,及使用同一个长提示作为模板进行批量任务等场景。举个栗子,假如每次调用 API 时都使用一个 5000 字的长 PROMPT,告诉模型它是谁、遵守哪些规则、以什么形式返回结果、参考哪些示例等,如果短时间内(官方数据是 5 分钟内)有一万次 API 调用,那这个 PROMPT 就只计费一次,这样费用上立省一台 4090,延迟也能大幅降低。
原理
那 prompt caching 是怎么实现的呢?
概括的说,即是复用多条输入重叠部分的注意力状态(attention states),减少计算量。也即是在推理阶段,以文本片段(text segment)为单位,把频繁出现的文本段对应的注意力状态存储下来,在下次推理时,如果遇到相同的文本段,就直接使用上次存储的注意力状态。
序列内缓存——KV cache
稍微回顾一下 KV cache 是怎么工作的。众所周知,在推理阶段,LLM 会把生成的 token 跟输入拼接起来,继续生成下一个 token,直到生成结束。假如初始序列 ,然后逐渐生成 个 token 。
在自回归模型最初始版本中,每生成一个新的 token 时都需要重复计算注意力状态 ,既然如此,那能不能把这些注意力状态存储起来,在生成下一个 token 时,直接复用上次存储的注意力状态呢?
当然可以,这就是 KV cache 的想法。给定输入 ,先将注意力层计算出来的键值对存储起来,后面推理时直接使用(如下图 (b))。
跨序列的缓存——Prompt cache
KV cache 的局限性在于它只能在同一条序列中使用,而 prompt cache 更进一步,把 attention states 的复用从单条序列中解放出来,并用到其他序列中。
从上图可以看到,对比 KV cache,Prompt cache 多了一个组件,不难想到,这个组件得负责处理两件事:
- 复用:在推理过程中遇到一个新的序列时,如何从 cache 中直接得到 attention states?
- 缓存:另外,cache 本身又是怎么来的?
Cache 的诞生
如果不参考已有方案,让我们自己去设计,跨序列的缓存应该怎么实现?
当遇到一条新的序列时,我们首先要“拆解”它才能考虑复用,也就是先把它按特定的“结构”进行拆分,找出重复的部分,然后复用这些“结构”对应的注意力状态。那作为 LLM 的输入,一条序列有哪些重复的部分?
以日常的对话为例,LLM API 输入通常包含系统提示、元模板、背景文档、对话历史等元素,以及用户自定义的部分。
这其中,系统提示、元模板、背景文档等都是通用的元素,同一类型的任务仅在用户自定义的部分不同,因此一个合理的想法就当把这些通用的元素看作上面提及的特定“结构”。接下来要做的事情就是把这个结构抽象出来,并把它形式化。
编码系统
要形式化这个系统,一种方法就是使用一套编码系统将 prompt 的结构识别出来并进行编码。论文中定义了提示标记语言(Prompt Markup Language, PML),将一类 prompt(通常对应一类任务)定义成一个模式(schema),这个模式将一个“结构”定义成一个模块(module)。即形如:
每一条序列都可以通过这个 schema 重新编码成一个新的 prompt。比如我们关于一个任务的 prompt 是这样的:
其中,系统提示、元模板、背景文档、示例等是预定义的模块,多个任务之前可以共用。UserPrompt
是用户自定义的,每次任务不同。我们可以为这类任务定义一个可能的如下的 schema:
<schema name="TaskPrompt">
<module name="SystemPrompt">...</module>
<module name="MetaTemplate">...</module>
<module name="Context">...</module>
<module name="Examples">...</module>
</schema>
这个 schema 定义了若干个可供用户直接调用的模块,比如一个 SystemPrompt 模块编码的内容可能为:
You are a helpful assistant.
假设我们有一条包含 SystemPrompt、Context、Examples 的输入序列:
PROMPT = """
SystemPrompt
Context
Examples
Query: q.
Answer: ?
"""
通过以上 schema 定义,我们可以将这条序列编码成如下形式的 prompt:
<prompt schema="TaskPrompt">
<SystemPrompt/>
<Context/>
<Examples/>
Query: q. Answer: ?
</prompt>
在推理时,prompt cache 处理流程:
- 检索缓存的注意力状态: 从缓存中获取 SystemPrompt, Context, Examples 的注意力状态。
- 处理新文本: 用户输入的文本都未缓存,需要重新计算注意力状态。上例中的
Query: q. Answer: ?
就是未缓存的部分。 - 合并注意力状态: 按顺序拼接 SystemPrompt + Context + Examples + UserPrompt 的注意力状态。
- 生成响应: 使用合并后的注意力状态来生成 LLM 的响应。
编码方式
在模型支持 prompt caching 的情况下,主要有两种方式。一种是手动改写:用户根据 schema 的定义,手动将输入序列改写成上面形式的 prompt。
另一种方式是自动生成,从提示程序自动生成(Python-to-PML 编译): 论文提到可以自动将 Python 函数转换为相应的 PML schema。
小结
本文探讨了 prompt cache 的基本原理及实现,一种方式是通过 PML 形式化输入序列,定义并可复用部分。在推理时,只计算未缓存的文本。