Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/strands/models/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,8 +409,12 @@ async def stream(
if event.type in AnthropicModel.EVENT_TYPES:
yield self.format_chunk(event.model_dump())

usage = event.message.usage # type: ignore
yield self.format_chunk({"type": "metadata", "usage": usage.model_dump()})
try:
message_snapshot = await stream.get_final_message()
except AssertionError as e:
logger.warning("error=<%s> | failed to retrieve message snapshot, usage metadata unavailable", e)
else:
yield self.format_chunk({"type": "metadata", "usage": message_snapshot.usage.model_dump()})

except anthropic.RateLimitError as error:
raise ModelThrottledException(str(error)) from error
Expand Down
87 changes: 72 additions & 15 deletions tests/strands/models/test_anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,24 @@ class TestOutputModel(pydantic.BaseModel):
return TestOutputModel


def generate_mock_stream_context(events, final_message=None):
    """Build a mock async context manager that stands in for a message stream.

    Args:
        events: Iterable of event objects the mocked stream yields, in order, when it
            is iterated with ``async for``.
        final_message: Controls ``get_final_message`` on the mocked stream. An exception
            instance is raised when the method is awaited; any other non-``None`` value
            (including falsy ones such as ``{}``) is returned from it; ``None`` leaves
            the ``AsyncMock`` default behavior in place.

    Returns:
        An ``AsyncMock`` whose ``__aenter__`` resolves to the mocked stream.
    """
    mock_stream = unittest.mock.AsyncMock()

    async def iterate_events(self):
        for event in events:
            yield event

    # A plain function assigned to a magic method on a mock is called with the mock as ``self``.
    mock_stream.__aiter__ = iterate_events

    if isinstance(final_message, BaseException):
        # BaseException (not just Exception) so e.g. asyncio.CancelledError is raised too.
        mock_stream.get_final_message.side_effect = final_message
    elif final_message is not None:
        # Explicit None check: a falsy final message is still a configured value.
        mock_stream.get_final_message.return_value = final_message

    mock_context = unittest.mock.AsyncMock()
    mock_context.__aenter__.return_value = mock_stream
    return mock_context


def test__init__model_configs(anthropic_client, model_id, max_tokens):
_ = anthropic_client

Expand Down Expand Up @@ -692,7 +710,7 @@ def test_format_chunk_unknown(model):


@pytest.mark.asyncio
async def test_stream(anthropic_client, model, agenerator, alist):
async def test_stream(anthropic_client, model, alist):
mock_event_1 = unittest.mock.Mock(
type="message_start",
dict=lambda: {"type": "message_start"},
Expand All @@ -713,9 +731,14 @@ async def test_stream(anthropic_client, model, agenerator, alist):
),
)

mock_context = unittest.mock.AsyncMock()
mock_context.__aenter__.return_value = agenerator([mock_event_1, mock_event_2, mock_event_3])
anthropic_client.messages.stream.return_value = mock_context
anthropic_client.messages.stream.return_value = generate_mock_stream_context(
[mock_event_1, mock_event_2, mock_event_3],
final_message=unittest.mock.Mock(
usage=unittest.mock.Mock(
model_dump=lambda: {"input_tokens": 1, "output_tokens": 2},
)
),
)

messages = [{"role": "user", "content": [{"text": "hello"}]}]
response = model.stream(messages, None, None)
Expand All @@ -738,6 +761,42 @@ async def test_stream(anthropic_client, model, agenerator, alist):
anthropic_client.messages.stream.assert_called_once_with(**expected_request)


@pytest.mark.asyncio
async def test_stream_early_termination(anthropic_client, model, alist, caplog):
    """Check streaming when the final message snapshot cannot be retrieved.

    The mocked stream yields a single ``message_start`` event and raises an
    ``AssertionError`` from ``get_final_message``. The test expects exactly one
    formatted event plus a warning in the captured log.
    """
    caplog.set_level(logging.WARNING, logger="strands.models.anthropic")

    snapshot_error = AssertionError("message snapshot is not available")
    start_event = unittest.mock.Mock(
        type="message_start",
        model_dump=lambda: {"type": "message_start"},
    )
    anthropic_client.messages.stream.return_value = generate_mock_stream_context(
        [start_event],
        final_message=snapshot_error,
    )

    request_messages = [{"role": "user", "content": [{"text": "hello"}]}]
    streamed = await alist(model.stream(request_messages, None, None))

    assert len(streamed) == 1
    assert "messageStart" in streamed[0]
    assert "failed to retrieve message snapshot, usage metadata unavailable" in caplog.text


@pytest.mark.asyncio
async def test_stream_empty(anthropic_client, model, alist, caplog):
    """Check streaming when the mocked stream yields no events at all.

    ``get_final_message`` raises an ``AssertionError``; the test expects an empty
    event list plus a warning in the captured log.
    """
    caplog.set_level(logging.WARNING, logger="strands.models.anthropic")

    snapshot_error = AssertionError("message snapshot is not available")
    anthropic_client.messages.stream.return_value = generate_mock_stream_context(
        [],
        final_message=snapshot_error,
    )

    request_messages = [{"role": "user", "content": [{"text": "hello"}]}]
    streamed = await alist(model.stream(request_messages, None, None))

    assert streamed == []
    assert "failed to retrieve message snapshot, usage metadata unavailable" in caplog.text


@pytest.mark.asyncio
async def test_stream_rate_limit_error(anthropic_client, model, alist):
anthropic_client.messages.stream.side_effect = anthropic.RateLimitError(
Expand Down Expand Up @@ -780,7 +839,7 @@ async def test_stream_bad_request_error(anthropic_client, model):


@pytest.mark.asyncio
async def test_structured_output(anthropic_client, model, test_output_model_cls, agenerator, alist):
async def test_structured_output(anthropic_client, model, test_output_model_cls, alist):
messages = [{"role": "user", "content": [{"text": "Generate a person"}]}]

events = [
Expand Down Expand Up @@ -815,18 +874,16 @@ async def test_structured_output(anthropic_client, model, test_output_model_cls,
return_value={"type": "message_stop", "message": {"stop_reason": "tool_use"}}
),
),
unittest.mock.Mock(
message=unittest.mock.Mock(
usage=unittest.mock.Mock(
model_dump=unittest.mock.Mock(return_value={"input_tokens": 0, "output_tokens": 0})
),
),
),
]

mock_context = unittest.mock.AsyncMock()
mock_context.__aenter__.return_value = agenerator(events)
anthropic_client.messages.stream.return_value = mock_context
anthropic_client.messages.stream.return_value = generate_mock_stream_context(
events,
final_message=unittest.mock.Mock(
usage=unittest.mock.Mock(
model_dump=unittest.mock.Mock(return_value={"input_tokens": 0, "output_tokens": 0})
),
),
)

stream = model.structured_output(test_output_model_cls, messages)
events = await alist(stream)
Expand Down
Loading