ConversationalRetrievalChain で LangChain の QA にチャット履歴実装

FacebooktwitterredditpinterestlinkedinmailFacebooktwitterredditpinterestlinkedinmail

LangChain では、 EmbeddingAPI を使って vector search とその結果を LLM に与えて QA Bot を構築したり、あるいは ChatGPT のような記憶・履歴(Memory)を実装して、自然な対話を行う便利なモジュールが揃っています。しかし、この Memory のオブジェクトを QA チェーンに入れても、それだけでは上手く動作しません。

そこで今回は、まず QA を一つの Chain オブジェクトで実行する方法と、さらに Memory を実装するについてです。

参考:https://python.langchain.com/en/latest/modules/chains/index_examples/vector_db_qa.html

参考:https://python.langchain.com/en/latest/modules/chains/index_examples/chat_vector_db.html

検索と回答をまとめる RetrievalQA

RetrievalQA は、 vector index の内容を元に QA の応答を行う Chain です。 load_qa_chain との違いがなかなか分かりにくいですが、load_qa_chain の方は回答の元となるドキュメントを渡す必要があるのに対し、 RetrivalQA は Query や vector_index を直接渡します。

極論すると、 load_qa_chain は SQL でも web検索でも、あるいは(様々な制約が許せば)百科事典全編を参考ドキュメントとして渡すことが許されますが、 RetrriebalQA は vector search ができなければならないという違いがあります。 RetrievalQA は融通が利かないですが、 Agent や Sequantial な Chain に入れる場合には重宝すると思います。ソースコードも短くて済みますし。

import と vector index の準備

# Embedding用
from langchain.embeddings.openai import OpenAIEmbeddings
# Vector 格納 / FAISS
from langchain.vectorstores import FAISS
# Q&A用Chain
from langchain.chains import RetrievalQA
# ChatOpenAI GPT 3.5
from langchain.chat_models import ChatOpenAI
from langchain.prompts import PromptTemplate
# env に読み込ませるAPIキーの類
import key

# 環境変数にAPIキーを設定
import os
os.environ["OPENAI_API_KEY"] = key.OPEN_API_KEY

embeddings = OpenAIEmbeddings()

db = FAISS.load_local("faiss_index", embeddings)

DB は ローカル環境に保存されたものを使っています。 vector index の作り方は以前の記事を確認してください。

Embedding, LLM の他は、Chains から RetrievalQA を import すれば大丈夫です。さほど複雑ではないですね。

RetrievalQA を動かしてみる

qa = RetrievalQA.from_chain_type(llm=ChatOpenAI(), chain_type="stuff", retriever=db.as_retriever())
query = "A社の主要な作物は何?"
print(qa.run(query))
# A社の主要な作物は、サツマイモ、レタス、トマト、苺、トウモロコシなどです。

from_chain_type メソッドで QA Chain を作成します。引数として、 LLM, chain_type(詳細は過去記事)、それから retriever が必要です。retriever は vector index のオブジェクトのメソッド、as_retriever で取得します。ここでは、 FAISS で load_local した(中小企業診断士の試験問題)ドキュメントの vector index の retriever を渡しています。

後は、query として質問を渡せば動作します。similarity_search と load_qa_chain を組み合わせるより簡単ですね。

カスタムプロンプトを読み込ませる

LangChain で困るのが無理矢理 LLMChain に色々入れて、Promptを変えようとすると、Parse Error が出てしまうことです。issue として出されてもいますが、適切なモジュールを使っていない場合というのもありそうです。

QA の場合は、 RetrievalQA に渡す引数を工夫することで実現できます。

prompt_template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.

{context}

