Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 70 additions & 62 deletions src/strands/event_loop/event_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,25 +139,29 @@ async def event_loop_cycle(
)
invocation_state["event_loop_cycle_span"] = cycle_span

with trace_api.use_span(cycle_span, end_on_exit=True):
# Skipping model invocation if in interrupt state as interrupts are currently only supported for tool calls.
if agent._interrupt_state.activated:
stop_reason: StopReason = "tool_use"
message = agent._interrupt_state.context["tool_use_message"]
# Skip model invocation if the latest message contains ToolUse
elif _has_tool_use_in_latest_message(agent.messages):
stop_reason = "tool_use"
message = agent.messages[-1]
else:
model_events = _handle_model_execution(
agent, cycle_span, cycle_trace, invocation_state, tracer, structured_output_context
)
async for model_event in model_events:
if not isinstance(model_event, ModelStopReason):
yield model_event
with trace_api.use_span(cycle_span, end_on_exit=False):
try:
# Skipping model invocation if in interrupt state as interrupts are currently only supported for tool calls.
if agent._interrupt_state.activated:
stop_reason: StopReason = "tool_use"
message = agent._interrupt_state.context["tool_use_message"]
# Skip model invocation if the latest message contains ToolUse
elif _has_tool_use_in_latest_message(agent.messages):
stop_reason = "tool_use"
message = agent.messages[-1]
else:
model_events = _handle_model_execution(
agent, cycle_span, cycle_trace, invocation_state, tracer, structured_output_context
)
async for model_event in model_events:
if not isinstance(model_event, ModelStopReason):
yield model_event

stop_reason, message, *_ = model_event["stop"]
yield ModelMessageEvent(message=message)
stop_reason, message, *_ = model_event["stop"]
yield ModelMessageEvent(message=message)
except Exception as e:
tracer.end_span_with_error(cycle_span, str(e), e)
raise

