Skip to content

Commit 8c65e84

Browse files
authored
feat: InMemoryTaskStore creates a copy of Task by default to make it consistent with database task stores (#887)
Sharing the Task object instance in InMemoryTaskStore leads to unexpected behaviour (from differences of in-place update Task in AgentExecutor to non-trivial concurrency reporting issues on task state reporting). Fixes #869
1 parent 405be3f commit 8c65e84

5 files changed

Lines changed: 469 additions & 4 deletions

File tree

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
from __future__ import annotations
2+
3+
import logging
4+
5+
from typing import TYPE_CHECKING
6+
7+
8+
if TYPE_CHECKING:
9+
from a2a.server.context import ServerCallContext
10+
from a2a.server.tasks.task_store import TaskStore
11+
from a2a.types.a2a_pb2 import ListTasksRequest, ListTasksResponse, Task
12+
13+
14+
logger = logging.getLogger(__name__)
15+
16+
17+
class CopyingTaskStoreAdapter(TaskStore):
18+
"""An adapter that ensures deep copies of tasks are passed to and returned from the underlying TaskStore.
19+
20+
This prevents accidental shared mutable state bugs where code modifies a Task object
21+
retrieved from the store without explicitly saving it, which hides missing save calls.
22+
"""
23+
24+
def __init__(self, underlying_store: TaskStore):
25+
self._store = underlying_store
26+
27+
async def save(
28+
self, task: Task, context: ServerCallContext | None = None
29+
) -> None:
30+
"""Saves a copy of the task to the underlying store."""
31+
task_copy = Task()
32+
task_copy.CopyFrom(task)
33+
await self._store.save(task_copy, context)
34+
35+
async def get(
36+
self, task_id: str, context: ServerCallContext | None = None
37+
) -> Task | None:
38+
"""Retrieves a task from the underlying store and returns a copy."""
39+
task = await self._store.get(task_id, context)
40+
if task is None:
41+
return None
42+
task_copy = Task()
43+
task_copy.CopyFrom(task)
44+
return task_copy
45+
46+
async def list(
47+
self,
48+
params: ListTasksRequest,
49+
context: ServerCallContext | None = None,
50+
) -> ListTasksResponse:
51+
"""Retrieves a list of tasks from the underlying store and returns a copy."""
52+
response = await self._store.list(params, context)
53+
response_copy = ListTasksResponse()
54+
response_copy.CopyFrom(response)
55+
return response_copy
56+
57+
async def delete(
58+
self, task_id: str, context: ServerCallContext | None = None
59+
) -> None:
60+
"""Deletes a task from the underlying store."""
61+
await self._store.delete(task_id, context)

src/a2a/server/tasks/inmemory_task_store.py

Lines changed: 57 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from a2a.server.context import ServerCallContext
55
from a2a.server.owner_resolver import OwnerResolver, resolve_user_scope
6+
from a2a.server.tasks.copying_task_store import CopyingTaskStoreAdapter
67
from a2a.server.tasks.task_store import TaskStore
78
from a2a.types import a2a_pb2
89
from a2a.types.a2a_pb2 import Task
@@ -14,8 +15,8 @@
1415
logger = logging.getLogger(__name__)
1516

1617

17-
class InMemoryTaskStore(TaskStore):
18-
"""In-memory implementation of TaskStore.
18+
class _InMemoryTaskStoreImpl(TaskStore):
19+
"""Internal In-memory implementation of TaskStore.
1920
2021
Stores task objects in a nested dictionary in memory, keyed by owner then task_id.
2122
Task data is lost when the server process stops.
@@ -25,8 +26,8 @@ def __init__(
2526
self,
2627
owner_resolver: OwnerResolver = resolve_user_scope,
2728
) -> None:
28-
"""Initializes the InMemoryTaskStore."""
29-
logger.debug('Initializing InMemoryTaskStore')
29+
"""Initializes the internal _InMemoryTaskStoreImpl."""
30+
logger.debug('Initializing _InMemoryTaskStoreImpl')
3031
self.tasks: dict[str, dict[str, Task]] = {}
3132
self.lock = asyncio.Lock()
3233
self.owner_resolver = owner_resolver
@@ -183,3 +184,55 @@ async def delete(
183184
if not owner_tasks:
184185
del self.tasks[owner]
185186
logger.debug('Removed empty owner %s from store.', owner)
187+
188+
189+
class InMemoryTaskStore(TaskStore):
190+
"""In-memory implementation of TaskStore.
191+
192+
Can optionally use CopyingTaskStoreAdapter to wrap the internal dictionary-based
193+
implementation, preventing shared mutable state issues by always returning and
194+
storing deep copies.
195+
"""
196+
197+
def __init__(
198+
self,
199+
owner_resolver: OwnerResolver = resolve_user_scope,
200+
use_copying: bool = True,
201+
) -> None:
202+
"""Initializes the InMemoryTaskStore.
203+
204+
Args:
205+
owner_resolver: Resolver for task owners.
206+
use_copying: If True, the store will return and save deep copies of tasks.
207+
Copying behavior is consistent with database task stores.
208+
"""
209+
self._impl = _InMemoryTaskStoreImpl(owner_resolver=owner_resolver)
210+
self._store: TaskStore = (
211+
CopyingTaskStoreAdapter(self._impl) if use_copying else self._impl
212+
)
213+
214+
async def save(
215+
self, task: Task, context: ServerCallContext | None = None
216+
) -> None:
217+
"""Saves or updates a task in the store."""
218+
await self._store.save(task, context)
219+
220+
async def get(
221+
self, task_id: str, context: ServerCallContext | None = None
222+
) -> Task | None:
223+
"""Retrieves a task from the store by ID."""
224+
return await self._store.get(task_id, context)
225+
226+
async def list(
227+
self,
228+
params: a2a_pb2.ListTasksRequest,
229+
context: ServerCallContext | None = None,
230+
) -> a2a_pb2.ListTasksResponse:
231+
"""Retrieves a list of tasks from the store."""
232+
return await self._store.list(params, context)
233+
234+
async def delete(
235+
self, task_id: str, context: ServerCallContext | None = None
236+
) -> None:
237+
"""Deletes a task from the store by ID."""
238+
await self._store.delete(task_id, context)
Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
import httpx
2+
import pytest
3+
from typing import NamedTuple
4+
5+
from starlette.applications import Starlette
6+
7+
from a2a.client.client import Client, ClientConfig
8+
from a2a.client.client_factory import ClientFactory
9+
from a2a.server.agent_execution import AgentExecutor, RequestContext
10+
from a2a.server.routes import create_agent_card_routes, create_jsonrpc_routes
11+
from a2a.server.events import EventQueue
12+
from a2a.server.events.in_memory_queue_manager import InMemoryQueueManager
13+
from a2a.server.request_handlers import DefaultRequestHandler
14+
from a2a.server.tasks import TaskUpdater
15+
from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore
16+
from a2a.types import (
17+
AgentCapabilities,
18+
AgentCard,
19+
AgentInterface,
20+
Artifact,
21+
GetTaskRequest,
22+
Message,
23+
Part,
24+
Role,
25+
SendMessageRequest,
26+
TaskState,
27+
)
28+
from a2a.utils import TransportProtocol
29+
30+
31+
class MockMutatingAgentExecutor(AgentExecutor):
32+
async def execute(self, context: RequestContext, event_queue: EventQueue):
33+
assert context.task_id is not None
34+
assert context.context_id is not None
35+
task_updater = TaskUpdater(
36+
event_queue,
37+
context.task_id,
38+
context.context_id,
39+
)
40+
41+
user_input = context.get_user_input()
42+
43+
if user_input == 'Init task':
44+
# Explicitly save status change to ensure task exists with some state
45+
await task_updater.update_status(
46+
TaskState.TASK_STATE_WORKING,
47+
message=task_updater.new_agent_message(
48+
[Part(text='task working')]
49+
),
50+
)
51+
else:
52+
# Mutate the task WITHOUT saving it properly
53+
assert context.current_task is not None
54+
context.current_task.artifacts.append(
55+
Artifact(
56+
name='leaked-artifact',
57+
parts=[Part(text='leaked artifact')],
58+
)
59+
)
60+
61+
async def cancel(self, context: RequestContext, event_queue: EventQueue):
62+
raise NotImplementedError('Cancellation is not supported')
63+
64+
65+
@pytest.fixture
66+
def agent_card() -> AgentCard:
67+
return AgentCard(
68+
name='Mutating Agent',
69+
description='Real in-memory integration testing.',
70+
version='1.0.0',
71+
capabilities=AgentCapabilities(
72+
streaming=True, push_notifications=False
73+
),
74+
skills=[],
75+
default_input_modes=['text/plain'],
76+
default_output_modes=['text/plain'],
77+
supported_interfaces=[
78+
AgentInterface(
79+
protocol_binding=TransportProtocol.JSONRPC,
80+
url='http://testserver',
81+
),
82+
],
83+
)
84+
85+
86+
class ClientSetup(NamedTuple):
87+
client: Client
88+
task_store: InMemoryTaskStore
89+
use_copying: bool
90+
91+
92+
def setup_client(agent_card: AgentCard, use_copying: bool) -> ClientSetup:
93+
task_store = InMemoryTaskStore(use_copying=use_copying)
94+
handler = DefaultRequestHandler(
95+
agent_executor=MockMutatingAgentExecutor(),
96+
task_store=task_store,
97+
queue_manager=InMemoryQueueManager(),
98+
)
99+
agent_card_routes = create_agent_card_routes(
100+
agent_card=agent_card, card_url='/'
101+
)
102+
jsonrpc_routes = create_jsonrpc_routes(
103+
agent_card=agent_card,
104+
request_handler=handler,
105+
extended_agent_card=agent_card,
106+
rpc_url='/',
107+
)
108+
app = Starlette(routes=[*agent_card_routes, *jsonrpc_routes])
109+
httpx_client = httpx.AsyncClient(
110+
transport=httpx.ASGITransport(app=app), base_url='http://testserver'
111+
)
112+
factory = ClientFactory(
113+
config=ClientConfig(
114+
httpx_client=httpx_client,
115+
supported_protocol_bindings=[TransportProtocol.JSONRPC],
116+
)
117+
)
118+
client = factory.create(agent_card)
119+
return ClientSetup(
120+
client=client,
121+
task_store=task_store,
122+
use_copying=use_copying,
123+
)
124+
125+
126+
@pytest.mark.asyncio
127+
@pytest.mark.parametrize('use_copying', [True, False])
128+
async def test_mutation_observability(agent_card: AgentCard, use_copying: bool):
129+
"""Tests that task mutations are observable when copying is disabled.
130+
131+
When copying is disabled, the agent mutates the task in-place and the
132+
changes are observable by the client. When copying is enabled, the agent
133+
mutates a copy of the task and the changes are not observable by the client.
134+
135+
It is ok to remove the `use_copying` parameter from the system in the future
136+
to make InMemoryTaskStore consistent with other task stores.
137+
"""
138+
client_setup = setup_client(agent_card, use_copying)
139+
client = client_setup.client
140+
141+
# 1. First message to create the task
142+
message_to_send = Message(
143+
role=Role.ROLE_USER,
144+
message_id='msg-mut-init',
145+
parts=[Part(text='Init task')],
146+
)
147+
148+
events = [
149+
event
150+
async for event in client.send_message(
151+
request=SendMessageRequest(message=message_to_send)
152+
)
153+
]
154+
155+
task = events[-1][1]
156+
assert task is not None
157+
task_id = task.id
158+
159+
# 2. Second message to mutate it
160+
message_to_send_2 = Message(
161+
role=Role.ROLE_USER,
162+
message_id='msg-mut-do',
163+
task_id=task_id,
164+
parts=[Part(text='Update task without saving it')],
165+
)
166+
167+
_ = [
168+
event
169+
async for event in client.send_message(
170+
request=SendMessageRequest(message=message_to_send_2)
171+
)
172+
]
173+
174+
# 3. Get task via client
175+
retrieved_task = await client.get_task(request=GetTaskRequest(id=task_id))
176+
177+
# 4. Assert behavior based on `use_copying`
178+
if use_copying:
179+
# The un-saved artifact IS NOT leaked to the client
180+
assert len(retrieved_task.artifacts) == 0
181+
else:
182+
# The un-saved artifact IS leaked to the client
183+
assert len(retrieved_task.artifacts) == 1
184+
assert retrieved_task.artifacts[0].name == 'leaked-artifact'

0 commit comments

Comments
 (0)