From 66becd5e50d6209bd46e4c75bef253edaed9872c Mon Sep 17 00:00:00 2001 From: miguel Date: Mon, 2 Jun 2025 14:00:23 -0700 Subject: [PATCH 1/6] abstract functionality and fix default_schema --- stagehand/api.py | 182 +++++++++ stagehand/browser.py | 329 ++++++++++++++++ stagehand/client.py | 515 +++----------------------- stagehand/handlers/extract_handler.py | 7 +- stagehand/llm/client.py | 3 + stagehand/llm/inference.py | 3 + stagehand/page.py | 1 + stagehand/types/__init__.py | 4 +- stagehand/types/page.py | 4 +- stagehand/utils.py | 4 + 10 files changed, 581 insertions(+), 471 deletions(-) diff --git a/stagehand/api.py b/stagehand/api.py index e69de29b..ddda08ce 100644 --- a/stagehand/api.py +++ b/stagehand/api.py @@ -0,0 +1,182 @@ +import json +from typing import Any + +import httpx + +from .utils import convert_dict_keys_to_camel_case + +__all__ = ["_create_session", "_execute"] + + +async def _create_session(self): + """ + Create a new session by calling /sessions/start on the server. + Depends on browserbase_api_key, browserbase_project_id, and model_api_key. + """ + if not self.browserbase_api_key: + raise ValueError("browserbase_api_key is required to create a session.") + if not self.browserbase_project_id: + raise ValueError("browserbase_project_id is required to create a session.") + if not self.model_api_key: + raise ValueError("model_api_key is required to create a session.") + + browserbase_session_create_params = ( + convert_dict_keys_to_camel_case(self.browserbase_session_create_params) + if self.browserbase_session_create_params + else None + ) + + payload = { + "modelName": self.model_name, + "verbose": 2 if self.verbose == 3 else self.verbose, + "domSettleTimeoutMs": self.dom_settle_timeout_ms, + "browserbaseSessionCreateParams": ( + browserbase_session_create_params + if browserbase_session_create_params + else { + "browserSettings": { + "blockAds": True, + "viewport": { + "width": 1024, + "height": 768, + }, + }, + } + ), + "proxies": True, + } + + # Add the new parameters if they have values + if hasattr(self, "self_heal") and self.self_heal is not None: + payload["selfHeal"] = self.self_heal + + if ( + hasattr(self, "wait_for_captcha_solves") + and self.wait_for_captcha_solves is not None + ): + payload["waitForCaptchaSolves"] = self.wait_for_captcha_solves + + if hasattr(self, "act_timeout_ms") and self.act_timeout_ms is not None: + payload["actTimeoutMs"] = self.act_timeout_ms + + if hasattr(self, "system_prompt") and self.system_prompt: + payload["systemPrompt"] = self.system_prompt + + if hasattr(self, "model_client_options") and self.model_client_options: + payload["modelClientOptions"] = self.model_client_options + + headers = { + "x-bb-api-key": self.browserbase_api_key, + "x-bb-project-id": self.browserbase_project_id, + "x-model-api-key": self.model_api_key, + "Content-Type": "application/json", + "x-language": "python", + } + + client = self.httpx_client or httpx.AsyncClient(timeout=self.timeout_settings) + async with client: + resp = await client.post( + f"{self.api_url}/sessions/start", + json=payload, + headers=headers, + ) + if resp.status_code != 200: + raise RuntimeError(f"Failed to create session: {resp.text}") + data = resp.json() + self.logger.debug(f"Session created: {data}") + if not data.get("success") or "sessionId" not in data.get("data", {}): + raise RuntimeError(f"Invalid response format: {resp.text}") + + self.session_id = data["data"]["sessionId"] + + +async def _execute(self, method: str, payload: dict[str, Any]) -> Any: + """ + Internal helper to call /sessions/{session_id}/{method} with the given method and payload. + Streams line-by-line, returning the 'result' from the final message (if any). + """ + headers = { + "x-bb-api-key": self.browserbase_api_key, + "x-bb-project-id": self.browserbase_project_id, + "Content-Type": "application/json", + "Connection": "keep-alive", + # Always enable streaming for better log handling + "x-stream-response": "true", + } + if self.model_api_key: + headers["x-model-api-key"] = self.model_api_key + + # Convert snake_case keys to camelCase for the API + modified_payload = convert_dict_keys_to_camel_case(payload) + + client = self.httpx_client or httpx.AsyncClient(timeout=self.timeout_settings) + self.logger.debug(f"\n==== EXECUTING {method.upper()} ====") + self.logger.debug(f"URL: {self.api_url}/sessions/{self.session_id}/{method}") + self.logger.debug(f"Payload: {modified_payload}") + self.logger.debug(f"Headers: {headers}") + + async with client: + try: + # Always use streaming for consistent log handling + async with client.stream( + "POST", + f"{self.api_url}/sessions/{self.session_id}/{method}", + json=modified_payload, + headers=headers, + ) as response: + if response.status_code != 200: + error_text = await response.aread() + error_message = error_text.decode("utf-8") + self.logger.error( + f"[HTTP ERROR] Status {response.status_code}: {error_message}" + ) + raise RuntimeError( + f"Request failed with status {response.status_code}: {error_message}" + ) + + self.logger.debug("[STREAM] Processing server response") + result = None + + async for line in response.aiter_lines(): + # Skip empty lines + if not line.strip(): + continue + + try: + # Handle SSE-style messages that start with "data: " + if line.startswith("data: "): + line = line[len("data: ") :] + + message = json.loads(line) + # Handle different message types + msg_type = message.get("type") + + if msg_type == "system": + status = message.get("data", {}).get("status") + if status == "error": + error_msg = message.get("data", {}).get( + "error", "Unknown error" + ) + self.logger.error(f"[ERROR] {error_msg}") + raise RuntimeError( + f"Server returned error: {error_msg}" + ) + elif status == "finished": + result = message.get("data", {}).get("result") + self.logger.debug( + "[SYSTEM] Operation completed successfully" + ) + elif msg_type == "log": + # Process log message using _handle_log + await self._handle_log(message) + else: + # Log any other message types + self.logger.debug(f"[UNKNOWN] Message type: {msg_type}") + except json.JSONDecodeError: + self.logger.warning(f"Could not parse line as JSON: {line}") + + # Return the final result + return result + except Exception as e: + self.logger.error(f"[EXCEPTION] {str(e)}") + raise diff --git a/stagehand/browser.py b/stagehand/browser.py index e69de29b..ed920c00 100644 --- a/stagehand/browser.py +++ b/stagehand/browser.py @@ -0,0 +1,329 @@ +import json +import os +import shutil +import tempfile +from pathlib import Path +from typing import Any, Optional + +from browserbase import Browserbase +from playwright.async_api import ( + Browser, + BrowserContext, + Playwright, +) + +from .context import StagehandContext +from .page import StagehandPage +from .utils import StagehandLogger + + +async def connect_browserbase_browser( + playwright: Playwright, + session_id: str, + browserbase_api_key: str, + stagehand_instance: Any, + logger: StagehandLogger, +) -> tuple[Browser, BrowserContext, StagehandContext, StagehandPage]: + """ + Connect to a Browserbase remote browser session. + + Args: + playwright: The Playwright instance + session_id: The Browserbase session ID + browserbase_api_key: The Browserbase API key + stagehand_instance: The Stagehand instance (for context initialization) + logger: The logger instance + + Returns: + tuple of (browser, context, stagehand_context, page) + """ + # Connect to remote browser via Browserbase SDK and CDP + bb = Browserbase(api_key=browserbase_api_key) + try: + logger.debug(f"Retrieving Browserbase session details for {session_id}...") + session = bb.sessions.retrieve(session_id) + if session.status != "RUNNING": + raise RuntimeError( + f"Browserbase session {session_id} is not running (status: {session.status})" + ) + connect_url = session.connectUrl + except Exception as e: + logger.error(f"Error retrieving or validating Browserbase session: {str(e)}") + raise + + logger.debug(f"Connecting to remote browser at: {connect_url}") + try: + browser = await playwright.chromium.connect_over_cdp(connect_url) + logger.debug(f"Connected to remote browser: {browser}") + except Exception as e: + logger.error(f"Failed to connect Playwright via CDP: {str(e)}") + raise + + existing_contexts = browser.contexts + logger.debug(f"Existing contexts in remote browser: {len(existing_contexts)}") + if existing_contexts: + context = existing_contexts[0] + else: + # This case might be less common with Browserbase but handle it + logger.warning( + "No existing context found in remote browser, creating a new one." + ) + context = await browser.new_context() + + stagehand_context = await StagehandContext.init(context, stagehand_instance) + + # Access or create a page via StagehandContext + existing_pages = context.pages + logger.debug(f"Existing pages in context: {len(existing_pages)}") + if existing_pages: + logger.debug("Using existing page via StagehandContext") + page = await stagehand_context.get_stagehand_page(existing_pages[0]) + else: + logger.debug("Creating a new page via StagehandContext") + page = await stagehand_context.new_page() + + return browser, context, stagehand_context, page + + +async def connect_local_browser( + playwright: Playwright, + local_browser_launch_options: dict[str, Any], + stagehand_instance: Any, + logger: StagehandLogger, +) -> tuple[ + Optional[Browser], BrowserContext, StagehandContext, StagehandPage, Optional[Path] +]: + """ + Connect to a local browser via CDP or launch a new browser context. + + Args: + playwright: The Playwright instance + local_browser_launch_options: Options for launching the local browser + stagehand_instance: The Stagehand instance (for context initialization) + logger: The logger instance + + Returns: + tuple of (browser, context, stagehand_context, page, temp_user_data_dir) + """ + cdp_url = local_browser_launch_options.get("cdp_url") + temp_user_data_dir = None + + if cdp_url: + logger.info(f"Connecting to local browser via CDP URL: {cdp_url}") + try: + browser = await playwright.chromium.connect_over_cdp(cdp_url) + + if not browser.contexts: + raise RuntimeError(f"No browser contexts found at CDP URL: {cdp_url}") + context = browser.contexts[0] + stagehand_context = await StagehandContext.init(context, stagehand_instance) + logger.debug(f"Connected via CDP. Using context: {context}") + except Exception as e: + logger.error(f"Failed to connect via CDP URL ({cdp_url}): {str(e)}") + raise + else: + logger.info("Launching new local browser context...") + browser = None + + user_data_dir_option = local_browser_launch_options.get("user_data_dir") + if user_data_dir_option: + user_data_dir = Path(user_data_dir_option).resolve() + else: + # Create temporary directory + temp_dir = tempfile.mkdtemp(prefix="stagehand_ctx_") + temp_user_data_dir = Path(temp_dir) + user_data_dir = temp_user_data_dir + # Create Default profile directory and Preferences file like in TS + default_profile_path = user_data_dir / "Default" + default_profile_path.mkdir(parents=True, exist_ok=True) + prefs_path = default_profile_path / "Preferences" + default_prefs = {"plugins": {"always_open_pdf_externally": True}} + try: + with open(prefs_path, "w") as f: + json.dump(default_prefs, f) + logger.debug( + f"Created temporary user_data_dir with default preferences: {user_data_dir}" + ) + except Exception as e: + logger.error( + f"Failed to write default preferences to {prefs_path}: {e}" + ) + + downloads_path_option = local_browser_launch_options.get("downloads_path") + if downloads_path_option: + downloads_path = str(Path(downloads_path_option).resolve()) + else: + downloads_path = str(Path.cwd() / "downloads") + try: + os.makedirs(downloads_path, exist_ok=True) + logger.debug(f"Using downloads_path: {downloads_path}") + except Exception as e: + logger.error(f"Failed to create downloads_path {downloads_path}: {e}") + + # Prepare Launch Options (translate keys if needed) + launch_options = { + "headless": local_browser_launch_options.get("headless", False), + "accept_downloads": local_browser_launch_options.get( + "acceptDownloads", True + ), + "downloads_path": downloads_path, + "args": local_browser_launch_options.get( + "args", + [ + "--disable-blink-features=AutomationControlled", + ], + ), + "viewport": local_browser_launch_options.get( + "viewport", {"width": 1024, "height": 768} + ), + "locale": local_browser_launch_options.get("locale", "en-US"), + "timezone_id": local_browser_launch_options.get( + "timezoneId", "America/New_York" + ), + "bypass_csp": local_browser_launch_options.get("bypassCSP", True), + "proxy": local_browser_launch_options.get("proxy"), + "ignore_https_errors": local_browser_launch_options.get( + "ignoreHTTPSErrors", True + ), + } + launch_options = {k: v for k, v in launch_options.items() if v is not None} + + # Launch Context + try: + context = await playwright.chromium.launch_persistent_context( + str(user_data_dir), # Needs to be string path + **launch_options, + ) + stagehand_context = await StagehandContext.init(context, stagehand_instance) + logger.info("Local browser context launched successfully.") + browser = context.browser + + except Exception as e: + logger.error(f"Failed to launch local browser context: {str(e)}") + if temp_user_data_dir: + try: + shutil.rmtree(temp_user_data_dir) + except: + pass + raise + + cookies = local_browser_launch_options.get("cookies") + if cookies: + try: + await context.add_cookies(cookies) + logger.debug(f"Added {len(cookies)} cookies to the context.") + except Exception as e: + logger.error(f"Failed to add cookies: {e}") + + # Apply stealth scripts + await apply_stealth_scripts(context, logger) + + # Get the initial page (usually one is created by default) + if context.pages: + playwright_page = context.pages[0] + logger.debug("Using initial page from local context.") + else: + logger.debug("No initial page found, creating a new one.") + playwright_page = await context.new_page() + + page = StagehandPage(playwright_page, stagehand_instance) + + return browser, context, stagehand_context, page, temp_user_data_dir + + +async def apply_stealth_scripts(context: BrowserContext, logger: StagehandLogger): + """Applies JavaScript init scripts to make the browser less detectable.""" + logger.debug("Applying stealth init scripts to the context...") + stealth_script = """ + (() => { + // Override navigator.webdriver + if (navigator.webdriver) { + Object.defineProperty(navigator, 'webdriver', { + get: () => undefined + }); + } + + // Mock languages and plugins + Object.defineProperty(navigator, 'languages', { + get: () => ['en-US', 'en'], + }); + + // Avoid complex plugin mocking, just return a non-empty array like structure + if (navigator.plugins instanceof PluginArray && navigator.plugins.length === 0) { + Object.defineProperty(navigator, 'plugins', { + get: () => Object.values({ + 'plugin1': { name: 'Chrome PDF Plugin', filename: 'internal-pdf-viewer', description: 'Portable Document Format' }, + 'plugin2': { name: 'Chrome PDF Viewer', filename: 'mhjfbmdgcfjbbpaeojofohoefgiehjai', description: '' }, + 'plugin3': { name: 'Native Client', filename: 'internal-nacl-plugin', description: '' } + }), + }); + } + + // Remove Playwright-specific properties from window + try { + delete window.__playwright_run; // Example property, check actual properties if needed + delete window.navigator.__proto__.webdriver; // Another common place + } catch (e) {} + + // Override permissions API (example for notifications) + if (window.navigator && window.navigator.permissions) { + const originalQuery = window.navigator.permissions.query; + window.navigator.permissions.query = (parameters) => { + if (parameters && parameters.name === 'notifications') { + return Promise.resolve({ state: Notification.permission }); + } + // Call original for other permissions + return originalQuery.apply(window.navigator.permissions, [parameters]); + }; + } + })(); + """ + try: + await context.add_init_script(stealth_script) + logger.debug("Stealth init script added successfully.") + except Exception as e: + logger.error(f"Failed to add stealth init script: {str(e)}") + + +async def cleanup_browser_resources( + browser: Optional[Browser], + context: Optional[BrowserContext], + playwright: Optional[Playwright], + temp_user_data_dir: Optional[Path], + logger: StagehandLogger, +): + """ + Clean up browser resources. + + Args: + browser: The browser instance (if any) + context: The browser context + playwright: The Playwright instance + temp_user_data_dir: Temporary user data directory to remove (if any) + logger: The logger instance + """ + if context: + try: + logger.debug("Closing browser context...") + await context.close() + except Exception as e: + logger.error(f"Error closing context: {str(e)}") + + # Clean up temporary user data directory if created + if temp_user_data_dir: + try: + logger.debug( + f"Removing temporary user data directory: {temp_user_data_dir}" + ) + shutil.rmtree(temp_user_data_dir) + except Exception as e: + logger.error( + f"Error removing temporary directory {temp_user_data_dir}: {str(e)}" + ) + + if playwright: + try: + logger.debug("Stopping Playwright...") + await playwright.stop() + except Exception as e: + logger.error(f"Error stopping Playwright: {str(e)}") diff --git a/stagehand/client.py b/stagehand/client.py index 37b35ed7..1a4f7c6e 100644 --- a/stagehand/client.py +++ b/stagehand/client.py @@ -1,14 +1,10 @@ import asyncio -import json import os -import shutil -import tempfile import time from pathlib import Path from typing import Any, Literal, Optional import httpx -from browserbase import Browserbase from dotenv import load_dotenv from playwright.async_api import ( BrowserContext, @@ -18,6 +14,12 @@ from playwright.async_api import Page as PlaywrightPage from .agent import Agent +from .api import _create_session, _execute +from .browser import ( + cleanup_browser_resources, + connect_browserbase_browser, + connect_local_browser, +) from .config import StagehandConfig, default_config from .context import StagehandContext from .llm import LLMClient @@ -26,7 +28,6 @@ from .schemas import AgentConfig from .utils import ( StagehandLogger, - convert_dict_keys_to_camel_case, default_log_handler, make_serializable, ) @@ -374,215 +375,44 @@ async def init(self): f"Using existing Browserbase session: {self.session_id}" ) - # Connect to remote browser via Browserbase SDK and CDP - bb = Browserbase(api_key=self.browserbase_api_key) + # Connect to remote browser try: - self.logger.debug( - f"Retrieving Browserbase session details for {self.session_id}..." - ) - session = bb.sessions.retrieve(self.session_id) - if session.status != "RUNNING": - raise RuntimeError( - f"Browserbase session {self.session_id} is not running (status: {session.status})" - ) - connect_url = session.connectUrl - except Exception as e: - self.logger.error( - f"Error retrieving or validating Browserbase session: {str(e)}" + ( + self._browser, + self._context, + self.context, + self.page, + ) = await connect_browserbase_browser( + self._playwright, + self.session_id, + self.browserbase_api_key, + self, + self.logger, ) - await self.close() # Clean up playwright if started + self._playwright_page = self.page._page + except Exception: + await self.close() raise - self.logger.debug(f"Connecting to remote browser at: {connect_url}") + elif self.env == "LOCAL": + # Connect to local browser try: - self._browser = await self._playwright.chromium.connect_over_cdp( - connect_url + ( + self._browser, + self._context, + self.context, + self.page, + self._local_user_data_dir_temp, + ) = await connect_local_browser( + self._playwright, + self.local_browser_launch_options, + self, + self.logger, ) - self.logger.debug(f"Connected to remote browser: {self._browser}") - except Exception as e: - self.logger.error(f"Failed to connect Playwright via CDP: {str(e)}") + self._playwright_page = self.page._page + except Exception: await self.close() raise - - existing_contexts = self._browser.contexts - self.logger.debug( - f"Existing contexts in remote browser: {len(existing_contexts)}" - ) - if existing_contexts: - self._context = existing_contexts[0] - else: - # This case might be less common with Browserbase but handle it - self.logger.warning( - "No existing context found in remote browser, creating a new one." - ) - self._context = ( - await self._browser.new_context() - ) # Should we pass options? - - self.context = await StagehandContext.init(self._context, self) - - # Access or create a page via StagehandContext - existing_pages = self._context.pages - self.logger.debug(f"Existing pages in context: {len(existing_pages)}") - if existing_pages: - self.logger.debug("Using existing page via StagehandContext") - self.page = await self.context.get_stagehand_page(existing_pages[0]) - self._playwright_page = existing_pages[0] - else: - self.logger.debug("Creating a new page via StagehandContext") - self.page = await self.context.new_page() - self._playwright_page = self.page.page - - elif self.env == "LOCAL": - cdp_url = self.local_browser_launch_options.get("cdp_url") - - if cdp_url: - self.logger.info(f"Connecting to local browser via CDP URL: {cdp_url}") - try: - self._browser = await self._playwright.chromium.connect_over_cdp( - cdp_url - ) - - if not self._browser.contexts: - raise RuntimeError( - f"No browser contexts found at CDP URL: {cdp_url}" - ) - self._context = self._browser.contexts[0] - self.context = await StagehandContext.init(self._context, self) - self.logger.debug( - f"Connected via CDP. Using context: {self._context}" - ) - except Exception as e: - self.logger.error( - f"Failed to connect via CDP URL ({cdp_url}): {str(e)}" - ) - await self.close() - raise - else: - self.logger.info("Launching new local browser context...") - - user_data_dir_option = self.local_browser_launch_options.get( - "user_data_dir" - ) - if user_data_dir_option: - user_data_dir = Path(user_data_dir_option).resolve() - else: - # Create temporary directory - temp_dir = tempfile.mkdtemp(prefix="stagehand_ctx_") - self._local_user_data_dir_temp = Path(temp_dir) - user_data_dir = self._local_user_data_dir_temp - # Create Default profile directory and Preferences file like in TS - default_profile_path = user_data_dir / "Default" - default_profile_path.mkdir(parents=True, exist_ok=True) - prefs_path = default_profile_path / "Preferences" - default_prefs = {"plugins": {"always_open_pdf_externally": True}} - try: - with open(prefs_path, "w") as f: - json.dump(default_prefs, f) - self.logger.debug( - f"Created temporary user_data_dir with default preferences: {user_data_dir}" - ) - except Exception as e: - self.logger.error( - f"Failed to write default preferences to {prefs_path}: {e}" - ) - - downloads_path_option = self.local_browser_launch_options.get( - "downloads_path" - ) - if downloads_path_option: - downloads_path = str(Path(downloads_path_option).resolve()) - else: - downloads_path = str(Path.cwd() / "downloads") - try: - os.makedirs(downloads_path, exist_ok=True) - self.logger.debug(f"Using downloads_path: {downloads_path}") - except Exception as e: - self.logger.error( - f"Failed to create downloads_path {downloads_path}: {e}" - ) - - # 3. Prepare Launch Options (translate keys if needed) - launch_options = { - "headless": self.local_browser_launch_options.get( - "headless", False - ), - "accept_downloads": self.local_browser_launch_options.get( - "acceptDownloads", True - ), - "downloads_path": downloads_path, - "args": self.local_browser_launch_options.get( - "args", - [ - # Common args from TS version - # "--enable-webgl", - # "--use-gl=swiftshader", - # "--enable-accelerated-2d-canvas", - "--disable-blink-features=AutomationControlled", - # "--disable-web-security", # Use with caution - ], - ), - # Add more translations as needed based on local_browser_launch_options structure - "viewport": self.local_browser_launch_options.get( - "viewport", {"width": 1024, "height": 768} - ), - "locale": self.local_browser_launch_options.get("locale", "en-US"), - "timezone_id": self.local_browser_launch_options.get( - "timezoneId", "America/New_York" - ), - "bypass_csp": self.local_browser_launch_options.get( - "bypassCSP", True - ), - "proxy": self.local_browser_launch_options.get("proxy"), - "ignore_https_errors": self.local_browser_launch_options.get( - "ignoreHTTPSErrors", True - ), - } - launch_options = { - k: v for k, v in launch_options.items() if v is not None - } - - # 4. Launch Context - try: - self._context = ( - await self._playwright.chromium.launch_persistent_context( - str(user_data_dir), # Needs to be string path - **launch_options, - ) - ) - self.context = await StagehandContext.init(self._context, self) - self.logger.info("Local browser context launched successfully.") - self._browser = self._context.browser - - except Exception as e: - self.logger.error( - f"Failed to launch local browser context: {str(e)}" - ) - await self.close() # Clean up playwright and temp dir - raise - - cookies = self.local_browser_launch_options.get("cookies") - if cookies: - try: - await self._context.add_cookies(cookies) - self.logger.debug( - f"Added {len(cookies)} cookies to the context." - ) - except Exception as e: - self.logger.error(f"Failed to add cookies: {e}") - - # Apply stealth scripts - await self._apply_stealth_scripts(self._context) - - # Get the initial page (usually one is created by default) - if self._context.pages: - self._playwright_page = self._context.pages[0] - self.logger.debug("Using initial page from local context.") - else: - self.logger.debug("No initial page found, creating a new one.") - self._playwright_page = await self._context.new_page() - - self.page = StagehandPage(self._playwright_page, self) else: # Should not happen due to __init__ validation raise RuntimeError(f"Invalid env value: {self.env}") @@ -653,210 +483,16 @@ async def close(self): await self._client.aclose() self._client = None - elif self.env == "LOCAL": - if self._context: - try: - self.logger.debug("Closing local browser context...") - await self._context.close() - self._context = None - self._browser = None # Clear browser reference too - except Exception as e: - self.logger.error(f"Error closing local context: {str(e)}") - - # Clean up temporary user data directory if created - if self._local_user_data_dir_temp: - try: - self.logger.debug( - f"Removing temporary user data directory: {self._local_user_data_dir_temp}" - ) - shutil.rmtree(self._local_user_data_dir_temp) - self._local_user_data_dir_temp = None - except Exception as e: - self.logger.error( - f"Error removing temporary directory {self._local_user_data_dir_temp}: {str(e)}" - ) - - if self._playwright: - try: - self.logger.debug("Stopping Playwright...") - await self._playwright.stop() - self._playwright = None - except Exception as e: - self.logger.error(f"Error stopping Playwright: {str(e)}") - - self._closed = True - - async def _create_session(self): - """ - Create a new session by calling /sessions/start on the server. - Depends on browserbase_api_key, browserbase_project_id, and model_api_key. - """ - if not self.browserbase_api_key: - raise ValueError("browserbase_api_key is required to create a session.") - if not self.browserbase_project_id: - raise ValueError("browserbase_project_id is required to create a session.") - if not self.model_api_key: - raise ValueError("model_api_key is required to create a session.") - - browserbase_session_create_params = ( - convert_dict_keys_to_camel_case(self.browserbase_session_create_params) - if self.browserbase_session_create_params - else None + # Use the centralized cleanup function for browser resources + await cleanup_browser_resources( + self._browser, + self._context, + self._playwright, + self._local_user_data_dir_temp, + self.logger, ) - payload = { - "modelName": self.model_name, - "verbose": 2 if self.verbose == 3 else self.verbose, - "domSettleTimeoutMs": self.dom_settle_timeout_ms, - "browserbaseSessionCreateParams": ( - browserbase_session_create_params - if browserbase_session_create_params - else { - "browserSettings": { - "blockAds": True, - "viewport": { - "width": 1024, - "height": 768, - }, - }, - } - ), - "proxies": True, - } - - # Add the new parameters if they have values - if hasattr(self, "self_heal") and self.self_heal is not None: - payload["selfHeal"] = self.self_heal - - if ( - hasattr(self, "wait_for_captcha_solves") - and self.wait_for_captcha_solves is not None - ): - payload["waitForCaptchaSolves"] = self.wait_for_captcha_solves - - if hasattr(self, "act_timeout_ms") and self.act_timeout_ms is not None: - payload["actTimeoutMs"] = self.act_timeout_ms - - if hasattr(self, "system_prompt") and self.system_prompt: - payload["systemPrompt"] = self.system_prompt - - if hasattr(self, "model_client_options") and self.model_client_options: - payload["modelClientOptions"] = self.model_client_options - - headers = { - "x-bb-api-key": self.browserbase_api_key, - "x-bb-project-id": self.browserbase_project_id, - "x-model-api-key": self.model_api_key, - "Content-Type": "application/json", - "x-language": "python", - } - - client = self.httpx_client or httpx.AsyncClient(timeout=self.timeout_settings) - async with client: - resp = await client.post( - f"{self.api_url}/sessions/start", - json=payload, - headers=headers, - ) - if resp.status_code != 200: - raise RuntimeError(f"Failed to create session: {resp.text}") - data = resp.json() - self.logger.debug(f"Session created: {data}") - if not data.get("success") or "sessionId" not in data.get("data", {}): - raise RuntimeError(f"Invalid response format: {resp.text}") - - self.session_id = data["data"]["sessionId"] - - async def _execute(self, method: str, payload: dict[str, Any]) -> Any: - """ - Internal helper to call /sessions/{session_id}/{method} with the given method and payload. - Streams line-by-line, returning the 'result' from the final message (if any). - """ - headers = { - "x-bb-api-key": self.browserbase_api_key, - "x-bb-project-id": self.browserbase_project_id, - "Content-Type": "application/json", - "Connection": "keep-alive", - # Always enable streaming for better log handling - "x-stream-response": "true", - } - if self.model_api_key: - headers["x-model-api-key"] = self.model_api_key - - # Convert snake_case keys to camelCase for the API - modified_payload = convert_dict_keys_to_camel_case(payload) - - client = self.httpx_client or httpx.AsyncClient(timeout=self.timeout_settings) - self.logger.debug(f"\n==== EXECUTING {method.upper()} ====") - self.logger.debug(f"URL: {self.api_url}/sessions/{self.session_id}/{method}") - self.logger.debug(f"Payload: {modified_payload}") - self.logger.debug(f"Headers: {headers}") - - async with client: - try: - # Always use streaming for consistent log handling - async with client.stream( - "POST", - f"{self.api_url}/sessions/{self.session_id}/{method}", - json=modified_payload, - headers=headers, - ) as response: - if response.status_code != 200: - error_text = await response.aread() - error_message = error_text.decode("utf-8") - self.logger.error( - f"[HTTP ERROR] Status {response.status_code}: {error_message}" - ) - raise RuntimeError( - f"Request failed with status {response.status_code}: {error_message}" - ) - - self.logger.debug("[STREAM] Processing server response") - result = None - - async for line in response.aiter_lines(): - # Skip empty lines - if not line.strip(): - continue - - try: - # Handle SSE-style messages that start with "data: " - if line.startswith("data: "): - line = line[len("data: ") :] - - message = json.loads(line) - # Handle different message types - msg_type = message.get("type") - - if msg_type == "system": - status = message.get("data", {}).get("status") - if status == "error": - error_msg = message.get("data", {}).get( - "error", "Unknown error" - ) - self.logger.error(f"[ERROR] {error_msg}") - raise RuntimeError( - f"Server returned error: {error_msg}" - ) - elif status == "finished": - result = message.get("data", {}).get("result") - self.logger.debug( - "[SYSTEM] Operation completed successfully" - ) - elif msg_type == "log": - # Process log message using _handle_log - await self._handle_log(message) - else: - # Log any other message types - self.logger.debug(f"[UNKNOWN] Message type: {msg_type}") - except json.JSONDecodeError: - self.logger.warning(f"Could not parse line as JSON: {line}") - - # Return the final result - return result - except Exception as e: - self.logger.error(f"[EXCEPTION] {str(e)}") - raise + self._closed = True async def _handle_log(self, msg: dict[str, Any]): """ @@ -930,64 +566,6 @@ def _log( # Use the structured logger self.logger.log(message, level=level, category=category, auxiliary=auxiliary) - async def _apply_stealth_scripts(self, context: BrowserContext): - """Applies JavaScript init scripts to make the browser less detectable.""" - self.logger.debug("Applying stealth init scripts to the context...") - # Adapted from the TypeScript version - stealth_script = """ - (() => { - // Override navigator.webdriver - if (navigator.webdriver) { - Object.defineProperty(navigator, 'webdriver', { - get: () => undefined - }); - } - - // Mock languages and plugins - Object.defineProperty(navigator, 'languages', { - get: () => ['en-US', 'en'], - }); - - // Avoid complex plugin mocking, just return a non-empty array like structure - if (navigator.plugins instanceof PluginArray && navigator.plugins.length === 0) { - Object.defineProperty(navigator, 'plugins', { - get: () => Object.values({ - 'plugin1': { name: 'Chrome PDF Plugin', filename: 'internal-pdf-viewer', description: 'Portable Document Format' }, - 'plugin2': { name: 'Chrome PDF Viewer', filename: 'mhjfbmdgcfjbbpaeojofohoefgiehjai', description: '' }, - 'plugin3': { name: 'Native Client', filename: 'internal-nacl-plugin', description: '' } - }), - }); - } - - - // Remove Playwright-specific properties from window - try { - delete window.__playwright_run; // Example property, check actual properties if needed - delete window.navigator.__proto__.webdriver; // Another common place - } catch (e) {} - - // Override permissions API (example for notifications) - if (window.navigator && window.navigator.permissions) { - const originalQuery = window.navigator.permissions.query; - window.navigator.permissions.query = (parameters) => { - if (parameters && parameters.name === 'notifications') { - return Promise.resolve({ state: Notification.permission }); - } - // Call original for other permissions - return originalQuery.apply(window.navigator.permissions, [parameters]); - }; - } - - // You might need to add more overrides depending on the detection methods used by websites. - // For example, overriding Chrome runtime properties, canvas fingerprinting, etc. - })(); - """ - try: - await context.add_init_script(stealth_script) - self.logger.debug("Stealth init script added successfully.") - except Exception as e: - self.logger.error(f"Failed to add stealth init script: {str(e)}") - def _handle_llm_metrics( self, response: Any, inference_time_ms: int, function_name=None ): @@ -1014,3 +592,8 @@ def _handle_llm_metrics( function_enum = function_name self.update_metrics_from_response(function_enum, response, inference_time_ms) + + +# Bind the imported API methods to the Stagehand class +Stagehand._create_session = _create_session +Stagehand._execute = _execute diff --git a/stagehand/handlers/extract_handler.py b/stagehand/handlers/extract_handler.py index a73c3687..aedb3029 100644 --- a/stagehand/handlers/extract_handler.py +++ b/stagehand/handlers/extract_handler.py @@ -7,7 +7,7 @@ from stagehand.a11y.utils import get_accessibility_tree from stagehand.llm.inference import extract as extract_inference from stagehand.metrics import StagehandFunctionName # Changed import location -from stagehand.types import ExtractOptions, ExtractResult +from stagehand.types import DefaultExtractSchema, ExtractOptions, ExtractResult from stagehand.utils import inject_urls, transform_url_strings_to_ids T = TypeVar("T", bound=BaseModel) @@ -85,6 +85,7 @@ async def extract( self.logger.info("Getting accessibility tree data") output_string = tree["simplified"] id_to_url_mapping = tree.get("idToUrl", {}) + self.logger.info(f"schema: {schema}") # Transform schema URL fields to numeric IDs if necessary transformed_schema = schema @@ -92,6 +93,10 @@ async def extract( if schema: # TODO: Remove this once we have a better way to handle URLs transformed_schema, url_paths = transform_url_strings_to_ids(schema) + else: + transformed_schema = DefaultExtractSchema + + self.logger.info(f"Transformed schema: {transformed_schema}") # Use inference to call the LLM extraction_response = extract_inference( diff --git a/stagehand/llm/client.py b/stagehand/llm/client.py index d3c31c92..f6cecd1f 100644 --- a/stagehand/llm/client.py +++ b/stagehand/llm/client.py @@ -11,6 +11,9 @@ logger = logging.getLogger(__name__) +litellm.enable_json_schema_validation = True + + class LLMClient: """ Client for making LLM calls using the litellm library. diff --git a/stagehand/llm/inference.py b/stagehand/llm/inference.py index 24f0de91..cb4c8c8e 100644 --- a/stagehand/llm/inference.py +++ b/stagehand/llm/inference.py @@ -195,6 +195,9 @@ def extract( extract_content = extract_response.choices[0].message.content if isinstance(extract_content, str): try: + logger.info( + f"Extraction response: {extract_content} for schema: {schema}" + ) extracted_data = json.loads(extract_content) except json.JSONDecodeError: logger.error( diff --git a/stagehand/page.py b/stagehand/page.py index 6430e0b6..38a40256 100644 --- a/stagehand/page.py +++ b/stagehand/page.py @@ -36,6 +36,7 @@ def __init__(self, page: Page, stagehand_client): self._page = page self._stagehand = stagehand_client + # TODO try catch here async def ensure_injection(self): """Ensure custom injection scripts are present on the page using domScripts.js.""" exists_before = await self._page.evaluate( diff --git a/stagehand/types/__init__.py b/stagehand/types/__init__.py index 78217f45..ac1af176 100644 --- a/stagehand/types/__init__.py +++ b/stagehand/types/__init__.py @@ -20,9 +20,9 @@ ChatMessage, ) from .page import ( - DEFAULT_EXTRACT_SCHEMA, ActOptions, ActResult, + DefaultExtractSchema, ExtractOptions, ExtractResult, MetadataSchema, @@ -50,7 +50,7 @@ "ObserveOptions", "ObserveResult", "MetadataSchema", - "DEFAULT_EXTRACT_SCHEMA", + "DefaultExtractSchema", "ExtractOptions", "ExtractResult", "AgentConfig", diff --git a/stagehand/types/page.py b/stagehand/types/page.py index 690a6acd..ecfee164 100644 --- a/stagehand/types/page.py +++ b/stagehand/types/page.py @@ -5,7 +5,7 @@ # Ignore linting error for this class name since it's used as a constant # ruff: noqa: N801 -class DEFAULT_EXTRACT_SCHEMA(BaseModel): +class DefaultExtractSchema(BaseModel): extraction: str @@ -132,7 +132,7 @@ class ExtractOptions(BaseModel): # IMPORTANT: If using a Pydantic model for schema_definition, please call its .model_json_schema() method # to convert it to a JSON serializable dictionary before sending it with the extract command. schema_definition: Union[dict[str, Any], type[BaseModel]] = Field( - default=DEFAULT_EXTRACT_SCHEMA, + default=DefaultExtractSchema, description="A JSON schema or Pydantic model that defines the structure of the expected data.", ) use_text_extract: Optional[bool] = None diff --git a/stagehand/utils.py b/stagehand/utils.py index 9ef5278d..b5456d5f 100644 --- a/stagehand/utils.py +++ b/stagehand/utils.py @@ -117,6 +117,10 @@ def configure_logging( logging.getLogger("httpx").setLevel(logging.WARNING) logging.getLogger("httpcore").setLevel(logging.WARNING) logging.getLogger("asyncio").setLevel(logging.WARNING) + logging.getLogger("litellm").setLevel(logging.WARNING) + logging.getLogger("LiteLLM").setLevel( + logging.WARNING + ) # Cover both possible logger names ################################################################################ From 852f42a7de3f49c7e4fceec3b710c16a7b127eda Mon Sep 17 00:00:00 2001 From: miguel Date: Mon, 2 Jun 2025 14:10:31 -0700 Subject: [PATCH 2/6] remove unused imports --- stagehand/client.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/stagehand/client.py b/stagehand/client.py index 45927c32..9ee9191b 100644 --- a/stagehand/client.py +++ b/stagehand/client.py @@ -1,7 +1,5 @@ import asyncio import os -import shutil -import tempfile import signal import sys import time From bc9c3be9c3aca08ceda042cda94fc12e048c7c12 Mon Sep 17 00:00:00 2001 From: miguel Date: Mon, 2 Jun 2025 14:25:38 -0700 Subject: [PATCH 3/6] formatting --- stagehand/browser.py | 2 +- stagehand/client.py | 15 +++++++++------ stagehand/utils.py | 3 ++- 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/stagehand/browser.py b/stagehand/browser.py index ed920c00..c0c41d9a 100644 --- a/stagehand/browser.py +++ b/stagehand/browser.py @@ -203,7 +203,7 @@ async def connect_local_browser( if temp_user_data_dir: try: shutil.rmtree(temp_user_data_dir) - except: + except Exception: pass raise diff --git a/stagehand/client.py b/stagehand/client.py index 9ee9191b..54f20671 100644 --- a/stagehand/client.py +++ b/stagehand/client.py @@ -48,7 +48,7 @@ class Stagehand: # Dictionary to store one lock per session_id _session_locks = {} - + # Flag to track if cleanup has been called _cleanup_called = False @@ -194,7 +194,7 @@ def __init__( raise ValueError( "browserbase_project_id is required for BROWSERBASE env with existing session_id (or set BROWSERBASE_PROJECT_ID in env)." ) - + # Register signal handlers for graceful shutdown self._register_signal_handlers() @@ -225,13 +225,16 @@ def __init__( def _register_signal_handlers(self): """Register signal handlers for SIGINT and SIGTERM to ensure proper cleanup.""" + def cleanup_handler(sig, frame): # Prevent multiple cleanup calls if self.__class__._cleanup_called: return self.__class__._cleanup_called = True - print(f"\n[{signal.Signals(sig).name}] received. Ending Browserbase session...") + print( + f"\n[{signal.Signals(sig).name}] received. Ending Browserbase session..." + ) try: # Try to get the current event loop @@ -253,11 +256,11 @@ def cleanup_handler(sig, frame): def schedule_cleanup(): task = asyncio.create_task(self._async_cleanup()) # Shield the task to prevent it from being cancelled - shielded = asyncio.shield(task) + asyncio.shield(task) # We don't need to await here since we're in call_soon_threadsafe - + loop.call_soon_threadsafe(schedule_cleanup) - + except Exception as e: print(f"Error during signal cleanup: {str(e)}") sys.exit(1) diff --git a/stagehand/utils.py b/stagehand/utils.py index b5456d5f..ef45d954 100644 --- a/stagehand/utils.py +++ b/stagehand/utils.py @@ -844,7 +844,8 @@ def transform_url_strings_to_ids(schema): return transform_model(schema) -def transform_model(model_cls, path=[]): +# TODO: remove path? +def transform_model(model_cls, path=[]): # noqa: F841 B006 """ Recursively transforms a Pydantic model by replacing URL fields with numeric fields. From d437133d0c35940d04904261e5c898afc42578c9 Mon Sep 17 00:00:00 2001 From: miguel Date: Mon, 2 Jun 2025 14:27:42 -0700 Subject: [PATCH 4/6] address comments --- stagehand/handlers/extract_handler.py | 2 -- stagehand/llm/client.py | 3 --- stagehand/llm/inference.py | 3 --- 3 files changed, 8 deletions(-) diff --git a/stagehand/handlers/extract_handler.py b/stagehand/handlers/extract_handler.py index aedb3029..190e54cb 100644 --- a/stagehand/handlers/extract_handler.py +++ b/stagehand/handlers/extract_handler.py @@ -96,8 +96,6 @@ async def extract( else: transformed_schema = DefaultExtractSchema - self.logger.info(f"Transformed schema: {transformed_schema}") - # Use inference to call the LLM extraction_response = extract_inference( instruction=instruction, diff --git a/stagehand/llm/client.py b/stagehand/llm/client.py index f6cecd1f..d3c31c92 100644 --- a/stagehand/llm/client.py +++ b/stagehand/llm/client.py @@ -11,9 +11,6 @@ logger = logging.getLogger(__name__) -litellm.enable_json_schema_validation = True - - class LLMClient: """ Client for making LLM calls using the litellm library. diff --git a/stagehand/llm/inference.py b/stagehand/llm/inference.py index cb4c8c8e..24f0de91 100644 --- a/stagehand/llm/inference.py +++ b/stagehand/llm/inference.py @@ -195,9 +195,6 @@ def extract( extract_content = extract_response.choices[0].message.content if isinstance(extract_content, str): try: - logger.info( - f"Extraction response: {extract_content} for schema: {schema}" - ) extracted_data = json.loads(extract_content) except json.JSONDecodeError: logger.error( From d742c528739f65705eb59eb48f77b2fc500c9e49 Mon Sep 17 00:00:00 2001 From: miguel Date: Mon, 2 Jun 2025 16:27:15 -0700 Subject: [PATCH 5/6] remove unnecessary logging --- stagehand/api.py | 9 --------- stagehand/browser.py | 11 +++++++---- stagehand/client.py | 1 - 3 files changed, 7 insertions(+), 14 deletions(-) diff --git a/stagehand/api.py b/stagehand/api.py index ddda08ce..b958a960 100644 --- a/stagehand/api.py +++ b/stagehand/api.py @@ -110,10 +110,6 @@ async def _execute(self, method: str, payload: dict[str, Any]) -> Any: modified_payload = convert_dict_keys_to_camel_case(payload) client = self.httpx_client or httpx.AsyncClient(timeout=self.timeout_settings) - self.logger.debug(f"\n==== EXECUTING {method.upper()} ====") - self.logger.debug(f"URL: {self.api_url}/sessions/{self.session_id}/{method}") - self.logger.debug(f"Payload: {modified_payload}") - self.logger.debug(f"Headers: {headers}") async with client: try: @@ -133,8 +129,6 @@ async def _execute(self, method: str, payload: dict[str, Any]) -> Any: raise RuntimeError( f"Request failed with status {response.status_code}: {error_message}" ) - - self.logger.debug("[STREAM] Processing server response") result = None async for line in response.aiter_lines(): @@ -163,9 +157,6 @@ async def _execute(self, method: str, payload: dict[str, Any]) -> Any: ) elif status == "finished": result = message.get("data", {}).get("result") - self.logger.debug( - "[SYSTEM] Operation completed successfully" - ) elif msg_type == "log": # Process log message using _handle_log await self._handle_log(message) diff --git a/stagehand/browser.py b/stagehand/browser.py index c0c41d9a..431d1ce9 100644 --- a/stagehand/browser.py +++ b/stagehand/browser.py @@ -40,7 +40,6 @@ async def connect_browserbase_browser( # Connect to remote browser via Browserbase SDK and CDP bb = Browserbase(api_key=browserbase_api_key) try: - logger.debug(f"Retrieving Browserbase session details for {session_id}...") session = bb.sessions.retrieve(session_id) if session.status != "RUNNING": raise RuntimeError( @@ -54,7 +53,6 @@ async def connect_browserbase_browser( logger.debug(f"Connecting to remote browser at: {connect_url}") try: browser = await playwright.chromium.connect_over_cdp(connect_url) - logger.debug(f"Connected to remote browser: {browser}") except Exception as e: logger.error(f"Failed to connect Playwright via CDP: {str(e)}") raise @@ -233,7 +231,7 @@ async def connect_local_browser( async def apply_stealth_scripts(context: BrowserContext, logger: StagehandLogger): """Applies JavaScript init scripts to make the browser less detectable.""" - logger.debug("Applying stealth init scripts to the context...") + logger.debug("Applying stealth scripts to the context...") stealth_script = """ (() => { // Override navigator.webdriver @@ -280,7 +278,6 @@ async def apply_stealth_scripts(context: BrowserContext, logger: StagehandLogger """ try: await context.add_init_script(stealth_script) - logger.debug("Stealth init script added successfully.") except Exception as e: logger.error(f"Failed to add stealth init script: {str(e)}") @@ -308,6 +305,12 @@ async def cleanup_browser_resources( await context.close() except Exception as e: logger.error(f"Error closing context: {str(e)}") + if browser: + try: + logger.debug("Closing browser...") + await browser.close() + except Exception as e: + logger.error(f"Error closing browser: {str(e)}") # Clean up temporary user data directory if created if temp_user_data_dir: diff --git a/stagehand/client.py b/stagehand/client.py index 54f20671..53c28611 100644 --- a/stagehand/client.py +++ b/stagehand/client.py @@ -395,7 +395,6 @@ def _get_lock_for_session(self) -> asyncio.Lock: """ if self.session_id not in self._session_locks: self._session_locks[self.session_id] = asyncio.Lock() - self.logger.debug(f"Created lock for session {self.session_id}") return self._session_locks[self.session_id] async def __aenter__(self): From 700d34932f3f8da3d3bf91af651d9e268a70347d Mon Sep 17 00:00:00 2001 From: Miguel <36487034+miguelg719@users.noreply.github.com> Date: Mon, 2 Jun 2025 20:19:41 -0700 Subject: [PATCH 6/6] Update stagehand/handlers/extract_handler.py --- stagehand/handlers/extract_handler.py | 1 - 1 file changed, 1 deletion(-) diff --git a/stagehand/handlers/extract_handler.py b/stagehand/handlers/extract_handler.py index 190e54cb..9025ff82 100644 --- a/stagehand/handlers/extract_handler.py +++ b/stagehand/handlers/extract_handler.py @@ -85,7 +85,6 @@ async def extract( self.logger.info("Getting accessibility tree data") output_string = tree["simplified"] id_to_url_mapping = tree.get("idToUrl", {}) - self.logger.info(f"schema: {schema}") # Transform schema URL fields to numeric IDs if necessary transformed_schema = schema