From e75ae0cb0dbcb12bc4aeb1e308320efc749290bc Mon Sep 17 00:00:00 2001 From: Di-Is Date: Fri, 20 Mar 2026 11:01:55 +0900 Subject: [PATCH] fix: handle BaseException in trace spans to prevent span leaks on KeyboardInterrupt Trace spans were not properly closed when BaseException (e.g. KeyboardInterrupt, asyncio.CancelledError) was raised. Add explicit BaseException handlers to close spans and aclose() calls to ensure async generators are cleaned up. --- src/strands/agent/agent.py | 4 +- src/strands/event_loop/event_loop.py | 30 ++++- src/strands/telemetry/tracer.py | 10 +- tests/strands/agent/test_agent.py | 21 ++++ tests/strands/event_loop/test_event_loop.py | 115 ++++++++++++++++++++ tests/strands/telemetry/test_tracer.py | 36 ++++++ 6 files changed, 204 insertions(+), 12 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 3a23133de..9cdf071fd 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -841,7 +841,7 @@ async def stream_async( self._end_agent_trace_span(response=result) - except Exception as e: + except BaseException as e: self._end_agent_trace_span(error=e) raise @@ -1044,7 +1044,7 @@ def _start_agent_trace_span(self, messages: Messages) -> trace_api.Span: def _end_agent_trace_span( self, response: AgentResult | None = None, - error: Exception | None = None, + error: BaseException | None = None, ) -> None: """Ends a trace span for the agent. diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index b4af16058..835122b4e 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -138,6 +138,7 @@ async def event_loop_cycle( custom_trace_attributes=agent.trace_attributes, ) invocation_state["event_loop_cycle_span"] = cycle_span + model_events: AsyncGenerator[TypedEvent, None] | None = None with trace_api.use_span(cycle_span, end_on_exit=False): try: @@ -153,15 +154,21 @@ async def event_loop_cycle( model_events = _handle_model_execution( agent, cycle_span, cycle_trace, invocation_state, tracer, structured_output_context ) - async for model_event in model_events: - if not isinstance(model_event, ModelStopReason): - yield model_event + try: + async for model_event in model_events: + if not isinstance(model_event, ModelStopReason): + yield model_event + finally: + await model_events.aclose() stop_reason, message, *_ = model_event["stop"] yield ModelMessageEvent(message=message) except Exception as e: tracer.end_span_with_error(cycle_span, str(e), e) raise + except BaseException as e: + tracer.end_span_with_error(cycle_span, str(e), e) + raise try: if stop_reason == "max_tokens": @@ -238,6 +245,9 @@ async def event_loop_cycle( yield ForceStopEvent(reason=e) logger.exception("cycle failed") raise EventLoopException(e, invocation_state["request_state"]) from e + except BaseException as e: + tracer.end_span_with_error(cycle_span, str(e), e) + raise async def recurse_event_loop( @@ -323,6 +333,7 @@ async def _handle_model_execution( system_prompt=agent.system_prompt, system_prompt_content=agent._system_prompt_content, ) + streamed_events: AsyncGenerator[TypedEvent, None] | None = None with trace_api.use_span(model_invoke_span, end_on_exit=False): try: await agent.hooks.invoke_callbacks_async( @@ -338,7 +349,7 @@ async def _handle_model_execution( else: tool_specs = agent.tool_registry.get_all_tool_specs() - async for event in stream_messages( + streamed_events = stream_messages( agent.model, agent.system_prompt, agent.messages, @@ -348,8 +359,12 @@ async def _handle_model_execution( invocation_state=invocation_state, model_state=agent._model_state, cancel_signal=agent._cancel_signal, - ): - yield event + ) + try: + async for event in streamed_events: + yield event + finally: + await streamed_events.aclose() stop_reason, message, usage, metrics = event["stop"] invocation_state.setdefault("request_state", {}) @@ -410,6 +425,9 @@ async def _handle_model_execution( # No retry requested, raise the exception yield ForceStopEvent(reason=e) raise e + except BaseException as e: + tracer.end_span_with_error(model_invoke_span, str(e), e) + raise try: # Add message in trace and mark the end of the stream messages trace diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py index 19a163f5c..1ff968558 100644 --- a/src/strands/telemetry/tracer.py +++ b/src/strands/telemetry/tracer.py @@ -184,7 +184,7 @@ def _end_span( self, span: Span, attributes: dict[str, AttributeValue] | None = None, - error: Exception | None = None, + error: BaseException | None = None, error_message: str | None = None, ) -> None: """Generic helper method to end a span. @@ -224,7 +224,7 @@ def _end_span( except Exception as e: logger.warning("error=<%s> | failed to force flush tracer provider", e) - def end_span_with_error(self, span: Span, error_message: str, exception: Exception | None = None) -> None: + def end_span_with_error(self, span: Span, error_message: str, exception: BaseException | None = None) -> None: """End a span with error status. Args: @@ -450,7 +450,9 @@ def start_tool_call_span( return span - def end_tool_call_span(self, span: Span, tool_result: ToolResult | None, error: Exception | None = None) -> None: + def end_tool_call_span( + self, span: Span, tool_result: ToolResult | None, error: BaseException | None = None + ) -> None: """End a tool call span with results. Args: @@ -650,7 +652,7 @@ def end_agent_span( self, span: Span, response: AgentResult | None = None, - error: Exception | None = None, + error: BaseException | None = None, ) -> None: """End an agent span with results and metrics. diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 5a3cce11c..337b269af 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -1423,6 +1423,27 @@ async def test_agent_stream_async_creates_and_ends_span_on_exception(mock_get_tr mock_tracer.end_agent_span.assert_called_once_with(span=mock_span, error=test_exception) +@pytest.mark.asyncio +@unittest.mock.patch("strands.agent.agent.get_tracer") +async def test_agent_stream_async_creates_and_ends_span_on_base_exception(mock_get_tracer, mock_model, alist): + """Test that stream_async ends the agent span when a BaseException occurs.""" + mock_tracer = unittest.mock.MagicMock() + mock_span = unittest.mock.MagicMock() + mock_tracer.start_agent_span.return_value = mock_span + mock_get_tracer.return_value = mock_tracer + + test_exception = KeyboardInterrupt("stop now") + mock_model.mock_stream.side_effect = test_exception + + agent = Agent(model=mock_model) + + with pytest.raises(KeyboardInterrupt, match="stop now"): + stream = agent.stream_async("test prompt") + await alist(stream) + + mock_tracer.end_agent_span.assert_called_once_with(span=mock_span, error=test_exception) + + def test_agent_init_with_state_object(): agent = Agent(state=AgentState({"foo": "bar"})) assert agent.state.get("foo") == "bar" diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index f91f7c2af..5903651f4 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -685,6 +685,121 @@ async def test_event_loop_tracing_with_tool_execution( assert mock_tracer.end_model_invoke_span.call_count == 2 +@patch("strands.event_loop.event_loop.get_tracer") +@pytest.mark.asyncio +async def test_event_loop_cycle_closes_spans_on_stream_aclose( + mock_get_tracer, + agent, + model, + mock_tracer, +): + mock_get_tracer.return_value = mock_tracer + cycle_span = MagicMock() + mock_tracer.start_event_loop_cycle_span.return_value = cycle_span + model_span = MagicMock() + mock_tracer.start_model_invoke_span.return_value = model_span + + async def interrupted_stream(): + yield {"contentBlockDelta": {"delta": {"text": "test text"}}} + await asyncio.sleep(10) + yield {"contentBlockStop": {}} + + model.stream.return_value = interrupted_stream() + + stream = strands.event_loop.event_loop.event_loop_cycle( + agent=agent, + invocation_state={}, + ) + await anext(stream) + await anext(stream) + await anext(stream) + await stream.aclose() + + assert [call.args[0] for call in mock_tracer.end_span_with_error.call_args_list] == [model_span, cycle_span] + assert [call.args[1] for call in mock_tracer.end_span_with_error.call_args_list] == [ + "", + "", + ] + + +@patch("strands.event_loop.event_loop.get_tracer") +@pytest.mark.asyncio +async def test_event_loop_cycle_closes_spans_on_task_cancellation( + mock_get_tracer, + agent, + model, + mock_tracer, +): + mock_get_tracer.return_value = mock_tracer + cycle_span = MagicMock() + mock_tracer.start_event_loop_cycle_span.return_value = cycle_span + model_span = MagicMock() + mock_tracer.start_model_invoke_span.return_value = model_span + + blocked_on_stream = asyncio.Event() + release_stream = asyncio.Event() + + async def interrupted_stream(): + yield {"contentBlockDelta": {"delta": {"text": "test text"}}} + blocked_on_stream.set() + await release_stream.wait() + yield {"contentBlockStop": {}} + + model.stream.return_value = interrupted_stream() + + async def consume() -> None: + stream = strands.event_loop.event_loop.event_loop_cycle( + agent=agent, + invocation_state={}, + ) + async for _ in stream: + pass + + task = asyncio.create_task(consume()) + await blocked_on_stream.wait() + task.cancel() + + with pytest.raises(asyncio.CancelledError): + await task + + assert [call.args[0] for call in mock_tracer.end_span_with_error.call_args_list] == [model_span, cycle_span] + assert [call.args[1] for call in mock_tracer.end_span_with_error.call_args_list] == [ + "", + "", + ] + + +@patch("strands.event_loop.event_loop.get_tracer") +@pytest.mark.asyncio +async def test_event_loop_cycle_closes_spans_on_keyboard_interrupt( + mock_get_tracer, + agent, + model, + mock_tracer, + alist, +): + mock_get_tracer.return_value = mock_tracer + cycle_span = MagicMock() + mock_tracer.start_event_loop_cycle_span.return_value = cycle_span + model_span = MagicMock() + mock_tracer.start_model_invoke_span.return_value = model_span + + test_exception = KeyboardInterrupt("stop now") + model.stream.side_effect = test_exception + + with pytest.raises(KeyboardInterrupt, match="stop now"): + stream = strands.event_loop.event_loop.event_loop_cycle( + agent=agent, + invocation_state={}, + ) + await alist(stream) + + assert mock_tracer.end_span_with_error.call_args_list == [ + call(model_span, "stop now", test_exception), + call(cycle_span, "stop now", test_exception), + ] + + @pytest.mark.asyncio async def test_event_loop_cycle_closes_cycle_span_before_recursive_cycle( agent, diff --git a/tests/strands/telemetry/test_tracer.py b/tests/strands/telemetry/test_tracer.py index bcd42b610..f1f26b835 100644 --- a/tests/strands/telemetry/test_tracer.py +++ b/tests/strands/telemetry/test_tracer.py @@ -140,6 +140,18 @@ def test_end_span_with_empty_exception_message_uses_exception_name(mock_span): mock_span.end.assert_called_once() +def test_end_span_with_empty_base_exception_message_uses_exception_name(mock_span): + """Test that empty BaseException messages fall back to the exception type name.""" + tracer = Tracer() + error = KeyboardInterrupt() + + tracer.end_span_with_error(mock_span, "", error) + + mock_span.set_status.assert_called_once_with(StatusCode.ERROR, "KeyboardInterrupt") + mock_span.record_exception.assert_called_once_with(error) + mock_span.end.assert_called_once() + + def test_end_span_with_error_prefers_explicit_message(mock_span): """Test that an explicit error message takes precedence over the exception text.""" tracer = Tracer() @@ -1162,6 +1174,30 @@ def test_force_flush_with_error(mock_span, mock_get_tracer_provider): mock_tracer_provider.force_flush.assert_called_once() +def test_end_agent_span_with_empty_error_message_uses_exception_name(mock_span): + """Test that agent spans fall back to the exception type name for empty errors.""" + tracer = Tracer() + error = Exception() + + tracer.end_agent_span(mock_span, error=error) + + mock_span.set_status.assert_called_once_with(StatusCode.ERROR, "Exception") + mock_span.record_exception.assert_called_once_with(error) + mock_span.end.assert_called_once() + + +def test_end_tool_call_span_with_empty_error_message_uses_exception_name(mock_span): + """Test that tool call spans fall back to the exception type name for empty errors.""" + tracer = Tracer() + error = Exception() + + tracer.end_tool_call_span(mock_span, None, error=error) + + mock_span.set_status.assert_called_once_with(StatusCode.ERROR, "Exception") + mock_span.record_exception.assert_called_once_with(error) + mock_span.end.assert_called_once() + + def test_end_tool_call_span_with_none(mock_span): """Test ending a tool call span with None result.""" tracer = Tracer()