Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/a2a/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,5 @@ class A2ABaseModel(BaseModel):
validate_by_alias=True,
serialize_by_alias=True,
alias_generator=to_camel_custom,
extra='forbid',
)
27 changes: 25 additions & 2 deletions src/a2a/server/request_handlers/default_request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ async def push_notification_callback() -> None:
)

except Exception:
logger.exception('Agent execution failed')
await self._handle_execution_failure(producer_task, queue)
raise
finally:
if interrupted_or_non_blocking:
Expand Down Expand Up @@ -392,6 +392,10 @@ async def on_message_send_stream(
bg_task.set_name(f'background_consume:{task_id}')
self._track_background_task(bg_task)
raise
except Exception:
# If the consumer fails (e.g. database error), we must cleanup.
await self._handle_execution_failure(producer_task, queue)
raise
finally:
cleanup_task = asyncio.create_task(
self._cleanup_producer(producer_task, task_id)
Expand Down Expand Up @@ -429,13 +433,32 @@ def _on_done(completed: asyncio.Task) -> None:

task.add_done_callback(_on_done)

async def _handle_execution_failure(
    self, producer_task: asyncio.Task, queue: EventQueue
) -> None:
    """Cancels the producer and closes the queue immediately on failure.

    Intended to be awaited from inside an ``except`` block: it relies on
    ``logger.exception`` picking up the currently handled exception, and it
    performs the two teardown steps the failure path needs — cancel the
    producer, then force-close the queue.

    Args:
        producer_task: The asyncio task running the agent executor.
        queue: The event queue shared between producer and consumer.
    """
    logger.exception('Agent execution failed')
    # If the consumer fails, we must cancel the producer to prevent it from hanging
    # on queue operations (e.g., waiting for the queue to drain).
    producer_task.cancel()
    # Force the queue to close immediately, discarding any pending events.
    # This ensures that any producers waiting on the queue are unblocked.
    await queue.close(immediate=True)

async def _cleanup_producer(
self,
producer_task: asyncio.Task,
task_id: str,
) -> None:
"""Cleans up the agent execution task and queue manager entry."""
await producer_task
try:
await producer_task
except asyncio.CancelledError:
logger.debug(
'Producer task %s was cancelled during cleanup', task_id
)
except Exception:
logger.exception('Producer task %s failed during cleanup', task_id)
await self._queue_manager.close(task_id)
async with self._running_agents_lock:
self._running_agents.pop(task_id, None)
Expand Down
39 changes: 39 additions & 0 deletions tck/sut_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,11 @@
TaskStatus,
TaskStatusUpdateEvent,
TextPart,
FilePart,
DataPart,
InvalidParamsError,
)
from a2a.utils.errors import ServerError


JSONRPC_URL = '/a2a/jsonrpc'
Expand Down Expand Up @@ -67,6 +71,41 @@ async def execute(
task_id = context.task_id
context_id = context.context_id

# Validate message parts
if not user_message.parts:
# Empty parts array is invalid
raise ServerError(
error=InvalidParamsError(message='Message must contain at least one part')
)

for part in user_message.parts:
# Unwrap RootModel if present to get the actual part
actual_part = part
if hasattr(part, 'root'):
actual_part = part.root

# Check if it's a known part type
if not isinstance(actual_part, (TextPart, FilePart, DataPart)):
# If we received something that isn't a known part, treating it as unsupported.
# Enqueue a failed status event.
await event_queue.enqueue_event(TaskStatusUpdateEvent(
task_id=task_id,
context_id=context_id,
status=TaskStatus(
state=TaskState.failed,
message=Message(
role='agent',
message_id=str(uuid.uuid4()),
parts=[TextPart(text='Unsupported message part type')],
task_id=task_id,
context_id=context_id,
),
timestamp=datetime.now(timezone.utc).isoformat(),
),
final=True,
))
return

self.running_tasks.add(task_id)

logger.info(
Expand Down
168 changes: 168 additions & 0 deletions tests/server/request_handlers/test_default_request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2644,3 +2644,171 @@ async def test_on_message_send_stream_task_id_provided_but_task_not_found():
f'Task {task_id} was specified but does not exist'
in exc_info.value.error.message
)


@pytest.mark.asyncio
async def test_on_message_send_stream_consumer_error_cancels_producer_and_closes_queue():
    """Test that if the consumer (result aggregator) raises an exception, the producer is cancelled and queue is closed immediately."""
    # All collaborators are mocked so only the handler's error-path wiring
    # (producer cancellation + immediate queue close) is exercised.
    mock_task_store = AsyncMock(spec=TaskStore)
    mock_queue_manager = AsyncMock(spec=QueueManager)
    mock_agent_executor = AsyncMock(spec=AgentExecutor)
    mock_request_context_builder = AsyncMock(spec=RequestContextBuilder)

    task_id = 'error_cleanup_task'
    context_id = 'error_cleanup_ctx'

    mock_request_context = MagicMock(spec=RequestContext)
    mock_request_context.task_id = task_id
    mock_request_context.context_id = context_id
    mock_request_context_builder.build.return_value = mock_request_context

    mock_queue = AsyncMock(spec=EventQueue)
    mock_queue_manager.create_or_tap.return_value = mock_queue

    request_handler = DefaultRequestHandler(
        agent_executor=mock_agent_executor,
        task_store=mock_task_store,
        queue_manager=mock_queue_manager,
        request_context_builder=mock_request_context_builder,
    )

    params = MessageSendParams(
        message=Message(
            role=Role.user,
            message_id='msg_error_cleanup',
            parts=[],
            # Do NOT provide task_id here to avoid "Task ... was specified but does not exist" error
        )
    )

    # Mock ResultAggregator to raise exception
    mock_result_aggregator_instance = MagicMock(spec=ResultAggregator)

    async def raise_error_gen(_consumer):
        # Raise an exception to simulate consumer failure
        raise ValueError('Consumer failed!')
        yield  # unreachable, but makes this an async generator function

    mock_result_aggregator_instance.consume_and_emit.side_effect = (
        raise_error_gen
    )

    # Capture the producer task to verify cancellation
    captured_producer_task = None
    original_register = request_handler._register_producer

    async def spy_register_producer(tid, task):
        nonlocal captured_producer_task
        captured_producer_task = task
        # Wrap the cancel method to spy on it
        task.cancel = MagicMock(wraps=task.cancel)
        await original_register(tid, task)

    with (
        patch(
            'a2a.server.request_handlers.default_request_handler.ResultAggregator',
            return_value=mock_result_aggregator_instance,
        ),
        # get_task returning None forces the "new task" path.
        patch(
            'a2a.server.request_handlers.default_request_handler.TaskManager.get_task',
            return_value=None,
        ),
        patch.object(
            request_handler,
            '_register_producer',
            side_effect=spy_register_producer,
        ),
    ):
        # Act
        with pytest.raises(ValueError, match='Consumer failed!'):
            async for _ in request_handler.on_message_send_stream(
                params, create_server_call_context()
            ):
                pass

    assert captured_producer_task is not None
    # Verify producer was cancelled
    captured_producer_task.cancel.assert_called()

    # Verify queue closed immediately
    mock_queue.close.assert_awaited_with(immediate=True)


@pytest.mark.asyncio
async def test_on_message_send_consumer_error_cancels_producer_and_closes_queue():
    """Test that if the consumer raises an exception during blocking wait, the producer is cancelled."""
    # Non-streaming counterpart of the streaming error-path test: the
    # failure is injected into consume_and_break_on_interrupt instead.
    mock_task_store = AsyncMock(spec=TaskStore)
    mock_queue_manager = AsyncMock(spec=QueueManager)
    mock_agent_executor = AsyncMock(spec=AgentExecutor)
    mock_request_context_builder = AsyncMock(spec=RequestContextBuilder)

    task_id = 'error_cleanup_blocking_task'
    context_id = 'error_cleanup_blocking_ctx'

    mock_request_context = MagicMock(spec=RequestContext)
    mock_request_context.task_id = task_id
    mock_request_context.context_id = context_id
    mock_request_context_builder.build.return_value = mock_request_context

    mock_queue = AsyncMock(spec=EventQueue)
    mock_queue_manager.create_or_tap.return_value = mock_queue

    request_handler = DefaultRequestHandler(
        agent_executor=mock_agent_executor,
        task_store=mock_task_store,
        queue_manager=mock_queue_manager,
        request_context_builder=mock_request_context_builder,
    )

    params = MessageSendParams(
        message=Message(
            role=Role.user,
            message_id='msg_error_blocking',
            parts=[],
        )
    )

    # Mock ResultAggregator to raise exception
    mock_result_aggregator_instance = MagicMock(spec=ResultAggregator)
    mock_result_aggregator_instance.consume_and_break_on_interrupt.side_effect = ValueError(
        'Consumer failed!'
    )

    # Capture the producer task to verify cancellation
    captured_producer_task = None
    original_register = request_handler._register_producer

    async def spy_register_producer(tid, task):
        nonlocal captured_producer_task
        captured_producer_task = task
        # Wrap the cancel method to spy on it
        task.cancel = MagicMock(wraps=task.cancel)
        await original_register(tid, task)

    with (
        patch(
            'a2a.server.request_handlers.default_request_handler.ResultAggregator',
            return_value=mock_result_aggregator_instance,
        ),
        # get_task returning None forces the "new task" path.
        patch(
            'a2a.server.request_handlers.default_request_handler.TaskManager.get_task',
            return_value=None,
        ),
        patch.object(
            request_handler,
            '_register_producer',
            side_effect=spy_register_producer,
        ),
    ):
        # Act
        with pytest.raises(ValueError, match='Consumer failed!'):
            await request_handler.on_message_send(
                params, create_server_call_context()
            )

    assert captured_producer_task is not None
    # Verify producer was cancelled
    captured_producer_task.cancel.assert_called()

    # Verify queue closed immediately
    mock_queue.close.assert_awaited_with(immediate=True)
1 change: 0 additions & 1 deletion tests/server/request_handlers/test_jsonrpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,6 @@ async def streaming_coro():

self.assertIsInstance(response.root, JSONRPCErrorResponse)
assert response.root.error == UnsupportedOperationError() # type: ignore
mock_agent_executor.execute.assert_called_once()

@patch(
'a2a.server.agent_execution.simple_request_context_builder.SimpleRequestContextBuilder.build'
Expand Down