diff --git a/python/pathway/xpacks/llm/question_answering.py b/python/pathway/xpacks/llm/question_answering.py index 6bb607a0f..12fe609b7 100644 --- a/python/pathway/xpacks/llm/question_answering.py +++ b/python/pathway/xpacks/llm/question_answering.py @@ -462,6 +462,11 @@ class BaseRAGQuestionAnswerer(SummaryQuestionAnswerer): search_topk: Top k parameter for the retrieval. Adjusts number of chunks in the context. rerank_topk: Number of top-scoring documents to retain after reranking, when a reranker is provided. If ``None``, reranking is disabled. Defaults to ``None``. + query_transformer: Transformer for query rewriting before retrieval. + Can be a ``pw.UDF``, callable, or ``None`` to skip query transformation. + Available options: ``pathway.xpacks.llm.prompts.prompt_query_rewrite``, + ``pathway.xpacks.llm.prompts.prompt_query_rewrite_hyde``. + Defaults to ``None``. Example: @@ -523,11 +528,13 @@ def __init__( search_topk: int = 6, reranker: pw.UDF | None = None, rerank_topk: int | None = None, + query_transformer: pw.UDF | Callable | None = None, ) -> None: self.llm = llm self.indexer = indexer self.reranker = reranker + self.query_transformer = query_transformer if default_llm_name is None: default_llm_name = llm.model @@ -637,7 +644,10 @@ def add_score_to_doc(doc: pw.Json, score: float) -> dict: @pw.table_transformer def answer_query(self, pw_ai_queries: pw.Table) -> pw.Table: """Answer a question based on the available information.""" - + if self.query_transformer is not None: + pw_ai_queries = pw_ai_queries.with_columns( + prompt=self.query_transformer(pw.this.prompt) ) pw_ai_results = pw_ai_queries + self.indexer.retrieve_query( pw_ai_queries.select( metadata_filter=pw.this.filters,