Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 40 additions & 28 deletions src/google/adk/flows/llm_flows/base_llm_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

from google.adk.platform import time as platform_time
from google.genai import types
from opentelemetry import trace
from websockets.exceptions import ConnectionClosed
from websockets.exceptions import ConnectionClosedOK

Expand Down Expand Up @@ -1102,28 +1103,47 @@ async def _call_llm_async(
llm_request: LlmRequest,
model_response_event: Event,
) -> AsyncGenerator[LlmResponse, None]:
# Runs before_model_callback if it exists.
if response := await self._handle_before_model_callback(
invocation_context, llm_request, model_response_event
):
yield response
return

llm_request.config = llm_request.config or types.GenerateContentConfig()
llm_request.config.labels = llm_request.config.labels or {}
async def _call_llm_with_tracing() -> AsyncGenerator[LlmResponse, None]:
with tracer.start_as_current_span('call_llm') as span:
# Runs before_model_callback if it exists.
# This must be inside the call_llm span so that before_model_callback
# and after_model_callback/on_model_error_callback all share the same
# span context (fixes issue #4851).
if response := await self._handle_before_model_callback(
invocation_context, llm_request, model_response_event
):
yield response
return

llm_request.config = llm_request.config or types.GenerateContentConfig()
llm_request.config.labels = llm_request.config.labels or {}

# Add agent name as a label to the llm_request. This will help with
# slicing the billing reports on a per-agent basis.
if _ADK_AGENT_NAME_LABEL_KEY not in llm_request.config.labels:
llm_request.config.labels[_ADK_AGENT_NAME_LABEL_KEY] = (
invocation_context.agent.name
)

# Add agent name as a label to the llm_request. This will help with slicing
# the billing reports on a per-agent basis.
if _ADK_AGENT_NAME_LABEL_KEY not in llm_request.config.labels:
llm_request.config.labels[_ADK_AGENT_NAME_LABEL_KEY] = (
invocation_context.agent.name
)
# Calls the LLM.
llm = self.__get_llm(invocation_context)

# Calls the LLM.
llm = self.__get_llm(invocation_context)
async def _apply_after_model_callback(
response: LlmResponse,
) -> LlmResponse:
"""Applies after_model_callback within the call_llm span context.

Re-activates the call_llm span so after_model_callback sees the
same span_id as before_model_callback (issue #4851).
"""
with trace.use_span(span):
if altered := await self._handle_after_model_callback(
invocation_context, response, model_response_event
):
return altered
return response

async def _call_llm_with_tracing() -> AsyncGenerator[LlmResponse, None]:
with tracer.start_as_current_span('call_llm') as span:
if invocation_context.run_config.support_cfc:
invocation_context.live_request_queue = LiveRequestQueue()
responses_generator = self.run_live(invocation_context)
Expand All @@ -1136,11 +1156,7 @@ async def _call_llm_with_tracing() -> AsyncGenerator[LlmResponse, None]:
)
) as agen:
async for llm_response in agen:
# Runs after_model_callback if it exists.
if altered_llm_response := await self._handle_after_model_callback(
invocation_context, llm_response, model_response_event
):
llm_response = altered_llm_response
llm_response = await _apply_after_model_callback(llm_response)
# only yield partial response in SSE streaming mode
if (
invocation_context.run_config.streaming_mode
Expand Down Expand Up @@ -1176,11 +1192,7 @@ async def _call_llm_with_tracing() -> AsyncGenerator[LlmResponse, None]:
llm_response,
span,
)
# Runs after_model_callback if it exists.
if altered_llm_response := await self._handle_after_model_callback(
invocation_context, llm_response, model_response_event
):
llm_response = altered_llm_response
llm_response = await _apply_after_model_callback(llm_response)

yield llm_response

