Skip to content

Commit 27a48ab

Browse files
author
Krzysztof Dziedzic
committed
Add ITK test suite
1 parent 24f5f1e commit 27a48ab

File tree

5 files changed

+517
-1
lines changed

5 files changed

+517
-1
lines changed

.github/actions/spelling/allow.txt

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,10 @@ initdb
6262
inmemory
6363
INR
6464
isready
65+
itk
66+
ITK
6567
jcs
68+
jit
6669
jku
6770
JOSE
6871
JPY
@@ -106,11 +109,13 @@ protoc
106109
pydantic
107110
pyi
108111
pypistats
112+
pyproto
109113
pyupgrade
110114
pyversions
111115
redef
112116
respx
113117
resub
118+
rmi
114119
RS256
115120
RUF
116121
SECP256R1
@@ -127,7 +132,7 @@ taskupdate
127132
testuuid
128133
Tful
129134
tiangolo
135+
TResponse
130136
typ
131137
typeerror
132138
vulnz
133-
TResponse

itk/__init__.py

Whitespace-only changes.

itk/main.py

Lines changed: 353 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,353 @@
1+
import argparse # noqa: I001
2+
import asyncio
3+
import base64
4+
import logging
5+
import uuid
6+
7+
import grpc
8+
import httpx
9+
import uvicorn
10+
11+
from fastapi import FastAPI
12+
13+
from pyproto import instruction_pb2
14+
15+
from a2a.client import ClientConfig, ClientFactory
16+
from a2a.compat.v0_3 import a2a_v0_3_pb2_grpc
17+
from a2a.compat.v0_3.grpc_handler import CompatGrpcHandler
18+
from a2a.server.agent_execution import AgentExecutor, RequestContext
19+
from a2a.server.apps import A2AFastAPIApplication, A2ARESTFastAPIApplication
20+
from a2a.server.events import EventQueue
21+
from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager
22+
from a2a.server.request_handlers import DefaultRequestHandler, GrpcHandler
23+
from a2a.server.tasks import TaskUpdater
24+
from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore
25+
from a2a.types import a2a_pb2_grpc
26+
from a2a.types.a2a_pb2 import (
27+
AgentCapabilities,
28+
AgentCard,
29+
AgentInterface,
30+
Message,
31+
Part,
32+
SendMessageRequest,
33+
TaskState,
34+
)
35+
from a2a.utils import TransportProtocol
36+
37+
38+
logging.basicConfig(level=logging.INFO)
39+
logger = logging.getLogger(__name__)
40+
41+
42+
def extract_instruction(
43+
message: Message | None,
44+
) -> instruction_pb2.Instruction | None:
45+
"""Extracts an Instruction proto from an A2A Message."""
46+
if not message or not message.parts:
47+
return None
48+
49+
for part in message.parts:
50+
# 1. Handle binary protobuf part (media_type or filename)
51+
if (
52+
part.media_type == 'application/x-protobuf'
53+
or part.filename == 'instruction.bin'
54+
):
55+
try:
56+
inst = instruction_pb2.Instruction()
57+
if part.raw:
58+
inst.ParseFromString(part.raw)
59+
elif part.text:
60+
# Some clients might send it as base64 in text part
61+
raw = base64.b64decode(part.text)
62+
inst.ParseFromString(raw)
63+
except Exception: # noqa: BLE001
64+
logger.debug(
65+
'Failed to parse instruction from binary part',
66+
exc_info=True,
67+
)
68+
continue
69+
else:
70+
return inst
71+
72+
# 2. Handle base64 encoded instruction in any text part
73+
if part.text:
74+
try:
75+
raw = base64.b64decode(part.text)
76+
inst = instruction_pb2.Instruction()
77+
inst.ParseFromString(raw)
78+
except Exception: # noqa: BLE001
79+
logger.debug(
80+
'Failed to parse instruction from text part', exc_info=True
81+
)
82+
continue
83+
else:
84+
return inst
85+
return None
86+
87+
88+
def wrap_instruction_to_request(inst: instruction_pb2.Instruction) -> Message:
89+
"""Wraps an Instruction proto into an A2A Message."""
90+
inst_bytes = inst.SerializeToString()
91+
return Message(
92+
role='ROLE_USER',
93+
message_id=str(uuid.uuid4()),
94+
parts=[
95+
Part(
96+
raw=inst_bytes,
97+
media_type='application/x-protobuf',
98+
filename='instruction.bin',
99+
)
100+
],
101+
)
102+
103+
104+
async def handle_call_agent(call: instruction_pb2.CallAgent) -> list[str]:
105+
"""Handles the CallAgent instruction by invoking another agent."""
106+
logger.info('Calling agent %s via %s', call.agent_card_uri, call.transport)
107+
108+
# Mapping transport string to TransportProtocol enum
109+
transport_map = {
110+
'JSONRPC': TransportProtocol.JSONRPC,
111+
'HTTP+JSON': TransportProtocol.HTTP_JSON,
112+
'HTTP_JSON': TransportProtocol.HTTP_JSON,
113+
'REST': TransportProtocol.HTTP_JSON,
114+
'GRPC': TransportProtocol.GRPC,
115+
}
116+
117+
selected_transport = transport_map.get(
118+
call.transport.upper(), TransportProtocol.JSONRPC
119+
)
120+
if selected_transport is None:
121+
raise ValueError(f'Unsupported transport: {call.transport}')
122+
123+
config = ClientConfig()
124+
config.httpx_client = httpx.AsyncClient(timeout=30.0)
125+
config.grpc_channel_factory = grpc.aio.insecure_channel
126+
config.supported_protocol_bindings = [selected_transport]
127+
config.streaming = call.streaming or (
128+
selected_transport == TransportProtocol.GRPC
129+
)
130+
131+
try:
132+
client = await ClientFactory.connect(
133+
call.agent_card_uri,
134+
client_config=config,
135+
)
136+
137+
# Wrap nested instruction
138+
nested_msg = wrap_instruction_to_request(call.instruction)
139+
request = SendMessageRequest(message=nested_msg)
140+
141+
results = []
142+
async for event in client.send_message(request):
143+
# Event is streaming response and task
144+
logger.info('Event: %s', event)
145+
stream_resp, task = event
146+
147+
message = None
148+
if stream_resp.HasField('message'):
149+
message = stream_resp.message
150+
elif task and task.status.HasField('message'):
151+
message = task.status.message
152+
elif stream_resp.HasField(
153+
'status_update'
154+
) and stream_resp.status_update.status.HasField('message'):
155+
message = stream_resp.status_update.status.message
156+
157+
if message:
158+
results.extend(part.text for part in message.parts if part.text)
159+
160+
except Exception as e:
161+
logger.exception('Failed to call outbound agent')
162+
raise RuntimeError(
163+
f'Outbound call to {call.agent_card_uri} failed: {e!s}'
164+
) from e
165+
else:
166+
return results
167+
168+
169+
async def handle_instruction(inst: instruction_pb2.Instruction) -> list[str]:
170+
"""Recursively handles instructions."""
171+
if inst.HasField('call_agent'):
172+
return await handle_call_agent(inst.call_agent)
173+
if inst.HasField('return_response'):
174+
return [inst.return_response.response]
175+
if inst.HasField('steps'):
176+
all_results = []
177+
for step in inst.steps.instructions:
178+
results = await handle_instruction(step)
179+
all_results.extend(results)
180+
return all_results
181+
raise ValueError('Unknown instruction type')
182+
183+
184+
class V10AgentExecutor(AgentExecutor):
185+
"""Executor for ITK v10 agent tasks."""
186+
187+
async def execute(
188+
self, context: RequestContext, event_queue: EventQueue
189+
) -> None:
190+
"""Executes a task instruction."""
191+
logger.info('Executing task %s', context.task_id)
192+
task_updater = TaskUpdater(
193+
event_queue,
194+
context.task_id,
195+
context.context_id,
196+
)
197+
198+
await task_updater.update_status(TaskState.TASK_STATE_SUBMITTED)
199+
await task_updater.update_status(TaskState.TASK_STATE_WORKING)
200+
201+
instruction = extract_instruction(context.message)
202+
if not instruction:
203+
error_msg = 'No valid instruction found in request'
204+
logger.error(error_msg)
205+
await task_updater.update_status(
206+
TaskState.TASK_STATE_FAILED,
207+
message=task_updater.new_agent_message([Part(text=error_msg)]),
208+
)
209+
return
210+
211+
try:
212+
logger.info('Instruction: %s', instruction)
213+
results = await handle_instruction(instruction)
214+
response_text = '\n'.join(results)
215+
logger.info('Response: %s', response_text)
216+
await task_updater.update_status(
217+
TaskState.TASK_STATE_COMPLETED,
218+
message=task_updater.new_agent_message(
219+
[Part(text=response_text)]
220+
),
221+
)
222+
logger.info('Task %s completed', context.task_id)
223+
except Exception as e:
224+
logger.exception('Error during instruction handling')
225+
await task_updater.update_status(
226+
TaskState.TASK_STATE_FAILED,
227+
message=task_updater.new_agent_message([Part(text=str(e))]),
228+
)
229+
230+
async def cancel(
231+
self, context: RequestContext, event_queue: EventQueue
232+
) -> None:
233+
"""Cancels a task."""
234+
logger.info('Cancel requested for task %s', context.task_id)
235+
task_updater = TaskUpdater(
236+
event_queue,
237+
context.task_id,
238+
context.context_id,
239+
)
240+
await task_updater.update_status(TaskState.TASK_STATE_CANCELED)
241+
242+
243+
async def main_async(http_port: int, grpc_port: int) -> None:
244+
"""Starts the Agent with HTTP and gRPC interfaces."""
245+
interfaces = [
246+
AgentInterface(
247+
protocol_binding=TransportProtocol.GRPC,
248+
url=f'127.0.0.1:{grpc_port}',
249+
protocol_version='1.0',
250+
),
251+
AgentInterface(
252+
protocol_binding=TransportProtocol.GRPC,
253+
url=f'127.0.0.1:{grpc_port}',
254+
protocol_version='0.3',
255+
),
256+
]
257+
258+
interfaces.append(
259+
AgentInterface(
260+
protocol_binding=TransportProtocol.JSONRPC,
261+
url=f'http://127.0.0.1:{http_port}/jsonrpc/',
262+
)
263+
)
264+
interfaces.append(
265+
AgentInterface(
266+
protocol_binding=TransportProtocol.HTTP_JSON,
267+
url=f'http://127.0.0.1:{http_port}/rest/',
268+
protocol_version='1.0',
269+
)
270+
)
271+
interfaces.append(
272+
AgentInterface(
273+
protocol_binding=TransportProtocol.HTTP_JSON,
274+
url=f'http://127.0.0.1:{http_port}/rest/',
275+
protocol_version='0.3',
276+
)
277+
)
278+
279+
agent_card = AgentCard(
280+
name='ITK v10 Agent',
281+
description='Python agent using SDK 1.0.',
282+
version='1.0.0',
283+
capabilities=AgentCapabilities(streaming=True),
284+
default_input_modes=['text/plain'],
285+
default_output_modes=['text/plain'],
286+
supported_interfaces=interfaces,
287+
)
288+
289+
task_store = InMemoryTaskStore()
290+
handler = DefaultRequestHandler(
291+
agent_executor=V10AgentExecutor(),
292+
task_store=task_store,
293+
queue_manager=InMemoryQueueManager(),
294+
)
295+
296+
app = FastAPI()
297+
298+
json_rpc_app = A2AFastAPIApplication(
299+
agent_card, handler, enable_v0_3_compat=True
300+
).build()
301+
app.mount('/jsonrpc', json_rpc_app)
302+
rest_app = A2ARESTFastAPIApplication(
303+
http_handler=handler, agent_card=agent_card, enable_v0_3_compat=True
304+
).build()
305+
app.mount('/rest', rest_app)
306+
307+
server = grpc.aio.server()
308+
309+
compat_servicer = CompatGrpcHandler(agent_card, handler)
310+
a2a_v0_3_pb2_grpc.add_A2AServiceServicer_to_server(compat_servicer, server)
311+
servicer = GrpcHandler(agent_card, handler)
312+
a2a_pb2_grpc.add_A2AServiceServicer_to_server(servicer, server)
313+
314+
server.add_insecure_port(f'127.0.0.1:{grpc_port}')
315+
await server.start()
316+
317+
logger.info(
318+
'Starting ITK v10 Agent on HTTP port %s and gRPC port %s',
319+
http_port,
320+
grpc_port,
321+
)
322+
323+
config = uvicorn.Config(
324+
app, host='127.0.0.1', port=http_port, log_level='info'
325+
)
326+
uvicorn_server = uvicorn.Server(config)
327+
328+
await uvicorn_server.serve()
329+
330+
331+
def str2bool(v: str | bool) -> bool:
332+
"""Converts a string to a boolean value."""
333+
if isinstance(v, bool):
334+
return v
335+
if v.lower() in ('yes', 'true', 't', 'y', '1'):
336+
return True
337+
if v.lower() in ('no', 'false', 'f', 'n', '0'):
338+
return False
339+
raise argparse.ArgumentTypeError('Boolean value expected.')
340+
341+
342+
def main() -> None:
343+
"""Main entry point for the agent."""
344+
parser = argparse.ArgumentParser()
345+
parser.add_argument('--httpPort', type=int, default=10102)
346+
parser.add_argument('--grpcPort', type=int, default=11002)
347+
args = parser.parse_args()
348+
349+
asyncio.run(main_async(args.httpPort, args.grpcPort))
350+
351+
352+
if __name__ == '__main__':
353+
main()

0 commit comments

Comments
 (0)