Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion go/cli/internal/tui/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,11 @@ func (m *chatModel) handleMessageParts(msg protocol.Message, shouldDisplay bool)
continue
}

kagentType, ok := dp.Metadata["kagent_type"].(string)
typeVal, found := getMetadataValue(dp.Metadata, "type")
if !found {
continue
}
kagentType, ok := typeVal.(string)
if !ok {
continue
}
Expand Down Expand Up @@ -497,6 +501,23 @@ func (m *chatModel) updateStatus() {
}
}

// getMetadataValue looks up an unprefixed key in A2A metadata, checking
// "adk_<key>" first then falling back to "kagent_<key>". This allows
// interoperability with upstream ADK (adk_ prefix) while preserving
// backward-compatibility with kagent's own kagent_ prefix.
func getMetadataValue(metadata map[string]any, key string) (any, bool) {
if metadata == nil {
return nil, false
}
if v, ok := metadata["adk_"+key]; ok {
return v, true
}
if v, ok := metadata["kagent_"+key]; ok {
return v, true
}
return nil, false
}

// getString safely extracts a string value from a map
func getString(m map[string]any, key string) string {
if val, ok := m[key]; ok {
Expand Down
2 changes: 1 addition & 1 deletion python/packages/kagent-adk/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ dependencies = [
"anthropic[vertex]>=0.49.0",
"fastapi>=0.115.1",
"litellm>=1.74.3",
"google-adk>=1.22.1",
"google-adk>=1.25.0",
"google-genai>=1.21.1",
"google-auth>=2.40.2",
"httpx>=0.25.0",
Expand Down
84 changes: 65 additions & 19 deletions python/packages/kagent-adk/src/kagent/adk/_agent_executor.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from __future__ import annotations

import asyncio
import inspect
import logging
import uuid
from datetime import datetime, timezone
from typing import Any, Awaitable, Callable, Optional

import asyncio
from a2a.server.agent_execution import AgentExecutor
from a2a.server.agent_execution.context import RequestContext
from a2a.server.events.event_queue import EventQueue
from a2a.types import (
Expand All @@ -21,10 +20,15 @@
TaskStatusUpdateEvent,
TextPart,
)
from google.adk.a2a.executor.a2a_agent_executor import (
A2aAgentExecutor as UpstreamA2aAgentExecutor,
)
from google.adk.a2a.executor.a2a_agent_executor import (
A2aAgentExecutorConfig as UpstreamA2aAgentExecutorConfig,
)
from google.adk.events import Event, EventActions
from google.adk.runners import Runner
from google.adk.utils.context_utils import Aclosing
from opentelemetry import trace
from pydantic import BaseModel
from typing_extensions import override

Expand All @@ -35,24 +39,56 @@
)

from .converters.event_converter import convert_event_to_a2a_events
from .converters.part_converter import convert_a2a_part_to_genai_part, convert_genai_part_to_a2a_part
from .converters.request_converter import convert_a2a_request_to_adk_run_args

logger = logging.getLogger("kagent_adk." + __name__)


class A2aAgentExecutorConfig(BaseModel):
"""Configuration for the A2aAgentExecutor."""
"""Configuration for the KAgent A2aAgentExecutor."""

stream: bool = False


# This class is a copy of the A2aAgentExecutor class in the ADK sdk,
# with the following changes:
# - The runner is ALWAYS a callable that returns a Runner instance
# - The runner is cleaned up at the end of the execution
class A2aAgentExecutor(AgentExecutor):
"""An AgentExecutor that runs an ADK Agent against an A2A request and
publishes updates to an event queue.
def _kagent_request_converter(request, _part_converter=None):
"""Adapter to match the upstream A2ARequestToAgentRunRequestConverter signature.

Upstream expects (RequestContext, A2APartToGenAIPartConverter) -> AgentRunRequest.
Kagent's converter has a different signature, so this wraps it to satisfy
the upstream config type while still using kagent's own conversion logic.
"""
from google.adk.a2a.converters.request_converter import AgentRunRequest

run_args = convert_a2a_request_to_adk_run_args(request, stream=False)
return AgentRunRequest(
user_id=run_args["user_id"],
session_id=run_args["session_id"],
new_message=run_args["new_message"],
run_config=run_args["run_config"],
)


def _kagent_event_converter(event, invocation_context, task_id=None, context_id=None, _part_converter=None):
"""Adapter to match the upstream AdkEventToA2AEventsConverter signature.

Upstream expects (Event, InvocationContext, task_id, context_id, GenAIPartToA2APartConverter).
Kagent's converter doesn't take a part_converter arg, so this wraps it.
"""
return convert_event_to_a2a_events(event, invocation_context, task_id, context_id)


class A2aAgentExecutor(UpstreamA2aAgentExecutor):
"""KAgent's A2A agent executor.

Extends the upstream google-adk A2aAgentExecutor with:
- Per-request runner lifecycle (created fresh and closed after each request)
- OpenTelemetry span attribute management
- Enhanced error handling (Ollama-specific JSON parse errors, CancelledError)
- Partial event filtering to avoid duplicate aggregation during streaming
- Session naming from first message text
- Request header forwarding to session state
- Invocation ID tracking in final event metadata
"""

def __init__(
Expand All @@ -61,23 +97,33 @@ def __init__(
runner: Callable[..., Runner | Awaitable[Runner]],
config: Optional[A2aAgentExecutorConfig] = None,
):
super().__init__()
self._runner = runner
self._config = config
# Build upstream config with kagent's custom converters
upstream_config = UpstreamA2aAgentExecutorConfig(
a2a_part_converter=convert_a2a_part_to_genai_part,
gen_ai_part_converter=convert_genai_part_to_a2a_part,
request_converter=_kagent_request_converter,
event_converter=_kagent_event_converter,
)
super().__init__(runner=runner, config=upstream_config)
self._kagent_config = config

@override
async def _resolve_runner(self) -> Runner:
"""Resolve the runner, handling cases where it's a callable that returns a Runner."""
"""Resolve the runner from the callable.

Unlike the upstream executor which caches a single Runner instance,
kagent always creates a fresh Runner per request. This is necessary
because MCP toolset connections are not shared between requests and
must be cleaned up after each execution.
"""
if callable(self._runner):
# Call the function to get the runner
result = self._runner()

# Handle async callables
if inspect.iscoroutine(result):
resolved_runner = await result
else:
resolved_runner = result

# Ensure we got a Runner instance
if not isinstance(resolved_runner, Runner):
raise TypeError(f"Callable must return a Runner instance, got {type(resolved_runner)}")

Expand Down Expand Up @@ -111,7 +157,7 @@ async def execute(
raise ValueError("A2A request must have a message")

# Convert the a2a request to ADK run args
stream = self._config.stream if self._config is not None else False
stream = self._kagent_config.stream if self._kagent_config is not None else False
run_args = convert_a2a_request_to_adk_run_args(context, stream=stream)

# Prepare span attributes.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
"""Verify kagent-core A2A constants stay in sync with upstream google-adk.

kagent-core defines these constants locally to avoid depending on google-adk.
This test ensures the values match the upstream definitions.
"""

import pytest
from google.adk.a2a.converters import part_converter as upstream

from kagent.core.a2a import _consts as local

# Each tuple is (constant_name, local_value, upstream_value).
_SYNCED_CONSTANTS = [
(
"A2A_DATA_PART_METADATA_TYPE_KEY",
local.A2A_DATA_PART_METADATA_TYPE_KEY,
upstream.A2A_DATA_PART_METADATA_TYPE_KEY,
),
(
"A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY",
local.A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY,
upstream.A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY,
),
(
"A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL",
local.A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL,
upstream.A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL,
),
(
"A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE",
local.A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE,
upstream.A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE,
),
(
"A2A_DATA_PART_METADATA_TYPE_CODE_EXECUTION_RESULT",
local.A2A_DATA_PART_METADATA_TYPE_CODE_EXECUTION_RESULT,
upstream.A2A_DATA_PART_METADATA_TYPE_CODE_EXECUTION_RESULT,
),
(
"A2A_DATA_PART_METADATA_TYPE_EXECUTABLE_CODE",
local.A2A_DATA_PART_METADATA_TYPE_EXECUTABLE_CODE,
upstream.A2A_DATA_PART_METADATA_TYPE_EXECUTABLE_CODE,
),
]


@pytest.mark.parametrize("name,local_val,upstream_val", _SYNCED_CONSTANTS, ids=[t[0] for t in _SYNCED_CONSTANTS])
def test_constant_matches_upstream(name: str, local_val: str, upstream_val: str) -> None:
assert local_val == upstream_val, (
f"kagent-core constant {name} = {local_val!r} does not match "
f"upstream google-adk value {upstream_val!r}. "
f"Update the value in kagent-core/src/kagent/core/a2a/_consts.py."
)
4 changes: 4 additions & 0 deletions python/packages/kagent-core/src/kagent/core/a2a/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL,
A2A_DATA_PART_METADATA_TYPE_FUNCTION_RESPONSE,
A2A_DATA_PART_METADATA_TYPE_KEY,
ADK_METADATA_KEY_PREFIX,
KAGENT_HITL_DECISION_TYPE_APPROVE,
KAGENT_HITL_DECISION_TYPE_DENY,
KAGENT_HITL_DECISION_TYPE_KEY,
Expand All @@ -14,6 +15,7 @@
KAGENT_HITL_RESUME_KEYWORDS_APPROVE,
KAGENT_HITL_RESUME_KEYWORDS_DENY,
get_kagent_metadata_key,
read_metadata_value,
)
from ._hitl import (
DecisionType,
Expand All @@ -33,6 +35,8 @@
"KAgentRequestContextBuilder",
"KAgentTaskStore",
"get_kagent_metadata_key",
"read_metadata_value",
"ADK_METADATA_KEY_PREFIX",
"A2A_DATA_PART_METADATA_TYPE_KEY",
"A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY",
"A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL",
Expand Down
36 changes: 36 additions & 0 deletions python/packages/kagent-core/src/kagent/core/a2a/_consts.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# A2A DataPart metadata constants.
# These values MUST match the upstream google-adk definitions in
# google.adk.a2a.converters.part_converter. A sync-check test in
# kagent-adk verifies they stay in sync.
A2A_DATA_PART_METADATA_TYPE_KEY = "type"
A2A_DATA_PART_METADATA_IS_LONG_RUNNING_KEY = "is_long_running"
A2A_DATA_PART_METADATA_TYPE_FUNCTION_CALL = "function_call"
Expand All @@ -6,6 +10,7 @@
A2A_DATA_PART_METADATA_TYPE_EXECUTABLE_CODE = "executable_code"

KAGENT_METADATA_KEY_PREFIX = "kagent_"
ADK_METADATA_KEY_PREFIX = "adk_"


def get_kagent_metadata_key(key: str) -> str:
Expand All @@ -25,6 +30,37 @@ def get_kagent_metadata_key(key: str) -> str:
return f"{KAGENT_METADATA_KEY_PREFIX}{key}"


def read_metadata_value(metadata: dict | None, key: str, default=None):
"""Read a metadata value, checking ``adk_<key>`` first then ``kagent_<key>``.

This allows interoperability with upstream ADK (which uses the ``adk_``
prefix) while preserving backward-compatibility with kagent's own
``kagent_`` prefix.

Args:
metadata: The metadata dict to look up (may be ``None``).
key: The unprefixed key name (e.g. ``"type"``).
default: Value returned when the key is not found under either prefix.

Returns:
The value found under ``adk_<key>`` or ``kagent_<key>``, or *default*.

Raises:
ValueError: If *key* is empty or ``None``.
"""
if not key:
raise ValueError("Metadata key cannot be empty or None")
if not metadata:
return default
adk_key = f"{ADK_METADATA_KEY_PREFIX}{key}"
if adk_key in metadata:
return metadata[adk_key]
kagent_key = f"{KAGENT_METADATA_KEY_PREFIX}{key}"
if kagent_key in metadata:
return metadata[kagent_key]
return default


# Human-in-the-Loop (HITL) Constants
KAGENT_HITL_INTERRUPT_TYPE_TOOL_APPROVAL = "tool_approval"
KAGENT_HITL_DECISION_TYPE_KEY = "decision_type"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pydantic import BaseModel
from typing_extensions import override

from kagent.core.a2a import get_kagent_metadata_key
from kagent.core.a2a import read_metadata_value


class KAgentTaskResponse(BaseModel):
Expand Down Expand Up @@ -40,7 +40,7 @@ def __init__(self, client: httpx.AsyncClient):
def _is_partial_event(self, item: Message) -> bool:
"""Check if a history item is a partial ADK streaming event."""
metadata = item.metadata or {}
return metadata.get(get_kagent_metadata_key("adk_partial")) is True
return read_metadata_value(metadata, "adk_partial") is True

def _clean_partial_events(self, history: list[Message]) -> list[Message]:
"""Remove partial streaming events from history."""
Expand Down
53 changes: 53 additions & 0 deletions python/packages/kagent-core/tests/test_read_metadata_value.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import pytest

from kagent.core.a2a import read_metadata_value


class TestReadMetadataValue:
"""Tests for the dual-prefix metadata reader."""

def test_reads_kagent_prefix(self):
metadata = {"kagent_type": "function_call"}
assert read_metadata_value(metadata, "type") == "function_call"

def test_reads_adk_prefix(self):
metadata = {"adk_type": "function_call"}
assert read_metadata_value(metadata, "type") == "function_call"

def test_adk_takes_priority_when_both_present(self):
metadata = {"adk_type": "adk_value", "kagent_type": "kagent_value"}
assert read_metadata_value(metadata, "type") == "adk_value"

def test_returns_default_for_missing_key(self):
metadata = {"unrelated_key": "val"}
assert read_metadata_value(metadata, "type") is None
assert read_metadata_value(metadata, "type", "fallback") == "fallback"

def test_returns_default_for_none_metadata(self):
assert read_metadata_value(None, "type") is None
assert read_metadata_value(None, "type", "default") == "default"

def test_returns_default_for_empty_metadata(self):
assert read_metadata_value({}, "type") is None
assert read_metadata_value({}, "type", 42) == 42

def test_raises_for_empty_key(self):
with pytest.raises(ValueError, match="empty"):
read_metadata_value({"a": 1}, "")

def test_raises_for_none_key(self):
with pytest.raises(ValueError, match="empty"):
read_metadata_value({"a": 1}, None) # type: ignore[arg-type]

def test_preserves_non_string_values(self):
metadata = {"kagent_usage": {"total": 100}}
result = read_metadata_value(metadata, "usage")
assert result == {"total": 100}

def test_returns_false_value_not_default(self):
"""Ensure falsy values (False, 0, '') are returned, not treated as missing."""
metadata = {"kagent_flag": False}
assert read_metadata_value(metadata, "flag") is False

metadata2 = {"adk_count": 0}
assert read_metadata_value(metadata2, "count") == 0
Loading
Loading