Expand Down
232 changes: 232 additions & 0 deletions tests/unittests/flows/llm_flows/test_llm_callback_span_consistency.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,232 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests that LLM callbacks share the same OTel span context (issue #4851).

When OpenTelemetry tracing is enabled, before_model_callback,
after_model_callback, and on_model_error_callback must all execute within
the same call_llm span so that plugins (e.g. BigQueryAgentAnalyticsPlugin)
see a consistent span_id for LLM_REQUEST and LLM_RESPONSE events.
"""

from typing import Optional
from unittest import mock

from google.adk.agents.callback_context import CallbackContext
from google.adk.agents.llm_agent import Agent
from google.adk.flows.llm_flows import base_llm_flow
from google.adk.models.llm_request import LlmRequest
from google.adk.models.llm_response import LlmResponse
from google.adk.plugins.base_plugin import BasePlugin
from google.adk.telemetry import tracing as adk_tracing
from google.genai import types
from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
import pytest

from ... import testing_utils


def _make_real_tracer():
  """Return a tracer backed by a real TracerProvider so span IDs are valid."""
  return TracerProvider().get_tracer('test_tracer')


class SpanCapturingPlugin(BasePlugin):
  """Plugin that records the active OTel span ID inside each model callback."""

  def __init__(self):
    super().__init__(name='span_capturing_plugin')
    # Span IDs observed inside the respective callbacks; None until captured.
    self.before_model_span_id: Optional[int] = None
    self.after_model_span_id: Optional[int] = None
    self.on_model_error_span_id: Optional[int] = None

  @staticmethod
  def _active_span_id() -> Optional[int]:
    """Return the current span's non-zero ID, or None when no span is active."""
    span_context = trace.get_current_span().get_span_context()
    if span_context and span_context.span_id:
      return span_context.span_id
    return None

  async def before_model_callback(
      self,
      *,
      callback_context: CallbackContext,
      llm_request: LlmRequest,
  ) -> Optional[LlmResponse]:
    captured = self._active_span_id()
    if captured is not None:
      self.before_model_span_id = captured
    return None

  async def after_model_callback(
      self,
      *,
      callback_context: CallbackContext,
      llm_response: LlmResponse,
  ) -> Optional[LlmResponse]:
    captured = self._active_span_id()
    if captured is not None:
      self.after_model_span_id = captured
    return None

  async def on_model_error_callback(
      self,
      *,
      callback_context: CallbackContext,
      llm_request: LlmRequest,
      error: Exception,
  ) -> Optional[LlmResponse]:
    captured = self._active_span_id()
    if captured is not None:
      self.on_model_error_span_id = captured
    # Recover from the model error with a canned response.
    return LlmResponse(
        content=testing_utils.ModelContent(
            [types.Part.from_text(text='error handled')]
        )
    )


@pytest.mark.asyncio
async def test_before_and_after_model_callbacks_share_span_id():
  """Verify before_model_callback and after_model_callback share the same span.

  This is the core regression test for issue #4851. Before the fix,
  before_model_callback ran outside the call_llm span, causing a span_id
  mismatch between LLM_REQUEST and LLM_RESPONSE events.
  """
  plugin = SpanCapturingPlugin()
  real_tracer = _make_real_tracer()

  mock_model = testing_utils.MockModel.create(responses=['model_response'])
  agent = Agent(
      name='test_agent',
      model=mock_model,
  )

  with (
      mock.patch.object(base_llm_flow, 'tracer', real_tracer),
      mock.patch.object(adk_tracing, 'tracer', real_tracer),
  ):
    runner = testing_utils.TestInMemoryRunner(agent, plugins=[plugin])
    # The returned events are not needed here; only the span IDs captured by
    # the plugin matter for this assertion.
    await runner.run_async_with_new_session('test')

  # Both callbacks should have captured a span ID
  assert (
      plugin.before_model_span_id is not None
  ), 'before_model_callback did not capture a span ID'
  assert (
      plugin.after_model_span_id is not None
  ), 'after_model_callback did not capture a span ID'

  # The span IDs must match — this is the core assertion for issue #4851
  assert plugin.before_model_span_id == plugin.after_model_span_id, (
      'Span ID mismatch: before_model_callback span_id='
      f'{plugin.before_model_span_id:#018x}, '
      f'after_model_callback span_id={plugin.after_model_span_id:#018x}. '
      'Both callbacks must run inside the same call_llm span.'
  )


@pytest.mark.asyncio
async def test_before_and_on_error_model_callbacks_share_span_id():
  """Verify before_model_callback and on_model_error_callback share span.

  When the model raises an error, on_model_error_callback should see the
  same span as before_model_callback.
  """
  plugin = SpanCapturingPlugin()
  real_tracer = _make_real_tracer()

  mock_model = testing_utils.MockModel.create(
      responses=[], error=SystemError('model error')
  )
  agent = Agent(
      name='test_agent',
      model=mock_model,
  )

  with (
      mock.patch.object(base_llm_flow, 'tracer', real_tracer),
      mock.patch.object(adk_tracing, 'tracer', real_tracer),
  ):
    runner = testing_utils.TestInMemoryRunner(agent, plugins=[plugin])
    # The returned events are not needed here; only the span IDs captured by
    # the plugin matter for this assertion.
    await runner.run_async_with_new_session('test')

  # Both callbacks should have captured a span ID
  assert (
      plugin.before_model_span_id is not None
  ), 'before_model_callback did not capture a span ID'
  assert (
      plugin.on_model_error_span_id is not None
  ), 'on_model_error_callback did not capture a span ID'

  # The span IDs must match
  assert plugin.before_model_span_id == plugin.on_model_error_span_id, (
      'Span ID mismatch: before_model_callback span_id='
      f'{plugin.before_model_span_id:#018x}, '
      'on_model_error_callback span_id='
      f'{plugin.on_model_error_span_id:#018x}. '
      'Both callbacks must run inside the same call_llm span.'
  )


@pytest.mark.asyncio
async def test_before_model_callback_short_circuit_has_span():
  """Verify before_model_callback has a valid span when short-circuiting."""

  class ShortCircuitPlugin(BasePlugin):
    """Plugin whose before_model_callback short-circuits the LLM call."""

    def __init__(self):
      super().__init__(name='short_circuit_plugin')
      # Span ID observed inside the callback; None until captured.
      self.span_id: Optional[int] = None

    async def before_model_callback(
        self,
        *,
        callback_context: CallbackContext,
        llm_request: LlmRequest,
    ) -> Optional[LlmResponse]:
      # Record the active span ID, then short-circuit the model call by
      # returning a canned response instead of None.
      span_context = trace.get_current_span().get_span_context()
      if span_context and span_context.span_id:
        self.span_id = span_context.span_id
      return LlmResponse(
          content=testing_utils.ModelContent(
              [types.Part.from_text(text='short-circuited')]
          )
      )

  plugin = ShortCircuitPlugin()
  real_tracer = _make_real_tracer()

  agent = Agent(
      name='test_agent',
      model=testing_utils.MockModel.create(responses=['model_response']),
  )

  with (
      mock.patch.object(base_llm_flow, 'tracer', real_tracer),
      mock.patch.object(adk_tracing, 'tracer', real_tracer),
  ):
    runner = testing_utils.TestInMemoryRunner(agent, plugins=[plugin])
    events = await runner.run_async_with_new_session('test')

  # The callback should have a valid (non-zero) span ID from the call_llm span
  assert plugin.span_id is not None and plugin.span_id != 0, (
      'before_model_callback should have a valid span ID even when '
      'short-circuiting the LLM call'
  )

  # Verify the short-circuit response was received
  simplified = testing_utils.simplify_events(events)
  assert any('short-circuited' in str(e) for e in simplified)