1616#
1717"""Tests for vertex_rag.retrieval_preview."""
1818
19+ # pylint: disable=bad-indentation, unused-variable, unused-argument, redefined-outer-name, C0116
20+
1921import importlib
2022from google .cloud import aiplatform
2123from google .cloud .aiplatform_v1beta1 import VertexRagServiceAsyncClient
@@ -75,7 +77,9 @@ def rag_client_mock_exception():
7577
7678
7779def retrieve_contexts_eq (response , expected_response ):
78- assert len (response .contexts .contexts ) == len (expected_response .contexts .contexts )
80+ assert len (response .contexts .contexts ) == len (
81+ expected_response .contexts .contexts
82+ )
7983 assert (
8084 response .contexts .contexts [0 ].text
8185 == expected_response .contexts .contexts [0 ].text
@@ -87,7 +91,7 @@ def retrieve_contexts_eq(response, expected_response):
8791
8892
8993@pytest .mark .usefixtures ("google_auth_mock" )
90- class TestRagRetrieval : # pylint: disable=missing-class-docstring
94+ class TestRagRetrieval : # pylint: disable=missing-class-docstring, bad-indentation, unused-variable, unused-argument, redefined-outer-name
9195
9296 def setup_method (self ):
9397 importlib .reload (aiplatform .initializer )
@@ -118,6 +122,18 @@ def test_ask_contexts_rag_resources_success(self):
118122 )
119123 retrieve_contexts_eq (response , tc .TEST_RETRIEVAL_RESPONSE )
120124
125+ @pytest .mark .usefixtures ("ask_contexts_mock" )
126+ def test_ask_contexts_with_timeout (self , ask_contexts_mock ):
127+ response = rag .ask_contexts (
128+ rag_resources = [tc .TEST_RAG_RESOURCE ],
129+ text = tc .TEST_QUERY_TEXT ,
130+ rag_retrieval_config = tc .TEST_RAG_RETRIEVAL_CONFIG_ALPHA ,
131+ timeout = 300 ,
132+ )
133+ ask_contexts_mock .assert_called_once ()
134+ args , kwargs = ask_contexts_mock .call_args
135+ assert kwargs ["timeout" ] == 300
136+
121137 @pytest .mark .usefixtures ("ask_contexts_mock" )
122138 def test_ask_contexts_multiple_rag_resources_success (self ):
123139 response = rag .ask_contexts (
@@ -138,8 +154,9 @@ def test_ask_contexts_multiple_rag_corpora_success(self):
138154 retrieve_contexts_eq (response , tc .TEST_RETRIEVAL_RESPONSE )
139155
140156 @pytest .mark .asyncio
141- @pytest .mark .usefixtures ("async_retrieve_contexts_mock" )
142- async def test_async_retrieve_contexts_rag_resources_success (self ):
157+ async def test_async_retrieve_contexts_rag_resources_success (
158+ self , async_retrieve_contexts_mock
159+ ):
143160 response = await rag .async_retrieve_contexts (
144161 rag_resources = [tc .TEST_RAG_RESOURCE ],
145162 text = tc .TEST_QUERY_TEXT ,
@@ -148,8 +165,23 @@ async def test_async_retrieve_contexts_rag_resources_success(self):
148165 retrieve_contexts_eq (response , tc .TEST_RETRIEVAL_RESPONSE )
149166
150167 @pytest .mark .asyncio
151- @pytest .mark .usefixtures ("async_retrieve_contexts_mock" )
152- async def test_async_retrieve_contexts_multiple_rag_resources_success (self ):
168+ async def test_async_retrieve_contexts_with_timeout (
169+ self , async_retrieve_contexts_mock
170+ ):
171+ response = await rag .async_retrieve_contexts (
172+ rag_resources = [tc .TEST_RAG_RESOURCE ],
173+ text = tc .TEST_QUERY_TEXT ,
174+ rag_retrieval_config = tc .TEST_RAG_RETRIEVAL_CONFIG_ALPHA ,
175+ timeout = 300 ,
176+ )
177+ async_retrieve_contexts_mock .assert_called_once ()
178+ args , kwargs = async_retrieve_contexts_mock .call_args
179+ assert kwargs ["timeout" ] == 300
180+
181+ @pytest .mark .asyncio
182+ async def test_async_retrieve_contexts_multiple_rag_resources_success (
183+ self , async_retrieve_contexts_mock
184+ ):
153185 response = await rag .async_retrieve_contexts (
154186 rag_resources = [tc .TEST_RAG_RESOURCE , tc .TEST_RAG_RESOURCE ],
155187 text = tc .TEST_QUERY_TEXT ,
@@ -158,8 +190,9 @@ async def test_async_retrieve_contexts_multiple_rag_resources_success(self):
158190 retrieve_contexts_eq (response , tc .TEST_RETRIEVAL_RESPONSE )
159191
160192 @pytest .mark .asyncio
161- @pytest .mark .usefixtures ("async_retrieve_contexts_mock" )
162- async def test_async_retrieve_contexts_multiple_rag_corpora_success (self ):
193+ async def test_async_retrieve_contexts_multiple_rag_corpora_success (
194+ self , async_retrieve_contexts_mock
195+ ):
163196 with pytest .warns (DeprecationWarning ):
164197 response = await rag .async_retrieve_contexts (
165198 rag_corpora = [tc .TEST_RAG_CORPUS_ID , tc .TEST_RAG_CORPUS_ID ],
@@ -241,7 +274,8 @@ def test_retrieval_query_with_metadata_filter(self, retrieve_contexts_mock):
241274 args , kwargs = retrieve_contexts_mock .call_args
242275 request = kwargs ["request" ]
243276 assert (
244- request .query .rag_retrieval_config .filter .metadata_filter == metadata_filter
277+ request .query .rag_retrieval_config .filter .metadata_filter
278+ == metadata_filter
245279 )
246280
247281 @pytest .mark .usefixtures ("retrieve_contexts_mock" )
@@ -262,7 +296,7 @@ def test_retrieval_query_failure(self):
262296 similarity_top_k = 2 ,
263297 vector_distance_threshold = 0.5 ,
264298 )
265- e .match ("Failed in retrieving contexts due to" )
299+ e .match ("Failed in retrieving contexts due to" )
266300
267301 @pytest .mark .usefixtures ("rag_client_mock_exception" )
268302 def test_retrieval_query_config_failure (self ):
@@ -272,7 +306,7 @@ def test_retrieval_query_config_failure(self):
272306 text = tc .TEST_QUERY_TEXT ,
273307 rag_retrieval_config = tc .TEST_RAG_RETRIEVAL_CONFIG ,
274308 )
275- e .match ("Failed in retrieving contexts due to" )
309+ e .match ("Failed in retrieving contexts due to" )
276310
277311 def test_retrieval_query_invalid_name (self ):
278312 with pytest .raises (ValueError ) as e :
@@ -282,7 +316,7 @@ def test_retrieval_query_invalid_name(self):
282316 similarity_top_k = 2 ,
283317 vector_distance_threshold = 0.5 ,
284318 )
285- e .match ("Invalid RagCorpus name" )
319+ e .match ("Invalid RagCorpus name" )
286320
287321 def test_retrieval_query_invalid_name_config (self ):
288322 with pytest .raises (ValueError ) as e :
@@ -291,7 +325,7 @@ def test_retrieval_query_invalid_name_config(self):
291325 text = tc .TEST_QUERY_TEXT ,
292326 rag_retrieval_config = tc .TEST_RAG_RETRIEVAL_CONFIG ,
293327 )
294- e .match ("Invalid RagCorpus name" )
328+ e .match ("Invalid RagCorpus name" )
295329
296330 def test_retrieval_query_multiple_rag_corpora (self ):
297331 with pytest .raises (ValueError ) as e :
@@ -304,7 +338,7 @@ def test_retrieval_query_multiple_rag_corpora(self):
304338 similarity_top_k = 2 ,
305339 vector_distance_threshold = 0.5 ,
306340 )
307- e .match ("Currently only support 1 RagCorpus" )
341+ e .match ("Currently only support 1 RagCorpus" )
308342
309343 def test_retrieval_query_multiple_rag_corpora_config (self ):
310344 with pytest .raises (ValueError ) as e :
@@ -316,7 +350,7 @@ def test_retrieval_query_multiple_rag_corpora_config(self):
316350 text = tc .TEST_QUERY_TEXT ,
317351 rag_retrieval_config = tc .TEST_RAG_RETRIEVAL_CONFIG ,
318352 )
319- e .match ("Currently only support 1 RagCorpus" )
353+ e .match ("Currently only support 1 RagCorpus" )
320354
321355 def test_retrieval_query_multiple_rag_resources (self ):
322356 with pytest .raises (ValueError ) as e :
@@ -329,7 +363,7 @@ def test_retrieval_query_multiple_rag_resources(self):
329363 similarity_top_k = 2 ,
330364 vector_distance_threshold = 0.5 ,
331365 )
332- e .match ("Currently only support 1 RagResource" )
366+ e .match ("Currently only support 1 RagResource" )
333367
334368 def test_retrieval_query_multiple_rag_resources_config (self ):
335369 with pytest .raises (ValueError ) as e :
@@ -341,7 +375,7 @@ def test_retrieval_query_multiple_rag_resources_config(self):
341375 text = tc .TEST_QUERY_TEXT ,
342376 rag_retrieval_config = tc .TEST_RAG_RETRIEVAL_CONFIG ,
343377 )
344- e .match ("Currently only support 1 RagResource" )
378+ e .match ("Currently only support 1 RagResource" )
345379
346380 def test_retrieval_query_multiple_rag_resources_similarity_config (self ):
347381 with pytest .raises (ValueError ) as e :
@@ -353,7 +387,7 @@ def test_retrieval_query_multiple_rag_resources_similarity_config(self):
353387 text = tc .TEST_QUERY_TEXT ,
354388 rag_retrieval_config = tc .TEST_RAG_RETRIEVAL_SIMILARITY_CONFIG ,
355389 )
356- e .match ("Currently only support 1 RagResource" )
390+ e .match ("Currently only support 1 RagResource" )
357391
358392 def test_retrieval_query_invalid_config_filter (self ):
359393 with pytest .raises (ValueError ) as e :
@@ -362,8 +396,8 @@ def test_retrieval_query_invalid_config_filter(self):
362396 text = tc .TEST_QUERY_TEXT ,
363397 rag_retrieval_config = tc .TEST_RAG_RETRIEVAL_ERROR_CONFIG ,
364398 )
365- e .match (
366- "Only one of vector_distance_threshold or"
367- " vector_similarity_threshold can be specified at a time"
368- " in rag_retrieval_config."
369- )
399+ e .match (
400+ "Only one of vector_distance_threshold or"
401+ " vector_similarity_threshold can be specified at a time"
402+ " in rag_retrieval_config."
403+ )
0 commit comments