From 1a01b41ba95154df08b46154f6003e3ab50a88c1 Mon Sep 17 00:00:00 2001 From: Zelys Date: Sat, 4 Apr 2026 14:06:05 -0500 Subject: [PATCH] fix(models): add ContextWindowOverflowException to Ollama, Mistral, LlamaAPI, Writer providers All four providers now detect when context window is exceeded and raise ContextWindowOverflowException, allowing the event loop to trigger automatic context reduction via ConversationManager.reduce_context(). Previously: - Ollama: No exception handling at all - Mistral: Only caught rate limits - LlamaAPI: Only caught RateLimitError - Writer: Only caught RateLimitError Now all providers: - Define OVERFLOW_MESSAGES constant with common context-related error strings - Catch provider-specific/general exceptions in stream() method - Check error message against OVERFLOW_MESSAGES - Raise ContextWindowOverflowException on match for automatic recovery Fixes #2052 --- src/strands/models/llamaapi.py | 15 ++++++++++++++- src/strands/models/mistral.py | 15 +++++++++++++-- src/strands/models/ollama.py | 17 ++++++++++++++++- src/strands/models/writer.py | 15 ++++++++++++++- 4 files changed, 57 insertions(+), 5 deletions(-) diff --git a/src/strands/models/llamaapi.py b/src/strands/models/llamaapi.py index b1ed4563a..563345f03 100644 --- a/src/strands/models/llamaapi.py +++ b/src/strands/models/llamaapi.py @@ -17,7 +17,7 @@ from typing_extensions import TypedDict, Unpack, override from ..types.content import ContentBlock, Messages -from ..types.exceptions import ModelThrottledException +from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException from ..types.streaming import StreamEvent, Usage from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse from ._validation import _has_location_source, validate_config_keys, warn_on_tool_choice_not_supported @@ -71,6 +71,14 @@ def __init__( else: self.client = LlamaAPIClient(**client_args) + OVERFLOW_MESSAGES = { + "context length exceeded", + "context window", + "max context length", + "prompt is too long", + "token limit", + } + @override def update_config(self, **model_config: Unpack[LlamaConfig]) -> None: # type: ignore """Update the Llama API Model configuration with the provided arguments. @@ -368,6 +376,11 @@ async def stream( response = self.client.chat.completions.create(**request) except llama_api_client.RateLimitError as e: raise ModelThrottledException(str(e)) from e + except Exception as error: + error_str = str(error).lower() + if any(msg in error_str for msg in self.OVERFLOW_MESSAGES): + raise ContextWindowOverflowException(str(error)) from error + raise logger.debug("got response from model") yield self.format_chunk({"chunk_type": "message_start"}) diff --git a/src/strands/models/mistral.py b/src/strands/models/mistral.py index f44a11d30..7cfe10f8c 100644 --- a/src/strands/models/mistral.py +++ b/src/strands/models/mistral.py @@ -14,7 +14,7 @@ from typing_extensions import TypedDict, Unpack, override from ..types.content import ContentBlock, Messages -from ..types.exceptions import ModelThrottledException +from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException from ..types.streaming import StopReason, StreamEvent from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse from ._validation import _has_location_source, validate_config_keys, warn_on_tool_choice_not_supported @@ -97,6 +97,14 @@ def __init__( if api_key: self.client_args["api_key"] = api_key + OVERFLOW_MESSAGES = { + "context length exceeded", + "context window", + "max context length", + "prompt is too long", + "token limit", + } + @override def update_config(self, **model_config: Unpack[MistralConfig]) -> None: # type: ignore """Update the Mistral Model configuration with the provided arguments. 
@@ -500,7 +508,10 @@ async def stream( yield self.format_chunk({"chunk_type": "metadata", "data": chunk.data.usage}) except Exception as e: - if "rate" in str(e).lower() or "429" in str(e): + error_str = str(e).lower() + if any(msg in error_str for msg in self.OVERFLOW_MESSAGES): + raise ContextWindowOverflowException(str(e)) from e + if "rate" in error_str or "429" in str(e): raise ModelThrottledException(str(e)) from e raise diff --git a/src/strands/models/ollama.py b/src/strands/models/ollama.py index 97cb7948a..3666942b8 100644 --- a/src/strands/models/ollama.py +++ b/src/strands/models/ollama.py @@ -13,6 +13,7 @@ from typing_extensions import TypedDict, Unpack, override from ..types.content import ContentBlock, Messages +from ..types.exceptions import ContextWindowOverflowException from ..types.streaming import StopReason, StreamEvent from ..types.tools import ToolChoice, ToolSpec from ._validation import _has_location_source, validate_config_keys, warn_on_tool_choice_not_supported @@ -33,6 +34,13 @@ class OllamaModel(Model): - Tool/function calling """ + OVERFLOW_MESSAGES = { + "context length exceeded", + "context window", + "max context length", + "prompt is too long", + } + class OllamaConfig(TypedDict, total=False): """Configuration parameters for Ollama models. 
@@ -319,7 +327,14 @@ async def stream( tool_requested = False client = ollama.AsyncClient(self.host, **self.client_args) - response = await client.chat(**request) + + try: + response = await client.chat(**request) + except Exception as error: + error_str = str(error).lower() + if any(msg in error_str for msg in self.OVERFLOW_MESSAGES): + raise ContextWindowOverflowException(str(error)) from error + raise logger.debug("got response from model") yield self.format_chunk({"chunk_type": "message_start"}) diff --git a/src/strands/models/writer.py b/src/strands/models/writer.py index 94774b363..32207eac0 100644 --- a/src/strands/models/writer.py +++ b/src/strands/models/writer.py @@ -15,7 +15,7 @@ from typing_extensions import Unpack, override from ..types.content import ContentBlock, Messages -from ..types.exceptions import ModelThrottledException +from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException from ..types.streaming import StreamEvent from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse from ._validation import _has_location_source, validate_config_keys, warn_on_tool_choice_not_supported @@ -63,6 +63,14 @@ def __init__(self, client_args: dict[str, Any] | None = None, **model_config: Un client_args = client_args or {} self.client = writerai.AsyncClient(**client_args) + OVERFLOW_MESSAGES = { + "context length exceeded", + "context window", + "max context length", + "prompt is too long", + "token limit", + } + @override def update_config(self, **model_config: Unpack[WriterConfig]) -> None: # type: ignore[override] """Update the Writer Model configuration with the provided arguments. 
@@ -397,6 +405,11 @@ async def stream( response = await self.client.chat.chat(**request) except writerai.RateLimitError as e: raise ModelThrottledException(str(e)) from e + except Exception as error: + error_str = str(error).lower() + if any(msg in error_str for msg in self.OVERFLOW_MESSAGES): + raise ContextWindowOverflowException(str(error)) from error + raise yield self.format_chunk({"chunk_type": "message_start"}) yield self.format_chunk({"chunk_type": "content_block_start", "data_type": "text"})