Question: {question}
Answer in Japanese:"""
PROMPT = PromptTemplate(
    template=prompt_template, input_variables=["context", "question"]
)

chain_type_kwargs = {"prompt": PROMPT}
qa = RetrievalQA.from_chain_type(llm=ChatOpenAI(), chain_type="stuff", retriever=db.as_retriever(), chain_type_kwargs=chain_type_kwargs)

query = "A社の工場はどこにある?"
print(qa.run(query))
# 自社工場はどこにあるか記載されていないため、わかりません。 

RetrievalQA に渡す prompt には、ユーザーからの question の他、参考ドキュメントを入れる context が必要になります。

そして、 prompt については、引数に直接 prompt= と入れるのではなく、いったん chain_thype_kwargs という dict の中に prompt というキーで格納します。必要なキーワード引数(keywords arguments, kwargs)が入った dict を改めて from_chain_type メソッドに渡してインスタンスを取得します。恐らく、内部的に load_qa_chain などに引数をそのまま引き渡すためにこのような構造になっているのだと思います。

質問を行うと、prompt で指示された通り(If you don’t know the answer, just say that you don’t know, don’t try to make up an answer.)、回答を無理矢理作成するのではなく不明と返答してくれていることが分かります。

会話として動作させるにはConversationalRetrievalChain

RetrievalQA には Memory を処理する仕組みがありません。そのため、ChatGPT のように記憶を処理させるには別のモジュール、 ConversationalRetrievalChain を利用する必要があります。ただ、この ConversationalRetrievalChain も、他の LangChain に用意されている Memory モジュールをそのまま利用するものとは違うため、取り扱いに少し注意が必要です。

とりあえず import とか

# Embedding用
from langchain.embeddings.openai import OpenAIEmbeddings
# Vector 格納 / FAISS
from langchain.vectorstores import FAISS
# Q&A用Chain
from langchain.chains import ConversationalRetrievalChain
# ChatOpenAI GPT 3.5
from langchain.chat_models import ChatOpenAI
# env に読み込ませるAPIキーの類
import key

# 環境変数にAPIキーを設定
import os
os.environ["OPENAI_API_KEY"] = key.OPEN_API_KEY

embeddings = OpenAIEmbeddings()

db = FAISS.load_local("faiss_index", embeddings)

QA 用の Chain が RetrievalQA から変更になっただけです。 LangChain の Memory 系モジュールは不要(使えない)点に注意してください。

動かしてみる

qa = ConversationalRetrievalChain.from_llm(ChatOpenAI(temperature=0), db.as_retriever())
# 空のchat履歴
chat_history = []
query = "A社の主要な作物は何?"
# 複数の引数をdictで渡す
result = qa({"question": query, "chat_history": chat_history})
print(result["answer"])
# A社の主要な作物は、サツマイモ、レタス、トマト、イチゴ、トウモロコシなどです。特に、イチゴの栽培で知られ、高い糖度と大粒で形状や色合いが良く人気があります。

chat_history = [(query, result["answer"])]
query = "それらの作物を加工して出来た商品は何?"
result = qa({"question": query, "chat_history": chat_history})

print(result["answer"])
# A社は、自社工場で外部取引先からパン生地を調達し、自社栽培の新鮮で旬の野菜(トマトやレタスなど)やフルーツを使ったサンドイッチや総菜商品などを製造し、既存の大手中食業者を含めた複数の業者に卸しているということです。つまりA社の作物を加工して作られた商品は、サンドイッチや総菜商品などです。

使い方は、ほぼ RetrievalQA と一緒です。ここでは chain_type を指定していない(厳密には、直接指定できない)ので、インスタンスを取得するメソッドが from_llm となっており、必要な引数が LLM と retriever のみの順序引数となっています。

続いて、空の chat_history を 空のリスト型として作成しています。 ConversationalRetrievalChain では chat_history をシンプルなリストとしてしか受け取りません。なので、 import でも特別なライブラリやモジュールが追加されていませんでした。

QA の実行には、 qa chain に直接、dict 型の引数を渡します。ここではユーザーからの query と、過去の履歴(空)となっています。

続いて、ユーザーからのクエリと、得られた回答(result[“answer”])をペアとした tuple だけを含めたリストを、chat_history に格納しています。実際のコーディングでは、リストに対して insert していくといった処理になるかと思います(ただし、context の順序的に最初から末尾に追加した方がいい場合もある)。

QA を実行すると、先の回答の文脈を使って、「中食業者に卸している製品」の回答が得られています。

なんで Memory 使わないの?

筆者も最初、 Memory 関連にサンプルコードや API があるだろうと探していたのですが見つからず、やっと見つかったと思ったら、リスト内にタプルが格納されただけのシンプルな構造で混乱しました。

が、恐らく、 LangChain の Memory には ConversationalMemoryBuffer のような 1:1 の対話型以外の記録形式も内包しているため、質問に対して応答という形式を重視する RetrievalQA 型の実装にはあわなかったからだと推測されます。 Agent 内に格納し、 Memory を共有する場合などは、後述する get_chat_history 関数などを利用する必要があると思います。

回答の信頼性の閾値を設定したりソースを取得したり

# Source も返す
qa = ConversationalRetrievalChain.from_llm(ChatOpenAI(temperature=0), db.as_retriever(), return_source_documents=True)
# vector store が探していれば、 search distance に閾値を設定してフィルタがかけられる
vectordbkwargs = {"search_distance": 0.9}
# 空のchat履歴
chat_history = []
query = "A社の主要な作物は何?"
# 複数の引数をdictで渡す, search_distance はさらに vectordbkwargs というdict で渡す
result = qa({"question": query, "chat_history": chat_history, "vectordbkwargs": vectordbkwargs})
print(result["answer"])
# result['source_documents'][0] ソースドキュメントを出力する

vector index での検索は、多次元空間内のベクトル間の距離で行われます。そのため、キーワード自体に一致が見られなくても、言葉の意味する内容が近かったり、文脈が一致する場合などに検索結果として出力することが可能になっています。ただ、あくまでも距離が近いものが出力されるため、思ったより品質が低い内容が参考されてしまう場合もあります。特に、似たような機能が多いソフトウェアのマニュアルでは顕著でしょう(クラウド会計ソフトのマニュアルを Google 検索した場合、微妙に違う内容ばかりひっかかって、2ページ目以降の個人サイトの方が役に立つというのはよくある話です)。

そこで、ベクトル間の距離を閾値としたフィルターを設定し、関連度がより強いものしか参照しないようにできます。ここでは、 vectordbkwargs 内のdictに、 search_distance というキー名で格納します。このdict は、RetrievalQA に Prompt を設定したときとほぼ同じで、引数の名前が異なります(今度は、 vector index を格納する DB に直接渡す引数であるから vectordbkwargs ですね)。

また、ConversationalRetievalChain で、回答の元になったドキュメントを取得するには、最初のインスタンス化の際に return_source_documents のスイッチを True にします。すると、 QA の結果として返ってくるdict 内に、 source_documents というキーで格納されます。

Prompt を設定したり chain_type を設定したりする。

ConversationalRetrievalChain は、 load_qa_chain をシンプルにラップしているだけのようで、 chain_type の設定なども from_llm によるインスタンス取得からでは直接、行えません(各chain_type について)。Prompt についても同様ですが、 Prompt についてはもうちょっと複雑です。

その代わり、ConbersationalRetrievalChain 内で使われる qa_chain そのものを設定 / 変更することでこれらの各種機能にアクセスしたりできます。

import とか

# Embedding用
from langchain.embeddings.openai import OpenAIEmbeddings
# Vector 格納 / FAISS
from langchain.vectorstores import FAISS
# Q&A用Chain
from langchain.chains import ConversationalRetrievalChain
from langchain.chains import LLMChain
from langchain.chains.question_answering import load_qa_chain
# ChatOpenAI GPT 3.5
from langchain.chat_models import ChatOpenAI

# Prompt 0.0.128現在、CONDENSE_QUESTION_PROMPT の import が公式ドキュメントだと間違っているので注意
from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT
from langchain.prompts import PromptTemplate
# env に読み込ませるAPIキーの類
import key

# 環境変数にAPIキーを設定
import os
os.environ["OPENAI_API_KEY"] = key.OPEN_API_KEY

embeddings = OpenAIEmbeddings()

db = FAISS.load_local("faiss_index", embeddings)

import するモジュールとしては、直接使用する必要があるため、 LLMChain, load_qa_chain を追加しています。

また、Promptを作成するための PromptTemplate と、標準で用意されている CONDENSE_QUESTION_PROMPT を import しています。

CONDENSE_QUESTION_PROMPT は、condense (凝縮する)の名の通り、質問やチャットの履歴を圧縮する prompt になっています(ConversationSummarizeBufferMemory が Summarize する代わりに condense する)。

prompt の作成

llm = ChatOpenAI(temperature=0) # 最近ChatOpenAIにもtemperature が実装されたらしい
# QuestionGenerator として LLMChainに llm と prompt を渡す
question_generator = LLMChain(llm=llm, prompt=CONDENSE_QUESTION_PROMPT)
# print(CONDENSE_QUESTION_PROMPT)
# input_variables=['chat_history', 'question'] output_parser=None partial_variables={} 
# template='Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question.\n\nChat History:\n{chat_history}\nFollow Up Input: {question}\nStandalone question:'
# template_format='f-string' validate_template=True
#
#次の会話に対しフォローアップの質問があるので、フォローアップの質問を独立した質問に言い換えなさい。
#
#チャット履歴
#{chat_history}
#
#フォローアップの入力: {query}
#独立した質問:


# QA の最終質問のprompt(stuff用)
prompt_template = """Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.

