Skip to content

Commit 7a6c55d

Browse files
authored
fix: Avoid sending a single json message in multiple chunks. (#837)
Fixes #742
1 parent 7e6b3c2 commit 7a6c55d

6 files changed

Lines changed: 85 additions & 18 deletions

File tree

src/a2a/compat/v0_3/rest_adapter.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import functools
2+
import json
23
import logging
34

45
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable
@@ -103,9 +104,9 @@ async def _handle_streaming_request(
103104

104105
async def event_generator(
105106
stream: AsyncIterable[Any],
106-
) -> AsyncIterator[dict[str, dict[str, Any]]]:
107+
) -> AsyncIterator[str]:
107108
async for item in stream:
108-
yield {'data': item}
109+
yield json.dumps(item)
109110

110111
return EventSourceResponse(
111112
event_generator(method(request, call_context))

src/a2a/compat/v0_3/rest_handler.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import logging
22

3-
from collections.abc import AsyncIterable, AsyncIterator
3+
from collections.abc import AsyncIterator
44
from typing import TYPE_CHECKING, Any
55

6-
from google.protobuf.json_format import MessageToDict, MessageToJson, Parse
6+
from google.protobuf.json_format import MessageToDict, Parse
77

88

99
if TYPE_CHECKING:
@@ -86,7 +86,7 @@ async def on_message_send_stream(
8686
self,
8787
request: Request,
8888
context: ServerCallContext,
89-
) -> AsyncIterator[str]:
89+
) -> AsyncIterator[dict[str, Any]]:
9090
"""Handles the 'message/stream' REST method.
9191
9292
Args:
@@ -108,7 +108,7 @@ async def on_message_send_stream(
108108
v03_pb_resp = proto_utils.ToProto.stream_response(
109109
v03_stream_resp.result
110110
)
111-
yield MessageToJson(v03_pb_resp)
111+
yield MessageToDict(v03_pb_resp)
112112

113113
async def on_cancel_task(
114114
self,
@@ -142,7 +142,7 @@ async def on_subscribe_to_task(
142142
self,
143143
request: Request,
144144
context: ServerCallContext,
145-
) -> AsyncIterable[str]:
145+
) -> AsyncIterator[dict[str, Any]]:
146146
"""Handles the 'tasks/{id}:subscribe' REST method.
147147
148148
Args:
@@ -164,7 +164,7 @@ async def on_subscribe_to_task(
164164
v03_pb_resp = proto_utils.ToProto.stream_response(
165165
v03_stream_resp.result
166166
)
167-
yield MessageToJson(v03_pb_resp)
167+
yield MessageToDict(v03_pb_resp)
168168

169169
async def get_push_notification(
170170
self,

src/a2a/server/apps/rest/rest_adapter.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import functools
2+
import json
23
import logging
34

45
from abc import ABC, abstractmethod
@@ -150,9 +151,9 @@ async def _handle_streaming_request(
150151

151152
async def event_generator(
152153
stream: AsyncIterable[Any],
153-
) -> AsyncIterator[dict[str, dict[str, Any]]]:
154+
) -> AsyncIterator[str]:
154155
async for item in stream:
155-
yield {'data': item}
156+
yield json.dumps(item)
156157

157158
return EventSourceResponse(
158159
event_generator(method(request, call_context))

src/a2a/server/request_handlers/rest_handler.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
import logging
22

3-
from collections.abc import AsyncIterable, AsyncIterator
3+
from collections.abc import AsyncIterator
44
from typing import TYPE_CHECKING, Any
55

66
from google.protobuf.json_format import (
77
MessageToDict,
8-
MessageToJson,
98
Parse,
109
)
1110

@@ -96,7 +95,7 @@ async def on_message_send_stream(
9695
self,
9796
request: Request,
9897
context: ServerCallContext,
99-
) -> AsyncIterator[str]:
98+
) -> AsyncIterator[dict[str, Any]]:
10099
"""Handles the 'message/stream' REST method.
101100
102101
Yields response objects as they are produced by the underlying handler's stream.
@@ -116,7 +115,7 @@ async def on_message_send_stream(
116115
params, context
117116
):
118117
response = proto_utils.to_stream_response(event)
119-
yield MessageToJson(response)
118+
yield MessageToDict(response)
120119

121120
async def on_cancel_task(
122121
self,
@@ -148,7 +147,7 @@ async def on_subscribe_to_task(
148147
self,
149148
request: Request,
150149
context: ServerCallContext,
151-
) -> AsyncIterable[str]:
150+
) -> AsyncIterator[dict[str, Any]]:
152151
"""Handles the 'SubscribeToTask' REST method.
153152
154153
Yields response objects as they are produced by the underlying handler's stream.
@@ -164,7 +163,7 @@ async def on_subscribe_to_task(
164163
async for event in self.request_handler.on_subscribe_to_task(
165164
SubscribeToTaskRequest(id=task_id), context
166165
):
167-
yield MessageToJson(proto_utils.to_stream_response(event))
166+
yield MessageToDict(proto_utils.to_stream_response(event))
168167

169168
async def get_push_notification(
170169
self,

tests/compat/v0_3/test_rest_handler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ async def mock_stream(*args, **kwargs):
110110
)
111111

112112
results = [
113-
json.loads(chunk)
113+
chunk
114114
async for chunk in rest_handler.on_message_send_stream(
115115
mock_request, mock_context
116116
)
@@ -169,7 +169,7 @@ async def mock_stream(*args, **kwargs):
169169
)
170170

171171
results = [
172-
json.loads(chunk)
172+
chunk
173173
async for chunk in rest_handler.on_subscribe_to_task(
174174
mock_request, mock_context
175175
)

tests/server/apps/rest/test_rest_fastapi_app.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import logging
2+
import json
23

34
from typing import Any
45
from unittest.mock import MagicMock
@@ -339,6 +340,71 @@ async def mock_stream_response():
339340
request_handler.on_message_send_stream.assert_called_once()
340341

341342

343+
@pytest.mark.anyio
344+
async def test_streaming_content_verification(
345+
streaming_client: AsyncClient, request_handler: MagicMock
346+
) -> None:
347+
"""Test that streaming endpoint returns correct SSE content."""
348+
349+
async def mock_stream_response():
350+
yield Message(
351+
message_id='stream_msg_1',
352+
role=Role.ROLE_AGENT,
353+
parts=[Part(text='First chunk')],
354+
)
355+
yield Message(
356+
message_id='stream_msg_2',
357+
role=Role.ROLE_AGENT,
358+
parts=[Part(text='Second chunk')],
359+
)
360+
361+
request_handler.on_message_send_stream.return_value = mock_stream_response()
362+
363+
request = a2a_pb2.SendMessageRequest(
364+
message=a2a_pb2.Message(
365+
message_id='test_stream_msg',
366+
role=a2a_pb2.ROLE_USER,
367+
parts=[a2a_pb2.Part(text='Test message')],
368+
),
369+
)
370+
371+
response = await streaming_client.post(
372+
'/message:stream',
373+
json=json_format.MessageToDict(request),
374+
headers={'Accept': 'text/event-stream'},
375+
)
376+
377+
response.raise_for_status()
378+
379+
# Read the response content
380+
lines = [line async for line in response.aiter_lines()]
381+
382+
# SSE format is "data: <json>\n\n"
383+
# httpx.aiter_lines() will give us each line.
384+
data_lines = [
385+
json.loads(line[6:]) for line in lines if line.startswith('data: ')
386+
]
387+
388+
expected_data_lines = [
389+
{
390+
'message': {
391+
'messageId': 'stream_msg_1',
392+
'role': 'ROLE_AGENT',
393+
'parts': [{'text': 'First chunk'}],
394+
}
395+
},
396+
{
397+
'message': {
398+
'messageId': 'stream_msg_2',
399+
'role': 'ROLE_AGENT',
400+
'parts': [{'text': 'Second chunk'}],
401+
}
402+
},
403+
]
404+
405+
assert data_lines == expected_data_lines
406+
407+
342408
@pytest.mark.anyio
343409
async def test_streaming_endpoint_with_invalid_content_type(
344410
streaming_client: AsyncClient, request_handler: MagicMock

0 commit comments

Comments
 (0)