LangChain を使って自然言語で SQL データベースを操作する【GPT-3.5-turbo】

FacebooktwitterredditpinterestlinkedinmailFacebooktwitterredditpinterestlinkedinmail

今回は、 LangChain を使って SQLite を直接操作する方法を試してみます。

GPT-4 をはじめ、大規模言語モデル(LLM)は非常の強力なツールですが、事前に学習されたデータを元に回答を作成します。自社ドキュメントを Embedding したりGoogle 検索を利用したり、あるいはこれらを使い分けたりして学習されていないデータを元にした回答を作成する方法もあります。いわゆる ChatGPT を自社ビジネスに利用したいとなると真っ先に検討される内容と言えるのではないでしょうか。

一方で、生成系としての使い方には別の面もあります。出力される文字列を最終的なアウトプットとして利用するのではなく、別のツールの input として利用する考え方です。 ChatGPT にプログラムを書かせたり、SQL を書いてもらったりというのがこの方式であると言えます。

LangChain から SQLite を操作する場合、(ChatGPTのような汎用的なインターフェースを使う場合と比較し)大きく分けてメリットが3つあります。

  1. SQL 文を SQLite に入れるという作業がないので、SQL を全く知らないユーザーでも使える
  2. データベースの構造などのメタ情報を LangChain がよろしくやってくれるので、Prompt を深く考えなくていい
  3. 使用ツールの判断まで LLM 任せにできる

3 は、今回行う Chain での実行方式ではそこまでメリットを感じないかもしれませんが、 1, 2の時点で大分強力です。もちろん、 LLM に SQL を書かせることのメリット(作業負担の軽減、学習負担の軽減など)もあります。

一方で、デメリットというか危険性も理解しておく必要があります。

  1. prompt injection や予期せぬ命令によってデータベースが破壊される可能性がある
  2. LLM API にデータベースの内容が送信される

1 については、広く一般にデータベースにアクセスできる LLM を開放することはないと思うので、 prompt injection はそれほど心配しなくてもいいかもしれません。しかし、ユーザーが誤って変なpromptを入力してしまい、データベースに悪影響を与える可能性は考慮すべきでしょう。

2 については、特に秘匿性の高いデータベース(個人情報など)を扱う場合には重要です。SQLDatabaseChain に return_direct=True とスイッチを与えることで、 SQLデータベースからの応答は LLM に送信されなくなります。しかし、それでもデータベースのschema (table や key)は送信されてしまうのでR&Dなどで table 名や key名すら秘匿する必要がある場合は使用しないでください。

参考:https://langchain.readthedocs.io/en/latest/modules/chains/examples/sqlite.html

LangChainの SQLChain を使って SQLite を操作する

テーブルの内容

参考ページにサンプルデータベースの紹介もあったようですが、今回は ChatGPT に作って貰いました。

テーブルの内容

Students:
(1, ‘Taro’, 20)
(2, ‘Hanako’, 19)
(3, ‘Jiro’, 21)

Courses:
(1, ‘Math’)
(2, ‘English’)
(3, ‘Physics’)

Grades:
(‘Taro’, ‘Math’, 85.0)
(‘Taro’, ‘English’, 90.0)
(‘Taro’, ‘Physics’, 88.0)
(‘Hanako’, ‘Math’, 75.0)
(‘Hanako’, ‘English’, 80.0)
(‘Hanako’, ‘Physics’, 82.0)
(‘Jiro’, ‘Math’, 92.0)
(‘Jiro’, ‘English’, 95.0)
(‘Jiro’, ‘Physics’, 89.0)

※後でteachers も追加

動作させてみる

# SQL
from langchain import SQLDatabase, SQLDatabaseChain
# GPT-3.5-turbo
from langchain.chat_models import ChatOpenAI

# env に読み込ませるAPIキーの類
import key
# 環境変数にAPIキーを設定
import os
os.environ["OPENAI_API_KEY"] = key.OPEN_API_KEY

# データベース読み込み
db = SQLDatabase.from_uri("sqlite:///school_database.db")
# LLM
llm = ChatOpenAI()

