Skip to content

Commit e44cfbe

Browse files
vertex-sdk-bot, copybara-github
authored and committed
feat: Add Tool compatibility with RagRetrievalConfig in both Vertex AI SDK and GenAI SDK for use with generate_content.
PiperOrigin-RevId: 890232571
1 parent e164b19 commit e44cfbe

3 files changed

Lines changed: 284 additions & 19 deletions

File tree

tests/unit/vertexai/test_generative_models.py

Lines changed: 133 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@
3333
from vertexai import generative_models
3434
from vertexai.preview import (
3535
generative_models as preview_generative_models,
36-
rag,
3736
)
37+
from vertexai.preview.rag.utils import resources
3838
from vertexai.generative_models._generative_models import (
3939
prediction_service,
4040
gapic_prediction_service_types,
@@ -298,7 +298,7 @@ def mock_generate_content(
298298
tool.retrieval or tool.google_search_retrieval for tool in request.tools
299299
)
300300
has_rag_retrieval = any(
301-
isinstance(tool.retrieval, rag.Retrieval) for tool in request.tools
301+
isinstance(tool.retrieval, generative_models.grounding.Retrieval) and tool.retrieval._raw_retrieval._pb.WhichOneof("source") == "vertex_rag_store" for tool in request.tools
302302
)
303303
has_function_declarations = any(
304304
tool.function_declarations for tool in request.tools
@@ -367,6 +367,23 @@ def mock_generate_content(
367367
web_search_queries=[request.contents[0].parts[0].text],
368368
)
369369
elif has_rag_retrieval and request.contents[0].parts[0].text:
370+
vertex_rag_store = request.tools[0].retrieval.vertex_rag_store
371+
assert vertex_rag_store is not None
372+
# Validate rag_resources or rag_corpora
373+
if vertex_rag_store.rag_resources:
374+
assert len(vertex_rag_store.rag_resources) == 1
375+
assert vertex_rag_store.rag_resources[0].rag_corpus == f"projects/{_TEST_PROJECT}/locations/us-central1/ragCorpora/1234556"
376+
assert vertex_rag_store.rag_resources[0].rag_file_ids == ["123", "456"]
377+
elif vertex_rag_store.rag_corpora:
378+
assert vertex_rag_store.rag_corpora == [f"projects/{_TEST_PROJECT}/locations/us-central1/ragCorpora/654321"]
379+
380+
# Validate rag_retrieval_config and metadata_filter
381+
assert vertex_rag_store.rag_retrieval_config is not None
382+
assert vertex_rag_store.rag_retrieval_config.top_k == 1
383+
assert (
384+
vertex_rag_store.rag_retrieval_config.filter.metadata_filter
385+
== "test_metadata_filter"
386+
)
370387
grounding_metadata = gapic_content_types.GroundingMetadata(
371388
retrieval_queries=[request.contents[0].parts[0].text],
372389
)
@@ -1400,20 +1417,29 @@ def test_generate_content_grounding_vertex_ai_search_retriever_with_project_and_
14001417
assert response.text
14011418

14021419
@patch_genai_services
1403-
def test_generate_content_vertex_rag_retriever(self):
1404-
model = preview_generative_models.GenerativeModel("gemini-pro")
1420+
@pytest.mark.parametrize(
1421+
"generative_models",
1422+
[generative_models, preview_generative_models],
1423+
)
1424+
def test_generate_content_with_vertex_rag_store_tool(self, generative_models: generative_models):
1425+
model = generative_models.GenerativeModel("gemini-pro")
14051426
rag_resources = [
1406-
rag.RagResource(
1427+
resources.RagResource(
14071428
rag_corpus=f"projects/{_TEST_PROJECT}/locations/us-central1/ragCorpora/1234556",
14081429
rag_file_ids=["123", "456"],
14091430
),
14101431
]
1411-
rag_retriever_tool = preview_generative_models.Tool.from_retrieval(
1412-
retrieval=rag.Retrieval(
1413-
source=rag.VertexRagStore(
1432+
rag_retrieval_config = resources.RagRetrievalConfig(
1433+
top_k=1,
1434+
filter=resources.Filter(
1435+
metadata_filter="test_metadata_filter"
1436+
),
1437+
)
1438+
rag_retriever_tool = generative_models.Tool.from_retrieval(
1439+
retrieval=generative_models.grounding.Retrieval(
1440+
source=generative_models.grounding.VertexRagStore(
14141441
rag_resources=rag_resources,
1415-
similarity_top_k=1,
1416-
vector_distance_threshold=0.4,
1442+
rag_retrieval_config=rag_retrieval_config,
14171443
),
14181444
),
14191445
)
@@ -1422,6 +1448,103 @@ def test_generate_content_vertex_rag_retriever(self):
14221448
)
14231449
assert response.text
14241450

1451+
@patch_genai_services
@pytest.mark.parametrize(
    "generative_models",
    [generative_models, preview_generative_models],
)
def test_generate_content_with_vertex_rag_store_tool_with_rag_corpora(
    self, generative_models: generative_models
):
    """Generate content with a RAG store tool addressed via `rag_corpora`.

    Runs against both the GA and preview `generative_models` modules. The
    store is configured with a single corpus name plus a retrieval config
    carrying `top_k` and a metadata filter.
    """
    model = generative_models.GenerativeModel("gemini-pro")

    # Retrieval settings shared by the store below.
    retrieval_config = resources.RagRetrievalConfig(
        top_k=1,
        filter=resources.Filter(metadata_filter="test_metadata_filter"),
    )
    rag_store = generative_models.grounding.VertexRagStore(
        rag_corpora=[
            f"projects/{_TEST_PROJECT}/locations/us-central1/ragCorpora/654321",
        ],
        rag_retrieval_config=retrieval_config,
    )
    retrieval = generative_models.grounding.Retrieval(source=rag_store)
    rag_tool = generative_models.Tool.from_retrieval(retrieval=retrieval)

    response = model.generate_content("Why is sky blue?", tools=[rag_tool])
    assert response.text
1479+
1480+
@patch_genai_services
@pytest.mark.parametrize(
    "generative_models",
    [generative_models, preview_generative_models],
)
def test_generate_content_with_vertex_rag_store_tool_with_deprecated_fields(
    self, generative_models: generative_models
):
    """Building a RAG tool with `vector_distance_threshold` warns but still works.

    Runs against both the GA and preview `generative_models` modules.
    """
    model = generative_models.GenerativeModel("gemini-pro")
    corpus_resource = resources.RagResource(
        rag_corpus=f"projects/{_TEST_PROJECT}/locations/us-central1/ragCorpora/1234556",
        rag_file_ids=["123", "456"],
    )
    # Deprecated fields are now inside RagRetrievalConfig
    deprecated_filter = resources.Filter(vector_distance_threshold=0.4)
    retrieval_config = resources.RagRetrievalConfig(
        top_k=1,
        filter=deprecated_filter,
    )

    # Expect warnings for deprecated fields. The whole construction chain
    # stays inside the context so the warning is captured wherever it is
    # emitted.
    with pytest.warns(DeprecationWarning):
        rag_store = generative_models.grounding.VertexRagStore(
            rag_resources=[corpus_resource],
            rag_retrieval_config=retrieval_config,
        )
        retrieval = generative_models.grounding.Retrieval(source=rag_store)
        rag_tool = generative_models.Tool.from_retrieval(retrieval=retrieval)

    response = model.generate_content("Why is sky blue?", tools=[rag_tool])
    assert response.text
1512+
1513+
@patch_genai_services
@pytest.mark.parametrize(
    "generative_models",
    [generative_models, preview_generative_models],
)
def test_generate_content_with_vertex_rag_store_tool_invalid_config(
    self, generative_models: generative_models
):
    """Setting both distance and similarity thresholds raises `ValueError`.

    Runs against both the GA and preview `generative_models` modules. No
    model is created and no request is sent: the error is expected while
    the retrieval tool is being constructed.
    """
    rag_resources = [
        resources.RagResource(
            rag_corpus=f"projects/{_TEST_PROJECT}/locations/us-central1/ragCorpora/1234556",
            rag_file_ids=["123", "456"],
        ),
    ]
    # The two thresholds are mutually exclusive; supplying both is the
    # invalid configuration under test.
    rag_retrieval_config = resources.RagRetrievalConfig(
        top_k=1,
        filter=resources.Filter(
            vector_distance_threshold=0.4,
            vector_similarity_threshold=0.6,
        ),
    )
    # Passing the pattern through `match=` guarantees the message is
    # checked when the context exits. A separate `e.match(...)` call is
    # silently skipped if it ends up inside the `with` body, because the
    # exception aborts the body before reaching it.
    with pytest.raises(
        ValueError,
        match=(
            "Only one of vector_distance_threshold or"
            " vector_similarity_threshold can be specified at a time"
            " in rag_retrieval_config."
        ),
    ):
        generative_models.Tool.from_retrieval(
            retrieval=generative_models.grounding.Retrieval(
                source=generative_models.grounding.VertexRagStore(
                    rag_resources=rag_resources,
                    rag_retrieval_config=rag_retrieval_config,
                ),
            ),
        )
1547+
14251548
@patch_genai_services
14261549
def test_chat_automatic_function_calling_with_function_returning_dict(self):
14271550
generative_models = preview_generative_models

vertexai/generative_models/_generative_models.py

Lines changed: 49 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666

6767
if TYPE_CHECKING:
6868
from vertexai.caching import CachedContent
69+
from vertexai.preview import generative_models as preview_generative_models
6970

7071
try:
7172
from PIL import Image as PIL_Image # pylint: disable=g-import-not-at-top
@@ -2078,8 +2079,13 @@ def from_retrieval(
20782079
cls,
20792080
retrieval: Union["grounding.Retrieval"],
20802081
) -> "Tool":
2081-
raw_tool = gapic_tool_types.Tool(retrieval=retrieval._raw_retrieval)
2082-
return cls._from_gapic(raw_tool=raw_tool)
2082+
# Late import to avoid circular dependency
2083+
from vertexai.preview import generative_models as preview_generative_models
2084+
raw_retrieval = retrieval._raw_retrieval
2085+
raw_tool = gapic_tool_types.Tool(retrieval=raw_retrieval)
2086+
response = cls._from_gapic(raw_tool=raw_tool)
2087+
response.retrieval = retrieval
2088+
return response
20832089

20842090
@classmethod
20852091
def from_google_search_retrieval(
@@ -2978,24 +2984,35 @@ class Retrieval:
29782984

29792985
def __init__(
    self,
    source: Union["grounding.VertexAISearch", "grounding.VertexRagStore"],
    disable_attribution: Optional[bool] = None,
):
    """Initializes a Retrieval tool.

    Args:
        source (Union[VertexAISearch, VertexRagStore]):
            The data source to retrieve from: either Vertex AI Search
            or Vertex Rag Store.
        disable_attribution (bool):
            Deprecated. Disable using the result from this
            tool in detecting grounding attribution. This
            does not affect how the result is given to the
            model for generation.

    Raises:
        TypeError: If `source` is neither a `VertexAISearch` nor a
            `VertexRagStore`.
    """
    # Work out which field of the raw Retrieval message the source
    # populates, then build the message once at the end.
    if isinstance(source, grounding.VertexAISearch):
        source_field = {"vertex_ai_search": source._raw_vertex_ai_search}
    elif isinstance(source, grounding.VertexRagStore):
        # Late import to avoid circular dependency
        from vertexai.preview import generative_models as preview_generative_models

        source_field = {
            "vertex_rag_store": (
                preview_generative_models._preview_parse_vertex_rag_store_to_api(
                    source
                )
            ),
        }
    else:
        raise TypeError(f"Unexpected source type: {type(source)}")

    self._raw_retrieval = gapic_tool_types.Retrieval(
        disable_attribution=disable_attribution,
        **source_field,
    )
29993016

30003017
class VertexAISearch:
30013018
r"""Retrieve from Vertex AI Search data store for grounding.
@@ -3034,6 +3051,29 @@ def __init__(
30343051
datastore=datastore,
30353052
)
30363053

3054+
class VertexRagStore:
    """Grounding source that retrieves from a Vertex Rag Store.

    This is a plain holder: the constructor stores its arguments unchanged
    and performs no validation or conversion.

    Attributes:
        rag_resources: A list of RagResource. It can be used to specify
            corpus only or ragfiles. Currently only support one corpus or
            multiple files from one corpus. In the future we may open up
            multiple corpora support.
        rag_corpora: If rag_resources is not specified, use rag_corpora as
            a list of rag corpora names. Deprecated. Use rag_resources
            instead.
        rag_retrieval_config: Optional. The config containing the retrieval
            parameters, including top_k, vector_distance_threshold, and
            alpha.
    """

    def __init__(
        self,
        rag_resources: Optional[
            List["preview_generative_models.rag.RagResource"]
        ] = None,
        rag_corpora: Optional[List[str]] = None,
        rag_retrieval_config: Optional[
            "preview_generative_models.rag.RagRetrievalConfig"
        ] = None,
    ):
        # Where to retrieve from: structured resources, or bare corpus names.
        self.rag_resources = rag_resources
        self.rag_corpora = rag_corpora
        # How to retrieve.
        self.rag_retrieval_config = rag_retrieval_config
3076+
30373077

30383078
class preview_grounding(grounding): # pylint: disable=invalid-name
30393079
"""Grounding namespace (preview)."""

0 commit comments

Comments
 (0)