Skip to content

Commit 98efde0

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
fix: Increase default timeout to 300 seconds for ask_contexts and async_retrieve_contexts in VertexRagServiceClient.
PiperOrigin-RevId: 892413253
1 parent c12aedc commit 98efde0

File tree

4 files changed

+154
-77
lines changed

4 files changed

+154
-77
lines changed

tests/unit/vertex_rag/test_rag_retrieval.py

Lines changed: 44 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
#
1717
"""Tests for vertex_rag.retrieval."""
1818

19+
# pylint: disable=bad-indentation, unused-variable, unused-argument, redefined-outer-name, C0116
20+
1921
import importlib
2022
from google.cloud import aiplatform
2123
from google.cloud.aiplatform_v1 import VertexRagServiceAsyncClient
@@ -85,7 +87,7 @@ def retrieve_contexts_eq(response, expected_response):
8587

8688

8789
@pytest.mark.usefixtures("google_auth_mock")
88-
class TestRagRetrieval: # pylint: disable=missing-class-docstring
90+
class TestRagRetrieval: # pylint: disable=missing-class-docstring, bad-indentation, unused-variable, unused-argument, redefined-outer-name
8991

9092
def setup_method(self):
9193
importlib.reload(aiplatform.initializer)
@@ -113,6 +115,18 @@ def test_ask_contexts_rag_resources_success(self):
113115
)
114116
retrieve_contexts_eq(response, tc.TEST_RETRIEVAL_RESPONSE)
115117

118+
@pytest.mark.usefixtures("ask_contexts_mock")
119+
def test_ask_contexts_with_timeout(self, ask_contexts_mock):
120+
rag.ask_contexts(
121+
rag_resources=[tc.TEST_RAG_RESOURCE],
122+
text=tc.TEST_QUERY_TEXT,
123+
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_CONFIG,
124+
timeout=300,
125+
)
126+
ask_contexts_mock.assert_called_once()
127+
_, kwargs = ask_contexts_mock.call_args
128+
assert kwargs["timeout"] == 300
129+
116130
@pytest.mark.usefixtures("ask_contexts_mock")
117131
def test_ask_contexts_multiple_rag_resources_success(self):
118132
response = rag.ask_contexts(
@@ -123,8 +137,9 @@ def test_ask_contexts_multiple_rag_resources_success(self):
123137
retrieve_contexts_eq(response, tc.TEST_RETRIEVAL_RESPONSE)
124138

125139
@pytest.mark.asyncio
126-
@pytest.mark.usefixtures("async_retrieve_contexts_mock")
127-
async def test_async_retrieve_contexts_rag_resources_success(self):
140+
async def test_async_retrieve_contexts_rag_resources_success(
141+
self, async_retrieve_contexts_mock
142+
):
128143
response = await rag.async_retrieve_contexts(
129144
rag_resources=[tc.TEST_RAG_RESOURCE],
130145
text=tc.TEST_QUERY_TEXT,
@@ -133,8 +148,23 @@ async def test_async_retrieve_contexts_rag_resources_success(self):
133148
retrieve_contexts_eq(response, tc.TEST_RETRIEVAL_RESPONSE)
134149

135150
@pytest.mark.asyncio
136-
@pytest.mark.usefixtures("async_retrieve_contexts_mock")
137-
async def test_async_retrieve_contexts_multiple_rag_resources_success(self):
151+
async def test_async_retrieve_contexts_with_timeout(
152+
self, async_retrieve_contexts_mock
153+
):
154+
await rag.async_retrieve_contexts(
155+
rag_resources=[tc.TEST_RAG_RESOURCE],
156+
text=tc.TEST_QUERY_TEXT,
157+
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_CONFIG,
158+
timeout=300,
159+
)
160+
async_retrieve_contexts_mock.assert_called_once()
161+
_, kwargs = async_retrieve_contexts_mock.call_args
162+
assert kwargs["timeout"] == 300
163+
164+
@pytest.mark.asyncio
165+
async def test_async_retrieve_contexts_multiple_rag_resources_success(
166+
self, async_retrieve_contexts_mock
167+
):
138168
response = await rag.async_retrieve_contexts(
139169
rag_resources=[tc.TEST_RAG_RESOURCE, tc.TEST_RAG_RESOURCE],
140170
text=tc.TEST_QUERY_TEXT,
@@ -177,7 +207,7 @@ def test_retrieval_query_failure(self):
177207
text=tc.TEST_QUERY_TEXT,
178208
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_CONFIG,
179209
)
180-
e.match("Failed in retrieving contexts due to")
210+
e.match("Failed in retrieving contexts due to")
181211

182212
def test_retrieval_query_invalid_name(self):
183213
with pytest.raises(ValueError) as e:
@@ -186,7 +216,7 @@ def test_retrieval_query_invalid_name(self):
186216
text=tc.TEST_QUERY_TEXT,
187217
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_CONFIG,
188218
)
189-
e.match("Invalid RagCorpus name")
219+
e.match("Invalid RagCorpus name")
190220

191221
def test_retrieval_query_multiple_rag_resources(self):
192222
with pytest.raises(ValueError) as e:
@@ -195,7 +225,7 @@ def test_retrieval_query_multiple_rag_resources(self):
195225
text=tc.TEST_QUERY_TEXT,
196226
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_CONFIG,
197227
)
198-
e.match("Currently only support 1 RagResource")
228+
e.match("Currently only support 1 RagResource")
199229

