Skip to content

Commit f4a0bcd

Browse files
bartek-wishymko
andauthored
feat!: Raise errors on invalid AgentExecutor behavior. (#979)
Fixes #869 🦕 --------- Co-authored-by: Ivan Shymko <ishymko@google.com>
1 parent b8df210 commit f4a0bcd

7 files changed

Lines changed: 566 additions & 332 deletions

File tree

src/a2a/server/agent_execution/active_task.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
TaskStatusUpdateEvent,
3737
)
3838
from a2a.utils.errors import (
39+
InvalidAgentResponseError,
3940
InvalidParamsError,
4041
TaskNotFoundError,
4142
)
@@ -370,13 +371,12 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912
370371
elif isinstance(event, Message):
371372
if task_mode is not None:
372373
if task_mode:
373-
logger.error(
374-
'Received Message() object in task mode.'
375-
)
376-
else:
377-
logger.error(
378-
'Multiple Message() objects received.'
374+
raise InvalidAgentResponseError(
375+
'Received Message object in task mode. Use TaskStatusUpdateEvent or TaskArtifactUpdateEvent instead.'
379376
)
377+
raise InvalidAgentResponseError(
378+
'Multiple Message objects received.'
379+
)
380380
task_mode = False
381381
logger.debug(
382382
'Consumer[%s]: Setting result to Message: %s',
@@ -385,9 +385,8 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912
385385
)
386386
else:
387387
if task_mode is False:
388-
logger.error(
389-
'Received %s in message mode.',
390-
type(event).__name__,
388+
raise InvalidAgentResponseError(
389+
f'Received {type(event).__name__} in message mode. Use Task with TaskStatusUpdateEvent and TaskArtifactUpdateEvent instead.'
391390
)
392391

393392
if isinstance(event, Task):
@@ -408,6 +407,18 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912
408407
# Initial task should already contain the message.
409408
message_to_save = None
410409
else:
410+
if (
411+
isinstance(event, TaskStatusUpdateEvent)
412+
and not self._task_created.is_set()
413+
):
414+
task = (
415+
await self._task_manager.get_task()
416+
)
417+
if task is None:
418+
raise InvalidAgentResponseError(
419+
f'Agent should enqueue Task before {type(event).__name__} event'
420+
)
421+
411422
new_task = (
412423
await self._task_manager.ensure_task_id(
413424
self._task_id,
@@ -434,8 +445,6 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912
434445
if not isinstance(event, Task):
435446
await self._task_manager.process(event)
436447

437-
self._task_created.set()
438-
439448
# Check for AUTH_REQUIRED or INPUT_REQUIRED or TERMINAL states
440449
new_task = await self._task_manager.get_task()
441450
if new_task is None:
@@ -496,6 +505,9 @@ async def _run_consumer(self) -> None: # noqa: PLR0915, PLR0912
496505
await self._push_sender.send_notification(
497506
self._task_id, event
498507
)
508+
509+
self._task_created.set()
510+
499511
finally:
500512
if new_task is not None:
501513
new_task_copy = Task()

tests/integration/cross_version/client_server/server_0_3.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
from starlette.requests import Request
3939
from starlette.concurrency import iterate_in_threadpool
4040
import time
41-
41+
from a2a.utils.task import new_task
4242
from server_common import CustomLoggingMiddleware
4343

4444

@@ -48,12 +48,18 @@ def __init__(self):
4848

4949
async def execute(self, context: RequestContext, event_queue: EventQueue):
5050
print(f'SERVER: execute called for task {context.task_id}')
51+
52+
task = new_task(context.message)
53+
task.id = context.task_id
54+
task.context_id = context.context_id
55+
task.status.state = TaskState.working
56+
await event_queue.enqueue_event(task)
57+
5158
task_updater = TaskUpdater(
5259
event_queue,
5360
context.task_id,
5461
context.context_id,
5562
)
56-
await task_updater.update_status(TaskState.submitted)
5763
await task_updater.update_status(TaskState.working)
5864

5965
text = ''

tests/integration/cross_version/client_server/server_1_0.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from a2a.utils import TransportProtocol
2929
from server_common import CustomLoggingMiddleware
3030
from google.protobuf.struct_pb2 import Struct, Value
31+
from a2a.helpers.proto_helpers import new_task_from_user_message
3132

3233

3334
class MockAgentExecutor(AgentExecutor):
@@ -36,12 +37,17 @@ def __init__(self):
3637

3738
async def execute(self, context: RequestContext, event_queue: EventQueue):
3839
print(f'SERVER: execute called for task {context.task_id}')
40+
task = new_task_from_user_message(context.message)
41+
task.id = context.task_id
42+
task.context_id = context.context_id
43+
task.status.state = TaskState.TASK_STATE_WORKING
44+
await event_queue.enqueue_event(task)
45+
3946
task_updater = TaskUpdater(
4047
event_queue,
4148
context.task_id,
4249
context.context_id,
4350
)
44-
await task_updater.update_status(TaskState.TASK_STATE_SUBMITTED)
4551
await task_updater.update_status(TaskState.TASK_STATE_WORKING)
4652

4753
text = ''

tests/integration/test_copying_observability.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
SendMessageRequest,
2626
TaskState,
2727
)
28+
from a2a.helpers.proto_helpers import new_task_from_user_message
2829
from a2a.utils import TransportProtocol
2930

3031

@@ -42,6 +43,12 @@ async def execute(self, context: RequestContext, event_queue: EventQueue):
4243

4344
if user_input == 'Init task':
4445
# Explicitly save status change to ensure task exists with some state
46+
task = new_task_from_user_message(context.message)
47+
task.id = context.task_id
48+
task.context_id = context.context_id
49+
task.status.state = TaskState.TASK_STATE_WORKING
50+
await event_queue.enqueue_event(task)
51+
4552
await task_updater.update_status(
4653
TaskState.TASK_STATE_WORKING,
4754
message=task_updater.new_agent_message(
@@ -153,6 +160,7 @@ async def test_mutation_observability(agent_card: AgentCard, use_copying: bool):
153160
]
154161

155162
event = events[-1]
163+
assert event.HasField('status_update')
156164
task_id = event.status_update.task_id
157165

158166
# 2. Second message to mutate it
@@ -162,7 +170,6 @@ async def test_mutation_observability(agent_card: AgentCard, use_copying: bool):
162170
task_id=task_id,
163171
parts=[Part(text='Update task without saving it')],
164172
)
165-
166173
_ = [
167174
event
168175
async for event in client.send_message(

0 commit comments

Comments
 (0)