Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions api/ai/ai_response_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ class StreamRequest(BaseModel):
assistant_id: UUID
target_language: Optional[str] = None
prompt: List[str] = Field(min_length=1)
model: str = Field(min_length=1)
model: List[str] = Field(min_length=1)


class WorkflowResult(BaseModel):
output_text: str
Expand All @@ -44,6 +45,11 @@ class ResponseMetadata(BaseModel):


class StreamResponse(BaseModel):
    """Workflow output for a single model run.

    Carries the identifier of the model that produced the output, the
    per-prompt results, timing metadata, and any errors collected while
    the workflow executed.
    """

    model: str
    results: List[WorkflowResult]
    metadata: ResponseMetadata
    # NOTE: the original declared `errors` twice; the duplicate
    # re-declaration was redundant and has been removed.
    errors: List[Any]


class MultiModelResponse(BaseModel):
    """Envelope aggregating one ``StreamResponse`` per requested model."""
    responses: List[StreamResponse]
52 changes: 34 additions & 18 deletions api/ai/ai_service.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import AsyncGenerator
import asyncio
from typing import AsyncGenerator, List

from api.Assistant.assistant_repository import get_assistant_by_id_repository
from api.Assistant.assistant_response_model import ContextRequest
Expand All @@ -10,7 +11,8 @@
WorkflowResult,
ResponseMetadata,
AvailableModelsResponse,
ModelInfo
ModelInfo,
MultiModelResponse,
)
from api.db.pg_database import SessionLocal
from api.llm.router import get_model_router
Expand All @@ -37,55 +39,68 @@ def build_workflow_request(db_session, assistant_id, target_language, prompt, mo
)


def validate_models(models: List[str]) -> None:
    """Validate that every requested model is available on the router.

    Args:
        models: Model identifiers requested by the caller.

    Raises:
        HTTPException: 400 error listing every invalid model together with
            the set of valid choices, so the caller can fix all of them in
            one round trip.
    """
    model_router = get_model_router()
    # Collect ALL unavailable models first so the error reports them at once
    # instead of failing on the first bad entry.
    invalid_models = [m for m in models if not model_router.validate_model_availability(m)]
    if invalid_models:
        available_models = model_router.available_models()
        raise HTTPException(
            status_code=400,
            detail=f"Models not available: {invalid_models}. Select from: {list(available_models.keys())}"
        )


async def _run_single_workflow(assistant_id, target_language, prompt, model: str) -> StreamResponse:
    """Run the workflow once for a single model and package the response.

    Opens its own DB session so concurrent per-model runs never share one.
    Assumes the model identifier was already validated by the caller
    (``validate_models``).

    Args:
        assistant_id: Assistant whose context drives the workflow.
        target_language: Optional target language for the output.
        prompt: Prompts to process.
        model: Identifier of the model to run.

    Returns:
        StreamResponse with per-prompt results, timing metadata, and any
        errors reported by the workflow.
    """
    with SessionLocal() as db_session:
        workflow_request = build_workflow_request(
            db_session, assistant_id, target_language, prompt, model
        )

        workflow_response = await run_workflow(workflow_request)

        results = [
            WorkflowResult(output_text=result.output_text)
            for result in workflow_response.get("final_results", [])
        ]

        # Metadata keys are read defensively; missing keys become None.
        workflow_metadata = workflow_response.get("metadata", {})
        response_metadata = ResponseMetadata(
            initialized_at=workflow_metadata.get("initialized_at"),
            total_batches=workflow_metadata.get("total_batches"),
            completed_at=workflow_metadata.get("completed_at"),
            total_processing_time=workflow_metadata.get("total_processing_time")
        )

        return StreamResponse(
            model=model,
            results=results,
            metadata=response_metadata,
            errors=workflow_response.get("errors", [])
        )


async def run_workflow_service(
    assistant_id, target_language, prompt, models: List[str]
) -> MultiModelResponse:
    """Fan the workflow out over every requested model concurrently.

    Validates all model identifiers up front, then runs one workflow per
    model in parallel and aggregates the results.

    Args:
        assistant_id: Assistant whose context drives the workflow.
        target_language: Optional target language for the output.
        prompt: Prompts to process.
        models: Model identifiers to run the workflow against.

    Returns:
        MultiModelResponse holding one StreamResponse per model, in the
        same order as ``models``.

    Raises:
        HTTPException: if any requested model is unavailable.
    """
    validate_models(models)

    # One workflow coroutine per model, executed concurrently; gather
    # preserves input order in its results.
    per_model_runs = (
        _run_single_workflow(assistant_id, target_language, prompt, requested_model)
        for requested_model in models
    )
    gathered = await asyncio.gather(*per_model_runs)

    return MultiModelResponse(responses=list(gathered))


async def stream_workflow_service(
assistant_id,
target_language,
prompt,
model
model: str
) -> AsyncGenerator[str, None]:
validate_model(model)
validate_models([model])

with SessionLocal() as db_session:
workflow_request = build_workflow_request(
Expand All @@ -99,6 +114,7 @@ async def stream_workflow_service(
):
yield event


def get_available_models_service() -> AvailableModelsResponse:
model_router = get_model_router()
available_models_dict = model_router.available_models()
Expand Down
6 changes: 3 additions & 3 deletions api/ai/ai_view.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from api.ai.ai_response_model import StreamRequest, AvailableModelsResponse
from api.ai.ai_response_model import StreamRequest, AvailableModelsResponse, MultiModelResponse
from api.ai.ai_service import run_workflow_service, stream_workflow_service, get_available_models_service
from fastapi import APIRouter, HTTPException, Depends
from fastapi.responses import StreamingResponse
Expand Down Expand Up @@ -26,7 +26,7 @@
async def get_available_models():
return get_available_models_service()

@ai_router.post("", status_code=status.HTTP_200_OK)
@ai_router.post("", status_code=status.HTTP_200_OK, response_model=MultiModelResponse)
async def run_workflow(
payload: StreamRequest,
authentication_credential: Annotated[
Expand All @@ -48,7 +48,7 @@ async def run_workflow(
assistant_id=payload.assistant_id,
target_language=payload.target_language,
prompt=payload.prompt,
model=payload.model,
models=payload.model,
)


Expand Down
Loading