Skip to content

Commit 6aadcc3

Browse files
committed
feat: use rich sys prompt
1 parent 49c455e commit 6aadcc3

2 files changed

Lines changed: 65 additions & 11 deletions

File tree

src/strands/models/model.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -49,16 +49,18 @@ def _count_content_block_tokens(block: ContentBlock, encoding: Any) -> int:
4949
total += len(encoding.encode(block["text"]))
5050

5151
if "toolUse" in block:
52+
tool_use = block["toolUse"]
53+
total += len(encoding.encode(tool_use.get("name", "")))
5254
try:
53-
total += len(encoding.encode(json.dumps(block["toolUse"])))
55+
total += len(encoding.encode(json.dumps(tool_use.get("input", {}))))
5456
except (TypeError, ValueError):
5557
pass
5658

5759
if "toolResult" in block:
58-
try:
59-
total += len(encoding.encode(json.dumps(block["toolResult"])))
60-
except (TypeError, ValueError):
61-
pass
60+
tool_result = block["toolResult"]
61+
for item in tool_result.get("content", []):
62+
if "text" in item:
63+
total += len(encoding.encode(item["text"]))
6264

6365
if "reasoningContent" in block:
6466
reasoning = block["reasoningContent"]
@@ -74,9 +76,10 @@ def _count_content_block_tokens(block: ContentBlock, encoding: Any) -> int:
7476

7577
if "citationsContent" in block:
7678
citations = block["citationsContent"]
77-
for item in citations.get("content", []):
78-
if "text" in item:
79-
total += len(encoding.encode(item["text"]))
79+
if "content" in citations:
80+
for citation_item in citations["content"]:
81+
if "text" in citation_item:
82+
total += len(encoding.encode(citation_item["text"]))
8083

8184
return total
8285

@@ -85,6 +88,7 @@ def _estimate_tokens_with_tiktoken(
8588
messages: Messages,
8689
tool_specs: list[ToolSpec] | None = None,
8790
system_prompt: str | None = None,
91+
system_prompt_content: list[SystemContentBlock] | None = None,
8892
) -> int:
8993
"""Estimate tokens by serializing messages/tools to text and counting with tiktoken.
9094
@@ -97,6 +101,11 @@ def _estimate_tokens_with_tiktoken(
97101
if system_prompt:
98102
total += len(encoding.encode(system_prompt))
99103

104+
if system_prompt_content:
105+
for block in system_prompt_content:
106+
if "text" in block:
107+
total += len(encoding.encode(block["text"]))
108+
100109
for message in messages:
101110
for block in message["content"]:
102111
total += _count_content_block_tokens(block, encoding)
@@ -224,6 +233,7 @@ def _estimate_tokens(
224233
messages: Messages,
225234
tool_specs: list[ToolSpec] | None = None,
226235
system_prompt: str | None = None,
236+
system_prompt_content: list[SystemContentBlock] | None = None,
227237
) -> int:
228238
"""Estimate token count for the given input before sending to the model.
229239
@@ -239,11 +249,12 @@ def _estimate_tokens(
239249
messages: List of message objects to estimate tokens for.
240250
tool_specs: List of tool specifications to include in the estimate.
241251
system_prompt: System prompt to include in the estimate.
252+
system_prompt_content: System prompt content blocks to include in the estimate.
242253
243254
Returns:
244255
Estimated total input tokens.
245256
"""
246-
return _estimate_tokens_with_tiktoken(messages, tool_specs, system_prompt)
257+
return _estimate_tokens_with_tiktoken(messages, tool_specs, system_prompt, system_prompt_content)
247258

248259

249260
class _ModelPlugin(Plugin):

tests/strands/models/test_model.py

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@ def test_estimate_tokens_skips_binary_content(model):
313313
assert model._estimate_tokens(messages=messages) == 0
314314

315315

316-
def test_estimate_tokens_tool_result_with_bytes(model):
316+
def test_estimate_tokens_tool_result_with_bytes_only(model):
317317
messages = [
318318
{
319319
"role": "user",
@@ -332,6 +332,28 @@ def test_estimate_tokens_tool_result_with_bytes(model):
332332
assert result == 0
333333

334334

335+
def test_estimate_tokens_tool_result_with_text_and_bytes(model):
336+
messages = [
337+
{
338+
"role": "user",
339+
"content": [
340+
{
341+
"toolResult": {
342+
"toolUseId": "123",
343+
"content": [
344+
{"text": "Here is the screenshot"},
345+
{"image": {"format": "png", "source": {"bytes": b"image data"}}},
346+
],
347+
"status": "success",
348+
}
349+
}
350+
],
351+
}
352+
]
353+
result = model._estimate_tokens(messages=messages)
354+
assert result > 0
355+
356+
335357
def test_estimate_tokens_guard_content_block(model):
336358
messages = [
337359
{
@@ -359,7 +381,8 @@ def test_estimate_tokens_tool_use_with_bytes(model):
359381
}
360382
]
361383
result = model._estimate_tokens(messages=messages)
362-
assert result == 0
384+
# Should still count the tool name even though input has non-serializable bytes
385+
assert result > 0
363386

364387

365388
def test_estimate_tokens_non_serializable_tool_spec(model, messages):
@@ -393,6 +416,25 @@ def test_estimate_tokens_citations_block(model):
393416
assert result > 0
394417

395418

419+
def test_estimate_tokens_system_prompt_content(model):
420+
result = model._estimate_tokens(
421+
messages=[],
422+
system_prompt_content=[{"text": "You are a helpful assistant."}],
423+
)
424+
assert result > 0
425+
426+
427+
def test_estimate_tokens_system_prompt_content_with_cache_point(model):
428+
result = model._estimate_tokens(
429+
messages=[],
430+
system_prompt_content=[
431+
{"text": "You are a helpful assistant."},
432+
{"cachePoint": {"type": "default"}},
433+
],
434+
)
435+
assert result > 0
436+
437+
396438
def test_estimate_tokens_all_inputs(model):
397439
messages = [
398440
{"role": "user", "content": [{"text": "hello world"}]},
@@ -402,6 +444,7 @@ def test_estimate_tokens_all_inputs(model):
402444
messages=messages,
403445
tool_specs=[{"name": "test", "description": "a test tool", "inputSchema": {"json": {}}}],
404446
system_prompt="Be helpful.",
447+
system_prompt_content=[{"text": "Additional system context."}],
405448
)
406449
assert result > 0
407450

0 commit comments

Comments
 (0)