Skip to content

Commit ccfd37f

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: enable a2a streaming for agents deployed to Agent Engine.
PiperOrigin-RevId: 884414893
1 parent 15501c8 commit ccfd37f

File tree

1 file changed

+33
-9
lines changed
  • vertexai/preview/reasoning_engines/templates

1 file changed

+33
-9
lines changed

vertexai/preview/reasoning_engines/templates/a2a.py

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# limitations under the License.
1515
#
1616

17+
from collections.abc import AsyncIterator
1718
import os
1819
from typing import Any, Callable, Dict, List, Mapping, Optional, TYPE_CHECKING
1920

@@ -55,11 +56,13 @@ def create_agent_card(
5556
agent_card: Optional[Dict[str, Any]] = None,
5657
default_input_modes: Optional[List[str]] = None,
5758
default_output_modes: Optional[List[str]] = None,
59+
streaming: bool = False,
5860
) -> "AgentCard":
5961
"""Creates an AgentCard object.
6062
6163
The function can be called in two ways:
62-
1. By providing the individual parameters: agent_name, description, and skills.
64+
1. By providing the individual parameters: agent_name, description, and
65+
skills.
6366
2. By providing a single dictionary containing all the data.
6467
6568
If a dictionary is provided, the other parameters are ignored.
@@ -69,16 +72,19 @@ def create_agent_card(
6972
description (Optional[str]): A description of the agent.
7073
skills (Optional[List[AgentSkill]]): A list of AgentSkills.
7174
agent_card (Optional[Dict[str, Any]]): Agent Card as a dictionary.
72-
default_input_modes (Optional[List[str]]): A list of input modes,
73-
default to ["text/plain"].
75+
default_input_modes (Optional[List[str]]): A list of input modes, default
76+
to ["text/plain"].
7477
default_output_modes (Optional[List[str]]): A list of output modes,
75-
default to ["application/json"].
78+
default to ["application/json"].
79+
streaming (bool): Whether to enable streaming for the agent. Defaults to
80+
False.
7681
7782
Returns:
7883
AgentCard: A fully constructed AgentCard object.
7984
8085
Raises:
81-
ValueError: If neither a dictionary nor the required parameters are provided.
86+
ValueError: If neither a dictionary nor the required parameters are
87+
provided.
8288
"""
8389
# pylint: disable=g-import-not-at-top
8490
from a2a.types import AgentCard, AgentCapabilities, TransportProtocol
@@ -96,8 +102,7 @@ def create_agent_card(
96102
version="1.0.0",
97103
default_input_modes=default_input_modes or ["text/plain"],
98104
default_output_modes=default_output_modes or ["application/json"],
99-
# Agent Engine does not support streaming yet
100-
capabilities=AgentCapabilities(streaming=False),
105+
capabilities=AgentCapabilities(streaming=streaming),
101106
skills=skills,
102107
preferred_transport=TransportProtocol.http_json, # Http Only.
103108
supports_authenticated_extended_card=True,
@@ -185,8 +190,6 @@ def __init__(
185190
raise ValueError(
186191
"Only HTTP+JSON is supported for preferred transport on agent card "
187192
)
188-
if agent_card.capabilities and agent_card.capabilities.streaming:
189-
raise ValueError("Streaming is not supported by Agent Engine")
190193

191194
self._tmpl_attrs: dict[str, Any] = {
192195
"project": initializer.global_config.project,
@@ -334,6 +337,27 @@ def register_operations(self) -> Dict[str, List[str]]:
334337
"on_cancel_task",
335338
]
336339
}
340+
if self.agent_card.capabilities and self.agent_card.capabilities.streaming:
341+
routes["a2a_extension"].append("on_message_send_stream")
342+
routes["a2a_extension"].append("on_resubscribe_to_task")
337343
if self.agent_card.supports_authenticated_extended_card:
338344
routes["a2a_extension"].append("handle_authenticated_agent_card")
339345
return routes
346+
347+
async def on_message_send_stream(
348+
self,
349+
request: "Request",
350+
context: "ServerCallContext",
351+
) -> AsyncIterator[str]:
352+
"""Handles A2A streaming requests via SSE."""
353+
async for chunk in self.rest_handler.on_message_send_stream(request, context):
354+
yield chunk
355+
356+
async def on_resubscribe_to_task(
357+
self,
358+
request: "Request",
359+
context: "ServerCallContext",
360+
) -> AsyncIterator[str]:
361+
"""Handles A2A task resubscription requests via SSE."""
362+
async for chunk in self.rest_handler.on_resubscribe_to_task(request, context):
363+
yield chunk

0 commit comments

Comments
 (0)