Skip to content

Commit 3211393

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: update sdk to support a2a 1.0
PiperOrigin-RevId: 890388363
1 parent e164b19 commit 3211393

4 files changed

Lines changed: 121 additions & 40 deletions

File tree

vertexai/_genai/_agent_engines_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -632,9 +632,9 @@ def _generate_class_methods_spec_or_raise(
632632
class_method = _to_proto(schema_dict)
633633
class_method[_MODE_KEY_IN_SCHEMA] = mode
634634
if hasattr(agent, "agent_card"):
635-
class_method[_A2A_AGENT_CARD] = getattr(
636-
agent, "agent_card"
637-
).model_dump_json()
635+
class_method[_A2A_AGENT_CARD] = json_format.MessageToJson(
636+
getattr(agent, "agent_card")
637+
)
638638
class_methods_spec.append(class_method)
639639

640640
return class_methods_spec

vertexai/_genai/agent_engines.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1840,10 +1840,12 @@ def _create_config(
18401840
agent_card = getattr(agent, "agent_card")
18411841
if agent_card:
18421842
try:
1843-
agent_engine_spec["agent_card"] = agent_card.model_dump(
1844-
exclude_none=True
1843+
from google.protobuf import json_format
1844+
import json
1845+
agent_engine_spec["agent_card"] = json.loads(
1846+
json_format.MessageToJson(agent_card)
18451847
)
1846-
except TypeError as e:
1848+
except Exception as e:
18471849
raise ValueError(
18481850
f"Failed to convert agent card to dict (serialization error): {e}"
18491851
) from e

vertexai/agent_engines/_agent_engines.py

Lines changed: 32 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -119,23 +119,28 @@
119119
try:
120120
from a2a.types import (
121121
AgentCard,
122-
TransportProtocol,
122+
AgentInterface,
123123
Message,
124124
TaskIdParams,
125125
TaskQueryParams,
126126
)
127+
from a2a.utils.constants import TransportProtocol, PROTOCOL_VERSION_CURRENT
127128
from a2a.client import ClientConfig, ClientFactory
128129

129130
AgentCard = AgentCard
131+
AgentInterface = AgentInterface
130132
TransportProtocol = TransportProtocol
133+
PROTOCOL_VERSION_CURRENT = PROTOCOL_VERSION_CURRENT
131134
Message = Message
132135
ClientConfig = ClientConfig
133136
ClientFactory = ClientFactory
134137
TaskIdParams = TaskIdParams
135138
TaskQueryParams = TaskQueryParams
136139
except (ImportError, AttributeError):
137140
AgentCard = None
141+
AgentInterface = None
138142
TransportProtocol = None
143+
PROTOCOL_VERSION_CURRENT = None
139144
Message = None
140145
ClientConfig = None
141146
ClientFactory = None
@@ -1735,17 +1740,27 @@ async def _method(self, **kwargs) -> Any:
17351740
a2a_agent_card = AgentCard(**json.loads(agent_card))
17361741

17371742
# A2A + AE integration currently only supports Rest API.
1738-
if (
1739-
a2a_agent_card.preferred_transport
1740-
and a2a_agent_card.preferred_transport != TransportProtocol.http_json
1741-
):
1742-
raise ValueError(
1743-
"Only HTTP+JSON is supported for preferred transport on agent card "
1744-
)
1743+
has_http_transport = False
1744+
if a2a_agent_card.supported_interfaces:
1745+
for interface in a2a_agent_card.supported_interfaces:
1746+
if interface.protocol_binding == TransportProtocol.HTTP_JSON:
1747+
has_http_transport = True
1748+
break
1749+
if not has_http_transport:
1750+
raise ValueError(
1751+
"Only HTTP+JSON is supported for preferred transport on agent card "
1752+
)
17451753

17461754
# Set preferred transport to HTTP+JSON if not set.
1747-
if not hasattr(a2a_agent_card, "preferred_transport"):
1748-
a2a_agent_card.preferred_transport = TransportProtocol.http_json
1755+
if not has_http_transport:
1756+
if a2a_agent_card.supported_interfaces is None:
1757+
a2a_agent_card.supported_interfaces = []
1758+
a2a_agent_card.supported_interfaces.append(
1759+
AgentInterface(
1760+
protocol_binding=TransportProtocol.HTTP_JSON,
1761+
protocol_version=PROTOCOL_VERSION_CURRENT,
1762+
)
1763+
)
17491764

17501765
# AE cannot support streaming yet. Turn off streaming for now.
17511766
if a2a_agent_card.capabilities and a2a_agent_card.capabilities.streaming:
@@ -1759,12 +1774,13 @@ async def _method(self, **kwargs) -> Any:
17591774

17601775
# agent_card is set on the class_methods before set_up is invoked.
17611776
# Ensure that the agent_card url is set correctly before the client is created.
1762-
a2a_agent_card.url = f"https://{initializer.global_config.api_endpoint}/v1beta1/{self.resource_name}/a2a"
1777+
url = f"https://{initializer.global_config.api_endpoint}/v1beta1/{self.resource_name}/a2a"
1778+
a2a_agent_card.supported_interfaces[0].url = url
17631779

17641780
# Using a2a client, inject the auth token from the global config.
17651781
config = ClientConfig(
17661782
supported_transports=[
1767-
TransportProtocol.http_json,
1783+
TransportProtocol.HTTP_JSON,
17681784
],
17691785
use_client_preference=True,
17701786
httpx_client=httpx.AsyncClient(
@@ -1977,9 +1993,10 @@ def _generate_class_methods_spec_or_raise(
19771993
class_method[_MODE_KEY_IN_SCHEMA] = mode
19781994
# A2A agent card is a special case, when running in A2A mode,
19791995
if hasattr(agent_engine, "agent_card"):
1980-
class_method[_A2A_AGENT_CARD] = getattr(
1981-
agent_engine, "agent_card"
1982-
).model_dump_json()
1996+
from google.protobuf import json_format
1997+
class_method[_A2A_AGENT_CARD] = json_format.MessageToJson(
1998+
getattr(agent_engine, "agent_card")
1999+
)
19832000
class_methods_spec.append(class_method)
19842001

19852002
return class_methods_spec

vertexai/preview/reasoning_engines/templates/a2a.py

Lines changed: 81 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,8 @@ def create_agent_card(
8787
provided.
8888
"""
8989
# pylint: disable=g-import-not-at-top
90-
from a2a.types import AgentCard, AgentCapabilities, TransportProtocol
90+
from a2a.types import AgentCard, AgentCapabilities, AgentInterface
91+
from a2a.utils.constants import TransportProtocol, PROTOCOL_VERSION_CURRENT
9192

9293
# Check if a dictionary was provided.
9394
if agent_card:
@@ -98,14 +99,18 @@ def create_agent_card(
9899
return AgentCard(
99100
name=agent_name,
100101
description=description,
101-
url="http://localhost:9999/",
102102
version="1.0.0",
103103
default_input_modes=default_input_modes or ["text/plain"],
104104
default_output_modes=default_output_modes or ["application/json"],
105-
capabilities=AgentCapabilities(streaming=streaming),
105+
capabilities=AgentCapabilities(streaming=streaming, extended_agent_card=True),
106106
skills=skills,
107-
preferred_transport=TransportProtocol.http_json, # Http Only.
108-
supports_authenticated_extended_card=True,
107+
supported_interfaces=[
108+
AgentInterface(
109+
url="http://localhost:9999/",
110+
protocol_binding=TransportProtocol.HTTP_JSON,
111+
protocol_version=PROTOCOL_VERSION_CURRENT,
112+
)
113+
],
109114
)
110115

111116
# Raise an error if insufficient data is provided.
@@ -181,15 +186,18 @@ def __init__(
181186
"""Initializes the A2A agent."""
182187
# pylint: disable=g-import-not-at-top
183188
from google.cloud.aiplatform import initializer
184-
from a2a.types import TransportProtocol
185-
186-
if (
187-
agent_card.preferred_transport
188-
and agent_card.preferred_transport != TransportProtocol.http_json
189-
):
190-
raise ValueError(
191-
"Only HTTP+JSON is supported for preferred transport on agent card "
192-
)
189+
from a2a.utils.constants import TransportProtocol
190+
191+
if agent_card.supported_interfaces:
192+
has_http_transport = False
193+
for interface in agent_card.supported_interfaces:
194+
if interface.protocol_binding == TransportProtocol.HTTP_JSON:
195+
has_http_transport = True
196+
break
197+
if not has_http_transport:
198+
raise ValueError(
199+
"Only HTTP+JSON is supported for preferred transport on agent card "
200+
)
193201

194202
self._tmpl_attrs: dict[str, Any] = {
195203
"project": initializer.global_config.project,
@@ -244,7 +252,20 @@ def set_up(self):
244252
agent_engine_id = os.getenv("GOOGLE_CLOUD_AGENT_ENGINE_ID", "test-agent-engine")
245253
version = "v1beta1"
246254

247-
self.agent_card.url = f"https://{location}-aiplatform.googleapis.com/{version}/projects/{project}/locations/{location}/reasoningEngines/{agent_engine_id}/a2a"
255+
new_url = f"https://{location}-aiplatform.googleapis.com/{version}/projects/{project}/locations/{location}/reasoningEngines/{agent_engine_id}/a2a"
256+
if not self.agent_card.supported_interfaces:
257+
from a2a.types import AgentInterface
258+
from a2a.utils.constants import TransportProtocol, PROTOCOL_VERSION_CURRENT
259+
260+
self.agent_card.supported_interfaces.append(
261+
AgentInterface(
262+
url=new_url,
263+
protocol_binding=TransportProtocol.HTTP_JSON,
264+
protocol_version=PROTOCOL_VERSION_CURRENT,
265+
)
266+
)
267+
else:
268+
self.agent_card.supported_interfaces[0].url = new_url
248269
self._tmpl_attrs["agent_card"] = self.agent_card
249270

250271
# Create the agent executor if a builder is provided.
@@ -339,8 +360,8 @@ def register_operations(self) -> Dict[str, List[str]]:
339360
}
340361
if self.agent_card.capabilities and self.agent_card.capabilities.streaming:
341362
routes["a2a_extension"].append("on_message_send_stream")
342-
routes["a2a_extension"].append("on_resubscribe_to_task")
343-
if self.agent_card.supports_authenticated_extended_card:
363+
routes["a2a_extension"].append("on_subscribe_to_task")
364+
if self.agent_card.capabilities and self.agent_card.capabilities.extended_agent_card:
344365
routes["a2a_extension"].append("handle_authenticated_agent_card")
345366
return routes
346367

@@ -353,11 +374,52 @@ async def on_message_send_stream(
353374
async for chunk in self.rest_handler.on_message_send_stream(request, context):
354375
yield chunk
355376

356-
async def on_resubscribe_to_task(
377+
async def on_subscribe_to_task(
357378
self,
358379
request: "Request",
359380
context: "ServerCallContext",
360381
) -> AsyncIterator[str]:
361382
"""Handles A2A task resubscription requests via SSE."""
362-
async for chunk in self.rest_handler.on_resubscribe_to_task(request, context):
383+
async for chunk in self.rest_handler.on_subscribe_to_task(request, context):
363384
yield chunk
385+
386+
def __getstate__(self):
387+
"""Serializes the A2A agent for pickling."""
388+
from google.protobuf import json_format
389+
import json
390+
391+
state = self.__dict__.copy()
392+
393+
def _to_dict_if_proto(obj):
394+
if hasattr(obj, "DESCRIPTOR"):
395+
return {"__protobuf_AgentCard__": json.loads(json_format.MessageToJson(obj))}
396+
return obj
397+
398+
state["agent_card"] = _to_dict_if_proto(state.get("agent_card"))
399+
if "_tmpl_attrs" in state:
400+
tmpl_attrs = state["_tmpl_attrs"].copy()
401+
tmpl_attrs["agent_card"] = _to_dict_if_proto(tmpl_attrs.get("agent_card"))
402+
tmpl_attrs["extended_agent_card"] = _to_dict_if_proto(tmpl_attrs.get("extended_agent_card"))
403+
state["_tmpl_attrs"] = tmpl_attrs
404+
405+
return state
406+
407+
def __setstate__(self, state):
408+
"""Deserializes the A2A agent for unpickling."""
409+
from google.protobuf import json_format
410+
from a2a.types import AgentCard
411+
412+
def _from_dict_if_proto(obj):
413+
if isinstance(obj, dict) and "__protobuf_AgentCard__" in obj:
414+
agent_card = AgentCard()
415+
json_format.ParseDict(obj["__protobuf_AgentCard__"], agent_card)
416+
return agent_card
417+
return obj
418+
419+
state["agent_card"] = _from_dict_if_proto(state.get("agent_card"))
420+
if "_tmpl_attrs" in state:
421+
state["_tmpl_attrs"]["agent_card"] = _from_dict_if_proto(state["_tmpl_attrs"].get("agent_card"))
422+
state["_tmpl_attrs"]["extended_agent_card"] = _from_dict_if_proto(state["_tmpl_attrs"].get("extended_agent_card"))
423+
424+
self.__dict__.update(state)
425+

0 commit comments

Comments
 (0)