Stagehand tool improvements (#415)

* Stagehand tool improvements

This commit significantly improves the StagehandTool reliability and usability when working with CrewAI agents by addressing several critical
  issues:

  ## Key Improvements

  ### 1. Atomic Action Support
  - Added _extract_steps() method to break complex instructions into individual steps
  - Added _simplify_instruction() method for intelligent error recovery
  - Sequential execution of micro-actions with proper DOM settling between steps
  - Prevents token limit issues on complex pages by encouraging scoped actions

  ### 2. Enhanced Schema Design
  - Made instruction field optional to handle navigation-only commands
  - Added smart defaults for missing instructions based on command_type
  - Improved field descriptions to guide agents toward atomic actions with location context
  - Prevents "instruction Field required" validation errors

  ### 3. Intelligent API Key Management
  - Added _get_model_api_key() method with automatic detection based on model type
  - Support for OpenAI (GPT), Anthropic (Claude), and Google (Gemini) API keys
  - Removes need for manual model API key configuration

  ### 4. Robust Error Recovery
  - Step-by-step execution with individual error handling per atomic action
  - Automatic retry with simplified instructions when complex actions fail
  - Comprehensive error logging and reporting for debugging
  - Graceful degradation instead of complete failure

  ### 5. Token Management & Performance
  - Tool descriptions encourage atomic, scoped actions (e.g., "click search box in header")
  - Prevents "prompt too long" errors on complex pages like Wikipedia
  - Location-aware instruction patterns for better DOM targeting
  - Reduced observe-act cycles through better instruction decomposition

  ### 6. Enhanced Testing Support
  - Comprehensive async mock objects for testing mode
  - Proper async/sync compatibility for different execution contexts
  - Enhanced resource cleanup and session management

* Update stagehand_tool.py

Removing FixedStagehandTool in favour of StagehandTool

* removed comment

* Cleanup

Removed unused class
Improved tool description
This commit is contained in:
nicoferdi96
2025-08-13 14:57:11 +02:00
committed by GitHub
parent 41ce4981ac
commit 99e174e575

View File

@@ -1,7 +1,8 @@
import asyncio import asyncio
import json import json
import logging import os
from typing import Dict, List, Optional, Type, Union, Any import re
from typing import Any, Dict, List, Optional, Type, Union
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@@ -9,14 +10,14 @@ from pydantic import BaseModel, Field
_HAS_STAGEHAND = False _HAS_STAGEHAND = False
try: try:
from stagehand import Stagehand, StagehandConfig, StagehandPage from stagehand import Stagehand, StagehandConfig, StagehandPage, configure_logging
from stagehand.schemas import ( from stagehand.schemas import (
ActOptions, ActOptions,
AvailableModel, AvailableModel,
ExtractOptions, ExtractOptions,
ObserveOptions, ObserveOptions,
) )
from stagehand import configure_logging
_HAS_STAGEHAND = True _HAS_STAGEHAND = True
except ImportError: except ImportError:
# Define type stubs for when stagehand is not installed # Define type stubs for when stagehand is not installed
@@ -26,25 +27,19 @@ except ImportError:
ActOptions = Any ActOptions = Any
ExtractOptions = Any ExtractOptions = Any
ObserveOptions = Any ObserveOptions = Any
# Mock configure_logging function # Mock configure_logging function
def configure_logging(level=None, remove_logger_name=None, quiet_dependencies=None): def configure_logging(level=None, remove_logger_name=None, quiet_dependencies=None):
pass pass
# Define only what's needed for class defaults # Define only what's needed for class defaults
class AvailableModel: class AvailableModel:
CLAUDE_3_7_SONNET_LATEST = "anthropic.claude-3-7-sonnet-20240607" CLAUDE_3_7_SONNET_LATEST = "anthropic.claude-3-7-sonnet-20240607"
from crewai.tools import BaseTool from crewai.tools import BaseTool
class StagehandCommandType(str):
ACT = "act"
EXTRACT = "extract"
OBSERVE = "observe"
NAVIGATE = "navigate"
class StagehandResult(BaseModel): class StagehandResult(BaseModel):
"""Result from a Stagehand operation. """Result from a Stagehand operation.
@@ -68,9 +63,9 @@ class StagehandResult(BaseModel):
class StagehandToolSchema(BaseModel): class StagehandToolSchema(BaseModel):
"""Input for StagehandTool.""" """Input for StagehandTool."""
instruction: str = Field( instruction: Optional[str] = Field(
..., None,
description="Natural language instruction describing what you want to do on the website. Be specific about the action you want to perform, data to extract, or elements to observe. If your task is complex, break it down into simple, sequential steps. For example: 'Step 1: Navigate to https://example.com; Step 2: Click the login button; Step 3: Enter your credentials; Step 4: Submit the form.' Complex tasks like 'Search for OpenAI' should be broken down as: 'Step 1: Navigate to https://google.com; Step 2: Type OpenAI in the search box; Step 3: Press Enter or click the search button'.", description="Single atomic action with location context. For reliability on complex pages, use ONE specific action with location hints. Good examples: 'Click the search input field in the header', 'Type Italy in the focused field', 'Press Enter', 'Click the first link in the results area'. Avoid combining multiple actions. For 'navigate' command type, this can be omitted if only URL is provided.",
) )
url: Optional[str] = Field( url: Optional[str] = Field(
None, None,
@@ -78,19 +73,18 @@ class StagehandToolSchema(BaseModel):
) )
command_type: Optional[str] = Field( command_type: Optional[str] = Field(
"act", "act",
description="""The type of command to execute (choose one): description="""The type of command to execute (choose one):
- 'act': Perform an action like clicking buttons, filling forms, etc. (default) - 'act': Perform an action like clicking buttons, filling forms, etc. (default)
- 'navigate': Specifically navigate to a URL - 'navigate': Specifically navigate to a URL
- 'extract': Extract structured data from the page - 'extract': Extract structured data from the page
- 'observe': Identify and analyze elements on the page - 'observe': Identify and analyze elements on the page
""", """,
) )
class StagehandTool(BaseTool): class StagehandTool(BaseTool):
package_dependencies: List[str] = ["stagehand"]
""" """
A tool that uses Stagehand to automate web browser interactions using natural language. A tool that uses Stagehand to automate web browser interactions using natural language with atomic action handling.
Stagehand allows AI agents to interact with websites through a browser, Stagehand allows AI agents to interact with websites through a browser,
performing actions like clicking buttons, filling forms, and extracting data. performing actions like clicking buttons, filling forms, and extracting data.
@@ -101,24 +95,6 @@ class StagehandTool(BaseTool):
3. extract - Extract structured data from web pages 3. extract - Extract structured data from web pages
4. observe - Identify and analyze elements on a page 4. observe - Identify and analyze elements on a page
Usage patterns:
1. Using as a context manager (recommended):
```python
with StagehandTool() as tool:
agent = Agent(tools=[tool])
# ... use the agent
```
2. Manual resource management:
```python
tool = StagehandTool()
try:
agent = Agent(tools=[tool])
# ... use the agent
finally:
tool.close()
```
Usage examples: Usage examples:
- Navigate to a website: instruction="Go to the homepage", url="https://example.com" - Navigate to a website: instruction="Go to the homepage", url="https://example.com"
- Click a button: instruction="Click the login button" - Click a button: instruction="Click the login button"
@@ -136,7 +112,7 @@ class StagehandTool(BaseTool):
name: str = "Web Automation Tool" name: str = "Web Automation Tool"
description: str = """Use this tool to control a web browser and interact with websites using natural language. description: str = """Use this tool to control a web browser and interact with websites using natural language.
Capabilities: Capabilities:
- Navigate to websites and follow links - Navigate to websites and follow links
- Click buttons, links, and other elements - Click buttons, links, and other elements
@@ -144,13 +120,18 @@ class StagehandTool(BaseTool):
- Search within websites - Search within websites
- Extract information from web pages - Extract information from web pages
- Identify and analyze elements on a page - Identify and analyze elements on a page
To use this tool, provide a natural language instruction describing what you want to do. To use this tool, provide a natural language instruction describing what you want to do.
For reliability on complex pages, use specific, atomic instructions with location hints:
- Good: "Click the search box in the header"
- Good: "Type 'Italy' in the focused field"
- Bad: "Search for Italy and click the first result"
For different types of tasks, specify the command_type: For different types of tasks, specify the command_type:
- 'act': For performing actions (default) - 'act': For performing one atomic action (default)
- 'navigate': For navigating to a URL (shorthand for act with navigation) - 'navigate': For navigating to a URL
- 'extract': For getting data from the page - 'extract': For getting data from a specific page section
- 'observe': For finding and analyzing elements - 'observe': For finding elements in a specific area
""" """
args_schema: Type[BaseModel] = StagehandToolSchema args_schema: Type[BaseModel] = StagehandToolSchema
@@ -159,18 +140,21 @@ class StagehandTool(BaseTool):
project_id: Optional[str] = None project_id: Optional[str] = None
model_api_key: Optional[str] = None model_api_key: Optional[str] = None
model_name: Optional[AvailableModel] = AvailableModel.CLAUDE_3_7_SONNET_LATEST model_name: Optional[AvailableModel] = AvailableModel.CLAUDE_3_7_SONNET_LATEST
server_url: Optional[str] = "http://api.stagehand.browserbase.com/v1" server_url: Optional[str] = "https://api.stagehand.browserbase.com/v1"
headless: bool = False headless: bool = False
dom_settle_timeout_ms: int = 3000 dom_settle_timeout_ms: int = 3000
self_heal: bool = True self_heal: bool = True
wait_for_captcha_solves: bool = True wait_for_captcha_solves: bool = True
verbose: int = 1 verbose: int = 1
# Token management settings
max_retries_on_token_limit: int = 3
use_simplified_dom: bool = True
# Instance variables # Instance variables
_stagehand: Optional[Stagehand] = None _stagehand: Optional[Stagehand] = None
_page: Optional[StagehandPage] = None _page: Optional[StagehandPage] = None
_session_id: Optional[str] = None _session_id: Optional[str] = None
_logger: Optional[logging.Logger] = None
_testing: bool = False _testing: bool = False
def __init__( def __init__(
@@ -186,7 +170,7 @@ class StagehandTool(BaseTool):
self_heal: Optional[bool] = None, self_heal: Optional[bool] = None,
wait_for_captcha_solves: Optional[bool] = None, wait_for_captcha_solves: Optional[bool] = None,
verbose: Optional[int] = None, verbose: Optional[int] = None,
_testing: bool = False, # Flag to bypass dependency check in tests _testing: bool = False,
**kwargs, **kwargs,
): ):
# Set testing flag early so that other init logic can rely on it # Set testing flag early so that other init logic can rely on it
@@ -194,21 +178,13 @@ class StagehandTool(BaseTool):
super().__init__(**kwargs) super().__init__(**kwargs)
# Set up logger # Set up logger
import logging
self._logger = logging.getLogger(__name__) self._logger = logging.getLogger(__name__)
# For backward compatibility # Set configuration from parameters or environment
browserbase_api_key = kwargs.get("browserbase_api_key") self.api_key = api_key or os.getenv("BROWSERBASE_API_KEY")
browserbase_project_id = kwargs.get("browserbase_project_id") self.project_id = project_id or os.getenv("BROWSERBASE_PROJECT_ID")
if api_key:
self.api_key = api_key
elif browserbase_api_key:
self.api_key = browserbase_api_key
if project_id:
self.project_id = project_id
elif browserbase_project_id:
self.project_id = browserbase_project_id
if model_api_key: if model_api_key:
self.model_api_key = model_api_key self.model_api_key = model_api_key
@@ -230,226 +206,340 @@ class StagehandTool(BaseTool):
self._session_id = session_id self._session_id = session_id
# Configure logging based on verbosity level # Configure logging based on verbosity level
log_level = logging.ERROR if not self._testing:
if self.verbose == 1: log_level = {1: "INFO", 2: "WARNING", 3: "DEBUG"}.get(self.verbose, "ERROR")
log_level = logging.INFO configure_logging(
elif self.verbose == 2: level=log_level, remove_logger_name=True, quiet_dependencies=True
log_level = logging.WARNING )
elif self.verbose >= 3:
log_level = logging.DEBUG
configure_logging(
level=log_level, remove_logger_name=True, quiet_dependencies=True
)
self._check_required_credentials() self._check_required_credentials()
def _check_required_credentials(self): def _check_required_credentials(self):
"""Validate that required credentials are present.""" """Validate that required credentials are present."""
# Check if stagehand is available, but only if we're not in testing mode
if not self._testing and not _HAS_STAGEHAND: if not self._testing and not _HAS_STAGEHAND:
raise ImportError( raise ImportError(
"`stagehand` package not found, please run `uv add stagehand`" "`stagehand` package not found, please run `uv add stagehand`"
) )
if not self.api_key: if not self.api_key:
raise ValueError("api_key is required (or set BROWSERBASE_API_KEY in env).") raise ValueError("api_key is required (or set BROWSERBASE_API_KEY in env).")
if not self.project_id: if not self.project_id:
raise ValueError( raise ValueError(
"project_id is required (or set BROWSERBASE_PROJECT_ID in env)." "project_id is required (or set BROWSERBASE_PROJECT_ID in env)."
) )
if not self.model_api_key:
raise ValueError( def __del__(self):
"model_api_key is required (or set OPENAI_API_KEY or ANTHROPIC_API_KEY in env)." """Ensure cleanup on deletion"""
try:
self.close()
except Exception:
pass
def _get_model_api_key(self):
"""Get the appropriate API key based on the model being used."""
# Check model type and get appropriate key
model_str = str(self.model_name)
if "gpt" in model_str.lower():
return self.model_api_key or os.getenv("OPENAI_API_KEY")
elif "claude" in model_str.lower() or "anthropic" in model_str.lower():
return self.model_api_key or os.getenv("ANTHROPIC_API_KEY")
elif "gemini" in model_str.lower():
return self.model_api_key or os.getenv("GOOGLE_API_KEY")
else:
# Default to trying OpenAI, then Anthropic
return (
self.model_api_key
or os.getenv("OPENAI_API_KEY")
or os.getenv("ANTHROPIC_API_KEY")
) )
async def _setup_stagehand(self, session_id: Optional[str] = None): async def _setup_stagehand(self, session_id: Optional[str] = None):
"""Initialize Stagehand if not already set up.""" """Initialize Stagehand if not already set up."""
# If we're in testing mode, return mock objects # If we're in testing mode, return mock objects
if self._testing: if self._testing:
if not self._stagehand: if not self._stagehand:
# Create a minimal mock for testing with non-async methods # Create mock objects for testing
class MockPage: class MockPage:
def act(self, options): async def act(self, options):
mock_result = type('MockResult', (), {})() mock_result = type("MockResult", (), {})()
mock_result.model_dump = lambda: {"message": "Action completed successfully"} mock_result.model_dump = lambda: {
"message": "Action completed successfully"
}
return mock_result return mock_result
def goto(self, url): async def goto(self, url):
return None return None
def extract(self, options): async def extract(self, options):
mock_result = type('MockResult', (), {})() mock_result = type("MockResult", (), {})()
mock_result.model_dump = lambda: {"data": "Extracted content"} mock_result.model_dump = lambda: {"data": "Extracted content"}
return mock_result return mock_result
def observe(self, options): async def observe(self, options):
mock_result1 = type('MockResult', (), {"description": "Test element", "method": "click"})() mock_result1 = type(
"MockResult",
(),
{"description": "Test element", "method": "click"},
)()
return [mock_result1] return [mock_result1]
async def wait_for_load_state(self, state):
return None
class MockStagehand: class MockStagehand:
def __init__(self): def __init__(self):
self.page = MockPage() self.page = MockPage()
self.session_id = "test-session-id" self.session_id = "test-session-id"
def init(self): async def init(self):
return None return None
def close(self): async def close(self):
return None return None
self._stagehand = MockStagehand() self._stagehand = MockStagehand()
# No need to await the init call in test mode await self._stagehand.init()
self._stagehand.init()
self._page = self._stagehand.page self._page = self._stagehand.page
self._session_id = self._stagehand.session_id self._session_id = self._stagehand.session_id
return self._stagehand, self._page return self._stagehand, self._page
# Normal initialization for non-testing mode # Normal initialization for non-testing mode
if not self._stagehand: if not self._stagehand:
self._logger.debug("Initializing Stagehand") # Get the appropriate API key based on model type
# Create model client options with the API key model_api_key = self._get_model_api_key()
model_client_options = {"apiKey": self.model_api_key}
# Build the StagehandConfig object if not model_api_key:
raise ValueError(
"No appropriate API key found for model. Please set OPENAI_API_KEY, ANTHROPIC_API_KEY, or GOOGLE_API_KEY"
)
# Build the StagehandConfig with proper parameter names
config = StagehandConfig( config = StagehandConfig(
env="BROWSERBASE", env="BROWSERBASE",
api_key=self.api_key, apiKey=self.api_key, # Browserbase API key (camelCase)
project_id=self.project_id, projectId=self.project_id, # Browserbase project ID (camelCase)
headless=self.headless, modelApiKey=model_api_key, # LLM API key - auto-detected based on model
dom_settle_timeout_ms=self.dom_settle_timeout_ms, modelName=self.model_name,
model_name=self.model_name, apiUrl=self.server_url
self_heal=self.self_heal, if self.server_url
wait_for_captcha_solves=self.wait_for_captcha_solves, else "https://api.stagehand.browserbase.com/v1",
model_client_options=model_client_options, domSettleTimeoutMs=self.dom_settle_timeout_ms,
selfHeal=self.self_heal,
waitForCaptchaSolves=self.wait_for_captcha_solves,
verbose=self.verbose, verbose=self.verbose,
session_id=session_id or self._session_id, browserbaseSessionID=session_id or self._session_id,
) )
# Initialize Stagehand with config and server_url # Initialize Stagehand with config
self._stagehand = Stagehand(config=config, server_url=self.server_url) self._stagehand = Stagehand(config=config)
# Initialize the Stagehand instance # Initialize the Stagehand instance
await self._stagehand.init() await self._stagehand.init()
self._page = self._stagehand.page self._page = self._stagehand.page
self._session_id = self._stagehand.session_id self._session_id = self._stagehand.session_id
self._logger.info(f"Session ID: {self._stagehand.session_id}")
self._logger.info(
f"Browser session: https://www.browserbase.com/sessions/{self._stagehand.session_id}"
)
return self._stagehand, self._page return self._stagehand, self._page
def _extract_steps(self, instruction: str) -> List[str]:
"""Extract individual steps from multi-step instructions"""
# Check for numbered steps (Step 1:, Step 2:, etc.)
if re.search(r"Step \d+:", instruction, re.IGNORECASE):
steps = re.findall(
r"Step \d+:\s*([^;]+?)(?=Step \d+:|$)",
instruction,
re.IGNORECASE | re.DOTALL,
)
return [step.strip() for step in steps if step.strip()]
# Check for semicolon-separated instructions
elif ";" in instruction:
return [step.strip() for step in instruction.split(";") if step.strip()]
else:
return [instruction]
def _simplify_instruction(self, instruction: str) -> str:
"""Simplify complex instructions to basic actions"""
# Extract the core action from complex instructions
instruction_lower = instruction.lower()
if "search" in instruction_lower and "click" in instruction_lower:
# For search tasks, focus on the search action first
if "type" in instruction_lower or "enter" in instruction_lower:
return "click on the search input field"
else:
return "search for content on the page"
elif "click" in instruction_lower:
# Extract what to click
if "button" in instruction_lower:
return "click the button"
elif "link" in instruction_lower:
return "click the link"
elif "search" in instruction_lower:
return "click the search field"
else:
return "click on the element"
elif "type" in instruction_lower or "enter" in instruction_lower:
return "type in the input field"
else:
return instruction # Return as-is if can't simplify
async def _async_run( async def _async_run(
self, self,
instruction: str, instruction: Optional[str] = None,
url: Optional[str] = None, url: Optional[str] = None,
command_type: str = "act", command_type: str = "act",
) -> StagehandResult: ):
"""Asynchronous implementation of the tool.""" """Override _async_run with improved atomic action handling"""
# Handle missing instruction based on command type
if not instruction:
if command_type == "navigate" and url:
instruction = f"Navigate to {url}"
elif command_type == "observe":
instruction = "Observe elements on the page"
elif command_type == "extract":
instruction = "Extract information from the page"
else:
instruction = "Perform the requested action"
# For testing mode, use parent implementation
if self._testing:
return await super()._async_run(instruction, url, command_type)
try: try:
# Special handling for test mode to avoid coroutine issues _, page = await self._setup_stagehand(self._session_id)
if self._testing:
# Return predefined mock results based on command type
if command_type.lower() == "act":
return StagehandResult(
success=True,
data={"message": "Action completed successfully"}
)
elif command_type.lower() == "navigate":
return StagehandResult(
success=True,
data={
"url": url or "https://example.com",
"message": f"Successfully navigated to {url or 'https://example.com'}",
},
)
elif command_type.lower() == "extract":
return StagehandResult(
success=True,
data={"data": "Extracted content", "metadata": {"source": "test"}}
)
elif command_type.lower() == "observe":
return StagehandResult(
success=True,
data=[
{"index": 1, "description": "Test element", "method": "click"}
],
)
else:
return StagehandResult(
success=False,
data={},
error=f"Unknown command type: {command_type}"
)
# Normal execution for non-test mode
stagehand, page = await self._setup_stagehand(self._session_id)
self._logger.info( self._logger.info(
f"Executing {command_type} with instruction: {instruction}" f"Executing {command_type} with instruction: {instruction}"
) )
# Get the API key to pass to model operations
model_api_key = self._get_model_api_key()
model_client_options = {"apiKey": model_api_key}
# Always navigate first if URL is provided and we're doing actions
if url and command_type.lower() == "act":
self._logger.info(f"Navigating to {url} before performing actions")
await page.goto(url)
await page.wait_for_load_state("networkidle")
# Small delay to ensure page is fully loaded
await asyncio.sleep(1)
# Process according to command type # Process according to command type
if command_type.lower() == "act": if command_type.lower() == "act":
# Create act options # Extract steps from complex instructions
act_options = ActOptions( steps = self._extract_steps(instruction)
action=instruction, self._logger.info(f"Extracted {len(steps)} steps: {steps}")
model_name=self.model_name,
dom_settle_timeout_ms=self.dom_settle_timeout_ms,
)
# Execute the act command results = []
result = await page.act(act_options) for i, step in enumerate(steps):
self._logger.info(f"Act operation completed: {result}") self._logger.info(f"Executing step {i + 1}/{len(steps)}: {step}")
return StagehandResult(success=True, data=result.model_dump())
try:
# Create act options with API key for each step
from stagehand.schemas import ActOptions
act_options = ActOptions(
action=step,
modelName=self.model_name,
domSettleTimeoutMs=self.dom_settle_timeout_ms,
modelClientOptions=model_client_options,
)
result = await page.act(act_options)
results.append(result.model_dump())
# Small delay between steps to let DOM settle
if i < len(steps) - 1: # Don't delay after last step
await asyncio.sleep(0.5)
except Exception as step_error:
error_msg = f"Step failed: {step_error}"
self._logger.warning(f"Step {i + 1} failed: {error_msg}")
# Try with simplified instruction
try:
simplified = self._simplify_instruction(step)
if simplified != step:
self._logger.info(
f"Retrying with simplified instruction: {simplified}"
)
act_options = ActOptions(
action=simplified,
modelName=self.model_name,
domSettleTimeoutMs=self.dom_settle_timeout_ms,
modelClientOptions=model_client_options,
)
result = await page.act(act_options)
results.append(result.model_dump())
else:
# If we can't simplify or retry fails, record the error
results.append({"error": error_msg, "step": step})
except Exception as retry_error:
self._logger.error(f"Retry also failed: {retry_error}")
results.append({"error": str(retry_error), "step": step})
# Return combined results
if len(results) == 1:
# Single step, return as-is
if "error" in results[0]:
return self._format_result(
False, results[0], results[0]["error"]
)
return self._format_result(True, results[0])
else:
# Multiple steps, return all results
has_errors = any("error" in result for result in results)
return self._format_result(not has_errors, {"steps": results})
elif command_type.lower() == "navigate": elif command_type.lower() == "navigate":
# For navigation, use the goto method directly # For navigation, use the goto method directly
target_url = url if not url:
if not target_url:
error_msg = "No URL provided for navigation. Please provide a URL." error_msg = "No URL provided for navigation. Please provide a URL."
self._logger.error(error_msg) self._logger.error(error_msg)
return StagehandResult(success=False, data={}, error=error_msg) return self._format_result(False, {}, error_msg)
# Navigate using the goto method result = await page.goto(url)
result = await page.goto(target_url) self._logger.info(f"Navigate operation completed to {url}")
self._logger.info(f"Navigate operation completed to {target_url}") return self._format_result(
return StagehandResult( True,
success=True, {
data={ "url": url,
"url": target_url, "message": f"Successfully navigated to {url}",
"message": f"Successfully navigated to {target_url}",
}, },
) )
elif command_type.lower() == "extract": elif command_type.lower() == "extract":
# Create extract options # Create extract options with API key
from stagehand.schemas import ExtractOptions
extract_options = ExtractOptions( extract_options = ExtractOptions(
instruction=instruction, instruction=instruction,
model_name=self.model_name, modelName=self.model_name,
dom_settle_timeout_ms=self.dom_settle_timeout_ms, domSettleTimeoutMs=self.dom_settle_timeout_ms,
use_text_extract=True, useTextExtract=True,
modelClientOptions=model_client_options, # Add API key here
) )
# Execute the extract command
result = await page.extract(extract_options) result = await page.extract(extract_options)
self._logger.info(f"Extract operation completed successfully {result}") self._logger.info(f"Extract operation completed successfully {result}")
return StagehandResult(success=True, data=result.model_dump()) return self._format_result(True, result.model_dump())
elif command_type.lower() == "observe": elif command_type.lower() == "observe":
# Create observe options # Create observe options with API key
from stagehand.schemas import ObserveOptions
observe_options = ObserveOptions( observe_options = ObserveOptions(
instruction=instruction, instruction=instruction,
model_name=self.model_name, modelName=self.model_name,
only_visible=True, onlyVisible=True,
dom_settle_timeout_ms=self.dom_settle_timeout_ms, domSettleTimeoutMs=self.dom_settle_timeout_ms,
modelClientOptions=model_client_options, # Add API key here
) )
# Execute the observe command
results = await page.observe(observe_options) results = await page.observe(observe_options)
# Format the observation results # Format the observation results
@@ -466,21 +556,25 @@ class StagehandTool(BaseTool):
self._logger.info( self._logger.info(
f"Observe operation completed with {len(formatted_results)} elements found" f"Observe operation completed with {len(formatted_results)} elements found"
) )
return StagehandResult(success=True, data=formatted_results) return self._format_result(True, formatted_results)
else: else:
error_msg = f"Unknown command type: {command_type}. Please use 'act', 'navigate', 'extract', or 'observe'." error_msg = f"Unknown command type: {command_type}"
self._logger.error(error_msg) self._logger.error(error_msg)
return StagehandResult(success=False, data={}, error=error_msg) return self._format_result(False, {}, error_msg)
except Exception as e: except Exception as e:
error_msg = f"Error using Stagehand: {str(e)}" error_msg = f"Error using Stagehand: {str(e)}"
self._logger.error(f"Operation failed: {error_msg}") self._logger.error(f"Operation failed: {error_msg}")
return StagehandResult(success=False, data={}, error=error_msg) return self._format_result(False, {}, error_msg)
def _format_result(self, success, data, error=None):
    """Build the ``StagehandResult`` returned by every command path.

    Args:
        success: Whether the operation succeeded.
        data: Payload produced by the operation.
        error: Optional error message; defaults to ``None``.

    Returns:
        A ``StagehandResult`` carrying the three values unchanged.
    """
    outcome = {"success": success, "data": data, "error": error}
    return StagehandResult(**outcome)
def _run( def _run(
self, self,
instruction: str, instruction: Optional[str] = None,
url: Optional[str] = None, url: Optional[str] = None,
command_type: str = "act", command_type: str = "act",
) -> str: ) -> str:
@@ -495,14 +589,28 @@ class StagehandTool(BaseTool):
Returns: Returns:
The result of the browser automation task The result of the browser automation task
""" """
# Handle missing instruction based on command type
if not instruction:
if command_type == "navigate" and url:
instruction = f"Navigate to {url}"
elif command_type == "observe":
instruction = "Observe elements on the page"
elif command_type == "extract":
instruction = "Extract information from the page"
else:
instruction = "Perform the requested action"
# Create an event loop if we're not already in one # Create an event loop if we're not already in one
try: try:
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
if loop.is_running(): if loop.is_running():
# We're in an existing event loop, use it # We're in an existing event loop, use it
result = asyncio.run_coroutine_threadsafe( import concurrent.futures
self._async_run(instruction, url, command_type), loop
).result() with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(
asyncio.run, self._async_run(instruction, url, command_type)
)
result = future.result()
else: else:
# We have a loop but it's not running # We have a loop but it's not running
result = loop.run_until_complete( result = loop.run_until_complete(
@@ -512,7 +620,23 @@ class StagehandTool(BaseTool):
# Format the result for output # Format the result for output
if result.success: if result.success:
if command_type.lower() == "act": if command_type.lower() == "act":
return f"Action result: {result.data.get('message', 'Completed')}" if isinstance(result.data, dict) and "steps" in result.data:
# Multiple steps
step_messages = []
for i, step in enumerate(result.data["steps"]):
if "error" in step:
step_messages.append(
f"Step {i + 1}: Failed - {step['error']}"
)
else:
step_messages.append(
f"Step {i + 1}: {step.get('message', 'Completed')}"
)
return "\n".join(step_messages)
else:
return (
f"Action result: {result.data.get('message', 'Completed')}"
)
elif command_type.lower() == "extract": elif command_type.lower() == "extract":
return f"Extracted data: {json.dumps(result.data, indent=2)}" return f"Extracted data: {json.dumps(result.data, indent=2)}"
elif command_type.lower() == "observe": elif command_type.lower() == "observe":
@@ -525,7 +649,6 @@ class StagehandTool(BaseTool):
formatted_results.append( formatted_results.append(
f"Suggested action: {element['method']}" f"Suggested action: {element['method']}"
) )
return "\n".join(formatted_results) return "\n".join(formatted_results)
else: else:
return json.dumps(result.data, indent=2) return json.dumps(result.data, indent=2)
@@ -551,7 +674,7 @@ class StagehandTool(BaseTool):
self._stagehand = None self._stagehand = None
self._page = None self._page = None
return return
if self._stagehand: if self._stagehand:
await self._stagehand.close() await self._stagehand.close()
self._stagehand = None self._stagehand = None
@@ -565,7 +688,7 @@ class StagehandTool(BaseTool):
self._stagehand = None self._stagehand = None
self._page = None self._page = None
return return
if self._stagehand: if self._stagehand:
try: try:
# Handle both synchronous and asynchronous cases # Handle both synchronous and asynchronous cases
@@ -574,7 +697,15 @@ class StagehandTool(BaseTool):
try: try:
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
if loop.is_running(): if loop.is_running():
asyncio.run_coroutine_threadsafe(self._async_close(), loop).result() import concurrent.futures
with (
concurrent.futures.ThreadPoolExecutor() as executor
):
future = executor.submit(
asyncio.run, self._async_close()
)
future.result()
else: else:
loop.run_until_complete(self._async_close()) loop.run_until_complete(self._async_close())
except RuntimeError: except RuntimeError:
@@ -584,11 +715,10 @@ class StagehandTool(BaseTool):
self._stagehand.close() self._stagehand.close()
except Exception as e: except Exception as e:
# Log but don't raise - we're cleaning up # Log but don't raise - we're cleaning up
if self._logger: print(f"Error closing Stagehand: {str(e)}")
self._logger.error(f"Error closing Stagehand: {str(e)}")
self._stagehand = None self._stagehand = None
if self._page: if self._page:
self._page = None self._page = None
@@ -599,3 +729,4 @@ class StagehandTool(BaseTool):
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
"""Exit the context manager and clean up resources.""" """Exit the context manager and clean up resources."""
self.close() self.close()