Skip to content

Commit c149ebb

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
fix: Increase default timeout to 300 seconds for ask_contexts and async_retrieve_contexts in VertexRagServiceClient.
PiperOrigin-RevId: 892413253
1 parent c12aedc commit c149ebb

4 files changed

Lines changed: 201 additions & 91 deletions

File tree

tests/unit/vertex_rag/test_rag_retrieval.py

Lines changed: 47 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
#
1717
"""Tests for vertex_rag.retrieval."""
1818

19+
# pylint: disable=bad-indentation, unused-variable, unused-argument, redefined-outer-name, C0116
20+
1921
import importlib
2022
from google.cloud import aiplatform
2123
from google.cloud.aiplatform_v1 import VertexRagServiceAsyncClient
@@ -73,7 +75,9 @@ def rag_client_mock_exception():
7375

7476

7577
def retrieve_contexts_eq(response, expected_response):
76-
assert len(response.contexts.contexts) == len(expected_response.contexts.contexts)
78+
assert len(response.contexts.contexts) == len(
79+
expected_response.contexts.contexts
80+
)
7781
assert (
7882
response.contexts.contexts[0].text
7983
== expected_response.contexts.contexts[0].text
@@ -85,7 +89,7 @@ def retrieve_contexts_eq(response, expected_response):
8589

8690

8791
@pytest.mark.usefixtures("google_auth_mock")
88-
class TestRagRetrieval: # pylint: disable=missing-class-docstring
92+
class TestRagRetrieval: # pylint: disable=missing-class-docstring, bad-indentation, unused-variable, unused-argument, redefined-outer-name
8993

9094
def setup_method(self):
9195
importlib.reload(aiplatform.initializer)
@@ -113,6 +117,18 @@ def test_ask_contexts_rag_resources_success(self):
113117
)
114118
retrieve_contexts_eq(response, tc.TEST_RETRIEVAL_RESPONSE)
115119

120+
@pytest.mark.usefixtures("ask_contexts_mock")
121+
def test_ask_contexts_with_timeout(self, ask_contexts_mock):
122+
rag.ask_contexts(
123+
rag_resources=[tc.TEST_RAG_RESOURCE],
124+
text=tc.TEST_QUERY_TEXT,
125+
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_CONFIG,
126+
timeout=300,
127+
)
128+
ask_contexts_mock.assert_called_once()
129+
_, kwargs = ask_contexts_mock.call_args
130+
assert kwargs["timeout"] == 300
131+
116132
@pytest.mark.usefixtures("ask_contexts_mock")
117133
def test_ask_contexts_multiple_rag_resources_success(self):
118134
response = rag.ask_contexts(
@@ -123,8 +139,9 @@ def test_ask_contexts_multiple_rag_resources_success(self):
123139
retrieve_contexts_eq(response, tc.TEST_RETRIEVAL_RESPONSE)
124140

125141
@pytest.mark.asyncio
126-
@pytest.mark.usefixtures("async_retrieve_contexts_mock")
127-
async def test_async_retrieve_contexts_rag_resources_success(self):
142+
async def test_async_retrieve_contexts_rag_resources_success(
143+
self, async_retrieve_contexts_mock
144+
):
128145
response = await rag.async_retrieve_contexts(
129146
rag_resources=[tc.TEST_RAG_RESOURCE],
130147
text=tc.TEST_QUERY_TEXT,
@@ -133,8 +150,23 @@ async def test_async_retrieve_contexts_rag_resources_success(self):
133150
retrieve_contexts_eq(response, tc.TEST_RETRIEVAL_RESPONSE)
134151

135152
@pytest.mark.asyncio
136-
@pytest.mark.usefixtures("async_retrieve_contexts_mock")
137-
async def test_async_retrieve_contexts_multiple_rag_resources_success(self):
153+
async def test_async_retrieve_contexts_with_timeout(
154+
self, async_retrieve_contexts_mock
155+
):
156+
await rag.async_retrieve_contexts(
157+
rag_resources=[tc.TEST_RAG_RESOURCE],
158+
text=tc.TEST_QUERY_TEXT,
159+
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_CONFIG,
160+
timeout=300,
161+
)
162+
async_retrieve_contexts_mock.assert_called_once()
163+
_, kwargs = async_retrieve_contexts_mock.call_args
164+
assert kwargs["timeout"] == 300
165+
166+
@pytest.mark.asyncio
167+
async def test_async_retrieve_contexts_multiple_rag_resources_success(
168+
self, async_retrieve_contexts_mock
169+
):
138170
response = await rag.async_retrieve_contexts(
139171
rag_resources=[tc.TEST_RAG_RESOURCE, tc.TEST_RAG_RESOURCE],
140172
text=tc.TEST_QUERY_TEXT,
@@ -177,7 +209,7 @@ def test_retrieval_query_failure(self):
177209
text=tc.TEST_QUERY_TEXT,
178210
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_CONFIG,
179211
)
180-
e.match("Failed in retrieving contexts due to")
212+
e.match("Failed in retrieving contexts due to")
181213

182214
def test_retrieval_query_invalid_name(self):
183215
with pytest.raises(ValueError) as e:
@@ -186,7 +218,7 @@ def test_retrieval_query_invalid_name(self):
186218
text=tc.TEST_QUERY_TEXT,
187219
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_CONFIG,
188220
)
189-
e.match("Invalid RagCorpus name")
221+
e.match("Invalid RagCorpus name")
190222

191223
def test_retrieval_query_multiple_rag_resources(self):
192224
with pytest.raises(ValueError) as e:
@@ -195,7 +227,7 @@ def test_retrieval_query_multiple_rag_resources(self):
195227
text=tc.TEST_QUERY_TEXT,
196228
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_CONFIG,
197229
)
198-
e.match("Currently only support 1 RagResource")
230+
e.match("Currently only support 1 RagResource")
199231

