3333from vertexai import generative_models
3434from vertexai .preview import (
3535 generative_models as preview_generative_models ,
36- rag ,
3736)
37+ from vertexai .preview .rag .utils import resources
3838from vertexai .generative_models ._generative_models import (
3939 prediction_service ,
4040 gapic_prediction_service_types ,
@@ -298,7 +298,7 @@ def mock_generate_content(
298298 tool .retrieval or tool .google_search_retrieval for tool in request .tools
299299 )
300300 has_rag_retrieval = any (
301- isinstance (tool .retrieval , rag . Retrieval ) for tool in request .tools
301+ isinstance (tool .retrieval , generative_models . grounding . Retrieval ) and tool . retrieval . _raw_retrieval . _pb . WhichOneof ( "source" ) == "vertex_rag_store" for tool in request .tools
302302 )
303303 has_function_declarations = any (
304304 tool .function_declarations for tool in request .tools
@@ -367,6 +367,23 @@ def mock_generate_content(
367367 web_search_queries = [request .contents [0 ].parts [0 ].text ],
368368 )
369369 elif has_rag_retrieval and request .contents [0 ].parts [0 ].text :
370+ vertex_rag_store = request .tools [0 ].retrieval .vertex_rag_store
371+ assert vertex_rag_store is not None
372+ # Validate rag_resources or rag_corpora
373+ if vertex_rag_store .rag_resources :
374+ assert len (vertex_rag_store .rag_resources ) == 1
375+ assert vertex_rag_store .rag_resources [0 ].rag_corpus == f"projects/{ _TEST_PROJECT } /locations/us-central1/ragCorpora/1234556"
376+ assert vertex_rag_store .rag_resources [0 ].rag_file_ids == ["123" , "456" ]
377+ elif vertex_rag_store .rag_corpora :
378+ assert vertex_rag_store .rag_corpora == [f"projects/{ _TEST_PROJECT } /locations/us-central1/ragCorpora/654321" ]
379+
380+ # Validate rag_retrieval_config and metadata_filter
381+ assert vertex_rag_store .rag_retrieval_config is not None
382+ assert vertex_rag_store .rag_retrieval_config .top_k == 1
383+ assert (
384+ vertex_rag_store .rag_retrieval_config .filter .metadata_filter
385+ == "test_metadata_filter"
386+ )
370387 grounding_metadata = gapic_content_types .GroundingMetadata (
371388 retrieval_queries = [request .contents [0 ].parts [0 ].text ],
372389 )
@@ -1400,20 +1417,29 @@ def test_generate_content_grounding_vertex_ai_search_retriever_with_project_and_
14001417 assert response .text
14011418
14021419 @patch_genai_services
1403- def test_generate_content_vertex_rag_retriever (self ):
1404- model = preview_generative_models .GenerativeModel ("gemini-pro" )
1420+ @pytest .mark .parametrize (
1421+ "generative_models" ,
1422+ [generative_models , preview_generative_models ],
1423+ )
1424+ def test_generate_content_with_vertex_rag_store_tool (self , generative_models : generative_models ):
1425+ model = generative_models .GenerativeModel ("gemini-pro" )
14051426 rag_resources = [
1406- rag .RagResource (
1427+ resources .RagResource (
14071428 rag_corpus = f"projects/{ _TEST_PROJECT } /locations/us-central1/ragCorpora/1234556" ,
14081429 rag_file_ids = ["123" , "456" ],
14091430 ),
14101431 ]
1411- rag_retriever_tool = preview_generative_models .Tool .from_retrieval (
1412- retrieval = rag .Retrieval (
1413- source = rag .VertexRagStore (
1432+ rag_retrieval_config = resources .RagRetrievalConfig (
1433+ top_k = 1 ,
1434+ filter = resources .Filter (
1435+ metadata_filter = "test_metadata_filter"
1436+ ),
1437+ )
1438+ rag_retriever_tool = generative_models .Tool .from_retrieval (
1439+ retrieval = generative_models .grounding .Retrieval (
1440+ source = generative_models .grounding .VertexRagStore (
14141441 rag_resources = rag_resources ,
1415- similarity_top_k = 1 ,
1416- vector_distance_threshold = 0.4 ,
1442+ rag_retrieval_config = rag_retrieval_config ,
14171443 ),
14181444 ),
14191445 )
@@ -1422,6 +1448,103 @@ def test_generate_content_vertex_rag_retriever(self):
14221448 )
14231449 assert response .text
14241450
1451+ @patch_genai_services
1452+ @pytest .mark .parametrize (
1453+ "generative_models" ,
1454+ [generative_models , preview_generative_models ],
1455+ )
1456+ def test_generate_content_with_vertex_rag_store_tool_with_rag_corpora (self , generative_models : generative_models ):
1457+ model = generative_models .GenerativeModel ("gemini-pro" )
1458+ rag_corpora = [
1459+ f"projects/{ _TEST_PROJECT } /locations/us-central1/ragCorpora/654321" ,
1460+ ]
1461+ rag_retrieval_config = resources .RagRetrievalConfig (
1462+ top_k = 1 ,
1463+ filter = resources .Filter (
1464+ metadata_filter = "test_metadata_filter"
1465+ ),
1466+ )
1467+ rag_retriever_tool = generative_models .Tool .from_retrieval (
1468+ retrieval = generative_models .grounding .Retrieval (
1469+ source = generative_models .grounding .VertexRagStore (
1470+ rag_corpora = rag_corpora ,
1471+ rag_retrieval_config = rag_retrieval_config ,
1472+ ),
1473+ ),
1474+ )
1475+ response = model .generate_content (
1476+ "Why is sky blue?" , tools = [rag_retriever_tool ]
1477+ )
1478+ assert response .text
1479+
1480+ @patch_genai_services
1481+ @pytest .mark .parametrize (
1482+ "generative_models" ,
1483+ [generative_models , preview_generative_models ],
1484+ )
1485+ def test_generate_content_with_vertex_rag_store_tool_with_deprecated_fields (self , generative_models : generative_models ):
1486+ model = generative_models .GenerativeModel ("gemini-pro" )
1487+ rag_resources = [
1488+ resources .RagResource (
1489+ rag_corpus = f"projects/{ _TEST_PROJECT } /locations/us-central1/ragCorpora/1234556" ,
1490+ rag_file_ids = ["123" , "456" ],
1491+ ),
1492+ ]
1493+ # Deprecated fields are now inside RagRetrievalConfig
1494+ rag_retrieval_config = resources .RagRetrievalConfig (
1495+ top_k = 1 ,
1496+ filter = resources .Filter (
1497+ vector_distance_threshold = 0.4 ,
1498+ ),
1499+ )
1500+ # Expect warnings for deprecated fields
1501+ with pytest .warns (DeprecationWarning ):
1502+ rag_retriever_tool = generative_models .Tool .from_retrieval (
1503+ retrieval = generative_models .grounding .Retrieval (
1504+ source = generative_models .grounding .VertexRagStore (
1505+ rag_resources = rag_resources ,
1506+ rag_retrieval_config = rag_retrieval_config ,
1507+ ),
1508+ ),
1509+ )
1510+ response = model .generate_content ("Why is sky blue?" , tools = [rag_retriever_tool ])
1511+ assert response .text
1512+
1513+ @patch_genai_services
1514+ @pytest .mark .parametrize (
1515+ "generative_models" ,
1516+ [generative_models , preview_generative_models ],
1517+ )
1518+ def test_generate_content_with_vertex_rag_store_tool_invalid_config (self , generative_models : generative_models ):
1519+ model = generative_models .GenerativeModel ("gemini-pro" )
1520+ rag_resources = [
1521+ resources .RagResource (
1522+ rag_corpus = f"projects/{ _TEST_PROJECT } /locations/us-central1/ragCorpora/1234556" ,
1523+ rag_file_ids = ["123" , "456" ],
1524+ ),
1525+ ]
1526+ rag_retrieval_config = resources .RagRetrievalConfig (
1527+ top_k = 1 ,
1528+ filter = resources .Filter (
1529+ vector_distance_threshold = 0.4 ,
1530+ vector_similarity_threshold = 0.6 ,
1531+ ),
1532+ )
1533+ with pytest .raises (ValueError ) as e :
1534+ rag_retriever_tool = generative_models .Tool .from_retrieval (
1535+ retrieval = generative_models .grounding .Retrieval (
1536+ source = generative_models .grounding .VertexRagStore (
1537+ rag_resources = rag_resources ,
1538+ rag_retrieval_config = rag_retrieval_config ,
1539+ ),
1540+ ),
1541+ )
1542+ e .match (
1543+ "Only one of vector_distance_threshold or"
1544+ " vector_similarity_threshold can be specified at a time"
1545+ " in rag_retrieval_config."
1546+ )
1547+
14251548 @patch_genai_services
14261549 def test_chat_automatic_function_calling_with_function_returning_dict (self ):
14271550 generative_models = preview_generative_models
0 commit comments