Skip to content

Commit 0259343

Browse files
committed
wip
1 parent d80ec70 commit 0259343

5 files changed

Lines changed: 31 additions & 20 deletions

File tree

src/a2a/client/transports/grpc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def _map_grpc_error(e: grpc.aio.AioRpcError) -> NoReturn:
7676
data = {'errors': errors}
7777
exception_cls = InvalidParamsError
7878
break
79-
elif detail.Is(error_details_pb2.ErrorInfo.DESCRIPTOR):
79+
if detail.Is(error_details_pb2.ErrorInfo.DESCRIPTOR):
8080
error_info = error_details_pb2.ErrorInfo()
8181
detail.Unpack(error_info)
8282
if error_info.domain == 'a2a-protocol.org':

src/a2a/server/apps/rest/rest_adapter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ async def event_generator(
159159
yield json.dumps(item)
160160

161161
return EventSourceResponse(
162-
event_generator(method(request, call_context))
162+
event_generator(await method(request, call_context))
163163
)
164164

165165
async def handle_get_agent_card(

src/a2a/server/request_handlers/grpc_handler.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -406,7 +406,11 @@ async def abort_context(
406406
# Create standard Status
407407
status = status_pb2.Status(code=status_code, message=error_msg)
408408

409-
if isinstance(error, types.InvalidParamsError) and error.data and error.data.get('errors'):
409+
if (
410+
isinstance(error, types.InvalidParamsError)
411+
and error.data
412+
and error.data.get('errors')
413+
):
410414
bad_request = error_details_pb2.BadRequest()
411415
for err_dict in error.data['errors']:
412416
violation = bad_request.field_violations.add()

src/a2a/server/request_handlers/request_handler.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -229,25 +229,24 @@ async def on_delete_task_push_notification_config(
229229

230230
def validate_request_params(method: Callable) -> Callable:
231231
"""Decorator for RequestHandler methods to validate required fields on incoming requests."""
232-
if inspect.isasyncgenfunction(method):
232+
if inspect.iscoroutinefunction(method):
233233

234234
@functools.wraps(method)
235-
async def async_generator_wrapper(
235+
async def async_wrapper(
236236
self: RequestHandler,
237237
params: ProtoMessage,
238238
context: ServerCallContext,
239239
*args: Any,
240240
**kwargs: Any,
241-
) -> AsyncGenerator:
241+
) -> Any:
242242
if params is not None:
243243
validate_proto_required_fields(params)
244-
async for item in method(self, params, context, *args, **kwargs):
245-
yield item
244+
return await method(self, params, context, *args, **kwargs)
246245

247-
return async_generator_wrapper
246+
return async_wrapper
248247

249248
@functools.wraps(method)
250-
async def async_wrapper(
249+
def sync_wrapper(
251250
self: RequestHandler,
252251
params: ProtoMessage,
253252
context: ServerCallContext,
@@ -256,6 +255,6 @@ async def async_wrapper(
256255
) -> Any:
257256
if params is not None:
258257
validate_proto_required_fields(params)
259-
return await method(self, params, context, *args, **kwargs)
258+
return method(self, params, context, *args, **kwargs)
260259

261-
return async_wrapper
260+
return sync_wrapper

src/a2a/server/request_handlers/rest_handler.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -116,11 +116,14 @@ async def on_message_send_stream(
116116
body = await request.body()
117117
params = a2a_pb2.SendMessageRequest()
118118
Parse(body, params)
119-
async for event in self.request_handler.on_message_send_stream(
120-
params, context
121-
):
122-
response = proto_utils.to_stream_response(event)
123-
yield MessageToDict(response)
119+
stream = self.request_handler.on_message_send_stream(params, context)
120+
121+
async def _generator() -> AsyncIterator[dict[str, Any]]:
122+
async for event in stream:
123+
response = proto_utils.to_stream_response(event)
124+
yield MessageToDict(response)
125+
126+
return _generator()
124127

125128
@validate_version(constants.PROTOCOL_VERSION_1_0)
126129
async def on_cancel_task(
@@ -167,10 +170,15 @@ async def on_subscribe_to_task(
167170
JSON serialized objects containing streaming events
168171
"""
169172
task_id = request.path_params['id']
170-
async for event in self.request_handler.on_subscribe_to_task(
173+
stream = self.request_handler.on_subscribe_to_task(
171174
SubscribeToTaskRequest(id=task_id), context
172-
):
173-
yield MessageToDict(proto_utils.to_stream_response(event))
175+
)
176+
177+
async def _generator() -> AsyncIterator[dict[str, Any]]:
178+
async for event in stream:
179+
yield MessageToDict(proto_utils.to_stream_response(event))
180+
181+
return _generator()
174182

175183
@validate_version(constants.PROTOCOL_VERSION_1_0)
176184
async def get_push_notification(

0 commit comments

Comments
 (0)