|
| 1 | +import argparse # noqa: I001 |
| 2 | +import asyncio |
| 3 | +import base64 |
| 4 | +import logging |
| 5 | +import uuid |
| 6 | + |
| 7 | +import grpc |
| 8 | +import httpx |
| 9 | +import uvicorn |
| 10 | + |
| 11 | +from fastapi import FastAPI |
| 12 | + |
| 13 | +from pyproto import instruction_pb2 |
| 14 | + |
| 15 | +from a2a.client import ClientConfig, ClientFactory |
| 16 | +from a2a.compat.v0_3 import a2a_v0_3_pb2_grpc |
| 17 | +from a2a.compat.v0_3.grpc_handler import CompatGrpcHandler |
| 18 | +from a2a.server.agent_execution import AgentExecutor, RequestContext |
| 19 | +from a2a.server.apps import A2AFastAPIApplication, A2ARESTFastAPIApplication |
| 20 | +from a2a.server.events import EventQueue |
| 21 | +from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager |
| 22 | +from a2a.server.request_handlers import DefaultRequestHandler, GrpcHandler |
| 23 | +from a2a.server.tasks import TaskUpdater |
| 24 | +from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore |
| 25 | +from a2a.types import a2a_pb2_grpc |
| 26 | +from a2a.types.a2a_pb2 import ( |
| 27 | + AgentCapabilities, |
| 28 | + AgentCard, |
| 29 | + AgentInterface, |
| 30 | + Message, |
| 31 | + Part, |
| 32 | + SendMessageRequest, |
| 33 | + TaskState, |
| 34 | +) |
| 35 | +from a2a.utils import TransportProtocol |
| 36 | + |
| 37 | + |
| 38 | +logging.basicConfig(level=logging.INFO) |
| 39 | +logger = logging.getLogger(__name__) |
| 40 | + |
| 41 | + |
| 42 | +def extract_instruction( |
| 43 | + message: Message | None, |
| 44 | +) -> instruction_pb2.Instruction | None: |
| 45 | + """Extracts an Instruction proto from an A2A Message.""" |
| 46 | + if not message or not message.parts: |
| 47 | + return None |
| 48 | + |
| 49 | + for part in message.parts: |
| 50 | + # 1. Handle binary protobuf part (media_type or filename) |
| 51 | + if ( |
| 52 | + part.media_type == 'application/x-protobuf' |
| 53 | + or part.filename == 'instruction.bin' |
| 54 | + ): |
| 55 | + try: |
| 56 | + inst = instruction_pb2.Instruction() |
| 57 | + if part.raw: |
| 58 | + inst.ParseFromString(part.raw) |
| 59 | + elif part.text: |
| 60 | + # Some clients might send it as base64 in text part |
| 61 | + raw = base64.b64decode(part.text) |
| 62 | + inst.ParseFromString(raw) |
| 63 | + except Exception: # noqa: BLE001 |
| 64 | + logger.debug( |
| 65 | + 'Failed to parse instruction from binary part', |
| 66 | + exc_info=True, |
| 67 | + ) |
| 68 | + continue |
| 69 | + else: |
| 70 | + return inst |
| 71 | + |
| 72 | + # 2. Handle base64 encoded instruction in any text part |
| 73 | + if part.text: |
| 74 | + try: |
| 75 | + raw = base64.b64decode(part.text) |
| 76 | + inst = instruction_pb2.Instruction() |
| 77 | + inst.ParseFromString(raw) |
| 78 | + except Exception: # noqa: BLE001 |
| 79 | + logger.debug( |
| 80 | + 'Failed to parse instruction from text part', exc_info=True |
| 81 | + ) |
| 82 | + continue |
| 83 | + else: |
| 84 | + return inst |
| 85 | + return None |
| 86 | + |
| 87 | + |
| 88 | +def wrap_instruction_to_request(inst: instruction_pb2.Instruction) -> Message: |
| 89 | + """Wraps an Instruction proto into an A2A Message.""" |
| 90 | + inst_bytes = inst.SerializeToString() |
| 91 | + return Message( |
| 92 | + role='ROLE_USER', |
| 93 | + message_id=str(uuid.uuid4()), |
| 94 | + parts=[ |
| 95 | + Part( |
| 96 | + raw=inst_bytes, |
| 97 | + media_type='application/x-protobuf', |
| 98 | + filename='instruction.bin', |
| 99 | + ) |
| 100 | + ], |
| 101 | + ) |
| 102 | + |
| 103 | + |
| 104 | +async def handle_call_agent(call: instruction_pb2.CallAgent) -> list[str]: |
| 105 | + """Handles the CallAgent instruction by invoking another agent.""" |
| 106 | + logger.info('Calling agent %s via %s', call.agent_card_uri, call.transport) |
| 107 | + |
| 108 | + # Mapping transport string to TransportProtocol enum |
| 109 | + transport_map = { |
| 110 | + 'JSONRPC': TransportProtocol.JSONRPC, |
| 111 | + 'HTTP+JSON': TransportProtocol.HTTP_JSON, |
| 112 | + 'HTTP_JSON': TransportProtocol.HTTP_JSON, |
| 113 | + 'REST': TransportProtocol.HTTP_JSON, |
| 114 | + 'GRPC': TransportProtocol.GRPC, |
| 115 | + } |
| 116 | + |
| 117 | + selected_transport = transport_map.get( |
| 118 | + call.transport.upper(), TransportProtocol.JSONRPC |
| 119 | + ) |
| 120 | + if selected_transport is None: |
| 121 | + raise ValueError(f'Unsupported transport: {call.transport}') |
| 122 | + |
| 123 | + config = ClientConfig() |
| 124 | + config.httpx_client = httpx.AsyncClient(timeout=30.0) |
| 125 | + config.grpc_channel_factory = grpc.aio.insecure_channel |
| 126 | + config.supported_protocol_bindings = [selected_transport] |
| 127 | + config.streaming = call.streaming or ( |
| 128 | + selected_transport == TransportProtocol.GRPC |
| 129 | + ) |
| 130 | + |
| 131 | + try: |
| 132 | + client = await ClientFactory.connect( |
| 133 | + call.agent_card_uri, |
| 134 | + client_config=config, |
| 135 | + ) |
| 136 | + |
| 137 | + # Wrap nested instruction |
| 138 | + nested_msg = wrap_instruction_to_request(call.instruction) |
| 139 | + request = SendMessageRequest(message=nested_msg) |
| 140 | + |
| 141 | + results = [] |
| 142 | + async for event in client.send_message(request): |
| 143 | + # Event is streaming response and task |
| 144 | + logger.info('Event: %s', event) |
| 145 | + stream_resp, task = event |
| 146 | + |
| 147 | + message = None |
| 148 | + if stream_resp.HasField('message'): |
| 149 | + message = stream_resp.message |
| 150 | + elif task and task.status.HasField('message'): |
| 151 | + message = task.status.message |
| 152 | + elif stream_resp.HasField( |
| 153 | + 'status_update' |
| 154 | + ) and stream_resp.status_update.status.HasField('message'): |
| 155 | + message = stream_resp.status_update.status.message |
| 156 | + |
| 157 | + if message: |
| 158 | + results.extend(part.text for part in message.parts if part.text) |
| 159 | + |
| 160 | + except Exception as e: |
| 161 | + logger.exception('Failed to call outbound agent') |
| 162 | + raise RuntimeError( |
| 163 | + f'Outbound call to {call.agent_card_uri} failed: {e!s}' |
| 164 | + ) from e |
| 165 | + else: |
| 166 | + return results |
| 167 | + |
| 168 | + |
| 169 | +async def handle_instruction(inst: instruction_pb2.Instruction) -> list[str]: |
| 170 | + """Recursively handles instructions.""" |
| 171 | + if inst.HasField('call_agent'): |
| 172 | + return await handle_call_agent(inst.call_agent) |
| 173 | + if inst.HasField('return_response'): |
| 174 | + return [inst.return_response.response] |
| 175 | + if inst.HasField('steps'): |
| 176 | + all_results = [] |
| 177 | + for step in inst.steps.instructions: |
| 178 | + results = await handle_instruction(step) |
| 179 | + all_results.extend(results) |
| 180 | + return all_results |
| 181 | + raise ValueError('Unknown instruction type') |
| 182 | + |
| 183 | + |
| 184 | +class V10AgentExecutor(AgentExecutor): |
| 185 | + """Executor for ITK v10 agent tasks.""" |
| 186 | + |
| 187 | + async def execute( |
| 188 | + self, context: RequestContext, event_queue: EventQueue |
| 189 | + ) -> None: |
| 190 | + """Executes a task instruction.""" |
| 191 | + logger.info('Executing task %s', context.task_id) |
| 192 | + task_updater = TaskUpdater( |
| 193 | + event_queue, |
| 194 | + context.task_id, |
| 195 | + context.context_id, |
| 196 | + ) |
| 197 | + |
| 198 | + await task_updater.update_status(TaskState.TASK_STATE_SUBMITTED) |
| 199 | + await task_updater.update_status(TaskState.TASK_STATE_WORKING) |
| 200 | + |
| 201 | + instruction = extract_instruction(context.message) |
| 202 | + if not instruction: |
| 203 | + error_msg = 'No valid instruction found in request' |
| 204 | + logger.error(error_msg) |
| 205 | + await task_updater.update_status( |
| 206 | + TaskState.TASK_STATE_FAILED, |
| 207 | + message=task_updater.new_agent_message([Part(text=error_msg)]), |
| 208 | + ) |
| 209 | + return |
| 210 | + |
| 211 | + try: |
| 212 | + logger.info('Instruction: %s', instruction) |
| 213 | + results = await handle_instruction(instruction) |
| 214 | + response_text = '\n'.join(results) |
| 215 | + logger.info('Response: %s', response_text) |
| 216 | + await task_updater.update_status( |
| 217 | + TaskState.TASK_STATE_COMPLETED, |
| 218 | + message=task_updater.new_agent_message( |
| 219 | + [Part(text=response_text)] |
| 220 | + ), |
| 221 | + ) |
| 222 | + logger.info('Task %s completed', context.task_id) |
| 223 | + except Exception as e: |
| 224 | + logger.exception('Error during instruction handling') |
| 225 | + await task_updater.update_status( |
| 226 | + TaskState.TASK_STATE_FAILED, |
| 227 | + message=task_updater.new_agent_message([Part(text=str(e))]), |
| 228 | + ) |
| 229 | + |
| 230 | + async def cancel( |
| 231 | + self, context: RequestContext, event_queue: EventQueue |
| 232 | + ) -> None: |
| 233 | + """Cancels a task.""" |
| 234 | + logger.info('Cancel requested for task %s', context.task_id) |
| 235 | + task_updater = TaskUpdater( |
| 236 | + event_queue, |
| 237 | + context.task_id, |
| 238 | + context.context_id, |
| 239 | + ) |
| 240 | + await task_updater.update_status(TaskState.TASK_STATE_CANCELED) |
| 241 | + |
| 242 | + |
| 243 | +async def main_async(http_port: int, grpc_port: int) -> None: |
| 244 | + """Starts the Agent with HTTP and gRPC interfaces.""" |
| 245 | + interfaces = [ |
| 246 | + AgentInterface( |
| 247 | + protocol_binding=TransportProtocol.GRPC, |
| 248 | + url=f'127.0.0.1:{grpc_port}', |
| 249 | + protocol_version='1.0', |
| 250 | + ), |
| 251 | + AgentInterface( |
| 252 | + protocol_binding=TransportProtocol.GRPC, |
| 253 | + url=f'127.0.0.1:{grpc_port}', |
| 254 | + protocol_version='0.3', |
| 255 | + ), |
| 256 | + ] |
| 257 | + |
| 258 | + interfaces.append( |
| 259 | + AgentInterface( |
| 260 | + protocol_binding=TransportProtocol.JSONRPC, |
| 261 | + url=f'http://127.0.0.1:{http_port}/jsonrpc/', |
| 262 | + ) |
| 263 | + ) |
| 264 | + interfaces.append( |
| 265 | + AgentInterface( |
| 266 | + protocol_binding=TransportProtocol.HTTP_JSON, |
| 267 | + url=f'http://127.0.0.1:{http_port}/rest/', |
| 268 | + protocol_version='1.0', |
| 269 | + ) |
| 270 | + ) |
| 271 | + interfaces.append( |
| 272 | + AgentInterface( |
| 273 | + protocol_binding=TransportProtocol.HTTP_JSON, |
| 274 | + url=f'http://127.0.0.1:{http_port}/rest/', |
| 275 | + protocol_version='0.3', |
| 276 | + ) |
| 277 | + ) |
| 278 | + |
| 279 | + agent_card = AgentCard( |
| 280 | + name='ITK v10 Agent', |
| 281 | + description='Python agent using SDK 1.0.', |
| 282 | + version='1.0.0', |
| 283 | + capabilities=AgentCapabilities(streaming=True), |
| 284 | + default_input_modes=['text/plain'], |
| 285 | + default_output_modes=['text/plain'], |
| 286 | + supported_interfaces=interfaces, |
| 287 | + ) |
| 288 | + |
| 289 | + task_store = InMemoryTaskStore() |
| 290 | + handler = DefaultRequestHandler( |
| 291 | + agent_executor=V10AgentExecutor(), |
| 292 | + task_store=task_store, |
| 293 | + queue_manager=InMemoryQueueManager(), |
| 294 | + ) |
| 295 | + |
| 296 | + app = FastAPI() |
| 297 | + |
| 298 | + json_rpc_app = A2AFastAPIApplication( |
| 299 | + agent_card, handler, enable_v0_3_compat=True |
| 300 | + ).build() |
| 301 | + app.mount('/jsonrpc', json_rpc_app) |
| 302 | + rest_app = A2ARESTFastAPIApplication( |
| 303 | + http_handler=handler, agent_card=agent_card, enable_v0_3_compat=True |
| 304 | + ).build() |
| 305 | + app.mount('/rest', rest_app) |
| 306 | + |
| 307 | + server = grpc.aio.server() |
| 308 | + |
| 309 | + compat_servicer = CompatGrpcHandler(agent_card, handler) |
| 310 | + a2a_v0_3_pb2_grpc.add_A2AServiceServicer_to_server(compat_servicer, server) |
| 311 | + servicer = GrpcHandler(agent_card, handler) |
| 312 | + a2a_pb2_grpc.add_A2AServiceServicer_to_server(servicer, server) |
| 313 | + |
| 314 | + server.add_insecure_port(f'127.0.0.1:{grpc_port}') |
| 315 | + await server.start() |
| 316 | + |
| 317 | + logger.info( |
| 318 | + 'Starting ITK v10 Agent on HTTP port %s and gRPC port %s', |
| 319 | + http_port, |
| 320 | + grpc_port, |
| 321 | + ) |
| 322 | + |
| 323 | + config = uvicorn.Config( |
| 324 | + app, host='127.0.0.1', port=http_port, log_level='info' |
| 325 | + ) |
| 326 | + uvicorn_server = uvicorn.Server(config) |
| 327 | + |
| 328 | + await uvicorn_server.serve() |
| 329 | + |
| 330 | + |
| 331 | +def str2bool(v: str | bool) -> bool: |
| 332 | + """Converts a string to a boolean value.""" |
| 333 | + if isinstance(v, bool): |
| 334 | + return v |
| 335 | + if v.lower() in ('yes', 'true', 't', 'y', '1'): |
| 336 | + return True |
| 337 | + if v.lower() in ('no', 'false', 'f', 'n', '0'): |
| 338 | + return False |
| 339 | + raise argparse.ArgumentTypeError('Boolean value expected.') |
| 340 | + |
| 341 | + |
| 342 | +def main() -> None: |
| 343 | + """Main entry point for the agent.""" |
| 344 | + parser = argparse.ArgumentParser() |
| 345 | + parser.add_argument('--httpPort', type=int, default=10102) |
| 346 | + parser.add_argument('--grpcPort', type=int, default=11002) |
| 347 | + args = parser.parse_args() |
| 348 | + |
| 349 | + asyncio.run(main_async(args.httpPort, args.grpcPort)) |
| 350 | + |
| 351 | + |
| 352 | +if __name__ == '__main__': |
| 353 | + main() |
0 commit comments