From bd4983b488a3181578c552482af6ec78e8aa408e Mon Sep 17 00:00:00 2001 From: A Vertex SDK engineer Date: Wed, 8 Apr 2026 15:43:57 -0700 Subject: [PATCH] fix: Add SDK workaround for double-wrapped Any response in async_retrieve_contexts. PiperOrigin-RevId: 896740495 --- vertexai/preview/rag/rag_retrieval.py | 19 ++++++++++++++++++- vertexai/rag/rag_retrieval.py | 21 ++++++++++++++++++++- 2 files changed, 38 insertions(+), 2 deletions(-) diff --git a/vertexai/preview/rag/rag_retrieval.py b/vertexai/preview/rag/rag_retrieval.py index 7dd3ca30df..d3d879509a 100644 --- a/vertexai/preview/rag/rag_retrieval.py +++ b/vertexai/preview/rag/rag_retrieval.py @@ -25,6 +25,8 @@ from vertexai.preview.rag.utils import _gapic_utils from vertexai.preview.rag.utils import resources +from google.protobuf import any_pb2 + def retrieval_query( text: str, @@ -528,7 +530,22 @@ async def async_retrieve_contexts( response_lro = await client.async_retrieve_contexts( request=request, timeout=timeout ) - response = await response_lro.result() + try: + response = await response_lro.result(timeout=timeout) + except Exception as e: + if response_lro.done(): + raw_op = response_lro.operation + if raw_op.WhichOneof("result") == "response": + any_response = raw_op.response + inner_any = any_pb2.Any() + if any_response.Unpack(inner_any): + inner_any.type_url = "type.googleapis.com/google.cloud.aiplatform.v1beta1.RagContexts" + rag_contexts = aiplatform_v1beta1.RagContexts() + if inner_any.Unpack(rag_contexts._pb): + return aiplatform_v1beta1.AsyncRetrieveContextsResponse( + contexts=rag_contexts + ) + raise e except Exception as e: raise RuntimeError( "Failed in retrieving contexts asynchronously due to: ", e diff --git a/vertexai/rag/rag_retrieval.py b/vertexai/rag/rag_retrieval.py index 9b6d684610..0640d413ba 100644 --- a/vertexai/rag/rag_retrieval.py +++ b/vertexai/rag/rag_retrieval.py @@ -24,6 +24,8 @@ from vertexai.rag.utils import _gapic_utils from vertexai.rag.utils import resources +from google.protobuf import any_pb2 + def retrieval_query( text: str, @@ -325,7 +327,24 @@ async def async_retrieve_contexts( response_lro = await client.async_retrieve_contexts( request=request, timeout=timeout ) - response = await response_lro.result() + try: + response = await response_lro.result(timeout=timeout) + except Exception as e: + if response_lro.done(): + raw_op = response_lro.operation + if raw_op.WhichOneof("result") == "response": + any_response = raw_op.response + inner_any = any_pb2.Any() + if any_response.Unpack(inner_any): + inner_any.type_url = ( + "type.googleapis.com/google.cloud.aiplatform.v1.RagContexts" + ) + rag_contexts = aiplatform_v1.RagContexts() + if inner_any.Unpack(rag_contexts._pb): + return aiplatform_v1.AsyncRetrieveContextsResponse( + contexts=rag_contexts + ) + raise e except Exception as e: raise RuntimeError( "Failed in retrieving contexts asynchronously due to: ", e