# データベース Chain
db_chain = SQLDatabaseChain(llm=llm, database=db, verbose=True)

print(db_chain.run("生徒は何人いる?"))

やることは、 SQLDatabase モジュールで、まずはデータベース自体への参照を Python 上に持ちます。ただし、これは Python に標準でインストールされている SQLite3 ライブラリではなく、LangChain で拡張されたものであることに注意してください。後々出てきますが、 LLM がデータベースを操作できるように内部のメタデータなどを含んだオブジェクトになっています。

データベースへの参照と LLM への参照が得られたら、 SQLDatabaseChain を作成します。今回は動作を分かりやすくするため、 verbose スイッチに True を指定します。

結果

> Entering new SQLDatabaseChain chain…
生徒は何人いる?
SQLQuery:SELECT COUNT(*) FROM students
SQLResult: [(3,)]
Answer:生徒は3人います。
> Finished chain.
生徒は3人います。

ChatGPT であれば、どんなテーブルがあるか、といった情報を提供しなければならなかったですが、LangChain で SQLDatabaseChain を利用した場合「生徒は何人いる?」という端的な質問だけで SQL を発行してくれていることが分かります。

また、テーブル名は「students」と英語にも関わらず、日本語の指示から正しいテーブルを選択しています。実用もコーディングも楽ですね。

カスタムテンプレートを使って見る

実際には標準のテンプレートでの動作でも充分だと思いますが、チャット型への拡張なども考えると、テンプレートの変更も出来た方がいいでしょう。

ここではデフォルトのテンプレートを、 GPT-3.5-turbo 用に System用とHuman用に分割して再構築してみます(これだけで prompt injection を防げるものではないですが、間違ってテーブルを削除されないような文言をSystem Roleで書いておくと気休めくらいにはなるでしょう)。

# カスタムテンプレート用(Chat特化)
from langchain.prompts.prompt import PromptTemplate
from langchain.prompts.chat import (
    # メッセージテンプレート
    ChatPromptTemplate,
    # System メッセージテンプレート
    SystemMessagePromptTemplate,
    # assistant メッセージテンプレート
    AIMessagePromptTemplate,
    # user メッセージテンプレート
    HumanMessagePromptTemplate,
)
from langchain.schema import (
    # それぞれ GPT-3.5-turbo API の assistant, user, system role に対応
    AIMessage,
    HumanMessage,
    SystemMessage
)
_DEFAULT_TEMPLATE = """Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
Use the following format:

Question: "Question here"
SQLQuery: "SQL Query to run"
SQLResult: "Result of the SQLQuery"
Answer: "Final answer here"

Only use the following tables:

{table_info}

If someone asks for the table hogehoge, they really mean the student table.
"""
user_template = "Question: {input}"

system_message_prompt = SystemMessagePromptTemplate.from_template(_DEFAULT_TEMPLATE)
human_message_prompt = HumanMessagePromptTemplate.from_template(user_template)

chat_prompt = ChatPromptTemplate.from_messages([system_message_prompt, human_message_prompt])
chat_prompt.input_variables = ["input", "table_info", "dialect"]

db_chain = SQLDatabaseChain(llm=llm, database=db, prompt=chat_prompt, verbose=True)
print(db_chain.run("hogehogeテーブルに生徒は何人いる?"))

標準のprompt に加えて、students テーブルの別名として hogehoge という名前を指導しています。

> Entering new SQLDatabaseChain chain…
hogehogeテーブルに生徒は何人いる?
SQLQuery:SELECT COUNT(*) FROM students
SQLResult: [(3,)]
Answer:3
> Finished chain.
3

問題なく動作しますが、なんだか答えがぶっきらぼうになってしまいました(特にreturn_direct スイッチなどは使っていないのですが…)。

また、prompt には、 input, table_info, dialect の3引数は必須となっています。

検索数を制限する

# 上位3つの結果
db_chain = SQLDatabaseChain(llm=llm, database=db, verbose=True, top_k=3)
print(db_chain.run("上位の点数を見せて"))

