Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 34 additions & 23 deletions src/google/adk/flows/llm_flows/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,22 +480,12 @@ async def _run_on_tool_error_callbacks(
invocation_context, function_call, tool_confirmation
)

tool_not_found_error = None
try:
tool = _get_tool(function_call, tools_dict)
except ValueError as tool_error:
tool = BaseTool(name=function_call.name, description='Tool not found')
error_response = await _run_on_tool_error_callbacks(
tool=tool,
tool_args=function_args,
tool_context=tool_context,
error=tool_error,
)
if error_response is not None:
return __build_response_event(
tool, error_response, tool_context, invocation_context
)
else:
raise tool_error
tool_not_found_error = tool_error

async def _run_with_trace():
nonlocal function_args
Expand All @@ -520,6 +510,21 @@ async def _run_with_trace():
if function_response:
break

# Step 2.5: If the tool was not found (hallucinated tool), handle the
# error after before_tool_callback has run to maintain balanced
# lifecycle callbacks (push/pop invariant for trace spans).
if tool_not_found_error is not None:
error_response = await _run_on_tool_error_callbacks(
tool=tool,
tool_args=function_args,
tool_context=tool_context,
error=tool_not_found_error,
)
if error_response is not None:
function_response = error_response
else:
raise tool_not_found_error

# Step 3: Otherwise, proceed calling the tool normally.
if function_response is None:
try:
Expand Down Expand Up @@ -711,21 +716,12 @@ async def _run_on_tool_error_callbacks(

tool_context = _create_tool_context(invocation_context, function_call)

tool_not_found_error = None
try:
tool = _get_tool(function_call, tools_dict)
except ValueError as tool_error:
tool = BaseTool(name=function_call.name, description='Tool not found')
error_response = await _run_on_tool_error_callbacks(
tool=tool,
tool_args=function_args,
tool_context=tool_context,
error=tool_error,
)
if error_response is not None:
return __build_response_event(
tool, error_response, tool_context, invocation_context
)
raise tool_error
tool_not_found_error = tool_error

async def _run_with_trace():
nonlocal function_args
Expand All @@ -743,6 +739,21 @@ async def _run_with_trace():
)
)

# If the tool was not found (hallucinated tool), handle the error after
# before_tool_callback has run to maintain balanced lifecycle callbacks
# (push/pop invariant for trace spans).
if tool_not_found_error is not None:
error_response = await _run_on_tool_error_callbacks(
tool=tool,
tool_args=function_args,
tool_context=tool_context,
error=tool_not_found_error,
)
if error_response is not None:
function_response = error_response
else:
raise tool_not_found_error

# Step 2: If no overrides are provided from the plugins, further run the
# canonical callback.
if function_response is None:
Expand Down
56 changes: 56 additions & 0 deletions tests/unittests/flows/llm_flows/test_tool_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,3 +432,59 @@ async def async_on_tool_error_callback(
),
('root_agent', 'response1'),
]


def test_hallucinated_tool_calls_before_tool_callback():
  """Check that a call to an unregistered tool still triggers before_tool.

  The mock model requests 'hallucinated_tool' while the agent only registers
  `simple_function`. Both the before-tool hook and the tool-error hook are
  expected to fire, in that order, so that plugins depending on paired
  push/pop style lifecycle callbacks (e.g., a TraceManager span stack) stay
  balanced. See https://github.com/google/adk-python/issues/4775.
  """
  # Each hook appends a (hook_label, tool_name) tuple here when invoked.
  observed_calls = []

  # NOTE: parameter names of the two hooks are left untouched because the
  # framework may pass them by keyword.
  def record_before_tool(
      tool: BaseTool,
      args: dict[str, Any],
      tool_context: ToolContext,
  ):
    observed_calls.append(('before_tool', tool.name))
    # No override: let the flow continue past the before-tool stage.
    return None

  def record_tool_error(
      tool: BaseTool,
      args: dict[str, Any],
      tool_context: ToolContext,
      error: Exception,
  ):
    observed_calls.append(('on_tool_error', tool.name))
    # A non-None result is used as the function response instead of raising.
    return {'error': str(error)}

  # First model turn asks for a tool the agent does not have; the second
  # turn is a plain text reply.
  bogus_call = types.Part.from_function_call(
      name='hallucinated_tool',
      args={'input_str': 'test'},
  )
  model = testing_utils.MockModel.create(responses=[bogus_call, 'response1'])

  root_agent = Agent(
      name='root_agent',
      model=model,
      before_tool_callback=record_before_tool,
      on_tool_error_callback=record_tool_error,
      tools=[simple_function],
  )

  in_memory_runner = testing_utils.InMemoryRunner(root_agent)
  simplified = testing_utils.simplify_events(in_memory_runner.run('test'))

  # Exactly two hook invocations, before_tool first and on_tool_error second,
  # both reporting the hallucinated tool's name.
  assert observed_calls == [
      ('before_tool', 'hallucinated_tool'),
      ('on_tool_error', 'hallucinated_tool'),
  ]

  # The run must still end with the model's follow-up text response.
  assert simplified[-1] == ('root_agent', 'response1')