diff --git a/backend/api_v2/api_deployment_views.py b/backend/api_v2/api_deployment_views.py index 7462864f04..25d264ba25 100644 --- a/backend/api_v2/api_deployment_views.py +++ b/backend/api_v2/api_deployment_views.py @@ -214,9 +214,14 @@ def get( response_status = status.HTTP_422_UNPROCESSABLE_ENTITY if execution_status_value == CeleryTaskState.COMPLETED.value: response_status = status.HTTP_200_OK - # Check if highlight data should be removed using configuration registry + # Ensure workflow identification keys are always in item metadata api_deployment = deployment_execution_dto.api organization = api_deployment.organization if api_deployment else None + org_id = str(organization.organization_id) if organization else "" + DeploymentHelper._enrich_result_with_workflow_metadata( + response, organization_id=org_id + ) + # Check if highlight data should be removed using configuration registry enable_highlight = False # Safe default if the key is unavailable (e.g., OSS) # Check if the configuration key exists (Cloud deployment) or use settings (OSS) from configuration.config_registry import ConfigurationRegistry @@ -231,8 +236,10 @@ def get( if not enable_highlight: response.remove_result_metadata_keys(["highlight_data"]) response.remove_result_metadata_keys(["extracted_text"]) - if not include_metadata: - response.remove_result_metadata_keys() + if include_metadata or include_metrics: + DeploymentHelper._enrich_result_with_usage_metadata(response) + if not include_metadata and not include_metrics: + response.remove_inner_result_metadata() if not include_metrics: response.remove_result_metrics() return Response( diff --git a/backend/api_v2/deployment_helper.py b/backend/api_v2/deployment_helper.py index bfbff58b7b..8b1c540347 100644 --- a/backend/api_v2/deployment_helper.py +++ b/backend/api_v2/deployment_helper.py @@ -258,8 +258,11 @@ def execute_workflow( result.status_api = DeploymentHelper.construct_status_endpoint( api_endpoint=api.api_endpoint, 
@staticmethod
def _enrich_result_with_usage_metadata(result: ExecutionResponse) -> None:
    """Enrich each file result's metadata with usage data.

    For every dict item in ``result.result`` that carries a
    ``file_execution_id``:

    1. Per-model usage returned by ``UsageHelper.get_usage_by_model`` is
       merged into ``item["result"]["metadata"]`` when that dict exists.
    2. Aggregated totals returned by
       ``UsageHelper.get_aggregated_token_count`` are stored under
       ``item["metadata"]["usage"]`` when ``item["metadata"]`` is a dict.

    The items are mutated in place. Nothing happens when ``result.result``
    is not a list or no item has a ``file_execution_id``.

    Args:
        result: Execution response whose ``result`` list is enriched.
    """
    if not isinstance(result.result, list):
        return

    targets = [
        item
        for item in result.result
        if isinstance(item, dict) and item.get("file_execution_id")
    ]
    if not targets:
        # Nothing to enrich: skip the lazy import and any usage lookups.
        return

    from usage_v2.helper import UsageHelper

    for item in targets:
        file_exec_id = item["file_execution_id"]

        # Inner result metadata gets the per-model breakdown.
        inner_result = item.get("result")
        if isinstance(inner_result, dict):
            inner_metadata = inner_result.get("metadata")
            if isinstance(inner_metadata, dict):
                usage_by_model = UsageHelper.get_usage_by_model(file_exec_id)
                if usage_by_model:
                    inner_metadata.update(usage_by_model)

        # Top-level item metadata gets the aggregated totals.
        item_metadata = item.get("metadata")
        if isinstance(item_metadata, dict):
            aggregated = UsageHelper.get_aggregated_token_count(file_exec_id)
            if aggregated:
                # Copy instead of mutating the helper's return value.
                item_metadata["usage"] = {
                    **aggregated,
                    "file_execution_id": file_exec_id,
                }


@staticmethod
def _enrich_result_with_workflow_metadata(
    result: ExecutionResponse,
    organization_id: str,
) -> None:
    """Ensure workflow identification keys are present in item metadata.

    Only ``setdefault()`` is used, so keys that are already present in an
    item's metadata are never overwritten. Items are mutated in place; an
    item whose ``metadata`` is missing or not a dict gets a fresh dict.

    Args:
        result: Execution response whose ``result`` list is enriched.
        organization_id: Value used for the ``organization_id`` key.
    """
    if not isinstance(result.result, list):
        return

    # 1. Collect file_execution_ids
    file_exec_ids = [
        item.get("file_execution_id")
        for item in result.result
        if isinstance(item, dict) and item.get("file_execution_id")
    ]
    if not file_exec_ids:
        return

    # Imported lazily, and only once there is something to look up.
    from workflow_manager.file_execution.models import WorkflowFileExecution

    # 2. Batch query (single JOIN query for all file executions)
    fe_lookup = {
        str(fe.id): fe
        for fe in WorkflowFileExecution.objects.filter(
            id__in=file_exec_ids
        ).select_related("workflow_execution")
    }

    # 3. Execution-level data (tags) is read from the first row found and
    #    reused as the fallback for items without a matching row.
    workflow_execution = None
    tag_names: list[str] = []
    if fe_lookup:
        workflow_execution = next(iter(fe_lookup.values())).workflow_execution
        # Guard against a missing execution, consistent with ``if we:`` below.
        if workflow_execution is not None:
            tag_names = list(
                workflow_execution.tags.values_list("name", flat=True)
            )

    # 4. Enrich each item
    for item in result.result:
        if not isinstance(item, dict):
            continue
        file_exec_id = item.get("file_execution_id")
        if not file_exec_id:
            continue

        # Ensure metadata dict exists
        if not isinstance(item.get("metadata"), dict):
            item["metadata"] = {}
        metadata = item["metadata"]

        fe = fe_lookup.get(str(file_exec_id))
        we = fe.workflow_execution if fe else workflow_execution

        # Fill MISSING keys only (setdefault won't overwrite)
        if fe:
            metadata.setdefault("source_name", fe.file_name)
            metadata.setdefault("source_hash", fe.file_hash or "")
            metadata.setdefault("file_execution_id", str(fe.id))
            metadata.setdefault("total_elapsed_time", fe.execution_time)
        if we:
            metadata.setdefault("workflow_id", str(we.workflow_id))
            metadata.setdefault("execution_id", str(we.id))
            metadata.setdefault(
                "workflow_start_time",
                we.created_at.timestamp() if we.created_at else None,
            )

        metadata.setdefault("organization_id", organization_id)
        # Per-item copy so items never share one mutable list.
        metadata.setdefault("tags", list(tag_names))
"""Lightweight Celery app for dispatching tasks to worker-v2 workers.

The Django backend already has a Celery app for internal tasks (beat,
periodic tasks, etc.) whose broker URL is set via CELERY_BROKER_URL.
Workers use the same broker. This module provides a second Celery app
instance that reuses the same broker URL (from Django settings) but
bypasses Celery's env-var-takes-priority behaviour so it can coexist
with the main Django Celery app in the same process.

Problem: Celery reads the ``CELERY_BROKER_URL`` environment variable
with highest priority — overriding constructor args, ``conf.update()``,
and ``config_from_object()``.

Solution: Subclass Celery and override ``connection_for_write`` /
``connection_for_read`` so they always use our explicit broker URL,
bypassing the config resolution chain entirely.
"""

import logging
from urllib.parse import quote_plus

from celery import Celery
from django.conf import settings
from kombu import Queue

logger = logging.getLogger(__name__)

# Process-wide dispatch app, created on first use by get_worker_celery_app().
_worker_app: Celery | None = None


class _WorkerDispatchCelery(Celery):
    """Celery subclass that forces an explicit broker URL.

    Works around Celery's env-var-takes-priority behaviour where
    ``CELERY_BROKER_URL`` always overrides per-app configuration.
    Both connection factories fall back to ``_explicit_broker`` whenever
    the caller does not pass a ``url`` of its own.
    """

    # Set by get_worker_celery_app() right after construction.
    _explicit_broker: str | None = None

    def connection_for_write(self, url=None, *args, **kwargs):
        """Open a write connection, defaulting ``url`` to the explicit broker."""
        return super().connection_for_write(
            url=url or self._explicit_broker, *args, **kwargs
        )

    def connection_for_read(self, url=None, *args, **kwargs):
        """Open a read connection, defaulting ``url`` to the explicit broker."""
        return super().connection_for_read(
            url=url or self._explicit_broker, *args, **kwargs
        )


def _strip_credentials(url: str) -> str:
    """Return the part of *url* after the last ``@`` (all of it if none)."""
    return url.rsplit("@", 1)[-1]


def get_worker_celery_app() -> Celery:
    """Get or create a Celery app for dispatching to worker-v2 workers.

    The app uses:
    - Same broker as the workers (``settings.CELERY_BROKER_URL``)
    - A PostgreSQL result backend assembled from the Django DB settings

    The instance is created once and cached in the module-level
    ``_worker_app``; later calls return the cached app.

    Returns:
        Celery app configured for worker-v2 dispatch.
    """
    global _worker_app
    if _worker_app is not None:
        return _worker_app

    # Reuse the broker URL already built by Django settings
    broker_url = settings.CELERY_BROKER_URL

    # Escape BOTH user and password: either may contain URL-reserved
    # characters ("@", ":", "/") that would otherwise corrupt the URL.
    result_backend = (
        f"db+postgresql://{quote_plus(settings.DB_USER)}:"
        f"{quote_plus(settings.DB_PASSWORD)}"
        f"@{settings.DB_HOST}:{settings.DB_PORT}/"
        f"{settings.CELERY_BACKEND_DB_NAME}"
    )

    app = _WorkerDispatchCelery(
        "worker-dispatch",
        set_as_current=False,
        fixups=[],
    )
    # Store the explicit broker URL for use in connection overrides
    app._explicit_broker = broker_url

    app.conf.update(
        result_backend=result_backend,
        task_queues=[Queue("executor")],
        task_serializer="json",
        accept_content=["json"],
        result_serializer="json",
        result_extended=True,
    )

    _worker_app = app
    # Log host part only (mask credentials)
    logger.info(
        "Created worker dispatch Celery app (broker=%s, result_backend=%s)",
        _strip_credentials(broker_url),
        _strip_credentials(result_backend),
    )
    return _worker_app
@staticmethod
def _is_async_execution_enabled() -> bool:
    """Check if the async execution feature flag is enabled.

    Returns:
        The flag status, or ``False`` when the flag lookup raises.
    """
    try:
        return FeatureFlagHelper.check_flag_status(ASYNC_EXECUTION_FLAG)
    except Exception:
        # Best effort: keep the sync flow, but record why the check failed.
        logger.warning(
            "Feature flag check failed, falling back to sync flow",
            exc_info=True,
        )
        return False


@staticmethod
def _get_dispatcher() -> ExecutionDispatcher:
    """Get an ExecutionDispatcher backed by the worker dispatch Celery app.

    The Celery app comes from ``backend.worker_celery.get_worker_celery_app``
    and is imported lazily here.
    """
    from backend.worker_celery import get_worker_celery_app

    return ExecutionDispatcher(celery_app=get_worker_celery_app())


@staticmethod
def _get_platform_api_key(org_id: str) -> str:
    """Get the platform API key for the given organization."""
    from platform_settings_v2.platform_auth_service import (
        PlatformAuthenticationService,
    )

    platform_key = PlatformAuthenticationService.get_active_platform_key(org_id)
    return str(platform_key.key)


# ------------------------------------------------------------------
# Phase 5B — Payload builders for fire-and-forget dispatch
# ------------------------------------------------------------------


@staticmethod
def build_index_payload(
    tool_id: str,
    file_name: str,
    org_id: str,
    user_id: str,
    document_id: str,
    run_id: str,
) -> tuple[ExecutionContext, dict[str, Any]]:
    """Build ide_index ExecutionContext for fire-and-forget dispatch.

    Does ORM validation and summarization synchronously, then returns
    the execution context so the caller can dispatch with callbacks.

    Args:
        tool_id: Primary key of the ``CustomTool``.
        file_name: Document name inside the tool's prompt-studio directory.
        org_id: Organization the document belongs to.
        user_id: User the indexing is performed for.
        document_id: Document being indexed.
        run_id: Run identifier; a UUID4 is generated when it is falsy and
            used consistently for the context, usage and callback kwargs.

    Returns:
        ``(context, cb_kwargs)`` — the ``ExecutionContext`` to dispatch and
        the keyword arguments for the completion callback.

    Raises:
        ToolNotValid: If no ``CustomTool`` exists for ``tool_id``.
    """
    # objects.get() never returns None; map the real failure mode to the
    # domain error the original falsy check was aiming for.
    try:
        tool: CustomTool = CustomTool.objects.get(pk=tool_id)
    except CustomTool.DoesNotExist as err:
        raise ToolNotValid() from err

    # One run id everywhere, so usage records and callbacks correlate.
    run_id = run_id or str(uuid.uuid4())

    file_path = PromptStudioFileHelper.get_or_create_prompt_studio_subdirectory(
        org_id,
        is_create=False,
        user_id=user_id,
        tool_id=tool_id,
    )
    file_path = str(Path(file_path) / file_name)

    default_profile = ProfileManager.get_default_llm_profile(tool)
    PromptStudioHelper.validate_adapter_status(default_profile)
    PromptStudioHelper.validate_profile_manager_owner_access(default_profile)

    # Shared by the summary and the main index-key computation.
    fs_instance = EnvHelper.get_storage(
        storage_type=StorageType.PERMANENT,
        env_name=FileStorageKeys.PERMANENT_REMOTE_STORAGE,
    )
    util = PromptIdeBaseTool(log_level=LogLevel.INFO, org_id=org_id)

    # Handle summarization synchronously (uses Django plugin)
    if tool.summarize_context:
        SummarizeMigrationUtils.migrate_tool_to_adapter_based(tool)
        summary_profile = default_profile
        if not tool.summarize_llm_adapter:
            try:
                sp = ProfileManager.objects.get(
                    prompt_studio_tool=tool, is_summarize_llm=True
                )
                sp.chunk_size = 0
                summary_profile = sp
            except ProfileManager.DoesNotExist:
                # No dedicated summarize profile: keep the default one.
                pass

        if summary_profile != default_profile:
            PromptStudioHelper.validate_adapter_status(summary_profile)
            PromptStudioHelper.validate_profile_manager_owner_access(summary_profile)

        summarize_file_path = PromptStudioHelper.summarize(
            file_name, org_id, run_id, tool
        )
        summarize_doc_id = IndexingUtils.generate_index_key(
            vector_db=str(summary_profile.vector_store.id),
            embedding=str(summary_profile.embedding_model.id),
            x2text=str(summary_profile.x2text.id),
            chunk_size="0",
            chunk_overlap=str(summary_profile.chunk_overlap),
            file_path=summarize_file_path,
            fs=fs_instance,
            tool=util,
        )
        PromptStudioIndexHelper.handle_index_manager(
            document_id=document_id,
            is_summary=True,
            profile_manager=summary_profile,
            doc_id=summarize_doc_id,
        )

    # Generate doc_id for indexing tracking
    doc_id_key = IndexingUtils.generate_index_key(
        vector_db=str(default_profile.vector_store.id),
        embedding=str(default_profile.embedding_model.id),
        x2text=str(default_profile.x2text.id),
        chunk_size=str(default_profile.chunk_size),
        chunk_overlap=str(default_profile.chunk_overlap),
        file_path=file_path,
        file_hash=None,
        fs=fs_instance,
        tool=util,
    )

    # Build extract params
    directory, filename = os.path.split(file_path)
    extract_file_path = os.path.join(
        directory, "extract", os.path.splitext(filename)[0] + ".txt"
    )
    platform_api_key = PromptStudioHelper._get_platform_api_key(org_id)
    usage_kwargs = {"run_id": run_id, "file_name": filename}

    extract_params = {
        IKeys.X2TEXT_INSTANCE_ID: str(default_profile.x2text.id),
        IKeys.FILE_PATH: file_path,
        IKeys.ENABLE_HIGHLIGHT: tool.enable_highlight,
        IKeys.OUTPUT_FILE_PATH: extract_file_path,
        "platform_api_key": platform_api_key,
        IKeys.USAGE_KWARGS: usage_kwargs,
    }

    index_params = {
        IKeys.TOOL_ID: tool_id,
        IKeys.EMBEDDING_INSTANCE_ID: str(default_profile.embedding_model.id),
        IKeys.VECTOR_DB_INSTANCE_ID: str(default_profile.vector_store.id),
        IKeys.X2TEXT_INSTANCE_ID: str(default_profile.x2text.id),
        IKeys.FILE_PATH: extract_file_path,
        IKeys.FILE_HASH: None,
        IKeys.CHUNK_OVERLAP: default_profile.chunk_overlap,
        IKeys.CHUNK_SIZE: default_profile.chunk_size,
        IKeys.REINDEX: True,
        IKeys.ENABLE_HIGHLIGHT: tool.enable_highlight,
        IKeys.USAGE_KWARGS: usage_kwargs,
        IKeys.RUN_ID: run_id,
        TSPKeys.EXECUTION_SOURCE: ExecutionSource.IDE.value,
        "platform_api_key": platform_api_key,
    }

    log_events_id = StateStore.get(Common.LOG_EVENTS_ID) or ""
    request_id = StateStore.get(Common.REQUEST_ID) or ""

    context = ExecutionContext(
        executor_name="legacy",
        operation="ide_index",
        run_id=run_id,
        execution_source="ide",
        organization_id=org_id,
        executor_params={
            "extract_params": extract_params,
            "index_params": index_params,
        },
        request_id=request_id,
        log_events_id=log_events_id,
    )

    # x2text config hash for extraction status tracking in callback
    x2text_metadata = default_profile.x2text.metadata or {}
    x2text_config_hash = ToolUtils.hash_str(
        json.dumps(x2text_metadata, sort_keys=True)
    )

    cb_kwargs = {
        "log_events_id": log_events_id,
        "request_id": request_id,
        "org_id": org_id,
        "user_id": user_id,
        "document_id": document_id,
        "doc_id_key": doc_id_key,
        "profile_manager_id": str(default_profile.profile_id),
        "tool_id": tool_id,
        "run_id": run_id,
        "file_name": file_name,
        "x2text_config_hash": x2text_config_hash,
        "enable_highlight": tool.enable_highlight,
    }

    # Mark as indexing in progress only after every step that can raise has
    # run, so a failure above cannot leave the document flagged forever.
    DocumentIndexingService.set_document_indexing(
        org_id=org_id, user_id=user_id, doc_id_key=doc_id_key
    )

    return context, cb_kwargs
@staticmethod
def build_fetch_response_payload(
    tool: CustomTool,
    doc_path: str,
    doc_name: str,
    prompt: ToolStudioPrompt,
    org_id: str,
    user_id: str,
    document_id: str,
    run_id: str,
    profile_manager_id: str | None = None,
) -> tuple[ExecutionContext | None, dict[str, Any]]:
    """Build answer_prompt ExecutionContext for fire-and-forget dispatch.

    Does ORM work, extraction, and indexing synchronously. Only the
    LLM answer_prompt call is dispatched asynchronously.

    Args:
        tool: Prompt Studio tool the prompt belongs to.
        doc_path: Path of the source document.
        doc_name: Display name of the document.
        prompt: Prompt to run.
        org_id: Organization identifier.
        user_id: User identifier.
        document_id: Document identifier.
        run_id: Run identifier forwarded to the executor.
        profile_manager_id: Optional profile override; the prompt's own
            profile manager is used when omitted.

    Returns:
        ``(context, cb_kwargs)`` when the prompt can be dispatched, or
        ``(None, pending_response_dict)`` when indexing reports a pending
        status.

    Raises:
        DefaultProfileError: If no profile manager can be resolved.
    """
    profile_manager = prompt.profile_manager
    if profile_manager_id:
        profile_manager = ProfileManagerHelper.get_profile_manager(
            profile_manager_id=profile_manager_id
        )

    # Fail with the domain error before any validator dereferences it.
    if not profile_manager:
        raise DefaultProfileError()

    PromptStudioHelper.validate_adapter_status(profile_manager)
    PromptStudioHelper.validate_profile_manager_owner_access(profile_manager)

    monitor_llm_instance: AdapterInstance | None = tool.monitor_llm
    challenge_llm_instance: AdapterInstance | None = tool.challenge_llm
    # Fallback LLM id is resolved at most once, and only if it is needed.
    fallback_llm: str | None = None
    if not (monitor_llm_instance and challenge_llm_instance):
        fallback_llm = str(ProfileManager.get_default_llm_profile(tool).llm.id)
    monitor_llm = (
        str(monitor_llm_instance.id) if monitor_llm_instance else fallback_llm
    )
    challenge_llm = (
        str(challenge_llm_instance.id) if challenge_llm_instance else fallback_llm
    )

    vector_db = str(profile_manager.vector_store.id)
    embedding_model = str(profile_manager.embedding_model.id)
    llm = str(profile_manager.llm.id)
    x2text = str(profile_manager.x2text.id)

    fs_instance = EnvHelper.get_storage(
        storage_type=StorageType.PERMANENT,
        env_name=FileStorageKeys.PERMANENT_REMOTE_STORAGE,
    )
    util = PromptIdeBaseTool(log_level=LogLevel.INFO, org_id=org_id)
    file_path = doc_path
    directory, filename = os.path.split(doc_path)
    extract_path = os.path.join(
        directory, "extract", os.path.splitext(filename)[0] + ".txt"
    )

    doc_id = IndexingUtils.generate_index_key(
        vector_db=vector_db,
        embedding=embedding_model,
        x2text=x2text,
        chunk_size=str(profile_manager.chunk_size),
        chunk_overlap=str(profile_manager.chunk_overlap),
        file_path=file_path,
        file_hash=None,
        fs=fs_instance,
        tool=util,
    )

    # Extract (blocking, usually cached)
    extracted_text = PromptStudioHelper.dynamic_extractor(
        profile_manager=profile_manager,
        file_path=file_path,
        org_id=org_id,
        document_id=document_id,
        run_id=run_id,
        enable_highlight=tool.enable_highlight,
    )

    is_summary = tool.summarize_as_source
    if is_summary:
        # Summary source: no chunking, and read from the "summarize" folder.
        profile_manager.chunk_size = 0
        p = Path(extract_path)
        extract_path = str(p.parent.parent / "summarize" / (p.stem + ".txt"))

    # Index (blocking, usually cached)
    index_result = PromptStudioHelper.dynamic_indexer(
        profile_manager=profile_manager,
        tool_id=str(tool.tool_id),
        file_path=file_path,
        org_id=org_id,
        document_id=document_id,
        run_id=run_id,
        user_id=user_id,
        enable_highlight=tool.enable_highlight,
        extracted_text=extracted_text,
        doc_id_key=doc_id,
    )

    if index_result.get("status") == IndexingStatus.PENDING_STATUS.value:
        return None, {
            "status": IndexingStatus.PENDING_STATUS.value,
            "message": IndexingStatus.DOCUMENT_BEING_INDEXED.value,
        }

    # Build outputs
    tool_id = str(tool.tool_id)
    prompt_grammer = tool.prompt_grammer
    grammar_list: list[dict[str, Any]] = [
        {TSPKeys.WORD: word, TSPKeys.SYNONYMS: synonyms}
        for word, synonyms in (prompt_grammer or {}).items()
    ]

    output: dict[str, Any] = {}
    output[TSPKeys.PROMPT] = prompt.prompt
    output[TSPKeys.ACTIVE] = prompt.active
    output[TSPKeys.REQUIRED] = prompt.required
    output[TSPKeys.CHUNK_SIZE] = profile_manager.chunk_size
    output[TSPKeys.VECTOR_DB] = vector_db
    output[TSPKeys.EMBEDDING] = embedding_model
    output[TSPKeys.CHUNK_OVERLAP] = profile_manager.chunk_overlap
    output[TSPKeys.LLM] = llm
    output[TSPKeys.TYPE] = prompt.enforce_type
    output[TSPKeys.NAME] = prompt.prompt_key
    output[TSPKeys.RETRIEVAL_STRATEGY] = profile_manager.retrieval_strategy
    output[TSPKeys.SIMILARITY_TOP_K] = profile_manager.similarity_top_k
    output[TSPKeys.SECTION] = profile_manager.section
    output[TSPKeys.X2TEXT_ADAPTER] = x2text

    # The webhook only counts as enabled when a non-blank URL is configured.
    webhook_url = (prompt.postprocessing_webhook_url or "").strip()
    webhook_enabled = bool(prompt.enable_postprocessing_webhook) and bool(
        webhook_url
    )
    output[TSPKeys.ENABLE_POSTPROCESSING_WEBHOOK] = webhook_enabled
    if webhook_enabled:
        output[TSPKeys.POSTPROCESSING_WEBHOOK_URL] = webhook_url

    eval_settings: dict[str, Any] = {
        TSPKeys.EVAL_SETTINGS_EVALUATE: prompt.evaluate,
        TSPKeys.EVAL_SETTINGS_MONITOR_LLM: [monitor_llm],
        TSPKeys.EVAL_SETTINGS_EXCLUDE_FAILED: tool.exclude_failed,
    }
    # Copy every eval-metric attribute found on the prompt.
    for attr in dir(prompt):
        if attr.startswith(TSPKeys.EVAL_METRIC_PREFIX):
            eval_settings[attr] = getattr(prompt, attr)
    output[TSPKeys.EVAL_SETTINGS] = eval_settings

    output = PromptStudioHelper.fetch_table_settings_if_enabled(
        doc_name, prompt, org_id, user_id, tool_id, output
    )
    variable_map = PromptStudioVariableService.frame_variable_replacement_map(
        doc_id=document_id, prompt_object=prompt
    )
    if variable_map:
        output[TSPKeys.VARIABLE_MAP] = variable_map
    outputs: list[dict[str, Any]] = [output]

    tool_settings: dict[str, Any] = {}
    tool_settings[TSPKeys.ENABLE_CHALLENGE] = tool.enable_challenge
    tool_settings[TSPKeys.CHALLENGE_LLM] = challenge_llm
    tool_settings[TSPKeys.SINGLE_PASS_EXTRACTION_MODE] = (
        tool.single_pass_extraction_mode
    )
    tool_settings[TSPKeys.SUMMARIZE_AS_SOURCE] = tool.summarize_as_source
    tool_settings[TSPKeys.PREAMBLE] = tool.preamble
    tool_settings[TSPKeys.POSTAMBLE] = tool.postamble
    tool_settings[TSPKeys.GRAMMAR] = grammar_list
    tool_settings[TSPKeys.ENABLE_HIGHLIGHT] = tool.enable_highlight
    tool_settings[TSPKeys.ENABLE_WORD_CONFIDENCE] = tool.enable_word_confidence
    tool_settings[TSPKeys.PLATFORM_POSTAMBLE] = getattr(
        settings, TSPKeys.PLATFORM_POSTAMBLE.upper(), ""
    )
    tool_settings[TSPKeys.WORD_CONFIDENCE_POSTAMBLE] = getattr(
        settings, TSPKeys.WORD_CONFIDENCE_POSTAMBLE.upper(), ""
    )

    file_hash = fs_instance.get_hash_from_file(path=extract_path)

    # Read once; the payload keeps the raw value, the context gets "".
    raw_log_events_id = StateStore.get(Common.LOG_EVENTS_ID)
    log_events_id = raw_log_events_id or ""
    request_id = StateStore.get(Common.REQUEST_ID) or ""

    payload: dict[str, Any] = {
        TSPKeys.TOOL_SETTINGS: tool_settings,
        TSPKeys.OUTPUTS: outputs,
        TSPKeys.TOOL_ID: tool_id,
        TSPKeys.RUN_ID: run_id,
        TSPKeys.FILE_NAME: doc_name,
        TSPKeys.FILE_HASH: file_hash,
        TSPKeys.FILE_PATH: extract_path,
        Common.LOG_EVENTS_ID: raw_log_events_id,
        TSPKeys.EXECUTION_SOURCE: ExecutionSource.IDE.value,
        TSPKeys.CUSTOM_DATA: tool.custom_data,
    }

    platform_api_key = PromptStudioHelper._get_platform_api_key(org_id)
    payload[ToolStudioKeys.PLATFORM_SERVICE_API_KEY] = platform_api_key
    payload[TSPKeys.INCLUDE_METADATA] = True

    context = ExecutionContext(
        executor_name="legacy",
        operation="answer_prompt",
        run_id=run_id,
        execution_source="ide",
        organization_id=org_id,
        executor_params=payload,
        request_id=request_id,
        log_events_id=log_events_id,
    )

    cb_kwargs = {
        "log_events_id": log_events_id,
        "request_id": request_id,
        "org_id": org_id,
        "operation": "fetch_response",
        "run_id": run_id,
        "document_id": document_id,
        "tool_id": tool_id,
        "prompt_ids": [str(prompt.prompt_id)],
        "profile_manager_id": profile_manager_id,
        "is_single_pass": False,
    }

    return context, cb_kwargs
@staticmethod
def build_single_pass_payload(
    tool: CustomTool,
    doc_path: str,
    doc_name: str,
    prompts: list[ToolStudioPrompt],
    org_id: str,
    document_id: str,
    run_id: str,
) -> tuple[ExecutionContext, dict[str, Any]]:
    """Build single_pass_extraction ExecutionContext.

    Does ORM work and extraction synchronously. Only the LLM
    single-pass call is dispatched asynchronously.

    Args:
        tool: Prompt Studio tool being executed.
        doc_path: Path of the source document.
        doc_name: Display name of the document.
        prompts: Prompts to answer in the single pass.
        org_id: Organization identifier.
        document_id: Document identifier.
        run_id: Run identifier; a UUID4 is generated when it is falsy and
            used consistently for the payload, context and callback kwargs.

    Returns:
        ``(context, cb_kwargs)`` — the ``ExecutionContext`` to dispatch and
        the keyword arguments for the completion callback.

    Raises:
        DefaultProfileError: If the tool has no default profile.
        EmptyPromptError: If any prompt has empty text.
    """
    default_profile = ProfileManager.get_default_llm_profile(tool)
    # Guard first: everything below dereferences the profile.
    if not default_profile:
        raise DefaultProfileError()

    # Validate prompts before the blocking extraction so bad input fails fast.
    outputs: list[dict[str, Any]] = []
    for p in prompts:
        if not p.prompt:
            raise EmptyPromptError()
        outputs.append(
            {
                TSPKeys.PROMPT: p.prompt,
                TSPKeys.ACTIVE: p.active,
                TSPKeys.TYPE: p.enforce_type,
                TSPKeys.NAME: p.prompt_key,
            }
        )

    # One run id everywhere, so payload, context and callback correlate.
    run_id = run_id or str(uuid.uuid4())
    tool_id = str(tool.tool_id)

    challenge_llm_instance: AdapterInstance | None = tool.challenge_llm
    if challenge_llm_instance:
        challenge_llm = str(challenge_llm_instance.id)
    else:
        challenge_llm = str(default_profile.llm.id)

    PromptStudioHelper.validate_adapter_status(default_profile)
    PromptStudioHelper.validate_profile_manager_owner_access(default_profile)
    # In-memory override only (not saved): single pass runs with chunk size 0.
    default_profile.chunk_size = 0

    prompt_grammar = tool.prompt_grammer
    grammar: list[dict[str, Any]] = [
        {TSPKeys.WORD: word, TSPKeys.SYNONYMS: synonyms}
        for word, synonyms in (prompt_grammar or {}).items()
    ]

    fs_instance = EnvHelper.get_storage(
        storage_type=StorageType.PERMANENT,
        env_name=FileStorageKeys.PERMANENT_REMOTE_STORAGE,
    )
    directory, filename = os.path.split(doc_path)
    file_path = os.path.join(
        directory, "extract", os.path.splitext(filename)[0] + ".txt"
    )

    # Extract (blocking, usually cached)
    PromptStudioHelper.dynamic_extractor(
        profile_manager=default_profile,
        file_path=doc_path,
        org_id=org_id,
        document_id=document_id,
        run_id=run_id,
        enable_highlight=tool.enable_highlight,
    )

    vector_db = str(default_profile.vector_store.id)
    embedding_model = str(default_profile.embedding_model.id)
    llm = str(default_profile.llm.id)
    x2text = str(default_profile.x2text.id)

    tool_settings: dict[str, Any] = {
        TSPKeys.PREAMBLE: tool.preamble,
        TSPKeys.POSTAMBLE: tool.postamble,
        TSPKeys.GRAMMAR: grammar,
        TSPKeys.LLM: llm,
        TSPKeys.X2TEXT_ADAPTER: x2text,
        TSPKeys.VECTOR_DB: vector_db,
        TSPKeys.EMBEDDING: embedding_model,
        TSPKeys.CHUNK_SIZE: default_profile.chunk_size,
        TSPKeys.CHUNK_OVERLAP: default_profile.chunk_overlap,
        TSPKeys.ENABLE_CHALLENGE: tool.enable_challenge,
        TSPKeys.ENABLE_HIGHLIGHT: tool.enable_highlight,
        TSPKeys.ENABLE_WORD_CONFIDENCE: tool.enable_word_confidence,
        TSPKeys.CHALLENGE_LLM: challenge_llm,
        TSPKeys.PLATFORM_POSTAMBLE: getattr(
            settings, TSPKeys.PLATFORM_POSTAMBLE.upper(), ""
        ),
        TSPKeys.WORD_CONFIDENCE_POSTAMBLE: getattr(
            settings, TSPKeys.WORD_CONFIDENCE_POSTAMBLE.upper(), ""
        ),
        TSPKeys.SUMMARIZE_AS_SOURCE: tool.summarize_as_source,
        TSPKeys.RETRIEVAL_STRATEGY: default_profile.retrieval_strategy
        or TSPKeys.SIMPLE,
        TSPKeys.SIMILARITY_TOP_K: default_profile.similarity_top_k,
    }

    if tool.summarize_as_source:
        # Read the summary file instead of the raw extraction output.
        path_obj = Path(file_path)
        file_path = str(
            path_obj.parent.parent / TSPKeys.SUMMARIZE / (path_obj.stem + ".txt")
        )

    file_hash = fs_instance.get_hash_from_file(path=file_path)

    # Read once; the payload keeps the raw value, the context gets "".
    raw_log_events_id = StateStore.get(Common.LOG_EVENTS_ID)
    log_events_id = raw_log_events_id or ""
    request_id = StateStore.get(Common.REQUEST_ID) or ""

    payload: dict[str, Any] = {
        TSPKeys.TOOL_SETTINGS: tool_settings,
        TSPKeys.OUTPUTS: outputs,
        TSPKeys.TOOL_ID: tool_id,
        TSPKeys.RUN_ID: run_id,
        TSPKeys.FILE_HASH: file_hash,
        TSPKeys.FILE_NAME: doc_name,
        TSPKeys.FILE_PATH: file_path,
        Common.LOG_EVENTS_ID: raw_log_events_id,
        TSPKeys.EXECUTION_SOURCE: ExecutionSource.IDE.value,
        TSPKeys.CUSTOM_DATA: tool.custom_data,
    }

    platform_api_key = PromptStudioHelper._get_platform_api_key(org_id)
    payload[ToolStudioKeys.PLATFORM_SERVICE_API_KEY] = platform_api_key
    payload[TSPKeys.INCLUDE_METADATA] = True

    context = ExecutionContext(
        executor_name="legacy",
        operation="single_pass_extraction",
        run_id=run_id,
        execution_source="ide",
        organization_id=org_id,
        executor_params=payload,
        request_id=request_id,
        log_events_id=log_events_id,
    )

    cb_kwargs = {
        "log_events_id": log_events_id,
        "request_id": request_id,
        "org_id": org_id,
        "operation": "single_pass_extraction",
        "run_id": run_id,
        "document_id": document_id,
        "tool_id": tool_id,
        "prompt_ids": [str(p.prompt_id) for p in prompts],
        "profile_manager_id": str(default_profile.profile_id),
        "is_single_pass": True,
    }

    return context, cb_kwargs
e.actual_err.response.json().get("error", str(e)) - raise AnswerFetchError( - "Error while fetching response for " - f"'{prompt.prompt_key}' with '{doc_name}'. {msg}", - status_code=int(e.status_code or 500), - ) + result = dispatcher.dispatch(context) + if not result.success: + raise AnswerFetchError( + "Error while fetching response for " + f"'{prompt.prompt_key}' with '{doc_name}'. {result.error}", + ) + return result.data + else: + # === OLD: PromptTool HTTP → prompt-service === + try: + responder = PromptTool( + tool=util, + prompt_host=settings.PROMPT_HOST, + prompt_port=settings.PROMPT_PORT, + request_id=StateStore.get(Common.REQUEST_ID), + ) + params = {TSPKeys.INCLUDE_METADATA: True} + return responder.answer_prompt(payload=payload, params=params) + except SdkError as e: + msg = str(e) + if e.actual_err and hasattr(e.actual_err, "response"): + msg = e.actual_err.response.json().get("error", str(e)) + raise AnswerFetchError( + "Error while fetching response for " + f"'{prompt.prompt_key}' with '{doc_name}'. 
{msg}", + status_code=int(e.status_code or 500), + ) @staticmethod def fetch_table_settings_if_enabled( @@ -1140,9 +1762,31 @@ def dynamic_indexer( TSPKeys.EXECUTION_SOURCE: ExecutionSource.IDE.value, } - util = PromptIdeBaseTool(log_level=LogLevel.INFO, org_id=org_id) - - try: + if PromptStudioHelper._is_async_execution_enabled(): + # === NEW: ExecutionDispatcher → Celery executor worker === + platform_api_key = PromptStudioHelper._get_platform_api_key(org_id) + payload["platform_api_key"] = platform_api_key + + dispatcher = PromptStudioHelper._get_dispatcher() + index_context = ExecutionContext( + executor_name="legacy", + operation="index", + run_id=run_id or str(uuid.uuid4()), + execution_source="ide", + organization_id=org_id, + executor_params=payload, + request_id=StateStore.get(Common.REQUEST_ID), + log_events_id=StateStore.get(Common.LOG_EVENTS_ID), + ) + result = dispatcher.dispatch(index_context) + if not result.success: + raise IndexingAPIError( + f"Failed to index '{filename}'. {result.error}", + ) + doc_id = result.data.get("doc_id") + else: + # === OLD: PromptTool HTTP → prompt-service === + util = PromptIdeBaseTool(log_level=LogLevel.INFO, org_id=org_id) responder = PromptTool( tool=util, prompt_host=settings.PROMPT_HOST, @@ -1150,14 +1794,6 @@ def dynamic_indexer( request_id=StateStore.get(Common.REQUEST_ID), ) doc_id = responder.index(payload=payload) - except SdkError as e: - msg = str(e) - if e.actual_err and hasattr(e.actual_err, "response"): - msg = e.actual_err.response.json().get("error", str(e)) - raise IndexingAPIError( - f"Failed to index '{filename}'. 
{msg}", - status_code=int(e.status_code or 500), - ) PromptStudioIndexHelper.handle_index_manager( document_id=document_id, @@ -1169,6 +1805,13 @@ def dynamic_indexer( ) return {"status": IndexingStatus.COMPLETED_STATUS.value, "output": doc_id} except (IndexingError, IndexingAPIError, SdkError) as e: + # Clear the indexing flag so subsequent requests are not blocked + try: + DocumentIndexingService.remove_document_indexing( + org_id=org_id, user_id=user_id, doc_id_key=doc_id_key + ) + except Exception: + logger.exception("Failed to clear indexing flag for %s", doc_id_key) msg = str(e) if isinstance(e, SdkError) and hasattr(e.actual_err, "response"): msg = e.actual_err.response.json().get("error", str(e)) @@ -1221,7 +1864,6 @@ def _fetch_single_pass_response( storage_type=StorageType.PERMANENT, env_name=FileStorageKeys.PERMANENT_REMOTE_STORAGE, ) - util = PromptIdeBaseTool(log_level=LogLevel.INFO, org_id=org_id) directory, filename = os.path.split(input_file_path) file_path = os.path.join( directory, "extract", os.path.splitext(filename)[0] + ".txt" @@ -1260,6 +1902,10 @@ def _fetch_single_pass_response( settings, TSPKeys.WORD_CONFIDENCE_POSTAMBLE.upper(), "" ) tool_settings[TSPKeys.SUMMARIZE_AS_SOURCE] = tool.summarize_as_source + tool_settings[TSPKeys.RETRIEVAL_STRATEGY] = ( + default_profile.retrieval_strategy or TSPKeys.SIMPLE + ) + tool_settings[TSPKeys.SIMILARITY_TOP_K] = default_profile.similarity_top_k for prompt in prompts: if not prompt.prompt: raise EmptyPromptError() @@ -1288,14 +1934,40 @@ def _fetch_single_pass_response( TSPKeys.CUSTOM_DATA: tool.custom_data, } - responder = PromptTool( - tool=util, - prompt_host=settings.PROMPT_HOST, - prompt_port=settings.PROMPT_PORT, - request_id=StateStore.get(Common.REQUEST_ID), - ) - params = {TSPKeys.INCLUDE_METADATA: True} - return responder.single_pass_extraction(payload=payload, params=params) + if PromptStudioHelper._is_async_execution_enabled(): + # === NEW: ExecutionDispatcher → Celery executor worker === 
+ platform_api_key = PromptStudioHelper._get_platform_api_key(org_id) + payload[ToolStudioKeys.PLATFORM_SERVICE_API_KEY] = platform_api_key + payload[TSPKeys.INCLUDE_METADATA] = True + + dispatcher = PromptStudioHelper._get_dispatcher() + context = ExecutionContext( + executor_name="legacy", + operation="single_pass_extraction", + run_id=run_id or str(uuid.uuid4()), + execution_source="ide", + organization_id=org_id, + executor_params=payload, + request_id=StateStore.get(Common.REQUEST_ID), + log_events_id=StateStore.get(Common.LOG_EVENTS_ID), + ) + result = dispatcher.dispatch(context) + if not result.success: + raise AnswerFetchError( + f"Error fetching single pass response. {result.error}", + ) + return result.data + else: + # === OLD: PromptTool HTTP → prompt-service === + util = PromptIdeBaseTool(log_level=LogLevel.INFO, org_id=org_id) + responder = PromptTool( + tool=util, + prompt_host=settings.PROMPT_HOST, + prompt_port=settings.PROMPT_PORT, + request_id=StateStore.get(Common.REQUEST_ID), + ) + params = {TSPKeys.INCLUDE_METADATA: True} + return responder.single_pass_extraction(payload=payload, params=params) @staticmethod def get_tool_from_tool_id(tool_id: str) -> CustomTool | None: @@ -1361,49 +2033,74 @@ def dynamic_extractor( IKeys.OUTPUT_FILE_PATH: extract_file_path, } - util = PromptIdeBaseTool(log_level=LogLevel.INFO, org_id=org_id) - - try: - responder = PromptTool( - tool=util, - prompt_host=settings.PROMPT_HOST, - prompt_port=settings.PROMPT_PORT, + if PromptStudioHelper._is_async_execution_enabled(): + # === NEW: ExecutionDispatcher → Celery executor worker === + platform_api_key = PromptStudioHelper._get_platform_api_key(org_id) + payload["platform_api_key"] = platform_api_key + + dispatcher = PromptStudioHelper._get_dispatcher() + extract_context = ExecutionContext( + executor_name="legacy", + operation="extract", + run_id=run_id or str(uuid.uuid4()), + execution_source="ide", + organization_id=org_id, + executor_params=payload, 
request_id=StateStore.get(Common.REQUEST_ID), + log_events_id=StateStore.get(Common.LOG_EVENTS_ID), ) - extracted_text = responder.extract(payload=payload) - success = PromptStudioIndexHelper.mark_extraction_status( - document_id=document_id, - profile_manager=profile_manager, - x2text_config_hash=x2text_config_hash, - enable_highlight=enable_highlight, - ) - if not success: - logger.warning( - f"Failed to mark extraction success for document {document_id}. " - f"Extraction completed but status not saved." + result = dispatcher.dispatch(extract_context) + if not result.success: + msg = result.error or "Unknown extraction error" + PromptStudioIndexHelper.mark_extraction_status( + document_id=document_id, + profile_manager=profile_manager, + x2text_config_hash=x2text_config_hash, + enable_highlight=enable_highlight, + extracted=False, + error_message=msg, ) - except SdkError as e: - msg = str(e) - if e.actual_err and hasattr(e.actual_err, "response"): - msg = e.actual_err.response.json().get("error", str(e)) - - success = PromptStudioIndexHelper.mark_extraction_status( - document_id=document_id, - profile_manager=profile_manager, - x2text_config_hash=x2text_config_hash, - enable_highlight=enable_highlight, - extracted=False, - error_message=msg, - ) - if not success: - logger.warning( - f"Failed to mark extraction failure for document {document_id}. " - f"Extraction failed but status not saved." + raise ExtractionAPIError( + f"Failed to extract '{filename}'. 
{msg}", + ) + extracted_text = result.data.get("extracted_text", "") + else: + # === OLD: PromptTool HTTP → prompt-service === + util = PromptIdeBaseTool(log_level=LogLevel.INFO, org_id=org_id) + try: + responder = PromptTool( + tool=util, + prompt_host=settings.PROMPT_HOST, + prompt_port=settings.PROMPT_PORT, + request_id=StateStore.get(Common.REQUEST_ID), + ) + extracted_text = responder.extract(payload=payload) + except SdkError as e: + msg = str(e) + if e.actual_err and hasattr(e.actual_err, "response"): + msg = e.actual_err.response.json().get("error", str(e)) + PromptStudioIndexHelper.mark_extraction_status( + document_id=document_id, + profile_manager=profile_manager, + x2text_config_hash=x2text_config_hash, + enable_highlight=enable_highlight, + extracted=False, + error_message=msg, + ) + raise ExtractionAPIError( + f"Failed to extract '{filename}'. {msg}", + status_code=int(e.status_code or 500), ) - raise ExtractionAPIError( - f"Failed to extract '{filename}'. {msg}", - status_code=int(e.status_code or 500), + success = PromptStudioIndexHelper.mark_extraction_status( + document_id=document_id, + profile_manager=profile_manager, + x2text_config_hash=x2text_config_hash, + enable_highlight=enable_highlight, + ) + if not success: + logger.warning( + f"Failed to mark extraction success for document {document_id}." 
) return extracted_text diff --git a/backend/prompt_studio/prompt_studio_core_v2/tasks.py b/backend/prompt_studio/prompt_studio_core_v2/tasks.py new file mode 100644 index 0000000000..1ccaad8a0b --- /dev/null +++ b/backend/prompt_studio/prompt_studio_core_v2/tasks.py @@ -0,0 +1,513 @@ +import json +import logging +import uuid +from datetime import date, datetime +from typing import Any + +from account_v2.constants import Common +from celery import shared_task +from utils.constants import Account +from utils.local_context import StateStore +from utils.log_events import _emit_websocket_event + +logger = logging.getLogger(__name__) + +PROMPT_STUDIO_RESULT_EVENT = "prompt_studio_result" + + +class _SafeEncoder(json.JSONEncoder): + """JSON encoder that converts uuid.UUID and datetime objects to strings.""" + + def default(self, obj: Any) -> Any: + if isinstance(obj, uuid.UUID): + return str(obj) + if isinstance(obj, (datetime, date)): + return obj.isoformat() + return super().default(obj) + + +def _json_safe(data: Any) -> Any: + """Round-trip through JSON to convert non-serializable types. + + Handles uuid.UUID (from DRF serializers) and datetime/date objects + (from plugins or ORM fields) that stdlib json.dumps cannot handle. 
+ """ + return json.loads(json.dumps(data, cls=_SafeEncoder)) + + +def _setup_state_store(log_events_id: str, request_id: str, org_id: str = "") -> None: + """Restore thread-local context that was captured in the Django view.""" + StateStore.set(Common.LOG_EVENTS_ID, log_events_id) + StateStore.set(Common.REQUEST_ID, request_id) + if org_id: + StateStore.set(Account.ORGANIZATION_ID, org_id) + + +def _clear_state_store() -> None: + """Clean up thread-local context to prevent leaking between tasks.""" + StateStore.clear(Common.LOG_EVENTS_ID) + StateStore.clear(Common.REQUEST_ID) + StateStore.clear(Account.ORGANIZATION_ID) + + +def _emit_result( + log_events_id: str, + task_id: str, + operation: str, + result: dict[str, Any], +) -> None: + """Push a success event to the frontend via Socket.IO.""" + _emit_websocket_event( + room=log_events_id, + event=PROMPT_STUDIO_RESULT_EVENT, + data=_json_safe( + { + "task_id": task_id, + "status": "completed", + "operation": operation, + "result": result, + } + ), + ) + + +def _emit_error( + log_events_id: str, + task_id: str, + operation: str, + error: str, + extra: dict[str, Any] | None = None, +) -> None: + """Push a failure event to the frontend via Socket.IO.""" + data: dict[str, Any] = { + "task_id": task_id, + "status": "failed", + "operation": operation, + "error": error, + } + if extra: + data.update(extra) + _emit_websocket_event( + room=log_events_id, + event=PROMPT_STUDIO_RESULT_EVENT, + data=data, + ) + + +# ------------------------------------------------------------------ +# Phase 5B — Fire-and-forget callback tasks +# +# These are lightweight callbacks invoked by Celery `link` / `link_error` +# after the executor worker finishes. They run on the backend +# (prompt_studio_callback queue) and do only post-ORM writes + socket +# emission — no heavy computation. 
+# ------------------------------------------------------------------ + + +@shared_task(name="ide_index_complete") +def ide_index_complete( + result_dict: dict[str, Any], + callback_kwargs: dict[str, Any] | None = None, +) -> dict[str, Any]: + """Celery ``link`` callback after a successful ``ide_index`` execution. + + Performs post-indexing ORM bookkeeping and pushes a socket event to + the frontend. + """ + from prompt_studio.prompt_profile_manager_v2.models import ProfileManager + from prompt_studio.prompt_studio_core_v2.document_indexing_service import ( + DocumentIndexingService, + ) + from prompt_studio.prompt_studio_index_manager_v2.prompt_studio_index_helper import ( + PromptStudioIndexHelper, + ) + + cb = callback_kwargs or {} + log_events_id = cb.get("log_events_id", "") + request_id = cb.get("request_id", "") + org_id = cb.get("org_id", "") + user_id = cb.get("user_id", "") + document_id = cb.get("document_id", "") + doc_id_key = cb.get("doc_id_key", "") + profile_manager_id = cb.get("profile_manager_id") + executor_task_id = cb.get("executor_task_id", "") + + try: + _setup_state_store(log_events_id, request_id, org_id) + + # Check executor-level failure + if not result_dict.get("success", False): + error_msg = result_dict.get("error", "Unknown executor error") + logger.error("ide_index executor reported failure: %s", error_msg) + DocumentIndexingService.remove_document_indexing( + org_id=org_id, user_id=user_id, doc_id_key=doc_id_key + ) + _emit_error( + log_events_id, + executor_task_id, + "index_document", + error_msg, + extra={"document_id": document_id}, + ) + return {"status": "failed", "error": error_msg} + + doc_id = result_dict.get("data", {}).get("doc_id", doc_id_key) + + # ORM writes + DocumentIndexingService.mark_document_indexed( + org_id=org_id, user_id=user_id, doc_id_key=doc_id_key, doc_id=doc_id + ) + if profile_manager_id: + profile_manager = ProfileManager.objects.get(pk=profile_manager_id) + 
PromptStudioIndexHelper.handle_index_manager( + document_id=document_id, + profile_manager=profile_manager, + doc_id=doc_id, + ) + + result: dict[str, Any] = { + "message": "Document indexed successfully.", + "document_id": document_id, + } + _emit_result(log_events_id, executor_task_id, "index_document", result) + return result + except Exception as e: + logger.exception("ide_index_complete callback failed") + _emit_error( + log_events_id, + executor_task_id, + "index_document", + str(e), + extra={"document_id": document_id}, + ) + raise + finally: + _clear_state_store() + + +@shared_task(name="ide_index_error") +def ide_index_error( + failed_task_id: str, + callback_kwargs: dict[str, Any] | None = None, +) -> None: + """Celery ``link_error`` callback when an ``ide_index`` task fails. + + Cleans up the indexing-in-progress flag and pushes an error socket + event to the frontend. + """ + from celery.result import AsyncResult + + from prompt_studio.prompt_studio_core_v2.document_indexing_service import ( + DocumentIndexingService, + ) + + cb = callback_kwargs or {} + log_events_id = cb.get("log_events_id", "") + request_id = cb.get("request_id", "") + org_id = cb.get("org_id", "") + user_id = cb.get("user_id", "") + document_id = cb.get("document_id", "") + doc_id_key = cb.get("doc_id_key", "") + executor_task_id = cb.get("executor_task_id", "") + + try: + _setup_state_store(log_events_id, request_id, org_id) + + # Attempt to retrieve the actual exception from the result backend + error_msg = "Indexing failed" + try: + from backend.worker_celery import get_worker_celery_app + + res = AsyncResult(failed_task_id, app=get_worker_celery_app()) + if res.result: + error_msg = str(res.result) + except Exception: + pass + + # Clean up the indexing-in-progress flag + if doc_id_key: + DocumentIndexingService.remove_document_indexing( + org_id=org_id, user_id=user_id, doc_id_key=doc_id_key + ) + + _emit_error( + log_events_id, + executor_task_id, + "index_document", + 
error_msg, + extra={"document_id": document_id}, + ) + except Exception: + logger.exception("ide_index_error callback failed") + finally: + _clear_state_store() + + +@shared_task(name="ide_prompt_complete") +def ide_prompt_complete( + result_dict: dict[str, Any], + callback_kwargs: dict[str, Any] | None = None, +) -> dict[str, Any]: + """Celery ``link`` callback after a successful answer_prompt / single_pass + execution. + + Persists prompt outputs via OutputManagerHelper and pushes a socket + event. + """ + from prompt_studio.prompt_studio_output_manager_v2.output_manager_helper import ( + OutputManagerHelper, + ) + from prompt_studio.prompt_studio_v2.models import ToolStudioPrompt + + cb = callback_kwargs or {} + log_events_id = cb.get("log_events_id", "") + request_id = cb.get("request_id", "") + org_id = cb.get("org_id", "") + operation = cb.get("operation", "fetch_response") + run_id = cb.get("run_id", "") + document_id = cb.get("document_id", "") + prompt_ids = cb.get("prompt_ids", []) + profile_manager_id = cb.get("profile_manager_id") + is_single_pass = cb.get("is_single_pass", False) + executor_task_id = cb.get("executor_task_id", "") + + try: + _setup_state_store(log_events_id, request_id, org_id) + + # Check executor-level failure + if not result_dict.get("success", False): + error_msg = result_dict.get("error", "Unknown executor error") + logger.error("ide_prompt executor reported failure: %s", error_msg) + _emit_error( + log_events_id, + executor_task_id, + operation, + error_msg, + extra={ + "prompt_ids": prompt_ids, + "document_id": document_id, + "profile_manager_id": profile_manager_id, + }, + ) + return {"status": "failed", "error": error_msg} + + data = result_dict.get("data", {}) + + # Sanitize outputs and metadata so that any non-JSON-safe + # values (e.g. datetime from plugins) are converted before + # they reach Django JSONField saves. 
+ outputs = _json_safe(data.get("output", {})) + metadata = _json_safe(data.get("metadata", {})) + + # Re-fetch prompt ORM objects for OutputManagerHelper + prompts = list( + ToolStudioPrompt.objects.filter(prompt_id__in=prompt_ids).order_by( + "sequence_number" + ) + ) + + response = OutputManagerHelper.handle_prompt_output_update( + run_id=run_id, + prompts=prompts, + outputs=outputs, + document_id=document_id, + is_single_pass_extract=is_single_pass, + profile_manager_id=profile_manager_id, + metadata=metadata, + ) + + _emit_result(log_events_id, executor_task_id, operation, response) + # Return minimal status — full data is sent via websocket above. + # Returning the full response would cause Celery to log sensitive + # extracted data in its "Task succeeded" message. + return {"status": "completed", "operation": operation} + except Exception as e: + logger.exception("ide_prompt_complete callback failed") + _emit_error( + log_events_id, + executor_task_id, + operation, + str(e), + extra={ + "prompt_ids": prompt_ids, + "document_id": document_id, + "profile_manager_id": profile_manager_id, + }, + ) + raise + finally: + _clear_state_store() + + +@shared_task(name="ide_prompt_error") +def ide_prompt_error( + failed_task_id: str, + callback_kwargs: dict[str, Any] | None = None, +) -> None: + """Celery ``link_error`` callback when an answer_prompt / single_pass + task fails. + + Pushes an error socket event to the frontend. 
+ """ + from celery.result import AsyncResult + + cb = callback_kwargs or {} + log_events_id = cb.get("log_events_id", "") + request_id = cb.get("request_id", "") + org_id = cb.get("org_id", "") + operation = cb.get("operation", "fetch_response") + executor_task_id = cb.get("executor_task_id", "") + + try: + _setup_state_store(log_events_id, request_id, org_id) + + error_msg = "Prompt execution failed" + try: + from backend.worker_celery import get_worker_celery_app + + res = AsyncResult(failed_task_id, app=get_worker_celery_app()) + if res.result: + error_msg = str(res.result) + except Exception: + pass + + _emit_error( + log_events_id, + executor_task_id, + operation, + error_msg, + extra={ + "prompt_ids": cb.get("prompt_ids", []), + "document_id": cb.get("document_id", ""), + "profile_manager_id": cb.get("profile_manager_id"), + }, + ) + except Exception: + logger.exception("ide_prompt_error callback failed") + finally: + _clear_state_store() + + +# ------------------------------------------------------------------ +# Legacy tasks (kept for backward compatibility during rollout) +# ------------------------------------------------------------------ + + +@shared_task(name="prompt_studio_index_document", bind=True) +def run_index_document( + self, + tool_id: str, + file_name: str, + org_id: str, + user_id: str, + document_id: str, + run_id: str, + log_events_id: str, + request_id: str, +) -> dict[str, Any]: + from prompt_studio.prompt_studio_core_v2.prompt_studio_helper import ( + PromptStudioHelper, + ) + + try: + _setup_state_store(log_events_id, request_id, org_id) + PromptStudioHelper.index_document( + tool_id=tool_id, + file_name=file_name, + org_id=org_id, + user_id=user_id, + document_id=document_id, + run_id=run_id, + ) + result: dict[str, Any] = { + "message": "Document indexed successfully.", + "document_id": document_id, + } + _emit_result(log_events_id, self.request.id, "index_document", result) + return result + except Exception as e: + 
logger.exception("run_index_document failed") + _emit_error( + log_events_id, + self.request.id, + "index_document", + str(e), + extra={"document_id": document_id}, + ) + raise + finally: + _clear_state_store() + + +@shared_task(name="prompt_studio_fetch_response", bind=True) +def run_fetch_response( + self, + tool_id: str, + org_id: str, + user_id: str, + document_id: str, + run_id: str, + log_events_id: str, + request_id: str, + id: str | None = None, + profile_manager_id: str | None = None, +) -> dict[str, Any]: + from prompt_studio.prompt_studio_core_v2.prompt_studio_helper import ( + PromptStudioHelper, + ) + + try: + _setup_state_store(log_events_id, request_id, org_id) + response: dict[str, Any] = PromptStudioHelper.prompt_responder( + id=id, + tool_id=tool_id, + org_id=org_id, + user_id=user_id, + document_id=document_id, + run_id=run_id, + profile_manager_id=profile_manager_id, + ) + _emit_result(log_events_id, self.request.id, "fetch_response", response) + # Return minimal status to avoid logging sensitive extracted data + return {"status": "completed", "operation": "fetch_response"} + except Exception as e: + logger.exception("run_fetch_response failed") + _emit_error(log_events_id, self.request.id, "fetch_response", str(e)) + raise + finally: + _clear_state_store() + + +@shared_task(name="prompt_studio_single_pass", bind=True) +def run_single_pass_extraction( + self, + tool_id: str, + org_id: str, + user_id: str, + document_id: str, + run_id: str, + log_events_id: str, + request_id: str, +) -> dict[str, Any]: + from prompt_studio.prompt_studio_core_v2.prompt_studio_helper import ( + PromptStudioHelper, + ) + + try: + _setup_state_store(log_events_id, request_id, org_id) + response: dict[str, Any] = PromptStudioHelper.prompt_responder( + tool_id=tool_id, + org_id=org_id, + user_id=user_id, + document_id=document_id, + run_id=run_id, + ) + _emit_result(log_events_id, self.request.id, "single_pass_extraction", response) + # Return minimal status to avoid 
logging sensitive extracted data + return {"status": "completed", "operation": "single_pass_extraction"} + except Exception as e: + logger.exception("run_single_pass_extraction failed") + _emit_error(log_events_id, self.request.id, "single_pass_extraction", str(e)) + raise + finally: + _clear_state_store() diff --git a/backend/prompt_studio/prompt_studio_core_v2/test_tasks.py b/backend/prompt_studio/prompt_studio_core_v2/test_tasks.py new file mode 100644 index 0000000000..4efef90987 --- /dev/null +++ b/backend/prompt_studio/prompt_studio_core_v2/test_tasks.py @@ -0,0 +1,417 @@ +"""Phase 7-9 sanity tests for Prompt Studio IDE async backend. + +Tests the Celery task definitions (Phase 7), view dispatch (Phase 8), +and polling endpoint (Phase 9). + +Requires Django to be configured (source .env before running): + set -a && source .env && set +a + uv run pytest prompt_studio/prompt_studio_core_v2/test_tasks.py -v +""" + +import os +from unittest.mock import patch + +import django + +os.environ.setdefault("DJANGO_SETTINGS_MODULE", "backend.settings.dev") +django.setup() + +import pytest # noqa: E402 +from account_v2.constants import Common # noqa: E402 +from celery import Celery # noqa: E402 +from utils.local_context import StateStore # noqa: E402 + +from prompt_studio.prompt_studio_core_v2.tasks import ( # noqa: E402 + PROMPT_STUDIO_RESULT_EVENT, + run_fetch_response, + run_index_document, + run_single_pass_extraction, +) + +# --------------------------------------------------------------------------- +# Celery eager-mode app for testing +# --------------------------------------------------------------------------- +test_app = Celery("test") +test_app.conf.update( + task_always_eager=True, + task_eager_propagates=True, + result_backend="cache+memory://", +) +run_index_document.bind(test_app) +run_fetch_response.bind(test_app) +run_single_pass_extraction.bind(test_app) + + +# --------------------------------------------------------------------------- +# Fixtures +# 
--------------------------------------------------------------------------- +COMMON_KWARGS = { + "tool_id": "tool-123", + "org_id": "org-456", + "user_id": "user-789", + "document_id": "doc-abc", + "run_id": "run-def", + "log_events_id": "session-room-xyz", + "request_id": "req-001", +} + + +# =================================================================== +# Phase 7: Task definition tests +# =================================================================== +class TestTaskNames: + def test_index_document_task_name(self): + assert run_index_document.name == "prompt_studio_index_document" + + def test_fetch_response_task_name(self): + assert run_fetch_response.name == "prompt_studio_fetch_response" + + def test_single_pass_task_name(self): + assert run_single_pass_extraction.name == "prompt_studio_single_pass" + + +class TestRunIndexDocument: + @patch("prompt_studio.prompt_studio_core_v2.tasks._emit_websocket_event") + @patch("prompt_studio.prompt_studio_core_v2.prompt_studio_helper.PromptStudioHelper") + def test_success_returns_result(self, mock_helper, mock_emit): + mock_helper.index_document.return_value = "unique-id-123" + result = run_index_document.apply( + kwargs={**COMMON_KWARGS, "file_name": "test.pdf"} + ).get() + + assert result == {"message": "Document indexed successfully."} + mock_helper.index_document.assert_called_once_with( + tool_id="tool-123", + file_name="test.pdf", + org_id="org-456", + user_id="user-789", + document_id="doc-abc", + run_id="run-def", + ) + + @patch("prompt_studio.prompt_studio_core_v2.tasks._emit_websocket_event") + @patch("prompt_studio.prompt_studio_core_v2.prompt_studio_helper.PromptStudioHelper") + def test_success_emits_completed_event(self, mock_helper, mock_emit): + mock_helper.index_document.return_value = "unique-id-123" + run_index_document.apply(kwargs={**COMMON_KWARGS, "file_name": "test.pdf"}).get() + + mock_emit.assert_called_once() + kwargs = mock_emit.call_args.kwargs + assert kwargs["room"] == 
"session-room-xyz" + assert kwargs["event"] == PROMPT_STUDIO_RESULT_EVENT + assert kwargs["data"]["status"] == "completed" + assert kwargs["data"]["operation"] == "index_document" + assert kwargs["data"]["result"] == {"message": "Document indexed successfully."} + assert "task_id" in kwargs["data"] + + @patch("prompt_studio.prompt_studio_core_v2.tasks._emit_websocket_event") + @patch("prompt_studio.prompt_studio_core_v2.prompt_studio_helper.PromptStudioHelper") + def test_failure_emits_error_and_reraises(self, mock_helper, mock_emit): + mock_helper.index_document.side_effect = RuntimeError("index boom") + + with pytest.raises(RuntimeError, match="index boom"): + run_index_document.apply( + kwargs={**COMMON_KWARGS, "file_name": "test.pdf"} + ).get() + + mock_emit.assert_called_once() + assert mock_emit.call_args.kwargs["data"]["status"] == "failed" + assert "index boom" in mock_emit.call_args.kwargs["data"]["error"] + + @patch("prompt_studio.prompt_studio_core_v2.tasks._emit_websocket_event") + @patch("prompt_studio.prompt_studio_core_v2.prompt_studio_helper.PromptStudioHelper") + def test_state_store_cleared_on_success(self, mock_helper, mock_emit): + mock_helper.index_document.return_value = "ok" + run_index_document.apply(kwargs={**COMMON_KWARGS, "file_name": "test.pdf"}).get() + + assert StateStore.get(Common.LOG_EVENTS_ID) is None + assert StateStore.get(Common.REQUEST_ID) is None + + @patch("prompt_studio.prompt_studio_core_v2.tasks._emit_websocket_event") + @patch("prompt_studio.prompt_studio_core_v2.prompt_studio_helper.PromptStudioHelper") + def test_state_store_cleared_on_failure(self, mock_helper, mock_emit): + mock_helper.index_document.side_effect = RuntimeError("fail") + with pytest.raises(RuntimeError): + run_index_document.apply( + kwargs={**COMMON_KWARGS, "file_name": "test.pdf"} + ).get() + + assert StateStore.get(Common.LOG_EVENTS_ID) is None + assert StateStore.get(Common.REQUEST_ID) is None + + 
@patch("prompt_studio.prompt_studio_core_v2.tasks._emit_websocket_event") + @patch("prompt_studio.prompt_studio_core_v2.prompt_studio_helper.PromptStudioHelper") + def test_state_store_set_during_execution(self, mock_helper, mock_emit): + """Verify StateStore has the right values while the helper runs.""" + captured = {} + + def capture_state(**kwargs): + captured["log_events_id"] = StateStore.get(Common.LOG_EVENTS_ID) + captured["request_id"] = StateStore.get(Common.REQUEST_ID) + return "ok" + + mock_helper.index_document.side_effect = capture_state + run_index_document.apply(kwargs={**COMMON_KWARGS, "file_name": "test.pdf"}).get() + + assert captured["log_events_id"] == "session-room-xyz" + assert captured["request_id"] == "req-001" + # And cleared after + assert StateStore.get(Common.LOG_EVENTS_ID) is None + + +class TestRunFetchResponse: + @patch("prompt_studio.prompt_studio_core_v2.tasks._emit_websocket_event") + @patch("prompt_studio.prompt_studio_core_v2.prompt_studio_helper.PromptStudioHelper") + def test_success_returns_response(self, mock_helper, mock_emit): + expected = {"output": {"field": "value"}, "metadata": {"tokens": 42}} + mock_helper.prompt_responder.return_value = expected + + result = run_fetch_response.apply( + kwargs={ + **COMMON_KWARGS, + "id": "prompt-1", + "profile_manager_id": "pm-1", + } + ).get() + + assert result == expected + mock_helper.prompt_responder.assert_called_once_with( + id="prompt-1", + tool_id="tool-123", + org_id="org-456", + user_id="user-789", + document_id="doc-abc", + run_id="run-def", + profile_manager_id="pm-1", + ) + + @patch("prompt_studio.prompt_studio_core_v2.tasks._emit_websocket_event") + @patch("prompt_studio.prompt_studio_core_v2.prompt_studio_helper.PromptStudioHelper") + def test_success_emits_fetch_response_event(self, mock_helper, mock_emit): + mock_helper.prompt_responder.return_value = {"output": "data"} + run_fetch_response.apply( + kwargs={**COMMON_KWARGS, "id": "p1", "profile_manager_id": None} + 
).get() + + data = mock_emit.call_args.kwargs["data"] + assert data["status"] == "completed" + assert data["operation"] == "fetch_response" + + @patch("prompt_studio.prompt_studio_core_v2.tasks._emit_websocket_event") + @patch("prompt_studio.prompt_studio_core_v2.prompt_studio_helper.PromptStudioHelper") + def test_failure_emits_error(self, mock_helper, mock_emit): + mock_helper.prompt_responder.side_effect = ValueError("prompt fail") + + with pytest.raises(ValueError, match="prompt fail"): + run_fetch_response.apply(kwargs=COMMON_KWARGS).get() + + data = mock_emit.call_args.kwargs["data"] + assert data["status"] == "failed" + assert "prompt fail" in data["error"] + + @patch("prompt_studio.prompt_studio_core_v2.tasks._emit_websocket_event") + @patch("prompt_studio.prompt_studio_core_v2.prompt_studio_helper.PromptStudioHelper") + def test_optional_params_default_none(self, mock_helper, mock_emit): + mock_helper.prompt_responder.return_value = {} + run_fetch_response.apply(kwargs=COMMON_KWARGS).get() + + mock_helper.prompt_responder.assert_called_once_with( + id=None, + tool_id="tool-123", + org_id="org-456", + user_id="user-789", + document_id="doc-abc", + run_id="run-def", + profile_manager_id=None, + ) + + @patch("prompt_studio.prompt_studio_core_v2.tasks._emit_websocket_event") + @patch("prompt_studio.prompt_studio_core_v2.prompt_studio_helper.PromptStudioHelper") + def test_state_store_cleared(self, mock_helper, mock_emit): + mock_helper.prompt_responder.return_value = {} + run_fetch_response.apply(kwargs=COMMON_KWARGS).get() + assert StateStore.get(Common.LOG_EVENTS_ID) is None + + +class TestRunSinglePassExtraction: + @patch("prompt_studio.prompt_studio_core_v2.tasks._emit_websocket_event") + @patch("prompt_studio.prompt_studio_core_v2.prompt_studio_helper.PromptStudioHelper") + def test_success_returns_response(self, mock_helper, mock_emit): + expected = {"output": {"key": "val"}} + mock_helper.prompt_responder.return_value = expected + + result = 
run_single_pass_extraction.apply(kwargs=COMMON_KWARGS).get() + + assert result == expected + mock_helper.prompt_responder.assert_called_once_with( + tool_id="tool-123", + org_id="org-456", + user_id="user-789", + document_id="doc-abc", + run_id="run-def", + ) + + @patch("prompt_studio.prompt_studio_core_v2.tasks._emit_websocket_event") + @patch("prompt_studio.prompt_studio_core_v2.prompt_studio_helper.PromptStudioHelper") + def test_success_emits_single_pass_event(self, mock_helper, mock_emit): + mock_helper.prompt_responder.return_value = {"data": "ok"} + run_single_pass_extraction.apply(kwargs=COMMON_KWARGS).get() + + data = mock_emit.call_args.kwargs["data"] + assert data["status"] == "completed" + assert data["operation"] == "single_pass_extraction" + + @patch("prompt_studio.prompt_studio_core_v2.tasks._emit_websocket_event") + @patch("prompt_studio.prompt_studio_core_v2.prompt_studio_helper.PromptStudioHelper") + def test_failure_emits_error(self, mock_helper, mock_emit): + mock_helper.prompt_responder.side_effect = TypeError("single pass fail") + + with pytest.raises(TypeError, match="single pass fail"): + run_single_pass_extraction.apply(kwargs=COMMON_KWARGS).get() + + data = mock_emit.call_args.kwargs["data"] + assert data["status"] == "failed" + + @patch("prompt_studio.prompt_studio_core_v2.tasks._emit_websocket_event") + @patch("prompt_studio.prompt_studio_core_v2.prompt_studio_helper.PromptStudioHelper") + def test_state_store_cleared(self, mock_helper, mock_emit): + mock_helper.prompt_responder.return_value = {} + run_single_pass_extraction.apply(kwargs=COMMON_KWARGS).get() + assert StateStore.get(Common.LOG_EVENTS_ID) is None + + +# =================================================================== +# Phase 8: View dispatch tests +# =================================================================== +class TestViewsDispatchTasks: + """Verify the three views no longer call helpers directly.""" + + def 
test_index_document_view_has_no_blocking_call(self): + import inspect + + from prompt_studio.prompt_studio_core_v2.views import PromptStudioCoreView + + source = inspect.getsource(PromptStudioCoreView.index_document) + assert "run_index_document.apply_async" in source + assert "PromptStudioHelper.index_document(" not in source + assert "HTTP_202_ACCEPTED" in source + + def test_fetch_response_view_has_no_blocking_call(self): + import inspect + + from prompt_studio.prompt_studio_core_v2.views import PromptStudioCoreView + + source = inspect.getsource(PromptStudioCoreView.fetch_response) + assert "run_fetch_response.apply_async" in source + assert "PromptStudioHelper.prompt_responder(" not in source + assert "HTTP_202_ACCEPTED" in source + + def test_single_pass_view_has_no_blocking_call(self): + import inspect + + from prompt_studio.prompt_studio_core_v2.views import PromptStudioCoreView + + source = inspect.getsource(PromptStudioCoreView.single_pass_extraction) + assert "run_single_pass_extraction.apply_async" in source + assert "PromptStudioHelper.prompt_responder(" not in source + assert "HTTP_202_ACCEPTED" in source + + def test_views_capture_state_store_context(self): + import inspect + + from prompt_studio.prompt_studio_core_v2.views import PromptStudioCoreView + + for method_name in [ + "index_document", + "fetch_response", + "single_pass_extraction", + ]: + source = inspect.getsource(getattr(PromptStudioCoreView, method_name)) + assert ( + "StateStore.get(Common.LOG_EVENTS_ID)" in source + ), f"{method_name} missing LOG_EVENTS_ID capture" + assert ( + "StateStore.get(Common.REQUEST_ID)" in source + ), f"{method_name} missing REQUEST_ID capture" + + +# =================================================================== +# Phase 9: Polling endpoint tests +# =================================================================== +class TestTaskStatusAction: + def test_task_status_method_exists(self): + from prompt_studio.prompt_studio_core_v2.views import 
PromptStudioCoreView + + assert hasattr(PromptStudioCoreView, "task_status") + assert callable(PromptStudioCoreView.task_status) + + def test_task_status_url_registered(self): + from prompt_studio.prompt_studio_core_v2.urls import urlpatterns + + task_status_urls = [ + p + for p in urlpatterns + if hasattr(p, "name") and p.name == "prompt-studio-task-status" + ] + assert len(task_status_urls) >= 1 + url = task_status_urls[0] + assert "<uuid:pk>" in str(url.pattern) + assert "<str:task_id>" in str(url.pattern) + + @patch("prompt_studio.prompt_studio_core_v2.views.AsyncResult", create=True) + def test_task_status_processing(self, MockAsyncResult): + """Verify processing response for unfinished task.""" + import inspect + + from prompt_studio.prompt_studio_core_v2.views import PromptStudioCoreView + + source = inspect.getsource(PromptStudioCoreView.task_status) + assert "not result.ready()" in source + assert '"processing"' in source + + @patch("prompt_studio.prompt_studio_core_v2.views.AsyncResult", create=True) + def test_task_status_completed(self, MockAsyncResult): + """Verify completed response structure.""" + import inspect + + from prompt_studio.prompt_studio_core_v2.views import PromptStudioCoreView + + source = inspect.getsource(PromptStudioCoreView.task_status) + assert "result.successful()" in source + assert '"completed"' in source + assert "result.result" in source + + @patch("prompt_studio.prompt_studio_core_v2.views.AsyncResult", create=True) + def test_task_status_failed(self, MockAsyncResult): + """Verify failed response structure.""" + import inspect + + from prompt_studio.prompt_studio_core_v2.views import PromptStudioCoreView + + source = inspect.getsource(PromptStudioCoreView.task_status) + assert '"failed"' in source + assert "HTTP_500_INTERNAL_SERVER_ERROR" in source + + +# =================================================================== +# Phase 6: Config tests +# =================================================================== +class TestCeleryConfig: + def 
test_task_routes_defined(self): + from backend.celery_config import CeleryConfig + + assert hasattr(CeleryConfig, "task_routes") + + def test_all_three_tasks_routed(self): + from backend.celery_config import CeleryConfig + + routes = CeleryConfig.task_routes + assert routes["prompt_studio_index_document"] == {"queue": "celery_prompt_studio"} + assert routes["prompt_studio_fetch_response"] == {"queue": "celery_prompt_studio"} + assert routes["prompt_studio_single_pass"] == {"queue": "celery_prompt_studio"} + + def test_celery_app_loads_routes(self): + from backend.celery_service import app + + assert app.conf.task_routes is not None + assert "prompt_studio_index_document" in app.conf.task_routes diff --git a/backend/prompt_studio/prompt_studio_core_v2/urls.py b/backend/prompt_studio/prompt_studio_core_v2/urls.py index 228368544a..f0fcb63513 100644 --- a/backend/prompt_studio/prompt_studio_core_v2/urls.py +++ b/backend/prompt_studio/prompt_studio_core_v2/urls.py @@ -59,6 +59,8 @@ {"get": "check_deployment_usage"} ) +prompt_studio_task_status = PromptStudioCoreView.as_view({"get": "task_status"}) + urlpatterns = format_suffix_patterns( [ @@ -143,5 +145,10 @@ prompt_studio_deployment_usage, name="prompt_studio_deployment_usage", ), + path( + "prompt-studio/<uuid:pk>/task-status/<str:task_id>", + prompt_studio_task_status, + name="prompt-studio-task-status", + ), ] ) diff --git a/backend/prompt_studio/prompt_studio_core_v2/views.py b/backend/prompt_studio/prompt_studio_core_v2/views.py index 6f447b51e5..b05b220766 100644 --- a/backend/prompt_studio/prompt_studio_core_v2/views.py +++ b/backend/prompt_studio/prompt_studio_core_v2/views.py @@ -2,11 +2,13 @@ import logging import uuid from datetime import datetime +from pathlib import Path from typing import Any import magic from account_v2.custom_exceptions import DuplicateData from api_v2.models import APIDeployment +from celery import signature from django.db import IntegrityError from django.db.models import QuerySet from django.http import 
HttpRequest, HttpResponse @@ -47,8 +49,12 @@ MaxProfilesReachedError, ToolDeleteError, ) +from feature_flag.helper import FeatureFlagHelper from prompt_studio.prompt_studio_core_v2.migration_utils import SummarizeMigrationUtils -from prompt_studio.prompt_studio_core_v2.prompt_studio_helper import PromptStudioHelper +from prompt_studio.prompt_studio_core_v2.prompt_studio_helper import ( + ASYNC_EXECUTION_FLAG, + PromptStudioHelper, +) from prompt_studio.prompt_studio_core_v2.retrieval_strategies import ( get_retrieval_strategy_metadata, ) @@ -367,87 +373,283 @@ def index_document(self, request: HttpRequest, pk: Any = None) -> Response: document_id: str = serializer.validated_data.get(ToolStudioPromptKeys.DOCUMENT_ID) document: DocumentManager = DocumentManager.objects.get(pk=document_id) file_name: str = document.document_name - # Generate a run_id run_id = CommonUtils.generate_uuid() - unique_id = PromptStudioHelper.index_document( - tool_id=str(tool.tool_id), - file_name=file_name, - org_id=UserSessionUtils.get_organization_id(request), - user_id=tool.created_by.user_id, - document_id=document_id, - run_id=run_id, - ) - if unique_id: + if FeatureFlagHelper.check_flag_status(ASYNC_EXECUTION_FLAG): + # ── NEW ASYNC PATH ── + context, cb_kwargs = PromptStudioHelper.build_index_payload( + tool_id=str(tool.tool_id), + file_name=file_name, + org_id=UserSessionUtils.get_organization_id(request), + user_id=tool.created_by.user_id, + document_id=document_id, + run_id=run_id, + ) + + dispatcher = PromptStudioHelper._get_dispatcher() + + import uuid as _uuid + + executor_task_id = str(_uuid.uuid4()) + cb_kwargs["executor_task_id"] = executor_task_id + + task = dispatcher.dispatch_with_callback( + context, + on_success=signature( + "ide_index_complete", + kwargs={"callback_kwargs": cb_kwargs}, + queue="prompt_studio_callback", + ), + on_error=signature( + "ide_index_error", + kwargs={"callback_kwargs": cb_kwargs}, + queue="prompt_studio_callback", + ), + 
task_id=executor_task_id, + ) return Response( - {"message": "Document indexed successfully."}, - status=status.HTTP_200_OK, + {"task_id": task.id, "run_id": run_id, "status": "accepted"}, + status=status.HTTP_202_ACCEPTED, ) else: - logger.error("Error occured while indexing. Unique ID is not valid.") - raise IndexingAPIError() + # ── OLD SYNC PATH ── + unique_id = PromptStudioHelper.index_document( + tool_id=str(tool.tool_id), + file_name=file_name, + org_id=UserSessionUtils.get_organization_id(request), + user_id=tool.created_by.user_id, + document_id=document_id, + run_id=run_id, + ) + if unique_id: + return Response( + {"message": "Document indexed successfully."}, + status=status.HTTP_200_OK, + ) + else: + raise IndexingAPIError() @action(detail=True, methods=["post"]) def fetch_response(self, request: HttpRequest, pk: Any = None) -> Response: """API Entry point method to fetch response to prompt. Args: - request (HttpRequest): _description_ - - Raises: - FilenameMissingError: _description_ + request (HttpRequest) Returns: Response """ custom_tool = self.get_object() - tool_id: str = str(custom_tool.tool_id) document_id: str = request.data.get(ToolStudioPromptKeys.DOCUMENT_ID) - id: str = request.data.get(ToolStudioPromptKeys.ID) + prompt_id: str = request.data.get(ToolStudioPromptKeys.ID) run_id: str = request.data.get(ToolStudioPromptKeys.RUN_ID) - profile_manager: str = request.data.get(ToolStudioPromptKeys.PROFILE_MANAGER_ID) + profile_manager_id: str = request.data.get( + ToolStudioPromptKeys.PROFILE_MANAGER_ID + ) if not run_id: - # Generate a run_id run_id = CommonUtils.generate_uuid() - response: dict[str, Any] = PromptStudioHelper.prompt_responder( - id=id, - tool_id=tool_id, - org_id=UserSessionUtils.get_organization_id(request), - user_id=custom_tool.created_by.user_id, - document_id=document_id, - run_id=run_id, - profile_manager_id=profile_manager, - ) - return Response(response, status=status.HTTP_200_OK) + + if 
FeatureFlagHelper.check_flag_status(ASYNC_EXECUTION_FLAG): + # ── NEW ASYNC PATH ── + org_id = UserSessionUtils.get_organization_id(request) + user_id = custom_tool.created_by.user_id + + prompt = ToolStudioPrompt.objects.get(pk=prompt_id) + + doc_path = PromptStudioFileHelper.get_or_create_prompt_studio_subdirectory( + org_id, + is_create=False, + user_id=user_id, + tool_id=str(custom_tool.tool_id), + ) + document: DocumentManager = DocumentManager.objects.get(pk=document_id) + doc_path = str(Path(doc_path) / document.document_name) + + context, cb_kwargs = PromptStudioHelper.build_fetch_response_payload( + tool=custom_tool, + doc_path=doc_path, + doc_name=document.document_name, + prompt=prompt, + org_id=org_id, + user_id=user_id, + document_id=document_id, + run_id=run_id, + profile_manager_id=profile_manager_id, + ) + + # If document is being indexed, return pending status + if context is None: + return Response(cb_kwargs, status=status.HTTP_200_OK) + + dispatcher = PromptStudioHelper._get_dispatcher() + + import uuid as _uuid + + executor_task_id = str(_uuid.uuid4()) + cb_kwargs["executor_task_id"] = executor_task_id + + task = dispatcher.dispatch_with_callback( + context, + on_success=signature( + "ide_prompt_complete", + kwargs={"callback_kwargs": cb_kwargs}, + queue="prompt_studio_callback", + ), + on_error=signature( + "ide_prompt_error", + kwargs={"callback_kwargs": cb_kwargs}, + queue="prompt_studio_callback", + ), + task_id=executor_task_id, + ) + return Response( + {"task_id": task.id, "run_id": run_id, "status": "accepted"}, + status=status.HTTP_202_ACCEPTED, + ) + else: + # ── OLD SYNC PATH ── + tool_id: str = str(custom_tool.tool_id) + response: dict[str, Any] = PromptStudioHelper.prompt_responder( + id=prompt_id, + tool_id=tool_id, + org_id=UserSessionUtils.get_organization_id(request), + user_id=custom_tool.created_by.user_id, + document_id=document_id, + run_id=run_id, + profile_manager_id=profile_manager_id, + ) + return Response(response, 
status=status.HTTP_200_OK) @action(detail=True, methods=["post"]) def single_pass_extraction(self, request: HttpRequest, pk: uuid) -> Response: - """API Entry point method to fetch response to prompt. + """API Entry point method for single pass extraction. Args: - request (HttpRequest): _description_ - pk (Any): Primary key of the CustomTool + request (HttpRequest) + pk: Primary key of the CustomTool Returns: Response """ - # TODO: Handle fetch_response and single_pass_ - # extraction using common function custom_tool = self.get_object() - tool_id: str = str(custom_tool.tool_id) document_id: str = request.data.get(ToolStudioPromptKeys.DOCUMENT_ID) run_id: str = request.data.get(ToolStudioPromptKeys.RUN_ID) if not run_id: - # Generate a run_id run_id = CommonUtils.generate_uuid() - response: dict[str, Any] = PromptStudioHelper.prompt_responder( - tool_id=tool_id, - org_id=UserSessionUtils.get_organization_id(request), - user_id=custom_tool.created_by.user_id, - document_id=document_id, - run_id=run_id, + + if FeatureFlagHelper.check_flag_status(ASYNC_EXECUTION_FLAG): + # ── NEW ASYNC PATH ── + org_id = UserSessionUtils.get_organization_id(request) + user_id = custom_tool.created_by.user_id + + doc_path = PromptStudioFileHelper.get_or_create_prompt_studio_subdirectory( + org_id, + is_create=False, + user_id=user_id, + tool_id=str(custom_tool.tool_id), + ) + document: DocumentManager = DocumentManager.objects.get(pk=document_id) + doc_path = str(Path(doc_path) / document.document_name) + + prompts = list( + ToolStudioPrompt.objects.filter(tool_id=custom_tool.tool_id).order_by( + "sequence_number" + ) + ) + prompts = [ + p + for p in prompts + if p.prompt_type != ToolStudioPromptKeys.NOTES + and p.active + and p.enforce_type != ToolStudioPromptKeys.TABLE + and p.enforce_type != ToolStudioPromptKeys.RECORD + ] + if not prompts: + return Response( + {"error": "No active prompts found for single pass extraction."}, + status=status.HTTP_400_BAD_REQUEST, + ) + + context, 
cb_kwargs = PromptStudioHelper.build_single_pass_payload( + tool=custom_tool, + doc_path=doc_path, + doc_name=document.document_name, + prompts=prompts, + org_id=org_id, + document_id=document_id, + run_id=run_id, + ) + + dispatcher = PromptStudioHelper._get_dispatcher() + + import uuid as _uuid + + executor_task_id = str(_uuid.uuid4()) + cb_kwargs["executor_task_id"] = executor_task_id + + task = dispatcher.dispatch_with_callback( + context, + on_success=signature( + "ide_prompt_complete", + kwargs={"callback_kwargs": cb_kwargs}, + queue="prompt_studio_callback", + ), + on_error=signature( + "ide_prompt_error", + kwargs={"callback_kwargs": cb_kwargs}, + queue="prompt_studio_callback", + ), + task_id=executor_task_id, + ) + return Response( + {"task_id": task.id, "run_id": run_id, "status": "accepted"}, + status=status.HTTP_202_ACCEPTED, + ) + else: + # ── OLD SYNC PATH ── + tool_id: str = str(custom_tool.tool_id) + response: dict[str, Any] = PromptStudioHelper.prompt_responder( + tool_id=tool_id, + org_id=UserSessionUtils.get_organization_id(request), + user_id=custom_tool.created_by.user_id, + document_id=document_id, + run_id=run_id, + ) + return Response(response, status=status.HTTP_200_OK) + + @action(detail=True, methods=["get"]) + def task_status( + self, request: HttpRequest, pk: Any = None, task_id: str = None + ) -> Response: + """Poll the status of an async Prompt Studio task. + + Task IDs now point to executor worker tasks dispatched via the + worker-v2 Celery app. Both apps share the same PostgreSQL + result backend, so we use the worker app to look up results. 
+ + Args: + request (HttpRequest) + pk: Primary key of the CustomTool (for permission check) + task_id: Celery task ID returned by the 202 response + + Returns: + Response with {task_id, status} and optionally result or error + """ + from celery.result import AsyncResult + + from backend.worker_celery import get_worker_celery_app + + result = AsyncResult(task_id, app=get_worker_celery_app()) + if not result.ready(): + return Response({"task_id": task_id, "status": "processing"}) + if result.successful(): + return Response( + {"task_id": task_id, "status": "completed", "result": result.result} + ) + return Response( + {"task_id": task_id, "status": "failed", "error": str(result.result)}, + status=status.HTTP_500_INTERNAL_SERVER_ERROR, ) - return Response(response, status=status.HTTP_200_OK) @action(detail=True, methods=["get"]) def list_of_shared_users(self, request: HttpRequest, pk: Any = None) -> Response: diff --git a/backend/usage_v2/helper.py b/backend/usage_v2/helper.py index 8cefb3b403..04f256ff9b 100644 --- a/backend/usage_v2/helper.py +++ b/backend/usage_v2/helper.py @@ -74,6 +74,64 @@ def get_aggregated_token_count(run_id: str) -> dict: logger.error(f"An unexpected error occurred for run_id {run_id}: {str(e)}") raise APIException("Error while aggregating token counts") + @staticmethod + def get_usage_by_model(run_id: str) -> dict[str, list[dict[str, Any]]]: + """Get per-model usage breakdown matching prompt-service format. + + Groups usage data by (usage_type, llm_usage_reason, model_name) and + returns cost arrays keyed as 'extraction_llm', 'challenge_llm', + 'embedding', etc. — matching the legacy prompt-service response. + + Args: + run_id: The file_execution_id / run_id to query. + + Returns: + Dict with keys like 'extraction_llm', 'embedding' mapping to + lists of per-model cost entries. Empty dict on error. 
+ """ + try: + rows = ( + Usage.objects.filter(run_id=run_id) + .values("usage_type", "llm_usage_reason", "model_name") + .annotate( + sum_input_tokens=Sum("prompt_tokens"), + sum_output_tokens=Sum("completion_tokens"), + sum_total_tokens=Sum("total_tokens"), + sum_embedding_tokens=Sum("embedding_tokens"), + sum_cost=Sum("cost_in_dollars"), + ) + ) + result: dict[str, list[dict[str, Any]]] = {} + for row in rows: + usage_type = row["usage_type"] + llm_reason = row["llm_usage_reason"] + cost_str = UsageHelper._format_float_positional(row["sum_cost"] or 0.0) + + key = usage_type + item: dict[str, Any] = { + "model_name": row["model_name"], + "cost_in_dollars": cost_str, + } + if llm_reason: + key = f"{llm_reason}_{usage_type}" + item["input_tokens"] = row["sum_input_tokens"] or 0 + item["output_tokens"] = row["sum_output_tokens"] or 0 + item["total_tokens"] = row["sum_total_tokens"] or 0 + else: + item["embedding_tokens"] = row["sum_embedding_tokens"] or 0 + + result.setdefault(key, []).append(item) + return result + except Exception as e: + logger.error("Error querying per-model usage for run_id %s: %s", run_id, e) + return {} + + @staticmethod + def _format_float_positional(value: float, precision: int = 10) -> str: + """Format float without scientific notation, stripping trailing zeros.""" + formatted: str = f"{value:.{precision}f}" + return formatted.rstrip("0").rstrip(".") if "." in formatted else formatted + @staticmethod def aggregate_usage_metrics(queryset: QuerySet) -> dict[str, Any]: """Aggregate usage metrics from a queryset of Usage objects. 
diff --git a/backend/workflow_manager/workflow_v2/dto.py b/backend/workflow_manager/workflow_v2/dto.py index b2398e883e..7c06126db8 100644 --- a/backend/workflow_manager/workflow_v2/dto.py +++ b/backend/workflow_manager/workflow_v2/dto.py @@ -61,13 +61,41 @@ def remove_result_metadata_keys(self, keys_to_remove: list[str] = []) -> None: for item in self.result: if not isinstance(item, dict): - break + continue + # Handle metadata nested inside item["result"]["metadata"] result = item.get("result") - if not isinstance(result, dict): - break + if isinstance(result, dict): + self._remove_specific_keys(result=result, keys_to_remove=keys_to_remove) + + # Handle top-level item["metadata"] (workers cache path) + if "metadata" in item: + if keys_to_remove: + item_metadata = item["metadata"] + if isinstance(item_metadata, dict): + for key in keys_to_remove: + item_metadata.pop(key, None) + else: + item.pop("metadata", None) + + def remove_inner_result_metadata(self) -> None: + """Removes only the inner item["result"]["metadata"] dict (extraction + metadata like highlight_data, per-model costs, etc.) while preserving + the outer item["metadata"] dict which contains workflow identification + keys (source_name, source_hash, workflow_id, etc.). + + Use this instead of remove_result_metadata_keys() when you want to + strip extraction metadata but keep workflow identification metadata. 
+ """ + if not isinstance(self.result, list): + return - self._remove_specific_keys(result=result, keys_to_remove=keys_to_remove) + for item in self.result: + if not isinstance(item, dict): + continue + result = item.get("result") + if isinstance(result, dict): + result.pop("metadata", None) def remove_result_metrics(self) -> None: """Removes the 'metrics' key from the 'result' dictionary within each diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 6f1996818a..1318d515d6 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -57,6 +57,30 @@ services: labels: - traefik.enable=false + # Celery worker for Prompt Studio IDE callbacks. + # Processes post-execution ORM updates and Socket.IO notifications + # after executor workers complete tasks (ide_index_complete, ide_prompt_complete, etc.). + worker-prompt-studio-callback: + image: unstract/backend:${VERSION} + container_name: unstract-worker-prompt-studio-callback + restart: unless-stopped + entrypoint: .venv/bin/celery + command: "-A backend worker --loglevel=info -Q prompt_studio_callback --autoscale=${WORKER_PROMPT_STUDIO_CALLBACK_AUTOSCALE:-4,1}" + env_file: + - ../backend/.env + - ./essentials.env + depends_on: + - db + - redis + - rabbitmq + environment: + - ENVIRONMENT=development + - APPLICATION_NAME=unstract-worker-prompt-studio-callback + labels: + - traefik.enable=false + volumes: + - prompt_studio_data:/app/prompt-studio-data + # Celery Flower celery-flower: image: unstract/backend:${VERSION} @@ -483,6 +507,42 @@ services: - ./workflow_data:/data - ${TOOL_REGISTRY_CONFIG_SRC_PATH}:/data/tool_registry_config + worker-executor-v2: + image: unstract/worker-unified:${VERSION} + container_name: unstract-worker-executor-v2 + restart: unless-stopped + command: ["executor"] + ports: + - "8092:8088" + env_file: + - ../workers/.env + - ./essentials.env + depends_on: + - db + - redis + - rabbitmq + - platform-service + environment: + - ENVIRONMENT=development + - 
APPLICATION_NAME=unstract-worker-executor-v2 + - WORKER_TYPE=executor + - WORKER_NAME=executor-worker-v2 + - EXECUTOR_METRICS_PORT=8088 + - HEALTH_PORT=8088 + # Configurable Celery options + - CELERY_QUEUES_EXECUTOR=${CELERY_QUEUES_EXECUTOR:-celery_executor_legacy} + - CELERY_POOL=${WORKER_EXECUTOR_POOL:-prefork} + - CELERY_PREFETCH_MULTIPLIER=${WORKER_EXECUTOR_PREFETCH_MULTIPLIER:-1} + - CELERY_CONCURRENCY=${WORKER_EXECUTOR_CONCURRENCY:-2} + - CELERY_EXTRA_ARGS=${WORKER_EXECUTOR_EXTRA_ARGS:-} + labels: + - traefik.enable=false + volumes: + - ./workflow_data:/data + - ${TOOL_REGISTRY_CONFIG_SRC_PATH}:/data/tool_registry_config + profiles: + - workers-v2 + volumes: prompt_studio_data: unstract_data: diff --git a/docker/dockerfiles/worker-unified.Dockerfile b/docker/dockerfiles/worker-unified.Dockerfile index 202f71b699..0ea425b623 100644 --- a/docker/dockerfiles/worker-unified.Dockerfile +++ b/docker/dockerfiles/worker-unified.Dockerfile @@ -83,6 +83,20 @@ RUN uv sync --group deploy --locked && \ touch requirements.txt && \ { chown -R worker:worker ./run-worker.sh ./run-worker-docker.sh 2>/dev/null || true; } +# Install executor plugins from workers/plugins/ (cloud-only, no-op for OSS). +# Plugins register via setuptools entry points in two groups: +# - unstract.executor.executors (executor classes, e.g. table_extractor) +# - unstract.executor.plugins (utility plugins, e.g. highlight-data, challenge) +# Editable installs (-e) ensure Path(__file__) resolves to the source directory, +# giving plugins access to non-Python assets (.md prompts, .txt templates, etc.). 
+RUN for plugin_dir in /app/plugins/*/; do \ + if [ -f "$plugin_dir/pyproject.toml" ] && \ + grep -qE 'unstract\.executor\.(executors|plugins)' "$plugin_dir/pyproject.toml" 2>/dev/null; then \ + echo "Installing executor plugin: $(basename $plugin_dir)" && \ + uv pip install -e "$plugin_dir" || true; \ + fi; \ + done + # Switch to worker user USER worker diff --git a/docker/dockerfiles/worker-unified.Dockerfile.dockerignore b/docker/dockerfiles/worker-unified.Dockerfile.dockerignore index fca472f1f1..110627ea61 100644 --- a/docker/dockerfiles/worker-unified.Dockerfile.dockerignore +++ b/docker/dockerfiles/worker-unified.Dockerfile.dockerignore @@ -51,7 +51,6 @@ Thumbs.db # Documentation **/docs/ -**/*.md !README.md !unstract !unstract/** diff --git a/docker/sample.compose.override.yaml b/docker/sample.compose.override.yaml index eeb728c822..32f5d3573d 100644 --- a/docker/sample.compose.override.yaml +++ b/docker/sample.compose.override.yaml @@ -319,6 +319,23 @@ services: - action: rebuild path: ../workers/uv.lock + ######################################################################################################### + # Prompt Studio callback worker (Django backend, processes prompt_studio_callback queue) + worker-prompt-studio-callback: + build: + dockerfile: docker/dockerfiles/backend.Dockerfile + context: .. 
+ develop: + watch: + - action: sync+restart + path: ../backend/ + target: /app + ignore: [.venv/, __pycache__/, "*.pyc", .pytest_cache/, .mypy_cache/] + - action: sync+restart + path: ../unstract/ + target: /unstract + ignore: [.venv/, __pycache__/, "*.pyc", .pytest_cache/, .mypy_cache/] + # V1 workers disabled by default (use workers-v2 profile instead) worker: profiles: diff --git a/docs/local-dev-setup-executor-migration.md b/docs/local-dev-setup-executor-migration.md new file mode 100644 index 0000000000..8bb6921fee --- /dev/null +++ b/docs/local-dev-setup-executor-migration.md @@ -0,0 +1,586 @@ +# Local Dev Setup: Executor Migration (Pluggable Executor System v2) + +> **Branch:** `feat/execution-backend` +> **Date:** 2026-02-19 + +This guide covers everything needed to run and test the executor migration locally. + +--- + +## Table of Contents + +1. [Architecture Overview (Post-Migration)](#1-architecture-overview-post-migration) +2. [Prerequisites](#2-prerequisites) +3. [Service Dependency Map](#3-service-dependency-map) +4. [Step-by-Step Setup](#4-step-by-step-setup) +5. [Environment Configuration](#5-environment-configuration) +6. [Running the Executor Worker](#6-running-the-executor-worker) +7. [Port Reference](#7-port-reference) +8. [Health Check Endpoints](#8-health-check-endpoints) +9. [Debugging & Troubleshooting](#9-debugging--troubleshooting) +10. [Test Verification Checklist](#10-test-verification-checklist) + +--- + +## 1. 
Architecture Overview (Post-Migration) + +``` +┌──────────────────────────────────────────────────────────────┐ +│ CALLERS │ +│ │ +│ Workflow Path: │ +│ process_file_batch → structure_tool_task │ +│ → ExecutionDispatcher.dispatch() [Celery] │ +│ → AsyncResult.get() │ +│ │ +│ Prompt Studio IDE: │ +│ Django View → PromptStudioHelper │ +│ → ExecutionDispatcher.dispatch() [Celery] │ +│ → AsyncResult.get() │ +└───────────────────────┬──────────────────────────────────────┘ + │ Celery task: execute_extraction + ▼ +┌──────────────────────────────────────────────────────────────┐ +│ EXECUTOR WORKER (dedicated, queue: "executor") │ +│ │ +│ execute_extraction task │ +│ → ExecutionOrchestrator → ExecutorRegistry → LegacyExecutor │ +│ → Returns ExecutionResult via Celery result backend │ +└──────────────────────────────────────────────────────────────┘ +``` + +**What changed:** +- `prompt-service` Flask app is **replaced** by the executor worker (Celery) +- Structure tool Docker container is **replaced** by `structure_tool_task` (Celery task in file_processing worker) +- `PromptTool` SDK HTTP client is **replaced** by `ExecutionDispatcher` (Celery dispatch) +- **No DB schema changes** — no Django migrations needed + +**What stays the same:** +- `platform-service` (port 3001) — still serves tool metadata +- `runner` (port 5002) — still needed for Classifier, Text Extractor, Translate tools +- `x2text-service` (port 3004) — still needed for text extraction +- All adapter SDKs (LLM, Embedding, VectorDB, X2Text) — used by LegacyExecutor via ExecutorToolShim +- Frontend — no changes (same REST API responses) + +--- + +## 2. 
Prerequisites + +### 2.1 System Requirements + +| Requirement | Minimum | Notes | +|---|---|---| +| Docker + Docker Compose | v2.20+ | `docker compose version` | +| Python | 3.11+ | System or pyenv | +| uv | Latest | `pip install uv` or use the repo-local binary at `backend/venv/bin/uv` | +| Git | 2.30+ | On `feat/execution-backend` branch | +| Free RAM | 8 GB+ | Many services run concurrently | +| Free Disk | 10 GB+ | Docker images + volumes | + +### 2.2 Verify Branch + +```bash +cd /home/harini/Documents/Workspace/unstract-poc/clean/unstract +git branch --show-current +# Expected: feat/execution-backend +``` + +### 2.3 Required Docker Images + +The system needs these images built: + +```bash +# Build all images (from docker/ directory) +cd docker +docker compose -f docker-compose.build.yaml build + +# Or build just the critical ones: +docker compose -f docker-compose.build.yaml build backend +docker compose -f docker-compose.build.yaml build platform-service +docker compose -f docker-compose.build.yaml build worker-unified # V2 workers including executor +docker compose -f docker-compose.build.yaml build runner +docker compose -f docker-compose.build.yaml build frontend +``` + +> **Tip:** For faster dev builds, set `MINIMAL_BUILD=1` in docker-compose.build.yaml args. + +--- + +## 3. 
Service Dependency Map + +### Essential Infrastructure (must be running for ANYTHING to work) + +| Service | Container | Port | Purpose | +|---|---|---|---| +| PostgreSQL (pgvector) | `unstract-db` | 5432 | Primary database | +| Redis | `unstract-redis` | 6379 | Cache + queues | +| RabbitMQ | `unstract-rabbitmq` | 5672 (AMQP), 15672 (UI) | Celery message broker | +| MinIO | `unstract-minio` | 9000 (S3), 9001 (Console) | Object storage | +| Traefik | `unstract-proxy` | 80, 8080 (Dashboard) | Reverse proxy | + +### Application Services + +| Service | Container | Port | Required For | +|---|---|---|---| +| Backend (Django) | `unstract-backend` | 8000 | API, auth, DB migrations | +| Platform Service | `unstract-platform-service` | 3001 | Tool metadata, adapter configs | +| X2Text Service | `unstract-x2text-service` | 3004 | Text extraction (used by executor) | +| Runner | `unstract-runner` | 5002 | Non-structure tools (Classifier, etc.) | +| Frontend | `unstract-frontend` | 3000 | Web UI | +| Flipt | `unstract-flipt` | 8082 (REST), 9005 (gRPC) | Feature flags | + +### Workers (V2 Unified — `--profile workers-v2`) + +| Worker | Container | Health Port | Queue(s) | +|---|---|---|---| +| **Executor** | `unstract-worker-executor-v2` | 8088 | `executor` | +| File Processing | `unstract-worker-file-processing-v2` | 8082 | `file_processing`, `api_file_processing` | +| API Deployment | `unstract-worker-api-deployment-v2` | 8090 | `celery_api_deployments` | +| Callback | `unstract-worker-callback-v2` | 8083 | `file_processing_callback`, `api_file_processing_callback` | +| General | `unstract-worker-general-v2` | 8082 | `celery` | +| Notification | `unstract-worker-notification-v2` | 8085 | `notifications`, `notifications_*` | +| Log Consumer | `unstract-worker-log-consumer-v2` | 8084 | `celery_log_task_queue` | +| Scheduler | `unstract-worker-scheduler-v2` | 8087 | `scheduler` | + +### Post-Migration: REMOVED Services + +| Service | Port | Replaced By | +|---|---|---| +| 
~~Prompt Service~~ | ~~3003~~ | Executor Worker (LegacyExecutor inline) | +| ~~Structure Tool (Docker)~~ | N/A | `structure_tool_task` (Celery) | + +--- + +## 4. Step-by-Step Setup + +### 4.1 Start Essential Infrastructure + +```bash +cd /home/harini/Documents/Workspace/unstract-poc/clean/unstract/docker + +# Start infrastructure services only +docker compose -f docker-compose-dev-essentials.yaml up -d +``` + +Wait for all services to be healthy: +```bash +docker compose -f docker-compose-dev-essentials.yaml ps +``` + +### 4.2 Start Application Services + +**Option A: All via Docker Compose (recommended for first-time setup)** + +```bash +cd docker + +# Start everything including V2 workers (with executor) +docker compose --profile workers-v2 up -d +``` + +**Option B: Hybrid mode (services in Docker, workers local)** + +This is useful when you want to iterate on worker code without rebuilding images. + +```bash +# Start only infrastructure + app services (no V2 workers) +docker compose up -d + +# Then run executor worker locally (see Section 6) +``` + +### 4.3 Verify DB Migrations + +The backend container runs migrations on startup (`--migrate` flag). 
Verify: + +```bash +docker logs unstract-backend 2>&1 | grep -i "migration" +``` + +### 4.4 Create Workers .env for Local Development + +If running workers outside Docker, you need a local `.env`: + +```bash +cd /home/harini/Documents/Workspace/unstract-poc/clean/unstract/workers + +# Copy sample and adjust for local dev +cp sample.env .env +``` + +Then edit `workers/.env` — change all Docker hostnames to `localhost`: + +```ini +# === CRITICAL CHANGES FOR LOCAL DEV === +DJANGO_APP_BACKEND_URL=http://localhost:8000 +INTERNAL_API_BASE_URL=http://localhost:8000/internal +CELERY_BROKER_BASE_URL=amqp://localhost:5672// +DB_HOST=localhost +REDIS_HOST=localhost +CACHE_REDIS_HOST=localhost +PLATFORM_SERVICE_HOST=http://localhost +PLATFORM_SERVICE_PORT=3001 +PROMPT_HOST=http://localhost +PROMPT_PORT=3003 +X2TEXT_HOST=http://localhost +X2TEXT_PORT=3004 +UNSTRACT_RUNNER_HOST=http://localhost +UNSTRACT_RUNNER_PORT=5002 +WORKFLOW_EXECUTION_FILE_STORAGE_CREDENTIALS='{"provider": "minio", "credentials": {"endpoint_url": "http://localhost:9000", "key": "minio", "secret": "minio123"}}' +API_FILE_STORAGE_CREDENTIALS='{"provider": "minio", "credentials": {"endpoint_url": "http://localhost:9000", "key": "minio", "secret": "minio123"}}' +``` + +> **Important:** The `INTERNAL_SERVICE_API_KEY` must match what the backend expects. Default dev value: `dev-internal-key-123`. + +--- + +## 5. 
Environment Configuration + +### 5.1 Key Environment Variables for Executor Worker + +| Variable | Default (Docker) | Local Override | Purpose | +|---|---|---|---| +| `CELERY_BROKER_BASE_URL` | `amqp://unstract-rabbitmq:5672//` | `amqp://localhost:5672//` | RabbitMQ connection | +| `CELERY_BROKER_USER` | `admin` | same | RabbitMQ user | +| `CELERY_BROKER_PASS` | `password` | same | RabbitMQ password | +| `DB_HOST` | `unstract-db` | `localhost` | PostgreSQL for result backend | +| `DB_USER` | `unstract_dev` | same | DB user | +| `DB_PASSWORD` | `unstract_pass` | same | DB password | +| `DB_NAME` | `unstract_db` | same | DB name | +| `DB_PORT` | `5432` | same | DB port | +| `REDIS_HOST` | `unstract-redis` | `localhost` | Redis for caching | +| `PLATFORM_SERVICE_HOST` | `http://unstract-platform-service` | `http://localhost` | Platform service URL | +| `PLATFORM_SERVICE_PORT` | `3001` | same | Platform service port | +| `X2TEXT_HOST` | `http://unstract-x2text-service` | `http://localhost` | X2Text service URL | +| `X2TEXT_PORT` | `3004` | same | X2Text service port | +| `INTERNAL_SERVICE_API_KEY` | `dev-internal-key-123` | same | Worker→Backend auth | +| `INTERNAL_API_BASE_URL` | `http://unstract-backend:8000/internal` | `http://localhost:8000/internal` | Backend internal API | +| `WORKFLOW_EXECUTION_FILE_STORAGE_CREDENTIALS` | (MinIO JSON, Docker host) | (MinIO JSON, localhost) | Shared file storage | + +### 5.2 Credentials Reference (Default Dev) + +| Service | Username | Password | +|---|---|---| +| PostgreSQL | `unstract_dev` | `unstract_pass` | +| RabbitMQ | `admin` | `password` | +| MinIO | `minio` | `minio123` | +| Redis | (none) | (none) | + +### 5.3 Hierarchical Celery Config + +Worker settings use a 3-tier hierarchy (most specific wins): + +1. **Worker-specific:** `EXECUTOR_TASK_TIME_LIMIT=7200` +2. **Global Celery:** `CELERY_TASK_TIME_LIMIT=3600` +3. **Code default:** (hardcoded fallback) + +--- + +## 6. 
Running the Executor Worker + +### 6.1 Via Docker Compose (easiest) + +```bash +cd docker + +# Start just the executor worker (assumes infra is up) +docker compose --profile workers-v2 up -d worker-executor-v2 + +# Check logs +docker logs -f unstract-worker-executor-v2 +``` + +### 6.2 Locally with run-worker.sh + +```bash +cd /home/harini/Documents/Workspace/unstract-poc/clean/unstract/workers + +# Ensure .env has local overrides (Section 4.4) +./run-worker.sh executor +``` + +Options: +```bash +./run-worker.sh -l DEBUG executor # Debug logging +./run-worker.sh -c 4 executor # 4 concurrent tasks +./run-worker.sh -P threads executor # Thread pool instead of prefork +./run-worker.sh -d executor # Run in background (detached) +./run-worker.sh -s # Show status of all workers +./run-worker.sh -k # Kill all workers +``` + +### 6.3 Locally with uv (manual) + +```bash +cd /home/harini/Documents/Workspace/unstract-poc/clean/unstract/workers + +# Load env +set -a && source .env && set +a + +# Run executor worker +uv run celery -A worker worker \ + --queues=executor \ + --loglevel=INFO \ + --pool=prefork \ + --concurrency=2 \ + --hostname=executor-worker@%h +``` + +### 6.4 Verify Executor Worker is Running + +```bash +# Check health endpoint +curl -s http://localhost:8088/health | python3 -m json.tool + +# Check Celery registered tasks +uv run celery -A worker inspect registered \ + --destination=executor-worker@$(hostname) + +# Expected task: execute_extraction +``` + +### 6.5 Running All V2 Workers + +```bash +# Via Docker +cd docker && docker compose --profile workers-v2 up -d + +# Via script (local) +cd workers && ./run-worker.sh all +``` + +--- + +## 7. 
Port Reference + +### Infrastructure + +| Service | Port | URL | +|---|---|---| +| PostgreSQL | 5432 | `psql -h localhost -U unstract_dev -d unstract_db` | +| Redis | 6379 | `redis-cli -h localhost` | +| RabbitMQ AMQP | 5672 | `amqp://admin:password@localhost:5672//` | +| RabbitMQ Management | 15672 | http://localhost:15672 (admin/password) | +| MinIO S3 API | 9000 | http://localhost:9000 | +| MinIO Console | 9001 | http://localhost:9001 (minio/minio123) | +| Qdrant | 6333 | http://localhost:6333 | +| Traefik Dashboard | 8080 | http://localhost:8080 | + +### Application + +| Service | Port | URL | +|---|---|---| +| Backend API | 8000 | http://localhost:8000/api/v1/ | +| Frontend | 3000 | http://frontend.unstract.localhost | +| Platform Service | 3001 | http://localhost:3001 | +| X2Text Service | 3004 | http://localhost:3004 | +| Runner | 5002 | http://localhost:5002 | +| Celery Flower (optional) | 5555 | http://localhost:5555 | + +### V2 Worker Health Ports + +| Worker | Internal Port | External Port (Docker) | +|---|---|---| +| API Deployment | 8090 | 8085 | +| Callback | 8083 | 8086 | +| File Processing | 8082 | 8087 | +| General | 8082 | 8088 | +| Notification | 8085 | 8089 | +| Log Consumer | 8084 | 8090 | +| Scheduler | 8087 | 8091 | +| **Executor** | **8088** | **8092** | + +### Debug Ports (Docker dev mode via compose.override.yaml) + +| Service | Debug Port | +|---|---| +| Backend | 5678 | +| Runner | 5679 | +| Platform Service | 5680 | +| Prompt Service | 5681 | +| File Processing Worker | 5682 | +| Callback Worker | 5683 | +| API Deployment Worker | 5684 | +| General Worker | 5685 | + +--- + +## 8. 
Health Check Endpoints + +Every V2 worker exposes `GET /health` on its health port: + +```bash +# Executor worker +curl -s http://localhost:8088/health + +# Expected response: +# {"status": "healthy", "worker_type": "executor", ...} +``` + +All endpoints: +``` +http://localhost:8080/health — API Deployment worker +http://localhost:8081/health — General worker +http://localhost:8082/health — File Processing worker +http://localhost:8083/health — Callback worker +http://localhost:8084/health — Log Consumer worker +http://localhost:8085/health — Notification worker +http://localhost:8087/health — Scheduler worker +http://localhost:8088/health — Executor worker +``` + +--- + +## 9. Debugging & Troubleshooting + +### 9.1 Common Issues + +**"Connection refused" to RabbitMQ/Redis/DB** +- Check Docker containers are running: `docker ps` +- Check if using Docker hostnames vs localhost (see Section 5.1) +- Ensure ports are exposed: `docker port unstract-rabbitmq` + +**Executor worker starts but tasks don't execute** +- Check queue binding: Worker must listen on `executor` queue +- Check RabbitMQ UI (http://localhost:15672) → Queues tab → look for `executor` queue +- Check task is registered: `celery -A worker inspect registered` +- Check task routing in `workers/shared/infrastructure/config/registry.py` + +**"Module not found" errors in executor worker** +- Ensure `PYTHONPATH` includes the workers directory +- If running locally, `cd workers` before starting +- If using `run-worker.sh`, it sets PYTHONPATH automatically + +**MinIO file access errors** +- Check `WORKFLOW_EXECUTION_FILE_STORAGE_CREDENTIALS` has correct endpoint (localhost vs Docker hostname) +- Verify MinIO bucket exists: `mc ls minio/unstract/` +- MinIO bootstrap container creates the bucket on first start + +**Platform service connection errors** +- Executor needs `PLATFORM_SERVICE_HOST` and `PLATFORM_SERVICE_PORT` +- Verify platform-service is running: `curl http://localhost:3001/health` + +### 9.2 Useful 
Debug Commands + +```bash +# Check all Docker containers +docker ps --format "table {{.Names}}\t{{.Status}}\t{{.Ports}}" + +# Check RabbitMQ queues +docker exec unstract-rabbitmq rabbitmqctl list_queues name messages consumers + +# Check Celery worker status (from workers/ dir) +cd workers && uv run celery -A worker inspect active + +# Check registered tasks +cd workers && uv run celery -A worker inspect registered + +# Send a test task to executor +cd workers && uv run python -c " +from worker import app +from shared.enums.task_enums import TaskName +result = app.send_task( + TaskName.EXECUTE_EXTRACTION, + args=[{ + 'executor_name': 'legacy', + 'operation': 'extract', + 'run_id': 'test-123', + 'execution_source': 'tool', + 'executor_params': {} + }], + queue='executor' +) +print(f'Task ID: {result.id}') +print(f'Result: {result.get(timeout=30)}') +" + +# Monitor Celery events in real-time +cd workers && uv run celery -A worker events + +# Check Postgres (Celery result backend) +docker exec -it unstract-db psql -U unstract_dev -d unstract_db -c "SELECT task_id, status FROM public.celery_taskmeta ORDER BY date_done DESC LIMIT 10;" +``` + +### 9.3 Log Locations + +| Context | Location | +|---|---| +| Docker container | `docker logs <container-name>` | +| Local worker (foreground) | stdout/stderr | +| Local worker (detached) | `workers/<worker-type>/<worker-type>.log` | +| Backend | `docker logs unstract-backend` | + +--- + +## 10. 
Test Verification Checklist + +### Phase 1 Sanity (Executor Framework) + +- [ ] Executor worker starts and connects to Celery broker +- [ ] Health check responds: `curl http://localhost:8088/health` +- [ ] `execute_extraction` task is registered in Celery +- [ ] No-op task dispatch round-trips successfully +- [ ] Task routing: task goes to `executor` queue, processed by executor worker + +### Phase 2 Sanity (LegacyExecutor) + +- [ ] `extract` operation returns `{"extracted_text": "..."}` +- [ ] `index` operation returns `{"doc_id": "..."}` +- [ ] `answer_prompt` returns `{"output": {...}, "metadata": {...}, "metrics": {...}}` +- [ ] `single_pass_extraction` returns same shape as answer_prompt +- [ ] `summarize` returns `{"data": "..."}` +- [ ] Error cases return `ExecutionResult(success=False, error="...")` not unhandled exceptions + +### Phase 3 Sanity (Structure Tool as Celery Task) + +- [ ] Run workflow with structure tool via new Celery path +- [ ] Compare output with Docker-based structure tool output +- [ ] Non-structure tools still work via Docker/Runner (regression check) + +### Phase 4 Sanity (IDE Path) + +- [ ] Open Prompt Studio IDE, create/load a project +- [ ] Run extraction on a document — result displays correctly +- [ ] Run prompt answering — output persists in DB +- [ ] Error cases display properly in IDE + +### Phase 5 Sanity (Decommission) + +- [ ] `docker compose up` boots cleanly — no errors from missing services +- [ ] No dangling references to prompt-service, PromptTool, PROMPT_HOST, PROMPT_PORT +- [ ] All health checks pass + +### Running Unit Tests + +```bash +# SDK1 tests (execution framework) +cd /home/harini/Documents/Workspace/unstract-poc/clean/unstract/unstract/sdk1 +/home/harini/Documents/Workspace/unstract-poc/clean/unstract/backend/venv/bin/uv run pytest -v + +# Workers tests (executor, LegacyExecutor, retrievers, etc.) 
+cd /home/harini/Documents/Workspace/unstract-poc/clean/unstract/workers +/home/harini/Documents/Workspace/unstract-poc/clean/unstract/backend/venv/bin/uv run pytest -v +``` + +--- + +## Quick Reference: One-Liner Setup + +```bash +# From repo root: +cd docker + +# 1. Build images +docker compose -f docker-compose.build.yaml build + +# 2. Start everything with V2 workers +docker compose --profile workers-v2 up -d + +# 3. Verify +docker ps --format "table {{.Names}}\t{{.Status}}" + +# 4. Check executor health +curl -s http://localhost:8092/health # 8092 = external Docker port for executor +``` + +For the automated version, use the setup check script: `scripts/check-local-setup.sh` diff --git a/frontend/src/components/custom-tools/manage-docs-modal/ManageDocsModal.jsx b/frontend/src/components/custom-tools/manage-docs-modal/ManageDocsModal.jsx index 29d7d533e2..c1ccfb898c 100644 --- a/frontend/src/components/custom-tools/manage-docs-modal/ManageDocsModal.jsx +++ b/frontend/src/components/custom-tools/manage-docs-modal/ManageDocsModal.jsx @@ -219,9 +219,13 @@ function ManageDocsModal({ newMessages = newMessages.slice(0, lastIndex); } - // Filter only INFO and ERROR logs + // Filter only INFO and ERROR logs that are NOT from answer_prompt. + // Answer prompt messages carry a prompt_key in their component; + // indexing messages do not. 
newMessages = newMessages.filter( - (item) => item?.level === "INFO" || item?.level === "ERROR", + (item) => + (item?.level === "INFO" || item?.level === "ERROR") && + !item?.component?.prompt_key, ); // If there are no new INFO or ERROR messages, return early diff --git a/frontend/src/components/custom-tools/prompt-card/DisplayPromptResult.jsx b/frontend/src/components/custom-tools/prompt-card/DisplayPromptResult.jsx index dff233f5bc..17006246b3 100644 --- a/frontend/src/components/custom-tools/prompt-card/DisplayPromptResult.jsx +++ b/frontend/src/components/custom-tools/prompt-card/DisplayPromptResult.jsx @@ -25,6 +25,7 @@ function DisplayPromptResult({ wordConfidenceData, isTable = false, setOpenExpandModal = () => {}, + progressMsg, }) { const [isLoading, setIsLoading] = useState(false); const [parsedOutput, setParsedOutput] = useState(null); @@ -66,7 +67,19 @@ function DisplayPromptResult({ ]); if (isLoading) { - return } />; + return ( +
+ } /> + {progressMsg?.message && ( + + {progressMsg.message} + + )} +
+ ); } if (output === undefined) { @@ -427,6 +440,7 @@ DisplayPromptResult.propTypes = { wordConfidenceData: PropTypes.object, isTable: PropTypes.bool, setOpenExpandModal: PropTypes.func, + progressMsg: PropTypes.object, }; export { DisplayPromptResult }; diff --git a/frontend/src/components/custom-tools/prompt-card/PromptCard.css b/frontend/src/components/custom-tools/prompt-card/PromptCard.css index f5b4a66a04..72ba8bcb70 100644 --- a/frontend/src/components/custom-tools/prompt-card/PromptCard.css +++ b/frontend/src/components/custom-tools/prompt-card/PromptCard.css @@ -325,3 +325,17 @@ .prompt-output-result { font-size: 12px; } + +.prompt-loading-container { + display: flex; + align-items: center; + gap: 8px; +} + +.prompt-progress-msg { + font-size: 12px; + max-width: 300px; + overflow: hidden; + text-overflow: ellipsis; + white-space: nowrap; +} diff --git a/frontend/src/components/custom-tools/prompt-card/PromptCard.jsx b/frontend/src/components/custom-tools/prompt-card/PromptCard.jsx index a4188ee4c0..8c481d4c9c 100644 --- a/frontend/src/components/custom-tools/prompt-card/PromptCard.jsx +++ b/frontend/src/components/custom-tools/prompt-card/PromptCard.jsx @@ -74,7 +74,8 @@ const PromptCard = memo( .find( (item) => (item?.component?.prompt_id === promptDetailsState?.prompt_id || - item?.component?.prompt_key === promptKey) && + item?.component?.prompt_key === promptKey || + item?.component?.tool_id === details?.tool_id) && (item?.level === "INFO" || item?.level === "ERROR"), ); diff --git a/frontend/src/components/custom-tools/prompt-card/PromptCardItems.jsx b/frontend/src/components/custom-tools/prompt-card/PromptCardItems.jsx index 67a19905bd..18cb947acb 100644 --- a/frontend/src/components/custom-tools/prompt-card/PromptCardItems.jsx +++ b/frontend/src/components/custom-tools/prompt-card/PromptCardItems.jsx @@ -317,6 +317,7 @@ function PromptCardItems({ promptRunStatus={promptRunStatus} isChallenge={isChallenge} 
handleSelectHighlight={handleSelectHighlight} + progressMsg={progressMsg} /> diff --git a/frontend/src/components/custom-tools/prompt-card/PromptOutput.jsx b/frontend/src/components/custom-tools/prompt-card/PromptOutput.jsx index 0e3e350b46..debc1df762 100644 --- a/frontend/src/components/custom-tools/prompt-card/PromptOutput.jsx +++ b/frontend/src/components/custom-tools/prompt-card/PromptOutput.jsx @@ -66,6 +66,7 @@ function PromptOutput({ promptRunStatus, isChallenge, handleSelectHighlight, + progressMsg, }) { const [openExpandModal, setOpenExpandModal] = useState(false); const { width: windowWidth } = useWindowDimensions(); @@ -111,6 +112,7 @@ function PromptOutput({ promptDetails={promptDetails} isTable={true} setOpenExpandModal={setOpenExpandModal} + progressMsg={progressMsg} />
state.activeApis); - const queue = usePromptRunQueueStore((state) => state.queue); - const setPromptRunQueue = usePromptRunQueueStore( - (state) => state.setPromptRunQueue, - ); - const { runPrompt, syncPromptRunApisAndStatus } = usePromptRun(); - const promptRunStatus = usePromptRunStatusStore( - (state) => state.promptRunStatus, - ); - const updateCustomTool = useCustomToolStore( - (state) => state.updateCustomTool, - ); - - useEffect(() => { - // Retrieve queue from cookies on component load - const queueData = Cookies.get("promptRunQueue"); - if (queueData && JSON.parse(queueData)?.length) { - const promptApis = JSON.parse(queueData); - syncPromptRunApisAndStatus(promptApis); - } - - // Setup the beforeunload event handler to store queue in cookies - const handleBeforeUnload = () => { - if (!PROMPT_RUN_STATE_PERSISTENCE) return; - const { queue } = usePromptRunQueueStore.getState(); // Get the latest state dynamically - if (queue?.length) { - Cookies.set("promptRunQueue", JSON.stringify(queue), { - expires: 5 / 1440, // Expire in 5 minutes - }); - } - }; - - window.addEventListener("beforeunload", handleBeforeUnload); - - return () => { - window.removeEventListener("beforeunload", handleBeforeUnload); // Clean up event listener - }; - }, [syncPromptRunApisAndStatus]); - - useEffect(() => { - if (!queue?.length || activeApis >= MAX_ACTIVE_APIS) return; - - const canRunApis = MAX_ACTIVE_APIS - activeApis; - const apisToRun = queue.slice(0, canRunApis); - - setPromptRunQueue({ - activeApis: activeApis + apisToRun.length, - queue: queue.slice(apisToRun.length), - }); - runPrompt(apisToRun); - }, [activeApis, queue, setPromptRunQueue, runPrompt]); - - useEffect(() => { - const isMultiPassExtractLoading = !!Object.keys(promptRunStatus).length; - updateCustomTool({ isMultiPassExtractLoading }); - }, [promptRunStatus, updateCustomTool]); + const { sessionDetails } = useSessionStore(); + const isAsync = !!sessionDetails?.flags?.async_prompt_execution; - return null; + 
return isAsync ? : ; } export { PromptRun }; diff --git a/frontend/src/components/custom-tools/prompt-card/PromptRunAsync.jsx b/frontend/src/components/custom-tools/prompt-card/PromptRunAsync.jsx new file mode 100644 index 0000000000..88b894f1cb --- /dev/null +++ b/frontend/src/components/custom-tools/prompt-card/PromptRunAsync.jsx @@ -0,0 +1,79 @@ +import Cookies from "js-cookie"; +import { useEffect } from "react"; +import { usePromptRunQueueStore } from "../../../store/prompt-run-queue-store"; +import usePromptRun from "../../../hooks/usePromptRun"; +import usePromptStudioSocket from "../../../hooks/usePromptStudioSocket"; +import { useCustomToolStore } from "../../../store/custom-tool-store"; +import { usePromptRunStatusStore } from "../../../store/prompt-run-status-store"; + +const MAX_ACTIVE_APIS = 5; +/* + Change this to 'true' to allow persistence of the prompt run state + Right now, this feature cannot be supported as the prompt studio details + are not persisted across the entire application. 
+*/ +const PROMPT_RUN_STATE_PERSISTENCE = false; + +function PromptRunAsync() { + const activeApis = usePromptRunQueueStore((state) => state.activeApis); + const queue = usePromptRunQueueStore((state) => state.queue); + const setPromptRunQueue = usePromptRunQueueStore( + (state) => state.setPromptRunQueue, + ); + const { runPrompt, syncPromptRunApisAndStatus } = usePromptRun(); + usePromptStudioSocket(); + const promptRunStatus = usePromptRunStatusStore( + (state) => state.promptRunStatus, + ); + const updateCustomTool = useCustomToolStore( + (state) => state.updateCustomTool, + ); + + useEffect(() => { + // Retrieve queue from cookies on component load + const queueData = Cookies.get("promptRunQueue"); + if (queueData && JSON.parse(queueData)?.length) { + const promptApis = JSON.parse(queueData); + syncPromptRunApisAndStatus(promptApis); + } + + // Setup the beforeunload event handler to store queue in cookies + const handleBeforeUnload = () => { + if (!PROMPT_RUN_STATE_PERSISTENCE) return; + const { queue } = usePromptRunQueueStore.getState(); // Get the latest state dynamically + if (queue?.length) { + Cookies.set("promptRunQueue", JSON.stringify(queue), { + expires: 5 / 1440, // Expire in 5 minutes + }); + } + }; + + window.addEventListener("beforeunload", handleBeforeUnload); + + return () => { + window.removeEventListener("beforeunload", handleBeforeUnload); // Clean up event listener + }; + }, [syncPromptRunApisAndStatus]); + + useEffect(() => { + if (!queue?.length || activeApis >= MAX_ACTIVE_APIS) return; + + const canRunApis = MAX_ACTIVE_APIS - activeApis; + const apisToRun = queue.slice(0, canRunApis); + + setPromptRunQueue({ + activeApis: activeApis + apisToRun.length, + queue: queue.slice(apisToRun.length), + }); + runPrompt(apisToRun); + }, [activeApis, queue, setPromptRunQueue, runPrompt]); + + useEffect(() => { + const isMultiPassExtractLoading = !!Object.keys(promptRunStatus).length; + updateCustomTool({ isMultiPassExtractLoading }); + }, 
[promptRunStatus, updateCustomTool]); + + return null; +} + +export { PromptRunAsync }; diff --git a/frontend/src/components/custom-tools/prompt-card/PromptRunSync.jsx b/frontend/src/components/custom-tools/prompt-card/PromptRunSync.jsx new file mode 100644 index 0000000000..c4ec189fa5 --- /dev/null +++ b/frontend/src/components/custom-tools/prompt-card/PromptRunSync.jsx @@ -0,0 +1,77 @@ +import Cookies from "js-cookie"; +import { useEffect } from "react"; +import { usePromptRunQueueStore } from "../../../store/prompt-run-queue-store"; +import usePromptRunSync from "../../../hooks/usePromptRunSync"; +import { useCustomToolStore } from "../../../store/custom-tool-store"; +import { usePromptRunStatusStore } from "../../../store/prompt-run-status-store"; + +const MAX_ACTIVE_APIS = 5; +/* + Change this to 'true' to allow persistence of the prompt run state + Right now, this feature cannot be supported as the prompt studio details + are not persisted across the entire application. +*/ +const PROMPT_RUN_STATE_PERSISTENCE = false; + +function PromptRunSync() { + const activeApis = usePromptRunQueueStore((state) => state.activeApis); + const queue = usePromptRunQueueStore((state) => state.queue); + const setPromptRunQueue = usePromptRunQueueStore( + (state) => state.setPromptRunQueue, + ); + const { runPrompt, syncPromptRunApisAndStatus } = usePromptRunSync(); + const promptRunStatus = usePromptRunStatusStore( + (state) => state.promptRunStatus, + ); + const updateCustomTool = useCustomToolStore( + (state) => state.updateCustomTool, + ); + + useEffect(() => { + // Retrieve queue from cookies on component load + const queueData = Cookies.get("promptRunQueue"); + if (queueData && JSON.parse(queueData)?.length) { + const promptApis = JSON.parse(queueData); + syncPromptRunApisAndStatus(promptApis); + } + + // Setup the beforeunload event handler to store queue in cookies + const handleBeforeUnload = () => { + if (!PROMPT_RUN_STATE_PERSISTENCE) return; + const { queue } = 
usePromptRunQueueStore.getState(); // Get the latest state dynamically + if (queue?.length) { + Cookies.set("promptRunQueue", JSON.stringify(queue), { + expires: 5 / 1440, // Expire in 5 minutes + }); + } + }; + + window.addEventListener("beforeunload", handleBeforeUnload); + + return () => { + window.removeEventListener("beforeunload", handleBeforeUnload); // Clean up event listener + }; + }, [syncPromptRunApisAndStatus]); + + useEffect(() => { + if (!queue?.length || activeApis >= MAX_ACTIVE_APIS) return; + + const canRunApis = MAX_ACTIVE_APIS - activeApis; + const apisToRun = queue.slice(0, canRunApis); + + setPromptRunQueue({ + activeApis: activeApis + apisToRun.length, + queue: queue.slice(apisToRun.length), + }); + runPrompt(apisToRun); + }, [activeApis, queue, setPromptRunQueue, runPrompt]); + + useEffect(() => { + const isMultiPassExtractLoading = !!Object.keys(promptRunStatus).length; + updateCustomTool({ isMultiPassExtractLoading }); + }, [promptRunStatus, updateCustomTool]); + + return null; +} + +export { PromptRunSync }; diff --git a/frontend/src/components/custom-tools/tool-ide/ToolIde.jsx b/frontend/src/components/custom-tools/tool-ide/ToolIde.jsx index 3b7fa94291..9c54c35d3a 100644 --- a/frontend/src/components/custom-tools/tool-ide/ToolIde.jsx +++ b/frontend/src/components/custom-tools/tool-ide/ToolIde.jsx @@ -266,27 +266,23 @@ function ToolIde() { pushIndexDoc(docId); return axiosPrivate(requestOptions) - .then(() => { + .then((res) => { + if (res?.status === 202) { + // Async path — 202 means accepted, spinner stays until socket event + return; + } + // Sync path — 200 means done + deleteIndexDoc(docId); setAlertDetails({ type: "success", content: `${doc?.document_name} - Indexed successfully`, }); - - try { - setPostHogCustomEvent("intent_success_ps_indexed_file", { - info: "Indexing completed", - }); - } catch (err) { - // If an error occurs while setting custom posthog event, ignore it and continue - } }) .catch((err) => { + 
deleteIndexDoc(docId); setAlertDetails( handleException(err, `${doc?.document_name} - Failed to index`), ); - }) - .finally(() => { - deleteIndexDoc(docId); }); }; diff --git a/frontend/src/components/helpers/socket-messages/SocketMessages.js b/frontend/src/components/helpers/socket-messages/SocketMessages.js index e2843e60d8..36cb9b9d60 100644 --- a/frontend/src/components/helpers/socket-messages/SocketMessages.js +++ b/frontend/src/components/helpers/socket-messages/SocketMessages.js @@ -12,6 +12,7 @@ import { SocketContext } from "../../../helpers/SocketContext"; import { useExceptionHandler } from "../../../hooks/useExceptionHandler"; import { useAlertStore } from "../../../store/alert-store"; import { useSessionStore } from "../../../store/session-store"; +import { useSocketCustomToolStore } from "../../../store/socket-custom-tool"; import { useSocketLogsStore } from "../../../store/socket-logs-store"; import { useSocketMessagesStore } from "../../../store/socket-messages-store"; import { useUsageStore } from "../../../store/usage-store"; @@ -28,6 +29,7 @@ function SocketMessages() { setPointer, } = useSocketMessagesStore(); const { pushLogMessages } = useSocketLogsStore(); + const { updateCusToolMessages } = useSocketCustomToolStore(); const { sessionDetails } = useSessionStore(); const socket = useContext(SocketContext); const { setAlertDetails } = useAlertStore(); @@ -89,6 +91,8 @@ function SocketMessages() { pushStagedMessage(msg); } else if (msg?.type === "LOG" && msg?.service === "prompt") { handleLogMessages(msg); + } else if (msg?.type === "PROGRESS") { + updateCusToolMessages([msg]); } if (msg?.type === "LOG" && msg?.service === "usage") { @@ -102,7 +106,7 @@ function SocketMessages() { ); } }, - [handleLogMessages, pushStagedMessage], + [handleLogMessages, pushStagedMessage, updateCusToolMessages], ); // Subscribe/unsubscribe to the socket channel diff --git a/frontend/src/helpers/SocketContext.js b/frontend/src/helpers/SocketContext.js index 
c2c9caa0be..6e6ace9a63 100644 --- a/frontend/src/helpers/SocketContext.js +++ b/frontend/src/helpers/SocketContext.js @@ -10,17 +10,15 @@ const SocketProvider = ({ children }) => { const [socket, setSocket] = useState(null); useEffect(() => { - let baseUrl = ""; - const body = { + // Always connect to the same origin as the page. + // - Dev: CRA proxy (ws: true in setupProxy.js) forwards to the backend. + // - Prod: Traefik routes /api/v1/socket to the backend. + // This ensures session cookies are sent (same-origin) and avoids + // cross-origin WebSocket issues. + const newSocket = io(getBaseUrl(), { transports: ["websocket"], path: "/api/v1/socket", - }; - if (!import.meta.env.MODE || import.meta.env.MODE === "development") { - baseUrl = import.meta.env.VITE_BACKEND_URL; - } else { - baseUrl = getBaseUrl(); - } - const newSocket = io(baseUrl, body); + }); setSocket(newSocket); // Clean up the socket connection on browser unload window.onbeforeunload = () => { diff --git a/frontend/src/hooks/usePromptRun.js b/frontend/src/hooks/usePromptRun.js index 753c128a8e..11e83dc42f 100644 --- a/frontend/src/hooks/usePromptRun.js +++ b/frontend/src/hooks/usePromptRun.js @@ -16,11 +16,9 @@ import usePromptOutput from "./usePromptOutput"; const usePromptRun = () => { const { pushPromptRunApi, freeActiveApi } = usePromptRunQueueStore(); - const { generatePromptOutputKey, updatePromptOutputState } = - usePromptOutput(); + const { generatePromptOutputKey } = usePromptOutput(); const { addPromptStatus, removePromptStatus } = usePromptRunStatusStore(); - const { details, llmProfiles, listOfDocs, selectedDoc } = - useCustomToolStore(); + const { details, llmProfiles, listOfDocs } = useCustomToolStore(); const { sessionDetails } = useSessionStore(); const axiosPrivate = useAxiosPrivate(); const { setAlertDetails } = useAlertStore(); @@ -28,6 +26,8 @@ const usePromptRun = () => { const makeApiRequest = (requestOptions) => axiosPrivate(requestOptions); + const SOCKET_TIMEOUT_MS = 5 * 
60 * 1000; // 5 minutes + const runPromptApi = (api) => { const [promptId, docId, profileId] = api.split("__"); const runId = generateUUID(); @@ -49,32 +49,33 @@ const usePromptRun = () => { data: body, }; - const startTime = Date.now(); - const maxWaitTime = 30 * 1000; // 30 seconds - const pollingInterval = 5000; // 5 seconds - - pollForCompletion( - startTime, - requestOptions, - maxWaitTime, - pollingInterval, - makeApiRequest, - ) - .then((res) => { - if (docId !== selectedDoc?.document_id) return; - const data = res?.data || []; - const timeTakenInSeconds = Math.floor((Date.now() - startTime) / 1000); - updatePromptOutputState(data, false, timeTakenInSeconds); + // Fire-and-forget: POST dispatches the Celery task, socket delivers result. + makeApiRequest(requestOptions) + .then(() => { + // Timeout safety net: clear stale status if socket event never arrives. + setTimeout(() => { + const statusKey = generateApiRunStatusId(docId, profileId); + const current = usePromptRunStatusStore.getState().promptRunStatus; + if ( + current?.[promptId]?.[statusKey] === PROMPT_RUN_API_STATUSES.RUNNING + ) { + removePromptStatus(promptId, statusKey); + setAlertDetails({ + type: "warning", + content: "Prompt execution timed out. 
Please try again.", + }); + } + }, SOCKET_TIMEOUT_MS); }) .catch((err) => { setAlertDetails( handleException(err, "Failed to generate prompt output"), ); + const statusKey = generateApiRunStatusId(docId, profileId); + removePromptStatus(promptId, statusKey); }) .finally(() => { freeActiveApi(); - const statusKey = generateApiRunStatusId(docId, profileId); - removePromptStatus(promptId, statusKey); }); }; diff --git a/frontend/src/hooks/usePromptRunSync.js b/frontend/src/hooks/usePromptRunSync.js new file mode 100644 index 0000000000..98d39f550f --- /dev/null +++ b/frontend/src/hooks/usePromptRunSync.js @@ -0,0 +1,219 @@ +import { + generateApiRunStatusId, + generateUUID, + pollForCompletion, + PROMPT_RUN_API_STATUSES, + PROMPT_RUN_TYPES, +} from "../helpers/GetStaticData"; +import { useAlertStore } from "../store/alert-store"; +import { useCustomToolStore } from "../store/custom-tool-store"; +import { usePromptRunQueueStore } from "../store/prompt-run-queue-store"; +import { usePromptRunStatusStore } from "../store/prompt-run-status-store"; +import { useSessionStore } from "../store/session-store"; +import { useAxiosPrivate } from "./useAxiosPrivate"; +import { useExceptionHandler } from "./useExceptionHandler"; +import usePromptOutput from "./usePromptOutput"; + +const usePromptRunSync = () => { + const { pushPromptRunApi, freeActiveApi } = usePromptRunQueueStore(); + const { generatePromptOutputKey, updatePromptOutputState } = + usePromptOutput(); + const { addPromptStatus, removePromptStatus } = usePromptRunStatusStore(); + const { details, llmProfiles, listOfDocs, selectedDoc } = + useCustomToolStore(); + const { sessionDetails } = useSessionStore(); + const axiosPrivate = useAxiosPrivate(); + const { setAlertDetails } = useAlertStore(); + const handleException = useExceptionHandler(); + + const makeApiRequest = (requestOptions) => axiosPrivate(requestOptions); + + const runPromptApi = (api) => { + const [promptId, docId, profileId] = api.split("__"); + const 
runId = generateUUID(); + + const body = { + id: promptId, + document_id: docId, + profile_manager: profileId, + run_id: runId, + }; + + const requestOptions = { + method: "POST", + url: `/api/v1/unstract/${sessionDetails?.orgId}/prompt-studio/fetch_response/${details?.tool_id}`, + headers: { + "X-CSRFToken": sessionDetails?.csrfToken, + "Content-Type": "application/json", + }, + data: body, + }; + + const startTime = Date.now(); + const maxWaitTime = 30 * 1000; // 30 seconds + const pollingInterval = 5000; // 5 seconds + + pollForCompletion( + startTime, + requestOptions, + maxWaitTime, + pollingInterval, + makeApiRequest, + ) + .then((res) => { + if (docId !== selectedDoc?.document_id) return; + const data = res?.data || []; + const timeTakenInSeconds = Math.floor((Date.now() - startTime) / 1000); + updatePromptOutputState(data, false, timeTakenInSeconds); + }) + .catch((err) => { + setAlertDetails( + handleException(err, "Failed to generate prompt output"), + ); + }) + .finally(() => { + freeActiveApi(); + const statusKey = generateApiRunStatusId(docId, profileId); + removePromptStatus(promptId, statusKey); + }); + }; + + const runPrompt = (listOfApis) => { + if (!listOfApis?.length) return; + listOfApis.forEach(runPromptApi); + }; + + const prepareApiRequests = (promptIds, profileIds, docIds) => { + const apiRequestsToQueue = []; + const promptRunApiStatus = []; + const combinations = []; + + for (const promptId of promptIds) { + for (const profileId of profileIds) { + for (const docId of docIds) { + combinations.push({ promptId, profileId, docId }); + } + } + } + + combinations.forEach(({ promptId, profileId, docId }) => { + if (!promptRunApiStatus[promptId]) { + promptRunApiStatus[promptId] = {}; + } + const key = generatePromptOutputKey( + promptId, + docId, + profileId, + null, + false, + ); + const statusKey = generateApiRunStatusId(docId, profileId); + promptRunApiStatus[promptId][statusKey] = PROMPT_RUN_API_STATUSES.RUNNING; + 
apiRequestsToQueue.push(key); + }); + + return { apiRequestsToQueue, promptRunApiStatus }; + }; + + const syncPromptRunApisAndStatus = (promptApis) => { + const promptRunApiStatus = {}; + + promptApis.forEach((apiDetails) => { + const [promptId, docId, profileId] = apiDetails.split("__"); + const statusKey = generateApiRunStatusId(docId, profileId); + + if (!promptRunApiStatus[promptId]) { + promptRunApiStatus[promptId] = {}; + } + + promptRunApiStatus[promptId][statusKey] = PROMPT_RUN_API_STATUSES.RUNNING; + }); + + addPromptStatus(promptRunApiStatus); + pushPromptRunApi(promptApis); + }; + + const handlePromptRunRequest = ( + promptRunType, + promptId = null, + profileId = null, + docId = null, + ) => { + const promptIds = promptId + ? [promptId] + : details?.prompts.map((p) => p.prompt_id) || []; + const profileIds = profileId + ? [profileId] + : llmProfiles.map((p) => p.profile_id) || []; + const docIds = docId ? [docId] : listOfDocs.map((d) => d.document_id) || []; + + let apiRequestsToQueue = []; + let promptRunApiStatus = {}; + + const paramsMap = { + [PROMPT_RUN_TYPES.RUN_ONE_PROMPT_ONE_LLM_ONE_DOC]: { + requiredParams: ["promptId", "profileId", "docId"], + prompts: [promptId], + profiles: [profileId], + docs: [docId], + }, + [PROMPT_RUN_TYPES.RUN_ONE_PROMPT_ONE_LLM_ALL_DOCS]: { + requiredParams: ["promptId", "profileId"], + prompts: [promptId], + profiles: [profileId], + docs: docIds, + }, + [PROMPT_RUN_TYPES.RUN_ONE_PROMPT_ALL_LLMS_ONE_DOC]: { + requiredParams: ["promptId", "docId"], + prompts: [promptId], + profiles: profileIds, + docs: [docId], + }, + [PROMPT_RUN_TYPES.RUN_ONE_PROMPT_ALL_LLMS_ALL_DOCS]: { + requiredParams: ["promptId"], + prompts: [promptId], + profiles: profileIds, + docs: docIds, + }, + [PROMPT_RUN_TYPES.RUN_ALL_PROMPTS_ALL_LLMS_ONE_DOC]: { + requiredParams: ["docId"], + prompts: promptIds, + profiles: profileIds, + docs: [docId], + }, + [PROMPT_RUN_TYPES.RUN_ALL_PROMPTS_ALL_LLMS_ALL_DOCS]: { + requiredParams: [], + prompts: 
promptIds, + profiles: profileIds, + docs: docIds, + }, + }; + + const params = paramsMap[promptRunType]; + if (!params) return; + + const paramValues = { promptId, profileId, docId }; + const missingParams = params.requiredParams.filter( + (param) => !paramValues[param], + ); + + if (missingParams.length > 0) return; + + ({ apiRequestsToQueue, promptRunApiStatus } = prepareApiRequests( + params.prompts, + params.profiles, + params.docs, + )); + + addPromptStatus(promptRunApiStatus); + pushPromptRunApi(apiRequestsToQueue); + }; + + return { + runPrompt, + handlePromptRunRequest, + syncPromptRunApisAndStatus, + }; +}; + +export default usePromptRunSync; diff --git a/frontend/src/hooks/usePromptStudioSocket.js b/frontend/src/hooks/usePromptStudioSocket.js new file mode 100644 index 0000000000..c5ffa3c765 --- /dev/null +++ b/frontend/src/hooks/usePromptStudioSocket.js @@ -0,0 +1,156 @@ +import { useContext, useEffect, useCallback } from "react"; + +import { SocketContext } from "../helpers/SocketContext"; +import { generateApiRunStatusId } from "../helpers/GetStaticData"; +import { useAlertStore } from "../store/alert-store"; +import { useCustomToolStore } from "../store/custom-tool-store"; +import { usePromptRunStatusStore } from "../store/prompt-run-status-store"; +import { useExceptionHandler } from "./useExceptionHandler"; +import usePromptOutput from "./usePromptOutput"; + +const PROMPT_STUDIO_RESULT_EVENT = "prompt_studio_result"; + +/** + * Hook that listens for `prompt_studio_result` Socket.IO events emitted by + * backend Celery tasks (fetch_response, single_pass_extraction, index_document). + * + * On completion it feeds the result into the prompt-output store and clears + * the corresponding run-status entries so the UI stops showing spinners. 
+ */ +const usePromptStudioSocket = () => { + const socket = useContext(SocketContext); + const { removePromptStatus, clearPromptStatusById } = + usePromptRunStatusStore(); + const { updateCustomTool, deleteIndexDoc } = useCustomToolStore(); + const { setAlertDetails } = useAlertStore(); + const handleException = useExceptionHandler(); + const { updatePromptOutputState } = usePromptOutput(); + + const clearResultStatuses = useCallback( + (data) => { + if (!Array.isArray(data)) return; + data.forEach((item) => { + const promptId = item?.prompt_id; + const docId = item?.document_manager; + const profileId = item?.profile_manager; + if (promptId && docId && profileId) { + const statusKey = generateApiRunStatusId(docId, profileId); + removePromptStatus(promptId, statusKey); + } + }); + }, + [removePromptStatus] + ); + + const handleCompleted = useCallback( + (operation, result) => { + if (operation === "fetch_response") { + const data = Array.isArray(result) ? result : []; + updatePromptOutputState(data, false); + clearResultStatuses(data); + setAlertDetails({ + type: "success", + content: "Prompt execution completed successfully.", + }); + } else if (operation === "single_pass_extraction") { + const data = Array.isArray(result) ? 
result : []; + updatePromptOutputState(data, false); + updateCustomTool({ isSinglePassExtractLoading: false }); + clearResultStatuses(data); + setAlertDetails({ + type: "success", + content: "Single pass extraction completed successfully.", + }); + } else if (operation === "index_document") { + const docId = result?.document_id; + if (docId) deleteIndexDoc(docId); + setAlertDetails({ + type: "success", + content: result?.message || "Document indexed successfully.", + }); + } + }, + [ + updatePromptOutputState, + clearResultStatuses, + updateCustomTool, + setAlertDetails, + deleteIndexDoc, + ] + ); + + const handleFailed = useCallback( + (operation, error, extra) => { + setAlertDetails({ + type: "error", + content: error || `${operation} failed`, + }); + if (operation === "single_pass_extraction") { + updateCustomTool({ isSinglePassExtractLoading: false }); + } else if (operation === "index_document") { + const docId = extra?.document_id; + if (docId) deleteIndexDoc(docId); + } + + // Clear spinner for prompt operations so buttons re-enable + if ( + operation === "fetch_response" || + operation === "single_pass_extraction" + ) { + const promptIds = extra?.prompt_ids || []; + const docId = extra?.document_id; + const profileId = extra?.profile_manager_id; + if (docId && profileId) { + // Specific clearing (ideal path) + const statusKey = generateApiRunStatusId(docId, profileId); + promptIds.forEach((promptId) => { + removePromptStatus(promptId, statusKey); + }); + } else { + // Fallback: clear ALL statuses for these prompts + promptIds.forEach((promptId) => { + clearPromptStatusById(promptId); + }); + } + } + }, + [ + setAlertDetails, + updateCustomTool, + deleteIndexDoc, + removePromptStatus, + clearPromptStatusById, + ] + ); + + const onResult = useCallback( + (payload) => { + try { + const msg = payload?.data || payload; + const { status, operation, result, error, ...extra } = msg; + + if (status === "completed") { + handleCompleted(operation, result); + } else if 
(status === "failed") { + handleFailed(operation, error, extra); + } + } catch (err) { + setAlertDetails( + handleException(err, "Failed to process prompt studio result") + ); + } + }, + [handleCompleted, handleFailed, setAlertDetails, handleException] + ); + + useEffect(() => { + if (!socket) return; + + socket.on(PROMPT_STUDIO_RESULT_EVENT, onResult); + return () => { + socket.off(PROMPT_STUDIO_RESULT_EVENT, onResult); + }; + }, [socket, onResult]); +}; + +export default usePromptStudioSocket; diff --git a/frontend/src/store/prompt-run-status-store.js b/frontend/src/store/prompt-run-status-store.js index dcc852a502..8c55e27ac9 100644 --- a/frontend/src/store/prompt-run-status-store.js +++ b/frontend/src/store/prompt-run-status-store.js @@ -26,6 +26,13 @@ const usePromptRunStatusStore = create((setState, getState) => ({ return { promptRunStatus: newStatus }; }); }, + clearPromptStatusById: (promptId) => { + setState((state) => { + const newStatus = { ...state.promptRunStatus }; + delete newStatus[promptId]; + return { promptRunStatus: newStatus }; + }); + }, removePromptStatus: (promptId, key) => { setState((state) => { const currentStatus = state.promptRunStatus || {}; diff --git a/unstract/core/src/unstract/core/pubsub_helper.py b/unstract/core/src/unstract/core/pubsub_helper.py index 6f96d9f7c6..d45b1dfd30 100644 --- a/unstract/core/src/unstract/core/pubsub_helper.py +++ b/unstract/core/src/unstract/core/pubsub_helper.py @@ -16,16 +16,16 @@ class LogPublisher: broker_url = str( httpx.URL(os.getenv("CELERY_BROKER_BASE_URL", "amqp://")).copy_with( - username=os.getenv("CELERY_BROKER_USER"), - password=os.getenv("CELERY_BROKER_PASS"), + username=os.getenv("CELERY_BROKER_USER") or None, + password=os.getenv("CELERY_BROKER_PASS") or None, ) ) kombu_conn = Connection(broker_url) r = redis.Redis( host=os.environ.get("REDIS_HOST"), port=os.environ.get("REDIS_PORT", 6379), - username=os.environ.get("REDIS_USER"), - password=os.environ.get("REDIS_PASSWORD"), + 
username=os.environ.get("REDIS_USER") or None, + password=os.environ.get("REDIS_PASSWORD") or None, ) @staticmethod @@ -91,6 +91,29 @@ def log_workflow_update( "message": message, } + @staticmethod + def log_progress( + component: dict[str, str], + level: str, + state: str, + message: str, + ) -> dict[str, str]: + """Build a progress log message for streaming to the frontend. + + Same structure as ``log_prompt()`` but uses ``type: "PROGRESS"`` + so the frontend can distinguish executor progress from regular + log messages. + """ + return { + "timestamp": datetime.now(UTC).timestamp(), + "type": "PROGRESS", + "service": "prompt", + "component": component, + "level": level, + "state": state, + "message": message, + } + @staticmethod def log_prompt( component: dict[str, str], diff --git a/unstract/flags/src/unstract/flags/feature_flag.py b/unstract/flags/src/unstract/flags/feature_flag.py index 4776f5bd29..58c463e482 100644 --- a/unstract/flags/src/unstract/flags/feature_flag.py +++ b/unstract/flags/src/unstract/flags/feature_flag.py @@ -2,8 +2,6 @@ import logging -from .client.flipt import FliptClient - logger = logging.getLogger(__name__) @@ -30,7 +28,8 @@ def check_feature_flag_status( True if the feature flag is enabled for the entity, False otherwise. 
""" try: - # Initialize Flipt client + from .client.flipt import FliptClient + client = FliptClient() logger.info(f"Client has been Initialised {client.list_feature_flags()}") @@ -42,6 +41,6 @@ def check_feature_flag_status( context=context or {}, ) - return bool(result.enabled) + return bool(result) except Exception: return False diff --git a/unstract/sdk1/src/unstract/sdk1/adapters/ocr/register.py b/unstract/sdk1/src/unstract/sdk1/adapters/ocr/register.py index fde5558c16..cbc1a6ea67 100644 --- a/unstract/sdk1/src/unstract/sdk1/adapters/ocr/register.py +++ b/unstract/sdk1/src/unstract/sdk1/adapters/ocr/register.py @@ -45,5 +45,5 @@ def _build_adapter_list(adapter: str, package: str, adapters: dict[str, Any]) -> Common.MODULE: module, Common.METADATA: metadata, } - except ModuleNotFoundError as exception: - logger.warning(f"Unable to import ocr adapters : {exception}") + except Exception as exception: + logger.warning(f"Unable to import OCR adapter '{adapter}': {exception}") diff --git a/unstract/sdk1/src/unstract/sdk1/adapters/vectordb/exceptions.py b/unstract/sdk1/src/unstract/sdk1/adapters/vectordb/exceptions.py index edef6bd043..e44784671e 100644 --- a/unstract/sdk1/src/unstract/sdk1/adapters/vectordb/exceptions.py +++ b/unstract/sdk1/src/unstract/sdk1/adapters/vectordb/exceptions.py @@ -1,5 +1,3 @@ -from qdrant_client.http.exceptions import ApiException as QdrantAPIException -from unstract.sdk1.adapters.vectordb.qdrant.src import Qdrant from unstract.sdk1.adapters.vectordb.vectordb_adapter import VectorDBAdapter from unstract.sdk1.exceptions import VectorDBError @@ -20,9 +18,18 @@ def parse_vector_db_err(e: Exception, vector_db: VectorDBAdapter) -> VectorDBErr if isinstance(e, VectorDBError): return e - if isinstance(e, QdrantAPIException): - err = Qdrant.parse_vector_db_err(e) - else: + # Lazy import to avoid hard dependency on qdrant_client at module level. 
+ # qdrant_client's protobuf files can fail to load depending on the + # protobuf runtime version (KeyError: '_POINTID'). + try: + from qdrant_client.http.exceptions import ApiException as QdrantAPIException + from unstract.sdk1.adapters.vectordb.qdrant.src import Qdrant + + if isinstance(e, QdrantAPIException): + err = Qdrant.parse_vector_db_err(e) + else: + err = VectorDBError(str(e), actual_err=e) + except Exception: err = VectorDBError(str(e), actual_err=e) msg = f"Error from vector DB '{vector_db.get_name()}'." diff --git a/unstract/sdk1/src/unstract/sdk1/adapters/vectordb/register.py b/unstract/sdk1/src/unstract/sdk1/adapters/vectordb/register.py index 1c551dafe1..05c01d822e 100644 --- a/unstract/sdk1/src/unstract/sdk1/adapters/vectordb/register.py +++ b/unstract/sdk1/src/unstract/sdk1/adapters/vectordb/register.py @@ -45,5 +45,5 @@ def _build_adapter_list(adapter: str, package: str, adapters: dict[str, Any]) -> Common.MODULE: module, Common.METADATA: metadata, } - except ModuleNotFoundError as exception: - logger.warning(f"Unable to import vectorDB adapters : {exception}") + except Exception as exception: + logger.warning(f"Unable to import vectorDB adapter '{adapter}': {exception}") diff --git a/unstract/sdk1/src/unstract/sdk1/adapters/x2text/llm_whisperer_v2/src/helper.py b/unstract/sdk1/src/unstract/sdk1/adapters/x2text/llm_whisperer_v2/src/helper.py index 8fda907903..14790065ae 100644 --- a/unstract/sdk1/src/unstract/sdk1/adapters/x2text/llm_whisperer_v2/src/helper.py +++ b/unstract/sdk1/src/unstract/sdk1/adapters/x2text/llm_whisperer_v2/src/helper.py @@ -203,23 +203,31 @@ def get_whisperer_params( ), WhispererConfig.ADD_LINE_NOS: extra_params.enable_highlight, WhispererConfig.INCLUDE_LINE_CONFIDENCE: extra_params.enable_highlight, - # Not providing default value to maintain legacy compatablity - # these are optional params and identifiers for audit - WhispererConfig.TAG: extra_params.tag - or config.get( - WhispererConfig.TAG, - WhispererDefaults.TAG, - 
), - WhispererConfig.USE_WEBHOOK: config.get(WhispererConfig.USE_WEBHOOK, ""), - WhispererConfig.WEBHOOK_METADATA: config.get( - WhispererConfig.WEBHOOK_METADATA - ), - WhispererConfig.WAIT_TIMEOUT: config.get( - WhispererConfig.WAIT_TIMEOUT, - WhispererDefaults.WAIT_TIMEOUT, - ), - WhispererConfig.WAIT_FOR_COMPLETION: WhispererDefaults.WAIT_FOR_COMPLETION, } + logger.info( + "HIGHLIGHT_DEBUG whisper params: ADD_LINE_NOS=%s", + params.get(WhispererConfig.ADD_LINE_NOS), + ) + params.update( + { + # Not providing default value to maintain legacy compatablity + # these are optional params and identifiers for audit + WhispererConfig.TAG: extra_params.tag + or config.get( + WhispererConfig.TAG, + WhispererDefaults.TAG, + ), + WhispererConfig.USE_WEBHOOK: config.get(WhispererConfig.USE_WEBHOOK, ""), + WhispererConfig.WEBHOOK_METADATA: config.get( + WhispererConfig.WEBHOOK_METADATA + ), + WhispererConfig.WAIT_TIMEOUT: config.get( + WhispererConfig.WAIT_TIMEOUT, + WhispererDefaults.WAIT_TIMEOUT, + ), + WhispererConfig.WAIT_FOR_COMPLETION: WhispererDefaults.WAIT_FOR_COMPLETION, + } + ) if params[WhispererConfig.MODE] == Modes.LOW_COST.value: params.update( { diff --git a/unstract/sdk1/src/unstract/sdk1/adapters/x2text/llm_whisperer_v2/src/llm_whisperer_v2.py b/unstract/sdk1/src/unstract/sdk1/adapters/x2text/llm_whisperer_v2/src/llm_whisperer_v2.py index 892339a9be..3a48a57647 100644 --- a/unstract/sdk1/src/unstract/sdk1/adapters/x2text/llm_whisperer_v2/src/llm_whisperer_v2.py +++ b/unstract/sdk1/src/unstract/sdk1/adapters/x2text/llm_whisperer_v2/src/llm_whisperer_v2.py @@ -82,6 +82,10 @@ def process( if fs is None: fs = FileStorage(provider=FileStorageProvider.LOCAL) enable_highlight = kwargs.get(X2TextConstants.ENABLE_HIGHLIGHT, False) + logger.info( + "HIGHLIGHT_DEBUG LLMWhispererV2.process: enable_highlight=%s", + enable_highlight, + ) extra_params = WhispererRequestParams( tag=kwargs.get(X2TextConstants.TAGS), enable_highlight=enable_highlight, diff --git 
a/unstract/sdk1/src/unstract/sdk1/adapters/x2text/register.py b/unstract/sdk1/src/unstract/sdk1/adapters/x2text/register.py index 48d6a606af..3318887f95 100644 --- a/unstract/sdk1/src/unstract/sdk1/adapters/x2text/register.py +++ b/unstract/sdk1/src/unstract/sdk1/adapters/x2text/register.py @@ -45,5 +45,5 @@ def _build_adapter_list(adapter: str, package: str, adapters: dict[str, Any]) -> Common.MODULE: module, Common.METADATA: metadata, } - except ModuleNotFoundError as exception: - logger.warning(f"Unable to import X2Text adapters : {exception}") + except Exception as exception: + logger.warning(f"Unable to import X2Text adapter '{adapter}': {exception}") diff --git a/unstract/sdk1/src/unstract/sdk1/execution/__init__.py b/unstract/sdk1/src/unstract/sdk1/execution/__init__.py new file mode 100644 index 0000000000..fa70c88821 --- /dev/null +++ b/unstract/sdk1/src/unstract/sdk1/execution/__init__.py @@ -0,0 +1,15 @@ +from unstract.sdk1.execution.context import ExecutionContext +from unstract.sdk1.execution.dispatcher import ExecutionDispatcher +from unstract.sdk1.execution.executor import BaseExecutor +from unstract.sdk1.execution.orchestrator import ExecutionOrchestrator +from unstract.sdk1.execution.registry import ExecutorRegistry +from unstract.sdk1.execution.result import ExecutionResult + +__all__ = [ + "BaseExecutor", + "ExecutionContext", + "ExecutionDispatcher", + "ExecutionOrchestrator", + "ExecutionResult", + "ExecutorRegistry", +] diff --git a/unstract/sdk1/src/unstract/sdk1/execution/context.py b/unstract/sdk1/src/unstract/sdk1/execution/context.py new file mode 100644 index 0000000000..a1efb4c3f8 --- /dev/null +++ b/unstract/sdk1/src/unstract/sdk1/execution/context.py @@ -0,0 +1,128 @@ +"""Execution context model for the executor framework. + +Defines the serializable context that is dispatched to executor +workers via Celery. Used by both the workflow path (structure tool +task) and the IDE path (PromptStudioHelper). 
+""" + +import uuid +from dataclasses import dataclass, field +from enum import Enum +from typing import Any + + +class ExecutionSource(str, Enum): + """Origin of the execution request.""" + + IDE = "ide" + TOOL = "tool" + + +class Operation(str, Enum): + """Supported extraction operations. + + Maps 1-to-1 with current PromptTool HTTP endpoints. + """ + + EXTRACT = "extract" + INDEX = "index" + ANSWER_PROMPT = "answer_prompt" + SINGLE_PASS_EXTRACTION = "single_pass_extraction" + SUMMARIZE = "summarize" + IDE_INDEX = "ide_index" + STRUCTURE_PIPELINE = "structure_pipeline" + TABLE_EXTRACT = "table_extract" + SMART_TABLE_EXTRACT = "smart_table_extract" + SPS_ANSWER_PROMPT = "sps_answer_prompt" + SPS_INDEX = "sps_index" + AGENTIC_EXTRACT = "agentic_extract" + AGENTIC_SUMMARIZE = "agentic_summarize" + AGENTIC_UNIFORMIZE = "agentic_uniformize" + AGENTIC_FINALIZE = "agentic_finalize" + AGENTIC_GENERATE_PROMPT = "agentic_generate_prompt" + AGENTIC_GENERATE_PROMPT_PIPELINE = "agentic_generate_prompt_pipeline" + AGENTIC_COMPARE = "agentic_compare" + AGENTIC_TUNE_FIELD = "agentic_tune_field" + + +@dataclass +class ExecutionContext: + """Serializable execution context dispatched to executor worker. + + This is the single payload sent as a Celery task argument to + ``execute_extraction``. It must remain JSON-serializable (no + ORM objects, no file handles, no callables). + + Attributes: + executor_name: Registered executor to handle this request + (e.g. ``"legacy"``, ``"agentic_table"``). + operation: The extraction operation to perform. + run_id: Unique identifier for this execution run. + execution_source: Where the request originated + (``"ide"`` or ``"tool"``). + organization_id: Tenant/org scope. ``None`` for public + calls. + executor_params: Opaque, operation-specific payload passed + through to the executor. Must be JSON-serializable. + request_id: Correlation ID for tracing across services. 
+ log_events_id: Socket.IO channel ID for streaming progress + logs to the frontend. ``None`` when not in an IDE + session (no logs published). + """ + + executor_name: str + operation: str + run_id: str + execution_source: str + organization_id: str | None = None + executor_params: dict[str, Any] = field(default_factory=dict) + request_id: str | None = None + log_events_id: str | None = None + + def __post_init__(self) -> None: + """Validate required fields after initialization.""" + if not self.executor_name: + raise ValueError("executor_name is required") + if not self.operation: + raise ValueError("operation is required") + if not self.run_id: + raise ValueError("run_id is required") + if not self.execution_source: + raise ValueError("execution_source is required") + + # Normalize enum values to plain strings for serialization + if isinstance(self.operation, Operation): + self.operation = self.operation.value + if isinstance(self.execution_source, ExecutionSource): + self.execution_source = self.execution_source.value + + # Auto-generate request_id if not provided + if self.request_id is None: + self.request_id = str(uuid.uuid4()) + + def to_dict(self) -> dict[str, Any]: + """Serialize to a JSON-compatible dict for Celery dispatch.""" + return { + "executor_name": self.executor_name, + "operation": self.operation, + "run_id": self.run_id, + "execution_source": self.execution_source, + "organization_id": self.organization_id, + "executor_params": self.executor_params, + "request_id": self.request_id, + "log_events_id": self.log_events_id, + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "ExecutionContext": + """Deserialize from a dict (e.g. 
Celery task argument).""" + return cls( + executor_name=data["executor_name"], + operation=data["operation"], + run_id=data["run_id"], + execution_source=data["execution_source"], + organization_id=data.get("organization_id"), + executor_params=data.get("executor_params", {}), + request_id=data.get("request_id"), + log_events_id=data.get("log_events_id"), + ) diff --git a/unstract/sdk1/src/unstract/sdk1/execution/dispatcher.py b/unstract/sdk1/src/unstract/sdk1/execution/dispatcher.py new file mode 100644 index 0000000000..7fc9c5f720 --- /dev/null +++ b/unstract/sdk1/src/unstract/sdk1/execution/dispatcher.py @@ -0,0 +1,272 @@ +"""Execution dispatcher for sending Celery tasks to executor workers. + +The dispatcher is the caller-side component used by both: +- Structure tool Celery task (workflow path) +- PromptStudioHelper (IDE path) + +It sends ``execute_extraction`` tasks to the ``executor`` queue. +Three dispatch modes are available: + +- ``dispatch()``: Send and block until result (synchronous). +- ``dispatch_async()``: Fire-and-forget, returns task_id for polling. +- ``dispatch_with_callback()``: Fire-and-forget with Celery ``link`` + / ``link_error`` callbacks for post-processing. +""" + +import logging +import os +from typing import Any + +from unstract.sdk1.execution.context import ExecutionContext +from unstract.sdk1.execution.result import ExecutionResult + +logger = logging.getLogger(__name__) + +# Constants matching workers/shared/enums values. +# Defined here to avoid an SDK1 → workers package dependency. +_TASK_NAME = "execute_extraction" + +# Queue-per-executor prefix. Each executor gets its own Celery queue +# named ``celery_executor_{executor_name}``, derived automatically +# from ``ExecutionContext.executor_name``. +_QUEUE_PREFIX = "celery_executor_" + +# Caller-side timeout (seconds) for AsyncResult.get(). 
class ExecutionDispatcher:
    """Dispatch executions to executor workers as Celery tasks.

    Every dispatch sends the ``execute_extraction`` task, with the
    serialized ``ExecutionContext`` as its single positional argument,
    to the queue derived from ``context.executor_name``.

    Usage::

        dispatcher = ExecutionDispatcher(celery_app=app)
        result = dispatcher.dispatch(context, timeout=120)

    Fire-and-forget::

        task_id = dispatcher.dispatch_async(context)

    Fire-and-forget with callbacks::

        from celery import signature

        task = dispatcher.dispatch_with_callback(
            context,
            on_success=signature("my_success_task", args=[...], queue="q"),
            on_error=signature("my_error_task", args=[...], queue="q"),
        )
    """

    def __init__(self, celery_app: Any = None) -> None:
        """Initialize the dispatcher.

        Args:
            celery_app: A Celery application instance. Required
                for dispatching tasks. Can be ``None`` only if
                set later via the ``celery_app`` attribute.
        """
        self._app = celery_app

    @property
    def celery_app(self) -> Any:
        """The Celery application used for dispatching, or ``None``."""
        return self._app

    @celery_app.setter
    def celery_app(self, celery_app: Any) -> None:
        # Lets callers attach the app after construction, as the
        # constructor documentation promises.
        self._app = celery_app

    @staticmethod
    def _get_queue(executor_name: str) -> str:
        """Derive the Celery queue name from *executor_name*.

        Convention: ``celery_executor_{executor_name}``.
        Adding a new executor automatically gets its own queue —
        no registry change needed.
        """
        return f"{_QUEUE_PREFIX}{executor_name}"

    @staticmethod
    def _resolve_timeout(timeout: int | None) -> int:
        """Return the caller-side wait time in seconds.

        An explicit *timeout* wins; otherwise the
        ``EXECUTOR_RESULT_TIMEOUT`` env var is used, falling back to
        the hardcoded default when it is unset or not an integer.
        """
        if timeout is not None:
            return timeout
        raw = os.environ.get(_DEFAULT_TIMEOUT_ENV)
        if raw is None:
            return _DEFAULT_TIMEOUT
        try:
            return int(raw)
        except ValueError:
            # A typo in the environment must not break every dispatch.
            logger.warning(
                "Ignoring invalid %s=%r; using default of %ss",
                _DEFAULT_TIMEOUT_ENV,
                raw,
                _DEFAULT_TIMEOUT,
            )
            return _DEFAULT_TIMEOUT

    def _require_app(self) -> Any:
        """Return the configured Celery app or raise ``ValueError``."""
        if self._app is None:
            raise ValueError("No Celery app configured on ExecutionDispatcher")
        return self._app

    def dispatch(
        self,
        context: ExecutionContext,
        timeout: int | None = None,
    ) -> ExecutionResult:
        """Dispatch context as a Celery task and wait for result.

        Args:
            context: ExecutionContext to dispatch.
            timeout: Max seconds to wait. ``None`` reads from
                the ``EXECUTOR_RESULT_TIMEOUT`` env var,
                falling back to 3600s.

        Returns:
            ExecutionResult from the executor. Any error raised while
            waiting for the result, and any result payload that cannot
            be deserialized, is returned as a failure result instead
            of being raised.

        Raises:
            ValueError: If no Celery app is configured.
        """
        app = self._require_app()
        timeout = self._resolve_timeout(timeout)

        queue = self._get_queue(context.executor_name)
        logger.info(
            "Dispatching execution: executor=%s operation=%s "
            "run_id=%s request_id=%s timeout=%ss queue=%s",
            context.executor_name,
            context.operation,
            context.run_id,
            context.request_id,
            timeout,
            queue,
        )

        async_result = app.send_task(
            _TASK_NAME,
            args=[context.to_dict()],
            queue=queue,
        )
        logger.info(
            "Task sent: celery_task_id=%s, waiting for result...",
            async_result.id,
        )

        try:
            # disable_sync_subtasks=False: safe because the executor task
            # runs on a separate worker pool (worker-v2) — no deadlock
            # risk even when dispatch() is called from inside a Django
            # Celery task.
            result_dict = async_result.get(
                timeout=timeout,
                disable_sync_subtasks=False,
            )
        except Exception as exc:
            # Deliberately broad: timeouts, broker errors and remote task
            # exceptions are all reported to the caller as a failure result.
            logger.error(
                "Dispatch failed: executor=%s operation=%s run_id=%s error=%s",
                context.executor_name,
                context.operation,
                context.run_id,
                exc,
            )
            return ExecutionResult.failure(
                error=f"{type(exc).__name__}: {exc}",
            )

        try:
            return ExecutionResult.from_dict(result_dict)
        except (AttributeError, KeyError, TypeError, ValueError) as exc:
            # The task finished but returned something that is not a
            # serialized ExecutionResult; keep the "always a result"
            # behaviour instead of leaking a deserialization error.
            logger.error(
                "Malformed executor result: executor=%s operation=%s "
                "run_id=%s error=%s",
                context.executor_name,
                context.operation,
                context.run_id,
                exc,
            )
            return ExecutionResult.failure(
                error=f"Malformed executor result: {type(exc).__name__}: {exc}",
            )

    def dispatch_async(
        self,
        context: ExecutionContext,
    ) -> str:
        """Dispatch without waiting. Returns task_id for polling.

        Args:
            context: ExecutionContext to dispatch.

        Returns:
            The Celery task ID (use with ``AsyncResult`` to poll).

        Raises:
            ValueError: If no Celery app is configured.
        """
        app = self._require_app()

        queue = self._get_queue(context.executor_name)
        logger.info(
            "Dispatching async execution: executor=%s "
            "operation=%s run_id=%s request_id=%s queue=%s",
            context.executor_name,
            context.operation,
            context.run_id,
            context.request_id,
            queue,
        )

        async_result = app.send_task(
            _TASK_NAME,
            args=[context.to_dict()],
            queue=queue,
        )
        return async_result.id

    def dispatch_with_callback(
        self,
        context: ExecutionContext,
        on_success: Any = None,
        on_error: Any = None,
        task_id: str | None = None,
    ) -> Any:
        """Fire-and-forget dispatch with Celery link callbacks.

        Sends the task to the executor queue and returns immediately.
        When the executor task completes, Celery invokes the
        ``on_success`` callback (via ``link``). If the executor task
        raises an exception, Celery invokes ``on_error`` (via
        ``link_error``).

        Args:
            context: ExecutionContext to dispatch.
            on_success: A Celery ``Signature`` invoked on success.
                Receives ``(result_dict,)`` as first positional arg
                followed by the signature's own args.
            on_error: A Celery ``Signature`` invoked on failure.
                Receives ``(failed_task_uuid,)`` as first positional
                arg followed by the signature's own args.
            task_id: Optional pre-generated Celery task ID. Useful
                when the caller needs to know the task ID before
                dispatch (e.g. to include it in callback kwargs).

        Returns:
            The ``AsyncResult`` from ``send_task``. Callers can
            use ``.id`` for task tracking but should NOT call
            ``.get()`` (that would block, defeating the purpose).

        Raises:
            ValueError: If no Celery app is configured.
        """
        app = self._require_app()

        queue = self._get_queue(context.executor_name)
        logger.info(
            "Dispatching with callback: executor=%s "
            "operation=%s run_id=%s request_id=%s "
            "on_success=%s on_error=%s queue=%s",
            context.executor_name,
            context.operation,
            context.run_id,
            context.request_id,
            on_success,
            on_error,
            queue,
        )

        # Only forward the optional settings that were actually supplied,
        # so send_task sees the same keyword set as a plain dispatch when
        # no callbacks or task id are given.
        send_kwargs: dict[str, Any] = {
            "args": [context.to_dict()],
            "queue": queue,
        }
        if on_success is not None:
            send_kwargs["link"] = on_success
        if on_error is not None:
            send_kwargs["link_error"] = on_error
        if task_id is not None:
            send_kwargs["task_id"] = task_id

        async_result = app.send_task(
            _TASK_NAME,
            **send_kwargs,
        )
        logger.info(
            "Task sent with callbacks: celery_task_id=%s",
            async_result.id,
        )
        return async_result
class ExecutionOrchestrator:
    """Looks up and invokes the executor for a given context.

    Usage (inside the Celery task)::

        orchestrator = ExecutionOrchestrator()
        result = orchestrator.execute(context)
    """

    @staticmethod
    def _failure_from_exception(exc: Exception, start: float) -> ExecutionResult:
        """Wrap *exc* in a failed result carrying the elapsed time."""
        elapsed = time.monotonic() - start
        return ExecutionResult.failure(
            error=f"{type(exc).__name__}: {exc}",
            metadata={"elapsed_seconds": round(elapsed, 3)},
        )

    def execute(self, context: ExecutionContext) -> ExecutionResult:
        """Resolve the executor and run it.

        Args:
            context: Fully-populated execution context.

        Returns:
            ``ExecutionResult`` — always, even on unhandled
            exceptions (wrapped as a failure result). This covers an
            unknown executor name, an executor that cannot be
            instantiated, an executor that raises, and an executor
            that returns something other than an ``ExecutionResult``.
        """
        logger.info(
            "Orchestrating execution: executor=%s operation=%s "
            "run_id=%s request_id=%s",
            context.executor_name,
            context.operation,
            context.run_id,
            context.request_id,
        )

        start = time.monotonic()
        try:
            executor = ExecutorRegistry.get(context.executor_name)
        except KeyError as exc:
            # str(KeyError("msg")) is the repr "'msg'", which would wrap the
            # message in quotes; read the message from args instead.
            message = str(exc.args[0]) if exc.args else repr(exc)
            logger.error("Executor lookup failed: %s", message)
            return ExecutionResult.failure(error=message)
        except Exception as exc:
            # ExecutorRegistry.get() also instantiates the executor class;
            # a failing constructor must still yield a failure result.
            logger.exception(
                "Executor %r could not be instantiated",
                context.executor_name,
            )
            return self._failure_from_exception(exc, start)

        try:
            result = executor.execute(context)
        except Exception as exc:
            elapsed = time.monotonic() - start
            logger.exception(
                "Executor %r raised an unhandled exception after %.2fs",
                context.executor_name,
                elapsed,
            )
            return self._failure_from_exception(exc, start)

        elapsed = time.monotonic() - start
        if not isinstance(result, ExecutionResult):
            # Guard the return contract: callers serialize the result and
            # read ``success`` from it.
            logger.error(
                "Executor %r returned %s instead of ExecutionResult",
                context.executor_name,
                type(result).__name__,
            )
            return ExecutionResult.failure(
                error=(
                    f"Executor {context.executor_name!r} returned "
                    f"{type(result).__name__}, expected ExecutionResult"
                ),
                metadata={"elapsed_seconds": round(elapsed, 3)},
            )

        logger.info(
            "Execution completed: executor=%s operation=%s "
            "success=%s elapsed=%.2fs",
            context.executor_name,
            context.operation,
            result.success,
            elapsed,
        )
        return result
@classmethod
def register(cls, executor_cls: T) -> T:
    """Class decorator that registers an executor.

    Instantiates the class once to read its ``name`` property,
    then stores the *class* (not the instance) so a fresh
    instance is created per ``get()`` call. Registering the very
    same class again is a no-op.

    Args:
        executor_cls: A concrete ``BaseExecutor`` subclass.

    Returns:
        The same class, unmodified (passthrough decorator).

    Raises:
        TypeError: If *executor_cls* is not a BaseExecutor
            subclass.
        ValueError: If a different executor class is already
            registered under the same name.
    """
    if not (
        isinstance(executor_cls, type) and issubclass(executor_cls, BaseExecutor)
    ):
        raise TypeError(f"{executor_cls!r} is not a BaseExecutor subclass")

    # ``name`` is an instance property, so a throwaway instance is needed
    # to read it.
    name = executor_cls().name

    existing = cls._registry.get(name)
    if existing is executor_cls:
        # Same class registered twice: nothing conflicts, so do not raise
        # an "already registered" error that names the class as its own
        # competitor.
        return executor_cls
    if existing is not None:
        raise ValueError(
            f"Executor name {name!r} is already registered "
            f"by {existing.__name__}; cannot register "
            f"{executor_cls.__name__}"
        )

    cls._registry[name] = executor_cls
    logger.info(
        "Registered executor %r (%s)",
        name,
        executor_cls.__name__,
    )
    return executor_cls
" f"Available: {available}" + ) + return executor_cls() + + @classmethod + def list_executors(cls) -> list[str]: + """Return sorted list of registered executor names.""" + return sorted(cls._registry) + + @classmethod + def clear(cls) -> None: + """Remove all registered executors (for testing).""" + cls._registry.clear() diff --git a/unstract/sdk1/src/unstract/sdk1/execution/result.py b/unstract/sdk1/src/unstract/sdk1/execution/result.py new file mode 100644 index 0000000000..0088d071f5 --- /dev/null +++ b/unstract/sdk1/src/unstract/sdk1/execution/result.py @@ -0,0 +1,72 @@ +"""Execution result model for the executor framework. + +Defines the standardized result returned by executors via the +Celery result backend. All executors must return an +``ExecutionResult`` so that callers (structure tool task, +PromptStudioHelper) have a uniform interface. +""" + +from dataclasses import dataclass, field +from typing import Any + + +@dataclass +class ExecutionResult: + """Standardized result from an executor. + + Returned via the Celery result backend as a JSON dict. + + Attributes: + success: Whether the execution completed without error. + data: Operation-specific output payload. The shape depends + on the operation (see response contract in the + migration plan). + metadata: Auxiliary information such as token usage, + timings, or adapter metrics. + error: Human-readable error message when ``success`` is + ``False``. ``None`` on success. 
+ """ + + success: bool + data: dict[str, Any] = field(default_factory=dict) + metadata: dict[str, Any] = field(default_factory=dict) + error: str | None = None + + def __post_init__(self) -> None: + """Validate result consistency after initialization.""" + if not self.success and not self.error: + raise ValueError("error message is required when success is False") + + def to_dict(self) -> dict[str, Any]: + """Serialize to a JSON-compatible dict for Celery.""" + result: dict[str, Any] = { + "success": self.success, + "data": self.data, + "metadata": self.metadata, + } + if self.error is not None: + result["error"] = self.error + return result + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "ExecutionResult": + """Deserialize from a dict (e.g. Celery result backend).""" + return cls( + success=data["success"], + data=data.get("data", {}), + metadata=data.get("metadata", {}), + error=data.get("error"), + ) + + @classmethod + def failure( + cls, + error: str, + metadata: dict[str, Any] | None = None, + ) -> "ExecutionResult": + """Convenience factory for a failed result.""" + return cls( + success=False, + error=error, + metadata=metadata or {}, + ) diff --git a/unstract/sdk1/src/unstract/sdk1/platform.py b/unstract/sdk1/src/unstract/sdk1/platform.py index a7995164c4..e5ce7fc172 100644 --- a/unstract/sdk1/src/unstract/sdk1/platform.py +++ b/unstract/sdk1/src/unstract/sdk1/platform.py @@ -140,8 +140,8 @@ def _get_adapter_configuration( provider = adapter_data.get("adapter_id", "").split("|")[0] # TODO: Print metadata after redacting sensitive information tool.stream_log( - f"Retrieved config for '{adapter_instance_id}', type: " - f"'{adapter_type}', provider: '{provider}', name: '{adapter_name}'", + f"Retrieved adapter config — name: '{adapter_name}', " + f"type: '{adapter_type}', provider: '{provider}'", level=LogLevel.DEBUG, ) except HTTPError as e: @@ -188,7 +188,7 @@ def get_adapter_config( return adapter_metadata tool.stream_log( - f"Retrieving 
config from DB for '{adapter_instance_id}'", + "Retrieving adapter configuration from platform service", level=LogLevel.DEBUG, ) diff --git a/unstract/sdk1/src/unstract/sdk1/vector_db.py b/unstract/sdk1/src/unstract/sdk1/vector_db.py index 9638faf358..c46b1c0cb0 100644 --- a/unstract/sdk1/src/unstract/sdk1/vector_db.py +++ b/unstract/sdk1/src/unstract/sdk1/vector_db.py @@ -110,7 +110,7 @@ def _get_vector_db(self) -> BasePydanticVectorStore | VectorStore: return self.vector_db_adapter_class.get_vector_db_instance() except Exception as e: self._tool.stream_log( - log=f"Unable to get vector_db {self._adapter_instance_id}: {e}", + log=f"Unable to get vector database: {e}", level=LogLevel.ERROR, ) raise VectorDBError(f"Error getting vectorDB instance: {e}") from e diff --git a/unstract/sdk1/tests/test_execution.py b/unstract/sdk1/tests/test_execution.py new file mode 100644 index 0000000000..3839a01073 --- /dev/null +++ b/unstract/sdk1/tests/test_execution.py @@ -0,0 +1,1063 @@ +"""Unit tests for execution framework (Phase 1A–1G).""" + +import json +import logging +from typing import Any, Self +from unittest.mock import MagicMock + +import pytest +from unstract.sdk1.constants import LogLevel, ToolEnv +from unstract.sdk1.exceptions import SdkError +from unstract.sdk1.execution.context import ( + ExecutionContext, + ExecutionSource, + Operation, +) +from unstract.sdk1.execution.dispatcher import ExecutionDispatcher +from unstract.sdk1.execution.executor import BaseExecutor +from unstract.sdk1.execution.orchestrator import ExecutionOrchestrator +from unstract.sdk1.execution.registry import ExecutorRegistry +from unstract.sdk1.execution.result import ExecutionResult + + +class TestExecutionContext: + """Tests for ExecutionContext serialization and validation.""" + + def _make_context(self, **overrides: Any) -> ExecutionContext: + """Create a default ExecutionContext with optional overrides.""" + defaults: dict[str, Any] = { + "executor_name": "legacy", + "operation": 
"extract", + "run_id": "run-001", + "execution_source": "tool", + "organization_id": "org-123", + "executor_params": {"file_path": "/tmp/test.pdf"}, + "request_id": "req-abc", + } + defaults.update(overrides) + return ExecutionContext(**defaults) + + def test_round_trip_serialization(self: Self) -> None: + """to_dict -> from_dict produces identical context.""" + original = self._make_context() + restored = ExecutionContext.from_dict(original.to_dict()) + + assert restored.executor_name == original.executor_name + assert restored.operation == original.operation + assert restored.run_id == original.run_id + assert restored.execution_source == original.execution_source + assert restored.organization_id == original.organization_id + assert restored.executor_params == original.executor_params + assert restored.request_id == original.request_id + + def test_json_serializable(self: Self) -> None: + """to_dict output is JSON-serializable (Celery requirement).""" + ctx = self._make_context() + serialized = json.dumps(ctx.to_dict()) + deserialized = json.loads(serialized) + restored = ExecutionContext.from_dict(deserialized) + assert restored.executor_name == ctx.executor_name + + def test_enum_values_normalized(self: Self) -> None: + """Enum instances are normalized to plain strings.""" + ctx = self._make_context( + operation=Operation.ANSWER_PROMPT, + execution_source=ExecutionSource.IDE, + ) + assert ctx.operation == "answer_prompt" + assert ctx.execution_source == "ide" + # Also check dict output + d = ctx.to_dict() + assert d["operation"] == "answer_prompt" + assert d["execution_source"] == "ide" + + def test_string_values_accepted(self: Self) -> None: + """Plain string values work without enum coercion.""" + ctx = self._make_context( + operation="custom_op", + execution_source="tool", + ) + assert ctx.operation == "custom_op" + assert ctx.execution_source == "tool" + + def test_auto_generates_request_id(self: Self) -> None: + """request_id is generated when not 
provided.""" + ctx = self._make_context(request_id=None) + assert ctx.request_id is not None + assert len(ctx.request_id) > 0 + + def test_explicit_request_id_preserved(self: Self) -> None: + """Explicit request_id is not overwritten.""" + ctx = self._make_context(request_id="my-req-id") + assert ctx.request_id == "my-req-id" + + def test_optional_organization_id(self: Self) -> None: + """organization_id can be None (public calls).""" + ctx = self._make_context(organization_id=None) + assert ctx.organization_id is None + d = ctx.to_dict() + assert d["organization_id"] is None + restored = ExecutionContext.from_dict(d) + assert restored.organization_id is None + + def test_empty_executor_params_default(self: Self) -> None: + """executor_params defaults to empty dict.""" + ctx = ExecutionContext( + executor_name="legacy", + operation="extract", + run_id="run-001", + execution_source="tool", + ) + assert ctx.executor_params == {} + + def test_complex_executor_params(self: Self) -> None: + """Nested executor_params round-trip correctly.""" + params = { + "file_path": "/data/doc.pdf", + "outputs": [ + {"prompt_key": "p1", "llm": "adapter-1"}, + {"prompt_key": "p2", "llm": "adapter-2"}, + ], + "options": {"reindex": True, "chunk_size": 512}, + } + ctx = self._make_context(executor_params=params) + restored = ExecutionContext.from_dict(ctx.to_dict()) + assert restored.executor_params == params + + @pytest.mark.parametrize( + "field,value", + [ + ("executor_name", ""), + ("operation", ""), + ("run_id", ""), + ("execution_source", ""), + ], + ) + def test_validation_rejects_empty_required_fields( + self: Self, field: str, value: str + ) -> None: + """Empty required fields raise ValueError.""" + with pytest.raises(ValueError, match=f"{field} is required"): + self._make_context(**{field: value}) + + def test_all_operations_accepted(self: Self) -> None: + """All Operation enum values create valid contexts.""" + for op in Operation: + ctx = self._make_context(operation=op) + 
assert ctx.operation == op.value + + def test_from_dict_missing_optional_fields(self: Self) -> None: + """from_dict handles missing optional fields gracefully.""" + minimal = { + "executor_name": "legacy", + "operation": "extract", + "run_id": "run-001", + "execution_source": "tool", + } + ctx = ExecutionContext.from_dict(minimal) + assert ctx.organization_id is None + assert ctx.executor_params == {} + # request_id is None from dict (no auto-gen in from_dict) + # but __post_init__ auto-generates it + assert ctx.request_id is not None + + +class TestExecutionResult: + """Tests for ExecutionResult serialization and validation.""" + + def test_success_round_trip(self: Self) -> None: + """Successful result round-trips through dict.""" + original = ExecutionResult( + success=True, + data={"output": {"key": "value"}, "metadata": {}}, + metadata={"tokens": 150, "latency_ms": 320}, + ) + restored = ExecutionResult.from_dict(original.to_dict()) + assert restored.success is True + assert restored.data == original.data + assert restored.metadata == original.metadata + assert restored.error is None + + def test_failure_round_trip(self: Self) -> None: + """Failed result round-trips through dict.""" + original = ExecutionResult( + success=False, + error="LLM adapter timeout", + metadata={"retry_count": 2}, + ) + restored = ExecutionResult.from_dict(original.to_dict()) + assert restored.success is False + assert restored.error == "LLM adapter timeout" + assert restored.data == {} + assert restored.metadata == {"retry_count": 2} + + def test_json_serializable(self: Self) -> None: + """to_dict output is JSON-serializable.""" + result = ExecutionResult( + success=True, + data={"extracted_text": "Hello world"}, + ) + serialized = json.dumps(result.to_dict()) + deserialized = json.loads(serialized) + restored = ExecutionResult.from_dict(deserialized) + assert restored.data == result.data + + def test_failure_requires_error_message(self: Self) -> None: + """success=False without error 
raises ValueError.""" + with pytest.raises( + ValueError, + match="error message is required", + ): + ExecutionResult(success=False) + + def test_success_allows_no_error(self: Self) -> None: + """success=True with no error is valid.""" + result = ExecutionResult(success=True) + assert result.error is None + + def test_failure_factory(self: Self) -> None: + """ExecutionResult.failure() convenience constructor.""" + result = ExecutionResult.failure( + error="Something broke", + metadata={"debug": True}, + ) + assert result.success is False + assert result.error == "Something broke" + assert result.data == {} + assert result.metadata == {"debug": True} + + def test_failure_factory_no_metadata(self: Self) -> None: + """failure() works without metadata.""" + result = ExecutionResult.failure(error="Oops") + assert result.metadata == {} + + def test_error_not_in_success_dict(self: Self) -> None: + """Successful result dict omits error key.""" + result = ExecutionResult(success=True, data={"k": "v"}) + d = result.to_dict() + assert "error" not in d + + def test_error_in_failure_dict(self: Self) -> None: + """Failed result dict includes error key.""" + result = ExecutionResult.failure(error="fail") + d = result.to_dict() + assert d["error"] == "fail" + + def test_default_empty_dicts(self: Self) -> None: + """Data and metadata default to empty dicts.""" + result = ExecutionResult(success=True) + assert result.data == {} + assert result.metadata == {} + + def test_from_dict_missing_optional_fields(self: Self) -> None: + """from_dict handles missing optional fields.""" + minimal = {"success": True} + result = ExecutionResult.from_dict(minimal) + assert result.data == {} + assert result.metadata == {} + assert result.error is None + + def test_response_contract_extract(self: Self) -> None: + """Verify extract operation response shape.""" + result = ExecutionResult( + success=True, + data={"extracted_text": "The quick brown fox"}, + ) + assert "extracted_text" in result.data + + 
def test_response_contract_index(self: Self) -> None: + """Verify index operation response shape.""" + result = ExecutionResult( + success=True, + data={"doc_id": "doc-abc-123"}, + ) + assert "doc_id" in result.data + + def test_response_contract_answer_prompt(self: Self) -> None: + """Verify answer_prompt operation response shape.""" + result = ExecutionResult( + success=True, + data={ + "output": {"field1": "value1"}, + "metadata": {"confidence": 0.95}, + "metrics": {"tokens": 200}, + }, + ) + assert "output" in result.data + assert "metadata" in result.data + assert "metrics" in result.data + + +# ---- Phase 1B: BaseExecutor & ExecutorRegistry ---- + + +def _make_executor_class( + executor_name: str, +) -> type[BaseExecutor]: + """Helper: build a concrete BaseExecutor subclass dynamically.""" + + class _Executor(BaseExecutor): + @property + def name(self) -> str: + return executor_name + + def execute(self, context: ExecutionContext) -> ExecutionResult: + return ExecutionResult( + success=True, + data={"echo": context.operation}, + ) + + # Give it a readable __name__ for error messages + _Executor.__name__ = f"{executor_name.title()}Executor" + _Executor.__qualname__ = _Executor.__name__ + return _Executor + + +class TestBaseExecutor: + """Tests for BaseExecutor ABC contract.""" + + def test_cannot_instantiate_abstract(self: Self) -> None: + """BaseExecutor itself cannot be instantiated.""" + with pytest.raises(TypeError): + BaseExecutor() # type: ignore[abstract] + + def test_concrete_subclass_works(self: Self) -> None: + """A properly implemented subclass can be instantiated.""" + cls = _make_executor_class("test_abc") + instance = cls() + assert instance.name == "test_abc" + + def test_execute_returns_result(self: Self) -> None: + """execute() returns an ExecutionResult.""" + cls = _make_executor_class("test_exec") + instance = cls() + ctx = ExecutionContext( + executor_name="test_exec", + operation="extract", + run_id="run-1", + execution_source="tool", + ) 
+ result = instance.execute(ctx) + assert isinstance(result, ExecutionResult) + assert result.success is True + assert result.data == {"echo": "extract"} + + +class TestExecutorRegistry: + """Tests for ExecutorRegistry.""" + + @pytest.fixture(autouse=True) + def _clean_registry(self: Self) -> None: + """Ensure a clean registry for every test.""" + ExecutorRegistry.clear() + + def test_register_and_get(self: Self) -> None: + """Register an executor and retrieve by name.""" + cls = _make_executor_class("alpha") + ExecutorRegistry.register(cls) + + executor = ExecutorRegistry.get("alpha") + assert isinstance(executor, BaseExecutor) + assert executor.name == "alpha" + + def test_get_returns_fresh_instance(self: Self) -> None: + """Each get() call returns a new instance.""" + cls = _make_executor_class("fresh") + ExecutorRegistry.register(cls) + + a = ExecutorRegistry.get("fresh") + b = ExecutorRegistry.get("fresh") + assert a is not b + + def test_register_as_decorator(self: Self) -> None: + """@ExecutorRegistry.register works as a class decorator.""" + + @ExecutorRegistry.register + class MyExecutor(BaseExecutor): + @property + def name(self) -> str: + return "decorated" + + def execute(self, context: ExecutionContext) -> ExecutionResult: + return ExecutionResult(success=True) + + executor = ExecutorRegistry.get("decorated") + assert executor.name == "decorated" + # Decorator returns the class unchanged + assert MyExecutor is not None + + def test_list_executors(self: Self) -> None: + """list_executors() returns sorted names.""" + ExecutorRegistry.register(_make_executor_class("charlie")) + ExecutorRegistry.register(_make_executor_class("alpha")) + ExecutorRegistry.register(_make_executor_class("bravo")) + + assert ExecutorRegistry.list_executors() == [ + "alpha", + "bravo", + "charlie", + ] + + def test_list_executors_empty(self: Self) -> None: + """list_executors() returns empty list when nothing registered.""" + assert ExecutorRegistry.list_executors() == [] + + 
def test_get_unknown_lists_available(self: Self) -> None:
    """A failed lookup mentions the registered executor names in its KeyError."""
    for registered_name in ("one", "two"):
        ExecutorRegistry.register(_make_executor_class(registered_name))

    with pytest.raises(KeyError, match="one") as exc_info:
        ExecutorRegistry.get("missing")

    # ``match`` already checked "one"; confirm the second name is listed too.
    error_text = str(exc_info.value)
    assert "two" in error_text
def _make_failing_executor_class(
    executor_name: str,
    exc: Exception,
) -> type[BaseExecutor]:
    """Create a concrete executor class whose ``execute()`` raises *exc*.

    Args:
        executor_name: Value returned by the class's ``name`` property;
            also used (title-cased) to build the class name.
        exc: Exception instance raised on every ``execute()`` call.

    Returns:
        The newly created executor class (not an instance).
    """
    display_name = f"{executor_name.title()}FailExecutor"

    class _RaisingExecutor(BaseExecutor):
        @property
        def name(self) -> str:
            return executor_name

        def execute(self, context: ExecutionContext) -> ExecutionResult:
            raise exc

    _RaisingExecutor.__name__ = display_name
    _RaisingExecutor.__qualname__ = display_name
    return _RaisingExecutor
def test_exception_result_has_elapsed_metadata(self: Self) -> None:
    """A failure built from a raised exception carries a float ``elapsed_seconds``."""
    failing_cls = _make_failing_executor_class("slow_fail", ValueError("bad input"))
    ExecutorRegistry.register(failing_cls)
    context = self._make_context(executor_name="slow_fail")

    outcome = ExecutionOrchestrator().execute(context)

    assert outcome.success is False
    metadata = outcome.metadata
    assert "elapsed_seconds" in metadata
    assert isinstance(metadata["elapsed_seconds"], float)
ExecutionOrchestrator() + + result = orchestrator.execute(self._make_context(executor_name="graceful_fail")) + assert result.success is False + assert result.error == "LLM rate limited" + + +# ---- Phase 1F: ExecutionDispatcher ---- + + +class TestExecutionDispatcher: + """Tests for ExecutionDispatcher (mocked Celery).""" + + def _make_context(self, **overrides: Any) -> ExecutionContext: + defaults: dict[str, Any] = { + "executor_name": "legacy", + "operation": "extract", + "run_id": "run-1", + "execution_source": "tool", + "request_id": "req-1", + } + defaults.update(overrides) + return ExecutionContext(**defaults) + + def _make_mock_app( + self, + result_dict: dict[str, Any] | None = None, + side_effect: Exception | None = None, + task_id: str = "celery-task-123", + ) -> MagicMock: + """Create a mock Celery app with send_task configured.""" + mock_app = MagicMock() + mock_async_result = MagicMock() + mock_async_result.id = task_id + + if side_effect is not None: + mock_async_result.get.side_effect = side_effect + else: + mock_async_result.get.return_value = ( + result_dict + if result_dict is not None + else {"success": True, "data": {}, "metadata": {}} + ) + + mock_app.send_task.return_value = mock_async_result + return mock_app + + def test_dispatch_sends_task_and_returns_result( + self: Self, + ) -> None: + """dispatch() sends task to executor queue and returns result.""" + result_dict = { + "success": True, + "data": {"extracted_text": "hello"}, + "metadata": {}, + } + mock_app = self._make_mock_app(result_dict=result_dict) + dispatcher = ExecutionDispatcher(celery_app=mock_app) + ctx = self._make_context() + + result = dispatcher.dispatch(ctx, timeout=60) + + assert result.success is True + assert result.data == {"extracted_text": "hello"} + + # Verify send_task was called correctly + mock_app.send_task.assert_called_once_with( + "execute_extraction", + args=[ctx.to_dict()], + queue="celery_executor_legacy", + ) + 
def test_dispatch_uses_default_timeout(
    self: Self, monkeypatch: pytest.MonkeyPatch
) -> None:
    """dispatch() without timeout uses default (3600s)."""
    # dispatch() reads EXECUTOR_RESULT_TIMEOUT from the environment (see
    # the env-based timeout test), so a value inherited from the shell or
    # CI would replace the default. Clear it so the test is hermetic.
    monkeypatch.delenv("EXECUTOR_RESULT_TIMEOUT", raising=False)
    mock_app = self._make_mock_app()
    dispatcher = ExecutionDispatcher(celery_app=mock_app)
    ctx = self._make_context()

    dispatcher.dispatch(ctx)

    mock_app.send_task.return_value.get.assert_called_once_with(
        timeout=3600, disable_sync_subtasks=False
    )
def test_dispatch_async_returns_task_id(self: Self) -> None:
    """dispatch_async() hands back the id of the task it sent."""
    expected_id = "task-xyz-789"
    celery_app = self._make_mock_app(task_id=expected_id)
    context = self._make_context()

    returned_id = ExecutionDispatcher(celery_app=celery_app).dispatch_async(context)

    assert returned_id == expected_id
    # Exactly one task is sent, carrying the serialized context.
    celery_app.send_task.assert_called_once_with(
        "execute_extraction",
        args=[context.to_dict()],
        queue="celery_executor_legacy",
    )
def test_dispatch_with_callback_sends_link_and_link_error(
    self: Self,
) -> None:
    """dispatch_with_callback() forwards on_success as link and on_error as link_error."""
    celery_app = self._make_mock_app(task_id="cb-task-001")
    context = self._make_context()
    success_sig = MagicMock(name="on_success_sig")
    error_sig = MagicMock(name="on_error_sig")
    callbacks = {"on_success": success_sig, "on_error": error_sig}

    async_result = ExecutionDispatcher(celery_app=celery_app).dispatch_with_callback(
        context, **callbacks
    )

    assert async_result.id == "cb-task-001"
    celery_app.send_task.assert_called_once_with(
        "execute_extraction",
        args=[context.to_dict()],
        queue="celery_executor_legacy",
        link=success_sig,
        link_error=error_sig,
    )
def test_dispatch_with_callback_no_callbacks(
    self: Self,
) -> None:
    """With no callbacks supplied, neither link nor link_error is sent."""
    celery_app = self._make_mock_app(task_id="cb-task-004")
    context = self._make_context()

    async_result = ExecutionDispatcher(celery_app=celery_app).dispatch_with_callback(
        context
    )

    assert async_result.id == "cb-task-004"
    sent_kwargs = celery_app.send_task.call_args.kwargs
    for callback_kwarg in ("link", "link_error"):
        assert callback_kwarg not in sent_kwargs
def test_dispatch_with_callback_custom_task_id(
    self: Self,
) -> None:
    """dispatch_with_callback() passes custom task_id to send_task."""
    mock_app = self._make_mock_app(task_id="pre-gen-id-123")
    dispatcher = ExecutionDispatcher(celery_app=mock_app)
    ctx = self._make_context()

    result = dispatcher.dispatch_with_callback(ctx, task_id="pre-gen-id-123")

    call_kwargs = mock_app.send_task.call_args
    assert call_kwargs[1]["task_id"] == "pre-gen-id-123"
    # The return value was previously bound but never checked; verify the
    # caller gets back the async result carrying the pre-generated id.
    assert result is mock_app.send_task.return_value
    assert result.id == "pre-gen-id-123"
def get_env_or_die(self, env_key: str) -> str:
    """Return the value for *env_key* or raise ``SdkError`` when it is unset.

    ``ToolEnv.PLATFORM_API_KEY`` is answered from the key given at
    construction time; every other name is read from ``os.environ``.
    Empty values count as missing in both cases.
    """
    import os

    if env_key == ToolEnv.PLATFORM_API_KEY:
        resolved = self.platform_api_key
        missing = not resolved
    else:
        resolved = os.environ.get(env_key)
        missing = resolved is None or resolved == ""

    if missing:
        raise SdkError(f"Env variable '{env_key}' is required")
    return resolved
def test_stream_log_respects_level(
    self: Self, caplog: pytest.LogCaptureFixture
) -> None:
    """stream_log() translates the SDK LogLevel into a Python logging level."""
    shim = _MockExecutorToolShim()
    emitted = (("debug msg", LogLevel.DEBUG), ("warn msg", LogLevel.WARN))

    with caplog.at_level(logging.WARNING, logger="executor_tool_shim"):
        for text, sdk_level in emitted:
            shim.stream_log(text, level=sdk_level)

    captured = caplog.text
    # The capture threshold is WARNING, so the DEBUG record is dropped.
    assert "debug msg" not in captured
    assert "warn msg" in captured
def test_stream_error_and_exit_wraps_original(
    self: Self,
) -> None:
    """The SdkError raised by stream_error_and_exit() keeps the original exception."""
    root_cause = ValueError("root cause")
    shim = _MockExecutorToolShim()

    with pytest.raises(SdkError) as raised:
        shim.stream_error_and_exit("wrapper msg", err=root_cause)

    assert raised.value.actual_err is root_cause
class ExecutorToolShim(StreamMixin):
    """Minimal BaseTool substitute for use inside executor workers.

    Provides the methods that adapters call on a tool:

    - ``get_env_or_die(env_key)`` — reads env vars, with special
      handling for ``PLATFORM_SERVICE_API_KEY`` (multitenancy)
    - ``stream_log(log, level)`` — routes to Python logging instead
      of the Unstract stdout JSON protocol used by tools and, when a
      ``log_events_id`` is configured, publishes progress through
      ``LogPublisher``
    - ``stream_error_and_exit(message, err)`` — raises ``SdkError``
      instead of exiting the process

    Usage::

        shim = ExecutorToolShim(platform_api_key="sk-...")
        adapter = SomeAdapter(tool=shim)  # adapter calls shim.get_env_or_die()
    """

    # Severity rank of each SDK log level, lowest first. Built once at
    # class creation instead of on every stream_log() call.
    _SHIM_LEVEL_RANK: dict[LogLevel, int] = {
        LogLevel.DEBUG: 0,
        LogLevel.INFO: 1,
        LogLevel.WARN: 2,
        LogLevel.ERROR: 3,
        LogLevel.FATAL: 4,
    }

    def __init__(
        self,
        platform_api_key: str = "",
        log_events_id: str = "",
        component: dict[str, str] | None = None,
    ) -> None:
        """Initialize the shim.

        Args:
            platform_api_key: The platform service API key for this
                execution. Returned by ``get_env_or_die()`` when the
                caller asks for ``PLATFORM_SERVICE_API_KEY``.
            log_events_id: Socket.IO channel ID for streaming progress
                logs. Empty string disables publishing.
            component: Structured identifier dict for log correlation
                (``tool_id``, ``run_id``, ``doc_name``, optionally
                ``prompt_key``). ``None`` is stored as an empty dict.
        """
        self.platform_api_key = platform_api_key
        self.log_events_id = log_events_id
        self.component = component or {}
        # Initialize StreamMixin. EXECUTION_BY_TOOL is not set in
        # the worker environment, so _exec_by_tool will be False.
        super().__init__(log_level=LogLevel.INFO)

    def get_env_or_die(self, env_key: str) -> str:
        """Return environment variable value.

        Special-cases ``PLATFORM_SERVICE_API_KEY`` to return the key
        passed at construction time (supports multitenancy — each
        execution may use a different org's API key).

        Args:
            env_key: Environment variable name.

        Returns:
            The value of the environment variable.

        Raises:
            SdkError: If the variable is missing or empty.
        """
        if env_key == ToolEnv.PLATFORM_API_KEY:
            if not self.platform_api_key:
                raise SdkError(f"Env variable '{env_key}' is required")
            return self.platform_api_key

        env_value = os.environ.get(env_key)
        if env_value is None or env_value == "":
            raise SdkError(f"Env variable '{env_key}' is required")
        return env_value

    def stream_log(
        self,
        log: str,
        level: LogLevel = LogLevel.INFO,
        stage: str = "TOOL_RUN",
        **kwargs: Any,
    ) -> None:
        """Route log messages to Python logging and publish progress.

        In the executor worker context, logs go through the standard
        Python logging framework rather than the Unstract stdout JSON
        protocol used by tools.

        Messages at or above ``self.log_level`` are additionally
        published via ``LogPublisher.publish()`` when a
        ``log_events_id`` was configured. Publishing is best-effort:
        a failure is logged at DEBUG level and never raised.

        Args:
            log: The log message.
            level: SDK log level. A level missing from the lookup
                tables is treated as INFO.
            stage: Forwarded as the ``state`` of the published progress
                payload; not used for Python logging.
            **kwargs: Accepted for signature compatibility and ignored.
        """
        py_level = _LEVEL_MAP.get(level, logging.INFO)
        logger.log(py_level, log)

        # Respect the log level threshold for frontend publishing. Python
        # logging above still captures everything for debugging. Unknown
        # levels rank as INFO, matching the ``.get(..., INFO)`` fallbacks
        # used by the other two lookups, so a stray level cannot raise
        # from inside a logging call.
        rank = self._SHIM_LEVEL_RANK
        info_rank = rank[LogLevel.INFO]
        if rank.get(level, info_rank) < rank.get(self.log_level, info_rank):
            return

        # Publishing is disabled when no channel was configured.
        if not self.log_events_id:
            return

        try:
            wf_level = _SDK_TO_WF_LEVEL.get(level, "INFO")
            payload = LogPublisher.log_progress(
                component=self.component,
                level=wf_level,
                state=stage,
                message=log,
            )
            LogPublisher.publish(
                channel_id=self.log_events_id,
                payload=payload,
            )
        except Exception:
            # Deliberate best-effort: progress publishing must never
            # fail the extraction itself.
            logger.debug(
                "Failed to publish progress log (non-fatal)",
                exc_info=True,
            )

    def stream_error_and_exit(self, message: str, err: Exception | None = None) -> None:
        """Log error and raise SdkError.

        Always raises instead of exiting the process, so the caller
        (the Celery task) can handle the failure.

        Args:
            message: Error description.
            err: Original exception, if any; attached to the raised
                error as ``actual_err``.

        Raises:
            SdkError: Always.
        """
        logger.error(message)
        raise SdkError(message, actual_err=err)
def _is_safe_public_url(url: str) -> bool:
    """Validate webhook URL for SSRF protection.

    A URL is accepted only when its scheme is HTTPS and every address
    its host maps to — the IP literal itself, or each ``getaddrinfo``
    result for a hostname — is globally routable.

    Args:
        url: Candidate URL.

    Returns:
        ``True`` when the URL may be called. ``False`` on any doubt,
        including parse errors and DNS failures (fail closed).

    Note:
        Only the addresses seen at call time are checked; this function
        cannot by itself guard against a hostname that resolves to a
        different address later.
    """
    try:
        parsed = urlparse(url)
        if parsed.scheme != "https":
            return False

        host = parsed.hostname or ""
        # Reject a missing host and "localhost" (with or without the
        # trailing root dot) without touching DNS.
        bare_host = host.rstrip(".")
        if not bare_host or bare_host == "localhost":
            return False

        try:
            ipaddress.ip_address(host)
        except ValueError:
            # Not an IP literal: vet every address the name resolves to.
            try:
                resolved = socket.getaddrinfo(host, None, type=socket.SOCK_STREAM)
            except OSError:
                return False
            candidates = {sockaddr[0] for *_, sockaddr in resolved}
        else:
            candidates = {host}

        if not candidates:
            return False

        for candidate in candidates:
            try:
                ip = ipaddress.ip_address(candidate)
            except ValueError:
                return False
            if (
                ip.is_private
                or ip.is_loopback
                or ip.is_link_local
                or ip.is_reserved
                or ip.is_multicast
                or ip.is_unspecified
                # Catches non-global ranges that are not flagged private,
                # e.g. the 100.64.0.0/10 shared address space.
                or not ip.is_global
            ):
                return False
        return True
    except Exception:
        # Fail closed: any unexpected parsing/resolution error means the
        # URL is not considered safe.
        return False
+ prompt: Key into ``output`` for the prompt text (usually "promptx"). + metadata: Metadata dict (updated in place with highlight info). + file_path: Path to the extracted text file. + execution_source: "ide" or "tool". + process_text: Optional callback for text processing during + completion (e.g. highlight-data plugin's ``run`` method). + + Returns: + The LLM answer string. + """ + platform_postamble = tool_settings.get(PSKeys.PLATFORM_POSTAMBLE, "") + word_confidence_postamble = tool_settings.get( + PSKeys.WORD_CONFIDENCE_POSTAMBLE, "" + ) + summarize_as_source = tool_settings.get(PSKeys.SUMMARIZE_AS_SOURCE) + enable_highlight = tool_settings.get(PSKeys.ENABLE_HIGHLIGHT, False) + enable_word_confidence = tool_settings.get(PSKeys.ENABLE_WORD_CONFIDENCE, False) + if not enable_highlight: + enable_word_confidence = False + prompt_type = output.get(PSKeys.TYPE, PSKeys.TEXT) + if not enable_highlight or summarize_as_source: + platform_postamble = "" + if not enable_word_confidence or summarize_as_source: + word_confidence_postamble = "" + + prompt = AnswerPromptService.construct_prompt( + preamble=tool_settings.get(PSKeys.PREAMBLE, ""), + prompt=output[prompt], + postamble=tool_settings.get(PSKeys.POSTAMBLE, ""), + grammar_list=tool_settings.get(PSKeys.GRAMMAR, []), + context=context, + platform_postamble=platform_postamble, + word_confidence_postamble=word_confidence_postamble, + prompt_type=prompt_type, + ) + output[PSKeys.COMBINED_PROMPT] = prompt + return AnswerPromptService.run_completion( + llm=llm, + prompt=prompt, + metadata=metadata, + prompt_key=output[PSKeys.NAME], + prompt_type=prompt_type, + enable_highlight=enable_highlight, + enable_word_confidence=enable_word_confidence, + file_path=file_path, + execution_source=execution_source, + process_text=process_text, + ) + + @staticmethod + def construct_prompt( + preamble: str, + prompt: str, + postamble: str, + grammar_list: list[dict[str, Any]], + context: str, + platform_postamble: str, + 
word_confidence_postamble: str, + prompt_type: str = "text", + ) -> str: + """Build the full prompt string with preamble, grammar, postamble, context.""" + prompt = f"{preamble}\n\nQuestion or Instruction: {prompt}" + if grammar_list is not None and len(grammar_list) > 0: + prompt += "\n" + for grammar in grammar_list: + word = "" + synonyms = [] + if PSKeys.WORD in grammar: + word = grammar[PSKeys.WORD] + if PSKeys.SYNONYMS in grammar: + synonyms = grammar[PSKeys.SYNONYMS] + if len(synonyms) > 0 and word != "": + prompt += ( + f"\nNote: You can consider that the word '{word}' " + f"is the same as {', '.join(synonyms)} " + f"in both the question and the context." + ) + if prompt_type == PSKeys.JSON: + json_postamble = os.environ.get( + PSKeys.JSON_POSTAMBLE, PSKeys.DEFAULT_JSON_POSTAMBLE + ) + postamble += f"\n{json_postamble}" + if platform_postamble: + platform_postamble += "\n\n" + if word_confidence_postamble: + platform_postamble += f"{word_confidence_postamble}\n\n" + prompt += ( + f"\n\n{postamble}\n\nContext:\n---------------\n{context}\n" + f"-----------------\n\n{platform_postamble}Answer:" + ) + return prompt + + @staticmethod + def run_completion( + llm: Any, + prompt: str, + metadata: dict[str, str] | None = None, + prompt_key: str | None = None, + prompt_type: str | None = "text", + enable_highlight: bool = False, + enable_word_confidence: bool = False, + file_path: str = "", + execution_source: str | None = None, + process_text: Any = None, + ) -> str: + """Run LLM completion and extract the answer. + + Args: + process_text: Optional callback for text processing during + completion (e.g. highlight-data plugin's ``run`` method). + When provided, the SDK passes LLM response text through + this callback, enabling source attribution. 
+ """ + try: + from unstract.sdk1.exceptions import RateLimitError as SdkRateLimitError + from unstract.sdk1.exceptions import SdkError + except ImportError: + SdkRateLimitError = Exception + SdkError = Exception + + try: + completion = llm.complete( + prompt=prompt, + process_text=process_text, + extract_json=prompt_type.lower() != PSKeys.TEXT, + ) + answer: str = completion[PSKeys.RESPONSE].text + highlight_data = completion.get(PSKeys.HIGHLIGHT_DATA, []) + confidence_data = completion.get(PSKeys.CONFIDENCE_DATA) + word_confidence_data = completion.get(PSKeys.WORD_CONFIDENCE_DATA) + line_numbers = completion.get(PSKeys.LINE_NUMBERS, []) + whisper_hash = completion.get(PSKeys.WHISPER_HASH, "") + if metadata is not None and prompt_key: + metadata.setdefault(PSKeys.HIGHLIGHT_DATA, {})[prompt_key] = ( + highlight_data + ) + metadata.setdefault(PSKeys.LINE_NUMBERS, {})[prompt_key] = line_numbers + metadata[PSKeys.WHISPER_HASH] = whisper_hash + if confidence_data: + metadata.setdefault(PSKeys.CONFIDENCE_DATA, {})[prompt_key] = ( + confidence_data + ) + if enable_word_confidence and word_confidence_data: + metadata.setdefault(PSKeys.WORD_CONFIDENCE_DATA, {})[prompt_key] = ( + word_confidence_data + ) + return answer + except SdkRateLimitError as e: + raise RateLimitError(f"Rate limit error. 
{str(e)}") from e + except SdkError as e: + logger.error("Error fetching response for prompt: %s", e) + status_code = getattr(e, "status_code", None) or 500 + raise LegacyExecutorError(message=str(e), code=status_code) from e + + @staticmethod + def handle_json( + answer: str, + structured_output: dict[str, Any], + output: dict[str, Any], + llm: Any, + enable_highlight: bool = False, + enable_word_confidence: bool = False, + execution_source: str = "ide", + metadata: dict[str, Any] | None = None, + file_path: str = "", + log_events_id: str = "", + tool_id: str = "", + doc_name: str = "", + ) -> None: + """Handle JSON responses from the LLM.""" + from executor.executors.json_repair_helper import repair_json_with_best_structure + from executor.executors.postprocessor import postprocess_data + + prompt_key = output[PSKeys.NAME] + if answer.lower() == "na": + structured_output[prompt_key] = None + else: + parsed_data = repair_json_with_best_structure(answer) + + if isinstance(parsed_data, str): + logger.error("Error parsing response to JSON") + structured_output[prompt_key] = {} + else: + webhook_enabled = output.get(PSKeys.ENABLE_POSTPROCESSING_WEBHOOK, False) + webhook_url = output.get(PSKeys.POSTPROCESSING_WEBHOOK_URL) + + highlight_data = None + if enable_highlight and metadata and PSKeys.HIGHLIGHT_DATA in metadata: + highlight_data = metadata[PSKeys.HIGHLIGHT_DATA].get(prompt_key) + + processed_data = parsed_data + updated_highlight_data = None + + if webhook_enabled: + if not webhook_url: + logger.warning( + "Postprocessing webhook enabled but URL missing; skipping." + ) + elif not _is_safe_public_url(webhook_url): + logger.warning( + "Postprocessing webhook URL is not allowed; skipping." + ) + else: + try: + processed_data, updated_highlight_data = postprocess_data( + parsed_data, + webhook_enabled=True, + webhook_url=webhook_url, + highlight_data=highlight_data, + timeout=60, + ) + except Exception as e: + logger.warning( + "Postprocessing webhook failed: %s. 
" + "Using unprocessed data.", + e, + ) + + structured_output[prompt_key] = processed_data + + if enable_highlight and metadata and updated_highlight_data is not None: + metadata.setdefault(PSKeys.HIGHLIGHT_DATA, {})[prompt_key] = ( + updated_highlight_data + ) diff --git a/workers/executor/executors/constants.py b/workers/executor/executors/constants.py new file mode 100644 index 0000000000..9eddab8423 --- /dev/null +++ b/workers/executor/executors/constants.py @@ -0,0 +1,203 @@ +from enum import Enum + + +class PromptServiceConstants: + """Constants used in the prompt service.""" + + WORD = "word" + SYNONYMS = "synonyms" + OUTPUTS = "outputs" + TOOL_ID = "tool_id" + RUN_ID = "run_id" + EXECUTION_ID = "execution_id" + FILE_NAME = "file_name" + FILE_HASH = "file_hash" + NAME = "name" + ACTIVE = "active" + PROMPT = "prompt" + CHUNK_SIZE = "chunk-size" + PROMPTX = "promptx" + VECTOR_DB = "vector-db" + EMBEDDING = "embedding" + X2TEXT_ADAPTER = "x2text_adapter" + CHUNK_OVERLAP = "chunk-overlap" + LLM = "llm" + IS_ASSERT = "is_assert" + ASSERTION_FAILURE_PROMPT = "assertion_failure_prompt" + RETRIEVAL_STRATEGY = "retrieval-strategy" + TYPE = "type" + NUMBER = "number" + EMAIL = "email" + DATE = "date" + BOOLEAN = "boolean" + JSON = "json" + PREAMBLE = "preamble" + SIMILARITY_TOP_K = "similarity-top-k" + PROMPT_TOKENS = "prompt_tokens" + COMPLETION_TOKENS = "completion_tokens" + TOTAL_TOKENS = "total_tokens" + RESPONSE = "response" + POSTAMBLE = "postamble" + GRAMMAR = "grammar" + PLATFORM_SERVICE_API_KEY = "PLATFORM_SERVICE_API_KEY" + EMBEDDING_SUFFIX = "embedding_suffix" + EVAL_SETTINGS = "eval_settings" + EVAL_SETTINGS_EVALUATE = "evaluate" + EVAL_SETTINGS_MONITOR_LLM = "monitor_llm" + EVAL_SETTINGS_EXCLUDE_FAILED = "exclude_failed" + TOOL_SETTINGS = "tool_settings" + LOG_EVENTS_ID = "log_events_id" + CHALLENGE_LLM = "challenge_llm" + CHALLENGE = "challenge" + ENABLE_CHALLENGE = "enable_challenge" + EXTRACTION = "extraction" + SUMMARIZE = "summarize" + 
SINGLE_PASS_EXTRACTION = "single-pass-extraction" + SIMPLE_PROMPT_STUDIO = "simple-prompt-studio" + LLM_USAGE_REASON = "llm_usage_reason" + METADATA = "metadata" + OUTPUT = "output" + CONTEXT = "context" + INCLUDE_METADATA = "include_metadata" + TABLE = "table" + TABLE_SETTINGS = "table_settings" + EPILOGUE = "epilogue" + PLATFORM_POSTAMBLE = "platform_postamble" + WORD_CONFIDENCE_POSTAMBLE = "word_confidence_postamble" + HIGHLIGHT_DATA_PLUGIN = "highlight-data" + SUMMARIZE_AS_SOURCE = "summarize_as_source" + VARIABLE_MAP = "variable_map" + RECORD = "record" + CUSTOM_DATA = "custom_data" + TEXT = "text" + ENABLE_HIGHLIGHT = "enable_highlight" + ENABLE_WORD_CONFIDENCE = "enable_word_confidence" + FILE_PATH = "file_path" + HIGHLIGHT_DATA = "highlight_data" + CONFIDENCE_DATA = "confidence_data" + WORD_CONFIDENCE_DATA = "word_confidence_data" + REQUIRED_FIELDS = "required_fields" + REQUIRED = "required" + EXECUTION_SOURCE = "execution_source" + METRICS = "metrics" + CAPTURE_METRICS = "capture_metrics" + LINE_ITEM = "line-item" + LINE_NUMBERS = "line_numbers" + WHISPER_HASH = "whisper_hash" + PAID_FEATURE_MSG = ( + "It is a cloud / enterprise feature. If you have purchased a plan and still " + "face this issue, please contact support" + ) + NO_CONTEXT_ERROR = ( + "Couldn't fetch context from vector DB. " + "This happens usually due to a delay by the Vector DB " + "provider to confirm writes to DB. " + "Please try again after some time" + ) + COMBINED_PROMPT = "combined_prompt" + TOOL = "tool" + JSON_POSTAMBLE = "JSON_POSTAMBLE" + DEFAULT_JSON_POSTAMBLE = "Wrap the final JSON result inbetween §§§ like below example:\n§§§\n\n§§§" + DOCUMENT_TYPE = "document_type" + # Webhook postprocessing settings + ENABLE_POSTPROCESSING_WEBHOOK = "enable_postprocessing_webhook" + POSTPROCESSING_WEBHOOK_URL = "postprocessing_webhook_url" + + +class RunLevel(Enum): + """Different stages of prompt execution. + + Comprises of prompt run and response evaluation stages. 
+ """ + + RUN = "RUN" + EVAL = "EVAL" + CHALLENGE = "CHALLENGE" + TABLE_EXTRACTION = "TABLE_EXTRACTION" + + +class DBTableV2: + """Database tables.""" + + ORGANIZATION = "organization" + ADAPTER_INSTANCE = "adapter_instance" + PROMPT_STUDIO_REGISTRY = "prompt_studio_registry" + PLATFORM_KEY = "platform_key" + TOKEN_USAGE = "usage" + + +class FileStorageKeys: + """File storage keys.""" + + PERMANENT_REMOTE_STORAGE = "PERMANENT_REMOTE_STORAGE" + TEMPORARY_REMOTE_STORAGE = "TEMPORARY_REMOTE_STORAGE" + + +class FileStorageType(Enum): + """File storage type.""" + + PERMANENT = "permanent" + TEMPORARY = "temporary" + + +class ExecutionSource(Enum): + """Execution source.""" + + IDE = "ide" + TOOL = "tool" + + +class VariableType(str, Enum): + """Type of variable.""" + + STATIC = "STATIC" + DYNAMIC = "DYNAMIC" + CUSTOM_DATA = "CUSTOM_DATA" + + +class RetrievalStrategy(str, Enum): + """Available retrieval strategies for prompt service.""" + + SIMPLE = "simple" + SUBQUESTION = "subquestion" + FUSION = "fusion" + RECURSIVE = "recursive" + ROUTER = "router" + KEYWORD_TABLE = "keyword_table" + AUTOMERGING = "automerging" + + +class VariableConstants: + """Constants for variable extraction.""" + + VARIABLE_REGEX = "{{(.+?)}}" + DYNAMIC_VARIABLE_DATA_REGEX = r"\[(.*?)\]" + DYNAMIC_VARIABLE_URL_REGEX = ( + r"(?i)\b((?:https?://|www\d{0,3}[.]|[a-z0-9.\-]+[.][a-z]{2,4}/)(?:[^\s()<>]+|\(([^\s()<>]+|(\([^\s()<>]+\)))*\))+(?:\(([^\s()<>]+|(\([^\s()<>]+\)))*\)|[^\s`!()\[\]{};:'\".,<>?«»" + "'']))" + ) # noqa: E501 + CUSTOM_DATA_VARIABLE_REGEX = r"custom_data\.([a-zA-Z0-9_\.]+)" + + +class IndexingConstants: + TOOL_ID = "tool_id" + EMBEDDING_INSTANCE_ID = "embedding_instance_id" + VECTOR_DB_INSTANCE_ID = "vector_db_instance_id" + X2TEXT_INSTANCE_ID = "x2text_instance_id" + FILE_PATH = "file_path" + CHUNK_SIZE = "chunk_size" + CHUNK_OVERLAP = "chunk_overlap" + REINDEX = "reindex" + FILE_HASH = "file_hash" + OUTPUT_FILE_PATH = "output_file_path" + ENABLE_HIGHLIGHT = "enable_highlight" + 
@dataclass
class ChunkingConfig:
    """Chunk size and overlap used when indexing a document.

    Raises:
        ValueError: On construction when ``chunk_size`` is zero.
    """

    chunk_size: int
    chunk_overlap: int

    def __post_init__(self) -> None:
        if self.chunk_size == 0:
            # The trailing space keeps the two sentences from running
            # together in the rendered message.
            raise ValueError(
                "Indexing cannot be done for zero chunks. "
                "Please provide a valid chunk_size."
            )
class RetrievalError(LegacyExecutorError):
    """Custom exception raised for errors during retrieval from VectorDB.

    When raised without an explicit message, the error text is
    ``DEFAULT_MESSAGE``.
    """

    DEFAULT_MESSAGE = (
        "Error while retrieving data from the VectorDB. "
        "Please contact the admin for further assistance."
    )
    # The base initializer falls back to the class-level ``message`` when no
    # message is passed, so the default text must be wired in here or it is
    # never used.
    message = DEFAULT_MESSAGE


class ExtractionError(LegacyExecutorError):
    """Raised when extracting from a document fails.

    When raised without an explicit message, the error text is
    ``DEFAULT_MESSAGE``.
    """

    DEFAULT_MESSAGE = "Error while extracting from a document"
    # See RetrievalError: the base class reads ``message``, not
    # ``DEFAULT_MESSAGE``.
    message = DEFAULT_MESSAGE
class FileUtils:
    """Helpers for choosing the file storage backend of an execution."""

    @staticmethod
    def get_fs_instance(execution_source: str) -> FileStorage:
        """Return the ``FileStorage`` matching the execution source.

        Args:
            execution_source: The source from which the execution is
                triggered; compared against the ``ExecutionSource`` values.

        Returns:
            FileStorage: Permanent storage for the IDE source, shared
            temporary storage for the tool source.

        Raises:
            ValueError: If the execution source is neither of the two.
        """
        if execution_source == ExecutionSource.IDE.value:
            storage_type = StorageType.PERMANENT
            env_name = FileStorageKeys.PERMANENT_REMOTE_STORAGE
        elif execution_source == ExecutionSource.TOOL.value:
            storage_type = StorageType.SHARED_TEMPORARY
            env_name = FileStorageKeys.TEMPORARY_REMOTE_STORAGE
        else:
            raise ValueError(f"Invalid execution source: {execution_source}")

        return EnvHelper.get_storage(storage_type=storage_type, env_name=env_name)
+Performs document chunking and vector DB indexing. + +Heavy dependencies (``llama_index``, ``openai``, vectordb adapters) +are imported lazily inside methods to avoid protobuf descriptor +conflicts at test-collection time. +""" + +from __future__ import annotations + +import json +import logging +from typing import TYPE_CHECKING, Any + +from executor.executors.dto import ( + ChunkingConfig, + FileInfo, + InstanceIdentifiers, + ProcessingOptions, +) + +from unstract.sdk1.constants import LogLevel +from unstract.sdk1.exceptions import SdkError, parse_litellm_err +from unstract.sdk1.file_storage.impl import FileStorage +from unstract.sdk1.file_storage.provider import FileStorageProvider +from unstract.sdk1.platform import PlatformHelper as ToolAdapter +from unstract.sdk1.tool.stream import StreamMixin +from unstract.sdk1.utils.tool import ToolUtils + +if TYPE_CHECKING: + from unstract.sdk1.embedding import Embedding + from unstract.sdk1.vector_db import VectorDB + +logger = logging.getLogger(__name__) + + +class Index: + def __init__( + self, + tool: StreamMixin, + instance_identifiers: InstanceIdentifiers, + chunking_config: ChunkingConfig, + processing_options: ProcessingOptions, + run_id: str | None = None, + capture_metrics: bool = False, + ): + self.tool = tool + self._run_id = run_id + self._capture_metrics = capture_metrics + self.instance_identifiers = instance_identifiers + self.chunking_config = chunking_config + self.processing_options = processing_options + self._metrics = {} + + def generate_index_key( + self, + file_info: FileInfo, + fs: FileStorage = FileStorage(provider=FileStorageProvider.LOCAL), + ) -> str: + """Generate a unique index key for document indexing.""" + if not file_info.file_path and not file_info.file_hash: + raise ValueError("One of `file_path` or `file_hash` need to be provided") + + file_hash = file_info.file_hash + if not file_hash: + file_hash = fs.get_hash_from_file(path=file_info.file_path) + + index_key = { + "file_hash": 
file_hash, + "vector_db_config": ToolAdapter.get_adapter_config( + self.tool, self.instance_identifiers.vector_db_instance_id + ), + "embedding_config": ToolAdapter.get_adapter_config( + self.tool, self.instance_identifiers.embedding_instance_id + ), + "x2text_config": ToolAdapter.get_adapter_config( + self.tool, self.instance_identifiers.x2text_instance_id + ), + "chunk_size": str(self.chunking_config.chunk_size), + "chunk_overlap": str(self.chunking_config.chunk_overlap), + } + hashed_index_key = ToolUtils.hash_str(json.dumps(index_key, sort_keys=True)) + return hashed_index_key + + def is_document_indexed( + self, + doc_id: str, + embedding: Embedding, + vector_db: VectorDB, + ) -> bool: + """Check if nodes are already present in the vector DB for a doc_id.""" + from llama_index.core.vector_stores import ( + FilterOperator, + MetadataFilter, + MetadataFilters, + VectorStoreQuery, + VectorStoreQueryResult, + ) + + doc_id_eq_filter = MetadataFilter.from_dict( + {"key": "doc_id", "operator": FilterOperator.EQ, "value": doc_id} + ) + filters = MetadataFilters(filters=[doc_id_eq_filter]) + q = VectorStoreQuery( + query_embedding=embedding.get_query_embedding(" "), + doc_ids=[doc_id], + filters=filters, + ) + + doc_id_found = False + try: + n: VectorStoreQueryResult = vector_db.query(query=q) + if len(n.nodes) > 0: + doc_id_found = True + self.tool.stream_log(f"Found {len(n.nodes)} nodes for {doc_id}") + else: + self.tool.stream_log(f"No nodes found for {doc_id}") + except Exception as e: + logger.warning( + f"Error querying {self.instance_identifiers.vector_db_instance_id}:" + f" {str(e)}, proceeding to index", + exc_info=True, + ) + + if doc_id_found and not self.processing_options.reindex: + self.tool.stream_log(f"File was indexed already under {doc_id}") + return doc_id_found + + return doc_id_found + + def perform_indexing( + self, + vector_db: VectorDB, + doc_id: str, + extracted_text: str, + doc_id_found: bool, + ) -> str: + from 
unstract.sdk1.adapters.vectordb.no_op.src.no_op_custom_vectordb import ( + NoOpCustomVectorDB, + ) + + if isinstance( + vector_db.get_vector_db( + adapter_instance_id=self.instance_identifiers.vector_db_instance_id, + embedding_dimension=1, + ), + (NoOpCustomVectorDB), + ): + return doc_id + + self.tool.stream_log("Indexing file...") + full_text = [ + { + "section": "full", + "text_contents": str(extracted_text), + } + ] + documents = self._prepare_documents(doc_id, full_text) + if self.processing_options.reindex and doc_id_found: + self.delete_nodes(vector_db, doc_id) + self._trigger_indexing(vector_db, documents) + return doc_id + + def _trigger_indexing(self, vector_db: Any, documents: list) -> None: + import openai + + self.tool.stream_log("Adding nodes to vector db...") + try: + vector_db.index_document( + documents, + chunk_size=self.chunking_config.chunk_size, + chunk_overlap=self.chunking_config.chunk_overlap, + show_progress=True, + ) + self.tool.stream_log("File has been indexed successfully") + except openai.OpenAIError as e: + e = parse_litellm_err(e) + raise e + except Exception as e: + self.tool.stream_log( + f"Error adding nodes to vector db: {e}", + level=LogLevel.ERROR, + ) + raise e + + def delete_nodes(self, vector_db: Any, doc_id: str) -> None: + try: + vector_db.delete(ref_doc_id=doc_id) + self.tool.stream_log(f"Deleted nodes for {doc_id}") + except Exception as e: + self.tool.stream_log( + f"Error deleting nodes for {doc_id}: {e}", + level=LogLevel.ERROR, + ) + raise SdkError(f"Error deleting nodes for {doc_id}: {e}") from e + + def _prepare_documents(self, doc_id: str, full_text: Any) -> list: + from llama_index.core import Document + + documents = [] + try: + for item in full_text: + text = item["text_contents"] + document = Document( + text=text, + doc_id=doc_id, + metadata={"section": item["section"]}, + ) + document.id_ = doc_id + documents.append(document) + self.tool.stream_log(f"Number of documents: {len(documents)}") + return 
def repair_json_with_best_structure(json_str: str) -> Any:
    """Parse a JSON string, repairing it when strict parsing fails.

    Strict ``json.loads`` is tried first. If that fails and the optional
    ``json_repair`` package is importable, the input is repaired twice
    (as given, and prefixed with ``[``) and the better-structured result is
    chosen. Without ``json_repair`` the raw string is returned unchanged.

    Args:
        json_str: The JSON string to repair

    Returns:
        The parsed JSON object with the best structure
    """
    # Fast path: well-formed JSON needs no repair.
    try:
        return json.loads(json_str)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        pass

    try:
        from json_repair import repair_json

        plain = repair_json(json_str=json_str, return_objects=True, ensure_ascii=False)
        wrapped = repair_json(
            json_str="[" + json_str, return_objects=True, ensure_ascii=False
        )

        # A string result means that attempt produced no structure.
        if isinstance(wrapped, str):
            return plain
        if isinstance(plain, str):
            return wrapped

        wrapped_is_list = isinstance(wrapped, list)

        # The wrap only added a redundant outer list around the same value.
        if wrapped_is_list and len(wrapped) == 1 and wrapped[0] == plain:
            return plain

        if isinstance(plain, (dict, list)):
            # Several top-level items were recovered only by wrapping.
            if wrapped_is_list and len(wrapped) > 1:
                return wrapped
            return plain

        return wrapped
    except ImportError:
        # json_repair not installed: return the raw string
        return json_str
logging.getLogger(__name__)


@ExecutorRegistry.register
class LegacyExecutor(BaseExecutor):
    """Executor that wraps the full prompt-service extraction pipeline.

    Routes incoming ``ExecutionContext`` requests to the appropriate
    handler method based on the ``Operation`` enum. Each handler
    corresponds to one of the original prompt-service HTTP endpoints.
    """

    # Maps Operation enum values to handler method names.
    _OPERATION_MAP: dict[str, str] = {
        Operation.EXTRACT.value: "_handle_extract",
        Operation.INDEX.value: "_handle_index",
        Operation.ANSWER_PROMPT.value: "_handle_answer_prompt",
        Operation.SINGLE_PASS_EXTRACTION.value: "_handle_single_pass_extraction",
        Operation.SUMMARIZE.value: "_handle_summarize",
        Operation.IDE_INDEX.value: "_handle_ide_index",
        Operation.STRUCTURE_PIPELINE.value: "_handle_structure_pipeline",
    }

    # Defaults for log streaming (overridden by execute()).
    # NOTE: the dict below is shared at class level; execute() rebinds the
    # attribute on the instance and nothing here mutates it in place.
    _log_events_id: str = ""
    _log_component: dict[str, str] = {}

    @property
    def name(self) -> str:
        """Return the executor's name, ``"legacy"``."""
        return "legacy"

    def execute(self, context: ExecutionContext) -> ExecutionResult:
        """Route to the handler for ``context.operation``.

        Returns:
            ``ExecutionResult`` on success or for unsupported operations.
            ``LegacyExecutorError`` subclasses are caught and mapped to
            ``ExecutionResult.failure()`` so callers always get a result.

        Raises:
            NotImplementedError: From stub handlers (until 2D–2H).
            Any other exception type raised by a handler propagates
            unchanged; only ``LegacyExecutorError`` is converted.
        """
        # Extract log streaming info (set by tasks.py for IDE sessions).
        # ``or {}`` guards against an explicit None, which would break the
        # ``{**self._log_component, ...}`` merges done by handlers.
        self._log_events_id: str = context.log_events_id or ""
        self._log_component: dict[str, str] = (
            getattr(context, "_log_component", {}) or {}
        )

        handler_name = self._OPERATION_MAP.get(context.operation)
        if handler_name is None:
            return ExecutionResult.failure(
                error=(
                    f"LegacyExecutor does not support operation '{context.operation}'"
                )
            )

        handler = getattr(self, handler_name)
        logger.info(
            "LegacyExecutor routing operation=%s to %s "
            "(run_id=%s request_id=%s execution_source=%s)",
            context.operation,
            handler_name,
            context.run_id,
            context.request_id,
            context.execution_source,
        )
        start = time.monotonic()
        try:
            result = handler(context)
        except LegacyExecutorError as exc:
            elapsed = time.monotonic() - start
            # Fall back to the exception type so the user-facing message and
            # the returned error are never empty.
            error_text = exc.message or type(exc).__name__
            logger.warning(
                "Handler %s failed after %.2fs: %s: %s",
                handler_name,
                elapsed,
                type(exc).__name__,
                exc.message,
                exc_info=True,
            )
            # Stream error to FE so the user sees the failure in real-time
            if self._log_events_id:
                try:
                    shim = ExecutorToolShim(
                        log_events_id=self._log_events_id,
                        component=self._log_component,
                    )
                    shim.stream_log(f"Error: {error_text}", level=LogLevel.ERROR)
                except Exception:
                    # Best-effort — don't mask the original error, but leave
                    # a trace instead of swallowing the failure silently.
                    logger.debug(
                        "Could not stream handler error to log_events_id=%s",
                        self._log_events_id,
                        exc_info=True,
                    )
            return ExecutionResult.failure(error=error_text)
        else:
            elapsed = time.monotonic() - start
            logger.info(
                "Handler %s completed in %.2fs (run_id=%s success=%s)",
                handler_name,
                elapsed,
                context.run_id,
                result.success,
            )
            return result

    # ------------------------------------------------------------------
    # Phase 2B — Extract handler
    # ------------------------------------------------------------------

    def _handle_extract(self, context: ExecutionContext) -> ExecutionResult:
        """Handle ``Operation.EXTRACT`` — text extraction via x2text.

        Migrated from ``ExtractionService.perform_extraction()`` in
        ``prompt-service/.../services/extraction.py``.

        Returns:
            ExecutionResult with ``data`` containing ``extracted_text``
            and, when the extractor returned line metadata,
            ``highlight_metadata``. A failure result is returned when
            the x2text instance id or file path is missing.

        Raises:
            ExtractionError: When the x2text adapter raises
                ``AdapterError``.
        """
        params: dict[str, Any] = context.executor_params

        # Required params
        x2text_instance_id: str = params.get(IKeys.X2TEXT_INSTANCE_ID, "")
        file_path: str = params.get(IKeys.FILE_PATH, "")
        platform_api_key: str = params.get("platform_api_key", "")

        missing = [
            key
            for key, value in (
                (IKeys.X2TEXT_INSTANCE_ID, x2text_instance_id),
                (IKeys.FILE_PATH, file_path),
            )
            if not value
        ]
        if missing:
            return ExecutionResult.failure(
                error=f"Missing required params: {', '.join(missing)}"
            )

        # Optional params
        output_file_path: str | None = params.get(IKeys.OUTPUT_FILE_PATH)
        enable_highlight: bool = params.get(IKeys.ENABLE_HIGHLIGHT, False)
        usage_kwargs: dict[Any, Any] = params.get(IKeys.USAGE_KWARGS, {})
        tags: list[str] | None = params.get(IKeys.TAGS)
        execution_source: str = context.execution_source
        tool_exec_metadata: dict[str, Any] = params.get(
            IKeys.TOOL_EXECUTION_METATADA, {}
        )
        execution_data_dir: str | None = params.get(IKeys.EXECUTION_DATA_DIR)

        # Build adapter shim and X2Text
        shim = ExecutorToolShim(
            platform_api_key=platform_api_key,
            log_events_id=self._log_events_id,
            component=self._log_component,
        )
        x2text = X2Text(
            tool=shim,
            adapter_instance_id=x2text_instance_id,
            usage_kwargs=usage_kwargs,
        )
        fs = FileUtils.get_fs_instance(execution_source=execution_source)
        file_name = Path(file_path).name

        logger.info(
            "Starting text extraction: x2text_adapter=%s file=%s run_id=%s",
            x2text_instance_id,
            file_name,
            context.run_id,
        )
        extractor_type = type(x2text.x2text_instance).__name__
        # Diagnostic instrumentation: kept, but at DEBUG so it does not
        # clutter INFO-level logs.
        logger.debug(
            "HIGHLIGHT_DEBUG _handle_extract: enable_highlight=%s "
            "x2text_type=%s file=%s run_id=%s",
            enable_highlight,
            extractor_type,
            file_name,
            context.run_id,
        )
        shim.stream_log("Initializing text extractor...")
        shim.stream_log(f"Using text extractor: {extractor_type}")

        try:
            shim.stream_log("Extracting text from document...")
            if enable_highlight and isinstance(
                x2text.x2text_instance, (LLMWhisperer, LLMWhispererV2)
            ):
                # Only the LLMWhisperer adapters are asked for highlight data.
                shim.stream_log("Extracting text with highlight support enabled...")
                process_response: TextExtractionResult = x2text.process(
                    input_file_path=file_path,
                    output_file_path=output_file_path,
                    enable_highlight=enable_highlight,
                    tags=tags,
                    fs=fs,
                )
                self._update_exec_metadata(
                    fs=fs,
                    execution_source=execution_source,
                    tool_exec_metadata=tool_exec_metadata,
                    execution_data_dir=execution_data_dir,
                    process_response=process_response,
                )
            else:
                process_response = x2text.process(
                    input_file_path=file_path,
                    output_file_path=output_file_path,
                    tags=tags,
                    fs=fs,
                )

            extraction_metadata = process_response.extraction_metadata
            line_metadata = (
                extraction_metadata.line_metadata if extraction_metadata else None
            )
            logger.debug(
                "HIGHLIGHT_DEBUG extraction result: has_line_metadata=%s "
                "whisper_hash=%s run_id=%s",
                bool(line_metadata),
                getattr(extraction_metadata, "whisper_hash", None)
                if extraction_metadata
                else None,
                context.run_id,
            )
            logger.info(
                "Text extraction completed: file=%s run_id=%s",
                file_name,
                context.run_id,
            )
            shim.stream_log("Text extraction completed")
            result_data: dict[str, Any] = {
                IKeys.EXTRACTED_TEXT: process_response.extracted_text,
            }
            # Include highlight metadata when available
            # (used by agentic extraction for PDF source referencing)
            if line_metadata:
                shim.stream_log("Saving extraction metadata...")
                result_data["highlight_metadata"] = line_metadata
            return ExecutionResult(
                success=True,
                data=result_data,
            )
        except AdapterError as e:
            name = x2text.x2text_instance.get_name()
            logger.error(
                "Text extraction failed: adapter=%s file=%s error=%s",
                name,
                file_name,
                str(e),
            )
            msg = f"Error from text extractor '{name}'. {e}"
            raise ExtractionError(message=msg) from e

    @staticmethod
    def _update_exec_metadata(
        fs: Any,
        execution_source: str,
        tool_exec_metadata: dict[str, Any] | None,
        execution_data_dir: str | None,
        process_response: TextExtractionResult,
    ) -> None:
        """Write whisper_hash metadata for tool-sourced executions.

        Does nothing unless ``execution_source`` is
        ``ExecutionSource.TOOL.value``. Otherwise the whisper hash is
        copied into ``tool_exec_metadata`` (when given) and dumped as
        JSON to ``<execution_data_dir>/<IKeys.METADATA_FILE>``.

        Args:
            fs: Filesystem handle forwarded to ``ToolUtils.dump_json``.
            execution_source: Source identifier of the current run.
            tool_exec_metadata: Dict updated in place, or ``None``.
            execution_data_dir: Directory for the metadata file; when
                missing, the file dump is skipped with a warning.
            process_response: Extraction result carrying
                ``extraction_metadata``.
        """
        if execution_source != ExecutionSource.TOOL.value:
            return

        extraction_metadata = process_response.extraction_metadata
        if extraction_metadata is None:
            # Nothing to record; dereferencing None here used to raise
            # AttributeError and abort an otherwise successful extraction.
            logger.warning(
                "Extraction returned no metadata; skipping whisper_hash update"
            )
            return

        metadata = {X2TextConstants.WHISPER_HASH: extraction_metadata.whisper_hash}
        if tool_exec_metadata is not None:
            tool_exec_metadata.update(metadata)

        if not execution_data_dir:
            # Path(None) raises TypeError, and an empty string would write
            # relative to the current working directory.
            logger.warning(
                "No execution data dir provided; whisper_hash metadata not persisted"
            )
            return

        metadata_path = str(Path(execution_data_dir) / IKeys.METADATA_FILE)
        ToolUtils.dump_json(
            file_to_dump=metadata_path,
            json_to_dump=metadata,
            fs=fs,
        )

    @staticmethod
    def _get_indexing_deps() -> tuple[Any, Any, Any]:
        """Lazy-import heavy indexing dependencies.

        These imports trigger llama_index/qdrant/protobuf loading,
        so they must not happen at module-collection time (tests).
        Wrapped in a method so tests can mock it cleanly.

        Returns:
            The ``(Index, EmbeddingCompat, VectorDB)`` classes.
        """
        from executor.executors.index import Index

        from unstract.sdk1.embedding import EmbeddingCompat
        from unstract.sdk1.vector_db import VectorDB

        return Index, EmbeddingCompat, VectorDB

    # ------------------------------------------------------------------
    # Phase 5C — Compound IDE index handler (extract + index)
    # ------------------------------------------------------------------

    def _handle_ide_index(self, context: ExecutionContext) -> ExecutionResult:
        """Handle ``Operation.IDE_INDEX`` — compound extract then index.

        This compound operation combines ``_handle_extract`` and
        ``_handle_index`` in a single executor invocation, eliminating
        the need for the backend Celery worker to block between steps.

        The ``executor_params`` must contain:
        - ``extract_params``: Parameters for ``_handle_extract``.
        - ``index_params``: Parameters for ``_handle_index``. The
          executor injects ``extracted_text`` from the extract step
          before calling index (on a copy; the caller's dict is left
          untouched).

        Returns:
            ExecutionResult with ``data`` containing ``doc_id`` from
            the index step, or the failing step's result unchanged.
        """
        params = context.executor_params
        extract_params = params.get("extract_params")
        index_params = params.get("index_params")

        missing = [
            key
            for key, value in (
                ("extract_params", extract_params),
                ("index_params", index_params),
            )
            if not value
        ]
        if missing:
            return ExecutionResult.failure(
                error=f"ide_index missing required params: {', '.join(missing)}"
            )

        # Step 1: Extract
        extract_ctx = ExecutionContext(
            executor_name=context.executor_name,
            operation=Operation.EXTRACT.value,
            run_id=context.run_id,
            execution_source=context.execution_source,
            organization_id=context.organization_id,
            executor_params=extract_params,
            request_id=context.request_id,
            log_events_id=context.log_events_id,
        )
        extract_result = self._handle_extract(extract_ctx)
        if not extract_result.success:
            return extract_result

        # Step 2: Index — inject extracted text without mutating the
        # caller-supplied ``index_params`` dict.
        extracted_text = extract_result.data.get(IKeys.EXTRACTED_TEXT, "")
        index_params = {**index_params, IKeys.EXTRACTED_TEXT: extracted_text}

        index_ctx = ExecutionContext(
            executor_name=context.executor_name,
            operation=Operation.INDEX.value,
            run_id=context.run_id,
            execution_source=context.execution_source,
            organization_id=context.organization_id,
            executor_params=index_params,
            request_id=context.request_id,
            log_events_id=context.log_events_id,
        )
        index_result = self._handle_index(index_ctx)
        if not index_result.success:
            return index_result

        return ExecutionResult(
            success=True,
            data={
                IKeys.DOC_ID: index_result.data.get(IKeys.DOC_ID, ""),
            },
        )

    # ------------------------------------------------------------------
    # Phase 5D — Compound structure pipeline handler
    #
------------------------------------------------------------------ + + def _handle_structure_pipeline(self, context: ExecutionContext) -> ExecutionResult: + """Handle ``Operation.STRUCTURE_PIPELINE``. + + Runs the full structure-tool pipeline in a single executor + invocation: extract → summarize → index → answer_prompt. + + This eliminates three sequential ``dispatcher.dispatch()`` calls + that would otherwise block a file_processing worker slot. + + Expected ``executor_params`` keys: + + ``extract_params`` + Parameters for ``_handle_extract``. + ``index_template`` + Common indexing params (``tool_id``, ``file_hash``, + ``is_highlight_enabled``, ``platform_api_key``, + ``extracted_file_path``). + ``answer_params`` + Full payload for ``_handle_answer_prompt`` / + ``_handle_single_pass_extraction``. + ``pipeline_options`` + Control flags: ``skip_extraction_and_indexing``, + ``is_summarization_enabled``, ``is_single_pass_enabled``, + ``input_file_path``, ``source_file_name``. + ``summarize_params`` + (Optional) Parameters for ``_handle_summarize`` plus + filesystem paths for caching. + + Returns: + ExecutionResult with ``data`` containing the structured + output dict (``output``, ``metadata``, ``metrics``). 
+ """ + params = context.executor_params + extract_params = params.get("extract_params", {}) + index_template = params.get("index_template", {}) + answer_params = params.get("answer_params", {}) + pipeline_options = params.get("pipeline_options", {}) + summarize_params = params.get("summarize_params") + + skip_extraction = pipeline_options.get("skip_extraction_and_indexing", False) + is_summarization = pipeline_options.get("is_summarization_enabled", False) + is_single_pass = pipeline_options.get("is_single_pass_enabled", False) + input_file_path = pipeline_options.get("input_file_path", "") + source_file_name = pipeline_options.get("source_file_name", "") + + extracted_text = "" + index_metrics: dict = {} + + shim = ExecutorToolShim( + platform_api_key=extract_params.get("platform_api_key", ""), + log_events_id=self._log_events_id, + component=self._log_component, + ) + step = 1 + + # ---- Step 1: Extract ---- + if not skip_extraction: + shim.stream_log(f"Pipeline step {step}: Extracting text from document...") + step += 1 + extract_ctx = ExecutionContext( + executor_name=context.executor_name, + operation=Operation.EXTRACT.value, + run_id=context.run_id, + execution_source=context.execution_source, + organization_id=context.organization_id, + executor_params=extract_params, + request_id=context.request_id, + log_events_id=context.log_events_id, + ) + extract_result = self._handle_extract(extract_ctx) + if not extract_result.success: + return extract_result + extracted_text = extract_result.data.get(IKeys.EXTRACTED_TEXT, "") + + # ---- Step 2: Summarize (if enabled) ---- + if is_summarization: + shim.stream_log(f"Pipeline step {step}: Summarizing extracted text...") + step += 1 + summarize_result = self._run_pipeline_summarize( + context=context, + summarize_params=summarize_params or {}, + answer_params=answer_params, + ) + if not summarize_result.success: + return summarize_result + # answer_params file_path/hash updated in-place by helper + elif 
skip_extraction: + # Smart table: use original source file + answer_params["file_path"] = input_file_path + elif not is_single_pass: + # ---- Step 3: Index per output with dedup ---- + shim.stream_log( + f"Pipeline step {step}: Indexing document into vector store..." + ) + step += 1 + index_metrics = self._run_pipeline_index( + context=context, + index_template=index_template, + answer_params=answer_params, + extracted_text=extracted_text, + ) + + # ---- Step 4: Table settings injection ---- + if not is_single_pass: + outputs = answer_params.get("outputs", []) + extracted_file_path = index_template.get("extracted_file_path", "") + for output in outputs: + if "table_settings" in output: + table_settings = output["table_settings"] + is_dir = table_settings.get("is_directory_mode", False) + if skip_extraction: + table_settings["input_file"] = input_file_path + answer_params["file_path"] = input_file_path + else: + table_settings["input_file"] = extracted_file_path + table_settings["is_directory_mode"] = is_dir + output["table_settings"] = table_settings + + # ---- Step 5: Answer prompt / Single pass ---- + mode_label = "single pass" if is_single_pass else "prompt" + shim.stream_log(f"Pipeline step {step}: Running {mode_label} execution...") + operation = ( + Operation.SINGLE_PASS_EXTRACTION.value + if is_single_pass + else Operation.ANSWER_PROMPT.value + ) + answer_ctx = ExecutionContext( + executor_name=context.executor_name, + operation=operation, + run_id=context.run_id, + execution_source=context.execution_source, + organization_id=context.organization_id, + executor_params=answer_params, + request_id=context.request_id, + log_events_id=context.log_events_id, + ) + answer_result = self._handle_answer_prompt(answer_ctx) + if not answer_result.success: + return answer_result + + # ---- Step 6: Merge results ---- + structured_output = answer_result.data + + # Ensure metadata section + if "metadata" not in structured_output: + structured_output["metadata"] = {} + 
structured_output["metadata"]["file_name"] = source_file_name + + # Add extracted text for HITL raw view + if extracted_text: + structured_output["metadata"]["extracted_text"] = extracted_text + + # Merge index metrics + if index_metrics: + existing_metrics = structured_output.get("metrics", {}) + merged = self._merge_pipeline_metrics(existing_metrics, index_metrics) + structured_output["metrics"] = merged + + shim.stream_log("Pipeline completed successfully") + return ExecutionResult(success=True, data=structured_output) + + def _run_pipeline_summarize( + self, + context: ExecutionContext, + summarize_params: dict, + answer_params: dict, + ) -> ExecutionResult: + """Run the summarize step of the structure pipeline. + + Handles filesystem caching: if a cached summary exists, uses it. + Otherwise calls ``_handle_summarize`` and writes the result. + Updates ``answer_params`` in-place with new file_path and + file_hash. + """ + extract_file_path = summarize_params.get("extract_file_path", "") + summarize_file_path = summarize_params.get("summarize_file_path", "") + platform_api_key = summarize_params.get("platform_api_key", "") + llm_adapter_id = summarize_params.get("llm_adapter_instance_id", "") + summarize_prompt = summarize_params.get("summarize_prompt", "") + prompt_keys = summarize_params.get("prompt_keys", []) + outputs = answer_params.get("outputs", []) + + fs = FileUtils.get_fs_instance(execution_source=context.execution_source) + + # Set chunk_size=0 for all outputs when summarizing + embedding = answer_params.get("tool_settings", {}).get("embedding", "") + vector_db = answer_params.get("tool_settings", {}).get("vector-db", "") + x2text = answer_params.get("tool_settings", {}).get("x2text_adapter", "") + for output in outputs: + output["embedding"] = embedding + output["vector-db"] = vector_db + output["x2text_adapter"] = x2text + output["chunk-size"] = 0 + output["chunk-overlap"] = 0 + + # Check cache + summarized_context = "" + if 
fs.exists(summarize_file_path): + summarized_context = fs.read(path=summarize_file_path, mode="r") + + if not summarized_context: + # Read extracted text + doc_context = fs.read(path=extract_file_path, mode="r") + if not doc_context: + return ExecutionResult.failure( + error="No extracted text found for summarization" + ) + + summarize_ctx = ExecutionContext( + executor_name=context.executor_name, + operation=Operation.SUMMARIZE.value, + run_id=context.run_id, + execution_source=context.execution_source, + organization_id=context.organization_id, + request_id=context.request_id, + log_events_id=context.log_events_id, + executor_params={ + "llm_adapter_instance_id": llm_adapter_id, + "summarize_prompt": summarize_prompt, + "context": doc_context, + "prompt_keys": prompt_keys, + "PLATFORM_SERVICE_API_KEY": platform_api_key, + }, + ) + summarize_result = self._handle_summarize(summarize_ctx) + if not summarize_result.success: + return summarize_result + + summarized_context = summarize_result.data.get("data", "") + fs.write( + path=summarize_file_path, + mode="w", + data=summarized_context, + ) + + # Update answer_params + summarize_file_hash = fs.get_hash_from_file(path=summarize_file_path) + answer_params["file_hash"] = summarize_file_hash + answer_params["file_path"] = str(summarize_file_path) + + return ExecutionResult(success=True, data={}) + + def _run_pipeline_index( + self, + context: ExecutionContext, + index_template: dict, + answer_params: dict, + extracted_text: str, + ) -> dict: + """Run per-output indexing with dedup for the structure pipeline. + + Returns: + Dict of index metrics keyed by output name. 
+ """ + import datetime + + tool_settings = answer_params.get("tool_settings", {}) + outputs = answer_params.get("outputs", []) + tool_id = index_template.get("tool_id", "") + file_hash = index_template.get("file_hash", "") + is_highlight = index_template.get("is_highlight_enabled", False) + platform_api_key = index_template.get("platform_api_key", "") + extracted_file_path = index_template.get("extracted_file_path", "") + + index_metrics: dict = {} + seen_params: set = set() + + for output in outputs: + chunk_size = output.get("chunk-size", 0) + chunk_overlap = output.get("chunk-overlap", 0) + vector_db = tool_settings.get("vector-db", "") + embedding = tool_settings.get("embedding", "") + x2text = tool_settings.get("x2text_adapter", "") + + param_key = ( + f"chunk_size={chunk_size}_" + f"chunk_overlap={chunk_overlap}_" + f"vector_db={vector_db}_" + f"embedding={embedding}_" + f"x2text={x2text}" + ) + + if chunk_size != 0 and param_key not in seen_params: + seen_params.add(param_key) + + indexing_start = datetime.datetime.now() + logger.info( + "Pipeline indexing: chunk_size=%s " "chunk_overlap=%s vector_db=%s", + chunk_size, + chunk_overlap, + vector_db, + ) + + index_ctx = ExecutionContext( + executor_name=context.executor_name, + operation=Operation.INDEX.value, + run_id=context.run_id, + execution_source=context.execution_source, + organization_id=context.organization_id, + request_id=context.request_id, + log_events_id=context.log_events_id, + executor_params={ + "embedding_instance_id": embedding, + "vector_db_instance_id": vector_db, + "x2text_instance_id": x2text, + "chunk_size": chunk_size, + "chunk_overlap": chunk_overlap, + "file_path": extracted_file_path, + "reindex": True, + "tool_id": tool_id, + "file_hash": file_hash, + "enable_highlight": is_highlight, + "extracted_text": extracted_text, + "platform_api_key": platform_api_key, + }, + ) + index_result = self._handle_index(index_ctx) + if not index_result.success: + logger.warning( + "Pipeline 
indexing failed for %s: %s", + param_key, + index_result.error, + ) + + elapsed = (datetime.datetime.now() - indexing_start).total_seconds() + output_name = output.get("name", "") + index_metrics[output_name] = {"indexing": {"time_taken(s)": elapsed}} + + return index_metrics + + @staticmethod + def _merge_pipeline_metrics(metrics1: dict, metrics2: dict) -> dict: + """Merge two metrics dicts, combining sub-dicts for shared keys.""" + merged: dict = {} + all_keys = set(metrics1) | set(metrics2) + for key in all_keys: + if ( + key in metrics1 + and key in metrics2 + and isinstance(metrics1[key], dict) + and isinstance(metrics2[key], dict) + ): + merged[key] = {**metrics1[key], **metrics2[key]} + elif key in metrics1: + merged[key] = metrics1[key] + else: + merged[key] = metrics2[key] + return merged + + # ------------------------------------------------------------------ + # Phase 2C — Index handler + # ------------------------------------------------------------------ + + def _handle_index(self, context: ExecutionContext) -> ExecutionResult: + """Handle ``Operation.INDEX`` — vector DB indexing. + + Migrated from ``IndexingService.index()`` in + ``prompt-service/.../services/indexing.py``. + + Returns: + ExecutionResult with ``data`` containing ``doc_id``. 
+ """ + params: dict[str, Any] = context.executor_params + + # Required params + embedding_instance_id: str = params.get(IKeys.EMBEDDING_INSTANCE_ID, "") + vector_db_instance_id: str = params.get(IKeys.VECTOR_DB_INSTANCE_ID, "") + x2text_instance_id: str = params.get(IKeys.X2TEXT_INSTANCE_ID, "") + file_path: str = params.get(IKeys.FILE_PATH, "") + extracted_text: str = params.get(IKeys.EXTRACTED_TEXT, "") + platform_api_key: str = params.get("platform_api_key", "") + + missing = [] + if not embedding_instance_id: + missing.append(IKeys.EMBEDDING_INSTANCE_ID) + if not vector_db_instance_id: + missing.append(IKeys.VECTOR_DB_INSTANCE_ID) + if not x2text_instance_id: + missing.append(IKeys.X2TEXT_INSTANCE_ID) + if not file_path: + missing.append(IKeys.FILE_PATH) + if missing: + return ExecutionResult.failure( + error=f"Missing required params: {', '.join(missing)}" + ) + + # Optional params + tool_id: str = params.get(IKeys.TOOL_ID, "") + file_hash: str | None = params.get(IKeys.FILE_HASH) + chunk_size: int = params.get(IKeys.CHUNK_SIZE, 512) + chunk_overlap: int = params.get(IKeys.CHUNK_OVERLAP, 128) + reindex: bool = params.get(IKeys.REINDEX, False) + enable_highlight: bool = params.get(IKeys.ENABLE_HIGHLIGHT, False) + enable_word_confidence: bool = params.get(IKeys.ENABLE_WORD_CONFIDENCE, False) + usage_kwargs: dict[Any, Any] = params.get(IKeys.USAGE_KWARGS, {}) + tags: list[str] | None = params.get(IKeys.TAGS) + execution_source: str = context.execution_source + + instance_ids = InstanceIdentifiers( + embedding_instance_id=embedding_instance_id, + vector_db_instance_id=vector_db_instance_id, + x2text_instance_id=x2text_instance_id, + tool_id=tool_id, + tags=tags, + llm_instance_id=None, + ) + file_info = FileInfo(file_path=file_path, file_hash=file_hash) + processing_options = ProcessingOptions( + reindex=reindex, + enable_highlight=enable_highlight, + enable_word_confidence=enable_word_confidence, + usage_kwargs=usage_kwargs, + ) + + shim = ExecutorToolShim( + 
platform_api_key=platform_api_key, + log_events_id=self._log_events_id, + component=self._log_component, + ) + fs_instance = FileUtils.get_fs_instance(execution_source=execution_source) + + logger.info( + "Starting indexing: chunk_size=%d chunk_overlap=%d " + "reindex=%s file=%s run_id=%s", + chunk_size, + chunk_overlap, + reindex, + Path(file_path).name, + context.run_id, + ) + shim.stream_log("Initializing indexing pipeline...") + + # Skip indexing when chunk_size is 0 — no vector operations needed. + # ChunkingConfig raises ValueError for 0, so handle before DTO. + if chunk_size == 0: + from unstract.sdk1.utils.indexing import IndexingUtils + + doc_id = IndexingUtils.generate_index_key( + vector_db=vector_db_instance_id, + embedding=embedding_instance_id, + x2text=x2text_instance_id, + chunk_size=str(chunk_size), + chunk_overlap=str(chunk_overlap), + tool=shim, + file_path=file_path, + file_hash=file_hash, + fs=fs_instance, + ) + logger.info("Skipping indexing for chunk_size=0. Doc ID: %s", doc_id) + return ExecutionResult(success=True, data={IKeys.DOC_ID: doc_id}) + + chunking_config = ChunkingConfig( + chunk_size=chunk_size, chunk_overlap=chunk_overlap + ) + shim.stream_log( + f"Configured chunking: size={chunk_size}, overlap={chunk_overlap}" + ) + + Index, EmbeddingCompat, VectorDB = self._get_indexing_deps() + + vector_db = None + try: + index = Index( + tool=shim, + run_id=context.run_id, + capture_metrics=True, + instance_identifiers=instance_ids, + chunking_config=chunking_config, + processing_options=processing_options, + ) + doc_id = index.generate_index_key(file_info=file_info, fs=fs_instance) + logger.debug("Generated index key: doc_id=%s", doc_id) + shim.stream_log("Checking document index status...") + + embedding = EmbeddingCompat( + adapter_instance_id=embedding_instance_id, + tool=shim, + kwargs={**usage_kwargs}, + ) + vector_db = VectorDB( + tool=shim, + adapter_instance_id=vector_db_instance_id, + embedding=embedding, + ) + 
shim.stream_log("Initialized embedding and vector DB adapters") + + doc_id_found = index.is_document_indexed( + doc_id=doc_id, embedding=embedding, vector_db=vector_db + ) + logger.info( + "Index status: doc_id=%s found=%s reindex=%s", + doc_id, + doc_id_found, + reindex, + ) + if doc_id_found and reindex: + shim.stream_log("Document already indexed, re-indexing...") + elif not doc_id_found: + shim.stream_log("Indexing document for the first time...") + shim.stream_log("Indexing document into vector store...") + index.perform_indexing( + vector_db=vector_db, + doc_id=doc_id, + extracted_text=extracted_text, + doc_id_found=doc_id_found, + ) + logger.info( + "Indexing completed: doc_id=%s file=%s", + doc_id, + Path(file_path).name, + ) + shim.stream_log("Document indexing completed") + return ExecutionResult(success=True, data={IKeys.DOC_ID: doc_id}) + except Exception as e: + logger.error( + "Indexing failed: file=%s error=%s", + Path(file_path).name, + str(e), + ) + status_code = getattr(e, "status_code", 500) + raise LegacyExecutorError( + message=f"Error while indexing: {e}", code=status_code + ) from e + finally: + if vector_db is not None: + vector_db.close() + + @staticmethod + def _get_prompt_deps(): + """Lazy-import heavy dependencies for answer_prompt processing. + + These imports trigger llama_index/protobuf loading so they must + not happen at module-collection time (tests). 
+ """ + from executor.executors.answer_prompt import AnswerPromptService + from executor.executors.index import Index + from executor.executors.retrieval import RetrievalService + from executor.executors.variable_replacement import ( + VariableReplacementService, + ) + + from unstract.sdk1.embedding import EmbeddingCompat + from unstract.sdk1.llm import LLM + from unstract.sdk1.vector_db import VectorDB + + return ( + AnswerPromptService, + RetrievalService, + VariableReplacementService, + Index, + LLM, + EmbeddingCompat, + VectorDB, + ) + + @staticmethod + def _sanitize_null_values( + structured_output: dict[str, Any], + ) -> dict[str, Any]: + """Replace 'NA' strings with None in structured output.""" + for k, v in structured_output.items(): + if isinstance(v, str) and v.lower() == "na": + structured_output[k] = None + elif isinstance(v, list): + for i in range(len(v)): + if isinstance(v[i], str) and v[i].lower() == "na": + v[i] = None + elif isinstance(v[i], dict): + for k1, v1 in v[i].items(): + if isinstance(v1, str) and v1.lower() == "na": + v[i][k1] = None + elif isinstance(v, dict): + for k1, v1 in v.items(): + if isinstance(v1, str) and v1.lower() == "na": + v[k1] = None + return structured_output + + def _handle_answer_prompt(self, context: ExecutionContext) -> ExecutionResult: + """Handle ``Operation.ANSWER_PROMPT`` — multi-prompt extraction. + + Migrated from ``prompt_processor()`` in the prompt-service + ``answer_prompt`` controller. Processes all prompts in the + payload: variable replacement, context retrieval, LLM + completion, and type-specific post-processing. 
+ + Returns: + ExecutionResult with ``data`` containing:: + + {"output": dict, "metadata": dict, "metrics": dict} + """ + from executor.executors.constants import ( + PromptServiceConstants as PSKeys, + ) + from executor.executors.constants import ( + RetrievalStrategy, + ) + + params: dict[str, Any] = context.executor_params + + # ---- Unpack payload ------------------------------------------------ + tool_settings = params.get(PSKeys.TOOL_SETTINGS, {}) + prompts = params.get(PSKeys.OUTPUTS, []) + tool_id: str = params.get(PSKeys.TOOL_ID, "") + run_id: str = context.run_id + execution_id: str = params.get(PSKeys.EXECUTION_ID, "") + file_hash = params.get(PSKeys.FILE_HASH) + file_path = params.get(PSKeys.FILE_PATH) + doc_name = str(params.get(PSKeys.FILE_NAME, "")) + log_events_id: str = params.get(PSKeys.LOG_EVENTS_ID, "") + custom_data: dict[str, Any] = params.get(PSKeys.CUSTOM_DATA, {}) + execution_source = params.get(PSKeys.EXECUTION_SOURCE, context.execution_source) + platform_api_key: str = params.get(PSKeys.PLATFORM_SERVICE_API_KEY, "") + + structured_output: dict[str, Any] = {} + metadata: dict[str, Any] = { + PSKeys.RUN_ID: run_id, + PSKeys.FILE_NAME: doc_name, + PSKeys.CONTEXT: {}, + PSKeys.REQUIRED_FIELDS: {}, + } + metrics: dict[str, Any] = {} + variable_names: list[str] = [] + context_retrieval_metrics: dict[str, Any] = {} + + logger.info( + "Starting answer_prompt: tool_id=%s prompt_count=%d " "file=%s run_id=%s", + tool_id, + len(prompts), + doc_name, + run_id, + ) + + # Lazy imports + ( + AnswerPromptService, + RetrievalService, + VariableReplacementService, + _Index, # unused — doc_id via IndexingUtils + LLM, + EmbeddingCompat, + VectorDB, + ) = self._get_prompt_deps() + + # ---- Initialize highlight plugin (if enabled + installed) ---------- + process_text_fn = None + enable_highlight = tool_settings.get(PSKeys.ENABLE_HIGHLIGHT, False) + enable_word_confidence = tool_settings.get(PSKeys.ENABLE_WORD_CONFIDENCE, False) + pipeline_shim = 
ExecutorToolShim( + platform_api_key=platform_api_key, + log_events_id=self._log_events_id, + component=self._log_component, + ) + if enable_highlight: + from executor.executors.plugins import ExecutorPluginLoader + + highlight_cls = ExecutorPluginLoader.get("highlight-data") + if highlight_cls: + from executor.executors.file_utils import FileUtils + + fs_instance = FileUtils.get_fs_instance(execution_source=execution_source) + highlight_instance = highlight_cls( + file_path=file_path, + fs_instance=fs_instance, + enable_word_confidence=enable_word_confidence, + ) + process_text_fn = highlight_instance.run + logger.info( + "Highlight plugin initialized for file=%s", + doc_name, + ) + pipeline_shim.stream_log("Highlight data plugin ready") + else: + logger.warning( + "Highlight is enabled but highlight-data plugin is not " + "installed. Coordinates will not be produced. Install " + "the plugin via: pip install -e " + ) + pipeline_shim.stream_log("Highlight data plugin not available") + + # ---- Merge tool_settings as defaults into each prompt output -------- + # Single-pass payloads carry adapter IDs and chunk config in + # tool_settings only (not per-prompt), while answer_prompt payloads + # carry them per-prompt. Merging tool_settings as a base ensures + # both paths work. 
+ _ts_defaults = { + k: v + for k, v in tool_settings.items() + if k + in { + PSKeys.CHUNK_SIZE, + PSKeys.CHUNK_OVERLAP, + PSKeys.LLM, + PSKeys.VECTOR_DB, + PSKeys.EMBEDDING, + PSKeys.X2TEXT_ADAPTER, + PSKeys.RETRIEVAL_STRATEGY, + PSKeys.SIMILARITY_TOP_K, + } + } + if _ts_defaults: + prompts = [{**_ts_defaults, **p} for p in prompts] + + # ---- First pass: collect variable names + required fields ---------- + for output in prompts: + variable_names.append(output[PSKeys.NAME]) + metadata[PSKeys.REQUIRED_FIELDS][output[PSKeys.NAME]] = output.get( + PSKeys.REQUIRED, None + ) + + # ---- Process each prompt ------------------------------------------- + for output in prompts: + prompt_name = output[PSKeys.NAME] + prompt_text = output[PSKeys.PROMPT] + chunk_size = output[PSKeys.CHUNK_SIZE] + + logger.debug( + "Prompt config: name=%s chunk_size=%d type=%s", + prompt_name, + chunk_size, + output.get(PSKeys.TYPE, "TEXT"), + ) + + # Enrich component with current prompt_key for log correlation. + prompt_component = { + **self._log_component, + "prompt_key": prompt_name, + } + shim = ExecutorToolShim( + platform_api_key=platform_api_key, + log_events_id=self._log_events_id, + component=prompt_component, + ) + shim.stream_log(f"Processing prompt: {prompt_name}") + + # {{variable}} template replacement + if VariableReplacementService.is_variables_present(prompt_text=prompt_text): + is_ide = execution_source == "ide" + prompt_text = VariableReplacementService.replace_variables_in_prompt( + prompt=output, + structured_output=structured_output, + log_events_id=log_events_id, + tool_id=tool_id, + prompt_name=prompt_name, + doc_name=doc_name, + custom_data=custom_data, + is_ide=is_ide, + ) + shim.stream_log(f"Resolved template variables for: {prompt_name}") + + logger.info( + "Executing prompt: tool_id=%s name=%s run_id=%s", + tool_id, + prompt_name, + run_id, + ) + + # %variable% replacement + output[PSKeys.PROMPTX] = AnswerPromptService.extract_variable( + structured_output, 
variable_names, output, prompt_text + ) + + # Generate doc_id (standalone util — no Index DTOs needed) + from unstract.sdk1.utils.indexing import IndexingUtils + + doc_id = IndexingUtils.generate_index_key( + vector_db=output[PSKeys.VECTOR_DB], + embedding=output[PSKeys.EMBEDDING], + x2text=output[PSKeys.X2TEXT_ADAPTER], + chunk_size=str(output[PSKeys.CHUNK_SIZE]), + chunk_overlap=str(output[PSKeys.CHUNK_OVERLAP]), + tool=shim, + file_hash=file_hash, + file_path=file_path, + ) + + # TABLE/RECORD: delegate to TableExtractorExecutor in-process. + # The table executor plugin handles PDF table detection, + # header extraction, and CSV-to-JSON post-processing. + if output.get(PSKeys.TYPE) in (PSKeys.TABLE, PSKeys.RECORD): + from unstract.sdk1.execution.registry import ExecutorRegistry + + try: + table_executor = ExecutorRegistry.get("table") + except KeyError: + raise LegacyExecutorError( + message=( + "TABLE extraction requires the table executor " + "plugin. Install the table_extractor plugin." 
+ ) + ) + + table_ctx = ExecutionContext( + executor_name="table", + operation="table_extract", + run_id=run_id, + execution_source=execution_source, + organization_id=context.organization_id, + request_id=context.request_id, + executor_params={ + "llm_adapter_instance_id": output.get(PSKeys.LLM, ""), + "table_settings": output.get(PSKeys.TABLE_SETTINGS, {}), + "prompt": output.get(PSKeys.PROMPT, ""), + "PLATFORM_SERVICE_API_KEY": platform_api_key, + "execution_id": execution_id, + "tool_id": tool_id, + "file_name": doc_name, + }, + ) + table_ctx._log_component = self._log_component + table_ctx.log_events_id = self._log_events_id + + shim.stream_log(f"Running table extraction for: {prompt_name}") + table_result = table_executor.execute(table_ctx) + + if table_result.success: + structured_output[prompt_name] = table_result.data.get("output", "") + table_metrics = table_result.data.get("metadata", {}).get( + "metrics", {} + ) + metrics.setdefault(prompt_name, {}).update( + {"table_extraction": table_metrics} + ) + shim.stream_log(f"Table extraction completed for: {prompt_name}") + logger.info("TABLE extraction completed: prompt=%s", prompt_name) + else: + structured_output[prompt_name] = "" + logger.error( + "TABLE extraction failed for prompt=%s: %s", + prompt_name, + table_result.error, + ) + shim.stream_log(f"Completed prompt: {prompt_name}") + continue + + if output.get(PSKeys.TYPE) == PSKeys.LINE_ITEM: + raise LegacyExecutorError( + message="LINE_ITEM extraction is not supported." 
+ ) + + # Create adapters + try: + usage_kwargs = { + "run_id": run_id, + "execution_id": execution_id, + } + llm = LLM( + adapter_instance_id=output[PSKeys.LLM], + tool=shim, + usage_kwargs={ + **usage_kwargs, + PSKeys.LLM_USAGE_REASON: PSKeys.EXTRACTION, + }, + capture_metrics=True, + ) + embedding = None + vector_db = None + if chunk_size > 0: + embedding = EmbeddingCompat( + adapter_instance_id=output[PSKeys.EMBEDDING], + tool=shim, + kwargs={**usage_kwargs}, + ) + vector_db = VectorDB( + tool=shim, + adapter_instance_id=output[PSKeys.VECTOR_DB], + embedding=embedding, + ) + shim.stream_log( + f"Initialized LLM and retrieval adapters for: {prompt_name}" + ) + except Exception as e: + msg = f"Couldn't fetch adapter. {e}" + logger.error(msg) + status_code = getattr(e, "status_code", None) or 500 + raise LegacyExecutorError(message=msg, code=status_code) from e + + # ---- Retrieval + Answer ---------------------------------------- + context_list: list[str] = [] + try: + answer = "NA" + retrieval_strategy = output.get(PSKeys.RETRIEVAL_STRATEGY) + valid_strategies = {s.value for s in RetrievalStrategy} + + if retrieval_strategy in valid_strategies: + shim.stream_log(f"Retrieving context for: {prompt_name}") + logger.info( + "Performing retrieval: prompt=%s strategy=%s " "chunk_size=%d", + prompt_name, + retrieval_strategy, + chunk_size, + ) + if chunk_size == 0: + context_list = RetrievalService.retrieve_complete_context( + execution_source=execution_source, + file_path=file_path, + context_retrieval_metrics=context_retrieval_metrics, + prompt_key=prompt_name, + ) + else: + context_list = RetrievalService.run_retrieval( + output=output, + doc_id=doc_id, + llm=llm, + vector_db=vector_db, + retrieval_type=retrieval_strategy, + context_retrieval_metrics=context_retrieval_metrics, + ) + metadata[PSKeys.CONTEXT][prompt_name] = context_list + shim.stream_log( + f"Retrieved {len(context_list)} context chunks" + f" for: {prompt_name}" + ) + logger.debug( + "Retrieved %d 
context chunks for prompt: %s", + len(context_list), + prompt_name, + ) + + # Run prompt with retrieved context + shim.stream_log(f"Running LLM completion for: {prompt_name}") + answer = AnswerPromptService.construct_and_run_prompt( + tool_settings=tool_settings, + output=output, + llm=llm, + context="\n".join(context_list), + prompt=PSKeys.PROMPTX, + metadata=metadata, + execution_source=execution_source, + file_path=file_path, + process_text=process_text_fn, + ) + else: + logger.warning( + "Skipping retrieval: invalid strategy=%s " "for prompt=%s", + retrieval_strategy, + prompt_name, + ) + + # ---- Type-specific post-processing ------------------------- + self._apply_type_conversion( + output=output, + answer=answer, + structured_output=structured_output, + llm=llm, + tool_settings=tool_settings, + metadata=metadata, + execution_source=execution_source, + file_path=file_path, + log_events_id=log_events_id, + tool_id=tool_id, + doc_name=doc_name, + ) + shim.stream_log(f"Applied type conversion for: {prompt_name}") + + # ---- Challenge (quality verification) ---------------------- + if tool_settings.get(PSKeys.ENABLE_CHALLENGE): + from executor.executors.plugins import ( + ExecutorPluginLoader, + ) + + challenge_cls = ExecutorPluginLoader.get("challenge") + if challenge_cls: + challenge_llm_id = tool_settings.get(PSKeys.CHALLENGE_LLM) + if challenge_llm_id: + shim.stream_log(f"Running challenge for: {prompt_name}") + challenge_llm = LLM( + adapter_instance_id=challenge_llm_id, + tool=shim, + usage_kwargs={ + **usage_kwargs, + PSKeys.LLM_USAGE_REASON: PSKeys.CHALLENGE, + }, + capture_metrics=True, + ) + challenger = challenge_cls( + llm=llm, + challenge_llm=challenge_llm, + context="\n".join(context_list), + tool_settings=tool_settings, + output=output, + structured_output=structured_output, + run_id=run_id, + platform_key=platform_api_key, + metadata=metadata, + ) + challenger.run() + shim.stream_log( + f"Challenge verification completed" f" for: {prompt_name}" + 
@staticmethod
def _apply_type_conversion(
    output: dict[str, Any],
    answer: str,
    structured_output: dict[str, Any],
    llm: Any,
    tool_settings: dict[str, Any],
    metadata: dict[str, Any],
    execution_source: str,
    file_path: str,
    log_events_id: str = "",
    tool_id: str = "",
    doc_name: str = "",
) -> None:
    """Convert the raw LLM answer to the prompt's type and store the result.

    The converted value is written to ``structured_output[<prompt name>]``;
    nothing is returned.

    * NUMBER, EMAIL, DATE, BOOLEAN: an answer equal to "NA" (any case) is
      stored as ``None``. Otherwise a follow-up LLM completion normalises
      the answer. NUMBER is parsed with ``float`` (``None`` when parsing
      fails), BOOLEAN is ``True`` only when the completion is "yes", and
      EMAIL / DATE store the completion text unchanged.
    * JSON: delegated to ``AnswerPromptService.handle_json``.
    * Any other type, or a prompt with no type key: the raw answer is
      stored as-is.

    Args:
        output: Prompt definition; must contain the prompt name.
        answer: Raw answer text produced for the prompt.
        structured_output: Mapping that receives the converted value.
        llm: LLM instance used for the follow-up completions.
        tool_settings: Read for the highlight / word-confidence flags that
            are forwarded to the JSON handler.
        metadata: Forwarded to the JSON handler.
        execution_source: Forwarded to the JSON handler.
        file_path: Forwarded to the JSON handler.
        log_events_id: Forwarded to the JSON handler.
        tool_id: Forwarded to the JSON handler.
        doc_name: Forwarded to the JSON handler.

    Raises:
        KeyError: If ``output`` has no prompt name.
    """
    from executor.executors.answer_prompt import AnswerPromptService
    from executor.executors.constants import PromptServiceConstants as PSKeys

    prompt_name = output[PSKeys.NAME]
    # The caller logs the type with ``output.get(PSKeys.TYPE, "TEXT")``, so a
    # missing type is tolerated there. Use ``.get`` here as well so such a
    # prompt takes the TEXT path instead of raising ``KeyError``.
    output_type = output.get(PSKeys.TYPE)

    # All scalar types share the same "no answer" short-circuit.
    scalar_types = (PSKeys.NUMBER, PSKeys.EMAIL, PSKeys.DATE, PSKeys.BOOLEAN)
    if output_type in scalar_types and answer.lower() == "na":
        structured_output[prompt_name] = None
        return

    if output_type == PSKeys.NUMBER:
        prompt = (
            f"Extract the number from the following "
            f"text:\n{answer}\n\nOutput just the number. "
            f"If the number is expressed in millions "
            f"or thousands, expand the number to its numeric value "
            f"The number should be directly assignable "
            f"to a numeric variable. "
            f"It should not have any commas, "
            f"percentages or other grouping "
            f"characters. No explanation is required. "
            f"If you cannot extract the number, output 0."
        )
        answer = AnswerPromptService.run_completion(llm=llm, prompt=prompt)
        try:
            structured_output[prompt_name] = float(answer)
        except (TypeError, ValueError):
            # The completion was not a parseable number.
            structured_output[prompt_name] = None

    elif output_type == PSKeys.EMAIL:
        prompt = (
            f"Extract the email from the following text:\n{answer}"
            f"\n\nOutput just the email. "
            f"The email should be directly assignable to a string "
            f"variable. No explanation is required. If you cannot "
            f'extract the email, output "NA".'
        )
        structured_output[prompt_name] = AnswerPromptService.run_completion(
            llm=llm, prompt=prompt
        )

    elif output_type == PSKeys.DATE:
        prompt = (
            f"Extract the date from the following text:\n{answer}"
            f"\n\nOutput just the date. "
            f"The date should be in ISO date time format. "
            f"No explanation is required. The date should be "
            f"directly assignable to a date variable. "
            f"If you cannot convert the string into a date, "
            f'output "NA".'
        )
        structured_output[prompt_name] = AnswerPromptService.run_completion(
            llm=llm, prompt=prompt
        )

    elif output_type == PSKeys.BOOLEAN:
        prompt = (
            f"Extract yes/no from the following text:\n{answer}\n\n"
            f"Output in single word. "
            f"If the context is trying to convey that the answer "
            f'is true, then return "yes", else return "no".'
        )
        answer = AnswerPromptService.run_completion(llm=llm, prompt=prompt)
        structured_output[prompt_name] = answer.lower() == "yes"

    elif output_type == PSKeys.JSON:
        AnswerPromptService.handle_json(
            answer=answer,
            structured_output=structured_output,
            output=output,
            llm=llm,
            enable_highlight=tool_settings.get(PSKeys.ENABLE_HIGHLIGHT, False),
            enable_word_confidence=tool_settings.get(
                PSKeys.ENABLE_WORD_CONFIDENCE, False
            ),
            execution_source=execution_source,
            metadata=metadata,
            file_path=file_path,
            log_events_id=log_events_id,
            tool_id=tool_id,
            doc_name=doc_name,
        )

    else:
        # TEXT, an unrecognised type, or no type at all: keep the raw answer.
        structured_output[prompt_name] = answer
def _handle_summarize(self, context: ExecutionContext) -> ExecutionResult:
    """Handle ``Operation.SUMMARIZE`` by summarizing a document with an LLM.

    Reads its inputs from ``context.executor_params``:

    - ``llm_adapter_instance_id``: LLM adapter to use (required)
    - ``summarize_prompt``: the user's summarize instruction
    - the ``PSKeys.CONTEXT`` entry: document text to summarize (required)
    - ``prompt_keys``: optional list of field names to focus on
    - the ``PSKeys.PLATFORM_SERVICE_API_KEY`` entry: key given to the tool shim

    Returns:
        A failure ``ExecutionResult`` when a required param is missing,
        otherwise a successful one whose ``data`` is ``{"data": <summary>}``.

    Raises:
        LegacyExecutorError: If creating the LLM or running the completion
            fails. The code is the original exception's ``status_code``
            when it has a truthy one, else 500.
    """
    from executor.executors.constants import PromptServiceConstants as PSKeys

    request: dict[str, Any] = context.executor_params

    adapter_id: str = request.get("llm_adapter_instance_id", "")
    instruction: str = request.get("summarize_prompt", "")
    document_text: str = request.get(PSKeys.CONTEXT, "")
    focus_fields: list[str] = request.get("prompt_keys", [])
    api_key: str = request.get(PSKeys.PLATFORM_SERVICE_API_KEY, "")

    # Guard clauses: both the adapter and the document text are mandatory.
    if not adapter_id:
        return ExecutionResult.failure(
            error="Missing required param: llm_adapter_instance_id"
        )
    if not document_text:
        return ExecutionResult.failure(error="Missing required param: context")

    logger.info(
        "Starting summarization: prompt_keys=%s run_id=%s",
        focus_fields,
        context.run_id,
    )

    # Assemble the prompt: instruction, optional field focus, then document.
    sections = [f"{instruction}\n\n"]
    if focus_fields:
        sections.append(f"Focus on these fields: {', '.join(focus_fields)}\n\n")
    sections.append(
        f"Context:\n---------------\n{document_text}\n"
        f"-----------------\n\nSummary:"
    )
    prompt = "".join(sections)

    shim = ExecutorToolShim(
        platform_api_key=api_key,
        log_events_id=self._log_events_id,
        component=self._log_component,
    )
    run_usage = {"run_id": context.run_id}

    # Only the LLM class (fifth item) of the dependency tuple is needed here.
    LLM = self._get_prompt_deps()[4]

    shim.stream_log("Initializing LLM for summarization...")
    try:
        llm = LLM(
            adapter_instance_id=adapter_id,
            tool=shim,
            usage_kwargs=dict(run_usage),
        )
        from executor.executors.answer_prompt import AnswerPromptService

        shim.stream_log("Running document summarization...")
        summary = AnswerPromptService.run_completion(llm=llm, prompt=prompt)
    except Exception as e:
        logger.error("Summarization failed: error=%s", str(e))
        raise LegacyExecutorError(
            message=f"Error during summarization: {e}",
            code=getattr(e, "status_code", None) or 500,
        ) from e

    logger.info("Summarization completed: run_id=%s", context.run_id)
    shim.stream_log("Summarization completed")
    return ExecutionResult(success=True, data={"data": summary})
+ +- ``unstract.executor.executors`` + Executor classes that self-register via ``@ExecutorRegistry.register``. + Loaded eagerly at worker startup from ``executors/__init__.py``. +""" + +import logging + +logger = logging.getLogger(__name__) + + +class ExecutorPluginLoader: + """Discovers cloud plugins and executors via setuptools entry points.""" + + _plugins: dict[str, type] | None = None + + @classmethod + def get(cls, name: str) -> type | None: + """Get a plugin class by name. Returns None if not installed.""" + if cls._plugins is None: + cls._discover_plugins() + return cls._plugins.get(name) + + @classmethod + def discover_executors(cls) -> list[str]: + """Load cloud executor classes via entry points. + + Importing each entry point's class triggers + ``@ExecutorRegistry.register``. Called once at worker startup. + + Returns: + List of discovered executor entry point names. + """ + from importlib.metadata import entry_points + + discovered: list[str] = [] + eps = entry_points(group="unstract.executor.executors") + for ep in eps: + try: + ep.load() # import triggers @ExecutorRegistry.register + discovered.append(ep.name) + logger.info("Loaded cloud executor: %s", ep.name) + except Exception: + logger.warning( + "Failed to load cloud executor: %s", + ep.name, + exc_info=True, + ) + return discovered + + @classmethod + def _discover_plugins(cls) -> None: + """Discover utility plugins from entry points (lazy, first use).""" + from importlib.metadata import entry_points + + cls._plugins = {} + eps = entry_points(group="unstract.executor.plugins") + for ep in eps: + try: + cls._plugins[ep.name] = ep.load() + logger.info("Loaded executor plugin: %s", ep.name) + except Exception: + logger.warning( + "Failed to load executor plugin: %s", + ep.name, + exc_info=True, + ) + + @classmethod + def clear(cls) -> None: + """Reset cached state. 
Intended for tests only.""" + cls._plugins = None diff --git a/workers/executor/executors/plugins/protocols.py b/workers/executor/executors/plugins/protocols.py new file mode 100644 index 0000000000..fb4d676b37 --- /dev/null +++ b/workers/executor/executors/plugins/protocols.py @@ -0,0 +1,51 @@ +"""Protocol classes defining contracts for cloud executor plugins. + +Cloud plugins must satisfy these protocols. The OSS repo never imports +cloud code — only these protocols and ``ExecutorPluginLoader.get(name)`` +are used to interact with plugins. +""" + +from typing import Any, Protocol, runtime_checkable + + +@runtime_checkable +class HighlightDataProtocol(Protocol): + """Cross-cutting: source attribution from LLMWhisperer metadata. + + Matches the cloud ``HighlightData`` plugin constructor which + accepts ``enable_word_confidence`` (not ``execution_source``). + The filesystem instance is determined by the caller and passed in. + """ + + def __init__( + self, + file_path: str, + fs_instance: Any = None, + enable_word_confidence: bool = False, + **kwargs: Any, + ) -> None: ... + + def run( + self, + response: Any = None, + is_json: bool = False, + original_text: str = "", + **kwargs: Any, + ) -> dict: ... + + @staticmethod + def extract_word_confidence(original_text: str, is_json: bool = False) -> dict: ... + + +@runtime_checkable +class ChallengeProtocol(Protocol): + """Legacy executor: quality verification with a second LLM.""" + + def run(self) -> None: ... + + +@runtime_checkable +class EvaluationProtocol(Protocol): + """Legacy executor: prompt evaluation.""" + + def run(self, **kwargs: Any) -> dict: ... 
def add_hex_line_numbers(text: str) -> str:
    """Prefix every line of *text* with its zero-based index in hexadecimal.

    Each line becomes ``0x<HEX>: <line>``, where ``<HEX>`` is the line index
    in upper-case hex, zero-padded to a common width. That width is the
    number of hex digits needed to write the total line count.

    Args:
        text: Multi-line string to number; lines are split on ``"\\n"``.

    Returns:
        The same lines, each carrying its hex prefix, joined with ``"\\n"``.
    """
    rows = text.split("\n")
    # Pad to the digit count of the line *count*, not of the last index.
    digits = max(len(f"{len(rows):x}"), 1)
    numbered = []
    for index, row in enumerate(rows):
        numbered.append(f"0x{index:0{digits}X}: {row}")
    return "\n".join(numbered)
+""" + +import json +import logging +from typing import Any + +import requests + +logger = logging.getLogger(__name__) + + +def _validate_structured_output(data: Any) -> bool: + """Validate that structured output is a dict or list.""" + return isinstance(data, (dict, list)) + + +def _validate_highlight_data(updated_data: Any, original_data: Any) -> Any: + """Validate highlight data and return appropriate value.""" + if ( + updated_data is not None + and updated_data != original_data + and not isinstance(updated_data, list) + ): + logger.warning( + "Ignoring webhook highlight_data due to invalid type (expected list)" + ) + return original_data + return updated_data + + +def _process_successful_response( + response_data: dict, parsed_data: dict, highlight_data: list | None +) -> tuple[dict[str, Any], list | None]: + """Process successful webhook response.""" + if "structured_output" not in response_data: + logger.warning("Response missing 'structured_output' key") + return parsed_data, highlight_data + + updated_parsed_data = response_data["structured_output"] + + if not _validate_structured_output(updated_parsed_data): + logger.warning("Ignoring postprocessing due to invalid structured_output type") + return parsed_data, highlight_data + + updated_highlight_data = response_data.get("highlight_data", highlight_data) + updated_highlight_data = _validate_highlight_data( + updated_highlight_data, highlight_data + ) + + return updated_parsed_data, updated_highlight_data + + +def _make_webhook_request( + webhook_url: str, payload: dict, timeout: float +) -> tuple[dict[str, Any], list | None] | None: + """Make webhook request and return processed response or None on failure.""" + try: + response = requests.post( + webhook_url, + json=payload, + timeout=timeout, + headers={"Content-Type": "application/json"}, + allow_redirects=False, # Prevent redirect-based SSRF + ) + + if response.status_code != 200: + logger.warning( + f"Postprocessing server returned status code: 
{response.status_code}" + ) + return None + + return response.json() + + except json.JSONDecodeError as e: + logger.warning(f"Invalid JSON response from postprocessing server: {e}") + except requests.exceptions.Timeout: + logger.warning(f"Postprocessing server request timed out after {timeout}s") + except requests.exceptions.RequestException as e: + logger.warning(f"Postprocessing server request failed: {e}") + except Exception as e: + logger.warning(f"Unexpected error during postprocessing: {e}") + + return None + + +def postprocess_data( + parsed_data: dict[str, Any], + webhook_enabled: bool = False, + webhook_url: str | None = None, + timeout: float = 2.0, + highlight_data: list | None = None, +) -> tuple[dict[str, Any], list | None]: + """Post-process parsed data by sending it to an external server. + + Args: + parsed_data: The parsed data to be post-processed + webhook_enabled: Whether webhook postprocessing is enabled + webhook_url: URL endpoint for the webhook + timeout: Request timeout in seconds (default: 2.0) + highlight_data: Highlight data from metadata to send to webhook + + Returns: + tuple: (postprocessed_data, updated_highlight_data) + """ + if not webhook_enabled or not webhook_url: + return parsed_data, highlight_data + + payload = {"structured_output": parsed_data} + if highlight_data is not None: + payload["highlight_data"] = highlight_data + + response_data = _make_webhook_request(webhook_url, payload, timeout) + if response_data is None: + return parsed_data, highlight_data + + return _process_successful_response(response_data, parsed_data, highlight_data) diff --git a/workers/executor/executors/retrieval.py b/workers/executor/executors/retrieval.py new file mode 100644 index 0000000000..3b4cd1da0a --- /dev/null +++ b/workers/executor/executors/retrieval.py @@ -0,0 +1,113 @@ +"""Retrieval service — factory for retriever strategies. + +Lazy-imports retriever classes to avoid llama_index/protobuf conflicts +at test-collection time. 
class RetrievalService:
    """Static helpers that fetch the context handed to a prompt."""

    @staticmethod
    def _get_retriever_map() -> dict:
        """Build the strategy-value -> retriever-class lookup.

        The retriever modules are imported inside this method rather than at
        module import time. Keeping it a separate method lets tests replace it.
        """
        from executor.executors.retrievers.automerging import AutomergingRetriever
        from executor.executors.retrievers.fusion import FusionRetriever
        from executor.executors.retrievers.keyword_table import KeywordTableRetriever
        from executor.executors.retrievers.recursive import RecursiveRetrieval
        from executor.executors.retrievers.router import RouterRetriever
        from executor.executors.retrievers.simple import SimpleRetriever
        from executor.executors.retrievers.subquestion import SubquestionRetriever

        strategy_classes = (
            (RetrievalStrategy.SIMPLE, SimpleRetriever),
            (RetrievalStrategy.SUBQUESTION, SubquestionRetriever),
            (RetrievalStrategy.FUSION, FusionRetriever),
            (RetrievalStrategy.RECURSIVE, RecursiveRetrieval),
            (RetrievalStrategy.ROUTER, RouterRetriever),
            (RetrievalStrategy.KEYWORD_TABLE, KeywordTableRetriever),
            (RetrievalStrategy.AUTOMERGING, AutomergingRetriever),
        )
        return {strategy.value: klass for strategy, klass in strategy_classes}

    @staticmethod
    def run_retrieval(
        output: dict[str, Any],
        doc_id: str,
        llm: Any,
        vector_db: Any,
        retrieval_type: str,
        context_retrieval_metrics: dict[str, Any] | None = None,
    ) -> list[str]:
        """Run the retriever registered for *retrieval_type* and time it.

        The query text and ``top_k`` are read from *output*. When
        *context_retrieval_metrics* is given, the elapsed seconds are stored
        in it under the prompt name.

        Returns:
            The retrieved chunks as a list.

        Raises:
            ValueError: If no retriever is registered for *retrieval_type*.
        """
        from executor.executors.constants import PromptServiceConstants as PSKeys

        query = output[PSKeys.PROMPTX]
        top_k = output[PSKeys.SIMILARITY_TOP_K]
        prompt_key = output.get(PSKeys.NAME, "")
        started_at = datetime.datetime.now()

        retriever_class = RetrievalService._get_retriever_map().get(retrieval_type)
        if not retriever_class:
            raise ValueError(f"Unknown retrieval type: {retrieval_type}")

        context = retriever_class(
            vector_db=vector_db,
            doc_id=doc_id,
            prompt=query,
            top_k=top_k,
            llm=llm,
        ).retrieve()

        elapsed = (datetime.datetime.now() - started_at).total_seconds()
        if context_retrieval_metrics is not None:
            context_retrieval_metrics[prompt_key] = {"time_taken(s)": elapsed}

        logger.info(
            "[Retrieval] prompt='%s' doc_id=%s strategy='%s' top_k=%d "
            "chunks=%d time=%.3fs",
            prompt_key,
            doc_id,
            retrieval_type,
            top_k,
            len(context),
            elapsed,
        )
        return list(context)

    @staticmethod
    def retrieve_complete_context(
        execution_source: str,
        file_path: str,
        context_retrieval_metrics: dict[str, Any] | None = None,
        prompt_key: str = "",
    ) -> list[str]:
        """Read the whole file as a single context chunk (chunk_size=0 path).

        Only the file read itself is timed. When *context_retrieval_metrics*
        is given, the elapsed seconds are stored in it under *prompt_key*.

        Returns:
            A one-element list holding the file content.
        """
        from executor.executors.file_utils import FileUtils

        storage = FileUtils.get_fs_instance(execution_source=execution_source)
        started_at = datetime.datetime.now()
        content = storage.read(path=file_path, mode="r")
        elapsed = (datetime.datetime.now() - started_at).total_seconds()

        if context_retrieval_metrics is not None:
            context_retrieval_metrics[prompt_key] = {"time_taken(s)": elapsed}

        logger.info(
            "[Retrieval] prompt='%s' complete_context chars=%d time=%.3fs",
            prompt_key,
            len(content),
            elapsed,
        )
        return [content]
class AutomergingRetriever(BaseRetriever):
    """Automerging retrieval using LlamaIndex's native AutoMergingRetriever.

    This retriever merges smaller chunks into larger ones when the smaller
    chunks don't contain enough information, providing better context for
    answers.
    """

    def retrieve(self) -> set[str]:
        """Retrieve text chunks using LlamaIndex's native AutoMergingRetriever.

        Returns:
            set[str]: Unique content of every retrieved node whose score is
                greater than zero.

        Raises:
            RetrievalError: If the index lookup, the auto-merging step or the
                chunk extraction fails. A failure of the auto-merging step is
                wrapped and logged exactly once.
        """
        try:
            logger.info(
                "Retrieving chunks for %s using LlamaIndex AutoMergingRetriever.",
                self.doc_id,
            )

            vector_store_index: VectorStoreIndex = (
                self.vector_db.get_vector_store_index()
            )

            # Restrict the similarity search to nodes of this document.
            base_retriever = vector_store_index.as_retriever(
                similarity_top_k=self.top_k,
                filters=MetadataFilters(
                    filters=[
                        ExactMatchFilter(key="doc_id", value=self.doc_id),
                    ],
                ),
            )

            try:
                # The vector DB wrapper may not expose a storage context.
                storage_context = (
                    self.vector_db.get_storage_context()
                    if hasattr(self.vector_db, "get_storage_context")
                    else None
                )
                auto_merging_retriever = LlamaAutoMergingRetriever(
                    base_retriever,
                    storage_context=storage_context,
                    verbose=False,
                )
                nodes = auto_merging_retriever.retrieve(self.prompt)
            except Exception as e:
                logger.error(
                    "AutoMergingRetriever failed: %s: %s",
                    type(e).__name__,
                    e,
                    exc_info=True,
                )
                raise RetrievalError(
                    f"AutoMergingRetriever failed: {type(e).__name__}: {e}"
                ) from e

            # Keep only positively scored nodes; the set de-duplicates text.
            chunks: set[str] = set()
            for node in nodes:
                if node.score > 0:
                    chunks.add(node.get_content())
                else:
                    logger.info(
                        "Node score is less than 0. Ignored: %s with score %s",
                        node.node_id,
                        node.score,
                    )

            logger.info(
                "Successfully retrieved %d chunks using AutoMergingRetriever.",
                len(chunks),
            )
            return chunks

        except RetrievalError:
            # Already logged and wrapped by the inner handler above; without
            # this clause the generic handler below would wrap it a second
            # time as "Unexpected error: RetrievalError: ...".
            raise
        except (ValueError, AttributeError, KeyError, ImportError) as e:
            logger.error(
                "Error during auto-merging retrieval for %s: %s: %s",
                self.doc_id,
                type(e).__name__,
                e,
                exc_info=True,
            )
            raise RetrievalError(f"{type(e).__name__}: {e}") from e
        except Exception as e:
            logger.error(
                "Unexpected error during auto-merging retrieval for %s: %s: %s",
                self.doc_id,
                type(e).__name__,
                e,
                exc_info=True,
            )
            raise RetrievalError(
                f"Unexpected error: {type(e).__name__}: {e}"
            ) from e
class FusionRetriever(BaseRetriever):
    """Fusion retrieval class using LlamaIndex's native QueryFusionRetriever.

    The LLM generates several variations of the prompt, every variation is
    run against three document-scoped vector retrievers of different breadth,
    and the result lists are combined with the ``"simple"`` fusion mode.

    NOTE(review): earlier documentation described this as reciprocal rank
    fusion. In LlamaIndex that is the separate ``"reciprocal_rerank"`` mode;
    ``"simple"`` is a different fusion strategy. The mode is left unchanged
    here because switching it changes node scores and ordering - confirm
    which one is intended.
    """

    def retrieve(self) -> set[str]:
        """Retrieve text chunks using LlamaIndex's QueryFusionRetriever.

        Returns:
            set[str]: Unique content of every fused node whose score is
                greater than zero.

        Raises:
            RetrievalError: If building or running the fusion retriever fails.
        """
        try:
            logger.info(
                "Retrieving chunks for %s using LlamaIndex QueryFusionRetriever.",
                self.doc_id,
            )

            vector_store_index: VectorStoreIndex = (
                self.vector_db.get_vector_store_index()
            )

            # All three retrievers are restricted to this document.
            filters = MetadataFilters(
                filters=[
                    ExactMatchFilter(key="doc_id", value=self.doc_id),
                ],
            )

            # Retriever 1: standard similarity search.
            retriever_1 = vector_store_index.as_retriever(
                similarity_top_k=self.top_k,
                filters=filters,
            )

            # Retriever 2: broader search with more candidates.
            retriever_2 = vector_store_index.as_retriever(
                similarity_top_k=self.top_k * 2,
                filters=filters,
            )

            # Retriever 3: focused search with fewer candidates (at least 1).
            retriever_3 = vector_store_index.as_retriever(
                similarity_top_k=max(1, self.top_k // 2),
                filters=filters,
            )

            fusion_retriever = QueryFusionRetriever(
                [retriever_1, retriever_2, retriever_3],
                similarity_top_k=self.top_k,
                num_queries=4,  # Generate multiple query variations
                # "simple" fusion - NOT reciprocal rank fusion, which would
                # be mode="reciprocal_rerank".
                mode="simple",
                use_async=False,
                verbose=True,
                llm=self.llm,  # LLM generates query variations
            )

            nodes = fusion_retriever.retrieve(self.prompt)

            # Keep only positively scored nodes; the set de-duplicates text.
            chunks: set[str] = set()
            for node in nodes:
                if node.score > 0:
                    chunks.add(node.get_content())
                else:
                    logger.info(
                        "Node score is less than 0. Ignored: %s with score %s",
                        node.node_id,
                        node.score,
                    )

            logger.info(
                "Successfully retrieved %d chunks using fusion.", len(chunks)
            )
            return chunks

        except (ValueError, AttributeError, KeyError, ImportError) as e:
            logger.error(
                "Error during fusion retrieval for %s: %s: %s",
                self.doc_id,
                type(e).__name__,
                e,
                exc_info=True,
            )
            raise RetrievalError(f"{type(e).__name__}: {e}") from e
        except Exception as e:
            logger.error(
                "Unexpected error during fusion retrieval for %s: %s: %s",
                self.doc_id,
                type(e).__name__,
                e,
                exc_info=True,
            )
            raise RetrievalError(
                f"Unexpected error: {type(e).__name__}: {e}"
            ) from e
+ ) + + # Get documents from vector index for keyword indexing + vector_store_index: VectorStoreIndex = self.vector_db.get_vector_store_index() + + # Get all nodes for the document + all_retriever = vector_store_index.as_retriever( + similarity_top_k=1000, # Get all nodes + filters=MetadataFilters( + filters=[ + ExactMatchFilter(key="doc_id", value=self.doc_id), + ], + ), + ) + + # Retrieve all nodes to build keyword index + all_nodes = all_retriever.retrieve(" ") + + if not all_nodes: + logger.warning(f"No nodes found for doc_id: {self.doc_id}") + return set() + + # Create KeywordTableIndex from nodes using our provided LLM + keyword_index = KeywordTableIndex( + nodes=[node.node for node in all_nodes], + show_progress=True, + llm=self.llm, # Use the provided LLM instead of defaulting to OpenAI + ) + + # Create retriever from keyword index + keyword_retriever = keyword_index.as_retriever( + similarity_top_k=self.top_k, + ) + + # Retrieve nodes using keyword matching + nodes = keyword_retriever.retrieve(self.prompt) + + # Extract unique text chunks + chunks: set[str] = set() + for node in nodes: + chunks.add(node.get_content()) + + logger.info( + f"Successfully retrieved {len(chunks)} chunks using KeywordTableIndex." 
class RecursiveRetrieval(BaseRetriever):
    """Document-scoped retrieval driven by LlamaIndex's RecursiveRetriever.

    One filtered vector retriever is registered under the root key
    ``"vector"`` and the recursive retriever resolves the prompt through it.
    """

    def retrieve(self) -> set[str]:
        """Run the prompt through a RecursiveRetriever and collect the text.

        Returns:
            set[str]: Unique content of every retrieved node whose score is
                greater than zero.

        Raises:
            RetrievalError: If any step of the retrieval fails.
        """
        try:
            logger.info(
                f"Retrieving chunks for {self.doc_id} using LlamaIndex RecursiveRetriever."
            )

            index: VectorStoreIndex = self.vector_db.get_vector_store_index()

            # Only nodes that belong to this document are searched.
            doc_filter = MetadataFilters(
                filters=[ExactMatchFilter(key="doc_id", value=self.doc_id)],
            )
            root_retriever = index.as_retriever(
                similarity_top_k=self.top_k,
                filters=doc_filter,
            )

            retrievers_by_key = {"vector": root_retriever}
            recursive = RecursiveRetriever(
                "vector",  # root retriever key
                retriever_dict=retrievers_by_key,
                verbose=True,
            )

            found = recursive.retrieve(self.prompt)

            texts: set[str] = set()
            for item in found:
                if not item.score > 0:
                    logger.info(
                        f"Node score is less than 0. "
                        f"Ignored: {item.node_id} with score {item.score}"
                    )
                    continue
                texts.add(item.get_content())

            logger.info(
                f"Successfully retrieved {len(texts)} chunks using RecursiveRetriever."
            )
            return texts

        except (ValueError, AttributeError, KeyError, ImportError) as e:
            logger.error(
                "Error during recursive retrieval for %s: %s: %s",
                self.doc_id,
                type(e).__name__,
                e,
                exc_info=True,
            )
            raise RetrievalError(f"{type(e).__name__}: {e}") from e
        except Exception as e:
            logger.error(
                "Unexpected error during recursive retrieval for %s: %s: %s",
                self.doc_id,
                type(e).__name__,
                e,
                exc_info=True,
            )
            raise RetrievalError(
                f"Unexpected error: {type(e).__name__}: {e}"
            ) from e
class RouterRetriever(BaseRetriever):
    """Router retrieval class using LlamaIndex's native RouterQueryEngine.

    An LLM selector picks one of several query-engine tools for the prompt.
    All tools are vector query engines over the same document-filtered index;
    they differ only in ``similarity_top_k`` (1x, 2x and 3x ``top_k``) and in
    the description shown to the selector.
    """

    def _create_metadata_filters(self):
        """Create metadata filters for doc_id."""
        return MetadataFilters(
            filters=[
                ExactMatchFilter(key="doc_id", value=self.doc_id),
            ],
        )

    def _create_base_query_engine(self, vector_store_index, filters):
        """Create the base vector query engine."""
        return vector_store_index.as_query_engine(
            similarity_top_k=self.top_k,
            filters=filters,
            llm=self.llm,
        )

    def _add_keyword_search_tool(self, query_engine_tools, vector_store_index, filters):
        """Append the ``keyword_search`` tool to ``query_engine_tools``.

        Best effort: if the engine cannot be created the tool is left out and
        the failure is logged, so routing can continue with the other tools.
        """
        try:
            keyword_query_engine = vector_store_index.as_query_engine(
                similarity_top_k=self.top_k * 2,
                filters=filters,
                llm=self.llm,
            )
            query_engine_tools.append(
                QueryEngineTool(
                    query_engine=keyword_query_engine,
                    metadata=ToolMetadata(
                        name="keyword_search",
                        description=(
                            "Best for finding specific terms, names, numbers, dates, "
                            "or exact phrases. Use when looking for precise matches."
                        ),
                    ),
                )
            )
        except Exception as e:
            # Logged as a warning so a missing routing option is visible in
            # normal worker logs instead of being hidden at debug level.
            logger.warning("Could not create keyword search engine: %s", e)

    def _add_broad_search_tool(self, query_engine_tools, vector_store_index, filters):
        """Append the ``broad_search`` tool to ``query_engine_tools``.

        Best effort, like the keyword search tool: failures are logged and
        the tool is left out.
        """
        try:
            broad_query_engine = vector_store_index.as_query_engine(
                similarity_top_k=self.top_k * 3,
                filters=filters,
                llm=self.llm,
            )
            query_engine_tools.append(
                QueryEngineTool(
                    query_engine=broad_query_engine,
                    metadata=ToolMetadata(
                        name="broad_search",
                        description=(
                            "Useful for general questions, exploratory queries, "
                            "or when you need comprehensive information on a topic."
                        ),
                    ),
                )
            )
        except Exception as e:
            logger.warning("Could not create broad search engine: %s", e)

    def _extract_chunks_from_response(self, response):
        """Collect the text of positively scored source nodes of ``response``.

        Returns an empty set when the response has no ``source_nodes``.
        """
        chunks: set[str] = set()
        if hasattr(response, "source_nodes"):
            for node in response.source_nodes:
                if node.score > 0:
                    chunks.add(node.get_content())
                else:
                    logger.info(
                        "Node score is less than 0. Ignored: %s with score %s",
                        node.node_id,
                        node.score,
                    )
        return chunks

    def retrieve(self) -> set[str]:
        """Retrieve text chunks using LlamaIndex's RouterQueryEngine.

        Returns:
            set[str]: A set of text chunks retrieved from the database. Empty
                when no LLM was supplied, because routing needs one.

        Raises:
            RetrievalError: If building or running the router fails.
        """
        # Check for the LLM before touching the index: the selector and every
        # query engine below need it, and building an engine with llm=None
        # would leave LlamaIndex to resolve a default LLM of its own.
        if not self.llm:
            logger.warning(
                "Router retrieval skipped for %s: no LLM was provided.",
                self.doc_id,
            )
            return set()

        try:
            logger.info(
                "Retrieving chunks for %s using LlamaIndex RouterQueryEngine.",
                self.doc_id,
            )

            vector_store_index: VectorStoreIndex = (
                self.vector_db.get_vector_store_index()
            )
            filters = self._create_metadata_filters()
            vector_query_engine = self._create_base_query_engine(
                vector_store_index, filters
            )

            # The base tool is mandatory; the two below are optional extras.
            query_engine_tools = [
                QueryEngineTool(
                    query_engine=vector_query_engine,
                    metadata=ToolMetadata(
                        name="vector_search",
                        description=(
                            "Useful for semantic similarity search, conceptual questions, "
                            "and finding information based on meaning and context."
                        ),
                    ),
                ),
            ]
            self._add_keyword_search_tool(query_engine_tools, vector_store_index, filters)
            self._add_broad_search_tool(query_engine_tools, vector_store_index, filters)

            router_query_engine = RouterQueryEngine.from_defaults(
                selector=LLMSingleSelector.from_defaults(llm=self.llm),
                query_engine_tools=query_engine_tools,
                verbose=True,
                llm=self.llm,
            )

            response = router_query_engine.query(self.prompt)
            chunks = self._extract_chunks_from_response(response)

            logger.info(
                "Successfully retrieved %d chunks using router.", len(chunks)
            )
            return chunks

        except (ValueError, AttributeError, KeyError, ImportError) as e:
            logger.error(
                "Error during router retrieval for %s: %s: %s",
                self.doc_id,
                type(e).__name__,
                e,
                exc_info=True,
            )
            raise RetrievalError(f"{type(e).__name__}: {e}") from e
        except Exception as e:
            logger.error(
                "Unexpected error during router retrieval for %s: %s: %s",
                self.doc_id,
                type(e).__name__,
                e,
                exc_info=True,
            )
            raise RetrievalError(
                f"Unexpected error: {type(e).__name__}: {e}"
            ) from e
class SubquestionRetriever(BaseRetriever):
    """SubquestionRetrieval class for querying VectorDB using LlamaIndex's
    SubQuestionQueryEngine.
    """

    def retrieve(self) -> set[str]:
        """Retrieve text chunks from the VectorDB based on the provided prompt.

        The vector query engine is restricted to nodes whose ``doc_id``
        metadata equals ``self.doc_id``, so sub-questions are only answered
        from this document.

        Returns:
            set[str]: A set of text chunks retrieved from the database.

        Raises:
            RetrievalError: If building or running the query engine fails.
        """
        # Imported locally: this module's header does not import the filter
        # types that the sibling retrievers use.
        from llama_index.core.vector_stores import ExactMatchFilter, MetadataFilters

        try:
            logger.info("Initialising vector query engine...")
            # Without this filter the engine would search every document
            # stored in the vector DB, not just the requested one.
            doc_filters = MetadataFilters(
                filters=[
                    ExactMatchFilter(key="doc_id", value=self.doc_id),
                ],
            )
            vector_query_engine = self.vector_db.get_vector_store_index().as_query_engine(
                llm=self.llm,
                similarity_top_k=self.top_k,
                filters=doc_filters,
            )
            logger.info(
                "Retrieving chunks for %s using SubQuestionQueryEngine.",
                self.doc_id,
            )
            query_engine_tools = [
                QueryEngineTool(
                    query_engine=vector_query_engine,
                    metadata=ToolMetadata(
                        name=self.doc_id, description=f"Nodes for {self.doc_id}"
                    ),
                ),
            ]
            query_bundle = QueryBundle(query_str=self.prompt)

            query_engine = SubQuestionQueryEngine.from_defaults(
                query_engine_tools=query_engine_tools,
                use_async=True,
                llm=self.llm,
            )

            response = query_engine.query(str_or_query_bundle=query_bundle)

            # The set de-duplicates chunks shared by several sub-questions.
            chunks: set[str] = {node.text for node in response.source_nodes}
            logger.info("Successfully retrieved %d chunks.", len(chunks))
            return chunks

        except (ValueError, AttributeError, KeyError, ImportError) as e:
            logger.error(
                "Error during retrieving chunks %s: %s: %s",
                self.doc_id,
                type(e).__name__,
                e,
                exc_info=True,
            )
            raise RetrievalError(f"{type(e).__name__}: {e}") from e
        except Exception as e:
            logger.error(
                "Unexpected error during retrieving chunks %s: %s: %s",
                self.doc_id,
                type(e).__name__,
                e,
                exc_info=True,
            )
            raise RetrievalError(
                f"Unexpected error: {type(e).__name__}: {e}"
            ) from e
@staticmethod
def push_usage_data(
    event_type: str,
    kwargs: dict[str, Any],
    platform_api_key: str,
    token_counter: Any = None,
    model_name: str = "",
) -> bool:
    """Send one usage record through the SDK1 ``Audit`` client.

    Inputs are validated first and every failure is logged and reported via
    the return value; nothing is raised to the caller.

    Args:
        event_type: Type of usage event (e.g. "llm", "embedding").
        kwargs: Non-empty context dict (run_id, execution_id, etc.).
        platform_api_key: Non-empty API key for platform service auth.
        token_counter: Token counter with usage metrics.
        model_name: Name of the model used.

    Returns:
        True if the record was pushed, False if validation or the push failed.
    """
    if not (kwargs and isinstance(kwargs, dict)):
        logger.error("Invalid kwargs provided to push_usage_data")
        return False

    if not (platform_api_key and isinstance(platform_api_key, str)):
        logger.error("Invalid platform_api_key provided to push_usage_data")
        return False

    try:
        from unstract.sdk1.audit import Audit

        logger.debug(
            "Pushing usage data for event_type=%s model=%s",
            event_type,
            model_name,
        )

        audit_client = Audit()
        audit_client.push_usage_data(
            platform_api_key=platform_api_key,
            token_counter=token_counter,
            model_name=model_name,
            event_type=event_type,
            kwargs=kwargs,
        )

        logger.info("Successfully pushed usage data for %s", model_name)
        return True
    except Exception:
        # Usage reporting must never break the caller; log and report failure.
        logger.exception("Error pushing usage data")
        return False
in formatted else formatted diff --git a/workers/executor/executors/variable_replacement.py b/workers/executor/executors/variable_replacement.py new file mode 100644 index 0000000000..cca158cba0 --- /dev/null +++ b/workers/executor/executors/variable_replacement.py @@ -0,0 +1,264 @@ +"""Variable replacement for prompt templates. + +Ported from prompt-service variable_replacement service + helper. +Flask dependencies (app.logger, publish_log) replaced with standard logging. +""" + +import json +import logging +import re +from functools import lru_cache +from typing import Any + +import requests as pyrequests +from executor.executors.constants import VariableConstants, VariableType +from executor.executors.exceptions import CustomDataError, LegacyExecutorError +from requests.exceptions import RequestException + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# VariableReplacementHelper — low-level replacement logic +# --------------------------------------------------------------------------- + + +class VariableReplacementHelper: + @staticmethod + def replace_static_variable( + prompt: str, structured_output: dict[str, Any], variable: str + ) -> str: + output_value = VariableReplacementHelper.check_static_variable_run_status( + structure_output=structured_output, variable=variable + ) + if not output_value: + return prompt + static_variable_marker_string = "".join(["{{", variable, "}}"]) + replaced_prompt: str = VariableReplacementHelper.replace_generic_string_value( + prompt=prompt, variable=static_variable_marker_string, value=output_value + ) + return replaced_prompt + + @staticmethod + def check_static_variable_run_status( + structure_output: dict[str, Any], variable: str + ) -> Any: + output = None + try: + output = structure_output[variable] + except KeyError: + logger.warning( + "Prompt with %s is not executed yet. 
" "Unable to replace the variable", + variable, + ) + return output + + @staticmethod + def replace_generic_string_value(prompt: str, variable: str, value: Any) -> str: + formatted_value: str = value + if not isinstance(value, str): + formatted_value = VariableReplacementHelper.handle_json_and_str_types(value) + replaced_prompt = prompt.replace(variable, formatted_value) + return replaced_prompt + + @staticmethod + def handle_json_and_str_types(value: Any) -> str: + try: + formatted_value = json.dumps(value) + except ValueError: + formatted_value = str(value) + return formatted_value + + @staticmethod + def identify_variable_type(variable: str) -> VariableType: + custom_data_pattern = re.compile(VariableConstants.CUSTOM_DATA_VARIABLE_REGEX) + if re.findall(custom_data_pattern, variable): + return VariableType.CUSTOM_DATA + + dynamic_pattern = re.compile(VariableConstants.DYNAMIC_VARIABLE_URL_REGEX) + if re.findall(dynamic_pattern, variable): + return VariableType.DYNAMIC + + return VariableType.STATIC + + @staticmethod + def replace_dynamic_variable( + prompt: str, variable: str, structured_output: dict[str, Any] + ) -> str: + url = re.search(VariableConstants.DYNAMIC_VARIABLE_URL_REGEX, variable).group(0) + data = re.findall(VariableConstants.DYNAMIC_VARIABLE_DATA_REGEX, variable)[0] + output_value = VariableReplacementHelper.check_static_variable_run_status( + structure_output=structured_output, variable=data + ) + if not output_value: + return prompt + api_response: Any = VariableReplacementHelper.fetch_dynamic_variable_value( + url=url, data=output_value + ) + formatted_api_response: str = VariableReplacementHelper.handle_json_and_str_types( + api_response + ) + static_variable_marker_string = "".join(["{{", variable, "}}"]) + replaced_prompt: str = VariableReplacementHelper.replace_generic_string_value( + prompt=prompt, + variable=static_variable_marker_string, + value=formatted_api_response, + ) + return replaced_prompt + + @staticmethod + def 
replace_custom_data_variable( + prompt: str, + variable: str, + custom_data: dict[str, Any], + is_ide: bool = True, + ) -> str: + custom_data_match = re.search( + VariableConstants.CUSTOM_DATA_VARIABLE_REGEX, variable + ) + if not custom_data_match: + error_msg = "Invalid variable format." + logger.error("%s: %s", error_msg, variable) + raise CustomDataError(variable=variable, reason=error_msg, is_ide=is_ide) + + path_str = custom_data_match.group(1) + path_parts = path_str.split(".") + + if not custom_data: + error_msg = "Custom data is not configured." + logger.error(error_msg) + raise CustomDataError(variable=path_str, reason=error_msg, is_ide=is_ide) + + try: + value = custom_data + for part in path_parts: + value = value[part] + except (KeyError, TypeError) as e: + error_msg = f"Key '{path_str}' not found in custom data." + logger.error(error_msg) + raise CustomDataError( + variable=path_str, reason=error_msg, is_ide=is_ide + ) from e + + variable_marker_string = "".join(["{{", variable, "}}"]) + replaced_prompt = VariableReplacementHelper.replace_generic_string_value( + prompt=prompt, + variable=variable_marker_string, + value=value, + ) + return replaced_prompt + + @staticmethod + @lru_cache(maxsize=128) + def _extract_variables_cached(prompt_text: str) -> tuple[str, ...]: + return tuple(re.findall(VariableConstants.VARIABLE_REGEX, prompt_text)) + + @staticmethod + def extract_variables_from_prompt(prompt_text: str) -> list[str]: + result = VariableReplacementHelper._extract_variables_cached(prompt_text) + return list(result) + + @staticmethod + def fetch_dynamic_variable_value(url: str, data: str) -> Any: + """Fetch dynamic variable value from an external URL. + + Ported from prompt-service make_http_request — simplified to direct + requests.post since we don't need Flask error classes. 
class VariableReplacementService:
    """Entry points for substituting variables inside a prompt's text.

    Finding the variables, classifying them and performing each individual
    substitution is delegated to ``VariableReplacementHelper``; this class
    only picks the lookup map and walks the variables found in the prompt.
    """

    @staticmethod
    def is_variables_present(prompt_text: str) -> bool:
        """Return True when at least one variable is found in ``prompt_text``."""
        found = VariableReplacementHelper.extract_variables_from_prompt(prompt_text)
        return len(found) > 0

    @staticmethod
    def replace_variables_in_prompt(
        prompt: dict[str, Any],
        structured_output: dict[str, Any],
        prompt_name: str,
        tool_id: str = "",
        log_events_id: str = "",
        doc_name: str = "",
        custom_data: dict[str, Any] | None = None,
        is_ide: bool = True,
    ) -> str:
        """Return the prompt text with its variables substituted.

        The prompt's own variable map is tried first; ``structured_output``
        is used as the lookup map when that attempt ends in ``KeyError``.

        Args:
            prompt: Prompt definition; must contain the prompt text key and
                may contain a variable map.
            structured_output: Fallback lookup map.
            prompt_name: Used only in the log line.
            tool_id: Used only in the log line.
            log_events_id: Accepted but not used by this method.
            doc_name: Accepted but not used by this method.
            custom_data: Values for custom-data variables (``None`` is
                treated as empty further down).
            is_ide: Forwarded to the custom-data replacement.

        Returns:
            The prompt text after replacement.
        """
        from executor.executors.constants import PromptServiceConstants as PSKeys

        logger.info("[%s] Replacing variables in prompt: %s", tool_id, prompt_name)

        substitute = VariableReplacementService._execute_variable_replacement
        raw_prompt = prompt[PSKeys.PROMPT]
        try:
            lookup = prompt[PSKeys.VARIABLE_MAP]
            return substitute(
                prompt_text=prompt[PSKeys.PROMPT],
                variable_map=lookup,
                custom_data=custom_data,
                is_ide=is_ide,
            )
        except KeyError:
            # Reached when the prompt has no variable map, and also when the
            # attempt above raised KeyError part-way through; either way the
            # untouched prompt text is retried against structured_output.
            return substitute(
                prompt_text=raw_prompt,
                variable_map=structured_output,
                custom_data=custom_data,
                is_ide=is_ide,
            )

    @staticmethod
    def _execute_variable_replacement(
        prompt_text: str,
        variable_map: dict[str, Any],
        custom_data: dict[str, Any] | None = None,
        is_ide: bool = True,
    ) -> str:
        """Substitute every variable found in ``prompt_text``.

        The variable list is extracted once from the incoming text; each
        variable is then classified and handed to the matching helper, and
        the text is threaded through the calls. Variables whose type matches
        none of STATIC / DYNAMIC / CUSTOM_DATA are left untouched.
        """
        helper = VariableReplacementHelper
        for name in helper.extract_variables_from_prompt(prompt_text=prompt_text):
            kind = helper.identify_variable_type(variable=name)
            if kind == VariableType.STATIC:
                prompt_text = helper.replace_static_variable(
                    prompt=prompt_text,
                    structured_output=variable_map,
                    variable=name,
                )
                continue
            if kind == VariableType.DYNAMIC:
                prompt_text = helper.replace_dynamic_variable(
                    prompt=prompt_text,
                    variable=name,
                    structured_output=variable_map,
                )
                continue
            if kind == VariableType.CUSTOM_DATA:
                prompt_text = helper.replace_custom_data_variable(
                    prompt=prompt_text,
                    variable=name,
                    custom_data=custom_data or {},
                    is_ide=is_ide,
                )
        return prompt_text
def _build_log_component(context: ExecutionContext) -> dict:
    """Build the log-correlation dict attached to an execution context.

    Returns an empty dict when the context has no ``log_events_id``.
    Compound operations keep their identifying values in nested parameter
    dicts, so the tool id and document name are read from there:

    - ``ide_index``: both values come from ``extract_params``.
    - ``structure_pipeline``: tool id from ``answer_params``, document name
      from ``pipeline_options["source_file_name"]``.
    - every other operation: both values come from the top-level params.
    """
    if not context.log_events_id:
        return {}

    # ``or {}`` keeps a missing / None params value from raising
    # AttributeError before the executor even starts.
    params = context.executor_params or {}
    if context.operation == "ide_index":
        tool_params = params.get("extract_params") or {}
        doc_name = tool_params.get("file_name", "")
    elif context.operation == "structure_pipeline":
        tool_params = params.get("answer_params") or {}
        pipeline_opts = params.get("pipeline_options") or {}
        doc_name = pipeline_opts.get("source_file_name", "")
    else:
        tool_params = params
        doc_name = params.get("file_name", "")

    return {
        "tool_id": tool_params.get("tool_id", ""),
        "run_id": context.run_id,
        "doc_name": str(doc_name),
        "operation": context.operation,
    }


@shared_task(
    bind=True,
    name=TaskName.EXECUTE_EXTRACTION,
    autoretry_for=(ConnectionError, TimeoutError, OSError),
    retry_backoff=True,
    retry_backoff_max=60,
    max_retries=3,
    retry_jitter=True,
)
def execute_extraction(self, execution_context_dict: dict) -> dict:
    """Execute an extraction operation via the executor framework.

    This is the single Celery task entry point for all extraction
    operations. Both the workflow path (structure tool task) and
    the IDE path (PromptStudioHelper) dispatch to this task.

    Args:
        execution_context_dict: Serialized ExecutionContext.

    Returns:
        Serialized ExecutionResult dict. A context that cannot be
        deserialized (``KeyError`` / ``ValueError``) yields a failure
        result instead of raising.
    """
    request_id = execution_context_dict.get("request_id", "")
    logger.info(
        "Received execute_extraction task: "
        "celery_task_id=%s request_id=%s executor=%s "
        "operation=%s execution_source=%s run_id=%s",
        self.request.id,
        request_id,
        execution_context_dict.get("executor_name"),
        execution_context_dict.get("operation"),
        execution_context_dict.get("execution_source"),
        execution_context_dict.get("run_id"),
    )

    try:
        context = ExecutionContext.from_dict(execution_context_dict)
    except (KeyError, ValueError) as exc:
        logger.error("Invalid execution context: %s", exc, exc_info=True)
        return ExecutionResult.failure(
            error=f"Invalid execution context: {exc}"
        ).to_dict()

    # Component dict for log correlation when streaming to the frontend.
    # Attached as a transient attribute (not serialized).
    context._log_component = _build_log_component(context)

    orchestrator = ExecutionOrchestrator()
    result = orchestrator.execute(context)

    logger.info(
        "execute_extraction complete: celery_task_id=%s request_id=%s success=%s",
        self.request.id,
        context.request_id,
        result.success,
    )

    return result.to_dict()
def check_executor_health():
    """Custom health check for executor worker.

    Asks ``ExecutorRegistry`` for its registered executors and wraps the
    answer in a ``HealthCheckResult``.

    Returns:
        A ``HealthCheckResult`` named ``executor_health``: HEALTHY with the
        executor list and count on success, DEGRADED with the error text
        when importing or querying the registry raises.
    """
    from shared.infrastructure.monitoring.health import (
        HealthCheckResult,
        HealthStatus,
    )

    try:
        # Imported inside the try so that an import failure is reported as a
        # DEGRADED result rather than propagating to the caller.
        from unstract.sdk1.execution.registry import (
            ExecutorRegistry,
        )

        executors = ExecutorRegistry.list_executors()

        return HealthCheckResult(
            name="executor_health",
            status=HealthStatus.HEALTHY,
            message="Executor worker is healthy",
            details={
                "worker_type": "executor",
                "registered_executors": executors,
                "executor_count": len(executors),
                # NOTE(review): the queue name is a literal here; the run
                # scripts let CELERY_QUEUES_EXECUTOR override the consumed
                # queues, so this detail may not match the running worker —
                # confirm whether it should come from configuration.
                "queues": ["celery_executor_legacy"],
            },
        )

    except Exception as e:
        # Any failure (import or registry lookup) degrades the check instead
        # of raising into the health endpoint.
        return HealthCheckResult(
            name="executor_health",
            status=HealthStatus.DEGRADED,
            message=f"Health check failed: {e}",
            details={"error": str(e)},
        )


# Register health check
WorkerRegistry.register_health_check(
    WorkerType.EXECUTOR,
    "executor_health",
    check_executor_health,
)


@app.task(bind=True)
def healthcheck(self):
    """Health check task for monitoring systems.

    Returns a static "healthy" payload carrying this task's id and the
    worker name (``config.worker_name`` when ``config`` is truthy, otherwise
    the literal ``"executor-worker"``). It does not run
    ``check_executor_health``.
    """
    return {
        "status": "healthy",
        "worker_type": "executor",
        "task_id": self.request.id,
        "worker_name": (config.worker_name if config else "executor-worker"),
    }
+import executor.executors # noqa: E402, F401 +import executor.tasks # noqa: E402, F401 diff --git a/workers/file_processing/__init__.py b/workers/file_processing/__init__.py index b3f8b74a97..b2b8ece391 100644 --- a/workers/file_processing/__init__.py +++ b/workers/file_processing/__init__.py @@ -4,6 +4,7 @@ direct Django ORM access, implementing the hybrid approach for tool execution. """ +from .structure_tool_task import execute_structure_tool from .tasks import ( process_file_batch, process_file_batch_api, @@ -13,6 +14,7 @@ __all__ = [ "celery_app", + "execute_structure_tool", "process_file_batch", "process_file_batch_api", "process_file_batch_resilient", diff --git a/workers/file_processing/structure_tool_task.py b/workers/file_processing/structure_tool_task.py new file mode 100644 index 0000000000..94d058952c --- /dev/null +++ b/workers/file_processing/structure_tool_task.py @@ -0,0 +1,682 @@ +"""Structure tool Celery task — Phase 3 of executor migration. + +Replaces the Docker-container-based StructureTool.run() with a Celery +task that runs in the file_processing worker. Instead of PromptTool +HTTP calls to prompt-service, it uses ExecutionDispatcher to send +operations to the executor worker via Celery. 
+ +Before (Docker-based): + File Processing Worker → WorkflowExecutionService → ToolSandbox + → Docker container → StructureTool.run() → PromptTool (HTTP) → prompt-service + +After (Celery-based): + File Processing Worker → WorkerWorkflowExecutionService + → execute_structure_tool task → ExecutionDispatcher + → executor worker → LegacyExecutor +""" + +import json +import logging +import os +import time +from pathlib import Path +from typing import Any + +from file_processing.worker import app +from shared.enums.task_enums import TaskName + +from unstract.sdk1.constants import ToolEnv, UsageKwargs +from unstract.sdk1.execution.context import ExecutionContext +from unstract.sdk1.execution.dispatcher import ExecutionDispatcher +from unstract.sdk1.execution.result import ExecutionResult + +logger = logging.getLogger(__name__) + +# Timeout for executor worker calls (seconds). +# Reads from EXECUTOR_RESULT_TIMEOUT env, defaults to 3600. +EXECUTOR_TIMEOUT = int(os.environ.get("EXECUTOR_RESULT_TIMEOUT", 3600)) + + +# ----------------------------------------------------------------------- +# Constants mirrored from tools/structure/src/constants.py +# These are the keys used in tool_metadata and payload dicts. 
def _apply_profile_overrides(tool_metadata: dict, profile_data: dict) -> list[str]:
    """Copy LLM-profile values over the matching tool metadata settings.

    Standalone version of StructureTool._apply_profile_overrides.

    ``tool_metadata["tool_settings"]`` and every entry of
    ``tool_metadata["outputs"]`` are updated in place, in that order.

    Returns:
        One ``"<section>.<key>: <old> -> <new>"`` description per value
        that actually changed.
    """
    # profile field -> settings key it overrides
    key_map = {
        "chunk_overlap": "chunk-overlap",
        "chunk_size": "chunk-size",
        "embedding_model_id": "embedding",
        "llm_id": "llm",
        "similarity_top_k": "similarity-top-k",
        "vector_store_id": "vector-db",
        "x2text_id": "x2text_adapter",
        "retrieval_strategy": "retrieval-strategy",
    }

    applied: list[str] = []

    if "tool_settings" in tool_metadata:
        applied += _override_section(
            section=tool_metadata["tool_settings"],
            profile_data=profile_data,
            mapping=key_map,
            section_name="tool_settings",
        )

    for index, output in enumerate(tool_metadata.get("outputs", ())):
        # Unnamed outputs are labelled by their position.
        label = output.get("name", f"output_{index}")
        applied += _override_section(
            section=output,
            profile_data=profile_data,
            mapping=key_map,
            section_name=f"output[{label}]",
        )

    return applied


def _override_section(
    section: dict,
    profile_data: dict,
    mapping: dict,
    section_name: str = "section",
) -> list[str]:
    """Override values in a section using profile data.

    A key is touched only when the profile provides the source field AND
    the section already has the target key; keys are never added. Each
    change is logged and returned as a description string.
    """
    applied: list[str] = []
    for source_key, target_key in mapping.items():
        if source_key not in profile_data or target_key not in section:
            continue
        current = section[target_key]
        replacement = profile_data[source_key]
        if current != replacement:
            section[target_key] = replacement
            description = f"{section_name}.{target_key}: {current} -> {replacement}"
            applied.append(description)
            logger.info("Overrode %s", description)
    return applied
def _execute_structure_tool_impl(params: dict) -> dict:
    """Run the structure tool pipeline for one file.

    Separated from the task function for testability.

    Phase 5E: Uses a single ``structure_pipeline`` dispatch instead of
    3 sequential ``dispatcher.dispatch()`` calls. The executor worker
    handles the full extract → summarize → index → answer_prompt
    pipeline internally, freeing the file_processing worker slot.

    Args:
        params: Task payload. Required keys: ``organization_id``,
            ``file_execution_id``, ``tool_instance_metadata``,
            ``platform_service_api_key``, ``input_file_path``,
            ``output_dir_path``, ``source_file_name`` and
            ``execution_data_dir``. Optional keys: ``execution_id``,
            ``file_hash`` and ``exec_metadata``.

    Returns:
        A serialized ``ExecutionResult`` dict: the executor's own failure
        result when the dispatch fails, a failure result when the output
        files cannot be written, otherwise a success result carrying the
        structured output.

    Raises:
        KeyError: If a required key is missing from ``params`` or from the
            fetched tool metadata.
        RuntimeError: Propagated from the metadata / profile helpers.
    """
    # ---- Unpack params ----
    organization_id = params["organization_id"]
    execution_id = params.get("execution_id", "")
    file_execution_id = params["file_execution_id"]
    tool_instance_metadata = params["tool_instance_metadata"]
    platform_service_api_key = params["platform_service_api_key"]
    input_file_path = params["input_file_path"]
    output_dir_path = params["output_dir_path"]
    source_file_name = params["source_file_name"]
    execution_data_dir = params["execution_data_dir"]
    file_hash = params.get("file_hash", "")
    # ``or {}`` also covers an explicit ``exec_metadata=None`` in the payload.
    exec_metadata = params.get("exec_metadata") or {}

    # ---- Step 1: Setup ----
    from executor.executor_tool_shim import ExecutorToolShim

    shim = ExecutorToolShim(platform_api_key=platform_service_api_key)

    platform_helper = _create_platform_helper(shim, file_execution_id)
    dispatcher = ExecutionDispatcher(celery_app=app)
    fs = _get_file_storage()

    # ---- Step 2: Fetch tool metadata ----
    prompt_registry_id = tool_instance_metadata.get(_SK.PROMPT_REGISTRY_ID, "")
    logger.info("Fetching exported tool with UUID '%s'", prompt_registry_id)

    tool_metadata, is_agentic = _fetch_tool_metadata(platform_helper, prompt_registry_id)

    # ---- Route agentic vs regular ----
    if is_agentic:
        return _run_agentic_extraction(
            tool_metadata=tool_metadata,
            input_file_path=input_file_path,
            output_dir_path=output_dir_path,
            tool_instance_metadata=tool_instance_metadata,
            dispatcher=dispatcher,
            shim=shim,
            platform_helper=platform_helper,
            file_execution_id=file_execution_id,
            organization_id=organization_id,
            source_file_name=source_file_name,
            fs=fs,
            execution_data_dir=execution_data_dir,
        )

    # ---- Step 3: Profile overrides ----
    _handle_profile_overrides(exec_metadata, platform_helper, tool_metadata)

    # ---- Read workflow-level settings from the tool instance metadata ----
    settings = tool_instance_metadata
    is_challenge_enabled = settings.get(_SK.ENABLE_CHALLENGE, False)
    is_summarization_enabled = settings.get(_SK.SUMMARIZE_AS_SOURCE, False)
    is_single_pass_enabled = settings.get(_SK.SINGLE_PASS_EXTRACTION_MODE, False)
    challenge_llm = settings.get(_SK.CHALLENGE_LLM_ADAPTER_ID, "")
    is_highlight_enabled = settings.get(_SK.ENABLE_HIGHLIGHT, False)
    is_word_confidence_enabled = settings.get(_SK.ENABLE_WORD_CONFIDENCE, False)
    # Diagnostic only: kept at debug level so it does not fire for every
    # processed file in production logs.
    logger.debug(
        "Structure tool flags: highlight=%s word_confidence=%s",
        is_highlight_enabled,
        is_word_confidence_enabled,
    )

    tool_id = tool_metadata[_SK.TOOL_ID]
    tool_settings = tool_metadata[_SK.TOOL_SETTINGS]
    outputs = tool_metadata[_SK.OUTPUTS]

    # Inject workflow-level settings into tool_settings
    tool_settings[_SK.CHALLENGE_LLM] = challenge_llm
    tool_settings[_SK.ENABLE_CHALLENGE] = is_challenge_enabled
    tool_settings[_SK.ENABLE_SINGLE_PASS_EXTRACTION] = is_single_pass_enabled
    tool_settings[_SK.SUMMARIZE_AS_SOURCE] = is_summarization_enabled
    tool_settings[_SK.ENABLE_HIGHLIGHT] = is_highlight_enabled
    tool_settings[_SK.ENABLE_WORD_CONFIDENCE] = is_word_confidence_enabled

    # With summarization on, the answer step is pointed at the SUMMARIZE
    # artifact name instead of the input file's base name.
    file_name = os.path.basename(input_file_path)
    if is_summarization_enabled:
        file_name = _SK.SUMMARIZE

    execution_run_data_folder = Path(execution_data_dir)
    extracted_input_file = str(execution_run_data_folder / _SK.EXTRACT)
    summarize_file_path = str(execution_run_data_folder / _SK.SUMMARIZE)

    # ---- Step 4: Smart table detection ----
    skip_extraction_and_indexing = _should_skip_extraction_for_smart_table(
        input_file_path, outputs
    )
    if skip_extraction_and_indexing:
        logger.info(
            "Skipping extraction and indexing for Excel table with valid JSON schema"
        )

    # ---- Step 5: Build pipeline params ----
    usage_kwargs: dict[Any, Any] = {}
    if not skip_extraction_and_indexing:
        usage_kwargs[UsageKwargs.RUN_ID] = file_execution_id
        usage_kwargs[UsageKwargs.FILE_NAME] = source_file_name
        usage_kwargs[UsageKwargs.EXECUTION_ID] = execution_id

    custom_data = exec_metadata.get(_SK.CUSTOM_DATA, {})
    answer_params = {
        _SK.RUN_ID: file_execution_id,
        _SK.EXECUTION_ID: execution_id,
        _SK.TOOL_SETTINGS: tool_settings,
        _SK.OUTPUTS: outputs,
        _SK.TOOL_ID: tool_id,
        _SK.FILE_HASH: file_hash,
        _SK.FILE_NAME: file_name,
        _SK.FILE_PATH: extracted_input_file,
        _SK.EXECUTION_SOURCE: _SK.TOOL,
        _SK.CUSTOM_DATA: custom_data,
        "PLATFORM_SERVICE_API_KEY": platform_service_api_key,
    }

    extract_params = {
        "x2text_instance_id": tool_settings[_SK.X2TEXT_ADAPTER],
        "file_path": input_file_path,
        "enable_highlight": is_highlight_enabled,
        "output_file_path": extracted_input_file,
        "platform_api_key": platform_service_api_key,
        "usage_kwargs": usage_kwargs,
        "tags": exec_metadata.get("tags"),
        "tool_execution_metadata": exec_metadata,
        "execution_data_dir": str(execution_run_data_folder),
    }

    index_template = {
        "tool_id": tool_id,
        "file_hash": file_hash,
        "is_highlight_enabled": is_highlight_enabled,
        "platform_api_key": platform_service_api_key,
        "extracted_file_path": extracted_input_file,
    }

    pipeline_options = {
        "skip_extraction_and_indexing": skip_extraction_and_indexing,
        "is_summarization_enabled": is_summarization_enabled,
        "is_single_pass_enabled": is_single_pass_enabled,
        "input_file_path": input_file_path,
        "source_file_name": source_file_name,
    }

    # Build summarize params if enabled
    summarize_params = None
    if is_summarization_enabled:
        prompt_keys = [o[_SK.NAME] for o in outputs]
        summarize_params = {
            "llm_adapter_instance_id": tool_settings[_SK.LLM],
            "summarize_prompt": tool_settings.get(_SK.SUMMARIZE_PROMPT, ""),
            "extract_file_path": extracted_input_file,
            "summarize_file_path": summarize_file_path,
            "platform_api_key": platform_service_api_key,
            "prompt_keys": prompt_keys,
        }

    # ---- Step 6: Single dispatch to executor ----
    logger.info(
        "Dispatching structure_pipeline: tool_id=%s "
        "skip_extract=%s summarize=%s single_pass=%s",
        tool_id,
        skip_extraction_and_indexing,
        is_summarization_enabled,
        is_single_pass_enabled,
    )

    pipeline_ctx = ExecutionContext(
        executor_name="legacy",
        operation="structure_pipeline",
        run_id=file_execution_id,
        execution_source="tool",
        organization_id=organization_id,
        request_id=file_execution_id,
        executor_params={
            "extract_params": extract_params,
            "index_template": index_template,
            "answer_params": answer_params,
            "pipeline_options": pipeline_options,
            "summarize_params": summarize_params,
        },
    )
    pipeline_start = time.monotonic()
    pipeline_result = dispatcher.dispatch(pipeline_ctx, timeout=EXECUTOR_TIMEOUT)
    pipeline_elapsed = time.monotonic() - pipeline_start

    if not pipeline_result.success:
        return pipeline_result.to_dict()

    structured_output = pipeline_result.data

    # ---- Step 7: Write output files ----
    # (metadata/metrics merging already done by executor pipeline)
    try:
        output_path = Path(output_dir_path) / f"{Path(source_file_name).stem}.json"
        logger.info("Writing output to %s", output_path)
        fs.json_dump(path=output_path, data=structured_output)

        # Overwrite INFILE with JSON output (matches Docker-based tool behavior).
        # The destination connector reads from INFILE and checks MIME type —
        # if we don't overwrite it, INFILE still has the original PDF.
        logger.info("Overwriting INFILE with structured output: %s", input_file_path)
        fs.json_dump(path=input_file_path, data=structured_output)

        logger.info("Output written successfully to workflow storage")
    except Exception as e:
        # Keep the traceback in the worker log; the caller only receives the
        # flattened error string.
        logger.error("Error writing structure tool output: %s", e, exc_info=True)
        return ExecutionResult.failure(error=f"Error writing output file: {e}").to_dict()

    # Write tool result + tool_metadata to METADATA.json
    # (destination connector reads output_type from tool_metadata)
    _write_tool_result(fs, execution_data_dir, structured_output, pipeline_elapsed)

    return ExecutionResult(success=True, data=structured_output).to_dict()
def _handle_profile_overrides(
    exec_metadata: dict, platform_helper, tool_metadata: dict
) -> None:
    """Apply the LLM profile named in ``exec_metadata`` to ``tool_metadata``.

    Does nothing when no profile id is configured or when the platform
    returns a falsy profile. ``tool_metadata`` is updated in place.

    Raises:
        RuntimeError: Wrapping any exception raised while fetching or
            applying the profile.
    """
    profile_id = exec_metadata.get(_SK.LLM_PROFILE_ID)
    if not profile_id:
        return

    try:
        profile = platform_helper.get_llm_profile(profile_id)
        if not profile:
            return

        logger.info(
            "Applying profile overrides from profile: %s",
            profile.get("profile_name", profile_id),
        )
        applied = _apply_profile_overrides(tool_metadata, profile)
        if not applied:
            logger.info("Profile overrides applied - no changes needed")
            return
        logger.info(
            "Profile overrides applied. Changes: %s",
            "; ".join(applied),
        )
    except Exception as e:
        raise RuntimeError(f"Error applying profile overrides: {e}") from e
Extract text from document using X2Text/LLMWhisperer + x2text = X2Text(tool=shim, adapter_instance_id=llmwhisperer) + extraction_result = x2text.process( + input_file_path=input_file_path, + enable_highlight=enable_highlight, + fs=fs, + ) + document_text = extraction_result.extracted_text + + # Parse json_schema if stored as string + if isinstance(json_schema, str): + json_schema = json.loads(json_schema) + + # 4. Dispatch with flat executor_params matching executor expectations + start_time = time.monotonic() + agentic_ctx = ExecutionContext( + executor_name="agentic", + operation="agentic_extract", + run_id=file_execution_id, + execution_source="tool", + organization_id=organization_id, + request_id=file_execution_id, + executor_params={ + "document_id": file_execution_id, + "document_text": document_text, + "prompt_text": prompt_text, + "schema": json_schema, + "adapter_instance_id": extractor_llm, + "PLATFORM_SERVICE_API_KEY": platform_service_api_key, + "include_source_refs": enable_highlight, + }, + ) + agentic_result = dispatcher.dispatch(agentic_ctx, timeout=EXECUTOR_TIMEOUT) + + if not agentic_result.success: + return agentic_result.to_dict() + + structured_output = agentic_result.data + elapsed = time.monotonic() - start_time + + # Write output files (matches regular pipeline path) + try: + output_path = Path(output_dir_path) / f"{Path(source_file_name).stem}.json" + logger.info("Writing agentic output to %s", output_path) + fs.json_dump(path=output_path, data=structured_output) + + # Overwrite INFILE with JSON output so destination connector reads JSON, not PDF + logger.info("Overwriting INFILE with agentic output: %s", input_file_path) + fs.json_dump(path=input_file_path, data=structured_output) + except Exception as e: + return ExecutionResult.failure( + error=f"Error writing agentic output: {e}" + ).to_dict() + + # Write tool result + tool_metadata to METADATA.json + _write_tool_result(fs, execution_data_dir, structured_output, elapsed) + + return 
def _read_existing_metadata(fs: Any, metadata_path: Path) -> dict:
    """Return the current METADATA.json content, or ``{}`` when unusable.

    Unreadable, unparsable or non-object content is reported with a warning
    and treated as empty so the caller can still write a fresh file.
    """
    if not fs.exists(metadata_path):
        return {}
    try:
        raw = fs.read(path=metadata_path, mode="r")
        loaded = json.loads(raw) if raw else {}
    except Exception as e:
        logger.warning(
            "Ignoring unreadable METADATA.json at %s: %s", metadata_path, e
        )
        return {}
    if not isinstance(loaded, dict):
        logger.warning(
            "Ignoring METADATA.json at %s: expected a JSON object, got %s",
            metadata_path,
            type(loaded).__name__,
        )
        return {}
    return loaded


def _write_tool_result(
    fs: Any, execution_data_dir: str, data: dict, elapsed_time: float = 0.0
) -> None:
    """Write tool result and tool_metadata to METADATA.json.

    Matches BaseTool._update_exec_metadata() + write_tool_result():
    - tool_metadata: list of dicts with tool_name, output_type, elapsed_time
      (destination connector reads output_type from here)
    - total_elapsed_time: cumulative elapsed time
    - tool_result: the structured output data

    Existing keys in the file are preserved. The write is best-effort: any
    failure is logged as a warning and never raised to the caller.

    Args:
        fs: File storage object providing ``exists``, ``read`` and ``write``.
        execution_data_dir: Directory that holds ``METADATA.json``.
        data: Structured output stored under ``tool_result``.
        elapsed_time: Seconds spent by this tool run.
    """
    try:
        metadata_path = Path(execution_data_dir) / "METADATA.json"
        existing = _read_existing_metadata(fs, metadata_path)

        # Add tool_metadata (matches BaseTool._update_exec_metadata)
        # The destination connector reads output_type from tool_metadata[-1]
        tool_meta_entry = {
            "tool_name": "structure_tool",
            "output_type": "JSON",
            "elapsed_time": elapsed_time,
        }
        previous_entries = existing.get("tool_metadata")
        if not isinstance(previous_entries, list):
            if previous_entries is not None:
                # A malformed value must not block writing the entry the
                # destination connector depends on.
                logger.warning(
                    "Resetting non-list tool_metadata in %s", metadata_path
                )
            previous_entries = []
        existing["tool_metadata"] = [*previous_entries, tool_meta_entry]

        existing["total_elapsed_time"] = (
            existing.get("total_elapsed_time", 0.0) + elapsed_time
        )

        # Add tool result
        existing["tool_result"] = data
        fs.write(
            path=metadata_path,
            mode="w",
            data=json.dumps(existing, indent=2),
        )
    except Exception as e:
        logger.warning("Failed to write tool result to METADATA.json: %s", e)
["notification"]="notifications,notifications_webhook,notifications_email,notifications_sms,notifications_priority" ["log_consumer"]="celery_log_task_queue" ["scheduler"]="scheduler" + ["executor"]="celery_executor_legacy" ) # Worker health ports @@ -62,6 +64,7 @@ declare -A WORKER_HEALTH_PORTS=( ["log_consumer"]="8084" ["notification"]="8085" ["scheduler"]="8087" + ["executor"]="8088" ) # Function to print colored output @@ -196,6 +199,7 @@ detect_worker_type_from_args() { *"notifications"*) echo "notification" ;; *"celery_log_task_queue"*) echo "log_consumer" ;; *"scheduler"*) echo "scheduler" ;; + *"executor"*) echo "executor" ;; *"celery"*) echo "general" ;; *) echo "general" ;; # fallback esac @@ -259,6 +263,9 @@ run_worker() { "scheduler") queues="${CELERY_QUEUES_SCHEDULER:-$queues}" ;; + "executor") + queues="${CELERY_QUEUES_EXECUTOR:-$queues}" + ;; esac # Get health port @@ -294,6 +301,10 @@ run_worker() { export SCHEDULER_HEALTH_PORT="${health_port}" export SCHEDULER_METRICS_PORT="${health_port}" ;; + "executor") + export EXECUTOR_HEALTH_PORT="${health_port}" + export EXECUTOR_METRICS_PORT="${health_port}" + ;; *) # Default for pluggable workers local worker_type_upper=$(echo "$worker_type" | tr '[:lower:]' '[:upper:]' | tr '-' '_') @@ -326,6 +337,9 @@ run_worker() { "scheduler") concurrency="${WORKER_SCHEDULER_CONCURRENCY:-2}" ;; + "executor") + concurrency="${WORKER_EXECUTOR_CONCURRENCY:-2}" + ;; *) # Default for pluggable workers or unknown types local worker_type_upper=$(echo "$worker_type" | tr '[:lower:]' '[:upper:]' | tr '-' '_') @@ -534,6 +548,10 @@ if [[ "$1" == *"celery"* ]] || [[ "$1" == *".venv"* ]]; then export SCHEDULER_HEALTH_PORT="8087" export SCHEDULER_METRICS_PORT="8087" ;; + "executor") + export EXECUTOR_HEALTH_PORT="8088" + export EXECUTOR_METRICS_PORT="8088" + ;; *) # Default for pluggable workers - use dynamic port from WORKER_HEALTH_PORTS health_port="${WORKER_HEALTH_PORTS[$WORKER_TYPE]:-8090}" diff --git a/workers/run-worker.sh 
b/workers/run-worker.sh index 152a72d859..abd6931534 100755 --- a/workers/run-worker.sh +++ b/workers/run-worker.sh @@ -37,6 +37,7 @@ declare -A WORKERS=( ["notify"]="notification" ["scheduler"]="scheduler" ["schedule"]="scheduler" + ["executor"]="executor" ["all"]="all" ) @@ -52,6 +53,7 @@ declare -A WORKER_QUEUES=( ["log_consumer"]="celery_log_task_queue" ["notification"]="notifications,notifications_webhook,notifications_email,notifications_sms,notifications_priority" ["scheduler"]="scheduler" + ["executor"]="celery_executor_legacy" ) # Worker health ports @@ -63,6 +65,7 @@ declare -A WORKER_HEALTH_PORTS=( ["log_consumer"]="8084" ["notification"]="8085" ["scheduler"]="8087" + ["executor"]="8088" ) # Function to display usage @@ -80,6 +83,7 @@ WORKER_TYPE: log, log-consumer Run log consumer worker notification, notify Run notification worker scheduler, schedule Run scheduler worker (scheduled pipeline tasks) + executor Run executor worker (extraction execution tasks) all Run all workers (in separate processes, includes auto-discovered pluggable workers) Note: Pluggable workers in pluggable_worker/ directory are automatically discovered and can be run by name. 
@@ -147,6 +151,7 @@ HEALTH CHECKS: - Log Consumer: http://localhost:8084/health - Notification: http://localhost:8085/health - Scheduler: http://localhost:8087/health + - Executor: http://localhost:8088/health - Pluggable workers: http://localhost:8090+/health (auto-assigned ports) EOF @@ -301,7 +306,7 @@ show_status() { print_status $BLUE "Worker Status:" echo "==============" - local workers_to_check="api-deployment general file_processing callback log_consumer notification scheduler" + local workers_to_check="api-deployment general file_processing callback log_consumer notification scheduler executor" # Add discovered pluggable workers if [[ ${#PLUGGABLE_WORKERS[@]} -gt 0 ]]; then @@ -405,6 +410,9 @@ run_worker() { "scheduler") export SCHEDULER_HEALTH_PORT="$health_port" ;; + "executor") + export EXECUTOR_HEALTH_PORT="$health_port" + ;; *) # Handle pluggable workers dynamically if [[ -n "${PLUGGABLE_WORKERS[$worker_type]:-}" ]]; then @@ -478,6 +486,9 @@ run_worker() { "scheduler") cmd_args+=("--concurrency=2") ;; + "executor") + cmd_args+=("--concurrency=2") + ;; *) # Default for pluggable and other workers if [[ -n "${PLUGGABLE_WORKERS[$worker_type]:-}" ]]; then @@ -525,7 +536,7 @@ run_all_workers() { print_status $GREEN "Starting all workers..." 
# Define core workers - local core_workers="api-deployment general file_processing callback log_consumer notification scheduler" + local core_workers="api-deployment general file_processing callback log_consumer notification scheduler executor" # Add discovered pluggable workers if [[ ${#PLUGGABLE_WORKERS[@]} -gt 0 ]]; then diff --git a/workers/shared/enums/task_enums.py b/workers/shared/enums/task_enums.py index 5f57913cd9..6f3fa1cdd7 100644 --- a/workers/shared/enums/task_enums.py +++ b/workers/shared/enums/task_enums.py @@ -33,6 +33,12 @@ class TaskName(str, Enum): # API deployment worker tasks CHECK_API_DEPLOYMENT_STATUS = "check_api_deployment_status" + # Structure tool task (runs in file_processing worker) + EXECUTE_STRUCTURE_TOOL = "execute_structure_tool" + + # Executor worker tasks + EXECUTE_EXTRACTION = "execute_extraction" + def __str__(self): """Return enum value for Celery task naming.""" return self.value diff --git a/workers/shared/enums/worker_enums_base.py b/workers/shared/enums/worker_enums_base.py index babc19512f..3ed5a6ff35 100644 --- a/workers/shared/enums/worker_enums_base.py +++ b/workers/shared/enums/worker_enums_base.py @@ -23,6 +23,7 @@ class WorkerType(str, Enum): NOTIFICATION = "notification" LOG_CONSUMER = "log_consumer" SCHEDULER = "scheduler" + EXECUTOR = "executor" @classmethod def from_directory_name(cls, name: str) -> "WorkerType": @@ -110,6 +111,7 @@ def to_health_port(self) -> int: WorkerType.NOTIFICATION: 8085, WorkerType.LOG_CONSUMER: 8086, WorkerType.SCHEDULER: 8087, + WorkerType.EXECUTOR: 8088, } return port_mapping.get(self, 8080) @@ -147,6 +149,11 @@ class QueueName(str, Enum): # Scheduler queue SCHEDULER = "scheduler" + # Executor queue — queue-per-executor naming convention. + # The dispatcher derives queue names as ``celery_executor_{executor_name}``. + # The "legacy" executor is the default OSS executor. 
+ EXECUTOR = "celery_executor_legacy" + def to_env_var_name(self) -> str: """Convert queue name to environment variable name. diff --git a/workers/shared/infrastructure/config/registry.py b/workers/shared/infrastructure/config/registry.py index 37ad1c08b9..8d1b208032 100644 --- a/workers/shared/infrastructure/config/registry.py +++ b/workers/shared/infrastructure/config/registry.py @@ -64,6 +64,9 @@ class WorkerRegistry: WorkerType.SCHEDULER: WorkerQueueConfig( primary_queue=QueueName.SCHEDULER, additional_queues=[QueueName.GENERAL] ), + WorkerType.EXECUTOR: WorkerQueueConfig( + primary_queue=QueueName.EXECUTOR, + ), } # Pluggable worker configurations loaded dynamically @@ -134,6 +137,13 @@ class WorkerRegistry: TaskRoute("scheduler.tasks.*", QueueName.SCHEDULER), ], ), + WorkerType.EXECUTOR: WorkerTaskRouting( + worker_type=WorkerType.EXECUTOR, + routes=[ + TaskRoute("execute_extraction", QueueName.EXECUTOR), + TaskRoute("executor.tasks.*", QueueName.EXECUTOR), + ], + ), } # Pluggable worker task routes loaded dynamically @@ -171,6 +181,9 @@ class WorkerRegistry: WorkerType.SCHEDULER: { "log_level": "INFO", }, + WorkerType.EXECUTOR: { + "log_level": "INFO", + }, } # Pluggable worker logging configs loaded dynamically diff --git a/workers/shared/workflow/execution/service.py b/workers/shared/workflow/execution/service.py index e38e372a91..0ad7b5c67f 100644 --- a/workers/shared/workflow/execution/service.py +++ b/workers/shared/workflow/execution/service.py @@ -971,17 +971,115 @@ def _prepare_workflow_input_file( def _build_and_execute_workflow( self, execution_service: WorkflowExecutionService, file_name: str ) -> None: - """Build and execute the workflow.""" - # Build workflow - execution_service.build_workflow() - logger.info(f"Workflow built successfully for file {file_name}") + """Build and execute the workflow. 
- # Execute workflow - from unstract.workflow_execution.enums import ExecutionType + When the async_prompt_execution flag is ON, detects structure tool + workflows and routes them to the Celery-based execute_structure_tool + task instead of the Docker container flow. When the flag is OFF (or + unreachable), all workflows use the original Docker flow. + """ + from unstract.flags.feature_flag import check_feature_flag_status + + use_async = False + try: + use_async = check_feature_flag_status("async_prompt_execution") + except Exception: + logger.warning("Feature flag check failed, using Docker flow") + + if use_async and self._is_structure_tool_workflow(execution_service): + self._execute_structure_tool_workflow(execution_service, file_name) + else: + # Original Docker-based flow (for non-structure tools or flag OFF) + execution_service.build_workflow() + logger.info(f"Workflow built successfully for file {file_name}") - execution_service.execute_workflow(ExecutionType.COMPLETE) + from unstract.workflow_execution.enums import ExecutionType + + execution_service.execute_workflow(ExecutionType.COMPLETE) logger.info(f"Workflow executed successfully for file {file_name}") + def _is_structure_tool_workflow( + self, execution_service: WorkflowExecutionService + ) -> bool: + """Check if workflow uses the structure tool. + + Compares the base image name (last path component without tag) + to handle registry prefixes like gcr.io/project/tool-structure + vs the default unstract/tool-structure. 
+ """ + structure_image = os.environ.get( + "STRUCTURE_TOOL_IMAGE_NAME", "unstract/tool-structure" + ) + structure_base = structure_image.split(":")[0].rsplit("/", 1)[-1] + for ti in execution_service.tool_instances: + ti_name = str(ti.image_name) if ti.image_name else "" + if not ti_name: + continue + ti_base = ti_name.split(":")[0].rsplit("/", 1)[-1] + if ti_name == structure_image or ti_base == structure_base: + logger.info( + "Detected structure tool workflow " + f"(image={ti_name}, expected={structure_image})" + ) + return True + return False + + def _execute_structure_tool_workflow( + self, execution_service: WorkflowExecutionService, file_name: str + ) -> None: + """Execute structure tool as Celery task instead of Docker container. + + Calls execute_structure_tool directly (same process, in-band). + Only the inner ExecutionDispatcher calls go through Celery to + the executor worker. + """ + from file_processing.structure_tool_task import ( + execute_structure_tool as _execute_structure_tool, + ) + + tool_instance = execution_service.tool_instances[0] + file_handler = execution_service.file_handler + + # Read metadata from METADATA.json for file_hash and exec_metadata + metadata = {} + try: + metadata = file_handler.get_workflow_metadata() + except Exception as e: + logger.warning(f"Could not read workflow metadata: {e}") + + # Get API key from the same source used to create execution_service + platform_api_key = self._get_platform_service_api_key( + execution_service.organization_id + ) + + params = { + "organization_id": execution_service.organization_id, + "workflow_id": execution_service.workflow_id, + "execution_id": execution_service.execution_id, + "file_execution_id": execution_service.file_execution_id, + "tool_instance_metadata": tool_instance.metadata, + "platform_service_api_key": platform_api_key, + "input_file_path": str(file_handler.infile), + "output_dir_path": str(file_handler.execution_dir), + "source_file_name": str( + 
os.path.basename(file_handler.source_file) + if file_handler.source_file + else file_name + ), + "execution_data_dir": str(file_handler.file_execution_dir), + "messaging_channel": getattr(execution_service, "messaging_channel", ""), + "file_hash": metadata.get("source_hash", ""), + "exec_metadata": metadata, + } + + # Call synchronously (same process, in-band) + result = _execute_structure_tool(params) + + if not result.get("success"): + raise Exception( + f"Structure tool failed: {result.get('error', 'Unknown error')}" + ) + def _extract_source_connector_details( self, source_config: dict[str, Any] | None ) -> tuple[str | None, dict[str, Any]]: diff --git a/workers/tests/__init__.py b/workers/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/workers/tests/conftest.py b/workers/tests/conftest.py new file mode 100644 index 0000000000..084a8ef88c --- /dev/null +++ b/workers/tests/conftest.py @@ -0,0 +1,14 @@ +"""Shared fixtures for workers tests. + +Environment variables are loaded from .env.test at module level +BEFORE any shared package imports. This is required because +shared/constants/api_endpoints.py raises ValueError at import +time if INTERNAL_API_BASE_URL is not set. +""" + +from pathlib import Path + +from dotenv import load_dotenv + +_env_test = Path(__file__).resolve().parent.parent / ".env.test" +load_dotenv(_env_test) diff --git a/workers/tests/test_answer_prompt.py b/workers/tests/test_answer_prompt.py new file mode 100644 index 0000000000..53dfd4d79f --- /dev/null +++ b/workers/tests/test_answer_prompt.py @@ -0,0 +1,861 @@ +"""Tests for the answer_prompt pipeline (Phase 2E). + +Tests the _handle_answer_prompt method, AnswerPromptService, +VariableReplacementService, and type conversion logic. +All heavy dependencies (LLM, VectorDB, etc.) are mocked. 
+""" + +from unittest.mock import MagicMock, patch + +import pytest + +from executor.executors.constants import ( + PromptServiceConstants as PSKeys, +) +from unstract.sdk1.execution.context import ExecutionContext, Operation + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_prompt( + name: str = "field_a", + prompt: str = "What is the revenue?", + output_type: str = "text", + chunk_size: int = 512, + chunk_overlap: int = 128, + retrieval_strategy: str = "simple", + llm_id: str = "llm-1", + embedding_id: str = "emb-1", + vector_db_id: str = "vdb-1", + x2text_id: str = "x2t-1", + similarity_top_k: int = 5, +): + """Build a minimal prompt definition dict.""" + return { + PSKeys.NAME: name, + PSKeys.PROMPT: prompt, + PSKeys.TYPE: output_type, + PSKeys.CHUNK_SIZE: chunk_size, + PSKeys.CHUNK_OVERLAP: chunk_overlap, + PSKeys.RETRIEVAL_STRATEGY: retrieval_strategy, + PSKeys.LLM: llm_id, + PSKeys.EMBEDDING: embedding_id, + PSKeys.VECTOR_DB: vector_db_id, + PSKeys.X2TEXT_ADAPTER: x2text_id, + PSKeys.SIMILARITY_TOP_K: similarity_top_k, + } + + +def _make_context( + prompts=None, + tool_settings=None, + file_hash="abc123", + file_path="/data/doc.txt", + file_name="doc.txt", + execution_source="ide", + platform_api_key="pk-test", + run_id="run-1", +): + """Build an ExecutionContext for answer_prompt.""" + if prompts is None: + prompts = [_make_prompt()] + if tool_settings is None: + tool_settings = {} + + params = { + PSKeys.OUTPUTS: prompts, + PSKeys.TOOL_SETTINGS: tool_settings, + PSKeys.TOOL_ID: "tool-1", + PSKeys.EXECUTION_ID: "exec-1", + PSKeys.FILE_HASH: file_hash, + PSKeys.FILE_PATH: file_path, + PSKeys.FILE_NAME: file_name, + PSKeys.LOG_EVENTS_ID: "", + PSKeys.CUSTOM_DATA: {}, + PSKeys.EXECUTION_SOURCE: execution_source, + PSKeys.PLATFORM_SERVICE_API_KEY: platform_api_key, + } + return ExecutionContext( + executor_name="legacy", + 
operation=Operation.ANSWER_PROMPT.value, + executor_params=params, + run_id=run_id, + execution_source=execution_source, + ) + + +def _mock_llm(): + """Create a mock LLM that returns a configurable answer.""" + llm = MagicMock(name="llm") + response = MagicMock() + response.text = "test answer" + llm.complete.return_value = { + PSKeys.RESPONSE: response, + PSKeys.HIGHLIGHT_DATA: [], + PSKeys.CONFIDENCE_DATA: None, + PSKeys.WORD_CONFIDENCE_DATA: None, + PSKeys.LINE_NUMBERS: [], + PSKeys.WHISPER_HASH: "", + } + llm.get_usage_reason.return_value = "extraction" + llm.get_metrics.return_value = {"tokens": 100} + return llm + + +def _mock_deps(llm=None): + """Return a tuple of mocked prompt deps matching _get_prompt_deps().""" + if llm is None: + llm = _mock_llm() + + # AnswerPromptService — use the real class + from executor.executors.answer_prompt import AnswerPromptService + + RetrievalService = MagicMock(name="RetrievalService") + RetrievalService.run_retrieval.return_value = ["chunk1", "chunk2"] + RetrievalService.retrieve_complete_context.return_value = ["full content"] + + VariableReplacementService = MagicMock(name="VariableReplacementService") + VariableReplacementService.is_variables_present.return_value = False + + Index = MagicMock(name="Index") + index_instance = MagicMock() + index_instance.generate_index_key.return_value = "doc-id-1" + Index.return_value = index_instance + + LLM_cls = MagicMock(name="LLM") + LLM_cls.return_value = llm + + EmbeddingCompat = MagicMock(name="EmbeddingCompat") + VectorDB = MagicMock(name="VectorDB") + + return ( + AnswerPromptService, + RetrievalService, + VariableReplacementService, + Index, + LLM_cls, + EmbeddingCompat, + VectorDB, + ) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +_PATCH_INDEX_UTILS = ( + "unstract.sdk1.utils.indexing.IndexingUtils.generate_index_key" +) + + 
+@pytest.fixture(autouse=True) +def _mock_indexing_utils(): + """Mock IndexingUtils.generate_index_key for all answer_prompt tests. + + _handle_answer_prompt calls IndexingUtils.generate_index_key(tool=shim) + which delegates to PlatformHelper.get_adapter_config() — a real HTTP + call. Since tests use a mock shim, the platform URL is invalid. + """ + with patch(_PATCH_INDEX_UTILS, return_value="doc-id-test"): + yield + + +# --------------------------------------------------------------------------- +# Tests — _handle_answer_prompt +# --------------------------------------------------------------------------- + +class TestHandleAnswerPromptText: + """Tests for TEXT type prompts.""" + + @patch( + "executor.executors.legacy_executor.LegacyExecutor._get_prompt_deps" + ) + @patch("executor.executors.legacy_executor.ExecutorToolShim") + def test_text_prompt_returns_success(self, mock_shim_cls, mock_deps): + """Simple TEXT prompt returns success with structured output.""" + from executor.executors.legacy_executor import LegacyExecutor + + llm = _mock_llm() + mock_deps.return_value = _mock_deps(llm) + mock_shim_cls.return_value = MagicMock() + + executor = LegacyExecutor() + ctx = _make_context() + result = executor._handle_answer_prompt(ctx) + + assert result.success is True + assert PSKeys.OUTPUT in result.data + assert PSKeys.METADATA in result.data + assert PSKeys.METRICS in result.data + assert "field_a" in result.data[PSKeys.OUTPUT] + + @patch( + "executor.executors.legacy_executor.LegacyExecutor._get_prompt_deps" + ) + @patch("executor.executors.legacy_executor.ExecutorToolShim") + def test_text_prompt_answer_stored(self, mock_shim_cls, mock_deps): + """The LLM answer is stored in structured_output.""" + from executor.executors.legacy_executor import LegacyExecutor + + llm = _mock_llm() + mock_deps.return_value = _mock_deps(llm) + mock_shim_cls.return_value = MagicMock() + + executor = LegacyExecutor() + ctx = _make_context() + result = 
executor._handle_answer_prompt(ctx) + + assert result.data[PSKeys.OUTPUT]["field_a"] == "test answer" + + @patch( + "executor.executors.legacy_executor.LegacyExecutor._get_prompt_deps" + ) + @patch("executor.executors.legacy_executor.ExecutorToolShim") + def test_trailing_newline_stripped(self, mock_shim_cls, mock_deps): + """Trailing newlines are stripped from text answers.""" + from executor.executors.legacy_executor import LegacyExecutor + + llm = _mock_llm() + response = MagicMock() + response.text = "answer with trailing\n" + llm.complete.return_value = { + PSKeys.RESPONSE: response, + PSKeys.HIGHLIGHT_DATA: [], + PSKeys.CONFIDENCE_DATA: None, + PSKeys.WORD_CONFIDENCE_DATA: None, + PSKeys.LINE_NUMBERS: [], + PSKeys.WHISPER_HASH: "", + } + mock_deps.return_value = _mock_deps(llm) + mock_shim_cls.return_value = MagicMock() + + executor = LegacyExecutor() + result = executor._handle_answer_prompt(_make_context()) + + assert result.data[PSKeys.OUTPUT]["field_a"] == "answer with trailing" + + +class TestHandleAnswerPromptTypes: + """Tests for type-specific post-processing.""" + + @patch( + "executor.executors.legacy_executor.LegacyExecutor._get_prompt_deps" + ) + @patch("executor.executors.legacy_executor.ExecutorToolShim") + def test_number_type_converts_to_float(self, mock_shim_cls, mock_deps): + """NUMBER type converts answer to float.""" + from executor.executors.legacy_executor import LegacyExecutor + + llm = _mock_llm() + # First call: main retrieval answer. Second call: number extraction. 
+ response1 = MagicMock() + response1.text = "revenue is $42.5M" + response2 = MagicMock() + response2.text = "42500000" + llm.complete.side_effect = [ + {PSKeys.RESPONSE: response1, PSKeys.HIGHLIGHT_DATA: [], + PSKeys.CONFIDENCE_DATA: None, PSKeys.WORD_CONFIDENCE_DATA: None, + PSKeys.LINE_NUMBERS: [], PSKeys.WHISPER_HASH: ""}, + {PSKeys.RESPONSE: response2, PSKeys.HIGHLIGHT_DATA: [], + PSKeys.CONFIDENCE_DATA: None, PSKeys.WORD_CONFIDENCE_DATA: None, + PSKeys.LINE_NUMBERS: [], PSKeys.WHISPER_HASH: ""}, + ] + mock_deps.return_value = _mock_deps(llm) + mock_shim_cls.return_value = MagicMock() + + executor = LegacyExecutor() + ctx = _make_context(prompts=[_make_prompt(output_type="number")]) + result = executor._handle_answer_prompt(ctx) + + assert result.data[PSKeys.OUTPUT]["field_a"] == 42500000.0 + + @patch( + "executor.executors.legacy_executor.LegacyExecutor._get_prompt_deps" + ) + @patch("executor.executors.legacy_executor.ExecutorToolShim") + def test_number_na_returns_none(self, mock_shim_cls, mock_deps): + """NUMBER type with NA answer returns None.""" + from executor.executors.legacy_executor import LegacyExecutor + + llm = _mock_llm() + response = MagicMock() + response.text = "NA" + llm.complete.return_value = { + PSKeys.RESPONSE: response, PSKeys.HIGHLIGHT_DATA: [], + PSKeys.CONFIDENCE_DATA: None, PSKeys.WORD_CONFIDENCE_DATA: None, + PSKeys.LINE_NUMBERS: [], PSKeys.WHISPER_HASH: "", + } + mock_deps.return_value = _mock_deps(llm) + mock_shim_cls.return_value = MagicMock() + + executor = LegacyExecutor() + ctx = _make_context(prompts=[_make_prompt(output_type="number")]) + result = executor._handle_answer_prompt(ctx) + + # NA → sanitized to None + assert result.data[PSKeys.OUTPUT]["field_a"] is None + + @patch( + "executor.executors.legacy_executor.LegacyExecutor._get_prompt_deps" + ) + @patch("executor.executors.legacy_executor.ExecutorToolShim") + def test_boolean_yes(self, mock_shim_cls, mock_deps): + """BOOLEAN type converts 'yes' to True.""" + from 
executor.executors.legacy_executor import LegacyExecutor + + llm = _mock_llm() + response1 = MagicMock() + response1.text = "The document confirms it" + response2 = MagicMock() + response2.text = "yes" + llm.complete.side_effect = [ + {PSKeys.RESPONSE: response1, PSKeys.HIGHLIGHT_DATA: [], + PSKeys.CONFIDENCE_DATA: None, PSKeys.WORD_CONFIDENCE_DATA: None, + PSKeys.LINE_NUMBERS: [], PSKeys.WHISPER_HASH: ""}, + {PSKeys.RESPONSE: response2, PSKeys.HIGHLIGHT_DATA: [], + PSKeys.CONFIDENCE_DATA: None, PSKeys.WORD_CONFIDENCE_DATA: None, + PSKeys.LINE_NUMBERS: [], PSKeys.WHISPER_HASH: ""}, + ] + mock_deps.return_value = _mock_deps(llm) + mock_shim_cls.return_value = MagicMock() + + executor = LegacyExecutor() + ctx = _make_context(prompts=[_make_prompt(output_type="boolean")]) + result = executor._handle_answer_prompt(ctx) + + assert result.data[PSKeys.OUTPUT]["field_a"] is True + + @patch( + "executor.executors.legacy_executor.LegacyExecutor._get_prompt_deps" + ) + @patch("executor.executors.legacy_executor.ExecutorToolShim") + def test_boolean_no(self, mock_shim_cls, mock_deps): + """BOOLEAN type converts 'no' to False.""" + from executor.executors.legacy_executor import LegacyExecutor + + llm = _mock_llm() + response1 = MagicMock() + response1.text = "not confirmed" + response2 = MagicMock() + response2.text = "no" + llm.complete.side_effect = [ + {PSKeys.RESPONSE: response1, PSKeys.HIGHLIGHT_DATA: [], + PSKeys.CONFIDENCE_DATA: None, PSKeys.WORD_CONFIDENCE_DATA: None, + PSKeys.LINE_NUMBERS: [], PSKeys.WHISPER_HASH: ""}, + {PSKeys.RESPONSE: response2, PSKeys.HIGHLIGHT_DATA: [], + PSKeys.CONFIDENCE_DATA: None, PSKeys.WORD_CONFIDENCE_DATA: None, + PSKeys.LINE_NUMBERS: [], PSKeys.WHISPER_HASH: ""}, + ] + mock_deps.return_value = _mock_deps(llm) + mock_shim_cls.return_value = MagicMock() + + executor = LegacyExecutor() + ctx = _make_context(prompts=[_make_prompt(output_type="boolean")]) + result = executor._handle_answer_prompt(ctx) + + assert 
result.data[PSKeys.OUTPUT]["field_a"] is False + + @patch( + "executor.executors.legacy_executor.LegacyExecutor._get_prompt_deps" + ) + @patch("executor.executors.legacy_executor.ExecutorToolShim") + def test_email_type(self, mock_shim_cls, mock_deps): + """EMAIL type extracts email address.""" + from executor.executors.legacy_executor import LegacyExecutor + + llm = _mock_llm() + response1 = MagicMock() + response1.text = "Contact: user@example.com" + response2 = MagicMock() + response2.text = "user@example.com" + llm.complete.side_effect = [ + {PSKeys.RESPONSE: response1, PSKeys.HIGHLIGHT_DATA: [], + PSKeys.CONFIDENCE_DATA: None, PSKeys.WORD_CONFIDENCE_DATA: None, + PSKeys.LINE_NUMBERS: [], PSKeys.WHISPER_HASH: ""}, + {PSKeys.RESPONSE: response2, PSKeys.HIGHLIGHT_DATA: [], + PSKeys.CONFIDENCE_DATA: None, PSKeys.WORD_CONFIDENCE_DATA: None, + PSKeys.LINE_NUMBERS: [], PSKeys.WHISPER_HASH: ""}, + ] + mock_deps.return_value = _mock_deps(llm) + mock_shim_cls.return_value = MagicMock() + + executor = LegacyExecutor() + ctx = _make_context(prompts=[_make_prompt(output_type="email")]) + result = executor._handle_answer_prompt(ctx) + + assert result.data[PSKeys.OUTPUT]["field_a"] == "user@example.com" + + @patch( + "executor.executors.legacy_executor.LegacyExecutor._get_prompt_deps" + ) + @patch("executor.executors.legacy_executor.ExecutorToolShim") + def test_date_type(self, mock_shim_cls, mock_deps): + """DATE type extracts date in ISO format.""" + from executor.executors.legacy_executor import LegacyExecutor + + llm = _mock_llm() + response1 = MagicMock() + response1.text = "The date is January 15, 2024" + response2 = MagicMock() + response2.text = "2024-01-15" + llm.complete.side_effect = [ + {PSKeys.RESPONSE: response1, PSKeys.HIGHLIGHT_DATA: [], + PSKeys.CONFIDENCE_DATA: None, PSKeys.WORD_CONFIDENCE_DATA: None, + PSKeys.LINE_NUMBERS: [], PSKeys.WHISPER_HASH: ""}, + {PSKeys.RESPONSE: response2, PSKeys.HIGHLIGHT_DATA: [], + PSKeys.CONFIDENCE_DATA: None, 
PSKeys.WORD_CONFIDENCE_DATA: None, + PSKeys.LINE_NUMBERS: [], PSKeys.WHISPER_HASH: ""}, + ] + mock_deps.return_value = _mock_deps(llm) + mock_shim_cls.return_value = MagicMock() + + executor = LegacyExecutor() + ctx = _make_context(prompts=[_make_prompt(output_type="date")]) + result = executor._handle_answer_prompt(ctx) + + assert result.data[PSKeys.OUTPUT]["field_a"] == "2024-01-15" + + +class TestHandleAnswerPromptJSON: + """Tests for JSON type handling.""" + + @patch( + "executor.executors.legacy_executor.LegacyExecutor._get_prompt_deps" + ) + @patch("executor.executors.legacy_executor.ExecutorToolShim") + def test_json_parsed(self, mock_shim_cls, mock_deps): + """JSON type parses valid JSON from answer.""" + from executor.executors.legacy_executor import LegacyExecutor + + llm = _mock_llm() + response = MagicMock() + response.text = '{"key": "value"}' + llm.complete.return_value = { + PSKeys.RESPONSE: response, PSKeys.HIGHLIGHT_DATA: [], + PSKeys.CONFIDENCE_DATA: None, PSKeys.WORD_CONFIDENCE_DATA: None, + PSKeys.LINE_NUMBERS: [], PSKeys.WHISPER_HASH: "", + } + mock_deps.return_value = _mock_deps(llm) + mock_shim_cls.return_value = MagicMock() + + executor = LegacyExecutor() + ctx = _make_context(prompts=[_make_prompt(output_type="json")]) + result = executor._handle_answer_prompt(ctx) + + assert result.data[PSKeys.OUTPUT]["field_a"] == {"key": "value"} + + @patch( + "executor.executors.legacy_executor.LegacyExecutor._get_prompt_deps" + ) + @patch("executor.executors.legacy_executor.ExecutorToolShim") + def test_json_na_returns_none(self, mock_shim_cls, mock_deps): + """JSON type with NA answer returns None.""" + from executor.executors.legacy_executor import LegacyExecutor + + llm = _mock_llm() + response = MagicMock() + response.text = "NA" + llm.complete.return_value = { + PSKeys.RESPONSE: response, PSKeys.HIGHLIGHT_DATA: [], + PSKeys.CONFIDENCE_DATA: None, PSKeys.WORD_CONFIDENCE_DATA: None, + PSKeys.LINE_NUMBERS: [], PSKeys.WHISPER_HASH: "", + } + 
mock_deps.return_value = _mock_deps(llm) + mock_shim_cls.return_value = MagicMock() + + executor = LegacyExecutor() + ctx = _make_context(prompts=[_make_prompt(output_type="json")]) + result = executor._handle_answer_prompt(ctx) + + assert result.data[PSKeys.OUTPUT]["field_a"] is None + + +class TestHandleAnswerPromptRetrieval: + """Tests for retrieval integration.""" + + @patch( + "executor.executors.legacy_executor.LegacyExecutor._get_prompt_deps" + ) + @patch("executor.executors.legacy_executor.ExecutorToolShim") + def test_chunked_retrieval_uses_run_retrieval( + self, mock_shim_cls, mock_deps + ): + """chunk_size > 0 uses RetrievalService.run_retrieval.""" + from executor.executors.legacy_executor import LegacyExecutor + + llm = _mock_llm() + deps = _mock_deps(llm) + _, RetrievalService, *_ = deps + mock_deps.return_value = deps + mock_shim_cls.return_value = MagicMock() + + executor = LegacyExecutor() + ctx = _make_context( + prompts=[_make_prompt(chunk_size=512)] + ) + result = executor._handle_answer_prompt(ctx) + + RetrievalService.run_retrieval.assert_called_once() + assert result.success is True + + @patch( + "executor.executors.legacy_executor.LegacyExecutor._get_prompt_deps" + ) + @patch("executor.executors.legacy_executor.ExecutorToolShim") + def test_complete_context_for_chunk_zero( + self, mock_shim_cls, mock_deps + ): + """chunk_size=0 uses RetrievalService.retrieve_complete_context.""" + from executor.executors.legacy_executor import LegacyExecutor + + llm = _mock_llm() + deps = _mock_deps(llm) + _, RetrievalService, *_ = deps + mock_deps.return_value = deps + mock_shim_cls.return_value = MagicMock() + + executor = LegacyExecutor() + ctx = _make_context( + prompts=[_make_prompt(chunk_size=0)] + ) + result = executor._handle_answer_prompt(ctx) + + RetrievalService.retrieve_complete_context.assert_called_once() + assert result.success is True + + @patch( + "executor.executors.legacy_executor.LegacyExecutor._get_prompt_deps" + ) + 
@patch("executor.executors.legacy_executor.ExecutorToolShim") + def test_context_stored_in_metadata(self, mock_shim_cls, mock_deps): + """Retrieved context is stored in metadata.""" + from executor.executors.legacy_executor import LegacyExecutor + + llm = _mock_llm() + mock_deps.return_value = _mock_deps(llm) + mock_shim_cls.return_value = MagicMock() + + executor = LegacyExecutor() + result = executor._handle_answer_prompt(_make_context()) + + metadata = result.data[PSKeys.METADATA] + assert "field_a" in metadata[PSKeys.CONTEXT] + + @patch( + "executor.executors.legacy_executor.LegacyExecutor._get_prompt_deps" + ) + @patch("executor.executors.legacy_executor.ExecutorToolShim") + def test_invalid_strategy_skips_retrieval( + self, mock_shim_cls, mock_deps + ): + """Invalid retrieval strategy skips retrieval, answer stays NA.""" + from executor.executors.legacy_executor import LegacyExecutor + + llm = _mock_llm() + mock_deps.return_value = _mock_deps(llm) + mock_shim_cls.return_value = MagicMock() + + executor = LegacyExecutor() + ctx = _make_context( + prompts=[_make_prompt(retrieval_strategy="nonexistent")] + ) + result = executor._handle_answer_prompt(ctx) + + # Answer stays "NA" which gets sanitized to None + assert result.data[PSKeys.OUTPUT]["field_a"] is None + + +class TestHandleAnswerPromptMultiPrompt: + """Tests for multi-prompt processing.""" + + @patch( + "executor.executors.legacy_executor.LegacyExecutor._get_prompt_deps" + ) + @patch("executor.executors.legacy_executor.ExecutorToolShim") + def test_multiple_prompts(self, mock_shim_cls, mock_deps): + """Multiple prompts are all processed.""" + from executor.executors.legacy_executor import LegacyExecutor + + llm = _mock_llm() + mock_deps.return_value = _mock_deps(llm) + mock_shim_cls.return_value = MagicMock() + + prompts = [ + _make_prompt(name="revenue"), + _make_prompt(name="date_signed", output_type="text"), + ] + executor = LegacyExecutor() + ctx = _make_context(prompts=prompts) + result = 
executor._handle_answer_prompt(ctx) + + output = result.data[PSKeys.OUTPUT] + assert "revenue" in output + assert "date_signed" in output + + +class TestHandleAnswerPromptErrors: + """Tests for error handling.""" + + @patch( + "executor.executors.legacy_executor.LegacyExecutor._get_prompt_deps" + ) + @patch("executor.executors.legacy_executor.ExecutorToolShim") + def test_table_type_raises_error(self, mock_shim_cls, mock_deps): + """TABLE type raises LegacyExecutorError (plugins not available).""" + from executor.executors.legacy_executor import LegacyExecutor + + llm = _mock_llm() + mock_deps.return_value = _mock_deps(llm) + mock_shim_cls.return_value = MagicMock() + + executor = LegacyExecutor() + ctx = _make_context( + prompts=[_make_prompt(output_type="table")] + ) + # TABLE raises LegacyExecutorError which is caught by execute() + result = executor.execute(ctx) + assert result.success is False + assert "TABLE" in result.error + + @patch( + "executor.executors.legacy_executor.LegacyExecutor._get_prompt_deps" + ) + @patch("executor.executors.legacy_executor.ExecutorToolShim") + def test_line_item_type_raises_error(self, mock_shim_cls, mock_deps): + """LINE_ITEM type raises LegacyExecutorError.""" + from executor.executors.legacy_executor import LegacyExecutor + + llm = _mock_llm() + mock_deps.return_value = _mock_deps(llm) + mock_shim_cls.return_value = MagicMock() + + executor = LegacyExecutor() + ctx = _make_context( + prompts=[_make_prompt(output_type="line-item")] + ) + result = executor.execute(ctx) + assert result.success is False + assert "LINE_ITEM" in result.error + + +class TestHandleAnswerPromptMetrics: + """Tests for metrics collection.""" + + @patch( + "executor.executors.legacy_executor.LegacyExecutor._get_prompt_deps" + ) + @patch("executor.executors.legacy_executor.ExecutorToolShim") + def test_metrics_collected(self, mock_shim_cls, mock_deps): + """Metrics include context_retrieval and LLM metrics.""" + from executor.executors.legacy_executor 
import LegacyExecutor + + llm = _mock_llm() + mock_deps.return_value = _mock_deps(llm) + mock_shim_cls.return_value = MagicMock() + + executor = LegacyExecutor() + result = executor._handle_answer_prompt(_make_context()) + + metrics = result.data[PSKeys.METRICS] + assert "field_a" in metrics + assert "context_retrieval" in metrics["field_a"] + assert "extraction_llm" in metrics["field_a"] + + @patch( + "executor.executors.legacy_executor.LegacyExecutor._get_prompt_deps" + ) + @patch("executor.executors.legacy_executor.ExecutorToolShim") + def test_vectordb_closed(self, mock_shim_cls, mock_deps): + """VectorDB is closed after processing.""" + from executor.executors.legacy_executor import LegacyExecutor + + llm = _mock_llm() + deps = _mock_deps(llm) + mock_deps.return_value = deps + _, _, _, _, _, _, VectorDB = deps + vdb_instance = MagicMock() + VectorDB.return_value = vdb_instance + mock_shim_cls.return_value = MagicMock() + + executor = LegacyExecutor() + executor._handle_answer_prompt(_make_context()) + + vdb_instance.close.assert_called_once() + + +class TestNullSanitization: + """Tests for _sanitize_null_values.""" + + def test_na_string_becomes_none(self): + """Top-level 'NA' string → None.""" + from executor.executors.legacy_executor import LegacyExecutor + + output = {"field": "NA"} + result = LegacyExecutor._sanitize_null_values(output) + assert result["field"] is None + + def test_na_case_insensitive(self): + """'na' (lowercase) → None.""" + from executor.executors.legacy_executor import LegacyExecutor + + output = {"field": "na"} + result = LegacyExecutor._sanitize_null_values(output) + assert result["field"] is None + + def test_nested_list_na(self): + """NA in nested list items → None.""" + from executor.executors.legacy_executor import LegacyExecutor + + output = {"field": ["value", "NA", "other"]} + result = LegacyExecutor._sanitize_null_values(output) + assert result["field"] == ["value", None, "other"] + + def test_nested_dict_in_list_na(self): + 
"""NA in dicts inside lists → None.""" + from executor.executors.legacy_executor import LegacyExecutor + + output = {"field": [{"a": "NA", "b": "ok"}]} + result = LegacyExecutor._sanitize_null_values(output) + assert result["field"] == [{"a": None, "b": "ok"}] + + def test_nested_dict_na(self): + """NA in nested dict values → None.""" + from executor.executors.legacy_executor import LegacyExecutor + + output = {"field": {"a": "NA", "b": "ok"}} + result = LegacyExecutor._sanitize_null_values(output) + assert result["field"] == {"a": None, "b": "ok"} + + def test_non_na_values_untouched(self): + """Non-NA values are not modified.""" + from executor.executors.legacy_executor import LegacyExecutor + + output = {"field": "hello", "num": 42, "flag": True} + result = LegacyExecutor._sanitize_null_values(output) + assert result == {"field": "hello", "num": 42, "flag": True} + + +class TestAnswerPromptServiceUnit: + """Unit tests for AnswerPromptService methods.""" + + def test_extract_variable_replaces_percent_vars(self): + """Replace %var% references in prompt text.""" + from executor.executors.answer_prompt import AnswerPromptService + + structured = {"field_a": "42"} + output = {"prompt": "Original: %field_a%"} + result = AnswerPromptService.extract_variable( + structured, ["field_a"], output, "Value is %field_a%" + ) + assert result == "Value is 42" + + def test_extract_variable_missing_raises(self): + """Missing variable raises ValueError.""" + from executor.executors.answer_prompt import AnswerPromptService + + output = {"prompt": "test"} + with pytest.raises(ValueError, match="not found"): + AnswerPromptService.extract_variable( + {}, ["missing_var"], output, "Value is %missing_var%" + ) + + def test_construct_prompt_includes_all_parts(self): + """Constructed prompt includes preamble, prompt, postamble, context.""" + from executor.executors.answer_prompt import AnswerPromptService + + result = AnswerPromptService.construct_prompt( + preamble="You are a helpful 
assistant", + prompt="What is the revenue?", + postamble="Be precise", + grammar_list=[], + context="Revenue was $1M", + platform_postamble="", + word_confidence_postamble="", + ) + assert "You are a helpful assistant" in result + assert "What is the revenue?" in result + assert "Be precise" in result + assert "Revenue was $1M" in result + assert "Answer:" in result + + def test_construct_prompt_with_grammar(self): + """Grammar list adds synonym notes.""" + from executor.executors.answer_prompt import AnswerPromptService + + result = AnswerPromptService.construct_prompt( + preamble="", + prompt="Find the amount", + postamble="", + grammar_list=[{"word": "amount", "synonyms": ["sum", "total"]}], + context="test", + platform_postamble="", + word_confidence_postamble="", + ) + assert "amount" in result + assert "sum, total" in result + + +class TestVariableReplacementService: + """Tests for the VariableReplacementService.""" + + def test_is_variables_present_true(self): + """Detects {{variables}} in text.""" + from executor.executors.variable_replacement import ( + VariableReplacementService, + ) + + assert VariableReplacementService.is_variables_present( + "Hello {{name}}" + ) is True + + def test_is_variables_present_false(self): + """Returns False when no variables present.""" + from executor.executors.variable_replacement import ( + VariableReplacementService, + ) + + assert VariableReplacementService.is_variables_present( + "Hello world" + ) is False + + def test_replace_static_variable(self): + """Static variable {{var}} is replaced with structured output value.""" + from executor.executors.variable_replacement import ( + VariableReplacementHelper, + ) + + result = VariableReplacementHelper.replace_static_variable( + prompt="Total is {{revenue}}", + structured_output={"revenue": "$1M"}, + variable="revenue", + ) + assert result == "Total is $1M" + + def test_custom_data_variable(self): + """Custom data variable {{custom_data.key}} is replaced.""" + from 
executor.executors.variable_replacement import ( + VariableReplacementHelper, + ) + + result = VariableReplacementHelper.replace_custom_data_variable( + prompt="Company: {{custom_data.company_name}}", + variable="custom_data.company_name", + custom_data={"company_name": "Acme Inc"}, + ) + assert result == "Company: Acme Inc" + + def test_custom_data_missing_raises(self): + """Missing custom data key raises CustomDataError.""" + from executor.executors.exceptions import CustomDataError + from executor.executors.variable_replacement import ( + VariableReplacementHelper, + ) + + with pytest.raises(CustomDataError): + VariableReplacementHelper.replace_custom_data_variable( + prompt="{{custom_data.missing}}", + variable="custom_data.missing", + custom_data={"other": "value"}, + ) diff --git a/workers/tests/test_executor_sanity.py b/workers/tests/test_executor_sanity.py new file mode 100644 index 0000000000..8f0c10927a --- /dev/null +++ b/workers/tests/test_executor_sanity.py @@ -0,0 +1,288 @@ +"""Phase 1 Sanity Check — Executor worker integration tests. + +These tests verify the full executor chain works end-to-end. + +Verifies: +1. Worker enums and registry configuration +2. ExecutorToolShim works from workers venv +3. NoOpExecutor registers and executes via orchestrator +4. Celery task wiring (execute_extraction task logic) +5. Full dispatch -> task -> orchestrator -> executor round-trip +6. 
Retry configuration on the task +""" + +import pytest +from unstract.sdk1.execution.context import ExecutionContext +from unstract.sdk1.execution.executor import BaseExecutor +from unstract.sdk1.execution.orchestrator import ExecutionOrchestrator +from unstract.sdk1.execution.registry import ExecutorRegistry +from unstract.sdk1.execution.result import ExecutionResult + + +@pytest.fixture(autouse=True) +def _clean_registry(): + """Ensure a clean executor registry for every test.""" + ExecutorRegistry.clear() + yield + ExecutorRegistry.clear() + + +def _make_context(**overrides): + defaults = { + "executor_name": "noop", + "operation": "extract", + "run_id": "run-sanity-001", + "execution_source": "tool", + "organization_id": "org-test", + "request_id": "req-sanity-001", + } + defaults.update(overrides) + return ExecutionContext(**defaults) + + +def _register_noop(): + """Register a NoOpExecutor for testing.""" + + @ExecutorRegistry.register + class NoOpExecutor(BaseExecutor): + @property + def name(self): + return "noop" + + def execute(self, context): + return ExecutionResult( + success=True, + data={"echo": context.operation, "run_id": context.run_id}, + metadata={"executor": self.name}, + ) + + +# --- 1. 
Worker enums and registry --- + + +class TestWorkerEnumsAndRegistry: + """Verify executor is properly registered in worker infrastructure.""" + + def test_worker_type_executor_exists(self): + from shared.enums.worker_enums import WorkerType + + assert WorkerType.EXECUTOR.value == "executor" + + def test_queue_name_executor_exists(self): + from shared.enums.worker_enums import QueueName + + assert QueueName.EXECUTOR.value == "celery_executor_legacy" + + def test_task_name_execute_extraction_exists(self): + from shared.enums.task_enums import TaskName + + assert TaskName.EXECUTE_EXTRACTION.value == "execute_extraction" + + def test_health_port_is_8088(self): + from shared.enums.worker_enums import WorkerType + + assert WorkerType.EXECUTOR.to_health_port() == 8088 + + def test_worker_registry_has_executor_config(self): + from shared.enums.worker_enums import WorkerType + from shared.infrastructure.config.registry import WorkerRegistry + + config = WorkerRegistry.get_queue_config(WorkerType.EXECUTOR) + assert "celery_executor_legacy" in config.all_queues() + + def test_task_routing_includes_execute_extraction(self): + from shared.enums.worker_enums import WorkerType + from shared.infrastructure.config.registry import WorkerRegistry + + routing = WorkerRegistry.get_task_routing(WorkerType.EXECUTOR) + patterns = [r.pattern for r in routing.routes] + assert "execute_extraction" in patterns + + +# --- 2. 
ExecutorToolShim --- + + +class TestExecutorToolShim: + """Verify the real ExecutorToolShim works in the workers venv.""" + + def test_import(self): + from executor.executor_tool_shim import ExecutorToolShim + + shim = ExecutorToolShim(platform_api_key="sk-test") + assert shim.platform_api_key == "sk-test" + + def test_platform_key_returned(self): + from executor.executor_tool_shim import ExecutorToolShim + + shim = ExecutorToolShim(platform_api_key="sk-real-key") + assert shim.get_env_or_die("PLATFORM_SERVICE_API_KEY") == "sk-real-key" + + def test_env_var_from_environ(self, monkeypatch): + from executor.executor_tool_shim import ExecutorToolShim + + monkeypatch.setenv("TEST_SHIM_VAR", "hello") + shim = ExecutorToolShim(platform_api_key="sk-test") + assert shim.get_env_or_die("TEST_SHIM_VAR") == "hello" + + def test_missing_var_raises(self): + from executor.executor_tool_shim import ExecutorToolShim + from unstract.sdk1.exceptions import SdkError + + shim = ExecutorToolShim(platform_api_key="sk-test") + with pytest.raises(SdkError, match="NONEXISTENT"): + shim.get_env_or_die("NONEXISTENT") + + def test_stream_log_does_not_print_json(self, capsys): + """stream_log routes to logging, not stdout JSON.""" + from executor.executor_tool_shim import ExecutorToolShim + + shim = ExecutorToolShim(platform_api_key="sk-test") + shim.stream_log("test message") + captured = capsys.readouterr() + # Should NOT produce JSON on stdout (that's the old protocol) + assert '"type": "LOG"' not in captured.out + + def test_stream_error_raises_sdk_error(self): + from executor.executor_tool_shim import ExecutorToolShim + from unstract.sdk1.exceptions import SdkError + + shim = ExecutorToolShim(platform_api_key="sk-test") + with pytest.raises(SdkError, match="boom"): + shim.stream_error_and_exit("boom") + + +# --- 3. 
NoOpExecutor via Orchestrator --- + + +class TestNoOpExecutorOrchestrator: + """Verify a NoOpExecutor works through the orchestrator.""" + + def test_noop_executor_round_trip(self): + _register_noop() + + ctx = _make_context(operation="extract") + orchestrator = ExecutionOrchestrator() + result = orchestrator.execute(ctx) + + assert result.success is True + assert result.data == {"echo": "extract", "run_id": "run-sanity-001"} + + def test_unknown_executor_fails_gracefully(self): + orchestrator = ExecutionOrchestrator() + ctx = _make_context(executor_name="nonexistent") + result = orchestrator.execute(ctx) + + assert result.success is False + assert "nonexistent" in result.error + + +# --- 4 & 5. Full chain with Celery eager mode --- +# +# executor/worker.py imports executor/tasks.py which defines +# execute_extraction as a shared_task. We import the real app, +# configure it for eager mode, and exercise the actual task. + + +@pytest.fixture +def eager_app(): + """Configure the real executor Celery app for eager-mode testing.""" + from executor.worker import app + + original = { + "task_always_eager": app.conf.task_always_eager, + "task_eager_propagates": app.conf.task_eager_propagates, + "result_backend": app.conf.result_backend, + } + + app.conf.update( + task_always_eager=True, + task_eager_propagates=False, + result_backend="cache+memory://", + ) + + yield app + + app.conf.update(original) + + +class TestCeleryTaskWiring: + """Verify the execute_extraction task configuration.""" + + def test_task_is_registered(self, eager_app): + assert "execute_extraction" in eager_app.tasks + + def test_task_has_retry_config(self, eager_app): + task = eager_app.tasks["execute_extraction"] + assert task.max_retries == 3 + assert ConnectionError in task.autoretry_for + assert TimeoutError in task.autoretry_for + assert OSError in task.autoretry_for + + def test_task_retry_backoff_enabled(self, eager_app): + task = eager_app.tasks["execute_extraction"] + assert task.retry_backoff 
is True + assert task.retry_jitter is True + + +class TestFullChainEager: + """End-to-end test using Celery's eager mode. + + task_always_eager=True makes tasks execute inline in the + calling process — full chain without a broker. + """ + + def _run_task(self, eager_app, context_dict): + """Run execute_extraction task via task.apply() (eager-safe).""" + task = eager_app.tasks["execute_extraction"] + result = task.apply(args=[context_dict]) + return result.get() + + def test_eager_dispatch_round_trip(self, eager_app): + """Execute task inline, verify result comes back.""" + _register_noop() + + ctx = _make_context(operation="answer_prompt", run_id="run-eager") + result_dict = self._run_task(eager_app, ctx.to_dict()) + result = ExecutionResult.from_dict(result_dict) + + assert result.success is True + assert result.data["echo"] == "answer_prompt" + assert result.data["run_id"] == "run-eager" + assert result.metadata.get("executor") == "noop" + + def test_eager_dispatch_invalid_context(self, eager_app): + """Invalid context dict returns failure result (not exception).""" + result_dict = self._run_task(eager_app, {"bad": "data"}) + result = ExecutionResult.from_dict(result_dict) + + assert result.success is False + assert "Invalid execution context" in result.error + + def test_eager_dispatch_unknown_executor(self, eager_app): + """Unknown executor returns failure (no unhandled exceptions).""" + ctx = _make_context(executor_name="does_not_exist") + result_dict = self._run_task(eager_app, ctx.to_dict()) + result = ExecutionResult.from_dict(result_dict) + + assert result.success is False + assert "does_not_exist" in result.error + + def test_result_serialization_round_trip(self, eager_app): + """Verify ExecutionResult survives Celery serialization.""" + _register_noop() + + ctx = _make_context( + operation="single_pass_extraction", + executor_params={"schema": {"name": "str", "age": "int"}}, + ) + result_dict = self._run_task(eager_app, ctx.to_dict()) + + # Verify the 
raw dict is JSON-compatible + import json + + serialized = json.dumps(result_dict) + deserialized = json.loads(serialized) + + result = ExecutionResult.from_dict(deserialized) + assert result.success is True + assert result.data["echo"] == "single_pass_extraction" diff --git a/workers/tests/test_legacy_executor_extract.py b/workers/tests/test_legacy_executor_extract.py new file mode 100644 index 0000000000..0711d2255a --- /dev/null +++ b/workers/tests/test_legacy_executor_extract.py @@ -0,0 +1,594 @@ +"""Phase 2B — LegacyExecutor._handle_extract tests. + +Verifies: +1. Happy path: extraction returns success with extracted_text +2. With highlight (LLMWhisperer): enable_highlight passed through +3. Without highlight (non-Whisperer): enable_highlight NOT passed +4. AdapterError → failure result +5. Missing required params → failure result +6. Metadata update for tool source: ToolUtils.dump_json called +7. IDE source skips metadata writing +8. FileUtils routing: correct storage type for ide vs tool +9. Orchestrator integration: extract returns success (mocked) +10. Celery eager-mode: full task chain returns extraction result +11. 
LegacyExecutorError caught by execute() → failure result +""" + +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +from executor.executors.constants import ( + FileStorageKeys, + IndexingConstants as IKeys, +) +from executor.executors.exceptions import LegacyExecutorError +from unstract.sdk1.adapters.x2text.constants import X2TextConstants +from unstract.sdk1.adapters.x2text.dto import ( + TextExtractionMetadata, + TextExtractionResult, +) +from unstract.sdk1.execution.context import ExecutionContext +from unstract.sdk1.execution.orchestrator import ExecutionOrchestrator +from unstract.sdk1.execution.registry import ExecutorRegistry +from unstract.sdk1.execution.result import ExecutionResult + + +@pytest.fixture(autouse=True) +def _clean_registry(): + """Ensure a clean executor registry for every test.""" + ExecutorRegistry.clear() + yield + ExecutorRegistry.clear() + + +def _register_legacy(): + from executor.executors.legacy_executor import LegacyExecutor # noqa: F401 + + ExecutorRegistry.register(LegacyExecutor) + + +def _make_context(**overrides): + defaults = { + "executor_name": "legacy", + "operation": "extract", + "run_id": "run-2b-001", + "execution_source": "tool", + "organization_id": "org-test", + "request_id": "req-2b-001", + "executor_params": { + "x2text_instance_id": "x2t-001", + "file_path": "/data/test.pdf", + "platform_api_key": "sk-test-key", + }, + } + defaults.update(overrides) + return ExecutionContext(**defaults) + + +def _mock_process_response(extracted_text="hello world", whisper_hash="hash-123"): + """Build a mock TextExtractionResult.""" + metadata = TextExtractionMetadata(whisper_hash=whisper_hash) + return TextExtractionResult( + extracted_text=extracted_text, + extraction_metadata=metadata, + ) + + +# --- 1. 
Happy path --- + + +class TestHappyPath: + @patch("executor.executors.legacy_executor.FileUtils.get_fs_instance") + @patch("executor.executors.legacy_executor.X2Text") + def test_extract_returns_success(self, mock_x2text_cls, mock_get_fs): + _register_legacy() + executor = ExecutorRegistry.get("legacy") + + mock_x2text = MagicMock() + mock_x2text.process.return_value = _mock_process_response("hello") + mock_x2text.x2text_instance = MagicMock() # not a Whisperer + mock_x2text_cls.return_value = mock_x2text + mock_get_fs.return_value = MagicMock() + + ctx = _make_context() + result = executor.execute(ctx) + + assert result.success is True + assert result.data[IKeys.EXTRACTED_TEXT] == "hello" + + @patch("executor.executors.legacy_executor.FileUtils.get_fs_instance") + @patch("executor.executors.legacy_executor.X2Text") + def test_extract_passes_correct_params_to_x2text( + self, mock_x2text_cls, mock_get_fs + ): + _register_legacy() + executor = ExecutorRegistry.get("legacy") + + mock_x2text = MagicMock() + mock_x2text.process.return_value = _mock_process_response() + mock_x2text.x2text_instance = MagicMock() + mock_x2text_cls.return_value = mock_x2text + mock_get_fs.return_value = MagicMock() + + ctx = _make_context( + executor_params={ + "x2text_instance_id": "x2t-002", + "file_path": "/data/doc.pdf", + "platform_api_key": "sk-key", + "usage_kwargs": {"org": "test-org"}, + } + ) + executor.execute(ctx) + + mock_x2text_cls.assert_called_once() + call_kwargs = mock_x2text_cls.call_args + assert call_kwargs.kwargs.get("adapter_instance_id") == "x2t-002" or ( + call_kwargs.args + and len(call_kwargs.args) > 1 + and call_kwargs.args[1] == "x2t-002" + ) + + +# --- 2. 
With highlight (LLMWhisperer) --- + + +class TestWithHighlight: + @patch("executor.executors.legacy_executor.ToolUtils.dump_json") + @patch("executor.executors.legacy_executor.FileUtils.get_fs_instance") + @patch("executor.executors.legacy_executor.X2Text") + def test_highlight_with_whisperer_v2( + self, mock_x2text_cls, mock_get_fs, mock_dump + ): + from unstract.sdk1.adapters.x2text.llm_whisperer_v2.src import LLMWhispererV2 + + _register_legacy() + executor = ExecutorRegistry.get("legacy") + + mock_x2text = MagicMock() + mock_x2text.process.return_value = _mock_process_response() + # Make isinstance check pass for LLMWhispererV2 + mock_x2text.x2text_instance = MagicMock(spec=LLMWhispererV2) + mock_x2text_cls.return_value = mock_x2text + mock_get_fs.return_value = MagicMock() + + ctx = _make_context( + executor_params={ + "x2text_instance_id": "x2t-whisperer", + "file_path": "/data/test.pdf", + "platform_api_key": "sk-key", + "enable_highlight": True, + "execution_data_dir": "/data/run", + "tool_execution_metadata": {}, + } + ) + result = executor.execute(ctx) + + assert result.success is True + # Verify enable_highlight was passed to process() + mock_x2text.process.assert_called_once() + call_kwargs = mock_x2text.process.call_args.kwargs + assert call_kwargs.get("enable_highlight") is True + + @patch("executor.executors.legacy_executor.ToolUtils.dump_json") + @patch("executor.executors.legacy_executor.FileUtils.get_fs_instance") + @patch("executor.executors.legacy_executor.X2Text") + def test_highlight_with_whisperer_v1( + self, mock_x2text_cls, mock_get_fs, mock_dump + ): + from unstract.sdk1.adapters.x2text.llm_whisperer.src import LLMWhisperer + + _register_legacy() + executor = ExecutorRegistry.get("legacy") + + mock_x2text = MagicMock() + mock_x2text.process.return_value = _mock_process_response() + mock_x2text.x2text_instance = MagicMock(spec=LLMWhisperer) + mock_x2text_cls.return_value = mock_x2text + mock_get_fs.return_value = MagicMock() + + ctx = 
_make_context( + executor_params={ + "x2text_instance_id": "x2t-whisperer-v1", + "file_path": "/data/test.pdf", + "platform_api_key": "sk-key", + "enable_highlight": True, + "execution_data_dir": "/data/run", + "tool_execution_metadata": {}, + } + ) + result = executor.execute(ctx) + + assert result.success is True + call_kwargs = mock_x2text.process.call_args.kwargs + assert call_kwargs.get("enable_highlight") is True + + +# --- 3. Without highlight (non-Whisperer) --- + + +class TestWithoutHighlight: + @patch("executor.executors.legacy_executor.FileUtils.get_fs_instance") + @patch("executor.executors.legacy_executor.X2Text") + def test_no_highlight_for_non_whisperer(self, mock_x2text_cls, mock_get_fs): + _register_legacy() + executor = ExecutorRegistry.get("legacy") + + mock_x2text = MagicMock() + mock_x2text.process.return_value = _mock_process_response() + # Generic adapter — not LLMWhisperer + mock_x2text.x2text_instance = MagicMock() + mock_x2text_cls.return_value = mock_x2text + mock_get_fs.return_value = MagicMock() + + ctx = _make_context( + executor_params={ + "x2text_instance_id": "x2t-generic", + "file_path": "/data/test.pdf", + "platform_api_key": "sk-key", + "enable_highlight": True, # requested but adapter doesn't support it + } + ) + result = executor.execute(ctx) + + assert result.success is True + # enable_highlight should NOT be in process() call + call_kwargs = mock_x2text.process.call_args.kwargs + assert "enable_highlight" not in call_kwargs + + @patch("executor.executors.legacy_executor.FileUtils.get_fs_instance") + @patch("executor.executors.legacy_executor.X2Text") + def test_highlight_false_skips_whisperer_branch( + self, mock_x2text_cls, mock_get_fs + ): + from unstract.sdk1.adapters.x2text.llm_whisperer_v2.src import LLMWhispererV2 + + _register_legacy() + executor = ExecutorRegistry.get("legacy") + + mock_x2text = MagicMock() + mock_x2text.process.return_value = _mock_process_response() + mock_x2text.x2text_instance = 
MagicMock(spec=LLMWhispererV2) + mock_x2text_cls.return_value = mock_x2text + mock_get_fs.return_value = MagicMock() + + ctx = _make_context( + executor_params={ + "x2text_instance_id": "x2t-whisperer", + "file_path": "/data/test.pdf", + "platform_api_key": "sk-key", + "enable_highlight": False, # highlight disabled + } + ) + result = executor.execute(ctx) + + assert result.success is True + call_kwargs = mock_x2text.process.call_args.kwargs + assert "enable_highlight" not in call_kwargs + + +# --- 4. AdapterError → failure result --- + + +class TestAdapterError: + @patch("executor.executors.legacy_executor.FileUtils.get_fs_instance") + @patch("executor.executors.legacy_executor.X2Text") + def test_adapter_error_returns_failure(self, mock_x2text_cls, mock_get_fs): + from unstract.sdk1.adapters.exceptions import AdapterError + + _register_legacy() + executor = ExecutorRegistry.get("legacy") + + mock_x2text = MagicMock() + mock_x2text.x2text_instance = MagicMock() + mock_x2text.x2text_instance.get_name.return_value = "TestExtractor" + mock_x2text.process.side_effect = AdapterError("connection timeout") + mock_x2text_cls.return_value = mock_x2text + mock_get_fs.return_value = MagicMock() + + ctx = _make_context() + result = executor.execute(ctx) + + assert result.success is False + assert "TestExtractor" in result.error + assert "connection timeout" in result.error + + +# --- 5. 
Missing required params --- + + +class TestMissingParams: + @patch("executor.executors.legacy_executor.FileUtils.get_fs_instance") + @patch("executor.executors.legacy_executor.X2Text") + def test_missing_x2text_instance_id(self, mock_x2text_cls, mock_get_fs): + _register_legacy() + executor = ExecutorRegistry.get("legacy") + + ctx = _make_context( + executor_params={ + "file_path": "/data/test.pdf", + "platform_api_key": "sk-key", + } + ) + result = executor.execute(ctx) + + assert result.success is False + assert "x2text_instance_id" in result.error + mock_x2text_cls.assert_not_called() + + @patch("executor.executors.legacy_executor.FileUtils.get_fs_instance") + @patch("executor.executors.legacy_executor.X2Text") + def test_missing_file_path(self, mock_x2text_cls, mock_get_fs): + _register_legacy() + executor = ExecutorRegistry.get("legacy") + + ctx = _make_context( + executor_params={ + "x2text_instance_id": "x2t-001", + "platform_api_key": "sk-key", + } + ) + result = executor.execute(ctx) + + assert result.success is False + assert "file_path" in result.error + mock_x2text_cls.assert_not_called() + + @patch("executor.executors.legacy_executor.FileUtils.get_fs_instance") + @patch("executor.executors.legacy_executor.X2Text") + def test_missing_both_params(self, mock_x2text_cls, mock_get_fs): + _register_legacy() + executor = ExecutorRegistry.get("legacy") + + ctx = _make_context(executor_params={"platform_api_key": "sk-key"}) + result = executor.execute(ctx) + + assert result.success is False + assert "x2text_instance_id" in result.error + assert "file_path" in result.error + + +# --- 6. 
Metadata update for tool source --- + + +class TestMetadataToolSource: + @patch("executor.executors.legacy_executor.ToolUtils.dump_json") + @patch("executor.executors.legacy_executor.FileUtils.get_fs_instance") + @patch("executor.executors.legacy_executor.X2Text") + def test_tool_source_writes_metadata( + self, mock_x2text_cls, mock_get_fs, mock_dump + ): + from unstract.sdk1.adapters.x2text.llm_whisperer_v2.src import LLMWhispererV2 + + _register_legacy() + executor = ExecutorRegistry.get("legacy") + + mock_x2text = MagicMock() + mock_x2text.process.return_value = _mock_process_response( + whisper_hash="whash-456" + ) + mock_x2text.x2text_instance = MagicMock(spec=LLMWhispererV2) + mock_x2text_cls.return_value = mock_x2text + mock_fs = MagicMock() + mock_get_fs.return_value = mock_fs + + tool_meta = {} + ctx = _make_context( + execution_source="tool", + executor_params={ + "x2text_instance_id": "x2t-whisperer", + "file_path": "/data/test.pdf", + "platform_api_key": "sk-key", + "enable_highlight": True, + "execution_data_dir": "/run/data", + "tool_execution_metadata": tool_meta, + }, + ) + result = executor.execute(ctx) + + assert result.success is True + # ToolUtils.dump_json should have been called + mock_dump.assert_called_once() + dump_kwargs = mock_dump.call_args.kwargs + assert dump_kwargs["file_to_dump"] == str( + Path("/run/data") / IKeys.METADATA_FILE + ) + assert dump_kwargs["json_to_dump"] == { + X2TextConstants.WHISPER_HASH: "whash-456" + } + assert dump_kwargs["fs"] is mock_fs + # tool_exec_metadata should be updated in-place + assert tool_meta[X2TextConstants.WHISPER_HASH] == "whash-456" + + +# --- 7. 
IDE source skips metadata --- + + +class TestMetadataIDESource: + @patch("executor.executors.legacy_executor.ToolUtils.dump_json") + @patch("executor.executors.legacy_executor.FileUtils.get_fs_instance") + @patch("executor.executors.legacy_executor.X2Text") + def test_ide_source_skips_metadata( + self, mock_x2text_cls, mock_get_fs, mock_dump + ): + from unstract.sdk1.adapters.x2text.llm_whisperer_v2.src import LLMWhispererV2 + + _register_legacy() + executor = ExecutorRegistry.get("legacy") + + mock_x2text = MagicMock() + mock_x2text.process.return_value = _mock_process_response() + mock_x2text.x2text_instance = MagicMock(spec=LLMWhispererV2) + mock_x2text_cls.return_value = mock_x2text + mock_get_fs.return_value = MagicMock() + + ctx = _make_context( + execution_source="ide", + executor_params={ + "x2text_instance_id": "x2t-whisperer", + "file_path": "/data/test.pdf", + "platform_api_key": "sk-key", + "enable_highlight": True, + }, + ) + result = executor.execute(ctx) + + assert result.success is True + mock_dump.assert_not_called() + + +# --- 8. 
FileUtils routing --- + + +class TestFileUtilsRouting: + @patch("executor.executors.file_utils.EnvHelper.get_storage") + def test_ide_returns_permanent_storage(self, mock_get_storage): + from executor.executors.file_utils import FileUtils + from unstract.sdk1.file_storage.constants import StorageType + + mock_get_storage.return_value = MagicMock() + FileUtils.get_fs_instance("ide") + + mock_get_storage.assert_called_once_with( + storage_type=StorageType.PERMANENT, + env_name=FileStorageKeys.PERMANENT_REMOTE_STORAGE, + ) + + @patch("executor.executors.file_utils.EnvHelper.get_storage") + def test_tool_returns_temporary_storage(self, mock_get_storage): + from executor.executors.file_utils import FileUtils + from unstract.sdk1.file_storage.constants import StorageType + + mock_get_storage.return_value = MagicMock() + FileUtils.get_fs_instance("tool") + + mock_get_storage.assert_called_once_with( + storage_type=StorageType.SHARED_TEMPORARY, + env_name=FileStorageKeys.TEMPORARY_REMOTE_STORAGE, + ) + + def test_invalid_source_raises_value_error(self): + from executor.executors.file_utils import FileUtils + + with pytest.raises(ValueError, match="Invalid execution source"): + FileUtils.get_fs_instance("unknown") + + +# --- 9. Orchestrator integration --- + + +class TestOrchestratorIntegration: + @patch("executor.executors.legacy_executor.FileUtils.get_fs_instance") + @patch("executor.executors.legacy_executor.X2Text") + def test_orchestrator_extract_returns_success( + self, mock_x2text_cls, mock_get_fs + ): + _register_legacy() + orchestrator = ExecutionOrchestrator() + + mock_x2text = MagicMock() + mock_x2text.process.return_value = _mock_process_response("extracted!") + mock_x2text.x2text_instance = MagicMock() + mock_x2text_cls.return_value = mock_x2text + mock_get_fs.return_value = MagicMock() + + ctx = _make_context() + result = orchestrator.execute(ctx) + + assert result.success is True + assert result.data[IKeys.EXTRACTED_TEXT] == "extracted!" + + +# --- 10. 
Celery eager-mode --- + + +@pytest.fixture +def eager_app(): + """Configure the real executor Celery app for eager-mode testing.""" + from executor.worker import app + + original = { + "task_always_eager": app.conf.task_always_eager, + "task_eager_propagates": app.conf.task_eager_propagates, + "result_backend": app.conf.result_backend, + } + app.conf.update( + task_always_eager=True, + task_eager_propagates=False, + result_backend="cache+memory://", + ) + yield app + app.conf.update(original) + + +class TestCeleryEager: + @patch("executor.executors.legacy_executor.FileUtils.get_fs_instance") + @patch("executor.executors.legacy_executor.X2Text") + def test_eager_extract_returns_success( + self, mock_x2text_cls, mock_get_fs, eager_app + ): + _register_legacy() + + mock_x2text = MagicMock() + mock_x2text.process.return_value = _mock_process_response("celery text") + mock_x2text.x2text_instance = MagicMock() + mock_x2text_cls.return_value = mock_x2text + mock_get_fs.return_value = MagicMock() + + ctx = _make_context() + task = eager_app.tasks["execute_extraction"] + result_dict = task.apply(args=[ctx.to_dict()]).get() + result = ExecutionResult.from_dict(result_dict) + + assert result.success is True + assert result.data[IKeys.EXTRACTED_TEXT] == "celery text" + + +# --- 11. 
LegacyExecutorError caught by execute() --- + + +class TestExecuteErrorCatching: + @patch("executor.executors.legacy_executor.FileUtils.get_fs_instance") + @patch("executor.executors.legacy_executor.X2Text") + def test_extraction_error_caught_by_execute( + self, mock_x2text_cls, mock_get_fs + ): + """ExtractionError (a LegacyExecutorError) is caught in execute() + and mapped to ExecutionResult.failure().""" + from unstract.sdk1.adapters.exceptions import AdapterError + + _register_legacy() + executor = ExecutorRegistry.get("legacy") + + mock_x2text = MagicMock() + mock_x2text.x2text_instance = MagicMock() + mock_x2text.x2text_instance.get_name.return_value = "BadExtractor" + mock_x2text.process.side_effect = AdapterError("timeout") + mock_x2text_cls.return_value = mock_x2text + mock_get_fs.return_value = MagicMock() + + ctx = _make_context() + result = executor.execute(ctx) + + # Should be a clean failure, NOT an unhandled exception + assert result.success is False + assert "BadExtractor" in result.error + assert "timeout" in result.error + + def test_legacy_executor_error_subclass_caught(self): + """Any LegacyExecutorError subclass raised by a handler is caught.""" + _register_legacy() + executor = ExecutorRegistry.get("legacy") + + # Monkey-patch _handle_extract to raise a LegacyExecutorError + def _raise_err(ctx): + raise LegacyExecutorError(message="custom error", code=422) + + executor._handle_extract = _raise_err + + ctx = _make_context() + result = executor.execute(ctx) + + assert result.success is False + assert result.error == "custom error" diff --git a/workers/tests/test_legacy_executor_index.py b/workers/tests/test_legacy_executor_index.py new file mode 100644 index 0000000000..d87d5b5b97 --- /dev/null +++ b/workers/tests/test_legacy_executor_index.py @@ -0,0 +1,453 @@ +"""Phase 2C — LegacyExecutor._handle_index tests. + +Verifies: +1. Happy path: indexing returns success with doc_id +2. Chunk size 0: skips indexing, still returns doc_id +3. 
Missing required params → failure result +4. Reindex flag: passes reindex through to Index +5. VectorDB.close() always called (even on error) +6. Indexing error → LegacyExecutorError → failure result +7. Orchestrator integration: index returns success (mocked) +8. Celery eager-mode: full task chain returns indexing result +9. Index class: generate_index_key called with correct DTOs +10. EmbeddingCompat and VectorDB created with correct params + +Heavy SDK1 dependencies (llama_index, qdrant) are lazily imported +via ``LegacyExecutor._get_indexing_deps()``. We mock that method +to avoid protobuf conflicts in the test environment. +""" + +from unittest.mock import MagicMock, patch + +import pytest + +from executor.executors.constants import IndexingConstants as IKeys +from unstract.sdk1.execution.context import ExecutionContext +from unstract.sdk1.execution.orchestrator import ExecutionOrchestrator +from unstract.sdk1.execution.registry import ExecutorRegistry +from unstract.sdk1.execution.result import ExecutionResult + + +@pytest.fixture(autouse=True) +def _clean_registry(): + ExecutorRegistry.clear() + yield + ExecutorRegistry.clear() + + +def _register_legacy(): + from executor.executors.legacy_executor import LegacyExecutor # noqa: F401 + + ExecutorRegistry.register(LegacyExecutor) + + +def _make_index_context(**overrides): + defaults = { + "executor_name": "legacy", + "operation": "index", + "run_id": "run-2c-001", + "execution_source": "tool", + "organization_id": "org-test", + "request_id": "req-2c-001", + "executor_params": { + "embedding_instance_id": "emb-001", + "vector_db_instance_id": "vdb-001", + "x2text_instance_id": "x2t-001", + "file_path": "/data/test.pdf", + "file_hash": "abc123", + "extracted_text": "Hello world", + "platform_api_key": "sk-test", + "chunk_size": 512, + "chunk_overlap": 128, + }, + } + defaults.update(overrides) + return ExecutionContext(**defaults) + + +_PATCH_FS = "executor.executors.legacy_executor.FileUtils.get_fs_instance" 
+_PATCH_DEPS = ( + "executor.executors.legacy_executor.LegacyExecutor._get_indexing_deps" +) + + +@pytest.fixture +def mock_indexing_deps(): + """Mock the heavy indexing dependencies via _get_indexing_deps().""" + mock_index_cls = MagicMock() + mock_emb_cls = MagicMock() + mock_vdb_cls = MagicMock() + + with patch(_PATCH_DEPS, return_value=(mock_index_cls, mock_emb_cls, mock_vdb_cls)): + yield mock_index_cls, mock_emb_cls, mock_vdb_cls + + +def _setup_mock_index(mock_index_cls, doc_id="doc-hash-123"): + """Configure a mock Index instance.""" + mock_index = MagicMock() + mock_index.generate_index_key.return_value = doc_id + mock_index.is_document_indexed.return_value = False + mock_index.perform_indexing.return_value = doc_id + mock_index_cls.return_value = mock_index + return mock_index + + +# --- 1. Happy path --- + + +class TestHappyPath: + @patch(_PATCH_FS) + def test_index_returns_success_with_doc_id( + self, mock_get_fs, mock_indexing_deps + ): + mock_index_cls, mock_emb_cls, mock_vdb_cls = mock_indexing_deps + _register_legacy() + executor = ExecutorRegistry.get("legacy") + + _setup_mock_index(mock_index_cls, "doc-hash-123") + mock_emb_cls.return_value = MagicMock() + mock_vdb = MagicMock() + mock_vdb_cls.return_value = mock_vdb + mock_get_fs.return_value = MagicMock() + + ctx = _make_index_context() + result = executor.execute(ctx) + + assert result.success is True + assert result.data[IKeys.DOC_ID] == "doc-hash-123" + mock_vdb.close.assert_called_once() + + +# --- 2. 
Chunk size 0: skips indexing --- + + +class TestChunkSizeZero: + @patch( + "unstract.sdk1.utils.indexing.IndexingUtils.generate_index_key", + return_value="doc-zero-chunk", + ) + @patch(_PATCH_FS) + def test_chunk_size_zero_skips_indexing(self, mock_get_fs, mock_gen_key): + _register_legacy() + executor = ExecutorRegistry.get("legacy") + mock_get_fs.return_value = MagicMock() + + ctx = _make_index_context( + executor_params={ + "embedding_instance_id": "emb-001", + "vector_db_instance_id": "vdb-001", + "x2text_instance_id": "x2t-001", + "file_path": "/data/test.pdf", + "file_hash": "abc123", + "extracted_text": "text", + "platform_api_key": "sk-test", + "chunk_size": 0, + "chunk_overlap": 0, + } + ) + result = executor.execute(ctx) + + assert result.success is True + assert result.data[IKeys.DOC_ID] == "doc-zero-chunk" + mock_gen_key.assert_called_once() + + +# --- 3. Missing required params --- + + +class TestMissingParams: + def test_missing_embedding_instance_id(self): + _register_legacy() + executor = ExecutorRegistry.get("legacy") + ctx = _make_index_context( + executor_params={ + "vector_db_instance_id": "vdb-001", + "x2text_instance_id": "x2t-001", + "file_path": "/data/test.pdf", + "platform_api_key": "sk-test", + } + ) + result = executor.execute(ctx) + assert result.success is False + assert "embedding_instance_id" in result.error + + def test_missing_multiple_params(self): + _register_legacy() + executor = ExecutorRegistry.get("legacy") + ctx = _make_index_context( + executor_params={"platform_api_key": "sk-test"} + ) + result = executor.execute(ctx) + assert result.success is False + assert "embedding_instance_id" in result.error + assert "vector_db_instance_id" in result.error + assert "x2text_instance_id" in result.error + assert "file_path" in result.error + + +# --- 4. 
Reindex flag --- + + +class TestReindex: + @patch(_PATCH_FS) + def test_reindex_passed_through(self, mock_get_fs, mock_indexing_deps): + mock_index_cls, mock_emb_cls, mock_vdb_cls = mock_indexing_deps + _register_legacy() + executor = ExecutorRegistry.get("legacy") + + _setup_mock_index(mock_index_cls, "doc-reindex") + mock_index_cls.return_value.is_document_indexed.return_value = True + mock_emb_cls.return_value = MagicMock() + mock_vdb_cls.return_value = MagicMock() + mock_get_fs.return_value = MagicMock() + + ctx = _make_index_context( + executor_params={ + "embedding_instance_id": "emb-001", + "vector_db_instance_id": "vdb-001", + "x2text_instance_id": "x2t-001", + "file_path": "/data/test.pdf", + "file_hash": "abc123", + "extracted_text": "text", + "platform_api_key": "sk-test", + "chunk_size": 512, + "chunk_overlap": 128, + "reindex": True, + } + ) + result = executor.execute(ctx) + + assert result.success is True + init_call = mock_index_cls.call_args + assert init_call.kwargs["processing_options"].reindex is True + + +# --- 5. 
VectorDB.close() always called --- + + +class TestVectorDBClose: + @patch(_PATCH_FS) + def test_vectordb_closed_on_success(self, mock_get_fs, mock_indexing_deps): + mock_index_cls, mock_emb_cls, mock_vdb_cls = mock_indexing_deps + _register_legacy() + executor = ExecutorRegistry.get("legacy") + + _setup_mock_index(mock_index_cls) + mock_emb_cls.return_value = MagicMock() + mock_vdb = MagicMock() + mock_vdb_cls.return_value = mock_vdb + mock_get_fs.return_value = MagicMock() + + ctx = _make_index_context() + executor.execute(ctx) + mock_vdb.close.assert_called_once() + + @patch(_PATCH_FS) + def test_vectordb_closed_on_error(self, mock_get_fs, mock_indexing_deps): + mock_index_cls, mock_emb_cls, mock_vdb_cls = mock_indexing_deps + _register_legacy() + executor = ExecutorRegistry.get("legacy") + + mock_index = _setup_mock_index(mock_index_cls) + mock_index.is_document_indexed.side_effect = RuntimeError("boom") + mock_emb_cls.return_value = MagicMock() + mock_vdb = MagicMock() + mock_vdb_cls.return_value = mock_vdb + mock_get_fs.return_value = MagicMock() + + ctx = _make_index_context() + result = executor.execute(ctx) + + assert result.success is False + mock_vdb.close.assert_called_once() + + +# --- 6. 
Indexing error → failure result --- + + +class TestIndexingError: + @patch(_PATCH_FS) + def test_indexing_error_returns_failure( + self, mock_get_fs, mock_indexing_deps + ): + mock_index_cls, mock_emb_cls, mock_vdb_cls = mock_indexing_deps + _register_legacy() + executor = ExecutorRegistry.get("legacy") + + mock_index = _setup_mock_index(mock_index_cls, "doc-err") + mock_index.perform_indexing.side_effect = RuntimeError( + "vector DB unavailable" + ) + mock_emb_cls.return_value = MagicMock() + mock_vdb_cls.return_value = MagicMock() + mock_get_fs.return_value = MagicMock() + + ctx = _make_index_context() + result = executor.execute(ctx) + + assert result.success is False + assert "indexing" in result.error.lower() + assert "vector DB unavailable" in result.error + + +# --- 7. Orchestrator integration --- + + +class TestOrchestratorIntegration: + @patch(_PATCH_FS) + def test_orchestrator_index_returns_success( + self, mock_get_fs, mock_indexing_deps + ): + mock_index_cls, mock_emb_cls, mock_vdb_cls = mock_indexing_deps + _register_legacy() + orchestrator = ExecutionOrchestrator() + + _setup_mock_index(mock_index_cls, "doc-orch") + mock_emb_cls.return_value = MagicMock() + mock_vdb_cls.return_value = MagicMock() + mock_get_fs.return_value = MagicMock() + + ctx = _make_index_context() + result = orchestrator.execute(ctx) + + assert result.success is True + assert result.data[IKeys.DOC_ID] == "doc-orch" + + +# --- 8. 
Celery eager-mode --- + + +@pytest.fixture +def eager_app(): + from executor.worker import app + + original = { + "task_always_eager": app.conf.task_always_eager, + "task_eager_propagates": app.conf.task_eager_propagates, + "result_backend": app.conf.result_backend, + } + app.conf.update( + task_always_eager=True, + task_eager_propagates=False, + result_backend="cache+memory://", + ) + yield app + app.conf.update(original) + + +class TestCeleryEager: + @patch(_PATCH_FS) + def test_eager_index_returns_success( + self, mock_get_fs, mock_indexing_deps, eager_app + ): + mock_index_cls, mock_emb_cls, mock_vdb_cls = mock_indexing_deps + _register_legacy() + + _setup_mock_index(mock_index_cls, "doc-celery") + mock_emb_cls.return_value = MagicMock() + mock_vdb_cls.return_value = MagicMock() + mock_get_fs.return_value = MagicMock() + + ctx = _make_index_context() + task = eager_app.tasks["execute_extraction"] + result_dict = task.apply(args=[ctx.to_dict()]).get() + result = ExecutionResult.from_dict(result_dict) + + assert result.success is True + assert result.data[IKeys.DOC_ID] == "doc-celery" + + +# --- 9. 
Index class receives correct DTOs --- + + +class TestIndexDTOs: + @patch(_PATCH_FS) + def test_index_created_with_correct_dtos( + self, mock_get_fs, mock_indexing_deps + ): + mock_index_cls, mock_emb_cls, mock_vdb_cls = mock_indexing_deps + _register_legacy() + executor = ExecutorRegistry.get("legacy") + + _setup_mock_index(mock_index_cls, "doc-dto") + mock_emb_cls.return_value = MagicMock() + mock_vdb_cls.return_value = MagicMock() + mock_get_fs.return_value = MagicMock() + + ctx = _make_index_context( + executor_params={ + "embedding_instance_id": "emb-dto", + "vector_db_instance_id": "vdb-dto", + "x2text_instance_id": "x2t-dto", + "file_path": "/data/doc.pdf", + "file_hash": "hash-dto", + "extracted_text": "text", + "platform_api_key": "sk-test", + "chunk_size": 256, + "chunk_overlap": 64, + "tool_id": "tool-dto", + "tags": ["tag1"], + } + ) + executor.execute(ctx) + + init_kwargs = mock_index_cls.call_args.kwargs + ids = init_kwargs["instance_identifiers"] + assert ids.embedding_instance_id == "emb-dto" + assert ids.vector_db_instance_id == "vdb-dto" + assert ids.x2text_instance_id == "x2t-dto" + assert ids.tool_id == "tool-dto" + assert ids.tags == ["tag1"] + + chunking = init_kwargs["chunking_config"] + assert chunking.chunk_size == 256 + assert chunking.chunk_overlap == 64 + + gen_call = mock_index_cls.return_value.generate_index_key.call_args + fi = gen_call.kwargs["file_info"] + assert fi.file_path == "/data/doc.pdf" + assert fi.file_hash == "hash-dto" + + +# --- 10. 
EmbeddingCompat and VectorDB created with correct params --- + + +class TestAdapterCreation: + @patch(_PATCH_FS) + def test_embedding_and_vectordb_params( + self, mock_get_fs, mock_indexing_deps + ): + mock_index_cls, mock_emb_cls, mock_vdb_cls = mock_indexing_deps + _register_legacy() + executor = ExecutorRegistry.get("legacy") + + _setup_mock_index(mock_index_cls, "doc-adapt") + mock_emb = MagicMock() + mock_emb_cls.return_value = mock_emb + mock_vdb = MagicMock() + mock_vdb_cls.return_value = mock_vdb + mock_get_fs.return_value = MagicMock() + + ctx = _make_index_context( + executor_params={ + "embedding_instance_id": "emb-check", + "vector_db_instance_id": "vdb-check", + "x2text_instance_id": "x2t-001", + "file_path": "/data/test.pdf", + "file_hash": "abc", + "extracted_text": "text", + "platform_api_key": "sk-test", + "chunk_size": 512, + "chunk_overlap": 128, + "usage_kwargs": {"org": "test-org"}, + } + ) + executor.execute(ctx) + + emb_call = mock_emb_cls.call_args + assert emb_call.kwargs["adapter_instance_id"] == "emb-check" + assert emb_call.kwargs["kwargs"] == {"org": "test-org"} + + vdb_call = mock_vdb_cls.call_args + assert vdb_call.kwargs["adapter_instance_id"] == "vdb-check" + assert vdb_call.kwargs["embedding"] is mock_emb diff --git a/workers/tests/test_legacy_executor_scaffold.py b/workers/tests/test_legacy_executor_scaffold.py new file mode 100644 index 0000000000..f2d9935f9b --- /dev/null +++ b/workers/tests/test_legacy_executor_scaffold.py @@ -0,0 +1,305 @@ +"""Phase 2A — LegacyExecutor scaffold tests. + +Verifies: +1. Registration in ExecutorRegistry +2. Name property +3. Unsupported operation handling +4. Each operation raises NotImplementedError +5. Orchestrator wraps NotImplementedError as failure +6. Celery eager-mode chain +7. Dispatch table coverage (every Operation has a handler) +8. Constants importable +9. DTOs importable +10. 
Exceptions standalone (no Flask dependency) +""" + +import pytest + +from unstract.sdk1.execution.context import ExecutionContext, Operation +from unstract.sdk1.execution.registry import ExecutorRegistry +from unstract.sdk1.execution.result import ExecutionResult + + +@pytest.fixture(autouse=True) +def _clean_registry(): + """Ensure a clean executor registry for every test.""" + ExecutorRegistry.clear() + yield + ExecutorRegistry.clear() + + +def _register_legacy(): + """Import executor.executors to trigger LegacyExecutor registration.""" + from executor.executors.legacy_executor import LegacyExecutor # noqa: F401 + + ExecutorRegistry.register(LegacyExecutor) + + +def _make_context(**overrides): + defaults = { + "executor_name": "legacy", + "operation": "extract", + "run_id": "run-2a-001", + "execution_source": "tool", + "organization_id": "org-test", + "request_id": "req-2a-001", + } + defaults.update(overrides) + return ExecutionContext(**defaults) + + +# --- 1. Registration --- + + +class TestRegistration: + def test_legacy_in_registry(self): + _register_legacy() + assert "legacy" in ExecutorRegistry.list_executors() + + +# --- 2. Name --- + + +class TestName: + def test_name_is_legacy(self): + _register_legacy() + executor = ExecutorRegistry.get("legacy") + assert executor.name == "legacy" + + +# --- 3. Unsupported operation --- + + +class TestUnsupportedOperation: + def test_unsupported_operation_returns_failure(self): + _register_legacy() + executor = ExecutorRegistry.get("legacy") + ctx = _make_context(operation="totally_unknown_op") + result = executor.execute(ctx) + + assert result.success is False + assert "does not support operation" in result.error + assert "totally_unknown_op" in result.error + + +# --- 4. All operations are implemented (no stubs remain) --- +# TestHandlerStubs and TestOrchestratorWrapping removed: +# All operations (extract, index, answer_prompt, single_pass_extraction, +# summarize) are now fully implemented. 
Agentic operations moved to +# AgenticPromptStudioExecutor (cloud plugin). + + +# --- 6. Celery eager-mode chain --- + + +@pytest.fixture +def eager_app(): + """Configure the real executor Celery app for eager-mode testing.""" + from executor.worker import app + + original = { + "task_always_eager": app.conf.task_always_eager, + "task_eager_propagates": app.conf.task_eager_propagates, + "result_backend": app.conf.result_backend, + } + + app.conf.update( + task_always_eager=True, + task_eager_propagates=False, + result_backend="cache+memory://", + ) + + yield app + + app.conf.update(original) + + +class TestCeleryEagerChain: + def test_eager_unsupported_op_returns_failure(self, eager_app): + """execute_extraction with an unsupported operation returns failure.""" + _register_legacy() + + ctx = _make_context(operation="totally_unknown_op") + task = eager_app.tasks["execute_extraction"] + result_dict = task.apply(args=[ctx.to_dict()]).get() + result = ExecutionResult.from_dict(result_dict) + + assert result.success is False + assert "does not support operation" in result.error + + +# --- 7. Dispatch table coverage --- + + +class TestDispatchTableCoverage: + def test_every_operation_has_handler(self): + """Every Operation handled by LegacyExecutor is in _OPERATION_MAP. + + Operations handled by cloud executors (discovered via entry points) + are excluded — they have their own executor classes. 
+ """ + from executor.executors.legacy_executor import LegacyExecutor + + # Operations handled by cloud executors, not LegacyExecutor + cloud_executor_operations = { + "table_extract", # TableExtractorExecutor + "smart_table_extract", # SmartTableExtractorExecutor + "sps_answer_prompt", # SimplePromptStudioExecutor + "sps_index", # SimplePromptStudioExecutor + "agentic_extract", # AgenticPromptStudioExecutor + "agentic_summarize", # AgenticPromptStudioExecutor + "agentic_uniformize", # AgenticPromptStudioExecutor + "agentic_finalize", # AgenticPromptStudioExecutor + "agentic_generate_prompt", # AgenticPromptStudioExecutor + "agentic_generate_prompt_pipeline", # AgenticPromptStudioExecutor + "agentic_compare", # AgenticPromptStudioExecutor + "agentic_tune_field", # AgenticPromptStudioExecutor + } + + for op in Operation: + if op.value in cloud_executor_operations: + continue + assert op.value in LegacyExecutor._OPERATION_MAP, ( + f"Operation {op.value} missing from _OPERATION_MAP" + ) + + +# --- 8. Constants importable --- + + +class TestConstants: + def test_prompt_service_constants(self): + from executor.executors.constants import PromptServiceConstants + + assert hasattr(PromptServiceConstants, "TOOL_ID") + assert PromptServiceConstants.TOOL_ID == "tool_id" + + def test_retrieval_strategy(self): + from executor.executors.constants import RetrievalStrategy + + assert RetrievalStrategy.SIMPLE.value == "simple" + assert RetrievalStrategy.SUBQUESTION.value == "subquestion" + + def test_run_level(self): + from executor.executors.constants import RunLevel + + assert RunLevel.RUN.value == "RUN" + assert RunLevel.EVAL.value == "EVAL" + + +# --- 9. 
DTOs importable --- + + +class TestDTOs: + def test_chunking_config(self): + from executor.executors.dto import ChunkingConfig + + cfg = ChunkingConfig(chunk_size=512, chunk_overlap=64) + assert cfg.chunk_size == 512 + + def test_chunking_config_zero_raises(self): + from executor.executors.dto import ChunkingConfig + + with pytest.raises(ValueError, match="zero chunks"): + ChunkingConfig(chunk_size=0, chunk_overlap=0) + + def test_file_info(self): + from executor.executors.dto import FileInfo + + fi = FileInfo(file_path="/tmp/test.pdf", file_hash="abc123") + assert fi.file_path == "/tmp/test.pdf" + + def test_instance_identifiers(self): + from executor.executors.dto import InstanceIdentifiers + + ids = InstanceIdentifiers( + embedding_instance_id="emb-1", + vector_db_instance_id="vdb-1", + x2text_instance_id="x2t-1", + llm_instance_id="llm-1", + tool_id="tool-1", + ) + assert ids.tool_id == "tool-1" + + def test_processing_options(self): + from executor.executors.dto import ProcessingOptions + + opts = ProcessingOptions(reindex=True) + assert opts.reindex is True + assert opts.enable_highlight is False + + +# --- 10. 
Exceptions standalone --- + + +class TestExceptions: + def test_legacy_executor_error_has_code_and_message(self): + from executor.executors.exceptions import LegacyExecutorError + + err = LegacyExecutorError(message="test error", code=418) + assert err.message == "test error" + assert err.code == 418 + assert str(err) == "test error" + + def test_extraction_error_has_code_and_message(self): + from executor.executors.exceptions import ExtractionError + + err = ExtractionError(message="extraction failed", code=500) + assert err.message == "extraction failed" + assert err.code == 500 + + def test_no_flask_import(self): + """Verify exceptions module does NOT import Flask.""" + import importlib + import sys + + # Ensure fresh import + mod_name = "executor.executors.exceptions" + if mod_name in sys.modules: + importlib.reload(sys.modules[mod_name]) + else: + importlib.import_module(mod_name) + + # Check that no flask modules were pulled in + flask_modules = [m for m in sys.modules if m.startswith("flask")] + assert flask_modules == [], ( + f"Flask modules imported: {flask_modules}" + ) + + def test_custom_data_error_signature(self): + from executor.executors.exceptions import CustomDataError + + err = CustomDataError( + variable="invoice_num", reason="not found", is_ide=True + ) + assert "invoice_num" in err.message + assert "not found" in err.message + assert "Prompt Studio" in err.message + + def test_custom_data_error_tool_mode(self): + from executor.executors.exceptions import CustomDataError + + err = CustomDataError( + variable="order_id", reason="missing", is_ide=False + ) + assert "API request" in err.message + + def test_missing_field_error(self): + from executor.executors.exceptions import MissingFieldError + + err = MissingFieldError(missing_fields=["tool_id", "file_path"]) + assert "tool_id" in err.message + assert "file_path" in err.message + + def test_bad_request_defaults(self): + from executor.executors.exceptions import BadRequest + + err = BadRequest() 
+ assert err.code == 400 + assert "Bad Request" in err.message + + def test_rate_limit_error_defaults(self): + from executor.executors.exceptions import RateLimitError + + err = RateLimitError() + assert err.code == 429 diff --git a/workers/tests/test_phase1_log_streaming.py b/workers/tests/test_phase1_log_streaming.py new file mode 100644 index 0000000000..903449d75a --- /dev/null +++ b/workers/tests/test_phase1_log_streaming.py @@ -0,0 +1,489 @@ +"""Phase 1 — Executor log streaming to frontend via Socket.IO. + +Tests cover: +- ExecutionContext round-trips log_events_id through to_dict/from_dict +- LogPublisher.log_progress() returns type: "PROGRESS" (not "LOG") +- LogPublisher.log_prompt() still returns type: "LOG" (unchanged) +- ExecutorToolShim with log_events_id: stream_log() publishes progress +- ExecutorToolShim without log_events_id: no publishing, no exceptions +- ExecutorToolShim with failing LogPublisher: no exception raised +- execute_extraction builds component dict when log_events_id present +- execute_extraction skips component dict when log_events_id absent +""" + +from unittest.mock import MagicMock, patch + + +from unstract.sdk1.constants import LogLevel +from unstract.sdk1.execution.context import ExecutionContext + + +# --------------------------------------------------------------------------- +# 1A — ExecutionContext.log_events_id round-trip +# --------------------------------------------------------------------------- + + +class TestExecutionContextLogEventsId: + """Verify log_events_id serialization in ExecutionContext.""" + + def test_log_events_id_default_is_none(self): + ctx = ExecutionContext( + executor_name="legacy", + operation="extract", + run_id="r1", + execution_source="ide", + ) + assert ctx.log_events_id is None + + def test_log_events_id_round_trips(self): + ctx = ExecutionContext( + executor_name="legacy", + operation="extract", + run_id="r1", + execution_source="ide", + log_events_id="session-abc", + ) + d = ctx.to_dict() + 
assert d["log_events_id"] == "session-abc" + + restored = ExecutionContext.from_dict(d) + assert restored.log_events_id == "session-abc" + + def test_log_events_id_none_round_trips(self): + ctx = ExecutionContext( + executor_name="legacy", + operation="extract", + run_id="r1", + execution_source="ide", + ) + d = ctx.to_dict() + assert d["log_events_id"] is None + + restored = ExecutionContext.from_dict(d) + assert restored.log_events_id is None + + def test_backward_compat_missing_key(self): + """from_dict with old payload lacking log_events_id.""" + old_payload = { + "executor_name": "legacy", + "operation": "extract", + "run_id": "r1", + "execution_source": "ide", + } + ctx = ExecutionContext.from_dict(old_payload) + assert ctx.log_events_id is None + + +# --------------------------------------------------------------------------- +# 1B-i — LogPublisher.log_progress() vs log_prompt() +# --------------------------------------------------------------------------- + + +class TestLogPublisherLogProgress: + """Verify log_progress returns type PROGRESS, log_prompt returns LOG.""" + + def test_log_progress_type(self): + from unstract.core.pubsub_helper import LogPublisher + + result = LogPublisher.log_progress( + component={"tool_id": "t1"}, + level="INFO", + state="TOOL_RUN", + message="Extracting text...", + ) + assert result["type"] == "PROGRESS" + assert result["service"] == "prompt" + assert result["message"] == "Extracting text..." 
+ assert result["component"] == {"tool_id": "t1"} + assert "timestamp" in result + + def test_log_prompt_type_unchanged(self): + from unstract.core.pubsub_helper import LogPublisher + + result = LogPublisher.log_prompt( + component={"tool_id": "t1"}, + level="INFO", + state="RUNNING", + message="test", + ) + assert result["type"] == "LOG" + assert result["service"] == "prompt" + + def test_log_progress_has_all_fields(self): + from unstract.core.pubsub_helper import LogPublisher + + result = LogPublisher.log_progress( + component={"tool_id": "t1", "prompt_key": "pk"}, + level="ERROR", + state="FAILED", + message="boom", + ) + assert result["level"] == "ERROR" + assert result["state"] == "FAILED" + assert result["component"]["prompt_key"] == "pk" + + +# --------------------------------------------------------------------------- +# 1B-ii — ExecutorToolShim progress publishing +# --------------------------------------------------------------------------- + + +class TestExecutorToolShimProgress: + """Verify ExecutorToolShim publishes progress via LogPublisher.""" + + @patch("executor.executor_tool_shim.LogPublisher") + def test_stream_log_publishes_when_log_events_id_set(self, mock_lp): + from executor.executor_tool_shim import ExecutorToolShim + + component = {"tool_id": "t1", "run_id": "r1"} + shim = ExecutorToolShim( + platform_api_key="sk-test", + log_events_id="session-xyz", + component=component, + ) + shim.stream_log("Extracting...", level=LogLevel.INFO) + + mock_lp.log_progress.assert_called_once_with( + component=component, + level="INFO", + state="TOOL_RUN", + message="Extracting...", + ) + mock_lp.publish.assert_called_once_with( + channel_id="session-xyz", + payload=mock_lp.log_progress.return_value, + ) + + @patch("executor.executor_tool_shim.LogPublisher") + def test_stream_log_no_publish_without_log_events_id(self, mock_lp): + from executor.executor_tool_shim import ExecutorToolShim + + shim = ExecutorToolShim(platform_api_key="sk-test") + 
shim.stream_log("Hello", level=LogLevel.INFO) + + mock_lp.log_progress.assert_not_called() + mock_lp.publish.assert_not_called() + + @patch("executor.executor_tool_shim.LogPublisher") + def test_stream_log_empty_log_events_id_no_publish(self, mock_lp): + from executor.executor_tool_shim import ExecutorToolShim + + shim = ExecutorToolShim( + platform_api_key="sk-test", log_events_id="" + ) + shim.stream_log("Hello", level=LogLevel.INFO) + + mock_lp.log_progress.assert_not_called() + + @patch("executor.executor_tool_shim.LogPublisher") + def test_stream_log_swallows_publish_error(self, mock_lp): + from executor.executor_tool_shim import ExecutorToolShim + + mock_lp.publish.side_effect = ConnectionError("AMQP down") + shim = ExecutorToolShim( + platform_api_key="sk-test", + log_events_id="session-xyz", + component={"tool_id": "t1"}, + ) + # Should NOT raise + shim.stream_log("test", level=LogLevel.INFO) + + @patch("executor.executor_tool_shim.LogPublisher") + def test_level_mapping(self, mock_lp): + from executor.executor_tool_shim import ExecutorToolShim + + shim = ExecutorToolShim( + platform_api_key="sk-test", + log_events_id="s1", + component={}, + ) + + # DEBUG is below the shim's log_level (INFO) so it should NOT + # be published to the frontend. + shim.stream_log("msg", level=LogLevel.DEBUG) + assert not mock_lp.log_progress.called, ( + "DEBUG should be filtered out (below INFO threshold)" + ) + + # INFO and above should be published with the correct mapped level. 
+ published_cases = [ + (LogLevel.INFO, "INFO"), + (LogLevel.WARN, "WARN"), + (LogLevel.ERROR, "ERROR"), + (LogLevel.FATAL, "ERROR"), + ] + for sdk_level, expected_wf_level in published_cases: + mock_lp.reset_mock() + shim.stream_log("msg", level=sdk_level) + call_kwargs = mock_lp.log_progress.call_args + assert call_kwargs.kwargs["level"] == expected_wf_level, ( + f"SDK {sdk_level} should map to {expected_wf_level}" + ) + + @patch("executor.executor_tool_shim.LogPublisher") + def test_custom_stage_passed_through(self, mock_lp): + from executor.executor_tool_shim import ExecutorToolShim + + shim = ExecutorToolShim( + platform_api_key="sk-test", + log_events_id="s1", + component={}, + ) + shim.stream_log("msg", level=LogLevel.INFO, stage="INDEXING") + call_kwargs = mock_lp.log_progress.call_args + assert call_kwargs.kwargs["state"] == "INDEXING" + + +# --------------------------------------------------------------------------- +# 1C — Component dict building in execute_extraction +# --------------------------------------------------------------------------- + + +class TestExecuteExtractionComponentDict: + """Verify component dict is built from executor_params.""" + + @patch("executor.tasks.ExecutionOrchestrator") + def test_component_dict_built_when_log_events_id_present( + self, mock_orch_cls + ): + mock_orch = MagicMock() + mock_orch.execute.return_value = MagicMock( + success=True, to_dict=lambda: {"success": True} + ) + mock_orch_cls.return_value = mock_orch + + from executor.tasks import execute_extraction + + payload = { + "executor_name": "legacy", + "operation": "extract", + "run_id": "r1", + "execution_source": "ide", + "log_events_id": "session-abc", + "executor_params": { + "tool_id": "tool-123", + "file_name": "invoice.pdf", + }, + } + execute_extraction(payload) + + # Verify the context passed to orchestrator has _log_component + ctx = mock_orch.execute.call_args[0][0] + assert ctx._log_component == { + "tool_id": "tool-123", + "run_id": "r1", + 
"doc_name": "invoice.pdf", + "operation": "extract", + } + + @patch("executor.tasks.ExecutionOrchestrator") + def test_component_dict_empty_when_no_log_events_id( + self, mock_orch_cls + ): + mock_orch = MagicMock() + mock_orch.execute.return_value = MagicMock( + success=True, to_dict=lambda: {"success": True} + ) + mock_orch_cls.return_value = mock_orch + + from executor.tasks import execute_extraction + + payload = { + "executor_name": "legacy", + "operation": "extract", + "run_id": "r1", + "execution_source": "ide", + "executor_params": {}, + } + execute_extraction(payload) + + ctx = mock_orch.execute.call_args[0][0] + assert ctx._log_component == {} + + +# --------------------------------------------------------------------------- +# 1D — LegacyExecutor passes log info to shim +# --------------------------------------------------------------------------- + + +class TestLegacyExecutorLogPassthrough: + """Verify LegacyExecutor passes log_events_id and component to shim.""" + + @patch("executor.executors.legacy_executor.FileUtils.get_fs_instance") + @patch("executor.executors.legacy_executor.X2Text") + @patch("executor.executors.legacy_executor.ExecutorToolShim") + def test_extract_passes_log_info_to_shim( + self, mock_shim_cls, mock_x2text, mock_fs + ): + from executor.executors.legacy_executor import LegacyExecutor + from unstract.sdk1.execution.registry import ExecutorRegistry + + if "legacy" not in ExecutorRegistry.list_executors(): + ExecutorRegistry._registry["legacy"] = LegacyExecutor + + mock_shim = MagicMock() + mock_shim_cls.return_value = mock_shim + mock_x2t = MagicMock() + mock_x2t.process.return_value = MagicMock( + extracted_text="hello" + ) + mock_x2text.return_value = mock_x2t + + ctx = ExecutionContext( + executor_name="legacy", + operation="extract", + run_id="r1", + execution_source="ide", + log_events_id="session-abc", + executor_params={ + "x2text_instance_id": "x2t-1", + "file_path": "/tmp/test.pdf", + "platform_api_key": "sk-test", + }, + ) 
+ ctx._log_component = {"tool_id": "t1", "run_id": "r1", "doc_name": "test.pdf"} + + executor = LegacyExecutor() + result = executor.execute(ctx) + + assert result.success + mock_shim_cls.assert_called_once_with( + platform_api_key="sk-test", + log_events_id="session-abc", + component={"tool_id": "t1", "run_id": "r1", "doc_name": "test.pdf"}, + ) + + @patch("executor.executors.legacy_executor.FileUtils.get_fs_instance") + @patch("executor.executors.legacy_executor.X2Text") + @patch("executor.executors.legacy_executor.ExecutorToolShim") + def test_extract_no_log_info_when_absent( + self, mock_shim_cls, mock_x2text, mock_fs + ): + from executor.executors.legacy_executor import LegacyExecutor + from unstract.sdk1.execution.registry import ExecutorRegistry + + if "legacy" not in ExecutorRegistry.list_executors(): + ExecutorRegistry._registry["legacy"] = LegacyExecutor + + mock_shim = MagicMock() + mock_shim_cls.return_value = mock_shim + mock_x2t = MagicMock() + mock_x2t.process.return_value = MagicMock( + extracted_text="hello" + ) + mock_x2text.return_value = mock_x2t + + ctx = ExecutionContext( + executor_name="legacy", + operation="extract", + run_id="r1", + execution_source="tool", + executor_params={ + "x2text_instance_id": "x2t-1", + "file_path": "/tmp/test.pdf", + "platform_api_key": "sk-test", + }, + ) + + executor = LegacyExecutor() + result = executor.execute(ctx) + + assert result.success + mock_shim_cls.assert_called_once_with( + platform_api_key="sk-test", + log_events_id="", + component={}, + ) + + @patch( + "executor.executors.legacy_executor.LegacyExecutor._get_prompt_deps" + ) + @patch("executor.executors.legacy_executor.ExecutorToolShim") + def test_answer_prompt_enriches_component_with_prompt_key( + self, mock_shim_cls, mock_prompt_deps + ): + """Verify per-prompt shim includes prompt_key in component.""" + from executor.executors.legacy_executor import LegacyExecutor + from unstract.sdk1.execution.registry import ExecutorRegistry + + if "legacy" 
not in ExecutorRegistry.list_executors(): + ExecutorRegistry._registry["legacy"] = LegacyExecutor + + mock_shim = MagicMock() + mock_shim_cls.return_value = mock_shim + + # Mock prompt deps + MockAnswerPromptService = MagicMock() + MockAnswerPromptService.extract_variable.return_value = "prompt text" + MockRetrievalService = MagicMock() + MockVariableReplacementService = MagicMock() + MockVariableReplacementService.is_variables_present.return_value = ( + False + ) + MockIndex = MagicMock() + MockLLM = MagicMock() + MockEmbeddingCompat = MagicMock() + MockVectorDB = MagicMock() + + mock_prompt_deps.return_value = ( + MockAnswerPromptService, + MockRetrievalService, + MockVariableReplacementService, + MockIndex, + MockLLM, + MockEmbeddingCompat, + MockVectorDB, + ) + + ctx = ExecutionContext( + executor_name="legacy", + operation="answer_prompt", + run_id="r1", + execution_source="ide", + log_events_id="session-abc", + executor_params={ + "tool_id": "t1", + "outputs": [ + { + "name": "invoice_number", + "prompt": "What is the invoice number?", + "chunk-size": 0, + "type": "text", + "retrieval-strategy": "simple", + "vector-db": "vdb1", + "embedding": "emb1", + "x2text_adapter": "x2t1", + "chunk-overlap": 0, + "llm": "llm1", + }, + ], + "tool_settings": {}, + "PLATFORM_SERVICE_API_KEY": "sk-test", + }, + ) + ctx._log_component = { + "tool_id": "t1", + "run_id": "r1", + "doc_name": "test.pdf", + } + + # Mock IndexingUtils + with patch( + "unstract.sdk1.utils.indexing.IndexingUtils.generate_index_key", + return_value="doc-id-1", + ): + executor = LegacyExecutor() + # The handler will try retrieval which we need to mock + MockRetrievalService.retrieve_complete_context.return_value = [ + "context" + ] + MockAnswerPromptService.construct_and_run_prompt.return_value = ( + "INV-001" + ) + + executor.execute(ctx) + + # Check that shim was created with prompt_key in component + shim_call = mock_shim_cls.call_args + assert shim_call.kwargs["component"]["prompt_key"] == 
"invoice_number" + assert shim_call.kwargs["log_events_id"] == "session-abc" diff --git a/workers/tests/test_phase2f.py b/workers/tests/test_phase2f.py new file mode 100644 index 0000000000..a5913367c1 --- /dev/null +++ b/workers/tests/test_phase2f.py @@ -0,0 +1,330 @@ +"""Phase 2F — single_pass_extraction, summarize, agentic operations tests. + +Verifies: +1. single_pass_extraction delegates to answer_prompt +2. summarize constructs prompt and calls LLM +3. summarize missing params return failure +4. summarize prompt includes prompt_keys +5. agentic operations rejected by LegacyExecutor (cloud executor handles them) +""" + +from unittest.mock import MagicMock, patch + +import pytest + +from unstract.sdk1.execution.context import ExecutionContext +from unstract.sdk1.execution.registry import ExecutorRegistry +from unstract.sdk1.execution.result import ExecutionResult + + +@pytest.fixture(autouse=True) +def _clean_registry(): + """Ensure a clean executor registry for every test.""" + ExecutorRegistry.clear() + yield + ExecutorRegistry.clear() + + +def _register_legacy(): + from executor.executors.legacy_executor import LegacyExecutor # noqa: F401 + + if "legacy" not in ExecutorRegistry.list_executors(): + ExecutorRegistry.register(LegacyExecutor) + + +def _make_context(**overrides): + defaults = { + "executor_name": "legacy", + "operation": "summarize", + "run_id": "run-2f-001", + "execution_source": "tool", + "organization_id": "org-test", + "request_id": "req-2f-001", + } + defaults.update(overrides) + return ExecutionContext(**defaults) + + +# --------------------------------------------------------------------------- +# 1. 
single_pass_extraction delegates to answer_prompt +# --------------------------------------------------------------------------- + + +class TestSinglePassExtraction: + def test_delegates_to_answer_prompt(self): + """single_pass_extraction calls _handle_answer_prompt internally.""" + _register_legacy() + executor = ExecutorRegistry.get("legacy") + + # Mock _handle_answer_prompt so we can verify delegation + expected_result = ExecutionResult( + success=True, + data={"output": {"field1": "value1"}, "metadata": {}, "metrics": {}}, + ) + executor._handle_answer_prompt = MagicMock(return_value=expected_result) + + ctx = _make_context(operation="single_pass_extraction") + result = executor.execute(ctx) + + assert result.success is True + assert result.data["output"]["field1"] == "value1" + executor._handle_answer_prompt.assert_called_once_with(ctx) + + def test_delegates_failure_too(self): + """Failures from answer_prompt propagate through single_pass.""" + _register_legacy() + executor = ExecutorRegistry.get("legacy") + + fail_result = ExecutionResult.failure(error="some error") + executor._handle_answer_prompt = MagicMock(return_value=fail_result) + + ctx = _make_context(operation="single_pass_extraction") + result = executor.execute(ctx) + + assert result.success is False + assert "some error" in result.error + + +# --------------------------------------------------------------------------- +# 2. 
summarize +# --------------------------------------------------------------------------- + + +def _make_summarize_params(**overrides): + """Build executor_params for summarize operation.""" + defaults = { + "llm_adapter_instance_id": "llm-001", + "summarize_prompt": "Summarize the following document.", + "context": "This is a long document with lots of content.", + "prompt_keys": ["invoice_number", "total_amount"], + "PLATFORM_SERVICE_API_KEY": "test-key", + } + defaults.update(overrides) + return defaults + + +class TestSummarize: + @patch("executor.executors.legacy_executor.LegacyExecutor._get_prompt_deps") + @patch("executor.executors.legacy_executor.ExecutorToolShim") + def test_summarize_success(self, mock_shim_cls, mock_get_deps): + """Successful summarize returns data with summary text.""" + _register_legacy() + executor = ExecutorRegistry.get("legacy") + + # Set up mock LLM + mock_llm_cls = MagicMock() + mock_llm = MagicMock() + mock_llm_cls.return_value = mock_llm + + mock_get_deps.return_value = ( + MagicMock(), # AnswerPromptService + MagicMock(), # RetrievalService + MagicMock(), # VariableReplacementService + MagicMock(), # Index + mock_llm_cls, # LLM + MagicMock(), # EmbeddingCompat + MagicMock(), # VectorDB + ) + + # Mock AnswerPromptService.run_completion + with patch( + "executor.executors.answer_prompt.AnswerPromptService.run_completion", + return_value="This is a summary of the document.", + ): + ctx = _make_context( + operation="summarize", + executor_params=_make_summarize_params(), + ) + result = executor.execute(ctx) + + assert result.success is True + assert result.data["data"] == "This is a summary of the document." 
+ + @patch("executor.executors.legacy_executor.LegacyExecutor._get_prompt_deps") + @patch("executor.executors.legacy_executor.ExecutorToolShim") + def test_summarize_prompt_includes_keys(self, mock_shim_cls, mock_get_deps): + """The summarize prompt includes prompt_keys.""" + _register_legacy() + executor = ExecutorRegistry.get("legacy") + + mock_llm_cls = MagicMock() + mock_llm = MagicMock() + mock_llm_cls.return_value = mock_llm + + mock_get_deps.return_value = ( + MagicMock(), MagicMock(), MagicMock(), MagicMock(), + mock_llm_cls, MagicMock(), MagicMock(), + ) + + captured_prompt = {} + + def capture_run_completion(llm, prompt, **kwargs): + captured_prompt["value"] = prompt + return "summary" + + with patch( + "executor.executors.answer_prompt.AnswerPromptService.run_completion", + side_effect=capture_run_completion, + ): + ctx = _make_context( + operation="summarize", + executor_params=_make_summarize_params( + prompt_keys=["name", "address"], + ), + ) + executor.execute(ctx) + + assert "name" in captured_prompt["value"] + assert "address" in captured_prompt["value"] + + @patch("executor.executors.legacy_executor.LegacyExecutor._get_prompt_deps") + @patch("executor.executors.legacy_executor.ExecutorToolShim") + def test_summarize_no_prompt_keys(self, mock_shim_cls, mock_get_deps): + """Summarize works without prompt_keys.""" + _register_legacy() + executor = ExecutorRegistry.get("legacy") + + mock_llm_cls = MagicMock() + mock_llm_cls.return_value = MagicMock() + + mock_get_deps.return_value = ( + MagicMock(), MagicMock(), MagicMock(), MagicMock(), + mock_llm_cls, MagicMock(), MagicMock(), + ) + + with patch( + "executor.executors.answer_prompt.AnswerPromptService.run_completion", + return_value="summary without keys", + ): + params = _make_summarize_params() + del params["prompt_keys"] + ctx = _make_context( + operation="summarize", + executor_params=params, + ) + result = executor.execute(ctx) + + assert result.success is True + assert result.data["data"] == 
"summary without keys" + + def test_summarize_missing_llm_adapter(self): + """Missing llm_adapter_instance_id returns failure.""" + _register_legacy() + executor = ExecutorRegistry.get("legacy") + + params = _make_summarize_params(llm_adapter_instance_id="") + ctx = _make_context( + operation="summarize", + executor_params=params, + ) + result = executor.execute(ctx) + + assert result.success is False + assert "llm_adapter_instance_id" in result.error + + def test_summarize_missing_context(self): + """Missing context returns failure.""" + _register_legacy() + executor = ExecutorRegistry.get("legacy") + + params = _make_summarize_params(context="") + ctx = _make_context( + operation="summarize", + executor_params=params, + ) + result = executor.execute(ctx) + + assert result.success is False + assert "context" in result.error + + @patch("executor.executors.legacy_executor.LegacyExecutor._get_prompt_deps") + @patch("executor.executors.legacy_executor.ExecutorToolShim") + def test_summarize_llm_error(self, mock_shim_cls, mock_get_deps): + """LLM errors are wrapped in ExecutionResult.failure.""" + _register_legacy() + executor = ExecutorRegistry.get("legacy") + + mock_llm_cls = MagicMock() + mock_llm_cls.return_value = MagicMock() + + mock_get_deps.return_value = ( + MagicMock(), MagicMock(), MagicMock(), MagicMock(), + mock_llm_cls, MagicMock(), MagicMock(), + ) + + with patch( + "executor.executors.answer_prompt.AnswerPromptService.run_completion", + side_effect=Exception("LLM unavailable"), + ): + ctx = _make_context( + operation="summarize", + executor_params=_make_summarize_params(), + ) + result = executor.execute(ctx) + + assert result.success is False + assert "summarization" in result.error.lower() or "LLM" in result.error + + @patch("executor.executors.legacy_executor.LegacyExecutor._get_prompt_deps") + @patch("executor.executors.legacy_executor.ExecutorToolShim") + def test_summarize_creates_llm_with_correct_adapter( + self, mock_shim_cls, mock_get_deps + ): 
+ """LLM is instantiated with the provided adapter instance ID.""" + _register_legacy() + executor = ExecutorRegistry.get("legacy") + + mock_llm_cls = MagicMock() + mock_llm = MagicMock() + mock_llm_cls.return_value = mock_llm + + mock_get_deps.return_value = ( + MagicMock(), MagicMock(), MagicMock(), MagicMock(), + mock_llm_cls, MagicMock(), MagicMock(), + ) + + with patch( + "executor.executors.answer_prompt.AnswerPromptService.run_completion", + return_value="summary", + ): + ctx = _make_context( + operation="summarize", + executor_params=_make_summarize_params( + llm_adapter_instance_id="custom-llm-42", + ), + ) + executor.execute(ctx) + + mock_llm_cls.assert_called_once() + call_kwargs = mock_llm_cls.call_args + assert call_kwargs.kwargs["adapter_instance_id"] == "custom-llm-42" + + +# --------------------------------------------------------------------------- +# 3. agentic operations — handled by AgenticPromptStudioExecutor (cloud) +# --------------------------------------------------------------------------- + + +class TestAgenticExtraction: + def test_legacy_rejects_agentic_operations(self): + """LegacyExecutor does not handle agentic operations (cloud executor).""" + _register_legacy() + executor = ExecutorRegistry.get("legacy") + + ctx = _make_context(operation="agentic_extract") + result = executor.execute(ctx) + + assert result.success is False + assert "does not support" in result.error + + def test_orchestrator_wraps_unsupported_agentic(self): + """ExecutionOrchestrator returns failure for agentic ops on legacy.""" + from unstract.sdk1.execution.orchestrator import ExecutionOrchestrator + + _register_legacy() + orchestrator = ExecutionOrchestrator() + ctx = _make_context(operation="agentic_extract") + result = orchestrator.execute(ctx) + + assert result.success is False + assert "does not support" in result.error diff --git a/workers/tests/test_phase2h.py b/workers/tests/test_phase2h.py new file mode 100644 index 0000000000..cf02c767b2 --- /dev/null 
+++ b/workers/tests/test_phase2h.py @@ -0,0 +1,484 @@ +"""Phase 2H: Tests for variable replacement and postprocessor modules. + +Covers VariableReplacementHelper, VariableReplacementService, and +the webhook postprocessor — all pure Python with no llama_index deps. +""" + +import json +from unittest.mock import MagicMock, patch + +import pytest +import requests as real_requests + +from executor.executors.constants import VariableType +from executor.executors.exceptions import CustomDataError, LegacyExecutorError +from executor.executors.postprocessor import ( + _validate_structured_output, + postprocess_data, +) +from executor.executors.variable_replacement import ( + VariableReplacementHelper, + VariableReplacementService, +) + + +# ============================================================================ +# 1. VariableReplacementHelper (15 tests) +# ============================================================================ + + +class TestVariableReplacementHelper: + """Tests for the low-level replacement helper.""" + + # --- extract_variables_from_prompt --- + + def test_extract_variables_single(self): + result = VariableReplacementHelper.extract_variables_from_prompt("{{name}}") + assert result == ["name"] + + def test_extract_variables_multiple(self): + result = VariableReplacementHelper.extract_variables_from_prompt( + "{{a}} and {{b}}" + ) + assert result == ["a", "b"] + + def test_extract_variables_none(self): + result = VariableReplacementHelper.extract_variables_from_prompt("no vars here") + assert result == [] + + # --- identify_variable_type --- + + def test_identify_static_type(self): + assert ( + VariableReplacementHelper.identify_variable_type("name") + == VariableType.STATIC + ) + + def test_identify_dynamic_type(self): + assert ( + VariableReplacementHelper.identify_variable_type( + "https://example.com/api[field1]" + ) + == VariableType.DYNAMIC + ) + + def test_identify_custom_data_type(self): + assert ( + 
VariableReplacementHelper.identify_variable_type("custom_data.company") + == VariableType.CUSTOM_DATA + ) + + # --- handle_json_and_str_types --- + + def test_handle_json_dict(self): + result = VariableReplacementHelper.handle_json_and_str_types({"k": "v"}) + assert result == '{"k": "v"}' + + def test_handle_json_list(self): + result = VariableReplacementHelper.handle_json_and_str_types([1, 2]) + assert result == "[1, 2]" + + # --- replace_generic_string_value --- + + def test_replace_generic_string_non_str(self): + """Non-string values get JSON-formatted before replacement.""" + result = VariableReplacementHelper.replace_generic_string_value( + prompt="value: {{x}}", variable="{{x}}", value={"nested": True} + ) + assert result == 'value: {"nested": true}' + + # --- check_static_variable_run_status --- + + def test_check_static_missing_key(self): + result = VariableReplacementHelper.check_static_variable_run_status( + structure_output={}, variable="missing" + ) + assert result is None + + # --- replace_static_variable --- + + def test_replace_static_missing_returns_prompt(self): + """Missing key in structured_output leaves prompt unchanged.""" + prompt = "Total is {{revenue}}" + result = VariableReplacementHelper.replace_static_variable( + prompt=prompt, structured_output={}, variable="revenue" + ) + assert result == prompt + + # --- replace_custom_data_variable --- + + def test_custom_data_nested_path(self): + """custom_data.nested.key navigates nested dict.""" + result = VariableReplacementHelper.replace_custom_data_variable( + prompt="val: {{custom_data.nested.key}}", + variable="custom_data.nested.key", + custom_data={"nested": {"key": "deep_value"}}, + ) + assert result == "val: deep_value" + + def test_custom_data_empty_dict_raises(self): + """Empty custom_data={} raises CustomDataError.""" + with pytest.raises(CustomDataError, match="Custom data is not configured"): + VariableReplacementHelper.replace_custom_data_variable( + prompt="{{custom_data.company}}", 
+ variable="custom_data.company", + custom_data={}, + ) + + # --- fetch_dynamic_variable_value / replace_dynamic_variable --- + + @patch("executor.executors.variable_replacement.pyrequests.post") + def test_dynamic_variable_success(self, mock_post): + """Mock HTTP POST, verify URL extraction and replacement.""" + mock_resp = MagicMock() + mock_resp.headers = {"content-type": "application/json"} + mock_resp.json.return_value = {"result": "ok"} + mock_resp.raise_for_status = MagicMock() + mock_post.return_value = mock_resp + + variable = "https://example.com/api[field1]" + result = VariableReplacementHelper.replace_dynamic_variable( + prompt="data: {{" + variable + "}}", + variable=variable, + structured_output={"field1": "input_data"}, + ) + mock_post.assert_called_once() + assert '{"result": "ok"}' in result + + @patch("executor.executors.variable_replacement.pyrequests.post") + def test_dynamic_variable_http_error(self, mock_post): + """HTTP error raises LegacyExecutorError.""" + mock_post.side_effect = real_requests.exceptions.ConnectionError("refused") + + with pytest.raises(LegacyExecutorError, match="failed"): + VariableReplacementHelper.fetch_dynamic_variable_value( + url="https://example.com/api", data="payload" + ) + + +# ============================================================================ +# 2. 
VariableReplacementService (8 tests) +# ============================================================================ + + +class TestVariableReplacementService: + """Tests for the high-level orchestration service.""" + + def test_replace_with_variable_map(self): + """Uses variable_map key from prompt dict when present.""" + prompt = { + "prompt": "Hello {{name}}", + "variable_map": {"name": "World"}, + } + result = VariableReplacementService.replace_variables_in_prompt( + prompt=prompt, + structured_output={"name": "Fallback"}, + prompt_name="test", + ) + assert result == "Hello World" + + def test_replace_fallback_structured_output(self): + """Falls back to structured_output when no variable_map.""" + prompt = {"prompt": "Hello {{name}}"} + result = VariableReplacementService.replace_variables_in_prompt( + prompt=prompt, + structured_output={"name": "Fallback"}, + prompt_name="test", + ) + assert result == "Hello Fallback" + + def test_mixed_variable_types(self): + """Prompt with static + custom_data variables replaces both.""" + prompt = { + "prompt": "{{name}} works at {{custom_data.company}}", + "variable_map": {"name": "Alice"}, + } + result = VariableReplacementService.replace_variables_in_prompt( + prompt=prompt, + structured_output={}, + prompt_name="test", + custom_data={"company": "Acme"}, + ) + assert result == "Alice works at Acme" + + def test_no_variables_noop(self): + """Prompt without {{}} returns unchanged.""" + prompt = {"prompt": "No variables here"} + result = VariableReplacementService.replace_variables_in_prompt( + prompt=prompt, + structured_output={}, + prompt_name="test", + ) + assert result == "No variables here" + + def test_replace_with_custom_data(self): + """custom_data dict gets passed through to helper.""" + prompt = { + "prompt": "Company: {{custom_data.name}}", + "variable_map": {}, + } + result = VariableReplacementService.replace_variables_in_prompt( + prompt=prompt, + structured_output={}, + prompt_name="test", + 
custom_data={"name": "TestCorp"}, + ) + assert result == "Company: TestCorp" + + def test_is_ide_flag_propagated(self): + """is_ide=False propagates — error message says 'API request'.""" + prompt = { + "prompt": "{{custom_data.missing}}", + "variable_map": {}, + } + with pytest.raises(CustomDataError, match="API request"): + VariableReplacementService.replace_variables_in_prompt( + prompt=prompt, + structured_output={}, + prompt_name="test", + custom_data={}, + is_ide=False, + ) + + def test_multiple_same_variable(self): + """{{x}} and {{x}} — both occurrences replaced.""" + prompt = { + "prompt": "{{x}} and {{x}}", + "variable_map": {"x": "val"}, + } + result = VariableReplacementService.replace_variables_in_prompt( + prompt=prompt, + structured_output={}, + prompt_name="test", + ) + assert result == "val and val" + + def test_json_value_replacement(self): + """Dict value gets JSON-serialized before replacement.""" + prompt = { + "prompt": "data: {{info}}", + "variable_map": {"info": {"key": "value"}}, + } + result = VariableReplacementService.replace_variables_in_prompt( + prompt=prompt, + structured_output={}, + prompt_name="test", + ) + assert result == 'data: {"key": "value"}' + + +# ============================================================================ +# 3. 
Postprocessor (15 tests) +# ============================================================================ + + +class TestPostprocessor: + """Tests for the webhook postprocessor.""" + + PARSED = {"field": "original"} + HIGHLIGHT = [{"page": 1, "spans": []}] + + # --- disabled / no-op paths --- + + def test_disabled_returns_original(self): + result = postprocess_data( + parsed_data=self.PARSED, + webhook_enabled=False, + highlight_data=self.HIGHLIGHT, + ) + assert result == (self.PARSED, self.HIGHLIGHT) + + def test_no_url_returns_original(self): + result = postprocess_data( + parsed_data=self.PARSED, + webhook_enabled=True, + webhook_url=None, + highlight_data=self.HIGHLIGHT, + ) + assert result == (self.PARSED, self.HIGHLIGHT) + + # --- successful webhook --- + + @patch("executor.executors.postprocessor.requests.post") + def test_success_returns_updated(self, mock_post): + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.json.return_value = {"structured_output": {"field": "updated"}} + mock_post.return_value = mock_resp + + result = postprocess_data( + parsed_data=self.PARSED, + webhook_enabled=True, + webhook_url="https://hook.example.com", + highlight_data=self.HIGHLIGHT, + ) + assert result[0] == {"field": "updated"} + + @patch("executor.executors.postprocessor.requests.post") + def test_success_preserves_highlight_data(self, mock_post): + """Response without highlight_data preserves original.""" + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.json.return_value = {"structured_output": {"f": "v"}} + mock_post.return_value = mock_resp + + _, highlight = postprocess_data( + parsed_data=self.PARSED, + webhook_enabled=True, + webhook_url="https://hook.example.com", + highlight_data=self.HIGHLIGHT, + ) + assert highlight == self.HIGHLIGHT + + @patch("executor.executors.postprocessor.requests.post") + def test_success_updates_highlight_data(self, mock_post): + """Response with valid list highlight_data uses updated.""" + 
new_highlight = [{"page": 2}] + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.json.return_value = { + "structured_output": {"f": "v"}, + "highlight_data": new_highlight, + } + mock_post.return_value = mock_resp + + _, highlight = postprocess_data( + parsed_data=self.PARSED, + webhook_enabled=True, + webhook_url="https://hook.example.com", + highlight_data=self.HIGHLIGHT, + ) + assert highlight == new_highlight + + @patch("executor.executors.postprocessor.requests.post") + def test_invalid_highlight_data_ignored(self, mock_post): + """Response with non-list highlight_data keeps original.""" + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.json.return_value = { + "structured_output": {"f": "v"}, + "highlight_data": "not-a-list", + } + mock_post.return_value = mock_resp + + _, highlight = postprocess_data( + parsed_data=self.PARSED, + webhook_enabled=True, + webhook_url="https://hook.example.com", + highlight_data=self.HIGHLIGHT, + ) + assert highlight == self.HIGHLIGHT + + # --- response validation failures --- + + @patch("executor.executors.postprocessor.requests.post") + def test_missing_structured_output_key(self, mock_post): + """Response without structured_output returns original.""" + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.json.return_value = {"other_key": "value"} + mock_post.return_value = mock_resp + + result = postprocess_data( + parsed_data=self.PARSED, + webhook_enabled=True, + webhook_url="https://hook.example.com", + highlight_data=self.HIGHLIGHT, + ) + assert result == (self.PARSED, self.HIGHLIGHT) + + @patch("executor.executors.postprocessor.requests.post") + def test_invalid_structured_output_type(self, mock_post): + """Response with string structured_output returns original.""" + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.json.return_value = {"structured_output": "just-a-string"} + mock_post.return_value = mock_resp + + result = postprocess_data( + 
parsed_data=self.PARSED, + webhook_enabled=True, + webhook_url="https://hook.example.com", + highlight_data=self.HIGHLIGHT, + ) + assert result == (self.PARSED, self.HIGHLIGHT) + + # --- HTTP error paths --- + + @patch("executor.executors.postprocessor.requests.post") + def test_http_error_returns_original(self, mock_post): + mock_resp = MagicMock() + mock_resp.status_code = 500 + mock_post.return_value = mock_resp + + result = postprocess_data( + parsed_data=self.PARSED, + webhook_enabled=True, + webhook_url="https://hook.example.com", + highlight_data=self.HIGHLIGHT, + ) + assert result == (self.PARSED, self.HIGHLIGHT) + + @patch("executor.executors.postprocessor.requests.post") + def test_timeout_returns_original(self, mock_post): + mock_post.side_effect = real_requests.exceptions.Timeout("timed out") + + result = postprocess_data( + parsed_data=self.PARSED, + webhook_enabled=True, + webhook_url="https://hook.example.com", + highlight_data=self.HIGHLIGHT, + ) + assert result == (self.PARSED, self.HIGHLIGHT) + + @patch("executor.executors.postprocessor.requests.post") + def test_connection_error_returns_original(self, mock_post): + mock_post.side_effect = real_requests.exceptions.ConnectionError("refused") + + result = postprocess_data( + parsed_data=self.PARSED, + webhook_enabled=True, + webhook_url="https://hook.example.com", + highlight_data=self.HIGHLIGHT, + ) + assert result == (self.PARSED, self.HIGHLIGHT) + + @patch("executor.executors.postprocessor.requests.post") + def test_json_decode_error_returns_original(self, mock_post): + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.json.side_effect = json.JSONDecodeError("err", "doc", 0) + mock_post.return_value = mock_resp + + result = postprocess_data( + parsed_data=self.PARSED, + webhook_enabled=True, + webhook_url="https://hook.example.com", + highlight_data=self.HIGHLIGHT, + ) + assert result == (self.PARSED, self.HIGHLIGHT) + + @patch("executor.executors.postprocessor.requests.post") + 
def test_custom_timeout_passed(self, mock_post): + """timeout=5.0 is passed to requests.post().""" + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.json.return_value = {"structured_output": {"f": "v"}} + mock_post.return_value = mock_resp + + postprocess_data( + parsed_data=self.PARSED, + webhook_enabled=True, + webhook_url="https://hook.example.com", + timeout=5.0, + ) + _, kwargs = mock_post.call_args + assert kwargs["timeout"] == 5.0 + + # --- _validate_structured_output --- + + def test_validate_structured_output_dict(self): + assert _validate_structured_output({"k": "v"}) is True + + def test_validate_structured_output_list(self): + assert _validate_structured_output([1, 2]) is True diff --git a/workers/tests/test_phase5d.py b/workers/tests/test_phase5d.py new file mode 100644 index 0000000000..c5b0a0640a --- /dev/null +++ b/workers/tests/test_phase5d.py @@ -0,0 +1,900 @@ +"""Phase 5D — Tests for structure_pipeline compound operation. + +Tests _handle_structure_pipeline in LegacyExecutor which runs the full +extract → summarize → index → answer_prompt pipeline in a single +executor invocation. 
+""" + +from unittest.mock import MagicMock, patch + +import pytest + +from unstract.sdk1.execution.context import ExecutionContext, Operation +from unstract.sdk1.execution.result import ExecutionResult + +# --------------------------------------------------------------------------- +# Patch targets — all at source in executor.executors.legacy_executor +# --------------------------------------------------------------------------- + +_PATCH_FILE_UTILS = "executor.executors.file_utils.FileUtils.get_fs_instance" +_PATCH_INDEXING_DEPS = ( + "executor.executors.legacy_executor.LegacyExecutor._get_indexing_deps" +) +_PATCH_PROMPT_DEPS = ( + "executor.executors.legacy_executor.LegacyExecutor._get_prompt_deps" +) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def executor(): + """Create a LegacyExecutor instance.""" + from executor.executors.legacy_executor import LegacyExecutor + + return LegacyExecutor() + + +@pytest.fixture +def mock_fs(): + """Mock filesystem.""" + fs = MagicMock(name="file_storage") + fs.exists.return_value = False + fs.read.return_value = "" + fs.write.return_value = None + fs.get_hash_from_file.return_value = "hash123" + return fs + + +def _make_pipeline_context( + executor_params: dict, + run_id: str = "run-1", + organization_id: str = "org-1", +) -> ExecutionContext: + """Build a structure_pipeline ExecutionContext.""" + return ExecutionContext( + executor_name="legacy", + operation=Operation.STRUCTURE_PIPELINE.value, + run_id=run_id, + execution_source="tool", + organization_id=organization_id, + request_id="req-1", + executor_params=executor_params, + ) + + +def _base_extract_params() -> dict: + """Extract params template.""" + return { + "x2text_instance_id": "x2t-1", + "file_path": "/data/test.pdf", + "enable_highlight": False, + "output_file_path": "/data/exec/EXTRACT", + "platform_api_key": 
"sk-test", + "usage_kwargs": {"run_id": "run-1", "file_name": "test.pdf"}, + } + + +def _base_index_template() -> dict: + """Index template.""" + return { + "tool_id": "tool-1", + "file_hash": "hash-abc", + "is_highlight_enabled": False, + "platform_api_key": "sk-test", + "extracted_file_path": "/data/exec/EXTRACT", + } + + +def _base_answer_params() -> dict: + """Answer params (payload for answer_prompt).""" + return { + "run_id": "run-1", + "tool_settings": { + "vector-db": "vdb-1", + "embedding": "emb-1", + "x2text_adapter": "x2t-1", + "llm": "llm-1", + "challenge_llm": "", + "enable_challenge": False, + "enable_single_pass_extraction": False, + "summarize_as_source": False, + "enable_highlight": False, + }, + "outputs": [ + { + "name": "field_a", + "prompt": "What is the revenue?", + "type": "text", + "active": True, + "chunk-size": 512, + "chunk-overlap": 128, + "llm": "llm-1", + "embedding": "emb-1", + "vector-db": "vdb-1", + "x2text_adapter": "x2t-1", + "retrieval-strategy": "simple", + "similarity-top-k": 5, + }, + ], + "tool_id": "tool-1", + "file_hash": "hash-abc", + "file_name": "test.pdf", + "file_path": "/data/exec/EXTRACT", + "execution_source": "tool", + "custom_data": {}, + "PLATFORM_SERVICE_API_KEY": "sk-test", + } + + +def _base_pipeline_options() -> dict: + """Default pipeline options.""" + return { + "skip_extraction_and_indexing": False, + "is_summarization_enabled": False, + "is_single_pass_enabled": False, + "input_file_path": "/data/test.pdf", + "source_file_name": "test.pdf", + } + + +# --------------------------------------------------------------------------- +# Tests — Operation enum and routing +# --------------------------------------------------------------------------- + + +class TestStructurePipelineEnum: + """Verify enum and operation map registration.""" + + def test_operation_enum_exists(self): + assert Operation.STRUCTURE_PIPELINE.value == "structure_pipeline" + + def test_operation_map_has_structure_pipeline(self, executor): + 
assert "structure_pipeline" in executor._OPERATION_MAP + + +# --------------------------------------------------------------------------- +# Tests — Normal pipeline: extract → index → answer_prompt +# --------------------------------------------------------------------------- + + +class TestNormalPipeline: + """Normal pipeline: extract + index + answer_prompt.""" + + def test_extract_index_answer(self, executor): + """Full pipeline calls extract, index, and answer_prompt.""" + extract_result = ExecutionResult( + success=True, data={"extracted_text": "Revenue is $1M"} + ) + index_result = ExecutionResult( + success=True, data={"doc_id": "doc-1"} + ) + answer_result = ExecutionResult( + success=True, + data={ + "output": {"field_a": "$1M"}, + "metadata": {}, + "metrics": {"field_a": {"llm": {"time_taken(s)": 1.0}}}, + }, + ) + + executor._handle_extract = MagicMock(return_value=extract_result) + executor._handle_index = MagicMock(return_value=index_result) + executor._handle_answer_prompt = MagicMock( + return_value=answer_result + ) + + ctx = _make_pipeline_context({ + "extract_params": _base_extract_params(), + "index_template": _base_index_template(), + "answer_params": _base_answer_params(), + "pipeline_options": _base_pipeline_options(), + }) + + result = executor._handle_structure_pipeline(ctx) + + assert result.success + assert executor._handle_extract.call_count == 1 + assert executor._handle_index.call_count == 1 + assert executor._handle_answer_prompt.call_count == 1 + + def test_result_has_metadata_and_file_name(self, executor): + """Result includes source_file_name in metadata.""" + executor._handle_extract = MagicMock( + return_value=ExecutionResult( + success=True, data={"extracted_text": "text"} + ) + ) + executor._handle_index = MagicMock( + return_value=ExecutionResult( + success=True, data={"doc_id": "d1"} + ) + ) + executor._handle_answer_prompt = MagicMock( + return_value=ExecutionResult( + success=True, data={"output": {}, "metadata": {}} + ) + ) 
+ + ctx = _make_pipeline_context({ + "extract_params": _base_extract_params(), + "index_template": _base_index_template(), + "answer_params": _base_answer_params(), + "pipeline_options": _base_pipeline_options(), + }) + result = executor._handle_structure_pipeline(ctx) + + assert result.success + assert result.data["metadata"]["file_name"] == "test.pdf" + + def test_extracted_text_in_metadata(self, executor): + """Extracted text is added to result metadata.""" + executor._handle_extract = MagicMock( + return_value=ExecutionResult( + success=True, data={"extracted_text": "Revenue $1M"} + ) + ) + executor._handle_index = MagicMock( + return_value=ExecutionResult( + success=True, data={"doc_id": "d1"} + ) + ) + executor._handle_answer_prompt = MagicMock( + return_value=ExecutionResult( + success=True, data={"output": {}} + ) + ) + + ctx = _make_pipeline_context({ + "extract_params": _base_extract_params(), + "index_template": _base_index_template(), + "answer_params": _base_answer_params(), + "pipeline_options": _base_pipeline_options(), + }) + result = executor._handle_structure_pipeline(ctx) + + assert result.data["metadata"]["extracted_text"] == "Revenue $1M" + + def test_index_metrics_merged(self, executor): + """Index metrics are merged into answer metrics.""" + executor._handle_extract = MagicMock( + return_value=ExecutionResult( + success=True, data={"extracted_text": "text"} + ) + ) + executor._handle_index = MagicMock( + return_value=ExecutionResult( + success=True, data={"doc_id": "d1"} + ) + ) + executor._handle_answer_prompt = MagicMock( + return_value=ExecutionResult( + success=True, + data={ + "output": {}, + "metrics": { + "field_a": {"llm": {"time_taken(s)": 2.0}}, + }, + }, + ) + ) + # Simulate index metrics by patching _run_pipeline_index + executor._run_pipeline_index = MagicMock( + return_value={ + "field_a": {"indexing": {"time_taken(s)": 0.5}}, + } + ) + + ctx = _make_pipeline_context({ + "extract_params": _base_extract_params(), + 
"index_template": _base_index_template(), + "answer_params": _base_answer_params(), + "pipeline_options": _base_pipeline_options(), + }) + result = executor._handle_structure_pipeline(ctx) + + assert result.success + metrics = result.data["metrics"] + # Both llm and indexing metrics for field_a should be merged + assert "llm" in metrics["field_a"] + assert "indexing" in metrics["field_a"] + + +# --------------------------------------------------------------------------- +# Tests — Extract failure propagation +# --------------------------------------------------------------------------- + + +class TestExtractFailure: + """Extract failure stops the pipeline.""" + + def test_extract_failure_stops_pipeline(self, executor): + executor._handle_extract = MagicMock( + return_value=ExecutionResult.failure(error="x2text error") + ) + executor._handle_index = MagicMock() + executor._handle_answer_prompt = MagicMock() + + ctx = _make_pipeline_context({ + "extract_params": _base_extract_params(), + "index_template": _base_index_template(), + "answer_params": _base_answer_params(), + "pipeline_options": _base_pipeline_options(), + }) + result = executor._handle_structure_pipeline(ctx) + + assert not result.success + assert "x2text error" in result.error + executor._handle_index.assert_not_called() + executor._handle_answer_prompt.assert_not_called() + + +# --------------------------------------------------------------------------- +# Tests — Skip extraction (smart table) +# --------------------------------------------------------------------------- + + +class TestSkipExtraction: + """Smart table: skip extract+index, use source file.""" + + def test_skip_extraction_uses_input_file(self, executor): + executor._handle_extract = MagicMock() + executor._handle_index = MagicMock() + executor._handle_answer_prompt = MagicMock( + return_value=ExecutionResult( + success=True, data={"output": {}} + ) + ) + + opts = _base_pipeline_options() + opts["skip_extraction_and_indexing"] = True + 
answer = _base_answer_params() + + ctx = _make_pipeline_context({ + "extract_params": _base_extract_params(), + "index_template": _base_index_template(), + "answer_params": answer, + "pipeline_options": opts, + }) + result = executor._handle_structure_pipeline(ctx) + + assert result.success + executor._handle_extract.assert_not_called() + executor._handle_index.assert_not_called() + # file_path should be set to input_file_path + call_ctx = executor._handle_answer_prompt.call_args[0][0] + assert call_ctx.executor_params["file_path"] == "/data/test.pdf" + + def test_skip_extraction_table_settings_injection(self, executor): + """Table settings get input_file when extraction is skipped.""" + executor._handle_answer_prompt = MagicMock( + return_value=ExecutionResult( + success=True, data={"output": {}} + ) + ) + + opts = _base_pipeline_options() + opts["skip_extraction_and_indexing"] = True + answer = _base_answer_params() + answer["outputs"][0]["table_settings"] = { + "is_directory_mode": False, + } + + ctx = _make_pipeline_context({ + "extract_params": _base_extract_params(), + "index_template": _base_index_template(), + "answer_params": answer, + "pipeline_options": opts, + }) + result = executor._handle_structure_pipeline(ctx) + + assert result.success + ts = answer["outputs"][0]["table_settings"] + assert ts["input_file"] == "/data/test.pdf" + + +# --------------------------------------------------------------------------- +# Tests — Single pass extraction +# --------------------------------------------------------------------------- + + +class TestSinglePass: + """Single pass: extract + answer_prompt (no indexing).""" + + def test_single_pass_skips_index(self, executor): + executor._handle_extract = MagicMock( + return_value=ExecutionResult( + success=True, data={"extracted_text": "text"} + ) + ) + executor._handle_index = MagicMock() + executor._handle_answer_prompt = MagicMock( + return_value=ExecutionResult( + success=True, data={"output": {}} + ) + ) + + opts 
= _base_pipeline_options() + opts["is_single_pass_enabled"] = True + + ctx = _make_pipeline_context({ + "extract_params": _base_extract_params(), + "index_template": _base_index_template(), + "answer_params": _base_answer_params(), + "pipeline_options": opts, + }) + result = executor._handle_structure_pipeline(ctx) + + assert result.success + executor._handle_extract.assert_called_once() + executor._handle_index.assert_not_called() + executor._handle_answer_prompt.assert_called_once() + + def test_single_pass_operation_is_single_pass(self, executor): + """The answer_prompt call uses single_pass_extraction operation.""" + executor._handle_extract = MagicMock( + return_value=ExecutionResult( + success=True, data={"extracted_text": "text"} + ) + ) + executor._handle_answer_prompt = MagicMock( + return_value=ExecutionResult( + success=True, data={"output": {}} + ) + ) + + opts = _base_pipeline_options() + opts["is_single_pass_enabled"] = True + + ctx = _make_pipeline_context({ + "extract_params": _base_extract_params(), + "index_template": _base_index_template(), + "answer_params": _base_answer_params(), + "pipeline_options": opts, + }) + executor._handle_structure_pipeline(ctx) + + call_ctx = executor._handle_answer_prompt.call_args[0][0] + assert call_ctx.operation == "single_pass_extraction" + + +# --------------------------------------------------------------------------- +# Tests — Summarize pipeline +# --------------------------------------------------------------------------- + + +class TestSummarizePipeline: + """Summarize: extract + summarize + answer_prompt (no indexing).""" + + @patch(_PATCH_FILE_UTILS) + def test_summarize_calls_handle_summarize( + self, mock_get_fs, executor, mock_fs + ): + mock_get_fs.return_value = mock_fs + mock_fs.exists.return_value = False + mock_fs.read.return_value = "extracted text for summarize" + + executor._handle_extract = MagicMock( + return_value=ExecutionResult( + success=True, data={"extracted_text": "text"} + ) + ) + 
executor._handle_summarize = MagicMock( + return_value=ExecutionResult( + success=True, data={"data": "summarized text"} + ) + ) + executor._handle_index = MagicMock() + executor._handle_answer_prompt = MagicMock( + return_value=ExecutionResult( + success=True, data={"output": {}} + ) + ) + + opts = _base_pipeline_options() + opts["is_summarization_enabled"] = True + + ctx = _make_pipeline_context({ + "extract_params": _base_extract_params(), + "index_template": _base_index_template(), + "answer_params": _base_answer_params(), + "pipeline_options": opts, + "summarize_params": { + "llm_adapter_instance_id": "llm-1", + "summarize_prompt": "Summarize this", + "extract_file_path": "/data/exec/EXTRACT", + "summarize_file_path": "/data/exec/SUMMARIZE", + "platform_api_key": "sk-test", + "prompt_keys": ["field_a"], + }, + }) + result = executor._handle_structure_pipeline(ctx) + + assert result.success + executor._handle_summarize.assert_called_once() + executor._handle_index.assert_not_called() + + @patch(_PATCH_FILE_UTILS) + def test_summarize_uses_cache(self, mock_get_fs, executor, mock_fs): + """If cached summary exists, _handle_summarize is NOT called.""" + mock_get_fs.return_value = mock_fs + mock_fs.exists.return_value = True + mock_fs.read.return_value = "cached summary" + + executor._handle_extract = MagicMock( + return_value=ExecutionResult( + success=True, data={"extracted_text": "text"} + ) + ) + executor._handle_summarize = MagicMock() + executor._handle_answer_prompt = MagicMock( + return_value=ExecutionResult( + success=True, data={"output": {}} + ) + ) + + opts = _base_pipeline_options() + opts["is_summarization_enabled"] = True + + ctx = _make_pipeline_context({ + "extract_params": _base_extract_params(), + "index_template": _base_index_template(), + "answer_params": _base_answer_params(), + "pipeline_options": opts, + "summarize_params": { + "llm_adapter_instance_id": "llm-1", + "summarize_prompt": "Summarize this", + "extract_file_path": 
"/data/exec/EXTRACT", + "summarize_file_path": "/data/exec/SUMMARIZE", + "platform_api_key": "sk-test", + "prompt_keys": ["field_a"], + }, + }) + result = executor._handle_structure_pipeline(ctx) + + assert result.success + executor._handle_summarize.assert_not_called() + + @patch(_PATCH_FILE_UTILS) + def test_summarize_updates_answer_params( + self, mock_get_fs, executor, mock_fs + ): + """After summarize, answer_params file_path and hash are updated.""" + mock_get_fs.return_value = mock_fs + mock_fs.exists.return_value = False + mock_fs.read.return_value = "doc text" + mock_fs.get_hash_from_file.return_value = "sum-hash-456" + + executor._handle_extract = MagicMock( + return_value=ExecutionResult( + success=True, data={"extracted_text": "text"} + ) + ) + executor._handle_summarize = MagicMock( + return_value=ExecutionResult( + success=True, data={"data": "summarized"} + ) + ) + executor._handle_answer_prompt = MagicMock( + return_value=ExecutionResult( + success=True, data={"output": {}} + ) + ) + + answer = _base_answer_params() + opts = _base_pipeline_options() + opts["is_summarization_enabled"] = True + + ctx = _make_pipeline_context({ + "extract_params": _base_extract_params(), + "index_template": _base_index_template(), + "answer_params": answer, + "pipeline_options": opts, + "summarize_params": { + "llm_adapter_instance_id": "llm-1", + "summarize_prompt": "Summarize", + "extract_file_path": "/data/exec/EXTRACT", + "summarize_file_path": "/data/exec/SUMMARIZE", + "platform_api_key": "sk-test", + "prompt_keys": [], + }, + }) + executor._handle_structure_pipeline(ctx) + + # Check answer_params were updated + assert answer["file_hash"] == "sum-hash-456" + assert answer["file_path"] == "/data/exec/SUMMARIZE" + + @patch(_PATCH_FILE_UTILS) + def test_summarize_sets_chunk_size_zero( + self, mock_get_fs, executor, mock_fs + ): + """Summarize sets chunk-size=0 for all outputs.""" + mock_get_fs.return_value = mock_fs + mock_fs.exists.return_value = True + 
mock_fs.read.return_value = "cached" + + executor._handle_extract = MagicMock( + return_value=ExecutionResult( + success=True, data={"extracted_text": "t"} + ) + ) + executor._handle_answer_prompt = MagicMock( + return_value=ExecutionResult( + success=True, data={"output": {}} + ) + ) + + answer = _base_answer_params() + opts = _base_pipeline_options() + opts["is_summarization_enabled"] = True + + ctx = _make_pipeline_context({ + "extract_params": _base_extract_params(), + "index_template": _base_index_template(), + "answer_params": answer, + "pipeline_options": opts, + "summarize_params": { + "llm_adapter_instance_id": "llm-1", + "summarize_prompt": "Summarize", + "extract_file_path": "/data/exec/EXTRACT", + "summarize_file_path": "/data/exec/SUMMARIZE", + "platform_api_key": "sk-test", + "prompt_keys": [], + }, + }) + executor._handle_structure_pipeline(ctx) + + # Outputs should have chunk-size=0 + for output in answer["outputs"]: + assert output["chunk-size"] == 0 + assert output["chunk-overlap"] == 0 + + +# --------------------------------------------------------------------------- +# Tests — Index dedup +# --------------------------------------------------------------------------- + + +class TestIndexDedup: + """Index step deduplication.""" + + def test_index_dedup_skips_duplicate_params(self, executor): + """Duplicate param combos are only indexed once.""" + executor._handle_extract = MagicMock( + return_value=ExecutionResult( + success=True, data={"extracted_text": "text"} + ) + ) + index_call_count = 0 + original_index = executor._handle_index + + def counting_index(ctx): + nonlocal index_call_count + index_call_count += 1 + return ExecutionResult(success=True, data={"doc_id": "d1"}) + + executor._handle_index = counting_index + executor._handle_answer_prompt = MagicMock( + return_value=ExecutionResult( + success=True, data={"output": {}} + ) + ) + + answer = _base_answer_params() + # Add a second output with same adapter params + answer["outputs"].append({ 
+ "name": "field_b", + "prompt": "What is the profit?", + "type": "text", + "active": True, + "chunk-size": 512, + "chunk-overlap": 128, + "llm": "llm-1", + "embedding": "emb-1", + "vector-db": "vdb-1", + "x2text_adapter": "x2t-1", + }) + + ctx = _make_pipeline_context({ + "extract_params": _base_extract_params(), + "index_template": _base_index_template(), + "answer_params": answer, + "pipeline_options": _base_pipeline_options(), + }) + result = executor._handle_structure_pipeline(ctx) + + assert result.success + # Only one index call despite two outputs (same params) + assert index_call_count == 1 + + def test_index_different_params_indexes_both(self, executor): + """Different param combos are indexed separately.""" + executor._handle_extract = MagicMock( + return_value=ExecutionResult( + success=True, data={"extracted_text": "text"} + ) + ) + index_call_count = 0 + + def counting_index(ctx): + nonlocal index_call_count + index_call_count += 1 + return ExecutionResult(success=True, data={"doc_id": "d1"}) + + executor._handle_index = counting_index + executor._handle_answer_prompt = MagicMock( + return_value=ExecutionResult( + success=True, data={"output": {}} + ) + ) + + answer = _base_answer_params() + answer["outputs"].append({ + "name": "field_b", + "prompt": "What is the profit?", + "type": "text", + "active": True, + "chunk-size": 256, # Different chunk size + "chunk-overlap": 64, + "llm": "llm-1", + "embedding": "emb-1", + "vector-db": "vdb-1", + "x2text_adapter": "x2t-1", + }) + + ctx = _make_pipeline_context({ + "extract_params": _base_extract_params(), + "index_template": _base_index_template(), + "answer_params": answer, + "pipeline_options": _base_pipeline_options(), + }) + result = executor._handle_structure_pipeline(ctx) + + assert result.success + assert index_call_count == 2 + + def test_chunk_size_zero_skips_index(self, executor): + """chunk-size=0 outputs skip indexing entirely.""" + executor._handle_extract = MagicMock( + 
return_value=ExecutionResult( + success=True, data={"extracted_text": "text"} + ) + ) + executor._handle_index = MagicMock() + executor._handle_answer_prompt = MagicMock( + return_value=ExecutionResult( + success=True, data={"output": {}} + ) + ) + + answer = _base_answer_params() + answer["outputs"][0]["chunk-size"] = 0 + + ctx = _make_pipeline_context({ + "extract_params": _base_extract_params(), + "index_template": _base_index_template(), + "answer_params": answer, + "pipeline_options": _base_pipeline_options(), + }) + result = executor._handle_structure_pipeline(ctx) + + assert result.success + executor._handle_index.assert_not_called() + + +# --------------------------------------------------------------------------- +# Tests — Answer prompt failure +# --------------------------------------------------------------------------- + + +class TestAnswerPromptFailure: + """Answer prompt failure propagates correctly.""" + + def test_answer_failure_propagates(self, executor): + executor._handle_extract = MagicMock( + return_value=ExecutionResult( + success=True, data={"extracted_text": "text"} + ) + ) + executor._handle_index = MagicMock( + return_value=ExecutionResult( + success=True, data={"doc_id": "d1"} + ) + ) + executor._handle_answer_prompt = MagicMock( + return_value=ExecutionResult.failure(error="LLM timeout") + ) + + ctx = _make_pipeline_context({ + "extract_params": _base_extract_params(), + "index_template": _base_index_template(), + "answer_params": _base_answer_params(), + "pipeline_options": _base_pipeline_options(), + }) + result = executor._handle_structure_pipeline(ctx) + + assert not result.success + assert "LLM timeout" in result.error + + +# --------------------------------------------------------------------------- +# Tests — Merge metrics utility +# --------------------------------------------------------------------------- + + +class TestMergeMetrics: + """Test _merge_pipeline_metrics.""" + + def test_merge_disjoint(self, executor): + m = 
executor._merge_pipeline_metrics( + {"a": {"x": 1}}, {"b": {"y": 2}} + ) + assert m == {"a": {"x": 1}, "b": {"y": 2}} + + def test_merge_overlapping(self, executor): + m = executor._merge_pipeline_metrics( + {"a": {"x": 1}}, {"a": {"y": 2}} + ) + assert m == {"a": {"x": 1, "y": 2}} + + def test_merge_non_dict_values(self, executor): + m = executor._merge_pipeline_metrics( + {"a": 1}, {"b": 2} + ) + assert m == {"a": 1, "b": 2} + + +# --------------------------------------------------------------------------- +# Tests — Sub-context creation +# --------------------------------------------------------------------------- + + +class TestSubContextCreation: + """Verify sub-contexts inherit parent context fields.""" + + def test_extract_context_inherits_fields(self, executor): + """Extract sub-context gets run_id, org_id, etc. from parent.""" + executor._handle_extract = MagicMock( + return_value=ExecutionResult( + success=True, data={"extracted_text": "text"} + ) + ) + executor._handle_index = MagicMock( + return_value=ExecutionResult( + success=True, data={"doc_id": "d1"} + ) + ) + executor._handle_answer_prompt = MagicMock( + return_value=ExecutionResult( + success=True, data={"output": {}} + ) + ) + + ctx = _make_pipeline_context( + { + "extract_params": _base_extract_params(), + "index_template": _base_index_template(), + "answer_params": _base_answer_params(), + "pipeline_options": _base_pipeline_options(), + }, + run_id="custom-run", + organization_id="custom-org", + ) + executor._handle_structure_pipeline(ctx) + + extract_ctx = executor._handle_extract.call_args[0][0] + assert extract_ctx.run_id == "custom-run" + assert extract_ctx.organization_id == "custom-org" + assert extract_ctx.operation == "extract" + + index_ctx = executor._handle_index.call_args[0][0] + assert index_ctx.run_id == "custom-run" + assert index_ctx.operation == "index" + + answer_ctx = executor._handle_answer_prompt.call_args[0][0] + assert answer_ctx.run_id == "custom-run" + assert 
answer_ctx.operation == "answer_prompt" diff --git a/workers/tests/test_retrieval.py b/workers/tests/test_retrieval.py new file mode 100644 index 0000000000..a92ce08808 --- /dev/null +++ b/workers/tests/test_retrieval.py @@ -0,0 +1,275 @@ +"""Tests for the RetrievalService factory and complete-context path. + +Retriever internals are NOT tested here — they're llama_index wrappers +that will be validated in Phase 2-SANITY integration tests. +""" + +from unittest.mock import MagicMock, patch + +import pytest + +from executor.executors.constants import RetrievalStrategy +from executor.executors.retrieval import RetrievalService + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_output(prompt: str = "What is X?", top_k: int = 5, name: str = "field_a"): + """Build a minimal ``output`` dict matching PromptServiceConstants keys.""" + return { + "promptx": prompt, + "similarity-top-k": top_k, + "name": name, + } + + +def _mock_retriever_class(return_value=None): + """Return a mock class whose instances have a ``.retrieve()`` method.""" + if return_value is None: + return_value = {"chunk1", "chunk2"} + cls = MagicMock() + instance = MagicMock() + instance.retrieve.return_value = return_value + cls.return_value = instance + return cls, instance + + +# --------------------------------------------------------------------------- +# Factory — run_retrieval +# --------------------------------------------------------------------------- + +class TestRunRetrieval: + """Tests for RetrievalService.run_retrieval().""" + + @pytest.mark.parametrize("strategy", list(RetrievalStrategy)) + @patch("executor.executors.retrieval.RetrievalService._get_retriever_map") + def test_correct_class_selected_for_each_strategy(self, mock_map, strategy): + """Factory returns the correct retriever class for each strategy.""" + cls, _inst = _mock_retriever_class() + 
mock_map.return_value = {strategy.value: cls} + + result = RetrievalService.run_retrieval( + output=_make_output(), + doc_id="doc-1", + llm=MagicMock(), + vector_db=MagicMock(), + retrieval_type=strategy.value, + ) + cls.assert_called_once() + assert isinstance(result, list) + + @patch("executor.executors.retrieval.RetrievalService._get_retriever_map") + def test_unknown_strategy_raises_value_error(self, mock_map): + """Passing an invalid strategy string raises ValueError.""" + mock_map.return_value = {} + + with pytest.raises(ValueError, match="Unknown retrieval type"): + RetrievalService.run_retrieval( + output=_make_output(), + doc_id="doc-1", + llm=MagicMock(), + vector_db=MagicMock(), + retrieval_type="nonexistent", + ) + + @patch("executor.executors.retrieval.RetrievalService._get_retriever_map") + def test_retriever_instantiated_with_correct_params(self, mock_map): + """Verify vector_db, doc_id, prompt, top_k, llm passed through.""" + cls, _inst = _mock_retriever_class() + mock_map.return_value = {RetrievalStrategy.SIMPLE.value: cls} + + llm = MagicMock(name="llm") + vdb = MagicMock(name="vdb") + output = _make_output(prompt="Find revenue", top_k=10, name="revenue") + + RetrievalService.run_retrieval( + output=output, + doc_id="doc-42", + llm=llm, + vector_db=vdb, + retrieval_type=RetrievalStrategy.SIMPLE.value, + ) + + cls.assert_called_once_with( + vector_db=vdb, + doc_id="doc-42", + prompt="Find revenue", + top_k=10, + llm=llm, + ) + + @patch("executor.executors.retrieval.RetrievalService._get_retriever_map") + def test_retrieve_result_converted_to_list(self, mock_map): + """Mock retriever returns a set; run_retrieval returns a list.""" + cls, _inst = _mock_retriever_class(return_value={"a", "b", "c"}) + mock_map.return_value = {RetrievalStrategy.FUSION.value: cls} + + result = RetrievalService.run_retrieval( + output=_make_output(), + doc_id="doc-1", + llm=MagicMock(), + vector_db=MagicMock(), + retrieval_type=RetrievalStrategy.FUSION.value, + ) + assert 
isinstance(result, list) + assert set(result) == {"a", "b", "c"} + + @patch("executor.executors.retrieval.RetrievalService._get_retriever_map") + def test_metrics_recorded(self, mock_map): + """Verify context_retrieval_metrics dict populated with timing.""" + cls, _inst = _mock_retriever_class() + mock_map.return_value = {RetrievalStrategy.SIMPLE.value: cls} + + metrics: dict = {} + RetrievalService.run_retrieval( + output=_make_output(name="my_field"), + doc_id="doc-1", + llm=MagicMock(), + vector_db=MagicMock(), + retrieval_type=RetrievalStrategy.SIMPLE.value, + context_retrieval_metrics=metrics, + ) + + assert "my_field" in metrics + assert "time_taken(s)" in metrics["my_field"] + assert isinstance(metrics["my_field"]["time_taken(s)"], float) + + @patch("executor.executors.retrieval.RetrievalService._get_retriever_map") + def test_metrics_optional_none_does_not_crash(self, mock_map): + """context_retrieval_metrics=None doesn't crash.""" + cls, _inst = _mock_retriever_class() + mock_map.return_value = {RetrievalStrategy.SIMPLE.value: cls} + + # Should not raise + RetrievalService.run_retrieval( + output=_make_output(), + doc_id="doc-1", + llm=MagicMock(), + vector_db=MagicMock(), + retrieval_type=RetrievalStrategy.SIMPLE.value, + context_retrieval_metrics=None, + ) + + +# --------------------------------------------------------------------------- +# Complete context — retrieve_complete_context +# --------------------------------------------------------------------------- + +class TestRetrieveCompleteContext: + """Tests for RetrievalService.retrieve_complete_context().""" + + @patch("executor.executors.file_utils.FileUtils.get_fs_instance") + def test_reads_file_with_correct_path(self, mock_get_fs): + """Mock FileUtils.get_fs_instance, verify fs.read() called correctly.""" + mock_fs = MagicMock() + mock_fs.read.return_value = "full document text" + mock_get_fs.return_value = mock_fs + + RetrievalService.retrieve_complete_context( + execution_source="ide", + 
file_path="/data/doc.txt", + ) + + mock_get_fs.assert_called_once_with(execution_source="ide") + mock_fs.read.assert_called_once_with(path="/data/doc.txt", mode="r") + + @patch("executor.executors.file_utils.FileUtils.get_fs_instance") + def test_returns_list_with_single_item(self, mock_get_fs): + """Verify [content] shape.""" + mock_fs = MagicMock() + mock_fs.read.return_value = "hello world" + mock_get_fs.return_value = mock_fs + + result = RetrievalService.retrieve_complete_context( + execution_source="tool", + file_path="/data/doc.txt", + ) + + assert result == ["hello world"] + assert len(result) == 1 + + @patch("executor.executors.file_utils.FileUtils.get_fs_instance") + def test_complete_context_records_metrics(self, mock_get_fs): + """Timing dict populated.""" + mock_fs = MagicMock() + mock_fs.read.return_value = "content" + mock_get_fs.return_value = mock_fs + + metrics: dict = {} + RetrievalService.retrieve_complete_context( + execution_source="ide", + file_path="/data/doc.txt", + context_retrieval_metrics=metrics, + prompt_key="total_revenue", + ) + + assert "total_revenue" in metrics + assert "time_taken(s)" in metrics["total_revenue"] + assert isinstance(metrics["total_revenue"]["time_taken(s)"], float) + + @patch("executor.executors.file_utils.FileUtils.get_fs_instance") + def test_complete_context_metrics_none_does_not_crash(self, mock_get_fs): + """context_retrieval_metrics=None doesn't crash.""" + mock_fs = MagicMock() + mock_fs.read.return_value = "content" + mock_get_fs.return_value = mock_fs + + # Should not raise + RetrievalService.retrieve_complete_context( + execution_source="ide", + file_path="/data/doc.txt", + context_retrieval_metrics=None, + ) + + +# --------------------------------------------------------------------------- +# BaseRetriever interface +# --------------------------------------------------------------------------- + +class TestBaseRetriever: + """Tests for BaseRetriever base class.""" + + def 
test_default_retrieve_returns_empty_set(self): + """Default retrieve() returns empty set.""" + from executor.executors.retrievers.base_retriever import BaseRetriever + + r = BaseRetriever( + vector_db=MagicMock(), + prompt="test", + doc_id="doc-1", + top_k=5, + ) + assert r.retrieve() == set() + + def test_constructor_stores_all_params(self): + """Constructor stores vector_db, prompt, doc_id, top_k, llm.""" + from executor.executors.retrievers.base_retriever import BaseRetriever + + vdb = MagicMock(name="vdb") + llm = MagicMock(name="llm") + r = BaseRetriever( + vector_db=vdb, + prompt="my prompt", + doc_id="doc-99", + top_k=3, + llm=llm, + ) + assert r.vector_db is vdb + assert r.prompt == "my prompt" + assert r.doc_id == "doc-99" + assert r.top_k == 3 + assert r.llm is llm + + def test_constructor_llm_defaults_to_none(self): + """When llm not provided, it defaults to None.""" + from executor.executors.retrievers.base_retriever import BaseRetriever + + r = BaseRetriever( + vector_db=MagicMock(), + prompt="test", + doc_id="doc-1", + top_k=5, + ) + assert r.llm is None diff --git a/workers/tests/test_sanity_phase2.py b/workers/tests/test_sanity_phase2.py new file mode 100644 index 0000000000..2aaeb81730 --- /dev/null +++ b/workers/tests/test_sanity_phase2.py @@ -0,0 +1,792 @@ +"""Phase 2-SANITY — Full-chain integration tests for LegacyExecutor. + +All Phase 2 code and unit tests are complete (2A–2H, 194 workers tests). +This file bridges unit tests and real integration by testing the full +Celery chain: + + task.apply() → execute_extraction task → ExecutionOrchestrator + → ExecutorRegistry.get("legacy") → LegacyExecutor.execute() + → _handle_X() → ExecutionResult + +All in Celery eager mode (no broker needed). External adapters +(X2Text, LLM, VectorDB) are mocked. 
+""" + +import json +from unittest.mock import MagicMock, patch + +import pytest + +from executor.executors.constants import ( + IndexingConstants as IKeys, + PromptServiceConstants as PSKeys, +) +from unstract.sdk1.execution.context import ExecutionContext, Operation +from unstract.sdk1.execution.registry import ExecutorRegistry +from unstract.sdk1.execution.result import ExecutionResult + +# --------------------------------------------------------------------------- +# Patch targets +# --------------------------------------------------------------------------- + +_PATCH_X2TEXT = "executor.executors.legacy_executor.X2Text" +_PATCH_FS = "executor.executors.legacy_executor.FileUtils.get_fs_instance" +_PATCH_INDEX_DEPS = ( + "executor.executors.legacy_executor.LegacyExecutor._get_indexing_deps" +) +_PATCH_PROMPT_DEPS = ( + "executor.executors.legacy_executor.LegacyExecutor._get_prompt_deps" +) +_PATCH_SHIM = "executor.executors.legacy_executor.ExecutorToolShim" +_PATCH_RUN_COMPLETION = ( + "executor.executors.answer_prompt.AnswerPromptService.run_completion" +) +_PATCH_INDEX_UTILS = ( + "unstract.sdk1.utils.indexing.IndexingUtils.generate_index_key" +) + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +def _ensure_legacy_registered(): + """Ensure LegacyExecutor is registered without clearing other state. + + Unlike unit tests that clear() + re-register, sanity tests need + LegacyExecutor always present. We add it idempotently. 
+ """ + from executor.executors.legacy_executor import LegacyExecutor + + if "legacy" not in ExecutorRegistry.list_executors(): + ExecutorRegistry._registry["legacy"] = LegacyExecutor + yield + + +@pytest.fixture +def eager_app(): + """Configure the real executor Celery app for eager-mode testing.""" + from executor.worker import app + + original = { + "task_always_eager": app.conf.task_always_eager, + "task_eager_propagates": app.conf.task_eager_propagates, + "result_backend": app.conf.result_backend, + } + app.conf.update( + task_always_eager=True, + task_eager_propagates=False, + result_backend="cache+memory://", + ) + yield app + app.conf.update(original) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _run_task(eager_app, ctx_dict): + """Run execute_extraction task via task.apply() (eager-safe).""" + task = eager_app.tasks["execute_extraction"] + result = task.apply(args=[ctx_dict]) + return result.get() + + +def _mock_llm(answer="sanity answer"): + """Create a mock LLM matching the test_answer_prompt.py pattern.""" + llm = MagicMock(name="llm") + response = MagicMock() + response.text = answer + llm.complete.return_value = { + PSKeys.RESPONSE: response, + PSKeys.HIGHLIGHT_DATA: [], + PSKeys.CONFIDENCE_DATA: None, + PSKeys.WORD_CONFIDENCE_DATA: None, + PSKeys.LINE_NUMBERS: [], + PSKeys.WHISPER_HASH: "", + } + llm.get_usage_reason.return_value = "extraction" + llm.get_metrics.return_value = {"tokens": 100} + return llm + + +def _mock_prompt_deps(llm=None): + """Return a 7-tuple matching _get_prompt_deps() return shape. + + Uses the real AnswerPromptService + mocked adapters. 
+ """ + if llm is None: + llm = _mock_llm() + + from executor.executors.answer_prompt import AnswerPromptService + + RetrievalService = MagicMock(name="RetrievalService") + RetrievalService.run_retrieval.return_value = ["chunk1", "chunk2"] + RetrievalService.retrieve_complete_context.return_value = ["full content"] + + VariableReplacementService = MagicMock(name="VariableReplacementService") + VariableReplacementService.is_variables_present.return_value = False + + Index = MagicMock(name="Index") + index_instance = MagicMock() + index_instance.generate_index_key.return_value = "doc-id-sanity" + Index.return_value = index_instance + + LLM_cls = MagicMock(name="LLM") + LLM_cls.return_value = llm + + EmbeddingCompat = MagicMock(name="EmbeddingCompat") + VectorDB = MagicMock(name="VectorDB") + + return ( + AnswerPromptService, + RetrievalService, + VariableReplacementService, + Index, + LLM_cls, + EmbeddingCompat, + VectorDB, + ) + + +def _mock_process_response(text="sanity extracted text"): + """Build a mock TextExtractionResult.""" + from unstract.sdk1.adapters.x2text.dto import ( + TextExtractionMetadata, + TextExtractionResult, + ) + + metadata = TextExtractionMetadata(whisper_hash="sanity-hash") + return TextExtractionResult( + extracted_text=text, + extraction_metadata=metadata, + ) + + +def _make_prompt(name="field_a", prompt="What is the revenue?", + output_type="text", **overrides): + """Build a minimal prompt definition dict.""" + d = { + PSKeys.NAME: name, + PSKeys.PROMPT: prompt, + PSKeys.TYPE: output_type, + PSKeys.CHUNK_SIZE: 512, + PSKeys.CHUNK_OVERLAP: 128, + PSKeys.RETRIEVAL_STRATEGY: "simple", + PSKeys.LLM: "llm-1", + PSKeys.EMBEDDING: "emb-1", + PSKeys.VECTOR_DB: "vdb-1", + PSKeys.X2TEXT_ADAPTER: "x2t-1", + PSKeys.SIMILARITY_TOP_K: 5, + } + d.update(overrides) + return d + + +# --- Context factories per operation --- + + +def _extract_ctx(**overrides): + defaults = { + "executor_name": "legacy", + "operation": "extract", + "run_id": "run-sanity-ext", 
+ "execution_source": "tool", + "organization_id": "org-test", + "executor_params": { + "x2text_instance_id": "x2t-sanity", + "file_path": "/data/sanity.pdf", + "platform_api_key": "sk-sanity", + }, + } + defaults.update(overrides) + return ExecutionContext(**defaults) + + +def _index_ctx(**overrides): + defaults = { + "executor_name": "legacy", + "operation": "index", + "run_id": "run-sanity-idx", + "execution_source": "tool", + "organization_id": "org-test", + "executor_params": { + "embedding_instance_id": "emb-sanity", + "vector_db_instance_id": "vdb-sanity", + "x2text_instance_id": "x2t-sanity", + "file_path": "/data/sanity.pdf", + "file_hash": "sanity-hash", + "extracted_text": "Sanity test document text", + "platform_api_key": "sk-sanity", + "chunk_size": 512, + "chunk_overlap": 128, + }, + } + defaults.update(overrides) + return ExecutionContext(**defaults) + + +def _answer_prompt_ctx(prompts=None, **overrides): + if prompts is None: + prompts = [_make_prompt()] + defaults = { + "executor_name": "legacy", + "operation": Operation.ANSWER_PROMPT.value, + "run_id": "run-sanity-ap", + "execution_source": "ide", + "executor_params": { + PSKeys.OUTPUTS: prompts, + PSKeys.TOOL_SETTINGS: {}, + PSKeys.TOOL_ID: "tool-sanity", + PSKeys.EXECUTION_ID: "exec-sanity", + PSKeys.FILE_HASH: "hash-sanity", + PSKeys.FILE_PATH: "/data/sanity.txt", + PSKeys.FILE_NAME: "sanity.txt", + PSKeys.LOG_EVENTS_ID: "", + PSKeys.CUSTOM_DATA: {}, + PSKeys.EXECUTION_SOURCE: "ide", + PSKeys.PLATFORM_SERVICE_API_KEY: "pk-sanity", + }, + } + defaults.update(overrides) + return ExecutionContext(**defaults) + + +def _summarize_ctx(**overrides): + defaults = { + "executor_name": "legacy", + "operation": "summarize", + "run_id": "run-sanity-sum", + "execution_source": "tool", + "executor_params": { + "llm_adapter_instance_id": "llm-sanity", + "summarize_prompt": "Summarize the document.", + "context": "Long document content here.", + "prompt_keys": ["invoice_number", "total"], + 
"PLATFORM_SERVICE_API_KEY": "pk-sanity", + }, + } + defaults.update(overrides) + return ExecutionContext(**defaults) + + +# =========================================================================== +# Test classes +# =========================================================================== + + +class TestSanityExtract: + """Full-chain extract tests through Celery eager mode.""" + + @patch(_PATCH_FS) + @patch(_PATCH_X2TEXT) + def test_extract_full_chain(self, mock_x2text_cls, mock_get_fs, eager_app): + """Mocked X2Text + FileUtils → result.data has extracted_text.""" + mock_x2text = MagicMock() + mock_x2text.process.return_value = _mock_process_response( + "sanity extracted" + ) + mock_x2text.x2text_instance = MagicMock() + mock_x2text_cls.return_value = mock_x2text + mock_get_fs.return_value = MagicMock() + + ctx = _extract_ctx() + result_dict = _run_task(eager_app, ctx.to_dict()) + result = ExecutionResult.from_dict(result_dict) + + assert result.success is True + assert result.data[IKeys.EXTRACTED_TEXT] == "sanity extracted" + + @patch(_PATCH_FS) + @patch(_PATCH_X2TEXT) + def test_extract_missing_params_full_chain( + self, mock_x2text_cls, mock_get_fs, eager_app + ): + """Empty params → failure with missing fields message.""" + ctx = _extract_ctx(executor_params={"platform_api_key": "sk-test"}) + result_dict = _run_task(eager_app, ctx.to_dict()) + result = ExecutionResult.from_dict(result_dict) + + assert result.success is False + assert "x2text_instance_id" in result.error + assert "file_path" in result.error + + @patch(_PATCH_FS) + @patch(_PATCH_X2TEXT) + def test_extract_adapter_error_full_chain( + self, mock_x2text_cls, mock_get_fs, eager_app + ): + """X2Text raises AdapterError → failure result, no unhandled exception.""" + from unstract.sdk1.adapters.exceptions import AdapterError + + mock_x2text = MagicMock() + mock_x2text.x2text_instance = MagicMock() + mock_x2text.x2text_instance.get_name.return_value = "SanityExtractor" + 
mock_x2text.process.side_effect = AdapterError("sanity adapter err") + mock_x2text_cls.return_value = mock_x2text + mock_get_fs.return_value = MagicMock() + + ctx = _extract_ctx() + result_dict = _run_task(eager_app, ctx.to_dict()) + result = ExecutionResult.from_dict(result_dict) + + assert result.success is False + assert "SanityExtractor" in result.error + assert "sanity adapter err" in result.error + + +class TestSanityIndex: + """Full-chain index tests through Celery eager mode.""" + + @patch(_PATCH_FS) + @patch(_PATCH_INDEX_DEPS) + def test_index_full_chain(self, mock_deps, mock_get_fs, eager_app): + """Mocked _get_indexing_deps → result.data has doc_id.""" + mock_index_cls = MagicMock() + mock_index = MagicMock() + mock_index.generate_index_key.return_value = "doc-sanity-idx" + mock_index.is_document_indexed.return_value = False + mock_index.perform_indexing.return_value = "doc-sanity-idx" + mock_index_cls.return_value = mock_index + + mock_emb_cls = MagicMock() + mock_emb_cls.return_value = MagicMock() + mock_vdb_cls = MagicMock() + mock_vdb_cls.return_value = MagicMock() + + mock_deps.return_value = (mock_index_cls, mock_emb_cls, mock_vdb_cls) + mock_get_fs.return_value = MagicMock() + + ctx = _index_ctx() + result_dict = _run_task(eager_app, ctx.to_dict()) + result = ExecutionResult.from_dict(result_dict) + + assert result.success is True + assert result.data[IKeys.DOC_ID] == "doc-sanity-idx" + + @patch(_PATCH_INDEX_UTILS, return_value="doc-zero-chunk-sanity") + @patch(_PATCH_FS) + def test_index_chunk_size_zero_full_chain( + self, mock_get_fs, mock_gen_key, eager_app + ): + """chunk_size=0 skips heavy deps → returns doc_id via IndexingUtils.""" + mock_get_fs.return_value = MagicMock() + + params = { + "embedding_instance_id": "emb-sanity", + "vector_db_instance_id": "vdb-sanity", + "x2text_instance_id": "x2t-sanity", + "file_path": "/data/sanity.pdf", + "file_hash": "sanity-hash", + "extracted_text": "text", + "platform_api_key": "sk-sanity", + 
"chunk_size": 0, + "chunk_overlap": 0, + } + ctx = _index_ctx(executor_params=params) + result_dict = _run_task(eager_app, ctx.to_dict()) + result = ExecutionResult.from_dict(result_dict) + + assert result.success is True + assert result.data[IKeys.DOC_ID] == "doc-zero-chunk-sanity" + + @patch(_PATCH_FS) + @patch(_PATCH_INDEX_DEPS) + def test_index_error_full_chain(self, mock_deps, mock_get_fs, eager_app): + """perform_indexing raises → failure result.""" + mock_index_cls = MagicMock() + mock_index = MagicMock() + mock_index.generate_index_key.return_value = "doc-err" + mock_index.is_document_indexed.return_value = False + mock_index.perform_indexing.side_effect = RuntimeError("VDB down") + mock_index_cls.return_value = mock_index + + mock_emb_cls = MagicMock() + mock_emb_cls.return_value = MagicMock() + mock_vdb_cls = MagicMock() + mock_vdb_cls.return_value = MagicMock() + + mock_deps.return_value = (mock_index_cls, mock_emb_cls, mock_vdb_cls) + mock_get_fs.return_value = MagicMock() + + ctx = _index_ctx() + result_dict = _run_task(eager_app, ctx.to_dict()) + result = ExecutionResult.from_dict(result_dict) + + assert result.success is False + assert "indexing" in result.error.lower() + + +class TestSanityAnswerPrompt: + """Full-chain answer_prompt tests through Celery eager mode.""" + + @patch(_PATCH_INDEX_UTILS, return_value="doc-id-sanity") + @patch(_PATCH_PROMPT_DEPS) + @patch(_PATCH_SHIM) + def test_answer_prompt_text_full_chain( + self, mock_shim_cls, mock_deps, _mock_idx, eager_app + ): + """TEXT prompt → result.data has output, metadata, metrics.""" + llm = _mock_llm("sanity answer") + mock_deps.return_value = _mock_prompt_deps(llm) + mock_shim_cls.return_value = MagicMock() + + ctx = _answer_prompt_ctx() + result_dict = _run_task(eager_app, ctx.to_dict()) + result = ExecutionResult.from_dict(result_dict) + + assert result.success is True + assert PSKeys.OUTPUT in result.data + assert PSKeys.METADATA in result.data + assert PSKeys.METRICS in result.data + 
assert result.data[PSKeys.OUTPUT]["field_a"] == "sanity answer" + + @patch(_PATCH_INDEX_UTILS, return_value="doc-id-sanity") + @patch(_PATCH_PROMPT_DEPS) + @patch(_PATCH_SHIM) + def test_answer_prompt_multi_prompt_full_chain( + self, mock_shim_cls, mock_deps, _mock_idx, eager_app + ): + """Two prompts → both field names in output and metrics.""" + llm = _mock_llm("multi answer") + mock_deps.return_value = _mock_prompt_deps(llm) + mock_shim_cls.return_value = MagicMock() + + prompts = [ + _make_prompt(name="revenue"), + _make_prompt(name="date_signed"), + ] + ctx = _answer_prompt_ctx(prompts=prompts) + result_dict = _run_task(eager_app, ctx.to_dict()) + result = ExecutionResult.from_dict(result_dict) + + assert result.success is True + assert "revenue" in result.data[PSKeys.OUTPUT] + assert "date_signed" in result.data[PSKeys.OUTPUT] + assert "revenue" in result.data[PSKeys.METRICS] + assert "date_signed" in result.data[PSKeys.METRICS] + + @patch(_PATCH_INDEX_UTILS, return_value="doc-id-sanity") + @patch(_PATCH_PROMPT_DEPS) + @patch(_PATCH_SHIM) + def test_answer_prompt_table_fails_full_chain( + self, mock_shim_cls, mock_deps, _mock_idx, eager_app + ): + """TABLE type → failure mentioning TABLE.""" + llm = _mock_llm() + mock_deps.return_value = _mock_prompt_deps(llm) + mock_shim_cls.return_value = MagicMock() + + ctx = _answer_prompt_ctx( + prompts=[_make_prompt(output_type="table")] + ) + result_dict = _run_task(eager_app, ctx.to_dict()) + result = ExecutionResult.from_dict(result_dict) + + assert result.success is False + assert "TABLE" in result.error + + +class TestSanitySinglePass: + """Full-chain single_pass_extraction test.""" + + @patch(_PATCH_INDEX_UTILS, return_value="doc-id-sanity") + @patch(_PATCH_PROMPT_DEPS) + @patch(_PATCH_SHIM) + def test_single_pass_delegates_full_chain( + self, mock_shim_cls, mock_deps, _mock_idx, eager_app + ): + """Same mocks as answer_prompt → same response shape.""" + llm = _mock_llm("single pass answer") + 
mock_deps.return_value = _mock_prompt_deps(llm) + mock_shim_cls.return_value = MagicMock() + + ctx = _answer_prompt_ctx( + operation=Operation.SINGLE_PASS_EXTRACTION.value, + ) + result_dict = _run_task(eager_app, ctx.to_dict()) + result = ExecutionResult.from_dict(result_dict) + + assert result.success is True + assert PSKeys.OUTPUT in result.data + assert result.data[PSKeys.OUTPUT]["field_a"] == "single pass answer" + + +class TestSanitySummarize: + """Full-chain summarize tests through Celery eager mode.""" + + @patch(_PATCH_RUN_COMPLETION, return_value="Sanity summary text.") + @patch(_PATCH_PROMPT_DEPS) + @patch(_PATCH_SHIM) + def test_summarize_full_chain( + self, mock_shim_cls, mock_get_deps, mock_run, eager_app + ): + """Mocked _get_prompt_deps + run_completion → result.data has summary.""" + mock_llm_cls = MagicMock() + mock_llm_cls.return_value = MagicMock() + mock_get_deps.return_value = ( + MagicMock(), MagicMock(), MagicMock(), MagicMock(), + mock_llm_cls, MagicMock(), MagicMock(), + ) + + ctx = _summarize_ctx() + result_dict = _run_task(eager_app, ctx.to_dict()) + result = ExecutionResult.from_dict(result_dict) + + assert result.success is True + assert result.data["data"] == "Sanity summary text." 
+ + def test_summarize_missing_llm_full_chain(self, eager_app): + """Missing llm_adapter_instance_id → failure.""" + ctx = _summarize_ctx( + executor_params={ + "llm_adapter_instance_id": "", + "summarize_prompt": "Summarize.", + "context": "Document text.", + "PLATFORM_SERVICE_API_KEY": "pk-test", + } + ) + result_dict = _run_task(eager_app, ctx.to_dict()) + result = ExecutionResult.from_dict(result_dict) + + assert result.success is False + assert "llm_adapter_instance_id" in result.error + + @patch(_PATCH_RUN_COMPLETION, side_effect=Exception("LLM down")) + @patch(_PATCH_PROMPT_DEPS) + @patch(_PATCH_SHIM) + def test_summarize_error_full_chain( + self, mock_shim_cls, mock_get_deps, mock_run, eager_app + ): + """run_completion raises → failure mentioning summarization.""" + mock_llm_cls = MagicMock() + mock_llm_cls.return_value = MagicMock() + mock_get_deps.return_value = ( + MagicMock(), MagicMock(), MagicMock(), MagicMock(), + mock_llm_cls, MagicMock(), MagicMock(), + ) + + ctx = _summarize_ctx() + result_dict = _run_task(eager_app, ctx.to_dict()) + result = ExecutionResult.from_dict(result_dict) + + assert result.success is False + assert "summarization" in result.error.lower() or "LLM" in result.error + + +class TestSanityAgenticExtraction: + """Full-chain agentic operations test — rejected by LegacyExecutor.""" + + def test_agentic_extract_rejected_by_legacy(self, eager_app): + """Agentic operations are handled by cloud executor, not legacy.""" + ctx = ExecutionContext( + executor_name="legacy", + operation="agentic_extract", + run_id="run-sanity-agentic", + execution_source="tool", + ) + result_dict = _run_task(eager_app, ctx.to_dict()) + result = ExecutionResult.from_dict(result_dict) + + assert result.success is False + assert "does not support" in result.error + + +class TestSanityResponseContracts: + """Verify response dicts survive JSON round-trip with expected keys.""" + + @patch(_PATCH_FS) + @patch(_PATCH_X2TEXT) + def test_extract_contract(self, 
mock_x2text_cls, mock_get_fs, eager_app): + mock_x2text = MagicMock() + mock_x2text.process.return_value = _mock_process_response("contract") + mock_x2text.x2text_instance = MagicMock() + mock_x2text_cls.return_value = mock_x2text + mock_get_fs.return_value = MagicMock() + + ctx = _extract_ctx() + result_dict = _run_task(eager_app, ctx.to_dict()) + + # JSON round-trip + serialized = json.dumps(result_dict) + deserialized = json.loads(serialized) + result = ExecutionResult.from_dict(deserialized) + + assert result.success is True + assert isinstance(result.data[IKeys.EXTRACTED_TEXT], str) + + @patch(_PATCH_FS) + @patch(_PATCH_INDEX_DEPS) + def test_index_contract(self, mock_deps, mock_get_fs, eager_app): + mock_index_cls = MagicMock() + mock_index = MagicMock() + mock_index.generate_index_key.return_value = "doc-contract" + mock_index.is_document_indexed.return_value = False + mock_index.perform_indexing.return_value = "doc-contract" + mock_index_cls.return_value = mock_index + + mock_emb_cls = MagicMock() + mock_emb_cls.return_value = MagicMock() + mock_vdb_cls = MagicMock() + mock_vdb_cls.return_value = MagicMock() + + mock_deps.return_value = (mock_index_cls, mock_emb_cls, mock_vdb_cls) + mock_get_fs.return_value = MagicMock() + + ctx = _index_ctx() + result_dict = _run_task(eager_app, ctx.to_dict()) + + serialized = json.dumps(result_dict) + deserialized = json.loads(serialized) + result = ExecutionResult.from_dict(deserialized) + + assert result.success is True + assert isinstance(result.data[IKeys.DOC_ID], str) + + @patch(_PATCH_INDEX_UTILS, return_value="doc-id-sanity") + @patch(_PATCH_PROMPT_DEPS) + @patch(_PATCH_SHIM) + def test_answer_prompt_contract( + self, mock_shim_cls, mock_deps, _mock_idx, eager_app + ): + llm = _mock_llm("contract answer") + mock_deps.return_value = _mock_prompt_deps(llm) + mock_shim_cls.return_value = MagicMock() + + ctx = _answer_prompt_ctx() + result_dict = _run_task(eager_app, ctx.to_dict()) + + serialized = 
json.dumps(result_dict) + deserialized = json.loads(serialized) + result = ExecutionResult.from_dict(deserialized) + + assert result.success is True + assert isinstance(result.data[PSKeys.OUTPUT], dict) + assert isinstance(result.data[PSKeys.METADATA], dict) + assert isinstance(result.data[PSKeys.METRICS], dict) + + @patch(_PATCH_RUN_COMPLETION, return_value="contract summary") + @patch(_PATCH_PROMPT_DEPS) + @patch(_PATCH_SHIM) + def test_summarize_contract( + self, mock_shim_cls, mock_get_deps, mock_run, eager_app + ): + mock_llm_cls = MagicMock() + mock_llm_cls.return_value = MagicMock() + mock_get_deps.return_value = ( + MagicMock(), MagicMock(), MagicMock(), MagicMock(), + mock_llm_cls, MagicMock(), MagicMock(), + ) + + ctx = _summarize_ctx() + result_dict = _run_task(eager_app, ctx.to_dict()) + + serialized = json.dumps(result_dict) + deserialized = json.loads(serialized) + result = ExecutionResult.from_dict(deserialized) + + assert result.success is True + assert isinstance(result.data["data"], str) + + +class TestSanityDispatcher: + """Full-chain dispatcher tests with Celery eager mode.""" + + @patch(_PATCH_FS) + @patch(_PATCH_X2TEXT) + def test_dispatcher_dispatch_full_chain( + self, mock_x2text_cls, mock_get_fs, eager_app + ): + """ExecutionDispatcher dispatches through Celery and returns result. + + Celery's ``send_task`` doesn't reliably use eager mode, so we + patch it to route through ``task.apply()`` instead — this still + exercises the full Dispatcher → task → orchestrator chain. 
+ """ + from unstract.sdk1.execution.dispatcher import ExecutionDispatcher + + mock_x2text = MagicMock() + mock_x2text.process.return_value = _mock_process_response("dispatched") + mock_x2text.x2text_instance = MagicMock() + mock_x2text_cls.return_value = mock_x2text + mock_get_fs.return_value = MagicMock() + + task = eager_app.tasks["execute_extraction"] + + def eager_send_task(name, args=None, **kwargs): + return task.apply(args=args) + + with patch.object(eager_app, "send_task", side_effect=eager_send_task): + dispatcher = ExecutionDispatcher(celery_app=eager_app) + ctx = _extract_ctx() + result = dispatcher.dispatch(ctx, timeout=10) + + assert isinstance(result, ExecutionResult) + assert result.success is True + assert result.data[IKeys.EXTRACTED_TEXT] == "dispatched" + + def test_dispatcher_no_app_raises(self): + """ExecutionDispatcher(celery_app=None).dispatch() → ValueError.""" + from unstract.sdk1.execution.dispatcher import ExecutionDispatcher + + dispatcher = ExecutionDispatcher(celery_app=None) + ctx = _extract_ctx() + + with pytest.raises(ValueError, match="No Celery app"): + dispatcher.dispatch(ctx) + + +class TestSanityCrossCutting: + """Cross-cutting concerns: unknown ops, invalid contexts, error round-trip.""" + + def test_unknown_operation_full_chain(self, eager_app): + """operation='nonexistent' → failure mentioning unsupported.""" + ctx = ExecutionContext( + executor_name="legacy", + operation="nonexistent", + run_id="run-sanity-unknown", + execution_source="tool", + ) + result_dict = _run_task(eager_app, ctx.to_dict()) + result = ExecutionResult.from_dict(result_dict) + + assert result.success is False + assert "nonexistent" in result.error.lower() + + def test_invalid_context_dict_full_chain(self, eager_app): + """Malformed dict → failure mentioning 'Invalid execution context'.""" + result_dict = _run_task(eager_app, {"bad": "data"}) + result = ExecutionResult.from_dict(result_dict) + + assert result.success is False + assert "Invalid execution 
context" in result.error + + @patch(_PATCH_FS) + @patch(_PATCH_X2TEXT) + def test_failure_result_json_round_trip( + self, mock_x2text_cls, mock_get_fs, eager_app + ): + """Failure result survives JSON serialization with error preserved.""" + from unstract.sdk1.adapters.exceptions import AdapterError + + mock_x2text = MagicMock() + mock_x2text.x2text_instance = MagicMock() + mock_x2text.x2text_instance.get_name.return_value = "FailExtractor" + mock_x2text.process.side_effect = AdapterError("round trip error") + mock_x2text_cls.return_value = mock_x2text + mock_get_fs.return_value = MagicMock() + + ctx = _extract_ctx() + result_dict = _run_task(eager_app, ctx.to_dict()) + + # Verify raw dict survives JSON round-trip + serialized = json.dumps(result_dict) + deserialized = json.loads(serialized) + result = ExecutionResult.from_dict(deserialized) + + assert result.success is False + assert "round trip error" in result.error + assert "FailExtractor" in result.error diff --git a/workers/tests/test_sanity_phase3.py b/workers/tests/test_sanity_phase3.py new file mode 100644 index 0000000000..42dc89462c --- /dev/null +++ b/workers/tests/test_sanity_phase3.py @@ -0,0 +1,981 @@ +"""Phase 3-SANITY — Integration tests for the structure tool Celery task. + +Tests the full structure tool pipeline with mocked platform API and +ExecutionDispatcher. After Phase 5E, the structure tool task dispatches a +single ``structure_pipeline`` operation to the executor worker instead of +3 sequential dispatches. These tests verify the correct pipeline params +are assembled and the result is written to filesystem. 
+""" + +from unittest.mock import MagicMock, patch + +import pytest + +from shared.enums.task_enums import TaskName +from unstract.sdk1.execution.result import ExecutionResult + +# --------------------------------------------------------------------------- +# Patch targets +# --------------------------------------------------------------------------- + +_PATCH_DISPATCHER = ( + "file_processing.structure_tool_task.ExecutionDispatcher" +) +_PATCH_PLATFORM_HELPER = ( + "file_processing.structure_tool_task._create_platform_helper" +) +_PATCH_FILE_STORAGE = ( + "file_processing.structure_tool_task._get_file_storage" +) +_PATCH_SHIM = ( + "executor.executor_tool_shim.ExecutorToolShim" +) +_PATCH_SERVICE_IS_STRUCTURE = ( + "shared.workflow.execution.service." + "WorkerWorkflowExecutionService._is_structure_tool_workflow" +) +_PATCH_SERVICE_EXECUTE_STRUCTURE = ( + "shared.workflow.execution.service." + "WorkerWorkflowExecutionService._execute_structure_tool_workflow" +) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def mock_fs(): + """Create a mock file storage.""" + fs = MagicMock(name="file_storage") + fs.exists.return_value = False + fs.read.return_value = "" + fs.json_dump.return_value = None + fs.write.return_value = None + fs.get_hash_from_file.return_value = "abc123hash" + return fs + + +@pytest.fixture +def mock_dispatcher(): + """Create a mock ExecutionDispatcher that returns success results.""" + dispatcher = MagicMock(name="ExecutionDispatcher") + return dispatcher + + +@pytest.fixture +def mock_platform_helper(): + """Create a mock PlatformHelper.""" + helper = MagicMock(name="PlatformHelper") + return helper + + +@pytest.fixture +def tool_metadata_regular(): + """Standard prompt studio tool metadata.""" + return { + "name": "Test Project", + "is_agentic": False, + "tool_id": "tool-123", + "tool_settings": { + 
"vector-db": "vdb-1", + "embedding": "emb-1", + "x2text_adapter": "x2t-1", + "llm": "llm-1", + }, + "outputs": [ + { + "name": "field_a", + "prompt": "What is the revenue?", + "type": "text", + "active": True, + "chunk-size": 512, + "chunk-overlap": 128, + "retrieval-strategy": "simple", + "llm": "llm-1", + "embedding": "emb-1", + "vector-db": "vdb-1", + "x2text_adapter": "x2t-1", + "similarity-top-k": 5, + }, + ], + } + + +@pytest.fixture +def base_params(): + """Base params dict for execute_structure_tool.""" + return { + "organization_id": "org-test", + "workflow_id": "wf-123", + "execution_id": "exec-456", + "file_execution_id": "fexec-789", + "tool_instance_metadata": { + "prompt_registry_id": "preg-001", + }, + "platform_service_api_key": "sk-test-key", + "input_file_path": "/data/test.pdf", + "output_dir_path": "/output", + "source_file_name": "test.pdf", + "execution_data_dir": "/data/exec", + "messaging_channel": "channel-1", + "file_hash": "filehash123", + "exec_metadata": {"tags": ["tag1"]}, + } + + +def _make_pipeline_result( + output: dict | None = None, + metadata: dict | None = None, + metrics: dict | None = None, +) -> ExecutionResult: + """Create a mock structure_pipeline result.""" + return ExecutionResult( + success=True, + data={ + "output": output or {}, + "metadata": metadata or {}, + "metrics": metrics or {}, + }, + ) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestTaskEnumRegistered: + """3-SANITY: Verify TaskName enum exists.""" + + def test_task_enum_registered(self): + assert hasattr(TaskName, "EXECUTE_STRUCTURE_TOOL") + assert str(TaskName.EXECUTE_STRUCTURE_TOOL) == "execute_structure_tool" + + +class TestStructureToolPipeline: + """Full pipeline dispatched as single structure_pipeline operation.""" + + @patch(_PATCH_SHIM) + @patch(_PATCH_FILE_STORAGE) + @patch(_PATCH_PLATFORM_HELPER) + 
@patch(_PATCH_DISPATCHER) + def test_structure_tool_single_dispatch( + self, + MockDispatcher, + mock_create_ph, + mock_get_fs, + MockShim, + base_params, + tool_metadata_regular, + mock_fs, + mock_platform_helper, + ): + """Single structure_pipeline dispatch for extract+index+answer.""" + from file_processing.structure_tool_task import ( + _execute_structure_tool_impl as execute_structure_tool, + ) + + mock_get_fs.return_value = mock_fs + mock_create_ph.return_value = mock_platform_helper + mock_platform_helper.get_prompt_studio_tool.return_value = { + "tool_metadata": tool_metadata_regular, + } + + dispatcher_instance = MagicMock() + MockDispatcher.return_value = dispatcher_instance + + pipeline_result = _make_pipeline_result( + output={"field_a": "$1M"}, + metadata={"run_id": "fexec-789", "file_name": "test.pdf"}, + metrics={"field_a": {"extraction_llm": {"tokens": 50}}}, + ) + dispatcher_instance.dispatch.return_value = pipeline_result + + result = execute_structure_tool(base_params) + + assert result["success"] is True + assert result["data"]["output"]["field_a"] == "$1M" + assert result["data"]["metadata"]["file_name"] == "test.pdf" + # json_dump called twice: output file + INFILE overwrite + assert mock_fs.json_dump.call_count == 2 + + # Single dispatch with structure_pipeline + assert dispatcher_instance.dispatch.call_count == 1 + ctx = dispatcher_instance.dispatch.call_args[0][0] + assert ctx.operation == "structure_pipeline" + assert ctx.execution_source == "tool" + assert ctx.executor_name == "legacy" + + @patch(_PATCH_SHIM) + @patch(_PATCH_FILE_STORAGE) + @patch(_PATCH_PLATFORM_HELPER) + @patch(_PATCH_DISPATCHER) + def test_pipeline_params_structure( + self, + MockDispatcher, + mock_create_ph, + mock_get_fs, + MockShim, + base_params, + tool_metadata_regular, + mock_fs, + mock_platform_helper, + ): + """Verify executor_params contains all pipeline sub-params.""" + from file_processing.structure_tool_task import ( + _execute_structure_tool_impl as 
execute_structure_tool, + ) + + mock_get_fs.return_value = mock_fs + mock_create_ph.return_value = mock_platform_helper + mock_platform_helper.get_prompt_studio_tool.return_value = { + "tool_metadata": tool_metadata_regular, + } + + dispatcher_instance = MagicMock() + MockDispatcher.return_value = dispatcher_instance + dispatcher_instance.dispatch.return_value = _make_pipeline_result() + + execute_structure_tool(base_params) + + ctx = dispatcher_instance.dispatch.call_args[0][0] + ep = ctx.executor_params + + # All required keys present + assert "extract_params" in ep + assert "index_template" in ep + assert "answer_params" in ep + assert "pipeline_options" in ep + + # Extract params + assert ep["extract_params"]["file_path"] == "/data/test.pdf" + + # Index template + assert ep["index_template"]["tool_id"] == "tool-123" + assert ep["index_template"]["file_hash"] == "filehash123" + + # Answer params + assert ep["answer_params"]["tool_id"] == "tool-123" + assert ep["answer_params"]["run_id"] == "fexec-789" + + # Pipeline options (normal flow) + opts = ep["pipeline_options"] + assert opts["skip_extraction_and_indexing"] is False + assert opts["is_summarization_enabled"] is False + assert opts["is_single_pass_enabled"] is False + assert opts["source_file_name"] == "test.pdf" + + +class TestStructureToolSinglePass: + """Single-pass flag passed to pipeline_options.""" + + @patch(_PATCH_SHIM) + @patch(_PATCH_FILE_STORAGE) + @patch(_PATCH_PLATFORM_HELPER) + @patch(_PATCH_DISPATCHER) + def test_structure_tool_single_pass( + self, + MockDispatcher, + mock_create_ph, + mock_get_fs, + MockShim, + base_params, + tool_metadata_regular, + mock_fs, + mock_platform_helper, + ): + from file_processing.structure_tool_task import ( + _execute_structure_tool_impl as execute_structure_tool, + ) + + mock_get_fs.return_value = mock_fs + mock_create_ph.return_value = mock_platform_helper + mock_platform_helper.get_prompt_studio_tool.return_value = { + "tool_metadata": 
tool_metadata_regular, + } + + base_params["tool_instance_metadata"]["single_pass_extraction_mode"] = True + + dispatcher_instance = MagicMock() + MockDispatcher.return_value = dispatcher_instance + dispatcher_instance.dispatch.return_value = _make_pipeline_result( + output={"field_a": "answer"}, + ) + + result = execute_structure_tool(base_params) + + assert result["success"] is True + # Single dispatch with is_single_pass_enabled flag + assert dispatcher_instance.dispatch.call_count == 1 + ctx = dispatcher_instance.dispatch.call_args[0][0] + assert ctx.operation == "structure_pipeline" + opts = ctx.executor_params["pipeline_options"] + assert opts["is_single_pass_enabled"] is True + + +class TestStructureToolSummarize: + """Summarization params passed to pipeline.""" + + @patch(_PATCH_SHIM) + @patch(_PATCH_FILE_STORAGE) + @patch(_PATCH_PLATFORM_HELPER) + @patch(_PATCH_DISPATCHER) + def test_structure_tool_summarize_flow( + self, + MockDispatcher, + mock_create_ph, + mock_get_fs, + MockShim, + base_params, + tool_metadata_regular, + mock_fs, + mock_platform_helper, + ): + from file_processing.structure_tool_task import ( + _execute_structure_tool_impl as execute_structure_tool, + ) + + mock_get_fs.return_value = mock_fs + mock_create_ph.return_value = mock_platform_helper + mock_platform_helper.get_prompt_studio_tool.return_value = { + "tool_metadata": tool_metadata_regular, + } + + tool_metadata_regular["tool_settings"]["summarize_prompt"] = ( + "Summarize this doc" + ) + base_params["tool_instance_metadata"]["summarize_as_source"] = True + + dispatcher_instance = MagicMock() + MockDispatcher.return_value = dispatcher_instance + dispatcher_instance.dispatch.return_value = _make_pipeline_result( + output={"field_a": "answer"}, + ) + + result = execute_structure_tool(base_params) + + assert result["success"] is True + assert dispatcher_instance.dispatch.call_count == 1 + ctx = dispatcher_instance.dispatch.call_args[0][0] + assert ctx.operation == 
"structure_pipeline" + + opts = ctx.executor_params["pipeline_options"] + assert opts["is_summarization_enabled"] is True + + # Summarize params included + sp = ctx.executor_params["summarize_params"] + assert sp is not None + assert sp["summarize_prompt"] == "Summarize this doc" + assert sp["llm_adapter_instance_id"] == "llm-1" + assert "extract_file_path" in sp + assert "summarize_file_path" in sp + + +class TestStructureToolSmartTable: + """Excel with valid JSON schema sets skip_extraction_and_indexing.""" + + @patch(_PATCH_SHIM) + @patch(_PATCH_FILE_STORAGE) + @patch(_PATCH_PLATFORM_HELPER) + @patch(_PATCH_DISPATCHER) + def test_structure_tool_skip_extraction_smart_table( + self, + MockDispatcher, + mock_create_ph, + mock_get_fs, + MockShim, + base_params, + tool_metadata_regular, + mock_fs, + mock_platform_helper, + ): + from file_processing.structure_tool_task import ( + _execute_structure_tool_impl as execute_structure_tool, + ) + + mock_get_fs.return_value = mock_fs + mock_create_ph.return_value = mock_platform_helper + + tool_metadata_regular["outputs"][0]["table_settings"] = { + "is_directory_mode": False, + } + tool_metadata_regular["outputs"][0]["prompt"] = '{"key": "value"}' + + mock_platform_helper.get_prompt_studio_tool.return_value = { + "tool_metadata": tool_metadata_regular, + } + + dispatcher_instance = MagicMock() + MockDispatcher.return_value = dispatcher_instance + dispatcher_instance.dispatch.return_value = _make_pipeline_result( + output={"field_a": "table_answer"}, + ) + + result = execute_structure_tool(base_params) + + assert result["success"] is True + # Single pipeline dispatch with skip flag + assert dispatcher_instance.dispatch.call_count == 1 + ctx = dispatcher_instance.dispatch.call_args[0][0] + assert ctx.operation == "structure_pipeline" + opts = ctx.executor_params["pipeline_options"] + assert opts["skip_extraction_and_indexing"] is True + + +class TestStructureToolAgentic: + """Agentic project routes to 
AgenticPromptStudioExecutor.""" + + @patch("unstract.sdk1.x2txt.X2Text") + @patch(_PATCH_SHIM) + @patch(_PATCH_FILE_STORAGE) + @patch(_PATCH_PLATFORM_HELPER) + @patch(_PATCH_DISPATCHER) + def test_structure_tool_agentic_routing( + self, + MockDispatcher, + mock_create_ph, + mock_get_fs, + MockShim, + MockX2Text, + base_params, + mock_fs, + mock_platform_helper, + ): + from file_processing.structure_tool_task import ( + _execute_structure_tool_impl as execute_structure_tool, + ) + + mock_get_fs.return_value = mock_fs + mock_create_ph.return_value = mock_platform_helper + + # Mock X2Text to return extracted text + mock_x2text_instance = MagicMock() + mock_extraction_result = MagicMock() + mock_extraction_result.extracted_text = "Extracted document text" + mock_x2text_instance.process.return_value = mock_extraction_result + MockX2Text.return_value = mock_x2text_instance + + # Prompt studio lookup fails, agentic succeeds + mock_platform_helper.get_prompt_studio_tool.return_value = None + + agentic_metadata = { + "name": "Agentic Project", + "project_id": "ap-001", + "json_schema": {"field": "string"}, + "prompt_text": "Extract the field", + "adapter_config": { + "extractor_llm": "llm-adapter-1", + "llmwhisperer": "whisper-adapter-1", + }, + } + mock_platform_helper.get_agentic_studio_tool.return_value = { + "tool_metadata": agentic_metadata, + } + + dispatcher_instance = MagicMock() + MockDispatcher.return_value = dispatcher_instance + + # Simulate successful agentic extraction + agentic_result = ExecutionResult( + success=True, + data={"output": {"field": "value"}}, + ) + dispatcher_instance.dispatch.return_value = agentic_result + + result = execute_structure_tool(base_params) + + # Should dispatch to agentic executor with agentic_extract operation + calls = dispatcher_instance.dispatch.call_args_list + assert len(calls) == 1 + ctx = calls[0][0][0] + assert ctx.executor_name == "agentic" + assert ctx.operation == "agentic_extract" + # Verify flat params are passed 
(not nested dicts) + params = ctx.executor_params + assert params["adapter_instance_id"] == "llm-adapter-1" + assert params["document_text"] == "Extracted document text" + assert params["prompt_text"] == "Extract the field" + assert params["schema"] == {"field": "string"} + assert "PLATFORM_SERVICE_API_KEY" in params + + +class TestStructureToolProfileOverrides: + """Profile overrides modify tool_metadata before pipeline dispatch.""" + + @patch(_PATCH_SHIM) + @patch(_PATCH_FILE_STORAGE) + @patch(_PATCH_PLATFORM_HELPER) + @patch(_PATCH_DISPATCHER) + def test_structure_tool_profile_overrides( + self, + MockDispatcher, + mock_create_ph, + mock_get_fs, + MockShim, + base_params, + tool_metadata_regular, + mock_fs, + mock_platform_helper, + ): + from file_processing.structure_tool_task import ( + _execute_structure_tool_impl as execute_structure_tool, + ) + + mock_get_fs.return_value = mock_fs + mock_create_ph.return_value = mock_platform_helper + mock_platform_helper.get_prompt_studio_tool.return_value = { + "tool_metadata": tool_metadata_regular, + } + + base_params["exec_metadata"]["llm_profile_id"] = "profile-1" + mock_platform_helper.get_llm_profile.return_value = { + "profile_name": "Test Profile", + "llm_id": "llm-override", + } + + dispatcher_instance = MagicMock() + MockDispatcher.return_value = dispatcher_instance + dispatcher_instance.dispatch.return_value = _make_pipeline_result( + output={"field_a": "answer"}, + ) + + result = execute_structure_tool(base_params) + + assert result["success"] is True + mock_platform_helper.get_llm_profile.assert_called_once_with("profile-1") + assert tool_metadata_regular["tool_settings"]["llm"] == "llm-override" + + +class TestStructureToolPipelineFailure: + """Pipeline failure propagated to caller.""" + + @patch(_PATCH_SHIM) + @patch(_PATCH_FILE_STORAGE) + @patch(_PATCH_PLATFORM_HELPER) + @patch(_PATCH_DISPATCHER) + def test_structure_tool_pipeline_failure( + self, + MockDispatcher, + mock_create_ph, + mock_get_fs, + 
MockShim, + base_params, + tool_metadata_regular, + mock_fs, + mock_platform_helper, + ): + from file_processing.structure_tool_task import ( + _execute_structure_tool_impl as execute_structure_tool, + ) + + mock_get_fs.return_value = mock_fs + mock_create_ph.return_value = mock_platform_helper + mock_platform_helper.get_prompt_studio_tool.return_value = { + "tool_metadata": tool_metadata_regular, + } + + dispatcher_instance = MagicMock() + MockDispatcher.return_value = dispatcher_instance + + pipeline_failure = ExecutionResult.failure( + error="X2Text adapter error: connection refused" + ) + dispatcher_instance.dispatch.return_value = pipeline_failure + + result = execute_structure_tool(base_params) + + assert result["success"] is False + assert "X2Text" in result["error"] + assert dispatcher_instance.dispatch.call_count == 1 + + +class TestStructureToolMultipleOutputs: + """Multiple outputs are passed to executor in answer_params.""" + + @patch(_PATCH_SHIM) + @patch(_PATCH_FILE_STORAGE) + @patch(_PATCH_PLATFORM_HELPER) + @patch(_PATCH_DISPATCHER) + def test_structure_tool_multiple_outputs( + self, + MockDispatcher, + mock_create_ph, + mock_get_fs, + MockShim, + base_params, + tool_metadata_regular, + mock_fs, + mock_platform_helper, + ): + from file_processing.structure_tool_task import ( + _execute_structure_tool_impl as execute_structure_tool, + ) + + mock_get_fs.return_value = mock_fs + mock_create_ph.return_value = mock_platform_helper + + # Add a second output with same chunking params + second_output = dict(tool_metadata_regular["outputs"][0]) + second_output["name"] = "field_b" + tool_metadata_regular["outputs"].append(second_output) + + mock_platform_helper.get_prompt_studio_tool.return_value = { + "tool_metadata": tool_metadata_regular, + } + + dispatcher_instance = MagicMock() + MockDispatcher.return_value = dispatcher_instance + dispatcher_instance.dispatch.return_value = _make_pipeline_result( + output={"field_a": "a", "field_b": "b"}, + ) + + result 
= execute_structure_tool(base_params) + + assert result["success"] is True + # Single dispatch — index dedup handled inside executor + assert dispatcher_instance.dispatch.call_count == 1 + ctx = dispatcher_instance.dispatch.call_args[0][0] + outputs = ctx.executor_params["answer_params"]["outputs"] + assert len(outputs) == 2 + assert outputs[0]["name"] == "field_a" + assert outputs[1]["name"] == "field_b" + + +class TestStructureToolOutputWritten: + """Output JSON written to correct path with correct structure.""" + + @patch(_PATCH_SHIM) + @patch(_PATCH_FILE_STORAGE) + @patch(_PATCH_PLATFORM_HELPER) + @patch(_PATCH_DISPATCHER) + def test_structure_tool_output_written( + self, + MockDispatcher, + mock_create_ph, + mock_get_fs, + MockShim, + base_params, + tool_metadata_regular, + mock_fs, + mock_platform_helper, + ): + from file_processing.structure_tool_task import ( + _execute_structure_tool_impl as execute_structure_tool, + ) + + mock_get_fs.return_value = mock_fs + mock_create_ph.return_value = mock_platform_helper + mock_platform_helper.get_prompt_studio_tool.return_value = { + "tool_metadata": tool_metadata_regular, + } + + dispatcher_instance = MagicMock() + MockDispatcher.return_value = dispatcher_instance + dispatcher_instance.dispatch.return_value = _make_pipeline_result( + output={"field_a": "answer"}, + ) + + result = execute_structure_tool(base_params) + + assert result["success"] is True + + # json_dump called twice: once for output file, once for INFILE overwrite + assert mock_fs.json_dump.call_count == 2 + + # First call: output file (execution_dir/{stem}.json) + first_call = mock_fs.json_dump.call_args_list[0] + first_path = first_call.kwargs.get( + "path", first_call[1].get("path") if len(first_call) > 1 else None + ) + if first_path is None: + first_path = first_call[0][0] if first_call[0] else None + assert str(first_path).endswith("test.json") + + # Second call: INFILE overwrite (so destination connector reads JSON, not PDF) + second_call = 
mock_fs.json_dump.call_args_list[1] + second_path = second_call.kwargs.get( + "path", second_call[1].get("path") if len(second_call) > 1 else None + ) + if second_path is None: + second_path = second_call[0][0] if second_call[0] else None + assert str(second_path) == base_params["input_file_path"] + + +class TestStructureToolMetadataFileName: + """metadata.file_name in pipeline result preserved.""" + + @patch(_PATCH_SHIM) + @patch(_PATCH_FILE_STORAGE) + @patch(_PATCH_PLATFORM_HELPER) + @patch(_PATCH_DISPATCHER) + def test_structure_tool_metadata_file_name( + self, + MockDispatcher, + mock_create_ph, + mock_get_fs, + MockShim, + base_params, + tool_metadata_regular, + mock_fs, + mock_platform_helper, + ): + from file_processing.structure_tool_task import ( + _execute_structure_tool_impl as execute_structure_tool, + ) + + mock_get_fs.return_value = mock_fs + mock_create_ph.return_value = mock_platform_helper + mock_platform_helper.get_prompt_studio_tool.return_value = { + "tool_metadata": tool_metadata_regular, + } + + dispatcher_instance = MagicMock() + MockDispatcher.return_value = dispatcher_instance + dispatcher_instance.dispatch.return_value = _make_pipeline_result( + output={"field_a": "answer"}, + metadata={"run_id": "123", "file_name": "test.pdf"}, + ) + + result = execute_structure_tool(base_params) + + assert result["success"] is True + assert result["data"]["metadata"]["file_name"] == "test.pdf" + + +class TestStructureToolNoSummarize: + """No summarize_params when summarization is not enabled.""" + + @patch(_PATCH_SHIM) + @patch(_PATCH_FILE_STORAGE) + @patch(_PATCH_PLATFORM_HELPER) + @patch(_PATCH_DISPATCHER) + def test_no_summarize_params_when_disabled( + self, + MockDispatcher, + mock_create_ph, + mock_get_fs, + MockShim, + base_params, + tool_metadata_regular, + mock_fs, + mock_platform_helper, + ): + from file_processing.structure_tool_task import ( + _execute_structure_tool_impl as execute_structure_tool, + ) + + mock_get_fs.return_value = mock_fs + 
mock_create_ph.return_value = mock_platform_helper + mock_platform_helper.get_prompt_studio_tool.return_value = { + "tool_metadata": tool_metadata_regular, + } + + dispatcher_instance = MagicMock() + MockDispatcher.return_value = dispatcher_instance + dispatcher_instance.dispatch.return_value = _make_pipeline_result() + + execute_structure_tool(base_params) + + ctx = dispatcher_instance.dispatch.call_args[0][0] + assert ctx.executor_params["summarize_params"] is None + assert ctx.executor_params["pipeline_options"]["is_summarization_enabled"] is False + + +class TestWorkflowServiceDetection: + """Test _is_structure_tool_workflow detection.""" + + def test_is_structure_tool_detection(self): + from shared.workflow.execution.service import ( + WorkerWorkflowExecutionService, + ) + + service = WorkerWorkflowExecutionService() + + # Mock execution_service with a structure tool instance + mock_exec_service = MagicMock() + ti = MagicMock() + ti.image_name = "unstract/tool-structure" + mock_exec_service.tool_instances = [ti] + + result = service._is_structure_tool_workflow(mock_exec_service) + assert result is True + + def test_non_structure_tool_uses_docker(self): + from shared.workflow.execution.service import ( + WorkerWorkflowExecutionService, + ) + + service = WorkerWorkflowExecutionService() + + # Mock execution_service with a non-structure tool + mock_exec_service = MagicMock() + ti = MagicMock() + ti.image_name = "unstract/tool-classifier" + mock_exec_service.tool_instances = [ti] + + result = service._is_structure_tool_workflow(mock_exec_service) + assert result is False + + @patch.dict("os.environ", {"STRUCTURE_TOOL_IMAGE_NAME": "custom/structure"}) + def test_custom_structure_image_name(self): + from shared.workflow.execution.service import ( + WorkerWorkflowExecutionService, + ) + + service = WorkerWorkflowExecutionService() + + mock_exec_service = MagicMock() + ti = MagicMock() + ti.image_name = "custom/structure" + mock_exec_service.tool_instances = [ti] + + 
result = service._is_structure_tool_workflow(mock_exec_service) + assert result is True + + def test_registry_prefix_match(self): + """Image from backend with registry prefix matches default base name.""" + from shared.workflow.execution.service import ( + WorkerWorkflowExecutionService, + ) + + service = WorkerWorkflowExecutionService() + + # Worker uses default "unstract/tool-structure", but backend sends + # image with registry prefix (common in K8s deployments) + mock_exec_service = MagicMock() + ti = MagicMock() + ti.image_name = "gcr.io/my-project/tool-structure" + mock_exec_service.tool_instances = [ti] + + result = service._is_structure_tool_workflow(mock_exec_service) + assert result is True + + def test_registry_prefix_with_tag_match(self): + """Image with registry prefix and tag still matches.""" + from shared.workflow.execution.service import ( + WorkerWorkflowExecutionService, + ) + + service = WorkerWorkflowExecutionService() + + mock_exec_service = MagicMock() + ti = MagicMock() + ti.image_name = "us.gcr.io/prod/tool-structure:v1.2.3" + mock_exec_service.tool_instances = [ti] + + result = service._is_structure_tool_workflow(mock_exec_service) + assert result is True + + @patch.dict("os.environ", {"STRUCTURE_TOOL_IMAGE_NAME": "gcr.io/prod/tool-structure"}) + def test_env_has_registry_prefix_instance_has_different_prefix(self): + """Both env and instance have different registry prefixes, same base.""" + from shared.workflow.execution.service import ( + WorkerWorkflowExecutionService, + ) + + service = WorkerWorkflowExecutionService() + + mock_exec_service = MagicMock() + ti = MagicMock() + ti.image_name = "ecr.aws/other/tool-structure" + mock_exec_service.tool_instances = [ti] + + result = service._is_structure_tool_workflow(mock_exec_service) + assert result is True + + +class TestStructureToolParamsPassthrough: + """Task receives correct params from WorkerWorkflowExecutionService.""" + + @patch( + 
"shared.workflow.execution.service.WorkerWorkflowExecutionService." + "_execute_structure_tool_workflow" + ) + @patch( + "shared.workflow.execution.service.WorkerWorkflowExecutionService." + "_is_structure_tool_workflow", + return_value=True, + ) + def test_structure_tool_params_passthrough( + self, mock_is_struct, mock_exec_struct + ): + from shared.workflow.execution.service import ( + WorkerWorkflowExecutionService, + ) + + service = WorkerWorkflowExecutionService() + + mock_exec_service = MagicMock() + mock_exec_service.tool_instances = [MagicMock()] + + service._build_and_execute_workflow(mock_exec_service, "test.pdf") + + # Verify _execute_structure_tool_workflow was called + mock_exec_struct.assert_called_once_with( + mock_exec_service, "test.pdf" + ) + + +class TestHelperFunctions: + """Test standalone helper functions.""" + + def test_apply_profile_overrides(self): + from file_processing.structure_tool_task import ( + _apply_profile_overrides, + ) + + tool_metadata = { + "tool_settings": { + "llm": "old-llm", + "embedding": "old-emb", + }, + "outputs": [ + { + "name": "field_a", + "llm": "old-llm", + "embedding": "old-emb", + }, + ], + } + profile_data = { + "llm_id": "new-llm", + "embedding_model_id": "new-emb", + } + + changes = _apply_profile_overrides(tool_metadata, profile_data) + + assert len(changes) == 4 # 2 in tool_settings + 2 in output + assert tool_metadata["tool_settings"]["llm"] == "new-llm" + assert tool_metadata["tool_settings"]["embedding"] == "new-emb" + assert tool_metadata["outputs"][0]["llm"] == "new-llm" + assert tool_metadata["outputs"][0]["embedding"] == "new-emb" + + def test_should_skip_extraction_no_table_settings(self): + from file_processing.structure_tool_task import ( + _should_skip_extraction_for_smart_table, + ) + + outputs = [{"name": "field_a", "prompt": "What?"}] + assert ( + _should_skip_extraction_for_smart_table("file.xlsx", outputs) + is False + ) + + def test_should_skip_extraction_with_json_schema(self): + from 
file_processing.structure_tool_task import ( + _should_skip_extraction_for_smart_table, + ) + + outputs = [ + { + "name": "field_a", + "table_settings": {}, + "prompt": '{"col1": "string", "col2": "number"}', + } + ] + assert ( + _should_skip_extraction_for_smart_table("file.xlsx", outputs) + is True + ) diff --git a/workers/tests/test_sanity_phase4.py b/workers/tests/test_sanity_phase4.py new file mode 100644 index 0000000000..2d5e72715c --- /dev/null +++ b/workers/tests/test_sanity_phase4.py @@ -0,0 +1,899 @@ +"""Phase 4-SANITY — IDE path integration tests through executor chain. + +Phase 4 replaces PromptTool HTTP calls in PromptStudioHelper with +ExecutionDispatcher → executor worker → LegacyExecutor. + +These tests build the EXACT payloads that prompt_studio_helper.py +now sends via ExecutionDispatcher, push them through the full Celery +eager-mode chain, and verify the results match what the IDE expects. + +This validates the full contract: + prompt_studio_helper builds payload + → ExecutionContext(execution_source="ide", ...) + → Celery task → LegacyExecutor._handle_X() + → ExecutionResult → result.data used by IDE + +All tests use execution_source="ide" to match the real IDE path. 
+""" + +from unittest.mock import MagicMock, patch + +import pytest + +from executor.executors.constants import ( + PromptServiceConstants as PSKeys, +) +from unstract.sdk1.execution.context import ExecutionContext +from unstract.sdk1.execution.dispatcher import ExecutionDispatcher +from unstract.sdk1.execution.registry import ExecutorRegistry +from unstract.sdk1.execution.result import ExecutionResult + +# --------------------------------------------------------------------------- +# Patch targets (same as Phase 2 sanity) +# --------------------------------------------------------------------------- + +_PATCH_X2TEXT = "executor.executors.legacy_executor.X2Text" +_PATCH_FS = "executor.executors.legacy_executor.FileUtils.get_fs_instance" +_PATCH_INDEX_DEPS = ( + "executor.executors.legacy_executor.LegacyExecutor._get_indexing_deps" +) +_PATCH_PROMPT_DEPS = ( + "executor.executors.legacy_executor.LegacyExecutor._get_prompt_deps" +) +_PATCH_SHIM = "executor.executors.legacy_executor.ExecutorToolShim" +_PATCH_RUN_COMPLETION = ( + "executor.executors.answer_prompt.AnswerPromptService.run_completion" +) +_PATCH_INDEX_UTILS = ( + "unstract.sdk1.utils.indexing.IndexingUtils.generate_index_key" +) +_PATCH_PLUGIN_LOADER = ( + "executor.executors.plugins.loader.ExecutorPluginLoader.get" +) + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +def _ensure_legacy_registered(): + """Ensure LegacyExecutor is registered.""" + from executor.executors.legacy_executor import LegacyExecutor + + if "legacy" not in ExecutorRegistry.list_executors(): + ExecutorRegistry._registry["legacy"] = LegacyExecutor + yield + + +@pytest.fixture +def eager_app(): + """Configure executor Celery app for eager-mode testing.""" + from executor.worker import app + + original = { + "task_always_eager": app.conf.task_always_eager, + 
"task_eager_propagates": app.conf.task_eager_propagates, + "result_backend": app.conf.result_backend, + } + app.conf.update( + task_always_eager=True, + task_eager_propagates=False, + result_backend="cache+memory://", + ) + yield app + app.conf.update(original) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _run_task(eager_app, ctx_dict): + """Run execute_extraction task via task.apply() (eager-safe).""" + task = eager_app.tasks["execute_extraction"] + result = task.apply(args=[ctx_dict]) + return result.get() + + +def _mock_llm(answer="ide answer"): + """Create a mock LLM matching the answer_prompt pattern.""" + llm = MagicMock(name="llm") + response = MagicMock() + response.text = answer + llm.complete.return_value = { + PSKeys.RESPONSE: response, + PSKeys.HIGHLIGHT_DATA: [], + PSKeys.CONFIDENCE_DATA: None, + PSKeys.WORD_CONFIDENCE_DATA: None, + PSKeys.LINE_NUMBERS: [], + PSKeys.WHISPER_HASH: "", + } + llm.get_usage_reason.return_value = "extraction" + llm.get_metrics.return_value = {"tokens": 42} + return llm + + +def _mock_prompt_deps(llm=None): + """Return 7-tuple matching _get_prompt_deps() shape.""" + if llm is None: + llm = _mock_llm() + + from executor.executors.answer_prompt import AnswerPromptService + + RetrievalService = MagicMock(name="RetrievalService") + RetrievalService.run_retrieval.return_value = ["chunk1"] + RetrievalService.retrieve_complete_context.return_value = ["full doc"] + + VariableReplacementService = MagicMock(name="VariableReplacementService") + VariableReplacementService.is_variables_present.return_value = False + + Index = MagicMock(name="Index") + index_instance = MagicMock() + index_instance.generate_index_key.return_value = "doc-ide-key" + Index.return_value = index_instance + + LLM_cls = MagicMock(name="LLM") + LLM_cls.return_value = llm + + EmbeddingCompat = MagicMock(name="EmbeddingCompat") + 
VectorDB = MagicMock(name="VectorDB") + + return ( + AnswerPromptService, + RetrievalService, + VariableReplacementService, + Index, + LLM_cls, + EmbeddingCompat, + VectorDB, + ) + + +def _mock_process_response(text="ide extracted text"): + """Build a mock TextExtractionResult.""" + from unstract.sdk1.adapters.x2text.dto import ( + TextExtractionMetadata, + TextExtractionResult, + ) + + metadata = TextExtractionMetadata(whisper_hash="ide-hash") + return TextExtractionResult( + extracted_text=text, + extraction_metadata=metadata, + ) + + +def _make_ide_prompt(name="invoice_number", prompt="What is the invoice number?", + output_type="text", **overrides): + """Build a prompt dict matching what prompt_studio_helper builds. + + Uses the exact key strings from ToolStudioPromptKeys / PSKeys. + """ + d = { + PSKeys.NAME: name, + PSKeys.PROMPT: prompt, + PSKeys.TYPE: output_type, + # These match the hyphenated keys from ToolStudioPromptKeys + "chunk-size": 512, + "chunk-overlap": 64, + "retrieval-strategy": "simple", + "llm": "llm-ide-1", + "embedding": "emb-ide-1", + "vector-db": "vdb-ide-1", + "x2text_adapter": "x2t-ide-1", + "similarity-top-k": 3, + "active": True, + "required": True, + } + d.update(overrides) + return d + + +# --- IDE context factories matching prompt_studio_helper payloads --- + + +def _ide_extract_ctx(**overrides): + """Build ExecutionContext matching dynamic_extractor() dispatch. + + Key mapping: dynamic_extractor uses IKeys constants for payload keys, + and adds "platform_api_key" for the executor. 
+ """ + defaults = { + "executor_name": "legacy", + "operation": "extract", + "run_id": "run-ide-ext", + "execution_source": "ide", + "organization_id": "org-ide-test", + "executor_params": { + "x2text_instance_id": "x2t-ide-1", + "file_path": "/prompt-studio/org/user/tool/doc.pdf", + "enable_highlight": True, + "usage_kwargs": {"run_id": "run-ide-ext", "file_name": "doc.pdf"}, + "run_id": "run-ide-ext", + "log_events_id": "log-ide-1", + "execution_source": "ide", + "output_file_path": "/prompt-studio/org/user/tool/extract/doc.txt", + "platform_api_key": "pk-ide-test", + }, + } + defaults.update(overrides) + return ExecutionContext(**defaults) + + +def _ide_index_ctx(**overrides): + """Build ExecutionContext matching dynamic_indexer() dispatch. + + Key mapping: dynamic_indexer uses IKeys constants and adds + "platform_api_key" for the executor. + """ + defaults = { + "executor_name": "legacy", + "operation": "index", + "run_id": "run-ide-idx", + "execution_source": "ide", + "organization_id": "org-ide-test", + "executor_params": { + "tool_id": "tool-ide-1", + "embedding_instance_id": "emb-ide-1", + "vector_db_instance_id": "vdb-ide-1", + "x2text_instance_id": "x2t-ide-1", + "file_path": "/prompt-studio/org/user/tool/extract/doc.txt", + "file_hash": None, + "chunk_overlap": 64, + "chunk_size": 512, + "reindex": False, + "enable_highlight": True, + "usage_kwargs": {"run_id": "run-ide-idx", "file_name": "doc.pdf"}, + "extracted_text": "IDE extracted document text content", + "run_id": "run-ide-idx", + "log_events_id": "log-ide-1", + "execution_source": "ide", + "platform_api_key": "pk-ide-test", + }, + } + defaults.update(overrides) + return ExecutionContext(**defaults) + + +def _ide_answer_prompt_ctx(prompts=None, **overrides): + """Build ExecutionContext matching _fetch_response() dispatch. + + Key mapping: _fetch_response uses TSPKeys (ToolStudioPromptKeys) + constants and adds PLATFORM_SERVICE_API_KEY + include_metadata. 
+ """ + if prompts is None: + prompts = [_make_ide_prompt()] + defaults = { + "executor_name": "legacy", + "operation": "answer_prompt", + "run_id": "run-ide-ap", + "execution_source": "ide", + "organization_id": "org-ide-test", + "executor_params": { + "tool_settings": { + "enable_challenge": False, + "challenge_llm": "llm-challenge-1", + "single_pass_extraction_mode": False, + "summarize_as_source": False, + "preamble": "Extract accurately.", + "postamble": "No explanation.", + "grammar": [], + "enable_highlight": True, + "enable_word_confidence": False, + "platform_postamble": "", + "word_confidence_postamble": "", + }, + "outputs": prompts, + "tool_id": "tool-ide-1", + "run_id": "run-ide-ap", + "file_name": "invoice.pdf", + "file_hash": "abc123hash", + "file_path": "/prompt-studio/org/user/tool/extract/invoice.txt", + "log_events_id": "log-ide-1", + "execution_source": "ide", + "custom_data": {}, + "PLATFORM_SERVICE_API_KEY": "pk-ide-test", + "include_metadata": True, + }, + } + defaults.update(overrides) + return ExecutionContext(**defaults) + + +def _ide_single_pass_ctx(prompts=None, **overrides): + """Build ExecutionContext matching _fetch_single_pass_response() dispatch.""" + if prompts is None: + prompts = [ + _make_ide_prompt(name="revenue", prompt="What is total revenue?"), + _make_ide_prompt(name="date", prompt="What is the date?"), + ] + defaults = { + "executor_name": "legacy", + "operation": "single_pass_extraction", + "run_id": "run-ide-sp", + "execution_source": "ide", + "organization_id": "org-ide-test", + "executor_params": { + "tool_settings": { + "preamble": "Extract accurately.", + "postamble": "No explanation.", + "grammar": [], + "llm": "llm-ide-1", + "x2text_adapter": "x2t-ide-1", + "vector-db": "vdb-ide-1", + "embedding": "emb-ide-1", + "chunk-size": 0, + "chunk-overlap": 0, + "enable_challenge": False, + "enable_highlight": True, + "enable_word_confidence": False, + "challenge_llm": None, + "platform_postamble": "", + 
"word_confidence_postamble": "", + "summarize_as_source": False, + }, + "outputs": prompts, + "tool_id": "tool-ide-1", + "run_id": "run-ide-sp", + "file_hash": "abc123hash", + "file_name": "invoice.pdf", + "file_path": "/prompt-studio/org/user/tool/extract/invoice.txt", + "log_events_id": "log-ide-1", + "execution_source": "ide", + "custom_data": {}, + "PLATFORM_SERVICE_API_KEY": "pk-ide-test", + "include_metadata": True, + }, + } + defaults.update(overrides) + return ExecutionContext(**defaults) + + +# =========================================================================== +# Test classes +# =========================================================================== + + +class TestIDEExtract: + """IDE extract payload → executor → extracted_text.""" + + @patch(_PATCH_FS) + @patch(_PATCH_X2TEXT) + def test_ide_extract_returns_text(self, mock_x2text_cls, mock_get_fs, eager_app): + """IDE extract payload produces extracted_text in result.data.""" + mock_x2text = MagicMock() + mock_x2text.process.return_value = _mock_process_response( + "Invoice #12345 dated 2024-01-15" + ) + mock_x2text.x2text_instance = MagicMock() + mock_x2text_cls.return_value = mock_x2text + mock_get_fs.return_value = MagicMock() + + ctx = _ide_extract_ctx() + result_dict = _run_task(eager_app, ctx.to_dict()) + result = ExecutionResult.from_dict(result_dict) + + assert result.success is True + assert "extracted_text" in result.data + assert result.data["extracted_text"] == "Invoice #12345 dated 2024-01-15" + + @patch(_PATCH_FS) + @patch(_PATCH_X2TEXT) + def test_ide_extract_with_output_file_path( + self, mock_x2text_cls, mock_get_fs, eager_app + ): + """IDE extract passes output_file_path to x2text.process().""" + mock_x2text = MagicMock() + mock_x2text.process.return_value = _mock_process_response("text") + mock_x2text.x2text_instance = MagicMock() + mock_x2text_cls.return_value = mock_x2text + mock_get_fs.return_value = MagicMock() + + ctx = _ide_extract_ctx() + _run_task(eager_app, 
ctx.to_dict()) + + # Verify output_file_path was passed through + call_kwargs = mock_x2text.process.call_args + assert call_kwargs is not None + assert "output_file_path" in call_kwargs.kwargs + assert call_kwargs.kwargs["output_file_path"] == ( + "/prompt-studio/org/user/tool/extract/doc.txt" + ) + + @patch(_PATCH_FS) + @patch(_PATCH_X2TEXT) + def test_ide_extract_failure(self, mock_x2text_cls, mock_get_fs, eager_app): + """Adapter failure → ExecutionResult(success=False).""" + from unstract.sdk1.adapters.exceptions import AdapterError + + mock_x2text = MagicMock() + mock_x2text.x2text_instance = MagicMock() + mock_x2text.x2text_instance.get_name.return_value = "LLMWhisperer" + mock_x2text.process.side_effect = AdapterError("extraction failed") + mock_x2text_cls.return_value = mock_x2text + mock_get_fs.return_value = MagicMock() + + ctx = _ide_extract_ctx() + result_dict = _run_task(eager_app, ctx.to_dict()) + result = ExecutionResult.from_dict(result_dict) + + assert result.success is False + assert "extraction failed" in result.error + + +class TestIDEIndex: + """IDE index payload → executor → doc_id.""" + + @patch(_PATCH_FS) + @patch(_PATCH_INDEX_DEPS) + def test_ide_index_returns_doc_id(self, mock_deps, mock_get_fs, eager_app): + """IDE index payload produces doc_id in result.data.""" + mock_index_cls = MagicMock() + mock_index = MagicMock() + mock_index.generate_index_key.return_value = "doc-ide-indexed" + mock_index.is_document_indexed.return_value = False + mock_index.perform_indexing.return_value = "doc-ide-indexed" + mock_index_cls.return_value = mock_index + + mock_emb_cls = MagicMock() + mock_emb_cls.return_value = MagicMock() + mock_vdb_cls = MagicMock() + mock_vdb_cls.return_value = MagicMock() + + mock_deps.return_value = (mock_index_cls, mock_emb_cls, mock_vdb_cls) + mock_get_fs.return_value = MagicMock() + + ctx = _ide_index_ctx() + result_dict = _run_task(eager_app, ctx.to_dict()) + result = ExecutionResult.from_dict(result_dict) + + assert 
result.success is True + assert result.data["doc_id"] == "doc-ide-indexed" + + @patch(_PATCH_FS) + @patch(_PATCH_INDEX_DEPS) + def test_ide_index_with_null_file_hash(self, mock_deps, mock_get_fs, eager_app): + """IDE indexer sends file_hash=None — executor handles it.""" + mock_index_cls = MagicMock() + mock_index = MagicMock() + mock_index.generate_index_key.return_value = "doc-null-hash" + mock_index.is_document_indexed.return_value = False + mock_index.perform_indexing.return_value = "doc-null-hash" + mock_index_cls.return_value = mock_index + + mock_deps.return_value = (mock_index_cls, MagicMock(), MagicMock()) + mock_get_fs.return_value = MagicMock() + + # file_hash=None is exactly what dynamic_indexer sends + ctx = _ide_index_ctx() + assert ctx.executor_params["file_hash"] is None + + result_dict = _run_task(eager_app, ctx.to_dict()) + result = ExecutionResult.from_dict(result_dict) + + assert result.success is True + assert result.data["doc_id"] == "doc-null-hash" + + @patch(_PATCH_FS) + @patch(_PATCH_INDEX_DEPS) + def test_ide_index_failure(self, mock_deps, mock_get_fs, eager_app): + """Index failure → ExecutionResult(success=False).""" + mock_index_cls = MagicMock() + mock_index = MagicMock() + mock_index.generate_index_key.return_value = "doc-fail" + mock_index.is_document_indexed.return_value = False + mock_index.perform_indexing.side_effect = RuntimeError("VDB timeout") + mock_index_cls.return_value = mock_index + + mock_deps.return_value = (mock_index_cls, MagicMock(), MagicMock()) + mock_get_fs.return_value = MagicMock() + + ctx = _ide_index_ctx() + result_dict = _run_task(eager_app, ctx.to_dict()) + result = ExecutionResult.from_dict(result_dict) + + assert result.success is False + + +class TestIDEAnswerPrompt: + """IDE answer_prompt payload → executor → {output, metadata, metrics}.""" + + @patch(_PATCH_PLUGIN_LOADER, return_value=None) + @patch(_PATCH_INDEX_UTILS, return_value="doc-id-ide") + @patch(_PATCH_PROMPT_DEPS) + @patch(_PATCH_SHIM) + def 
test_ide_answer_prompt_text( + self, mock_shim_cls, mock_deps, _mock_idx, _mock_plugin, eager_app + ): + """IDE text prompt → output dict with prompt_key → answer.""" + llm = _mock_llm("INV-2024-001") + mock_deps.return_value = _mock_prompt_deps(llm) + mock_shim_cls.return_value = MagicMock() + + ctx = _ide_answer_prompt_ctx() + result_dict = _run_task(eager_app, ctx.to_dict()) + result = ExecutionResult.from_dict(result_dict) + + assert result.success is True + # IDE expects result.data to have "output", "metadata", "metrics" + assert "output" in result.data + assert "metadata" in result.data + assert "metrics" in result.data + assert result.data["output"]["invoice_number"] == "INV-2024-001" + + @patch(_PATCH_PLUGIN_LOADER, return_value=None) + @patch(_PATCH_INDEX_UTILS, return_value="doc-id-ide") + @patch(_PATCH_PROMPT_DEPS) + @patch(_PATCH_SHIM) + def test_ide_answer_prompt_metadata_has_run_id( + self, mock_shim_cls, mock_deps, _mock_idx, _mock_plugin, eager_app + ): + """IDE response metadata contains run_id and file_name.""" + llm = _mock_llm("answer") + mock_deps.return_value = _mock_prompt_deps(llm) + mock_shim_cls.return_value = MagicMock() + + ctx = _ide_answer_prompt_ctx() + result_dict = _run_task(eager_app, ctx.to_dict()) + result = ExecutionResult.from_dict(result_dict) + + metadata = result.data["metadata"] + assert metadata["run_id"] == "run-ide-ap" + assert metadata["file_name"] == "invoice.pdf" + + @patch(_PATCH_PLUGIN_LOADER, return_value=None) + @patch(_PATCH_INDEX_UTILS, return_value="doc-id-ide") + @patch(_PATCH_PROMPT_DEPS) + @patch(_PATCH_SHIM) + def test_ide_answer_prompt_with_eval_settings( + self, mock_shim_cls, mock_deps, _mock_idx, _mock_plugin, eager_app + ): + """Prompt with eval_settings passes through to executor cleanly.""" + llm = _mock_llm("answer") + mock_deps.return_value = _mock_prompt_deps(llm) + mock_shim_cls.return_value = MagicMock() + + prompt = _make_ide_prompt( + eval_settings={ + "evaluate": True, + "monitor_llm": 
["llm-monitor-1"], + "exclude_failed": True, + } + ) + ctx = _ide_answer_prompt_ctx(prompts=[prompt]) + result_dict = _run_task(eager_app, ctx.to_dict()) + result = ExecutionResult.from_dict(result_dict) + + assert result.success is True + + @patch(_PATCH_PLUGIN_LOADER, return_value=None) + @patch(_PATCH_INDEX_UTILS, return_value="doc-id-ide") + @patch(_PATCH_PROMPT_DEPS) + @patch(_PATCH_SHIM) + def test_ide_answer_prompt_platform_key_reaches_shim( + self, mock_shim_cls, mock_deps, _mock_idx, _mock_plugin, eager_app + ): + """PLATFORM_SERVICE_API_KEY in payload reaches ExecutorToolShim.""" + llm = _mock_llm("answer") + mock_deps.return_value = _mock_prompt_deps(llm) + mock_shim_cls.return_value = MagicMock() + + ctx = _ide_answer_prompt_ctx() + _run_task(eager_app, ctx.to_dict()) + + # Verify shim was constructed with the platform key + mock_shim_cls.assert_called() + call_kwargs = mock_shim_cls.call_args + assert call_kwargs.kwargs.get("platform_api_key") == "pk-ide-test" + + @patch(_PATCH_PLUGIN_LOADER, return_value=None) + @patch(_PATCH_INDEX_UTILS, return_value="doc-id-ide") + @patch(_PATCH_PROMPT_DEPS) + @patch(_PATCH_SHIM) + def test_ide_answer_prompt_webhook_settings( + self, mock_shim_cls, mock_deps, _mock_idx, _mock_plugin, eager_app + ): + """Prompt with webhook settings passes through cleanly.""" + llm = _mock_llm("answer") + mock_deps.return_value = _mock_prompt_deps(llm) + mock_shim_cls.return_value = MagicMock() + + prompt = _make_ide_prompt( + enable_postprocessing_webhook=True, + postprocessing_webhook_url="https://example.com/hook", + ) + ctx = _ide_answer_prompt_ctx(prompts=[prompt]) + result_dict = _run_task(eager_app, ctx.to_dict()) + result = ExecutionResult.from_dict(result_dict) + + assert result.success is True + + +class TestIDESinglePass: + """IDE single_pass_extraction → executor → same shape as answer_prompt.""" + + @patch(_PATCH_PLUGIN_LOADER, return_value=None) + @patch(_PATCH_INDEX_UTILS, return_value="doc-id-ide") + 
@patch(_PATCH_PROMPT_DEPS) + @patch(_PATCH_SHIM) + def test_ide_single_pass_multi_prompt( + self, mock_shim_cls, mock_deps, _mock_idx, _mock_plugin, eager_app + ): + """Single pass with multiple prompts → all fields in output.""" + llm = _mock_llm("single pass value") + mock_deps.return_value = _mock_prompt_deps(llm) + mock_shim_cls.return_value = MagicMock() + + ctx = _ide_single_pass_ctx() + result_dict = _run_task(eager_app, ctx.to_dict()) + result = ExecutionResult.from_dict(result_dict) + + assert result.success is True + assert "output" in result.data + assert "revenue" in result.data["output"] + assert "date" in result.data["output"] + + @patch(_PATCH_PLUGIN_LOADER, return_value=None) + @patch(_PATCH_INDEX_UTILS, return_value="doc-id-ide") + @patch(_PATCH_PROMPT_DEPS) + @patch(_PATCH_SHIM) + def test_ide_single_pass_has_metadata( + self, mock_shim_cls, mock_deps, _mock_idx, _mock_plugin, eager_app + ): + """Single pass returns metadata with run_id.""" + llm = _mock_llm("value") + mock_deps.return_value = _mock_prompt_deps(llm) + mock_shim_cls.return_value = MagicMock() + + ctx = _ide_single_pass_ctx() + result_dict = _run_task(eager_app, ctx.to_dict()) + result = ExecutionResult.from_dict(result_dict) + + assert result.success is True + assert "metadata" in result.data + assert result.data["metadata"]["run_id"] == "run-ide-sp" + + +class TestIDEDispatcherIntegration: + """Test ExecutionDispatcher dispatch() with IDE payloads in eager mode. + + Celery's send_task() doesn't work with eager mode for AsyncResult.get(), + so we patch send_task to delegate to task.apply() instead. 
+ """ + + @staticmethod + def _patch_send_task(eager_app): + """Patch send_task on eager_app to use task.apply().""" + original_send_task = eager_app.send_task + + def patched_send_task(name, args=None, kwargs=None, **opts): + task = eager_app.tasks[name] + return task.apply(args=args, kwargs=kwargs) + + eager_app.send_task = patched_send_task + return original_send_task + + @patch(_PATCH_FS) + @patch(_PATCH_X2TEXT) + def test_dispatcher_extract_round_trip( + self, mock_x2text_cls, mock_get_fs, eager_app + ): + """ExecutionDispatcher.dispatch() → extract → ExecutionResult.""" + mock_x2text = MagicMock() + mock_x2text.process.return_value = _mock_process_response( + "dispatcher extracted" + ) + mock_x2text.x2text_instance = MagicMock() + mock_x2text_cls.return_value = mock_x2text + mock_get_fs.return_value = MagicMock() + + original = self._patch_send_task(eager_app) + try: + dispatcher = ExecutionDispatcher(celery_app=eager_app) + ctx = _ide_extract_ctx() + result = dispatcher.dispatch(ctx) + finally: + eager_app.send_task = original + + assert result.success is True + assert result.data["extracted_text"] == "dispatcher extracted" + + @patch(_PATCH_PLUGIN_LOADER, return_value=None) + @patch(_PATCH_INDEX_UTILS, return_value="doc-id-ide") + @patch(_PATCH_PROMPT_DEPS) + @patch(_PATCH_SHIM) + def test_dispatcher_answer_prompt_round_trip( + self, mock_shim_cls, mock_deps, _mock_idx, _mock_plugin, eager_app + ): + """ExecutionDispatcher.dispatch() → answer_prompt → ExecutionResult.""" + llm = _mock_llm("dispatcher answer") + mock_deps.return_value = _mock_prompt_deps(llm) + mock_shim_cls.return_value = MagicMock() + + original = self._patch_send_task(eager_app) + try: + dispatcher = ExecutionDispatcher(celery_app=eager_app) + ctx = _ide_answer_prompt_ctx() + result = dispatcher.dispatch(ctx) + finally: + eager_app.send_task = original + + assert result.success is True + assert result.data["output"]["invoice_number"] == "dispatcher answer" + assert "metadata" in 
result.data + + @patch(_PATCH_PLUGIN_LOADER, return_value=None) + @patch(_PATCH_INDEX_UTILS, return_value="doc-id-ide") + @patch(_PATCH_PROMPT_DEPS) + @patch(_PATCH_SHIM) + def test_dispatcher_single_pass_round_trip( + self, mock_shim_cls, mock_deps, _mock_idx, _mock_plugin, eager_app + ): + """ExecutionDispatcher.dispatch() → single_pass → ExecutionResult.""" + llm = _mock_llm("sp dispatch") + mock_deps.return_value = _mock_prompt_deps(llm) + mock_shim_cls.return_value = MagicMock() + + original = self._patch_send_task(eager_app) + try: + dispatcher = ExecutionDispatcher(celery_app=eager_app) + ctx = _ide_single_pass_ctx() + result = dispatcher.dispatch(ctx) + finally: + eager_app.send_task = original + + assert result.success is True + assert "revenue" in result.data["output"] + + @patch(_PATCH_FS) + @patch(_PATCH_INDEX_DEPS) + def test_dispatcher_index_round_trip( + self, mock_deps, mock_get_fs, eager_app + ): + """ExecutionDispatcher.dispatch() → index → ExecutionResult.""" + mock_index_cls = MagicMock() + mock_index = MagicMock() + mock_index.generate_index_key.return_value = "doc-dispatch-idx" + mock_index.is_document_indexed.return_value = False + mock_index.perform_indexing.return_value = "doc-dispatch-idx" + mock_index_cls.return_value = mock_index + + mock_deps.return_value = (mock_index_cls, MagicMock(), MagicMock()) + mock_get_fs.return_value = MagicMock() + + original = self._patch_send_task(eager_app) + try: + dispatcher = ExecutionDispatcher(celery_app=eager_app) + ctx = _ide_index_ctx() + result = dispatcher.dispatch(ctx) + finally: + eager_app.send_task = original + + assert result.success is True + assert result.data["doc_id"] == "doc-dispatch-idx" + + +class TestIDEExecutionSourceRouting: + """Verify execution_source='ide' propagates correctly.""" + + @patch(_PATCH_FS) + @patch(_PATCH_X2TEXT) + def test_ide_source_reaches_extract_handler( + self, mock_x2text_cls, mock_get_fs, eager_app + ): + """Extract handler receives execution_source='ide' 
from context.""" + mock_x2text = MagicMock() + mock_x2text.process.return_value = _mock_process_response("text") + mock_x2text.x2text_instance = MagicMock() + mock_x2text_cls.return_value = mock_x2text + mock_fs = MagicMock() + mock_get_fs.return_value = mock_fs + + ctx = _ide_extract_ctx() + assert ctx.execution_source == "ide" + + result_dict = _run_task(eager_app, ctx.to_dict()) + result = ExecutionResult.from_dict(result_dict) + assert result.success is True + + # For IDE source, _update_exec_metadata should NOT write + # (it only writes for execution_source="tool"), so whisper_hash + # metadata is skipped. NOTE(review): that is not asserted here; + # mock_fs is never inspected. Add a not-called check to enforce it. + + @patch(_PATCH_PLUGIN_LOADER, return_value=None) + @patch(_PATCH_INDEX_UTILS, return_value="doc-id-ide") + @patch(_PATCH_PROMPT_DEPS) + @patch(_PATCH_SHIM) + def test_ide_source_in_answer_prompt_enables_variable_replacement( + self, mock_shim_cls, mock_deps, _mock_idx, _mock_plugin, eager_app + ): + """execution_source='ide' in payload sets is_ide=True for variable replacement.""" + llm = _mock_llm("var answer") + deps = _mock_prompt_deps(llm) + # Mock reports no variables present; is_ide routing is not asserted here + var_service = deps[2] # VariableReplacementService + var_service.is_variables_present.return_value = False + mock_deps.return_value = deps + mock_shim_cls.return_value = MagicMock() + + ctx = _ide_answer_prompt_ctx() + # Verify execution_source is in both context and payload + assert ctx.execution_source == "ide" + assert ctx.executor_params["execution_source"] == "ide" + + result_dict = _run_task(eager_app, ctx.to_dict()) + result = ExecutionResult.from_dict(result_dict) + assert result.success is True + + + class TestIDEPayloadKeyCompatibility: + """Verify the exact key names in IDE payloads match executor expectations.""" + + def test_extract_payload_keys_match_executor(self): + """dynamic_extractor payload keys match _handle_extract reads.""" + ctx = _ide_extract_ctx() +
params = ctx.executor_params + + # These are the keys _handle_extract reads from params + assert "x2text_instance_id" in params + assert "file_path" in params + assert "platform_api_key" in params + assert "output_file_path" in params + assert "enable_highlight" in params + assert "usage_kwargs" in params + + def test_index_payload_keys_match_executor(self): + """dynamic_indexer payload keys match _handle_index reads.""" + ctx = _ide_index_ctx() + params = ctx.executor_params + + # These are the keys _handle_index reads from params + assert "embedding_instance_id" in params + assert "vector_db_instance_id" in params + assert "x2text_instance_id" in params + assert "file_path" in params + assert "extracted_text" in params + assert "platform_api_key" in params + assert "chunk_size" in params + assert "chunk_overlap" in params + + def test_answer_prompt_payload_keys_match_executor(self): + """_fetch_response payload keys match _handle_answer_prompt reads.""" + ctx = _ide_answer_prompt_ctx() + params = ctx.executor_params + + # These are the keys _handle_answer_prompt reads + assert "tool_settings" in params + assert "outputs" in params + assert "tool_id" in params + assert "file_hash" in params + assert "file_path" in params + assert "file_name" in params + assert "PLATFORM_SERVICE_API_KEY" in params + assert "log_events_id" in params + assert "execution_source" in params + assert "custom_data" in params + + def test_answer_prompt_platform_key_is_uppercase(self): + """answer_prompt uses PLATFORM_SERVICE_API_KEY (uppercase, not lowercase).""" + ctx = _ide_answer_prompt_ctx() + # _handle_answer_prompt reads PSKeys.PLATFORM_SERVICE_API_KEY + # which is "PLATFORM_SERVICE_API_KEY" + assert "PLATFORM_SERVICE_API_KEY" in ctx.executor_params + # extract/index use "platform_api_key" instead; its absence is not checked + assert ctx.executor_params["PLATFORM_SERVICE_API_KEY"] == "pk-ide-test" + + def test_extract_platform_key_is_lowercase(self): + """extract/index uses platform_api_key (lowercase
snake_case).""" + ctx = _ide_extract_ctx() + assert "platform_api_key" in ctx.executor_params + + def test_execution_context_has_ide_source(self): + """All IDE contexts have execution_source='ide'.""" + assert _ide_extract_ctx().execution_source == "ide" + assert _ide_index_ctx().execution_source == "ide" + assert _ide_answer_prompt_ctx().execution_source == "ide" + assert _ide_single_pass_ctx().execution_source == "ide" diff --git a/workers/tests/test_sanity_phase5.py b/workers/tests/test_sanity_phase5.py new file mode 100644 index 0000000000..a7da15d1fb --- /dev/null +++ b/workers/tests/test_sanity_phase5.py @@ -0,0 +1,852 @@ +"""Phase 5-SANITY — Integration tests for the multi-hop elimination. + +Phase 5 eliminates idle backend worker slots by: + - Adding ``dispatch_with_callback`` (fire-and-forget with link/link_error) + - Adding compound operations: ``ide_index``, ``structure_pipeline`` + - Rewiring structure_tool_task to single ``structure_pipeline`` dispatch + +These tests push payloads through the full Celery eager-mode chain and +verify the results match what callers (views / structure_tool_task) expect. 
+""" + +from unittest.mock import MagicMock, patch + +import pytest + +from executor.executors.constants import ( + PromptServiceConstants as PSKeys, +) +from unstract.sdk1.execution.context import ExecutionContext, Operation +from unstract.sdk1.execution.dispatcher import ExecutionDispatcher +from unstract.sdk1.execution.registry import ExecutorRegistry +from unstract.sdk1.execution.result import ExecutionResult + +# --------------------------------------------------------------------------- +# Patch targets +# --------------------------------------------------------------------------- + +_PATCH_X2TEXT = "executor.executors.legacy_executor.X2Text" +_PATCH_FS = "executor.executors.legacy_executor.FileUtils.get_fs_instance" +_PATCH_INDEX_DEPS = ( + "executor.executors.legacy_executor.LegacyExecutor._get_indexing_deps" +) +_PATCH_PROMPT_DEPS = ( + "executor.executors.legacy_executor.LegacyExecutor._get_prompt_deps" +) +_PATCH_SHIM = "executor.executors.legacy_executor.ExecutorToolShim" +_PATCH_RUN_COMPLETION = ( + "executor.executors.answer_prompt.AnswerPromptService.run_completion" +) +_PATCH_INDEX_UTILS = ( + "unstract.sdk1.utils.indexing.IndexingUtils.generate_index_key" +) + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +def _ensure_legacy_registered(): + """Ensure LegacyExecutor is registered.""" + from executor.executors.legacy_executor import LegacyExecutor + + if "legacy" not in ExecutorRegistry.list_executors(): + ExecutorRegistry._registry["legacy"] = LegacyExecutor + yield + + +@pytest.fixture +def eager_app(): + """Configure executor Celery app for eager-mode testing.""" + from executor.worker import app + + original = { + "task_always_eager": app.conf.task_always_eager, + "task_eager_propagates": app.conf.task_eager_propagates, + "result_backend": app.conf.result_backend, + } + app.conf.update( + 
task_always_eager=True, + task_eager_propagates=False, + result_backend="cache+memory://", + ) + yield app + app.conf.update(original) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _run_task(eager_app, ctx_dict): + """Run execute_extraction task via task.apply() (eager-safe).""" + task = eager_app.tasks["execute_extraction"] + result = task.apply(args=[ctx_dict]) + return result.get() + + +def _mock_llm(answer="pipeline answer"): + """Create a mock LLM matching the answer_prompt pattern.""" + llm = MagicMock(name="llm") + response = MagicMock() + response.text = answer + llm.complete.return_value = { + PSKeys.RESPONSE: response, + PSKeys.HIGHLIGHT_DATA: [], + PSKeys.CONFIDENCE_DATA: None, + PSKeys.WORD_CONFIDENCE_DATA: None, + PSKeys.LINE_NUMBERS: [], + PSKeys.WHISPER_HASH: "", + } + llm.get_usage_reason.return_value = "extraction" + llm.get_metrics.return_value = {"tokens": 42} + return llm + + +def _mock_prompt_deps(llm=None): + """Return 7-tuple matching _get_prompt_deps() shape.""" + if llm is None: + llm = _mock_llm() + + from executor.executors.answer_prompt import AnswerPromptService + + RetrievalService = MagicMock(name="RetrievalService") + RetrievalService.run_retrieval.return_value = ["chunk1"] + RetrievalService.retrieve_complete_context.return_value = ["full doc"] + + VariableReplacementService = MagicMock(name="VariableReplacementService") + VariableReplacementService.is_variables_present.return_value = False + + Index = MagicMock(name="Index") + index_instance = MagicMock() + index_instance.generate_index_key.return_value = "doc-key-1" + Index.return_value = index_instance + + LLM_cls = MagicMock(name="LLM") + LLM_cls.return_value = llm + + EmbeddingCompat = MagicMock(name="EmbeddingCompat") + VectorDB = MagicMock(name="VectorDB") + + return ( + AnswerPromptService, + RetrievalService, + VariableReplacementService, + 
Index, + LLM_cls, + EmbeddingCompat, + VectorDB, + ) + + +def _mock_process_response(text="extracted text"): + """Build a mock TextExtractionResult.""" + from unstract.sdk1.adapters.x2text.dto import ( + TextExtractionMetadata, + TextExtractionResult, + ) + + metadata = TextExtractionMetadata(whisper_hash="test-hash") + return TextExtractionResult( + extracted_text=text, + extraction_metadata=metadata, + ) + + +def _make_output(name="field_a", prompt="What is the revenue?", **overrides): + """Build an output dict for answer_prompt payloads.""" + d = { + PSKeys.NAME: name, + PSKeys.PROMPT: prompt, + PSKeys.TYPE: "text", + "chunk-size": 512, + "chunk-overlap": 64, + "retrieval-strategy": "simple", + "llm": "llm-1", + "embedding": "emb-1", + "vector-db": "vdb-1", + "x2text_adapter": "x2t-1", + "similarity-top-k": 3, + "active": True, + } + d.update(overrides) + return d + + +# --------------------------------------------------------------------------- +# 5A: dispatch_with_callback +# --------------------------------------------------------------------------- + + +class TestDispatchWithCallback: + """Verify dispatch_with_callback passes link/link_error to send_task.""" + + def test_callback_kwargs_passed(self): + mock_app = MagicMock() + mock_app.send_task.return_value = MagicMock(id="task-123") + dispatcher = ExecutionDispatcher(celery_app=mock_app) + + ctx = ExecutionContext( + executor_name="legacy", + operation="answer_prompt", + run_id="run-cb-1", + execution_source="ide", + ) + on_success = MagicMock(name="success_sig") + on_error = MagicMock(name="error_sig") + + result = dispatcher.dispatch_with_callback( + ctx, + on_success=on_success, + on_error=on_error, + task_id="pre-generated-id", + ) + + call_kwargs = mock_app.send_task.call_args + assert call_kwargs.kwargs["link"] is on_success + assert call_kwargs.kwargs["link_error"] is on_error + assert call_kwargs.kwargs["task_id"] == "pre-generated-id" + assert result.id == "task-123" + + def 
test_no_callbacks_omits_link_kwargs(self): + mock_app = MagicMock() + mock_app.send_task.return_value = MagicMock(id="task-456") + dispatcher = ExecutionDispatcher(celery_app=mock_app) + + ctx = ExecutionContext( + executor_name="legacy", + operation="extract", + run_id="run-cb-2", + execution_source="tool", + ) + dispatcher.dispatch_with_callback(ctx) + + call_kwargs = mock_app.send_task.call_args + assert "link" not in call_kwargs.kwargs + assert "link_error" not in call_kwargs.kwargs + + def test_no_app_raises(self): + dispatcher = ExecutionDispatcher(celery_app=None) + ctx = ExecutionContext( + executor_name="legacy", + operation="extract", + run_id="run-cb-3", + execution_source="tool", + ) + with pytest.raises(ValueError, match="No Celery app"): + dispatcher.dispatch_with_callback(ctx) + + +# --------------------------------------------------------------------------- +# 5C: ide_index compound operation through eager chain +# --------------------------------------------------------------------------- + + +class TestIdeIndexEagerChain: + """ide_index: extract + index in a single executor invocation.""" + + @patch(_PATCH_INDEX_DEPS) + @patch(_PATCH_FS) + @patch(_PATCH_X2TEXT) + @patch(_PATCH_SHIM) + def test_ide_index_success( + self, + MockShim, + MockX2Text, + mock_fs, + mock_index_deps, + eager_app, + ): + """Full ide_index through eager chain returns doc_id.""" + # Mock extract + x2t_instance = MagicMock() + x2t_instance.process.return_value = _mock_process_response( + "IDE extracted text" + ) + MockX2Text.return_value = x2t_instance + + fs = MagicMock() + fs.exists.return_value = False + mock_fs.return_value = fs + + # Mock index + index_inst = MagicMock() + index_inst.index.return_value = "idx-doc-1" + index_inst.generate_index_key.return_value = "idx-key-1" + mock_index_deps.return_value = ( + MagicMock(return_value=index_inst), # Index + MagicMock(), # EmbeddingCompat + MagicMock(), # VectorDB + ) + + ctx = ExecutionContext( + executor_name="legacy", + 
operation="ide_index", + run_id="run-ide-idx", + execution_source="ide", + organization_id="org-test", + executor_params={ + "extract_params": { + "x2text_instance_id": "x2t-1", + "file_path": "/data/doc.pdf", + "enable_highlight": False, + "output_file_path": "/data/extract/doc.txt", + "platform_api_key": "pk-test", + "usage_kwargs": {}, + }, + "index_params": { + "tool_id": "tool-1", + "embedding_instance_id": "emb-1", + "vector_db_instance_id": "vdb-1", + "x2text_instance_id": "x2t-1", + "file_path": "/data/extract/doc.txt", + "file_hash": None, + "chunk_overlap": 64, + "chunk_size": 512, + "reindex": True, + "enable_highlight": False, + "usage_kwargs": {}, + "run_id": "run-ide-idx", + "execution_source": "ide", + "platform_api_key": "pk-test", + }, + }, + ) + + result_dict = _run_task(eager_app, ctx.to_dict()) + + result = ExecutionResult.from_dict(result_dict) + assert result.success + assert "doc_id" in result.data + + @patch(_PATCH_FS) + @patch(_PATCH_X2TEXT) + @patch(_PATCH_SHIM) + def test_ide_index_extract_failure( + self, + MockShim, + MockX2Text, + mock_fs, + eager_app, + ): + """ide_index returns failure if extract fails.""" + x2t_instance = MagicMock() + x2t_instance.process.side_effect = Exception("X2Text unavailable") + MockX2Text.return_value = x2t_instance + + fs = MagicMock() + fs.exists.return_value = False + mock_fs.return_value = fs + + ctx = ExecutionContext( + executor_name="legacy", + operation="ide_index", + run_id="run-ide-fail", + execution_source="ide", + executor_params={ + "extract_params": { + "x2text_instance_id": "x2t-1", + "file_path": "/data/doc.pdf", + "enable_highlight": False, + "platform_api_key": "pk-test", + "usage_kwargs": {}, + }, + "index_params": { + "tool_id": "tool-1", + "embedding_instance_id": "emb-1", + "vector_db_instance_id": "vdb-1", + "x2text_instance_id": "x2t-1", + "file_path": "/data/extract/doc.txt", + "file_hash": None, + "chunk_overlap": 64, + "chunk_size": 512, + "reindex": True, + "enable_highlight": 
False, + "usage_kwargs": {}, + "run_id": "run-ide-fail", + "execution_source": "ide", + "platform_api_key": "pk-test", + }, + }, + ) + + result_dict = _run_task(eager_app, ctx.to_dict()) + result = ExecutionResult.from_dict(result_dict) + assert not result.success + assert "X2Text" in result.error + + +# --------------------------------------------------------------------------- +# 5D: structure_pipeline compound operation through eager chain +# --------------------------------------------------------------------------- + + +class TestStructurePipelineEagerChain: + """structure_pipeline: full extract→index→answer through eager chain.""" + + @patch(_PATCH_INDEX_UTILS, return_value="doc-id-pipeline") + @patch(_PATCH_PROMPT_DEPS) + @patch(_PATCH_INDEX_DEPS) + @patch(_PATCH_FS) + @patch(_PATCH_X2TEXT) + @patch(_PATCH_SHIM) + def test_structure_pipeline_normal( + self, + MockShim, + MockX2Text, + mock_fs, + mock_index_deps, + mock_prompt_deps, + _mock_idx_utils, + eager_app, + ): + """Normal pipeline: extract → index → answer_prompt.""" + # Mock extract + x2t_instance = MagicMock() + x2t_instance.process.return_value = _mock_process_response("Revenue is $1M") + MockX2Text.return_value = x2t_instance + + fs = MagicMock() + fs.exists.return_value = False + mock_fs.return_value = fs + + # Mock index + index_inst = MagicMock() + index_inst.index.return_value = "idx-doc-1" + index_inst.generate_index_key.return_value = "idx-key-1" + mock_index_deps.return_value = ( + MagicMock(return_value=index_inst), + MagicMock(), + MagicMock(), + ) + + # Mock prompt deps + mock_prompt_deps.return_value = _mock_prompt_deps() + + ctx = ExecutionContext( + executor_name="legacy", + operation="structure_pipeline", + run_id="run-sp-1", + execution_source="tool", + organization_id="org-test", + executor_params={ + "extract_params": { + "x2text_instance_id": "x2t-1", + "file_path": "/data/test.pdf", + "enable_highlight": False, + "output_file_path": "/data/exec/EXTRACT", + "platform_api_key": 
"pk-test", + "usage_kwargs": {}, + }, + "index_template": { + "tool_id": "tool-1", + "file_hash": "hash123", + "is_highlight_enabled": False, + "platform_api_key": "pk-test", + "extracted_file_path": "/data/exec/EXTRACT", + }, + "answer_params": { + "run_id": "run-sp-1", + "execution_id": "exec-1", + "tool_settings": { + "vector-db": "vdb-1", + "embedding": "emb-1", + "x2text_adapter": "x2t-1", + "llm": "llm-1", + "enable_challenge": False, + "challenge_llm": "", + "enable_single_pass_extraction": False, + "summarize_as_source": False, + "enable_highlight": False, + }, + "outputs": [_make_output()], + "tool_id": "tool-1", + "file_hash": "hash123", + "file_name": "test.pdf", + "file_path": "/data/exec/EXTRACT", + "execution_source": "tool", + "PLATFORM_SERVICE_API_KEY": "pk-test", + }, + "pipeline_options": { + "skip_extraction_and_indexing": False, + "is_summarization_enabled": False, + "is_single_pass_enabled": False, + "input_file_path": "/data/test.pdf", + "source_file_name": "test.pdf", + }, + "summarize_params": None, + }, + ) + + result_dict = _run_task(eager_app, ctx.to_dict()) + + result = ExecutionResult.from_dict(result_dict) + assert result.success + assert "output" in result.data + assert "metadata" in result.data + # source_file_name injected into metadata + assert result.data["metadata"]["file_name"] == "test.pdf" + + @patch(_PATCH_INDEX_UTILS, return_value="doc-id-sp") + @patch(_PATCH_PROMPT_DEPS) + @patch(_PATCH_FS) + @patch(_PATCH_X2TEXT) + @patch(_PATCH_SHIM) + def test_structure_pipeline_single_pass( + self, + MockShim, + MockX2Text, + mock_fs, + mock_prompt_deps, + _mock_idx_utils, + eager_app, + ): + """Single pass: extract → single_pass_extraction (no index).""" + x2t_instance = MagicMock() + x2t_instance.process.return_value = _mock_process_response("Revenue data") + MockX2Text.return_value = x2t_instance + + fs = MagicMock() + fs.exists.return_value = False + mock_fs.return_value = fs + + mock_prompt_deps.return_value = _mock_prompt_deps() + 
+ ctx = ExecutionContext( + executor_name="legacy", + operation="structure_pipeline", + run_id="run-sp-sp", + execution_source="tool", + executor_params={ + "extract_params": { + "x2text_instance_id": "x2t-1", + "file_path": "/data/test.pdf", + "enable_highlight": False, + "output_file_path": "/data/exec/EXTRACT", + "platform_api_key": "pk-test", + "usage_kwargs": {}, + }, + "index_template": {}, + "answer_params": { + "run_id": "run-sp-sp", + "tool_settings": { + "vector-db": "vdb-1", + "embedding": "emb-1", + "x2text_adapter": "x2t-1", + "llm": "llm-1", + "enable_challenge": False, + "challenge_llm": "", + "enable_single_pass_extraction": True, + "summarize_as_source": False, + "enable_highlight": False, + }, + "outputs": [_make_output()], + "tool_id": "tool-1", + "file_hash": "hash123", + "file_name": "test.pdf", + "file_path": "/data/exec/EXTRACT", + "execution_source": "tool", + "PLATFORM_SERVICE_API_KEY": "pk-test", + }, + "pipeline_options": { + "skip_extraction_and_indexing": False, + "is_summarization_enabled": False, + "is_single_pass_enabled": True, + "input_file_path": "/data/test.pdf", + "source_file_name": "test.pdf", + }, + "summarize_params": None, + }, + ) + + result_dict = _run_task(eager_app, ctx.to_dict()) + + result = ExecutionResult.from_dict(result_dict) + assert result.success + assert "output" in result.data + + @patch(_PATCH_INDEX_UTILS, return_value="doc-id-skip") + @patch(_PATCH_PROMPT_DEPS) + @patch(_PATCH_FS) + @patch(_PATCH_X2TEXT) + @patch(_PATCH_SHIM) + def test_structure_pipeline_skip_extraction( + self, + MockShim, + MockX2Text, + mock_fs, + mock_prompt_deps, + _mock_idx_utils, + eager_app, + ): + """Smart table: skip extraction, go straight to answer_prompt.""" + fs = MagicMock() + fs.exists.return_value = False + mock_fs.return_value = fs + + mock_prompt_deps.return_value = _mock_prompt_deps() + + ctx = ExecutionContext( + executor_name="legacy", + operation="structure_pipeline", + run_id="run-sp-skip", + 
execution_source="tool", + executor_params={ + "extract_params": {}, + "index_template": {}, + "answer_params": { + "run_id": "run-sp-skip", + "tool_settings": { + "vector-db": "vdb-1", + "embedding": "emb-1", + "x2text_adapter": "x2t-1", + "llm": "llm-1", + "enable_challenge": False, + "challenge_llm": "", + "enable_single_pass_extraction": False, + "summarize_as_source": False, + "enable_highlight": False, + }, + "outputs": [_make_output(prompt='{"key": "value"}')], + "tool_id": "tool-1", + "file_hash": "hash123", + "file_name": "test.xlsx", + "file_path": "/data/test.xlsx", + "execution_source": "tool", + "PLATFORM_SERVICE_API_KEY": "pk-test", + }, + "pipeline_options": { + "skip_extraction_and_indexing": True, + "is_summarization_enabled": False, + "is_single_pass_enabled": False, + "input_file_path": "/data/test.xlsx", + "source_file_name": "test.xlsx", + }, + "summarize_params": None, + }, + ) + + result_dict = _run_task(eager_app, ctx.to_dict()) + + result = ExecutionResult.from_dict(result_dict) + assert result.success + # No extract was called (X2Text not mocked beyond fixture) + MockX2Text.assert_not_called() + + @patch(_PATCH_FS) + @patch(_PATCH_X2TEXT) + @patch(_PATCH_SHIM) + def test_structure_pipeline_extract_failure( + self, + MockShim, + MockX2Text, + mock_fs, + eager_app, + ): + """Pipeline extract failure propagated as result failure.""" + x2t_instance = MagicMock() + x2t_instance.process.side_effect = Exception("X2Text timeout") + MockX2Text.return_value = x2t_instance + + fs = MagicMock() + fs.exists.return_value = False + mock_fs.return_value = fs + + ctx = ExecutionContext( + executor_name="legacy", + operation="structure_pipeline", + run_id="run-sp-fail", + execution_source="tool", + executor_params={ + "extract_params": { + "x2text_instance_id": "x2t-1", + "file_path": "/data/test.pdf", + "enable_highlight": False, + "platform_api_key": "pk-test", + "usage_kwargs": {}, + }, + "index_template": {}, + "answer_params": {}, + "pipeline_options": 
{ + "skip_extraction_and_indexing": False, + "is_summarization_enabled": False, + "is_single_pass_enabled": False, + "input_file_path": "/data/test.pdf", + "source_file_name": "test.pdf", + }, + "summarize_params": None, + }, + ) + + result_dict = _run_task(eager_app, ctx.to_dict()) + + result = ExecutionResult.from_dict(result_dict) + assert not result.success + assert "X2Text" in result.error + + +# --------------------------------------------------------------------------- +# 5E: structure_tool_task single dispatch verification +# --------------------------------------------------------------------------- + + +class TestStructureToolSingleDispatch: + """Verify structure_tool_task dispatches exactly once.""" + + @patch( + "executor.executor_tool_shim.ExecutorToolShim" + ) + @patch( + "file_processing.structure_tool_task._get_file_storage" + ) + @patch( + "file_processing.structure_tool_task._create_platform_helper" + ) + @patch( + "file_processing.structure_tool_task.ExecutionDispatcher" + ) + def test_single_dispatch_normal( + self, + MockDispatcher, + mock_create_ph, + mock_get_fs, + MockShim, + ): + """Normal path sends single structure_pipeline dispatch.""" + from file_processing.structure_tool_task import ( + _execute_structure_tool_impl, + ) + + fs = MagicMock() + fs.exists.return_value = False + mock_get_fs.return_value = fs + + ph = MagicMock() + ph.get_prompt_studio_tool.return_value = { + "tool_metadata": { + "name": "Test", + "is_agentic": False, + "tool_id": "t1", + "tool_settings": { + "vector-db": "v1", + "embedding": "e1", + "x2text_adapter": "x1", + "llm": "l1", + }, + "outputs": [ + { + "name": "f1", + "prompt": "What?", + "type": "text", + "active": True, + "chunk-size": 512, + "chunk-overlap": 64, + "llm": "l1", + "embedding": "e1", + "vector-db": "v1", + "x2text_adapter": "x1", + }, + ], + }, + } + mock_create_ph.return_value = ph + + dispatcher = MagicMock() + MockDispatcher.return_value = dispatcher + dispatcher.dispatch.return_value = 
ExecutionResult( + success=True, + data={"output": {"f1": "ans"}, "metadata": {}, "metrics": {}}, + ) + + params = { + "organization_id": "org-1", + "workflow_id": "wf-1", + "execution_id": "ex-1", + "file_execution_id": "fex-1", + "tool_instance_metadata": {"prompt_registry_id": "pr-1"}, + "platform_service_api_key": "pk-1", + "input_file_path": "/data/test.pdf", + "output_dir_path": "/output", + "source_file_name": "test.pdf", + "execution_data_dir": "/data/exec", + "file_hash": "h1", + "exec_metadata": {}, + } + + result = _execute_structure_tool_impl(params) + + assert result["success"] is True + assert dispatcher.dispatch.call_count == 1 + ctx = dispatcher.dispatch.call_args[0][0] + assert ctx.operation == "structure_pipeline" + assert "extract_params" in ctx.executor_params + assert "index_template" in ctx.executor_params + assert "answer_params" in ctx.executor_params + assert "pipeline_options" in ctx.executor_params + + +# --------------------------------------------------------------------------- +# Operation enum completeness +# --------------------------------------------------------------------------- + + +class TestOperationEnum: + """Verify Phase 5 operations registered in enum.""" + + def test_ide_index_operation(self): + assert hasattr(Operation, "IDE_INDEX") + assert Operation.IDE_INDEX.value == "ide_index" + + def test_structure_pipeline_operation(self): + assert hasattr(Operation, "STRUCTURE_PIPELINE") + assert Operation.STRUCTURE_PIPELINE.value == "structure_pipeline" + + +# --------------------------------------------------------------------------- +# Dispatcher modes +# --------------------------------------------------------------------------- + + +class TestDispatcherModes: + """Verify all three dispatch modes work.""" + + def test_dispatch_sync(self): + """dispatch() calls send_task and .get().""" + mock_app = MagicMock() + async_result = MagicMock() + async_result.get.return_value = ExecutionResult( + success=True, data={"test": 1} + 
).to_dict() + mock_app.send_task.return_value = async_result + + dispatcher = ExecutionDispatcher(celery_app=mock_app) + ctx = ExecutionContext( + executor_name="legacy", + operation="extract", + run_id="r1", + execution_source="tool", + ) + result = dispatcher.dispatch(ctx, timeout=10) + + assert result.success + mock_app.send_task.assert_called_once() + async_result.get.assert_called_once() + + def test_dispatch_async(self): + """dispatch_async() returns task_id without blocking.""" + mock_app = MagicMock() + mock_app.send_task.return_value = MagicMock(id="async-id") + + dispatcher = ExecutionDispatcher(celery_app=mock_app) + ctx = ExecutionContext( + executor_name="legacy", + operation="extract", + run_id="r2", + execution_source="tool", + ) + task_id = dispatcher.dispatch_async(ctx) + + assert task_id == "async-id" + mock_app.send_task.assert_called_once() diff --git a/workers/tests/test_sanity_phase6a.py b/workers/tests/test_sanity_phase6a.py new file mode 100644 index 0000000000..d35833fc2c --- /dev/null +++ b/workers/tests/test_sanity_phase6a.py @@ -0,0 +1,310 @@ +"""Phase 6A Sanity — Plugin loader infrastructure + queue-per-executor routing. + +Verifies: +1. ExecutorPluginLoader.get() returns None when no plugins installed +2. ExecutorPluginLoader.discover_executors() returns empty when no cloud executors +3. ExecutorPluginLoader.clear() resets cached state +4. ExecutorPluginLoader.get() discovers entry-point-based plugins (mocked) +5. ExecutorPluginLoader.discover_executors() loads cloud executors (mocked) +6. text_processor.add_hex_line_numbers() +7. ExecutionDispatcher._get_queue() naming convention +8. Protocol classes importable and runtime-checkable +9. 
executors/__init__.py triggers discover_executors() +""" + +from unittest.mock import MagicMock, patch + +import pytest +from executor.executors.plugins.loader import ExecutorPluginLoader +from executor.executors.plugins.text_processor import add_hex_line_numbers +from unstract.sdk1.execution.dispatcher import ExecutionDispatcher + + +@pytest.fixture(autouse=True) +def _reset_plugin_loader(): + """Ensure clean plugin loader state for every test.""" + ExecutorPluginLoader.clear() + yield + ExecutorPluginLoader.clear() + + +# ── 1. Plugin loader: no plugins installed ────────────────────────── + + +class TestPluginLoaderNoPlugins: + """When no cloud plugins are installed, loader returns None / empty. + + Mocks entry_points to simulate a clean OSS environment where + no cloud executor plugins are pip-installed. + """ + + @patch( + "importlib.metadata.entry_points", + return_value=[], + ) + def test_get_returns_none_for_unknown_plugin(self, _mock_eps): + result = ExecutorPluginLoader.get("nonexistent-plugin") + assert result is None + + @patch( + "importlib.metadata.entry_points", + return_value=[], + ) + def test_get_returns_none_for_highlight_data(self, _mock_eps): + """highlight-data is a cloud plugin, not installed in OSS.""" + result = ExecutorPluginLoader.get("highlight-data") + assert result is None + + @patch( + "importlib.metadata.entry_points", + return_value=[], + ) + def test_get_returns_none_for_challenge(self, _mock_eps): + result = ExecutorPluginLoader.get("challenge") + assert result is None + + @patch( + "importlib.metadata.entry_points", + return_value=[], + ) + def test_get_returns_none_for_evaluation(self, _mock_eps): + result = ExecutorPluginLoader.get("evaluation") + assert result is None + + @patch( + "importlib.metadata.entry_points", + return_value=[], + ) + def test_discover_executors_returns_empty(self, _mock_eps): + discovered = ExecutorPluginLoader.discover_executors() + assert discovered == [] + + +# ── 2. 
# Plugin loader: clear resets cached state ─────────────────────


class TestPluginLoaderClear:
    """clear() drops the cached plugin mapping so the next get() rediscovers."""

    _ENTRY_POINTS = "importlib.metadata.entry_points"

    def test_clear_resets_plugins(self):
        """The cache filled by get() is reset to None by clear()."""
        with patch(self._ENTRY_POINTS, return_value=[]):
            # A lookup forces discovery, which leaves an (empty) cache behind.
            ExecutorPluginLoader.get("anything")
            assert ExecutorPluginLoader._plugins is not None

            ExecutorPluginLoader.clear()
            assert ExecutorPluginLoader._plugins is None

    def test_get_after_clear_re_discovers(self):
        """After clear(), next get() re-runs discovery."""
        with patch(self._ENTRY_POINTS, return_value=[]):
            ExecutorPluginLoader.get("x")
            assert ExecutorPluginLoader._plugins == {}

            ExecutorPluginLoader.clear()
            assert ExecutorPluginLoader._plugins is None

            # One more lookup has to repopulate the cache from scratch.
            ExecutorPluginLoader.get("y")
            assert ExecutorPluginLoader._plugins is not None


# ── 3. Plugin loader with mocked entry points ──────────────────────


class TestPluginLoaderWithMockedEntryPoints:
    """Simulate cloud plugins being installed by mocking entry_points()."""

    @staticmethod
    def _entry_point(ep_name, *, loads=None, raises=None):
        """Build a fake entry point called *ep_name*.

        ``load()`` returns *loads*, unless *raises* is given, in which case
        ``load()`` raises that exception instead.
        """
        ep = MagicMock()
        # ``name`` is consumed by the MagicMock constructor, so assign it
        # as an attribute afterwards.
        ep.name = ep_name
        if raises is not None:
            ep.load.side_effect = raises
        else:
            ep.load.return_value = loads
        return ep

    @staticmethod
    def _installed(*entry_points):
        """Patch entry_points() so it reports exactly *entry_points*."""
        return patch(
            "importlib.metadata.entry_points",
            return_value=list(entry_points),
        )

    def test_get_discovers_plugin_from_entry_point(self):
        """Mocked highlight-data entry point is loaded and cached."""

        class FakeHighlightData:
            pass

        ep = self._entry_point("highlight-data", loads=FakeHighlightData)

        with self._installed(ep):
            loaded = ExecutorPluginLoader.get("highlight-data")

        assert loaded is FakeHighlightData
        ep.load.assert_called_once()

    def test_get_caches_after_first_call(self):
        """Entry points are only queried once; subsequent calls use cache."""
        ep = self._entry_point("challenge", loads=type("FakeChallenge", (), {}))

        with self._installed(ep) as mock_eps:
            # Two lookups of the same plugin: only the first may hit
            # entry_points().
            for _ in range(2):
                ExecutorPluginLoader.get("challenge")

        mock_eps.assert_called_once()

    def test_failed_plugin_load_is_skipped(self):
        """If a plugin fails to load, it's skipped without raising."""
        broken = self._entry_point("bad-plugin", raises=ImportError("missing dep"))
        working = self._entry_point("good-plugin", loads=type("Good", (), {}))

        with self._installed(broken, working):
            assert ExecutorPluginLoader.get("good-plugin") is not None
            assert ExecutorPluginLoader.get("bad-plugin") is None

    def test_discover_executors_loads_classes(self):
        """Mocked cloud executor entry points are imported."""

        class FakeTableExecutor:
            pass

        ep = self._entry_point("table", loads=FakeTableExecutor)

        with self._installed(ep):
            names = ExecutorPluginLoader.discover_executors()

        assert names == ["table"]
        ep.load.assert_called_once()

    def test_discover_executors_skips_failures(self):
        """Failed executor loads are skipped, successful ones returned."""
        broken = self._entry_point("broken", raises=ImportError("nope"))
        working = self._entry_point(
            "smart_table", loads=type("FakeSmartTable", (), {})
        )

        with self._installed(broken, working):
            names = ExecutorPluginLoader.discover_executors()

        assert names == ["smart_table"]


# ── 4.
# text_processor ───────────────────────────────────────────────


class TestTextProcessor:
    """add_hex_line_numbers() prefixes each line with a hex index."""

    def test_single_line(self):
        """A single line is numbered 0x0."""
        assert add_hex_line_numbers("hello") == "0x0: hello"

    def test_multiple_lines(self):
        """Every newline-separated line gets its own index."""
        numbered = add_hex_line_numbers("a\nb\nc")
        assert numbered == "0x0: a\n0x1: b\n0x2: c"

    def test_empty_string(self):
        """Empty input still yields one (empty) numbered line."""
        assert add_hex_line_numbers("") == "0x0: "

    def test_hex_width_grows(self):
        """Indices are zero-padded once the count needs two hex digits."""
        # 17 lines → hex needs 2 digits (0x10 = 16)
        source = "\n".join([f"line{i}" for i in range(17)])
        numbered_lines = add_hex_line_numbers(source).split("\n")
        first, last = numbered_lines[0], numbered_lines[16]
        assert first.startswith("0x00: ")
        assert last.startswith("0x10: ")


# ── 5. Queue-per-executor routing ───────────────────────────────────


class TestQueuePerExecutor:
    """ExecutionDispatcher._get_queue() maps executor names to queue names."""

    def test_get_queue_legacy(self):
        queue = ExecutionDispatcher._get_queue("legacy")
        assert queue == "celery_executor_legacy"

    def test_get_queue_table(self):
        queue = ExecutionDispatcher._get_queue("table")
        assert queue == "celery_executor_table"

    def test_get_queue_smart_table(self):
        queue = ExecutionDispatcher._get_queue("smart_table")
        assert queue == "celery_executor_smart_table"

    def test_get_queue_simple_prompt_studio(self):
        queue = ExecutionDispatcher._get_queue("simple_prompt_studio")
        assert queue == "celery_executor_simple_prompt_studio"

    def test_get_queue_agentic(self):
        queue = ExecutionDispatcher._get_queue("agentic")
        assert queue == "celery_executor_agentic"

    def test_get_queue_arbitrary_name(self):
        """Any executor_name works — no whitelist."""
        queue = ExecutionDispatcher._get_queue("my_custom")
        assert queue == "celery_executor_my_custom"

    def test_queue_name_enum_matches_dispatcher(self):
        """QueueName.EXECUTOR matches what dispatcher generates for 'legacy'."""
        from shared.enums.worker_enums import QueueName

        generated = ExecutionDispatcher._get_queue("legacy")
        assert QueueName.EXECUTOR.value == generated


# ── 6.
Protocol classes importable ────────────────────────────────── + + +class TestProtocols: + def test_highlight_data_protocol_importable(self): + from executor.executors.plugins.protocols import HighlightDataProtocol + + assert HighlightDataProtocol is not None + + def test_challenge_protocol_importable(self): + from executor.executors.plugins.protocols import ChallengeProtocol + + assert ChallengeProtocol is not None + + def test_evaluation_protocol_importable(self): + from executor.executors.plugins.protocols import EvaluationProtocol + + assert EvaluationProtocol is not None + + def test_runtime_checkable(self): + """Protocols are @runtime_checkable — isinstance checks work.""" + from executor.executors.plugins.protocols import ChallengeProtocol + + class FakeChallenge: + def run(self): + pass + + assert isinstance(FakeChallenge(), ChallengeProtocol) + + +# ── 7. executors/__init__.py triggers discovery ───────────────────── + + +class TestExecutorsInit: + def test_cloud_executors_list_exists(self): + """executors.__init__ populates _cloud_executors (empty in OSS).""" + import executor.executors as mod + + assert hasattr(mod, "_cloud_executors") + # In pure OSS, no cloud executors are installed + assert isinstance(mod._cloud_executors, list) diff --git a/workers/tests/test_sanity_phase6c.py b/workers/tests/test_sanity_phase6c.py new file mode 100644 index 0000000000..54388f6fee --- /dev/null +++ b/workers/tests/test_sanity_phase6c.py @@ -0,0 +1,559 @@ +"""Phase 6C Sanity — Highlight data as cross-cutting plugin. + +Verifies: +1. run_completion() passes process_text to llm.complete() +2. run_completion() with process_text=None (default) works as before +3. construct_and_run_prompt() passes process_text through to run_completion() +4. _handle_answer_prompt() initializes highlight plugin when enabled + available +5. _handle_answer_prompt() skips highlight when plugin not installed +6. _handle_answer_prompt() skips highlight when enable_highlight=False +7. 
Highlight metadata populated when plugin provides data via process_text +""" + +from unittest.mock import MagicMock, patch + +import pytest +from executor.executors.answer_prompt import AnswerPromptService +from executor.executors.constants import PromptServiceConstants as PSKeys + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +@pytest.fixture() +def mock_llm(): + """Create a mock LLM that returns a realistic completion dict.""" + llm = MagicMock() + llm.complete.return_value = { + PSKeys.RESPONSE: MagicMock(text="42"), + PSKeys.HIGHLIGHT_DATA: [{"line": 1}], + PSKeys.CONFIDENCE_DATA: {"score": 0.95}, + PSKeys.WORD_CONFIDENCE_DATA: {"words": []}, + PSKeys.LINE_NUMBERS: [1, 2], + PSKeys.WHISPER_HASH: "abc123", + } + return llm + + +@pytest.fixture() +def mock_llm_no_highlight(): + """Create a mock LLM that returns completion without highlight data.""" + llm = MagicMock() + llm.complete.return_value = { + PSKeys.RESPONSE: MagicMock(text="answer"), + PSKeys.HIGHLIGHT_DATA: [], + PSKeys.LINE_NUMBERS: [], + PSKeys.WHISPER_HASH: "", + } + return llm + + +# --------------------------------------------------------------------------- +# 1. 
# run_completion() passes process_text to llm.complete()
# ---------------------------------------------------------------------------


class TestRunCompletionProcessText:
    """run_completion() hands its process_text argument to llm.complete()."""

    def test_process_text_passed_to_llm_complete(self, mock_llm):
        """process_text callback is forwarded to llm.complete()."""
        callback = MagicMock(name="highlight_run")
        AnswerPromptService.run_completion(
            llm=mock_llm,
            prompt="test prompt",
            process_text=callback,
        )
        mock_llm.complete.assert_called_once()
        # call_args[1] and call_args.kwargs are the same mapping, so a single
        # keyword lookup covers everything the previous two-way check did.
        assert mock_llm.complete.call_args.kwargs.get("process_text") is callback

    def test_process_text_none_by_default(self, mock_llm):
        """Without process_text, llm.complete() sees it as None or not at all."""
        AnswerPromptService.run_completion(
            llm=mock_llm,
            prompt="test prompt",
        )
        # .get() yields None both when the keyword is absent and when it is
        # passed explicitly as None; either outcome is accepted here.
        assert mock_llm.complete.call_args.kwargs.get("process_text") is None

    def test_process_text_none_explicit(self, mock_llm):
        """Explicit process_text=None works as before."""
        answer = AnswerPromptService.run_completion(
            llm=mock_llm,
            prompt="test prompt",
            process_text=None,
        )
        assert answer == "42"


# ---------------------------------------------------------------------------
# 2.
# run_completion() populates metadata from completion dict
# ---------------------------------------------------------------------------


class TestRunCompletionMetadata:
    """run_completion() copies highlight-related fields into metadata."""

    def test_highlight_metadata_populated_with_process_text(self, mock_llm):
        """When process_text is provided and LLM returns highlight data,
        metadata is populated correctly."""
        collected: dict = {}
        AnswerPromptService.run_completion(
            llm=mock_llm,
            prompt="test",
            metadata=collected,
            prompt_key="field1",
            enable_highlight=True,
            enable_word_confidence=True,
            process_text=MagicMock(name="highlight_run"),
        )

        # Per-prompt entries are keyed by the prompt_key passed above.
        expected_per_field = (
            (PSKeys.HIGHLIGHT_DATA, [{"line": 1}]),
            (PSKeys.CONFIDENCE_DATA, {"score": 0.95}),
            (PSKeys.WORD_CONFIDENCE_DATA, {"words": []}),
            (PSKeys.LINE_NUMBERS, [1, 2]),
        )
        for section, expected in expected_per_field:
            assert collected[section]["field1"] == expected
        # The whisper hash is stored once, not per prompt.
        assert collected[PSKeys.WHISPER_HASH] == "abc123"

    def test_highlight_metadata_empty_without_process_text(
        self, mock_llm_no_highlight
    ):
        """Without process_text, highlight data is empty but no error."""
        collected: dict = {}
        AnswerPromptService.run_completion(
            llm=mock_llm_no_highlight,
            prompt="test",
            metadata=collected,
            prompt_key="field1",
            enable_highlight=True,
            process_text=None,
        )

        for section in (PSKeys.HIGHLIGHT_DATA, PSKeys.LINE_NUMBERS):
            assert collected[section]["field1"] == []


# ---------------------------------------------------------------------------
# 3.
# construct_and_run_prompt() passes process_text through
# ---------------------------------------------------------------------------


class TestConstructAndRunPromptProcessText:
    """construct_and_run_prompt() relays process_text down to llm.complete()."""

    def test_process_text_forwarded(self, mock_llm):
        """construct_and_run_prompt passes process_text to run_completion."""
        callback = MagicMock(name="highlight_run")
        tool_settings = {
            PSKeys.PREAMBLE: "",
            PSKeys.POSTAMBLE: "",
            PSKeys.GRAMMAR: [],
            PSKeys.ENABLE_HIGHLIGHT: True,
        }
        output = {
            PSKeys.NAME: "field1",
            PSKeys.PROMPT: "What is the value?",
            PSKeys.PROMPTX: "What is the value?",
            PSKeys.TYPE: PSKeys.TEXT,
        }
        answer = AnswerPromptService.construct_and_run_prompt(
            tool_settings=tool_settings,
            output=output,
            llm=mock_llm,
            context="some context",
            prompt=PSKeys.PROMPTX,
            metadata={},
            process_text=callback,
        )
        # call_args[1] is the same mapping as call_args.kwargs, so one keyword
        # lookup is all the earlier fallback ever inspected.
        assert mock_llm.complete.call_args.kwargs.get("process_text") is callback
        assert answer == "42"

    def test_process_text_none_default(self, mock_llm):
        """construct_and_run_prompt defaults process_text to None."""
        tool_settings = {
            PSKeys.PREAMBLE: "",
            PSKeys.POSTAMBLE: "",
            PSKeys.GRAMMAR: [],
        }
        output = {
            PSKeys.NAME: "field1",
            PSKeys.PROMPT: "What?",
            PSKeys.PROMPTX: "What?",
            PSKeys.TYPE: PSKeys.TEXT,
        }
        AnswerPromptService.construct_and_run_prompt(
            tool_settings=tool_settings,
            output=output,
            llm=mock_llm,
            context="ctx",
            prompt=PSKeys.PROMPTX,
            metadata={},
        )
        # Passes when the keyword is absent or explicitly None, exactly as the
        # previous kwargs-then-kwargs fallback did.
        assert mock_llm.complete.call_args.kwargs.get("process_text") is None


# ---------------------------------------------------------------------------
# 4.
# _handle_answer_prompt() initializes highlight plugin
# ---------------------------------------------------------------------------


class TestHandleAnswerPromptHighlight:
    """Test highlight plugin integration in LegacyExecutor._handle_answer_prompt."""

    # Patch targets shared by every test in this class.
    _SHIM = "executor.executors.legacy_executor.ExecutorToolShim"
    _INDEX_KEY = "unstract.sdk1.utils.indexing.IndexingUtils.generate_index_key"
    _PLUGIN_GET = "executor.executors.plugins.loader.ExecutorPluginLoader.get"
    _FS_INSTANCE = "executor.executors.file_utils.FileUtils.get_fs_instance"

    def _make_context(self, enable_highlight=False):
        """Build a minimal ExecutionContext for answer_prompt."""
        from unstract.sdk1.execution.context import ExecutionContext

        tool_settings = {
            PSKeys.PREAMBLE: "",
            PSKeys.POSTAMBLE: "",
            PSKeys.GRAMMAR: [],
            PSKeys.ENABLE_HIGHLIGHT: enable_highlight,
        }
        single_prompt = {
            PSKeys.NAME: "field1",
            PSKeys.PROMPT: "What is X?",
            PSKeys.PROMPTX: "What is X?",
            PSKeys.TYPE: PSKeys.TEXT,
            PSKeys.CHUNK_SIZE: 0,
            PSKeys.CHUNK_OVERLAP: 0,
            PSKeys.LLM: "llm-123",
            PSKeys.EMBEDDING: "emb-123",
            PSKeys.VECTOR_DB: "vdb-123",
            PSKeys.X2TEXT_ADAPTER: "x2t-123",
            PSKeys.RETRIEVAL_STRATEGY: "simple",
        }
        params = {
            PSKeys.TOOL_SETTINGS: tool_settings,
            PSKeys.OUTPUTS: [single_prompt],
            PSKeys.TOOL_ID: "tool-1",
            PSKeys.FILE_HASH: "hash123",
            PSKeys.FILE_PATH: "/data/doc.txt",
            PSKeys.FILE_NAME: "doc.txt",
            PSKeys.PLATFORM_SERVICE_API_KEY: "key-123",
        }
        return ExecutionContext(
            executor_name="legacy",
            operation="answer_prompt",
            run_id="run-001",
            execution_source="ide",
            organization_id="org-1",
            executor_params=params,
        )

    def _get_executor(self):
        """Clear the executor registry and return its "legacy" executor."""
        from executor.executors.legacy_executor import LegacyExecutor
        from unstract.sdk1.execution.registry import ExecutorRegistry

        ExecutorRegistry.clear()
        registered = ExecutorRegistry.list_executors()
        if "legacy" not in registered:
            ExecutorRegistry.register(LegacyExecutor)
        return ExecutorRegistry.get("legacy")

    @staticmethod
    def _make_llm(completion):
        """Return a mock LLM whose complete() returns *completion*."""
        llm = MagicMock()
        llm.complete.return_value = completion
        llm.get_usage_reason.return_value = "extraction"
        llm.get_metrics.return_value = {}
        return llm

    @staticmethod
    def _prompt_deps(llm_cls, chunks):
        """Build the 7-tuple returned by the patched _get_prompt_deps().

        The second slot retrieves *chunks* as context and the third reports
        that no variables are present; *llm_cls* goes in the fifth slot.
        """
        retrieval = MagicMock(
            retrieve_complete_context=MagicMock(return_value=chunks)
        )
        variables = MagicMock(is_variables_present=MagicMock(return_value=False))
        return (
            AnswerPromptService,
            retrieval,
            variables,
            None,  # Index
            llm_cls,
            MagicMock(),  # EmbeddingCompat
            MagicMock(),  # VectorDB
        )

    @patch(_SHIM)
    @patch(_INDEX_KEY, return_value="doc-id-1")
    def test_highlight_plugin_initialized_when_enabled(
        self, mock_index_key, mock_shim_cls
    ):
        """When enable_highlight=True and plugin available, highlight is used."""
        mock_shim_cls.return_value = MagicMock()

        # Highlight plugin class plus the instance it will hand back.
        highlight_instance = MagicMock()
        highlight_cls = MagicMock(return_value=highlight_instance)

        llm = self._make_llm(
            {
                PSKeys.RESPONSE: MagicMock(text="result"),
                PSKeys.HIGHLIGHT_DATA: [{"line": 5}],
                PSKeys.CONFIDENCE_DATA: {"score": 0.9},
                PSKeys.LINE_NUMBERS: [5],
                PSKeys.WHISPER_HASH: "hash1",
            }
        )
        llm_cls = MagicMock(return_value=llm)
        fs = MagicMock()

        executor = self._get_executor()
        ctx = self._make_context(enable_highlight=True)
        deps = self._prompt_deps(llm_cls, ["context chunk"])

        with (
            patch.object(executor, "_get_prompt_deps", return_value=deps),
            patch(self._PLUGIN_GET, return_value=highlight_cls),
            patch(self._FS_INSTANCE, return_value=fs),
        ):
            result = executor._handle_answer_prompt(ctx)

        assert result.success
        # The plugin must be constructed exactly once, with these arguments.
        highlight_cls.assert_called_once_with(
            file_path="/data/doc.txt",
            fs_instance=fs,
            enable_word_confidence=False,
        )
        # The LLM receives the plugin instance's run method as process_text.
        passed = llm.complete.call_args.kwargs.get("process_text")
        assert passed is highlight_instance.run

    @patch(_SHIM)
    @patch(_INDEX_KEY, return_value="doc-id-1")
    def test_highlight_skipped_when_plugin_not_installed(
        self, mock_index_key, mock_shim_cls
    ):
        """When enable_highlight=True but plugin not installed, process_text=None."""
        mock_shim_cls.return_value = MagicMock()

        llm = self._make_llm(
            {
                PSKeys.RESPONSE: MagicMock(text="result"),
                PSKeys.HIGHLIGHT_DATA: [],
                PSKeys.LINE_NUMBERS: [],
                PSKeys.WHISPER_HASH: "",
            }
        )
        llm_cls = MagicMock(return_value=llm)

        executor = self._get_executor()
        ctx = self._make_context(enable_highlight=True)
        deps = self._prompt_deps(llm_cls, ["chunk"])

        with (
            patch.object(executor, "_get_prompt_deps", return_value=deps),
            # A None result from the loader stands for "plugin not installed".
            patch(self._PLUGIN_GET, return_value=None),
        ):
            result = executor._handle_answer_prompt(ctx)

        assert result.success
        # With no plugin available there is nothing to pass as process_text.
        assert llm.complete.call_args.kwargs.get("process_text") is None

    @patch(_SHIM)
    @patch(_INDEX_KEY, return_value="doc-id-1")
    def test_highlight_skipped_when_disabled(
        self, mock_index_key, mock_shim_cls
    ):
        """When enable_highlight=False, plugin loader is not even called."""
        mock_shim_cls.return_value = MagicMock()

        llm = self._make_llm(
            {
                PSKeys.RESPONSE: MagicMock(text="result"),
                PSKeys.HIGHLIGHT_DATA: [],
                PSKeys.LINE_NUMBERS: [],
                PSKeys.WHISPER_HASH: "",
            }
        )
        llm_cls = MagicMock(return_value=llm)

        executor = self._get_executor()
        ctx = self._make_context(enable_highlight=False)
        deps = self._prompt_deps(llm_cls, ["chunk"])

        with (
            patch.object(executor, "_get_prompt_deps", return_value=deps),
            patch(self._PLUGIN_GET) as plugin_get,
        ):
            result = executor._handle_answer_prompt(ctx)

        assert result.success
        # Plugin loader should NOT have been called
        plugin_get.assert_not_called()
        # process_text should be None
        assert llm.complete.call_args.kwargs.get("process_text") is None


# ---------------------------------------------------------------------------
# 5.
# Multiple prompts share same highlight instance
# ---------------------------------------------------------------------------

class TestHighlightMultiplePrompts:
    """Verify that one highlight instance is shared across all prompts."""

    def _make_multi_prompt_context(self):
        """Return an answer_prompt context with three prompts and highlight on."""
        from unstract.sdk1.execution.context import ExecutionContext

        prompts = [
            {
                PSKeys.NAME: name,
                PSKeys.PROMPT: f"What is {name}?",
                PSKeys.PROMPTX: f"What is {name}?",
                PSKeys.TYPE: PSKeys.TEXT,
                PSKeys.CHUNK_SIZE: 0,
                PSKeys.CHUNK_OVERLAP: 0,
                PSKeys.LLM: "llm-123",
                PSKeys.EMBEDDING: "emb-123",
                PSKeys.VECTOR_DB: "vdb-123",
                PSKeys.X2TEXT_ADAPTER: "x2t-123",
                PSKeys.RETRIEVAL_STRATEGY: "simple",
            }
            for name in ["field1", "field2", "field3"]
        ]
        tool_settings = {
            PSKeys.PREAMBLE: "",
            PSKeys.POSTAMBLE: "",
            PSKeys.GRAMMAR: [],
            PSKeys.ENABLE_HIGHLIGHT: True,
        }
        return ExecutionContext(
            executor_name="legacy",
            operation="answer_prompt",
            run_id="run-002",
            execution_source="tool",
            organization_id="org-1",
            executor_params={
                PSKeys.TOOL_SETTINGS: tool_settings,
                PSKeys.OUTPUTS: prompts,
                PSKeys.TOOL_ID: "tool-1",
                PSKeys.FILE_HASH: "hash123",
                PSKeys.FILE_PATH: "/data/doc.txt",
                PSKeys.FILE_NAME: "doc.txt",
                PSKeys.PLATFORM_SERVICE_API_KEY: "key-123",
            },
        )

    @patch("executor.executors.legacy_executor.ExecutorToolShim")
    @patch(
        "unstract.sdk1.utils.indexing.IndexingUtils.generate_index_key",
        return_value="doc-id-1",
    )
    def test_single_highlight_instance_for_all_prompts(
        self, mock_index_key, mock_shim_cls
    ):
        """One highlight instance is created and reused for all prompts."""
        from executor.executors.legacy_executor import LegacyExecutor
        from unstract.sdk1.execution.registry import ExecutorRegistry

        mock_shim_cls.return_value = MagicMock()

        highlighter = MagicMock()
        highlight_cls = MagicMock(return_value=highlighter)

        llm = MagicMock()
        llm.complete.return_value = {
            PSKeys.RESPONSE: MagicMock(text="val"),
            PSKeys.HIGHLIGHT_DATA: [],
            PSKeys.LINE_NUMBERS: [],
            PSKeys.WHISPER_HASH: "",
        }
        llm.get_usage_reason.return_value = "extraction"
        llm.get_metrics.return_value = {}

        ExecutorRegistry.clear()
        if "legacy" not in ExecutorRegistry.list_executors():
            ExecutorRegistry.register(LegacyExecutor)
        executor = ExecutorRegistry.get("legacy")
        ctx = self._make_multi_prompt_context()

        prompt_deps = (
            AnswerPromptService,
            MagicMock(
                retrieve_complete_context=MagicMock(return_value=["chunk"])
            ),
            MagicMock(is_variables_present=MagicMock(return_value=False)),
            None,
            MagicMock(return_value=llm),
            MagicMock(),
            MagicMock(),
        )
        with (
            patch.object(
                executor, "_get_prompt_deps", return_value=prompt_deps
            ),
            patch(
                "executor.executors.plugins.loader.ExecutorPluginLoader.get",
                return_value=highlight_cls,
            ),
            patch(
                "executor.executors.file_utils.FileUtils.get_fs_instance",
                return_value=MagicMock(),
            ),
        ):
            result = executor._handle_answer_prompt(ctx)

        assert result.success
        # highlight_cls should be instantiated exactly ONCE
        assert highlight_cls.call_count == 1
        # llm.complete should be called 3 times (once per prompt)
        assert llm.complete.call_count == 3
        # Each call should use the same process_text
        for complete_call in llm.complete.call_args_list:
            assert complete_call.kwargs.get("process_text") is highlighter.run


# --- patch boundary: the lines below belong to a different file -------------
# diff --git a/workers/tests/test_sanity_phase6d.py b/workers/tests/test_sanity_phase6d.py
# new file mode 100644
# index 0000000000..91cc8cf72c
# --- /dev/null
# +++ b/workers/tests/test_sanity_phase6d.py
# @@ -0,0 +1,554 @@
"""Phase 6D Sanity — LegacyExecutor plugin integration.

Verifies:
1. TABLE type is delegated to the table executor, and raises
   LegacyExecutorError when that executor plugin is not registered
2. LINE_ITEM type raises LegacyExecutorError (not supported)
3. Challenge plugin invoked when enable_challenge=True + plugin installed
4. Challenge skipped when plugin not installed (graceful degradation)
5. Challenge skipped when enable_challenge=False
6. Challenge skipped when challenge_llm not configured
7. Evaluation plugin invoked when eval_settings.evaluate=True + plugin installed
8. Evaluation skipped when plugin not installed
9. Evaluation skipped when no eval_settings are supplied
10. Challenge runs before evaluation (order matters)
11. Challenge mutates structured_output (via mock)
"""

from unittest.mock import MagicMock, patch

import pytest
from executor.executors.answer_prompt import AnswerPromptService
from executor.executors.constants import PromptServiceConstants as PSKeys
from executor.executors.exceptions import LegacyExecutorError
from unstract.sdk1.execution.result import ExecutionResult


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _make_context(
    output_type="TEXT",
    enable_highlight=False,
    enable_challenge=False,
    challenge_llm="",
    eval_settings=None,
):
    """Build a minimal ExecutionContext for answer_prompt tests.

    ``output_type`` becomes the single prompt's ``PSKeys.TYPE``. The
    highlight/challenge flags go into the tool settings. ``challenge_llm``
    and ``eval_settings`` are only written when truthy, so the corresponding
    keys are absent from the context otherwise.
    """
    from unstract.sdk1.execution.context import ExecutionContext

    tool_settings = {
        PSKeys.PREAMBLE: "",
        PSKeys.POSTAMBLE: "",
        PSKeys.GRAMMAR: [],
        PSKeys.ENABLE_HIGHLIGHT: enable_highlight,
        PSKeys.ENABLE_CHALLENGE: enable_challenge,
    }
    if challenge_llm:
        tool_settings[PSKeys.CHALLENGE_LLM] = challenge_llm

    prompt_output = {
        PSKeys.NAME: "field1",
        PSKeys.PROMPT: "What is X?",
        PSKeys.PROMPTX: "What is X?",
        PSKeys.TYPE: output_type,
        PSKeys.CHUNK_SIZE: 0,
        PSKeys.CHUNK_OVERLAP: 0,
        PSKeys.LLM: "llm-123",
        PSKeys.EMBEDDING: "emb-123",
        PSKeys.VECTOR_DB: "vdb-123",
        PSKeys.X2TEXT_ADAPTER: "x2t-123",
        PSKeys.RETRIEVAL_STRATEGY: "simple",
    }
    if eval_settings:
        prompt_output[PSKeys.EVAL_SETTINGS] = eval_settings

    executor_params = {
        PSKeys.TOOL_SETTINGS: tool_settings,
        PSKeys.OUTPUTS: [prompt_output],
        PSKeys.TOOL_ID: "tool-1",
        PSKeys.FILE_HASH: "hash123",
        PSKeys.FILE_PATH: "/data/doc.txt",
        PSKeys.FILE_NAME: "doc.txt",
        PSKeys.PLATFORM_SERVICE_API_KEY: "key-123",
    }
    return ExecutionContext(
        executor_name="legacy",
        operation="answer_prompt",
        run_id="run-001",
        execution_source="ide",
        organization_id="org-1",
        executor_params=executor_params,
    )


def _get_executor():
    """Reset the executor registry and return the "legacy" executor."""
    from executor.executors.legacy_executor import LegacyExecutor
    from unstract.sdk1.execution.registry import ExecutorRegistry

    ExecutorRegistry.clear()
    if "legacy" not in ExecutorRegistry.list_executors():
        ExecutorRegistry.register(LegacyExecutor)
    return ExecutorRegistry.get("legacy")


def _mock_llm():
    """Create a mock LLM that returns a realistic completion dict."""
    completion = {
        PSKeys.RESPONSE: MagicMock(text="42"),
        PSKeys.HIGHLIGHT_DATA: [],
        PSKeys.LINE_NUMBERS: [],
        PSKeys.WHISPER_HASH: "",
    }
    llm = MagicMock()
    llm.complete.return_value = completion
    llm.get_usage_reason.return_value = "extraction"
    llm.get_metrics.return_value = {}
    return llm


def _standard_patches(executor, mock_llm_instance):
    """Return common patches for _handle_answer_prompt tests.

    The result is a dict of un-entered patch objects keyed by
    ``"_get_prompt_deps"``, ``"shim"`` and ``"index_key"``; callers enter
    them with ``with``.
    """
    llm_factory = MagicMock(return_value=mock_llm_instance)
    retrieval = MagicMock(
        retrieve_complete_context=MagicMock(return_value=["context chunk"])
    )
    variables = MagicMock(is_variables_present=MagicMock(return_value=False))
    prompt_deps = (
        AnswerPromptService,
        retrieval,
        variables,
        None,  # Index
        llm_factory,
        MagicMock(),  # EmbeddingCompat
        MagicMock(),  # VectorDB
    )
    return {
        "_get_prompt_deps": patch.object(
            executor, "_get_prompt_deps", return_value=prompt_deps
        ),
        "shim": patch(
            "executor.executors.legacy_executor.ExecutorToolShim",
            return_value=MagicMock(),
        ),
        "index_key": patch(
            "unstract.sdk1.utils.indexing.IndexingUtils.generate_index_key",
            return_value="doc-id-1",
        ),
    }


# ---------------------------------------------------------------------------
# 1.
# TABLE type is delegated to the table executor; LINE_ITEM is rejected
# ---------------------------------------------------------------------------

class TestTableLineItemGuard:
    """Cover how ``_handle_answer_prompt`` treats TABLE and LINE_ITEM prompts."""

    @patch("executor.executors.legacy_executor.ExecutorToolShim")
    @patch(
        "unstract.sdk1.utils.indexing.IndexingUtils.generate_index_key",
        return_value="doc-id-1",
    )
    def test_table_type_delegates_to_table_executor(
        self, mock_key, mock_shim_cls
    ):
        """TABLE prompts are delegated to TableExtractorExecutor in-process."""
        mock_shim_cls.return_value = MagicMock()
        executor = _get_executor()
        ctx = _make_context(output_type=PSKeys.TABLE)  # "table"
        patches = _standard_patches(executor, _mock_llm())

        table_executor = MagicMock()
        table_executor.execute.return_value = ExecutionResult(
            success=True,
            data={"output": {"table_data": "extracted"}, "metadata": {"metrics": {}}},
        )

        # The registry lookup is patched only after the legacy executor has
        # been fetched, so the lookup under test resolves to the mock.
        with (
            patches["_get_prompt_deps"],
            patches["shim"],
            patches["index_key"],
            patch(
                "unstract.sdk1.execution.registry.ExecutorRegistry.get",
                return_value=table_executor,
            ),
        ):
            result = executor._handle_answer_prompt(ctx)

        assert result.success
        assert result.data["output"]["field1"] == {"table_data": "extracted"}
        table_executor.execute.assert_called_once()
        # Verify the sub-context was built with table executor params
        delegated_ctx = table_executor.execute.call_args[0][0]
        assert delegated_ctx.executor_name == "table"
        assert delegated_ctx.operation == "table_extract"

    @patch("executor.executors.legacy_executor.ExecutorToolShim")
    @patch(
        "unstract.sdk1.utils.indexing.IndexingUtils.generate_index_key",
        return_value="doc-id-1",
    )
    def test_table_type_raises_when_plugin_missing(
        self, mock_key, mock_shim_cls
    ):
        """TABLE prompts raise error when table executor plugin is not installed."""
        mock_shim_cls.return_value = MagicMock()
        executor = _get_executor()
        ctx = _make_context(output_type=PSKeys.TABLE)  # "table"
        patches = _standard_patches(executor, _mock_llm())

        with (
            patches["_get_prompt_deps"],
            patches["shim"],
            patches["index_key"],
            patch(
                "unstract.sdk1.execution.registry.ExecutorRegistry.get",
                side_effect=KeyError("No executor registered with name 'table'"),
            ),
            pytest.raises(LegacyExecutorError, match="table executor plugin"),
        ):
            executor._handle_answer_prompt(ctx)

    @patch("executor.executors.legacy_executor.ExecutorToolShim")
    @patch(
        "unstract.sdk1.utils.indexing.IndexingUtils.generate_index_key",
        return_value="doc-id-1",
    )
    def test_line_item_type_raises_not_supported(
        self, mock_key, mock_shim_cls
    ):
        """LINE_ITEM prompts raise LegacyExecutorError matching "not supported"."""
        mock_shim_cls.return_value = MagicMock()
        executor = _get_executor()
        ctx = _make_context(output_type=PSKeys.LINE_ITEM)  # "line-item"
        patches = _standard_patches(executor, _mock_llm())

        with (
            patches["_get_prompt_deps"],
            patches["shim"],
            patches["index_key"],
            pytest.raises(LegacyExecutorError, match="not supported"),
        ):
            executor._handle_answer_prompt(ctx)


# ---------------------------------------------------------------------------
# 2.
# Challenge plugin integration
# ---------------------------------------------------------------------------

class TestChallengeIntegration:
    """Check when ``_handle_answer_prompt`` loads and runs the challenge plugin."""

    @patch("executor.executors.legacy_executor.ExecutorToolShim")
    @patch(
        "unstract.sdk1.utils.indexing.IndexingUtils.generate_index_key",
        return_value="doc-id-1",
    )
    def test_challenge_invoked_when_enabled_and_installed(
        self, mock_key, mock_shim_cls
    ):
        """Challenge plugin is instantiated and run() called."""
        mock_shim_cls.return_value = MagicMock()
        executor = _get_executor()
        ctx = _make_context(enable_challenge=True, challenge_llm="ch-llm-1")
        llm = _mock_llm()
        challenger = MagicMock()
        challenge_cls = MagicMock(return_value=challenger)

        def only_challenge(name):
            # Every plugin other than "challenge" is reported as missing.
            return challenge_cls if name == "challenge" else None

        patches = _standard_patches(executor, llm)
        with (
            patches["_get_prompt_deps"],
            patches["shim"],
            patches["index_key"],
            patch(
                "executor.executors.plugins.loader.ExecutorPluginLoader.get",
                side_effect=only_challenge,
            ),
        ):
            result = executor._handle_answer_prompt(ctx)

        assert result.success
        # Challenge class was instantiated with correct args
        challenge_cls.assert_called_once()
        ctor_kwargs = challenge_cls.call_args.kwargs
        assert ctor_kwargs["run_id"] == "run-001"
        assert ctor_kwargs["platform_key"] == "key-123"
        assert ctor_kwargs["llm"] is llm
        # run() was called
        challenger.run.assert_called_once()

    @patch("executor.executors.legacy_executor.ExecutorToolShim")
    @patch(
        "unstract.sdk1.utils.indexing.IndexingUtils.generate_index_key",
        return_value="doc-id-1",
    )
    def test_challenge_skipped_when_plugin_not_installed(
        self, mock_key, mock_shim_cls
    ):
        """When challenge enabled but plugin missing, no error."""
        mock_shim_cls.return_value = MagicMock()
        executor = _get_executor()
        ctx = _make_context(enable_challenge=True, challenge_llm="ch-llm-1")
        patches = _standard_patches(executor, _mock_llm())

        with (
            patches["_get_prompt_deps"],
            patches["shim"],
            patches["index_key"],
            patch(
                "executor.executors.plugins.loader.ExecutorPluginLoader.get",
                return_value=None,
            ),
        ):
            result = executor._handle_answer_prompt(ctx)

        assert result.success

    @patch("executor.executors.legacy_executor.ExecutorToolShim")
    @patch(
        "unstract.sdk1.utils.indexing.IndexingUtils.generate_index_key",
        return_value="doc-id-1",
    )
    def test_challenge_skipped_when_disabled(
        self, mock_key, mock_shim_cls
    ):
        """When enable_challenge=False, plugin loader not called for challenge."""
        mock_shim_cls.return_value = MagicMock()
        executor = _get_executor()
        ctx = _make_context(enable_challenge=False)
        patches = _standard_patches(executor, _mock_llm())

        with (
            patches["_get_prompt_deps"],
            patches["shim"],
            patches["index_key"],
            patch(
                "executor.executors.plugins.loader.ExecutorPluginLoader.get",
            ) as mock_get,
        ):
            result = executor._handle_answer_prompt(ctx)

        assert result.success
        # Plugin loader should NOT have been called for "challenge"
        for lookup in mock_get.call_args_list:
            assert lookup.args[0] != "challenge", (
                "ExecutorPluginLoader.get('challenge') should not be called"
            )

    @patch("executor.executors.legacy_executor.ExecutorToolShim")
    @patch(
        "unstract.sdk1.utils.indexing.IndexingUtils.generate_index_key",
        return_value="doc-id-1",
    )
    def test_challenge_skipped_when_no_challenge_llm(
        self, mock_key, mock_shim_cls
    ):
        """When enable_challenge=True but no challenge_llm, skip challenge."""
        mock_shim_cls.return_value = MagicMock()
        executor = _get_executor()
        # enable_challenge=True but challenge_llm="" (empty)
        ctx = _make_context(enable_challenge=True, challenge_llm="")
        challenge_cls = MagicMock()
        patches = _standard_patches(executor, _mock_llm())

        with (
            patches["_get_prompt_deps"],
            patches["shim"],
            patches["index_key"],
            patch(
                "executor.executors.plugins.loader.ExecutorPluginLoader.get",
                return_value=challenge_cls,
            ),
        ):
            result = executor._handle_answer_prompt(ctx)

        assert result.success
        # Challenge class should NOT be instantiated (no LLM ID)
        challenge_cls.assert_not_called()


# ---------------------------------------------------------------------------
# 3. Evaluation plugin integration
# ---------------------------------------------------------------------------

class TestEvaluationIntegration:
    """Check when ``_handle_answer_prompt`` loads and runs the evaluation plugin."""

    @patch("executor.executors.legacy_executor.ExecutorToolShim")
    @patch(
        "unstract.sdk1.utils.indexing.IndexingUtils.generate_index_key",
        return_value="doc-id-1",
    )
    def test_evaluation_invoked_when_enabled_and_installed(
        self, mock_key, mock_shim_cls
    ):
        """Evaluation plugin is instantiated and run() called."""
        mock_shim_cls.return_value = MagicMock()
        executor = _get_executor()
        ctx = _make_context(
            eval_settings={PSKeys.EVAL_SETTINGS_EVALUATE: True}
        )
        evaluator = MagicMock()
        eval_cls = MagicMock(return_value=evaluator)

        def only_evaluation(name):
            # Every plugin other than "evaluation" is reported as missing.
            return eval_cls if name == "evaluation" else None

        patches = _standard_patches(executor, _mock_llm())
        with (
            patches["_get_prompt_deps"],
            patches["shim"],
            patches["index_key"],
            patch(
                "executor.executors.plugins.loader.ExecutorPluginLoader.get",
                side_effect=only_evaluation,
            ),
        ):
            result = executor._handle_answer_prompt(ctx)

        assert result.success
        eval_cls.assert_called_once()
        ctor_kwargs = eval_cls.call_args.kwargs
        assert ctor_kwargs["platform_key"] == "key-123"
        assert ctor_kwargs["response"] == "42"  # from mock LLM
        evaluator.run.assert_called_once()

    @patch("executor.executors.legacy_executor.ExecutorToolShim")
    @patch(
        "unstract.sdk1.utils.indexing.IndexingUtils.generate_index_key",
        return_value="doc-id-1",
    )
    def test_evaluation_skipped_when_plugin_not_installed(
        self, mock_key, mock_shim_cls
    ):
        """When evaluation enabled but plugin missing, no error."""
        mock_shim_cls.return_value = MagicMock()
        executor = _get_executor()
        ctx = _make_context(
            eval_settings={PSKeys.EVAL_SETTINGS_EVALUATE: True}
        )
        patches = _standard_patches(executor, _mock_llm())

        with (
            patches["_get_prompt_deps"],
            patches["shim"],
            patches["index_key"],
            patch(
                "executor.executors.plugins.loader.ExecutorPluginLoader.get",
                return_value=None,
            ),
        ):
            result = executor._handle_answer_prompt(ctx)

        assert result.success

    @patch("executor.executors.legacy_executor.ExecutorToolShim")
    @patch(
        "unstract.sdk1.utils.indexing.IndexingUtils.generate_index_key",
        return_value="doc-id-1",
    )
    def test_evaluation_skipped_when_not_enabled(
        self, mock_key, mock_shim_cls
    ):
        """When no eval_settings or evaluate=False, evaluation skipped."""
        mock_shim_cls.return_value = MagicMock()
        executor = _get_executor()
        # No eval_settings at all
        ctx = _make_context()
        patches = _standard_patches(executor, _mock_llm())

        with (
            patches["_get_prompt_deps"],
            patches["shim"],
            patches["index_key"],
            patch(
                "executor.executors.plugins.loader.ExecutorPluginLoader.get",
            ) as mock_get,
        ):
            result = executor._handle_answer_prompt(ctx)

        assert result.success
        # Plugin loader should NOT have been called for "evaluation"
        for lookup in mock_get.call_args_list:
            assert lookup.args[0] != "evaluation", (
                "ExecutorPluginLoader.get('evaluation') should not be called"
            )


# ---------------------------------------------------------------------------
# 4.
# Challenge runs before evaluation (ordering)
# ---------------------------------------------------------------------------

class TestChallengeBeforeEvaluation:
    """Pin the relative order of the challenge and evaluation plugin runs."""

    @patch("executor.executors.legacy_executor.ExecutorToolShim")
    @patch(
        "unstract.sdk1.utils.indexing.IndexingUtils.generate_index_key",
        return_value="doc-id-1",
    )
    def test_challenge_runs_before_evaluation(
        self, mock_key, mock_shim_cls
    ):
        """Challenge mutates structured_output before evaluation reads it."""
        mock_shim_cls.return_value = MagicMock()
        executor = _get_executor()
        ctx = _make_context(
            enable_challenge=True,
            challenge_llm="ch-llm-1",
            eval_settings={PSKeys.EVAL_SETTINGS_EVALUATE: True},
        )

        # Each plugin's run() appends its own label here, recording the order.
        call_order = []

        challenger = MagicMock()
        challenger.run.side_effect = lambda: call_order.append("challenge")
        evaluator = MagicMock()
        evaluator.run.side_effect = lambda: call_order.append("evaluation")

        plugin_classes = {
            "challenge": MagicMock(return_value=challenger),
            "evaluation": MagicMock(return_value=evaluator),
        }

        def plugin_get(name):
            # Unknown plugin names resolve to None (i.e. "not installed").
            return plugin_classes.get(name)

        patches = _standard_patches(executor, _mock_llm())
        with (
            patches["_get_prompt_deps"],
            patches["shim"],
            patches["index_key"],
            patch(
                "executor.executors.plugins.loader.ExecutorPluginLoader.get",
                side_effect=plugin_get,
            ),
        ):
            result = executor._handle_answer_prompt(ctx)

        assert result.success
        assert call_order == ["challenge", "evaluation"]


# ---------------------------------------------------------------------------
# 5.
# Challenge mutates structured_output
# ---------------------------------------------------------------------------

class TestChallengeMutation:
    """Check that a challenge plugin's in-place edits reach the final result."""

    @patch("executor.executors.legacy_executor.ExecutorToolShim")
    @patch(
        "unstract.sdk1.utils.indexing.IndexingUtils.generate_index_key",
        return_value="doc-id-1",
    )
    def test_challenge_mutates_structured_output(
        self, mock_key, mock_shim_cls
    ):
        """Challenge plugin can mutate structured_output dict."""
        mock_shim_cls.return_value = MagicMock()
        executor = _get_executor()
        ctx = _make_context(enable_challenge=True, challenge_llm="ch-llm-1")
        llm = _mock_llm()

        def challenge_run_side_effect():
            # Simulate challenge replacing the answer with improved version.
            # The dict is the one passed to the plugin constructor, read back
            # from the class mock's recorded call.
            so = mock_challenge_cls.call_args.kwargs["structured_output"]
            so["field1"] = "improved_42"

        mock_challenge_cls = MagicMock()
        mock_challenger = MagicMock()
        mock_challenger.run.side_effect = challenge_run_side_effect
        mock_challenge_cls.return_value = mock_challenger

        patches = _standard_patches(executor, llm)
        with (
            patches["_get_prompt_deps"],
            patches["shim"],
            patches["index_key"],
            patch(
                "executor.executors.plugins.loader.ExecutorPluginLoader.get",
                side_effect=lambda name: (
                    mock_challenge_cls if name == "challenge" else None
                ),
            ),
        ):
            result = executor._handle_answer_prompt(ctx)

        assert result.success
        # The structured_output should contain the mutated value
        assert result.data[PSKeys.OUTPUT]["field1"] == "improved_42"


# --- patch boundary: the lines below belong to a different file -------------
# diff --git a/workers/tests/test_sanity_phase6e.py b/workers/tests/test_sanity_phase6e.py
# new file mode 100644
# index 0000000000..85d4d60d65
# --- /dev/null
# +++ b/workers/tests/test_sanity_phase6e.py
# @@ -0,0 +1,215 @@
"""Phase 6E Sanity — TableExtractorExecutor + TABLE_EXTRACT operation.

Verifies:
1. Operation.TABLE_EXTRACT enum exists with value "table_extract"
2. tasks.py log_component builder handles table_extract operation
3. TableExtractorExecutor mock — registration via entry point
4. TableExtractorExecutor mock — dispatch to correct queue
5. LegacyExecutor excludes table_extract from its _OPERATION_MAP
6. Cloud executor entry point name matches pyproject.toml
"""

from unittest.mock import MagicMock


from unstract.sdk1.execution.context import ExecutionContext, Operation
from unstract.sdk1.execution.dispatcher import ExecutionDispatcher
from unstract.sdk1.execution.registry import ExecutorRegistry
from unstract.sdk1.execution.result import ExecutionResult


# ---------------------------------------------------------------------------
# 1. Operation enum
# ---------------------------------------------------------------------------

class TestTableExtractOperation:
    """Check that the ``Operation`` enum exposes TABLE_EXTRACT."""

    def test_table_extract_enum_exists(self):
        """``Operation.TABLE_EXTRACT`` exists and has value "table_extract"."""
        assert hasattr(Operation, "TABLE_EXTRACT")
        assert Operation.TABLE_EXTRACT.value == "table_extract"

    def test_table_extract_in_operation_values(self):
        """"table_extract" appears among the enum's values."""
        values = {op.value for op in Operation}
        assert "table_extract" in values


# ---------------------------------------------------------------------------
# 2. tasks.py log_component for table_extract
# ---------------------------------------------------------------------------

class TestTasksLogComponent:
    """Mirror the log_component dict built for table_extract contexts."""

    def test_table_extract_log_component(self):
        """tasks.py builds correct log_component for table_extract."""

        # Build a mock context dict
        ctx_dict = {
            "executor_name": "table",
            "operation": "table_extract",
            "run_id": "run-001",
            "execution_source": "tool",
            "organization_id": "org-1",
            "executor_params": {
                "tool_id": "tool-1",
                "file_name": "invoice.pdf",
            },
            "request_id": "req-1",
            "log_events_id": "evt-1",
        }

        # Deserialize the context, then rebuild the component the same way
        # this test's original comments say tasks.py does.
        context = ExecutionContext.from_dict(ctx_dict)
        params = context.executor_params

        # These preconditions were `if` guards around the only assertion, so
        # the test passed silently whenever deserialization dropped either
        # field. Asserting them makes that failure visible.
        assert context.log_events_id
        assert context.operation == "table_extract"

        component = {
            "tool_id": params.get("tool_id", ""),
            "run_id": context.run_id,
            "doc_name": str(params.get("file_name", "")),
            "operation": context.operation,
        }
        assert component == {
            "tool_id": "tool-1",
            "run_id": "run-001",
            "doc_name": "invoice.pdf",
            "operation": "table_extract",
        }


# ---------------------------------------------------------------------------
# 3. Mock TableExtractorExecutor — entry point registration
# ---------------------------------------------------------------------------

class TestTableExtractorExecutorRegistration:
    """Register a stand-in "table" executor and exercise it via the registry."""

    def test_mock_table_executor_discovered_via_entry_point(self):
        """Simulate cloud executor discovery via entry point."""
        from unstract.sdk1.execution.executor import BaseExecutor

        # Create a mock TableExtractorExecutor; the decorator registers it.
        @ExecutorRegistry.register
        class MockTableExtractorExecutor(BaseExecutor):
            @property
            def name(self) -> str:
                return "table"

            def execute(self, context):
                if context.operation != "table_extract":
                    return ExecutionResult.failure(
                        error=f"Unsupported: {context.operation}"
                    )
                return ExecutionResult(
                    success=True,
                    data={"output": "table_data", "metadata": {}},
                )

        try:
            # Verify it was registered
            assert "table" in ExecutorRegistry.list_executors()
            executor = ExecutorRegistry.get("table")
            assert executor.name == "table"

            # Verify it handles table_extract
            ctx = ExecutionContext(
                executor_name="table",
                operation="table_extract",
                run_id="run-1",
                execution_source="tool",
                executor_params={},
            )
            result = executor.execute(ctx)
            assert result.success
            assert result.data["output"] == "table_data"

            # Verify it rejects unsupported operations
            ctx2 = ExecutionContext(
                executor_name="table",
                operation="answer_prompt",
                run_id="run-2",
                execution_source="tool",
                executor_params={},
            )
            result2 = executor.execute(ctx2)
            assert not result2.success
        finally:
            # Cleanup: drop the stand-in so it cannot leak into other tests.
            ExecutorRegistry.clear()


# ---------------------------------------------------------------------------
# 4. Queue routing for table executor
# ---------------------------------------------------------------------------

class TestTableQueueRouting:
    """Check the queue name used for ``executor_name="table"``."""

    def test_table_executor_routes_to_correct_queue(self):
        """executor_name='table' routes to celery_executor_table queue."""
        queue = ExecutionDispatcher._get_queue("table")
        assert queue == "celery_executor_table"

    def test_dispatch_sends_to_table_queue(self):
        """ExecutionDispatcher sends table_extract to correct queue."""
        mock_app = MagicMock()
        mock_result = MagicMock()
        mock_result.get.return_value = ExecutionResult(
            success=True, data={"output": "ok"}
        ).to_dict()
        mock_app.send_task.return_value = mock_result

        dispatcher = ExecutionDispatcher(celery_app=mock_app)
        ctx = ExecutionContext(
            executor_name="table",
            operation="table_extract",
            run_id="run-1",
            execution_source="tool",
            executor_params={"table_settings": {}},
        )
        # Only the send_task call is inspected; the return value is not used.
        dispatcher.dispatch(ctx)

        mock_app.send_task.assert_called_once()
        send_call = mock_app.send_task.call_args
        assert send_call.kwargs.get("queue") == "celery_executor_table"


# ---------------------------------------------------------------------------
# 5.
LegacyExecutor does NOT handle table_extract +# --------------------------------------------------------------------------- + +class TestLegacyExcludesTable: + def test_table_extract_not_in_legacy_operation_map(self): + """LegacyExecutor._OPERATION_MAP should NOT contain table_extract.""" + from executor.executors.legacy_executor import LegacyExecutor + + assert "table_extract" not in LegacyExecutor._OPERATION_MAP + + def test_legacy_returns_failure_for_table_extract(self): + """LegacyExecutor.execute() returns failure for table_extract.""" + from executor.executors.legacy_executor import LegacyExecutor + + ExecutorRegistry.clear() + if "legacy" not in ExecutorRegistry.list_executors(): + ExecutorRegistry.register(LegacyExecutor) + executor = ExecutorRegistry.get("legacy") + + ctx = ExecutionContext( + executor_name="legacy", + operation="table_extract", + run_id="run-1", + execution_source="tool", + executor_params={}, + ) + result = executor.execute(ctx) + assert not result.success + assert "does not support" in result.error + + +# --------------------------------------------------------------------------- +# 6. Entry point name verification +# --------------------------------------------------------------------------- + +class TestEntryPointConfig: + def test_entry_point_name_is_table(self): + """The pyproject.toml entry point name should be 'table'.""" + # This is a documentation/verification test — the entry point + # in pyproject.toml maps 'table' to TableExtractorExecutor. + # Verify the queue name matches. + assert ExecutionDispatcher._get_queue("table") == "celery_executor_table" diff --git a/workers/tests/test_sanity_phase6f.py b/workers/tests/test_sanity_phase6f.py new file mode 100644 index 0000000000..4a8432f6ef --- /dev/null +++ b/workers/tests/test_sanity_phase6f.py @@ -0,0 +1,191 @@ +"""Phase 6F Sanity — SmartTableExtractorExecutor + SMART_TABLE_EXTRACT operation. + +Verifies: +1. 
Operation.SMART_TABLE_EXTRACT enum exists with value "smart_table_extract" +2. tasks.py log_component builder handles smart_table_extract operation +3. Mock SmartTableExtractorExecutor — registration and execution +4. Queue routing: executor_name="smart_table" → celery_executor_smart_table +5. LegacyExecutor does NOT handle smart_table_extract +6. Dispatch sends to correct queue +""" + +from unittest.mock import MagicMock + + +from unstract.sdk1.execution.context import ExecutionContext, Operation +from unstract.sdk1.execution.dispatcher import ExecutionDispatcher +from unstract.sdk1.execution.executor import BaseExecutor +from unstract.sdk1.execution.registry import ExecutorRegistry +from unstract.sdk1.execution.result import ExecutionResult + + +# --------------------------------------------------------------------------- +# 1. Operation enum +# --------------------------------------------------------------------------- + +class TestSmartTableExtractOperation: + def test_smart_table_extract_enum_exists(self): + assert hasattr(Operation, "SMART_TABLE_EXTRACT") + assert Operation.SMART_TABLE_EXTRACT.value == "smart_table_extract" + + def test_smart_table_extract_in_operation_values(self): + values = {op.value for op in Operation} + assert "smart_table_extract" in values + + +# --------------------------------------------------------------------------- +# 2. 
# tasks.py log_component for smart_table_extract
# ---------------------------------------------------------------------------

class TestTasksLogComponent:
    """Log-component payload check for the ``smart_table_extract`` operation."""

    def test_smart_table_extract_log_component(self):
        """Build the log component from a deserialized context and compare it.

        The test does not call tasks.py; it rebuilds the component dict
        locally from the context fields and checks the resulting values.
        """
        context = ExecutionContext.from_dict(
            {
                "executor_name": "smart_table",
                "operation": "smart_table_extract",
                "run_id": "run-001",
                "execution_source": "tool",
                "organization_id": "org-1",
                "executor_params": {
                    "tool_id": "tool-1",
                    "file_name": "data.xlsx",
                },
                "request_id": "req-1",
                "log_events_id": "evt-1",
            }
        )
        executor_params = context.executor_params

        # Simulate the tasks.py logic — smart_table_extract shares the
        # branch with table_extract
        table_branch_operations = ("table_extract", "smart_table_extract")
        assert context.operation in table_branch_operations

        expected_component = {
            "tool_id": "tool-1",
            "run_id": "run-001",
            "doc_name": "data.xlsx",
            "operation": "smart_table_extract",
        }
        built_component = dict(
            tool_id=executor_params.get("tool_id", ""),
            run_id=context.run_id,
            doc_name=str(executor_params.get("file_name", "")),
            operation=context.operation,
        )
        assert built_component == expected_component


# ---------------------------------------------------------------------------
# 3. 
# Mock SmartTableExtractorExecutor — registration and execution
# ---------------------------------------------------------------------------

class TestSmartTableExtractorRegistration:
    """Registers a stand-in ``smart_table`` executor and runs it directly."""

    def test_mock_smart_table_executor_registers_and_executes(self):
        """Simulate cloud executor discovery and execution."""

        @ExecutorRegistry.register
        class MockSmartTableExecutor(BaseExecutor):
            """Accepts only ``smart_table_extract``; everything else fails."""

            @property
            def name(self) -> str:
                return "smart_table"

            def execute(self, context):
                if context.operation == "smart_table_extract":
                    return ExecutionResult(
                        success=True,
                        data={
                            "output": [{"col1": "val1"}],
                            "metadata": {"total_records": 1},
                        },
                    )
                return ExecutionResult.failure(
                    error=f"Unsupported: {context.operation}"
                )

        def make_context(operation, run_id):
            # Both contexts differ only in operation and run id.
            return ExecutionContext(
                executor_name="smart_table",
                operation=operation,
                run_id=run_id,
                execution_source="tool",
                executor_params={},
            )

        try:
            assert "smart_table" in ExecutorRegistry.list_executors()
            smart_table_executor = ExecutorRegistry.get("smart_table")
            assert smart_table_executor.name == "smart_table"

            supported = smart_table_executor.execute(
                make_context("smart_table_extract", "run-1")
            )
            assert supported.success
            assert supported.data["output"] == [{"col1": "val1"}]
            assert supported.data["metadata"]["total_records"] == 1

            # Rejects unsupported operations
            rejected = smart_table_executor.execute(
                make_context("answer_prompt", "run-2")
            )
            assert not rejected.success
        finally:
            # Always drop the mock so it cannot leak into other tests.
            ExecutorRegistry.clear()


# ---------------------------------------------------------------------------
# 4. 
# Queue routing
# ---------------------------------------------------------------------------

class TestSmartTableQueueRouting:
    """Queue-routing checks for the ``smart_table`` executor name."""

    def test_smart_table_routes_to_correct_queue(self):
        """``_get_queue("smart_table")`` yields the smart_table queue name."""
        queue = ExecutionDispatcher._get_queue("smart_table")
        assert queue == "celery_executor_smart_table"

    def test_dispatch_sends_to_smart_table_queue(self):
        """Dispatching a smart_table context sends one task to its queue.

        The Celery app is a ``MagicMock`` whose async result returns a
        serialized successful ``ExecutionResult``, so no broker is involved.
        """
        mock_app = MagicMock()
        mock_result = MagicMock()
        mock_result.get.return_value = ExecutionResult(
            success=True, data={"output": "ok"}
        ).to_dict()
        mock_app.send_task.return_value = mock_result

        dispatcher = ExecutionDispatcher(celery_app=mock_app)
        ctx = ExecutionContext(
            executor_name="smart_table",
            operation="smart_table_extract",
            run_id="run-1",
            execution_source="tool",
            executor_params={"table_settings": {}},
        )
        result = dispatcher.dispatch(ctx)

        mock_app.send_task.assert_called_once()
        call_kwargs = mock_app.send_task.call_args
        assert call_kwargs.kwargs.get("queue") == "celery_executor_smart_table"
        # The dispatched result used to be bound and then ignored; check that
        # the mocked payload actually comes back through dispatch().
        assert result.success
        assert result.data["output"] == "ok"


# ---------------------------------------------------------------------------
# 5. 
LegacyExecutor does NOT handle smart_table_extract +# --------------------------------------------------------------------------- + +class TestLegacyExcludesSmartTable: + def test_smart_table_extract_not_in_legacy_operation_map(self): + from executor.executors.legacy_executor import LegacyExecutor + assert "smart_table_extract" not in LegacyExecutor._OPERATION_MAP + + def test_legacy_returns_failure_for_smart_table_extract(self): + from executor.executors.legacy_executor import LegacyExecutor + + ExecutorRegistry.clear() + if "legacy" not in ExecutorRegistry.list_executors(): + ExecutorRegistry.register(LegacyExecutor) + executor = ExecutorRegistry.get("legacy") + + ctx = ExecutionContext( + executor_name="legacy", + operation="smart_table_extract", + run_id="run-1", + execution_source="tool", + executor_params={}, + ) + result = executor.execute(ctx) + assert not result.success + assert "does not support" in result.error diff --git a/workers/tests/test_sanity_phase6g.py b/workers/tests/test_sanity_phase6g.py new file mode 100644 index 0000000000..73bb738911 --- /dev/null +++ b/workers/tests/test_sanity_phase6g.py @@ -0,0 +1,296 @@ +"""Phase 6G Sanity — SimplePromptStudioExecutor + SPS operations. + +Verifies: +1. Operation.SPS_ANSWER_PROMPT enum exists with value "sps_answer_prompt" +2. Operation.SPS_INDEX enum exists with value "sps_index" +3. Mock SimplePromptStudioExecutor — registration and execution +4. Queue routing: executor_name="simple_prompt_studio" → celery_executor_simple_prompt_studio +5. LegacyExecutor does NOT handle sps_answer_prompt or sps_index +6. Dispatch sends to correct queue +7. 
SimplePromptStudioExecutor rejects unsupported operations +""" + +from unittest.mock import MagicMock + + +from unstract.sdk1.execution.context import ExecutionContext, Operation +from unstract.sdk1.execution.dispatcher import ExecutionDispatcher +from unstract.sdk1.execution.executor import BaseExecutor +from unstract.sdk1.execution.registry import ExecutorRegistry +from unstract.sdk1.execution.result import ExecutionResult + + +# --------------------------------------------------------------------------- +# 1. Operation enums +# --------------------------------------------------------------------------- + +class TestSPSOperations: + def test_sps_answer_prompt_enum_exists(self): + assert hasattr(Operation, "SPS_ANSWER_PROMPT") + assert Operation.SPS_ANSWER_PROMPT.value == "sps_answer_prompt" + + def test_sps_index_enum_exists(self): + assert hasattr(Operation, "SPS_INDEX") + assert Operation.SPS_INDEX.value == "sps_index" + + def test_sps_operations_in_operation_values(self): + values = {op.value for op in Operation} + assert "sps_answer_prompt" in values + assert "sps_index" in values + + +# --------------------------------------------------------------------------- +# 2. 
# Mock SimplePromptStudioExecutor — registration and execution
# ---------------------------------------------------------------------------

class TestSimplePromptStudioRegistration:
    """Registers a stand-in ``simple_prompt_studio`` executor and runs it."""

    def test_mock_sps_executor_registers_and_executes(self):
        """Simulate cloud executor discovery and execution."""

        @ExecutorRegistry.register
        class MockSPSExecutor(BaseExecutor):
            """Routes the two SPS operations to handlers by name."""

            _OPERATION_MAP = {
                "sps_answer_prompt": "_handle_answer_prompt",
                "sps_index": "_handle_index",
            }

            @property
            def name(self) -> str:
                return "simple_prompt_studio"

            def execute(self, context):
                handler_name = self._OPERATION_MAP.get(context.operation)
                if not handler_name:
                    return ExecutionResult.failure(
                        error=f"Unsupported: {context.operation}"
                    )
                return getattr(self, handler_name)(context)

            def _handle_answer_prompt(self, context):
                return ExecutionResult(
                    success=True,
                    data={
                        "output": {"invoice_number": "INV-001"},
                        "metadata": {},
                    },
                )

            def _handle_index(self, context):
                return ExecutionResult(
                    success=True,
                    data={"output": "indexed", "metadata": {}},
                )

        try:
            assert "simple_prompt_studio" in ExecutorRegistry.list_executors()
            executor = ExecutorRegistry.get("simple_prompt_studio")
            assert executor.name == "simple_prompt_studio"

            # sps_answer_prompt
            ctx = ExecutionContext(
                executor_name="simple_prompt_studio",
                operation="sps_answer_prompt",
                run_id="run-1",
                execution_source="tool",
                executor_params={},
            )
            result = executor.execute(ctx)
            assert result.success
            assert result.data["output"] == {"invoice_number": "INV-001"}

            # sps_index
            ctx2 = ExecutionContext(
                executor_name="simple_prompt_studio",
                operation="sps_index",
                run_id="run-2",
                execution_source="tool",
                executor_params={},
            )
            result2 = executor.execute(ctx2)
            assert result2.success
            assert result2.data["output"] == "indexed"

            # Rejects unsupported operations
            ctx3 = ExecutionContext(
                executor_name="simple_prompt_studio",
                operation="extract",
                run_id="run-3",
                execution_source="tool",
                executor_params={},
            )
            result3 = executor.execute(ctx3)
            assert not result3.success
        finally:
            # Always drop the mock so it cannot leak into other tests.
            ExecutorRegistry.clear()


# ---------------------------------------------------------------------------
# 3. Queue routing
# ---------------------------------------------------------------------------

class TestSPSQueueRouting:
    """Queue-routing checks for the ``simple_prompt_studio`` executor name."""

    @staticmethod
    def _dispatch_via_mock_app(operation, executor_params, result_data):
        """Dispatch one SPS context through a mocked Celery app.

        Args:
            operation: Operation string placed on the context.
            executor_params: ``executor_params`` dict placed on the context.
            result_data: ``data`` of the successful result the mocked
                async result hands back.

        Returns:
            ``(mock_app, result)`` — the mocked app for call inspection and
            whatever ``dispatch()`` returned.
        """
        mock_result = MagicMock()
        mock_result.get.return_value = ExecutionResult(
            success=True, data=result_data
        ).to_dict()
        mock_app = MagicMock()
        mock_app.send_task.return_value = mock_result

        dispatcher = ExecutionDispatcher(celery_app=mock_app)
        ctx = ExecutionContext(
            executor_name="simple_prompt_studio",
            operation=operation,
            run_id="run-1",
            execution_source="tool",
            executor_params=executor_params,
        )
        return mock_app, dispatcher.dispatch(ctx)

    def test_sps_routes_to_correct_queue(self):
        """``_get_queue`` maps ``simple_prompt_studio`` to its queue name."""
        queue = ExecutionDispatcher._get_queue("simple_prompt_studio")
        assert queue == "celery_executor_simple_prompt_studio"

    def test_dispatch_sends_to_sps_queue(self):
        """``sps_answer_prompt`` is sent once, to the SPS queue."""
        mock_app, result = self._dispatch_via_mock_app(
            "sps_answer_prompt",
            {"tool_settings": {}, "output": {}},
            {"output": {"field": "value"}},
        )

        mock_app.send_task.assert_called_once()
        call_kwargs = mock_app.send_task.call_args
        assert call_kwargs.kwargs.get("queue") == "celery_executor_simple_prompt_studio"
        # Previously bound but never checked: the mocked payload round-trips.
        assert result.success
        assert result.data["output"] == {"field": "value"}

    def test_dispatch_sps_index_to_correct_queue(self):
        """``sps_index`` is sent once, to the SPS queue."""
        mock_app, result = self._dispatch_via_mock_app(
            "sps_index",
            {"output": {}, "file_path": "/tmp/test.pdf"},
            {"output": "indexed"},
        )

        mock_app.send_task.assert_called_once()
        call_kwargs = mock_app.send_task.call_args
        assert call_kwargs.kwargs.get("queue") == "celery_executor_simple_prompt_studio"
        # Previously bound but never checked: the mocked payload round-trips.
        assert result.success
        assert result.data["output"] == "indexed"


# ---------------------------------------------------------------------------
# 4. LegacyExecutor does NOT handle SPS operations
# ---------------------------------------------------------------------------

class TestLegacyExcludesSPS:
    """LegacyExecutor must neither map nor execute the SPS operations."""

    @staticmethod
    def _execute_on_legacy(operation):
        """Run ``operation`` on a freshly registered LegacyExecutor.

        The registry is cleared first, so an unconditional register is
        enough. LegacyExecutor is deliberately left registered afterwards,
        as the original tests did.
        """
        from executor.executors.legacy_executor import LegacyExecutor

        ExecutorRegistry.clear()
        ExecutorRegistry.register(LegacyExecutor)
        executor = ExecutorRegistry.get("legacy")

        ctx = ExecutionContext(
            executor_name="legacy",
            operation=operation,
            run_id="run-1",
            execution_source="tool",
            executor_params={},
        )
        return executor.execute(ctx)

    def test_sps_answer_prompt_not_in_legacy_operation_map(self):
        """``sps_answer_prompt`` is absent from the legacy operation map."""
        from executor.executors.legacy_executor import LegacyExecutor
        assert "sps_answer_prompt" not in LegacyExecutor._OPERATION_MAP

    def test_sps_index_not_in_legacy_operation_map(self):
        """``sps_index`` is absent from the legacy operation map."""
        from executor.executors.legacy_executor import LegacyExecutor
        assert "sps_index" not in LegacyExecutor._OPERATION_MAP

    def test_legacy_returns_failure_for_sps_answer_prompt(self):
        """Executing ``sps_answer_prompt`` on legacy yields a failure result."""
        result = self._execute_on_legacy("sps_answer_prompt")
        assert not result.success
        assert "does not support" in result.error

    def test_legacy_returns_failure_for_sps_index(self):
        """Executing ``sps_index`` on legacy yields a failure result."""
        result = self._execute_on_legacy("sps_index")
        assert not result.success
        assert "does not support" in result.error


# ---------------------------------------------------------------------------
# 5. 
tasks.py log_component for SPS operations +# --------------------------------------------------------------------------- + +class TestTasksLogComponent: + def test_sps_answer_prompt_uses_default_log_component(self): + """SPS operations use the default log_component branch in tasks.py.""" + ctx_dict = { + "executor_name": "simple_prompt_studio", + "operation": "sps_answer_prompt", + "run_id": "run-001", + "execution_source": "tool", + "organization_id": "org-1", + "executor_params": { + "tool_id": "tool-1", + "file_name": "invoice.pdf", + }, + "request_id": "req-1", + "log_events_id": "evt-1", + } + context = ExecutionContext.from_dict(ctx_dict) + params = context.executor_params + + # SPS operations fall through to the default branch + assert context.operation not in ("ide_index", "structure_pipeline", + "table_extract", "smart_table_extract") + component = { + "tool_id": params.get("tool_id", ""), + "run_id": context.run_id, + "doc_name": str(params.get("file_name", "")), + "operation": context.operation, + } + assert component == { + "tool_id": "tool-1", + "run_id": "run-001", + "doc_name": "invoice.pdf", + "operation": "sps_answer_prompt", + } + + def test_sps_index_uses_default_log_component(self): + """SPS index also uses the default log_component branch.""" + ctx_dict = { + "executor_name": "simple_prompt_studio", + "operation": "sps_index", + "run_id": "run-002", + "execution_source": "tool", + "executor_params": { + "tool_id": "tool-2", + "file_name": "contract.pdf", + }, + "request_id": "req-2", + "log_events_id": "evt-2", + } + context = ExecutionContext.from_dict(ctx_dict) + params = context.executor_params + + assert context.operation not in ("ide_index", "structure_pipeline", + "table_extract", "smart_table_extract") + component = { + "tool_id": params.get("tool_id", ""), + "run_id": context.run_id, + "doc_name": str(params.get("file_name", "")), + "operation": context.operation, + } + assert component["operation"] == "sps_index" diff --git 
a/workers/tests/test_sanity_phase6h.py b/workers/tests/test_sanity_phase6h.py new file mode 100644 index 0000000000..8fff6ff941 --- /dev/null +++ b/workers/tests/test_sanity_phase6h.py @@ -0,0 +1,268 @@ +"""Phase 6H Sanity — AgenticPromptStudioExecutor + agentic operations. + +Verifies: +1. All 8 agentic Operation enums exist +2. AGENTIC_EXTRACTION removed from Operation enum +3. Mock AgenticPromptStudioExecutor — registration and all 8 operations +4. Queue routing: executor_name="agentic" → celery_executor_agentic +5. LegacyExecutor does NOT handle any agentic operations +6. Dispatch sends to correct queue +7. Structure tool routes to agentic executor (not legacy) +""" + +from unittest.mock import MagicMock, patch + +import pytest + +from unstract.sdk1.execution.context import ExecutionContext, Operation +from unstract.sdk1.execution.dispatcher import ExecutionDispatcher +from unstract.sdk1.execution.executor import BaseExecutor +from unstract.sdk1.execution.registry import ExecutorRegistry +from unstract.sdk1.execution.result import ExecutionResult + + +AGENTIC_OPERATIONS = [ + "agentic_extract", + "agentic_summarize", + "agentic_uniformize", + "agentic_finalize", + "agentic_generate_prompt", + "agentic_generate_prompt_pipeline", + "agentic_compare", + "agentic_tune_field", +] + + +# --------------------------------------------------------------------------- +# 1. 
# Operation enums
# ---------------------------------------------------------------------------

class TestAgenticOperations:
    """Membership checks on the ``Operation`` enum for agentic operations."""

    @pytest.mark.parametrize("op", AGENTIC_OPERATIONS)
    def test_agentic_operation_enum_exists(self, op):
        """Each listed agentic operation string is a value of ``Operation``."""
        assert op in {member.value for member in Operation}

    def test_agentic_extraction_removed(self):
        """Old AGENTIC_EXTRACTION enum no longer exists."""
        assert not hasattr(Operation, "AGENTIC_EXTRACTION")
        enum_values = {member.value for member in Operation}
        assert "agentic_extraction" not in enum_values


# ---------------------------------------------------------------------------
# 2. Mock AgenticPromptStudioExecutor — registration and all operations
# ---------------------------------------------------------------------------

class TestAgenticExecutorRegistration:
    """Registers a stand-in ``agentic`` executor and runs every operation."""

    def test_mock_agentic_executor_registers_and_routes_all_ops(self):
        """Simulate cloud executor discovery and execution of all 8 ops."""

        @ExecutorRegistry.register
        class MockAgenticExecutor(BaseExecutor):
            """Echoes the operation name back for every known agentic op."""

            _OPERATION_MAP = {op: f"_handle_{op}" for op in AGENTIC_OPERATIONS}

            @property
            def name(self) -> str:
                return "agentic"

            def execute(self, context):
                handler_name = self._OPERATION_MAP.get(context.operation)
                if handler_name:
                    return ExecutionResult(
                        success=True,
                        data={
                            "output": {"operation": context.operation},
                            "metadata": {},
                        },
                    )
                return ExecutionResult.failure(
                    error=f"Unsupported: {context.operation}"
                )

        def run(executor, operation, run_id):
            # Every context in this test differs only in operation and run id.
            return executor.execute(
                ExecutionContext(
                    executor_name="agentic",
                    operation=operation,
                    run_id=run_id,
                    execution_source="tool",
                    executor_params={},
                )
            )

        try:
            assert "agentic" in ExecutorRegistry.list_executors()
            agentic_executor = ExecutorRegistry.get("agentic")
            assert agentic_executor.name == "agentic"

            # Test all 8 operations route successfully
            for op in AGENTIC_OPERATIONS:
                outcome = run(agentic_executor, op, f"run-{op}")
                assert outcome.success, f"Operation {op} failed"
                assert outcome.data["output"]["operation"] == op

            # Rejects unsupported operations
            rejected = run(agentic_executor, "answer_prompt", "run-unsupported")
            assert not rejected.success
        finally:
            # Always drop the mock so it cannot leak into other tests.
            ExecutorRegistry.clear()


# ---------------------------------------------------------------------------
# 3. Queue routing
# ---------------------------------------------------------------------------

class TestAgenticQueueRouting:
    """Queue-routing checks for the ``agentic`` executor name."""

    def test_agentic_routes_to_correct_queue(self):
        """``_get_queue("agentic")`` yields the agentic queue name."""
        assert ExecutionDispatcher._get_queue("agentic") == "celery_executor_agentic"

    @pytest.mark.parametrize("op", AGENTIC_OPERATIONS)
    def test_dispatch_sends_to_agentic_queue(self, op):
        """Each agentic operation is sent once, to the agentic queue."""
        async_result = MagicMock()
        async_result.get.return_value = ExecutionResult(
            success=True, data={"output": {}}
        ).to_dict()
        celery_app = MagicMock()
        celery_app.send_task.return_value = async_result

        context = ExecutionContext(
            executor_name="agentic",
            operation=op,
            run_id="run-1",
            execution_source="tool",
            executor_params={},
        )
        ExecutionDispatcher(celery_app=celery_app).dispatch(context)

        celery_app.send_task.assert_called_once()
        sent = celery_app.send_task.call_args
        assert sent.kwargs.get("queue") == "celery_executor_agentic"


# ---------------------------------------------------------------------------
# 4. 
# LegacyExecutor does NOT handle agentic operations
# ---------------------------------------------------------------------------

class TestLegacyExcludesAgentic:
    """LegacyExecutor must neither map nor execute agentic operations."""

    @staticmethod
    def _execute_on_legacy(operation):
        """Run ``operation`` on a freshly registered LegacyExecutor.

        The registry is cleared first, so an unconditional register is
        enough. LegacyExecutor is deliberately left registered afterwards,
        as the original tests did.
        """
        from executor.executors.legacy_executor import LegacyExecutor

        ExecutorRegistry.clear()
        ExecutorRegistry.register(LegacyExecutor)
        executor = ExecutorRegistry.get("legacy")

        ctx = ExecutionContext(
            executor_name="legacy",
            operation=operation,
            run_id="run-1",
            execution_source="tool",
            executor_params={},
        )
        return executor.execute(ctx)

    @pytest.mark.parametrize("op", AGENTIC_OPERATIONS)
    def test_agentic_op_not_in_legacy_operation_map(self, op):
        """No agentic operation appears in the legacy operation map."""
        from executor.executors.legacy_executor import LegacyExecutor
        assert op not in LegacyExecutor._OPERATION_MAP

    def test_legacy_returns_failure_for_agentic_extract(self):
        """Executing ``agentic_extract`` on legacy yields a failure result."""
        result = self._execute_on_legacy("agentic_extract")
        assert not result.success
        assert "does not support" in result.error

    def test_legacy_returns_failure_for_agentic_summarize(self):
        """Executing ``agentic_summarize`` on legacy yields a failure result."""
        result = self._execute_on_legacy("agentic_summarize")
        assert not result.success
        assert "does not support" in result.error


# ---------------------------------------------------------------------------
# 5. 
# Structure tool routes to agentic executor
# ---------------------------------------------------------------------------

class TestStructureToolAgenticRouting:
    """Checks which executor the structure tool's agentic path dispatches to."""

    @patch("unstract.sdk1.x2txt.X2Text")
    def test_structure_tool_dispatches_agentic_extract(self, mock_x2text_cls):
        """Verify _run_agentic_extraction sends executor_name='agentic'."""
        # Mock X2Text so it doesn't try to call real adapters
        mock_x2text = MagicMock()
        mock_x2text.process.return_value = MagicMock(extracted_text="hello world")
        mock_x2text_cls.return_value = mock_x2text

        from file_processing.structure_tool_task import _run_agentic_extraction

        mock_dispatcher = MagicMock()
        mock_dispatcher.dispatch.return_value = ExecutionResult(
            success=True, data={"output": {"field": "value"}}
        )

        mock_shim = MagicMock()
        mock_shim.platform_api_key = "test-key"

        # The return value is not part of this test's contract; only the
        # context handed to the dispatcher is inspected below.
        _run_agentic_extraction(
            tool_metadata={"name": "test"},
            input_file_path="/tmp/test.pdf",
            output_dir_path="/tmp/output",
            tool_instance_metadata={},
            dispatcher=mock_dispatcher,
            shim=mock_shim,
            platform_helper=MagicMock(),
            file_execution_id="exec-001",
            organization_id="org-001",
            source_file_name="test.pdf",
            fs=MagicMock(),
        )

        # Verify dispatch was called with correct routing
        mock_dispatcher.dispatch.assert_called_once()
        dispatched_ctx = mock_dispatcher.dispatch.call_args[0][0]
        assert dispatched_ctx.executor_name == "agentic"
        assert dispatched_ctx.operation == "agentic_extract"
        assert dispatched_ctx.organization_id == "org-001"


# ---------------------------------------------------------------------------
# 6. 
tasks.py log_component for agentic operations +# --------------------------------------------------------------------------- + +class TestTasksLogComponent: + @pytest.mark.parametrize("op", AGENTIC_OPERATIONS) + def test_agentic_ops_use_default_log_component(self, op): + """Agentic operations fall through to default log_component.""" + ctx_dict = { + "executor_name": "agentic", + "operation": op, + "run_id": "run-001", + "execution_source": "tool", + "executor_params": { + "tool_id": "tool-1", + "file_name": "doc.pdf", + }, + "request_id": "req-1", + "log_events_id": "evt-1", + } + context = ExecutionContext.from_dict(ctx_dict) + + # Agentic ops should NOT match ide_index, structure_pipeline, + # or table_extract/smart_table_extract branches + assert context.operation not in ( + "ide_index", "structure_pipeline", + "table_extract", "smart_table_extract", + ) diff --git a/workers/tests/test_sanity_phase6i.py b/workers/tests/test_sanity_phase6i.py new file mode 100644 index 0000000000..4de0e8f662 --- /dev/null +++ b/workers/tests/test_sanity_phase6i.py @@ -0,0 +1,272 @@ +"""Phase 6I Sanity — Backend Summarizer Migration. + +Verifies: +1. Summarize operation exists and routes through LegacyExecutor +2. Summarize executor_params contract matches _handle_summarize expectations +3. Dispatch routes summarize to celery_executor_legacy queue +4. Summarize result has expected shape (data.data = summary text) +5. 
Full Celery chain for summarize operation +""" + +from unittest.mock import MagicMock, patch + +import pytest + +from unstract.sdk1.execution.context import ExecutionContext, Operation +from unstract.sdk1.execution.dispatcher import ExecutionDispatcher +from unstract.sdk1.execution.registry import ExecutorRegistry +from unstract.sdk1.execution.result import ExecutionResult + + +# Patches +_PATCH_GET_PROMPT_DEPS = ( + "executor.executors.legacy_executor.LegacyExecutor._get_prompt_deps" +) + + +def _register_legacy(): + from executor.executors.legacy_executor import LegacyExecutor + ExecutorRegistry.clear() + ExecutorRegistry.register(LegacyExecutor) + + +# --------------------------------------------------------------------------- +# 1. Summarize operation enum +# --------------------------------------------------------------------------- + +class TestSummarizeOperation: + def test_summarize_enum_exists(self): + assert hasattr(Operation, "SUMMARIZE") + assert Operation.SUMMARIZE.value == "summarize" + + def test_summarize_in_legacy_operation_map(self): + from executor.executors.legacy_executor import LegacyExecutor + assert "summarize" in LegacyExecutor._OPERATION_MAP + + +# --------------------------------------------------------------------------- +# 2. 
# Executor params contract
# ---------------------------------------------------------------------------

class TestSummarizeParamsContract:
    """Pins the key set of the summarize ``executor_params`` payload."""

    def test_summarize_params_match_handler_expectations(self):
        """Verify the params the backend summarizer sends match
        what _handle_summarize expects."""
        # These are the keys the cloud summarizer.py now sends
        backend_params = {
            "llm_adapter_instance_id": "llm-uuid",
            "summarize_prompt": "Summarize the document...",
            "context": "This is the full document text...",
            "prompt_keys": ["invoice_number", "total_amount"],
            "PLATFORM_SERVICE_API_KEY": "platform-key-123",
        }

        # _handle_summarize reads these keys
        required_keys = (
            "llm_adapter_instance_id",
            "summarize_prompt",
            "context",
            "prompt_keys",
            "PLATFORM_SERVICE_API_KEY",
        )
        for key in required_keys:
            assert key in backend_params


# ---------------------------------------------------------------------------
# 3. Queue routing
# ---------------------------------------------------------------------------

class TestSummarizeQueueRouting:
    """Queue-routing checks for summarize on the ``legacy`` executor."""

    def test_summarize_routes_to_legacy_queue(self):
        """Summarize dispatches to celery_executor_legacy (LegacyExecutor)."""
        assert ExecutionDispatcher._get_queue("legacy") == "celery_executor_legacy"

    def test_dispatch_sends_summarize_to_legacy_queue(self):
        """A summarize context is sent once to the legacy queue and the
        mocked result payload comes back from ``dispatch()``."""
        async_result = MagicMock()
        async_result.get.return_value = ExecutionResult(
            success=True, data={"data": "Summary text here"}
        ).to_dict()
        celery_app = MagicMock()
        celery_app.send_task.return_value = async_result

        context = ExecutionContext(
            executor_name="legacy",
            operation="summarize",
            run_id="run-summarize",
            execution_source="ide",
            organization_id="org-1",
            executor_params={
                "llm_adapter_instance_id": "llm-1",
                "summarize_prompt": "Summarize...",
                "context": "Document text",
                "prompt_keys": ["field1"],
                "PLATFORM_SERVICE_API_KEY": "key-1",
            },
        )
        outcome = ExecutionDispatcher(celery_app=celery_app).dispatch(context)

        celery_app.send_task.assert_called_once()
        sent = celery_app.send_task.call_args
        assert sent.kwargs.get("queue") == "celery_executor_legacy"
        assert outcome.success
        assert outcome.data["data"] == "Summary text here"


# ---------------------------------------------------------------------------
# 4. Result shape
# ---------------------------------------------------------------------------

class TestSummarizeResultShape:
    """Runs summarize on LegacyExecutor with its prompt dependencies patched."""

    @patch(_PATCH_GET_PROMPT_DEPS)
    def test_summarize_returns_data_key(self, mock_deps):
        """_handle_summarize returns ExecutionResult with data.data = str."""
        mock_LLM = MagicMock(return_value=MagicMock())

        # Seven patched dependencies, in the order the original test listed
        # them: RetrievalService, PostProcessor, VariableReplacement,
        # JsonRepair, LLM, Embedding, VectorDB.
        deps = [MagicMock() for _ in range(7)]
        deps[4] = mock_LLM
        mock_deps.return_value = tuple(deps)

        context = ExecutionContext(
            executor_name="legacy",
            operation="summarize",
            run_id="run-result-shape",
            execution_source="ide",
            organization_id="org-1",
            executor_params={
                "llm_adapter_instance_id": "llm-1",
                "summarize_prompt": "Summarize the document.",
                "context": "Full document text here.",
                "prompt_keys": ["total"],
                "PLATFORM_SERVICE_API_KEY": "key-1",
            },
        )

        # Mock AnswerPromptService.run_completion
        with patch(
            "executor.executors.answer_prompt.AnswerPromptService.run_completion",
            return_value="This is the summary.",
        ):
            _register_legacy()
            legacy_executor = ExecutorRegistry.get("legacy")
            outcome = legacy_executor.execute(context)

        assert outcome.success
        assert outcome.data["data"] == "This is the summary."

    @patch(_PATCH_GET_PROMPT_DEPS)
    def test_summarize_missing_context_returns_failure(self, mock_deps):
        """An empty ``context`` param produces a failure that mentions it."""
        _register_legacy()
        legacy_executor = ExecutorRegistry.get("legacy")

        outcome = legacy_executor.execute(
            ExecutionContext(
                executor_name="legacy",
                operation="summarize",
                run_id="run-missing-ctx",
                execution_source="ide",
                executor_params={
                    "llm_adapter_instance_id": "llm-1",
                    "summarize_prompt": "Summarize.",
                    "context": "",  # empty
                    "PLATFORM_SERVICE_API_KEY": "key-1",
                },
            )
        )

        assert not outcome.success
        assert "context" in outcome.error.lower()

    @patch(_PATCH_GET_PROMPT_DEPS)
    def test_summarize_missing_llm_returns_failure(self, mock_deps):
        """Missing llm_adapter_instance_id returns failure."""
        _register_legacy()
        legacy_executor = ExecutorRegistry.get("legacy")

        outcome = legacy_executor.execute(
            ExecutionContext(
                executor_name="legacy",
                operation="summarize",
                run_id="run-missing-llm",
                execution_source="ide",
                executor_params={
                    "llm_adapter_instance_id": "",  # empty
                    "summarize_prompt": "Summarize.",
                    "context": "Some text",
                    "PLATFORM_SERVICE_API_KEY": "key-1",
                },
            )
        )

        assert not outcome.success
        assert "llm_adapter_instance_id" in outcome.error.lower()


# ---------------------------------------------------------------------------
# 5. 
# Full Celery chain
# ---------------------------------------------------------------------------

@pytest.fixture
def eager_app():
    """Yield the executor Celery app switched to eager mode.

    The three settings touched here are captured before the change and
    written back after the test has used the app.
    """
    from executor.worker import app

    conf = app.conf
    saved_settings = {
        "task_always_eager": conf.task_always_eager,
        "task_eager_propagates": conf.task_eager_propagates,
        "result_backend": conf.result_backend,
    }
    conf.update(
        task_always_eager=True,
        task_eager_propagates=False,
        result_backend="cache+memory://",
    )
    yield app
    app.conf.update(saved_settings)


class TestSummarizeCeleryChain:
    """Runs summarize through the ``execute_extraction`` task in eager mode."""

    @patch(_PATCH_GET_PROMPT_DEPS)
    def test_summarize_full_celery_chain(self, mock_deps, eager_app):
        """Summarize through full Celery task chain."""
        mock_LLM = MagicMock(return_value=MagicMock())

        # Seven patched dependencies; the LLM class sits in the fifth slot.
        deps = [MagicMock() for _ in range(7)]
        deps[4] = mock_LLM
        mock_deps.return_value = tuple(deps)

        with patch(
            "executor.executors.answer_prompt.AnswerPromptService.run_completion",
            return_value="Celery chain summary.",
        ):
            _register_legacy()

            context = ExecutionContext(
                executor_name="legacy",
                operation="summarize",
                run_id="run-celery-summarize",
                execution_source="ide",
                organization_id="org-1",
                executor_params={
                    "llm_adapter_instance_id": "llm-1",
                    "summarize_prompt": "Summarize.",
                    "context": "Document text for celery chain.",
                    "prompt_keys": ["amount"],
                    "PLATFORM_SERVICE_API_KEY": "key-1",
                },
            )

            # The task takes the serialized context and returns a serialized
            # result, so both ends go through to_dict()/from_dict().
            extraction_task = eager_app.tasks["execute_extraction"]
            serialized = extraction_task.apply(args=[context.to_dict()]).get()
            outcome = ExecutionResult.from_dict(serialized)

        assert outcome.success
        assert outcome.data["data"] == "Celery chain summary."
diff --git a/workers/tests/test_sanity_phase6j.py b/workers/tests/test_sanity_phase6j.py new file mode 100644 index 0000000000..2336b65d05 --- /dev/null +++ b/workers/tests/test_sanity_phase6j.py @@ -0,0 +1,684 @@ +"""Phase 6J — Comprehensive Phase 6 sanity tests. + +Consolidated regression + integration tests for the full Phase 6 +plugin migration. Verifies: + +1. Full Operation enum coverage — every operation has exactly one executor +2. Multi-executor coexistence in ExecutorRegistry +3. End-to-end Celery chain for each cloud executor (mock executors) +4. Cross-cutting highlight plugin works across executors +5. Plugin loader → executor registration → dispatch → result flow +6. Queue routing for all executor names +7. Graceful degradation when cloud plugins missing +8. tasks.py log_component for all operation types +""" + +from unittest.mock import MagicMock, patch + +import pytest + +from unstract.sdk1.execution.context import ExecutionContext, Operation +from unstract.sdk1.execution.dispatcher import ExecutionDispatcher +from unstract.sdk1.execution.executor import BaseExecutor +from unstract.sdk1.execution.orchestrator import ExecutionOrchestrator +from unstract.sdk1.execution.registry import ExecutorRegistry +from unstract.sdk1.execution.result import ExecutionResult + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +@pytest.fixture(autouse=True) +def _clean_registry(): + ExecutorRegistry.clear() + yield + ExecutorRegistry.clear() + + +@pytest.fixture +def eager_app(): + """Configure executor Celery app for eager-mode testing.""" + from executor.worker import app + + original = { + "task_always_eager": app.conf.task_always_eager, + "task_eager_propagates": app.conf.task_eager_propagates, + "result_backend": app.conf.result_backend, + } + app.conf.update( + task_always_eager=True, + task_eager_propagates=False, + 
result_backend="cache+memory://", + ) + yield app + app.conf.update(original) + + +def _register_legacy(): + from executor.executors.legacy_executor import LegacyExecutor + ExecutorRegistry.register(LegacyExecutor) + + +# Mock cloud executors for multi-executor tests +def _register_mock_cloud_executors(): + """Register mock cloud executors alongside LegacyExecutor.""" + + @ExecutorRegistry.register + class MockTableExecutor(BaseExecutor): + @property + def name(self) -> str: + return "table" + + def execute(self, context): + if context.operation != "table_extract": + return ExecutionResult.failure( + error=f"Unsupported: {context.operation}" + ) + return ExecutionResult( + success=True, + data={"output": "table_data", "metadata": {}}, + ) + + @ExecutorRegistry.register + class MockSmartTableExecutor(BaseExecutor): + @property + def name(self) -> str: + return "smart_table" + + def execute(self, context): + if context.operation != "smart_table_extract": + return ExecutionResult.failure( + error=f"Unsupported: {context.operation}" + ) + return ExecutionResult( + success=True, + data={"output": "smart_table_data", "metadata": {}}, + ) + + @ExecutorRegistry.register + class MockSPSExecutor(BaseExecutor): + @property + def name(self) -> str: + return "simple_prompt_studio" + + def execute(self, context): + if context.operation not in ("sps_answer_prompt", "sps_index"): + return ExecutionResult.failure( + error=f"Unsupported: {context.operation}" + ) + return ExecutionResult( + success=True, + data={"output": f"sps_{context.operation}", "metadata": {}}, + ) + + @ExecutorRegistry.register + class MockAgenticExecutor(BaseExecutor): + _OPS = { + "agentic_extract", "agentic_summarize", "agentic_uniformize", + "agentic_finalize", "agentic_generate_prompt", + "agentic_generate_prompt_pipeline", "agentic_compare", + "agentic_tune_field", + } + + @property + def name(self) -> str: + return "agentic" + + def execute(self, context): + if context.operation not in self._OPS: + 
return ExecutionResult.failure( + error=f"Unsupported: {context.operation}" + ) + return ExecutionResult( + success=True, + data={"output": f"agentic_{context.operation}", "metadata": {}}, + ) + + +# --------------------------------------------------------------------------- +# 1. Full Operation enum coverage — every operation has exactly one executor +# --------------------------------------------------------------------------- + +# Map of every Operation value to the executor that handles it +OPERATION_TO_EXECUTOR = { + # LegacyExecutor (OSS) + "extract": "legacy", + "index": "legacy", + "answer_prompt": "legacy", + "single_pass_extraction": "legacy", + "summarize": "legacy", + "ide_index": "legacy", + "structure_pipeline": "legacy", + # Cloud executors + "table_extract": "table", + "smart_table_extract": "smart_table", + "sps_answer_prompt": "simple_prompt_studio", + "sps_index": "simple_prompt_studio", + "agentic_extract": "agentic", + "agentic_summarize": "agentic", + "agentic_uniformize": "agentic", + "agentic_finalize": "agentic", + "agentic_generate_prompt": "agentic", + "agentic_generate_prompt_pipeline": "agentic", + "agentic_compare": "agentic", + "agentic_tune_field": "agentic", +} + + +class TestOperationEnumCoverage: + def test_every_operation_is_mapped(self): + """Every Operation enum value has an assigned executor.""" + for op in Operation: + assert op.value in OPERATION_TO_EXECUTOR, ( + f"Operation {op.value} not mapped to any executor" + ) + + def test_no_extra_mappings(self): + """No stale mappings for removed operations.""" + valid_ops = {op.value for op in Operation} + for mapped_op in OPERATION_TO_EXECUTOR: + assert mapped_op in valid_ops, ( + f"Mapped operation '{mapped_op}' not in Operation enum" + ) + + def test_operation_count(self): + """Verify total operation count matches expectations.""" + assert len(Operation) == 19 # 7 legacy + 2 table + 2 sps + 8 agentic + + def test_legacy_operations_in_operation_map(self): + """All legacy 
operations are in LegacyExecutor._OPERATION_MAP.""" + from executor.executors.legacy_executor import LegacyExecutor + + for op_value, executor_name in OPERATION_TO_EXECUTOR.items(): + if executor_name == "legacy": + assert op_value in LegacyExecutor._OPERATION_MAP, ( + f"Legacy operation {op_value} missing from _OPERATION_MAP" + ) + + def test_cloud_operations_not_in_legacy_map(self): + """Cloud operations are NOT in LegacyExecutor._OPERATION_MAP.""" + from executor.executors.legacy_executor import LegacyExecutor + + for op_value, executor_name in OPERATION_TO_EXECUTOR.items(): + if executor_name != "legacy": + assert op_value not in LegacyExecutor._OPERATION_MAP, ( + f"Cloud operation {op_value} should NOT be in legacy map" + ) + + +# --------------------------------------------------------------------------- +# 2. Multi-executor coexistence in registry +# --------------------------------------------------------------------------- + +class TestMultiExecutorCoexistence: + def test_all_five_executors_registered(self): + """Legacy + 4 cloud executors all coexist in registry.""" + _register_legacy() + _register_mock_cloud_executors() + + executors = ExecutorRegistry.list_executors() + assert "legacy" in executors + assert "table" in executors + assert "smart_table" in executors + assert "simple_prompt_studio" in executors + assert "agentic" in executors + assert len(executors) == 5 + + def test_each_executor_has_correct_name(self): + _register_legacy() + _register_mock_cloud_executors() + + for name in ["legacy", "table", "smart_table", "simple_prompt_studio", "agentic"]: + executor = ExecutorRegistry.get(name) + assert executor.name == name + + def test_wrong_executor_rejects_operation(self): + """Dispatching a table operation to legacy returns failure.""" + _register_legacy() + _register_mock_cloud_executors() + + legacy = ExecutorRegistry.get("legacy") + ctx = ExecutionContext( + executor_name="legacy", + operation="table_extract", + run_id="run-1", + 
execution_source="tool", + ) + result = legacy.execute(ctx) + assert not result.success + assert "does not support" in result.error + + def test_correct_executor_handles_operation(self): + """Each operation routes to the right executor.""" + _register_legacy() + _register_mock_cloud_executors() + + test_cases = [ + ("table", "table_extract"), + ("smart_table", "smart_table_extract"), + ("simple_prompt_studio", "sps_answer_prompt"), + ("simple_prompt_studio", "sps_index"), + ("agentic", "agentic_extract"), + ("agentic", "agentic_compare"), + ] + for executor_name, operation in test_cases: + executor = ExecutorRegistry.get(executor_name) + ctx = ExecutionContext( + executor_name=executor_name, + operation=operation, + run_id=f"run-{operation}", + execution_source="tool", + ) + result = executor.execute(ctx) + assert result.success, f"{executor_name}/{operation} failed" + + +# --------------------------------------------------------------------------- +# 3. End-to-end Celery chain for cloud executors +# --------------------------------------------------------------------------- + +class TestCeleryChainCloudExecutors: + def test_table_extract_celery_chain(self, eager_app): + """TABLE extraction through full Celery task chain.""" + _register_legacy() + _register_mock_cloud_executors() + + ctx = ExecutionContext( + executor_name="table", + operation="table_extract", + run_id="run-celery-table", + execution_source="tool", + ) + task = eager_app.tasks["execute_extraction"] + result_dict = task.apply(args=[ctx.to_dict()]).get() + result = ExecutionResult.from_dict(result_dict) + + assert result.success + assert result.data["output"] == "table_data" + + def test_smart_table_extract_celery_chain(self, eager_app): + """SMART TABLE extraction through full Celery task chain.""" + _register_legacy() + _register_mock_cloud_executors() + + ctx = ExecutionContext( + executor_name="smart_table", + operation="smart_table_extract", + run_id="run-celery-smart-table", + 
execution_source="tool", + ) + task = eager_app.tasks["execute_extraction"] + result_dict = task.apply(args=[ctx.to_dict()]).get() + result = ExecutionResult.from_dict(result_dict) + + assert result.success + assert result.data["output"] == "smart_table_data" + + def test_sps_answer_prompt_celery_chain(self, eager_app): + """SPS answer_prompt through full Celery task chain.""" + _register_legacy() + _register_mock_cloud_executors() + + ctx = ExecutionContext( + executor_name="simple_prompt_studio", + operation="sps_answer_prompt", + run_id="run-celery-sps", + execution_source="tool", + ) + task = eager_app.tasks["execute_extraction"] + result_dict = task.apply(args=[ctx.to_dict()]).get() + result = ExecutionResult.from_dict(result_dict) + + assert result.success + + def test_agentic_extract_celery_chain(self, eager_app): + """Agentic extraction through full Celery task chain.""" + _register_legacy() + _register_mock_cloud_executors() + + ctx = ExecutionContext( + executor_name="agentic", + operation="agentic_extract", + run_id="run-celery-agentic", + execution_source="tool", + ) + task = eager_app.tasks["execute_extraction"] + result_dict = task.apply(args=[ctx.to_dict()]).get() + result = ExecutionResult.from_dict(result_dict) + + assert result.success + + def test_unregistered_executor_returns_failure(self, eager_app): + """Dispatching to unregistered executor returns failure.""" + _register_legacy() + # Don't register cloud executors + + ctx = ExecutionContext( + executor_name="table", + operation="table_extract", + run_id="run-missing", + execution_source="tool", + ) + task = eager_app.tasks["execute_extraction"] + result_dict = task.apply(args=[ctx.to_dict()]).get() + result = ExecutionResult.from_dict(result_dict) + + assert not result.success + assert "table" in result.error.lower() + + +# --------------------------------------------------------------------------- +# 4. 
# Cross-cutting highlight plugin across executors
# ---------------------------------------------------------------------------

class TestCrossCuttingHighlight:
    """Highlight plugin lookup through ``ExecutorPluginLoader``.

    Both tests resolve plugins while ``importlib.metadata.entry_points``
    is patched, so the loader is reset after each test to keep whatever
    was discovered under the patch from reaching later tests.
    """

    @pytest.fixture(autouse=True)
    def _reset_plugin_loader(self):
        """Clear the plugin loader once the test (and its patch) is done."""
        from executor.executors.plugins.loader import ExecutorPluginLoader

        yield
        # NOTE(review): presumably the loader caches discovered plugins
        # (the tests call clear() before get()) — confirm against the loader.
        ExecutorPluginLoader.clear()

    @patch("importlib.metadata.entry_points", return_value=[])
    def test_highlight_plugin_not_installed_no_error(self, _mock_eps):
        """When highlight plugin not installed, extraction still works."""
        from executor.executors.plugins.loader import ExecutorPluginLoader

        ExecutorPluginLoader.clear()
        # No entry points are exposed, so the lookup yields None, not an error.
        assert ExecutorPluginLoader.get("highlight-data") is None

    def test_mock_highlight_plugin_shared_across_executors(self):
        """Multiple executors can use the same highlight plugin instance."""
        from executor.executors.plugins.loader import ExecutorPluginLoader

        class FakeHighlight:
            """Minimal stand-in returning fixed highlight/confidence data."""

            def __init__(self, **kwargs):
                self.kwargs = kwargs

            def run(self, response, **kwargs):
                return {"highlighted": True}

            def get_highlight_data(self):
                return {"lines": [1, 2, 3]}

            def get_confidence_data(self):
                return {"confidence": 0.95}

        # Entry point whose load() hands back the fake plugin class.
        fake_ep = MagicMock()
        fake_ep.name = "highlight-data"
        fake_ep.load.return_value = FakeHighlight

        with patch(
            "importlib.metadata.entry_points",
            return_value=[fake_ep],
        ):
            ExecutorPluginLoader.clear()
            cls = ExecutorPluginLoader.get("highlight-data")
            assert cls is FakeHighlight

            # Both legacy and agentic contexts can create instances
            legacy_hl = cls(file_path="/tmp/doc.txt", execution_source="ide")
            agentic_hl = cls(file_path="/tmp/other.txt", execution_source="tool")

            assert legacy_hl.get_highlight_data() == {"lines": [1, 2, 3]}
            assert agentic_hl.get_confidence_data() == {"confidence": 0.95}


# ---------------------------------------------------------------------------
# 5.
# Plugin loader → registration → dispatch → result flow
# ---------------------------------------------------------------------------

class TestPluginDiscoveryToDispatchFlow:
    """Registers an executor, then dispatches to it through a mocked Celery app."""

    def test_full_discovery_to_dispatch_flow(self):
        """Simulate: entry point discovery → register → dispatch → result."""

        # Step 1: "Discover" a cloud executor via entry point
        @ExecutorRegistry.register
        class DiscoveredExecutor(BaseExecutor):
            @property
            def name(self):
                return "discovered"

            def execute(self, context):
                return ExecutionResult(
                    success=True,
                    data={"output": "discovered_result"},
                )

        # Step 2: Verify registration
        assert "discovered" in ExecutorRegistry.list_executors()

        # Step 3: Dispatch via mock Celery; send_task returns an async
        # result whose get() yields a serialized successful result.
        serialized = ExecutionResult(
            success=True, data={"output": "discovered_result"}
        ).to_dict()
        async_result = MagicMock()
        async_result.get.return_value = serialized
        celery_app = MagicMock()
        celery_app.send_task.return_value = async_result

        ctx = ExecutionContext(
            executor_name="discovered",
            operation="custom_op",
            run_id="run-flow",
            execution_source="tool",
        )
        outcome = ExecutionDispatcher(celery_app=celery_app).dispatch(ctx)

        # Step 4: Verify result
        assert outcome.success
        assert outcome.data["output"] == "discovered_result"

        # Step 5: Verify queue routing
        send_call = celery_app.send_task.call_args
        assert send_call.kwargs["queue"] == "celery_executor_discovered"


# ---------------------------------------------------------------------------
# 6.
# Queue routing for all executor names
# ---------------------------------------------------------------------------

# Expected Celery queue name for each executor name.
EXECUTOR_QUEUE_MAP = {
    "legacy": "celery_executor_legacy",
    "table": "celery_executor_table",
    "smart_table": "celery_executor_smart_table",
    "simple_prompt_studio": "celery_executor_simple_prompt_studio",
    "agentic": "celery_executor_agentic",
}


class TestQueueRoutingAllExecutors:
    """Checks ``ExecutionDispatcher._get_queue`` for every known executor."""

    @pytest.mark.parametrize(
        "executor_name,expected_queue",
        list(EXECUTOR_QUEUE_MAP.items()),
    )
    def test_queue_name_for_executor(self, executor_name, expected_queue):
        queue = ExecutionDispatcher._get_queue(executor_name)
        assert queue == expected_queue


# ---------------------------------------------------------------------------
# 7. Graceful degradation when cloud plugins missing
# ---------------------------------------------------------------------------

class TestGracefulDegradation:
    """Behaviour when only the legacy executor is registered."""

    def test_legacy_works_without_cloud_executors(self, eager_app):
        """Legacy operations work even when no cloud executors installed."""
        _register_legacy()

        # Only legacy should be in registry
        assert ExecutorRegistry.list_executors() == ["legacy"]

        # A legacy context can still be built; it is not executed here
        # because the handler would fail without mocks.
        params = {
            "tool_id": "t-1",
            "file_name": "test.pdf",
            "file_hash": "abc",
            "PLATFORM_SERVICE_API_KEY": "key",
        }
        ctx = ExecutionContext(
            executor_name="legacy",
            operation="extract",
            run_id="run-degrade",
            execution_source="tool",
            executor_params=params,
        )
        # Registry lookup must still resolve the legacy executor.
        legacy = ExecutorRegistry.get("legacy")
        assert legacy is not None
        assert legacy.name == "legacy"

    def test_cloud_op_on_legacy_returns_meaningful_error(self):
        """Attempting a cloud operation on legacy gives clear error."""
        _register_legacy()
        legacy = ExecutorRegistry.get("legacy")

        cloud_ops = [
            "table_extract",
            "smart_table_extract",
            "sps_answer_prompt",
            "agentic_extract",
        ]
        for cloud_op in cloud_ops:
            outcome = legacy.execute(
                ExecutionContext(
                    executor_name="legacy",
                    operation=cloud_op,
                    run_id=f"run-{cloud_op}",
                    execution_source="tool",
                )
            )
            assert not outcome.success
            assert "does not support" in outcome.error

    def test_missing_executor_via_orchestrator(self):
        """Orchestrator returns failure for unregistered executor."""
        _register_legacy()

        ctx = ExecutionContext(
            executor_name="table",
            operation="table_extract",
            run_id="run-no-table",
            execution_source="tool",
        )
        outcome = ExecutionOrchestrator().execute(ctx)
        assert not outcome.success
        assert "table" in outcome.error.lower()


# ---------------------------------------------------------------------------
# 8. tasks.py log_component for all operation types
# ---------------------------------------------------------------------------

class TestLogComponentAllOperations:
    """Verify tasks.py log_component builder handles all operation types."""

    def _build_log_component(self, operation, executor_params=None):
        """Simulate the tasks.py log_component logic.

        Builds a context from a dict and returns the four-key component
        (``tool_id``, ``run_id``, ``doc_name``, ``operation``). Falls back
        to a default ``tool_id``/``file_name`` pair when no (or empty)
        ``executor_params`` are supplied.
        """
        params = executor_params or {
            "tool_id": "t-1",
            "file_name": "doc.pdf",
        }
        ctx = ExecutionContext.from_dict({
            "executor_name": "legacy",
            "operation": operation,
            "run_id": "run-log",
            "execution_source": "tool",
            "executor_params": params,
            "request_id": "req-1",
            "log_events_id": "evt-1",
        })

        op = ctx.operation
        if op == "ide_index":
            # Identifiers are nested under "extract_params".
            nested = params.get("extract_params", {})
            tool_id = nested.get("tool_id", "")
            doc_name = nested.get("file_name", "")
        elif op == "structure_pipeline":
            # Tool id and file name come from two different sub-dicts.
            tool_id = params.get("answer_params", {}).get("tool_id", "")
            doc_name = params.get("pipeline_options", {}).get(
                "source_file_name", ""
            )
        else:
            # Table operations and all remaining operations read the
            # top-level params in exactly the same way.
            tool_id = params.get("tool_id", "")
            doc_name = params.get("file_name", "")

        return {
            "tool_id": tool_id,
            "run_id": ctx.run_id,
            "doc_name": str(doc_name),
            "operation": op,
        }

    def test_ide_index_extracts_nested_params(self):
        component = self._build_log_component("ide_index", {
            "extract_params": {"tool_id": "t-nested", "file_name": "nested.pdf"},
        })
        assert component["tool_id"] == "t-nested"
        assert component["doc_name"] == "nested.pdf"

    def test_structure_pipeline_extracts_nested_params(self):
        component = self._build_log_component("structure_pipeline", {
            "answer_params": {"tool_id": "t-pipe"},
            "pipeline_options": {"source_file_name": "pipe.pdf"},
        })
        assert component["tool_id"] == "t-pipe"
        assert component["doc_name"] == "pipe.pdf"

    def test_table_extract_uses_direct_params(self):
        component = self._build_log_component("table_extract")
        assert component["tool_id"] == "t-1"
        assert component["operation"] == "table_extract"

    def test_smart_table_extract_uses_direct_params(self):
        component = self._build_log_component("smart_table_extract")
        assert component["operation"] == "smart_table_extract"

    @pytest.mark.parametrize("op", [
        "extract", "index", "answer_prompt", "single_pass_extraction",
        "summarize", "sps_answer_prompt", "sps_index",
        "agentic_extract", "agentic_summarize", "agentic_compare",
    ])
    def test_default_branch_for_standard_ops(self, op):
        component = self._build_log_component(op)
        assert component["tool_id"] == "t-1"
        assert component["doc_name"] == "doc.pdf"
        assert component["operation"] == op


# ---------------------------------------------------------------------------
# 9.
ExecutionResult serialization round-trip +# --------------------------------------------------------------------------- + +class TestResultRoundTrip: + def test_success_result_round_trip(self): + original = ExecutionResult( + success=True, + data={"output": {"field": "value"}, "metadata": {"tokens": 100}}, + ) + restored = ExecutionResult.from_dict(original.to_dict()) + assert restored.success == original.success + assert restored.data == original.data + + def test_failure_result_round_trip(self): + original = ExecutionResult.failure(error="Something went wrong") + restored = ExecutionResult.from_dict(original.to_dict()) + assert not restored.success + assert restored.error == "Something went wrong" + + def test_context_round_trip(self): + original = ExecutionContext( + executor_name="agentic", + operation="agentic_extract", + run_id="run-rt", + execution_source="tool", + organization_id="org-1", + executor_params={"key": "value"}, + log_events_id="evt-1", + ) + restored = ExecutionContext.from_dict(original.to_dict()) + assert restored.executor_name == "agentic" + assert restored.operation == "agentic_extract" + assert restored.organization_id == "org-1" + assert restored.executor_params == {"key": "value"} + assert restored.log_events_id == "evt-1" diff --git a/workers/tests/test_usage.py b/workers/tests/test_usage.py new file mode 100644 index 0000000000..2fecc76713 --- /dev/null +++ b/workers/tests/test_usage.py @@ -0,0 +1,312 @@ +"""Phase 2G — Usage tracking tests. + +Verifies: +1. UsageHelper.push_usage_data wraps Audit correctly +2. Invalid kwargs returns False +3. Invalid platform_api_key returns False +4. Audit exceptions are caught and return False +5. format_float_positional formats correctly +6. SDK1 adapters already push usage (integration check) +7. 
answer_prompt handler returns metrics in ExecutionResult +""" + +from unittest.mock import MagicMock, patch + + +from executor.executors.usage import UsageHelper + + +# --------------------------------------------------------------------------- +# 1. push_usage_data success +# --------------------------------------------------------------------------- + + +class TestPushUsageData: + @patch("unstract.sdk1.audit.Audit") + def test_push_success(self, mock_audit_cls): + """Successful push returns True and calls Audit.""" + mock_audit = MagicMock() + mock_audit_cls.return_value = mock_audit + + result = UsageHelper.push_usage_data( + event_type="llm", + kwargs={"run_id": "run-001", "execution_id": "exec-001"}, + platform_api_key="test-key", + token_counter=MagicMock(), + model_name="gpt-4", + ) + + assert result is True + mock_audit.push_usage_data.assert_called_once() + call_kwargs = mock_audit.push_usage_data.call_args + assert call_kwargs.kwargs["platform_api_key"] == "test-key" + assert call_kwargs.kwargs["model_name"] == "gpt-4" + assert call_kwargs.kwargs["event_type"] == "llm" + + @patch("unstract.sdk1.audit.Audit") + def test_push_passes_token_counter(self, mock_audit_cls): + """Token counter is passed through to Audit.""" + mock_audit = MagicMock() + mock_audit_cls.return_value = mock_audit + mock_counter = MagicMock() + + UsageHelper.push_usage_data( + event_type="embedding", + kwargs={"run_id": "run-002"}, + platform_api_key="key-2", + token_counter=mock_counter, + ) + + call_kwargs = mock_audit.push_usage_data.call_args + assert call_kwargs.kwargs["token_counter"] is mock_counter + + +# --------------------------------------------------------------------------- +# 2. 
# Invalid kwargs
# ---------------------------------------------------------------------------


class TestPushValidation:
    """``push_usage_data`` returns False for unusable ``kwargs`` values."""

    @staticmethod
    def _push_with_kwargs(value):
        """Call ``push_usage_data`` with a fixed event type and API key."""
        return UsageHelper.push_usage_data(
            event_type="llm",
            kwargs=value,
            platform_api_key="key",
        )

    def test_none_kwargs_returns_false(self):
        assert self._push_with_kwargs(None) is False

    def test_empty_kwargs_returns_false(self):
        assert self._push_with_kwargs({}) is False

    def test_non_dict_kwargs_returns_false(self):
        assert self._push_with_kwargs("not a dict") is False


# ---------------------------------------------------------------------------
# 3. Invalid platform_api_key
# ---------------------------------------------------------------------------


class TestPushApiKeyValidation:
    """``push_usage_data`` returns False for unusable API key values."""

    @staticmethod
    def _push_with_key(value):
        """Call ``push_usage_data`` with a fixed event type and kwargs."""
        return UsageHelper.push_usage_data(
            event_type="llm",
            kwargs={"run_id": "r1"},
            platform_api_key=value,
        )

    def test_none_key_returns_false(self):
        assert self._push_with_key(None) is False

    def test_empty_key_returns_false(self):
        assert self._push_with_key("") is False

    def test_non_string_key_returns_false(self):
        assert self._push_with_key(12345) is False


# ---------------------------------------------------------------------------
# 4.
# Audit exceptions are caught
# ---------------------------------------------------------------------------


class TestPushErrorHandling:
    """Failures raised by the mocked Audit class surface as a False return."""

    @patch("unstract.sdk1.audit.Audit")
    def test_audit_exception_returns_false(self, mock_audit_cls):
        """Audit errors are caught and return False."""
        failing_audit = MagicMock()
        failing_audit.push_usage_data.side_effect = Exception("Network error")
        mock_audit_cls.return_value = failing_audit

        pushed = UsageHelper.push_usage_data(
            event_type="llm",
            kwargs={"run_id": "r1"},
            platform_api_key="key",
            token_counter=MagicMock(),
        )

        assert pushed is False

    @patch("unstract.sdk1.audit.Audit")
    def test_import_error_returns_false(self, mock_audit_cls):
        """Import errors are caught gracefully."""
        # Constructing the patched Audit class raises ImportError.
        mock_audit_cls.side_effect = ImportError("no module")

        pushed = UsageHelper.push_usage_data(
            event_type="llm",
            kwargs={"run_id": "r1"},
            platform_api_key="key",
        )

        assert pushed is False


# ---------------------------------------------------------------------------
# 5. format_float_positional
# ---------------------------------------------------------------------------


class TestFormatFloat:
    """String output of ``UsageHelper.format_float_positional``."""

    def test_normal_float(self):
        formatted = UsageHelper.format_float_positional(0.0001234)
        assert formatted == "0.0001234"

    def test_trailing_zeros_removed(self):
        formatted = UsageHelper.format_float_positional(1.50)
        assert formatted == "1.5"

    def test_integer_value(self):
        formatted = UsageHelper.format_float_positional(42.0)
        assert formatted == "42"

    def test_zero(self):
        formatted = UsageHelper.format_float_positional(0.0)
        assert formatted == "0"

    def test_small_value(self):
        formatted = UsageHelper.format_float_positional(0.00000001)
        assert formatted == "0.00000001"

    def test_custom_precision(self):
        formatted = UsageHelper.format_float_positional(1.123456789, precision=3)
        assert formatted == "1.123"


# ---------------------------------------------------------------------------
# 6.
# SDK1 adapters already push usage
# ---------------------------------------------------------------------------


class TestAdapterUsageTracking:
    def test_llm_calls_audit_push(self):
        """Verify the LLM adapter imports and calls Audit.push_usage_data.

        This is a static analysis check — we verify the SDK1 LLM module
        references Audit.push_usage_data, confirming adapters handle
        usage tracking internally.
        """
        import inspect

        from unstract.sdk1.llm import LLM

        source = inspect.getsource(LLM)
        assert "push_usage_data" in source
        assert "Audit" in source


# ---------------------------------------------------------------------------
# 7. answer_prompt handler returns metrics
# ---------------------------------------------------------------------------


class TestMetricsInResult:
    """Checks that an ``answer_prompt`` result carries per-prompt metrics."""

    @patch(
        "unstract.sdk1.utils.indexing.IndexingUtils.generate_index_key",
        return_value="doc-id-test",
    )
    @patch(
        "executor.executors.legacy_executor.LegacyExecutor._get_prompt_deps"
    )
    @patch("executor.executors.legacy_executor.ExecutorToolShim")
    def test_answer_prompt_returns_metrics(
        self, mock_shim_cls, mock_get_deps, _mock_idx
    ):
        """answer_prompt result includes metrics dict.

        The executor registry is cleared on entry and again in a
        ``finally`` block, so a failing assertion or an exception from
        ``execute`` cannot leave the legacy executor registered for
        later tests.
        """
        from unstract.sdk1.execution.context import ExecutionContext
        from unstract.sdk1.execution.registry import ExecutorRegistry

        ExecutorRegistry.clear()
        try:
            from executor.executors.legacy_executor import LegacyExecutor

            # Guard kept from the original: the import above may already
            # have registered the executor.
            if "legacy" not in ExecutorRegistry.list_executors():
                ExecutorRegistry.register(LegacyExecutor)

            executor = ExecutorRegistry.get("legacy")

            # Mock all dependencies
            mock_llm = MagicMock()
            mock_llm.get_metrics.return_value = {"total_tokens": 100}
            mock_llm.get_usage_reason.return_value = "extraction"
            mock_llm.complete.return_value = {
                "response": MagicMock(text="test answer"),
                "highlight_data": [],
                "confidence_data": None,
                "word_confidence_data": None,
                "line_numbers": [],
                "whisper_hash": "",
            }

            mock_llm_cls = MagicMock(return_value=mock_llm)
            mock_index = MagicMock()
            mock_index.return_value.generate_index_key.return_value = "doc-123"

            mock_get_deps.return_value = (
                MagicMock(),  # AnswerPromptService — use real for construct
                MagicMock(),  # RetrievalService
                MagicMock(),  # VariableReplacementService
                mock_index,  # Index
                mock_llm_cls,  # LLM
                MagicMock(),  # EmbeddingCompat
                MagicMock(),  # VectorDB
            )

            # Patch AnswerPromptService methods at their real location
            with patch(
                "executor.executors.answer_prompt.AnswerPromptService.extract_variable",
                return_value="test prompt",
            ), patch(
                "executor.executors.answer_prompt.AnswerPromptService.construct_and_run_prompt",
                return_value="test answer",
            ):
                ctx = ExecutionContext(
                    executor_name="legacy",
                    operation="answer_prompt",
                    run_id="run-metrics-001",
                    execution_source="tool",
                    organization_id="org-test",
                    request_id="req-metrics-001",
                    executor_params={
                        "tool_settings": {},
                        "outputs": [
                            {
                                "name": "field1",
                                "prompt": "What is X?",
                                "chunk-size": 512,
                                "chunk-overlap": 64,
                                "vector-db": "vdb-1",
                                "embedding": "emb-1",
                                "x2text_adapter": "x2t-1",
                                "llm": "llm-1",
                                "type": "text",
                                "retrieval-strategy": "simple",
                                "similarity-top-k": 5,
                            },
                        ],
                        "tool_id": "tool-1",
                        "file_hash": "hash123",
                        "file_path": "/tmp/test.txt",
                        "file_name": "test.txt",
                        "PLATFORM_SERVICE_API_KEY": "test-key",
                    },
                )
                result = executor.execute(ctx)

                assert result.success is True
                assert "metrics" in result.data
                assert "field1" in result.data["metrics"]
        finally:
            # Always leave the registry empty, even when the test fails.
            ExecutorRegistry.clear()