From b66b9041581cb9b471ec6ec6a71ffa8dbc015911 Mon Sep 17 00:00:00 2001 From: A Vertex SDK engineer Date: Wed, 8 Apr 2026 21:40:25 -0700 Subject: [PATCH] fix: Add SDK workaround for double-wrapped Any response in async_retrieve_contexts. PiperOrigin-RevId: 896867061 --- vertexai/preview/rag/rag_retrieval.py | 1489 +++++++++++++------------ 1 file changed, 782 insertions(+), 707 deletions(-) diff --git a/vertexai/preview/rag/rag_retrieval.py b/vertexai/preview/rag/rag_retrieval.py index d3d879509a..fa0558e64d 100644 --- a/vertexai/preview/rag/rag_retrieval.py +++ b/vertexai/preview/rag/rag_retrieval.py @@ -25,8 +25,6 @@ from vertexai.preview.rag.utils import _gapic_utils from vertexai.preview.rag.utils import resources -from google.protobuf import any_pb2 - def retrieval_query( text: str, @@ -37,251 +35,263 @@ def retrieval_query( vector_search_alpha: Optional[float] = None, rag_retrieval_config: Optional[resources.RagRetrievalConfig] = None, ) -> aiplatform_v1beta1.RetrieveContextsResponse: - """Retrieve top k relevant docs/chunks. - - Example usage: - ``` - import vertexai - - vertexai.init(project="my-project") - - # Using deprecated parameters - results = vertexai.preview.rag.retrieval_query( - text="Why is the sky blue?", - rag_resources=[vertexai.preview.rag.RagResource( - rag_corpus="projects/my-project/locations/us-central1/ragCorpora/rag-corpus-1", - rag_file_ids=["rag-file-1", "rag-file-2", ...], - )], - similarity_top_k=2, - vector_distance_threshold=0.5, - vector_search_alpha=0.5, + """Retrieve top k relevant docs/chunks. + + Example usage: + ``` + import vertexai + + vertexai.init(project="my-project") + + # Using deprecated parameters + results = vertexai.preview.rag.retrieval_query( + text="Why is the sky blue?", + rag_resources=[vertexai.preview.rag.RagResource( + rag_corpus="projects/my-project/locations/us-central1/ragCorpora/rag-corpus-1", + rag_file_ids=["rag-file-1", "rag-file-2", ...], + )], + similarity_top_k=2, + vector_distance_threshold=0.5, + vector_search_alpha=0.5, + ) + + # Using RagRetrievalConfig. Equivalent to the above example. + config = vertexai.preview.rag.RagRetrievalConfig( + top_k=2, + filter=vertexai.preview.rag.Filter( + vector_distance_threshold=0.5 + ), + hybrid_search=vertexai.preview.rag.rag_retrieval_config.hybrid_search( + alpha=0.5 + ), + ranking=vertex.preview.rag.Ranking( + llm_ranker=vertexai.preview.rag.LlmRanker( + model_name="gemini-1.5-flash-002" + ) + ) + ) + + results = vertexai.preview.rag.retrieval_query( + text="Why is the sky blue?", + rag_resources=[vertexai.preview.rag.RagResource( + rag_corpus="projects/my-project/locations/us-central1/ragCorpora/rag-corpus-1", + rag_file_ids=["rag-file-1", "rag-file-2", ...], + )], + rag_retrieval_config=config, + ) + ``` + + Args: + text: The query in text format to get relevant contexts. + rag_resources: A list of RagResource. It can be used to specify corpus + only or ragfiles. Currently only support one corpus or multiple files + from one corpus. In the future we may open up multiple corpora support. + rag_corpora: If rag_resources is not specified, use rag_corpora as a list + of rag corpora names. Deprecated. Use rag_resources instead. + similarity_top_k: The number of contexts to retrieve. Deprecated. Use + rag_retrieval_config.top_k instead. + vector_distance_threshold: Optional. Only return contexts with vector + distance smaller than the threshold. Deprecated. Use + rag_retrieval_config.filter.vector_distance_threshold instead. + vector_search_alpha: Optional. Controls the weight between dense and + sparse vector search results. The range is [0, 1], where 0 means sparse + vector search only and 1 means dense vector search only. The default + value is 0.5. Deprecated. Use rag_retrieval_config.hybrid_search.alpha + instead. + rag_retrieval_config: Optional. The config containing the retrieval + parameters, including top_k, vector_distance_threshold, and alpha. + + Returns: + RetrieveContextsResonse. + """ + parent = initializer.global_config.common_location_path() + + client = _gapic_utils.create_rag_service_client() + + if rag_resources: + if len(rag_resources) > 1: + raise ValueError("Currently only support 1 RagResource.") + name = rag_resources[0].rag_corpus + elif rag_corpora: + if len(rag_corpora) > 1: + raise ValueError("Currently only support 1 RagCorpus.") + name = rag_corpora[0] + warnings.warn( + "rag_corpora is deprecated. Please use rag_resources instead." + f" After {resources.DEPRECATION_DATE} using" + " rag_corpora will raise error", + DeprecationWarning, ) - - # Using RagRetrievalConfig. Equivalent to the above example. - config = vertexai.preview.rag.RagRetrievalConfig( - top_k=2, - filter=vertexai.preview.rag.Filter( - vector_distance_threshold=0.5 - ), - hybrid_search=vertexai.preview.rag.rag_retrieval_config.hybrid_search( - alpha=0.5 - ), - ranking=vertex.preview.rag.Ranking( - llm_ranker=vertexai.preview.rag.LlmRanker( - model_name="gemini-1.5-flash-002" - ) - ) + else: + raise ValueError("rag_resources or rag_corpora must be specified.") + + data_client = _gapic_utils.create_rag_data_service_client() + if data_client.parse_rag_corpus_path(name): + rag_corpus_name = name + elif re.match( + "^{}$".format( + _gapic_utils._VALID_RESOURCE_NAME_REGEX # pylint: disable=protected-access + ), + name, + ): + rag_corpus_name = parent + "/ragCorpora/" + name + else: + raise ValueError( + f"Invalid RagCorpus name: {rag_corpora}. Proper format should be:" + " projects/{project}/locations/{location}/ragCorpora/{rag_corpus_id}" ) - results = vertexai.preview.rag.retrieval_query( - text="Why is the sky blue?", - rag_resources=[vertexai.preview.rag.RagResource( - rag_corpus="projects/my-project/locations/us-central1/ragCorpora/rag-corpus-1", - rag_file_ids=["rag-file-1", "rag-file-2", ...], - )], - rag_retrieval_config=config, - ) - ``` - - Args: - text: The query in text format to get relevant contexts. - rag_resources: A list of RagResource. It can be used to specify corpus - only or ragfiles. Currently only support one corpus or multiple files - from one corpus. In the future we may open up multiple corpora support. - rag_corpora: If rag_resources is not specified, use rag_corpora as a list - of rag corpora names. Deprecated. Use rag_resources instead. - similarity_top_k: The number of contexts to retrieve. Deprecated. Use - rag_retrieval_config.top_k instead. - vector_distance_threshold: Optional. Only return contexts with vector - distance smaller than the threshold. Deprecated. Use - rag_retrieval_config.filter.vector_distance_threshold instead. - vector_search_alpha: Optional. Controls the weight between dense and - sparse vector search results. The range is [0, 1], where 0 means sparse - vector search only and 1 means dense vector search only. The default - value is 0.5. Deprecated. Use rag_retrieval_config.hybrid_search.alpha - instead. - rag_retrieval_config: Optional. The config containing the retrieval - parameters, including top_k, vector_distance_threshold, and alpha. - - Returns: - RetrieveContextsResonse. - """ - parent = initializer.global_config.common_location_path() - - client = _gapic_utils.create_rag_service_client() - - if rag_resources: - if len(rag_resources) > 1: - raise ValueError("Currently only support 1 RagResource.") - name = rag_resources[0].rag_corpus - elif rag_corpora: - if len(rag_corpora) > 1: - raise ValueError("Currently only support 1 RagCorpus.") - name = rag_corpora[0] - warnings.warn( - "rag_corpora is deprecated. Please use rag_resources instead." - f" After {resources.DEPRECATION_DATE} using" - " rag_corpora will raise error", - DeprecationWarning, + if rag_resources: + gapic_rag_resource = ( + aiplatform_v1beta1.RetrieveContextsRequest.VertexRagStore.RagResource( + rag_corpus=rag_corpus_name, + rag_file_ids=rag_resources[0].rag_file_ids, ) - else: - raise ValueError("rag_resources or rag_corpora must be specified.") - - data_client = _gapic_utils.create_rag_data_service_client() - if data_client.parse_rag_corpus_path(name): - rag_corpus_name = name - elif re.match( - "^{}$".format( - _gapic_utils._VALID_RESOURCE_NAME_REGEX # pylint: disable=protected-access - ), - name, - ): - rag_corpus_name = parent + "/ragCorpora/" + name - else: - raise ValueError( - f"Invalid RagCorpus name: {rag_corpora}. Proper format should be:" - " projects/{project}/locations/{location}/ragCorpora/{rag_corpus_id}" - ) - - if rag_resources: - gapic_rag_resource = ( - aiplatform_v1beta1.RetrieveContextsRequest.VertexRagStore.RagResource( - rag_corpus=rag_corpus_name, - rag_file_ids=rag_resources[0].rag_file_ids, - ) - ) - vertex_rag_store = aiplatform_v1beta1.RetrieveContextsRequest.VertexRagStore( + ) + vertex_rag_store = ( + aiplatform_v1beta1.RetrieveContextsRequest.VertexRagStore( rag_resources=[gapic_rag_resource], ) - else: - vertex_rag_store = aiplatform_v1beta1.RetrieveContextsRequest.VertexRagStore( + ) + else: + vertex_rag_store = ( + aiplatform_v1beta1.RetrieveContextsRequest.VertexRagStore( rag_corpora=[rag_corpus_name], ) + ) - # Check for deprecated parameters and raise warnings. - if similarity_top_k: - # If similarity_top_k is specified, throw deprecation warning. - warnings.warn( - "similarity_top_k is deprecated. Please use" - " rag_retrieval_config.top_k instead." - f" After {resources.DEPRECATION_DATE} using" - " similarity_top_k will raise error", - DeprecationWarning, - ) - if vector_search_alpha: - # If vector_search_alpha is specified, throw deprecation warning. - warnings.warn( - "vector_search_alpha is deprecated. Please use" - " rag_retrieval_config.alpha instead." - f" After {resources.DEPRECATION_DATE} using" - " vector_search_alpha will raise error", - DeprecationWarning, - ) - if vector_distance_threshold: - # If vector_distance_threshold is specified, throw deprecation warning. - warnings.warn( - "vector_distance_threshold is deprecated. Please use" - " rag_retrieval_config.filter.vector_distance_threshold instead." - f" After {resources.DEPRECATION_DATE} using" - " vector_distance_threshold will raise error", - DeprecationWarning, - ) - - # If rag_retrieval_config is not specified, set it to default values. - if not rag_retrieval_config: - api_retrival_config = aiplatform_v1beta1.RagRetrievalConfig( - top_k=similarity_top_k, - hybrid_search=aiplatform_v1beta1.RagRetrievalConfig.HybridSearch( - alpha=vector_search_alpha, - ), - filter=aiplatform_v1beta1.RagRetrievalConfig.Filter( - vector_distance_threshold=vector_distance_threshold - ), - ) - else: - # If rag_retrieval_config is specified, check for missing parameters. - api_retrival_config = aiplatform_v1beta1.RagRetrievalConfig() - # Set top_k to config value if specified - if rag_retrieval_config.top_k: - api_retrival_config.top_k = rag_retrieval_config.top_k - else: - api_retrival_config.top_k = similarity_top_k - # Set alpha to config value if specified - if ( - rag_retrieval_config.hybrid_search - and rag_retrieval_config.hybrid_search.alpha - ): - api_retrival_config.hybrid_search.alpha = ( - rag_retrieval_config.hybrid_search.alpha - ) - else: - api_retrival_config.hybrid_search.alpha = vector_search_alpha - # Check if both vector_distance_threshold and vector_similarity_threshold - # are specified. - if ( - rag_retrieval_config.filter - and rag_retrieval_config.filter.vector_distance_threshold - and rag_retrieval_config.filter.vector_similarity_threshold - ): - raise ValueError( - "Only one of vector_distance_threshold or" - " vector_similarity_threshold can be specified at a time" - " in rag_retrieval_config." - ) - # Set vector_distance_threshold to config value if specified - if ( - rag_retrieval_config.filter - and rag_retrieval_config.filter.vector_distance_threshold - ): - api_retrival_config.filter.vector_distance_threshold = ( - rag_retrieval_config.filter.vector_distance_threshold - ) - else: - api_retrival_config.filter.vector_distance_threshold = ( - vector_distance_threshold - ) - # Set vector_similarity_threshold to config value if specified - if ( - rag_retrieval_config.filter - and rag_retrieval_config.filter.vector_similarity_threshold - ): - api_retrival_config.filter.vector_similarity_threshold = ( - rag_retrieval_config.filter.vector_similarity_threshold - ) - if rag_retrieval_config.filter and rag_retrieval_config.filter.metadata_filter: - api_retrival_config.filter.metadata_filter = ( - rag_retrieval_config.filter.metadata_filter - ) - - if ( - rag_retrieval_config.ranking - and rag_retrieval_config.ranking.rank_service - and rag_retrieval_config.ranking.llm_ranker - ): - raise ValueError("Only one of rank_service and llm_ranker can be set.") - if rag_retrieval_config.ranking and rag_retrieval_config.ranking.rank_service: - api_retrival_config.ranking.rank_service.model_name = ( - rag_retrieval_config.ranking.rank_service.model_name - ) - elif rag_retrieval_config.ranking and rag_retrieval_config.ranking.llm_ranker: - api_retrival_config.ranking.llm_ranker.model_name = ( - rag_retrieval_config.ranking.llm_ranker.model_name - ) - query = aiplatform_v1beta1.RagQuery( - text=text, - rag_retrieval_config=api_retrival_config, + # Check for deprecated parameters and raise warnings. + if similarity_top_k: + # If similarity_top_k is specified, throw deprecation warning. + warnings.warn( + "similarity_top_k is deprecated. Please use" + " rag_retrieval_config.top_k instead." + f" After {resources.DEPRECATION_DATE} using" + " similarity_top_k will raise error", + DeprecationWarning, ) - request = aiplatform_v1beta1.RetrieveContextsRequest( - vertex_rag_store=vertex_rag_store, - parent=parent, - query=query, + if vector_search_alpha: + # If vector_search_alpha is specified, throw deprecation warning. + warnings.warn( + "vector_search_alpha is deprecated. Please use" + " rag_retrieval_config.alpha instead." + f" After {resources.DEPRECATION_DATE} using" + " vector_search_alpha will raise error", + DeprecationWarning, + ) + if vector_distance_threshold: + # If vector_distance_threshold is specified, throw deprecation warning. + warnings.warn( + "vector_distance_threshold is deprecated. Please use" + " rag_retrieval_config.filter.vector_distance_threshold instead." + f" After {resources.DEPRECATION_DATE} using" + " vector_distance_threshold will raise error", + DeprecationWarning, ) - try: - response = client.retrieve_contexts(request=request) - except Exception as e: - raise RuntimeError("Failed in retrieving contexts due to: ", e) from e - return response + # If rag_retrieval_config is not specified, set it to default values. + if not rag_retrieval_config: + api_retrival_config = aiplatform_v1beta1.RagRetrievalConfig( + top_k=similarity_top_k, + hybrid_search=aiplatform_v1beta1.RagRetrievalConfig.HybridSearch( + alpha=vector_search_alpha, + ), + filter=aiplatform_v1beta1.RagRetrievalConfig.Filter( + vector_distance_threshold=vector_distance_threshold + ), + ) + else: + # If rag_retrieval_config is specified, check for missing parameters. + api_retrival_config = aiplatform_v1beta1.RagRetrievalConfig() + # Set top_k to config value if specified + if rag_retrieval_config.top_k: + api_retrival_config.top_k = rag_retrieval_config.top_k + else: + api_retrival_config.top_k = similarity_top_k + # Set alpha to config value if specified + if ( + rag_retrieval_config.hybrid_search + and rag_retrieval_config.hybrid_search.alpha + ): + api_retrival_config.hybrid_search.alpha = ( + rag_retrieval_config.hybrid_search.alpha + ) + else: + api_retrival_config.hybrid_search.alpha = vector_search_alpha + # Check if both vector_distance_threshold and vector_similarity_threshold + # are specified. + if ( + rag_retrieval_config.filter + and rag_retrieval_config.filter.vector_distance_threshold + and rag_retrieval_config.filter.vector_similarity_threshold + ): + raise ValueError( + "Only one of vector_distance_threshold or" + " vector_similarity_threshold can be specified at a time" + " in rag_retrieval_config." + ) + # Set vector_distance_threshold to config value if specified + if ( + rag_retrieval_config.filter + and rag_retrieval_config.filter.vector_distance_threshold + ): + api_retrival_config.filter.vector_distance_threshold = ( + rag_retrieval_config.filter.vector_distance_threshold + ) + else: + api_retrival_config.filter.vector_distance_threshold = ( + vector_distance_threshold + ) + # Set vector_similarity_threshold to config value if specified + if ( + rag_retrieval_config.filter + and rag_retrieval_config.filter.vector_similarity_threshold + ): + api_retrival_config.filter.vector_similarity_threshold = ( + rag_retrieval_config.filter.vector_similarity_threshold + ) + if ( + rag_retrieval_config.filter + and rag_retrieval_config.filter.metadata_filter + ): + api_retrival_config.filter.metadata_filter = ( + rag_retrieval_config.filter.metadata_filter + ) + + if ( + rag_retrieval_config.ranking + and rag_retrieval_config.ranking.rank_service + and rag_retrieval_config.ranking.llm_ranker + ): + raise ValueError("Only one of rank_service and llm_ranker can be set.") + if ( + rag_retrieval_config.ranking + and rag_retrieval_config.ranking.rank_service + ): + api_retrival_config.ranking.rank_service.model_name = ( + rag_retrieval_config.ranking.rank_service.model_name + ) + elif ( + rag_retrieval_config.ranking and rag_retrieval_config.ranking.llm_ranker + ): + api_retrival_config.ranking.llm_ranker.model_name = ( + rag_retrieval_config.ranking.llm_ranker.model_name + ) + query = aiplatform_v1beta1.RagQuery( + text=text, + rag_retrieval_config=api_retrival_config, + ) + request = aiplatform_v1beta1.RetrieveContextsRequest( + vertex_rag_store=vertex_rag_store, + parent=parent, + query=query, + ) + try: + response = client.retrieve_contexts(request=request) + except Exception as e: + raise RuntimeError("Failed in retrieving contexts due to: ", e) from e + + return response async def async_retrieve_contexts( @@ -294,264 +304,314 @@ async def async_retrieve_contexts( rag_retrieval_config: Optional[resources.RagRetrievalConfig] = None, timeout: int = 600, ) -> aiplatform_v1beta1.RetrieveContextsResponse: - """Retrieve top k relevant docs/chunks asynchronously. - - Example usage: - ``` - import vertexai - - vertexai.init(project="my-project") - - config = vertexai.preview.rag.RagRetrievalConfig( - top_k=2, + """Retrieve top k relevant docs/chunks asynchronously. + + Example usage: + ``` + import vertexai + + vertexai.init(project="my-project") + + config = vertexai.preview.rag.RagRetrievalConfig( + top_k=2, + ) + + results = await vertexai.preview.rag.async_retrieve_contexts( + text="Why is the sky blue?", + rag_resources=[vertexai.preview.rag.RagResource( + rag_corpus="projects/my-project/locations/us-central1/ragCorpora/rag-corpus-1", + rag_file_ids=["rag-file-1", "rag-file-2", ...], + )], + rag_retrieval_config=config, + ) + ``` + + Args: + text: Required. The query in text format to get relevant contexts. + rag_resources: Optional. A list of RagResource. It can be used to specify + corpus only or ragfiles. Currently only support one corpus or multiple + files from one corpus. In the future we may open up multiple corpora + support. + rag_corpora: Optional. Deprecated. Please use rag_resources instead. A + list of RagCorpora resource names. Format: + ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}`` + Currently only support one corpus. In the future we may open up multiple + corpora support. + similarity_top_k: Optional. Deprecated. Please use + rag_retrieval_config.top_k instead. + vector_distance_threshold: Optional. Deprecated. Please use + rag_retrieval_config.filter.vector_distance_threshold instead. + vector_search_alpha: Optional. Deprecated. Please use + rag_retrieval_config.hybrid_search.alpha instead. + rag_retrieval_config: Optional. The config containing the retrieval + parameters, including top_k, vector_distance_threshold, and alpha. + timeout: Optional. The timeout for the request in seconds. Default is 600. + + Returns: + RetrieveContextsResponse. + """ + parent = initializer.global_config.common_location_path() + + client = _gapic_utils.create_rag_service_async_client() + + if not rag_resources and not rag_corpora: + raise ValueError("rag_resources or rag_corpora must be specified.") + + data_client = _gapic_utils.create_rag_data_service_client() + + gapic_rag_resources = [] + gapic_rag_corpora = [] + if rag_resources: + for rag_resource in rag_resources: + name = rag_resource.rag_corpus + if data_client.parse_rag_corpus_path(name): + rag_corpus_name = name + elif re.match( + "^{}$".format( + _gapic_utils._VALID_RESOURCE_NAME_REGEX # pylint: disable=protected-access + ), + name, + ): + rag_corpus_name = parent + "/ragCorpora/" + name + else: + raise ValueError( + f"Invalid RagCorpus name: {name}. Proper format should be:" + " projects/{project}/locations/{location}/ragCorpora/{rag_corpus_id}" + ) + gapic_rag_resources.append( + aiplatform_v1beta1.VertexRagStore.RagResource( + rag_corpus=rag_corpus_name, + rag_file_ids=rag_resource.rag_file_ids, + ) + ) + vertex_rag_store = aiplatform_v1beta1.VertexRagStore( + rag_resources=gapic_rag_resources, ) - results = await vertexai.preview.rag.async_retrieve_contexts( - text="Why is the sky blue?", - rag_resources=[vertexai.preview.rag.RagResource( - rag_corpus="projects/my-project/locations/us-central1/ragCorpora/rag-corpus-1", - rag_file_ids=["rag-file-1", "rag-file-2", ...], - )], - rag_retrieval_config=config, + elif rag_corpora: + warnings.warn( + "rag_corpora is deprecated. Please use rag_resources instead." + f" After {resources.DEPRECATION_DATE} using" + " rag_corpora will raise error", + DeprecationWarning, ) - ``` - - Args: - text: Required. The query in text format to get relevant contexts. - rag_resources: Optional. A list of RagResource. It can be used to specify - corpus only or ragfiles. Currently only support one corpus or multiple - files from one corpus. In the future we may open up multiple corpora - support. - rag_corpora: Optional. Deprecated. Please use rag_resources instead. A - list of RagCorpora resource names. Format: - ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}`` - Currently only support one corpus. In the future we may open up multiple - corpora support. - similarity_top_k: Optional. Deprecated. Please use - rag_retrieval_config.top_k instead. - vector_distance_threshold: Optional. Deprecated. Please use - rag_retrieval_config.filter.vector_distance_threshold instead. - vector_search_alpha: Optional. Deprecated. Please use - rag_retrieval_config.hybrid_search.alpha instead. - rag_retrieval_config: Optional. The config containing the retrieval - parameters, including top_k, vector_distance_threshold, and alpha. - timeout: Optional. The timeout for the request in seconds. Default is 600. - - Returns: - RetrieveContextsResponse. - """ - parent = initializer.global_config.common_location_path() - - client = _gapic_utils.create_rag_service_async_client() - - if not rag_resources and not rag_corpora: - raise ValueError("rag_resources or rag_corpora must be specified.") - - data_client = _gapic_utils.create_rag_data_service_client() - - gapic_rag_resources = [] - gapic_rag_corpora = [] - if rag_resources: - for rag_resource in rag_resources: - name = rag_resource.rag_corpus - if data_client.parse_rag_corpus_path(name): - rag_corpus_name = name - elif re.match( - "^{}$".format( - _gapic_utils._VALID_RESOURCE_NAME_REGEX # pylint: disable=protected-access - ), - name, - ): - rag_corpus_name = parent + "/ragCorpora/" + name - else: - raise ValueError( - f"Invalid RagCorpus name: {name}. Proper format should be:" - " projects/{project}/locations/{location}/ragCorpora/{rag_corpus_id}" - ) - gapic_rag_resources.append( - aiplatform_v1beta1.VertexRagStore.RagResource( - rag_corpus=rag_corpus_name, - rag_file_ids=rag_resource.rag_file_ids, - ) - ) - vertex_rag_store = aiplatform_v1beta1.VertexRagStore( - rag_resources=gapic_rag_resources, - ) - elif rag_corpora: - warnings.warn( - "rag_corpora is deprecated. Please use rag_resources instead." - f" After {resources.DEPRECATION_DATE} using" - " rag_corpora will raise error", - DeprecationWarning, - ) - for name in rag_corpora: - if data_client.parse_rag_corpus_path(name): - rag_corpus_name = name - elif re.match( - "^{}$".format( - _gapic_utils._VALID_RESOURCE_NAME_REGEX # pylint: disable=protected-access - ), - name, - ): - rag_corpus_name = parent + "/ragCorpora/" + name - else: - raise ValueError( - f"Invalid RagCorpus name: {name}. Proper format should be:" - " projects/{project}/locations/{location}/ragCorpora/{rag_corpus_id}" - ) - gapic_rag_corpora.append(rag_corpus_name) - vertex_rag_store = aiplatform_v1beta1.VertexRagStore( - rag_corpora=gapic_rag_corpora, + for name in rag_corpora: + if data_client.parse_rag_corpus_path(name): + rag_corpus_name = name + elif re.match( + "^{}$".format( + _gapic_utils._VALID_RESOURCE_NAME_REGEX # pylint: disable=protected-access + ), + name, + ): + rag_corpus_name = parent + "/ragCorpora/" + name + else: + raise ValueError( + f"Invalid RagCorpus name: {name}. Proper format should be:" + " projects/{project}/locations/{location}/ragCorpora/{rag_corpus_id}" ) + gapic_rag_corpora.append(rag_corpus_name) + vertex_rag_store = aiplatform_v1beta1.VertexRagStore( + rag_corpora=gapic_rag_corpora, + ) - # Check for deprecated parameters and raise warnings. - if similarity_top_k: - warnings.warn( - "similarity_top_k is deprecated. Please use" - " rag_retrieval_config.top_k instead." - f" After {resources.DEPRECATION_DATE} using" - " similarity_top_k will raise error", - DeprecationWarning, - ) - if vector_search_alpha: - warnings.warn( - "vector_search_alpha is deprecated. Please use" - " rag_retrieval_config.alpha instead." - f" After {resources.DEPRECATION_DATE} using" - " vector_search_alpha will raise error", - DeprecationWarning, - ) - if vector_distance_threshold: - warnings.warn( - "vector_distance_threshold is deprecated. Please use" - " rag_retrieval_config.filter.vector_distance_threshold instead." - f" After {resources.DEPRECATION_DATE} using" - " vector_distance_threshold will raise error", - DeprecationWarning, - ) + # Check for deprecated parameters and raise warnings. + if similarity_top_k: + warnings.warn( + "similarity_top_k is deprecated. Please use" + " rag_retrieval_config.top_k instead." + f" After {resources.DEPRECATION_DATE} using" + " similarity_top_k will raise error", + DeprecationWarning, + ) + if vector_search_alpha: + warnings.warn( + "vector_search_alpha is deprecated. Please use" + " rag_retrieval_config.alpha instead." + f" After {resources.DEPRECATION_DATE} using" + " vector_search_alpha will raise error", + DeprecationWarning, + ) + if vector_distance_threshold: + warnings.warn( + "vector_distance_threshold is deprecated. Please use" + " rag_retrieval_config.filter.vector_distance_threshold instead." + f" After {resources.DEPRECATION_DATE} using" + " vector_distance_threshold will raise error", + DeprecationWarning, + ) - # If rag_retrieval_config is not specified, set it to default values. - if not rag_retrieval_config: - api_retrival_config = aiplatform_v1beta1.RagRetrievalConfig( - top_k=similarity_top_k, - hybrid_search=aiplatform_v1beta1.RagRetrievalConfig.HybridSearch( - alpha=vector_search_alpha, - ), - filter=aiplatform_v1beta1.RagRetrievalConfig.Filter( - vector_distance_threshold=vector_distance_threshold - ), - ) + # If rag_retrieval_config is not specified, set it to default values. + if not rag_retrieval_config: + api_retrival_config = aiplatform_v1beta1.RagRetrievalConfig( + top_k=similarity_top_k, + hybrid_search=aiplatform_v1beta1.RagRetrievalConfig.HybridSearch( + alpha=vector_search_alpha, + ), + filter=aiplatform_v1beta1.RagRetrievalConfig.Filter( + vector_distance_threshold=vector_distance_threshold + ), + ) + else: + api_retrival_config = aiplatform_v1beta1.RagRetrievalConfig() + if rag_retrieval_config.top_k: + api_retrival_config.top_k = rag_retrieval_config.top_k else: - api_retrival_config = aiplatform_v1beta1.RagRetrievalConfig() - if rag_retrieval_config.top_k: - api_retrival_config.top_k = rag_retrieval_config.top_k - else: - api_retrival_config.top_k = similarity_top_k - - if ( - rag_retrieval_config.hybrid_search - and rag_retrieval_config.hybrid_search.alpha - ): - api_retrival_config.hybrid_search.alpha = ( - rag_retrieval_config.hybrid_search.alpha - ) - else: - api_retrival_config.hybrid_search.alpha = vector_search_alpha - - if ( - rag_retrieval_config.filter - and rag_retrieval_config.filter.vector_distance_threshold - and rag_retrieval_config.filter.vector_similarity_threshold - ): - raise ValueError( - "Only one of vector_distance_threshold or" - " vector_similarity_threshold can be specified at a time" - " in rag_retrieval_config." - ) - - if ( - rag_retrieval_config.filter - and rag_retrieval_config.filter.vector_distance_threshold - ): - api_retrival_config.filter.vector_distance_threshold = ( - rag_retrieval_config.filter.vector_distance_threshold - ) - else: - api_retrival_config.filter.vector_distance_threshold = ( - vector_distance_threshold - ) - - if ( - rag_retrieval_config.filter - and rag_retrieval_config.filter.vector_similarity_threshold - ): - api_retrival_config.filter.vector_similarity_threshold = ( - rag_retrieval_config.filter.vector_similarity_threshold - ) + api_retrival_config.top_k = similarity_top_k - if ( - rag_retrieval_config.ranking - and rag_retrieval_config.ranking.rank_service - and rag_retrieval_config.ranking.llm_ranker - ): - raise ValueError("Only one of rank_service and llm_ranker can be set.") - if rag_retrieval_config.ranking and rag_retrieval_config.ranking.rank_service: - api_retrival_config.ranking.rank_service.model_name = ( - rag_retrieval_config.ranking.rank_service.model_name - ) - elif rag_retrieval_config.ranking and rag_retrieval_config.ranking.llm_ranker: - api_retrival_config.ranking.llm_ranker.model_name = ( - rag_retrieval_config.ranking.llm_ranker.model_name - ) - if rag_retrieval_config.filter and rag_retrieval_config.filter.metadata_filter: - api_retrival_config.filter.metadata_filter = ( - rag_retrieval_config.filter.metadata_filter - ) - - query = aiplatform_v1beta1.RagQuery( - text=text, - rag_retrieval_config=api_retrival_config, - ) + if ( + rag_retrieval_config.hybrid_search + and rag_retrieval_config.hybrid_search.alpha + ): + api_retrival_config.hybrid_search.alpha = ( + rag_retrieval_config.hybrid_search.alpha + ) + else: + api_retrival_config.hybrid_search.alpha = vector_search_alpha - vertex_rag_store.rag_retrieval_config = api_retrival_config + if ( + rag_retrieval_config.filter + and rag_retrieval_config.filter.vector_distance_threshold + and rag_retrieval_config.filter.vector_similarity_threshold + ): + raise ValueError( + "Only one of vector_distance_threshold or" + " vector_similarity_threshold can be specified at a time" + " in rag_retrieval_config." + ) + + if ( + rag_retrieval_config.filter + and rag_retrieval_config.filter.vector_distance_threshold + ): + api_retrival_config.filter.vector_distance_threshold = ( + rag_retrieval_config.filter.vector_distance_threshold + ) + else: + api_retrival_config.filter.vector_distance_threshold = ( + vector_distance_threshold + ) - tool = aiplatform_v1beta1.Tool( - retrieval=aiplatform_v1beta1.Retrieval( - vertex_rag_store=vertex_rag_store, - ) + if ( + rag_retrieval_config.filter + and rag_retrieval_config.filter.vector_similarity_threshold + ): + api_retrival_config.filter.vector_similarity_threshold = ( + rag_retrieval_config.filter.vector_similarity_threshold + ) + + if ( + rag_retrieval_config.ranking + and rag_retrieval_config.ranking.rank_service + and rag_retrieval_config.ranking.llm_ranker + ): + raise ValueError("Only one of rank_service and llm_ranker can be set.") + if ( + rag_retrieval_config.ranking + and rag_retrieval_config.ranking.rank_service + ): + api_retrival_config.ranking.rank_service.model_name = ( + rag_retrieval_config.ranking.rank_service.model_name + ) + elif ( + rag_retrieval_config.ranking and rag_retrieval_config.ranking.llm_ranker + ): + api_retrival_config.ranking.llm_ranker.model_name = ( + rag_retrieval_config.ranking.llm_ranker.model_name + ) + if ( + rag_retrieval_config.filter + and rag_retrieval_config.filter.metadata_filter + ): + api_retrival_config.filter.metadata_filter = ( + rag_retrieval_config.filter.metadata_filter + ) + + query = aiplatform_v1beta1.RagQuery( + text=text, + rag_retrieval_config=api_retrival_config, + ) + + vertex_rag_store.rag_retrieval_config = api_retrival_config + + tool = aiplatform_v1beta1.Tool( + retrieval=aiplatform_v1beta1.Retrieval( + vertex_rag_store=vertex_rag_store, + ) + ) + + request = aiplatform_v1beta1.AsyncRetrieveContextsRequest( + parent=parent, + query=query, + tools=[tool], + ) + try: + + response_lro = await client.async_retrieve_contexts( + request=request, timeout=timeout ) - - request = aiplatform_v1beta1.AsyncRetrieveContextsRequest( - parent=parent, - query=query, - tools=[tool], + print( + f"[DEBUG] async_retrieve_contexts called. LRO: {response_lro}, Type:" + f" {type(response_lro)}" ) try: - response_lro = await client.async_retrieve_contexts( - request=request, timeout=timeout - ) - try: - response = await response_lro.result(timeout=timeout) - except Exception as e: - if response_lro.done(): - raw_op = response_lro.operation - if raw_op.WhichOneof("result") == "response": - any_response = raw_op.response - inner_any = any_pb2.Any() - if any_response.Unpack(inner_any): - inner_any.type_url = "type.googleapis.com/google.cloud.aiplatform.v1beta1.RagContexts" - rag_contexts = aiplatform_v1beta1.RagContexts() - if inner_any.Unpack(rag_contexts._pb): - return aiplatform_v1beta1.AsyncRetrieveContextsResponse( - contexts=rag_contexts - ) - raise e + print(f"[DEBUG] Waiting for LRO result with timeout={timeout}") + response = await response_lro.result(timeout=timeout) + print(f"[DEBUG] LRO result received successfully. Type: {type(response)}") except Exception as e: - raise RuntimeError( - "Failed in retrieving contexts asynchronously due to: ", e - ) from e + print( + f"[DEBUG] Exception while waiting for LRO result: {e!r}, Type:" + f" {type(e)}" + ) + lro_done = response_lro.done() if response_lro else "None" + print(f"[DEBUG] response_lro.done(): {lro_done}") + if response_lro and response_lro.done(): + raw_op = response_lro.operation + print(f"[DEBUG] raw_op: {raw_op}, Type: {type(raw_op)}") + print(f"[DEBUG] raw_op.done: {raw_op.done}") + print(f"[DEBUG] raw_op result oneof: {raw_op.WhichOneof('result')}") + if raw_op.WhichOneof("result") == "error": + print(f"[DEBUG] raw_op error: {raw_op.error}") + if raw_op.WhichOneof("result") == "response": + any_response = raw_op.response + print( + f"[DEBUG] raw_op result oneof: {raw_op.WhichOneof('result')}," + f" Type: {type(raw_op)}" + ) + print( + "[DEBUG] any_response.type_url before hack:" + f" {any_response.type_url}" + ) + # HACK: Override type_url to match SDK's expected type + any_response.type_url = ( + "type.googleapis.com/google.cloud.aiplatform.v1beta1.RagContexts" + ) + print( + "[DEBUG] any_response.type_url after hack:" + f" {any_response.type_url}" + ) + rag_contexts = aiplatform_v1beta1.RagContexts() + if any_response.Unpack(rag_contexts._pb): + print( + f"[DEBUG] Unpack successful. rag_contexts: {rag_contexts}," + f" Type: {type(rag_contexts)}" + ) + if not rag_contexts.contexts: + raise ValueError("No rag contexts were returned.") from e + return aiplatform_v1beta1.AsyncRetrieveContextsResponse( + contexts=rag_contexts + ) + raise e + except Exception as e: + print( + f"[DEBUG] Top level exception in async_retrieve_contexts: {e!r}, Type:" + f" {type(e)}" + ) + raise RuntimeError( + "Failed in retrieving contexts asynchronously due to: ", e + ) from e - return response + return response def ask_contexts( @@ -564,241 +624,256 @@ def ask_contexts( rag_retrieval_config: Optional[resources.RagRetrievalConfig] = None, timeout: int = 600, ) -> aiplatform_v1beta1.AskContextsResponse: - """Ask questions on top k relevant docs/chunks. - - Example usage: - ``` - import vertexai - - vertexai.init(project="my-project") - - config = vertexai.preview.rag.RagRetrievalConfig( - top_k=2, + """Ask questions on top k relevant docs/chunks. + + Example usage: + ``` + import vertexai + + vertexai.init(project="my-project") + + config = vertexai.preview.rag.RagRetrievalConfig( + top_k=2, + ) + + results = vertexai.preview.rag.ask_contexts( + text="Why is the sky blue?", + rag_resources=[vertexai.preview.rag.RagResource( + rag_corpus="projects/my-project/locations/us-central1/ragCorpora/rag-corpus-1", + rag_file_ids=["rag-file-1", "rag-file-2", ...], + )], + rag_retrieval_config=config, + ) + ``` + + Args: + text: Required. The query in text format to get relevant contexts. + rag_resources: Optional. A list of RagResource. It can be used to specify + corpus only or ragfiles. Currently only support one corpus or multiple + files from one corpus. In the future we may open up multiple corpora + support. + rag_corpora: Optional. Deprecated. Please use rag_resources instead. A + list of RagCorpora resource names. Format: + ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}`` + Currently only support one corpus. In the future we may open up multiple + corpora support. + similarity_top_k: Optional. Deprecated. Please use + rag_retrieval_config.top_k instead. + vector_distance_threshold: Optional. Deprecated. Please use + rag_retrieval_config.filter.vector_distance_threshold instead. + vector_search_alpha: Optional. Deprecated. Please use + rag_retrieval_config.hybrid_search.alpha instead. + rag_retrieval_config: Optional. The config containing the retrieval + parameters, including top_k, vector_distance_threshold, and alpha. + timeout: Optional. The timeout for the request in seconds. Default is 600. + + Returns: + AskContextsResponse. + """ + parent = initializer.global_config.common_location_path() + + client = _gapic_utils.create_rag_service_client() + + if not rag_resources and not rag_corpora: + raise ValueError("rag_resources or rag_corpora must be specified.") + + data_client = _gapic_utils.create_rag_data_service_client() + + gapic_rag_resources = [] + gapic_rag_corpora = [] + if rag_resources: + for rag_resource in rag_resources: + name = rag_resource.rag_corpus + if data_client.parse_rag_corpus_path(name): + rag_corpus_name = name + elif re.match( + "^{}$".format( + _gapic_utils._VALID_RESOURCE_NAME_REGEX # pylint: disable=protected-access + ), + name, + ): + rag_corpus_name = parent + "/ragCorpora/" + name + else: + raise ValueError( + f"Invalid RagCorpus name: {name}. Proper format should be:" + " projects/{project}/locations/{location}/ragCorpora/{rag_corpus_id}" + ) + gapic_rag_resources.append( + aiplatform_v1beta1.VertexRagStore.RagResource( + rag_corpus=rag_corpus_name, + rag_file_ids=rag_resource.rag_file_ids, + ) + ) + vertex_rag_store = aiplatform_v1beta1.VertexRagStore( + rag_resources=gapic_rag_resources, ) - - results = vertexai.preview.rag.ask_contexts( - text="Why is the sky blue?", - rag_resources=[vertexai.preview.rag.RagResource( - rag_corpus="projects/my-project/locations/us-central1/ragCorpora/rag-corpus-1", - rag_file_ids=["rag-file-1", "rag-file-2", ...], - )], - rag_retrieval_config=config, + elif rag_corpora: + warnings.warn( + "rag_corpora is deprecated. Please use rag_resources instead." + f" After {resources.DEPRECATION_DATE} using" + " rag_corpora will raise error", + DeprecationWarning, ) - ``` - - Args: - text: Required. The query in text format to get relevant contexts. - rag_resources: Optional. A list of RagResource. It can be used to specify - corpus only or ragfiles. Currently only support one corpus or multiple - files from one corpus. In the future we may open up multiple corpora - support. - rag_corpora: Optional. Deprecated. Please use rag_resources instead. A - list of RagCorpora resource names. Format: - ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}`` - Currently only support one corpus. In the future we may open up multiple - corpora support. - similarity_top_k: Optional. Deprecated. Please use - rag_retrieval_config.top_k instead. - vector_distance_threshold: Optional. Deprecated. Please use - rag_retrieval_config.filter.vector_distance_threshold instead. - vector_search_alpha: Optional. Deprecated. Please use - rag_retrieval_config.hybrid_search.alpha instead. - rag_retrieval_config: Optional. The config containing the retrieval - parameters, including top_k, vector_distance_threshold, and alpha. - timeout: Optional. The timeout for the request in seconds. Default is 600. - - Returns: - AskContextsResponse. - """ - parent = initializer.global_config.common_location_path() - - client = _gapic_utils.create_rag_service_client() - - if not rag_resources and not rag_corpora: - raise ValueError("rag_resources or rag_corpora must be specified.") - - data_client = _gapic_utils.create_rag_data_service_client() - - gapic_rag_resources = [] - gapic_rag_corpora = [] - if rag_resources: - for rag_resource in rag_resources: - name = rag_resource.rag_corpus - if data_client.parse_rag_corpus_path(name): - rag_corpus_name = name - elif re.match( - "^{}$".format( - _gapic_utils._VALID_RESOURCE_NAME_REGEX # pylint: disable=protected-access - ), - name, - ): - rag_corpus_name = parent + "/ragCorpora/" + name - else: - raise ValueError( - f"Invalid RagCorpus name: {name}. Proper format should be:" - " projects/{project}/locations/{location}/ragCorpora/{rag_corpus_id}" - ) - gapic_rag_resources.append( - aiplatform_v1beta1.VertexRagStore.RagResource( - rag_corpus=rag_corpus_name, - rag_file_ids=rag_resource.rag_file_ids, - ) - ) - vertex_rag_store = aiplatform_v1beta1.VertexRagStore( - rag_resources=gapic_rag_resources, - ) - elif rag_corpora: - warnings.warn( - "rag_corpora is deprecated. Please use rag_resources instead." - f" After {resources.DEPRECATION_DATE} using" - " rag_corpora will raise error", - DeprecationWarning, - ) - for name in rag_corpora: - if data_client.parse_rag_corpus_path(name): - rag_corpus_name = name - elif re.match( - "^{}$".format( - _gapic_utils._VALID_RESOURCE_NAME_REGEX # pylint: disable=protected-access - ), - name, - ): - rag_corpus_name = parent + "/ragCorpora/" + name - else: - raise ValueError( - f"Invalid RagCorpus name: {name}. Proper format should be:" - " projects/{project}/locations/{location}/ragCorpora/{rag_corpus_id}" - ) - gapic_rag_corpora.append(rag_corpus_name) - vertex_rag_store = aiplatform_v1beta1.VertexRagStore( - rag_corpora=gapic_rag_corpora, - ) - - # Check for deprecated parameters and raise warnings. - if similarity_top_k: - warnings.warn( - "similarity_top_k is deprecated. Please use" - " rag_retrieval_config.top_k instead." - f" After {resources.DEPRECATION_DATE} using" - " similarity_top_k will raise error", - DeprecationWarning, - ) - if vector_search_alpha: - warnings.warn( - "vector_search_alpha is deprecated. Please use" - " rag_retrieval_config.alpha instead." - f" After {resources.DEPRECATION_DATE} using" - " vector_search_alpha will raise error", - DeprecationWarning, - ) - if vector_distance_threshold: - warnings.warn( - "vector_distance_threshold is deprecated. Please use" - " rag_retrieval_config.filter.vector_distance_threshold instead." - f" After {resources.DEPRECATION_DATE} using" - " vector_distance_threshold will raise error", - DeprecationWarning, - ) - - # If rag_retrieval_config is not specified, set it to default values. - if not rag_retrieval_config: - api_retrival_config = aiplatform_v1beta1.RagRetrievalConfig( - top_k=similarity_top_k, - hybrid_search=aiplatform_v1beta1.RagRetrievalConfig.HybridSearch( - alpha=vector_search_alpha, - ), - filter=aiplatform_v1beta1.RagRetrievalConfig.Filter( - vector_distance_threshold=vector_distance_threshold - ), + for name in rag_corpora: + if data_client.parse_rag_corpus_path(name): + rag_corpus_name = name + elif re.match( + "^{}$".format( + _gapic_utils._VALID_RESOURCE_NAME_REGEX # pylint: disable=protected-access + ), + name, + ): + rag_corpus_name = parent + "/ragCorpora/" + name + else: + raise ValueError( + f"Invalid RagCorpus name: {name}. Proper format should be:" + " projects/{project}/locations/{location}/ragCorpora/{rag_corpus_id}" ) - else: - api_retrival_config = aiplatform_v1beta1.RagRetrievalConfig() - if rag_retrieval_config.top_k: - api_retrival_config.top_k = rag_retrieval_config.top_k - else: - api_retrival_config.top_k = similarity_top_k - - if ( - rag_retrieval_config.hybrid_search - and rag_retrieval_config.hybrid_search.alpha - ): - api_retrival_config.hybrid_search.alpha = ( - rag_retrieval_config.hybrid_search.alpha - ) - else: - api_retrival_config.hybrid_search.alpha = vector_search_alpha - - if ( - rag_retrieval_config.filter - and rag_retrieval_config.filter.vector_distance_threshold - and rag_retrieval_config.filter.vector_similarity_threshold - ): - raise ValueError( - "Only one of vector_distance_threshold or" - " vector_similarity_threshold can be specified at a time" - " in rag_retrieval_config." - ) - - if ( - rag_retrieval_config.filter - and rag_retrieval_config.filter.vector_distance_threshold - ): - api_retrival_config.filter.vector_distance_threshold = ( - rag_retrieval_config.filter.vector_distance_threshold - ) - else: - api_retrival_config.filter.vector_distance_threshold = ( - vector_distance_threshold - ) - - if ( - rag_retrieval_config.filter - and rag_retrieval_config.filter.vector_similarity_threshold - ): - api_retrival_config.filter.vector_similarity_threshold = ( - rag_retrieval_config.filter.vector_similarity_threshold - ) - - if ( - rag_retrieval_config.ranking - and rag_retrieval_config.ranking.rank_service - and rag_retrieval_config.ranking.llm_ranker - ): - raise ValueError("Only one of rank_service and llm_ranker can be set.") - if rag_retrieval_config.ranking and rag_retrieval_config.ranking.rank_service: - api_retrival_config.ranking.rank_service.model_name = ( - rag_retrieval_config.ranking.rank_service.model_name - ) - elif rag_retrieval_config.ranking and rag_retrieval_config.ranking.llm_ranker: - api_retrival_config.ranking.llm_ranker.model_name = ( - rag_retrieval_config.ranking.llm_ranker.model_name - ) - if rag_retrieval_config.filter and rag_retrieval_config.filter.metadata_filter: - api_retrival_config.filter.metadata_filter = ( - rag_retrieval_config.filter.metadata_filter - ) - - query = aiplatform_v1beta1.RagQuery( - text=text, - rag_retrieval_config=api_retrival_config, + gapic_rag_corpora.append(rag_corpus_name) + vertex_rag_store = aiplatform_v1beta1.VertexRagStore( + rag_corpora=gapic_rag_corpora, ) - vertex_rag_store.rag_retrieval_config = api_retrival_config - - tool = aiplatform_v1beta1.Tool( - retrieval=aiplatform_v1beta1.Retrieval( - vertex_rag_store=vertex_rag_store, - ) + # Check for deprecated parameters and raise warnings. + if similarity_top_k: + warnings.warn( + "similarity_top_k is deprecated. Please use" + " rag_retrieval_config.top_k instead." + f" After {resources.DEPRECATION_DATE} using" + " similarity_top_k will raise error", + DeprecationWarning, + ) + if vector_search_alpha: + warnings.warn( + "vector_search_alpha is deprecated. Please use" + " rag_retrieval_config.alpha instead." + f" After {resources.DEPRECATION_DATE} using" + " vector_search_alpha will raise error", + DeprecationWarning, + ) + if vector_distance_threshold: + warnings.warn( + "vector_distance_threshold is deprecated. Please use" + " rag_retrieval_config.filter.vector_distance_threshold instead." + f" After {resources.DEPRECATION_DATE} using" + " vector_distance_threshold will raise error", + DeprecationWarning, ) - request = aiplatform_v1beta1.AskContextsRequest( - parent=parent, - query=query, - tools=[tool], + # If rag_retrieval_config is not specified, set it to default values. + if not rag_retrieval_config: + api_retrival_config = aiplatform_v1beta1.RagRetrievalConfig( + top_k=similarity_top_k, + hybrid_search=aiplatform_v1beta1.RagRetrievalConfig.HybridSearch( + alpha=vector_search_alpha, + ), + filter=aiplatform_v1beta1.RagRetrievalConfig.Filter( + vector_distance_threshold=vector_distance_threshold + ), ) - try: - response = client.ask_contexts(request=request, timeout=timeout) - except Exception as e: - raise RuntimeError("Failed in asking contexts due to: ", e) from e + else: + api_retrival_config = aiplatform_v1beta1.RagRetrievalConfig() + if rag_retrieval_config.top_k: + api_retrival_config.top_k = rag_retrieval_config.top_k + else: + api_retrival_config.top_k = similarity_top_k + + if ( + rag_retrieval_config.hybrid_search + and rag_retrieval_config.hybrid_search.alpha + ): + api_retrival_config.hybrid_search.alpha = ( + rag_retrieval_config.hybrid_search.alpha + ) + else: + api_retrival_config.hybrid_search.alpha = vector_search_alpha + + if ( + rag_retrieval_config.filter + and rag_retrieval_config.filter.vector_distance_threshold + and rag_retrieval_config.filter.vector_similarity_threshold + ): + raise ValueError( + "Only one of vector_distance_threshold or" + " vector_similarity_threshold can be specified at a time" + " in rag_retrieval_config." + ) + + if ( + rag_retrieval_config.filter + and rag_retrieval_config.filter.vector_distance_threshold + ): + api_retrival_config.filter.vector_distance_threshold = ( + rag_retrieval_config.filter.vector_distance_threshold + ) + else: + api_retrival_config.filter.vector_distance_threshold = ( + vector_distance_threshold + ) - return response + if ( + rag_retrieval_config.filter + and rag_retrieval_config.filter.vector_similarity_threshold + ): + api_retrival_config.filter.vector_similarity_threshold = ( + rag_retrieval_config.filter.vector_similarity_threshold + ) + + if ( + rag_retrieval_config.ranking + and rag_retrieval_config.ranking.rank_service + and rag_retrieval_config.ranking.llm_ranker + ): + raise ValueError("Only one of rank_service and llm_ranker can be set.") + if ( + rag_retrieval_config.ranking + and rag_retrieval_config.ranking.rank_service + ): + api_retrival_config.ranking.rank_service.model_name = ( + rag_retrieval_config.ranking.rank_service.model_name + ) + elif ( + rag_retrieval_config.ranking and rag_retrieval_config.ranking.llm_ranker + ): + api_retrival_config.ranking.llm_ranker.model_name = ( + rag_retrieval_config.ranking.llm_ranker.model_name + ) + if ( + rag_retrieval_config.filter + and rag_retrieval_config.filter.metadata_filter + ): + api_retrival_config.filter.metadata_filter = ( + rag_retrieval_config.filter.metadata_filter + ) + + query = aiplatform_v1beta1.RagQuery( + text=text, + rag_retrieval_config=api_retrival_config, + ) + + vertex_rag_store.rag_retrieval_config = api_retrival_config + + tool = aiplatform_v1beta1.Tool( + retrieval=aiplatform_v1beta1.Retrieval( + vertex_rag_store=vertex_rag_store, + ) + ) + + request = aiplatform_v1beta1.AskContextsRequest( + parent=parent, + query=query, + tools=[tool], + ) + try: + response = client.ask_contexts(request=request, timeout=timeout) + except Exception as e: + print(f"DEBUG: ask_contexts failed with error: {e}") + if hasattr(e, "trailing_metadata"): + metadata = e.trailing_metadata() + print(f"DEBUG: trailing metadata: {metadata}") + for key, value in metadata: + if key == "x-google-trace-id" or "trace" in key: + print(f"DEBUG: Found trace header: {key}={value}") + raise RuntimeError("Failed in asking contexts due to: ", e) from e + + return response