LangChain での LLM の Stream の基本とカスタマイズ

FacebooktwitterredditpinterestlinkedinmailFacebooktwitterredditpinterestlinkedinmail

LLM の Stream って?

ChatGPTの、1文字ずつ(1単語ずつ)出力されるアレ。あれは別に、時間をかけてユーザーに出力を提供することで負荷分散を図っているのではなく(多分)、 もともと LLM 自体が token 単位で文字を出力するため、それを少しずつユーザーに対して出力することによる UX の向上を図っているのだと思います。混雑時に token 出力にディレイをかけるくらいは、今はやっているかもしれませんが……。

で、通常の方法で LangChain から OpenAI の API を叩くと、API からの出力の完了を待って(同期的に)結果の出力を行います。これを、ChatGPT と同様に Token ごとに出力しようというのが、 LLM API の(と LangChain の) Stream です。

LangChain を使って ChatGPT 風に出力

参考:https://python.langchain.com/en/latest/modules/models/llms/examples/streaming_llm.html

注意点

2023/04/19 現在、 LangChain の Stream は OpenAI API(OpenAI(), ChatOpenAI())と、Anthropic のみの対応です。

他も Loadmap に含まれているようですが、使いたい API が対応しているかは事前に確認しましょう。

基本

基本的な使い方は、  LLM オブジェクトを生成する際に、 streaming のスイッチを True に、callback_manager に適切な CallbackManager を設定します。

from langchain.chat_models import ChatOpenAI
from langchain.callbacks.base import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.schema import HumanMessage

# key に OPEN_API_KEY を入れている
import key, os

os.environ["OPENAI_API_KEY"] = key.OPEN_API_KEY

llm = ChatOpenAI(streaming=True, callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]), verbose=True, temperature=0)
llm(messages=[HumanMessage(content="Write me a song about sparkling water.")])

出力例:

Verse 1:
Bubbles rising to the top
A refreshing drink that never stops
Clear and crisp, it's oh so pure
Sparkling water, I can't ignore

Chorus:
Sparkling water, oh how you shine
A taste so clean, it's simply divine
You quench my thirst, you make me feel alive
Sparkling water, you're my favorite vibe

Verse 2:
No sugar, no calories, just H2O
A drink that's good for me, don't you know
With lemon or lime, it's even better
Sparkling water, you're my forever

Chorus:
Sparkling water, oh how you shine
A taste so clean, it's simply divine
You quench my thirst, you make me feel alive
Sparkling water, you're my favorite vibe

Bridge:
I'll never get tired of your effervescence
You're the perfect drink for any occasion
Whether I'm at home or out on the town
Sparkling water, you never let me down

###### 後略 ########

英語の場合は日本語と異なり、1文字ずつではなく通常、1単語ずつ出力されているのが分かります。

contentを、「炭酸水についての歌詞を書いてください。」に変更して実行すると、

炭酸水、炭酸水
冷たくて爽やかな炭酸水
喉を潤してくれる
夏の暑さも吹き飛ばす

炭酸水、炭酸水
泡立ちが心地よい炭酸水
食事のお供にもぴったり
美味しい時間を過ごせる

#### 後略 ####

(なんだこれ)上記のような謎のポエムが1文字ずつ出力されます。

ここでは、StreamingStdOutCallbackHandler() を CallbackManager として利用しているので、 標準出力に Streaming の内容が表示されます(コードに printなどが入っていないことを確認してください)。

仮に llm の実行を以下のようにprint で囲むと、

print(llm(messages=[HumanMessage(content="炭酸水についての歌詞を書いてください。")]))

streaming が一通り出力された後、

content='炭酸水、炭酸水\n冷たくて爽やかな炭酸水\n #####中略##### いつまでも一緒にいたい\n私たちの大切な炭酸水' additional_kwargs={}

LLM からの出力がまとめて表示されます。ChatGPT と比較して API の実行が遅いのでは? というのは、全て出力されるまで待ってから次の動作に進むからというのが分かると思います。

カスタム CallbackHandler

参考:https://python.langchain.com/en/latest/modules/callbacks/getting_started.html

上記の例では、 LLM からの応答を標準出力に出力していました。組み込みの StreamingStdoutCallbackHandler を利用したからですね。ただ、実際にアプリケーションやサービスとして構築する場合に、標準出力を利用する場合は多くないでしょう。

様々な API や出力先で Callback を行いたいときに利用するのが、 CallbackManager と、 CallbackHandler になります。

いずれも Streaming に限らず、ロギング、モニタリング、その他のタスクに利用可能なコールバックを取り扱います。実際に Callback を処理する CallbackHandler と、CallbackHandler を束ねて管理する CallbackManager に分かれています。

今回の目的では、 CallbackManager をカスタマイズすることで、 streaming の出力先を変更することになります。例えば、 Flask に実装して ChatGPT のような UI の実装に使えるでしょう。

カスタム CallbackHandler の定義

独自の Callback 処理を実装するには、 BaseCallbackHandler クラスを継承したクラスを作成する必要があります。

from typing import Any, Dict, List, Optional, Union