try:
if stop_reason == "max_tokens":
Expand Down Expand Up @@ -196,42 +200,45 @@ async def event_loop_cycle(

# End the cycle and return results
agent.event_loop_metrics.end_cycle(cycle_start_time, cycle_trace, attributes)
# Set attributes before span auto-closes

# Force structured output tool call if LLM didn't use it automatically
if structured_output_context.is_enabled and stop_reason == "end_turn":
if structured_output_context.force_attempted:
raise StructuredOutputException(
"The model failed to invoke the structured output tool even after it was forced."
)
structured_output_context.set_forced_mode()
logger.debug("Forcing structured output tool")
await agent._append_messages(
{"role": "user", "content": [{"text": structured_output_context.structured_output_prompt}]}
)

tracer.end_event_loop_cycle_span(cycle_span, message)
events = recurse_event_loop(
agent=agent, invocation_state=invocation_state, structured_output_context=structured_output_context
)
async for typed_event in events:
yield typed_event
return

tracer.end_event_loop_cycle_span(cycle_span, message)
except EventLoopException:
# Don't yield or log the exception - we already did it when we
# raised the exception and we don't need that duplication.
yield EventLoopStopEvent(stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"])
except (
StructuredOutputException,
EventLoopException,
ContextWindowOverflowException,
MaxTokensReachedException,
) as e:
# These exceptions should bubble up directly rather than get wrapped in an EventLoopException
tracer.end_span_with_error(cycle_span, str(e), e)
raise
except (ContextWindowOverflowException, MaxTokensReachedException) as e:
# Special cased exceptions which we want to bubble up rather than get wrapped in an EventLoopException
raise e
except Exception as e:
tracer.end_span_with_error(cycle_span, str(e), e)
# Handle any other exceptions
yield ForceStopEvent(reason=e)
logger.exception("cycle failed")
raise EventLoopException(e, invocation_state["request_state"]) from e

# Force structured output tool call if LLM didn't use it automatically
if structured_output_context.is_enabled and stop_reason == "end_turn":
if structured_output_context.force_attempted:
raise StructuredOutputException(
"The model failed to invoke the structured output tool even after it was forced."
)
structured_output_context.set_forced_mode()
logger.debug("Forcing structured output tool")
await agent._append_messages(
{"role": "user", "content": [{"text": structured_output_context.structured_output_prompt}]}
)

events = recurse_event_loop(
agent=agent, invocation_state=invocation_state, structured_output_context=structured_output_context
)
async for typed_event in events:
yield typed_event
return

yield EventLoopStopEvent(stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"])


async def recurse_event_loop(
agent: "Agent",
Expand Down Expand Up @@ -316,20 +323,21 @@ async def _handle_model_execution(
system_prompt=agent.system_prompt,
system_prompt_content=agent._system_prompt_content,
)
with trace_api.use_span(model_invoke_span, end_on_exit=True):
await agent.hooks.invoke_callbacks_async(
BeforeModelCallEvent(
agent=agent,
invocation_state=invocation_state,
with trace_api.use_span(model_invoke_span, end_on_exit=False):
try:
await agent.hooks.invoke_callbacks_async(
BeforeModelCallEvent(
agent=agent,
invocation_state=invocation_state,
)
)
)

if structured_output_context.forced_mode:
tool_spec = structured_output_context.get_tool_spec()
tool_specs = [tool_spec] if tool_spec else []
else:
tool_specs = agent.tool_registry.get_all_tool_specs()
try:
if structured_output_context.forced_mode:
tool_spec = structured_output_context.get_tool_spec()
tool_specs = [tool_spec] if tool_spec else []
else:
tool_specs = agent.tool_registry.get_all_tool_specs()

async for event in stream_messages(
agent.model,
agent.system_prompt,
Expand Down Expand Up @@ -363,17 +371,17 @@ async def _handle_model_execution(
"stop_reason=<%s>, retry_requested=<True> | hook requested model retry",
stop_reason,
)
tracer.end_model_invoke_span(model_invoke_span, message, usage, metrics, stop_reason)
continue # Retry the model call

if stop_reason == "max_tokens":
message = recover_message_on_max_tokens_reached(message)

# Set attributes before span auto-closes
tracer.end_model_invoke_span(model_invoke_span, message, usage, metrics, stop_reason)
break # Success! Break out of retry loop

except Exception as e:
# Exception is automatically recorded by use_span with end_on_exit=True
tracer.end_span_with_error(model_invoke_span, str(e), e)
after_model_call_event = AfterModelCallEvent(
agent=agent,
invocation_state=invocation_state,
Expand Down Expand Up @@ -541,7 +549,7 @@ async def _handle_tool_execution(
interrupts,
structured_output=structured_output_result,
)
# Set attributes before span auto-closes (span is managed by use_span in event_loop_cycle)
# End the cycle span before yielding the recursive cycle.
if cycle_span:
tracer.end_event_loop_cycle_span(span=cycle_span, message=message)

Expand All @@ -559,7 +567,7 @@ async def _handle_tool_execution(

yield ToolResultMessageEvent(message=tool_result_message)

# Set attributes before span auto-closes (span is managed by use_span in event_loop_cycle)
# End the cycle span before yielding the recursive cycle.
if cycle_span:
tracer.end_event_loop_cycle_span(span=cycle_span, message=message, tool_result_message=tool_result_message)

Expand Down
34 changes: 15 additions & 19 deletions src/strands/telemetry/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,15 +185,17 @@ def _end_span(
span: Span,
attributes: dict[str, AttributeValue] | None = None,
error: Exception | None = None,
error_message: str | None = None,
) -> None:
"""Generic helper method to end a span.

Args:
span: The span to end
attributes: Optional attributes to set before ending the span
error: Optional exception if an error occurred
error_message: Optional error message to set in the span status
"""
if not span:
if not span or not span.is_recording():
return

try:
Expand All @@ -206,7 +208,8 @@ def _end_span(

# Handle error if present
if error:
span.set_status(StatusCode.ERROR, str(error))
status_description = error_message or str(error) or type(error).__name__
span.set_status(StatusCode.ERROR, status_description)
span.record_exception(error)
else:
span.set_status(StatusCode.OK)
Expand All @@ -229,11 +232,11 @@ def end_span_with_error(self, span: Span, error_message: str, exception: Excepti
error_message: Error message to set in the span status.
exception: Optional exception to record in the span.
"""
if not span:
if not span or not span.is_recording():
return

error = exception or Exception(error_message)
self._end_span(span, error=error)
self._end_span(span, error=error, error_message=error_message)

def _add_event(
self, span: Span | None, event_name: str, event_attributes: Attributes, to_span_attributes: bool = False
Expand Down Expand Up @@ -330,18 +333,15 @@ def end_model_invoke_span(
) -> None:
"""End a model invocation span with results and metrics.

Note: The span is automatically closed and exceptions recorded. This method just sets the necessary attributes.
Status in the span is automatically set to UNSET (OK) on success or ERROR on exception.

Args:
span: The span to set attributes on.
span: The span to end.
message: The message response from the model.
usage: Token usage information from the model call.
metrics: Metrics from the model call.
stop_reason: The reason the model stopped generating.
"""
# Set end time attribute
span.set_attribute("gen_ai.event.end_time", datetime.now(timezone.utc).isoformat())
if not span or not span.is_recording():
return

attributes: dict[str, AttributeValue] = {
"gen_ai.usage.prompt_tokens": usage["inputTokens"],
Expand Down Expand Up @@ -378,7 +378,7 @@ def end_model_invoke_span(
event_attributes={"finish_reason": str(stop_reason), "message": serialize(message["content"])},
)

span.set_attributes(attributes)
self._end_span(span, attributes)

def start_tool_call_span(
self,
Expand Down Expand Up @@ -553,20 +553,14 @@ def end_event_loop_cycle_span(
) -> None:
"""End an event loop cycle span with results.

Note: The span is automatically closed and exceptions recorded. This method just sets the necessary attributes.
Status in the span is automatically set to UNSET (OK) on success or ERROR on exception.

Args:
span: The span to set attributes on.
span: The span to end.
message: The message response from this cycle.
tool_result_message: Optional tool result message if a tool was called.
"""
if not span:
if not span or not span.is_recording():
return

# Set end time attribute
span.set_attribute("gen_ai.event.end_time", datetime.now(timezone.utc).isoformat())

event_attributes: dict[str, AttributeValue] = {"message": serialize(message["content"])}

if tool_result_message:
Expand All @@ -591,6 +585,8 @@ def end_event_loop_cycle_span(
else:
self._add_event(span, "gen_ai.choice", event_attributes=event_attributes)

self._end_span(span)

def start_agent_span(
self,
messages: Messages,
Expand Down
60 changes: 60 additions & 0 deletions tests/strands/event_loop/test_event_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
from unittest.mock import ANY, AsyncMock, MagicMock, call, patch

import pytest
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter

import strands
import strands.telemetry
Expand All @@ -19,6 +22,7 @@
)
from strands.interrupt import Interrupt, _InterruptState
from strands.telemetry.metrics import EventLoopMetrics
from strands.telemetry.tracer import Tracer
from strands.tools.executors import SequentialToolExecutor
from strands.tools.registry import ToolRegistry
from strands.types._events import EventLoopStopEvent
Expand Down Expand Up @@ -583,6 +587,14 @@ async def test_event_loop_tracing_with_model_error(
)
await alist(stream)

assert mock_tracer.end_span_with_error.call_count == 2
mock_tracer.end_span_with_error.assert_has_calls(
[
call(model_span, "Input too long", model.stream.side_effect),
call(cycle_span, "Input too long", model.stream.side_effect),
]
)


@pytest.mark.asyncio
async def test_event_loop_cycle_max_tokens_exception(
Expand Down Expand Up @@ -673,6 +685,53 @@ async def test_event_loop_tracing_with_tool_execution(
assert mock_tracer.end_model_invoke_span.call_count == 2


@pytest.mark.asyncio
async def test_event_loop_cycle_closes_cycle_span_before_recursive_cycle(
    agent,
    model,
    tool_stream,
    agenerator,
    alist,
):
    """Verify each event-loop cycle span is ended before the recursive cycle's span starts."""
    # Export spans in memory so their start/end timestamps can be inspected after the run.
    span_exporter = InMemorySpanExporter()
    tracer_provider = TracerProvider()
    tracer_provider.add_span_processor(SimpleSpanProcessor(span_exporter))

    real_tracer = Tracer()
    real_tracer.tracer_provider = tracer_provider
    real_tracer.tracer = tracer_provider.get_tracer(real_tracer.service_name)

    async def slow_text_stream():
        # The small delay makes the second cycle's span end measurably later,
        # so the ordering assertions below are meaningful.
        yield {"contentBlockDelta": {"delta": {"text": "test text"}}}
        await asyncio.sleep(0.05)
        yield {"contentBlockStop": {}}

    agent.trace_span = None
    agent._system_prompt_content = None
    model.config = {"model_id": "test-model"}
    # First model call triggers a tool use (causing a recursive cycle); the
    # second returns plain text so the loop terminates.
    model.stream.side_effect = [
        agenerator(tool_stream),
        slow_text_stream(),
    ]

    with patch("strands.event_loop.event_loop.get_tracer", return_value=real_tracer):
        await alist(
            strands.event_loop.event_loop.event_loop_cycle(
                agent=agent,
                invocation_state={},
            )
        )

    tracer_provider.force_flush()
    finished_cycle_spans = sorted(
        (span for span in span_exporter.get_finished_spans() if span.name == "execute_event_loop_cycle"),
        key=lambda span: span.start_time,
    )

    # Two cycles ran, and the first cycle's span closed before the second began.
    assert len(finished_cycle_spans) == 2
    assert finished_cycle_spans[0].end_time <= finished_cycle_spans[1].start_time
    assert finished_cycle_spans[0].end_time < finished_cycle_spans[1].end_time


@patch("strands.event_loop.event_loop.get_tracer")
@pytest.mark.asyncio
async def test_event_loop_tracing_with_throttling_exception(
Expand Down Expand Up @@ -709,6 +768,7 @@ async def test_event_loop_tracing_with_throttling_exception(
)
await alist(stream)

assert mock_tracer.end_span_with_error.call_count == 1
# Verify span was created for the successful retry
assert mock_tracer.start_model_invoke_span.call_count == 2
assert mock_tracer.end_model_invoke_span.call_count == 1
Expand Down
Loading
Loading