From fd785cdaa4ba00458237474b8552c611b48fa73e Mon Sep 17 00:00:00 2001 From: Florent Lacroute Date: Thu, 2 Apr 2026 16:21:28 +0200 Subject: [PATCH 1/2] feat: add token-aware conversation manager with proactive compaction MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Token-based context management that uses actual inputTokens from model responses to decide when to compact, instead of counting messages. Four-pass compaction strategy: 1. Sanitize — strip ANSI escape codes, collapse repeated lines 2. Truncate — replace oversized tool results with placeholders 3. Summarize — use model.stream() to summarize older messages 4. Trim — remove oldest messages as last resort The first user message is always preserved so the agent never loses sight of its original task. Summarization calls model.stream() directly, avoiding re-entrant agent invocation and deadlocks on _invocation_lock. --- .../agent/conversation_manager/__init__.py | 4 + .../token_aware_conversation_manager.py | 420 ++++++++++++++++++ 2 files changed, 424 insertions(+) create mode 100644 src/strands/agent/conversation_manager/token_aware_conversation_manager.py diff --git a/src/strands/agent/conversation_manager/__init__.py b/src/strands/agent/conversation_manager/__init__.py index c59623215..6fc23b453 100644 --- a/src/strands/agent/conversation_manager/__init__.py +++ b/src/strands/agent/conversation_manager/__init__.py @@ -8,6 +8,8 @@ size while preserving conversation coherence - SummarizingConversationManager: An implementation that summarizes older context instead of simply trimming it +- TokenAwareConversationManager: An implementation that uses actual input token counts to + proactively compact context with a four-pass strategy (sanitize, truncate, summarize, trim) Conversation managers help control memory usage and context length while maintaining relevant conversation state, which is critical for effective agent interactions. 
@@ -17,10 +19,12 @@ from .null_conversation_manager import NullConversationManager from .sliding_window_conversation_manager import SlidingWindowConversationManager from .summarizing_conversation_manager import SummarizingConversationManager +from .token_aware_conversation_manager import TokenAwareConversationManager __all__ = [ "ConversationManager", "NullConversationManager", "SlidingWindowConversationManager", "SummarizingConversationManager", + "TokenAwareConversationManager", ] diff --git a/src/strands/agent/conversation_manager/token_aware_conversation_manager.py b/src/strands/agent/conversation_manager/token_aware_conversation_manager.py new file mode 100644 index 000000000..dd75bab27 --- /dev/null +++ b/src/strands/agent/conversation_manager/token_aware_conversation_manager.py @@ -0,0 +1,420 @@ +"""Token-aware conversation manager with LLM summarization. + +Designed for autonomous agent workloads with long tool-call cycles. Uses actual input token count (from model +responses) to decide when to compact, and summarizes older context instead of just truncating. + +Four-pass compaction strategy: + 1. Sanitize — strip ANSI escape codes, collapse repeated lines + 2. Truncate — replace oversized tool result content with a placeholder + 3. Summarize — use the LLM to summarize older messages (preserves context) + 4. 
"""Token-aware conversation manager with LLM summarization.

Designed for autonomous agent workloads with long tool-call cycles. Uses actual input token count (from model
responses) to decide when to compact, and summarizes older context instead of just truncating.

Four-pass compaction strategy:
    1. Sanitize — strip ANSI escape codes, collapse repeated lines
    2. Truncate — replace oversized tool result content with a placeholder
    3. Summarize — use the LLM to summarize older messages (preserves context)
    4. Trim — remove oldest messages as last resort (loses context)
