Skip to content

Commit a2938ae

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
fix: Add SDK workaround for double-wrapped Any response in async_retrieve_contexts.
PiperOrigin-RevId: 896867061
1 parent e2e81c9 commit a2938ae

File tree

1 file changed

+91
-85
lines changed

1 file changed

+91
-85
lines changed

vertexai/preview/rag/rag_retrieval.py

Lines changed: 91 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,6 @@
2525
from vertexai.preview.rag.utils import _gapic_utils
2626
from vertexai.preview.rag.utils import resources
2727

28-
from google.protobuf import any_pb2
29-
3028

3129
def retrieval_query(
3230
text: str,
@@ -537,14 +535,15 @@ async def async_retrieve_contexts(
537535
raw_op = response_lro.operation
538536
if raw_op.WhichOneof("result") == "response":
539537
any_response = raw_op.response
540-
inner_any = any_pb2.Any()
541-
if any_response.Unpack(inner_any):
542-
inner_any.type_url = "type.googleapis.com/google.cloud.aiplatform.v1beta1.RagContexts"
543-
rag_contexts = aiplatform_v1beta1.RagContexts()
544-
if inner_any.Unpack(rag_contexts._pb):
545-
return aiplatform_v1beta1.AsyncRetrieveContextsResponse(
546-
contexts=rag_contexts
547-
)
538+
# HACK: Override type_url to match SDK's expected type
539+
any_response.type_url = "type.googleapis.com/google.cloud.aiplatform.v1beta1.RagContexts"
540+
rag_contexts = aiplatform_v1beta1.RagContexts()
541+
if any_response.Unpack(rag_contexts._pb):
542+
if not rag_contexts.contexts:
543+
raise ValueError("No rag contexts were returned.")
544+
return aiplatform_v1beta1.AsyncRetrieveContextsResponse(
545+
contexts=rag_contexts
546+
)
548547
raise e
549548
except Exception as e:
550549
raise RuntimeError(
@@ -564,7 +563,7 @@ def ask_contexts(
564563
rag_retrieval_config: Optional[resources.RagRetrievalConfig] = None,
565564
timeout: int = 600,
566565
) -> aiplatform_v1beta1.AskContextsResponse:
567-
"""Ask questions on top k relevant docs/chunks.
566+
"""Ask questions on top k relevant docs/chunks.
568567
569568
Example usage:
570569
```
@@ -610,99 +609,99 @@ def ask_contexts(
610609
Returns:
611610
AskContextsResponse.
612611
"""
613-
parent = initializer.global_config.common_location_path()
612+
parent = initializer.global_config.common_location_path()
614613

615-
client = _gapic_utils.create_rag_service_client()
614+
client = _gapic_utils.create_rag_service_client()
616615

617-
if not rag_resources and not rag_corpora:
618-
raise ValueError("rag_resources or rag_corpora must be specified.")
616+
if not rag_resources and not rag_corpora:
617+
raise ValueError("rag_resources or rag_corpora must be specified.")
619618

620-
data_client = _gapic_utils.create_rag_data_service_client()
619+
data_client = _gapic_utils.create_rag_data_service_client()
621620

622-
gapic_rag_resources = []
623-
gapic_rag_corpora = []
624-
if rag_resources:
625-
for rag_resource in rag_resources:
626-
name = rag_resource.rag_corpus
627-
if data_client.parse_rag_corpus_path(name):
628-
rag_corpus_name = name
629-
elif re.match(
621+
gapic_rag_resources = []
622+
gapic_rag_corpora = []
623+
if rag_resources:
624+
for rag_resource in rag_resources:
625+
name = rag_resource.rag_corpus
626+
if data_client.parse_rag_corpus_path(name):
627+
rag_corpus_name = name
628+
elif re.match(
630629
"^{}$".format(
631630
_gapic_utils._VALID_RESOURCE_NAME_REGEX # pylint: disable=protected-access
632631
),
633632
name,
634633
):
635-
rag_corpus_name = parent + "/ragCorpora/" + name
636-
else:
637-
raise ValueError(
634+
rag_corpus_name = parent + "/ragCorpora/" + name
635+
else:
636+
raise ValueError(
638637
f"Invalid RagCorpus name: {name}. Proper format should be:"
639638
" projects/{project}/locations/{location}/ragCorpora/{rag_corpus_id}"
640639
)
641-
gapic_rag_resources.append(
640+
gapic_rag_resources.append(
642641
aiplatform_v1beta1.VertexRagStore.RagResource(
643642
rag_corpus=rag_corpus_name,
644643
rag_file_ids=rag_resource.rag_file_ids,
645644
)
646645
)
647-
vertex_rag_store = aiplatform_v1beta1.VertexRagStore(
646+
vertex_rag_store = aiplatform_v1beta1.VertexRagStore(
648647
rag_resources=gapic_rag_resources,
649648
)
650-
elif rag_corpora:
651-
warnings.warn(
649+
elif rag_corpora:
650+
warnings.warn(
652651
"rag_corpora is deprecated. Please use rag_resources instead."
653652
f" After {resources.DEPRECATION_DATE} using"
654653
" rag_corpora will raise error",
655654
DeprecationWarning,
656655
)
657-
for name in rag_corpora:
658-
if data_client.parse_rag_corpus_path(name):
659-
rag_corpus_name = name
660-
elif re.match(
656+
for name in rag_corpora:
657+
if data_client.parse_rag_corpus_path(name):
658+
rag_corpus_name = name
659+
elif re.match(
661660
"^{}$".format(
662661
_gapic_utils._VALID_RESOURCE_NAME_REGEX # pylint: disable=protected-access
663662
),
664663
name,
665664
):
666-
rag_corpus_name = parent + "/ragCorpora/" + name
667-
else:
668-
raise ValueError(
665+
rag_corpus_name = parent + "/ragCorpora/" + name
666+
else:
667+
raise ValueError(
669668
f"Invalid RagCorpus name: {name}. Proper format should be:"
670669
" projects/{project}/locations/{location}/ragCorpora/{rag_corpus_id}"
671670
)
672-
gapic_rag_corpora.append(rag_corpus_name)
673-
vertex_rag_store = aiplatform_v1beta1.VertexRagStore(
671+
gapic_rag_corpora.append(rag_corpus_name)
672+
vertex_rag_store = aiplatform_v1beta1.VertexRagStore(
674673
rag_corpora=gapic_rag_corpora,
675674
)
676675

677-
# Check for deprecated parameters and raise warnings.
678-
if similarity_top_k:
679-
warnings.warn(
676+
# Check for deprecated parameters and raise warnings.
677+
if similarity_top_k:
678+
warnings.warn(
680679
"similarity_top_k is deprecated. Please use"
681680
" rag_retrieval_config.top_k instead."
682681
f" After {resources.DEPRECATION_DATE} using"
683682
" similarity_top_k will raise error",
684683
DeprecationWarning,
685684
)
686-
if vector_search_alpha:
687-
warnings.warn(
685+
if vector_search_alpha:
686+
warnings.warn(
688687
"vector_search_alpha is deprecated. Please use"
689688
" rag_retrieval_config.alpha instead."
690689
f" After {resources.DEPRECATION_DATE} using"
691690
" vector_search_alpha will raise error",
692691
DeprecationWarning,
693692
)
694-
if vector_distance_threshold:
695-
warnings.warn(
693+
if vector_distance_threshold:
694+
warnings.warn(
696695
"vector_distance_threshold is deprecated. Please use"
697696
" rag_retrieval_config.filter.vector_distance_threshold instead."
698697
f" After {resources.DEPRECATION_DATE} using"
699698
" vector_distance_threshold will raise error",
700699
DeprecationWarning,
701700
)
702701

703-
# If rag_retrieval_config is not specified, set it to default values.
704-
if not rag_retrieval_config:
705-
api_retrival_config = aiplatform_v1beta1.RagRetrievalConfig(
702+
# If rag_retrieval_config is not specified, set it to default values.
703+
if not rag_retrieval_config:
704+
api_retrival_config = aiplatform_v1beta1.RagRetrievalConfig(
706705
top_k=similarity_top_k,
707706
hybrid_search=aiplatform_v1beta1.RagRetrievalConfig.HybridSearch(
708707
alpha=vector_search_alpha,
@@ -711,94 +710,101 @@ def ask_contexts(
711710
vector_distance_threshold=vector_distance_threshold
712711
),
713712
)
713+
else:
714+
api_retrival_config = aiplatform_v1beta1.RagRetrievalConfig()
715+
if rag_retrieval_config.top_k:
716+
api_retrival_config.top_k = rag_retrieval_config.top_k
714717
else:
715-
api_retrival_config = aiplatform_v1beta1.RagRetrievalConfig()
716-
if rag_retrieval_config.top_k:
717-
api_retrival_config.top_k = rag_retrieval_config.top_k
718-
else:
719-
api_retrival_config.top_k = similarity_top_k
718+
api_retrival_config.top_k = similarity_top_k
720719

721-
if (
720+
if (
722721
rag_retrieval_config.hybrid_search
723722
and rag_retrieval_config.hybrid_search.alpha
724723
):
725-
api_retrival_config.hybrid_search.alpha = (
724+
api_retrival_config.hybrid_search.alpha = (
726725
rag_retrieval_config.hybrid_search.alpha
727726
)
728-
else:
729-
api_retrival_config.hybrid_search.alpha = vector_search_alpha
727+
else:
728+
api_retrival_config.hybrid_search.alpha = vector_search_alpha
730729

731-
if (
730+
if (
732731
rag_retrieval_config.filter
733732
and rag_retrieval_config.filter.vector_distance_threshold
734733
and rag_retrieval_config.filter.vector_similarity_threshold
735734
):
736-
raise ValueError(
735+
raise ValueError(
737736
"Only one of vector_distance_threshold or"
738737
" vector_similarity_threshold can be specified at a time"
739738
" in rag_retrieval_config."
740739
)
741740

742-
if (
741+
if (
743742
rag_retrieval_config.filter
744743
and rag_retrieval_config.filter.vector_distance_threshold
745744
):
746-
api_retrival_config.filter.vector_distance_threshold = (
745+
api_retrival_config.filter.vector_distance_threshold = (
747746
rag_retrieval_config.filter.vector_distance_threshold
748747
)
749-
else:
750-
api_retrival_config.filter.vector_distance_threshold = (
748+
else:
749+
api_retrival_config.filter.vector_distance_threshold = (
751750
vector_distance_threshold
752751
)
753752

754-
if (
753+
if (
755754
rag_retrieval_config.filter
756755
and rag_retrieval_config.filter.vector_similarity_threshold
757756
):
758-
api_retrival_config.filter.vector_similarity_threshold = (
757+
api_retrival_config.filter.vector_similarity_threshold = (
759758
rag_retrieval_config.filter.vector_similarity_threshold
760759
)
761760

762-
if (
761+
if (
763762
rag_retrieval_config.ranking
764763
and rag_retrieval_config.ranking.rank_service
765764
and rag_retrieval_config.ranking.llm_ranker
766765
):
767-
raise ValueError("Only one of rank_service and llm_ranker can be set.")
768-
if rag_retrieval_config.ranking and rag_retrieval_config.ranking.rank_service:
769-
api_retrival_config.ranking.rank_service.model_name = (
766+
raise ValueError("Only one of rank_service and llm_ranker can be set.")
767+
if rag_retrieval_config.ranking and rag_retrieval_config.ranking.rank_service:
768+
api_retrival_config.ranking.rank_service.model_name = (
770769
rag_retrieval_config.ranking.rank_service.model_name
771770
)
772-
elif rag_retrieval_config.ranking and rag_retrieval_config.ranking.llm_ranker:
773-
api_retrival_config.ranking.llm_ranker.model_name = (
771+
elif rag_retrieval_config.ranking and rag_retrieval_config.ranking.llm_ranker:
772+
api_retrival_config.ranking.llm_ranker.model_name = (
774773
rag_retrieval_config.ranking.llm_ranker.model_name
775774
)
776-
if rag_retrieval_config.filter and rag_retrieval_config.filter.metadata_filter:
777-
api_retrival_config.filter.metadata_filter = (
775+
if rag_retrieval_config.filter and rag_retrieval_config.filter.metadata_filter:
776+
api_retrival_config.filter.metadata_filter = (
778777
rag_retrieval_config.filter.metadata_filter
779778
)
780779

781-
query = aiplatform_v1beta1.RagQuery(
780+
query = aiplatform_v1beta1.RagQuery(
782781
text=text,
783782
rag_retrieval_config=api_retrival_config,
784783
)
785784

786-
vertex_rag_store.rag_retrieval_config = api_retrival_config
785+
vertex_rag_store.rag_retrieval_config = api_retrival_config
787786

788-
tool = aiplatform_v1beta1.Tool(
787+
tool = aiplatform_v1beta1.Tool(
789788
retrieval=aiplatform_v1beta1.Retrieval(
790789
vertex_rag_store=vertex_rag_store,
791790
)
792791
)
793792

794-
request = aiplatform_v1beta1.AskContextsRequest(
793+
request = aiplatform_v1beta1.AskContextsRequest(
795794
parent=parent,
796795
query=query,
797796
tools=[tool],
798797
)
799-
try:
800-
response = client.ask_contexts(request=request, timeout=timeout)
801-
except Exception as e:
802-
raise RuntimeError("Failed in asking contexts due to: ", e) from e
803-
804-
return response
798+
try:
799+
response = client.ask_contexts(request=request, timeout=timeout)
800+
except Exception as e:
801+
print(f"DEBUG: ask_contexts failed with error: {e}")
802+
if hasattr(e, "trailing_metadata"):
803+
metadata = e.trailing_metadata()
804+
print(f"DEBUG: trailing metadata: {metadata}")
805+
for key, value in metadata:
806+
if key == "x-google-trace-id" or "trace" in key:
807+
print(f"DEBUG: Found trace header: {key}={value}")
808+
raise RuntimeError("Failed in asking contexts due to: ", e) from e
809+
810+
return response

0 commit comments

Comments
 (0)