上位の点数を見せて
SQLQuery:SELECT students.name, courses.name, grades.grade FROM grades
JOIN students ON grades.student_id=students.id
JOIN courses ON grades.course_id=courses.id
ORDER BY grades.grade DESC LIMIT 3
SQLResult: [(‘Jiro’, ‘English’, 95.0), (‘Jiro’, ‘Math’, 92.0), (‘Taro’, ‘English’, 90.0)]
Answer:Here are the top 3 grades:
Jiro got 95 in English, Jiro got 92 in Math, and Taro got 90 in English.
> Finished chain.
Here are the top 3 grades:

top_k=に制限数を入れれば、無制限に SQLの結果が返され、 LLM の Token を大幅に使用してしまうことを防止できます。

LLM に渡すデータベースの情報を制限する

# 標準のデータベース読み込み
db = SQLDatabase.from_uri("sqlite:///school_database.db")
print(db.table_info)
print("----------------------------------------------")
# サンプルを指定してデータベース読み込み
db = SQLDatabase.from_uri("sqlite:///school_database.db",
                          include_tables=['grades'], # 成績テーブルだけサンプルにいれて token の節約
                          sample_rows_in_table_info=2)
print(db.table_info)

CREATE TABLE courses (
id INTEGER,
name TEXT NOT NULL,
PRIMARY KEY (id)
)
/*
3 rows from courses table:
id name
1 Math
2 English
3 Physics
*/

CREATE TABLE students (
id INTEGER,
name TEXT NOT NULL,
age INTEGER NOT NULL,
PRIMARY KEY (id)
)
/*
3 rows from students table:
id name age
1 Taro 20
2 Hanako 19
3 Jiro 21
*/

CREATE TABLE grades (
id INTEGER,
student_id INTEGER NOT NULL,
course_id INTEGER NOT NULL,
grade REAL NOT NULL,
PRIMARY KEY (id),
FOREIGN KEY(student_id) REFERENCES students (id),
FOREIGN KEY(course_id) REFERENCES courses (id)
)
/*
3 rows from grades table:
id student_id course_id grade
1 1 1 85.0
2 1 2 90.0
3 1 3 88.0
*/
———————————————-

CREATE TABLE grades (
id INTEGER,
student_id INTEGER NOT NULL,
course_id INTEGER NOT NULL,
grade REAL NOT NULL,
PRIMARY KEY (id),
FOREIGN KEY(student_id) REFERENCES students (id),
FOREIGN KEY(course_id) REFERENCES courses (id)
)
/*
2 rows from grades table:
id student_id course_id grade
1 1 1 85.0
2 1 2 90.0
*/

データベースのテーブル数が多すぎる場合、また、 LLM に操作・APIで送信してほしくないテーブルが存在する場合には、 include_tables に、LLM に情報を与えたいテーブル名を列挙します。これで、 token の節約や操作されては困るテーブルの情報を除外することが可能になります(もちろん、テーブル名を直接指定されてしまった場合などはこの限りではないですが)。

同様に、samples_row_in_table_infoに数値を入れることで、 LLM に提供するサンプルデータの数を制限することもできます。数値しかない、key名から(恐らく)自明である場合などは数値を小さくすると token を節約できます。

自分でtable 情報を構築する

custom_table_info = {
    "grades": """CREATE TABLE grades (
	"id" INTEGER,
    "student_id" INTEGER NOT NULL,
    "course_id" INTEGER NOT NULL,
    "grade" REAL NOT NULL, 
	PRIMARY KEY ("id"),
    FOREIGN KEY("student_id") REFERENCES students ("id"),
    FOREIGN KEY("course_id") REFERENCES courses ("id")
)
/*
3 rows from grades table:
id	studend_id	course_id   grade
3   1    3  88.0
4   2    1  75.0
7   3    1  92.0
*/"""
}

# カスタマイズされたテーブル情報を含めてデータベース読み込み
db = SQLDatabase.from_uri("sqlite:///school_database.db",
                         include_tables=['grades', 'students'],
                         sample_rows_in_table_info=2,
                         custom_table_info=custom_table_info)
