From c882376df78ca7348b2ac91904b00ebd795c33b4 Mon Sep 17 00:00:00 2001 From: Cyrus Bugwadia Date: Sat, 31 Jan 2026 15:37:59 -0800 Subject: [PATCH 1/2] make card modifier and extended card modifier async --- src/a2a/server/apps/jsonrpc/fastapi_app.py | 7 +- src/a2a/server/apps/jsonrpc/jsonrpc_app.py | 16 +++- src/a2a/server/apps/jsonrpc/starlette_app.py | 7 +- src/a2a/server/apps/rest/fastapi_app.py | 7 +- src/a2a/server/apps/rest/rest_adapter.py | 19 +++- .../server/request_handlers/grpc_handler.py | 15 ++- .../request_handlers/jsonrpc_handler.py | 17 ++-- src/a2a/utils/helpers.py | 19 +++- .../request_handlers/test_grpc_handler.py | 28 ++++++ .../request_handlers/test_jsonrpc_handler.py | 51 ++++++++++ tests/server/test_integration.py | 93 +++++++++++++++++++ tests/utils/test_signing.py | 19 ++-- 12 files changed, 256 insertions(+), 42 deletions(-) diff --git a/src/a2a/server/apps/jsonrpc/fastapi_app.py b/src/a2a/server/apps/jsonrpc/fastapi_app.py index ace2c6ae3..dfd92d87c 100644 --- a/src/a2a/server/apps/jsonrpc/fastapi_app.py +++ b/src/a2a/server/apps/jsonrpc/fastapi_app.py @@ -1,6 +1,6 @@ import logging -from collections.abc import Callable +from collections.abc import Awaitable, Callable from typing import TYPE_CHECKING, Any @@ -72,9 +72,10 @@ def __init__( # noqa: PLR0913 http_handler: RequestHandler, extended_agent_card: AgentCard | None = None, context_builder: CallContextBuilder | None = None, - card_modifier: Callable[[AgentCard], AgentCard] | None = None, + card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] + | None = None, extended_card_modifier: Callable[ - [AgentCard, ServerCallContext], AgentCard + [AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard ] | None = None, max_content_length: int | None = 10 * 1024 * 1024, # 10MB diff --git a/src/a2a/server/apps/jsonrpc/jsonrpc_app.py b/src/a2a/server/apps/jsonrpc/jsonrpc_app.py index 3e7c2854b..f4bffc320 100644 --- a/src/a2a/server/apps/jsonrpc/jsonrpc_app.py +++ b/src/a2a/server/apps/jsonrpc/jsonrpc_app.py @@ -4,7 +4,7 @@ import traceback from abc import ABC, abstractmethod -from collections.abc import AsyncGenerator, Callable +from collections.abc import AsyncGenerator, Awaitable, Callable from typing import TYPE_CHECKING, Any from pydantic import ValidationError @@ -51,6 +51,7 @@ PREV_AGENT_CARD_WELL_KNOWN_PATH, ) from a2a.utils.errors import MethodNotImplementedError +from a2a.utils.helpers import apply_optional_awaitable logger = logging.getLogger(__name__) @@ -178,9 +179,10 @@ def __init__( # noqa: PLR0913 http_handler: RequestHandler, extended_agent_card: AgentCard | None = None, context_builder: CallContextBuilder | None = None, - card_modifier: Callable[[AgentCard], AgentCard] | None = None, + card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] + | None = None, extended_card_modifier: Callable[ - [AgentCard, ServerCallContext], AgentCard + [AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard ] | None = None, max_content_length: int | None = 10 * 1024 * 1024, # 10MB @@ -576,7 +578,9 @@ async def _handle_get_agent_card(self, request: Request) -> JSONResponse: card_to_serve = self.agent_card if self.card_modifier: - card_to_serve = self.card_modifier(card_to_serve) + card_to_serve = await apply_optional_awaitable( + self.card_modifier, card_to_serve + ) return JSONResponse( card_to_serve.model_dump( @@ -605,7 +609,9 @@ async def _handle_get_authenticated_extended_agent_card( context = self._context_builder.build(request) # If no base extended card is provided, pass the public card to the modifier base_card = card_to_serve if card_to_serve else self.agent_card - card_to_serve = self.extended_card_modifier(base_card, context) + card_to_serve = await apply_optional_awaitable( + self.extended_card_modifier, base_card, context + ) if card_to_serve: return JSONResponse( diff --git a/src/a2a/server/apps/jsonrpc/starlette_app.py b/src/a2a/server/apps/jsonrpc/starlette_app.py index 1effa9d51..ceaf5ced1 100644 --- a/src/a2a/server/apps/jsonrpc/starlette_app.py +++ b/src/a2a/server/apps/jsonrpc/starlette_app.py @@ -1,6 +1,6 @@ import logging -from collections.abc import Callable +from collections.abc import Awaitable, Callable from typing import TYPE_CHECKING, Any @@ -54,9 +54,10 @@ def __init__( # noqa: PLR0913 http_handler: RequestHandler, extended_agent_card: AgentCard | None = None, context_builder: CallContextBuilder | None = None, - card_modifier: Callable[[AgentCard], AgentCard] | None = None, + card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] + | None = None, extended_card_modifier: Callable[ - [AgentCard, ServerCallContext], AgentCard + [AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard ] | None = None, max_content_length: int | None = 10 * 1024 * 1024, # 10MB diff --git a/src/a2a/server/apps/rest/fastapi_app.py b/src/a2a/server/apps/rest/fastapi_app.py index 3ae5ad6fe..12a03de84 100644 --- a/src/a2a/server/apps/rest/fastapi_app.py +++ b/src/a2a/server/apps/rest/fastapi_app.py @@ -1,6 +1,6 @@ import logging -from collections.abc import Callable +from collections.abc import Awaitable, Callable from typing import TYPE_CHECKING, Any @@ -49,9 +49,10 @@ def __init__( # noqa: PLR0913 http_handler: RequestHandler, extended_agent_card: AgentCard | None = None, context_builder: CallContextBuilder | None = None, - card_modifier: Callable[[AgentCard], AgentCard] | None = None, + card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] + | None = None, extended_card_modifier: Callable[ - [AgentCard, ServerCallContext], AgentCard + [AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard ] | None = None, ): diff --git a/src/a2a/server/apps/rest/rest_adapter.py b/src/a2a/server/apps/rest/rest_adapter.py index cdf86ab14..26011fd89 100644 --- a/src/a2a/server/apps/rest/rest_adapter.py +++ b/src/a2a/server/apps/rest/rest_adapter.py @@ -4,6 +4,8 @@ from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable from typing import TYPE_CHECKING, Any +from a2a.utils.helpers import apply_optional_awaitable + if TYPE_CHECKING: from sse_starlette.sse import EventSourceResponse @@ -58,9 +60,10 @@ def __init__( # noqa: PLR0913 http_handler: RequestHandler, extended_agent_card: AgentCard | None = None, context_builder: CallContextBuilder | None = None, - card_modifier: Callable[[AgentCard], AgentCard] | None = None, + card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] + | None = None, extended_card_modifier: Callable[ - [AgentCard, ServerCallContext], AgentCard + [AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard ] | None = None, ): @@ -150,7 +153,9 @@ async def handle_get_agent_card( """ card_to_serve = self.agent_card if self.card_modifier: - card_to_serve = self.card_modifier(card_to_serve) + card_to_serve = await apply_optional_awaitable( + self.card_modifier, card_to_serve + ) return card_to_serve.model_dump(mode='json', exclude_none=True) @@ -182,9 +187,13 @@ async def handle_authenticated_agent_card( if self.extended_card_modifier: context = self._context_builder.build(request) - card_to_serve = self.extended_card_modifier(card_to_serve, context) + card_to_serve = await apply_optional_awaitable( + self.extended_card_modifier, card_to_serve, context + ) elif self.card_modifier: - card_to_serve = self.card_modifier(card_to_serve) + card_to_serve = await apply_optional_awaitable( + self.card_modifier, card_to_serve + ) return card_to_serve.model_dump(mode='json', exclude_none=True) diff --git a/src/a2a/server/request_handlers/grpc_handler.py b/src/a2a/server/request_handlers/grpc_handler.py index e2ec69a15..409cf7f1e 100644 --- a/src/a2a/server/request_handlers/grpc_handler.py +++ b/src/a2a/server/request_handlers/grpc_handler.py @@ -3,7 +3,7 @@ import logging from abc import ABC, abstractmethod -from collections.abc import AsyncIterable, Sequence +from collections.abc import AsyncIterable, Awaitable, Sequence try: @@ -34,7 +34,11 @@ from a2a.types import AgentCard, TaskNotFoundError from a2a.utils import proto_utils from a2a.utils.errors import ServerError -from a2a.utils.helpers import validate, validate_async_generator +from a2a.utils.helpers import ( + apply_optional_awaitable, + validate, + validate_async_generator, +) logger = logging.getLogger(__name__) @@ -89,7 +93,8 @@ def __init__( agent_card: AgentCard, request_handler: RequestHandler, context_builder: CallContextBuilder | None = None, - card_modifier: Callable[[AgentCard], AgentCard] | None = None, + card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] + | None = None, ): """Initializes the GrpcHandler. @@ -339,7 +344,9 @@ async def GetAgentCard( """Get the agent card for the agent served.""" card_to_serve = self.agent_card if self.card_modifier: - card_to_serve = self.card_modifier(card_to_serve) + card_to_serve = await apply_optional_awaitable( + self.card_modifier, card_to_serve + ) return proto_utils.ToProto.agent_card(card_to_serve) async def abort_context( diff --git a/src/a2a/server/request_handlers/jsonrpc_handler.py b/src/a2a/server/request_handlers/jsonrpc_handler.py index 567c61484..f1c6dab81 100644 --- a/src/a2a/server/request_handlers/jsonrpc_handler.py +++ b/src/a2a/server/request_handlers/jsonrpc_handler.py @@ -1,6 +1,6 @@ import logging -from collections.abc import AsyncIterable, Callable +from collections.abc import AsyncIterable, Awaitable, Callable from a2a.server.context import ServerCallContext from a2a.server.request_handlers.request_handler import RequestHandler @@ -46,7 +46,7 @@ TaskStatusUpdateEvent, ) from a2a.utils.errors import ServerError -from a2a.utils.helpers import validate +from a2a.utils.helpers import apply_optional_awaitable, validate from a2a.utils.telemetry import SpanKind, trace_class @@ -63,10 +63,11 @@ def __init__( request_handler: RequestHandler, extended_agent_card: AgentCard | None = None, extended_card_modifier: Callable[ - [AgentCard, ServerCallContext], AgentCard + [AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard ] | None = None, - card_modifier: Callable[[AgentCard], AgentCard] | None = None, + card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] + | None = None, ): """Initializes the JSONRPCHandler. @@ -450,9 +451,13 @@ async def get_authenticated_extended_card( card_to_serve = base_card if self.extended_card_modifier and context: - card_to_serve = self.extended_card_modifier(base_card, context) + card_to_serve = await apply_optional_awaitable( + self.extended_card_modifier, base_card, context + ) elif self.card_modifier: - card_to_serve = self.card_modifier(base_card) + card_to_serve = await apply_optional_awaitable( + self.card_modifier, base_card + ) return GetAuthenticatedExtendedCardResponse( root=GetAuthenticatedExtendedCardSuccessResponse( diff --git a/src/a2a/utils/helpers.py b/src/a2a/utils/helpers.py index 96acdc1e6..d6d3a4ad4 100644 --- a/src/a2a/utils/helpers.py +++ b/src/a2a/utils/helpers.py @@ -5,8 +5,9 @@ import json import logging -from collections.abc import Callable -from typing import Any +from collections.abc import Awaitable, Callable +from inspect import isawaitable +from typing import Any, ParamSpec, TypeVar from uuid import uuid4 from a2a.types import ( @@ -24,6 +25,10 @@ from a2a.utils.telemetry import trace_function +T = TypeVar('T') +P = ParamSpec('P') + + logger = logging.getLogger(__name__) @@ -368,3 +373,13 @@ def canonicalize_agent_card(agent_card: AgentCard) -> str: # Recursively remove empty values cleaned_dict = _clean_empty(card_dict) return json.dumps(cleaned_dict, separators=(',', ':'), sort_keys=True) + + +async def apply_optional_awaitable( + func: Callable[P, Awaitable[T] | T], *args: P.args, **kwargs: P.kwargs +) -> T: + """Applies a function that may be sync or async and returns the result.""" + result = func(*args, **kwargs) + if isawaitable(result): + return await result + return result diff --git a/tests/server/request_handlers/test_grpc_handler.py b/tests/server/request_handlers/test_grpc_handler.py index 26f923c14..647d9e86f 100644 --- a/tests/server/request_handlers/test_grpc_handler.py +++ b/tests/server/request_handlers/test_grpc_handler.py @@ -209,6 +209,34 @@ async def test_get_agent_card_with_modifier( ) -> None: """Test GetAgentCard call with a card_modifier.""" + async def modifier(card: types.AgentCard) -> types.AgentCard: + modified_card = card.model_copy(deep=True) + modified_card.name = 'Modified gRPC Agent' + return modified_card + + grpc_handler_modified = GrpcHandler( + agent_card=sample_agent_card, + request_handler=mock_request_handler, + card_modifier=modifier, + ) + + request_proto = a2a_pb2.GetAgentCardRequest() + response = await grpc_handler_modified.GetAgentCard( + request_proto, mock_grpc_context + ) + + assert response.name == 'Modified gRPC Agent' + assert response.version == sample_agent_card.version + + +@pytest.mark.asyncio +async def test_get_agent_card_with_modifier_sync( + mock_request_handler: AsyncMock, + sample_agent_card: types.AgentCard, + mock_grpc_context: AsyncMock, +) -> None: + """Test GetAgentCard call with a synchronous card_modifier.""" + def modifier(card: types.AgentCard) -> types.AgentCard: modified_card = card.model_copy(deep=True) modified_card.name = 'Modified gRPC Agent' diff --git a/tests/server/request_handlers/test_jsonrpc_handler.py b/tests/server/request_handlers/test_jsonrpc_handler.py index d1ead0211..4ed6e7025 100644 --- a/tests/server/request_handlers/test_jsonrpc_handler.py +++ b/tests/server/request_handlers/test_jsonrpc_handler.py @@ -1295,6 +1295,57 @@ async def test_get_authenticated_extended_card_with_modifier(self) -> None: skills=[], ) + async def modifier( + card: AgentCard, context: ServerCallContext + ) -> AgentCard: + modified_card = card.model_copy(deep=True) + modified_card.name = 'Modified Card' + modified_card.description = ( + f'Modified for context: {context.state.get("foo")}' + ) + return modified_card + + handler = JSONRPCHandler( + self.mock_agent_card, + mock_request_handler, + extended_agent_card=mock_base_card, + extended_card_modifier=modifier, + ) + request = GetAuthenticatedExtendedCardRequest(id='ext-card-req-mod') + call_context = ServerCallContext(state={'foo': 'bar'}) + + # Act + response: GetAuthenticatedExtendedCardResponse = ( + await handler.get_authenticated_extended_card(request, call_context) + ) + + # Assert + self.assertIsInstance( + response.root, GetAuthenticatedExtendedCardSuccessResponse + ) + self.assertEqual(response.root.id, 'ext-card-req-mod') + modified_card = response.root.result + self.assertEqual(modified_card.name, 'Modified Card') + self.assertEqual(modified_card.description, 'Modified for context: bar') + self.assertEqual(modified_card.version, '1.0') + + async def test_get_authenticated_extended_card_with_modifier_sync( + self, + ) -> None: + """Test successful retrieval of a synchronously dynamically modified extended agent card.""" + # Arrange + mock_request_handler = AsyncMock(spec=DefaultRequestHandler) + mock_base_card = AgentCard( + name='Base Card', + description='Base details', + url='http://agent.example.com/api', + version='1.0', + capabilities=AgentCapabilities(), + default_input_modes=['text/plain'], + default_output_modes=['application/json'], + skills=[], + ) + def modifier(card: AgentCard, context: ServerCallContext) -> AgentCard: modified_card = card.model_copy(deep=True) modified_card.name = 'Modified Card' diff --git a/tests/server/test_integration.py b/tests/server/test_integration.py index d65657dea..8080136c1 100644 --- a/tests/server/test_integration.py +++ b/tests/server/test_integration.py @@ -858,6 +858,30 @@ def test_dynamic_agent_card_modifier( ): """Test that the card_modifier dynamically alters the public agent card.""" + async def modifier(card: AgentCard) -> AgentCard: + modified_card = card.model_copy(deep=True) + modified_card.name = 'Dynamically Modified Agent' + return modified_card + + app_instance = A2AStarletteApplication( + agent_card, handler, card_modifier=modifier + ) + client = TestClient(app_instance.build()) + + response = client.get(AGENT_CARD_WELL_KNOWN_PATH) + assert response.status_code == 200 + data = response.json() + assert data['name'] == 'Dynamically Modified Agent' + assert ( + data['version'] == agent_card.version + ) # Ensure other fields are intact + + +def test_dynamic_agent_card_modifier_sync( + agent_card: AgentCard, handler: mock.AsyncMock +): + """Test that a synchronous card_modifier dynamically alters the public agent card.""" + def modifier(card: AgentCard) -> AgentCard: modified_card = card.model_copy(deep=True) modified_card.name = 'Dynamically Modified Agent' @@ -885,6 +909,54 @@ def test_dynamic_extended_agent_card_modifier( """Test that the extended_card_modifier dynamically alters the extended agent card.""" agent_card.supports_authenticated_extended_card = True + async def modifier( + card: AgentCard, context: ServerCallContext + ) -> AgentCard: + modified_card = card.model_copy(deep=True) + modified_card.description = 'Dynamically Modified Extended Description' + return modified_card + + # Test with a base extended card + app_instance = A2AStarletteApplication( + agent_card, + handler, + extended_agent_card=extended_agent_card_fixture, + extended_card_modifier=modifier, + ) + client = TestClient(app_instance.build()) + + response = client.get(EXTENDED_AGENT_CARD_PATH) + assert response.status_code == 200 + data = response.json() + assert data['name'] == extended_agent_card_fixture.name + assert data['description'] == 'Dynamically Modified Extended Description' + + # Test without a base extended card (modifier should receive public card) + app_instance_no_base = A2AStarletteApplication( + agent_card, + handler, + extended_agent_card=None, + extended_card_modifier=modifier, + ) + client_no_base = TestClient(app_instance_no_base.build()) + response_no_base = client_no_base.get(EXTENDED_AGENT_CARD_PATH) + assert response_no_base.status_code == 200 + data_no_base = response_no_base.json() + assert data_no_base['name'] == agent_card.name + assert ( + data_no_base['description'] + == 'Dynamically Modified Extended Description' + ) + + +def test_dynamic_extended_agent_card_modifier_sync( + agent_card: AgentCard, + extended_agent_card_fixture: AgentCard, + handler: mock.AsyncMock, +): + """Test that a synchronous extended_card_modifier dynamically alters the extended agent card.""" + agent_card.supports_authenticated_extended_card = True + def modifier(card: AgentCard, context: ServerCallContext) -> AgentCard: modified_card = card.model_copy(deep=True) modified_card.description = 'Dynamically Modified Extended Description' @@ -928,6 +1000,27 @@ def test_fastapi_dynamic_agent_card_modifier( ): """Test that the card_modifier dynamically alters the public agent card for FastAPI.""" + async def modifier(card: AgentCard) -> AgentCard: + modified_card = card.model_copy(deep=True) + modified_card.name = 'Dynamically Modified Agent' + return modified_card + + app_instance = A2AFastAPIApplication( + agent_card, handler, card_modifier=modifier + ) + client = TestClient(app_instance.build()) + + response = client.get(AGENT_CARD_WELL_KNOWN_PATH) + assert response.status_code == 200 + data = response.json() + assert data['name'] == 'Dynamically Modified Agent' + + +def test_fastapi_dynamic_agent_card_modifier_sync( + agent_card: AgentCard, handler: mock.AsyncMock +): + """Test that a synchronous card_modifier dynamically alters the public agent card for FastAPI.""" + def modifier(card: AgentCard) -> AgentCard: modified_card = card.model_copy(deep=True) modified_card.name = 'Dynamically Modified Agent' diff --git a/tests/utils/test_signing.py b/tests/utils/test_signing.py index 9a843d340..f055d9cac 100644 --- a/tests/utils/test_signing.py +++ b/tests/utils/test_signing.py @@ -1,20 +1,17 @@ +from typing import Any + +import pytest + +from cryptography.hazmat.primitives import asymmetric +from jwt.utils import base64url_encode + from a2a.types import ( - AgentCard, AgentCapabilities, - AgentSkill, -) -from a2a.types import ( AgentCard, - AgentCapabilities, - AgentSkill, AgentCardSignature, + AgentSkill, ) from a2a.utils import signing -from typing import Any -from jwt.utils import base64url_encode - -import pytest -from cryptography.hazmat.primitives import asymmetric def create_key_provider(verification_key: str | bytes | dict[str, Any]): From f050cf982e983728716f0b2bfa36b4caa1e63079 Mon Sep 17 00:00:00 2001 From: Cyrus Bugwadia Date: Mon, 2 Feb 2026 10:04:50 -0800 Subject: [PATCH 2/2] address feedback --- src/a2a/server/apps/jsonrpc/jsonrpc_app.py | 10 ++++------ src/a2a/server/apps/rest/rest_adapter.py | 14 +++++--------- .../server/request_handlers/grpc_handler.py | 10 ++-------- .../request_handlers/jsonrpc_handler.py | 10 ++++------ src/a2a/utils/helpers.py | 17 ++++++----------- tests/utils/test_signing.py | 19 +++++++++++-------- 6 files changed, 32 insertions(+), 48 deletions(-) diff --git a/src/a2a/server/apps/jsonrpc/jsonrpc_app.py b/src/a2a/server/apps/jsonrpc/jsonrpc_app.py index f4bffc320..27839cd35 100644 --- a/src/a2a/server/apps/jsonrpc/jsonrpc_app.py +++ b/src/a2a/server/apps/jsonrpc/jsonrpc_app.py @@ -51,7 +51,7 @@ PREV_AGENT_CARD_WELL_KNOWN_PATH, ) from a2a.utils.errors import MethodNotImplementedError -from a2a.utils.helpers import apply_optional_awaitable +from a2a.utils.helpers import maybe_await logger = logging.getLogger(__name__) @@ -578,9 +578,7 @@ async def _handle_get_agent_card(self, request: Request) -> JSONResponse: card_to_serve = self.agent_card if self.card_modifier: - card_to_serve = await apply_optional_awaitable( - self.card_modifier, card_to_serve - ) + card_to_serve = await maybe_await(self.card_modifier(card_to_serve)) return JSONResponse( card_to_serve.model_dump( @@ -609,8 +607,8 @@ async def _handle_get_authenticated_extended_agent_card( context = self._context_builder.build(request) # If no base extended card is provided, pass the public card to the modifier base_card = card_to_serve if card_to_serve else self.agent_card - card_to_serve = await apply_optional_awaitable( - self.extended_card_modifier, base_card, context + card_to_serve = await maybe_await( + self.extended_card_modifier(base_card, context) ) if card_to_serve: diff --git a/src/a2a/server/apps/rest/rest_adapter.py b/src/a2a/server/apps/rest/rest_adapter.py index 26011fd89..719085604 100644 --- a/src/a2a/server/apps/rest/rest_adapter.py +++ b/src/a2a/server/apps/rest/rest_adapter.py @@ -4,7 +4,7 @@ from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable from typing import TYPE_CHECKING, Any -from a2a.utils.helpers import apply_optional_awaitable +from a2a.utils.helpers import maybe_await if TYPE_CHECKING: @@ -153,9 +153,7 @@ async def handle_get_agent_card( """ card_to_serve = self.agent_card if self.card_modifier: - card_to_serve = await apply_optional_awaitable( - self.card_modifier, card_to_serve - ) + card_to_serve = await maybe_await(self.card_modifier(card_to_serve)) return card_to_serve.model_dump(mode='json', exclude_none=True) @@ -187,13 +185,11 @@ async def handle_authenticated_agent_card( if self.extended_card_modifier: context = self._context_builder.build(request) - card_to_serve = await apply_optional_awaitable( - self.extended_card_modifier, card_to_serve, context + card_to_serve = await maybe_await( + self.extended_card_modifier(card_to_serve, context) ) elif self.card_modifier: - card_to_serve = await apply_optional_awaitable( - self.card_modifier, card_to_serve - ) + card_to_serve = await maybe_await(self.card_modifier(card_to_serve)) return card_to_serve.model_dump(mode='json', exclude_none=True) diff --git a/src/a2a/server/request_handlers/grpc_handler.py b/src/a2a/server/request_handlers/grpc_handler.py index 409cf7f1e..105b99471 100644 --- a/src/a2a/server/request_handlers/grpc_handler.py +++ b/src/a2a/server/request_handlers/grpc_handler.py @@ -34,11 +34,7 @@ from a2a.types import AgentCard, TaskNotFoundError from a2a.utils import proto_utils from a2a.utils.errors import ServerError -from a2a.utils.helpers import ( - apply_optional_awaitable, - validate, - validate_async_generator, -) +from a2a.utils.helpers import maybe_await, validate, validate_async_generator logger = logging.getLogger(__name__) @@ -344,9 +340,7 @@ async def GetAgentCard( """Get the agent card for the agent served.""" card_to_serve = self.agent_card if self.card_modifier: - card_to_serve = await apply_optional_awaitable( - self.card_modifier, card_to_serve - ) + card_to_serve = await maybe_await(self.card_modifier(card_to_serve)) return proto_utils.ToProto.agent_card(card_to_serve) async def abort_context( diff --git a/src/a2a/server/request_handlers/jsonrpc_handler.py b/src/a2a/server/request_handlers/jsonrpc_handler.py index f1c6dab81..6df872fca 100644 --- a/src/a2a/server/request_handlers/jsonrpc_handler.py +++ b/src/a2a/server/request_handlers/jsonrpc_handler.py @@ -46,7 +46,7 @@ TaskStatusUpdateEvent, ) from a2a.utils.errors import ServerError -from a2a.utils.helpers import apply_optional_awaitable, validate +from a2a.utils.helpers import maybe_await, validate from a2a.utils.telemetry import SpanKind, trace_class @@ -451,13 +451,11 @@ async def get_authenticated_extended_card( card_to_serve = base_card if self.extended_card_modifier and context: - card_to_serve = await apply_optional_awaitable( - self.extended_card_modifier, base_card, context + card_to_serve = await maybe_await( + self.extended_card_modifier(base_card, context) ) elif self.card_modifier: - card_to_serve = await apply_optional_awaitable( - self.card_modifier, base_card - ) + card_to_serve = await maybe_await(self.card_modifier(base_card)) return GetAuthenticatedExtendedCardResponse( root=GetAuthenticatedExtendedCardSuccessResponse( diff --git a/src/a2a/utils/helpers.py b/src/a2a/utils/helpers.py index d6d3a4ad4..8164674e5 100644 --- a/src/a2a/utils/helpers.py +++ b/src/a2a/utils/helpers.py @@ -6,8 +6,7 @@ import logging from collections.abc import Awaitable, Callable -from inspect import isawaitable -from typing import Any, ParamSpec, TypeVar +from typing import Any, TypeVar from uuid import uuid4 from a2a.types import ( @@ -26,7 +25,6 @@ T = TypeVar('T') -P = ParamSpec('P') logger = logging.getLogger(__name__) @@ -375,11 +373,8 @@ def canonicalize_agent_card(agent_card: AgentCard) -> str: return json.dumps(cleaned_dict, separators=(',', ':'), sort_keys=True) -async def apply_optional_awaitable( - func: Callable[P, Awaitable[T] | T], *args: P.args, **kwargs: P.kwargs -) -> T: - """Applies a function that may be sync or async and returns the result.""" - result = func(*args, **kwargs) - if isawaitable(result): - return await result - return result +async def maybe_await(value: T | Awaitable[T]) -> T: + """Awaits a value if it's awaitable, otherwise simply provides it back.""" + if inspect.isawaitable(value): + return await value + return value diff --git a/tests/utils/test_signing.py b/tests/utils/test_signing.py index f055d9cac..9a843d340 100644 --- a/tests/utils/test_signing.py +++ b/tests/utils/test_signing.py @@ -1,17 +1,20 @@ -from typing import Any - -import pytest - -from cryptography.hazmat.primitives import asymmetric -from jwt.utils import base64url_encode - from a2a.types import ( + AgentCard, AgentCapabilities, + AgentSkill, +) +from a2a.types import ( AgentCard, - AgentCardSignature, + AgentCapabilities, AgentSkill, + AgentCardSignature, ) from a2a.utils import signing +from typing import Any +from jwt.utils import base64url_encode + +import pytest +from cryptography.hazmat.primitives import asymmetric def create_key_provider(verification_key: str | bytes | dict[str, Any]):