app.py
from fastapi import FastAPI, Request, Form
from fastapi.responses import HTMLResponse, StreamingResponse
from fastapi.templating import Jinja2Templates
from langchain_pinecone import PineconeVectorStore
from src.config import Config
from src.helper import download_embeddings
from src.utility import QueryClassifier, StreamingHandler
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_classic.chains import create_retrieval_chain
from langchain_classic.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_community.chat_message_histories import ChatMessageHistory
from src.prompt import system_prompt
import json
import uuid

Config.validate()
PINECONE_API_KEY = Config.PINECONE_API_KEY
GEMINI_API_KEY = Config.GEMINI_API_KEY

templates = Jinja2Templates(directory="templates")

# Initialize FastAPI app
app = FastAPI(title="Medical Chatbot", version="0.0.0")

# Store for session-based chat histories (resets on server restart)
chat_histories = {}
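# NOTE: a plain dict only works for a single-process deployment; with multiple
# uvicorn workers, each process would keep its own diverging copy of the histories.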

# Initialize embedding model
print("Loading the embedding model...")
embeddings = download_embeddings()

# Connect to existing Pinecone index
index_name = Config.PINECONE_INDEX_NAME
print(f"Connecting to Pinecone index: {index_name}")
docsearch = PineconeVectorStore.from_existing_index(
    index_name=index_name, embedding=embeddings
)

# Create retriever from vector store
retriever = docsearch.as_retriever(
    search_type=Config.SEARCH_TYPE, search_kwargs={"k": Config.RETRIEVAL_K}
)
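# RETRIEVAL_K caps how many chunks are stuffed into the prompt per query; a
# small k (e.g. 3-5) keeps the context short at some cost in recall.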

# Initialize Google Gemini chat model
print("Initializing Gemini model...")
llm = ChatGoogleGenerativeAI(
    model=Config.GEMINI_MODEL,
    google_api_key=GEMINI_API_KEY,
    temperature=Config.LLM_TEMPERATURE,
    convert_system_message_to_human=True,
)
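# convert_system_message_to_human folds the system prompt into the first human
# turn, a workaround for Gemini endpoints that reject system messages; newer
# langchain-google-genai releases deprecate this flag.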

# Create chat prompt template with memory
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_prompt),
        MessagesPlaceholder(variable_name="chat_history"),
        ("human", "{input}"),
    ]
)

# Create the question-answer chain
question_answer_chain = create_stuff_documents_chain(llm, prompt)

# Create the RAG chain
rag_chain = create_retrieval_chain(retriever, question_answer_chain)
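# create_retrieval_chain passes the retrieved documents to the stuff chain as
# the "context" prompt variable, so system_prompt is expected to contain a
# {context} placeholder.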

# Function to get the chat history for a session
def get_chat_history(session_id: str) -> BaseChatMessageHistory:
    if session_id not in chat_histories:
        chat_histories[session_id] = ChatMessageHistory()
    return chat_histories[session_id]

# Function to maintain a conversation window buffer (keep the last 5 exchanges)
def manage_memory_window(session_id: str, max_messages: int = 10):
    """Keep only the last max_messages (5 pairs = 10 messages)"""
    if session_id in chat_histories:
        history = chat_histories[session_id]
        if len(history.messages) > max_messages:
            # Keep only the last max_messages
            history.messages = history.messages[-max_messages:]
print("Intialized Medical Chabot successfuly!")
print("Vector Store connected")
@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
"""Render the chatbot interface"""
# Clear all old sessions to prevent memory overflow
chat_histories.clear()
# Generate a new session ID for each page load
session_id = str(uuid.uuid4())
return templates.TemplateResponse(
"index.html", {"request": request, "session_id": session_id}
)
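# NOTE: clearing chat_histories on every page load also wipes any other active
# session; acceptable for a single-user demo, not for shared deployments.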
@app.post("/get")
async def chat(msg: str = Form(...), session_id: str = Form(...)):
"""Handle chat messages and return streaming AI responses with conversation memory"""
# Get chat history for this session
history = get_chat_history(session_id)
# Classify query to determine if retrieval is needed
needs_retrieval, reason = QueryClassifier.needs_retrieval(msg)
async def generate_response():
"""Generator for streaming response"""
full_answer = ""
try:
if needs_retrieval:
# Stream RAG chain response for medical queries
print(f"✓ [RETRIEVAL STREAM] Reason: {reason} | Query: {msg[:50]}...")
async for chunk in StreamingHandler.stream_rag_response(
rag_chain, {"input": msg, "chat_history": history.messages}
):
yield chunk
# Extract full answer from the last chunk
if b'"done": true' in chunk.encode():
import json
data = json.loads(chunk.replace("data: ", "").strip())
if "full_answer" in data:
full_answer = data["full_answer"]
else:
# Stream simple response for greetings/acknowledgments
print(f"[NO RETRIEVAL STREAM] Reason: {reason} | Query: {msg[:50]}...")
simple_resp = QueryClassifier.get_simple_response(msg)
full_answer = simple_resp
async for chunk in StreamingHandler.stream_simple_response(simple_resp):
yield chunk
# Add the conversation to history after streaming completes
history.add_user_message(msg)
history.add_ai_message(full_answer)
# Manage memory window
manage_memory_window(session_id, max_messages=10)
except Exception as e:
print(f"Error during streaming: {str(e)}")
import json
yield f"data: {json.dumps({'error': 'An error occurred', 'done': True})}\n\n"
return StreamingResponse(
generate_response(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no"
}
)
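
# Example client call (hypothetical values), assuming the server runs locally
# on port 7860:
#   curl -N -X POST http://localhost:7860/get \
#        -d "msg=What are the symptoms of anemia?" \
#        -d "session_id=demo-session"
# The -N flag stops curl from buffering the SSE stream.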
if __name__ == "__main__":
import uvicorn
import os
# Use PORT from environment (7860 for HF Spaces, 8080 for Render)
port = int(os.getenv("PORT", 7860))
uvicorn.run(app, host="0.0.0.0", port=port)