Skip to content

Commit e5efdb8

Browse files
committed
[feat] add guardrails with llama guard 3
1 parent 3cb6bac commit e5efdb8

File tree

7 files changed

+89
-10
lines changed

7 files changed

+89
-10
lines changed

src/api/controllers/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from .chat import new_message as chat_new_message
2+
from .guardrails import Guardrail
23

3-
__all__ = ["chat_new_message"]
4+
__all__ = ["chat_new_message", "Guardrail"]

src/api/controllers/guardrails.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
from fastapi import HTTPException, Request
2+
3+
from src.api.models import APIRequest
4+
from src.services.llama_guard import LlamaGuard
5+
from src.infrastructure.database import (
6+
MongoDB,
7+
get_user_details,
8+
block_user
9+
)
10+
11+
12+
class Guardrail:
    """FastAPI dependency that screens every incoming chat request.

    Two checks run per request:
      1. Reject requests from users already flagged as blocked in the DB.
      2. If a Llama Guard instance is configured on the app, classify the
         message and block the user when the content is unsafe.
    """

    def __call__(self, api_request: APIRequest, req: Request) -> None:
        # Fix: consult the block list FIRST — an already-blocked user must
        # be rejected without spending an LLM call on moderation.
        self.check_user(api_request.user_id, req.app.database)
        # `req.app.llama_guard` is falsy when the guard model is unavailable,
        # in which case moderation is skipped gracefully.
        if req.app.llama_guard:
            self.llama_guard_layer(
                api_request.message,
                req.app.llama_guard,
                api_request.user_id,
                req.app.database
            )

    def llama_guard_layer(
        self,
        message: str,
        llama_guard: LlamaGuard,
        user_id: str,
        db: MongoDB
    ) -> None:
        """Classify *message* with Llama Guard; block the user on violation.

        Raises:
            HTTPException: 400 when the content is flagged as unsafe
                (the user is blocked in the database as a side effect).
        """
        response = llama_guard(message)
        # `llama_guard(...)` returns True for "safe"; a falsy result means
        # the content violated policy.
        if not response:
            _ = block_user(user_id, db)

            # NOTE(review): 403 Forbidden may fit better than 400 here —
            # kept as 400 to preserve the existing API contract.
            raise HTTPException(
                status_code=400,
                detail=f"""
                O conteúdo fornecido viola as políticas da plataforma.
                O seu usuário foi bloqueado.

                Retorno do LLAMA GUARD: {response}
                """
            )

    def check_user(self, user_id: str, db: MongoDB) -> None:
        """Raise HTTP 400 if the user is flagged as blocked in the database."""
        user_details = get_user_details(user_id, db)
        if user_details and user_details.get("blocked"):
            raise HTTPException(
                status_code=400,
                detail="""
                O usuário está bloqueado devido a violação de políticas.
                """
            )

    # More Guardrails could be added here.

src/api/routes/chat.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
from fastapi import APIRouter, status, Request, Depends
22

33
from src.api.models import APIResponse, APIRequest
4-
from src.api.controllers import chat_new_message
4+
from src.api.controllers import chat_new_message, Guardrail
55

66

77
router = APIRouter(
88
prefix="/chat",
99
tags=["chat"],
10-
# dependencies=[Depends(validate_user)]
10+
dependencies=[
11+
Depends(Guardrail())
12+
]
1113
)
1214

1315

src/infrastructure/config/settings.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ class Settings(BaseSettings):
2525
MODEL_TEMPERATURE: float = 0.2
2626
MODEL_API_KEY: str = ''
2727

28+
# LlamaGuard
29+
LLAMA_GUARD_MODEL: str = "llama-guard3"
30+
2831
class Config:
2932
env_file = ".env"
3033
extra = "ignore"

src/main.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,19 @@
11
from fastapi import FastAPI
22

33
from src.infrastructure.database import MongoDB
4-
from src.api.routes import chat_router
5-
from src.infrastructure.config import settings
64
from src.infrastructure.config.llm import LLM
5+
from src.services.llama_guard import LlamaGuard
6+
7+
from src.api.routes import chat_router
78

89

910
def create_app():
1011
app = FastAPI()
1112

1213
# defining API variables
13-
app.database = MongoDB(db_name=settings.MONGO_DB)
14-
app.llm = LLM(model_name=settings.MODEL)
15-
16-
# app.vector_store = ChromaDB()
17-
# app.llm = choose_model(model_name=settings.MODEL)
14+
app.database = MongoDB()
15+
app.llm = LLM()
16+
app.llama_guard = LlamaGuard()
1817

1918
# including routes
2019
app.include_router(chat_router)
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .llama_guard import LlamaGuard
2+
3+
__all__ = ["LlamaGuard"]
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from langchain_ollama.llms import OllamaLLM
2+
from src.infrastructure.config import settings
3+
4+
5+
class LlamaGuard:
    """Thin wrapper around an Ollama-served Llama Guard moderation model.

    The instance is falsy (via ``__bool__``) when the underlying model
    could not be initialised, so callers can write
    ``if app.llama_guard:`` to skip moderation gracefully.
    """

    def __init__(self) -> None:
        try:
            self.llm = OllamaLLM(
                model=settings.LLAMA_GUARD_MODEL,
                base_url=settings.MODEL_URL,
            )
        except Exception:
            # Fix: the original bare ``except: return None`` left the
            # instance without an ``llm`` attribute while the object stayed
            # truthy, so the caller's `if app.llama_guard:` guard always
            # passed and later calls crashed with AttributeError. Record
            # the failure explicitly instead.
            self.llm = None

    def __bool__(self) -> bool:
        # Falsy when the model is unavailable — lets callers disable the guard.
        return self.llm is not None

    def __call__(self, message: str) -> bool:
        """Return True when Llama Guard classifies *message* as safe."""
        response = self.llm.invoke(message)
        # Llama Guard 3 replies "safe" or "unsafe\n<category>"; strip
        # surrounding whitespace/newlines before comparing.
        return response.strip() == "safe"

0 commit comments

Comments
 (0)