200230
def test_retrieval_query_similarity_multiple_rag_resources(self):
201231
with pytest.raises(ValueError) as e:
@@ -204,7 +234,7 @@ def test_retrieval_query_similarity_multiple_rag_resources(self):
204234
text=tc.TEST_QUERY_TEXT,
205235
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_SIMILARITY_CONFIG,
206236
)
207-
e.match("Currently only support 1 RagResource")
237+
e.match("Currently only support 1 RagResource")
208238

209239
def test_retrieval_query_invalid_config_filter(self):
210240
with pytest.raises(ValueError) as e:
@@ -213,8 +243,8 @@ def test_retrieval_query_invalid_config_filter(self):
213243
text=tc.TEST_QUERY_TEXT,
214244
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_ERROR_CONFIG,
215245
)
216-
e.match(
217-
"Only one of vector_distance_threshold or"
218-
" vector_similarity_threshold can be specified at a time"
219-
" in rag_retrieval_config."
220-
)
246+
e.match(
247+
"Only one of vector_distance_threshold or"
248+
" vector_similarity_threshold can be specified at a time"
249+
" in rag_retrieval_config."
250+
)

tests/unit/vertex_rag/test_rag_retrieval_preview.py

Lines changed: 52 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
#
1717
"""Tests for vertex_rag.retrieval_preview."""
1818

19+
# pylint: disable=bad-indentation, unused-variable, unused-argument, redefined-outer-name, C0116
20+
1921
import importlib
2022
from google.cloud import aiplatform
2123
from google.cloud.aiplatform_v1beta1 import VertexRagServiceAsyncClient
@@ -87,7 +89,7 @@ def retrieve_contexts_eq(response, expected_response):
8789

8890

8991
@pytest.mark.usefixtures("google_auth_mock")
90-
class TestRagRetrieval: # pylint: disable=missing-class-docstring
92+
class TestRagRetrieval: # pylint: disable=missing-class-docstring, bad-indentation, unused-variable, unused-argument, redefined-outer-name
9193

9294
def setup_method(self):
9395
importlib.reload(aiplatform.initializer)
@@ -118,6 +120,18 @@ def test_ask_contexts_rag_resources_success(self):
118120
)
119121
retrieve_contexts_eq(response, tc.TEST_RETRIEVAL_RESPONSE)
120122

