Skip to content

Commit 636cf86

Browse files
Emaan Khanemaan-c
authored andcommitted
feat: add streaming to direct tool calls
1 parent ae19308 commit 636cf86

3 files changed

Lines changed: 622 additions & 84 deletions

File tree

src/strands/tools/_caller.py

Lines changed: 186 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,15 @@
88
"""
99

1010
import json
11+
import logging
1112
import random
1213
import weakref
13-
from collections.abc import Callable
14+
from collections.abc import AsyncIterator, Iterator
1415
from typing import TYPE_CHECKING, Any
1516

1617
from .._async import run_async
1718
from ..tools.executors._executor import ToolExecutor
18-
from ..types._events import ToolInterruptEvent
19+
from ..types._events import ToolInterruptEvent, TypedEvent
1920
from ..types.content import ContentBlock, Message
2021
from ..types.exceptions import ConcurrencyException
2122
from ..types.tools import ToolResult, ToolUse
@@ -24,19 +25,27 @@
2425
from ..agent import Agent
2526
from ..experimental.bidi.agent import BidiAgent
2627

28+
logger = logging.getLogger(__name__)
2729

28-
class _ToolCaller:
29-
"""Call tool as a function."""
3030

31-
def __init__(self, agent: "Agent | BidiAgent") -> None:
32-
"""Initialize instance.
31+
class _ToolExecutor:
32+
"""Callable wrapper for tools that provides streaming methods.
33+
34+
This class enables three execution modes for tools:
35+
1. Synchronous: result = executor(x=5)
36+
2. Sync streaming: for event in executor.stream(x=5)
37+
3. Async streaming: async for event in executor.stream_async(x=5)
38+
"""
39+
40+
def __init__(self, agent: "Agent | BidiAgent", tool_name: str) -> None:
41+
"""Initialize tool executor.
3342
3443
Args:
35-
agent: Agent reference that will accept tool results.
44+
agent: Agent reference that owns the tools.
45+
tool_name: Name of the tool to execute.
3646
"""
37-
# WARNING: Do not add any other member variables or methods as this could result in a name conflict with
38-
# agent tools and thus break their execution.
3947
self._agent_ref = weakref.ref(agent)
48+
self._tool_name = tool_name
4049

4150
@property
4251
def _agent(self) -> "Agent | BidiAgent":
@@ -46,104 +55,161 @@ def _agent(self) -> "Agent | BidiAgent":
4655
raise ReferenceError("Agent has been garbage collected")
4756
return agent
4857