"""

import logging
import re
from typing import TYPE_CHECKING, Any

from typing_extensions import override

from ..._async import run_async
from ...event_loop.streaming import process_stream
from ...hooks import BeforeModelCallEvent, HookRegistry
from ...types.content import Message
from ...types.exceptions import ContextWindowOverflowException
from .conversation_manager import ConversationManager

if TYPE_CHECKING:
    from ..agent import Agent

logger = logging.getLogger(__name__)

# ANSI escape sequences: CSI codes, OSC sequences, charset designators, carriage returns
_ANSI_RE = re.compile(r"\x1b\[[0-9;]*[A-Za-z]|\x1b\].*?\x07|\x1b\([A-Z]|\r")
_TOOL_RESULT_TRUNCATED = "The tool result was too large!"

SUMMARIZATION_PROMPT = (
    "You are a conversation summarizer for an autonomous AI agent. "
    "Create a concise summary preserving:\n"
    "- Current task/goal and progress\n"
    "- Key decisions made and reasoning\n"
    "- Important file paths, code changes, and tool results\n"
    "- Errors encountered and how they were resolved\n"
    "- Pending work items\n\n"
    "Format as bullet points. Be concise but don't lose critical context."
)


class TokenAwareConversationManager(ConversationManager):
    """Manages conversation based on token count with LLM summarization.

    Uses actual ``inputTokens`` from model responses to decide when to compact. Unlike
    ``SlidingWindowConversationManager`` which counts messages, this manager reacts to the real context size the model
    processes.

    The first user message (index 0) is always preserved across all compaction passes so the agent never loses sight of
    its original task.
    """

    def __init__(
        self,
        compact_threshold: int = 150_000,
        preserve_recent: int = 6,
        should_truncate_results: bool = True,
    ):
        """Initialize the token-aware conversation manager.

        Args:
            compact_threshold: Trigger compaction when inputTokens exceeds this value. Default 150 000 leaves ~50K
                headroom on a 200K context window.
            preserve_recent: Minimum number of recent messages to always keep.
            should_truncate_results: Replace oversized tool result content with a placeholder as a first reduction
                strategy.
        """
        super().__init__()
        self.compact_threshold = compact_threshold
        self.preserve_recent = preserve_recent
        self.should_truncate_results = should_truncate_results
        # Most recent inputTokens observed from a completed event-loop cycle.
        self._last_input_tokens: int = 0
        # Number of BeforeModelCallEvent callbacks seen (diagnostic / persisted state).
        self._model_call_count: int = 0
        # Last summary message produced by pass 3, persisted across sessions.
        self._summary_message: Message | None = None

    # ------------------------------------------------------------------
    # Hook registration
    # ------------------------------------------------------------------

    @override
    def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None:
        """Register hooks to track token usage and apply proactive management.

        Args:
            registry: The hook registry to register callbacks with.
            **kwargs: Additional keyword arguments for future extensibility.
        """
        super().register_hooks(registry, **kwargs)
        registry.add_callback(BeforeModelCallEvent, self._on_before_model_call)

    def _on_before_model_call(self, event: BeforeModelCallEvent) -> None:
        """Proactive management: read token usage from the previous cycle and check budget.

        By the time this hook fires, ``start_cycle`` has already appended a new empty cycle to
        the invocation. The *previous* cycle (``cycles[-2]``) holds the most recent completed
        token counts. Reading ``cycles[-1]`` would always yield zero.

        Args:
            event: The before model call event.
        """
        self._model_call_count += 1

        # Read token count from the most recent *completed* cycle (the one before the current empty one)
        agent = event.agent
        invocation = agent.event_loop_metrics.latest_agent_invocation
        if invocation and len(invocation.cycles) >= 2:
            self._last_input_tokens = invocation.cycles[-2].usage.get("inputTokens", 0)

        if self._last_input_tokens > 0:
            self.apply_management(agent)

    # ------------------------------------------------------------------
    # State persistence
    # ------------------------------------------------------------------

    @override
    def get_state(self) -> dict[str, Any]:
        """Return serializable state for session persistence.

        Returns:
            Dictionary containing the manager's state.
        """
        state = super().get_state()
        state["last_input_tokens"] = self._last_input_tokens
        state["model_call_count"] = self._model_call_count
        state["summary_message"] = self._summary_message
        return state

    @override
    def restore_from_session(self, state: dict[str, Any]) -> list[Message] | None:
        """Restore manager state from a previous session.

        Args:
            state: Previous state of the conversation manager.

        Returns:
            Optionally returns the previous conversation summary if it exists.
        """
        result = super().restore_from_session(state)
        self._last_input_tokens = state.get("last_input_tokens", 0)
        self._model_call_count = state.get("model_call_count", 0)
        self._summary_message = state.get("summary_message")
        return [self._summary_message] if self._summary_message else result

    # ------------------------------------------------------------------
    # Core management interface
    # ------------------------------------------------------------------

    @override
    def apply_management(self, agent: "Agent", **kwargs: Any) -> None:
        """Proactively compact when token usage exceeds the threshold.

        Args:
            agent: The agent whose conversation history will be managed.
            **kwargs: Additional keyword arguments for future extensibility.
        """
        if self._last_input_tokens <= self.compact_threshold:
            return

        logger.info(
            "input_tokens=<%d>, threshold=<%d>, message_count=<%d> | compacting conversation",
            self._last_input_tokens,
            self.compact_threshold,
            len(agent.messages),
        )
        self._compact(agent)

    @override
    def reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: Any) -> None:
        """Reactive reduction when a ``ContextWindowOverflowException`` is caught.

        Args:
            agent: The agent whose conversation history will be reduced.
            e: The exception that triggered the context reduction, if any.
            **kwargs: Additional keyword arguments for future extensibility.
        """
        # Fix: log the triggering exception instead of the malformed constant "overflow=".
        logger.warning("error=<%s> | reduce_context triggered", e)
        self._compact(agent)

    # ------------------------------------------------------------------
    # Internal compaction logic
    # ------------------------------------------------------------------

    def _compact(self, agent: "Agent") -> None:
        """Run the four-pass compaction strategy.

        1. Sanitize all tool results (ANSI strip + dedup)
        2. Truncate oversized tool results (oldest first, skip first user message)
        3. Summarize older messages via LLM (preserve first user message)
        4. Hard-trim oldest messages as last resort (preserve first user message)

        The first user message (index 0) is always preserved — it contains the original task/prompt and must survive
        compaction so the agent never loses sight of what it was asked to do.

        Args:
            agent: The agent whose conversation history will be compacted.

        Raises:
            ContextWindowOverflowException: If the context cannot be reduced further.
        """
        messages = agent.messages
        if len(messages) <= self.preserve_recent:
            raise ContextWindowOverflowException("Cannot reduce: at minimum message count")

        # The first message is the original user prompt — never touch it.
        protect_start = 1

        # Pass 1: sanitize all tool results
        self._sanitize_all_tool_results(messages)

        # Pass 2: truncate tool results (oldest first, skip protected + recent)
        if self.should_truncate_results:
            truncatable_end = len(messages) - self.preserve_recent
            truncated_count = 0
            for idx in range(protect_start, truncatable_end):
                if self._truncate_tool_results_in_message(messages, idx):
                    truncated_count += 1
            if truncated_count > 0:
                logger.info("truncated_count=<%d> | truncated tool results", truncated_count)
                return  # re-try with truncated results first

        # Pass 3: summarize older messages using the LLM
        summarize_end = len(messages) - self.preserve_recent
        messages_to_summarize_count = summarize_end - protect_start
        if messages_to_summarize_count > 0:
            # Fix: a failed split search must not abort compaction — pass 4 starts from an
            # earlier index and may still find a valid split point before summarize_end.
            try:
                split = self._adjust_split_for_tool_pairs(messages, summarize_end)
            except ContextWindowOverflowException:
                split = protect_start  # no safe split for summarization; fall through to trim
            if split > protect_start:
                try:
                    first_message = messages[0]
                    old_messages = messages[protect_start:split]
                    remaining = messages[split:]
                    summary = self._generate_summary(old_messages, agent)
                    self.removed_message_count += len(old_messages)
                    if self._summary_message:
                        # The previous summary sat inside old_messages; don't double-count it.
                        self.removed_message_count -= 1
                    self._summary_message = summary
                    messages[:] = [first_message, summary] + remaining
                    logger.info(
                        "summarized_count=<%d>, remaining=<%d> | summarized older messages",
                        len(old_messages),
                        len(messages),
                    )
                    return
                except Exception as exc:
                    logger.warning("error=<%s> | summarization failed, falling back to trim", exc)

        # Pass 4: hard-trim as last resort (preserve first message)
        trim_target = max(self.preserve_recent, len(messages) // 2)
        trim_index = len(messages) - trim_target
        trim_index = max(trim_index, protect_start)
        trim_index = self._adjust_split_for_tool_pairs(messages, trim_index)
        if trim_index <= protect_start:
            raise ContextWindowOverflowException("Unable to trim conversation context!")

        first_message = messages[0]
        trimmed_count = trim_index - protect_start
        self.removed_message_count += trimmed_count
        messages[:] = [first_message] + messages[trim_index:]
        logger.info(
            "trimmed_count=<%d>, remaining=<%d> | trimmed oldest messages",
            trimmed_count,
            len(messages),
        )

    # ------------------------------------------------------------------
    # LLM summarization
    # ------------------------------------------------------------------

    @staticmethod
    def _generate_summary(old_messages: list[Message], agent: "Agent") -> Message:
        """Summarize older messages by calling the agent's model directly.

        Bypasses the full agent pipeline (lock, metrics, traces, tool loop) and simply asks the underlying model to
        summarize the conversation.

        Args:
            old_messages: The messages to summarize.
            agent: The parent agent whose model is used.

        Returns:
            A message containing the conversation summary with role ``assistant``.

        Raises:
            RuntimeError: If no response is received from the model.
        """
        summarization_messages: list[Message] = list(old_messages) + [
            {"role": "user", "content": [{"text": "Summarize this conversation concisely."}]}
        ]

        async def _call_model() -> Message:
            chunks = agent.model.stream(
                summarization_messages,
                tool_specs=None,
                system_prompt=SUMMARIZATION_PROMPT,
            )

            result_message: Message | None = None
            async for event in process_stream(chunks):
                if "stop" in event:
                    _, result_message, _, _ = event["stop"]

            if result_message is None:
                raise RuntimeError("Failed to generate summary: no response from model")
            return result_message

        message = run_async(_call_model)
        # Keep role as assistant — the summary sits between the preserved first user message
        # and the remaining conversation, maintaining proper user/assistant alternation.
        return message

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _sanitize_all_tool_results(messages: list[Message]) -> None:
        """Strip ANSI codes and collapse repeated lines in all tool results.

        Args:
            messages: The full list of messages to sanitize in-place.
        """
        for msg in messages:
            for content in msg.get("content", []):
                if isinstance(content, dict) and "toolResult" in content:
                    for item in content["toolResult"].get("content", []):
                        text = item.get("text")
                        # Only rewrite when an escape or carriage return is present, so
                        # clean text blocks are left byte-identical.
                        if text and ("\x1b" in text or "\r" in text):
                            item["text"] = _sanitize_text(text)

    @staticmethod
    def _truncate_tool_results_in_message(messages: list[Message], idx: int) -> bool:
        """Replace tool result content in a specific message with a placeholder.

        Fix: previously only tool results containing a non-empty ``text`` block were truncated,
        so large ``json``/image/document tool-result blocks were never reduced. Now any tool
        result with substantive content that is not already the placeholder is replaced.

        Args:
            messages: The full list of messages.
            idx: Index of the message to truncate.

        Returns:
            True if any tool results were truncated.
        """
        msg = messages[idx]
        changed = False
        for content in msg.get("content", []):
            if isinstance(content, dict) and "toolResult" in content:
                tr = content["toolResult"]
                if TokenAwareConversationManager._has_truncatable_content(tr.get("content", [])):
                    tr["status"] = "error"
                    tr["content"] = [{"text": _TOOL_RESULT_TRUNCATED}]
                    changed = True
        return changed

    @staticmethod
    def _has_truncatable_content(items: list[Any]) -> bool:
        """Return True if a tool result content list holds anything worth truncating.

        Empty content and the placeholder itself are not truncatable (avoids re-truncating
        and avoids replacing empty results with a *larger* placeholder).

        Args:
            items: The ``content`` list of a toolResult block.

        Returns:
            True if the content should be replaced with the placeholder.
        """
        for item in items:
            if "text" in item:
                text = item["text"]
                if text and text != _TOOL_RESULT_TRUNCATED:
                    return True
            elif item:
                # Non-text block (json, image, document, …) — always truncatable.
                return True
        return False

    @staticmethod
    def _adjust_split_for_tool_pairs(messages: list[Message], split: int) -> int:
        """Adjust split forward so it doesn't break toolUse/toolResult pairs.

        A split is invalid when the first remaining message is an orphaned toolResult, or a
        toolUse whose matching toolResult would be cut off.

        Args:
            messages: The full list of messages.
            split: The initially calculated split point.

        Returns:
            The adjusted split point.

        Raises:
            ContextWindowOverflowException: If no valid split point can be found.
        """
        while split < len(messages):
            if any("toolResult" in c for c in messages[split]["content"]) or (
                any("toolUse" in c for c in messages[split]["content"])
                and split + 1 < len(messages)
                and not any("toolResult" in c for c in messages[split + 1]["content"])
            ):
                split += 1
            else:
                break
        else:
            # while/else: the loop ran off the end without finding a valid split.
            raise ContextWindowOverflowException("Unable to trim conversation context!")

        return split


# ------------------------------------------------------------------
# Module-level helpers
# ------------------------------------------------------------------


def _sanitize_text(text: str) -> str:
    """Strip ANSI escape codes and collapse repeated consecutive lines.

    Args:
        text: Raw text potentially containing ANSI codes and repeated lines.

    Returns:
        Cleaned text with ANSI stripped and consecutive duplicate lines collapsed.
    """
    text = _ANSI_RE.sub("", text)
    lines = text.split("\n")
    result: list[str] = []
    prev: str | None = None
    repeat = 0
    for line in lines:
        stripped = line.strip()
        if stripped == prev and stripped:
            # Same non-blank line as before (ignoring surrounding whitespace) — count it.
            repeat += 1
        else:
            if repeat > 0:
                result.append(f"  [repeated {repeat} more time(s)]")
            result.append(line)
            prev = stripped
            repeat = 0
    if repeat > 0:
        result.append(f"  [repeated {repeat} more time(s)]")
    return "\n".join(result)
--- .../test_token_aware_conversation_manager.py | 648 ++++++++++++++++++ 1 file changed, 648 insertions(+) create mode 100644 tests/strands/agent/test_token_aware_conversation_manager.py diff --git a/tests/strands/agent/test_token_aware_conversation_manager.py b/tests/strands/agent/test_token_aware_conversation_manager.py new file mode 100644 index 000000000..2e9aecd70 --- /dev/null +++ b/tests/strands/agent/test_token_aware_conversation_manager.py @@ -0,0 +1,648 @@ +"""Tests for TokenAwareConversationManager.""" + +from typing import TYPE_CHECKING, cast +from unittest.mock import Mock, patch + +import pytest + +if TYPE_CHECKING: + from strands.agent.agent import Agent + +from strands.agent.conversation_manager.token_aware_conversation_manager import ( + SUMMARIZATION_PROMPT, + TokenAwareConversationManager, + _sanitize_text, +) +from strands.hooks.events import BeforeModelCallEvent +from strands.hooks.registry import HookRegistry +from strands.types.content import Message, Messages +from strands.types.exceptions import ContextWindowOverflowException + +# --------------------------------------------------------------------------- +# Async mock helpers (same pattern as test_summarizing_conversation_manager) +# --------------------------------------------------------------------------- + + +async def _mock_model_stream(response_text: str): + """Create an async generator that yields stream events for a text response.""" + yield {"messageStart": {"role": "assistant"}} + yield {"contentBlockStart": {"start": {}}} + yield {"contentBlockDelta": {"delta": {"text": response_text}}} + yield {"contentBlockStop": {}} + yield {"messageStop": {"stopReason": "end_turn"}} + + +async def _mock_model_stream_error(error: Exception): + """Async generator that raises an exception, simulating a model failure.""" + raise error + yield # pragma: no cover – makes this a generator + + +class MockAgent: + """Mock agent for testing token-aware conversation manager.""" + + def __init__(self, 
summary_response: str = "Summary of conversation."): + self.summary_response = summary_response + self.system_prompt = "You are helpful." + self.messages: Messages = [] + self.model = Mock() + self.model.stream = Mock(side_effect=lambda *a, **kw: _mock_model_stream(self.summary_response)) + self.tool_registry = Mock() + self.tool_names: list[str] = [] + self.event_loop_metrics = Mock() + invocation = Mock() + cycle = Mock() + cycle.usage = {"inputTokens": 200_000} + invocation.cycles = [cycle] + self.event_loop_metrics.latest_agent_invocation = invocation + + +def _create_mock_agent(summary_response: str = "Summary of conversation.") -> "MockAgent": + return MockAgent(summary_response) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def manager(): + """Token-aware manager with low threshold for testing.""" + return TokenAwareConversationManager(compact_threshold=100, preserve_recent=2) + + +# --------------------------------------------------------------------------- +# Init tests +# --------------------------------------------------------------------------- + + +def test_init_defaults(): + """Test initialization with default values.""" + m = TokenAwareConversationManager() + assert m.compact_threshold == 150_000 + assert m.preserve_recent == 6 + assert m.should_truncate_results is True + assert m._last_input_tokens == 0 + assert m._model_call_count == 0 + assert m._summary_message is None + + +def test_init_custom(): + """Test initialization with custom values.""" + m = TokenAwareConversationManager(compact_threshold=50_000, preserve_recent=4, should_truncate_results=False) + assert m.compact_threshold == 50_000 + assert m.preserve_recent == 4 + assert m.should_truncate_results is False + + +# --------------------------------------------------------------------------- +# Hook registration +# 
--------------------------------------------------------------------------- + + +def test_register_hooks(): + """Test that hooks are registered with the registry.""" + m = TokenAwareConversationManager() + registry = HookRegistry() + m.register_hooks(registry) + assert registry.has_callbacks() + + +# --------------------------------------------------------------------------- +# ANSI sanitization +# --------------------------------------------------------------------------- + + +def test_sanitize_strips_ansi_from_tool_results(manager): + """Test ANSI escape codes are stripped from tool result content.""" + messages: Messages = [ + { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "1", + "content": [{"text": "\x1b[31mred\x1b[0m normal"}], + "status": "success", + } + } + ], + } + ] + manager._sanitize_all_tool_results(messages) + assert messages[0]["content"][0]["toolResult"]["content"][0]["text"] == "red normal" + + +def test_sanitize_collapses_repeated_lines(manager): + """Test that repeated consecutive lines are collapsed when ANSI triggers sanitization.""" + repeated = "\x1b[0mline\nline\nline\nline\nother" + messages: Messages = [ + { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "1", + "content": [{"text": repeated}], + "status": "success", + } + } + ], + } + ] + manager._sanitize_all_tool_results(messages) + result = messages[0]["content"][0]["toolResult"]["content"][0]["text"] + assert "[repeated" in result + assert result.count("line") < 4 + + +def test_sanitize_skips_clean_text(manager): + """Test that messages without ANSI or carriage returns are not modified.""" + messages: Messages = [ + { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "1", + "content": [{"text": "clean text"}], + "status": "success", + } + } + ], + } + ] + manager._sanitize_all_tool_results(messages) + assert messages[0]["content"][0]["toolResult"]["content"][0]["text"] == "clean text" + + +def 
test_sanitize_text_function(): + """Test the module-level _sanitize_text helper.""" + text = "\x1b[31mhello\x1b[0m\ndup\ndup\ndup\nend" + result = _sanitize_text(text) + assert "\x1b" not in result + assert "[repeated 2 more time(s)]" in result + + +# --------------------------------------------------------------------------- +# Tool result truncation +# --------------------------------------------------------------------------- + + +def test_truncate_tool_results_replaces_content(manager): + """Test that tool result content is replaced with placeholder.""" + messages: Messages = [ + {"role": "user", "content": [{"text": "keep me"}]}, + { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "1", + "content": [{"text": "big result data"}], + "status": "success", + } + } + ], + }, + {"role": "user", "content": [{"text": "recent"}]}, + ] + changed = manager._truncate_tool_results_in_message(messages, 1) + assert changed + assert messages[1]["content"][0]["toolResult"]["content"][0]["text"] == "The tool result was too large!" 
+ assert messages[1]["content"][0]["toolResult"]["status"] == "error" + + +def test_truncate_skips_already_truncated(manager): + """Test that already-truncated tool results are not modified again.""" + messages: Messages = [ + { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "1", + "content": [{"text": "The tool result was too large!"}], + "status": "error", + } + } + ], + }, + ] + changed = manager._truncate_tool_results_in_message(messages, 0) + assert not changed + + +def test_truncate_skips_non_tool_messages(manager): + """Test that non-tool messages are not modified.""" + messages: Messages = [{"role": "user", "content": [{"text": "hello"}]}] + changed = manager._truncate_tool_results_in_message(messages, 0) + assert not changed + + +# --------------------------------------------------------------------------- +# Tool pair adjustment +# --------------------------------------------------------------------------- + + +def test_adjust_split_skips_tool_result(manager): + """Test that split point moves past orphaned toolResult.""" + messages: Messages = [ + {"role": "user", "content": [{"text": "msg"}]}, + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "1", "name": "t", "input": {}}}]}, + { + "role": "user", + "content": [{"toolResult": {"toolUseId": "1", "content": [{"text": "r"}], "status": "success"}}], + }, + {"role": "assistant", "content": [{"text": "done"}]}, + ] + assert manager._adjust_split_for_tool_pairs(messages, 2) == 3 + + +def test_adjust_split_valid_position_unchanged(manager): + """Test that a valid split point is returned unchanged.""" + messages: Messages = [ + {"role": "user", "content": [{"text": "msg"}]}, + {"role": "assistant", "content": [{"text": "response"}]}, + {"role": "user", "content": [{"text": "msg2"}]}, + ] + assert manager._adjust_split_for_tool_pairs(messages, 1) == 1 + + +def test_adjust_split_raises_on_all_tool_pairs(): + """Test that exception is raised when no valid split point exists.""" + 
m = TokenAwareConversationManager() + # toolResult without preceding toolUse at every position — no valid split + messages: Messages = [ + { + "role": "user", + "content": [{"toolResult": {"toolUseId": "1", "content": [{"text": "r"}], "status": "success"}}], + }, + { + "role": "user", + "content": [{"toolResult": {"toolUseId": "2", "content": [{"text": "r"}], "status": "success"}}], + }, + ] + with pytest.raises(ContextWindowOverflowException): + m._adjust_split_for_tool_pairs(messages, 0) + + +def test_adjust_split_tooluse_with_following_result(manager): + """Test that toolUse followed by toolResult is a valid split point.""" + messages: Messages = [ + {"role": "assistant", "content": [{"toolUse": {"toolUseId": "1", "name": "t", "input": {}}}]}, + { + "role": "user", + "content": [{"toolResult": {"toolUseId": "1", "content": [{"text": "r"}], "status": "success"}}], + }, + {"role": "assistant", "content": [{"text": "done"}]}, + ] + # Split at 0: toolUse with toolResult at 1 is valid + assert manager._adjust_split_for_tool_pairs(messages, 0) == 0 + + +# --------------------------------------------------------------------------- +# Compact — too few messages +# --------------------------------------------------------------------------- + + +def test_compact_raises_when_too_few_messages(manager): + """Test that compaction raises when message count <= preserve_recent.""" + agent = cast("Agent", _create_mock_agent()) + agent.messages = [{"role": "user", "content": [{"text": "only one"}]}] + with pytest.raises(ContextWindowOverflowException, match="Cannot reduce"): + manager._compact(agent) + + +# --------------------------------------------------------------------------- +# Compact — truncation pass +# --------------------------------------------------------------------------- + + +def test_compact_truncates_before_summarizing(manager): + """Test that truncation pass runs before summarization.""" + agent = cast("Agent", _create_mock_agent()) + agent.messages = [ + 
{"role": "user", "content": [{"text": "task"}]}, + { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "1", + "content": [{"text": "big result"}], + "status": "success", + } + } + ], + }, + {"role": "assistant", "content": [{"text": "R1"}]}, + {"role": "user", "content": [{"text": "recent1"}]}, + {"role": "assistant", "content": [{"text": "recent2"}]}, + ] + manager._compact(agent) + # Should have truncated tool result and returned early + assert agent.messages[1]["content"][0]["toolResult"]["content"][0]["text"] == "The tool result was too large!" + + +# --------------------------------------------------------------------------- +# Compact — summarization pass +# --------------------------------------------------------------------------- + + +def test_compact_summarizes_old_messages(manager): + """Test that older messages are summarized preserving first message and proper role alternation.""" + agent = cast("Agent", _create_mock_agent()) + agent.messages = [ + {"role": "user", "content": [{"text": "Original task"}]}, + {"role": "assistant", "content": [{"text": "R1"}]}, + {"role": "user", "content": [{"text": "M2"}]}, + {"role": "assistant", "content": [{"text": "R2"}]}, + {"role": "user", "content": [{"text": "M3"}]}, + {"role": "assistant", "content": [{"text": "R3"}]}, + ] + manager.should_truncate_results = False + manager._compact(agent) + # First message preserved as original user prompt + assert agent.messages[0]["content"][0]["text"] == "Original task" + assert agent.messages[0]["role"] == "user" + # Summary is assistant role (natural alternation: user → assistant → user) + assert agent.messages[1]["role"] == "assistant" + assert "Summary" in agent.messages[1]["content"][0]["text"] + # Recent messages kept + assert len(agent.messages) <= 5 + + +def test_compact_falls_back_to_trim_on_summarization_failure(manager): + """Test that compaction falls back to trimming when summarization fails.""" + agent = cast("Agent", _create_mock_agent()) 
+ agent.model.stream = Mock(side_effect=lambda *a, **kw: _mock_model_stream_error(Exception("fail"))) + agent.messages = [ + {"role": "user", "content": [{"text": "task"}]}, + {"role": "assistant", "content": [{"text": "R1"}]}, + {"role": "user", "content": [{"text": "M2"}]}, + {"role": "assistant", "content": [{"text": "R2"}]}, + {"role": "user", "content": [{"text": "M3"}]}, + {"role": "assistant", "content": [{"text": "R3"}]}, + ] + manager.should_truncate_results = False + manager._compact(agent) + # First message preserved (trim doesn't merge — no summary generated) + assert agent.messages[0]["content"][0]["text"] == "task" + assert len(agent.messages) < 6 + + +def test_compact_tracks_removed_message_count(manager): + """Test that removed_message_count is properly tracked across summarizations.""" + agent = cast("Agent", _create_mock_agent()) + agent.messages = [ + {"role": "user", "content": [{"text": "task"}]}, + {"role": "assistant", "content": [{"text": "R1"}]}, + {"role": "user", "content": [{"text": "M2"}]}, + {"role": "assistant", "content": [{"text": "R2"}]}, + {"role": "user", "content": [{"text": "M3"}]}, + {"role": "assistant", "content": [{"text": "R3"}]}, + ] + manager.should_truncate_results = False + manager._compact(agent) + assert manager.removed_message_count > 0 + + +# --------------------------------------------------------------------------- +# LLM summarization +# --------------------------------------------------------------------------- + + +def test_generate_summary_calls_model_stream(manager): + """Test that _generate_summary calls model.stream() and returns assistant role.""" + agent = cast("Agent", _create_mock_agent()) + messages: Messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi"}]}, + ] + summary = manager._generate_summary(messages, agent) + assert summary["role"] == "assistant" + assert "Summary" in summary["content"][0]["text"] + 
agent.model.stream.assert_called_once() + + +def test_generate_summary_uses_summarization_prompt(manager): + """Test that model.stream() is called with the summarization system prompt.""" + agent = cast("Agent", _create_mock_agent()) + messages: Messages = [{"role": "user", "content": [{"text": "test"}]}] + manager._generate_summary(messages, agent) + call_kwargs = agent.model.stream.call_args + assert call_kwargs.kwargs["system_prompt"] == SUMMARIZATION_PROMPT + + +def test_generate_summary_does_not_modify_agent_state(manager): + """Test that agent state is untouched after summarization.""" + agent = _create_mock_agent() + original_prompt = agent.system_prompt + original_messages = agent.messages.copy() + + messages: Messages = [ + {"role": "user", "content": [{"text": "Hello"}]}, + {"role": "assistant", "content": [{"text": "Hi"}]}, + ] + manager._generate_summary(messages, cast("Agent", agent)) + + assert agent.system_prompt == original_prompt + assert agent.messages == original_messages + + +def test_generate_summary_raises_on_model_failure(manager): + """Test that _generate_summary raises when model.stream() fails.""" + agent = cast("Agent", _create_mock_agent()) + agent.model.stream = Mock(side_effect=lambda *a, **kw: _mock_model_stream_error(Exception("Model failed"))) + messages: Messages = [{"role": "user", "content": [{"text": "test"}]}] + + with pytest.raises(Exception, match="Model failed"): + manager._generate_summary(messages, agent) + + +# --------------------------------------------------------------------------- +# Hook callbacks +# --------------------------------------------------------------------------- + + +def test_before_model_call_captures_tokens_from_previous_cycle(manager): + """Test that _on_before_model_call reads tokens from cycles[-2] (the completed cycle).""" + agent = Mock() + # Simulate 2 cycles: one completed (with tokens) and one just started (empty) + completed_cycle = Mock() + completed_cycle.usage = {"inputTokens": 200_000} + 
current_cycle = Mock() + current_cycle.usage = {"inputTokens": 0} + agent.event_loop_metrics.latest_agent_invocation.cycles = [completed_cycle, current_cycle] + event = BeforeModelCallEvent(agent=agent) + with patch.object(manager, "apply_management"): + manager._on_before_model_call(event) + assert manager._model_call_count == 1 + assert manager._last_input_tokens == 200_000 + + +def test_before_model_call_skips_when_only_one_cycle(manager): + """Test that tokens are not read on the first model call (only 1 cycle, no previous).""" + agent = Mock() + current_cycle = Mock() + current_cycle.usage = {"inputTokens": 0} + agent.event_loop_metrics.latest_agent_invocation.cycles = [current_cycle] + event = BeforeModelCallEvent(agent=agent) + manager._on_before_model_call(event) + assert manager._last_input_tokens == 0 + assert manager._model_call_count == 1 + + +def test_before_model_call_skips_management_when_no_invocation(manager): + """Test that management is skipped when no invocation exists.""" + agent = Mock() + agent.event_loop_metrics.latest_agent_invocation = None + event = BeforeModelCallEvent(agent=agent) + with patch.object(manager, "apply_management") as mock_apply: + manager._on_before_model_call(event) + mock_apply.assert_not_called() + + +def test_before_model_call_triggers_management_above_threshold(manager): + """Test that management is triggered when previous cycle tokens exceed threshold.""" + agent = Mock() + completed_cycle = Mock() + completed_cycle.usage = {"inputTokens": 200_000} + current_cycle = Mock() + current_cycle.usage = {"inputTokens": 0} + agent.event_loop_metrics.latest_agent_invocation.cycles = [completed_cycle, current_cycle] + event = BeforeModelCallEvent(agent=agent) + with patch.object(manager, "apply_management") as mock_apply: + manager._on_before_model_call(event) + mock_apply.assert_called_once_with(agent) + + +# --------------------------------------------------------------------------- +# apply_management / reduce_context +# 
--------------------------------------------------------------------------- + + +def test_apply_management_skips_below_threshold(manager): + """Test that apply_management does nothing when below threshold.""" + manager._last_input_tokens = 50 # below 100 threshold + agent = Mock() + agent.messages = [] + with patch.object(manager, "_compact") as mock_compact: + manager.apply_management(agent) + mock_compact.assert_not_called() + + +def test_apply_management_triggers_above_threshold(manager): + """Test that apply_management triggers compaction when above threshold.""" + manager._last_input_tokens = 200 + agent = Mock() + agent.messages = [{"role": "user", "content": [{"text": f"msg{i}"}]} for i in range(10)] + with patch.object(manager, "_compact") as mock_compact: + manager.apply_management(agent) + mock_compact.assert_called_once_with(agent) + + +def test_reduce_context_calls_compact(manager): + """Test that reduce_context delegates to _compact.""" + agent = Mock() + with patch.object(manager, "_compact") as mock_compact: + manager.reduce_context(agent) + mock_compact.assert_called_once_with(agent) + + +# --------------------------------------------------------------------------- +# State persistence +# --------------------------------------------------------------------------- + + +def test_get_state(manager): + """Test that get_state returns complete manager state.""" + manager._last_input_tokens = 1000 + manager._model_call_count = 5 + state = manager.get_state() + assert state["last_input_tokens"] == 1000 + assert state["model_call_count"] == 5 + assert state["__name__"] == "TokenAwareConversationManager" + assert state["removed_message_count"] == 0 + assert state["summary_message"] is None + + +def test_restore_from_session_with_summary(manager): + """Test that restore_from_session restores all state including summary.""" + summary: Message = {"role": "user", "content": [{"text": "prev summary"}]} + state = { + "__name__": "TokenAwareConversationManager", + 
"removed_message_count": 3, + "last_input_tokens": 500, + "model_call_count": 10, + "summary_message": summary, + } + result = manager.restore_from_session(state) + assert manager._last_input_tokens == 500 + assert manager._model_call_count == 10 + assert manager._summary_message == summary + assert manager.removed_message_count == 3 + assert result == [summary] + + +def test_restore_from_session_without_summary(manager): + """Test that restore_from_session returns None when no summary exists.""" + state = { + "__name__": "TokenAwareConversationManager", + "removed_message_count": 0, + "last_input_tokens": 0, + "model_call_count": 0, + "summary_message": None, + } + result = manager.restore_from_session(state) + assert result is None + + +def test_restore_from_session_wrong_name_raises(manager): + """Test that restore raises with mismatched manager name.""" + state = { + "__name__": "SlidingWindowConversationManager", + "removed_message_count": 0, + } + with pytest.raises(ValueError, match="Invalid conversation manager state"): + manager.restore_from_session(state) + + +# --------------------------------------------------------------------------- +# Second summarization properly accounts for previous summary +# --------------------------------------------------------------------------- + + +def test_second_compact_does_not_double_count_summary(manager): + """Test that removed_message_count subtracts previous summary on re-summarization.""" + agent = cast("Agent", _create_mock_agent()) + agent.messages = [ + {"role": "user", "content": [{"text": "task"}]}, + {"role": "assistant", "content": [{"text": "R1"}]}, + {"role": "user", "content": [{"text": "M2"}]}, + {"role": "assistant", "content": [{"text": "R2"}]}, + {"role": "user", "content": [{"text": "M3"}]}, + {"role": "assistant", "content": [{"text": "R3"}]}, + ] + manager.should_truncate_results = False + + # First compact + manager._compact(agent) + first_removed = manager.removed_message_count + + # Add more 
messages + agent.messages.extend( + [ + {"role": "user", "content": [{"text": "M4"}]}, + {"role": "assistant", "content": [{"text": "R4"}]}, + {"role": "user", "content": [{"text": "M5"}]}, + {"role": "assistant", "content": [{"text": "R5"}]}, + ] + ) + + # Second compact + manager._compact(agent) + # Should have subtracted 1 for the previous summary message + assert manager.removed_message_count >= first_removed