123+
@pytest.mark.usefixtures("ask_contexts_mock")
124+
def test_ask_contexts_with_timeout(self, ask_contexts_mock):
125+
response = rag.ask_contexts(
126+
rag_resources=[tc.TEST_RAG_RESOURCE],
127+
text=tc.TEST_QUERY_TEXT,
128+
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_CONFIG_ALPHA,
129+
timeout=300,
130+
)
131+
ask_contexts_mock.assert_called_once()
132+
args, kwargs = ask_contexts_mock.call_args
133+
assert kwargs["timeout"] == 300
134+
121135
@pytest.mark.usefixtures("ask_contexts_mock")
122136
def test_ask_contexts_multiple_rag_resources_success(self):
123137
response = rag.ask_contexts(
@@ -138,8 +152,9 @@ def test_ask_contexts_multiple_rag_corpora_success(self):
138152
retrieve_contexts_eq(response, tc.TEST_RETRIEVAL_RESPONSE)
139153

140154
@pytest.mark.asyncio
141-
@pytest.mark.usefixtures("async_retrieve_contexts_mock")
142-
async def test_async_retrieve_contexts_rag_resources_success(self):
155+
async def test_async_retrieve_contexts_rag_resources_success(
156+
self, async_retrieve_contexts_mock
157+
):
143158
response = await rag.async_retrieve_contexts(
144159
rag_resources=[tc.TEST_RAG_RESOURCE],
145160
text=tc.TEST_QUERY_TEXT,
@@ -148,8 +163,23 @@ async def test_async_retrieve_contexts_rag_resources_success(self):
148163
retrieve_contexts_eq(response, tc.TEST_RETRIEVAL_RESPONSE)
149164

150165
@pytest.mark.asyncio
151-
@pytest.mark.usefixtures("async_retrieve_contexts_mock")
152-
async def test_async_retrieve_contexts_multiple_rag_resources_success(self):
166+
async def test_async_retrieve_contexts_with_timeout(
167+
self, async_retrieve_contexts_mock
168+
):
169+
response = await rag.async_retrieve_contexts(
170+
rag_resources=[tc.TEST_RAG_RESOURCE],
171+
text=tc.TEST_QUERY_TEXT,
172+
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_CONFIG_ALPHA,
173+
timeout=300,
174+
)
175+
async_retrieve_contexts_mock.assert_called_once()
176+
args, kwargs = async_retrieve_contexts_mock.call_args
177+
assert kwargs["timeout"] == 300
178+
179+
@pytest.mark.asyncio
180+
async def test_async_retrieve_contexts_multiple_rag_resources_success(
181+
self, async_retrieve_contexts_mock
182+
):
153183
response = await rag.async_retrieve_contexts(
154184
rag_resources=[tc.TEST_RAG_RESOURCE, tc.TEST_RAG_RESOURCE],
155185
text=tc.TEST_QUERY_TEXT,
@@ -158,8 +188,9 @@ async def test_async_retrieve_contexts_multiple_rag_resources_success(self):
158188
retrieve_contexts_eq(response, tc.TEST_RETRIEVAL_RESPONSE)
159189

160190
@pytest.mark.asyncio
161-
@pytest.mark.usefixtures("async_retrieve_contexts_mock")
162-
async def test_async_retrieve_contexts_multiple_rag_corpora_success(self):
191+
async def test_async_retrieve_contexts_multiple_rag_corpora_success(
192+
self, async_retrieve_contexts_mock
193+
):
163194
with pytest.warns(DeprecationWarning):
164195
response = await rag.async_retrieve_contexts(
165196
rag_corpora=[tc.TEST_RAG_CORPUS_ID, tc.TEST_RAG_CORPUS_ID],
@@ -262,7 +293,7 @@ def test_retrieval_query_failure(self):
262293
similarity_top_k=2,
263294
vector_distance_threshold=0.5,
264295
)
265-
e.match("Failed in retrieving contexts due to")
296+
e.match("Failed in retrieving contexts due to")
266297

267298
@pytest.mark.usefixtures("rag_client_mock_exception")
268299
def test_retrieval_query_config_failure(self):
@@ -272,7 +303,7 @@ def test_retrieval_query_config_failure(self):
272303
text=tc.TEST_QUERY_TEXT,
273304
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_CONFIG,
274305
)
275-
e.match("Failed in retrieving contexts due to")
306+
e.match("Failed in retrieving contexts due to")
276307

277308
def test_retrieval_query_invalid_name(self):
278309
with pytest.raises(ValueError) as e:
@@ -282,7 +313,7 @@ def test_retrieval_query_invalid_name(self):
282313
similarity_top_k=2,
283314
vector_distance_threshold=0.5,
284315
)
285-
e.match("Invalid RagCorpus name")
316+
e.match("Invalid RagCorpus name")
286317

287318
def test_retrieval_query_invalid_name_config(self):
288319
with pytest.raises(ValueError) as e:
@@ -291,7 +322,7 @@ def test_retrieval_query_invalid_name_config(self):
291322
text=tc.TEST_QUERY_TEXT,
292323
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_CONFIG,
293324
)
294-
e.match("Invalid RagCorpus name")
325+
e.match("Invalid RagCorpus name")
295326

