|
50 | 50 | TaskStatus, |
51 | 51 | TaskStatusUpdateEvent, |
52 | 52 | ) |
53 | | -from a2a.utils.constants import TransportProtocol |
| 53 | +from a2a.utils.constants import ( |
| 54 | + TransportProtocol, |
| 55 | +) |
| 56 | +from a2a.utils.errors import ( |
| 57 | + ExtendedAgentCardNotConfiguredError, |
| 58 | + ContentTypeNotSupportedError, |
| 59 | + ExtensionSupportRequiredError, |
| 60 | + InternalError, |
| 61 | + InvalidAgentResponseError, |
| 62 | + InvalidParamsError, |
| 63 | + InvalidRequestError, |
| 64 | + MethodNotFoundError, |
| 65 | + PushNotificationNotSupportedError, |
| 66 | + TaskNotCancelableError, |
| 67 | + TaskNotFoundError, |
| 68 | + UnsupportedOperationError, |
| 69 | + VersionNotSupportedError, |
| 70 | +) |
54 | 71 | from a2a.utils.signing import ( |
55 | 72 | create_agent_card_signer, |
56 | 73 | create_signature_verifier, |
@@ -788,6 +805,43 @@ async def test_client_get_signed_base_and_extended_cards( |
788 | 805 | await client.close() |
789 | 806 |
|
790 | 807 |
|
| 808 | +@pytest.mark.asyncio |
| 809 | +@pytest.mark.parametrize( |
| 810 | + 'error_cls', |
| 811 | + [ |
| 812 | + TaskNotFoundError, |
| 813 | + TaskNotCancelableError, |
| 814 | + PushNotificationNotSupportedError, |
| 815 | + UnsupportedOperationError, |
| 816 | + ContentTypeNotSupportedError, |
| 817 | + InvalidAgentResponseError, |
| 818 | + ExtendedAgentCardNotConfiguredError, |
| 819 | + ExtensionSupportRequiredError, |
| 820 | + VersionNotSupportedError, |
| 821 | + ], |
| 822 | +) |
| 823 | +async def test_client_handles_a2a_errors(transport_setups, error_cls) -> None: |
| 824 | + """Integration test to verify error propagation from handler to client.""" |
| 825 | + client = transport_setups.client |
| 826 | + handler = transport_setups.handler |
| 827 | + |
| 828 | + # Mock the handler to raise the error |
| 829 | + handler.on_get_task.side_effect = error_cls('Test error message') |
| 830 | + |
| 831 | + params = GetTaskRequest(id='some-id') |
| 832 | + |
| 833 | + # We expect the client to raise the same error_cls. |
| 834 | + with pytest.raises(error_cls) as exc_info: |
| 835 | + await client.get_task(request=params) |
| 836 | + |
| 837 | + assert 'Test error message' in str(exc_info.value) |
| 838 | + |
| 839 | + # Reset side_effect for other tests |
| 840 | + handler.on_get_task.side_effect = None |
| 841 | + |
| 842 | + await client.close() |
| 843 | + |
| 844 | + |
791 | 845 | @pytest.mark.asyncio |
792 | 846 | @pytest.mark.parametrize( |
793 | 847 | 'request_kwargs, expected_error_code', |
|
0 commit comments