Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion api/Assistant/assistant_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ class Context(Base):

id = Column(UUID(as_uuid=True), primary_key=True, default=uuid4)
content=Column(Text,nullable=True)
file_url=Column(Text,nullable=True)
pecha_title = Column(String(255), nullable=True)
pecha_text_id = Column(String(255), nullable=True)
assistant_id = Column(UUID(as_uuid=True),ForeignKey("assistant.id", ondelete="CASCADE"),nullable=False)
Expand Down
2 changes: 0 additions & 2 deletions api/Assistant/assistant_response_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,12 @@

class ContextRequest(BaseModel):
content: Optional[str] = None
file_url: Optional[str] = None
pecha_title: Optional[str] = None
pecha_text_id: Optional[str] = None

class ContextResponse(BaseModel):
id: UUID
content: Optional[str] = None
file_url: Optional[str] = None
pecha_title: Optional[str] = None
pecha_text_id: Optional[str] = None

Expand Down
63 changes: 31 additions & 32 deletions api/Assistant/assistant_service.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import logging
from api.Users.user_service import validate_and_extract_user_email
from api.db.pg_database import SessionLocal
from api.Assistant.assistant_repository import get_all_assistants, get_assistant_by_id_repository, delete_assistant_repository, update_assistant_repository
Expand All @@ -8,29 +7,22 @@
from api.Assistant.assistant_model import Assistant, Context
from uuid import UUID
from datetime import datetime, timezone
from fastapi import HTTPException, status
from fastapi import HTTPException, status, UploadFile
from api.error_constant import ErrorConstants
from api.upload.S3_utils import generate_presigned_access_url, delete_file
from api.config import get
from api.cache.cache_enums import CacheType
from api.Assistant.assistant_cache_service import (
get_assistant_detail_cache,
set_assistant_detail_cache,
delete_assistant_detail_cache,
)
from api.utils import Utils


def _build_context_responses(contexts) -> List[ContextResponse]:
return [
ContextResponse(
id=context.id,
content=context.content,
file_url=(
generate_presigned_access_url(
bucket_name=get("AWS_BUCKET_NAME"),
s3_key=context.file_url
) if context.file_url else None
),
pecha_title=context.pecha_title,
pecha_text_id=context.pecha_text_id
) for context in contexts
Expand Down Expand Up @@ -79,21 +71,35 @@ def get_assistants(skip: 0, limit: 20) -> AssistantResponse:
return assistant_response


def create_assistant_service(token: str, assistant_request: AssistantRequest):
current_user_email=validate_and_extract_user_email(token=token)
async def create_assistant_service(token: str, assistant_request: AssistantRequest, files: List[UploadFile] = None):
current_user_email = validate_and_extract_user_email(token=token)
contexts_list = []
for ctx in assistant_request.contexts:
contexts_list.append(
Context(content=ctx.content, pecha_title=ctx.pecha_title, pecha_text_id=ctx.pecha_text_id)
)

if files:
for file in files:
if file.filename:
file_bytes = await file.read()
try:
Utils.validate_file(file.filename, len(file_bytes))
extracted_content = Utils.extract_content_from_file(file_bytes, file.filename)
contexts_list.append(Context(content=extracted_content))
except ValueError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))

with SessionLocal() as db_session:
assistant = Assistant(
name=assistant_request.name,
source_type=assistant_request.source_type,
description=assistant_request.description,
system_prompt=assistant_request.system_prompt,
system_assistance=assistant_request.system_assistance,
created_by=current_user_email,
contexts=[
Context(content=ctx.content, file_url=ctx.file_url, pecha_title=ctx.pecha_title, pecha_text_id=ctx.pecha_text_id)
for ctx in assistant_request.contexts
]
)
name=assistant_request.name,
source_type=assistant_request.source_type,
description=assistant_request.description,
system_prompt=assistant_request.system_prompt,
system_assistance=assistant_request.system_assistance,
created_by=current_user_email,
contexts=contexts_list
)
create_assistant_repository(db=db_session, assistant=assistant)

async def get_assistant_by_id_service(assistant_id: UUID) -> AssistantInfoResponse:
Expand All @@ -118,21 +124,14 @@ async def get_assistant_by_id_service(assistant_id: UUID) -> AssistantInfoRespon
return assistant_info

