diff --git a/contributing/samples/gepa/experiment.py b/contributing/samples/gepa/experiment.py index f3751206a8..2710c3894c 100644 --- a/contributing/samples/gepa/experiment.py +++ b/contributing/samples/gepa/experiment.py @@ -43,7 +43,6 @@ from tau_bench.types import EnvRunResult from tau_bench.types import RunConfig import tau_bench_agent as tau_bench_agent_lib - import utils diff --git a/contributing/samples/gepa/run_experiment.py b/contributing/samples/gepa/run_experiment.py index d857da9635..e31db15788 100644 --- a/contributing/samples/gepa/run_experiment.py +++ b/contributing/samples/gepa/run_experiment.py @@ -25,7 +25,6 @@ from absl import flags import experiment from google.genai import types - import utils _OUTPUT_DIR = flags.DEFINE_string( diff --git a/src/google/adk/utils/model_name_utils.py b/src/google/adk/utils/model_name_utils.py index 86fd79ab64..9a2ee90e1c 100644 --- a/src/google/adk/utils/model_name_utils.py +++ b/src/google/adk/utils/model_name_utils.py @@ -124,7 +124,7 @@ def is_gemini_2_or_above(model_string: Optional[str]) -> bool: except InvalidVersion: return False - return parsed_version.major >= 2 + return bool(parsed_version.major >= 2) def is_gemini_3_1_flash_live(model_string: Optional[str]) -> bool: @@ -143,3 +143,31 @@ def is_gemini_3_1_flash_live(model_string: Optional[str]) -> bool: return False return model_string == 'gemini-3.1-flash-live-preview' + + +def is_gemini_3_or_above(model_string: Optional[str]) -> bool: + """Check if the model is a Gemini 3.0 or newer model using semantic versions. + + Args: + model_string: Either a simple model name or path-based model name + + Returns: + True if it's a Gemini 3.0+ model, False otherwise + """ + if not model_string: + return False + + model_name = extract_model_name(model_string) + if not model_name.startswith('gemini-'): + return False + + version_string = model_name[len('gemini-') :].split('-', 1)[0] + if not version_string: + return False + + try: + parsed_version = Version(version_string) + except InvalidVersion: + return False + + return bool(parsed_version.major >= 3) diff --git a/src/google/adk/utils/output_schema_utils.py b/src/google/adk/utils/output_schema_utils.py index 228e95b66d..e7f18e483b 100644 --- a/src/google/adk/utils/output_schema_utils.py +++ b/src/google/adk/utils/output_schema_utils.py @@ -23,7 +23,7 @@ from typing import Union from ..models.base_llm import BaseLlm -from .model_name_utils import is_gemini_2_or_above +from .model_name_utils import is_gemini_3_or_above from .variant_utils import get_google_llm_variant from .variant_utils import GoogleLLMVariant @@ -49,5 +49,5 @@ def can_use_output_schema_with_tools(model: Union[str, BaseLlm]) -> bool: return ( get_google_llm_variant() == GoogleLLMVariant.VERTEX_AI - and is_gemini_2_or_above(model_string) + and is_gemini_3_or_above(model_string) ) diff --git a/tests/unittests/utils/test_model_name_utils.py b/tests/unittests/utils/test_model_name_utils.py index 2af1584b05..12ead6037c 100644 --- a/tests/unittests/utils/test_model_name_utils.py +++ b/tests/unittests/utils/test_model_name_utils.py @@ -17,6 +17,7 @@ from google.adk.utils.model_name_utils import extract_model_name from google.adk.utils.model_name_utils import is_gemini_1_model from google.adk.utils.model_name_utils import is_gemini_2_or_above +from google.adk.utils.model_name_utils import is_gemini_3_or_above from google.adk.utils.model_name_utils import is_gemini_model from google.adk.utils.model_name_utils import is_gemini_model_id_check_disabled @@ -236,6 +237,40 @@ def test_is_gemini_2_or_above_edge_cases(self): assert is_gemini_2_or_above('gemini-one') is False +class TestIsGemini3OrAbove: + """Test the is_gemini_3_or_above function.""" + + def test_is_gemini_3_or_above_simple_names(self): + """Test Gemini 3.0+ model detection with simple model names.""" + assert is_gemini_3_or_above('gemini-3.0-pro') is True + assert is_gemini_3_or_above('gemini-3.1-pro-preview') is True + assert is_gemini_3_or_above('gemini-3-flash-preview') is True + assert is_gemini_3_or_above('gemini-4.0-pro') is True + assert is_gemini_3_or_above('gemini-2.5-pro') is False + assert is_gemini_3_or_above('gemini-2.0-flash') is False + assert is_gemini_3_or_above('gemini-1.5-flash') is False + assert is_gemini_3_or_above('claude-3-sonnet') is False + + def test_is_gemini_3_or_above_path_based_names(self): + """Test Gemini 3.0+ model detection with path-based model names.""" + gemini_3_path = 'projects/12345/locations/us-east1/publishers/google/models/gemini-3.0-pro' + assert is_gemini_3_or_above(gemini_3_path) is True + + gemini_3_path_2 = 'projects/12345/locations/us-east1/publishers/google/models/gemini-3.1-pro-preview' + assert is_gemini_3_or_above(gemini_3_path_2) is True + + gemini_2_path = 'projects/265104255505/locations/us-central1/publishers/google/models/gemini-2.5-pro' + assert is_gemini_3_or_above(gemini_2_path) is False + + def test_is_gemini_3_or_above_edge_cases(self): + """Test edge cases for Gemini 3.0+ model detection.""" + assert is_gemini_3_or_above(None) is False + assert is_gemini_3_or_above('') is False + assert is_gemini_3_or_above('gemini-3.') is False + assert is_gemini_3_or_above('gemini-one') is False + assert is_gemini_3_or_above('my-gemini-3.0-model') is False + + class TestModelNameUtilsIntegration: """Integration tests for model name utilities.""" diff --git a/tests/unittests/utils/test_output_schema_utils.py b/tests/unittests/utils/test_output_schema_utils.py index 2f9eb4bb09..9742ad424e 100644 --- a/tests/unittests/utils/test_output_schema_utils.py +++ b/tests/unittests/utils/test_output_schema_utils.py @@ -29,13 +29,19 @@ @pytest.mark.parametrize( "model, env_value, expected", [ - ("gemini-2.5-pro", "1", True), + ("gemini-3.1-pro", "1", True), + ("gemini-3.1-pro", "0", False), + ("gemini-3.1-pro", None, False), + (Gemini(model="gemini-3.0-flash"), "1", True), + (Gemini(model="gemini-3.0-flash"), "0", False), + (Gemini(model="gemini-3.0-flash"), None, False), + ("gemini-2.5-pro", "1", False), ("gemini-2.5-pro", "0", False), ("gemini-2.5-pro", None, False), - (Gemini(model="gemini-2.5-pro"), "1", True), + (Gemini(model="gemini-2.5-pro"), "1", False), (Gemini(model="gemini-2.5-pro"), "0", False), (Gemini(model="gemini-2.5-pro"), None, False), - ("gemini-2.0-flash", "1", True), + ("gemini-2.0-flash", "1", False), ("gemini-2.0-flash", "0", False), ("gemini-2.0-flash", None, False), ("gemini-1.5-pro", "1", False),