Skip to content

Commit 9f2af36

Browse files
committed
fix: ensure balanced tool lifecycle callbacks for hallucinated tools
When the LLM hallucinates a non-existent tool name, the ValueError handler was calling on_tool_error_callback directly without first invoking before_tool_callback or entering the tracer span context. This broke the push/pop invariant that plugins (e.g., BigQueryAgentAnalyticsPlugin) rely on for TraceManager span stack management, causing stack corruption.

Move the hallucinated-tool error handling inside the traced lifecycle path (_run_with_trace) so that:

1. The tracer span context (start_as_current_span) is entered first.
2. before_tool_callback runs before on_tool_error_callback.
3. after_tool_callback / trace_tool_call run in the finally block.

Applied to both _execute_single_function_call_async and _execute_single_function_call_live.

Fixes #4775
1 parent b8e7647 commit 9f2af36

2 files changed

Lines changed: 90 additions & 23 deletions

File tree

src/google/adk/flows/llm_flows/functions.py

Lines changed: 34 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -480,22 +480,12 @@ async def _run_on_tool_error_callbacks(
480480
invocation_context, function_call, tool_confirmation
481481
)
482482

483+
tool_not_found_error = None
483484
try:
484485
tool = _get_tool(function_call, tools_dict)
485486
except ValueError as tool_error:
486487
tool = BaseTool(name=function_call.name, description='Tool not found')
487-
error_response = await _run_on_tool_error_callbacks(
488-
tool=tool,
489-
tool_args=function_args,
490-
tool_context=tool_context,
491-
error=tool_error,
492-
)
493-
if error_response is not None:
494-
return __build_response_event(
495-
tool, error_response, tool_context, invocation_context
496-
)
497-
else:
498-
raise tool_error
488+
tool_not_found_error = tool_error
499489

500490
async def _run_with_trace():
501491
nonlocal function_args
@@ -520,6 +510,21 @@ async def _run_with_trace():
520510
if function_response:
521511
break
522512

513+
# Step 2.5: If the tool was not found (hallucinated tool), handle the
514+
# error after before_tool_callback has run to maintain balanced
515+
# lifecycle callbacks (push/pop invariant for trace spans).
516+
if tool_not_found_error is not None:
517+
error_response = await _run_on_tool_error_callbacks(
518+
tool=tool,
519+
tool_args=function_args,
520+
tool_context=tool_context,
521+
error=tool_not_found_error,
522+
)
523+
if error_response is not None:
524+
function_response = error_response
525+
else:
526+
raise tool_not_found_error
527+
523528
# Step 3: Otherwise, proceed calling the tool normally.
524529
if function_response is None:
525530
try:
@@ -711,21 +716,12 @@ async def _run_on_tool_error_callbacks(
711716

712717
tool_context = _create_tool_context(invocation_context, function_call)
713718

719+
tool_not_found_error = None
714720
try:
715721
tool = _get_tool(function_call, tools_dict)
716722
except ValueError as tool_error:
717723
tool = BaseTool(name=function_call.name, description='Tool not found')
718-
error_response = await _run_on_tool_error_callbacks(
719-
tool=tool,
720-
tool_args=function_args,
721-
tool_context=tool_context,
722-
error=tool_error,
723-
)
724-
if error_response is not None:
725-
return __build_response_event(
726-
tool, error_response, tool_context, invocation_context
727-
)
728-
raise tool_error
724+
tool_not_found_error = tool_error
729725

730726
async def _run_with_trace():
731727
nonlocal function_args
@@ -743,6 +739,21 @@ async def _run_with_trace():
743739
)
744740
)
745741

742+
# If the tool was not found (hallucinated tool), handle the error after
743+
# before_tool_callback has run to maintain balanced lifecycle callbacks
744+
# (push/pop invariant for trace spans).
745+
if tool_not_found_error is not None:
746+
error_response = await _run_on_tool_error_callbacks(
747+
tool=tool,
748+
tool_args=function_args,
749+
tool_context=tool_context,
750+
error=tool_not_found_error,
751+
)
752+
if error_response is not None:
753+
function_response = error_response
754+
else:
755+
raise tool_not_found_error
756+
746757
# Step 2: If no overrides are provided from the plugins, further run the
747758
# canonical callback.
748759
if function_response is None:

tests/unittests/flows/llm_flows/test_tool_callbacks.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -432,3 +432,59 @@ async def async_on_tool_error_callback(
432432
),
433433
('root_agent', 'response1'),
434434
]
435+
436+
437+
def test_hallucinated_tool_calls_before_tool_callback():
438+
"""Test that before_tool_callback is called even for hallucinated (non-existent) tools.
439+
440+
This ensures balanced lifecycle callbacks (before/error) so that plugins
441+
relying on push/pop semantics (e.g., TraceManager span stack) are not
442+
corrupted. See https://github.com/google/adk-python/issues/4775.
443+
"""
444+
callback_order = []
445+
446+
def tracking_before_tool_callback(
447+
tool: BaseTool,
448+
args: dict[str, Any],
449+
tool_context: ToolContext,
450+
):
451+
callback_order.append(('before_tool', tool.name))
452+
return None
453+
454+
def tracking_on_tool_error_callback(
455+
tool: BaseTool,
456+
args: dict[str, Any],
457+
tool_context: ToolContext,
458+
error: Exception,
459+
):
460+
callback_order.append(('on_tool_error', tool.name))
461+
return {'error': str(error)}
462+
463+
responses = [
464+
types.Part.from_function_call(
465+
name='hallucinated_tool',
466+
args={'input_str': 'test'},
467+
),
468+
'response1',
469+
]
470+
mock_model = testing_utils.MockModel.create(responses=responses)
471+
agent = Agent(
472+
name='root_agent',
473+
model=mock_model,
474+
before_tool_callback=tracking_before_tool_callback,
475+
on_tool_error_callback=tracking_on_tool_error_callback,
476+
tools=[simple_function],
477+
)
478+
479+
runner = testing_utils.InMemoryRunner(agent)
480+
events = testing_utils.simplify_events(runner.run('test'))
481+
482+
# Verify the callback order: before_tool must be called before on_tool_error
483+
assert len(callback_order) == 2
484+
assert callback_order[0][0] == 'before_tool'
485+
assert callback_order[0][1] == 'hallucinated_tool'
486+
assert callback_order[1][0] == 'on_tool_error'
487+
assert callback_order[1][1] == 'hallucinated_tool'
488+
489+
# Verify the response event is still produced
490+
assert events[-1] == ('root_agent', 'response1')

0 commit comments

Comments
 (0)