-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathchat_with_sqlite.py
More file actions
124 lines (96 loc) · 3.86 KB
/
chat_with_sqlite.py
File metadata and controls
124 lines (96 loc) · 3.86 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
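# Demo: natural-language querying (NLQ) of a SQLite database.
# The simplest form (a single-shot, single-prompt chain) is kept below for reference.
# TODO: replace the model with the CHOP model and API key once available.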
# llm = OllamaLLM(model="llama3.2:latest")
# db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)
# db_chain.invoke("How many isolates are in the database?")
# db_chain.invoke("Can you tell me the names of the isolates with 0 aliquots?")
import os

from IPython.display import Image, display
from langchain.chat_models import init_chat_model
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_community.tools.sql_database.tool import QuerySQLDatabaseTool
from langchain_community.utilities import SQLDatabase
from langchain_core.prompts import ChatPromptTemplate
from langchain_experimental.sql import SQLDatabaseChain
from langchain_ollama.llms import OllamaLLM
from langgraph.graph import START, StateGraph
from langgraph.prebuilt import create_react_agent
from typing_extensions import Annotated, TypedDict
class QueryOutput(TypedDict):
    """Generated SQL query."""
    # Output schema handed to llm.with_structured_output() in write_query.
    # NOTE(review): the class docstring above and the Annotated description
    # string below are presumably forwarded to the model as schema / field
    # descriptions, so treat them as prompt text rather than plain docs.
    # In LangChain's TypedDict convention, `...` in Annotated means "no default"
    # (required field) -- confirm against the LangChain structured-output docs.
    query: Annotated[str, ..., "Syntactically valid SQL query."]
# Shared LangGraph state for the NLQ pipeline: the user's question, the SQL
# generated for it, the raw result of running that SQL, and the final answer.
# Each node returns a dict holding only the keys it fills in.
State = TypedDict(
    "State",
    {
        "question": str,
        "query": str,
        "result": str,
        "answer": str,
    },
)
# Database connection and chat model. Both can be overridden through environment
# variables; the defaults reproduce the original developer setup, so existing
# runs behave exactly as before.
DB_URI = os.environ.get(
    "MARC_DB_URI", "sqlite:////home/ctbus/Penn/marc_web/db.sqlite"
)
LLM_MODEL = os.environ.get("MARC_LLM_MODEL", "mistral-small")

db = SQLDatabase.from_uri(DB_URI)
llm = init_chat_model(LLM_MODEL, model_provider="ollama")

# SQL toolkit tools; only the ReAct-agent branch of the demo below uses them.
toolkit = SQLDatabaseToolkit(db=db, llm=llm)
tools = toolkit.get_tools()
# System instructions for the SQL-writing step. The {dialect}, {table_info},
# {top_k} and {existing_query} placeholders are filled in when the chat prompt
# template is invoked (see write_query).
QUERY_SYSTEM_PROMPT = (
    "\n"
    "You are an expert SQL query builder for {dialect} databases.\n"
    "Use the following database schema:\n"
    "{table_info}\n"
    "Return a syntactically valid SQL statement that answers the user's question.\n"
    "Unless the user specifies a row limit, default to returning up to {top_k} results.\n"
    "If an existing query is provided, treat it as a starting point and refine it if appropriate.\n"
    "Existing query draft (may be empty):\n"
    "{existing_query}\n"
)
# Two-turn chat prompt: the system instructions (variables: dialect, table_info,
# top_k, existing_query) followed by the user's question as the human turn
# (variable: input).
_query_messages = [
    ("system", QUERY_SYSTEM_PROMPT),
    ("human", "{input}"),
]
query_prompt_template = ChatPromptTemplate.from_messages(_query_messages)
def write_query(state: State) -> dict[str, str]:
    """Generate a SQL query answering ``state["question"]``.

    The prompt is filled with the connection's SQL dialect, the table info
    from ``db.get_table_info()``, a default row limit of 10, and any query
    already present in the state (used as a draft to refine).

    Args:
        state: Graph state; ``question`` is required, ``query`` is optional.

    Returns:
        A partial state update ``{"query": <sql>}``. LangGraph merges it into
        the full state, which is why this is not annotated as ``State``.

    Raises:
        ValueError: if the model does not return a structured ``query``.
    """
    prompt = query_prompt_template.invoke(
        {
            # Ask the connection for its dialect instead of hard-coding it, so
            # the prompt stays correct if the database URI changes.
            "dialect": db.dialect,
            "top_k": 10,
            "table_info": db.get_table_info(),
            # `or ""` also covers an explicit None stored under "query".
            "existing_query": state.get("query") or "",
            "input": state["question"],
        }
    )
    structured_llm = llm.with_structured_output(QueryOutput)
    result = structured_llm.invoke(prompt)
    # Structured-output parsing may yield None (or a dict without "query") when
    # the model does not produce the expected structured reply; fail with a
    # clear message instead of an opaque TypeError/KeyError.
    if not result or "query" not in result:
        raise ValueError(
            f"LLM did not return a SQL query for question: {state['question']!r}"
        )
    return {"query": result["query"]}
def execute_query(state: State):
    """Run the SQL stored in ``state["query"]`` against ``db``.

    Returns a partial state update ``{"result": ...}`` containing whatever
    LangChain's ``QuerySQLDatabaseTool`` returns for that statement.
    """
    # NOTE(review): this executes model-generated SQL verbatim; consider a
    # read-only connection if the database must not be modified.
    sql_tool = QuerySQLDatabaseTool(db=db)
    sql = state["query"]
    return {"result": sql_tool.invoke(sql)}
def generate_answer(state: State):
    """Ask the LLM to answer the user's question from the SQL query and its result.

    Returns a partial state update ``{"answer": <model reply text>}``.
    """
    question = state["question"]
    sql = state["query"]
    rows = state["result"]
    instruction = (
        "Given the following user question, corresponding SQL query, "
        "and SQL result, answer the user question."
    )
    prompt = f"{instruction}\n\nQuestion: {question}\nSQL Query: {sql}\nSQL Result: {rows}"
    reply = llm.invoke(prompt)
    return {"answer": reply.content}
# Choose which demo strategy to run:
#   True  -> explicit LangGraph pipeline: write_query -> execute_query -> generate_answer
#   False -> ReAct agent that drives the SQL toolkit's tools itself
USE_GRAPH_PIPELINE = True
question = "What is the average number of aliquots per isolate?"

if USE_GRAPH_PIPELINE:
    # add_sequence adds each function as a node named after it and chains them
    # in order; only the entry edge from START is added explicitly.
    graph_builder = StateGraph(State).add_sequence(
        [write_query, execute_query, generate_answer]
    )
    graph_builder.add_edge(START, "write_query")
    graph = graph_builder.compile()
    # display(Image(graph.get_graph().draw_mermaid_png()))
    # stream_mode="updates": print each node's partial state update as it finishes.
    for step in graph.stream({"question": question}, stream_mode="updates"):
        print(step)
else:
    # query_prompt_template also needs {table_info}, {existing_query} and the
    # human {input}, so formatting it with only dialect/top_k fails on the
    # missing variables. Fill the raw system prompt instead; this also gives
    # the agent a plain system string rather than a rendered chat transcript.
    system_message = QUERY_SYSTEM_PROMPT.format(
        dialect="SQLite",
        top_k=5,
        table_info=db.get_table_info(),
        existing_query="",
    )
    agent_executor = create_react_agent(llm, tools, prompt=system_message)
    # stream_mode="values": each step is the full state; print the newest message.
    for step in agent_executor.stream(
        {"messages": [{"role": "user", "content": question}]},
        stream_mode="values",
    ):
        step["messages"][-1].pretty_print()