diff --git a/src/google/adk/flows/llm_flows/_output_schema_processor.py b/src/google/adk/flows/llm_flows/_output_schema_processor.py index 284cc21383..e58add9a6c 100644 --- a/src/google/adk/flows/llm_flows/_output_schema_processor.py +++ b/src/google/adk/flows/llm_flows/_output_schema_processor.py @@ -17,17 +17,25 @@ from __future__ import annotations import json +import logging from typing import AsyncGenerator +from google.genai import types from typing_extensions import override from ...agents.invocation_context import InvocationContext from ...events.event import Event from ...models.llm_request import LlmRequest from ...tools.set_model_response_tool import SetModelResponseTool +from ...utils._schema_utils import is_basemodel_schema from ...utils.output_schema_utils import can_use_output_schema_with_tools from ._base_llm_processor import BaseLlmRequestProcessor +logger = logging.getLogger('google_adk.' + __name__) + +# Max tool rounds before forcing set_model_response (N-1) or terminating (N). +_MAX_TOOL_ROUNDS = 25 + class _OutputSchemaRequestProcessor(BaseLlmRequestProcessor): """Processor that handles output schema for agents with tools.""" @@ -36,8 +44,6 @@ class _OutputSchemaRequestProcessor(BaseLlmRequestProcessor): async def run_async( self, invocation_context: InvocationContext, llm_request: LlmRequest ) -> AsyncGenerator[Event, None]: - from ...agents.llm_agent import LlmAgent - agent = invocation_context.agent # Check if we need the processor: output_schema + tools + cannot use output @@ -49,20 +55,56 @@ async def run_async( ): return + # Count how many tool rounds have occurred in this invocation. + tool_rounds = sum( + 1 + for e in invocation_context._get_events( + current_invocation=True, current_branch=True + ) + if e.get_function_responses() + ) + + # Terminate the invocation if the model never calls set_model_response. 
+ if tool_rounds >= _MAX_TOOL_ROUNDS: + logger.error( + 'Tool execution reached %d rounds without producing structured' + ' output via set_model_response. Breaking loop to prevent' + ' runaway API costs.', + tool_rounds, + ) + invocation_context.end_invocation = True + return + # Add the set_model_response tool to handle structured output set_response_tool = SetModelResponseTool(agent.output_schema) llm_request.append_tools([set_response_tool]) - # Add instruction about using the set_model_response tool - instruction = ( - 'IMPORTANT: You have access to other tools, but you must provide ' - 'your final response using the set_model_response tool with the ' - 'required structured format. After using any other tools needed ' - 'to complete the task, always call set_model_response with your ' - 'final answer in the specified schema format.' - ) + # Primitive types (str, int, etc.) produce a trivial tool signature + # that flash models tend to ignore, so use a stronger instruction. + if is_basemodel_schema(agent.output_schema): + instruction = ( + 'After completing any needed tool calls, provide your final' + ' response by calling set_model_response with the required' + ' fields.' + ) + else: + instruction = ( + 'IMPORTANT: After using any needed tools, you MUST call' + ' set_model_response to provide your final answer.' + ' This is required to complete the task.' + ) llm_request.append_instructions([instruction]) + # On round N-1, restrict the model to only call set_model_response. + if tool_rounds >= _MAX_TOOL_ROUNDS - 1: + llm_request.config = llm_request.config or types.GenerateContentConfig() + llm_request.config.tool_config = types.ToolConfig( + function_calling_config=types.FunctionCallingConfig( + mode=types.FunctionCallingConfigMode.ANY, + allowed_function_names=['set_model_response'], + ) + ) return yield # Generator requires yield statement in function body. 
diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index 3a25799106..1c1ff108a2 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -1045,6 +1045,7 @@ async def _postprocess_live( ) ) yield final_event + return # Skip further processing after set_model_response. async def _postprocess_run_processors_async( self, invocation_context: InvocationContext, llm_response: LlmResponse @@ -1091,6 +1092,7 @@ async def _postprocess_handle_function_calls_async( ) ) yield final_event + return # Skip transfer_to_agent after set_model_response. transfer_to_agent = function_response_event.actions.transfer_to_agent if transfer_to_agent: agent_to_run = self._get_agent_to_run( diff --git a/tests/integration/test_output_schema_with_tools.py b/tests/integration/test_output_schema_with_tools.py new file mode 100644 index 0000000000..58f2e258aa --- /dev/null +++ b/tests/integration/test_output_schema_with_tools.py @@ -0,0 +1,131 @@ +"""Integration test for output_schema + tools behavior. + +Requires GOOGLE_API_KEY or Vertex AI credentials. +Run with: python -m pytest tests/integration/test_output_schema_with_tools.py -v -s +""" + +import os +import time + +from google.adk.agents.llm_agent import LlmAgent +from google.adk.runners import Runner +from google.adk.sessions.in_memory_session_service import InMemorySessionService +from google.genai import types +from pydantic import BaseModel +from pydantic import Field +import pytest + + +class AnalysisResult(BaseModel): + summary: str = Field(description='Brief summary of the analysis') + confidence: float = Field(description='Confidence score between 0 and 1') + + +def search_data(query: str) -> str: + """Search for data based on the query.""" + return f'Found data for: {query}. Revenue is $1M, growth is 15%.' 
+ + +def calculate_metric(metric_name: str, value: float) -> str: + """Calculate a business metric.""" + return f'{metric_name}: {value * 1.1:.2f} (adjusted)' + + +# Skip if no API key is configured. +skip_no_api_key = pytest.mark.skipif( + not os.environ.get('GOOGLE_API_KEY') + and not os.environ.get('GOOGLE_GENAI_USE_VERTEXAI'), + reason='No Gemini API key or Vertex AI configured', +) + + +@skip_no_api_key +@pytest.mark.asyncio +async def test_basemodel_schema_with_tools(): + """Test that BaseModel output_schema + tools produces structured output.""" + agent = LlmAgent( + name='analyst', + model='gemini-2.5-flash', + instruction=( + 'Analyze the query using the available tools, then return' + ' structured output.' + ), + output_schema=AnalysisResult, + tools=[search_data, calculate_metric], + ) + + session_service = InMemorySessionService() + runner = Runner( + agent=agent, app_name='test_app', session_service=session_service + ) + session = await session_service.create_session( + app_name='test_app', user_id='test_user' + ) + + events = [] + start = time.time() + + async for event in runner.run_async( + user_id='test_user', + session_id=session.id, + new_message=types.Content( + role='user', + parts=[types.Part(text='Analyze Q1 revenue performance')], + ), + ): + events.append(event) + + elapsed = time.time() - start + + # Should complete within a reasonable time (not infinite loop). + assert elapsed < 120, f'Took {elapsed:.1f}s — possible infinite loop' + + # Should have at least one event with structured output. 
+ final_texts = [ + e.content.parts[0].text + for e in events + if e.content and e.content.parts and e.content.parts[0].text + ] + assert len(final_texts) > 0, 'No text output produced' + print(f'\nCompleted in {elapsed:.1f}s with {len(events)} events') + print(f'Final output: {final_texts[-1][:200]}') + + +@skip_no_api_key +@pytest.mark.asyncio +async def test_str_schema_with_tools(): + """Test that str output_schema + tools produces output (not infinite loop).""" + agent = LlmAgent( + name='analyst', + model='gemini-2.5-flash', + instruction='Search for the data, then provide a brief text summary.', + output_schema=str, + tools=[search_data], + ) + + session_service = InMemorySessionService() + runner = Runner( + agent=agent, app_name='test_app', session_service=session_service + ) + session = await session_service.create_session( + app_name='test_app', user_id='test_user' + ) + + events = [] + start = time.time() + + async for event in runner.run_async( + user_id='test_user', + session_id=session.id, + new_message=types.Content( + role='user', + parts=[types.Part(text='What is the Q1 revenue?')], + ), + ): + events.append(event) + + elapsed = time.time() - start + + assert elapsed < 120, f'Took {elapsed:.1f}s — possible infinite loop' + assert len(events) > 0, 'No events produced' + print(f'\nCompleted in {elapsed:.1f}s with {len(events)} events') diff --git a/tests/unittests/flows/llm_flows/test_output_schema_processor.py b/tests/unittests/flows/llm_flows/test_output_schema_processor.py index 23c741bccc..1742002e5e 100644 --- a/tests/unittests/flows/llm_flows/test_output_schema_processor.py +++ b/tests/unittests/flows/llm_flows/test_output_schema_processor.py @@ -19,10 +19,13 @@ from google.adk.agents.invocation_context import InvocationContext from google.adk.agents.llm_agent import LlmAgent from google.adk.agents.run_config import RunConfig +from google.adk.events.event import Event +from google.adk.flows.llm_flows._output_schema_processor import 
_MAX_TOOL_ROUNDS from google.adk.flows.llm_flows.single_flow import SingleFlow from google.adk.models.llm_request import LlmRequest from google.adk.sessions.in_memory_session_service import InMemorySessionService from google.adk.tools.function_tool import FunctionTool +from google.genai import types from pydantic import BaseModel from pydantic import Field import pytest @@ -520,3 +523,251 @@ async def test_flow_yields_only_function_response_for_normal_tools(): assert first_event.get_function_responses()[0].response == { 'result': 'Searched for: test query' } + + +def _make_function_response_event(tool_name: str = 'dummy_tool') -> Event: + """Helper to create a function response event for round counting.""" + return Event( + author='test_agent', + content=types.Content( + role='user', + parts=[ + types.Part( + function_response=types.FunctionResponse( + name=tool_name, response={'result': 'ok'} + ) + ) + ], + ), + ) + + +@pytest.mark.asyncio +async def test_type_aware_instruction_basemodel(mocker): + """Test that BaseModel schema gets a softer instruction.""" + from google.adk.flows.llm_flows._output_schema_processor import _OutputSchemaRequestProcessor + + agent = LlmAgent( + name='test_agent', + model='gemini-1.5-flash', + output_schema=PersonSchema, + tools=[FunctionTool(func=dummy_tool)], + ) + invocation_context = await _create_invocation_context(agent) + llm_request = LlmRequest() + processor = _OutputSchemaRequestProcessor() + + mocker.patch( + 'google.adk.flows.llm_flows._output_schema_processor.can_use_output_schema_with_tools', + return_value=False, + ) + + async for _ in processor.run_async(invocation_context, llm_request): + pass + + assert ( + 'After completing any needed tool calls' + in llm_request.config.system_instruction + ) + assert 'IMPORTANT' not in llm_request.config.system_instruction + + +@pytest.mark.asyncio +async def test_type_aware_instruction_primitive(mocker): + """Test that primitive schema (str) gets a stronger instruction.""" + from 
google.adk.flows.llm_flows._output_schema_processor import _OutputSchemaRequestProcessor + + agent = LlmAgent( + name='test_agent', + model='gemini-1.5-flash', + output_schema=str, + tools=[FunctionTool(func=dummy_tool)], + ) + invocation_context = await _create_invocation_context(agent) + llm_request = LlmRequest() + processor = _OutputSchemaRequestProcessor() + + mocker.patch( + 'google.adk.flows.llm_flows._output_schema_processor.can_use_output_schema_with_tools', + return_value=False, + ) + + async for _ in processor.run_async(invocation_context, llm_request): + pass + + assert 'IMPORTANT' in llm_request.config.system_instruction + assert 'MUST call' in llm_request.config.system_instruction + + +@pytest.mark.asyncio +async def test_hard_cutoff_at_max_rounds(mocker): + """Test that invocation is terminated at _MAX_TOOL_ROUNDS.""" + from google.adk.flows.llm_flows._output_schema_processor import _OutputSchemaRequestProcessor + + agent = LlmAgent( + name='test_agent', + model='gemini-1.5-flash', + output_schema=PersonSchema, + tools=[FunctionTool(func=dummy_tool)], + ) + invocation_context = await _create_invocation_context(agent) + llm_request = LlmRequest() + processor = _OutputSchemaRequestProcessor() + + mocker.patch( + 'google.adk.flows.llm_flows._output_schema_processor.can_use_output_schema_with_tools', + return_value=False, + ) + + # Simulate _MAX_TOOL_ROUNDS function response events. + fake_events = [ + _make_function_response_event() for _ in range(_MAX_TOOL_ROUNDS) + ] + mocker.patch.object( + invocation_context, + '_get_events', + return_value=fake_events, + ) + + async for _ in processor.run_async(invocation_context, llm_request): + pass + + assert invocation_context.end_invocation is True + # Should NOT have added set_model_response tool. 
+ assert 'set_model_response' not in llm_request.tools_dict + + +@pytest.mark.asyncio +async def test_force_tool_choice_at_penultimate_round(mocker): + """Test that tool_choice is forced on round N-1.""" + from google.adk.flows.llm_flows._output_schema_processor import _OutputSchemaRequestProcessor + + agent = LlmAgent( + name='test_agent', + model='gemini-1.5-flash', + output_schema=PersonSchema, + tools=[FunctionTool(func=dummy_tool)], + ) + invocation_context = await _create_invocation_context(agent) + llm_request = LlmRequest() + processor = _OutputSchemaRequestProcessor() + + mocker.patch( + 'google.adk.flows.llm_flows._output_schema_processor.can_use_output_schema_with_tools', + return_value=False, + ) + + # Simulate _MAX_TOOL_ROUNDS - 1 function response events. + fake_events = [ + _make_function_response_event() for _ in range(_MAX_TOOL_ROUNDS - 1) + ] + mocker.patch.object( + invocation_context, + '_get_events', + return_value=fake_events, + ) + + async for _ in processor.run_async(invocation_context, llm_request): + pass + + # Should still add the tool. + assert 'set_model_response' in llm_request.tools_dict + + # Should have forced tool_choice. 
+ tool_config = llm_request.config.tool_config + assert tool_config is not None + fc_config = tool_config.function_calling_config + assert fc_config.mode == types.FunctionCallingConfigMode.ANY + assert fc_config.allowed_function_names == ['set_model_response'] + + +@pytest.mark.asyncio +async def test_no_force_tool_choice_on_normal_rounds(mocker): + """Test that tool_choice is NOT forced on normal rounds.""" + from google.adk.flows.llm_flows._output_schema_processor import _OutputSchemaRequestProcessor + + agent = LlmAgent( + name='test_agent', + model='gemini-1.5-flash', + output_schema=PersonSchema, + tools=[FunctionTool(func=dummy_tool)], + ) + invocation_context = await _create_invocation_context(agent) + llm_request = LlmRequest() + processor = _OutputSchemaRequestProcessor() + + mocker.patch( + 'google.adk.flows.llm_flows._output_schema_processor.can_use_output_schema_with_tools', + return_value=False, + ) + + # Simulate a few normal rounds. + fake_events = [_make_function_response_event() for _ in range(3)] + mocker.patch.object( + invocation_context, + '_get_events', + return_value=fake_events, + ) + + async for _ in processor.run_async(invocation_context, llm_request): + pass + + # Should have added the tool but NOT forced tool_choice. + assert 'set_model_response' in llm_request.tools_dict + assert llm_request.config.tool_config is None + + +@pytest.mark.asyncio +async def test_set_model_response_skips_transfer_to_agent(): + """Test that return after set_model_response prevents transfer_to_agent.""" + from google.adk.flows.llm_flows.base_llm_flow import BaseLlmFlow + from google.adk.tools.set_model_response_tool import SetModelResponseTool + + # Create a sub_agent that would be the transfer target. 
+ sub_agent = LlmAgent(name='sub_agent', model='gemini-1.5-flash') + agent = LlmAgent( + name='test_agent', + model='gemini-1.5-flash', + output_schema=PersonSchema, + sub_agents=[sub_agent], + ) + + invocation_context = await _create_invocation_context(agent) + flow = BaseLlmFlow() + + set_response_tool = SetModelResponseTool(PersonSchema) + llm_request = LlmRequest() + llm_request.tools_dict['set_model_response'] = set_response_tool + + # Create a function call event for set_model_response; handling it must + # return early before any transfer_to_agent processing is reached. + function_call_event = Event( + author='test_agent', + content=types.Content( + role='model', + parts=[ + types.Part( + function_call=types.FunctionCall( + name='set_model_response', + args={ + 'name': 'Test User', + 'age': 30, + 'city': 'Test City', + }, + ) + ) + ], + ), + ) + + events = [] + async for event in flow._postprocess_handle_function_calls_async( + invocation_context, function_call_event, llm_request + ): + events.append(event) + + # Should yield exactly 2 events (function response + final model response) + # and NOT attempt to run sub_agent via transfer_to_agent. + assert len(events) == 2 + assert events[1].content.role == 'model' + assert '"Test User"' in events[1].content.parts[0].text