49-
def __getattr__(self, name: str) -> Callable[..., Any]:
50-
"""Call tool as a function.
58+
def __call__(
59+
self,
60+
user_message_override: str | None = None,
61+
record_direct_tool_call: bool | None = None,
62+
**kwargs: Any,
63+
) -> ToolResult:
64+
"""Synchronous tool execution (existing behavior - backward compatible).
5165
5266
This method enables the method-style interface (e.g., `agent.tool.tool_name(param="value")`).
5367
It matches underscore-separated names to hyphenated tool names (e.g., 'some_thing' matches 'some-thing').
5468
5569
Args:
56-
name: The name of the attribute (tool) being accessed.
70+
user_message_override: Optional custom message to record.
71+
record_direct_tool_call: Whether to record in message history.
72+
**kwargs: Tool parameters.
5773
5874
Returns:
59-
A function that when called will execute the named tool.
75+
ToolResult from execution.
6076
6177
Raises:
62-
AttributeError: If no tool with the given name exists or if multiple tools match the given name.
78+
AttributeError: If tool doesn't exist.
79+
RuntimeError: If called during interrupt.
80+
ConcurrencyException: If invocation lock cannot be acquired.
6381
"""
82+
if self._agent._interrupt_state.activated:
83+
raise RuntimeError("cannot directly call tool during interrupt")
6484

65-
def caller(
66-
user_message_override: str | None = None,
67-
record_direct_tool_call: bool | None = None,
68-
**kwargs: Any,
69-
) -> Any:
70-
"""Call a tool directly by name.
71-
72-
Args:
73-
user_message_override: Optional custom message to record instead of default
74-
record_direct_tool_call: Whether to record direct tool calls in message history. Overrides class
75-
attribute if provided.
76-
**kwargs: Keyword arguments to pass to the tool.
77-
78-
Returns:
79-
The result returned by the tool.
80-
81-
Raises:
82-
AttributeError: If the tool doesn't exist.
83-
"""
84-
if self._agent._interrupt_state.activated:
85-
raise RuntimeError("cannot directly call tool during interrupt")
86-
87-
if record_direct_tool_call is not None:
88-
should_record_direct_tool_call = record_direct_tool_call
89-
else:
90-
should_record_direct_tool_call = self._agent.record_direct_tool_call
91-
92-
should_lock = should_record_direct_tool_call
93-
94-
from ..agent import Agent # Locally imported to avoid circular reference
95-
96-
acquired_lock = (
97-
should_lock
98-
and isinstance(self._agent, Agent)
99-
and self._agent._invocation_lock.acquire_lock(blocking=False)
100-
)
101-
if should_lock and not acquired_lock:
102-
raise ConcurrencyException(
103-
"Direct tool call cannot be made while the agent is in the middle of an invocation. "
104-
"Set record_direct_tool_call=False to allow direct tool calls during agent invocation."
105-
)
85+
if record_direct_tool_call is not None:
86+
should_record_direct_tool_call = record_direct_tool_call
87+
else:
88+
should_record_direct_tool_call = self._agent.record_direct_tool_call
10689

107-
try:
108-
normalized_name = self._find_normalized_tool_name(name)
90+
should_lock = should_record_direct_tool_call
10991

110-
# Create unique tool ID and set up the tool request
111-
tool_id = f"tooluse_{name}_{random.randint(100000000, 999999999)}"
112-
tool_use: ToolUse = {
113-
"toolUseId": tool_id,
114-
"name": normalized_name,
115-
"input": kwargs.copy(),
116-
}
117-
tool_results: list[ToolResult] = []
118-
invocation_state = kwargs
92+
from ..agent import Agent # Locally imported to avoid circular reference
11993

120-
async def acall() -> ToolResult:
121-
async for event in ToolExecutor._stream(self._agent, tool_use, tool_results, invocation_state):
122-
if isinstance(event, ToolInterruptEvent):
123-
self._agent._interrupt_state.deactivate()
124-
raise RuntimeError("cannot raise interrupt in direct tool call")
94+
acquired_lock = (
95+
should_lock and isinstance(self._agent, Agent) and self._agent._invocation_lock.acquire_lock(blocking=False)
96+
)
97+
if should_lock and not acquired_lock:
98+
raise ConcurrencyException(
99+
"Direct tool call cannot be made while the agent is in the middle of an invocation. "
100+
"Set record_direct_tool_call=False to allow direct tool calls during agent invocation."
101+
)
125102

126-
tool_result = tool_results[0]
103+
try:
104+
normalized_name = self._find_normalized_tool_name(self._tool_name)
127105

128-
if should_record_direct_tool_call:
129-
# Create a record of this tool execution in the message history
130-
await self._record_tool_execution(tool_use, tool_result, user_message_override)
106+
# Create unique tool ID and set up the tool request
107+
tool_id = f"tooluse_{self._tool_name}_{random.randint(100000000, 999999999)}"
108+
tool_use: ToolUse = {
109+
"toolUseId": tool_id,
110+
"name": normalized_name,
111+
"input": kwargs.copy(),
112+
}
113+
tool_results: list[ToolResult] = []
114+
invocation_state = kwargs
131115

132-
return tool_result
116+
async def acall() -> ToolResult:
117+
async for event in ToolExecutor._stream(self._agent, tool_use, tool_results, invocation_state):
118+
if isinstance(event, ToolInterruptEvent):
119+
self._agent._interrupt_state.deactivate()
120+
raise RuntimeError("cannot raise interrupt in direct tool call")
133121

134-
tool_result = run_async(acall)
122+
tool_result = tool_results[0]
135123

136-
# TODO: https://github.com/strands-agents/sdk-python/issues/1311
137-
if isinstance(self._agent, Agent):
138-
self._agent.conversation_manager.apply_management(self._agent)
124+
if should_record_direct_tool_call:
125+
# Create a record of this tool execution in the message history
126+
await self._record_tool_execution(tool_use, tool_result, user_message_override)
139127

140128
return tool_result
141129

142-
finally:
143-
if acquired_lock and isinstance(self._agent, Agent):
144-
self._agent._invocation_lock.release()
130+
tool_result = run_async(acall)
131+
132+
# TODO: https://github.com/strands-agents/sdk-python/issues/1311
133+
if isinstance(self._agent, Agent):
134+
self._agent.conversation_manager.apply_management(self._agent)
135+
136+
return tool_result
137+
138+
finally:
139+
if acquired_lock and isinstance(self._agent, Agent):
140+
self._agent._invocation_lock.release()
141+
142+
def stream(self, **kwargs: Any) -> Iterator[TypedEvent]:
143+
"""Synchronous streaming via async-to-sync wrapper.
144+
145+
This method provides synchronous streaming by wrapping stream_async()
146+
with run_async(). Note that due to Python's async/sync boundary constraints,
147+
events are buffered before yielding. For true streaming, use stream_async().
148+
149+
Args:
150+
**kwargs: Tool parameters.
151+
152+
Yields:
153+
Tool execution events.
154+
155+
Raises:
156+
AttributeError: If tool doesn't exist.
157+
RuntimeError: If called during interrupt.
158+
"""
159+
160+
async def async_generator() -> AsyncIterator[TypedEvent]:
161+
async for event in self.stream_async(**kwargs):
162+
yield event
163+
164+
# Run async generator in sync context
165+
async def collect_events() -> list[TypedEvent]:
166+
events = []
167+
async for event in async_generator():
168+
events.append(event)
169+
return events
170+
171+
events = run_async(collect_events)
172+
yield from events
173+
174+
async def stream_async(self, **kwargs: Any) -> AsyncIterator[TypedEvent]:
175+
"""Asynchronous streaming from ToolExecutor._stream().
176+
177+
This method yields events directly from tool execution without recording
178+
to message history. Designed for observability and real-time progress.
179+
180+
Args:
181+
**kwargs: Tool parameters.
182+
183+
Yields:
184+
Tool execution events from ToolExecutor._stream().
185+
186+
Raises:
187+
AttributeError: If tool doesn't exist.
188+
RuntimeError: If called during interrupt.
189+
"""
190+
if self._agent._interrupt_state.activated:
191+
raise RuntimeError("cannot directly call tool during interrupt")
145192

146-
return caller
193+
normalized_name = self._find_normalized_tool_name(self._tool_name)
194+
195+
logger.debug("tool_name=<%s>, streaming=<True> | executing tool stream", normalized_name)
196+
197+
# Create unique tool ID and set up the tool request
198+
tool_id = f"tooluse_{self._tool_name}_{random.randint(100000000, 999999999)}"
199+
tool_use: ToolUse = {
200+
"toolUseId": tool_id,
201+
"name": normalized_name,
202+
"input": kwargs.copy(),
203+
}
204+
tool_results: list[ToolResult] = []
205+
invocation_state = kwargs
206+
207+
# Stream events directly without recording to message history
208+
async for event in ToolExecutor._stream(self._agent, tool_use, tool_results, invocation_state):
209+
if isinstance(event, ToolInterruptEvent):
210+
self._agent._interrupt_state.deactivate()
211+
raise RuntimeError("cannot raise interrupt in direct tool call")
212+
yield event
147213

148214
def _find_normalized_tool_name(self, name: str) -> str:
149215
"""Lookup the tool represented by name, replacing characters with underscores as necessary."""
@@ -246,3 +312,39 @@ def _filter_tool_parameters_for_recording(self, tool_name: str, input_params: di
246312

247313
properties = tool_spec["inputSchema"]["json"]["properties"]
248314
return {k: v for k, v in input_params.items() if k in properties}
315+
316+
317+
class _ToolCaller:
318+
"""Call tool as a function."""
319+
320+
def __init__(self, agent: "Agent | BidiAgent") -> None:
321+
"""Initialize instance.
322+
323+
Args:
324+
agent: Agent reference that will accept tool results.
325+
"""
326+
# WARNING: Do not add any other member variables or methods as this could result in a name conflict with
327+
# agent tools and thus break their execution.
328+
self._agent_ref = weakref.ref(agent)
329+
330+
@property
331+
def _agent(self) -> "Agent | BidiAgent":
332+
"""Return the agent, raising ReferenceError if it has been garbage collected."""
333+
agent = self._agent_ref()
334+
if agent is None:
335+
raise ReferenceError("Agent has been garbage collected")
336+
return agent
337+
338+
def __getattr__(self, name: str) -> _ToolExecutor:
339+
"""Return tool executor with streaming methods.
340+
341+
This method enables the tool calling interface by returning a callable
342+
object that provides both synchronous execution and streaming methods.
343+
344+
Args:
345+
name: Tool name.
346+
347+
Returns:
348+
Tool executor instance.
349+
"""
350+
return _ToolExecutor(self._agent, name)

0 commit comments

Comments
 (0)