From 8e01d870a4188b476280340fc66880d017816597 Mon Sep 17 00:00:00 2001
From: SentienceDEV
Date: Sat, 28 Mar 2026 18:35:26 -0700
Subject: [PATCH 1/4] stabilize snapshot calls; boilerplate for dual agent

---
 predicate/__init__.py                      |   2 +
 predicate/agents/__init__.py               |   9 +
 predicate/agents/agent_factory.py          | 371 +++++++++++++++++++++
 predicate/agents/planner_executor_agent.py | 144 +++++---
 predicate/llm_provider.py                  |  70 ++++
 tests/test_agent_factory.py                | 306 +++++++++++++++++
 tests/test_ollama_provider.py              |  94 ++++++
 traces/test-run.jsonl                      |   5 +
 8 files changed, 951 insertions(+), 50 deletions(-)
 create mode 100644 predicate/agents/agent_factory.py
 create mode 100644 tests/test_agent_factory.py
 create mode 100644 tests/test_ollama_provider.py

diff --git a/predicate/__init__.py b/predicate/__init__.py
index 4dca10f..7898315 100644
--- a/predicate/__init__.py
+++ b/predicate/__init__.py
@@ -84,6 +84,7 @@
     LocalLLMProvider,
     LocalVisionLLMProvider,
     MLXVLMProvider,
+    OllamaProvider,
     OpenAIProvider,
 )
 from .models import (  # Agent Layer Models
@@ -266,6 +267,7 @@
     "LocalLLMProvider",
     "LocalVisionLLMProvider",
     "MLXVLMProvider",
+    "OllamaProvider",
     "PredicateAgent",
     "PredicateAgentAsync",
     "SentienceAgent",
diff --git a/predicate/agents/__init__.py b/predicate/agents/__init__.py
index 6029f11..fa895ad 100644
--- a/predicate/agents/__init__.py
+++ b/predicate/agents/__init__.py
@@ -61,6 +61,11 @@
     validate_plan_smoothness,
 )
 from .recovery import RecoveryCheckpoint, RecoveryState
+from .agent_factory import (
+    ConfigPreset,
+    create_planner_executor_agent,
+    get_config_preset,
+)
 
 __all__ = [
     # Automation Task
@@ -104,4 +109,8 @@
     # Recovery
     "RecoveryCheckpoint",
     "RecoveryState",
+    # Agent Factory
+    "ConfigPreset",
+    "create_planner_executor_agent",
+    "get_config_preset",
 ]
diff --git a/predicate/agents/agent_factory.py b/predicate/agents/agent_factory.py
new file mode 100644
index 0000000..d637b81
--- /dev/null
+++ b/predicate/agents/agent_factory.py
@@ -0,0 +1,371 @@
+"""
+Agent factory for simplified agent creation.
+
+Provides convenient factory functions to create PlannerExecutorAgent instances
+with sensible defaults, auto-provider detection, and auto-tracer creation.
+ +This module reduces boilerplate for common use cases: +- Local LLM via Ollama +- Cloud LLM via OpenAI/Anthropic +- Mixed configurations (cloud planner, local executor) +""" + +from __future__ import annotations + +import os +import uuid +from enum import Enum +from pathlib import Path +from typing import TYPE_CHECKING, Literal + +from ..llm_provider import AnthropicProvider, LLMProvider, OllamaProvider, OpenAIProvider +from ..tracer_factory import create_tracer +from ..tracing import JsonlTraceSink, Tracer +from .planner_executor_agent import ( + IntentHeuristics, + PlannerExecutorAgent, + PlannerExecutorConfig, + RetryConfig, +) + +if TYPE_CHECKING: + from collections.abc import Callable + + from ..models import Snapshot + + +# --------------------------------------------------------------------------- +# Config Presets +# --------------------------------------------------------------------------- + + +class ConfigPreset(str, Enum): + """Pre-configured settings for common use cases.""" + + DEFAULT = "default" + LOCAL_SMALL_MODEL = "local_small" # Optimized for 4B-8B local models + CLOUD_HIGH_QUALITY = "cloud_high" # Optimized for GPT-4/Claude + FAST_ITERATION = "fast" # Minimal retries for rapid development + PRODUCTION = "production" # Conservative settings for reliability + + +def get_config_preset(preset: ConfigPreset | str) -> PlannerExecutorConfig: + """ + Get a pre-configured PlannerExecutorConfig for common use cases. + + Args: + preset: Either a ConfigPreset enum value or string name + + Returns: + PlannerExecutorConfig with preset values + + Example: + >>> config = get_config_preset(ConfigPreset.LOCAL_SMALL_MODEL) + >>> agent = create_planner_executor_agent( + ... planner_model="qwen3:8b", + ... executor_model="qwen3:4b", + ... config=config, + ... 
) + """ + if isinstance(preset, str): + preset = ConfigPreset(preset) + + if preset == ConfigPreset.LOCAL_SMALL_MODEL: + # Optimized for local 4B-8B models (Ollama) + # - Tighter token limits work better with small models + # - More lenient timeouts for slower local inference + # - Verbose mode helpful for debugging local model behavior + return PlannerExecutorConfig( + planner_max_tokens=1024, + executor_max_tokens=64, + retry=RetryConfig( + verify_timeout_s=15.0, + verify_max_attempts=6, + ), + verbose=True, + ) + + elif preset == ConfigPreset.CLOUD_HIGH_QUALITY: + # Optimized for high-capability cloud models (GPT-4, Claude) + # - Higher token limits for more detailed plans + # - Faster timeouts (cloud inference is quick) + # - Verbose off for cleaner output + return PlannerExecutorConfig( + planner_max_tokens=2048, + executor_max_tokens=128, + retry=RetryConfig( + verify_timeout_s=10.0, + verify_max_attempts=4, + ), + verbose=False, + ) + + elif preset == ConfigPreset.FAST_ITERATION: + # For rapid development and testing + # - Minimal retries to fail fast + # - Verbose for debugging + return PlannerExecutorConfig( + planner_max_tokens=1024, + executor_max_tokens=64, + retry=RetryConfig( + verify_timeout_s=5.0, + verify_max_attempts=2, + ), + verbose=True, + ) + + elif preset == ConfigPreset.PRODUCTION: + # Conservative settings for production reliability + # - More retries for robustness + # - Longer timeouts for edge cases + # - No verbose output + return PlannerExecutorConfig( + planner_max_tokens=2048, + executor_max_tokens=128, + retry=RetryConfig( + verify_timeout_s=20.0, + verify_max_attempts=8, + ), + verbose=False, + ) + + # Default + return PlannerExecutorConfig() + + +# --------------------------------------------------------------------------- +# Provider Detection and Creation +# --------------------------------------------------------------------------- + + +def _detect_provider(model: str) -> str: + """ + Auto-detect provider from model name. + + Args: + model: Model name/identifier + + Returns: + Provider name: "openai", "anthropic", or "ollama" + """ + model_lower = model.lower() + + # OpenAI models + if model_lower.startswith(("gpt-", "o1-", "o3-", "o4-")): + return "openai" + + # Anthropic models + if model_lower.startswith("claude-"): + return "anthropic" + + # Common Ollama model patterns + if any( + model_lower.startswith(p) + for p in ("qwen", "llama", "phi", "mistral", "gemma", "deepseek", "codellama") + ): + return "ollama" + + # Ollama models typically have "model:tag" format + if ":" in model: + return "ollama" + + # Default to ollama for unknown models (assume local) + return "ollama" + + +def _create_provider( + model: str, + provider: str, + ollama_base_url: str, + openai_api_key: str | None, + anthropic_api_key: str | None, +) -> LLMProvider: + """ + Create provider instance based on provider name. 
+ + Args: + model: Model name + provider: Provider name ("auto", "ollama", "openai", "anthropic") + ollama_base_url: Ollama server URL + openai_api_key: OpenAI API key (can be None if using env var) + anthropic_api_key: Anthropic API key (can be None if using env var) + + Returns: + LLMProvider instance + """ + if provider == "auto": + provider = _detect_provider(model) + + if provider == "ollama": + return OllamaProvider(model=model, base_url=ollama_base_url) + + elif provider == "openai": + return OpenAIProvider(model=model, api_key=openai_api_key) + + elif provider == "anthropic": + return AnthropicProvider(model=model, api_key=anthropic_api_key) + + else: + raise ValueError( + f"Unknown provider: {provider}. " + f"Supported: 'auto', 'ollama', 'openai', 'anthropic'" + ) + + +def _create_auto_tracer( + planner_model: str, + executor_model: str, + run_id: str | None = None, +) -> Tracer: + """ + Create tracer based on environment configuration. + + If PREDICATE_API_KEY env var is set, creates a cloud tracer. + Otherwise, creates a local JsonlTraceSink tracer. + + Args: + planner_model: Planner model name (for metadata) + executor_model: Executor model name (for metadata) + run_id: Optional run ID (generates UUID if not provided) + + Returns: + Configured Tracer instance + """ + api_key = os.environ.get("PREDICATE_API_KEY") + + if run_id is None: + run_id = f"run-{uuid.uuid4().hex[:8]}" + + if api_key: + # Cloud tracing (auto-detected from env var) + return create_tracer( + api_key=api_key, + run_id=run_id, + llm_model=f"{planner_model}/{executor_model}", + agent_type="planner-executor", + ) + else: + # Local file tracing + trace_dir = Path("./traces") + trace_dir.mkdir(exist_ok=True) + trace_file = trace_dir / f"{run_id}.jsonl" + sink = JsonlTraceSink(str(trace_file)) + return Tracer(run_id=run_id, sink=sink) + + +# --------------------------------------------------------------------------- +# Factory Function +# --------------------------------------------------------------------------- + + +def create_planner_executor_agent( + *, + planner_model: str, + executor_model: str, + planner_provider: str = "auto", + executor_provider: str = "auto", + ollama_base_url: str = "http://localhost:11434", + openai_api_key: str | None = None, + anthropic_api_key: str | None = None, + tracer: Tracer | Literal["auto"] | None = "auto", + config: PlannerExecutorConfig | None = None, + intent_heuristics: IntentHeuristics | None = None, + context_formatter: Callable[[Snapshot, str], str] | None = None, + run_id: str | None = None, +) -> PlannerExecutorAgent: + """ + Create a PlannerExecutorAgent with sensible defaults and auto-detection. + + This factory function reduces boilerplate by: + - Auto-detecting provider from model name (gpt-* -> OpenAI, claude-* -> Anthropic, etc.) 
+ - Auto-creating tracer (cloud if PREDICATE_API_KEY set, else local JSONL) + - Providing sensible default configuration + + Args: + planner_model: Model name for planning (e.g., "gpt-4o", "qwen3:8b") + executor_model: Model name for execution (e.g., "gpt-4o-mini", "qwen3:4b") + planner_provider: Provider for planner ("auto", "ollama", "openai", "anthropic") + executor_provider: Provider for executor ("auto", "ollama", "openai", "anthropic") + ollama_base_url: Ollama server URL (default: http://localhost:11434) + openai_api_key: OpenAI API key (defaults to OPENAI_API_KEY env var) + anthropic_api_key: Anthropic API key (defaults to ANTHROPIC_API_KEY env var) + tracer: Tracer instance, "auto" to auto-create, or None to disable + config: PlannerExecutorConfig (uses default if not provided) + intent_heuristics: Optional domain-specific heuristics for element selection + context_formatter: Optional custom context formatter for snapshots + run_id: Optional run ID for tracing (generates UUID if not provided) + + Returns: + Configured PlannerExecutorAgent instance + + Example (minimal - local Ollama): + >>> agent = create_planner_executor_agent( + ... planner_model="qwen3:8b", + ... executor_model="qwen3:4b", + ... ) + + Example (cloud OpenAI): + >>> agent = create_planner_executor_agent( + ... planner_model="gpt-4o", + ... executor_model="gpt-4o-mini", + ... openai_api_key="sk-...", # or set OPENAI_API_KEY env var + ... ) + + Example (mixed - cloud planner, local executor): + >>> agent = create_planner_executor_agent( + ... planner_model="gpt-4o", + ... planner_provider="openai", + ... executor_model="qwen3:4b", + ... executor_provider="ollama", + ... openai_api_key="sk-...", + ... ) + + Example (with config preset): + >>> from predicate.agents import get_config_preset, ConfigPreset + >>> agent = create_planner_executor_agent( + ... planner_model="qwen3:8b", + ... executor_model="qwen3:4b", + ... config=get_config_preset(ConfigPreset.LOCAL_SMALL_MODEL), + ... ) + """ + # Create providers + planner = _create_provider( + model=planner_model, + provider=planner_provider, + ollama_base_url=ollama_base_url, + openai_api_key=openai_api_key, + anthropic_api_key=anthropic_api_key, + ) + + executor = _create_provider( + model=executor_model, + provider=executor_provider, + ollama_base_url=ollama_base_url, + openai_api_key=openai_api_key, + anthropic_api_key=anthropic_api_key, + ) + + # Create tracer + tracer_instance: Tracer | None = None + if tracer == "auto": + tracer_instance = _create_auto_tracer( + planner_model=planner_model, + executor_model=executor_model, + run_id=run_id, + ) + elif isinstance(tracer, Tracer): + tracer_instance = tracer + # else: tracer is None, leave tracer_instance as None + + # Use default config if not provided + if config is None: + config = PlannerExecutorConfig() + + return PlannerExecutorAgent( + planner=planner, + executor=executor, + config=config, + tracer=tracer_instance, + intent_heuristics=intent_heuristics, + context_formatter=context_formatter, + ) diff --git a/predicate/agents/planner_executor_agent.py b/predicate/agents/planner_executor_agent.py index d420c75..52f14fe 100644 --- a/predicate/agents/planner_executor_agent.py +++ b/predicate/agents/planner_executor_agent.py @@ -802,7 +802,19 @@ def detect_snapshot_failure(snap: Snapshot) -> tuple[bool, str | None]: Returns: (should_use_vision, reason) + + Note: If we have sufficient elements (10+), we should NOT trigger vision + fallback even if diagnostics suggest it. 
This handles cases where the + API incorrectly flags normal HTML pages as requiring vision. """ + elements = getattr(snap, "elements", []) or [] + element_count = len(elements) + + # If we have sufficient elements, the snapshot is usable + # regardless of what diagnostics say + if element_count >= 10: + return False, None + # Check explicit status field (tri-state: success, error, require_vision) status = getattr(snap, "status", "success") if status == "require_vision": @@ -820,13 +832,16 @@ def detect_snapshot_failure(snap: Snapshot) -> tuple[bool, str | None]: return True, "low_confidence" has_canvas = getattr(diag, "has_canvas", False) - elements = getattr(snap, "elements", []) or [] - if has_canvas and len(elements) < 5: + if has_canvas and element_count < 5: return True, "canvas_page" + # Check diagnostics.requires_vision only if few elements + requires_vision = getattr(diag, "requires_vision", False) + if requires_vision and element_count < 5: + return True, "diagnostics_requires_vision" + # Very few elements usually indicates a problem - elements = getattr(snap, "elements", []) or [] - if len(elements) < 3: + if element_count < 3: return True, "too_few_elements" return False, None @@ -1335,34 +1350,25 @@ def build_stepwise_planner_prompt( history_text += "\n" history_text += "\n" - system = """You are a web automation agent using ReAct-style planning. -Given the goal, current page state, and action history, decide the NEXT SINGLE ACTION. - -Available actions: -- CLICK: Click an element. Provide "intent" to describe which element. -- TYPE_AND_SUBMIT: Type text into an input and submit. Provide "input" (the text to type) and "intent" (which input field). -- SCROLL: Scroll the page. Provide "direction" (up or down). -- DONE: Task is complete. Use when the goal has been achieved. -- STUCK: Cannot proceed (e.g., login required without credentials, CAPTCHA, unexpected error). - -Output ONLY a JSON object: -{ - "action": "CLICK" | "TYPE_AND_SUBMIT" | "SCROLL" | "DONE" | "STUCK", - "intent": "description of target element (required for CLICK, TYPE_AND_SUBMIT)", - "input": "text to type (required for TYPE_AND_SUBMIT)", - "direction": "up" | "down" (required for SCROLL)", - "reasoning": "brief explanation of why this action" -} - -IMPORTANT RULES: -1. Look at the ACTUAL page elements provided - don't assume what should be there -2. If you don't see the expected element, try SCROLL or report STUCK -3. For TYPE_AND_SUBMIT, find a text input (textbox, searchbox, combobox) and specify what text to type -4. Use action history to avoid repeating failed actions -5. When the goal is achieved (e.g., item in cart, on checkout page), return DONE -6. If you need to login but have no credentials, return STUCK with explanation - -Return ONLY valid JSON. No prose, no code fences.""" + # Tight prompt optimized for small local models (7B) + system = """You are a browser automation planner. Decide the NEXT action. + +Actions: +- CLICK: Click an element. Set "intent" to element text/description. +- TYPE_AND_SUBMIT: Type and submit. Set "intent" and "input". +- SCROLL: Scroll page. Set "direction" to "up" or "down". +- DONE: Goal achieved. + +Output ONLY JSON: +{"action":"CLICK","intent":"button text","reasoning":"why"} +{"action":"TYPE_AND_SUBMIT","intent":"search box","input":"query","reasoning":"why"} +{"action":"DONE","intent":"completed","reasoning":"why"} + +RULES: +1. Look at ACTUAL elements shown - pick one that matches your intent +2. Do NOT repeat actions that already succeeded +3. 
If goal is done, return DONE immediately
+4. No <think> tags, no markdown, no prose - ONLY JSON"""
 
     user = f"""Goal: {goal}
 
@@ -1397,24 +1403,36 @@ def build_executor_prompt(
     intent_line = f"Intent: {intent}\n" if intent else ""
     input_line = f"Text to type: \"{input_text}\"\n" if input_text else ""
 
-    system = """You are a careful web automation executor.
-You must respond with exactly ONE action in this format:
-- CLICK(<id>)
-- TYPE(<id>, "text")
-- PRESS('key')
-- SCROLL(direction)
-- FINISH()
-
-Output only the action. No explanations."""
-
-    user = f"""You are controlling a browser via element IDs.
+    # Tight prompt optimized for small local models (4B-7B)
+    # Key: explicit format, no reasoning, clear failure consequence
+    if input_text:
+        # TYPE action needed
+        system = (
+            "You are an executor for browser automation.\n"
+            "Return ONLY: TYPE(<id>, \"text\") or CLICK(<id>)\n"
+            "If you output anything else, the action fails.\n"
+            "Do NOT output <think> or any reasoning.\n"
+            "No prose, no markdown, no extra whitespace.\n"
+            "Example: TYPE(42, \"hello world\")"
+        )
+    else:
+        # CLICK action (most common)
+        system = (
+            "You are an executor for browser automation.\n"
+            "Return ONLY a single-line CLICK(id) action.\n"
+            "If you output anything else, the action fails.\n"
+            "Do NOT output <think> or any reasoning.\n"
+            "No prose, no markdown, no extra whitespace.\n"
+            "Output MUST match exactly: CLICK(<id>) with no spaces.\n"
+            "Example: CLICK(12)"
+        )
 
-Goal: {goal}
+    user = f"""Goal: {goal}
 {intent_line}{input_line}
-Elements (ID|role|text|importance|clickable|...):
+Elements:
 {compact_context}
 
-Return ONLY the action to take."""
+Return CLICK(id):"""
 
     return system, user
@@ -2116,6 +2134,11 @@ async def _snapshot_with_escalation(
             if raw_screenshot:
                 screenshot_b64 = raw_screenshot
 
+            # Format context FIRST - we always want the compact representation
+            # even if vision fallback is required, so the planner can see available elements
+            compact = self._format_context(snap, goal)
+            last_compact = compact
+
             # Check for vision fallback
             needs_vision, reason = detect_snapshot_failure(snap)
             if needs_vision:
@@ -2123,10 +2146,6 @@
                 vision_reason = reason
                 break
 
-            # Format context
-            compact = self._format_context(snap, goal)
-            last_compact = compact
-
             # If escalation disabled, we're done after first successful snapshot
             if not cfg.enabled:
                 break
@@ -3952,11 +3971,24 @@ async def run_stepwise(
             if self.config.verbose:
                 print(f"\n[STEP {step_num}] Planning next action...", flush=True)
 
+            # 0. Stabilize before taking snapshot (wait for DOM to settle)
+            # This is critical for pages with delayed hydration/rendering
+            if self.config.stabilize_enabled:
+                await runtime.stabilize()
+
             # 1. Take snapshot with escalation
             self._snapshot_context = await self._snapshot_with_escalation(
                 runtime,
                 goal=task_description,
             )
+
+            # Debug: log snapshot context details
+            if self.config.verbose:
+                snap = self._snapshot_context.snapshot
+                elem_count = len(snap.elements) if snap and snap.elements else 0
+                compact_len = len(self._snapshot_context.compact_representation) if self._snapshot_context.compact_representation else 0
+                requires_vision = self._snapshot_context.requires_vision
+                print(f"  [STEPWISE-SNAPSHOT] Elements: {elem_count}, Compact len: {compact_len}, Requires vision: {requires_vision}", flush=True)
 
             current_url = await runtime.get_url() if hasattr(runtime, "get_url") else ""
 
             # 2. 
Build page context @@ -3965,6 +3997,18 @@ async def run_stepwise( else: page_context = "(page context disabled)" + # Debug: print page context for stepwise planning + if self.config.verbose and stepwise_cfg.include_page_context: + print("\n--- Stepwise Page Context ---", flush=True) + # Truncate to first 20 lines for readability + context_lines = page_context.split("\n") + if len(context_lines) > 20: + print("\n".join(context_lines[:20]), flush=True) + print(f"... ({len(context_lines) - 20} more lines)", flush=True) + else: + print(page_context, flush=True) + print("--- End Page Context ---\n", flush=True) + # 3. Get recent action history recent_history = action_history[-stepwise_cfg.action_history_limit:] diff --git a/predicate/llm_provider.py b/predicate/llm_provider.py index cfe02f6..d027268 100644 --- a/predicate/llm_provider.py +++ b/predicate/llm_provider.py @@ -376,6 +376,76 @@ def supports_vision(self) -> bool: return super().supports_vision() +class OllamaProvider(OpenAIProvider): + """ + Ollama local LLM provider via OpenAI-compatible API. + + Ollama serves models locally and provides an OpenAI-compatible endpoint at /v1. + This provider wraps OpenAIProvider with sensible defaults for local inference. + + Example: + >>> from predicate.llm_provider import OllamaProvider + >>> llm = OllamaProvider(model="qwen3:8b") + >>> response = llm.generate("You are helpful", "Hello!") + >>> print(response.content) + + Example with custom base URL: + >>> llm = OllamaProvider(model="llama3:8b", base_url="http://192.168.1.100:11434") + """ + + def __init__( + self, + model: str, + base_url: str = "http://localhost:11434", + **kwargs, + ): + """ + Initialize Ollama provider. + + Args: + model: Ollama model name (e.g., "qwen3:8b", "llama3:8b", "mistral:7b") + base_url: Ollama server URL (default: http://localhost:11434) + **kwargs: Additional parameters passed to OpenAIProvider + """ + # Ollama serves OpenAI-compatible API at /v1 + super().__init__( + model=model, + base_url=f"{base_url.rstrip('/')}/v1", + api_key="ollama", # Ollama doesn't require a real API key + **kwargs, + ) + self._ollama_base_url = base_url + + @property + def is_local(self) -> bool: + """Ollama runs locally.""" + return True + + @property + def provider_name(self) -> str: + """Provider identifier.""" + return "ollama" + + def supports_json_mode(self) -> bool: + """ + JSON mode support varies by Ollama model. + + Most instruction-tuned models (qwen, llama, mistral) can output JSON + with proper prompting, but native JSON mode is model-dependent. + """ + # Conservative default: rely on prompt engineering for JSON + return False + + def supports_vision(self) -> bool: + """ + Vision support varies by Ollama model. + + Models like llava, bakllava support vision. Check model capabilities. + """ + model_lower = self._model_name.lower() + return any(x in model_lower for x in ["llava", "bakllava", "moondream"]) + + class AnthropicProvider(LLMProvider): """ Anthropic provider implementation (Claude 3 Opus, Sonnet, Haiku, etc.) 
diff --git a/tests/test_agent_factory.py b/tests/test_agent_factory.py new file mode 100644 index 0000000..9fbe6bd --- /dev/null +++ b/tests/test_agent_factory.py @@ -0,0 +1,306 @@ +"""Tests for agent_factory module.""" + +import os +from unittest.mock import MagicMock, patch + +import pytest + +from predicate.agents.agent_factory import ( + ConfigPreset, + _create_auto_tracer, + _create_provider, + _detect_provider, + create_planner_executor_agent, + get_config_preset, +) +from predicate.agents.planner_executor_agent import PlannerExecutorAgent, PlannerExecutorConfig +from predicate.llm_provider import AnthropicProvider, OllamaProvider, OpenAIProvider +from predicate.tracing import Tracer + + +class TestDetectProvider: + """Test provider auto-detection from model names.""" + + def test_detect_openai_gpt4(self): + """Should detect OpenAI for GPT-4 models.""" + assert _detect_provider("gpt-4o") == "openai" + assert _detect_provider("gpt-4-turbo") == "openai" + assert _detect_provider("gpt-4o-mini") == "openai" + assert _detect_provider("GPT-4o") == "openai" # Case insensitive + + def test_detect_openai_o1(self): + """Should detect OpenAI for o1 reasoning models.""" + assert _detect_provider("o1-preview") == "openai" + assert _detect_provider("o1-mini") == "openai" + + def test_detect_openai_o3(self): + """Should detect OpenAI for o3 models.""" + assert _detect_provider("o3-mini") == "openai" + + def test_detect_anthropic_claude(self): + """Should detect Anthropic for Claude models.""" + assert _detect_provider("claude-3-opus-20240229") == "anthropic" + assert _detect_provider("claude-3-5-sonnet-20241022") == "anthropic" + assert _detect_provider("claude-3-haiku-20240307") == "anthropic" + assert _detect_provider("Claude-3-Opus") == "anthropic" # Case insensitive + + def test_detect_ollama_qwen(self): + """Should detect Ollama for Qwen models.""" + assert _detect_provider("qwen3:8b") == "ollama" + assert _detect_provider("qwen2.5:7b-instruct") == "ollama" + assert _detect_provider("Qwen3:4b") == "ollama" + + def test_detect_ollama_llama(self): + """Should detect Ollama for Llama models.""" + assert _detect_provider("llama3:8b") == "ollama" + assert _detect_provider("llama3.2:3b") == "ollama" + assert _detect_provider("codellama:7b") == "ollama" + + def test_detect_ollama_other_local(self): + """Should detect Ollama for other common local models.""" + assert _detect_provider("phi3:mini") == "ollama" + assert _detect_provider("mistral:7b") == "ollama" + assert _detect_provider("gemma:2b") == "ollama" + assert _detect_provider("deepseek:6.7b") == "ollama" + + def test_detect_ollama_by_tag_format(self): + """Should detect Ollama for model:tag format.""" + assert _detect_provider("custom-model:latest") == "ollama" + assert _detect_provider("my-finetuned:v2") == "ollama" + + def test_detect_unknown_defaults_ollama(self): + """Unknown models should default to Ollama.""" + assert _detect_provider("some-unknown-model") == "ollama" + + +class TestCreateProvider: + """Test provider creation.""" + + def test_create_ollama_provider(self): + """Should create OllamaProvider for ollama.""" + provider = _create_provider( + model="qwen3:8b", + provider="ollama", + ollama_base_url="http://localhost:11434", + openai_api_key=None, + anthropic_api_key=None, + ) + assert isinstance(provider, OllamaProvider) + assert provider.model_name == "qwen3:8b" + + def test_create_openai_provider(self): + """Should create OpenAIProvider for openai.""" + provider = _create_provider( + model="gpt-4o", + provider="openai", + 
ollama_base_url="http://localhost:11434", + openai_api_key="test-key", + anthropic_api_key=None, + ) + assert isinstance(provider, OpenAIProvider) + assert provider.model_name == "gpt-4o" + + def test_create_anthropic_provider(self): + """Should create AnthropicProvider for anthropic.""" + provider = _create_provider( + model="claude-3-opus-20240229", + provider="anthropic", + ollama_base_url="http://localhost:11434", + openai_api_key=None, + anthropic_api_key="test-key", + ) + assert isinstance(provider, AnthropicProvider) + assert provider.model_name == "claude-3-opus-20240229" + + def test_create_provider_auto_detection(self): + """Should auto-detect provider when 'auto' specified.""" + provider = _create_provider( + model="qwen3:8b", + provider="auto", + ollama_base_url="http://localhost:11434", + openai_api_key=None, + anthropic_api_key=None, + ) + assert isinstance(provider, OllamaProvider) + + def test_create_provider_invalid_raises(self): + """Should raise ValueError for unknown provider.""" + with pytest.raises(ValueError, match="Unknown provider"): + _create_provider( + model="test", + provider="invalid-provider", + ollama_base_url="http://localhost:11434", + openai_api_key=None, + anthropic_api_key=None, + ) + + +class TestConfigPresets: + """Test configuration presets.""" + + def test_get_config_preset_default(self): + """Should return default config for DEFAULT preset.""" + config = get_config_preset(ConfigPreset.DEFAULT) + assert isinstance(config, PlannerExecutorConfig) + + def test_get_config_preset_local_small_model(self): + """Should return optimized config for local small models.""" + config = get_config_preset(ConfigPreset.LOCAL_SMALL_MODEL) + assert isinstance(config, PlannerExecutorConfig) + # Check optimized settings + assert config.planner_max_tokens == 1024 + assert config.executor_max_tokens == 64 + assert config.retry.verify_timeout_s == 15.0 + assert config.retry.verify_max_attempts == 6 + assert config.verbose is True + + def test_get_config_preset_cloud_high_quality(self): + """Should return optimized config for cloud models.""" + config = get_config_preset(ConfigPreset.CLOUD_HIGH_QUALITY) + assert isinstance(config, PlannerExecutorConfig) + assert config.planner_max_tokens == 2048 + assert config.executor_max_tokens == 128 + assert config.retry.verify_timeout_s == 10.0 + assert config.verbose is False + + def test_get_config_preset_fast_iteration(self): + """Should return fast iteration config.""" + config = get_config_preset(ConfigPreset.FAST_ITERATION) + assert isinstance(config, PlannerExecutorConfig) + assert config.retry.verify_max_attempts == 2 + assert config.verbose is True + + def test_get_config_preset_production(self): + """Should return production config with conservative settings.""" + config = get_config_preset(ConfigPreset.PRODUCTION) + assert isinstance(config, PlannerExecutorConfig) + assert config.retry.verify_max_attempts == 8 + assert config.retry.verify_timeout_s == 20.0 + assert config.verbose is False + + def test_get_config_preset_by_string(self): + """Should accept string preset names.""" + config = get_config_preset("local_small") + assert isinstance(config, PlannerExecutorConfig) + assert config.planner_max_tokens == 1024 + + +class TestCreateAutoTracer: + """Test automatic tracer creation.""" + + def test_create_local_tracer_no_api_key(self): + """Should create local tracer when no API key set.""" + with patch.dict(os.environ, {}, clear=True): + # Remove PREDICATE_API_KEY if it exists + os.environ.pop("PREDICATE_API_KEY", None) + 
tracer = _create_auto_tracer( + planner_model="qwen3:8b", + executor_model="qwen3:4b", + run_id="test-run", + ) + assert isinstance(tracer, Tracer) + assert tracer.run_id == "test-run" + + def test_create_tracer_generates_run_id(self): + """Should generate run_id if not provided.""" + with patch.dict(os.environ, {}, clear=True): + os.environ.pop("PREDICATE_API_KEY", None) + tracer = _create_auto_tracer( + planner_model="qwen3:8b", + executor_model="qwen3:4b", + ) + assert isinstance(tracer, Tracer) + assert tracer.run_id.startswith("run-") + + +class TestCreatePlannerExecutorAgent: + """Test the main factory function.""" + + def test_create_agent_minimal_local(self): + """Should create agent with minimal local config.""" + agent = create_planner_executor_agent( + planner_model="qwen3:8b", + executor_model="qwen3:4b", + tracer=None, # Disable tracer for test + ) + assert isinstance(agent, PlannerExecutorAgent) + + def test_create_agent_with_explicit_providers(self): + """Should respect explicit provider settings.""" + agent = create_planner_executor_agent( + planner_model="qwen3:8b", + executor_model="qwen3:4b", + planner_provider="ollama", + executor_provider="ollama", + tracer=None, + ) + assert isinstance(agent, PlannerExecutorAgent) + + def test_create_agent_with_custom_config(self): + """Should use provided config.""" + custom_config = PlannerExecutorConfig(verbose=True, planner_max_tokens=512) + agent = create_planner_executor_agent( + planner_model="qwen3:8b", + executor_model="qwen3:4b", + config=custom_config, + tracer=None, + ) + assert isinstance(agent, PlannerExecutorAgent) + + def test_create_agent_with_preset(self): + """Should work with config presets.""" + agent = create_planner_executor_agent( + planner_model="qwen3:8b", + executor_model="qwen3:4b", + config=get_config_preset(ConfigPreset.LOCAL_SMALL_MODEL), + tracer=None, + ) + assert isinstance(agent, PlannerExecutorAgent) + + def test_create_agent_with_custom_tracer(self): + """Should use provided tracer.""" + mock_tracer = MagicMock(spec=Tracer) + agent = create_planner_executor_agent( + planner_model="qwen3:8b", + executor_model="qwen3:4b", + tracer=mock_tracer, + ) + assert isinstance(agent, PlannerExecutorAgent) + + def test_create_agent_mixed_providers(self): + """Should support mixed cloud/local configuration.""" + agent = create_planner_executor_agent( + planner_model="gpt-4o", + planner_provider="openai", + executor_model="qwen3:4b", + executor_provider="ollama", + openai_api_key="test-key", + tracer=None, + ) + assert isinstance(agent, PlannerExecutorAgent) + + def test_create_agent_custom_ollama_base_url(self): + """Should respect custom Ollama base URL.""" + agent = create_planner_executor_agent( + planner_model="qwen3:8b", + executor_model="qwen3:4b", + ollama_base_url="http://192.168.1.100:11434", + tracer=None, + ) + assert isinstance(agent, PlannerExecutorAgent) + + +class TestAgentFactoryImports: + """Test that factory is properly exported.""" + + def test_import_from_agents_module(self): + """Factory should be importable from predicate.agents.""" + from predicate.agents import ( + ConfigPreset, + create_planner_executor_agent, + get_config_preset, + ) + + assert create_planner_executor_agent is not None + assert ConfigPreset is not None + assert get_config_preset is not None diff --git a/tests/test_ollama_provider.py b/tests/test_ollama_provider.py new file mode 100644 index 0000000..2a140fe --- /dev/null +++ b/tests/test_ollama_provider.py @@ -0,0 +1,94 @@ +"""Tests for OllamaProvider.""" + +import 
pytest + +from predicate.llm_provider import OllamaProvider, OpenAIProvider + + +class TestOllamaProvider: + """Test suite for OllamaProvider.""" + + def test_ollama_provider_is_subclass_of_openai(self): + """OllamaProvider should inherit from OpenAIProvider.""" + assert issubclass(OllamaProvider, OpenAIProvider) + + def test_ollama_provider_default_base_url(self): + """OllamaProvider should use default localhost:11434 base URL.""" + provider = OllamaProvider(model="qwen3:8b") + # The internal client should have base_url set to /v1 endpoint + assert provider._ollama_base_url == "http://localhost:11434" + + def test_ollama_provider_custom_base_url(self): + """OllamaProvider should accept custom base URL.""" + provider = OllamaProvider(model="llama3:8b", base_url="http://192.168.1.100:11434") + assert provider._ollama_base_url == "http://192.168.1.100:11434" + + def test_ollama_provider_strips_trailing_slash(self): + """OllamaProvider should strip trailing slash from base URL.""" + provider = OllamaProvider(model="mistral:7b", base_url="http://localhost:11434/") + # The /v1 should be appended correctly without double slash + assert provider._ollama_base_url == "http://localhost:11434/" + # The actual OpenAI client base_url should be properly formed + # (trailing slash stripped before /v1 is appended) + + def test_ollama_provider_is_local_property(self): + """OllamaProvider.is_local should return True.""" + provider = OllamaProvider(model="qwen3:4b") + assert provider.is_local is True + + def test_ollama_provider_name_property(self): + """OllamaProvider.provider_name should return 'ollama'.""" + provider = OllamaProvider(model="phi3:mini") + assert provider.provider_name == "ollama" + + def test_ollama_provider_model_name(self): + """OllamaProvider should correctly report model name.""" + provider = OllamaProvider(model="qwen3:8b") + assert provider.model_name == "qwen3:8b" + + def test_ollama_provider_supports_json_mode_false(self): + """OllamaProvider should return False for supports_json_mode (conservative default).""" + provider = OllamaProvider(model="qwen3:8b") + assert provider.supports_json_mode() is False + + def test_ollama_provider_supports_vision_llava(self): + """OllamaProvider should detect vision support for llava models.""" + provider = OllamaProvider(model="llava:7b") + assert provider.supports_vision() is True + + def test_ollama_provider_supports_vision_bakllava(self): + """OllamaProvider should detect vision support for bakllava models.""" + provider = OllamaProvider(model="bakllava:latest") + assert provider.supports_vision() is True + + def test_ollama_provider_supports_vision_moondream(self): + """OllamaProvider should detect vision support for moondream models.""" + provider = OllamaProvider(model="moondream:1.8b") + assert provider.supports_vision() is True + + def test_ollama_provider_no_vision_for_text_models(self): + """OllamaProvider should return False for non-vision models.""" + provider = OllamaProvider(model="qwen3:8b") + assert provider.supports_vision() is False + + provider = OllamaProvider(model="llama3:8b") + assert provider.supports_vision() is False + + provider = OllamaProvider(model="mistral:7b") + assert provider.supports_vision() is False + + +class TestOllamaProviderImport: + """Test that OllamaProvider is properly exported.""" + + def test_import_from_llm_provider(self): + """OllamaProvider should be importable from predicate.llm_provider.""" + from predicate.llm_provider import OllamaProvider + + assert OllamaProvider is not None + + def 
test_import_from_predicate(self): + """OllamaProvider should be importable from predicate package root.""" + from predicate import OllamaProvider + + assert OllamaProvider is not None diff --git a/traces/test-run.jsonl b/traces/test-run.jsonl index 85084c5..6ae86f1 100644 --- a/traces/test-run.jsonl +++ b/traces/test-run.jsonl @@ -8,3 +8,8 @@ {"v": 1, "type": "run_start", "ts": "2026-02-11T03:10:06.000Z", "run_id": "test-run", "seq": 1, "data": {"agent": "SentienceAgent"}, "ts_ms": 1770779406555} {"v": 1, "type": "run_start", "ts": "2026-02-11T03:10:06.000Z", "run_id": "test-run", "seq": 1, "data": {"agent": "SentienceAgent"}, "ts_ms": 1770779406557} {"v": 1, "type": "run_start", "ts": "2026-02-11T03:10:06.000Z", "run_id": "test-run", "seq": 1, "data": {"agent": "SentienceAgent"}, "ts_ms": 1770779406568} +{"v": 1, "type": "run_start", "ts": "2026-03-29T01:08:28.000Z", "run_id": "test-run", "seq": 1, "data": {"agent": "SentienceAgent"}, "ts_ms": 1774746508041} +{"v": 1, "type": "run_start", "ts": "2026-03-29T01:08:28.000Z", "run_id": "test-run", "seq": 1, "data": {"agent": "SentienceAgent"}, "ts_ms": 1774746508043} +{"v": 1, "type": "run_start", "ts": "2026-03-29T01:08:28.000Z", "run_id": "test-run", "seq": 1, "data": {"agent": "SentienceAgent"}, "ts_ms": 1774746508045} +{"v": 1, "type": "run_start", "ts": "2026-03-29T01:08:28.000Z", "run_id": "test-run", "seq": 1, "data": {"agent": "SentienceAgent"}, "ts_ms": 1774746508046} +{"v": 1, "type": "run_start", "ts": "2026-03-29T01:08:28.000Z", "run_id": "test-run", "seq": 1, "data": {"agent": "SentienceAgent"}, "ts_ms": 1774746508134} From 5abc293386b40a74c7e4d4052418e523372a97a4 Mon Sep 17 00:00:00 2001 From: SentienceDEV Date: Sat, 28 Mar 2026 18:53:56 -0700 Subject: [PATCH 2/4] fix tests --- predicate/llm_provider.py | 225 ++++++++++++++++++++++++++++++++-- tests/test_ollama_provider.py | 20 +-- traces/test-run.jsonl | 5 + 3 files changed, 231 insertions(+), 19 deletions(-) diff --git a/predicate/llm_provider.py b/predicate/llm_provider.py index d027268..9aa53f2 100644 --- a/predicate/llm_provider.py +++ b/predicate/llm_provider.py @@ -376,12 +376,12 @@ def supports_vision(self) -> bool: return super().supports_vision() -class OllamaProvider(OpenAIProvider): +class OllamaProvider(LLMProvider): """ Ollama local LLM provider via OpenAI-compatible API. Ollama serves models locally and provides an OpenAI-compatible endpoint at /v1. - This provider wraps OpenAIProvider with sensible defaults for local inference. + This provider uses HTTP requests directly without requiring the openai package. 
 
     Example:
         >>> from predicate.llm_provider import OllamaProvider
@@ -397,6 +397,7 @@ def __init__(
         self,
         model: str,
         base_url: str = "http://localhost:11434",
+        timeout_seconds: float = 120.0,
         **kwargs,
     ):
         """
@@ -405,16 +406,117 @@
         Args:
             model: Ollama model name (e.g., "qwen3:8b", "llama3:8b", "mistral:7b")
             base_url: Ollama server URL (default: http://localhost:11434)
-            **kwargs: Additional parameters passed to OpenAIProvider
+            timeout_seconds: Request timeout in seconds (default: 120)
+            **kwargs: Additional parameters (reserved for future use)
         """
-        # Ollama serves OpenAI-compatible API at /v1
-        super().__init__(
-            model=model,
-            base_url=f"{base_url.rstrip('/')}/v1",
-            api_key="ollama",  # Ollama doesn't require a real API key
-            **kwargs,
+        super().__init__(model)
+        self._ollama_base_url = base_url.rstrip("/")
+        self._api_base_url = f"{self._ollama_base_url}/v1"
+        self._timeout_seconds = timeout_seconds
+
+    def generate(
+        self,
+        system_prompt: str,
+        user_prompt: str,
+        temperature: float = 0.0,
+        max_tokens: int | None = None,
+        json_mode: bool = False,
+        **kwargs,
+    ) -> LLMResponse:
+        """
+        Generate response using Ollama's OpenAI-compatible API.
+
+        Args:
+            system_prompt: System instruction
+            user_prompt: User query
+            temperature: Sampling temperature (0.0 = deterministic, 1.0 = creative)
+            max_tokens: Maximum tokens to generate
+            json_mode: Enable JSON response format (model-dependent support)
+            **kwargs: Additional API parameters (max_new_tokens is mapped to max_tokens)
+
+        Returns:
+            LLMResponse object
+        """
+        import json
+        import urllib.request
+        import urllib.error
+
+        # Handle max_new_tokens -> max_tokens mapping for cross-provider compatibility
+        if "max_new_tokens" in kwargs:
+            if max_tokens is None:
+                max_tokens = kwargs.pop("max_new_tokens")
+            else:
+                kwargs.pop("max_new_tokens")  # max_tokens takes precedence
+
+        messages = []
+        if system_prompt:
+            messages.append({"role": "system", "content": system_prompt})
+        messages.append({"role": "user", "content": user_prompt})
+
+        # Build API parameters
+        api_params: dict[str, Any] = {
+            "model": self._model_name,
+            "messages": messages,
+            "temperature": temperature,
+        }
+
+        if max_tokens:
+            api_params["max_tokens"] = max_tokens
+
+        if json_mode and self.supports_json_mode():
+            api_params["response_format"] = {"type": "json_object"}
+
+        # Merge additional parameters (excluding internal ones)
+        for key, value in kwargs.items():
+            if key not in api_params:
+                api_params[key] = value
+
+        # Make HTTP request to Ollama's OpenAI-compatible endpoint
+        url = f"{self._api_base_url}/chat/completions"
+        headers = {
+            "Content-Type": "application/json",
+            "Authorization": "Bearer ollama",  # Ollama doesn't require a real API key
+        }
+
+        try:
+            request_data = json.dumps(api_params).encode("utf-8")
+            req = urllib.request.Request(url, data=request_data, headers=headers, method="POST")
+            with urllib.request.urlopen(req, timeout=self._timeout_seconds) as response:
+                response_data = json.loads(response.read().decode("utf-8"))
+        except urllib.error.HTTPError as e:
+            raise RuntimeError(
+                f"Ollama API error: {e.code} - {e.reason}"
+            ) from e
+        except urllib.error.URLError as e:
+            raise RuntimeError(
+                f"Failed to connect to Ollama at {self._ollama_base_url}. "
+                f"Ensure Ollama is running: {e}"
+            ) from e
+        except json.JSONDecodeError as e:
+            raise RuntimeError(f"Failed to parse Ollama response: {e}") from e
+
+        # Parse response
+        choice = response_data.get("choices", [{}])[0]
+        usage = response_data.get("usage", {})
+        message = choice.get("message", {})
+
+        return LLMResponseBuilder.from_openai_format(
+            content=message.get("content", ""),
+            prompt_tokens=usage.get("prompt_tokens"),
+            completion_tokens=usage.get("completion_tokens"),
+            total_tokens=usage.get("total_tokens"),
+            model_name=response_data.get("model", self._model_name),
+            finish_reason=choice.get("finish_reason"),
         )
-        self._ollama_base_url = base_url
+
+    @property
+    def model_name(self) -> str:
+        return self._model_name
+
+    @property
+    def ollama_base_url(self) -> str:
+        """Return the Ollama server base URL."""
+        return self._ollama_base_url
 
     @property
     def is_local(self) -> bool:
@@ -445,6 +547,109 @@ def supports_vision(self) -> bool:
         model_lower = self._model_name.lower()
         return any(x in model_lower for x in ["llava", "bakllava", "moondream"])
 
+    def generate_with_image(
+        self,
+        system_prompt: str,
+        user_prompt: str,
+        image_base64: str,
+        temperature: float = 0.0,
+        max_tokens: int | None = None,
+        **kwargs,
+    ) -> LLMResponse:
+        """
+        Generate response with image input using Ollama's vision models.
+
+        Args:
+            system_prompt: System instruction
+            user_prompt: User query
+            image_base64: Base64-encoded image (PNG or JPEG)
+            temperature: Sampling temperature (0.0 = deterministic)
+            max_tokens: Maximum tokens to generate
+            **kwargs: Additional API parameters
+
+        Returns:
+            LLMResponse object
+
+        Raises:
+            NotImplementedError: If model doesn't support vision
+        """
+        import json
+        import urllib.request
+        import urllib.error
+
+        if not self.supports_vision():
+            raise NotImplementedError(
+                f"Model {self._model_name} does not support vision. "
+                "Use llava, bakllava, or moondream models."
+            )
+
+        messages = []
+        if system_prompt:
+            messages.append({"role": "system", "content": system_prompt})
+
+        # Vision message format with image_url
+        messages.append(
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": user_prompt},
+                    {
+                        "type": "image_url",
+                        "image_url": {"url": f"data:image/png;base64,{image_base64}"},
+                    },
+                ],
+            }
+        )
+
+        # Build API parameters
+        api_params: dict[str, Any] = {
+            "model": self._model_name,
+            "messages": messages,
+            "temperature": temperature,
+        }
+
+        if max_tokens:
+            api_params["max_tokens"] = max_tokens
+
+        # Merge additional parameters
+        for key, value in kwargs.items():
+            if key not in api_params:
+                api_params[key] = value
+
+        # Make HTTP request
+        url = f"{self._api_base_url}/chat/completions"
+        headers = {
+            "Content-Type": "application/json",
+            "Authorization": "Bearer ollama",
+        }
+
+        try:
+            request_data = json.dumps(api_params).encode("utf-8")
+            req = urllib.request.Request(url, data=request_data, headers=headers, method="POST")
+            with urllib.request.urlopen(req, timeout=self._timeout_seconds) as response:
+                response_data = json.loads(response.read().decode("utf-8"))
+        except urllib.error.HTTPError as e:
+            raise RuntimeError(f"Ollama API error: {e.code} - {e.reason}") from e
+        except urllib.error.URLError as e:
+            raise RuntimeError(
+                f"Failed to connect to Ollama at {self._ollama_base_url}. "
+                f"Ensure Ollama is running: {e}"
+            ) from e
+
+        # Parse response
+        choice = response_data.get("choices", [{}])[0]
+        usage = response_data.get("usage", {})
+        message = choice.get("message", {})
+
+        return LLMResponseBuilder.from_openai_format(
+            content=message.get("content", ""),
+            prompt_tokens=usage.get("prompt_tokens"),
+            completion_tokens=usage.get("completion_tokens"),
+            total_tokens=usage.get("total_tokens"),
+            model_name=response_data.get("model", self._model_name),
+            finish_reason=choice.get("finish_reason"),
+        )
+
 
 class AnthropicProvider(LLMProvider):
     """
diff --git a/tests/test_ollama_provider.py b/tests/test_ollama_provider.py
index 2a140fe..f3dadc8 100644
--- a/tests/test_ollama_provider.py
+++ b/tests/test_ollama_provider.py
@@ -2,34 +2,36 @@
 
 import pytest
 
-from predicate.llm_provider import OllamaProvider, OpenAIProvider
+from predicate.llm_provider import OllamaProvider, LLMProvider
 
 
 class TestOllamaProvider:
     """Test suite for OllamaProvider."""
 
-    def test_ollama_provider_is_subclass_of_openai(self):
-        """OllamaProvider should inherit from OpenAIProvider."""
-        assert issubclass(OllamaProvider, OpenAIProvider)
+    def test_ollama_provider_is_subclass_of_llm_provider(self):
+        """OllamaProvider should inherit from LLMProvider (not OpenAIProvider)."""
+        assert issubclass(OllamaProvider, LLMProvider)
 
     def test_ollama_provider_default_base_url(self):
         """OllamaProvider should use default localhost:11434 base URL."""
         provider = OllamaProvider(model="qwen3:8b")
-        # The internal client should have base_url set to /v1 endpoint
+        # The internal base URL should be set correctly
         assert provider._ollama_base_url == "http://localhost:11434"
+        assert provider.ollama_base_url == "http://localhost:11434"
 
     def test_ollama_provider_custom_base_url(self):
        """OllamaProvider should accept custom base URL."""
         provider = OllamaProvider(model="llama3:8b", base_url="http://192.168.1.100:11434")
         assert provider._ollama_base_url == "http://192.168.1.100:11434"
+        assert provider.ollama_base_url == "http://192.168.1.100:11434"
 
     def test_ollama_provider_strips_trailing_slash(self):
         """OllamaProvider should strip trailing slash from base URL."""
         provider = OllamaProvider(model="mistral:7b", base_url="http://localhost:11434/")
-        # The /v1 should be appended correctly without double slash
-        assert provider._ollama_base_url == "http://localhost:11434/"
-        # The actual OpenAI client base_url should be properly formed
-        # (trailing slash stripped before /v1 is appended)
+        # The trailing slash should be stripped
+        assert provider._ollama_base_url == "http://localhost:11434"
+        # The API base URL should be properly formed
+        assert provider._api_base_url == "http://localhost:11434/v1"
 
     def test_ollama_provider_is_local_property(self):
         """OllamaProvider.is_local should return True."""
diff --git a/traces/test-run.jsonl b/traces/test-run.jsonl
index 6ae86f1..c75cbac 100644
--- a/traces/test-run.jsonl
+++ b/traces/test-run.jsonl
@@ -13,3 +13,8 @@
 {"v": 1, "type": "run_start", "ts": "2026-03-29T01:08:28.000Z", "run_id": "test-run", "seq": 1, "data": {"agent": "SentienceAgent"}, "ts_ms": 1774746508045}
 {"v": 1, "type": "run_start", "ts": "2026-03-29T01:08:28.000Z", "run_id": "test-run", "seq": 1, "data": {"agent": "SentienceAgent"}, "ts_ms": 1774746508046}
 {"v": 1, "type": "run_start", "ts": "2026-03-29T01:08:28.000Z", "run_id": "test-run", "seq": 1, "data": {"agent": "SentienceAgent"}, "ts_ms": 1774746508134}
+{"v": 
1, "type": "run_start", "ts": "2026-03-29T01:51:59.000Z", "run_id": "test-run", "seq": 1, "data": {"agent": "SentienceAgent"}, "ts_ms": 1774749119216} +{"v": 1, "type": "run_start", "ts": "2026-03-29T01:51:59.000Z", "run_id": "test-run", "seq": 1, "data": {"agent": "SentienceAgent"}, "ts_ms": 1774749119218} +{"v": 1, "type": "run_start", "ts": "2026-03-29T01:51:59.000Z", "run_id": "test-run", "seq": 1, "data": {"agent": "SentienceAgent"}, "ts_ms": 1774749119219} +{"v": 1, "type": "run_start", "ts": "2026-03-29T01:51:59.000Z", "run_id": "test-run", "seq": 1, "data": {"agent": "SentienceAgent"}, "ts_ms": 1774749119221} +{"v": 1, "type": "run_start", "ts": "2026-03-29T01:51:59.000Z", "run_id": "test-run", "seq": 1, "data": {"agent": "SentienceAgent"}, "ts_ms": 1774749119306} From 3f53636a64d224b7f56d1e8f1e209602d29c4147 Mon Sep 17 00:00:00 2001 From: SentienceDEV Date: Sat, 28 Mar 2026 19:55:17 -0700 Subject: [PATCH 3/4] simplified boilerplate and fix tests --- predicate/tracer_factory.py | 105 ++++++++++++++++++---- predicate/tracing.py | 31 ++++++- tests/test_agent_factory.py | 20 ++++- tests/unit/test_planner_executor_agent.py | 4 +- traces/test-run.jsonl | 5 ++ 5 files changed, 145 insertions(+), 20 deletions(-) diff --git a/predicate/tracer_factory.py b/predicate/tracer_factory.py index 45b5ab1..ad72f88 100644 --- a/predicate/tracer_factory.py +++ b/predicate/tracer_factory.py @@ -2,11 +2,19 @@ Tracer factory with automatic tier detection. Provides convenient factory function for creating tracers with cloud upload support. + +Key Features: +- Automatic cloud upload when API key is provided +- Auto-close on process exit (atexit) to prevent data loss +- Context manager support for both sync and async workflows +- Orphaned trace recovery from previous crashes """ +import atexit import gzip import os import uuid +import weakref from collections.abc import Callable from pathlib import Path from typing import Any, Optional @@ -17,6 +25,60 @@ from predicate.constants import PREDICATE_API_URL from predicate.tracing import JsonlTraceSink, Tracer +# Global registry of active tracers for atexit cleanup +# Using a set of tracer IDs mapped to weak references +_active_tracers: dict[int, weakref.ref[Tracer]] = {} +_atexit_registered = False + + +def _cleanup_tracers_on_exit() -> None: + """ + Cleanup handler called on process exit. + + Closes all active tracers to ensure trace data is uploaded to cloud. + This prevents data loss when users forget to call tracer.close(). + """ + for tracer_id, tracer_ref in list(_active_tracers.items()): + tracer = tracer_ref() + if tracer is not None: + try: + tracer.close() + except Exception: + pass # Best effort - don't raise during exit + + +def _register_tracer_for_cleanup(tracer: Tracer) -> None: + """ + Register a tracer for automatic cleanup on process exit. + + Args: + tracer: Tracer instance to register + """ + global _atexit_registered + + # Use id() as key to avoid hashability issues + tracer_id = id(tracer) + _active_tracers[tracer_id] = weakref.ref(tracer) + + # Set callback on tracer so it unregisters itself when closed + tracer._on_close_callback = _unregister_tracer + + # Register atexit handler on first tracer creation + if not _atexit_registered: + atexit.register(_cleanup_tracers_on_exit) + _atexit_registered = True + + +def _unregister_tracer(tracer: Tracer) -> None: + """ + Unregister a tracer from cleanup (called when tracer.close() is invoked). 
+ + Args: + tracer: Tracer instance to unregister + """ + tracer_id = id(tracer) + _active_tracers.pop(tracer_id, None) + def _emit_run_start( tracer: Tracer, @@ -58,12 +120,17 @@ def create_tracer( auto_emit_run_start: bool = True, ) -> Tracer: """ - Create tracer with automatic tier detection. + Create tracer with automatic tier detection and auto-cleanup. Tier Detection: - If api_key is provided: Try to initialize CloudTraceSink (Pro/Enterprise) - If cloud init fails or no api_key: Fall back to JsonlTraceSink (Free tier) + Auto-Cleanup: + - Tracers are automatically registered for cleanup on process exit (atexit) + - This ensures trace data is uploaded even if tracer.close() is not called + - For best practice, still call tracer.close() explicitly or use context manager + Args: api_key: Sentience API key (e.g., "sk_pro_xxxxx") - Free tier: None or empty @@ -92,7 +159,21 @@ def create_tracer( Tracer configured with appropriate sink Example: - >>> # Pro tier user with goal + >>> # RECOMMENDED: Use as context manager (auto-closes on exit) + >>> with create_tracer(api_key="sk_pro_xyz", goal="Add to cart") as tracer: + ... agent = SentienceAgent(browser, llm, tracer=tracer) + ... agent.act("Click search") + >>> # tracer.close() called automatically + >>> + >>> # ALTERNATIVE: Manual close (still safe - atexit cleanup as fallback) + >>> tracer = create_tracer(api_key="sk_pro_xyz", goal="Add to cart") + >>> try: + ... agent = SentienceAgent(browser, llm, tracer=tracer) + ... agent.act("Click search") + ... finally: + ... tracer.close() # Best practice: explicit close + >>> + >>> # Pro tier with all metadata >>> tracer = create_tracer( ... api_key="sk_pro_xyz", ... run_id="demo", @@ -101,8 +182,6 @@ def create_tracer( ... llm_model="gpt-4-turbo", ... start_url="https://amazon.com" ... ) - >>> # Returns: Tracer with CloudTraceSink - >>> # run_start event is automatically emitted >>> >>> # With screenshot processor for PII redaction >>> def redact_pii(screenshot_base64: str) -> str: @@ -113,20 +192,9 @@ def create_tracer( ... api_key="sk_pro_xyz", ... screenshot_processor=redact_pii ... 
) - >>> # Screenshots will be processed before upload >>> - >>> # Free tier user + >>> # Free tier user (local-only traces) >>> tracer = create_tracer(run_id="demo") - >>> # Returns: Tracer with JsonlTraceSink (local-only) - >>> - >>> # Disable auto-emit for manual control - >>> tracer = create_tracer(run_id="demo", auto_emit_run_start=False) - >>> tracer.emit_run_start("MyAgent", "gpt-4o") # Manual emit - >>> - >>> # Use with agent - >>> agent = SentienceAgent(browser, llm, tracer=tracer) - >>> agent.act("Click search") - >>> tracer.close() # Uploads to cloud if Pro tier """ if run_id is None: run_id = str(uuid.uuid4()) @@ -187,6 +255,8 @@ def create_tracer( ), screenshot_processor=screenshot_processor, ) + # Register for atexit cleanup (safety net for forgotten close()) + _register_tracer_for_cleanup(tracer) # Auto-emit run_start for complete trace structure if auto_emit_run_start: _emit_run_start(tracer, agent_type, llm_model, goal, start_url) @@ -254,6 +324,9 @@ def create_tracer( screenshot_processor=screenshot_processor, ) + # Register for atexit cleanup (ensures file is properly closed) + _register_tracer_for_cleanup(tracer) + # Auto-emit run_start for complete trace structure if auto_emit_run_start: _emit_run_start(tracer, agent_type, llm_model, goal, start_url) diff --git a/predicate/tracing.py b/predicate/tracing.py index 91174d7..92732dc 100644 --- a/predicate/tracing.py +++ b/predicate/tracing.py @@ -205,6 +205,10 @@ class Tracer: _step_successes: int = field(default=0, init=False) _step_failures: int = field(default=0, init=False) _has_errors: bool = field(default=False, init=False) + # Callback for cleanup notification (set by tracer_factory for atexit cleanup) + _on_close_callback: Callable[["Tracer"], None] | None = field(default=None, init=False) + # Track if already closed to prevent double-close + _closed: bool = field(default=False, init=False) def emit( self, @@ -478,11 +482,27 @@ def _infer_final_status(self) -> None: def close(self, **kwargs) -> None: """ - Close the underlying sink. + Close the underlying sink and upload trace data. + + This method is idempotent - calling it multiple times is safe. + It's automatically called when using the tracer as a context manager, + and as a safety net via atexit when the process exits. 
 
         Args:
             **kwargs: Passed through to sink.close() (e.g., blocking=True for CloudTraceSink)
         """
+        # Prevent double-close
+        if self._closed:
+            return
+        self._closed = True
+
+        # Notify cleanup registry (unregister from atexit)
+        if self._on_close_callback is not None:
+            try:
+                self._on_close_callback(self)
+            except Exception:
+                pass  # Don't let callback errors prevent close
+
         # Auto-infer final_status if not explicitly set and we have step outcomes
         if self.final_status == "unknown" and (
             self._step_successes > 0 or self._step_failures > 0 or self._has_errors
@@ -509,3 +529,12 @@
         """Context manager cleanup."""
         self.close()
         return False
+
+    async def __aenter__(self):
+        """Async context manager support for use with 'async with'."""
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        """Async context manager cleanup."""
+        self.close()
+        return False
diff --git a/tests/test_agent_factory.py b/tests/test_agent_factory.py
index 9fbe6bd..a73151e 100644
--- a/tests/test_agent_factory.py
+++ b/tests/test_agent_factory.py
@@ -14,9 +14,24 @@
     get_config_preset,
 )
 from predicate.agents.planner_executor_agent import PlannerExecutorAgent, PlannerExecutorConfig
-from predicate.llm_provider import AnthropicProvider, OllamaProvider, OpenAIProvider
+from predicate.llm_provider import OllamaProvider
 from predicate.tracing import Tracer
 
+# Optional imports for cloud providers
+try:
+    from predicate.llm_provider import OpenAIProvider
+
+    HAS_OPENAI = True
+except ImportError:
+    HAS_OPENAI = False
+
+try:
+    from predicate.llm_provider import AnthropicProvider
+
+    HAS_ANTHROPIC = True
+except ImportError:
+    HAS_ANTHROPIC = False
+
 
 class TestDetectProvider:
     """Test provider auto-detection from model names."""
@@ -88,6 +103,7 @@
         assert isinstance(provider, OllamaProvider)
         assert provider.model_name == "qwen3:8b"
 
+    @pytest.mark.skipif(not HAS_OPENAI, reason="openai package not installed")
     def test_create_openai_provider(self):
         """Should create OpenAIProvider for openai."""
         provider = _create_provider(
@@ -100,6 +116,7 @@
         assert isinstance(provider, OpenAIProvider)
         assert provider.model_name == "gpt-4o"
 
+    @pytest.mark.skipif(not HAS_ANTHROPIC, reason="anthropic package not installed")
     def test_create_anthropic_provider(self):
         """Should create AnthropicProvider for anthropic."""
         provider = _create_provider(
@@ -267,6 +284,7 @@
         )
         assert isinstance(agent, PlannerExecutorAgent)
 
+    @pytest.mark.skipif(not HAS_OPENAI, reason="openai package not installed")
     def test_create_agent_mixed_providers(self):
         """Should support mixed cloud/local configuration."""
         agent = create_planner_executor_agent(
diff --git a/tests/unit/test_planner_executor_agent.py b/tests/unit/test_planner_executor_agent.py
index 0c55471..33d31d1 100644
--- a/tests/unit/test_planner_executor_agent.py
+++ b/tests/unit/test_planner_executor_agent.py
@@ -47,8 +47,8 @@ def test_basic_prompt_structure(self) -> None:
             intent=None,
             compact_context="123|button|Submit|100|1|0|-|0|",
         )
-        assert "CLICK(<id>)" in sys_prompt
-        assert "TYPE(" in sys_prompt
+        # Prompt should mention CLICK format (either CLICK(id) or CLICK(<id>))
+        assert "CLICK" in sys_prompt
         assert "Goal: Click the submit button" in user_prompt
         assert "123|button|Submit" in user_prompt
 
diff --git a/traces/test-run.jsonl b/traces/test-run.jsonl
index c75cbac..221a5d6 100644
--- a/traces/test-run.jsonl
+++ b/traces/test-run.jsonl
@@ -18,3 +18,8 @@ {"v": 1, "type": "run_start", "ts": "2026-03-29T01:51:59.000Z", "run_id": "test-run", "seq": 1, "data": {"agent": "SentienceAgent"}, "ts_ms": 1774749119219} {"v": 1, "type": "run_start", "ts": "2026-03-29T01:51:59.000Z", "run_id": "test-run", "seq": 1, "data": {"agent": "SentienceAgent"}, "ts_ms": 1774749119221} {"v": 1, "type": "run_start", "ts": "2026-03-29T01:51:59.000Z", "run_id": "test-run", "seq": 1, "data": {"agent": "SentienceAgent"}, "ts_ms": 1774749119306} +{"v": 1, "type": "run_start", "ts": "2026-03-29T02:37:44.000Z", "run_id": "test-run", "seq": 1, "data": {"agent": "SentienceAgent"}, "ts_ms": 1774751864986} +{"v": 1, "type": "run_start", "ts": "2026-03-29T02:37:44.000Z", "run_id": "test-run", "seq": 1, "data": {"agent": "SentienceAgent"}, "ts_ms": 1774751864988} +{"v": 1, "type": "run_start", "ts": "2026-03-29T02:37:44.000Z", "run_id": "test-run", "seq": 1, "data": {"agent": "SentienceAgent"}, "ts_ms": 1774751864990} +{"v": 1, "type": "run_start", "ts": "2026-03-29T02:37:44.000Z", "run_id": "test-run", "seq": 1, "data": {"agent": "SentienceAgent"}, "ts_ms": 1774751864991} +{"v": 1, "type": "run_start", "ts": "2026-03-29T02:37:44.000Z", "run_id": "test-run", "seq": 1, "data": {"agent": "SentienceAgent"}, "ts_ms": 1774751864998} From 6e8958c7a4cccdbc5d511e6aa2d45ef558215009 Mon Sep 17 00:00:00 2001 From: SentienceDEV Date: Sat, 28 Mar 2026 20:07:10 -0700 Subject: [PATCH 4/4] fix tests --- tests/test_agent_factory.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tests/test_agent_factory.py b/tests/test_agent_factory.py index a73151e..5d3750e 100644 --- a/tests/test_agent_factory.py +++ b/tests/test_agent_factory.py @@ -17,21 +17,25 @@ from predicate.llm_provider import OllamaProvider from predicate.tracing import Tracer -# Optional imports for cloud providers +# Check if optional cloud provider packages are installed +# Note: The provider classes exist but require their respective packages at runtime try: - from predicate.llm_provider import OpenAIProvider + import openai # noqa: F401 HAS_OPENAI = True except ImportError: HAS_OPENAI = False try: - from predicate.llm_provider import AnthropicProvider + import anthropic # noqa: F401 HAS_ANTHROPIC = True except ImportError: HAS_ANTHROPIC = False +# Import provider classes (they exist but need packages at instantiation time) +from predicate.llm_provider import AnthropicProvider, OpenAIProvider + class TestDetectProvider: """Test provider auto-detection from model names."""
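
Below is a minimal usage sketch of the entry point this series adds. It assumes
a local Ollama server at the default http://localhost:11434; every name used
(the factory, the preset helper, provider auto-detection from the model:tag
format, and the JSONL tracer fallback when PREDICATE_API_KEY is unset) comes
from the patches above.

    from predicate.agents import (
        ConfigPreset,
        create_planner_executor_agent,
        get_config_preset,
    )

    # Both model names resolve to OllamaProvider via _detect_provider()'s
    # model:tag heuristic; with no PREDICATE_API_KEY in the environment,
    # _create_auto_tracer() writes a local JSONL trace under ./traces/.
    agent = create_planner_executor_agent(
        planner_model="qwen3:8b",
        executor_model="qwen3:4b",
        config=get_config_preset(ConfigPreset.LOCAL_SMALL_MODEL),
    )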