|
1 | 1 | import logging |
| 2 | +import json |
2 | 3 |
|
3 | 4 | from typing import Any |
4 | 5 | from unittest.mock import MagicMock |
@@ -339,6 +340,71 @@ async def mock_stream_response(): |
339 | 340 | request_handler.on_message_send_stream.assert_called_once() |
340 | 341 |
|
341 | 342 |
|
| 343 | +@pytest.mark.anyio |
| 344 | +async def test_streaming_content_verification( |
| 345 | + streaming_client: AsyncClient, request_handler: MagicMock |
| 346 | +) -> None: |
| 347 | + """Test that streaming endpoint returns correct SSE content.""" |
| 348 | + |
| 349 | + async def mock_stream_response(): |
| 350 | + yield Message( |
| 351 | + message_id='stream_msg_1', |
| 352 | + role=Role.ROLE_AGENT, |
| 353 | + parts=[Part(text='First chunk')], |
| 354 | + ) |
| 355 | + yield Message( |
| 356 | + message_id='stream_msg_2', |
| 357 | + role=Role.ROLE_AGENT, |
| 358 | + parts=[Part(text='Second chunk')], |
| 359 | + ) |
| 360 | + |
| 361 | + request_handler.on_message_send_stream.return_value = mock_stream_response() |
| 362 | + |
| 363 | + request = a2a_pb2.SendMessageRequest( |
| 364 | + message=a2a_pb2.Message( |
| 365 | + message_id='test_stream_msg', |
| 366 | + role=a2a_pb2.ROLE_USER, |
| 367 | + parts=[a2a_pb2.Part(text='Test message')], |
| 368 | + ), |
| 369 | + ) |
| 370 | + |
| 371 | + response = await streaming_client.post( |
| 372 | + '/message:stream', |
| 373 | + json=json_format.MessageToDict(request), |
| 374 | + headers={'Accept': 'text/event-stream'}, |
| 375 | + ) |
| 376 | + |
| 377 | + response.raise_for_status() |
| 378 | + |
| 379 | + # Read the response content |
| 380 | + lines = [line async for line in response.aiter_lines()] |
| 381 | + |
| 382 | + # SSE format is "data: <json>\n\n" |
| 383 | + # httpx.aiter_lines() will give us each line. |
| 384 | + data_lines = [ |
| 385 | + json.loads(line[6:]) for line in lines if line.startswith('data: ') |
| 386 | + ] |
| 387 | + |
| 388 | + expected_data_lines = [ |
| 389 | + { |
| 390 | + 'message': { |
| 391 | + 'messageId': 'stream_msg_1', |
| 392 | + 'role': 'ROLE_AGENT', |
| 393 | + 'parts': [{'text': 'First chunk'}], |
| 394 | + } |
| 395 | + }, |
| 396 | + { |
| 397 | + 'message': { |
| 398 | + 'messageId': 'stream_msg_2', |
| 399 | + 'role': 'ROLE_AGENT', |
| 400 | + 'parts': [{'text': 'Second chunk'}], |
| 401 | + } |
| 402 | + }, |
| 403 | + ] |
| 404 | + |
| 405 | + assert data_lines == expected_data_lines |
| 406 | + |
| 407 | + |
342 | 408 | @pytest.mark.anyio |
343 | 409 | async def test_streaming_endpoint_with_invalid_content_type( |
344 | 410 | streaming_client: AsyncClient, request_handler: MagicMock |
|
0 commit comments