200232
def test_retrieval_query_similarity_multiple_rag_resources(self):
201233
with pytest.raises(ValueError) as e:
@@ -204,7 +236,7 @@ def test_retrieval_query_similarity_multiple_rag_resources(self):
204236
text=tc.TEST_QUERY_TEXT,
205237
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_SIMILARITY_CONFIG,
206238
)
207-
e.match("Currently only support 1 RagResource")
239+
e.match("Currently only support 1 RagResource")
208240

209241
def test_retrieval_query_invalid_config_filter(self):
210242
with pytest.raises(ValueError) as e:
@@ -213,8 +245,8 @@ def test_retrieval_query_invalid_config_filter(self):
213245
text=tc.TEST_QUERY_TEXT,
214246
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_ERROR_CONFIG,
215247
)
216-
e.match(
217-
"Only one of vector_distance_threshold or"
218-
" vector_similarity_threshold can be specified at a time"
219-
" in rag_retrieval_config."
220-
)
248+
e.match(
249+
"Only one of vector_distance_threshold or"
250+
" vector_similarity_threshold can be specified at a time"
251+
" in rag_retrieval_config."
252+
)

tests/unit/vertex_rag/test_rag_retrieval_preview.py

Lines changed: 57 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
#
1717
"""Tests for vertex_rag.retrieval_preview."""
1818

19+
# pylint: disable=bad-indentation, unused-variable, unused-argument, redefined-outer-name, C0116
20+
1921
import importlib
2022
from google.cloud import aiplatform
2123
from google.cloud.aiplatform_v1beta1 import VertexRagServiceAsyncClient
@@ -75,7 +77,9 @@ def rag_client_mock_exception():
7577

7678

7779
def retrieve_contexts_eq(response, expected_response):
78-
assert len(response.contexts.contexts) == len(expected_response.contexts.contexts)
80+
assert len(response.contexts.contexts) == len(
81+
expected_response.contexts.contexts
82+
)
7983
assert (
8084
response.contexts.contexts[0].text
8185
== expected_response.contexts.contexts[0].text
@@ -87,7 +91,7 @@ def retrieve_contexts_eq(response, expected_response):
8791

8892

8993
@pytest.mark.usefixtures("google_auth_mock")
90-
class TestRagRetrieval: # pylint: disable=missing-class-docstring
94+
class TestRagRetrieval: # pylint: disable=missing-class-docstring, bad-indentation, unused-variable, unused-argument, redefined-outer-name
9195

9296
def setup_method(self):
9397
importlib.reload(aiplatform.initializer)
@@ -118,6 +122,18 @@ def test_ask_contexts_rag_resources_success(self):
118122
)
119123
retrieve_contexts_eq(response, tc.TEST_RETRIEVAL_RESPONSE)
120124

