diff --git a/vertexai/_genai/_agent_engines_utils.py b/vertexai/_genai/_agent_engines_utils.py index ba388f4d38..9881efa2dd 100644 --- a/vertexai/_genai/_agent_engines_utils.py +++ b/vertexai/_genai/_agent_engines_utils.py @@ -632,9 +632,9 @@ def _generate_class_methods_spec_or_raise( class_method = _to_proto(schema_dict) class_method[_MODE_KEY_IN_SCHEMA] = mode if hasattr(agent, "agent_card"): - class_method[_A2A_AGENT_CARD] = getattr( - agent, "agent_card" - ).model_dump_json() + class_method[_A2A_AGENT_CARD] = json_format.MessageToJson( + getattr(agent, "agent_card") + ) class_methods_spec.append(class_method) return class_methods_spec @@ -1233,9 +1233,16 @@ def _upload_agent_engine( cloudpickle.dump(agent, f) except Exception as e: url = "https://cloud.google.com/vertex-ai/generative-ai/docs/agent-engine/develop/custom#deployment-considerations" - raise TypeError( - f"Failed to serialize agent engine. Visit {url} for details." - ) from e + error_msg = f"Failed to serialize agent engine. Visit {url} for details." + if "google._upb._message" in str(e) or "Descriptor" in str(e): + error_msg += ( + " This is often caused by protobuf objects (like Part, AgentCard) " + "being imported at the global module level. Please move these " + "imports inside the functions or methods where they are used. " + "Alternatively, you can import the entire module: " + "`from a2a import types`." + ) + raise TypeError(error_msg) from e with blob.open("rb") as f: try: _ = cloudpickle.load(f) diff --git a/vertexai/_genai/agent_engines.py b/vertexai/_genai/agent_engines.py index cf6e96c943..fa5d1a25e4 100644 --- a/vertexai/_genai/agent_engines.py +++ b/vertexai/_genai/agent_engines.py @@ -1834,10 +1834,13 @@ def _create_config( agent_card = getattr(agent, "agent_card") if agent_card: try: - agent_engine_spec["agent_card"] = agent_card.model_dump( - exclude_none=True + from google.protobuf import json_format + import json + + agent_engine_spec["agent_card"] = json.loads( + json_format.MessageToJson(agent_card) ) - except TypeError as e: + except Exception as e: raise ValueError( f"Failed to convert agent card to dict (serialization error): {e}" ) from e diff --git a/vertexai/agent_engines/_agent_engines.py b/vertexai/agent_engines/_agent_engines.py index dd4e35269d..1a90f4e3dd 100644 --- a/vertexai/agent_engines/_agent_engines.py +++ b/vertexai/agent_engines/_agent_engines.py @@ -119,15 +119,18 @@ try: from a2a.types import ( AgentCard, - TransportProtocol, + AgentInterface, Message, TaskIdParams, TaskQueryParams, ) + from a2a.utils.constants import TransportProtocol, PROTOCOL_VERSION_CURRENT from a2a.client import ClientConfig, ClientFactory AgentCard = AgentCard + AgentInterface = AgentInterface TransportProtocol = TransportProtocol + PROTOCOL_VERSION_CURRENT = PROTOCOL_VERSION_CURRENT Message = Message ClientConfig = ClientConfig ClientFactory = ClientFactory @@ -135,7 +138,9 @@ TaskQueryParams = TaskQueryParams except (ImportError, AttributeError): AgentCard = None + AgentInterface = None TransportProtocol = None + PROTOCOL_VERSION_CURRENT = None Message = None ClientConfig = None ClientFactory = None @@ -1216,9 +1221,16 @@ def _upload_agent_engine( cloudpickle.dump(agent_engine, f) except Exception as e: url = "https://cloud.google.com/vertex-ai/generative-ai/docs/agent-engine/develop/custom#deployment-considerations" - raise TypeError( - f"Failed to serialize agent engine. Visit {url} for details." - ) from e + error_msg = f"Failed to serialize agent engine. Visit {url} for details." + if "google._upb._message" in str(e) or "Descriptor" in str(e): + error_msg += ( + " This is often caused by protobuf objects (like Part, AgentCard) " + "being imported at the global module level. Please move these " + "imports inside the functions or methods where they are used. " + "Alternatively, you can import the entire module: " + "`from a2a import types as a2a_types`." + ) + raise TypeError(error_msg) from e with blob.open("rb") as f: try: _ = cloudpickle.load(f) @@ -1736,16 +1748,23 @@ async def _method(self, **kwargs) -> Any: # A2A + AE integration currently only supports Rest API. if ( - a2a_agent_card.preferred_transport - and a2a_agent_card.preferred_transport != TransportProtocol.http_json + a2a_agent_card.supported_interfaces + and a2a_agent_card.supported_interfaces[0].protocol_binding + != TransportProtocol.HTTP_JSON ): raise ValueError( - "Only HTTP+JSON is supported for preferred transport on agent card " + "Only HTTP+JSON is supported for primary interface on agent card " ) - # Set preferred transport to HTTP+JSON if not set. - if not hasattr(a2a_agent_card, "preferred_transport"): - a2a_agent_card.preferred_transport = TransportProtocol.http_json + # Set primary interface to HTTP+JSON if not set. + if not a2a_agent_card.supported_interfaces: + a2a_agent_card.supported_interfaces = [] + a2a_agent_card.supported_interfaces.append( + AgentInterface( + protocol_binding=TransportProtocol.HTTP_JSON, + protocol_version=PROTOCOL_VERSION_CURRENT, + ) + ) # AE cannot support streaming yet. Turn off streaming for now. if a2a_agent_card.capabilities and a2a_agent_card.capabilities.streaming: @@ -1759,12 +1778,13 @@ async def _method(self, **kwargs) -> Any: # agent_card is set on the class_methods before set_up is invoked. # Ensure that the agent_card url is set correctly before the client is created. - a2a_agent_card.url = f"https://{initializer.global_config.api_endpoint}/v1beta1/{self.resource_name}/a2a" + url = f"https://{initializer.global_config.api_endpoint}/v1beta1/{self.resource_name}/a2a" + a2a_agent_card.supported_interfaces[0].url = url # Using a2a client, inject the auth token from the global config. config = ClientConfig( supported_transports=[ - TransportProtocol.http_json, + TransportProtocol.HTTP_JSON, ], use_client_preference=True, httpx_client=httpx.AsyncClient( @@ -1977,9 +1997,11 @@ def _generate_class_methods_spec_or_raise( class_method[_MODE_KEY_IN_SCHEMA] = mode # A2A agent card is a special case, when running in A2A mode, if hasattr(agent_engine, "agent_card"): - class_method[_A2A_AGENT_CARD] = getattr( - agent_engine, "agent_card" - ).model_dump_json() + from google.protobuf import json_format + + class_method[_A2A_AGENT_CARD] = json_format.MessageToJson( + getattr(agent_engine, "agent_card") + ) class_methods_spec.append(class_method) return class_methods_spec diff --git a/vertexai/preview/reasoning_engines/templates/a2a.py b/vertexai/preview/reasoning_engines/templates/a2a.py index 724e2af41e..e4f9171308 100644 --- a/vertexai/preview/reasoning_engines/templates/a2a.py +++ b/vertexai/preview/reasoning_engines/templates/a2a.py @@ -87,7 +87,8 @@ def create_agent_card( provided. """ # pylint: disable=g-import-not-at-top - from a2a.types import AgentCard, AgentCapabilities, TransportProtocol + from a2a.types import AgentCard, AgentCapabilities, AgentInterface + from a2a.utils.constants import TransportProtocol, PROTOCOL_VERSION_CURRENT # Check if a dictionary was provided. if agent_card: @@ -98,14 +99,20 @@ def create_agent_card( return AgentCard( name=agent_name, description=description, - url="http://localhost:9999/", version="1.0.0", default_input_modes=default_input_modes or ["text/plain"], default_output_modes=default_output_modes or ["application/json"], - capabilities=AgentCapabilities(streaming=streaming), + capabilities=AgentCapabilities( + streaming=streaming, extended_agent_card=True + ), skills=skills, - preferred_transport=TransportProtocol.http_json, # Http Only. - supports_authenticated_extended_card=True, + supported_interfaces=[ + AgentInterface( + url="http://localhost:9999/", + protocol_binding=TransportProtocol.HTTP_JSON, + protocol_version=PROTOCOL_VERSION_CURRENT, + ) + ], ) # Raise an error if insufficient data is provided. @@ -162,6 +169,22 @@ async def cancel( ) +def _is_version_enabled(agent_card: "AgentCard", version: str) -> bool: + """Checks if a specific version compatibility should be enabled for the A2aAgent.""" + # pylint: disable=g-import-not-at-top + from a2a.utils.constants import TransportProtocol + + if not getattr(agent_card, "supported_interfaces", None): + return False + for interface in agent_card.supported_interfaces: + if ( + interface.protocol_version == version + and interface.protocol_binding == TransportProtocol.HTTP_JSON + ): + return True + return False + + class A2aAgent: """A class to initialize and set up an Agent-to-Agent application.""" @@ -181,14 +204,15 @@ def __init__( """Initializes the A2A agent.""" # pylint: disable=g-import-not-at-top from google.cloud.aiplatform import initializer - from a2a.types import TransportProtocol + from a2a.utils.constants import TransportProtocol if ( - agent_card.preferred_transport - and agent_card.preferred_transport != TransportProtocol.http_json + agent_card.supported_interfaces + and agent_card.supported_interfaces[0].interface.protocol_binding + != TransportProtocol.HTTP_JSON ): raise ValueError( - "Only HTTP+JSON is supported for preferred transport on agent card " + "Only HTTP+JSON is supported for the primary interface on agent card " ) self._tmpl_attrs: dict[str, Any] = { @@ -207,6 +231,7 @@ def __init__( "extended_agent_card": extended_agent_card, } self.agent_card = agent_card + self.a2a_rest_adapters = [] self.a2a_rest_adapter = None self.request_handler = None self.rest_handler = None @@ -232,7 +257,6 @@ def set_up(self): """Sets up the A2A application.""" # pylint: disable=g-import-not-at-top from a2a.server.apps.rest.rest_adapter import RESTAdapter - from a2a.server.request_handlers.rest_handler import RESTHandler from a2a.server.request_handlers import DefaultRequestHandler from a2a.server.tasks import InMemoryTaskStore @@ -244,7 +268,21 @@ def set_up(self): agent_engine_id = os.getenv("GOOGLE_CLOUD_AGENT_ENGINE_ID", "test-agent-engine") version = "v1beta1" - self.agent_card.url = f"https://{location}-aiplatform.googleapis.com/{version}/projects/{project}/locations/{location}/reasoningEngines/{agent_engine_id}/a2a" + new_url = f"https://{location}-aiplatform.googleapis.com/{version}/projects/{project}/locations/{location}/reasoningEngines/{agent_engine_id}/a2a" + if not self.agent_card.supported_interfaces: + from a2a.types import AgentInterface + from a2a.utils.constants import TransportProtocol, PROTOCOL_VERSION_CURRENT + + self.agent_card.supported_interfaces.append( + AgentInterface( + url=new_url, + protocol_binding=TransportProtocol.HTTP_JSON, + protocol_version=PROTOCOL_VERSION_CURRENT, + ) + ) + else: + # primary interface must be HTTP+JSON + self.agent_card.supported_interfaces[0].url = new_url self._tmpl_attrs["agent_card"] = self.agent_card # Create the agent executor if a builder is provided. @@ -286,45 +324,71 @@ def set_up(self): # a2a_rest_adapter is used to register the A2A API routes in the # Reasoning Engine API router. - self.a2a_rest_adapter = RESTAdapter( - agent_card=self.agent_card, - http_handler=self._tmpl_attrs.get("request_handler"), - extended_agent_card=self._tmpl_attrs.get("extended_agent_card"), - ) + if _is_version_enabled(self.agent_card, "1.0"): + self.a2a_rest_adapter = RESTAdapter( + agent_card=self.agent_card, + http_handler=self._tmpl_attrs.get("request_handler"), + extended_agent_card=self._tmpl_attrs.get("extended_agent_card"), + ) + self.a2a_rest_adapters.append(self.a2a_rest_adapter) - # rest_handler is used to handle the A2A API requests. - self.rest_handler = RESTHandler( - agent_card=self.agent_card, - request_handler=self._tmpl_attrs.get("request_handler"), - ) + # rest_handler is used to handle the A2A API requests. + self.rest_handler = self.a2a_rest_adapter.handler + + # v0.3 handlers will be deprecated in the future. + if _is_version_enabled(self.agent_card, "0.3"): + from a2a.compat.v0_3.rest_adapter import REST03Adapter + + adapter_03 = REST03Adapter( + agent_card=self.agent_card, + http_handler=self._tmpl_attrs.get("request_handler"), + extended_agent_card=self._tmpl_attrs.get("extended_agent_card"), + ) + self.a2a_rest_adapters.append(adapter_03) + + def _get_handler(self, kwargs: dict[str, Any]) -> Any: + handler = kwargs.get("rest_handler", getattr(self, "rest_handler", None)) + if not handler: + raise NotImplementedError("rest_handler not available.") + return handler + + def _get_adapter(self, kwargs: dict[str, Any]) -> Any: + adapter = kwargs.get("rest_adapter", getattr(self, "a2a_rest_adapter", None)) + if not adapter: + raise NotImplementedError("rest_adapter not available.") + return adapter async def on_message_send( self, request: "Request", context: "ServerCallContext", + **kwargs, ) -> dict[str, Any]: - return await self.rest_handler.on_message_send(request, context) + return await self._get_handler(kwargs).on_message_send(request, context) async def on_cancel_task( self, request: "Request", context: "ServerCallContext", + **kwargs, ) -> dict[str, Any]: - return await self.rest_handler.on_cancel_task(request, context) + return await self._get_handler(kwargs).on_cancel_task(request, context) async def on_get_task( self, request: "Request", context: "ServerCallContext", + **kwargs, ) -> dict[str, Any]: - return await self.rest_handler.on_get_task(request, context) + return await self._get_handler(kwargs).on_get_task(request, context) async def handle_authenticated_agent_card( self, request: "Request", context: "ServerCallContext", + **kwargs, ) -> dict[str, Any]: - return await self.a2a_rest_adapter.handle_authenticated_agent_card( + return await self._get_adapter(kwargs).handle_authenticated_agent_card( request, context ) @@ -339,8 +403,11 @@ def register_operations(self) -> Dict[str, List[str]]: } if self.agent_card.capabilities and self.agent_card.capabilities.streaming: routes["a2a_extension"].append("on_message_send_stream") - routes["a2a_extension"].append("on_resubscribe_to_task") - if self.agent_card.supports_authenticated_extended_card: + routes["a2a_extension"].append("on_subscribe_to_task") + if ( + self.agent_card.capabilities + and self.agent_card.capabilities.extended_agent_card + ): routes["a2a_extension"].append("handle_authenticated_agent_card") return routes @@ -348,16 +415,70 @@ async def on_message_send_stream( self, request: "Request", context: "ServerCallContext", + **kwargs, ) -> AsyncIterator[str]: """Handles A2A streaming requests via SSE.""" - async for chunk in self.rest_handler.on_message_send_stream(request, context): + async for chunk in self._get_handler(kwargs).on_message_send_stream( + request, context + ): yield chunk - async def on_resubscribe_to_task( + async def on_subscribe_to_task( self, request: "Request", context: "ServerCallContext", + **kwargs, ) -> AsyncIterator[str]: """Handles A2A task resubscription requests via SSE.""" - async for chunk in self.rest_handler.on_resubscribe_to_task(request, context): + async for chunk in self._get_handler(kwargs).on_subscribe_to_task( + request, context + ): yield chunk + + def __getstate__(self): + """Serializes AgentCard proto to a dictionary.""" + from google.protobuf import json_format + import json + + state = self.__dict__.copy() + + def _to_dict_if_proto(obj): + if hasattr(obj, "DESCRIPTOR"): + return { + "__protobuf_AgentCard__": json.loads(json_format.MessageToJson(obj)) + } + return obj + + state["agent_card"] = _to_dict_if_proto(state.get("agent_card")) + if "_tmpl_attrs" in state: + tmpl_attrs = state["_tmpl_attrs"].copy() + tmpl_attrs["agent_card"] = _to_dict_if_proto(tmpl_attrs.get("agent_card")) + tmpl_attrs["extended_agent_card"] = _to_dict_if_proto( + tmpl_attrs.get("extended_agent_card") + ) + state["_tmpl_attrs"] = tmpl_attrs + + return state + + def __setstate__(self, state): + """Deserializes AgentCard proto from a dictionary.""" + from google.protobuf import json_format + from a2a.types import AgentCard + + def _from_dict_if_proto(obj): + if isinstance(obj, dict) and "__protobuf_AgentCard__" in obj: + agent_card = AgentCard() + json_format.ParseDict(obj["__protobuf_AgentCard__"], agent_card) + return agent_card + return obj + + state["agent_card"] = _from_dict_if_proto(state.get("agent_card")) + if "_tmpl_attrs" in state: + state["_tmpl_attrs"]["agent_card"] = _from_dict_if_proto( + state["_tmpl_attrs"].get("agent_card") + ) + state["_tmpl_attrs"]["extended_agent_card"] = _from_dict_if_proto( + state["_tmpl_attrs"].get("extended_agent_card") + ) + + self.__dict__.update(state)