diff --git a/python/packages/a2a/agent_framework_a2a/_agent.py b/python/packages/a2a/agent_framework_a2a/_agent.py index e11aa668da..da4e23dd09 100644 --- a/python/packages/a2a/agent_framework_a2a/_agent.py +++ b/python/packages/a2a/agent_framework_a2a/_agent.py @@ -35,10 +35,12 @@ AgentResponseUpdate, AgentSession, BaseAgent, + BaseHistoryProvider, Content, ContinuationToken, Message, ResponseStream, + SessionContext, normalize_messages, prepend_agent_framework_to_user_agent, ) @@ -284,22 +286,115 @@ def run( # pyright: ignore[reportIncompatibleMethodOverride] When stream=True: A ResponseStream of AgentResponseUpdate items. """ del function_invocation_kwargs, client_kwargs, kwargs - if continuation_token is not None: - a2a_stream: AsyncIterable[A2AStreamItem] = self.client.resubscribe( - TaskIdParams(id=continuation_token["task_id"]) + normalized_messages = normalize_messages(messages) if continuation_token is None else None + + if not stream: + + async def _run_non_streaming() -> AgentResponse[Any]: + active_session: AgentSession | None = None + session_context: SessionContext | None = None + if self.context_providers: + active_session, session_context = await self._run_before_providers( + session=session, + input_messages=normalized_messages, + ) + if continuation_token is not None: + a2a_stream: AsyncIterable[A2AStreamItem] = self.client.resubscribe( + TaskIdParams(id=continuation_token["task_id"]) + ) + else: + a2a_message = self._prepare_message_for_a2a(normalized_messages[-1]) # type: ignore[index] + a2a_stream = self.client.send_message(a2a_message) + + response_stream = ResponseStream( + self._map_a2a_stream(a2a_stream, background=background), + finalizer=AgentResponse.from_updates, + ) + result = await response_stream.get_final_response() + if self.context_providers and session_context is not None: + session_context._response = result # type: ignore[assignment] # pyright: ignore[reportPrivateUsage] + await self._run_after_providers(session=active_session, context=session_context) + return result + + return _run_non_streaming() + + # Streaming path + active_session_holder: dict[str, AgentSession | None] = {"session": None} + context_holder: dict[str, SessionContext | None] = {"ctx": None} + + async def _post_hook(response: AgentResponse) -> None: + if not self.context_providers: + return + session_context = context_holder["ctx"] + if session_context is None: + return + session_context._response = response # type: ignore[assignment] # pyright: ignore[reportPrivateUsage] + await self._run_after_providers(session=active_session_holder["session"], context=session_context) + + async def _get_stream() -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: + if self.context_providers: + active_session, session_context = await self._run_before_providers( + session=session, + input_messages=normalized_messages, + ) + active_session_holder["session"] = active_session + context_holder["ctx"] = session_context + + if continuation_token is not None: + a2a_stream: AsyncIterable[A2AStreamItem] = self.client.resubscribe( + TaskIdParams(id=continuation_token["task_id"]) + ) + else: + a2a_message = self._prepare_message_for_a2a(normalized_messages[-1]) # type: ignore[index] + a2a_stream = self.client.send_message(a2a_message) + + return ResponseStream( + self._map_a2a_stream(a2a_stream, background=background), + finalizer=AgentResponse.from_updates, ) - else: - normalized_messages = normalize_messages(messages) - a2a_message = self._prepare_message_for_a2a(normalized_messages[-1]) - a2a_stream = self.client.send_message(a2a_message) - - response = ResponseStream( - self._map_a2a_stream(a2a_stream, background=background), - finalizer=AgentResponse.from_updates, + + return ( + ResponseStream.from_awaitable(_get_stream()).with_result_hook(_post_hook) # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] ) - if stream: - return response - return response.get_final_response() + + async def _run_before_providers( + self, + *, + session: AgentSession | None, + input_messages: list[Message] | None, + ) -> tuple[AgentSession | None, SessionContext]: + """Run before_run on all context providers and return the active session and context. + + Keyword Args: + session: The conversation session (None for stateless invocation). + input_messages: Messages to process. + + Returns: + A tuple of (active_session, session_context). + """ + active_session = session + if active_session is None and self.context_providers: + active_session = AgentSession() + + session_context = SessionContext( + session_id=active_session.session_id if active_session else None, + service_session_id=active_session.service_session_id if active_session else None, + input_messages=input_messages or [], + ) + + for provider in self.context_providers: + if isinstance(provider, BaseHistoryProvider) and not provider.load_messages: + continue + if active_session is None: + raise RuntimeError("Provider session must be available when context providers are configured.") + await provider.before_run( + agent=self, + session=active_session, + context=session_context, + state=active_session.state.setdefault(provider.source_id, {}), + ) + + return active_session, session_context async def _map_a2a_stream( self, diff --git a/python/packages/a2a/tests/test_a2a_agent.py b/python/packages/a2a/tests/test_a2a_agent.py index a426c27a7f..d2fcfba2c8 100644 --- a/python/packages/a2a/tests/test_a2a_agent.py +++ b/python/packages/a2a/tests/test_a2a_agent.py @@ -23,6 +23,9 @@ from agent_framework import ( AgentResponse, AgentResponseUpdate, + AgentSession, + BaseContextProvider, + BaseHistoryProvider, Content, Message, ) @@ -850,4 +853,280 @@ async def test_poll_task_completed(a2a_agent: A2AAgent, mock_a2a_client: MockA2A assert response.messages[0].text == "Poll result" +# region context_providers + + +class TrackingContextProvider(BaseContextProvider): + """Context provider that tracks before_run/after_run calls.""" + + def __init__(self, source_id: str = "tracking") -> None: + super().__init__(source_id=source_id) + self.before_run_called = False + self.after_run_called = False + self.before_run_session: AgentSession | None = None + self.after_run_session: AgentSession | None = None + self.after_run_response: AgentResponse | None = None + + async def before_run(self, *, agent, session, context, state) -> None: + self.before_run_called = True + self.before_run_session = session + + async def after_run(self, *, agent, session, context, state) -> None: + self.after_run_called = True + self.after_run_session = session + self.after_run_response = context.response + + +async def test_run_invokes_context_providers(mock_a2a_client: MockA2AClient) -> None: + """Test that run() calls before_run and after_run on context providers.""" + provider = TrackingContextProvider() + agent = A2AAgent(name="Test Agent", client=mock_a2a_client, http_client=None, context_providers=[provider]) + + mock_a2a_client.add_message_response("msg-ctx", "Hello!", "agent") + + response = await agent.run("Hi") + + assert provider.before_run_called + assert provider.after_run_called + assert provider.after_run_response is not None + assert isinstance(response, AgentResponse) + assert response.messages[0].text == "Hello!" + + +async def test_run_invokes_context_providers_with_session(mock_a2a_client: MockA2AClient) -> None: + """Test that context providers receive the provided session.""" + provider = TrackingContextProvider() + agent = A2AAgent(name="Test Agent", client=mock_a2a_client, http_client=None, context_providers=[provider]) + session = AgentSession(session_id="test-session") + + mock_a2a_client.add_message_response("msg-sess", "With session", "agent") + + await agent.run("Hi", session=session) + + assert provider.before_run_session is session + assert provider.after_run_session is session + + +async def test_run_creates_session_for_providers_when_none(mock_a2a_client: MockA2AClient) -> None: + """Test that a session is auto-created when context_providers are set but no session is passed.""" + provider = TrackingContextProvider() + agent = A2AAgent(name="Test Agent", client=mock_a2a_client, http_client=None, context_providers=[provider]) + + mock_a2a_client.add_message_response("msg-auto", "Auto session", "agent") + + await agent.run("Hi") + + assert provider.before_run_session is not None + assert provider.after_run_session is not None + + +async def test_streaming_invokes_context_providers(mock_a2a_client: MockA2AClient) -> None: + """Test that streaming run() calls before_run and after_run on context providers.""" + provider = TrackingContextProvider() + agent = A2AAgent(name="Test Agent", client=mock_a2a_client, http_client=None, context_providers=[provider]) + + mock_a2a_client.add_message_response("msg-stream-ctx", "Streamed!", "agent") + + response = await agent.run("Hi", stream=True).get_final_response() + + assert provider.before_run_called + assert provider.after_run_called + assert provider.after_run_response is not None + assert response.messages[0].text == "Streamed!" + + +async def test_run_without_providers_still_works(mock_a2a_client: MockA2AClient) -> None: + """Test that run() without context_providers still works correctly.""" + agent = A2AAgent(name="Test Agent", client=mock_a2a_client, http_client=None) + + mock_a2a_client.add_message_response("msg-no-ctx", "No providers", "agent") + + response = await agent.run("Hi") + + assert isinstance(response, AgentResponse) + assert response.messages[0].text == "No providers" + + +async def test_multiple_providers_invoked_in_order(mock_a2a_client: MockA2AClient) -> None: + """Test that multiple context providers are called in forward/reverse order.""" + call_order: list[str] = [] + + class OrderTrackingProvider(BaseContextProvider): + async def before_run(self, *, agent, session, context, state) -> None: + call_order.append(f"before:{self.source_id}") + + async def after_run(self, *, agent, session, context, state) -> None: + call_order.append(f"after:{self.source_id}") + + provider_a = OrderTrackingProvider(source_id="a") + provider_b = OrderTrackingProvider(source_id="b") + agent = A2AAgent( + name="Test Agent", client=mock_a2a_client, http_client=None, context_providers=[provider_a, provider_b] + ) + + mock_a2a_client.add_message_response("msg-order", "Ordered", "agent") + + await agent.run("Hi") + + assert call_order == ["before:a", "before:b", "after:b", "after:a"] + + +class TrackingHistoryProvider(BaseHistoryProvider): + """History provider that tracks before_run/after_run calls.""" + + def __init__(self, source_id: str = "history", *, load_messages: bool = True) -> None: + super().__init__(source_id=source_id, load_messages=load_messages) + self.before_run_called = False + self.after_run_called = False + + async def before_run(self, *, agent, session, context, state) -> None: + self.before_run_called = True + + async def after_run(self, *, agent, session, context, state) -> None: + self.after_run_called = True + + async def get_messages(self, session_id, **kwargs) -> list[Message]: + return [] + + async def save_messages(self, session_id, messages, **kwargs) -> None: + pass + + +async def test_history_provider_load_messages_false_skips_before_run(mock_a2a_client: MockA2AClient) -> None: + """Test that BaseHistoryProvider with load_messages=False has before_run skipped.""" + provider = TrackingHistoryProvider(load_messages=False) + agent = A2AAgent(name="Test Agent", client=mock_a2a_client, http_client=None, context_providers=[provider]) + + mock_a2a_client.add_message_response("msg-hist", "Hello!", "agent") + + await agent.run("Hi") + + assert not provider.before_run_called + assert provider.after_run_called + + +async def test_history_provider_load_messages_false_raises_if_before_run_called( + mock_a2a_client: MockA2AClient, +) -> None: + """Test with a stub whose before_run raises, proving it is never invoked.""" + + class FailingHistoryProvider(BaseHistoryProvider): + def __init__(self) -> None: + super().__init__(source_id="fail-hist", load_messages=False) + self.after_run_called = False + + async def before_run(self, *, agent, session, context, state) -> None: + raise AssertionError("before_run should not be called when load_messages=False") + + async def after_run(self, *, agent, session, context, state) -> None: + self.after_run_called = True + + async def get_messages(self, session_id, **kwargs) -> list[Message]: + return [] + + async def save_messages(self, session_id, messages, **kwargs) -> None: + pass + + provider = FailingHistoryProvider() + agent = A2AAgent(name="Test Agent", client=mock_a2a_client, http_client=None, context_providers=[provider]) + + mock_a2a_client.add_message_response("msg-fail", "OK", "agent") + + # Should not raise — before_run is skipped + await agent.run("Hi") + assert provider.after_run_called + + +async def test_history_provider_load_messages_true_calls_before_run(mock_a2a_client: MockA2AClient) -> None: + """Test that BaseHistoryProvider with load_messages=True (default) has before_run called.""" + provider = TrackingHistoryProvider(load_messages=True) + agent = A2AAgent(name="Test Agent", client=mock_a2a_client, http_client=None, context_providers=[provider]) + + mock_a2a_client.add_message_response("msg-hist-true", "Hello!", "agent") + + await agent.run("Hi") + + assert provider.before_run_called + assert provider.after_run_called + + +async def test_history_provider_load_messages_false_streaming(mock_a2a_client: MockA2AClient) -> None: + """Test that streaming skips before_run for BaseHistoryProvider with load_messages=False.""" + provider = TrackingHistoryProvider(load_messages=False) + agent = A2AAgent(name="Test Agent", client=mock_a2a_client, http_client=None, context_providers=[provider]) + + mock_a2a_client.add_message_response("msg-hist-stream", "Streamed!", "agent") + + await agent.run("Hi", stream=True).get_final_response() + + assert not provider.before_run_called + assert provider.after_run_called + + +async def test_mixed_providers_with_history_load_messages_false(mock_a2a_client: MockA2AClient) -> None: + """Test that a regular provider's before_run is called while history provider's is skipped.""" + context_provider = TrackingContextProvider(source_id="ctx") + history_provider = TrackingHistoryProvider(source_id="hist", load_messages=False) + agent = A2AAgent( + name="Test Agent", + client=mock_a2a_client, + http_client=None, + context_providers=[context_provider, history_provider], + ) + + mock_a2a_client.add_message_response("msg-mixed", "Mixed!", "agent") + + await agent.run("Hi") + + assert context_provider.before_run_called + assert not history_provider.before_run_called + assert context_provider.after_run_called + assert history_provider.after_run_called + + +async def test_resume_via_continuation_token_with_context_providers(mock_a2a_client: MockA2AClient) -> None: + """Test that non-streaming run() with continuation_token correctly invokes context providers.""" + provider = TrackingContextProvider() + agent = A2AAgent(name="Test Agent", client=mock_a2a_client, http_client=None, context_providers=[provider]) + + status = TaskStatus(state=TaskState.completed, message=None) + artifact = Artifact( + artifact_id="art-ctx-resume", + name="result", + parts=[Part(root=TextPart(text="Resumed with providers"))], + ) + task = Task(id="task-ctx-resume", context_id="ctx-cr", status=status, artifacts=[artifact]) + mock_a2a_client.resubscribe_responses.append((task, None)) + + token = A2AContinuationToken(task_id="task-ctx-resume", context_id="ctx-cr") + response = await agent.run(continuation_token=token) + + assert isinstance(response, AgentResponse) + assert response.messages[0].text == "Resumed with providers" + assert provider.before_run_called + assert provider.after_run_called + assert provider.after_run_response is not None + + +async def test_resume_via_continuation_token_no_context_providers(mock_a2a_client: MockA2AClient) -> None: + """Test that run() with continuation_token and no context_providers works without crash.""" + agent = A2AAgent(name="Test Agent", client=mock_a2a_client, http_client=None) + + status = TaskStatus(state=TaskState.completed, message=None) + artifact = Artifact( + artifact_id="art-no-ctx", + name="result", + parts=[Part(root=TextPart(text="Resumed no providers"))], + ) + task = Task(id="task-no-ctx", context_id="ctx-nc", status=status, artifacts=[artifact]) + mock_a2a_client.resubscribe_responses.append((task, None)) + + token = A2AContinuationToken(task_id="task-no-ctx", context_id="ctx-nc") + response = await agent.run(continuation_token=token) + + assert isinstance(response, AgentResponse) + assert response.messages[0].text == "Resumed no providers" + assert response.continuation_token is None + + # endregion