Various improvements to SpiderTool based on PR review recommendations

This commit is contained in:
Gilbert Bagaoisan
2024-12-17 20:53:17 -08:00
parent 3795d7dd8e
commit 73b803ddc3


@@ -1,6 +1,6 @@
 import logging
 from typing import Any, Dict, Literal, Optional, Type
-from urllib.parse import urlparse
+from urllib.parse import unquote, urlparse
 
 from crewai.tools import BaseTool
 from pydantic import BaseModel, Field
@@ -20,12 +20,28 @@ class SpiderToolSchema(BaseModel):
     )
 
-class SpiderTool(BaseTool):
-    """Tool for scraping and crawling websites."""
+class SpiderToolConfig(BaseModel):
+    """Configuration settings for SpiderTool.
+
+    Contains all default values and constants used by SpiderTool.
+    Centralizes configuration management for easier maintenance.
+    """
+
+    # Crawling settings
     DEFAULT_CRAWL_LIMIT: int = 5
     DEFAULT_RETURN_FORMAT: str = "markdown"
 
+    # Request parameters
+    DEFAULT_REQUEST_MODE: str = "smart"
+    FILTER_SVG: bool = True
+
+
+class SpiderTool(BaseTool):
+    """Tool for scraping and crawling websites.
+
+    This tool provides functionality to either scrape a single webpage or crawl multiple
+    pages, returning content in a format suitable for LLM processing.
+    """
+
     name: str = "SpiderTool"
     description: str = (
         "A tool to scrape or crawl a website and return LLM-ready content."
@@ -36,6 +52,7 @@ class SpiderTool(BaseTool):
     api_key: Optional[str] = None
     spider: Any = None
     log_failures: bool = True
+    config: SpiderToolConfig = SpiderToolConfig()
 
     def __init__(
         self,
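One side effect of declaring config as a pydantic field rather than as loose class constants: pydantic copies mutable defaults per instance, so two tools never share config state. A quick sanity check (a sketch; api_key values are placeholders):

t1 = SpiderTool(api_key="key-a")
t2 = SpiderTool(api_key="key-b")

t1.config.DEFAULT_CRAWL_LIMIT = 50          # tweak one instance only
assert t2.config.DEFAULT_CRAWL_LIMIT == 5   # the other keeps the declared default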
@@ -79,16 +96,26 @@ class SpiderTool(BaseTool):
raise RuntimeError(f"Failed to initialize Spider client: {str(e)}") raise RuntimeError(f"Failed to initialize Spider client: {str(e)}")
def _validate_url(self, url: str) -> bool: def _validate_url(self, url: str) -> bool:
"""Validate URL format. """Validate URL format and security constraints.
Args: Args:
url (str): URL to validate. url (str): URL to validate. Must be a properly formatted HTTP(S) URL
Returns: Returns:
bool: True if valid URL. bool: True if URL is valid and meets security requirements, False otherwise.
""" """
try: try:
result = urlparse(url) url = url.strip()
return all([result.scheme, result.netloc]) decoded_url = unquote(url)
result = urlparse(decoded_url)
if not all([result.scheme, result.netloc]):
return False
if result.scheme not in ["http", "https"]:
return False
return True
except Exception: except Exception:
return False return False
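The strip/unquote/scheme pipeline is easiest to see against concrete inputs: decoding before parsing means a URL is judged by what it resolves to, so a disallowed scheme cannot hide behind percent-encoding, while an encoded-but-legitimate https URL still passes. A behavioral sketch (calling the private method directly purely for illustration; api_key is a placeholder):

tool = SpiderTool(api_key="your-spider-api-key")

assert tool._validate_url("https://example.com")
assert tool._validate_url("  https://example.com  ")      # whitespace stripped first
assert tool._validate_url("https%3A//example.com")        # decoded, then parsed
assert not tool._validate_url("ftp://example.com")        # scheme outside http/https
assert not tool._validate_url("example.com")              # missing scheme and netloc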
@@ -96,42 +123,80 @@ class SpiderTool(BaseTool):
         self,
         website_url: str,
         mode: Literal["scrape", "crawl"] = "scrape",
-    ) -> str:
-        params = {}
-        url = website_url or self.website_url
-        if not url:
-            raise ValueError(
-                "Website URL must be provided either during initialization or execution"
-            )
-
-        if not self._validate_url(url):
-            raise ValueError("Invalid URL format")
-
-        if mode not in ["scrape", "crawl"]:
-            raise ValueError("Mode must be either 'scrape' or 'crawl'")
-
-        params["request"] = "smart"
-        params["filter_output_svg"] = True
-        params["return_format"] = self.DEFAULT_RETURN_FORMAT
-
-        if mode == "crawl":
-            params["limit"] = self.DEFAULT_CRAWL_LIMIT
-
-        # Update params with custom params if provided.
-        # This will override any params passed by LLM.
-        if self.custom_params:
-            params.update(self.custom_params)
-
+    ) -> Optional[str]:
+        """Execute the spider tool to scrape or crawl the specified website.
+
+        Args:
+            website_url (str): The URL to process. Must be a valid HTTP(S) URL.
+            mode (Literal["scrape", "crawl"]): Operation mode.
+                - "scrape": Extract content from a single page
+                - "crawl": Follow links and extract content from multiple pages
+
+        Returns:
+            Optional[str]: Extracted content in markdown format, or None if extraction
+                fails and log_failures is True.
+
+        Raises:
+            ValueError: If URL is invalid or missing, or if mode is invalid.
+            ImportError: If spider-client package is not properly installed.
+            ConnectionError: If network connection fails while accessing the URL.
+            Exception: For other runtime errors.
+        """
         try:
+            url = website_url or self.website_url
+
+            if not url:
+                raise ValueError(
+                    "Website URL must be provided either during initialization or execution"
+                )
+
+            if not self._validate_url(url):
+                raise ValueError(f"Invalid URL format: {url}")
+
+            if mode not in ["scrape", "crawl"]:
+                raise ValueError(
+                    f"Invalid mode: {mode}. Must be either 'scrape' or 'crawl'"
+                )
+
+            params = {
+                "request": self.config.DEFAULT_REQUEST_MODE,
+                "filter_output_svg": self.config.FILTER_SVG,
+                "return_format": self.config.DEFAULT_RETURN_FORMAT,
+            }
+
+            if mode == "crawl":
+                params["limit"] = self.config.DEFAULT_CRAWL_LIMIT
+
+            # Update params with custom params if provided.
+            # This will override any params passed by LLM.
+            if self.custom_params:
+                params.update(self.custom_params)
+
             action = (
                 self.spider.scrape_url if mode == "scrape" else self.spider.crawl_url
             )
             return action(url=url, params=params)
+        except ValueError as ve:
+            if self.log_failures:
+                logger.error(f"Validation error for URL {url}: {str(ve)}")
+                return None
+            raise ve
+        except ImportError as ie:
+            logger.error(f"Spider client import error: {str(ie)}")
+            raise ie
+        except ConnectionError as ce:
+            if self.log_failures:
+                logger.error(f"Connection error while accessing {url}: {str(ce)}")
+                return None
+            raise ce
         except Exception as e:
             if self.log_failures:
-                logger.error(f"Error fetching data from {url}, exception: {e}")
+                logger.error(
+                    f"Unexpected error during {mode} operation on {url}: {str(e)}"
+                )
                 return None
-            else:
-                raise e
+            raise e
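Taken together, the new handlers give _run a clear failure contract: with log_failures=True, validation, connection, and unexpected errors are logged and surface as None; with log_failures=False they propagate (import errors always propagate). A usage sketch, calling _run directly rather than through the framework's tool invocation and assuming the constructor accepts the fields shown above, with placeholder values:

tool = SpiderTool(api_key="your-spider-api-key", log_failures=True)
content = tool._run(website_url="https://example.com", mode="scrape")
if content is None:
    print("Extraction failed; details were logged")

strict = SpiderTool(api_key="your-spider-api-key", log_failures=False)
try:
    strict._run(website_url="not-a-url", mode="crawl")
except ValueError as exc:
    print(f"Rejected before any network call: {exc}")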