Skip to content

Commit 4823b5d

Browse files
committed
feat(compat): Tests and unified REST url.
1 parent a102d31 commit 4823b5d

9 files changed

Lines changed: 476 additions & 98 deletions

File tree

src/a2a/server/apps/rest/fastapi_app.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,15 @@ def build(
121121
A configured FastAPI application instance.
122122
"""
123123
app = FastAPI(**kwargs)
124+
if self.enable_v0_3_compat and self._v03_adapter:
125+
v03_adapter = self._v03_adapter
126+
v03_router = APIRouter()
127+
for route, callback in v03_adapter.routes().items():
128+
v03_router.add_api_route(
129+
f'{rpc_url}{route[0]}', callback, methods=[route[1]]
130+
)
131+
app.include_router(v03_router)
132+
124133
router = APIRouter()
125134
for route, callback in self._adapter.routes().items():
126135
router.add_api_route(
@@ -134,19 +143,4 @@ async def get_agent_card(request: Request) -> Response:
134143

135144
app.include_router(router)
136145

137-
if self.enable_v0_3_compat and self._v03_adapter:
138-
v03_adapter = self._v03_adapter
139-
v03_router = APIRouter()
140-
for route, callback in v03_adapter.routes().items():
141-
v03_router.add_api_route(
142-
f'{rpc_url}/v0.3{route[0]}', callback, methods=[route[1]]
143-
)
144-
145-
@v03_router.get(f'{rpc_url}/v0.3{agent_card_url}')
146-
async def get_v03_agent_card(request: Request) -> Response:
147-
card = await v03_adapter.handle_get_agent_card(request)
148-
return JSONResponse(card)
149-
150-
app.include_router(v03_router)
151-
152146
return app

tests/compat/v0_3/test_rest_fastapi_app_compat.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ async def test_send_message_success_message_v03(
9292
)
9393

9494
response = await client.post(
95-
'/v0.3/v1/message:send', json=json_format.MessageToDict(request)
95+
'/v1/message:send', json=json_format.MessageToDict(request)
9696
)
9797
response.raise_for_status()
9898

@@ -127,7 +127,7 @@ async def test_send_message_success_task_v03(
127127
)
128128

129129
response = await client.post(
130-
'/v0.3/v1/message:send', json=json_format.MessageToDict(request)
130+
'/v1/message:send', json=json_format.MessageToDict(request)
131131
)
132132
response.raise_for_status()
133133

@@ -155,7 +155,7 @@ async def test_get_task_v03(
155155
),
156156
)
157157

158-
response = await client.get('/v0.3/v1/tasks/test_task_id')
158+
response = await client.get('/v1/tasks/test_task_id')
159159
response.raise_for_status()
160160

161161
actual_response = a2a_v0_3_pb2.Task()
@@ -182,7 +182,7 @@ async def test_cancel_task_v03(
182182
),
183183
)
184184

185-
response = await client.post('/v0.3/v1/tasks/test_task_id:cancel')
185+
response = await client.post('/v1/tasks/test_task_id:cancel')
186186
response.raise_for_status()
187187

188188
actual_response = a2a_v0_3_pb2.Task()

tests/integration/cross_version/client_server/client_0_3.py

Lines changed: 133 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -14,20 +14,45 @@
1414
TransportProtocol,
1515
TaskQueryParams,
1616
TaskIdParams,
17+
TaskState,
1718
TaskPushNotificationConfig,
1819
PushNotificationConfig,
20+
FilePart,
21+
FileWithUri,
22+
FileWithBytes,
23+
DataPart,
1924
)
2025
from a2a.client.errors import A2AClientJSONRPCError, A2AClientHTTPError
2126
import sys
27+
import traceback
2228

2329

2430
async def test_send_message_stream(client):
2531
print('Testing send_message (streaming)...')
32+
2633
msg = Message(
2734
role=Role.user,
2835
message_id=f'stream-{uuid4()}',
29-
parts=[Part(root=TextPart(text='stream'))],
30-
metadata={'test_key': 'test_value'},
36+
parts=[
37+
Part(root=TextPart(text='stream')),
38+
Part(
39+
root=FilePart(
40+
file=FileWithUri(
41+
uri='https://example.com/file.txt',
42+
mime_type='text/plain',
43+
)
44+
)
45+
),
46+
Part(
47+
root=FilePart(
48+
file=FileWithBytes(
49+
bytes=b'aGVsbG8=', mime_type='application/octet-stream'
50+
)
51+
)
52+
),
53+
Part(root=DataPart(data={'key': 'value'})),
54+
],
55+
metadata={'test_key': 'full_message'},
3156
)
3257
events = []
3358

@@ -62,38 +87,43 @@ async def test_send_message_sync(url, protocol_enum):
6287
role=Role.user,
6388
message_id=f'sync-{uuid4()}',
6489
parts=[Part(root=TextPart(text='sync'))],
65-
metadata={'test_key': 'test_value'},
90+
metadata={'test_key': 'simple_message'},
6691
)
6792

68-
# In v0.3 SDK, send_message ALWAYS returns an async generator
6993
async for event in client.send_message(request=msg):
7094
assert event is not None
7195
event_obj = event[0] if isinstance(event, tuple) else event
72-
if (
73-
getattr(event_obj, 'status', None)
74-
and getattr(event_obj.status, 'state', None)
75-
== 'TASK_STATE_COMPLETED'
76-
):
77-
assert (
78-
getattr(event_obj.status.message, 'metadata', {}).get(
79-
'response_key'
80-
)
81-
== 'response_value'
82-
), (
83-
f'Missing response metadata: {getattr(event_obj.status.message, "metadata", {})}'
96+
97+
status = getattr(event_obj, 'status', None)
98+
if status and str(getattr(status, 'state', '')).endswith('completed'):
99+
# In 0.3 SDK, the message on the status might be exposed as 'message' or 'update'
100+
status_msg = getattr(
101+
status, 'message', getattr(status, 'update', None)
84102
)
85-
elif getattr(event_obj, 'status', None) and str(
86-
getattr(event_obj.status, 'state', None)
87-
).endswith('completed'):
88-
assert (
89-
getattr(event_obj.status.message, 'metadata', {}).get(
90-
'response_key'
91-
)
92-
== 'response_value'
93-
), (
94-
f'Missing response metadata: {getattr(event_obj.status.message, "metadata", {})}'
103+
assert status_msg is not None, (
104+
'TaskStatus message/update is missing'
95105
)
96-
break
106+
107+
metadata = getattr(status_msg, 'metadata', {})
108+
assert metadata.get('response_key') == 'response_value', (
109+
f'Missing response metadata: {metadata}'
110+
)
111+
112+
# Check Part translation (root text part in 0.3)
113+
parts = getattr(
114+
status_msg, 'parts', getattr(status_msg, 'content', [])
115+
)
116+
assert len(parts) > 0, 'No parts found in TaskStatus message'
117+
first_part = parts[0]
118+
text = getattr(first_part, 'text', '')
119+
if (
120+
not text
121+
and hasattr(first_part, 'root')
122+
and hasattr(first_part.root, 'text')
123+
):
124+
text = first_part.root.text
125+
assert text == 'done', f"Expected 'done' text in Part, got '{text}'"
126+
break
97127

98128
print(f'Success: send_message (synchronous) passed.')
99129

@@ -102,20 +132,73 @@ async def test_get_task(client, task_id):
102132
print(f'Testing get_task ({task_id})...')
103133
task = await client.get_task(request=TaskQueryParams(id=task_id))
104134
assert task.id == task_id
135+
136+
user_msgs = [
137+
m for m in task.history if getattr(m, 'role', None) == Role.user
138+
]
139+
assert user_msgs, 'Expected at least one ROLE_USER message in task history'
140+
141+
client_msg = user_msgs[0]
142+
143+
parts = client_msg.parts
144+
assert len(parts) == 4, f'Expected 4 parts, got {len(parts)}'
145+
146+
# 1. text part
147+
text = getattr(parts[0].root, 'text', '')
148+
assert text == 'stream', f"Expected 'stream', got {text}"
149+
150+
# 2. uri part
151+
file_uri = getattr(parts[1].root, 'file', None)
152+
assert (
153+
file_uri is not None
154+
and getattr(file_uri, 'uri', None) == 'https://example.com/file.txt'
155+
)
156+
157+
# 3. bytes part
158+
file_bytes = getattr(parts[2].root, 'file', None)
159+
actual_bytes = getattr(file_bytes, 'bytes', None)
160+
assert actual_bytes == 'aGVsbG8=', (
161+
f"Expected base64 'hello', got {actual_bytes}"
162+
)
163+
164+
# 4. data part
165+
data_val = getattr(parts[3].root, 'data', None)
166+
assert data_val is not None
167+
assert data_val == {'key': 'value'}
168+
105169
print('Success: get_task passed.')
106170

107171

108172
async def test_cancel_task(client, task_id):
109173
print(f'Testing cancel_task ({task_id})...')
110174
await client.cancel_task(request=TaskIdParams(id=task_id))
175+
task = await client.get_task(request=TaskQueryParams(id=task_id))
176+
assert task.status.state == TaskState.canceled, (
177+
f'Expected a canceled state, got {task.status.state}'
178+
)
111179
print('Success: cancel_task passed.')
112180

113181

114182
async def test_subscribe(client, task_id):
115183
print(f'Testing subscribe ({task_id})...')
184+
has_artifact = False
116185
async for event in client.resubscribe(request=TaskIdParams(id=task_id)):
117-
print(f'Received event: {event}')
118-
break
186+
# event is tuple (Task, UpdateEvent)
187+
task, update = event
188+
if update and hasattr(update, 'artifact'):
189+
has_artifact = True
190+
artifact = update.artifact
191+
assert artifact.name == 'test-artifact'
192+
assert artifact.metadata.get('artifact_key') == 'artifact_value'
193+
# part check
194+
assert len(artifact.parts) > 0
195+
p = artifact.parts[0]
196+
text = getattr(p.root, 'text', '')
197+
assert text == 'artifact-chunk'
198+
print('Success: received artifact update.')
199+
200+
if has_artifact:
201+
break
119202
print('Success: subscribe passed.')
120203

121204

@@ -124,7 +207,27 @@ async def test_get_extended_agent_card(client):
124207
# In v0.3, extended card is fetched via get_card() on the client
125208
card = await client.get_card()
126209
assert card is not None
127-
# the MockAgentExecutor might not have a name or has one, just assert card exists
210+
assert card.name in ('Server 0.3', 'Server 1.0')
211+
assert card.version == '1.0.0'
212+
assert 'Server running on a2a v' in card.description
213+
214+
assert card.capabilities is not None
215+
assert card.capabilities.streaming is True
216+
assert card.capabilities.push_notifications is True
217+
218+
if card.name == 'Server 0.3':
219+
assert card.url is not None
220+
assert card.preferred_transport == TransportProtocol.jsonrpc
221+
assert len(card.additional_interfaces) == 2
222+
assert card.supports_authenticated_extended_card is False
223+
else:
224+
assert card.url is not None
225+
assert card.preferred_transport is not None
226+
print(
227+
f'card.supports_authenticated_extended_card is: {card.supports_authenticated_extended_card}'
228+
)
229+
assert card.supports_authenticated_extended_card in (False, None)
230+
128231
print(f'Success: get_extended_agent_card passed.')
129232

130233

@@ -177,8 +280,6 @@ def main():
177280
try:
178281
asyncio.run(run_client(args.url, protocol))
179282
except Exception as e:
180-
import traceback
181-
182283
traceback.print_exc()
183284
print(f'FAILED protocol {protocol}: {e}')
184285
failed = True

0 commit comments

Comments
 (0)