Skip to content

Commit 49e2e2b

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: update sdk to support a2a 1.0
PiperOrigin-RevId: 890388363
1 parent c12aedc commit 49e2e2b

4 files changed

Lines changed: 105 additions & 36 deletions

File tree

vertexai/_genai/_agent_engines_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -632,9 +632,9 @@ def _generate_class_methods_spec_or_raise(
632632
class_method = _to_proto(schema_dict)
633633
class_method[_MODE_KEY_IN_SCHEMA] = mode
634634
if hasattr(agent, "agent_card"):
635-
class_method[_A2A_AGENT_CARD] = getattr(
636-
agent, "agent_card"
637-
).model_dump_json()
635+
class_method[_A2A_AGENT_CARD] = json_format.MessageToJson(
636+
getattr(agent, "agent_card")
637+
)
638638
class_methods_spec.append(class_method)
639639

640640
return class_methods_spec

vertexai/_genai/agent_engines.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1840,10 +1840,12 @@ def _create_config(
18401840
agent_card = getattr(agent, "agent_card")
18411841
if agent_card:
18421842
try:
1843-
agent_engine_spec["agent_card"] = agent_card.model_dump(
1844-
exclude_none=True
1843+
from google.protobuf import json_format
1844+
import json
1845+
agent_engine_spec["agent_card"] = json.loads(
1846+
json_format.MessageToJson(agent_card)
18451847
)
1846-
except TypeError as e:
1848+
except Exception as e:
18471849
raise ValueError(
18481850
f"Failed to convert agent card to dict (serialization error): {e}"
18491851
) from e

vertexai/agent_engines/_agent_engines.py

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -119,23 +119,28 @@
119119
try:
120120
from a2a.types import (
121121
AgentCard,
122-
TransportProtocol,
122+
AgentInterface,
123123
Message,
124124
TaskIdParams,
125125
TaskQueryParams,
126126
)
127+
from a2a.utils.constants import TransportProtocol, PROTOCOL_VERSION_CURRENT
127128
from a2a.client import ClientConfig, ClientFactory
128129

129130
AgentCard = AgentCard
131+
AgentInterface = AgentInterface
130132
TransportProtocol = TransportProtocol
133+
PROTOCOL_VERSION_CURRENT = PROTOCOL_VERSION_CURRENT
131134
Message = Message
132135
ClientConfig = ClientConfig
133136
ClientFactory = ClientFactory
134137
TaskIdParams = TaskIdParams
135138
TaskQueryParams = TaskQueryParams
136139
except (ImportError, AttributeError):
137140
AgentCard = None
141+
AgentInterface = None
138142
TransportProtocol = None
143+
PROTOCOL_VERSION_CURRENT = None
139144
Message = None
140145
ClientConfig = None
141146
ClientFactory = None
@@ -1735,17 +1740,20 @@ async def _method(self, **kwargs) -> Any:
17351740
a2a_agent_card = AgentCard(**json.loads(agent_card))
17361741

17371742
# A2A + AE integration currently only supports Rest API.
1738-
if (
1739-
a2a_agent_card.preferred_transport
1740-
and a2a_agent_card.preferred_transport != TransportProtocol.http_json
1741-
):
1743+
if a2a_agent_card.supported_interfaces and a2a_agent_card.supported_interfaces[0].protocol_binding != TransportProtocol.HTTP_JSON:
17421744
raise ValueError(
1743-
"Only HTTP+JSON is supported for preferred transport on agent card "
1745+
"Only HTTP+JSON is supported for primary interface on agent card "
17441746
)
17451747

1746-
# Set preferred transport to HTTP+JSON if not set.
1747-
if not hasattr(a2a_agent_card, "preferred_transport"):
1748-
a2a_agent_card.preferred_transport = TransportProtocol.http_json
1748+
# Set primary interface to HTTP+JSON if not set.
1749+
if not a2a_agent_card.supported_interfaces:
1750+
a2a_agent_card.supported_interfaces = []
1751+
a2a_agent_card.supported_interfaces.append(
1752+
AgentInterface(
1753+
protocol_binding=TransportProtocol.HTTP_JSON,
1754+
protocol_version=PROTOCOL_VERSION_CURRENT,
1755+
)
1756+
)
17491757

17501758
# AE cannot support streaming yet. Turn off streaming for now.
17511759
if a2a_agent_card.capabilities and a2a_agent_card.capabilities.streaming:
@@ -1759,12 +1767,13 @@ async def _method(self, **kwargs) -> Any:
17591767

17601768
# agent_card is set on the class_methods before set_up is invoked.
17611769
# Ensure that the agent_card url is set correctly before the client is created.
1762-
a2a_agent_card.url = f"https://{initializer.global_config.api_endpoint}/v1beta1/{self.resource_name}/a2a"
1770+
url = f"https://{initializer.global_config.api_endpoint}/v1beta1/{self.resource_name}/a2a"
1771+
a2a_agent_card.supported_interfaces[0].url = url
17631772

17641773
# Using a2a client, inject the auth token from the global config.
17651774
config = ClientConfig(
17661775
supported_transports=[
1767-
TransportProtocol.http_json,
1776+
TransportProtocol.HTTP_JSON,
17681777
],
17691778
use_client_preference=True,
17701779
httpx_client=httpx.AsyncClient(
@@ -1977,9 +1986,10 @@ def _generate_class_methods_spec_or_raise(
19771986
class_method[_MODE_KEY_IN_SCHEMA] = mode
19781987
# A2A agent card is a special case, when running in A2A mode,
19791988
if hasattr(agent_engine, "agent_card"):
1980-
class_method[_A2A_AGENT_CARD] = getattr(
1981-
agent_engine, "agent_card"
1982-
).model_dump_json()
1989+
from google.protobuf import json_format
1990+
class_method[_A2A_AGENT_CARD] = json_format.MessageToJson(
1991+
getattr(agent_engine, "agent_card")
1992+
)
19831993
class_methods_spec.append(class_method)
19841994

19851995
return class_methods_spec

vertexai/preview/reasoning_engines/templates/a2a.py

Lines changed: 73 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,8 @@ def create_agent_card(
8787
provided.
8888
"""
8989
# pylint: disable=g-import-not-at-top
90-
from a2a.types import AgentCard, AgentCapabilities, TransportProtocol
90+
from a2a.types import AgentCard, AgentCapabilities, AgentInterface
91+
from a2a.utils.constants import TransportProtocol, PROTOCOL_VERSION_CURRENT
9192

9293
# Check if a dictionary was provided.
9394
if agent_card:
@@ -98,14 +99,18 @@ def create_agent_card(
9899
return AgentCard(
99100
name=agent_name,
100101
description=description,
101-
url="http://localhost:9999/",
102102
version="1.0.0",
103103
default_input_modes=default_input_modes or ["text/plain"],
104104
default_output_modes=default_output_modes or ["application/json"],
105-
capabilities=AgentCapabilities(streaming=streaming),
105+
capabilities=AgentCapabilities(streaming=streaming, extended_agent_card=True),
106106
skills=skills,
107-
preferred_transport=TransportProtocol.http_json, # Http Only.
108-
supports_authenticated_extended_card=True,
107+
supported_interfaces=[
108+
AgentInterface(
109+
url="http://localhost:9999/",
110+
protocol_binding=TransportProtocol.HTTP_JSON,
111+
protocol_version=PROTOCOL_VERSION_CURRENT,
112+
)
113+
],
109114
)
110115

111116
# Raise an error if insufficient data is provided.
@@ -181,14 +186,11 @@ def __init__(
181186
"""Initializes the A2A agent."""
182187
# pylint: disable=g-import-not-at-top
183188
from google.cloud.aiplatform import initializer
184-
from a2a.types import TransportProtocol
189+
from a2a.utils.constants import TransportProtocol
185190

186-
if (
187-
agent_card.preferred_transport
188-
and agent_card.preferred_transport != TransportProtocol.http_json
189-
):
191+
if agent_card.supported_interfaces and agent_card.supported_interfaces[0].interface.protocol_binding != TransportProtocol.HTTP_JSON:
190192
raise ValueError(
191-
"Only HTTP+JSON is supported for preferred transport on agent card "
193+
"Only HTTP+JSON is supported for the primary interface on agent card "
192194
)
193195

194196
self._tmpl_attrs: dict[str, Any] = {
@@ -244,7 +246,21 @@ def set_up(self):
244246
agent_engine_id = os.getenv("GOOGLE_CLOUD_AGENT_ENGINE_ID", "test-agent-engine")
245247
version = "v1beta1"
246248

247-
self.agent_card.url = f"https://{location}-aiplatform.googleapis.com/{version}/projects/{project}/locations/{location}/reasoningEngines/{agent_engine_id}/a2a"
249+
new_url = f"https://{location}-aiplatform.googleapis.com/{version}/projects/{project}/locations/{location}/reasoningEngines/{agent_engine_id}/a2a"
250+
if not self.agent_card.supported_interfaces:
251+
from a2a.types import AgentInterface
252+
from a2a.utils.constants import TransportProtocol, PROTOCOL_VERSION_CURRENT
253+
254+
self.agent_card.supported_interfaces.append(
255+
AgentInterface(
256+
url=new_url,
257+
protocol_binding=TransportProtocol.HTTP_JSON,
258+
protocol_version=PROTOCOL_VERSION_CURRENT,
259+
)
260+
)
261+
else:
262+
# primary interface must be HTTP+JSON
263+
self.agent_card.supported_interfaces[0].url = new_url
248264
self._tmpl_attrs["agent_card"] = self.agent_card
249265

250266
# Create the agent executor if a builder is provided.
@@ -339,8 +355,8 @@ def register_operations(self) -> Dict[str, List[str]]:
339355
}
340356
if self.agent_card.capabilities and self.agent_card.capabilities.streaming:
341357
routes["a2a_extension"].append("on_message_send_stream")
342-
routes["a2a_extension"].append("on_resubscribe_to_task")
343-
if self.agent_card.supports_authenticated_extended_card:
358+
routes["a2a_extension"].append("on_subscribe_to_task")
359+
if self.agent_card.capabilities and self.agent_card.capabilities.extended_agent_card:
344360
routes["a2a_extension"].append("handle_authenticated_agent_card")
345361
return routes
346362

@@ -353,11 +369,52 @@ async def on_message_send_stream(
353369
async for chunk in self.rest_handler.on_message_send_stream(request, context):
354370
yield chunk
355371

356-
async def on_resubscribe_to_task(
372+
async def on_subscribe_to_task(
357373
self,
358374
request: "Request",
359375
context: "ServerCallContext",
360376
) -> AsyncIterator[str]:
361377
"""Handles A2A task resubscription requests via SSE."""
362-
async for chunk in self.rest_handler.on_resubscribe_to_task(request, context):
378+
async for chunk in self.rest_handler.on_subscribe_to_task(request, context):
363379
yield chunk
380+
381+
def __getstate__(self):
382+
"""Serializes AgentCard proto to a dictionary."""
383+
from google.protobuf import json_format
384+
import json
385+
386+
state = self.__dict__.copy()
387+
388+
def _to_dict_if_proto(obj):
389+
if hasattr(obj, "DESCRIPTOR"):
390+
return {"__protobuf_AgentCard__": json.loads(json_format.MessageToJson(obj))}
391+
return obj
392+
393+
state["agent_card"] = _to_dict_if_proto(state.get("agent_card"))
394+
if "_tmpl_attrs" in state:
395+
tmpl_attrs = state["_tmpl_attrs"].copy()
396+
tmpl_attrs["agent_card"] = _to_dict_if_proto(tmpl_attrs.get("agent_card"))
397+
tmpl_attrs["extended_agent_card"] = _to_dict_if_proto(tmpl_attrs.get("extended_agent_card"))
398+
state["_tmpl_attrs"] = tmpl_attrs
399+
400+
return state
401+
402+
def __setstate__(self, state):
403+
"""Deserializes AgentCard proto from a dictionary."""
404+
from google.protobuf import json_format
405+
from a2a.types import AgentCard
406+
407+
def _from_dict_if_proto(obj):
408+
if isinstance(obj, dict) and "__protobuf_AgentCard__" in obj:
409+
agent_card = AgentCard()
410+
json_format.ParseDict(obj["__protobuf_AgentCard__"], agent_card)
411+
return agent_card
412+
return obj
413+
414+
state["agent_card"] = _from_dict_if_proto(state.get("agent_card"))
415+
if "_tmpl_attrs" in state:
416+
state["_tmpl_attrs"]["agent_card"] = _from_dict_if_proto(state["_tmpl_attrs"].get("agent_card"))
417+
state["_tmpl_attrs"]["extended_agent_card"] = _from_dict_if_proto(state["_tmpl_attrs"].get("extended_agent_card"))
418+
419+
self.__dict__.update(state)
420+

0 commit comments

Comments
 (0)