diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py index b5f6fcf91..5bac0a05d 100644 --- a/src/strands/models/anthropic.py +++ b/src/strands/models/anthropic.py @@ -409,8 +409,12 @@ async def stream( if event.type in AnthropicModel.EVENT_TYPES: yield self.format_chunk(event.model_dump()) - usage = event.message.usage # type: ignore - yield self.format_chunk({"type": "metadata", "usage": usage.model_dump()}) + try: + message_snapshot = await stream.get_final_message() + except AssertionError as e: + logger.warning("error=<%s> | failed to retrieve message snapshot, usage metadata unavailable", e) + else: + yield self.format_chunk({"type": "metadata", "usage": message_snapshot.usage.model_dump()}) except anthropic.RateLimitError as error: raise ModelThrottledException(str(error)) from error diff --git a/tests/strands/models/test_anthropic.py b/tests/strands/models/test_anthropic.py index c5aff8062..ec11b16eb 100644 --- a/tests/strands/models/test_anthropic.py +++ b/tests/strands/models/test_anthropic.py @@ -52,6 +52,24 @@ class TestOutputModel(pydantic.BaseModel): return TestOutputModel +def generate_mock_stream_context(events, final_message=None): + mock_stream = unittest.mock.AsyncMock() + + async def mock_aiter(self): + for event in events: + yield event + + mock_stream.__aiter__ = mock_aiter + if isinstance(final_message, Exception): + mock_stream.get_final_message.side_effect = final_message + elif final_message: + mock_stream.get_final_message.return_value = final_message + + mock_context = unittest.mock.AsyncMock() + mock_context.__aenter__.return_value = mock_stream + return mock_context + + def test__init__model_configs(anthropic_client, model_id, max_tokens): _ = anthropic_client @@ -692,7 +710,7 @@ def test_format_chunk_unknown(model): @pytest.mark.asyncio -async def test_stream(anthropic_client, model, agenerator, alist): +async def test_stream(anthropic_client, model, alist): mock_event_1 = unittest.mock.Mock( type="message_start", dict=lambda: {"type": "message_start"}, @@ -713,9 +731,14 @@ async def test_stream(anthropic_client, model, agenerator, alist): ), ) - mock_context = unittest.mock.AsyncMock() - mock_context.__aenter__.return_value = agenerator([mock_event_1, mock_event_2, mock_event_3]) - anthropic_client.messages.stream.return_value = mock_context + anthropic_client.messages.stream.return_value = generate_mock_stream_context( + [mock_event_1, mock_event_2, mock_event_3], + final_message=unittest.mock.Mock( + usage=unittest.mock.Mock( + model_dump=lambda: {"input_tokens": 1, "output_tokens": 2}, + ) + ), + ) messages = [{"role": "user", "content": [{"text": "hello"}]}] response = model.stream(messages, None, None) @@ -738,6 +761,42 @@ async def test_stream(anthropic_client, model, agenerator, alist): anthropic_client.messages.stream.assert_called_once_with(**expected_request) +@pytest.mark.asyncio +async def test_stream_early_termination(anthropic_client, model, alist, caplog): + caplog.set_level(logging.WARNING, logger="strands.models.anthropic") + mock_event = unittest.mock.Mock( + type="message_start", + model_dump=lambda: {"type": "message_start"}, + ) + + anthropic_client.messages.stream.return_value = generate_mock_stream_context( + [mock_event], + final_message=AssertionError("message snapshot is not available"), + ) + + messages = [{"role": "user", "content": [{"text": "hello"}]}] + tru_events = await alist(model.stream(messages, None, None)) + + assert len(tru_events) == 1 + assert "messageStart" in tru_events[0] + assert "failed to retrieve message snapshot, usage metadata unavailable" in caplog.text + + +@pytest.mark.asyncio +async def test_stream_empty(anthropic_client, model, alist, caplog): + caplog.set_level(logging.WARNING, logger="strands.models.anthropic") + anthropic_client.messages.stream.return_value = generate_mock_stream_context( + [], + final_message=AssertionError("message snapshot is not available"), + ) + + messages = [{"role": "user", "content": [{"text": "hello"}]}] + tru_events = await alist(model.stream(messages, None, None)) + + assert tru_events == [] + assert "failed to retrieve message snapshot, usage metadata unavailable" in caplog.text + + @pytest.mark.asyncio async def test_stream_rate_limit_error(anthropic_client, model, alist): anthropic_client.messages.stream.side_effect = anthropic.RateLimitError( @@ -780,7 +839,7 @@ async def test_stream_bad_request_error(anthropic_client, model): @pytest.mark.asyncio -async def test_structured_output(anthropic_client, model, test_output_model_cls, agenerator, alist): +async def test_structured_output(anthropic_client, model, test_output_model_cls, alist): messages = [{"role": "user", "content": [{"text": "Generate a person"}]}] events = [ @@ -815,18 +874,16 @@ async def test_structured_output(anthropic_client, model, test_output_model_cls, return_value={"type": "message_stop", "message": {"stop_reason": "tool_use"}} ), ), - unittest.mock.Mock( - message=unittest.mock.Mock( - usage=unittest.mock.Mock( - model_dump=unittest.mock.Mock(return_value={"input_tokens": 0, "output_tokens": 0}) - ), - ), - ), ] - mock_context = unittest.mock.AsyncMock() - mock_context.__aenter__.return_value = agenerator(events) - anthropic_client.messages.stream.return_value = mock_context + anthropic_client.messages.stream.return_value = generate_mock_stream_context( + events, + final_message=unittest.mock.Mock( + usage=unittest.mock.Mock( + model_dump=unittest.mock.Mock(return_value={"input_tokens": 0, "output_tokens": 0}) + ), + ), + ) stream = model.structured_output(test_output_model_cls, messages) events = await alist(stream)