1414 TransportProtocol ,
1515 TaskQueryParams ,
1616 TaskIdParams ,
17+ TaskState ,
1718 TaskPushNotificationConfig ,
1819 PushNotificationConfig ,
20+ FilePart ,
21+ FileWithUri ,
22+ FileWithBytes ,
23+ DataPart ,
1924)
2025from a2a .client .errors import A2AClientJSONRPCError , A2AClientHTTPError
2126import sys
27+ import traceback
2228
2329
2430async def test_send_message_stream (client ):
2531 print ('Testing send_message (streaming)...' )
32+
2633 msg = Message (
2734 role = Role .user ,
2835 message_id = f'stream-{ uuid4 ()} ' ,
29- parts = [Part (root = TextPart (text = 'stream' ))],
30- metadata = {'test_key' : 'test_value' },
36+ parts = [
37+ Part (root = TextPart (text = 'stream' )),
38+ Part (
39+ root = FilePart (
40+ file = FileWithUri (
41+ uri = 'https://example.com/file.txt' ,
42+ mime_type = 'text/plain' ,
43+ )
44+ )
45+ ),
46+ Part (
47+ root = FilePart (
48+ file = FileWithBytes (
49+ bytes = b'aGVsbG8=' , mime_type = 'application/octet-stream'
50+ )
51+ )
52+ ),
53+ Part (root = DataPart (data = {'key' : 'value' })),
54+ ],
55+ metadata = {'test_key' : 'full_message' },
3156 )
3257 events = []
3358
@@ -62,38 +87,43 @@ async def test_send_message_sync(url, protocol_enum):
6287 role = Role .user ,
6388 message_id = f'sync-{ uuid4 ()} ' ,
6489 parts = [Part (root = TextPart (text = 'sync' ))],
65- metadata = {'test_key' : 'test_value ' },
90+ metadata = {'test_key' : 'simple_message ' },
6691 )
6792
68- # In v0.3 SDK, send_message ALWAYS returns an async generator
6993 async for event in client .send_message (request = msg ):
7094 assert event is not None
7195 event_obj = event [0 ] if isinstance (event , tuple ) else event
72- if (
73- getattr (event_obj , 'status' , None )
74- and getattr (event_obj .status , 'state' , None )
75- == 'TASK_STATE_COMPLETED'
76- ):
77- assert (
78- getattr (event_obj .status .message , 'metadata' , {}).get (
79- 'response_key'
80- )
81- == 'response_value'
82- ), (
83- f'Missing response metadata: { getattr (event_obj .status .message , "metadata" , {})} '
96+
97+ status = getattr (event_obj , 'status' , None )
98+ if status and str (getattr (status , 'state' , '' )).endswith ('completed' ):
99+ # In 0.3 SDK, the message on the status might be exposed as 'message' or 'update'
100+ status_msg = getattr (
101+ status , 'message' , getattr (status , 'update' , None )
84102 )
85- elif getattr (event_obj , 'status' , None ) and str (
86- getattr (event_obj .status , 'state' , None )
87- ).endswith ('completed' ):
88- assert (
89- getattr (event_obj .status .message , 'metadata' , {}).get (
90- 'response_key'
91- )
92- == 'response_value'
93- ), (
94- f'Missing response metadata: { getattr (event_obj .status .message , "metadata" , {})} '
103+ assert status_msg is not None , (
104+ 'TaskStatus message/update is missing'
95105 )
96- break
106+
107+ metadata = getattr (status_msg , 'metadata' , {})
108+ assert metadata .get ('response_key' ) == 'response_value' , (
109+ f'Missing response metadata: { metadata } '
110+ )
111+
112+ # Check Part translation (root text part in 0.3)
113+ parts = getattr (
114+ status_msg , 'parts' , getattr (status_msg , 'content' , [])
115+ )
116+ assert len (parts ) > 0 , 'No parts found in TaskStatus message'
117+ first_part = parts [0 ]
118+ text = getattr (first_part , 'text' , '' )
119+ if (
120+ not text
121+ and hasattr (first_part , 'root' )
122+ and hasattr (first_part .root , 'text' )
123+ ):
124+ text = first_part .root .text
125+ assert text == 'done' , f"Expected 'done' text in Part, got '{ text } '"
126+ break
97127
98128 print (f'Success: send_message (synchronous) passed.' )
99129
@@ -102,20 +132,73 @@ async def test_get_task(client, task_id):
102132 print (f'Testing get_task ({ task_id } )...' )
103133 task = await client .get_task (request = TaskQueryParams (id = task_id ))
104134 assert task .id == task_id
135+
136+ user_msgs = [
137+ m for m in task .history if getattr (m , 'role' , None ) == Role .user
138+ ]
139+ assert user_msgs , 'Expected at least one ROLE_USER message in task history'
140+
141+ client_msg = user_msgs [0 ]
142+
143+ parts = client_msg .parts
144+ assert len (parts ) == 4 , f'Expected 4 parts, got { len (parts )} '
145+
146+ # 1. text part
147+ text = getattr (parts [0 ].root , 'text' , '' )
148+ assert text == 'stream' , f"Expected 'stream', got { text } "
149+
150+ # 2. uri part
151+ file_uri = getattr (parts [1 ].root , 'file' , None )
152+ assert (
153+ file_uri is not None
154+ and getattr (file_uri , 'uri' , None ) == 'https://example.com/file.txt'
155+ )
156+
157+ # 3. bytes part
158+ file_bytes = getattr (parts [2 ].root , 'file' , None )
159+ actual_bytes = getattr (file_bytes , 'bytes' , None )
160+ assert actual_bytes == 'aGVsbG8=' , (
161+ f"Expected base64 'hello', got { actual_bytes } "
162+ )
163+
164+ # 4. data part
165+ data_val = getattr (parts [3 ].root , 'data' , None )
166+ assert data_val is not None
167+ assert data_val == {'key' : 'value' }
168+
105169 print ('Success: get_task passed.' )
106170
107171
108172async def test_cancel_task (client , task_id ):
109173 print (f'Testing cancel_task ({ task_id } )...' )
110174 await client .cancel_task (request = TaskIdParams (id = task_id ))
175+ task = await client .get_task (request = TaskQueryParams (id = task_id ))
176+ assert task .status .state == TaskState .canceled , (
177+ f'Expected a canceled state, got { task .status .state } '
178+ )
111179 print ('Success: cancel_task passed.' )
112180
113181
114182async def test_subscribe (client , task_id ):
115183 print (f'Testing subscribe ({ task_id } )...' )
184+ has_artifact = False
116185 async for event in client .resubscribe (request = TaskIdParams (id = task_id )):
117- print (f'Received event: { event } ' )
118- break
186+ # event is tuple (Task, UpdateEvent)
187+ task , update = event
188+ if update and hasattr (update , 'artifact' ):
189+ has_artifact = True
190+ artifact = update .artifact
191+ assert artifact .name == 'test-artifact'
192+ assert artifact .metadata .get ('artifact_key' ) == 'artifact_value'
193+ # part check
194+ assert len (artifact .parts ) > 0
195+ p = artifact .parts [0 ]
196+ text = getattr (p .root , 'text' , '' )
197+ assert text == 'artifact-chunk'
198+ print ('Success: received artifact update.' )
199+
200+ if has_artifact :
201+ break
119202 print ('Success: subscribe passed.' )
120203
121204
@@ -124,7 +207,27 @@ async def test_get_extended_agent_card(client):
124207 # In v0.3, extended card is fetched via get_card() on the client
125208 card = await client .get_card ()
126209 assert card is not None
127- # the MockAgentExecutor might not have a name or has one, just assert card exists
210+ assert card .name in ('Server 0.3' , 'Server 1.0' )
211+ assert card .version == '1.0.0'
212+ assert 'Server running on a2a v' in card .description
213+
214+ assert card .capabilities is not None
215+ assert card .capabilities .streaming is True
216+ assert card .capabilities .push_notifications is True
217+
218+ if card .name == 'Server 0.3' :
219+ assert card .url is not None
220+ assert card .preferred_transport == TransportProtocol .jsonrpc
221+ assert len (card .additional_interfaces ) == 2
222+ assert card .supports_authenticated_extended_card is False
223+ else :
224+ assert card .url is not None
225+ assert card .preferred_transport is not None
226+ print (
227+ f'card.supports_authenticated_extended_card is: { card .supports_authenticated_extended_card } '
228+ )
229+ assert card .supports_authenticated_extended_card in (False , None )
230+
128231 print (f'Success: get_extended_agent_card passed.' )
129232
130233
@@ -177,8 +280,6 @@ def main():
177280 try :
178281 asyncio .run (run_client (args .url , protocol ))
179282 except Exception as e :
180- import traceback
181-
182283 traceback .print_exc ()
183284 print (f'FAILED protocol { protocol } : { e } ' )
184285 failed = True
0 commit comments