async def delete_assistant_service(assistant_id: UUID, token: str):
current_user_email=validate_and_extract_user_email(token=token)
current_user_email = validate_and_extract_user_email(token=token)
with SessionLocal() as db_session:
assistant = get_assistant_by_id_repository(db=db_session, assistant_id=assistant_id)
if current_user_email != assistant.created_by:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=ErrorConstants.UNAUTHORIZED_ERROR_MESSAGE)
if assistant.system_assistance:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=ErrorConstants.FORBIDDEN_ERROR_MESSAGE)

for context in assistant.contexts:
if context.file_url:
try:
delete_file(context.file_url)
except Exception as e:
logging.error(f"Failed to delete S3 file {context.file_url}: {str(e)}")

delete_assistant_repository(db=db_session, assistant_id=assistant_id)

await delete_assistant_detail_cache(
Expand Down Expand Up @@ -161,7 +160,7 @@ async def update_assistant_service(assistant_id: UUID, update_request: UpdateAss
for context in assistant.contexts:
db_session.delete(context)
assistant.contexts = [
Context(content=ctx.content, file_url=ctx.file_url, pecha_title=ctx.pecha_title, pecha_text_id=ctx.pecha_text_id)
Context(content=ctx.content, pecha_title=ctx.pecha_title, pecha_text_id=ctx.pecha_text_id)
for ctx in update_request.contexts
]

Expand Down
31 changes: 27 additions & 4 deletions api/Assistant/assistant_view.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from fastapi import APIRouter
from fastapi import APIRouter, UploadFile, File, Form
from starlette import status
from api.Assistant.assistant_response_model import AssistantResponse, AssistantRequest, AssistantInfoResponse, UpdateAssistantRequest
from fastapi import Query, Depends
from api.Assistant.assistant_service import create_assistant_service, get_assistant_by_id_service, get_assistants, delete_assistant_service, update_assistant_service
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from typing import Annotated
from typing import Annotated, Optional, List
from uuid import UUID
from api.constant import Constant
import json
oauth2_scheme = HTTPBearer()

assistant_router=APIRouter(
Expand All @@ -21,8 +22,30 @@ async def get_all_assistants(
return get_assistants(skip=skip, limit=limit)

@assistant_router.post("", status_code=status.HTTP_201_CREATED)
async def create_assistant(assistant_request: AssistantRequest, authentication_credential: Annotated[HTTPAuthorizationCredentials, Depends(oauth2_scheme)]):
create_assistant_service(token=authentication_credential.credentials, assistant_request=assistant_request)
async def create_assistant(
authentication_credential: Annotated[HTTPAuthorizationCredentials, Depends(oauth2_scheme)],
name: str = Form(...),
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as we are sending file to backend. pydantic dont have this validation

system_prompt: str = Form(...),
source_type: Optional[str] = Form(None),
description: Optional[str] = Form(None),
system_assistance: bool = Form(False),
contexts: Optional[str] = Form(None),
files: List[UploadFile] = File(default=[])
):
contexts_data = json.loads(contexts) if contexts else []
assistant_request = AssistantRequest(
name=name,
source_type=source_type,
description=description,
system_prompt=system_prompt,
contexts=contexts_data,
system_assistance=system_assistance
)
await create_assistant_service(
token=authentication_credential.credentials,
assistant_request=assistant_request,
files=files
)
return {"message": Constant.CREATED_ASSISTANT_MESSAGE}

@assistant_router.get("/{assistant_id}", status_code=status.HTTP_200_OK)
Expand Down
95 changes: 0 additions & 95 deletions api/langgraph/context_processor.py

This file was deleted.

4 changes: 2 additions & 2 deletions api/langgraph/nodes/node_initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from datetime import datetime

from api.langgraph.workflow_type import WorkflowState,Batch
from api.langgraph.context_processor import process_contexts
from api.utils import Utils
from api import config

DEFAULT_MAX_BATCH_SIZE = 2
Expand All @@ -17,7 +17,7 @@ def initialize_workflow(state: WorkflowState) -> WorkflowState:
batch_size = init_size


processed_contexts = process_contexts(request.contexts) if request.contexts else None
processed_contexts = Utils.process_contexts(request.contexts) if request.contexts else None

for i in range(0, len(texts), batch_size):
batch_texts = texts[i : i + batch_size]
Expand Down
Loading