print(db.table_info)

CREATE TABLE students (
id INTEGER,
name TEXT NOT NULL,
age INTEGER NOT NULL,
PRIMARY KEY (id)
)
/*
2 rows from students table:
id name age
1 Taro 20
2 Hanako 19
*/

CREATE TABLE grades (
“id” INTEGER,
“student_id” INTEGER NOT NULL,
“course_id” INTEGER NOT NULL,
“grade” REAL NOT NULL,
PRIMARY KEY (“id”),
FOREIGN KEY(“student_id”) REFERENCES students (“id”),
FOREIGN KEY(“course_id”) REFERENCES courses (“id”)
)
/*
3 rows from grades table:
id studend_id course_id grade
3 1 3 88.0
4 2 1 75.0
7 3 1 92.0
*/

custom_table_info にテーブル名をキーとする dict を入れることで完全にコントロールされたテーブル情報を LLM に与えることができます。 dict 内にテーブル名が見つからなかったテーブルは標準機能で出力されます(include_tables と sample_rows_in_table_infoの制限内で)。とはいえ、大きくテーブルの schema などは書き換えない方がいいでしょう。

使い方としては、顧客の個人情報(氏名、メールアドレス、住所、暗号化されたパスワード、支払情報など)が含まれるテーブルを LLM 経由で扱いたい場合に、 sample_rows をダミーデータに変更して情報漏洩の防止に努めるといった使い方が考えられます。

SequentialChainで実行してみる

db = SQLDatabase.from_uri("sqlite:///school_database.db")
from langchain.chains import SQLDatabaseSequentialChain
db_chain = SQLDatabaseSequentialChain.from_llm(llm=llm, database=db, verbose=True)
print(db_chain.run("生徒と先生はあわせて何人いる?"))

> Entering new SQLDatabaseSequentialChain chain…
Table names to use:
[‘students’, ‘teachers’]

> Entering new SQLDatabaseChain chain…
生徒と先生はあわせて何人いる?
SQLQuery:SELECT COUNT(*) FROM students UNION SELECT COUNT(*) FROM teachers
SQLResult: [(3,), (5,)]
Answer:生徒と先生はあわせて8人いる。

(Note: The query calculates the total number of students and teachers separately using the COUNT function, and then unions the two results together to get a single result with the total number of people.)
> Finished chain.

> Finished chain.
生徒と先生はあわせて8人いる。

(Note: The query calculates the total number of students and teachers separately using the COUNT function, and then unions the two results together to get a single result with the total number of people.)

今回のサンプルデータベースのような小さなデータベースでは効果が薄い(全くない)ですが、シーケンシャルに意志決定を行ってデータベースを操作します。

このチェーンは、

  1. LLM への Query を元にして、使用するテーブルを決定
  2. 1.で決定されたテーブル情報に基づいて、通常の SQLデータベースチェーンを実行

となっています。単にデータベースが巨大というよりは、テーブル数が非常に多い場合(正規化され、リレーショナルが複雑だったり)に有効となっています。実ビジネスだと、取引に紐付く情報が顧客や担当者、請求書だけでなく、組み立て用のパーツだったり、あるいは工員だったりと複雑に絡み合う局面では有効かと思われます。

まとめ

  • SQLite を LangChain 経由で操作すると、プログラムを作る人も使う人もとても楽
  • SQLに直接アクセスするに等しいので、セキュリティやコンプライアンスに注意を払う必要がある
  • 様々な制限を加える場合には、promptによる制限だけではなく、ソースコードからの「ハードな」制約を設定した方がよい
  • 特に LangChain の SQLiteDatabaseChain には、 LLM に送信される情報を細かく制御できる仕組みが備わっている
FacebooktwitterredditpinterestlinkedinmailFacebooktwitterredditpinterestlinkedinmail

コメントを残す

メールアドレスが公開されることはありません。 が付いている欄は必須項目です

このサイトはスパムを低減するために Akismet を使っています。コメントデータの処理方法の詳細はこちらをご覧ください

最新の記事