diff --git a/src/a2a/_base.py b/src/a2a/_base.py index 6c50734cd..07efc6747 100644 --- a/src/a2a/_base.py +++ b/src/a2a/_base.py @@ -35,4 +35,5 @@ class A2ABaseModel(BaseModel): validate_by_alias=True, serialize_by_alias=True, alias_generator=to_camel_custom, + extra='forbid', ) diff --git a/src/a2a/server/request_handlers/default_request_handler.py b/src/a2a/server/request_handlers/default_request_handler.py index 30d1ee891..04ccccdaf 100644 --- a/src/a2a/server/request_handlers/default_request_handler.py +++ b/src/a2a/server/request_handlers/default_request_handler.py @@ -329,7 +329,7 @@ async def push_notification_callback() -> None: ) except Exception: - logger.exception('Agent execution failed') + await self._handle_execution_failure(producer_task, queue) raise finally: if interrupted_or_non_blocking: @@ -392,6 +392,10 @@ async def on_message_send_stream( bg_task.set_name(f'background_consume:{task_id}') self._track_background_task(bg_task) raise + except Exception: + # If the consumer fails (e.g. database error), we must cleanup. + await self._handle_execution_failure(producer_task, queue) + raise finally: cleanup_task = asyncio.create_task( self._cleanup_producer(producer_task, task_id) @@ -429,13 +433,32 @@ def _on_done(completed: asyncio.Task) -> None: task.add_done_callback(_on_done) + async def _handle_execution_failure( + self, producer_task: asyncio.Task, queue: EventQueue + ) -> None: + """Cancels the producer and closes the queue immediately on failure.""" + logger.exception('Agent execution failed') + # If the consumer fails, we must cancel the producer to prevent it from hanging + # on queue operations (e.g., waiting for the queue to drain). + producer_task.cancel() + # Force the queue to close immediately, discarding any pending events. + # This ensures that any producers waiting on the queue are unblocked. + await queue.close(immediate=True) + async def _cleanup_producer( self, producer_task: asyncio.Task, task_id: str, ) -> None: """Cleans up the agent execution task and queue manager entry.""" - await producer_task + try: + await producer_task + except asyncio.CancelledError: + logger.debug( + 'Producer task %s was cancelled during cleanup', task_id + ) + except Exception: + logger.exception('Producer task %s failed during cleanup', task_id) await self._queue_manager.close(task_id) async with self._running_agents_lock: self._running_agents.pop(task_id, None) diff --git a/tck/sut_agent.py b/tck/sut_agent.py index 525631ca0..1500d6988 100644 --- a/tck/sut_agent.py +++ b/tck/sut_agent.py @@ -24,7 +24,11 @@ TaskStatus, TaskStatusUpdateEvent, TextPart, + FilePart, + DataPart, + InvalidParamsError, ) +from a2a.utils.errors import ServerError JSONRPC_URL = '/a2a/jsonrpc' @@ -67,6 +71,41 @@ async def execute( task_id = context.task_id context_id = context.context_id + # Validate message parts + if not user_message.parts: + # Empty parts array is invalid + raise ServerError( + error=InvalidParamsError(message='Message must contain at least one part') + ) + + for part in user_message.parts: + # Unwrap RootModel if present to get the actual part + actual_part = part + if hasattr(part, 'root'): + actual_part = part.root + + # Check if it's a known part type + if not isinstance(actual_part, (TextPart, FilePart, DataPart)): + # If we received something that isn't a known part, treating it as unsupported. + # Enqueue a failed status event. + await event_queue.enqueue_event(TaskStatusUpdateEvent( + task_id=task_id, + context_id=context_id, + status=TaskStatus( + state=TaskState.failed, + message=Message( + role='agent', + message_id=str(uuid.uuid4()), + parts=[TextPart(text='Unsupported message part type')], + task_id=task_id, + context_id=context_id, + ), + timestamp=datetime.now(timezone.utc).isoformat(), + ), + final=True, + )) + return + self.running_tasks.add(task_id) logger.info( diff --git a/tests/server/request_handlers/test_default_request_handler.py b/tests/server/request_handlers/test_default_request_handler.py index 88dd77ab4..f64ed04c4 100644 --- a/tests/server/request_handlers/test_default_request_handler.py +++ b/tests/server/request_handlers/test_default_request_handler.py @@ -2644,3 +2644,171 @@ async def test_on_message_send_stream_task_id_provided_but_task_not_found(): f'Task {task_id} was specified but does not exist' in exc_info.value.error.message ) + + +@pytest.mark.asyncio +async def test_on_message_send_stream_consumer_error_cancels_producer_and_closes_queue(): + """Test that if the consumer (result aggregator) raises an exception, the producer is cancelled and queue is closed immediately.""" + mock_task_store = AsyncMock(spec=TaskStore) + mock_queue_manager = AsyncMock(spec=QueueManager) + mock_agent_executor = AsyncMock(spec=AgentExecutor) + mock_request_context_builder = AsyncMock(spec=RequestContextBuilder) + + task_id = 'error_cleanup_task' + context_id = 'error_cleanup_ctx' + + mock_request_context = MagicMock(spec=RequestContext) + mock_request_context.task_id = task_id + mock_request_context.context_id = context_id + mock_request_context_builder.build.return_value = mock_request_context + + mock_queue = AsyncMock(spec=EventQueue) + mock_queue_manager.create_or_tap.return_value = mock_queue + + request_handler = DefaultRequestHandler( + agent_executor=mock_agent_executor, + task_store=mock_task_store, + queue_manager=mock_queue_manager, + request_context_builder=mock_request_context_builder, + ) + + params = MessageSendParams( + message=Message( + role=Role.user, + message_id='msg_error_cleanup', + parts=[], + # Do NOT provide task_id here to avoid "Task ... was specified but does not exist" error + ) + ) + + # Mock ResultAggregator to raise exception + mock_result_aggregator_instance = MagicMock(spec=ResultAggregator) + + async def raise_error_gen(_consumer): + # Raise an exception to simulate consumer failure + raise ValueError('Consumer failed!') + yield # unreachable + + mock_result_aggregator_instance.consume_and_emit.side_effect = ( + raise_error_gen + ) + + # Capture the producer task to verify cancellation + captured_producer_task = None + original_register = request_handler._register_producer + + async def spy_register_producer(tid, task): + nonlocal captured_producer_task + captured_producer_task = task + # Wrap the cancel method to spy on it + task.cancel = MagicMock(wraps=task.cancel) + await original_register(tid, task) + + with ( + patch( + 'a2a.server.request_handlers.default_request_handler.ResultAggregator', + return_value=mock_result_aggregator_instance, + ), + patch( + 'a2a.server.request_handlers.default_request_handler.TaskManager.get_task', + return_value=None, + ), + patch.object( + request_handler, + '_register_producer', + side_effect=spy_register_producer, + ), + ): + # Act + with pytest.raises(ValueError, match='Consumer failed!'): + async for _ in request_handler.on_message_send_stream( + params, create_server_call_context() + ): + pass + + assert captured_producer_task is not None + # Verify producer was cancelled + captured_producer_task.cancel.assert_called() + + # Verify queue closed immediately + mock_queue.close.assert_awaited_with(immediate=True) + + +@pytest.mark.asyncio +async def test_on_message_send_consumer_error_cancels_producer_and_closes_queue(): + """Test that if the consumer raises an exception during blocking wait, the producer is cancelled.""" + mock_task_store = AsyncMock(spec=TaskStore) + mock_queue_manager = AsyncMock(spec=QueueManager) + mock_agent_executor = AsyncMock(spec=AgentExecutor) + mock_request_context_builder = AsyncMock(spec=RequestContextBuilder) + + task_id = 'error_cleanup_blocking_task' + context_id = 'error_cleanup_blocking_ctx' + + mock_request_context = MagicMock(spec=RequestContext) + mock_request_context.task_id = task_id + mock_request_context.context_id = context_id + mock_request_context_builder.build.return_value = mock_request_context + + mock_queue = AsyncMock(spec=EventQueue) + mock_queue_manager.create_or_tap.return_value = mock_queue + + request_handler = DefaultRequestHandler( + agent_executor=mock_agent_executor, + task_store=mock_task_store, + queue_manager=mock_queue_manager, + request_context_builder=mock_request_context_builder, + ) + + params = MessageSendParams( + message=Message( + role=Role.user, + message_id='msg_error_blocking', + parts=[], + ) + ) + + # Mock ResultAggregator to raise exception + mock_result_aggregator_instance = MagicMock(spec=ResultAggregator) + mock_result_aggregator_instance.consume_and_break_on_interrupt.side_effect = ValueError( + 'Consumer failed!' + ) + + # Capture the producer task to verify cancellation + captured_producer_task = None + original_register = request_handler._register_producer + + async def spy_register_producer(tid, task): + nonlocal captured_producer_task + captured_producer_task = task + # Wrap the cancel method to spy on it + task.cancel = MagicMock(wraps=task.cancel) + await original_register(tid, task) + + with ( + patch( + 'a2a.server.request_handlers.default_request_handler.ResultAggregator', + return_value=mock_result_aggregator_instance, + ), + patch( + 'a2a.server.request_handlers.default_request_handler.TaskManager.get_task', + return_value=None, + ), + patch.object( + request_handler, + '_register_producer', + side_effect=spy_register_producer, + ), + ): + # Act + with pytest.raises(ValueError, match='Consumer failed!'): + await request_handler.on_message_send( + params, create_server_call_context() + ) + + assert captured_producer_task is not None + # Verify producer was cancelled + captured_producer_task.cancel.assert_called() + + # Verify queue closed immediately + mock_queue.close.assert_awaited_with(immediate=True) diff --git a/tests/server/request_handlers/test_jsonrpc_handler.py b/tests/server/request_handlers/test_jsonrpc_handler.py index d1ead0211..d10d544ac 100644 --- a/tests/server/request_handlers/test_jsonrpc_handler.py +++ b/tests/server/request_handlers/test_jsonrpc_handler.py @@ -322,7 +322,6 @@ async def streaming_coro(): self.assertIsInstance(response.root, JSONRPCErrorResponse) assert response.root.error == UnsupportedOperationError() # type: ignore - mock_agent_executor.execute.assert_called_once() @patch( 'a2a.server.agent_execution.simple_request_context_builder.SimpleRequestContextBuilder.build'