Skip to content

Commit 73e5445

Browse files
committed
WIP
1 parent ad1ff53 commit 73e5445

3 files changed

Lines changed: 18 additions & 9 deletions

File tree

src/a2a/client/transports/grpc.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
TaskPushNotificationConfig,
4848
)
4949
from a2a.utils.constants import PROTOCOL_VERSION_CURRENT, VERSION_HEADER
50-
from a2a.utils.errors import A2A_REASON_TO_ERROR
50+
from a2a.utils.errors import InvalidParamsError
5151
from a2a.utils.telemetry import SpanKind, trace_class
5252

5353

@@ -74,8 +74,7 @@ def _map_grpc_error(e: grpc.aio.AioRpcError) -> NoReturn:
7474
for v in bad_request.field_violations
7575
]
7676
data = {'errors': errors}
77-
# Infer InvalidParamsError from BadRequest details
78-
exception_cls = A2A_REASON_TO_ERROR.get('INVALID_PARAMS')
77+
exception_cls = InvalidParamsError
7978
elif detail.Is(error_details_pb2.ErrorInfo.DESCRIPTOR):
8079
error_info = error_details_pb2.ErrorInfo()
8180
detail.Unpack(error_info)

src/a2a/server/request_handlers/grpc_handler.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -441,8 +441,7 @@ async def abort_context(
441441
# Create standard Status
442442
status = status_pb2.Status(code=status_code, message=error_msg)
443443

444-
# Exclusive details based on error type:
445-
if error.data and error.data.get('errors'):
444+
if isinstance(error, types.InvalidParamsError) and error.data and error.data.get('errors'):
446445
bad_request = error_details_pb2.BadRequest()
447446
for err_dict in error.data['errors']:
448447
violation = bad_request.field_violations.add()

tests/integration/test_end_to_end.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -570,11 +570,22 @@ async def test_end_to_end_validation_errors(
570570
) -> None:
571571
client = transport_setups.client
572572

573-
with pytest.raises(InvalidParamsError) as exc_info:
573+
try:
574574
async for _ in client.send_message(request=empty_request):
575575
pass
576-
577-
errors = exc_info.value.data.get('errors', [])
578-
assert {e['field'] for e in errors} == set(expected_fields)
576+
except Exception as e:
577+
# ASGITransport propagates server-side generator crashes as ExceptionGroups
578+
exc = e
579+
if hasattr(e, 'exceptions') and len(e.exceptions) == 1:
580+
exc = e.exceptions[0]
581+
582+
if not isinstance(exc, InvalidParamsError):
583+
raise e
584+
585+
errors = exc.data.get('errors', []) if exc.data else []
586+
assert {e['field'] for e in errors} == set(expected_fields)
587+
return
588+
589+
pytest.fail('InvalidParamsError was not raised')
579590

580591
await client.close()

0 commit comments

Comments
 (0)