diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index 053f2b3fec..0618d67b6a 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -455,6 +455,7 @@ async def _run_node_async( user_id: str, session_id: str, new_message: Optional[types.Content] = None, + state_delta: Optional[dict[str, Any]] = None, run_config: Optional[RunConfig] = None, yield_user_message: bool = False, node: Optional['BaseNode'] = None, @@ -512,7 +513,9 @@ async def _run_node_async( # Append user message to session for history if new_message: - user_event = await self._append_user_event(ic, new_message) + user_event = await self._append_user_event( + ic, new_message, state_delta=state_delta + ) if yield_user_message and user_event: yield user_event @@ -706,14 +709,26 @@ def _resolve_invocation_id_from_fr( return invocation_ids.pop() async def _append_user_event( - self, ic: InvocationContext, content: types.Content + self, + ic: InvocationContext, + content: types.Content, + *, + state_delta: Optional[dict[str, Any]] = None, ) -> Event: """Append a user message event to the session and return it.""" - event = Event( - invocation_id=ic.invocation_id, - author='user', - content=content, - ) + if state_delta: + event = Event( + invocation_id=ic.invocation_id, + author='user', + actions=EventActions(state_delta=state_delta), + content=content, + ) + else: + event = Event( + invocation_id=ic.invocation_id, + author='user', + content=content, + ) # when a paused task delegation is in flight, stamp # the new user message with that task's isolation_scope so the # task agent's content-build (scoped to ) sees it. @@ -989,6 +1004,7 @@ async def run_async( user_id=user_id, session_id=session_id, new_message=new_message, + state_delta=state_delta, run_config=run_config, yield_user_message=yield_user_message, node=agent_to_run, @@ -1008,6 +1024,7 @@ async def run_async( user_id=user_id, session_id=session_id, new_message=new_message, + state_delta=state_delta, run_config=run_config, yield_user_message=yield_user_message, ) diff --git a/tests/unittests/runners/test_runner_node.py b/tests/unittests/runners/test_runner_node.py index 7c01ce44a3..5abdfee529 100644 --- a/tests/unittests/runners/test_runner_node.py +++ b/tests/unittests/runners/test_runner_node.py @@ -24,7 +24,9 @@ from typing import Any from typing import AsyncGenerator +from google.adk.agents.callback_context import CallbackContext from google.adk.agents.context import Context +from google.adk.agents.llm_agent import LlmAgent from google.adk.events.event import Event from google.adk.runners import Runner from google.adk.sessions.in_memory_session_service import InMemorySessionService @@ -49,6 +51,10 @@ async def _run_impl( yield f'Echo: {text}' +def _user_message(text: str = 'hello') -> types.Content: + return types.Content(parts=[types.Part(text=text)], role='user') + + async def _run_node(node, message='hello'): """Run a BaseNode via Runner(node=...) and return (events, ss, session).""" ss = InMemorySessionService() @@ -288,6 +294,129 @@ async def test_yield_user_message_false_by_default(): assert user_events == [] +@pytest.mark.asyncio +async def test_node_runner_applies_state_delta_before_base_node_runs(): + """A BaseNode sees run_async state_delta as session state.""" + + class _StateReaderNode(BaseNode): + + async def _run_impl( + self, *, ctx: Context, node_input: Any + ) -> AsyncGenerator[Any, None]: + yield f'state:{ctx.state["test_state"]}' + + session_service = InMemorySessionService() + runner = Runner( + app_name='test', + node=_StateReaderNode(name='reader'), + session_service=session_service, + ) + session = await session_service.create_session(app_name='test', user_id='u') + + events: list[Event] = [] + async for event in runner.run_async( + user_id='u', + session_id=session.id, + new_message=_user_message(), + state_delta={'test_state': 'must_change'}, + ): + events.append(event) + + updated = await session_service.get_session( + app_name='test', user_id='u', session_id=session.id + ) + user_events = [event for event in updated.events if event.author == 'user'] + + assert [event.output for event in events if event.output is not None] == [ + 'state:must_change' + ] + assert updated.state['test_state'] == 'must_change' + assert user_events[0].actions.state_delta == {'test_state': 'must_change'} + + +@pytest.mark.asyncio +async def test_node_runner_yields_user_event_with_state_delta(): + """yield_user_message=True yields the user event with state_delta.""" + + class _NoopNode(BaseNode): + + async def _run_impl( + self, *, ctx: Context, node_input: Any + ) -> AsyncGenerator[Any, None]: + yield 'done' + + session_service = InMemorySessionService() + runner = Runner( + app_name='test', + node=_NoopNode(name='noop'), + session_service=session_service, + ) + session = await session_service.create_session(app_name='test', user_id='u') + + events: list[Event] = [] + async for event in runner.run_async( + user_id='u', + session_id=session.id, + new_message=_user_message(), + state_delta={'test_state': 'must_change'}, + yield_user_message=True, + ): + events.append(event) + + assert events[0].author == 'user' + assert events[0].actions.state_delta == {'test_state': 'must_change'} + + +@pytest.mark.asyncio +async def test_node_runner_applies_state_delta_before_llm_agent_runs(): + """An LlmAgent callback sees run_async state_delta before model execution.""" + + captured_state_value = None + + def _before_agent_callback( + callback_context: CallbackContext, + ) -> types.Content: + nonlocal captured_state_value + captured_state_value = callback_context.state['test_state'] + return types.Content( + role='model', + parts=[types.Part(text=f'state:{captured_state_value}')], + ) + + session_service = InMemorySessionService() + agent = LlmAgent( + name='state_agent', + before_agent_callback=_before_agent_callback, + ) + runner = Runner(app_name='test', agent=agent, session_service=session_service) + session = await session_service.create_session(app_name='test', user_id='u') + + events: list[Event] = [] + async for event in runner.run_async( + user_id='u', + session_id=session.id, + new_message=_user_message(), + state_delta={'test_state': 'must_change'}, + ): + events.append(event) + + updated = await session_service.get_session( + app_name='test', user_id='u', session_id=session.id + ) + user_events = [event for event in updated.events if event.author == 'user'] + response_texts = [ + part.text + for event in events + if event.content + for part in event.content.parts + if part.text + ] + + assert captured_state_value == 'must_change' + assert 'state:must_change' in response_texts + assert user_events[0].actions.state_delta == {'test_state': 'must_change'} + + # --------------------------------------------------------------------------- # Resume (HITL) # ---------------------------------------------------------------------------