Update scrapegraph_scrape_tool.py

This commit is contained in:
Marco Vinciguerra
2025-01-07 15:51:52 +01:00
parent ad4c711223
commit c27727b16e

View File

@@ -60,16 +60,19 @@ class ScrapegraphScrapeTool(BaseTool):
website_url: Optional[str] = None website_url: Optional[str] = None
user_prompt: Optional[str] = None user_prompt: Optional[str] = None
api_key: Optional[str] = None api_key: Optional[str] = None
enable_logging: bool = False
def __init__( def __init__(
self, self,
website_url: Optional[str] = None, website_url: Optional[str] = None,
user_prompt: Optional[str] = None, user_prompt: Optional[str] = None,
api_key: Optional[str] = None, api_key: Optional[str] = None,
enable_logging: bool = False,
**kwargs, **kwargs,
): ):
super().__init__(**kwargs) super().__init__(**kwargs)
self.api_key = api_key or os.getenv("SCRAPEGRAPH_API_KEY") self.api_key = api_key or os.getenv("SCRAPEGRAPH_API_KEY")
self.enable_logging = enable_logging
if not self.api_key: if not self.api_key:
raise ValueError("Scrapegraph API key is required") raise ValueError("Scrapegraph API key is required")
@@ -83,8 +86,9 @@ class ScrapegraphScrapeTool(BaseTool):
if user_prompt is not None: if user_prompt is not None:
self.user_prompt = user_prompt self.user_prompt = user_prompt
# Configure logging # Configure logging only if enabled
sgai_logger.set_logging(level="INFO") if self.enable_logging:
sgai_logger.set_logging(level="INFO")
@staticmethod @staticmethod
def _validate_url(url: str) -> None: def _validate_url(url: str) -> None:
@@ -96,22 +100,6 @@ class ScrapegraphScrapeTool(BaseTool):
except Exception: except Exception:
raise ValueError("Invalid URL format. URL must include scheme (http/https) and domain") raise ValueError("Invalid URL format. URL must include scheme (http/https) and domain")
def _handle_api_response(self, response: dict) -> str:
"""Handle and validate API response"""
if not response:
raise RuntimeError("Empty response from Scrapegraph API")
if "error" in response:
error_msg = response.get("error", {}).get("message", "Unknown error")
if "rate limit" in error_msg.lower():
raise RateLimitError(f"Rate limit exceeded: {error_msg}")
raise RuntimeError(f"API error: {error_msg}")
if "result" not in response:
raise RuntimeError("Invalid response format from Scrapegraph API")
return response["result"]
def _run( def _run(
self, self,
**kwargs: Any, **kwargs: Any,
@@ -135,8 +123,7 @@ class ScrapegraphScrapeTool(BaseTool):
user_prompt=user_prompt, user_prompt=user_prompt,
) )
# Handle and validate the response return response
return self._handle_api_response(response)
except RateLimitError: except RateLimitError:
raise # Re-raise rate limit errors raise # Re-raise rate limit errors