From 4d737df07d10e563b853cee351af812e950b998b Mon Sep 17 00:00:00 2001
From: TenzDelek
Date: Mon, 16 Mar 2026 11:36:11 +0530
Subject: [PATCH] changes

---
 api/ai/ai_response_model.py | 10 +++++--
 api/ai/ai_service.py        | 52 ++++++++++++++++++++++-------------
 api/ai/ai_view.py           |  6 ++---
 3 files changed, 45 insertions(+), 23 deletions(-)

diff --git a/api/ai/ai_response_model.py b/api/ai/ai_response_model.py
index 72d303d..86cd13a 100644
--- a/api/ai/ai_response_model.py
+++ b/api/ai/ai_response_model.py
@@ -30,7 +30,8 @@ class StreamRequest(BaseModel):
     assistant_id: UUID
     target_language: Optional[str] = None
     prompt: List[str] = Field(min_length=1)
-    model: str = Field(min_length=1)
+    model: List[str] = Field(min_length=1)
+
 
 class WorkflowResult(BaseModel):
     output_text: str
@@ -44,6 +45,11 @@ class ResponseMetadata(BaseModel):
 
 
 class StreamResponse(BaseModel):
+    model: str
     results: List[WorkflowResult]
     metadata: ResponseMetadata
-    errors: List[Any]
\ No newline at end of file
+    errors: List[Any]
+
+
+class MultiModelResponse(BaseModel):
+    responses: List[StreamResponse]
\ No newline at end of file
diff --git a/api/ai/ai_service.py b/api/ai/ai_service.py
index fbccd53..abec304 100644
--- a/api/ai/ai_service.py
+++ b/api/ai/ai_service.py
@@ -1,4 +1,5 @@
-from typing import AsyncGenerator
+import asyncio
+from typing import AsyncGenerator, List
 
 from api.Assistant.assistant_repository import get_assistant_by_id_repository
 from api.Assistant.assistant_response_model import ContextRequest
@@ -10,7 +11,8 @@
     WorkflowResult,
     ResponseMetadata,
     AvailableModelsResponse,
-    ModelInfo
+    ModelInfo,
+    MultiModelResponse,
 )
 from api.db.pg_database import SessionLocal
 from api.llm.router import get_model_router
@@ -37,31 +39,30 @@ def build_workflow_request(db_session, assistant_id, target_language, prompt, mo
     )
 
 
-def validate_model(model: str) -> None:
+def validate_models(models: List[str]) -> None:
     model_router = get_model_router()
-    if not model_router.validate_model_availability(model):
+    invalid_models = [m for m in models if not model_router.validate_model_availability(m)]
+    if invalid_models:
         available_models = model_router.available_models()
         raise HTTPException(
-            status_code=400,
-            detail=f"Model not available: Select from this list: {list(available_models.keys())}"
+            status_code=400,
+            detail=f"Models not available: {invalid_models}. Select from: {list(available_models.keys())}"
         )
 
 
-async def run_workflow_service(assistant_id, target_language, prompt, model):
-    validate_model(model)
-
+async def _run_single_workflow(assistant_id, target_language, prompt, model: str) -> StreamResponse:
     with SessionLocal() as db_session:
         workflow_request = build_workflow_request(
            db_session, assistant_id, target_language, prompt, model
         )
-
+
     workflow_response = await run_workflow(workflow_request)
-
+
     results = [
         WorkflowResult(output_text=result.output_text)
         for result in workflow_response.get("final_results", [])
     ]
-
+
     workflow_metadata = workflow_response.get("metadata", {})
     response_metadata = ResponseMetadata(
         initialized_at=workflow_metadata.get("initialized_at"),
@@ -69,23 +70,37 @@ async def run_workflow_service(assistant_id, target_language, prompt, model):
         completed_at=workflow_metadata.get("completed_at"),
         total_processing_time=workflow_metadata.get("total_processing_time")
     )
-
-    response = StreamResponse(
+
+    return StreamResponse(
+        model=model,
         results=results,
         metadata=response_metadata,
         errors=workflow_response.get("errors", [])
     )
-
-    return response
+
+
+async def run_workflow_service(
+    assistant_id, target_language, prompt, models: List[str]
+) -> MultiModelResponse:
+    validate_models(models)
+
+    responses = await asyncio.gather(
+        *[
+            _run_single_workflow(assistant_id, target_language, prompt, model)
+            for model in models
+        ]
+    )
+
+    return MultiModelResponse(responses=list(responses))
 
 
 async def stream_workflow_service(
     assistant_id,
     target_language,
     prompt,
-    model
+    model: str
 ) -> AsyncGenerator[str, None]:
-    validate_model(model)
+    validate_models([model])
 
     with SessionLocal() as db_session:
         workflow_request = build_workflow_request(
@@ -99,6 +114,7 @@
     ):
         yield event
 
+
 def get_available_models_service() -> AvailableModelsResponse:
     model_router = get_model_router()
     available_models_dict = model_router.available_models()
diff --git a/api/ai/ai_view.py b/api/ai/ai_view.py
index 085ca64..6f07b71 100644
--- a/api/ai/ai_view.py
+++ b/api/ai/ai_view.py
@@ -1,4 +1,4 @@
-from api.ai.ai_response_model import StreamRequest, AvailableModelsResponse
+from api.ai.ai_response_model import StreamRequest, AvailableModelsResponse, MultiModelResponse
 from api.ai.ai_service import run_workflow_service, stream_workflow_service, get_available_models_service
 from fastapi import APIRouter, HTTPException, Depends
 from fastapi.responses import StreamingResponse
@@ -26,7 +26,7 @@ async def get_available_models():
     return get_available_models_service()
 
 
-@ai_router.post("", status_code=status.HTTP_200_OK)
+@ai_router.post("", status_code=status.HTTP_200_OK, response_model=MultiModelResponse)
 async def run_workflow(
     payload: StreamRequest,
     authentication_credential: Annotated[
@@ -48,7 +48,7 @@
         assistant_id=payload.assistant_id,
         target_language=payload.target_language,
         prompt=payload.prompt,
-        model=payload.model,
+        models=payload.model,
     )
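
Note (not part of the patch): a minimal sketch of how the reworked endpoint could be exercised after this change. The field names ("model" as a list on StreamRequest, and "responses", "model", "results", "output_text" on MultiModelResponse) come from the models in the diff above; the base URL, the "/ai" route prefix, the bearer-token header, and the example model names are assumptions about how ai_router is mounted, not something this diff shows.

# Illustrative sketch only; see assumptions above.
import asyncio
import httpx

async def demo() -> None:
    payload = {
        "assistant_id": "00000000-0000-0000-0000-000000000000",
        "target_language": "en",
        "prompt": ["Summarize the release notes."],
        # StreamRequest.model is now a list; one workflow run is fanned out per entry.
        "model": ["gpt-4o-mini", "claude-3-haiku"],
    }
    async with httpx.AsyncClient(base_url="http://localhost:8000") as client:
        resp = await client.post(
            "/ai",
            json=payload,
            headers={"Authorization": "Bearer <token>"},
        )
        resp.raise_for_status()
        # MultiModelResponse carries one StreamResponse per requested model, in request
        # order, since asyncio.gather preserves the order of its awaitables.
        for item in resp.json()["responses"]:
            text = item["results"][0]["output_text"] if item["results"] else ""
            print(item["model"], text)

asyncio.run(demo())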