296327
def test_retrieval_query_multiple_rag_corpora(self):
297328
with pytest.raises(ValueError) as e:
@@ -304,7 +335,7 @@ def test_retrieval_query_multiple_rag_corpora(self):
304335
similarity_top_k=2,
305336
vector_distance_threshold=0.5,
306337
)
307-
e.match("Currently only support 1 RagCorpus")
338+
e.match("Currently only support 1 RagCorpus")
308339

309340
def test_retrieval_query_multiple_rag_corpora_config(self):
310341
with pytest.raises(ValueError) as e:
@@ -316,7 +347,7 @@ def test_retrieval_query_multiple_rag_corpora_config(self):
316347
text=tc.TEST_QUERY_TEXT,
317348
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_CONFIG,
318349
)
319-
e.match("Currently only support 1 RagCorpus")
350+
e.match("Currently only support 1 RagCorpus")
320351

321352
def test_retrieval_query_multiple_rag_resources(self):
322353
with pytest.raises(ValueError) as e:
@@ -329,7 +360,7 @@ def test_retrieval_query_multiple_rag_resources(self):
329360
similarity_top_k=2,
330361
vector_distance_threshold=0.5,
331362
)
332-
e.match("Currently only support 1 RagResource")
363+
e.match("Currently only support 1 RagResource")
333364

334365
def test_retrieval_query_multiple_rag_resources_config(self):
335366
with pytest.raises(ValueError) as e:
@@ -341,7 +372,7 @@ def test_retrieval_query_multiple_rag_resources_config(self):
341372
text=tc.TEST_QUERY_TEXT,
342373
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_CONFIG,
343374
)
344-
e.match("Currently only support 1 RagResource")
375+
e.match("Currently only support 1 RagResource")
345376

346377
def test_retrieval_query_multiple_rag_resources_similarity_config(self):
347378
with pytest.raises(ValueError) as e:
@@ -353,7 +384,7 @@ def test_retrieval_query_multiple_rag_resources_similarity_config(self):
353384
text=tc.TEST_QUERY_TEXT,
354385
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_SIMILARITY_CONFIG,
355386
)
356-
e.match("Currently only support 1 RagResource")
387+
e.match("Currently only support 1 RagResource")
357388

358389
def test_retrieval_query_invalid_config_filter(self):
359390
with pytest.raises(ValueError) as e:
@@ -362,8 +393,8 @@ def test_retrieval_query_invalid_config_filter(self):
362393
text=tc.TEST_QUERY_TEXT,
363394
rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_ERROR_CONFIG,
364395
)
365-
e.match(
366-
"Only one of vector_distance_threshold or"
367-
" vector_similarity_threshold can be specified at a time"
368-
" in rag_retrieval_config."
369-
)
396+
e.match(
397+
"Only one of vector_distance_threshold or"
398+
" vector_similarity_threshold can be specified at a time"
399+
" in rag_retrieval_config."
400+
)

0 commit comments

Comments
 (0)