Various improvements to SpiderTool based on PR review recommendations

commit 73b803ddc3
parent 3795d7dd8e
Author: Gilbert Bagaoisan
Date:   2024-12-17 20:53:17 -08:00


@@ -1,6 +1,6 @@
 import logging
 from typing import Any, Dict, Literal, Optional, Type
-from urllib.parse import urlparse
+from urllib.parse import unquote, urlparse
 
 from crewai.tools import BaseTool
 from pydantic import BaseModel, Field
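
For reference, the newly imported `unquote` reverses percent-encoding; this is what lets the hardened validator further down inspect the decoded scheme (standalone illustration, not part of the diff):

```python
from urllib.parse import unquote

# Percent-encoded characters decode back to their literal form, so an
# encoded "https%3A%2F%2F..." URL is inspected as "https://..." by the validator.
print(unquote("https%3A%2F%2Fexample.com"))  # -> https://example.com
```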
@@ -20,12 +20,28 @@ class SpiderToolSchema(BaseModel):
 )
 
-class SpiderTool(BaseTool):
-    """Tool for scraping and crawling websites."""
+class SpiderToolConfig(BaseModel):
+    """Configuration settings for SpiderTool.
+
+    Contains all default values and constants used by SpiderTool.
+    Centralizes configuration management for easier maintenance.
+    """
+
+    # Crawling settings
+    DEFAULT_CRAWL_LIMIT: int = 5
+    DEFAULT_RETURN_FORMAT: str = "markdown"
+
+    # Request parameters
+    DEFAULT_REQUEST_MODE: str = "smart"
+    FILTER_SVG: bool = True
+
+
+class SpiderTool(BaseTool):
+    """Tool for scraping and crawling websites.
+
+    This tool provides functionality to either scrape a single webpage or crawl
+    multiple pages, returning content in a format suitable for LLM processing.
+    """
+
     name: str = "SpiderTool"
     description: str = (
         "A tool to scrape or crawl a website and return LLM-ready content."
@@ -36,6 +52,7 @@ class SpiderTool(BaseTool):
     api_key: Optional[str] = None
     spider: Any = None
     log_failures: bool = True
+    config: SpiderToolConfig = SpiderToolConfig()
 
     def __init__(
         self,
@@ -79,16 +96,26 @@ class SpiderTool(BaseTool):
             raise RuntimeError(f"Failed to initialize Spider client: {str(e)}")
 
     def _validate_url(self, url: str) -> bool:
-        """Validate URL format.
+        """Validate URL format and security constraints.
 
         Args:
-            url (str): URL to validate.
+            url (str): URL to validate. Must be a properly formatted HTTP(S) URL.
 
         Returns:
-            bool: True if valid URL.
+            bool: True if URL is valid and meets security requirements, False otherwise.
         """
         try:
-            result = urlparse(url)
-            return all([result.scheme, result.netloc])
+            url = url.strip()
+            decoded_url = unquote(url)
+            result = urlparse(decoded_url)
+
+            if not all([result.scheme, result.netloc]):
+                return False
+
+            if result.scheme not in ["http", "https"]:
+                return False
+
+            return True
         except Exception:
             return False
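
To see what the hardened check rejects, here is a standalone replica of the new validation logic with a few probe inputs (illustration only; it mirrors the hunk above):

```python
from urllib.parse import unquote, urlparse

def validate_url(url: str) -> bool:
    # Same steps as the updated _validate_url: strip, decode, parse,
    # require both a scheme and a host, and restrict the scheme to HTTP(S).
    try:
        result = urlparse(unquote(url.strip()))
        if not all([result.scheme, result.netloc]):
            return False
        return result.scheme in ["http", "https"]
    except Exception:
        return False

print(validate_url("https://example.com"))      # True
print(validate_url("ftp://example.com"))        # False: scheme not allowed
print(validate_url("javascript:alert(1)"))      # False: no netloc
print(validate_url(" https://example.com "))    # True: whitespace stripped
```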
@@ -96,42 +123,80 @@ class SpiderTool(BaseTool):
         self,
         website_url: str,
         mode: Literal["scrape", "crawl"] = "scrape",
-    ) -> str:
-        params = {}
-        url = website_url or self.website_url
+    ) -> Optional[str]:
+        """Execute the spider tool to scrape or crawl the specified website.
 
-        if not url:
-            raise ValueError(
-                "Website URL must be provided either during initialization or execution"
-            )
+        Args:
+            website_url (str): The URL to process. Must be a valid HTTP(S) URL.
+            mode (Literal["scrape", "crawl"]): Operation mode.
+                - "scrape": Extract content from single page
+                - "crawl": Follow links and extract content from multiple pages
 
-        if not self._validate_url(url):
-            raise ValueError("Invalid URL format")
+        Returns:
+            Optional[str]: Extracted content in markdown format, or None if extraction fails
+                and log_failures is True.
 
-        if mode not in ["scrape", "crawl"]:
-            raise ValueError("Mode must be either 'scrape' or 'crawl'")
-
-        params["request"] = "smart"
-        params["filter_output_svg"] = True
-        params["return_format"] = self.DEFAULT_RETURN_FORMAT
-
-        if mode == "crawl":
-            params["limit"] = self.DEFAULT_CRAWL_LIMIT
-
-        # Update params with custom params if provided.
-        # This will override any params passed by LLM.
-        if self.custom_params:
-            params.update(self.custom_params)
-
-        action = (
-            self.spider.scrape_url if mode == "scrape" else self.spider.crawl_url
-        )
-        return action(url=url, params=params)
+        Raises:
+            ValueError: If URL is invalid or missing, or if mode is invalid.
+            ImportError: If spider-client package is not properly installed.
+            ConnectionError: If network connection fails while accessing the URL.
+            Exception: For other runtime errors.
+        """
+        try:
+            params = {}
+            url = website_url or self.website_url
+
+            if not url:
+                raise ValueError(
+                    "Website URL must be provided either during initialization or execution"
+                )
+
+            if not self._validate_url(url):
+                raise ValueError(f"Invalid URL format: {url}")
+
+            if mode not in ["scrape", "crawl"]:
+                raise ValueError(
+                    f"Invalid mode: {mode}. Must be either 'scrape' or 'crawl'"
+                )
+
+            params = {
+                "request": self.config.DEFAULT_REQUEST_MODE,
+                "filter_output_svg": self.config.FILTER_SVG,
+                "return_format": self.config.DEFAULT_RETURN_FORMAT,
+            }
+
+            if mode == "crawl":
+                params["limit"] = self.config.DEFAULT_CRAWL_LIMIT
+
+            if self.custom_params:
+                params.update(self.custom_params)
+
+            action = (
+                self.spider.scrape_url if mode == "scrape" else self.spider.crawl_url
+            )
+            return action(url=url, params=params)
+        except ValueError as ve:
+            if self.log_failures:
+                logger.error(f"Validation error for URL {url}: {str(ve)}")
+                return None
+            raise ve
+        except ImportError as ie:
+            logger.error(f"Spider client import error: {str(ie)}")
+            raise ie
+        except ConnectionError as ce:
+            if self.log_failures:
+                logger.error(f"Connection error while accessing {url}: {str(ce)}")
+                return None
+            raise ce
         except Exception as e:
             if self.log_failures:
-                logger.error(f"Error fetching data from {url}, exception: {e}")
+                logger.error(
+                    f"Unexpected error during {mode} operation on {url}: {str(e)}"
+                )
                 return None
-            else:
-                raise e
+            raise e
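
Finally, a minimal usage sketch of the updated tool. It assumes the `spider-client` package is installed, a valid Spider API key is available, and that crewai's `BaseTool.run` forwards keyword arguments to `_run`; the key and URLs are placeholders:

```python
from crewai_tools import SpiderTool

tool = SpiderTool(
    api_key="your-spider-api-key",   # placeholder; read from an env var in practice
    website_url="https://example.com",
    log_failures=True,  # validation/connection failures return None instead of raising
)

# Scrape a single page as LLM-ready markdown...
page = tool.run(website_url="https://example.com", mode="scrape")

# ...or crawl up to config.DEFAULT_CRAWL_LIMIT (5) pages.
site = tool.run(website_url="https://example.com", mode="crawl")
```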