Skip to content

Commit 5fff272

Browse files
committed
[feat] add custom chat
1 parent f3d4216 commit 5fff272

File tree

11 files changed

+185
-0
lines changed

11 files changed

+185
-0
lines changed

src/api/controllers/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .chat import new_message as chat_new_message
2+
3+
__all__ = ["chat_new_message"]

src/api/controllers/chat.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from src.infrastructure.database import MongoDB, get_service_details
2+
from src.infrastructure.config import LLM
3+
from src.services import CustomChat
4+
5+
6+
async def new_message(
    db: MongoDB,
    model: LLM,
    message: str,
    service_name: str,
) -> str:
    """Handle one chat turn for the given service.

    Looks up the service's stored system prompt in the database; when the
    service is unknown, falls back to a generic assistant prompt. The
    message is then run through a fresh CustomChat instance.
    """
    # Default prompt used when no service configuration is found.
    prompt = "You are a helpful assistant. Be kind!"
    details = get_service_details(service_name, db)
    if details:
        prompt = details["prompt"]

    chat = CustomChat(model=model, sys_prompt=prompt)
    return await chat(message)

src/api/models/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .api import APIResponse, APIRequest
2+
3+
__all__ = ["APIResponse", "APIRequest"]

src/api/models/api.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from typing import Optional
2+
from pydantic import BaseModel
3+
4+
5+
class APIResponse(BaseModel):
    """Standard response envelope returned by the API routes."""

    # Application-level status code carried in the body (e.g. 200 / 500).
    status_code: int
    # Optional human-readable status text (used for errors and health checks).
    status_message: Optional[str] = None
    # Optional payload; keys depend on the route that builds the response.
    response: Optional[dict] = None
9+
10+
11+
class APIRequest(BaseModel):
    """Incoming chat request body."""

    # The user's chat message.
    message: str
    # Identifier of the requesting user.
    user_id: str
    # Optional service whose stored prompt drives the chat; when omitted
    # or unknown, a default prompt is used by the controller.
    service_name: Optional[str] = None

src/api/routes/chat.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
from fastapi import APIRouter, status, Request, Depends
2+
3+
from src.api.models import APIResponse, APIRequest
4+
from src.api.controllers import chat_new_message
5+
6+
7+
# Router for all chat endpoints, mounted under the /chat prefix.
router = APIRouter(
    prefix="/chat",
    tags=["chat"],
    # dependencies=[Depends(validate_user)]
)
12+
13+
14+
@router.get("/", status_code=status.HTTP_200_OK)
async def router_test() -> APIResponse:
    """Health-check endpoint confirming the chat router is mounted."""
    result = APIResponse(
        status_code=200,
        status_message="-- CHAT ROUTER WORKING! --",
    )
    return result
20+
21+
22+
@router.post("/new_message", status_code=status.HTTP_200_OK)
async def new_message(api_request: APIRequest, req: Request) -> APIResponse:
    """Process a chat message and return the AI reply.

    Delegates to the chat controller using the app-wide database and LLM
    instances stored on the FastAPI application object.
    """
    try:
        ai_reply = await chat_new_message(
            req.app.database,
            req.app.llm,
            api_request.message,
            api_request.service_name,
        )
        return APIResponse(
            status_code=200,
            response={
                "user": api_request.message,
                "ai": ai_reply,
            },
        )
    except Exception as e:
        # NOTE(review): the HTTP status remains 200 while the body reports
        # 500 — confirm clients rely on the body's status_code field.
        return APIResponse(
            status_code=500,
            status_message=f"Error: {e}",
        )
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from .settings import settings
2+
from .llm import LLM
3+
4+
__all__ = ["settings", "LLM"]

