Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 21 additions & 3 deletions src/a2a/client/transports/http_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
from a2a.client.errors import A2AClientError, A2AClientTimeoutError


def _default_sse_error_handler(sse_data: str) -> NoReturn:
raise A2AClientError(f'SSE stream error event received: {sse_data}')


@contextmanager
def handle_http_exceptions(
status_error_handler: Callable[[httpx.HTTPStatusError], NoReturn]
Expand Down Expand Up @@ -69,11 +73,23 @@ async def send_http_stream_request(
httpx_client: httpx.AsyncClient,
method: str,
url: str,
status_error_handler: Callable[[httpx.HTTPStatusError], NoReturn]
| None = None,
status_error_handler: Callable[[httpx.HTTPStatusError], NoReturn],
sse_error_handler: Callable[[str], NoReturn] = _default_sse_error_handler,
**kwargs: Any,
) -> AsyncGenerator[str]:
"""Sends a streaming HTTP request, yielding SSE data strings and handling exceptions."""
"""Sends a streaming HTTP request, yielding SSE data strings and handling exceptions.

Args:
httpx_client: The async HTTP client.
method: The HTTP method (e.g. 'POST', 'GET').
url: The URL to send the request to.
status_error_handler: Handler for HTTP status errors. Should raise an
appropriate domain-specific exception.
sse_error_handler: Handler for SSE error events. Called with the
raw SSE data string when an ``event: error`` SSE event is received.
Should raise an appropriate domain-specific exception.
**kwargs: Additional keyword arguments forwarded to ``aconnect_sse``.
"""
with handle_http_exceptions(status_error_handler):
async with aconnect_sse(
httpx_client, method, url, **kwargs
Expand All @@ -97,4 +113,6 @@ async def send_http_stream_request(
async for sse in event_source.aiter_sse():
if not sse.data:
continue
if sse.event == 'error':
sse_error_handler(sse.data)
yield sse.data
10 changes: 9 additions & 1 deletion src/a2a/client/transports/jsonrpc.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging

from collections.abc import AsyncGenerator
from typing import Any
from typing import Any, NoReturn
from uuid import uuid4

import httpx
Expand Down Expand Up @@ -349,6 +349,7 @@ async def _send_stream_request(
'POST',
self.url,
None,
self._handle_sse_error,
Comment on lines 351 to +352
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The call to send_http_stream_request passes None for the status_error_handler argument. However, the signature for send_http_stream_request in src/a2a/client/transports/http_helpers.py was changed in this PR to make this argument mandatory, no longer allowing None. This will result in a TypeError at runtime.

To fix this, you should provide a valid error handler. RestTransport defines a _handle_http_error method for this purpose. A similar method should be added to JsonRpcTransport and passed here instead of None.

Suggested change
None,
self._handle_sse_error,
self._handle_http_error,
self._handle_sse_error,

json=rpc_request_payload,
**http_kwargs,
):
Expand All @@ -359,3 +360,10 @@ async def _send_stream_request(
json_rpc_response.result, StreamResponse()
)
yield response

def _handle_sse_error(self, sse_data: str) -> NoReturn:
"""Handles SSE error events by parsing JSON-RPC error payload and raising the appropriate domain error."""
json_rpc_response = JSONRPC20Response.from_json(sse_data)
if json_rpc_response.error:
raise self._create_jsonrpc_error(json_rpc_response.error)
raise A2AClientError(f'SSE stream error: {sse_data}')
83 changes: 53 additions & 30 deletions src/a2a/client/transports/rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,47 @@
logger = logging.getLogger(__name__)


def _parse_rest_error(
error_payload: dict[str, Any],
fallback_message: str,
) -> Exception | None:
"""Parses a REST error payload and returns the appropriate A2AError.

Args:
error_payload: The parsed JSON error payload.
fallback_message: Message to use if the payload has no ``message``.

Returns:
The mapped A2AError if a known reason was found, otherwise ``None``.
"""
error_data = error_payload.get('error', {})
message = error_data.get('message', fallback_message)
details = error_data.get('details', [])
if not isinstance(details, list):
return None

# The `details` array can contain multiple different error objects.
# We extract the first `ErrorInfo` object because it contains the
# specific `reason` code needed to map this back to a Python A2AError.
for d in details:
if (
isinstance(d, dict)
and d.get('@type') == 'type.googleapis.com/google.rpc.ErrorInfo'
):
reason = d.get('reason')
metadata = d.get('metadata') or {}
if isinstance(reason, str):
exception_cls = A2A_REASON_TO_ERROR.get(reason)
if exception_cls:
exc = exception_cls(message)
if metadata:
exc.data = metadata
return exc
break

return None


@trace_class(kind=SpanKind.CLIENT)
class RestTransport(ClientTransport):
"""A REST transport for the A2A client."""
Expand Down Expand Up @@ -294,39 +335,12 @@ def _handle_http_error(self, e: httpx.HTTPStatusError) -> NoReturn:
"""Handles HTTP status errors and raises the appropriate A2AError."""
try:
error_payload = e.response.json()
error_data = error_payload.get('error', {})

message = error_data.get('message', str(e))
details = error_data.get('details', [])
if not isinstance(details, list):
details = []

# The `details` array can contain multiple different error objects.
# We extract the first `ErrorInfo` object because it contains the
# specific `reason` code needed to map this back to a Python A2AError.
error_info = {}
for d in details:
if (
isinstance(d, dict)
and d.get('@type')
== 'type.googleapis.com/google.rpc.ErrorInfo'
):
error_info = d
break
reason = error_info.get('reason')
metadata = error_info.get('metadata') or {}

if isinstance(reason, str):
exception_cls = A2A_REASON_TO_ERROR.get(reason)
if exception_cls:
exc = exception_cls(message)
if metadata:
exc.data = metadata
raise exc from e
mapped = _parse_rest_error(error_payload, str(e))
if mapped:
raise mapped from e
except (json.JSONDecodeError, ValueError):
pass

# Fallback mappings for status codes if 'type' is missing or unknown
status_code = e.response.status_code
if status_code == httpx.codes.NOT_FOUND:
raise MethodNotFoundError(
Expand All @@ -335,6 +349,14 @@ def _handle_http_error(self, e: httpx.HTTPStatusError) -> NoReturn:

raise A2AClientError(f'HTTP Error {status_code}: {e}') from e

def _handle_sse_error(self, sse_data: str) -> NoReturn:
"""Handles SSE error events by parsing the REST error payload and raising the appropriate A2AError."""
error_payload = json.loads(sse_data)
mapped = _parse_rest_error(error_payload, sse_data)
if mapped:
raise mapped
raise A2AClientError(sse_data)

async def _send_stream_request(
self,
method: str,
Expand All @@ -352,6 +374,7 @@ async def _send_stream_request(
method,
f'{self.url}{path}',
self._handle_http_error,
self._handle_sse_error,
json=json,
**http_kwargs,
):
Expand Down
17 changes: 14 additions & 3 deletions src/a2a/server/apps/rest/rest_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@


if TYPE_CHECKING:
from sse_starlette.event import ServerSentEvent
from sse_starlette.sse import EventSourceResponse
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
Expand All @@ -20,6 +21,7 @@

else:
try:
from sse_starlette.event import ServerSentEvent
from sse_starlette.sse import EventSourceResponse
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
Expand All @@ -30,6 +32,7 @@
Request = Any
JSONResponse = Any
Response = Any
ServerSentEvent = Any

_package_starlette_installed = False

Expand All @@ -42,6 +45,7 @@
from a2a.server.routes import CallContextBuilder, DefaultCallContextBuilder
from a2a.types.a2a_pb2 import AgentCard
from a2a.utils.error_handlers import (
build_rest_error_payload,
rest_error_handler,
rest_stream_error_handler,
)
Expand Down Expand Up @@ -163,10 +167,17 @@ async def _handle_streaming_request(
except StopAsyncIteration:
return EventSourceResponse(iter([]))

async def event_generator() -> AsyncIterator[str]:
async def event_generator() -> AsyncIterator[str | ServerSentEvent]:
yield json.dumps(first_item)
async for item in stream:
yield json.dumps(item)
try:
async for item in stream:
yield json.dumps(item)
except Exception as e:
logger.exception('Error during REST SSE stream')
yield ServerSentEvent(
data=json.dumps(build_rest_error_payload(e)),
event='error',
)

return EventSourceResponse(event_generator())

Expand Down
26 changes: 24 additions & 2 deletions src/a2a/server/routes/jsonrpc_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,8 +559,30 @@ def _create_response(
async def event_generator(
stream: AsyncGenerator[dict[str, Any]],
) -> AsyncGenerator[dict[str, str]]:
async for item in stream:
yield {'data': json.dumps(item)}
try:
async for item in stream:
event: dict[str, str] = {
'data': json.dumps(item),
}
if 'error' in item:
event['event'] = 'error'
yield event
except Exception as e:
logger.exception(
'Unhandled error during JSON-RPC SSE stream'
)
rpc_error: A2AError | JSONRPCError = (
e
if isinstance(e, A2AError | JSONRPCError)
else InternalError(message=str(e))
)
error_response = build_error_response(
context.state.get('request_id'), rpc_error
)
yield {
'event': 'error',
'data': json.dumps(error_response),
}

return EventSourceResponse(
event_generator(handler_result), headers=headers
Expand Down
82 changes: 42 additions & 40 deletions src/a2a/utils/error_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,64 +54,66 @@
return {'error': payload}


def _create_error_response(error: Exception) -> Response:
"""Helper function to create a JSONResponse for an error."""
def build_rest_error_payload(error: Exception) -> dict[str, Any]:
"""Build a REST error payload dict from an exception.

Returns:
A dict with the error payload in the standard REST error format.
"""
if isinstance(error, A2AError):
mapping = A2A_REST_ERROR_MAPPING.get(
type(error), RestErrorMap(500, 'INTERNAL', 'INTERNAL_ERROR')
)
http_code = mapping.http_code
grpc_status = mapping.grpc_status
reason = mapping.reason
# SECURITY WARNING: Data attached to A2AError.data is serialized unaltered and exposed publicly to the client in the REST API response.
metadata = getattr(error, 'data', None) or {}
return _build_error_payload(
code=mapping.http_code,
status=mapping.grpc_status,
message=getattr(error, 'message', str(error)),
reason=mapping.reason,
metadata=metadata,
)
if isinstance(error, ParseError):
return _build_error_payload(
code=400,
status='INVALID_ARGUMENT',
message=str(error),
reason='INVALID_REQUEST',
metadata={},
)
return _build_error_payload(
code=500,
status='INTERNAL',
message='unknown exception',
)


def _create_error_response(error: Exception) -> Response:
"""Helper function to create a JSONResponse for an error."""
if isinstance(error, A2AError):
log_level = (
logging.ERROR
if isinstance(error, InternalError)
else logging.WARNING
)
logger.log(
log_level,
"Request error: Code=%s, Message='%s'%s",
getattr(error, 'code', 'N/A'),
getattr(error, 'message', str(error)),
f', Data={error.data}' if error.data else '',
)

# SECURITY WARNING: Data attached to A2AError.data is serialized unaltered and exposed publicly to the client in the REST API response.
metadata = getattr(error, 'data', None) or {}

return JSONResponse(
content=_build_error_payload(
code=http_code,
status=grpc_status,
message=getattr(error, 'message', str(error)),
reason=reason,
metadata=metadata,
),
status_code=http_code,
media_type='application/json',
)
if isinstance(error, ParseError):
elif isinstance(error, ParseError):

Check notice on line 106 in src/a2a/utils/error_handlers.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/utils/error_handlers.py (142-155)
logger.warning('Parse error: %s', str(error))
return JSONResponse(
content=_build_error_payload(
code=400,
status='INVALID_ARGUMENT',
message=str(error),
reason='INVALID_REQUEST',
metadata={},
),
status_code=400,
media_type='application/json',
)
logger.exception('Unknown error occurred')
else:
logger.exception('Unknown error occurred')

payload = build_rest_error_payload(error)
# Extract HTTP status code from the payload
http_code = payload.get('error', {}).get('code', 500)
return JSONResponse(
content=_build_error_payload(
code=500,
status='INTERNAL',
message='unknown exception',
),
status_code=500,
content=payload,
status_code=http_code,
media_type='application/json',
)

Expand Down Expand Up @@ -171,7 +173,7 @@
try:
async for item in original_iterator:
yield item
except Exception as stream_error:
except Exception as stream_error: # noqa: BLE001

Check failure on line 176 in src/a2a/utils/error_handlers.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

ruff (RUF100)

src/a2a/utils/error_handlers.py:176:56: RUF100 Unused `noqa` directive (unused: `BLE001`) help: Remove unused `noqa` directive
_log_error(stream_error)
raise stream_error

Expand Down
Loading
Loading