from langchain.agents import initialize_agent, load_tools
from langchain.agents import AgentType
from langchain.callbacks.base import CallbackManager, BaseCallbackHandler
from langchain.llms import OpenAI
from langchain.schema import AgentAction, AgentFinish, LLMResult

class MyCustomCallbackHandler(BaseCallbackHandler):
    """Custom CallbackHandler."""

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """LLM の処理開始。prompt の内容を出力"""
        pass

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """LLM の処理終了。何もしない"""
        pass

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """LLM から新しい Token が出力。いわゆる Streaming の部分"""
        pass

    def on_llm_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """LLM の処理中にエラーが発生"""
        pass

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Chain の処理がスタート"""
        pass

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Chain の処理が終了"""
        pass

    def on_chain_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Chain の実行でエラーが発生"""
        pass

    def on_tool_start(
        self,
        serialized: Dict[str, Any],
        input_str: str,
        **kwargs: Any,
    ) -> None:
        """Tool の実行が開始"""
        pass

    def on_agent_action(
        self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
    ) -> Any:
        """Agent がアクションを実施。Agent の Streaming は大体ここ"""
        print(action)

    def on_tool_end(
        self,
        output: str,
        color: Optional[str] = None,
        observation_prefix: Optional[str] = None,
        llm_prefix: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Tool の使用が終了。Final Answer でなければ[Observation]が出力"""
        print(output)

    def on_tool_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Tool の使用でエラーが発生"""
        pass

    def on_text(
        self,
        text: str,
        color: Optional[str] = None,
        end: str = "",
        **kwargs: Optional[str],
    ) -> None:
        """Agent の終了時に呼び出される。完全に終了したとき(?)。結果の出力"""
        print(text)

    def on_agent_finish(
        self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
    ) -> None:
        """Agent が終了した時に呼び出される。ログの出力"""
        print(finish.log)

LLM だけでなく、 Tool, Agent の Callback も全て網羅したクラスになります。

今回は、炭酸水をイメージして(?)標準出力に色を付けて Streaming と Callback を試してみます。

from typing import Any, Dict, List, Optional, Union

from langchain.callbacks.base import CallbackManager, BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult

from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage

class MyCustomCallbackHandler(BaseCallbackHandler):
    """Custom CallbackHandler."""

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        pass

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        pass

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        '''新しいtokenが来たらprintする'''
        print('\033[36m' + token + '\033[0m')

    def on_llm_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Do nothing."""
        pass

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Print out that we are entering a chain."""
        class_name = serialized["name"]
        print(f"\n\n\033[1m> Entering new {class_name} chain...\033[0m")

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Print out that we finished a chain."""
        print("\n\033[1m> Finished chain.\033[0m")

    def on_chain_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Do nothing."""
        pass

    def on_tool_start(
        self,
        serialized: Dict[str, Any],
        input_str: str,
        **kwargs: Any,
    ) -> None:
        """Do nothing."""
        pass

    def on_agent_action(
        self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
    ) -> Any:
        """Run on agent action."""
        print(action)

    def on_tool_end(
        self,
        output: str,
        color: Optional[str] = None,
        observation_prefix: Optional[str] = None,
        llm_prefix: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """If not the final action, print out observation."""
        print(output)

    def on_tool_error(
        self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Do nothing."""
        pass

    def on_text(
        self,
        text: str,
        color: Optional[str] = None,
        end: str = "",
        **kwargs: Optional[str],
    ) -> None:
        """Run when agent ends."""
        print(text)

    def on_agent_finish(
        self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
    ) -> None:
        """Run on agent end."""
        print(finish.log)

import key, os

os.environ["OPENAI_API_KEY"] = key.OPEN_API_KEY

llm = ChatOpenAI(streaming=True, callback_manager=CallbackManager([MyCustomCallbackHandler()]), verbose=True, temperature=0)
print(llm(messages=[HumanMessage(content="炭酸水についての歌詞を書いてください。")]))

出力例:

凄い鬱陶しい感じに出力されてしまいましたが、 Streaming の出力が変わっていることが分かりやすいのでこれはこれでいいと思います(鬱陶しいのが嫌な方は、print の end=” を利用してください)。

BasecallBackHandler の各メソッドは abstract なので、純粋な Streaming の実験としては on_llm_new_token だけ実装できればいいですが、 MyCustomCallbackHandler では全てのメソッドの実装が必要です。

まとめ

  • stream を行うには LLM の streaming スイッチを True にし、かつ、 CallbackManager に適切な CallbackHandler のインスタンスを渡す必要がある
  • LangChain でStreaming を行う、最も基本的な CallbackHandler は標準出力を利用する StreamingStdOutCallbackHandler
  • 標準出力以外に Streaming を行いたい場合は、 BaseCallbackHandler を継承した独自の CallbackHandler を実装する必要がある。
  • LLM の streaming には on_llm_new_token メソッドを利用する。その他のメソッドで、 LLM だけでなく Tool, Agent などの動作の Streaming も可能になる
FacebooktwitterredditpinterestlinkedinmailFacebooktwitterredditpinterestlinkedinmail

コメントを残す

メールアドレスが公開されることはありません。 が付いている欄は必須項目です

このサイトはスパムを低減するために Akismet を使っています。コメントデータの処理方法の詳細はこちらをご覧ください

最新の記事