125+
@pytest.mark.usefixtures("ask_contexts_mock")
126+
def test_ask_contexts_with_timeout(self, ask_contexts_mock):
127+
response = rag.ask_contexts(
128+
rag_resources=[tc.TEST_RAG_RESOURCE],
129+
text=tc.TEST_QUERY_TEXT,
130+
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_CONFIG_ALPHA,
131+
timeout=300,
132+
)
133+
ask_contexts_mock.assert_called_once()
134+
args, kwargs = ask_contexts_mock.call_args
135+
assert kwargs["timeout"] == 300
136+
121137
@pytest.mark.usefixtures("ask_contexts_mock")
122138
def test_ask_contexts_multiple_rag_resources_success(self):
123139
response = rag.ask_contexts(
@@ -138,8 +154,9 @@ def test_ask_contexts_multiple_rag_corpora_success(self):
138154
retrieve_contexts_eq(response, tc.TEST_RETRIEVAL_RESPONSE)
139155

140156
@pytest.mark.asyncio
141-
@pytest.mark.usefixtures("async_retrieve_contexts_mock")
142-
async def test_async_retrieve_contexts_rag_resources_success(self):
157+
async def test_async_retrieve_contexts_rag_resources_success(
158+
self, async_retrieve_contexts_mock
159+
):
143160
response = await rag.async_retrieve_contexts(
144161
rag_resources=[tc.TEST_RAG_RESOURCE],
145162
text=tc.TEST_QUERY_TEXT,
@@ -148,8 +165,23 @@ async def test_async_retrieve_contexts_rag_resources_success(self):
148165
retrieve_contexts_eq(response, tc.TEST_RETRIEVAL_RESPONSE)
149166

150167
@pytest.mark.asyncio
151-
@pytest.mark.usefixtures("async_retrieve_contexts_mock")
152-
async def test_async_retrieve_contexts_multiple_rag_resources_success(self):
168+
async def test_async_retrieve_contexts_with_timeout(
169+
self, async_retrieve_contexts_mock
170+
):
171+
response = await rag.async_retrieve_contexts(
172+
rag_resources=[tc.TEST_RAG_RESOURCE],
173+
text=tc.TEST_QUERY_TEXT,
174+
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_CONFIG_ALPHA,
175+
timeout=300,
176+
)
177+
async_retrieve_contexts_mock.assert_called_once()
178+
args, kwargs = async_retrieve_contexts_mock.call_args
179+
assert kwargs["timeout"] == 300
180+
181+
@pytest.mark.asyncio
182+
async def test_async_retrieve_contexts_multiple_rag_resources_success(
183+
self, async_retrieve_contexts_mock
184+
):
153185
response = await rag.async_retrieve_contexts(
154186
rag_resources=[tc.TEST_RAG_RESOURCE, tc.TEST_RAG_RESOURCE],
155187
text=tc.TEST_QUERY_TEXT,
@@ -158,8 +190,9 @@ async def test_async_retrieve_contexts_multiple_rag_resources_success(self):
158190
retrieve_contexts_eq(response, tc.TEST_RETRIEVAL_RESPONSE)
159191

160192
@pytest.mark.asyncio
161-
@pytest.mark.usefixtures("async_retrieve_contexts_mock")
162-
async def test_async_retrieve_contexts_multiple_rag_corpora_success(self):
193+
async def test_async_retrieve_contexts_multiple_rag_corpora_success(
194+
self, async_retrieve_contexts_mock
195+
):
163196
with pytest.warns(DeprecationWarning):
164197
response = await rag.async_retrieve_contexts(
165198
rag_corpora=[tc.TEST_RAG_CORPUS_ID, tc.TEST_RAG_CORPUS_ID],
@@ -241,7 +274,8 @@ def test_retrieval_query_with_metadata_filter(self, retrieve_contexts_mock):
241274
args, kwargs = retrieve_contexts_mock.call_args
242275
request = kwargs["request"]
243276
assert (
244-
request.query.rag_retrieval_config.filter.metadata_filter == metadata_filter
277+
request.query.rag_retrieval_config.filter.metadata_filter
278+
== metadata_filter
245279
)
246280

247281
@pytest.mark.usefixtures("retrieve_contexts_mock")
@@ -262,7 +296,7 @@ def test_retrieval_query_failure(self):
262296
similarity_top_k=2,
263297
vector_distance_threshold=0.5,
264298
)
265-
e.match("Failed in retrieving contexts due to")
299+
e.match("Failed in retrieving contexts due to")
266300

267301
@pytest.mark.usefixtures("rag_client_mock_exception")
268302
def test_retrieval_query_config_failure(self):
@@ -272,7 +306,7 @@ def test_retrieval_query_config_failure(self):
272306
text=tc.TEST_QUERY_TEXT,
273307
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_CONFIG,
274308
)
275-
e.match("Failed in retrieving contexts due to")
309+
e.match("Failed in retrieving contexts due to")
276310

277311
def test_retrieval_query_invalid_name(self):
278312
with pytest.raises(ValueError) as e:
@@ -282,7 +316,7 @@ def test_retrieval_query_invalid_name(self):
282316
similarity_top_k=2,
283317
vector_distance_threshold=0.5,
284318
)
285-
e.match("Invalid RagCorpus name")
319+
e.match("Invalid RagCorpus name")
286320

287321
def test_retrieval_query_invalid_name_config(self):
288322
with pytest.raises(ValueError) as e:
@@ -291,7 +325,7 @@ def test_retrieval_query_invalid_name_config(self):
291325
text=tc.TEST_QUERY_TEXT,
292326
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_CONFIG,
293327
)
294-
e.match("Invalid RagCorpus name")
328+
e.match("Invalid RagCorpus name")
295329

