Skip to content

Commit 6024ab0

Browse files
vertex-sdk-bot authored and copybara-github committed
fix: Increase default timeout to 300 seconds for ask_contexts and async_retrieve_contexts in VertexRagServiceClient.
PiperOrigin-RevId: 892413253
1 parent 1a33ad9 commit 6024ab0

4 files changed

Lines changed: 351 additions & 273 deletions

File tree

tests/unit/vertex_rag/test_rag_retrieval.py

Lines changed: 29 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -113,6 +113,19 @@ def test_ask_contexts_rag_resources_success(self):
113113
)
114114
retrieve_contexts_eq(response, tc.TEST_RETRIEVAL_RESPONSE)
115115

116+
@pytest.mark.usefixtures("ask_contexts_mock")
117+
def test_ask_contexts_with_timeout(self, ask_contexts_mock):
118+
response = rag.ask_contexts(
119+
rag_resources=[tc.TEST_RAG_RESOURCE],
120+
text=tc.TEST_QUERY_TEXT,
121+
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_CONFIG,
122+
timeout=300,
123+
)
124+
ask_contexts_mock.assert_called_once()
125+
args, kwargs = ask_contexts_mock.call_args
126+
assert kwargs["timeout"] == 300
127+
128+
116129
@pytest.mark.usefixtures("ask_contexts_mock")
117130
def test_ask_contexts_multiple_rag_resources_success(self):
118131
response = rag.ask_contexts(
@@ -132,6 +145,22 @@ async def test_async_retrieve_contexts_rag_resources_success(self):
132145
)
133146
retrieve_contexts_eq(response, tc.TEST_RETRIEVAL_RESPONSE)
134147

148+
@pytest.mark.asyncio
149+
@pytest.mark.usefixtures("async_retrieve_contexts_mock")
150+
async def test_async_retrieve_contexts_with_timeout(
151+
self, async_retrieve_contexts_mock
152+
):
153+
response = await rag.async_retrieve_contexts(
154+
rag_resources=[tc.TEST_RAG_RESOURCE],
155+
text=tc.TEST_QUERY_TEXT,
156+
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_CONFIG,
157+
timeout=300,
158+
)
159+
async_retrieve_contexts_mock.assert_called_once()
160+
args, kwargs = async_retrieve_contexts_mock.call_args
161+
assert kwargs["timeout"] == 300
162+
163+
135164
@pytest.mark.asyncio
136165
@pytest.mark.usefixtures("async_retrieve_contexts_mock")
137166
async def test_async_retrieve_contexts_multiple_rag_resources_success(self):

tests/unit/vertex_rag/test_rag_retrieval_preview.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,19 @@ def test_ask_contexts_rag_resources_success(self):
118118
)
119119
retrieve_contexts_eq(response, tc.TEST_RETRIEVAL_RESPONSE)
120120

121+
@pytest.mark.usefixtures("ask_contexts_mock")
122+
def test_ask_contexts_with_timeout(self, ask_contexts_mock):
123+
response = rag.ask_contexts(
124+
rag_resources=[tc.TEST_RAG_RESOURCE],
125+
text=tc.TEST_QUERY_TEXT,
126+
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_CONFIG_ALPHA,
127+
timeout=300,
128+
)
129+
ask_contexts_mock.assert_called_once()
130+
args, kwargs = ask_contexts_mock.call_args
131+
assert kwargs["timeout"] == 300
132+
133+
121134
@pytest.mark.usefixtures("ask_contexts_mock")
122135
def test_ask_contexts_multiple_rag_resources_success(self):
123136
response = rag.ask_contexts(
@@ -147,6 +160,22 @@ async def test_async_retrieve_contexts_rag_resources_success(self):
147160
)
148161
retrieve_contexts_eq(response, tc.TEST_RETRIEVAL_RESPONSE)
149162

163+
@pytest.mark.asyncio
164+
@pytest.mark.usefixtures("async_retrieve_contexts_mock")
165+
async def test_async_retrieve_contexts_with_timeout(
166+
self, async_retrieve_contexts_mock
167+
):
168+
response = await rag.async_retrieve_contexts(
169+
rag_resources=[tc.TEST_RAG_RESOURCE],
170+
text=tc.TEST_QUERY_TEXT,
171+
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_CONFIG_ALPHA,
172+
timeout=300,
173+
)
174+
async_retrieve_contexts_mock.assert_called_once()
175+
args, kwargs = async_retrieve_contexts_mock.call_args
176+
assert kwargs["timeout"] == 300
177+
178+
150179
@pytest.mark.asyncio
151180
@pytest.mark.usefixtures("async_retrieve_contexts_mock")
152181
async def test_async_retrieve_contexts_multiple_rag_resources_success(self):

vertexai/preview/rag/rag_retrieval.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,7 @@ async def async_retrieve_contexts(
290290
vector_distance_threshold: Optional[float] = None,
291291
vector_search_alpha: Optional[float] = None,
292292
rag_retrieval_config: Optional[resources.RagRetrievalConfig] = None,
293+
timeout: int = 600,
293294
) -> aiplatform_v1beta1.RetrieveContextsResponse:
294295
"""Retrieve top k relevant docs/chunks asynchronously.
295296
@@ -523,7 +524,9 @@ async def async_retrieve_contexts(
523524
tools=[tool],
524525
)
525526
try:
526-
response_lro = await client.async_retrieve_contexts(request=request)
527+
response_lro = await client.async_retrieve_contexts(
528+
request=request, timeout=timeout
529+
)
527530
response = await response_lro.result()
528531
except Exception as e:
529532
raise RuntimeError(
@@ -541,6 +544,7 @@ def ask_contexts(
541544
vector_distance_threshold: Optional[float] = None,
542545
vector_search_alpha: Optional[float] = None,
543546
rag_retrieval_config: Optional[resources.RagRetrievalConfig] = None,
547+
timeout: int = 600,
544548
) -> aiplatform_v1beta1.AskContextsResponse:
545549
"""Ask questions on top k relevant docs/chunks.
546550
@@ -774,7 +778,7 @@ def ask_contexts(
774778
tools=[tool],
775779
)
776780
try:
777-
response = client.ask_contexts(request=request)
781+
response = client.ask_contexts(request=request, timeout=timeout)
778782
except Exception as e:
779783
raise RuntimeError("Failed in asking contexts due to: ", e) from e
780784

0 commit comments

Comments
 (0)