2525from vertexai .preview .rag .utils import _gapic_utils
2626from vertexai .preview .rag .utils import resources
2727
28- from google .protobuf import any_pb2
29-
3028
3129def retrieval_query (
3230 text : str ,
@@ -537,14 +535,15 @@ async def async_retrieve_contexts(
537535 raw_op = response_lro .operation
538536 if raw_op .WhichOneof ("result" ) == "response" :
539537 any_response = raw_op .response
540- inner_any = any_pb2 .Any ()
541- if any_response .Unpack (inner_any ):
542- inner_any .type_url = "type.googleapis.com/google.cloud.aiplatform.v1beta1.RagContexts"
543- rag_contexts = aiplatform_v1beta1 .RagContexts ()
544- if inner_any .Unpack (rag_contexts ._pb ):
545- return aiplatform_v1beta1 .AsyncRetrieveContextsResponse (
546- contexts = rag_contexts
547- )
538+ # HACK: Override type_url to match SDK's expected type
539+ any_response .type_url = "type.googleapis.com/google.cloud.aiplatform.v1beta1.RagContexts"
540+ rag_contexts = aiplatform_v1beta1 .RagContexts ()
541+ if any_response .Unpack (rag_contexts ._pb ):
542+ if not rag_contexts .contexts :
543+ raise ValueError ("No rag contexts were returned." )
544+ return aiplatform_v1beta1 .AsyncRetrieveContextsResponse (
545+ contexts = rag_contexts
546+ )
548547 raise e
549548 except Exception as e :
550549 raise RuntimeError (
@@ -564,7 +563,7 @@ def ask_contexts(
564563 rag_retrieval_config : Optional [resources .RagRetrievalConfig ] = None ,
565564 timeout : int = 600 ,
566565) -> aiplatform_v1beta1 .AskContextsResponse :
567- """Ask questions on top k relevant docs/chunks.
566+ """Ask questions on top k relevant docs/chunks.
568567
569568 Example usage:
570569 ```
@@ -610,99 +609,99 @@ def ask_contexts(
610609 Returns:
611610 AskContextsResponse.
612611 """
613- parent = initializer .global_config .common_location_path ()
612+ parent = initializer .global_config .common_location_path ()
614613
615- client = _gapic_utils .create_rag_service_client ()
614+ client = _gapic_utils .create_rag_service_client ()
616615
617- if not rag_resources and not rag_corpora :
618- raise ValueError ("rag_resources or rag_corpora must be specified." )
616+ if not rag_resources and not rag_corpora :
617+ raise ValueError ("rag_resources or rag_corpora must be specified." )
619618
620- data_client = _gapic_utils .create_rag_data_service_client ()
619+ data_client = _gapic_utils .create_rag_data_service_client ()
621620
622- gapic_rag_resources = []
623- gapic_rag_corpora = []
624- if rag_resources :
625- for rag_resource in rag_resources :
626- name = rag_resource .rag_corpus
627- if data_client .parse_rag_corpus_path (name ):
628- rag_corpus_name = name
629- elif re .match (
621+ gapic_rag_resources = []
622+ gapic_rag_corpora = []
623+ if rag_resources :
624+ for rag_resource in rag_resources :
625+ name = rag_resource .rag_corpus
626+ if data_client .parse_rag_corpus_path (name ):
627+ rag_corpus_name = name
628+ elif re .match (
630629 "^{}$" .format (
631630 _gapic_utils ._VALID_RESOURCE_NAME_REGEX # pylint: disable=protected-access
632631 ),
633632 name ,
634633 ):
635- rag_corpus_name = parent + "/ragCorpora/" + name
636- else :
637- raise ValueError (
634+ rag_corpus_name = parent + "/ragCorpora/" + name
635+ else :
636+ raise ValueError (
638637 f"Invalid RagCorpus name: { name } . Proper format should be:"
639638 " projects/{project}/locations/{location}/ragCorpora/{rag_corpus_id}"
640639 )
641- gapic_rag_resources .append (
640+ gapic_rag_resources .append (
642641 aiplatform_v1beta1 .VertexRagStore .RagResource (
643642 rag_corpus = rag_corpus_name ,
644643 rag_file_ids = rag_resource .rag_file_ids ,
645644 )
646645 )
647- vertex_rag_store = aiplatform_v1beta1 .VertexRagStore (
646+ vertex_rag_store = aiplatform_v1beta1 .VertexRagStore (
648647 rag_resources = gapic_rag_resources ,
649648 )
650- elif rag_corpora :
651- warnings .warn (
649+ elif rag_corpora :
650+ warnings .warn (
652651 "rag_corpora is deprecated. Please use rag_resources instead."
653652 f" After { resources .DEPRECATION_DATE } using"
654653 " rag_corpora will raise error" ,
655654 DeprecationWarning ,
656655 )
657- for name in rag_corpora :
658- if data_client .parse_rag_corpus_path (name ):
659- rag_corpus_name = name
660- elif re .match (
656+ for name in rag_corpora :
657+ if data_client .parse_rag_corpus_path (name ):
658+ rag_corpus_name = name
659+ elif re .match (
661660 "^{}$" .format (
662661 _gapic_utils ._VALID_RESOURCE_NAME_REGEX # pylint: disable=protected-access
663662 ),
664663 name ,
665664 ):
666- rag_corpus_name = parent + "/ragCorpora/" + name
667- else :
668- raise ValueError (
665+ rag_corpus_name = parent + "/ragCorpora/" + name
666+ else :
667+ raise ValueError (
669668 f"Invalid RagCorpus name: { name } . Proper format should be:"
670669 " projects/{project}/locations/{location}/ragCorpora/{rag_corpus_id}"
671670 )
672- gapic_rag_corpora .append (rag_corpus_name )
673- vertex_rag_store = aiplatform_v1beta1 .VertexRagStore (
671+ gapic_rag_corpora .append (rag_corpus_name )
672+ vertex_rag_store = aiplatform_v1beta1 .VertexRagStore (
674673 rag_corpora = gapic_rag_corpora ,
675674 )
676675
677- # Check for deprecated parameters and raise warnings.
678- if similarity_top_k :
679- warnings .warn (
676+ # Check for deprecated parameters and raise warnings.
677+ if similarity_top_k :
678+ warnings .warn (
680679 "similarity_top_k is deprecated. Please use"
681680 " rag_retrieval_config.top_k instead."
682681 f" After { resources .DEPRECATION_DATE } using"
683682 " similarity_top_k will raise error" ,
684683 DeprecationWarning ,
685684 )
686- if vector_search_alpha :
687- warnings .warn (
685+ if vector_search_alpha :
686+ warnings .warn (
688687 "vector_search_alpha is deprecated. Please use"
689688 " rag_retrieval_config.alpha instead."
690689 f" After { resources .DEPRECATION_DATE } using"
691690 " vector_search_alpha will raise error" ,
692691 DeprecationWarning ,
693692 )
694- if vector_distance_threshold :
695- warnings .warn (
693+ if vector_distance_threshold :
694+ warnings .warn (
696695 "vector_distance_threshold is deprecated. Please use"
697696 " rag_retrieval_config.filter.vector_distance_threshold instead."
698697 f" After { resources .DEPRECATION_DATE } using"
699698 " vector_distance_threshold will raise error" ,
700699 DeprecationWarning ,
701700 )
702701
703- # If rag_retrieval_config is not specified, set it to default values.
704- if not rag_retrieval_config :
705- api_retrival_config = aiplatform_v1beta1 .RagRetrievalConfig (
702+ # If rag_retrieval_config is not specified, set it to default values.
703+ if not rag_retrieval_config :
704+ api_retrival_config = aiplatform_v1beta1 .RagRetrievalConfig (
706705 top_k = similarity_top_k ,
707706 hybrid_search = aiplatform_v1beta1 .RagRetrievalConfig .HybridSearch (
708707 alpha = vector_search_alpha ,
@@ -711,94 +710,101 @@ def ask_contexts(
711710 vector_distance_threshold = vector_distance_threshold
712711 ),
713712 )
713+ else :
714+ api_retrival_config = aiplatform_v1beta1 .RagRetrievalConfig ()
715+ if rag_retrieval_config .top_k :
716+ api_retrival_config .top_k = rag_retrieval_config .top_k
714717 else :
715- api_retrival_config = aiplatform_v1beta1 .RagRetrievalConfig ()
716- if rag_retrieval_config .top_k :
717- api_retrival_config .top_k = rag_retrieval_config .top_k
718- else :
719- api_retrival_config .top_k = similarity_top_k
718+ api_retrival_config .top_k = similarity_top_k
720719
721- if (
720+ if (
722721 rag_retrieval_config .hybrid_search
723722 and rag_retrieval_config .hybrid_search .alpha
724723 ):
725- api_retrival_config .hybrid_search .alpha = (
724+ api_retrival_config .hybrid_search .alpha = (
726725 rag_retrieval_config .hybrid_search .alpha
727726 )
728- else :
729- api_retrival_config .hybrid_search .alpha = vector_search_alpha
727+ else :
728+ api_retrival_config .hybrid_search .alpha = vector_search_alpha
730729
731- if (
730+ if (
732731 rag_retrieval_config .filter
733732 and rag_retrieval_config .filter .vector_distance_threshold
734733 and rag_retrieval_config .filter .vector_similarity_threshold
735734 ):
736- raise ValueError (
735+ raise ValueError (
737736 "Only one of vector_distance_threshold or"
738737 " vector_similarity_threshold can be specified at a time"
739738 " in rag_retrieval_config."
740739 )
741740
742- if (
741+ if (
743742 rag_retrieval_config .filter
744743 and rag_retrieval_config .filter .vector_distance_threshold
745744 ):
746- api_retrival_config .filter .vector_distance_threshold = (
745+ api_retrival_config .filter .vector_distance_threshold = (
747746 rag_retrieval_config .filter .vector_distance_threshold
748747 )
749- else :
750- api_retrival_config .filter .vector_distance_threshold = (
748+ else :
749+ api_retrival_config .filter .vector_distance_threshold = (
751750 vector_distance_threshold
752751 )
753752
754- if (
753+ if (
755754 rag_retrieval_config .filter
756755 and rag_retrieval_config .filter .vector_similarity_threshold
757756 ):
758- api_retrival_config .filter .vector_similarity_threshold = (
757+ api_retrival_config .filter .vector_similarity_threshold = (
759758 rag_retrieval_config .filter .vector_similarity_threshold
760759 )
761760
762- if (
761+ if (
763762 rag_retrieval_config .ranking
764763 and rag_retrieval_config .ranking .rank_service
765764 and rag_retrieval_config .ranking .llm_ranker
766765 ):
767- raise ValueError ("Only one of rank_service and llm_ranker can be set." )
768- if rag_retrieval_config .ranking and rag_retrieval_config .ranking .rank_service :
769- api_retrival_config .ranking .rank_service .model_name = (
766+ raise ValueError ("Only one of rank_service and llm_ranker can be set." )
767+ if rag_retrieval_config .ranking and rag_retrieval_config .ranking .rank_service :
768+ api_retrival_config .ranking .rank_service .model_name = (
770769 rag_retrieval_config .ranking .rank_service .model_name
771770 )
772- elif rag_retrieval_config .ranking and rag_retrieval_config .ranking .llm_ranker :
773- api_retrival_config .ranking .llm_ranker .model_name = (
771+ elif rag_retrieval_config .ranking and rag_retrieval_config .ranking .llm_ranker :
772+ api_retrival_config .ranking .llm_ranker .model_name = (
774773 rag_retrieval_config .ranking .llm_ranker .model_name
775774 )
776- if rag_retrieval_config .filter and rag_retrieval_config .filter .metadata_filter :
777- api_retrival_config .filter .metadata_filter = (
775+ if rag_retrieval_config .filter and rag_retrieval_config .filter .metadata_filter :
776+ api_retrival_config .filter .metadata_filter = (
778777 rag_retrieval_config .filter .metadata_filter
779778 )
780779
781- query = aiplatform_v1beta1 .RagQuery (
780+ query = aiplatform_v1beta1 .RagQuery (
782781 text = text ,
783782 rag_retrieval_config = api_retrival_config ,
784783 )
785784
786- vertex_rag_store .rag_retrieval_config = api_retrival_config
785+ vertex_rag_store .rag_retrieval_config = api_retrival_config
787786
788- tool = aiplatform_v1beta1 .Tool (
787+ tool = aiplatform_v1beta1 .Tool (
789788 retrieval = aiplatform_v1beta1 .Retrieval (
790789 vertex_rag_store = vertex_rag_store ,
791790 )
792791 )
793792
794- request = aiplatform_v1beta1 .AskContextsRequest (
793+ request = aiplatform_v1beta1 .AskContextsRequest (
795794 parent = parent ,
796795 query = query ,
797796 tools = [tool ],
798797 )
799- try :
800- response = client .ask_contexts (request = request , timeout = timeout )
801- except Exception as e :
802- raise RuntimeError ("Failed in asking contexts due to: " , e ) from e
803-
804- return response
798+ try :
799+ response = client .ask_contexts (request = request , timeout = timeout )
800+ except Exception as e :
801+ print (f"DEBUG: ask_contexts failed with error: { e } " )
802+ if hasattr (e , "trailing_metadata" ):
803+ metadata = e .trailing_metadata ()
804+ print (f"DEBUG: trailing metadata: { metadata } " )
805+ for key , value in metadata :
806+ if key == "x-google-trace-id" or "trace" in key :
807+ print (f"DEBUG: Found trace header: { key } ={ value } " )
808+ raise RuntimeError ("Failed in asking contexts due to: " , e ) from e
809+
810+ return response
0 commit comments