Mirror of https://github.com/crewAIInc/crewAI.git (synced 2026-01-10 00:28:31 +00:00)
Update scrapegraph_scrape_tool.py
@@ -60,16 +60,19 @@ class ScrapegraphScrapeTool(BaseTool):
     website_url: Optional[str] = None
     user_prompt: Optional[str] = None
     api_key: Optional[str] = None
+    enable_logging: bool = False
 
     def __init__(
         self,
         website_url: Optional[str] = None,
         user_prompt: Optional[str] = None,
         api_key: Optional[str] = None,
+        enable_logging: bool = False,
         **kwargs,
     ):
         super().__init__(**kwargs)
         self.api_key = api_key or os.getenv("SCRAPEGRAPH_API_KEY")
+        self.enable_logging = enable_logging
 
         if not self.api_key:
             raise ValueError("Scrapegraph API key is required")
@@ -83,8 +86,9 @@ class ScrapegraphScrapeTool(BaseTool):
         if user_prompt is not None:
             self.user_prompt = user_prompt
 
-        # Configure logging
-        sgai_logger.set_logging(level="INFO")
+        # Configure logging only if enabled
+        if self.enable_logging:
+            sgai_logger.set_logging(level="INFO")
 
     @staticmethod
     def _validate_url(url: str) -> None:
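
For context (not part of the commit): a minimal usage sketch of the new enable_logging flag introduced above. The import path crewai_tools and the literal argument values are assumptions for illustration; by default the Scrapegraph SDK logger is now left untouched, and callers opt in explicitly.

    # Hypothetical usage sketch; import path assumed to be crewai_tools.
    from crewai_tools import ScrapegraphScrapeTool

    # Default after this change: sgai logging is not configured.
    quiet_tool = ScrapegraphScrapeTool(
        website_url="https://example.com",
        user_prompt="Extract the page title",
        api_key="sgai-...",  # or rely on the SCRAPEGRAPH_API_KEY env var
    )

    # Opting in restores the previous behaviour (INFO-level sgai logging).
    verbose_tool = ScrapegraphScrapeTool(
        website_url="https://example.com",
        user_prompt="Extract the page title",
        api_key="sgai-...",
        enable_logging=True,
    )
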
@@ -96,22 +100,6 @@ class ScrapegraphScrapeTool(BaseTool):
         except Exception:
             raise ValueError("Invalid URL format. URL must include scheme (http/https) and domain")
 
-    def _handle_api_response(self, response: dict) -> str:
-        """Handle and validate API response"""
-        if not response:
-            raise RuntimeError("Empty response from Scrapegraph API")
-
-        if "error" in response:
-            error_msg = response.get("error", {}).get("message", "Unknown error")
-            if "rate limit" in error_msg.lower():
-                raise RateLimitError(f"Rate limit exceeded: {error_msg}")
-            raise RuntimeError(f"API error: {error_msg}")
-
-        if "result" not in response:
-            raise RuntimeError("Invalid response format from Scrapegraph API")
-
-        return response["result"]
-
     def _run(
         self,
         **kwargs: Any,
@@ -135,8 +123,7 @@ class ScrapegraphScrapeTool(BaseTool):
                 user_prompt=user_prompt,
             )
 
-            # Handle and validate the response
-            return self._handle_api_response(response)
+            return response
 
         except RateLimitError:
             raise  # Re-raise rate limit errors
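
A hedged follow-up sketch (not part of the commit): with _handle_api_response removed, _run now returns whatever the Scrapegraph client gives back. A caller that still wants the old unwrap-and-validate behaviour could do it itself; the dict shape ("error" and "result" keys) is assumed from the deleted helper, and extract_result is a hypothetical name.

    # Hypothetical caller-side helper; response shape assumed from the removed
    # _handle_api_response method.
    def extract_result(response: dict) -> str:
        if not response:
            raise RuntimeError("Empty response from Scrapegraph API")
        if "error" in response:
            raise RuntimeError(response["error"].get("message", "Unknown error"))
        return response.get("result", "")
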