Skip to content

Commit dce25cf

Browse files
feat(server): add async context manager support to EventQueue
# Conflicts: # src/a2a/server/events/event_queue.py
1 parent fd0a1bd commit dce25cf

2 files changed

Lines changed: 39 additions & 1 deletion

File tree

src/a2a/server/events/event_queue.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,11 @@
22
import logging
33
import sys
44

5-
from a2a.types.a2a_pb2 import (
5+
from types import TracebackType
6+
7+
from typing_extensions import Self
8+
9+
from a2a.types import (
610
Message,
711
Task,
812
TaskArtifactUpdateEvent,
@@ -43,6 +47,19 @@ def __init__(self, max_queue_size: int = DEFAULT_MAX_QUEUE_SIZE) -> None:
4347
self._lock = asyncio.Lock()
4448
logger.debug('EventQueue initialized.')
4549

50+
async def __aenter__(self) -> Self:
51+
"""Enters the async context manager, returning the queue itself."""
52+
return self
53+
54+
async def __aexit__(
55+
self,
56+
exc_type: type[BaseException] | None,
57+
exc_val: BaseException | None,
58+
exc_tb: TracebackType | None,
59+
) -> None:
60+
"""Exits the async context manager, ensuring close() is called."""
61+
await self.close()
62+
4663
async def enqueue_event(self, event: Event) -> None:
4764
"""Enqueues an event to this queue and all its children.
4865

tests/server/events/test_event_queue.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,27 @@ def test_constructor_invalid_max_queue_size() -> None:
7777
):
7878
EventQueue(max_queue_size=-10)
7979

80+
@pytest.mark.asyncio
81+
async def test_event_queue_async_context_manager(
82+
event_queue: EventQueue,
83+
) -> None:
84+
"""Test that EventQueue can be used as an async context manager."""
85+
async with event_queue as q:
86+
assert q is event_queue
87+
assert event_queue.is_closed() is False
88+
assert event_queue.is_closed() is True
89+
90+
91+
@pytest.mark.asyncio
92+
async def test_event_queue_async_context_manager_on_exception(
93+
event_queue: EventQueue,
94+
) -> None:
95+
"""Test that close() is called even when an exception occurs inside the context."""
96+
with pytest.raises(RuntimeError, match='boom'):
97+
async with event_queue:
98+
raise RuntimeError('boom')
99+
assert event_queue.is_closed() is True
100+
80101

81102
@pytest.mark.asyncio
82103
async def test_enqueue_and_dequeue_event(event_queue: EventQueue) -> None:

0 commit comments

Comments
 (0)