diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py index d2e1d61032..a4e32fb473 100644 --- a/src/google/adk/flows/llm_flows/functions.py +++ b/src/google/adk/flows/llm_flows/functions.py @@ -480,22 +480,12 @@ async def _run_on_tool_error_callbacks( invocation_context, function_call, tool_confirmation ) + tool_not_found_error = None try: tool = _get_tool(function_call, tools_dict) except ValueError as tool_error: tool = BaseTool(name=function_call.name, description='Tool not found') - error_response = await _run_on_tool_error_callbacks( - tool=tool, - tool_args=function_args, - tool_context=tool_context, - error=tool_error, - ) - if error_response is not None: - return __build_response_event( - tool, error_response, tool_context, invocation_context - ) - else: - raise tool_error + tool_not_found_error = tool_error async def _run_with_trace(): nonlocal function_args @@ -520,6 +510,21 @@ async def _run_with_trace(): if function_response: break + # Step 2.5: If the tool was not found (hallucinated tool), handle the + # error after before_tool_callback has run to maintain balanced + # lifecycle callbacks (push/pop invariant for trace spans). + if tool_not_found_error is not None: + error_response = await _run_on_tool_error_callbacks( + tool=tool, + tool_args=function_args, + tool_context=tool_context, + error=tool_not_found_error, + ) + if error_response is not None: + function_response = error_response + else: + raise tool_not_found_error + # Step 3: Otherwise, proceed calling the tool normally. if function_response is None: try: @@ -711,21 +716,12 @@ async def _run_on_tool_error_callbacks( tool_context = _create_tool_context(invocation_context, function_call) + tool_not_found_error = None try: tool = _get_tool(function_call, tools_dict) except ValueError as tool_error: tool = BaseTool(name=function_call.name, description='Tool not found') - error_response = await _run_on_tool_error_callbacks( - tool=tool, - tool_args=function_args, - tool_context=tool_context, - error=tool_error, - ) - if error_response is not None: - return __build_response_event( - tool, error_response, tool_context, invocation_context - ) - raise tool_error + tool_not_found_error = tool_error async def _run_with_trace(): nonlocal function_args @@ -743,6 +739,21 @@ async def _run_with_trace(): ) ) + # If the tool was not found (hallucinated tool), handle the error after + # before_tool_callback has run to maintain balanced lifecycle callbacks + # (push/pop invariant for trace spans). + if tool_not_found_error is not None: + error_response = await _run_on_tool_error_callbacks( + tool=tool, + tool_args=function_args, + tool_context=tool_context, + error=tool_not_found_error, + ) + if error_response is not None: + function_response = error_response + else: + raise tool_not_found_error + # Step 2: If no overrides are provided from the plugins, further run the # canonical callback. if function_response is None: diff --git a/tests/unittests/flows/llm_flows/test_tool_callbacks.py b/tests/unittests/flows/llm_flows/test_tool_callbacks.py index 695cef192f..c0a09a4609 100644 --- a/tests/unittests/flows/llm_flows/test_tool_callbacks.py +++ b/tests/unittests/flows/llm_flows/test_tool_callbacks.py @@ -432,3 +432,59 @@ async def async_on_tool_error_callback( ), ('root_agent', 'response1'), ] + + +def test_hallucinated_tool_calls_before_tool_callback(): + """Test that before_tool_callback is called even for hallucinated (non-existent) tools. + + This ensures balanced lifecycle callbacks (before/error) so that plugins + relying on push/pop semantics (e.g., TraceManager span stack) are not + corrupted. See https://github.com/google/adk-python/issues/4775. + """ + callback_order = [] + + def tracking_before_tool_callback( + tool: BaseTool, + args: dict[str, Any], + tool_context: ToolContext, + ): + callback_order.append(('before_tool', tool.name)) + return None + + def tracking_on_tool_error_callback( + tool: BaseTool, + args: dict[str, Any], + tool_context: ToolContext, + error: Exception, + ): + callback_order.append(('on_tool_error', tool.name)) + return {'error': str(error)} + + responses = [ + types.Part.from_function_call( + name='hallucinated_tool', + args={'input_str': 'test'}, + ), + 'response1', + ] + mock_model = testing_utils.MockModel.create(responses=responses) + agent = Agent( + name='root_agent', + model=mock_model, + before_tool_callback=tracking_before_tool_callback, + on_tool_error_callback=tracking_on_tool_error_callback, + tools=[simple_function], + ) + + runner = testing_utils.InMemoryRunner(agent) + events = testing_utils.simplify_events(runner.run('test')) + + # Verify the callback order: before_tool must be called before on_tool_error + assert len(callback_order) == 2 + assert callback_order[0][0] == 'before_tool' + assert callback_order[0][1] == 'hallucinated_tool' + assert callback_order[1][0] == 'on_tool_error' + assert callback_order[1][1] == 'hallucinated_tool' + + # Verify the response event is still produced + assert events[-1] == ('root_agent', 'response1')