{context}

Question: {question}
Answer in Japanese:"""
PROMPT = PromptTemplate(
    template=prompt_template, input_variables=["context", "question"]
)

import した LLMChain ですが、これは question_generator という前処理に使用します。そして、この question_generator が利用する prompt が CONDENSE_QUESTION_PROMPT となっています。興味深かったので、 dump したものと和訳したものをコメントとして記述しています。

英語特有の表現(follow up, standalone)があってちょっと和訳が難しいのですが、「つまり、この人の質問の意図は何か、そのものズバリの質問を作成しなさい」ということです。

本記事の例でいえば、

History
A社の主要な作物は何?

A社の主要な作物は、サツマイモ、レタス、トマト、イチゴ、トウモロコシなどです。特に、イチゴの栽培で知られ、高い糖度と大粒で形状や色合いが良く人気があります。

それらの作物を加工して出来た商品は何?

つまり、サツマイモ、レタス、トマト、イチゴ、トウモロコシなどから製造されている商品は?

ということになります。

また、 chain_type が stuffing 用の Prompt も例示しています(map_reduce と併用はできません)。

カスタマイズしたインスタンスを作成

# ConversationalRetrievalChain で利用する load_qa_chain.
doc_chain = load_qa_chain(llm, chain_type="map_reduce")
# promptを利用する stuffing のQAチェイン
# doc_chain = load_qa_chain(llm, chain_type="stuff", prompt=PROMPT)
qa = ConversationalRetrievalChain(
    retriever=db.as_retriever(), # Source の vector store
    question_generator=question_generator, # 要するに、何? という質問を作る chain
    combine_docs_chain=doc_chain, # つまりユーザーの質問に対する答えは、という load_qa_chain
)

combine_docs_chain に、 chain_typeを map_reduce にした(あるいは、カスタマイズされた prompt を使用した stuffing の)load_qa_chain を与えています。これで、 ConversationalRetrievalChain に対し、map_reduce を使わせたり、 prompt を使わせたりが可能になります。

更に、 question_generator に先ほど作成した、 CONDENSE_QUESTION_PROMPT を利用した LLMChain を入れています。これが、 前処理として動作することで、 load_qa_chain に文脈を持たせつつ、一番新しい質問の意図をそこまで希釈させずに回答させることが期待されます。

出力

chat_history = []
query = "A社の主要な作物は何?"
# 複数の引数をdictで渡す, search_distance はさらに vectordbkwargs というdict で渡す
result = qa({"question": query, "chat_history": chat_history})
print(result["answer"])
# A社は最初の文書によれば、主要な作物としてサツマイモ、レタス、トマト、苺、トウモロコシを栽培しています。ただし、特に苺は施設園芸用ハウスで栽培され、人気が高く、売上高を拡大しています。

文脈が変化し、 chain_type が map_reduce になっていることがなんとなく想像できますね。

まとめ

  • vector index search をチャットに仕込むのは大変
  • similarity_search と load_qa_chain をひとまとめに処理するには、RetrievalQA Chain が便利。
  • チャットの履歴を使うには、 ConvaerstationalRetrievalChain を使用する
    • 履歴は、 Memory オブジェクトではなく、リストに格納された質問と応答がペアになったタプルを利用する
      • Memory オブジェクトは、色々な形式があるため
    • question_generator という前処理を担当する LLM によって、これまでの文脈を含めつつ質問の意図が希釈されないように変換が行われている
    • stuff 以外の chain_type を使いたい場合や prompt を変えたい場合は、 load_qa_chain を自力で作成する
  • ちゃんとやると思ったより面倒臭かった
FacebooktwitterredditpinterestlinkedinmailFacebooktwitterredditpinterestlinkedinmail

コメントを残す

メールアドレスが公開されることはありません。 が付いている欄は必須項目です

このサイトはスパムを低減するために Akismet を使っています。コメントデータの処理方法の詳細はこちらをご覧ください

最新の記事