296330
def test_retrieval_query_multiple_rag_corpora(self):
297331
with pytest.raises(ValueError) as e:
@@ -304,7 +338,7 @@ def test_retrieval_query_multiple_rag_corpora(self):
304338
similarity_top_k=2,
305339
vector_distance_threshold=0.5,
306340
)
307-
e.match("Currently only support 1 RagCorpus")
341+
e.match("Currently only support 1 RagCorpus")
308342

309343
def test_retrieval_query_multiple_rag_corpora_config(self):
310344
with pytest.raises(ValueError) as e:
@@ -316,7 +350,7 @@ def test_retrieval_query_multiple_rag_corpora_config(self):
316350
text=tc.TEST_QUERY_TEXT,
317351
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_CONFIG,
318352
)
319-
e.match("Currently only support 1 RagCorpus")
353+
e.match("Currently only support 1 RagCorpus")
320354

321355
def test_retrieval_query_multiple_rag_resources(self):
322356
with pytest.raises(ValueError) as e:
@@ -329,7 +363,7 @@ def test_retrieval_query_multiple_rag_resources(self):
329363
similarity_top_k=2,
330364
vector_distance_threshold=0.5,
331365
)
332-
e.match("Currently only support 1 RagResource")
366+
e.match("Currently only support 1 RagResource")
333367

334368
def test_retrieval_query_multiple_rag_resources_config(self):
335369
with pytest.raises(ValueError) as e:
@@ -341,7 +375,7 @@ def test_retrieval_query_multiple_rag_resources_config(self):
341375
text=tc.TEST_QUERY_TEXT,
342376
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_CONFIG,
343377
)
344-
e.match("Currently only support 1 RagResource")
378+
e.match("Currently only support 1 RagResource")
345379

346380
def test_retrieval_query_multiple_rag_resources_similarity_config(self):
347381
with pytest.raises(ValueError) as e:
@@ -353,7 +387,7 @@ def test_retrieval_query_multiple_rag_resources_similarity_config(self):
353387
text=tc.TEST_QUERY_TEXT,
354388
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_SIMILARITY_CONFIG,
355389
)
356-
e.match("Currently only support 1 RagResource")
390+
e.match("Currently only support 1 RagResource")
357391

358392
def test_retrieval_query_invalid_config_filter(self):
359393
with pytest.raises(ValueError) as e:
@@ -362,8 +396,8 @@ def test_retrieval_query_invalid_config_filter(self):
362396
text=tc.TEST_QUERY_TEXT,
363397
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_ERROR_CONFIG,
364398
)
365-
e.match(
366-
"Only one of vector_distance_threshold or"
367-
" vector_similarity_threshold can be specified at a time"
368-
" in rag_retrieval_config."
369-
)
399+
e.match(
400+
"Only one of vector_distance_threshold or"
401+
" vector_similarity_threshold can be specified at a time"
402+
" in rag_retrieval_config."
403+
)

0 commit comments

Comments
 (0)