Skip to content
123 changes: 109 additions & 14 deletions python/packages/a2a/agent_framework_a2a/_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,12 @@
AgentResponseUpdate,
AgentSession,
BaseAgent,
BaseHistoryProvider,
Content,
ContinuationToken,
Message,
ResponseStream,
SessionContext,
normalize_messages,
prepend_agent_framework_to_user_agent,
)
Expand Down Expand Up @@ -284,22 +286,115 @@ def run( # pyright: ignore[reportIncompatibleMethodOverride]
When stream=True: A ResponseStream of AgentResponseUpdate items.
"""
del function_invocation_kwargs, client_kwargs, kwargs
if continuation_token is not None:
a2a_stream: AsyncIterable[A2AStreamItem] = self.client.resubscribe(
TaskIdParams(id=continuation_token["task_id"])
normalized_messages = normalize_messages(messages) if continuation_token is None else None

if not stream:

async def _run_non_streaming() -> AgentResponse[Any]:
active_session: AgentSession | None = None
session_context: SessionContext | None = None
if self.context_providers:
active_session, session_context = await self._run_before_providers(
session=session,
input_messages=normalized_messages,
)
if continuation_token is not None:
a2a_stream: AsyncIterable[A2AStreamItem] = self.client.resubscribe(
TaskIdParams(id=continuation_token["task_id"])
)
else:
a2a_message = self._prepare_message_for_a2a(normalized_messages[-1]) # type: ignore[index]
a2a_stream = self.client.send_message(a2a_message)

response_stream = ResponseStream(
self._map_a2a_stream(a2a_stream, background=background),
finalizer=AgentResponse.from_updates,
)
result = await response_stream.get_final_response()
if self.context_providers and session_context is not None:
session_context._response = result # type: ignore[assignment] # pyright: ignore[reportPrivateUsage]
await self._run_after_providers(session=active_session, context=session_context)
return result

return _run_non_streaming()

# Streaming path
active_session_holder: dict[str, AgentSession | None] = {"session": None}
context_holder: dict[str, SessionContext | None] = {"ctx": None}

async def _post_hook(response: AgentResponse) -> None:
if not self.context_providers:
return
session_context = context_holder["ctx"]
if session_context is None:
return
session_context._response = response # type: ignore[assignment] # pyright: ignore[reportPrivateUsage]
await self._run_after_providers(session=active_session_holder["session"], context=session_context)

async def _get_stream() -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
if self.context_providers:
active_session, session_context = await self._run_before_providers(
session=session,
input_messages=normalized_messages,
)
active_session_holder["session"] = active_session
context_holder["ctx"] = session_context

if continuation_token is not None:
a2a_stream: AsyncIterable[A2AStreamItem] = self.client.resubscribe(
TaskIdParams(id=continuation_token["task_id"])
)
else:
a2a_message = self._prepare_message_for_a2a(normalized_messages[-1]) # type: ignore[index]
a2a_stream = self.client.send_message(a2a_message)

return ResponseStream(
self._map_a2a_stream(a2a_stream, background=background),
finalizer=AgentResponse.from_updates,
)
else:
normalized_messages = normalize_messages(messages)
a2a_message = self._prepare_message_for_a2a(normalized_messages[-1])
a2a_stream = self.client.send_message(a2a_message)

response = ResponseStream(
self._map_a2a_stream(a2a_stream, background=background),
finalizer=AgentResponse.from_updates,

return (
ResponseStream.from_awaitable(_get_stream()).with_result_hook(_post_hook) # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
)
if stream:
return response
return response.get_final_response()

async def _run_before_providers(
self,
*,
session: AgentSession | None,
input_messages: list[Message] | None,
) -> tuple[AgentSession | None, SessionContext]:
"""Run before_run on all context providers and return the active session and context.

Keyword Args:
session: The conversation session (None for stateless invocation).
input_messages: Messages to process.

Returns:
A tuple of (active_session, session_context).
"""
active_session = session
if active_session is None and self.context_providers:
active_session = AgentSession()

session_context = SessionContext(
session_id=active_session.session_id if active_session else None,
service_session_id=active_session.service_session_id if active_session else None,
input_messages=input_messages or [],
)

for provider in self.context_providers:
if isinstance(provider, BaseHistoryProvider) and not provider.load_messages:
continue
if active_session is None:
raise RuntimeError("Provider session must be available when context providers are configured.")
await provider.before_run(
agent=self,
session=active_session,
context=session_context,
state=active_session.state.setdefault(provider.source_id, {}),
)

return active_session, session_context

async def _map_a2a_stream(
self,
Expand Down
Loading
Loading