mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 15:48:29 +00:00
fixing test
This commit is contained in:
207
src/crewai_tools/tools/stagehand_tool/stagehand_extract_tool.py
Normal file
207
src/crewai_tools/tools/stagehand_tool/stagehand_extract_tool.py
Normal file
@@ -0,0 +1,207 @@
|
||||
"""Tool for using Stagehand's AI-powered extraction capabilities in CrewAI."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, Optional, Type
|
||||
import subprocess
|
||||
import json
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
|
||||
# Set up logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class StagehandExtractSchema(BaseModel):
|
||||
"""Schema for data extraction using Stagehand.
|
||||
|
||||
Examples:
|
||||
```python
|
||||
# Extract a product price
|
||||
tool.run(
|
||||
url="https://example.com/product",
|
||||
instruction="Extract the price of the item",
|
||||
schema={
|
||||
"price": {"type": "number"}
|
||||
}
|
||||
)
|
||||
|
||||
# Extract article content
|
||||
tool.run(
|
||||
url="https://example.com/article",
|
||||
instruction="Extract the article title and content",
|
||||
schema={
|
||||
"title": {"type": "string"},
|
||||
"content": {"type": "string"},
|
||||
"date": {"type": "string", "optional": True}
|
||||
}
|
||||
)
|
||||
```
|
||||
"""
|
||||
url: str = Field(
|
||||
...,
|
||||
description="The URL of the website to extract data from"
|
||||
)
|
||||
instruction: str = Field(
|
||||
...,
|
||||
description="Instructions for what data to extract",
|
||||
min_length=1,
|
||||
max_length=500
|
||||
)
|
||||
schema: Dict[str, Dict[str, Any]] = Field(
|
||||
...,
|
||||
description="Zod-like schema defining the structure of data to extract"
|
||||
)
|
||||
|
||||
|
||||
class StagehandExtractTool(BaseTool):
|
||||
name: str = "StagehandExtractTool"
|
||||
description: str = (
|
||||
"A tool that uses Stagehand's AI-powered extraction to get structured data from websites. "
|
||||
"Requires a schema defining the structure of data to extract."
|
||||
)
|
||||
args_schema: Type[BaseModel] = StagehandExtractSchema
|
||||
config: Optional[Dict[str, Any]] = None
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
"""Initialize the StagehandExtractTool.
|
||||
|
||||
Args:
|
||||
**kwargs: Additional keyword arguments passed to the base class.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
|
||||
# Use provided API key or try environment variable
|
||||
if not os.getenv("OPENAI_API_KEY"):
|
||||
raise ValueError(
|
||||
"Set OPENAI_API_KEY environment variable, mandatory for Stagehand"
|
||||
)
|
||||
|
||||
def _convert_to_zod_schema(self, schema: Dict[str, Dict[str, Any]]) -> str:
|
||||
"""Convert Python schema definition to Zod schema string."""
|
||||
zod_parts = []
|
||||
for field_name, field_def in schema.items():
|
||||
field_type = field_def["type"]
|
||||
is_optional = field_def.get("optional", False)
|
||||
|
||||
if field_type == "string":
|
||||
zod_type = "z.string()"
|
||||
elif field_type == "number":
|
||||
zod_type = "z.number()"
|
||||
elif field_type == "boolean":
|
||||
zod_type = "z.boolean()"
|
||||
elif field_type == "array":
|
||||
item_type = field_def.get("items", {"type": "string"})
|
||||
zod_type = f"z.array({self._convert_to_zod_schema({'item': item_type})})"
|
||||
else:
|
||||
zod_type = "z.string()" # Default to string for unknown types
|
||||
|
||||
if is_optional:
|
||||
zod_type += ".optional()"
|
||||
|
||||
zod_parts.append(f"{field_name}: {zod_type}")
|
||||
|
||||
return f"z.object({{ {', '.join(zod_parts)} }})"
|
||||
|
||||
def _run(self, url: str, instruction: str, schema: Dict[str, Dict[str, Any]]) -> Any:
|
||||
"""Execute a Stagehand extract command.
|
||||
|
||||
Args:
|
||||
url: The URL to extract data from
|
||||
instruction: What data to extract
|
||||
schema: Schema defining the structure of data to extract
|
||||
|
||||
Returns:
|
||||
The extracted data matching the provided schema
|
||||
"""
|
||||
logger.debug(
|
||||
"Starting extraction - URL: %s, Instruction: %s, Schema: %s",
|
||||
url,
|
||||
instruction,
|
||||
schema
|
||||
)
|
||||
|
||||
# Convert Python schema to Zod schema
|
||||
zod_schema = self._convert_to_zod_schema(schema)
|
||||
|
||||
# Prepare the Node.js command
|
||||
command = [
|
||||
"node",
|
||||
"-e",
|
||||
f"""
|
||||
const {{ Stagehand }} = require('@browserbasehq/stagehand');
|
||||
const z = require('zod');
|
||||
|
||||
async function run() {{
|
||||
console.log('Initializing Stagehand...');
|
||||
const stagehand = new Stagehand({{
|
||||
apiKey: '{os.getenv("OPENAI_API_KEY")}',
|
||||
env: 'LOCAL'
|
||||
}});
|
||||
|
||||
try {{
|
||||
console.log('Initializing browser...');
|
||||
await stagehand.init();
|
||||
|
||||
console.log('Navigating to:', '{url}');
|
||||
await stagehand.page.goto('{url}');
|
||||
|
||||
console.log('Extracting data...');
|
||||
const result = await stagehand.page.extract({{
|
||||
instruction: '{instruction}',
|
||||
schema: {zod_schema}
|
||||
}});
|
||||
|
||||
process.stdout.write('RESULT_START');
|
||||
process.stdout.write(JSON.stringify({{ data: result, success: true }}));
|
||||
process.stdout.write('RESULT_END');
|
||||
|
||||
await stagehand.close();
|
||||
}} catch (error) {{
|
||||
console.error('Extraction failed:', error);
|
||||
process.stdout.write('RESULT_START');
|
||||
process.stdout.write(JSON.stringify({{
|
||||
error: error.message,
|
||||
name: error.name,
|
||||
success: false
|
||||
}}));
|
||||
process.stdout.write('RESULT_END');
|
||||
process.exit(1);
|
||||
}}
|
||||
}}
|
||||
|
||||
run();
|
||||
"""
|
||||
]
|
||||
|
||||
try:
|
||||
# Execute Node.js script
|
||||
result = subprocess.run(
|
||||
command,
|
||||
check=True,
|
||||
capture_output=True,
|
||||
text=True
|
||||
)
|
||||
|
||||
# Extract the JSON result using markers
|
||||
if 'RESULT_START' in result.stdout and 'RESULT_END' in result.stdout:
|
||||
json_str = result.stdout.split('RESULT_START')[1].split('RESULT_END')[0]
|
||||
try:
|
||||
parsed_result = json.loads(json_str)
|
||||
logger.info("Successfully parsed result: %s", parsed_result)
|
||||
if parsed_result.get('success', False):
|
||||
return parsed_result.get('data')
|
||||
else:
|
||||
raise Exception(f"Extraction failed: {parsed_result.get('error', 'Unknown error')}")
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error("Failed to parse JSON output: %s", json_str)
|
||||
raise Exception(f"Invalid JSON response: {e}")
|
||||
else:
|
||||
logger.error("No valid result markers found in output")
|
||||
raise ValueError("No valid output from Stagehand command")
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
logger.error("Node.js script failed with exit code %d", e.returncode)
|
||||
if e.stderr:
|
||||
logger.error("Error output: %s", e.stderr)
|
||||
raise Exception(f"Stagehand command failed: {e}")
|
||||
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
Reference in New Issue
Block a user