From 50d4467b8df8c68033112d485c5754142d9cd82d Mon Sep 17 00:00:00 2001 From: Declan Brady Date: Tue, 24 Mar 2026 17:28:46 -0400 Subject: [PATCH 1/6] Update tasks PUT endpoint to accept status changes --- agentex/src/api/routes/tasks.py | 13 +- agentex/src/api/schemas/tasks.py | 16 +- .../src/domain/use_cases/tasks_use_case.py | 30 +- .../integration/api/tasks/test_tasks_api.py | 108 +++++++ .../unit/use_cases/test_tasks_use_case.py | 269 ++++++++++++++++++ 5 files changed, 422 insertions(+), 14 deletions(-) create mode 100644 agentex/tests/unit/use_cases/test_tasks_use_case.py diff --git a/agentex/src/api/routes/tasks.py b/agentex/src/api/routes/tasks.py index 0cccdc5e..eb69dc95 100644 --- a/agentex/src/api/routes/tasks.py +++ b/agentex/src/api/routes/tasks.py @@ -15,6 +15,7 @@ TaskResponse, UpdateTaskRequest, ) +from src.domain.entities.tasks import TaskStatus as AgentexTaskStatus from src.domain.services.authorization_service import DAuthorizationService from src.domain.use_cases.streams_use_case import DStreamsUseCase from src.domain.use_cases.tasks_use_case import DTaskUseCase @@ -144,8 +145,12 @@ async def update_task( task_id: DAuthorizedId(AgentexResourceType.task, AuthorizedOperationType.update), task_use_case: DTaskUseCase, ) -> Task: + domain_status = AgentexTaskStatus(request.status) if request.status else None updated_task_entity = await task_use_case.update_mutable_fields_on_task( - id=task_id, task_metadata=request.task_metadata + id=task_id, + task_metadata=request.task_metadata, + status=domain_status, + status_reason=request.status_reason, ) return Task.model_validate(updated_task_entity) @@ -163,8 +168,12 @@ async def update_task_by_name( ), task_use_case: DTaskUseCase, ) -> Task: + domain_status = AgentexTaskStatus(request.status) if request.status else None updated_task_entity = await task_use_case.update_mutable_fields_on_task( - name=task_name, task_metadata=request.task_metadata + name=task_name, + task_metadata=request.task_metadata, + status=domain_status, + status_reason=request.status_reason, ) return Task.model_validate(updated_task_entity) diff --git a/agentex/src/api/schemas/tasks.py b/agentex/src/api/schemas/tasks.py index df5ad42a..9893c4d7 100644 --- a/agentex/src/api/schemas/tasks.py +++ b/agentex/src/api/schemas/tasks.py @@ -1,6 +1,6 @@ from datetime import datetime from enum import Enum -from typing import Any +from typing import Any, Literal from pydantic import Field @@ -24,6 +24,12 @@ class TaskStatus(str, Enum): DELETED = "DELETED" +# Statuses that agents can transition a running task to via the update endpoint +TerminalTaskStatus = Literal[ + "COMPLETED", "FAILED", "CANCELED", "TERMINATED", "TIMED_OUT" +] + + class Task(BaseModel): id: str = Field( ..., @@ -73,3 +79,11 @@ class UpdateTaskRequest(BaseModel): None, title="If provided, replaces task_metadata with this value", ) + status: TerminalTaskStatus | None = Field( + None, + title="If provided, transitions the task to this status. Only RUNNING tasks can be transitioned.", + ) + status_reason: str | None = Field( + None, + title="Optional reason for the status change", + ) diff --git a/agentex/src/domain/use_cases/tasks_use_case.py b/agentex/src/domain/use_cases/tasks_use_case.py index 954f7cc3..5f39115d 100644 --- a/agentex/src/domain/use_cases/tasks_use_case.py +++ b/agentex/src/domain/use_cases/tasks_use_case.py @@ -90,29 +90,37 @@ async def update_mutable_fields_on_task( id: str | None = None, name: str | None = None, task_metadata: dict[str, Any] | None = None, + status: TaskStatus | None = None, + status_reason: str | None = None, ) -> TaskEntity: """Update mutable fields on a task entity. This is used by our API since not all fields should be mutable.""" if not id and not name: raise ClientError("Either id or name must be provided") - # todo: make this a transaction? task_entity = await self.task_service.get_task(id=id, name=name) if task_entity.status == TaskStatus.DELETED: - if id: - raise ItemDoesNotExist(f"Task {id} not found") - else: - raise ItemDoesNotExist(f"Task {name} not found") - - # if no mutations are provided, don't do anything - if task_metadata is None: - return task_entity + identifier = id or name + raise ItemDoesNotExist(f"Task {identifier} not found") + + # Handle status transition (valid target statuses are enforced by the API schema) + if status is not None: + if task_entity.status != TaskStatus.RUNNING: + raise ClientError( + f"Task {task_entity.id} is not running (current status: {task_entity.status}). " + f"Only running tasks can have their status updated." + ) + task_entity.status = status + task_entity.status_reason = status_reason or f"Task {status.value.lower()}" if task_metadata is not None: task_entity.task_metadata = task_metadata - updated_task_entity = await self.task_service.update_task(task=task_entity) - return updated_task_entity + # If no mutations were provided, don't write + if status is None and task_metadata is None: + return task_entity + + return await self.task_service.update_task(task=task_entity) DTaskUseCase = Annotated[TasksUseCase, Depends(TasksUseCase)] diff --git a/agentex/tests/integration/api/tasks/test_tasks_api.py b/agentex/tests/integration/api/tasks/test_tasks_api.py index 3fd4874c..75b10323 100644 --- a/agentex/tests/integration/api/tasks/test_tasks_api.py +++ b/agentex/tests/integration/api/tasks/test_tasks_api.py @@ -1381,3 +1381,111 @@ async def test_list_tasks_filters_work_with_views( assert "agents" in task_data assert len(task_data["agents"]) == 1 assert task_data["agents"][0]["name"] == "target-filter-agent" + + async def test_update_task_status_to_completed(self, isolated_client, test_task): + """Test transitioning a RUNNING task to COMPLETED via PUT endpoint""" + # When + response = await isolated_client.put( + f"/tasks/{test_task.id}", + json={"status": "COMPLETED", "status_reason": "Agent finished"}, + ) + + # Then + assert response.status_code == 200 + task_data = response.json() + assert task_data["status"] == "COMPLETED" + assert task_data["status_reason"] == "Agent finished" + + async def test_update_task_status_to_terminated(self, isolated_client, test_task): + """Test transitioning a RUNNING task to TERMINATED via PUT endpoint""" + # When + response = await isolated_client.put( + f"/tasks/{test_task.id}", + json={"status": "TERMINATED", "status_reason": "Workflow killed"}, + ) + + # Then + assert response.status_code == 200 + task_data = response.json() + assert task_data["status"] == "TERMINATED" + assert task_data["status_reason"] == "Workflow killed" + + async def test_update_task_status_to_timed_out(self, isolated_client, test_task): + """Test transitioning a RUNNING task to TIMED_OUT via PUT endpoint""" + # When + response = await isolated_client.put( + f"/tasks/{test_task.id}", + json={"status": "TIMED_OUT"}, + ) + + # Then + assert response.status_code == 200 + task_data = response.json() + assert task_data["status"] == "TIMED_OUT" + assert task_data["status_reason"] == "Task timed_out" + + async def test_update_task_status_by_name(self, isolated_client, test_task): + """Test transitioning a task to COMPLETED by name""" + # When + response = await isolated_client.put( + f"/tasks/name/{test_task.name}", + json={"status": "COMPLETED", "status_reason": "Done by name"}, + ) + + # Then + assert response.status_code == 200 + task_data = response.json() + assert task_data["status"] == "COMPLETED" + assert task_data["status_reason"] == "Done by name" + + async def test_cannot_transition_non_running_task(self, isolated_client, test_task): + """Test that a completed task cannot be transitioned again""" + # Given - Complete the task first + response = await isolated_client.put( + f"/tasks/{test_task.id}", + json={"status": "COMPLETED"}, + ) + assert response.status_code == 200 + + # When - Try to transition again + response = await isolated_client.put( + f"/tasks/{test_task.id}", + json={"status": "TERMINATED"}, + ) + + # Then - Should fail + assert response.status_code == 400 + + async def test_update_task_rejects_invalid_status(self, isolated_client, test_task): + """Test that RUNNING and DELETED are rejected as target statuses""" + # When - Try to set status to RUNNING + response = await isolated_client.put( + f"/tasks/{test_task.id}", + json={"status": "RUNNING"}, + ) + + # Then - Should be rejected by schema validation (422) + assert response.status_code == 422 + + # When - Try to set status to DELETED + response = await isolated_client.put( + f"/tasks/{test_task.id}", + json={"status": "DELETED"}, + ) + + # Then - Should be rejected by schema validation (422) + assert response.status_code == 422 + + async def test_update_metadata_still_works(self, isolated_client, test_task): + """Test that updating only metadata without status still works""" + # When + response = await isolated_client.put( + f"/tasks/{test_task.id}", + json={"task_metadata": {"key": "value"}}, + ) + + # Then + assert response.status_code == 200 + task_data = response.json() + assert task_data["status"] == "RUNNING" + assert task_data["task_metadata"] == {"key": "value"} diff --git a/agentex/tests/unit/use_cases/test_tasks_use_case.py b/agentex/tests/unit/use_cases/test_tasks_use_case.py new file mode 100644 index 00000000..4dabd15c --- /dev/null +++ b/agentex/tests/unit/use_cases/test_tasks_use_case.py @@ -0,0 +1,269 @@ +""" +Unit tests for TasksUseCase - specifically the status transition logic +in update_mutable_fields_on_task. +""" + +from uuid import uuid4 + +import pytest +from src.adapters.crud_store.exceptions import DuplicateItemError, ItemDoesNotExist +from src.domain.entities.agents import ACPType, AgentEntity, AgentStatus +from src.domain.entities.tasks import TaskStatus +from src.domain.exceptions import ClientError +from src.domain.repositories.agent_repository import AgentRepository +from src.domain.repositories.task_repository import TaskRepository +from src.domain.use_cases.tasks_use_case import TasksUseCase + + +async def create_or_get_agent(agent_repository, agent): + """Helper to create agent or get existing one if name already exists""" + try: + return await agent_repository.create(agent) + except DuplicateItemError: + existing_agent = await agent_repository.get(name=agent.name) + agent.id = existing_agent.id + return existing_agent + + +@pytest.fixture +def agent_repository(postgres_session_maker): + """Real AgentRepository using test PostgreSQL database""" + return AgentRepository(postgres_session_maker, postgres_session_maker) + + +@pytest.fixture +def task_repository(postgres_session_maker): + """Real TaskRepository using test PostgreSQL database""" + return TaskRepository(postgres_session_maker, postgres_session_maker) + + +@pytest.fixture +def tasks_use_case(task_service): + """TasksUseCase with real task_service""" + return TasksUseCase(task_service=task_service) + + +@pytest.fixture +def sample_agent(): + """Sample agent entity for testing""" + return AgentEntity( + id=str(uuid4()), + name="test-agent-use-case", + description="A test agent for use case testing", + status=AgentStatus.READY, + acp_type=ACPType.ASYNC, + acp_url="http://test-acp.example.com", + ) + + +@pytest.mark.unit +@pytest.mark.asyncio +class TestTasksUseCaseStatusTransitions: + """Test suite for task status transitions via update_mutable_fields_on_task""" + + async def test_complete_running_task( + self, tasks_use_case, task_service, agent_repository, sample_agent + ): + """Test that a RUNNING task can be transitioned to COMPLETED""" + # Given + await create_or_get_agent(agent_repository, sample_agent) + task = await task_service.create_task( + agent=sample_agent, task_name="complete-test" + ) + assert task.status == TaskStatus.RUNNING + + # When + updated = await tasks_use_case.update_mutable_fields_on_task( + id=task.id, status=TaskStatus.COMPLETED, status_reason="Agent finished" + ) + + # Then + assert updated.status == TaskStatus.COMPLETED + assert updated.status_reason == "Agent finished" + + async def test_terminate_running_task( + self, tasks_use_case, task_service, agent_repository, sample_agent + ): + """Test that a RUNNING task can be transitioned to TERMINATED""" + # Given + await create_or_get_agent(agent_repository, sample_agent) + task = await task_service.create_task( + agent=sample_agent, task_name="terminate-test" + ) + + # When + updated = await tasks_use_case.update_mutable_fields_on_task( + id=task.id, status=TaskStatus.TERMINATED, status_reason="Workflow killed" + ) + + # Then + assert updated.status == TaskStatus.TERMINATED + assert updated.status_reason == "Workflow killed" + + async def test_timeout_running_task( + self, tasks_use_case, task_service, agent_repository, sample_agent + ): + """Test that a RUNNING task can be transitioned to TIMED_OUT""" + # Given + await create_or_get_agent(agent_repository, sample_agent) + task = await task_service.create_task( + agent=sample_agent, task_name="timeout-test" + ) + + # When + updated = await tasks_use_case.update_mutable_fields_on_task( + id=task.id, status=TaskStatus.TIMED_OUT + ) + + # Then + assert updated.status == TaskStatus.TIMED_OUT + assert updated.status_reason == "Task timed_out" + + async def test_default_status_reason( + self, tasks_use_case, task_service, agent_repository, sample_agent + ): + """Test that a default status_reason is set when none provided""" + # Given + await create_or_get_agent(agent_repository, sample_agent) + task = await task_service.create_task( + agent=sample_agent, task_name="default-reason-test" + ) + + # When + updated = await tasks_use_case.update_mutable_fields_on_task( + id=task.id, status=TaskStatus.COMPLETED + ) + + # Then + assert updated.status_reason == "Task completed" + + async def test_cannot_transition_completed_task( + self, tasks_use_case, task_service, agent_repository, sample_agent + ): + """Test that a COMPLETED task cannot be transitioned again""" + # Given + await create_or_get_agent(agent_repository, sample_agent) + task = await task_service.create_task( + agent=sample_agent, task_name="double-complete-test" + ) + await tasks_use_case.update_mutable_fields_on_task( + id=task.id, status=TaskStatus.COMPLETED + ) + + # When / Then + with pytest.raises(ClientError, match="not running"): + await tasks_use_case.update_mutable_fields_on_task( + id=task.id, status=TaskStatus.TERMINATED + ) + + async def test_cannot_transition_canceled_task( + self, tasks_use_case, task_service, agent_repository, sample_agent + ): + """Test that a CANCELED task cannot be transitioned""" + # Given + await create_or_get_agent(agent_repository, sample_agent) + task = await task_service.create_task( + agent=sample_agent, task_name="cancel-block-test" + ) + await tasks_use_case.update_mutable_fields_on_task( + id=task.id, status=TaskStatus.CANCELED + ) + + # When / Then + with pytest.raises(ClientError, match="not running"): + await tasks_use_case.update_mutable_fields_on_task( + id=task.id, status=TaskStatus.COMPLETED + ) + + async def test_cannot_transition_failed_task( + self, tasks_use_case, task_service, agent_repository, sample_agent + ): + """Test that a FAILED task cannot be transitioned""" + # Given + await create_or_get_agent(agent_repository, sample_agent) + task = await task_service.create_task( + agent=sample_agent, task_name="fail-block-test" + ) + await tasks_use_case.update_mutable_fields_on_task( + id=task.id, status=TaskStatus.FAILED + ) + + # When / Then + with pytest.raises(ClientError, match="not running"): + await tasks_use_case.update_mutable_fields_on_task( + id=task.id, status=TaskStatus.COMPLETED + ) + + async def test_cannot_transition_deleted_task( + self, tasks_use_case, task_service, agent_repository, sample_agent + ): + """Test that a DELETED task raises not found""" + # Given + await create_or_get_agent(agent_repository, sample_agent) + task = await task_service.create_task( + agent=sample_agent, task_name="delete-block-test" + ) + await tasks_use_case.delete_task(id=task.id) + + # When / Then + with pytest.raises(ItemDoesNotExist): + await tasks_use_case.update_mutable_fields_on_task( + id=task.id, status=TaskStatus.COMPLETED + ) + + async def test_update_metadata_without_status( + self, tasks_use_case, task_service, agent_repository, sample_agent + ): + """Test that updating only task_metadata does not change status""" + # Given + await create_or_get_agent(agent_repository, sample_agent) + task = await task_service.create_task( + agent=sample_agent, task_name="metadata-only-test" + ) + + # When + updated = await tasks_use_case.update_mutable_fields_on_task( + id=task.id, task_metadata={"key": "value"} + ) + + # Then + assert updated.status == TaskStatus.RUNNING + assert updated.task_metadata == {"key": "value"} + + async def test_update_status_and_metadata_together( + self, tasks_use_case, task_service, agent_repository, sample_agent + ): + """Test that status and metadata can be updated in a single call""" + # Given + await create_or_get_agent(agent_repository, sample_agent) + task = await task_service.create_task( + agent=sample_agent, task_name="both-update-test" + ) + + # When + updated = await tasks_use_case.update_mutable_fields_on_task( + id=task.id, + status=TaskStatus.COMPLETED, + status_reason="Done", + task_metadata={"result": "success"}, + ) + + # Then + assert updated.status == TaskStatus.COMPLETED + assert updated.status_reason == "Done" + assert updated.task_metadata == {"result": "success"} + + async def test_no_op_when_nothing_provided( + self, tasks_use_case, task_service, agent_repository, sample_agent + ): + """Test that providing neither status nor metadata is a no-op""" + # Given + await create_or_get_agent(agent_repository, sample_agent) + task = await task_service.create_task(agent=sample_agent, task_name="noop-test") + + # When + result = await tasks_use_case.update_mutable_fields_on_task(id=task.id) + + # Then + assert result.status == TaskStatus.RUNNING + assert result.id == task.id From 9d4ec296daad0a6e6b7c2a38f16b4f416c92f5a1 Mon Sep 17 00:00:00 2001 From: Declan Brady Date: Tue, 24 Mar 2026 17:31:58 -0400 Subject: [PATCH 2/6] Only allow chats with running tasks --- .../primary-content/prompt-input.tsx | 23 ++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/agentex-ui/components/primary-content/prompt-input.tsx b/agentex-ui/components/primary-content/prompt-input.tsx index 1ae7475f..b85deb90 100644 --- a/agentex-ui/components/primary-content/prompt-input.tsx +++ b/agentex-ui/components/primary-content/prompt-input.tsx @@ -20,6 +20,7 @@ import { useSafeSearchParams, } from '@/hooks/use-safe-search-params'; import { useSendMessage } from '@/hooks/use-task-messages'; +import { useTask } from '@/hooks/use-tasks'; type PromptInputProps = { prompt: string; @@ -52,10 +53,16 @@ export function PromptInput({ prompt, setPrompt }: PromptInputProps) { const createTaskMutation = useCreateTask({ agentexClient }); const sendMessageMutation = useSendMessage({ agentexClient }); + const { data: task } = useTask({ agentexClient, taskId: taskID ?? '' }); const textInputRef = useRef(null); const codeMirrorViewRef = useRef(null); + const isTaskTerminal = useMemo(() => { + if (!taskID || !task) return false; + return task.status != null && task.status !== 'RUNNING'; + }, [taskID, task]); + const handleSetJson = useCallback( (value: boolean) => { if (value && !prompt.trim()) { @@ -86,8 +93,8 @@ export function PromptInput({ prompt, setPrompt }: PromptInputProps) { }, [taskID, isClient, isSendingJSON]); const isDisabled = useMemo( - () => !agentName || !isClient, - [agentName, isClient] + () => !agentName || !isClient || isTaskTerminal, + [agentName, isClient, isTaskTerminal] ); const handleSendPrompt = useCallback(async () => { @@ -171,6 +178,8 @@ export function PromptInput({ prompt, setPrompt }: PromptInputProps) { prompt={prompt} setPrompt={setPrompt} isDisabled={isDisabled} + isTaskTerminal={isTaskTerminal} + taskStatus={task?.status} handleSendPrompt={handleSendPrompt} inputRef={textInputRef} /> @@ -205,12 +214,16 @@ const TextInput = ({ prompt, setPrompt, isDisabled, + isTaskTerminal, + taskStatus, handleSendPrompt, inputRef, }: { prompt: string; setPrompt: (prompt: string) => void; isDisabled: boolean; + isTaskTerminal: boolean; + taskStatus: string | null | undefined; handleSendPrompt: () => void; inputRef: React.RefObject; }) => { @@ -230,7 +243,11 @@ const TextInput = ({ }} disabled={isDisabled} placeholder={ - isDisabled ? 'Select an agent to start' : 'Enter your prompt' + isTaskTerminal + ? `Task ${taskStatus?.toLowerCase() ?? 'ended'}` + : isDisabled + ? 'Select an agent to start' + : 'Enter your prompt' } className="mr-2 flex-1 outline-none focus:ring-0 focus:outline-none" style={{ From 5808b577be07ec955fd0217906c75ee3bace5015 Mon Sep 17 00:00:00 2001 From: Declan Brady Date: Tue, 24 Mar 2026 17:32:26 -0400 Subject: [PATCH 3/6] Update UI to not require user message --- .../task-messages/task-messages.tsx | 38 +++++++++++-------- 1 file changed, 23 insertions(+), 15 deletions(-) diff --git a/agentex-ui/components/task-messages/task-messages.tsx b/agentex-ui/components/task-messages/task-messages.tsx index 67330839..5e658a7c 100644 --- a/agentex-ui/components/task-messages/task-messages.tsx +++ b/agentex-ui/components/task-messages/task-messages.tsx @@ -23,7 +23,7 @@ type TaskMessagesProps = { }; type MessagePair = { id: string; - userMessage: TaskMessage; + userMessage: TaskMessage | null; agentMessages: TaskMessage[]; }; @@ -58,36 +58,41 @@ function TaskMessagesImpl({ taskId, headerRef }: TaskMessagesProps) { const pairs: MessagePair[] = []; let currentUserMessage: TaskMessage | null = null; let currentAgentMessages: TaskMessage[] = []; + let pairStarted = false; for (const message of messages) { const isUserMessage = message.content.author === 'user'; if (isUserMessage) { - if (currentUserMessage) { + if (pairStarted) { pairs.push({ - id: currentUserMessage.id || `pair-${pairs.length}`, + id: + currentUserMessage?.id || + currentAgentMessages[0]?.id || + `pair-${pairs.length}`, userMessage: currentUserMessage, agentMessages: currentAgentMessages, }); } currentUserMessage = message; currentAgentMessages = []; + pairStarted = true; } else { - if (currentUserMessage) { - currentAgentMessages.push(message); - } else { - pairs.push({ - id: message.id || `pair-${pairs.length}`, - userMessage: message, - agentMessages: [], - }); + if (!pairStarted) { + currentUserMessage = null; + currentAgentMessages = []; + pairStarted = true; } + currentAgentMessages.push(message); } } - if (currentUserMessage) { + if (pairStarted) { pairs.push({ - id: currentUserMessage.id || `pair-${pairs.length}`, + id: + currentUserMessage?.id || + currentAgentMessages[0]?.id || + `pair-${pairs.length}`, userMessage: currentUserMessage, agentMessages: currentAgentMessages, }); @@ -101,10 +106,13 @@ function TaskMessagesImpl({ taskId, headerRef }: TaskMessagesProps) { const lastPair = messagePairs[messagePairs.length - 1]!; const hasNoAgentMessages = lastPair.agentMessages.length === 0; + const hasUserMessage = lastPair.userMessage !== null; const rpcStatus = queryData?.rpcStatus; return ( - hasNoAgentMessages && (rpcStatus === 'pending' || rpcStatus === 'success') + hasUserMessage && + hasNoAgentMessages && + (rpcStatus === 'pending' || rpcStatus === 'success') ); }, [messagePairs, queryData?.rpcStatus]); @@ -191,7 +199,7 @@ function TaskMessagesImpl({ taskId, headerRef }: TaskMessagesProps) { containerHeight={containerHeight} > - {renderMessage(pair.userMessage)} + {pair.userMessage && renderMessage(pair.userMessage)} {pair.agentMessages.map(agentMessage => ( {renderMessage(agentMessage)} From 4785eb24879bfcb2db1adbfd4d3ece99d138ee4e Mon Sep 17 00:00:00 2001 From: Declan Brady Date: Tue, 24 Mar 2026 17:48:22 -0400 Subject: [PATCH 4/6] Make task status updates atomic --- .../domain/repositories/task_repository.py | 34 ++++++++++++++++- agentex/src/domain/services/task_service.py | 38 +++++++++++++++++++ .../src/domain/use_cases/tasks_use_case.py | 31 ++++++++++----- 3 files changed, 92 insertions(+), 11 deletions(-) diff --git a/agentex/src/domain/repositories/task_repository.py b/agentex/src/domain/repositories/task_repository.py index 53898eb6..fe3e2eae 100644 --- a/agentex/src/domain/repositories/task_repository.py +++ b/agentex/src/domain/repositories/task_repository.py @@ -2,7 +2,7 @@ from typing import Annotated, Literal from fastapi import Depends -from sqlalchemy import select +from sqlalchemy import select, update from sqlalchemy.orm import selectinload from src.adapters.crud_store.adapter_postgres import ( ColumnPrimitiveValue, @@ -139,5 +139,37 @@ async def update(self, task: TaskEntity) -> TaskEntity: # Return with agents populated return TaskEntity.model_validate(modified_orm) + async def transition_status( + self, + task_id: str, + expected_status: TaskStatus, + new_status: TaskStatus, + status_reason: str, + task_metadata: dict | None = None, + ) -> TaskEntity | None: + """Atomically transition task status. Returns None if the expected status didn't match (i.e. lost the race).""" + + async with ( + self.start_async_db_session(True) as session, + async_sql_exception_handler(), + ): + values: dict = {"status": new_status, "status_reason": status_reason} + if task_metadata is not None: + values["task_metadata"] = task_metadata + + stmt = ( + update(TaskORM) + .where(TaskORM.id == task_id, TaskORM.status == expected_status) + .values(**values) + ) + result = await session.execute(stmt) + await session.commit() + + if result.rowcount == 0: + return None + + refreshed = await session.get(TaskORM, task_id) + return TaskEntity.model_validate(refreshed) + DTaskRepository = Annotated[TaskRepository, Depends(TaskRepository)] diff --git a/agentex/src/domain/services/task_service.py b/agentex/src/domain/services/task_service.py index c31b3561..83923ff9 100644 --- a/agentex/src/domain/services/task_service.py +++ b/agentex/src/domain/services/task_service.py @@ -144,6 +144,44 @@ async def get_task( id=id, name=name, relationships=relationships ) + async def transition_task_status( + self, + task_id: str, + expected_status: TaskStatus, + new_status: TaskStatus, + status_reason: str, + task_metadata: dict | None = None, + ) -> TaskEntity | None: + """ + Atomically transition task status. Returns None if the expected status didn't match. + Publishes a task_updated event on success. + """ + updated_task = await self.task_repository.transition_status( + task_id=task_id, + expected_status=expected_status, + new_status=new_status, + status_reason=status_reason, + task_metadata=task_metadata, + ) + if updated_task is None: + return None + + try: + topic = get_task_event_stream_topic(task_id=task_id) + await self.stream_repository.send_data( + topic, + TaskStreamTaskUpdatedEventEntity( + type="task_updated", task=updated_task + ).model_dump(mode="json"), + ) + logger.info(f"task_updated event published to topic: {topic}") + except Exception as e: + logger.error( + f"Error sending task_updated event to stream: {e}", exc_info=True + ) + + return updated_task + async def update_task(self, task: TaskEntity) -> TaskEntity: """ Update a task in the repository. diff --git a/agentex/src/domain/use_cases/tasks_use_case.py b/agentex/src/domain/use_cases/tasks_use_case.py index 5f39115d..417994fa 100644 --- a/agentex/src/domain/use_cases/tasks_use_case.py +++ b/agentex/src/domain/use_cases/tasks_use_case.py @@ -103,23 +103,34 @@ async def update_mutable_fields_on_task( identifier = id or name raise ItemDoesNotExist(f"Task {identifier} not found") - # Handle status transition (valid target statuses are enforced by the API schema) + # If no mutations were provided, don't write + if status is None and task_metadata is None: + return task_entity + + # Status transition uses an atomic conditional update to prevent race conditions if status is not None: if task_entity.status != TaskStatus.RUNNING: raise ClientError( f"Task {task_entity.id} is not running (current status: {task_entity.status}). " f"Only running tasks can have their status updated." ) - task_entity.status = status - task_entity.status_reason = status_reason or f"Task {status.value.lower()}" - - if task_metadata is not None: - task_entity.task_metadata = task_metadata - - # If no mutations were provided, don't write - if status is None and task_metadata is None: - return task_entity + reason = status_reason or f"Task {status.value.lower()}" + updated = await self.task_service.transition_task_status( + task_id=task_entity.id, + expected_status=TaskStatus.RUNNING, + new_status=status, + status_reason=reason, + task_metadata=task_metadata, + ) + if updated is None: + raise ClientError( + f"Task {task_entity.id} status was concurrently modified. " + f"Please retry the request." + ) + return updated + # Metadata-only update (no status change) + task_entity.task_metadata = task_metadata return await self.task_service.update_task(task=task_entity) From 58d2723224c25c2a2fe955571c899216627d6e71 Mon Sep 17 00:00:00 2001 From: Declan Brady Date: Tue, 24 Mar 2026 19:47:58 -0400 Subject: [PATCH 5/6] Move to definied endpoints and functions for easier accessibility --- agentex/src/api/routes/tasks.py | 99 +++++++++++++-- agentex/src/api/schemas/tasks.py | 17 +-- .../src/domain/use_cases/tasks_use_case.py | 115 +++++++++++++----- .../integration/api/tasks/test_tasks_api.py | 114 ++++++++--------- .../unit/use_cases/test_tasks_use_case.py | 88 +++----------- 5 files changed, 246 insertions(+), 187 deletions(-) diff --git a/agentex/src/api/routes/tasks.py b/agentex/src/api/routes/tasks.py index eb69dc95..9eef026b 100644 --- a/agentex/src/api/routes/tasks.py +++ b/agentex/src/api/routes/tasks.py @@ -13,9 +13,9 @@ Task, TaskRelationships, TaskResponse, + TaskStatusReasonRequest, UpdateTaskRequest, ) -from src.domain.entities.tasks import TaskStatus as AgentexTaskStatus from src.domain.services.authorization_service import DAuthorizationService from src.domain.use_cases.streams_use_case import DStreamsUseCase from src.domain.use_cases.tasks_use_case import DTaskUseCase @@ -145,12 +145,8 @@ async def update_task( task_id: DAuthorizedId(AgentexResourceType.task, AuthorizedOperationType.update), task_use_case: DTaskUseCase, ) -> Task: - domain_status = AgentexTaskStatus(request.status) if request.status else None updated_task_entity = await task_use_case.update_mutable_fields_on_task( - id=task_id, - task_metadata=request.task_metadata, - status=domain_status, - status_reason=request.status_reason, + id=task_id, task_metadata=request.task_metadata ) return Task.model_validate(updated_task_entity) @@ -168,16 +164,97 @@ async def update_task_by_name( ), task_use_case: DTaskUseCase, ) -> Task: - domain_status = AgentexTaskStatus(request.status) if request.status else None updated_task_entity = await task_use_case.update_mutable_fields_on_task( - name=task_name, - task_metadata=request.task_metadata, - status=domain_status, - status_reason=request.status_reason, + name=task_name, task_metadata=request.task_metadata ) return Task.model_validate(updated_task_entity) +@router.post( + "/{task_id}/complete", + response_model=Task, + summary="Complete Task", + description="Mark a running task as completed.", +) +async def complete_task( + task_id: DAuthorizedId(AgentexResourceType.task, AuthorizedOperationType.update), + task_use_case: DTaskUseCase, + request: TaskStatusReasonRequest | None = None, +) -> Task: + updated = await task_use_case.complete_task( + id=task_id, reason=request.reason if request else None + ) + return Task.model_validate(updated) + + +@router.post( + "/{task_id}/fail", + response_model=Task, + summary="Fail Task", + description="Mark a running task as failed.", +) +async def fail_task( + task_id: DAuthorizedId(AgentexResourceType.task, AuthorizedOperationType.update), + task_use_case: DTaskUseCase, + request: TaskStatusReasonRequest | None = None, +) -> Task: + updated = await task_use_case.fail_task( + id=task_id, reason=request.reason if request else None + ) + return Task.model_validate(updated) + + +@router.post( + "/{task_id}/cancel", + response_model=Task, + summary="Cancel Task", + description="Mark a running task as canceled.", +) +async def cancel_task( + task_id: DAuthorizedId(AgentexResourceType.task, AuthorizedOperationType.update), + task_use_case: DTaskUseCase, + request: TaskStatusReasonRequest | None = None, +) -> Task: + updated = await task_use_case.cancel_task( + id=task_id, reason=request.reason if request else None + ) + return Task.model_validate(updated) + + +@router.post( + "/{task_id}/terminate", + response_model=Task, + summary="Terminate Task", + description="Mark a running task as terminated.", +) +async def terminate_task( + task_id: DAuthorizedId(AgentexResourceType.task, AuthorizedOperationType.update), + task_use_case: DTaskUseCase, + request: TaskStatusReasonRequest | None = None, +) -> Task: + updated = await task_use_case.terminate_task( + id=task_id, reason=request.reason if request else None + ) + return Task.model_validate(updated) + + +@router.post( + "/{task_id}/timeout", + response_model=Task, + summary="Timeout Task", + description="Mark a running task as timed out.", +) +async def timeout_task( + task_id: DAuthorizedId(AgentexResourceType.task, AuthorizedOperationType.update), + task_use_case: DTaskUseCase, + request: TaskStatusReasonRequest | None = None, +) -> Task: + updated = await task_use_case.timeout_task( + id=task_id, reason=request.reason if request else None + ) + return Task.model_validate(updated) + + @router.get( "/{task_id}/stream", summary="Stream Task Events by ID", diff --git a/agentex/src/api/schemas/tasks.py b/agentex/src/api/schemas/tasks.py index 9893c4d7..4055cadc 100644 --- a/agentex/src/api/schemas/tasks.py +++ b/agentex/src/api/schemas/tasks.py @@ -1,6 +1,6 @@ from datetime import datetime from enum import Enum -from typing import Any, Literal +from typing import Any from pydantic import Field @@ -24,12 +24,6 @@ class TaskStatus(str, Enum): DELETED = "DELETED" -# Statuses that agents can transition a running task to via the update endpoint -TerminalTaskStatus = Literal[ - "COMPLETED", "FAILED", "CANCELED", "TERMINATED", "TIMED_OUT" -] - - class Task(BaseModel): id: str = Field( ..., @@ -79,11 +73,10 @@ class UpdateTaskRequest(BaseModel): None, title="If provided, replaces task_metadata with this value", ) - status: TerminalTaskStatus | None = Field( - None, - title="If provided, transitions the task to this status. Only RUNNING tasks can be transitioned.", - ) - status_reason: str | None = Field( + + +class TaskStatusReasonRequest(BaseModel): + reason: str | None = Field( None, title="Optional reason for the status change", ) diff --git a/agentex/src/domain/use_cases/tasks_use_case.py b/agentex/src/domain/use_cases/tasks_use_case.py index 417994fa..f358cf1b 100644 --- a/agentex/src/domain/use_cases/tasks_use_case.py +++ b/agentex/src/domain/use_cases/tasks_use_case.py @@ -90,48 +90,103 @@ async def update_mutable_fields_on_task( id: str | None = None, name: str | None = None, task_metadata: dict[str, Any] | None = None, - status: TaskStatus | None = None, - status_reason: str | None = None, ) -> TaskEntity: """Update mutable fields on a task entity. This is used by our API since not all fields should be mutable.""" if not id and not name: raise ClientError("Either id or name must be provided") + # todo: make this a transaction? task_entity = await self.task_service.get_task(id=id, name=name) if task_entity.status == TaskStatus.DELETED: - identifier = id or name - raise ItemDoesNotExist(f"Task {identifier} not found") + if id: + raise ItemDoesNotExist(f"Task {id} not found") + else: + raise ItemDoesNotExist(f"Task {name} not found") - # If no mutations were provided, don't write - if status is None and task_metadata is None: + # if no mutations are provided, don't do anything + if task_metadata is None: return task_entity - # Status transition uses an atomic conditional update to prevent race conditions - if status is not None: - if task_entity.status != TaskStatus.RUNNING: - raise ClientError( - f"Task {task_entity.id} is not running (current status: {task_entity.status}). " - f"Only running tasks can have their status updated." - ) - reason = status_reason or f"Task {status.value.lower()}" - updated = await self.task_service.transition_task_status( - task_id=task_entity.id, - expected_status=TaskStatus.RUNNING, - new_status=status, - status_reason=reason, - task_metadata=task_metadata, + if task_metadata is not None: + task_entity.task_metadata = task_metadata + + updated_task_entity = await self.task_service.update_task(task=task_entity) + return updated_task_entity + + async def _transition_to_terminal( + self, + target_status: TaskStatus, + id: str | None = None, + name: str | None = None, + reason: str | None = None, + ) -> TaskEntity: + """Atomically transition a running task to a terminal status.""" + if not id and not name: + raise ClientError("Either id or name must be provided") + + task_entity = await self.task_service.get_task(id=id, name=name) + if task_entity.status == TaskStatus.DELETED: + raise ItemDoesNotExist(f"Task {id or name} not found") + if task_entity.status != TaskStatus.RUNNING: + raise ClientError( + f"Task {task_entity.id} is not running (current status: {task_entity.status}). " + f"Only running tasks can have their status updated." ) - if updated is None: - raise ClientError( - f"Task {task_entity.id} status was concurrently modified. " - f"Please retry the request." - ) - return updated - - # Metadata-only update (no status change) - task_entity.task_metadata = task_metadata - return await self.task_service.update_task(task=task_entity) + + status_reason = reason or f"Task {target_status.value.lower()}" + updated = await self.task_service.transition_task_status( + task_id=task_entity.id, + expected_status=TaskStatus.RUNNING, + new_status=target_status, + status_reason=status_reason, + ) + if updated is None: + raise ClientError( + f"Task {task_entity.id} status was concurrently modified. " + f"Please retry the request." + ) + return updated + + async def complete_task( + self, id: str | None = None, name: str | None = None, reason: str | None = None + ) -> TaskEntity: + """Mark a running task as completed.""" + return await self._transition_to_terminal( + TaskStatus.COMPLETED, id=id, name=name, reason=reason + ) + + async def fail_task( + self, id: str | None = None, name: str | None = None, reason: str | None = None + ) -> TaskEntity: + """Mark a running task as failed.""" + return await self._transition_to_terminal( + TaskStatus.FAILED, id=id, name=name, reason=reason + ) + + async def cancel_task( + self, id: str | None = None, name: str | None = None, reason: str | None = None + ) -> TaskEntity: + """Mark a running task as canceled.""" + return await self._transition_to_terminal( + TaskStatus.CANCELED, id=id, name=name, reason=reason + ) + + async def terminate_task( + self, id: str | None = None, name: str | None = None, reason: str | None = None + ) -> TaskEntity: + """Mark a running task as terminated.""" + return await self._transition_to_terminal( + TaskStatus.TERMINATED, id=id, name=name, reason=reason + ) + + async def timeout_task( + self, id: str | None = None, name: str | None = None, reason: str | None = None + ) -> TaskEntity: + """Mark a running task as timed out.""" + return await self._transition_to_terminal( + TaskStatus.TIMED_OUT, id=id, name=name, reason=reason + ) DTaskUseCase = Annotated[TasksUseCase, Depends(TasksUseCase)] diff --git a/agentex/tests/integration/api/tasks/test_tasks_api.py b/agentex/tests/integration/api/tasks/test_tasks_api.py index 75b10323..934e079f 100644 --- a/agentex/tests/integration/api/tasks/test_tasks_api.py +++ b/agentex/tests/integration/api/tasks/test_tasks_api.py @@ -1382,12 +1382,12 @@ async def test_list_tasks_filters_work_with_views( assert len(task_data["agents"]) == 1 assert task_data["agents"][0]["name"] == "target-filter-agent" - async def test_update_task_status_to_completed(self, isolated_client, test_task): - """Test transitioning a RUNNING task to COMPLETED via PUT endpoint""" + async def test_complete_task_endpoint(self, isolated_client, test_task): + """Test POST /tasks/{task_id}/complete transitions RUNNING to COMPLETED""" # When - response = await isolated_client.put( - f"/tasks/{test_task.id}", - json={"status": "COMPLETED", "status_reason": "Agent finished"}, + response = await isolated_client.post( + f"/tasks/{test_task.id}/complete", + json={"reason": "Agent finished"}, ) # Then @@ -1396,12 +1396,40 @@ async def test_update_task_status_to_completed(self, isolated_client, test_task) assert task_data["status"] == "COMPLETED" assert task_data["status_reason"] == "Agent finished" - async def test_update_task_status_to_terminated(self, isolated_client, test_task): - """Test transitioning a RUNNING task to TERMINATED via PUT endpoint""" + async def test_fail_task_endpoint(self, isolated_client, test_task): + """Test POST /tasks/{task_id}/fail transitions RUNNING to FAILED""" # When - response = await isolated_client.put( - f"/tasks/{test_task.id}", - json={"status": "TERMINATED", "status_reason": "Workflow killed"}, + response = await isolated_client.post( + f"/tasks/{test_task.id}/fail", + json={"reason": "Something went wrong"}, + ) + + # Then + assert response.status_code == 200 + task_data = response.json() + assert task_data["status"] == "FAILED" + assert task_data["status_reason"] == "Something went wrong" + + async def test_cancel_task_endpoint(self, isolated_client, test_task): + """Test POST /tasks/{task_id}/cancel transitions RUNNING to CANCELED""" + # When + response = await isolated_client.post( + f"/tasks/{test_task.id}/cancel", + json={"reason": "User requested cancellation"}, + ) + + # Then + assert response.status_code == 200 + task_data = response.json() + assert task_data["status"] == "CANCELED" + assert task_data["status_reason"] == "User requested cancellation" + + async def test_terminate_task_endpoint(self, isolated_client, test_task): + """Test POST /tasks/{task_id}/terminate transitions RUNNING to TERMINATED""" + # When + response = await isolated_client.post( + f"/tasks/{test_task.id}/terminate", + json={"reason": "Workflow killed"}, ) # Then @@ -1410,12 +1438,11 @@ async def test_update_task_status_to_terminated(self, isolated_client, test_task assert task_data["status"] == "TERMINATED" assert task_data["status_reason"] == "Workflow killed" - async def test_update_task_status_to_timed_out(self, isolated_client, test_task): - """Test transitioning a RUNNING task to TIMED_OUT via PUT endpoint""" + async def test_timeout_task_endpoint(self, isolated_client, test_task): + """Test POST /tasks/{task_id}/timeout transitions RUNNING to TIMED_OUT""" # When - response = await isolated_client.put( - f"/tasks/{test_task.id}", - json={"status": "TIMED_OUT"}, + response = await isolated_client.post( + f"/tasks/{test_task.id}/timeout", ) # Then @@ -1424,68 +1451,31 @@ async def test_update_task_status_to_timed_out(self, isolated_client, test_task) assert task_data["status"] == "TIMED_OUT" assert task_data["status_reason"] == "Task timed_out" - async def test_update_task_status_by_name(self, isolated_client, test_task): - """Test transitioning a task to COMPLETED by name""" + async def test_complete_task_with_default_reason(self, isolated_client, test_task): + """Test POST /tasks/{task_id}/complete without a reason uses default""" # When - response = await isolated_client.put( - f"/tasks/name/{test_task.name}", - json={"status": "COMPLETED", "status_reason": "Done by name"}, + response = await isolated_client.post( + f"/tasks/{test_task.id}/complete", ) # Then assert response.status_code == 200 task_data = response.json() assert task_data["status"] == "COMPLETED" - assert task_data["status_reason"] == "Done by name" + assert task_data["status_reason"] == "Task completed" async def test_cannot_transition_non_running_task(self, isolated_client, test_task): """Test that a completed task cannot be transitioned again""" # Given - Complete the task first - response = await isolated_client.put( - f"/tasks/{test_task.id}", - json={"status": "COMPLETED"}, + response = await isolated_client.post( + f"/tasks/{test_task.id}/complete", ) assert response.status_code == 200 - # When - Try to transition again - response = await isolated_client.put( - f"/tasks/{test_task.id}", - json={"status": "TERMINATED"}, + # When - Try to terminate the already-completed task + response = await isolated_client.post( + f"/tasks/{test_task.id}/terminate", ) # Then - Should fail assert response.status_code == 400 - - async def test_update_task_rejects_invalid_status(self, isolated_client, test_task): - """Test that RUNNING and DELETED are rejected as target statuses""" - # When - Try to set status to RUNNING - response = await isolated_client.put( - f"/tasks/{test_task.id}", - json={"status": "RUNNING"}, - ) - - # Then - Should be rejected by schema validation (422) - assert response.status_code == 422 - - # When - Try to set status to DELETED - response = await isolated_client.put( - f"/tasks/{test_task.id}", - json={"status": "DELETED"}, - ) - - # Then - Should be rejected by schema validation (422) - assert response.status_code == 422 - - async def test_update_metadata_still_works(self, isolated_client, test_task): - """Test that updating only metadata without status still works""" - # When - response = await isolated_client.put( - f"/tasks/{test_task.id}", - json={"task_metadata": {"key": "value"}}, - ) - - # Then - assert response.status_code == 200 - task_data = response.json() - assert task_data["status"] == "RUNNING" - assert task_data["task_metadata"] == {"key": "value"} diff --git a/agentex/tests/unit/use_cases/test_tasks_use_case.py b/agentex/tests/unit/use_cases/test_tasks_use_case.py index 4dabd15c..6ce7ca7c 100644 --- a/agentex/tests/unit/use_cases/test_tasks_use_case.py +++ b/agentex/tests/unit/use_cases/test_tasks_use_case.py @@ -1,6 +1,6 @@ """ Unit tests for TasksUseCase - specifically the status transition logic -in update_mutable_fields_on_task. +via explicit status methods (complete_task, fail_task, etc.). """ from uuid import uuid4 @@ -59,7 +59,7 @@ def sample_agent(): @pytest.mark.unit @pytest.mark.asyncio class TestTasksUseCaseStatusTransitions: - """Test suite for task status transitions via update_mutable_fields_on_task""" + """Test suite for task status transitions via explicit status methods""" async def test_complete_running_task( self, tasks_use_case, task_service, agent_repository, sample_agent @@ -73,8 +73,8 @@ async def test_complete_running_task( assert task.status == TaskStatus.RUNNING # When - updated = await tasks_use_case.update_mutable_fields_on_task( - id=task.id, status=TaskStatus.COMPLETED, status_reason="Agent finished" + updated = await tasks_use_case.complete_task( + id=task.id, reason="Agent finished" ) # Then @@ -92,8 +92,8 @@ async def test_terminate_running_task( ) # When - updated = await tasks_use_case.update_mutable_fields_on_task( - id=task.id, status=TaskStatus.TERMINATED, status_reason="Workflow killed" + updated = await tasks_use_case.terminate_task( + id=task.id, reason="Workflow killed" ) # Then @@ -111,9 +111,7 @@ async def test_timeout_running_task( ) # When - updated = await tasks_use_case.update_mutable_fields_on_task( - id=task.id, status=TaskStatus.TIMED_OUT - ) + updated = await tasks_use_case.timeout_task(id=task.id) # Then assert updated.status == TaskStatus.TIMED_OUT @@ -130,9 +128,7 @@ async def test_default_status_reason( ) # When - updated = await tasks_use_case.update_mutable_fields_on_task( - id=task.id, status=TaskStatus.COMPLETED - ) + updated = await tasks_use_case.complete_task(id=task.id) # Then assert updated.status_reason == "Task completed" @@ -146,15 +142,11 @@ async def test_cannot_transition_completed_task( task = await task_service.create_task( agent=sample_agent, task_name="double-complete-test" ) - await tasks_use_case.update_mutable_fields_on_task( - id=task.id, status=TaskStatus.COMPLETED - ) + await tasks_use_case.complete_task(id=task.id) # When / Then with pytest.raises(ClientError, match="not running"): - await tasks_use_case.update_mutable_fields_on_task( - id=task.id, status=TaskStatus.TERMINATED - ) + await tasks_use_case.terminate_task(id=task.id) async def test_cannot_transition_canceled_task( self, tasks_use_case, task_service, agent_repository, sample_agent @@ -165,15 +157,11 @@ async def test_cannot_transition_canceled_task( task = await task_service.create_task( agent=sample_agent, task_name="cancel-block-test" ) - await tasks_use_case.update_mutable_fields_on_task( - id=task.id, status=TaskStatus.CANCELED - ) + await tasks_use_case.cancel_task(id=task.id) # When / Then with pytest.raises(ClientError, match="not running"): - await tasks_use_case.update_mutable_fields_on_task( - id=task.id, status=TaskStatus.COMPLETED - ) + await tasks_use_case.complete_task(id=task.id) async def test_cannot_transition_failed_task( self, tasks_use_case, task_service, agent_repository, sample_agent @@ -184,15 +172,11 @@ async def test_cannot_transition_failed_task( task = await task_service.create_task( agent=sample_agent, task_name="fail-block-test" ) - await tasks_use_case.update_mutable_fields_on_task( - id=task.id, status=TaskStatus.FAILED - ) + await tasks_use_case.fail_task(id=task.id) # When / Then with pytest.raises(ClientError, match="not running"): - await tasks_use_case.update_mutable_fields_on_task( - id=task.id, status=TaskStatus.COMPLETED - ) + await tasks_use_case.complete_task(id=task.id) async def test_cannot_transition_deleted_task( self, tasks_use_case, task_service, agent_repository, sample_agent @@ -207,9 +191,7 @@ async def test_cannot_transition_deleted_task( # When / Then with pytest.raises(ItemDoesNotExist): - await tasks_use_case.update_mutable_fields_on_task( - id=task.id, status=TaskStatus.COMPLETED - ) + await tasks_use_case.complete_task(id=task.id) async def test_update_metadata_without_status( self, tasks_use_case, task_service, agent_repository, sample_agent @@ -222,48 +204,10 @@ async def test_update_metadata_without_status( ) # When - updated = await tasks_use_case.update_mutable_fields_on_task( + updated = await tasks_use_case.update_task_metadata( id=task.id, task_metadata={"key": "value"} ) # Then assert updated.status == TaskStatus.RUNNING assert updated.task_metadata == {"key": "value"} - - async def test_update_status_and_metadata_together( - self, tasks_use_case, task_service, agent_repository, sample_agent - ): - """Test that status and metadata can be updated in a single call""" - # Given - await create_or_get_agent(agent_repository, sample_agent) - task = await task_service.create_task( - agent=sample_agent, task_name="both-update-test" - ) - - # When - updated = await tasks_use_case.update_mutable_fields_on_task( - id=task.id, - status=TaskStatus.COMPLETED, - status_reason="Done", - task_metadata={"result": "success"}, - ) - - # Then - assert updated.status == TaskStatus.COMPLETED - assert updated.status_reason == "Done" - assert updated.task_metadata == {"result": "success"} - - async def test_no_op_when_nothing_provided( - self, tasks_use_case, task_service, agent_repository, sample_agent - ): - """Test that providing neither status nor metadata is a no-op""" - # Given - await create_or_get_agent(agent_repository, sample_agent) - task = await task_service.create_task(agent=sample_agent, task_name="noop-test") - - # When - result = await tasks_use_case.update_mutable_fields_on_task(id=task.id) - - # Then - assert result.status == TaskStatus.RUNNING - assert result.id == task.id From af1db6846ae81ffe328fdd84fe289e2819625002 Mon Sep 17 00:00:00 2001 From: Declan Brady Date: Tue, 24 Mar 2026 20:05:25 -0400 Subject: [PATCH 6/6] add better tests for tasks --- .../unit/use_cases/test_tasks_use_case.py | 274 +++++++++++++++++- 1 file changed, 259 insertions(+), 15 deletions(-) diff --git a/agentex/tests/unit/use_cases/test_tasks_use_case.py b/agentex/tests/unit/use_cases/test_tasks_use_case.py index 6ce7ca7c..3de88eea 100644 --- a/agentex/tests/unit/use_cases/test_tasks_use_case.py +++ b/agentex/tests/unit/use_cases/test_tasks_use_case.py @@ -1,6 +1,6 @@ """ -Unit tests for TasksUseCase - specifically the status transition logic -via explicit status methods (complete_task, fail_task, etc.). +Unit tests for TasksUseCase - status transition logic via explicit status +methods (complete_task, fail_task, etc.) and metadata updates. """ from uuid import uuid4 @@ -61,6 +61,8 @@ def sample_agent(): class TestTasksUseCaseStatusTransitions: """Test suite for task status transitions via explicit status methods""" + # --- Happy-path transitions (RUNNING -> terminal) --- + async def test_complete_running_task( self, tasks_use_case, task_service, agent_repository, sample_agent ): @@ -81,6 +83,42 @@ async def test_complete_running_task( assert updated.status == TaskStatus.COMPLETED assert updated.status_reason == "Agent finished" + async def test_fail_running_task( + self, tasks_use_case, task_service, agent_repository, sample_agent + ): + """Test that a RUNNING task can be transitioned to FAILED""" + # Given + await create_or_get_agent(agent_repository, sample_agent) + task = await task_service.create_task(agent=sample_agent, task_name="fail-test") + + # When + updated = await tasks_use_case.fail_task( + id=task.id, reason="Something went wrong" + ) + + # Then + assert updated.status == TaskStatus.FAILED + assert updated.status_reason == "Something went wrong" + + async def test_cancel_running_task( + self, tasks_use_case, task_service, agent_repository, sample_agent + ): + """Test that a RUNNING task can be transitioned to CANCELED""" + # Given + await create_or_get_agent(agent_repository, sample_agent) + task = await task_service.create_task( + agent=sample_agent, task_name="cancel-test" + ) + + # When + updated = await tasks_use_case.cancel_task( + id=task.id, reason="User requested cancellation" + ) + + # Then + assert updated.status == TaskStatus.CANCELED + assert updated.status_reason == "User requested cancellation" + async def test_terminate_running_task( self, tasks_use_case, task_service, agent_repository, sample_agent ): @@ -117,21 +155,62 @@ async def test_timeout_running_task( assert updated.status == TaskStatus.TIMED_OUT assert updated.status_reason == "Task timed_out" + # --- Default reason for each transition --- + + @pytest.mark.parametrize( + "method,expected_reason", + [ + ("complete_task", "Task completed"), + ("fail_task", "Task failed"), + ("cancel_task", "Task canceled"), + ("terminate_task", "Task terminated"), + ("timeout_task", "Task timed_out"), + ], + ) async def test_default_status_reason( + self, + tasks_use_case, + task_service, + agent_repository, + sample_agent, + method, + expected_reason, + ): + """Test that each transition method sets a default reason when none provided""" + # Given + await create_or_get_agent(agent_repository, sample_agent) + task = await task_service.create_task( + agent=sample_agent, task_name=f"default-reason-{method}" + ) + + # When + updated = await getattr(tasks_use_case, method)(id=task.id) + + # Then + assert updated.status_reason == expected_reason + + # --- Transition by name --- + + async def test_complete_task_by_name( self, tasks_use_case, task_service, agent_repository, sample_agent ): - """Test that a default status_reason is set when none provided""" + """Test that a task can be transitioned using name instead of id""" # Given await create_or_get_agent(agent_repository, sample_agent) task = await task_service.create_task( - agent=sample_agent, task_name="default-reason-test" + agent=sample_agent, task_name="complete-by-name-test" ) # When - updated = await tasks_use_case.complete_task(id=task.id) + updated = await tasks_use_case.complete_task( + name=task.name, reason="Done by name" + ) # Then - assert updated.status_reason == "Task completed" + assert updated.status == TaskStatus.COMPLETED + assert updated.status_reason == "Done by name" + + # --- Blocked transitions from each terminal state --- async def test_cannot_transition_completed_task( self, tasks_use_case, task_service, agent_repository, sample_agent @@ -148,6 +227,21 @@ async def test_cannot_transition_completed_task( with pytest.raises(ClientError, match="not running"): await tasks_use_case.terminate_task(id=task.id) + async def test_cannot_transition_failed_task( + self, tasks_use_case, task_service, agent_repository, sample_agent + ): + """Test that a FAILED task cannot be transitioned""" + # Given + await create_or_get_agent(agent_repository, sample_agent) + task = await task_service.create_task( + agent=sample_agent, task_name="fail-block-test" + ) + await tasks_use_case.fail_task(id=task.id) + + # When / Then + with pytest.raises(ClientError, match="not running"): + await tasks_use_case.complete_task(id=task.id) + async def test_cannot_transition_canceled_task( self, tasks_use_case, task_service, agent_repository, sample_agent ): @@ -163,16 +257,31 @@ async def test_cannot_transition_canceled_task( with pytest.raises(ClientError, match="not running"): await tasks_use_case.complete_task(id=task.id) - async def test_cannot_transition_failed_task( + async def test_cannot_transition_terminated_task( self, tasks_use_case, task_service, agent_repository, sample_agent ): - """Test that a FAILED task cannot be transitioned""" + """Test that a TERMINATED task cannot be transitioned""" # Given await create_or_get_agent(agent_repository, sample_agent) task = await task_service.create_task( - agent=sample_agent, task_name="fail-block-test" + agent=sample_agent, task_name="terminate-block-test" ) - await tasks_use_case.fail_task(id=task.id) + await tasks_use_case.terminate_task(id=task.id) + + # When / Then + with pytest.raises(ClientError, match="not running"): + await tasks_use_case.complete_task(id=task.id) + + async def test_cannot_transition_timed_out_task( + self, tasks_use_case, task_service, agent_repository, sample_agent + ): + """Test that a TIMED_OUT task cannot be transitioned""" + # Given + await create_or_get_agent(agent_repository, sample_agent) + task = await task_service.create_task( + agent=sample_agent, task_name="timeout-block-test" + ) + await tasks_use_case.timeout_task(id=task.id) # When / Then with pytest.raises(ClientError, match="not running"): @@ -193,21 +302,156 @@ async def test_cannot_transition_deleted_task( with pytest.raises(ItemDoesNotExist): await tasks_use_case.complete_task(id=task.id) - async def test_update_metadata_without_status( + # --- Validation --- + + async def test_transition_requires_id_or_name(self, tasks_use_case): + """Test that transitioning without id or name raises ClientError""" + with pytest.raises(ClientError, match="Either id or name must be provided"): + await tasks_use_case.complete_task() + + +@pytest.mark.unit +@pytest.mark.asyncio +class TestTasksUseCaseMetadataUpdate: + """Test suite for update_mutable_fields_on_task""" + + async def test_update_metadata( self, tasks_use_case, task_service, agent_repository, sample_agent ): - """Test that updating only task_metadata does not change status""" + """Test that task_metadata is replaced with the provided value""" # Given await create_or_get_agent(agent_repository, sample_agent) task = await task_service.create_task( - agent=sample_agent, task_name="metadata-only-test" + agent=sample_agent, task_name="metadata-update-test" ) # When - updated = await tasks_use_case.update_task_metadata( + updated = await tasks_use_case.update_mutable_fields_on_task( id=task.id, task_metadata={"key": "value"} ) # Then - assert updated.status == TaskStatus.RUNNING assert updated.task_metadata == {"key": "value"} + assert updated.status == TaskStatus.RUNNING + + async def test_update_metadata_does_not_change_status( + self, tasks_use_case, task_service, agent_repository, sample_agent + ): + """Test that updating metadata leaves status unchanged""" + # Given + await create_or_get_agent(agent_repository, sample_agent) + task = await task_service.create_task( + agent=sample_agent, task_name="metadata-status-test" + ) + + # When + updated = await tasks_use_case.update_mutable_fields_on_task( + id=task.id, task_metadata={"new": "data"} + ) + + # Then + assert updated.status == TaskStatus.RUNNING + + async def test_update_metadata_replaces_existing( + self, tasks_use_case, task_service, agent_repository, sample_agent + ): + """Test that metadata is fully replaced, not merged""" + # Given + await create_or_get_agent(agent_repository, sample_agent) + task = await task_service.create_task( + agent=sample_agent, task_name="metadata-replace-test" + ) + await tasks_use_case.update_mutable_fields_on_task( + id=task.id, task_metadata={"original": "data", "keep": "this"} + ) + + # When + updated = await tasks_use_case.update_mutable_fields_on_task( + id=task.id, task_metadata={"replaced": "entirely"} + ) + + # Then + assert updated.task_metadata == {"replaced": "entirely"} + assert "original" not in updated.task_metadata + + async def test_update_metadata_with_empty_dict( + self, tasks_use_case, task_service, agent_repository, sample_agent + ): + """Test that metadata can be set to an empty dict""" + # Given + await create_or_get_agent(agent_repository, sample_agent) + task = await task_service.create_task( + agent=sample_agent, task_name="metadata-empty-test" + ) + await tasks_use_case.update_mutable_fields_on_task( + id=task.id, task_metadata={"some": "data"} + ) + + # When + updated = await tasks_use_case.update_mutable_fields_on_task( + id=task.id, task_metadata={} + ) + + # Then + assert updated.task_metadata == {} + + async def test_update_metadata_noop_when_none( + self, tasks_use_case, task_service, agent_repository, sample_agent + ): + """Test that passing task_metadata=None is a no-op""" + # Given + await create_or_get_agent(agent_repository, sample_agent) + task = await task_service.create_task( + agent=sample_agent, task_name="metadata-noop-test" + ) + + # When + updated = await tasks_use_case.update_mutable_fields_on_task( + id=task.id, task_metadata=None + ) + + # Then + assert updated.id == task.id + assert updated.task_metadata == task.task_metadata + + async def test_update_metadata_by_name( + self, tasks_use_case, task_service, agent_repository, sample_agent + ): + """Test that metadata can be updated using task name""" + # Given + await create_or_get_agent(agent_repository, sample_agent) + task = await task_service.create_task( + agent=sample_agent, task_name="metadata-by-name-test" + ) + + # When + updated = await tasks_use_case.update_mutable_fields_on_task( + name=task.name, task_metadata={"via": "name"} + ) + + # Then + assert updated.task_metadata == {"via": "name"} + + async def test_update_metadata_on_deleted_task_raises( + self, tasks_use_case, task_service, agent_repository, sample_agent + ): + """Test that updating metadata on a deleted task raises not found""" + # Given + await create_or_get_agent(agent_repository, sample_agent) + task = await task_service.create_task( + agent=sample_agent, task_name="metadata-deleted-test" + ) + await tasks_use_case.delete_task(id=task.id) + + # When / Then + with pytest.raises(ItemDoesNotExist): + await tasks_use_case.update_mutable_fields_on_task( + id=task.id, task_metadata={"should": "fail"} + ) + + async def test_update_metadata_requires_id_or_name(self, tasks_use_case): + """Test that updating metadata without id or name raises ClientError""" + with pytest.raises(ClientError, match="Either id or name must be provided"): + await tasks_use_case.update_mutable_fields_on_task( + task_metadata={"key": "value"} + )