1616#
1717"""Tests for vertex_rag.retrieval_preview."""
1818
19+ # pylint: disable=bad-indentation, unused-variable, unused-argument, redefined-outer-name, C0116
20+
1921import importlib
2022from google .cloud import aiplatform
2123from google .cloud .aiplatform_v1beta1 import VertexRagServiceAsyncClient
@@ -87,7 +89,7 @@ def retrieve_contexts_eq(response, expected_response):
8789
8890
8991@pytest .mark .usefixtures ("google_auth_mock" )
90- class TestRagRetrieval : # pylint: disable=missing-class-docstring
92+ class TestRagRetrieval : # pylint: disable=missing-class-docstring, bad-indentation, unused-variable, unused-argument, redefined-outer-name
9193
9294 def setup_method (self ):
9395 importlib .reload (aiplatform .initializer )
@@ -118,6 +120,18 @@ def test_ask_contexts_rag_resources_success(self):
118120 )
119121 retrieve_contexts_eq (response , tc .TEST_RETRIEVAL_RESPONSE )
120122
123+ @pytest .mark .usefixtures ("ask_contexts_mock" )
124+ def test_ask_contexts_with_timeout (self , ask_contexts_mock ):
125+ response = rag .ask_contexts (
126+ rag_resources = [tc .TEST_RAG_RESOURCE ],
127+ text = tc .TEST_QUERY_TEXT ,
128+ rag_retrieval_config = tc .TEST_RAG_RETRIEVAL_CONFIG_ALPHA ,
129+ timeout = 300 ,
130+ )
131+ ask_contexts_mock .assert_called_once ()
132+ args , kwargs = ask_contexts_mock .call_args
133+ assert kwargs ["timeout" ] == 300
134+
121135 @pytest .mark .usefixtures ("ask_contexts_mock" )
122136 def test_ask_contexts_multiple_rag_resources_success (self ):
123137 response = rag .ask_contexts (
@@ -138,8 +152,9 @@ def test_ask_contexts_multiple_rag_corpora_success(self):
138152 retrieve_contexts_eq (response , tc .TEST_RETRIEVAL_RESPONSE )
139153
140154 @pytest .mark .asyncio
141- @pytest .mark .usefixtures ("async_retrieve_contexts_mock" )
142- async def test_async_retrieve_contexts_rag_resources_success (self ):
155+ async def test_async_retrieve_contexts_rag_resources_success (
156+ self , async_retrieve_contexts_mock
157+ ):
143158 response = await rag .async_retrieve_contexts (
144159 rag_resources = [tc .TEST_RAG_RESOURCE ],
145160 text = tc .TEST_QUERY_TEXT ,
@@ -148,8 +163,23 @@ async def test_async_retrieve_contexts_rag_resources_success(self):
148163 retrieve_contexts_eq (response , tc .TEST_RETRIEVAL_RESPONSE )
149164
150165 @pytest .mark .asyncio
151- @pytest .mark .usefixtures ("async_retrieve_contexts_mock" )
152- async def test_async_retrieve_contexts_multiple_rag_resources_success (self ):
166+ async def test_async_retrieve_contexts_with_timeout (
167+ self , async_retrieve_contexts_mock
168+ ):
169+ response = await rag .async_retrieve_contexts (
170+ rag_resources = [tc .TEST_RAG_RESOURCE ],
171+ text = tc .TEST_QUERY_TEXT ,
172+ rag_retrieval_config = tc .TEST_RAG_RETRIEVAL_CONFIG_ALPHA ,
173+ timeout = 300 ,
174+ )
175+ async_retrieve_contexts_mock .assert_called_once ()
176+ args , kwargs = async_retrieve_contexts_mock .call_args
177+ assert kwargs ["timeout" ] == 300
178+
179+ @pytest .mark .asyncio
180+ async def test_async_retrieve_contexts_multiple_rag_resources_success (
181+ self , async_retrieve_contexts_mock
182+ ):
153183 response = await rag .async_retrieve_contexts (
154184 rag_resources = [tc .TEST_RAG_RESOURCE , tc .TEST_RAG_RESOURCE ],
155185 text = tc .TEST_QUERY_TEXT ,
@@ -158,8 +188,9 @@ async def test_async_retrieve_contexts_multiple_rag_resources_success(self):
158188 retrieve_contexts_eq (response , tc .TEST_RETRIEVAL_RESPONSE )
159189
160190 @pytest .mark .asyncio
161- @pytest .mark .usefixtures ("async_retrieve_contexts_mock" )
162- async def test_async_retrieve_contexts_multiple_rag_corpora_success (self ):
191+ async def test_async_retrieve_contexts_multiple_rag_corpora_success (
192+ self , async_retrieve_contexts_mock
193+ ):
163194 with pytest .warns (DeprecationWarning ):
164195 response = await rag .async_retrieve_contexts (
165196 rag_corpora = [tc .TEST_RAG_CORPUS_ID , tc .TEST_RAG_CORPUS_ID ],
@@ -262,7 +293,7 @@ def test_retrieval_query_failure(self):
262293 similarity_top_k = 2 ,
263294 vector_distance_threshold = 0.5 ,
264295 )
265- e .match ("Failed in retrieving contexts due to" )
296+ e .match ("Failed in retrieving contexts due to" )
266297
267298 @pytest .mark .usefixtures ("rag_client_mock_exception" )
268299 def test_retrieval_query_config_failure (self ):
@@ -272,7 +303,7 @@ def test_retrieval_query_config_failure(self):
272303 text = tc .TEST_QUERY_TEXT ,
273304 rag_retrieval_config = tc .TEST_RAG_RETRIEVAL_CONFIG ,
274305 )
275- e .match ("Failed in retrieving contexts due to" )
306+ e .match ("Failed in retrieving contexts due to" )
276307
277308 def test_retrieval_query_invalid_name (self ):
278309 with pytest .raises (ValueError ) as e :
@@ -282,7 +313,7 @@ def test_retrieval_query_invalid_name(self):
282313 similarity_top_k = 2 ,
283314 vector_distance_threshold = 0.5 ,
284315 )
285- e .match ("Invalid RagCorpus name" )
316+ e .match ("Invalid RagCorpus name" )
286317
287318 def test_retrieval_query_invalid_name_config (self ):
288319 with pytest .raises (ValueError ) as e :
@@ -291,7 +322,7 @@ def test_retrieval_query_invalid_name_config(self):
291322 text = tc .TEST_QUERY_TEXT ,
292323 rag_retrieval_config = tc .TEST_RAG_RETRIEVAL_CONFIG ,
293324 )
294- e .match ("Invalid RagCorpus name" )
325+ e .match ("Invalid RagCorpus name" )
295326
296327 def test_retrieval_query_multiple_rag_corpora (self ):
297328 with pytest .raises (ValueError ) as e :
@@ -304,7 +335,7 @@ def test_retrieval_query_multiple_rag_corpora(self):
304335 similarity_top_k = 2 ,
305336 vector_distance_threshold = 0.5 ,
306337 )
307- e .match ("Currently only support 1 RagCorpus" )
338+ e .match ("Currently only support 1 RagCorpus" )
308339
309340 def test_retrieval_query_multiple_rag_corpora_config (self ):
310341 with pytest .raises (ValueError ) as e :
@@ -316,7 +347,7 @@ def test_retrieval_query_multiple_rag_corpora_config(self):
316347 text = tc .TEST_QUERY_TEXT ,
317348 rag_retrieval_config = tc .TEST_RAG_RETRIEVAL_CONFIG ,
318349 )
319- e .match ("Currently only support 1 RagCorpus" )
350+ e .match ("Currently only support 1 RagCorpus" )
320351
321352 def test_retrieval_query_multiple_rag_resources (self ):
322353 with pytest .raises (ValueError ) as e :
@@ -329,7 +360,7 @@ def test_retrieval_query_multiple_rag_resources(self):
329360 similarity_top_k = 2 ,
330361 vector_distance_threshold = 0.5 ,
331362 )
332- e .match ("Currently only support 1 RagResource" )
363+ e .match ("Currently only support 1 RagResource" )
333364
334365 def test_retrieval_query_multiple_rag_resources_config (self ):
335366 with pytest .raises (ValueError ) as e :
@@ -341,7 +372,7 @@ def test_retrieval_query_multiple_rag_resources_config(self):
341372 text = tc .TEST_QUERY_TEXT ,
342373 rag_retrieval_config = tc .TEST_RAG_RETRIEVAL_CONFIG ,
343374 )
344- e .match ("Currently only support 1 RagResource" )
375+ e .match ("Currently only support 1 RagResource" )
345376
346377 def test_retrieval_query_multiple_rag_resources_similarity_config (self ):
347378 with pytest .raises (ValueError ) as e :
@@ -353,7 +384,7 @@ def test_retrieval_query_multiple_rag_resources_similarity_config(self):
353384 text = tc .TEST_QUERY_TEXT ,
354385 rag_retrieval_config = tc .TEST_RAG_RETRIEVAL_SIMILARITY_CONFIG ,
355386 )
356- e .match ("Currently only support 1 RagResource" )
387+ e .match ("Currently only support 1 RagResource" )
357388
358389 def test_retrieval_query_invalid_config_filter (self ):
359390 with pytest .raises (ValueError ) as e :
@@ -362,8 +393,8 @@ def test_retrieval_query_invalid_config_filter(self):
362393 text = tc .TEST_QUERY_TEXT ,
363394 rag_retrieval_config = tc .TEST_RAG_RETRIEVAL_ERROR_CONFIG ,
364395 )
365- e .match (
366- "Only one of vector_distance_threshold or"
367- " vector_similarity_threshold can be specified at a time"
368- " in rag_retrieval_config."
369- )
396+ e .match (
397+ "Only one of vector_distance_threshold or"
398+ " vector_similarity_threshold can be specified at a time"
399+ " in rag_retrieval_config."
400+ )
0 commit comments