src/infrastructure/config/llm.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from .settings import settings
2+
3+
from langchain_ollama.llms import OllamaLLM
4+
from langchain_openai import ChatOpenAI
5+
6+
7+
class LLM:
    """Factory for chat-model instances configured from application settings.

    Note: ``__new__`` deliberately returns an instance of the selected
    backend class (not an ``LLM``), so ``LLM("ollama")`` yields an
    ``OllamaLLM`` and ``LLM("openai")`` a ``ChatOpenAI``.

    Raises:
        ValueError: if ``model_name`` is unsupported, or if the backend
            fails to instantiate (original error chained as the cause).
    """

    # Backends currently wired up; extend this when adding models below.
    _SUPPORTED = ("ollama", "openai")

    def __new__(cls, model_name: str):
        # Validate up front so an unsupported name raises a clear error
        # instead of being caught and re-wrapped by the except below
        # (the original code double-wrapped its own ValueError).
        if model_name not in cls._SUPPORTED:
            raise ValueError(f"Model {model_name} not supported")

        try:
            if model_name == "ollama":
                return OllamaLLM(
                    model=settings.MODEL_NAME,
                    base_url=settings.MODEL_URL,
                    temperature=settings.MODEL_TEMPERATURE
                )

            # model_name == "openai"
            return ChatOpenAI(
                model=settings.MODEL_NAME,
                base_url=settings.MODEL_URL,
                temperature=settings.MODEL_TEMPERATURE,
                api_key=settings.MODEL_API_KEY
            )
            # More models can be added here

        except Exception as e:
            # Chain the original exception so the real cause is preserved.
            raise ValueError(
                f"Problem instantiating the model {model_name}: {e}"
            ) from e
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from pydantic_settings import BaseSettings
2+
3+
4+
class Settings(BaseSettings):
    """Application configuration loaded from the environment / .env file."""

    # MongoDB
    MONGO_USER: str
    MONGO_PASSWORD: str
    MONGO_HOST: str = "localhost"
    MONGO_PORT: str = "27017"
    MONGO_DB: str

    # ChromaDB
    CHROMA_HOST: str = "localhost"
    CHROMA_PORT: str = "8000"
    CHROMA_DB: str

    # General Settings
    TIMEZONE: str = "America/Sao_Paulo"

    # LLM
    MODEL: str = "ollama"  # backend selector ("ollama" or "openai")
    MODEL_NAME: str = "llama3"
    MODEL_URL: str = "http://localhost:11434"
    MODEL_TEMPERATURE: float = 0.2
    MODEL_API_KEY: str = ''  # only passed to the OpenAI backend

    class Config:
        # Read values from a local .env file; ignore unrelated variables.
        env_file = ".env"
        extra = "ignore"


# Singleton settings instance shared across the application.
settings = Settings()

src/services/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .custom_chat.chat import CustomChat
2+
3+
__all__ = ["CustomChat"]

src/services/custom_chat/chat.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from src.infrastructure.config import LLM
2+
from langchain_core.prompts import ChatPromptTemplate
3+
from langchain_core.tools import BaseTool
4+
from typing import List
5+
6+
7+
class CustomChat:
    """Thin chat wrapper pairing an LLM with a fixed system prompt."""

    def __init__(self, model: LLM, sys_prompt: str):
        """Build the prompt template and the runnable chain once.

        Args:
            model: Language model instance to run messages through.
            sys_prompt: System prompt prepended to every user message.
        """
        self.model = model
        self.sys_prompt = sys_prompt
        self._prompt_template = ChatPromptTemplate.from_messages([
            ("system", self.sys_prompt),
            ("user", "{input}")
        ])
        # Build the chain once instead of on every call; rebuild here if
        # add_tools ever starts mutating the model/chain.
        self.chain = self._prompt_template | self.model

    async def __call__(self, user_input: str):
        """Run one user message through the chain and return the reply."""
        # TO DO: add bind_tools
        # Use the async API so the event loop is not blocked: the original
        # synchronous invoke() inside this async method stalled other tasks.
        response = await self.chain.ainvoke({"input": user_input})
        return response

    def add_tools(self, tools: List[BaseTool]):
        # Not implemented yet; intentionally a no-op (see TO DO above).
        pass

0 commit comments

Comments
 (0)