WheatField
WheatField

LangChain 小记:LLM 多任务并行处理

October 10, 20241289 words, 7 min read
Authors

在使用 LLM 处理任务时,有时候需要同时执行多个任务,比如同时进行摘要总结、关键词提取及情感分析等。 如果子任务互相独立,并行处理可以显著提高执行效率。

在 Python 中,一个方法是通过 threading 或者 asyncio 写个脚本同时执行也是可以的,不过 LangChain 提供了一种很方便的方式来实现这一点,即是直接使用 RunnableParallel

from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_openai import ChatOpenAI
from langchain_core.runnables import RunnableParallel

llm = ChatOpenAI()

summary_prompt = PromptTemplate(input_variables=["text"], template="Summarize the following text:\n{text}\nSummary:")
tags_prompt = PromptTemplate(input_variables=["text"], template="Extract key tags from the following text:\n{text}\nTags:")


summary_chain = summary_prompt | llm | StrOutputParser()
tags_chain = tags_prompt | llm | StrOutputParser()

parallel_chain = RunnableParallel(summary=summary_chain, tags=tags_chain)
response = parallel_chain.invoke({"text": text})

RunnableParallel 省去了自己写代码管理线程或者协程的麻烦,运行机制也很简单,就是同时 invoke 多个 chain,然后合并结果。 如果是同步代码(e.g., chain.invoke),后台运行时通过 ThreadPoolExecutor 来实现并发,如果是异步代码(e.g., chain.ainvoke),则是通过 asyncio 来实现并发。

以上示例完整代码见:parallel-prompt.py

Chain 的实现

初次使用 LangChain 链式调用时,笔者对 | 操作符甚是困惑,一开始以为是类似形式化方法中的 or 操作符:αβ\alpha | \beta,多个模型中选择一个进行任务处理。后面稍微研究之后,才发现是串行处理。Python 语法中没有 -> 操作符,不然笔者感觉用 prompt -> llm -> output_parser 这种写法更直观一些。

看文档说明,LangChain 最开始是通过 LLMChain 类来构建链式调用,在版本 0.1.17 之后开始推荐使用 RunnableSequence,进而有 prompt | llm | output_parser 这种写法,而 | 对应的链式行为实际上是重载 RunnableSequence 类的 __or__ 方法实现的。

__or__ 运算符重载

众所周知,在 Python 中,可以通过自定义 __or__ 方法来实现运算符重载,常见的几个使用场景:

  1. 集合操作、位操作:合并集合或者是进行位运算,这是最常见的两种用法。
  2. 组合对象:将两个对象组合成一个复合对象,比如字符串拼接。
  3. 函数式编程:串联多个函数以实现管道处理。
class Chainable:
    def __call__(self, x):
        raise NotImplementedError("__call__ method must be implemented in subclass")

    def __or__(self, other):
        return ChainedCallable(self, other)

class ChainedCallable(Chainable):
    def __init__(self, *callables):
        self.callables = callables

    def __call__(self, x):
        for callable in self.callables:
            x = callable(x)
        return x

    def __or__(self, other):
        if isinstance(other, ChainedCallable):
            return ChainedCallable(*self.callables, *other.callables)
        return ChainedCallable(*self.callables, other)

class ClassA(Chainable):
    def __call__(self, x):
        print("ClassA processing:", x)
        return x ** 2

class ClassB(Chainable):
    def __call__(self, x):
        print("ClassB processing:", x)
        return x + 5

class ClassC(Chainable):
    def __call__(self, x):
        print("ClassC processing:", x)
        return x * 2

a, b, c = ClassA(), ClassB(), ClassC()
pipeline = a | b | c
result = pipeline(4)
print("Final result:", result)
# Final result: 42

LangChain 属于第三种,通过巧妙的重载 Python 的运算符,将前一个单元的输出作为下一个单元的输入,从而串联多个单元进行链式处理。LangChain 中 RunnableSequence 类的实现如下:

class RunnableSequence(Runnable):
    def __init__(self, *components):
        self.components = components

    def invoke(self, input: Any) -> Any:
        for component in self.components:
            input = component.invoke(input)
        return input

    def __or__(
        self,
        other: Union[
            Runnable[Any, Other],
            Callable[[Any], Other],
            Callable[[Iterator[Any]], Iterator[Other]],
            Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]],
        ],
    ) -> RunnableSerializable[Input, Other]:
        if isinstance(other, RunnableSequence):
            return RunnableSequence(
                self.first,
                *self.middle,
                self.last,
                other.first,
                *other.middle,
                other.last,
                name=self.name or other.name,
            )
        else:
            return RunnableSequence(
                self.first,
                *self.middle,
                self.last,
                coerce_to_runnable(other),
                name=self.name,
            )

如上所述,执行逻辑就是依次调用每个组件,将前一个组件的输出作为下一个组件的输入。因此当执行 summary_prompt | llm | StrOutputParser() 时,Python 解释器处理流程:

  1. 首先调用 summary_prompt.invoke(input),获得结果 result-1
  2. 然后将 result-1 作为输入调用 llm.invoke(result-1),获得结果 result-2
  3. 最后将 result-2 作为输入调用 StrOutputParser.invoke(result-2),获得最终结果。

多 LLMs 责任链

为提高系统可用性,实际部署项目时会考虑接入多个 LLM API,以防止某个 API 因为网络问题或者速率限制失效。 多个 LLM API 的调用可以串联起来,形成一个责任链,这样当其中一个失效时,责任链中的下一个 LLM 可以自动顶上。 如果代码中使用了 LangChain,那就不用再手撸责任链了,直接使用 with_fallbacks 方法即可。

groq = ChatOpenAI(...)
deepseek = ChatOpenAI(...)
mistral = ChatOpenAI(...)

llm = groq.with_fallbacks([deepseek, mistral])

年少不知框架好,错把原生当宝刀

笔者之前其实是不太喜欢使用 LangChain、LlamaIndex 之类框架或者第三方 SDK 的,基础功能都是自己手写,比如模型的 API 调用、JSON 格式化输出等。主要担心有两点:

  1. 上手成本,新兴框架更新迭代快,需要不断跟踪学习,而且即使不用框架,一些基本功能自己手搓一个也不是什么难事;
  2. 扩展性,如果框架的封装太深,一些定制化的需求不好处理。

但个人项目毕竟体量小,随着项目复杂度增加,需要自己造轮子的地方越来越多,维护扩展也举步维艰,这时候框架的便利性就体现出来了,很多基础常用的功能都有现成的组件可用。一旦熟悉了框架,开发效率提升还是挺明显的。