From 0ed898af3e1440ec3b83c9e745c2fa3e5212aab6 Mon Sep 17 00:00:00 2001 From: brucearctor <5032356+brucearctor@users.noreply.github.com> Date: Mon, 16 Mar 2026 18:31:18 -0700 Subject: [PATCH 1/4] fix: ensure LLM callbacks share the same OTel span context (#4851) Move before_model_callback inside the call_llm span and wrap after_model_callback with trace.use_span(span) to re-activate the call_llm span context. This ensures before_model_callback, after_model_callback, and on_model_error_callback all see the same span_id, fixing the mismatch that broke the BigQuery Analytics Plugin. The root cause was twofold: 1. before_model_callback ran outside the call_llm span 2. after_model_callback ran inside a child generate_content span (created by _run_and_handle_error via use_inference_span) Fixes #4851 --- .../adk/flows/llm_flows/base_llm_flow.py | 65 +++-- .../test_llm_callback_span_consistency.py | 226 ++++++++++++++++++ 2 files changed, 265 insertions(+), 26 deletions(-) create mode 100644 tests/unittests/flows/llm_flows/test_llm_callback_span_consistency.py diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index bd0037bdcb..47ec9bfc61 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -46,6 +46,7 @@ from ...telemetry.tracing import trace_call_llm from ...telemetry.tracing import trace_send_data from ...telemetry.tracing import tracer +from opentelemetry import trace from ...tools.base_toolset import BaseToolset from ...tools.tool_context import ToolContext from ...utils.context_utils import Aclosing @@ -1102,28 +1103,34 @@ async def _call_llm_async( llm_request: LlmRequest, model_response_event: Event, ) -> AsyncGenerator[LlmResponse, None]: - # Runs before_model_callback if it exists. - if response := await self._handle_before_model_callback( - invocation_context, llm_request, model_response_event - ): - yield response - return - llm_request.config = llm_request.config or types.GenerateContentConfig() - llm_request.config.labels = llm_request.config.labels or {} + async def _call_llm_with_tracing() -> AsyncGenerator[LlmResponse, None]: + with tracer.start_as_current_span('call_llm') as span: + # Runs before_model_callback if it exists. + # This must be inside the call_llm span so that before_model_callback + # and after_model_callback/on_model_error_callback all share the same + # span context (fixes issue #4851). + if response := await self._handle_before_model_callback( + invocation_context, llm_request, model_response_event + ): + yield response + return - # Add agent name as a label to the llm_request. This will help with slicing - # the billing reports on a per-agent basis. - if _ADK_AGENT_NAME_LABEL_KEY not in llm_request.config.labels: - llm_request.config.labels[_ADK_AGENT_NAME_LABEL_KEY] = ( - invocation_context.agent.name - ) + llm_request.config = ( + llm_request.config or types.GenerateContentConfig() + ) + llm_request.config.labels = llm_request.config.labels or {} - # Calls the LLM. - llm = self.__get_llm(invocation_context) + # Add agent name as a label to the llm_request. This will help with + # slicing the billing reports on a per-agent basis. + if _ADK_AGENT_NAME_LABEL_KEY not in llm_request.config.labels: + llm_request.config.labels[_ADK_AGENT_NAME_LABEL_KEY] = ( + invocation_context.agent.name + ) + + # Calls the LLM. + llm = self.__get_llm(invocation_context) - async def _call_llm_with_tracing() -> AsyncGenerator[LlmResponse, None]: - with tracer.start_as_current_span('call_llm') as span: if invocation_context.run_config.support_cfc: invocation_context.live_request_queue = LiveRequestQueue() responses_generator = self.run_live(invocation_context) @@ -1137,10 +1144,13 @@ async def _call_llm_with_tracing() -> AsyncGenerator[LlmResponse, None]: ) as agen: async for llm_response in agen: # Runs after_model_callback if it exists. - if altered_llm_response := await self._handle_after_model_callback( - invocation_context, llm_response, model_response_event - ): - llm_response = altered_llm_response + # Re-activate the call_llm span so after_model_callback sees + # the same span_id as before_model_callback (issue #4851). + with trace.use_span(span): + if altered_llm_response := await self._handle_after_model_callback( + invocation_context, llm_response, model_response_event + ): + llm_response = altered_llm_response # only yield partial response in SSE streaming mode if ( invocation_context.run_config.streaming_mode @@ -1177,10 +1187,13 @@ async def _call_llm_with_tracing() -> AsyncGenerator[LlmResponse, None]: span, ) # Runs after_model_callback if it exists. - if altered_llm_response := await self._handle_after_model_callback( - invocation_context, llm_response, model_response_event - ): - llm_response = altered_llm_response + # Re-activate the call_llm span so after_model_callback sees + # the same span_id as before_model_callback (issue #4851). + with trace.use_span(span): + if altered_llm_response := await self._handle_after_model_callback( + invocation_context, llm_response, model_response_event + ): + llm_response = altered_llm_response yield llm_response diff --git a/tests/unittests/flows/llm_flows/test_llm_callback_span_consistency.py b/tests/unittests/flows/llm_flows/test_llm_callback_span_consistency.py new file mode 100644 index 0000000000..d4109dfed0 --- /dev/null +++ b/tests/unittests/flows/llm_flows/test_llm_callback_span_consistency.py @@ -0,0 +1,226 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests that LLM callbacks share the same OTel span context (issue #4851). + +When OpenTelemetry tracing is enabled, before_model_callback, +after_model_callback, and on_model_error_callback must all execute within +the same call_llm span so that plugins (e.g. BigQueryAgentAnalyticsPlugin) +see a consistent span_id for LLM_REQUEST and LLM_RESPONSE events. +""" + +from typing import Optional +from unittest import mock + +from google.adk.agents.callback_context import CallbackContext +from google.adk.agents.llm_agent import Agent +from google.adk.flows.llm_flows import base_llm_flow +from google.adk.models.llm_request import LlmRequest +from google.adk.models.llm_response import LlmResponse +from google.adk.plugins.base_plugin import BasePlugin +from google.adk.telemetry import tracing as adk_tracing +from google.genai import types +from opentelemetry import trace +from opentelemetry.sdk.trace import TracerProvider +import pytest + +from ... import testing_utils + + +def _make_real_tracer(): + """Create a real tracer that produces valid span IDs.""" + provider = TracerProvider() + return provider.get_tracer('test_tracer') + + +class SpanCapturingPlugin(BasePlugin): + """Plugin that captures the current span ID in each model callback.""" + + def __init__(self): + super().__init__(name='span_capturing_plugin') + self.before_model_span_id: Optional[int] = None + self.after_model_span_id: Optional[int] = None + self.on_model_error_span_id: Optional[int] = None + + async def before_model_callback( + self, + *, + callback_context: CallbackContext, + llm_request: LlmRequest, + ) -> Optional[LlmResponse]: + span = trace.get_current_span() + ctx = span.get_span_context() + if ctx and ctx.span_id: + self.before_model_span_id = ctx.span_id + return None + + async def after_model_callback( + self, + *, + callback_context: CallbackContext, + llm_response: LlmResponse, + ) -> Optional[LlmResponse]: + span = trace.get_current_span() + ctx = span.get_span_context() + if ctx and ctx.span_id: + self.after_model_span_id = ctx.span_id + return None + + async def on_model_error_callback( + self, + *, + callback_context: CallbackContext, + llm_request: LlmRequest, + error: Exception, + ) -> Optional[LlmResponse]: + span = trace.get_current_span() + ctx = span.get_span_context() + if ctx and ctx.span_id: + self.on_model_error_span_id = ctx.span_id + return LlmResponse( + content=testing_utils.ModelContent( + [types.Part.from_text(text='error handled')] + ) + ) + + +@pytest.mark.asyncio +async def test_before_and_after_model_callbacks_share_span_id(): + """Verify before_model_callback and after_model_callback share the same span. + + This is the core regression test for issue #4851. Before the fix, + before_model_callback ran outside the call_llm span, causing a span_id + mismatch between LLM_REQUEST and LLM_RESPONSE events. + """ + plugin = SpanCapturingPlugin() + real_tracer = _make_real_tracer() + + mock_model = testing_utils.MockModel.create(responses=['model_response']) + agent = Agent( + name='test_agent', + model=mock_model, + ) + + with mock.patch.object(base_llm_flow, 'tracer', real_tracer), \ + mock.patch.object(adk_tracing, 'tracer', real_tracer): + runner = testing_utils.TestInMemoryRunner(agent, plugins=[plugin]) + events = await runner.run_async_with_new_session('test') + + # Both callbacks should have captured a span ID + assert plugin.before_model_span_id is not None, ( + 'before_model_callback did not capture a span ID' + ) + assert plugin.after_model_span_id is not None, ( + 'after_model_callback did not capture a span ID' + ) + + # The span IDs must match — this is the core assertion for issue #4851 + assert plugin.before_model_span_id == plugin.after_model_span_id, ( + f'Span ID mismatch: before_model_callback span_id=' + f'{plugin.before_model_span_id:#018x}, ' + f'after_model_callback span_id={plugin.after_model_span_id:#018x}. ' + f'Both callbacks must run inside the same call_llm span.' + ) + + +@pytest.mark.asyncio +async def test_before_and_on_error_model_callbacks_share_span_id(): + """Verify before_model_callback and on_model_error_callback share span. + + When the model raises an error, on_model_error_callback should see the + same span as before_model_callback. + """ + plugin = SpanCapturingPlugin() + real_tracer = _make_real_tracer() + + mock_model = testing_utils.MockModel.create( + responses=[], error=SystemError('model error') + ) + agent = Agent( + name='test_agent', + model=mock_model, + ) + + with mock.patch.object(base_llm_flow, 'tracer', real_tracer), \ + mock.patch.object(adk_tracing, 'tracer', real_tracer): + runner = testing_utils.TestInMemoryRunner(agent, plugins=[plugin]) + events = await runner.run_async_with_new_session('test') + + # Both callbacks should have captured a span ID + assert plugin.before_model_span_id is not None, ( + 'before_model_callback did not capture a span ID' + ) + assert plugin.on_model_error_span_id is not None, ( + 'on_model_error_callback did not capture a span ID' + ) + + # The span IDs must match + assert plugin.before_model_span_id == plugin.on_model_error_span_id, ( + f'Span ID mismatch: before_model_callback span_id=' + f'{plugin.before_model_span_id:#018x}, ' + f'on_model_error_callback span_id=' + f'{plugin.on_model_error_span_id:#018x}. ' + f'Both callbacks must run inside the same call_llm span.' + ) + + +@pytest.mark.asyncio +async def test_before_model_callback_short_circuit_has_span(): + """Verify before_model_callback has a valid span when short-circuiting.""" + + class ShortCircuitPlugin(BasePlugin): + + def __init__(self): + super().__init__(name='short_circuit_plugin') + self.span_id: Optional[int] = None + + async def before_model_callback( + self, + *, + callback_context: CallbackContext, + llm_request: LlmRequest, + ) -> Optional[LlmResponse]: + span = trace.get_current_span() + ctx = span.get_span_context() + if ctx and ctx.span_id: + self.span_id = ctx.span_id + return LlmResponse( + content=testing_utils.ModelContent( + [types.Part.from_text(text='short-circuited')] + ) + ) + + plugin = ShortCircuitPlugin() + real_tracer = _make_real_tracer() + + mock_model = testing_utils.MockModel.create(responses=['model_response']) + agent = Agent( + name='test_agent', + model=mock_model, + ) + + with mock.patch.object(base_llm_flow, 'tracer', real_tracer), \ + mock.patch.object(adk_tracing, 'tracer', real_tracer): + runner = testing_utils.TestInMemoryRunner(agent, plugins=[plugin]) + events = await runner.run_async_with_new_session('test') + + # The callback should have a valid (non-zero) span ID from the call_llm span + assert plugin.span_id is not None and plugin.span_id != 0, ( + 'before_model_callback should have a valid span ID even when ' + 'short-circuiting the LLM call' + ) + + # Verify the short-circuit response was received + simplified = testing_utils.simplify_events(events) + assert any('short-circuited' in str(e) for e in simplified) From 1c84fbea4281dde0320fc4c37f85826b4de2db19 Mon Sep 17 00:00:00 2001 From: brucearctor <5032356+brucearctor@users.noreply.github.com> Date: Tue, 17 Mar 2026 11:41:28 -0700 Subject: [PATCH 2/4] refactor: extract _apply_after_model_callback helper (review feedback) Extract the duplicated after_model_callback + trace.use_span(span) logic into a local _apply_after_model_callback coroutine for DRY. --- .../adk/flows/llm_flows/base_llm_flow.py | 33 ++++++++++--------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index 47ec9bfc61..b54f9a6eeb 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -1131,6 +1131,21 @@ async def _call_llm_with_tracing() -> AsyncGenerator[LlmResponse, None]: # Calls the LLM. llm = self.__get_llm(invocation_context) + async def _apply_after_model_callback( + response: LlmResponse, + ) -> LlmResponse: + """Applies after_model_callback within the call_llm span context. + + Re-activates the call_llm span so after_model_callback sees the + same span_id as before_model_callback (issue #4851). + """ + with trace.use_span(span): + if altered := await self._handle_after_model_callback( + invocation_context, response, model_response_event + ): + return altered + return response + if invocation_context.run_config.support_cfc: invocation_context.live_request_queue = LiveRequestQueue() responses_generator = self.run_live(invocation_context) @@ -1143,14 +1158,7 @@ async def _call_llm_with_tracing() -> AsyncGenerator[LlmResponse, None]: ) ) as agen: async for llm_response in agen: - # Runs after_model_callback if it exists. - # Re-activate the call_llm span so after_model_callback sees - # the same span_id as before_model_callback (issue #4851). - with trace.use_span(span): - if altered_llm_response := await self._handle_after_model_callback( - invocation_context, llm_response, model_response_event - ): - llm_response = altered_llm_response + llm_response = await _apply_after_model_callback(llm_response) # only yield partial response in SSE streaming mode if ( invocation_context.run_config.streaming_mode @@ -1186,14 +1194,7 @@ async def _call_llm_with_tracing() -> AsyncGenerator[LlmResponse, None]: llm_response, span, ) - # Runs after_model_callback if it exists. - # Re-activate the call_llm span so after_model_callback sees - # the same span_id as before_model_callback (issue #4851). - with trace.use_span(span): - if altered_llm_response := await self._handle_after_model_callback( - invocation_context, llm_response, model_response_event - ): - llm_response = altered_llm_response + llm_response = await _apply_after_model_callback(llm_response) yield llm_response From 3a009a2c44f210c6a89fbb78f7f0845e3ce93831 Mon Sep 17 00:00:00 2001 From: brucearctor <5032356+brucearctor@users.noreply.github.com> Date: Tue, 17 Mar 2026 12:15:27 -0700 Subject: [PATCH 3/4] style: move opentelemetry import to third-party section (PEP 8) Move 'from opentelemetry import trace' to the third-party imports group per PEP 8 import ordering convention. --- src/google/adk/flows/llm_flows/base_llm_flow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index b54f9a6eeb..55b019f47a 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -24,6 +24,7 @@ from google.adk.platform import time as platform_time from google.genai import types +from opentelemetry import trace from websockets.exceptions import ConnectionClosed from websockets.exceptions import ConnectionClosedOK @@ -46,7 +47,6 @@ from ...telemetry.tracing import trace_call_llm from ...telemetry.tracing import trace_send_data from ...telemetry.tracing import tracer -from opentelemetry import trace from ...tools.base_toolset import BaseToolset from ...tools.tool_context import ToolContext from ...utils.context_utils import Aclosing From 1b11d62418007d27dbdb031296d1d3b0393c0e04 Mon Sep 17 00:00:00 2001 From: brucearctor <5032356+brucearctor@users.noreply.github.com> Date: Tue, 17 Mar 2026 16:33:46 -0700 Subject: [PATCH 4/4] style: run autoformat.sh (pyink + isort) --- .../adk/flows/llm_flows/base_llm_flow.py | 4 +- .../test_llm_callback_span_consistency.py | 52 +++++++++++-------- 2 files changed, 30 insertions(+), 26 deletions(-) diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index 55b019f47a..be64ccabec 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -1116,9 +1116,7 @@ async def _call_llm_with_tracing() -> AsyncGenerator[LlmResponse, None]: yield response return - llm_request.config = ( - llm_request.config or types.GenerateContentConfig() - ) + llm_request.config = llm_request.config or types.GenerateContentConfig() llm_request.config.labels = llm_request.config.labels or {} # Add agent name as a label to the llm_request. This will help with diff --git a/tests/unittests/flows/llm_flows/test_llm_callback_span_consistency.py b/tests/unittests/flows/llm_flows/test_llm_callback_span_consistency.py index d4109dfed0..bcead2df3d 100644 --- a/tests/unittests/flows/llm_flows/test_llm_callback_span_consistency.py +++ b/tests/unittests/flows/llm_flows/test_llm_callback_span_consistency.py @@ -112,25 +112,27 @@ async def test_before_and_after_model_callbacks_share_span_id(): model=mock_model, ) - with mock.patch.object(base_llm_flow, 'tracer', real_tracer), \ - mock.patch.object(adk_tracing, 'tracer', real_tracer): + with ( + mock.patch.object(base_llm_flow, 'tracer', real_tracer), + mock.patch.object(adk_tracing, 'tracer', real_tracer), + ): runner = testing_utils.TestInMemoryRunner(agent, plugins=[plugin]) events = await runner.run_async_with_new_session('test') # Both callbacks should have captured a span ID - assert plugin.before_model_span_id is not None, ( - 'before_model_callback did not capture a span ID' - ) - assert plugin.after_model_span_id is not None, ( - 'after_model_callback did not capture a span ID' - ) + assert ( + plugin.before_model_span_id is not None + ), 'before_model_callback did not capture a span ID' + assert ( + plugin.after_model_span_id is not None + ), 'after_model_callback did not capture a span ID' # The span IDs must match — this is the core assertion for issue #4851 assert plugin.before_model_span_id == plugin.after_model_span_id, ( - f'Span ID mismatch: before_model_callback span_id=' + 'Span ID mismatch: before_model_callback span_id=' f'{plugin.before_model_span_id:#018x}, ' f'after_model_callback span_id={plugin.after_model_span_id:#018x}. ' - f'Both callbacks must run inside the same call_llm span.' + 'Both callbacks must run inside the same call_llm span.' ) @@ -152,26 +154,28 @@ async def test_before_and_on_error_model_callbacks_share_span_id(): model=mock_model, ) - with mock.patch.object(base_llm_flow, 'tracer', real_tracer), \ - mock.patch.object(adk_tracing, 'tracer', real_tracer): + with ( + mock.patch.object(base_llm_flow, 'tracer', real_tracer), + mock.patch.object(adk_tracing, 'tracer', real_tracer), + ): runner = testing_utils.TestInMemoryRunner(agent, plugins=[plugin]) events = await runner.run_async_with_new_session('test') # Both callbacks should have captured a span ID - assert plugin.before_model_span_id is not None, ( - 'before_model_callback did not capture a span ID' - ) - assert plugin.on_model_error_span_id is not None, ( - 'on_model_error_callback did not capture a span ID' - ) + assert ( + plugin.before_model_span_id is not None + ), 'before_model_callback did not capture a span ID' + assert ( + plugin.on_model_error_span_id is not None + ), 'on_model_error_callback did not capture a span ID' # The span IDs must match assert plugin.before_model_span_id == plugin.on_model_error_span_id, ( - f'Span ID mismatch: before_model_callback span_id=' + 'Span ID mismatch: before_model_callback span_id=' f'{plugin.before_model_span_id:#018x}, ' - f'on_model_error_callback span_id=' + 'on_model_error_callback span_id=' f'{plugin.on_model_error_span_id:#018x}. ' - f'Both callbacks must run inside the same call_llm span.' + 'Both callbacks must run inside the same call_llm span.' ) @@ -210,8 +214,10 @@ async def before_model_callback( model=mock_model, ) - with mock.patch.object(base_llm_flow, 'tracer', real_tracer), \ - mock.patch.object(adk_tracing, 'tracer', real_tracer): + with ( + mock.patch.object(base_llm_flow, 'tracer', real_tracer), + mock.patch.object(adk_tracing, 'tracer', real_tracer), + ): runner = testing_utils.TestInMemoryRunner(agent, plugins=[plugin]) events = await runner.run_async_with_new_session('test')