diff --git a/src/crewai_tools/__init__.py b/src/crewai_tools/__init__.py index 890dc36f8..6274840fc 100644 --- a/src/crewai_tools/__init__.py +++ b/src/crewai_tools/__init__.py @@ -16,6 +16,7 @@ from .tools import ( FirecrawlScrapeWebsiteTool, FirecrawlSearchTool, GithubSearchTool, + HyperbrowserLoadTool, JSONSearchTool, LinkupSearchTool, LlamaIndexTool, @@ -23,6 +24,9 @@ from .tools import ( MultiOnTool, MySQLSearchTool, NL2SQLTool, + PatronusEvalTool, + PatronusLocalEvaluatorTool, + PatronusPredefinedCriteriaEvalTool, PDFSearchTool, PGSearchTool, RagTool, @@ -32,20 +36,22 @@ from .tools import ( ScrapeWebsiteTool, ScrapflyScrapeWebsiteTool, SeleniumScrapingTool, + SerpApiGoogleSearchTool, + SerpApiGoogleShoppingTool, SerperDevTool, SerplyJobSearchTool, SerplyNewsSearchTool, SerplyScholarSearchTool, SerplyWebpageToMarkdownTool, SerplyWebSearchTool, + SnowflakeConfig, + SnowflakeSearchTool, SpiderTool, TXTSearchTool, VisionTool, + WeaviateVectorSearchTool, WebsiteSearchTool, XMLSearchTool, YoutubeChannelSearchTool, YoutubeVideoSearchTool, - WeaviateVectorSearchTool, - SerpApiGoogleSearchTool, - SerpApiGoogleShoppingTool, ) diff --git a/src/crewai_tools/tools/__init__.py b/src/crewai_tools/tools/__init__.py index c8ee55084..b4f46e073 100644 --- a/src/crewai_tools/tools/__init__.py +++ b/src/crewai_tools/tools/__init__.py @@ -19,6 +19,7 @@ from .firecrawl_scrape_website_tool.firecrawl_scrape_website_tool import ( ) from .firecrawl_search_tool.firecrawl_search_tool import FirecrawlSearchTool from .github_search_tool.github_search_tool import GithubSearchTool +from .hyperbrowser_load_tool.hyperbrowser_load_tool import HyperbrowserLoadTool from .json_search_tool.json_search_tool import JSONSearchTool from .linkup.linkup_search_tool import LinkupSearchTool from .llamaindex_tool.llamaindex_tool import LlamaIndexTool @@ -26,33 +27,46 @@ from .mdx_seach_tool.mdx_search_tool import MDXSearchTool from .multion_tool.multion_tool import MultiOnTool from .mysql_search_tool.mysql_search_tool import MySQLSearchTool from .nl2sql.nl2sql_tool import NL2SQLTool +from .patronus_eval_tool import ( + PatronusEvalTool, + PatronusLocalEvaluatorTool, + PatronusPredefinedCriteriaEvalTool, +) from .pdf_search_tool.pdf_search_tool import PDFSearchTool from .pg_seach_tool.pg_search_tool import PGSearchTool from .rag.rag_tool import RagTool from .scrape_element_from_website.scrape_element_from_website import ( ScrapeElementFromWebsiteTool, ) -from .scrapegraph_scrape_tool.scrapegraph_scrape_tool import ScrapegraphScrapeTool, ScrapegraphScrapeToolSchema from .scrape_website_tool.scrape_website_tool import ScrapeWebsiteTool +from .scrapegraph_scrape_tool.scrapegraph_scrape_tool import ( + ScrapegraphScrapeTool, + ScrapegraphScrapeToolSchema, +) from .scrapfly_scrape_website_tool.scrapfly_scrape_website_tool import ( ScrapflyScrapeWebsiteTool, ) from .selenium_scraping_tool.selenium_scraping_tool import SeleniumScrapingTool +from .serpapi_tool.serpapi_google_search_tool import SerpApiGoogleSearchTool +from .serpapi_tool.serpapi_google_shopping_tool import SerpApiGoogleShoppingTool from .serper_dev_tool.serper_dev_tool import SerperDevTool from .serply_api_tool.serply_job_search_tool import SerplyJobSearchTool from .serply_api_tool.serply_news_search_tool import SerplyNewsSearchTool from .serply_api_tool.serply_scholar_search_tool import SerplyScholarSearchTool from .serply_api_tool.serply_web_search_tool import SerplyWebSearchTool from .serply_api_tool.serply_webpage_to_markdown_tool import SerplyWebpageToMarkdownTool 
+from .snowflake_search_tool import ( + SnowflakeConfig, + SnowflakeSearchTool, + SnowflakeSearchToolInput, +) from .spider_tool.spider_tool import SpiderTool from .txt_search_tool.txt_search_tool import TXTSearchTool from .vision_tool.vision_tool import VisionTool +from .weaviate_tool.vector_search import WeaviateVectorSearchTool from .website_search.website_search_tool import WebsiteSearchTool from .xml_search_tool.xml_search_tool import XMLSearchTool from .youtube_channel_search_tool.youtube_channel_search_tool import ( YoutubeChannelSearchTool, ) from .youtube_video_search_tool.youtube_video_search_tool import YoutubeVideoSearchTool -from .weaviate_tool.vector_search import WeaviateVectorSearchTool -from .serpapi_tool.serpapi_google_search_tool import SerpApiGoogleSearchTool -from .serpapi_tool.serpapi_google_shopping_tool import SerpApiGoogleShoppingTool diff --git a/src/crewai_tools/tools/browserbase_load_tool/browserbase_load_tool.py b/src/crewai_tools/tools/browserbase_load_tool/browserbase_load_tool.py index 2ca1b95fc..d3f76e0a6 100644 --- a/src/crewai_tools/tools/browserbase_load_tool/browserbase_load_tool.py +++ b/src/crewai_tools/tools/browserbase_load_tool/browserbase_load_tool.py @@ -1,8 +1,8 @@ import os from typing import Any, Optional, Type -from pydantic import BaseModel, Field from crewai.tools import BaseTool +from pydantic import BaseModel, Field class BrowserbaseLoadToolSchema(BaseModel): @@ -11,12 +11,10 @@ class BrowserbaseLoadToolSchema(BaseModel): class BrowserbaseLoadTool(BaseTool): name: str = "Browserbase web load tool" - description: str = ( - "Load webpages url in a headless browser using Browserbase and return the contents" - ) + description: str = "Load webpages url in a headless browser using Browserbase and return the contents" args_schema: Type[BaseModel] = BrowserbaseLoadToolSchema - api_key: Optional[str] = os.getenv('BROWSERBASE_API_KEY') - project_id: Optional[str] = os.getenv('BROWSERBASE_PROJECT_ID') + api_key: Optional[str] = os.getenv("BROWSERBASE_API_KEY") + project_id: Optional[str] = os.getenv("BROWSERBASE_PROJECT_ID") text_content: Optional[bool] = False session_id: Optional[str] = None proxy: Optional[bool] = None @@ -33,7 +31,9 @@ class BrowserbaseLoadTool(BaseTool): ): super().__init__(**kwargs) if not self.api_key: - raise EnvironmentError("BROWSERBASE_API_KEY environment variable is required for initialization") + raise EnvironmentError( + "BROWSERBASE_API_KEY environment variable is required for initialization" + ) try: from browserbase import Browserbase # type: ignore except ImportError: diff --git a/src/crewai_tools/tools/code_interpreter_tool/code_interpreter_tool.py b/src/crewai_tools/tools/code_interpreter_tool/code_interpreter_tool.py index fd0d39932..2a0f9ffe6 100644 --- a/src/crewai_tools/tools/code_interpreter_tool/code_interpreter_tool.py +++ b/src/crewai_tools/tools/code_interpreter_tool/code_interpreter_tool.py @@ -2,10 +2,11 @@ import importlib.util import os from typing import List, Optional, Type +from crewai.tools import BaseTool from docker import from_env as docker_from_env +from docker import DockerClient from docker.models.containers import Container from docker.errors import ImageNotFound, NotFound -from crewai.tools import BaseTool from pydantic import BaseModel, Field @@ -30,7 +32,7 @@ class CodeInterpreterTool(BaseTool): default_image_tag: str = "code-interpreter:latest" code: Optional[str] = None user_dockerfile_path: Optional[str] = None - user_docker_base_url:
Optional[str] = None + user_docker_base_url: Optional[str] = None unsafe_mode: bool = False @staticmethod @@ -43,7 +45,11 @@ class CodeInterpreterTool(BaseTool): Verify if the Docker image is available. Optionally use a user-provided Dockerfile. """ - client = docker_from_env() if self.user_docker_base_url == None else docker.DockerClient(base_url=self.user_docker_base_url) + client = ( + docker_from_env() + if self.user_docker_base_url is None + else DockerClient(base_url=self.user_docker_base_url) + ) try: client.images.get(self.default_image_tag) @@ -76,9 +82,7 @@ class CodeInterpreterTool(BaseTool): else: return self.run_code_in_docker(code, libraries_used) - def _install_libraries( - self, container: Container, libraries: List[str] - ) -> None: + def _install_libraries(self, container: Container, libraries: List[str]) -> None: """ Install missing libraries in the Docker container """ @@ -135,4 +139,4 @@ class CodeInterpreterTool(BaseTool): exec(code, {}, exec_locals) return exec_locals.get("result", "No result variable found.") except Exception as e: - return f"An error occurred: {str(e)}" + return f"An error occurred: {str(e)}" \ No newline at end of file diff --git a/src/crewai_tools/tools/directory_read_tool/directory_read_tool.py b/src/crewai_tools/tools/directory_read_tool/directory_read_tool.py index 6033202be..8488f391e 100644 --- a/src/crewai_tools/tools/directory_read_tool/directory_read_tool.py +++ b/src/crewai_tools/tools/directory_read_tool/directory_read_tool.py @@ -8,8 +8,6 @@ from pydantic import BaseModel, Field class FixedDirectoryReadToolSchema(BaseModel): """Input for DirectoryReadTool.""" - pass - class DirectoryReadToolSchema(FixedDirectoryReadToolSchema): """Input for DirectoryReadTool.""" diff --git a/src/crewai_tools/tools/file_read_tool/file_read_tool.py b/src/crewai_tools/tools/file_read_tool/file_read_tool.py index 323a26d51..55fb5d490 100644 --- a/src/crewai_tools/tools/file_read_tool/file_read_tool.py +++ b/src/crewai_tools/tools/file_read_tool/file_read_tool.py @@ -32,6 +32,7 @@ class FileReadTool(BaseTool): >>> content = tool.run() # Reads /path/to/file.txt >>> content = tool.run(file_path="/path/to/other.txt") # Reads other.txt """ + name: str = "Read a file's content" description: str = "A tool that reads the content of a file. To use this tool, provide a 'file_path' parameter with the path to the file you want to read." args_schema: Type[BaseModel] = FileReadToolSchema @@ -45,10 +46,11 @@ class FileReadTool(BaseTool): this becomes the default file path for the tool. **kwargs: Additional keyword arguments passed to BaseTool. """ - super().__init__(**kwargs) if file_path is not None: - self.file_path = file_path - self.description = f"A tool that reads file content. The default file is {file_path}, but you can provide a different 'file_path' parameter to read another file." + kwargs["description"] = f"A tool that reads file content. The default file is {file_path}, but you can provide a different 'file_path' parameter to read another file." + + super().__init__(**kwargs) + self.file_path = file_path def _run( self, @@ -57,7 +59,7 @@ class FileReadTool(BaseTool): file_path = kwargs.get("file_path", self.file_path) if file_path is None: return "Error: No file path provided. Please provide a file path either in the constructor or as an argument."
- + try: with open(file_path, "r") as file: return file.read() @@ -66,16 +68,4 @@ class FileReadTool(BaseTool): except PermissionError: return f"Error: Permission denied when trying to read file: {file_path}" except Exception as e: - return f"Error: Failed to read file {file_path}. {str(e)}" - - def _generate_description(self) -> None: - """Generate the tool description based on file path. - - This method updates the tool's description to include information about - the default file path while maintaining the ability to specify a different - file at runtime. - - Returns: - None - """ - self.description = f"A tool that can be used to read {self.file_path}'s content." + return f"Error: Failed to read file {file_path}. {str(e)}" \ No newline at end of file diff --git a/src/crewai_tools/tools/file_writer_tool/file_writer_tool.py b/src/crewai_tools/tools/file_writer_tool/file_writer_tool.py index ed454a1bd..f975d3301 100644 --- a/src/crewai_tools/tools/file_writer_tool/file_writer_tool.py +++ b/src/crewai_tools/tools/file_writer_tool/file_writer_tool.py @@ -15,9 +15,7 @@ class FileWriterToolInput(BaseModel): class FileWriterTool(BaseTool): name: str = "File Writer Tool" - description: str = ( - "A tool to write content to a specified file. Accepts filename, content, and optionally a directory path and overwrite flag as input." - ) + description: str = "A tool to write content to a specified file. Accepts filename, content, and optionally a directory path and overwrite flag as input." args_schema: Type[BaseModel] = FileWriterToolInput def _run(self, **kwargs: Any) -> str: diff --git a/src/crewai_tools/tools/firecrawl_crawl_website_tool/firecrawl_crawl_website_tool.py b/src/crewai_tools/tools/firecrawl_crawl_website_tool/firecrawl_crawl_website_tool.py index edada38dd..dcb70e291 100644 --- a/src/crewai_tools/tools/firecrawl_crawl_website_tool/firecrawl_crawl_website_tool.py +++ b/src/crewai_tools/tools/firecrawl_crawl_website_tool/firecrawl_crawl_website_tool.py @@ -1,9 +1,8 @@ import os from typing import TYPE_CHECKING, Any, Dict, Optional, Type -from pydantic import BaseModel, ConfigDict, Field - from crewai.tools import BaseTool +from pydantic import BaseModel, ConfigDict, Field, PrivateAttr # Type checking import if TYPE_CHECKING: @@ -12,6 +11,14 @@ if TYPE_CHECKING: class FirecrawlCrawlWebsiteToolSchema(BaseModel): url: str = Field(description="Website URL") + crawler_options: Optional[Dict[str, Any]] = Field( + default=None, description="Options for crawling" + ) + timeout: Optional[int] = Field( + default=30000, + description="Timeout in milliseconds for the crawling operation. The default value is 30000.", + ) + class FirecrawlCrawlWebsiteTool(BaseTool): model_config = ConfigDict( @@ -20,25 +27,10 @@ class FirecrawlCrawlWebsiteTool(BaseTool): name: str = "Firecrawl web crawl tool" description: str = "Crawl webpages using Firecrawl and return the contents" args_schema: Type[BaseModel] = FirecrawlCrawlWebsiteToolSchema - firecrawl_app: Optional["FirecrawlApp"] = None api_key: Optional[str] = None - url: Optional[str] = None - params: Optional[Dict[str, Any]] = None - poll_interval: Optional[int] = 2 - idempotency_key: Optional[str] = None + _firecrawl: Optional["FirecrawlApp"] = PrivateAttr(None) def __init__(self, api_key: Optional[str] = None, **kwargs): - """Initialize FirecrawlCrawlWebsiteTool. - - Args: - api_key (Optional[str]): Firecrawl API key. If not provided, will check FIRECRAWL_API_KEY env var. - url (Optional[str]): Base URL to crawl. Can be overridden by the _run method. 
- firecrawl_app (Optional[FirecrawlApp]): Previously created FirecrawlApp instance. - params (Optional[Dict[str, Any]]): Additional parameters to pass to the FirecrawlApp. - poll_interval (Optional[int]): Poll interval for the FirecrawlApp. - idempotency_key (Optional[str]): Idempotency key for the FirecrawlApp. - **kwargs: Additional arguments passed to BaseTool. - """ super().__init__(**kwargs) try: from firecrawl import FirecrawlApp # type: ignore @@ -47,28 +39,28 @@ class FirecrawlCrawlWebsiteTool(BaseTool): "`firecrawl` package not found, please run `pip install firecrawl-py`" ) - # Allows passing a previously created FirecrawlApp instance - # or builds a new one with the provided API key - if not self.firecrawl_app: - client_api_key = api_key or os.getenv("FIRECRAWL_API_KEY") - if not client_api_key: - raise ValueError( - "FIRECRAWL_API_KEY is not set. Please provide it either via the constructor " - "with the `api_key` argument or by setting the FIRECRAWL_API_KEY environment variable." - ) - self.firecrawl_app = FirecrawlApp(api_key=client_api_key) + client_api_key = api_key or os.getenv("FIRECRAWL_API_KEY") + if not client_api_key: + raise ValueError( + "FIRECRAWL_API_KEY is not set. Please provide it either via the constructor " + "with the `api_key` argument or by setting the FIRECRAWL_API_KEY environment variable." + ) + self._firecrawl = FirecrawlApp(api_key=client_api_key) - def _run(self, url: str): - # Unless url has been previously set via constructor by the user, - # use the url argument provided by the agent at runtime. - base_url = self.url or url + def _run( + self, + url: str, + crawler_options: Optional[Dict[str, Any]] = None, + timeout: Optional[int] = 30000, + ): + if crawler_options is None: + crawler_options = {} - return self.firecrawl_app.crawl_url( - base_url, - params=self.params, - poll_interval=self.poll_interval, - idempotency_key=self.idempotency_key - ) + options = { + "crawlerOptions": crawler_options, + "timeout": timeout, + } + return self._firecrawl.crawl_url(url, options) try: @@ -80,4 +72,3 @@ except ImportError: """ When this tool is not used, then exception can be ignored. """ - pass diff --git a/src/crewai_tools/tools/firecrawl_scrape_website_tool/firecrawl_scrape_website_tool.py b/src/crewai_tools/tools/firecrawl_scrape_website_tool/firecrawl_scrape_website_tool.py index 9ab7d293e..3f5f8c4c4 100644 --- a/src/crewai_tools/tools/firecrawl_scrape_website_tool/firecrawl_scrape_website_tool.py +++ b/src/crewai_tools/tools/firecrawl_scrape_website_tool/firecrawl_scrape_website_tool.py @@ -1,7 +1,7 @@ -from typing import TYPE_CHECKING, Any, Dict, Optional, Type +from typing import TYPE_CHECKING, Optional, Type from crewai.tools import BaseTool -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field, PrivateAttr # Type checking import if TYPE_CHECKING: @@ -10,14 +10,8 @@ if TYPE_CHECKING: class FirecrawlScrapeWebsiteToolSchema(BaseModel): url: str = Field(description="Website URL") - page_options: Optional[Dict[str, Any]] = Field( - default=None, description="Options for page scraping" - ) - extractor_options: Optional[Dict[str, Any]] = Field( - default=None, description="Options for data extraction" - ) timeout: Optional[int] = Field( - default=None, + default=30000, description="Timeout in milliseconds for the scraping operation. 
The default value is 30000.", ) @@ -27,10 +21,10 @@ class FirecrawlScrapeWebsiteTool(BaseTool): arbitrary_types_allowed=True, validate_assignment=True, frozen=False ) name: str = "Firecrawl web scrape tool" - description: str = "Scrape webpages url using Firecrawl and return the contents" + description: str = "Scrape webpages using Firecrawl and return the contents" args_schema: Type[BaseModel] = FirecrawlScrapeWebsiteToolSchema api_key: Optional[str] = None - firecrawl: Optional["FirecrawlApp"] = None # Updated to use TYPE_CHECKING + _firecrawl: Optional["FirecrawlApp"] = PrivateAttr(None) def __init__(self, api_key: Optional[str] = None, **kwargs): super().__init__(**kwargs) @@ -41,28 +35,23 @@ class FirecrawlScrapeWebsiteTool(BaseTool): "`firecrawl` package not found, please run `pip install firecrawl-py`" ) - self.firecrawl = FirecrawlApp(api_key=api_key) + self._firecrawl = FirecrawlApp(api_key=api_key) def _run( self, url: str, - page_options: Optional[Dict[str, Any]] = None, - extractor_options: Optional[Dict[str, Any]] = None, - timeout: Optional[int] = None, + timeout: Optional[int] = 30000, ): - if page_options is None: - page_options = {} - if extractor_options is None: - extractor_options = {} - if timeout is None: - timeout = 30000 - options = { - "pageOptions": page_options, - "extractorOptions": extractor_options, + "formats": ["markdown"], + "onlyMainContent": True, + "includeTags": [], + "excludeTags": [], + "headers": {}, + "waitFor": 0, "timeout": timeout, } - return self.firecrawl.scrape_url(url, options) + return self._firecrawl.scrape_url(url, options) try: @@ -74,4 +63,3 @@ except ImportError: """ When this tool is not used, then exception can be ignored. """ - pass diff --git a/src/crewai_tools/tools/firecrawl_search_tool/firecrawl_search_tool.py b/src/crewai_tools/tools/firecrawl_search_tool/firecrawl_search_tool.py index 5efd274de..da483fb34 100644 --- a/src/crewai_tools/tools/firecrawl_search_tool/firecrawl_search_tool.py +++ b/src/crewai_tools/tools/firecrawl_search_tool/firecrawl_search_tool.py @@ -1,7 +1,7 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Type from crewai.tools import BaseTool -from pydantic import BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field, PrivateAttr # Type checking import if TYPE_CHECKING: @@ -10,20 +10,34 @@ if TYPE_CHECKING: class FirecrawlSearchToolSchema(BaseModel): query: str = Field(description="Search query") - page_options: Optional[Dict[str, Any]] = Field( - default=None, description="Options for result formatting" + limit: Optional[int] = Field( + default=5, description="Maximum number of results to return" ) - search_options: Optional[Dict[str, Any]] = Field( - default=None, description="Options for searching" + tbs: Optional[str] = Field(default=None, description="Time-based search parameter") + lang: Optional[str] = Field( + default="en", description="Language code for search results" + ) + country: Optional[str] = Field( + default="us", description="Country code for search results" + ) + location: Optional[str] = Field( + default=None, description="Location parameter for search results" + ) + timeout: Optional[int] = Field(default=60000, description="Timeout in milliseconds") + scrape_options: Optional[Dict[str, Any]] = Field( + default=None, description="Options for scraping search results" ) class FirecrawlSearchTool(BaseTool): + model_config = ConfigDict( + arbitrary_types_allowed=True, validate_assignment=True, frozen=False + ) name: str = "Firecrawl web search tool" description: str = 
"Search webpages using Firecrawl and return the results" args_schema: Type[BaseModel] = FirecrawlSearchToolSchema api_key: Optional[str] = None - firecrawl: Optional["FirecrawlApp"] = None + _firecrawl: Optional["FirecrawlApp"] = PrivateAttr(None) def __init__(self, api_key: Optional[str] = None, **kwargs): super().__init__(**kwargs) @@ -33,19 +47,39 @@ class FirecrawlSearchTool(BaseTool): raise ImportError( "`firecrawl` package not found, please run `pip install firecrawl-py`" ) - - self.firecrawl = FirecrawlApp(api_key=api_key) + self._firecrawl = FirecrawlApp(api_key=api_key) def _run( self, query: str, - page_options: Optional[Dict[str, Any]] = None, - result_options: Optional[Dict[str, Any]] = None, + limit: Optional[int] = 5, + tbs: Optional[str] = None, + lang: Optional[str] = "en", + country: Optional[str] = "us", + location: Optional[str] = None, + timeout: Optional[int] = 60000, + scrape_options: Optional[Dict[str, Any]] = None, ): - if page_options is None: - page_options = {} - if result_options is None: - result_options = {} + if scrape_options is None: + scrape_options = {} - options = {"pageOptions": page_options, "resultOptions": result_options} - return self.firecrawl.search(query, **options) + options = { + "limit": limit, + "tbs": tbs, + "lang": lang, + "country": country, + "location": location, + "timeout": timeout, + "scrapeOptions": scrape_options, + } + return self._firecrawl.search(query, options) + + +try: + from firecrawl import FirecrawlApp + + # Rebuild the model after class is defined + FirecrawlSearchTool.model_rebuild() +except ImportError: + # Exception can be ignored if the tool is not used + pass diff --git a/src/crewai_tools/tools/github_search_tool/github_search_tool.py b/src/crewai_tools/tools/github_search_tool/github_search_tool.py index 4bf8b9e05..6ba7b919c 100644 --- a/src/crewai_tools/tools/github_search_tool/github_search_tool.py +++ b/src/crewai_tools/tools/github_search_tool/github_search_tool.py @@ -27,9 +27,7 @@ class GithubSearchToolSchema(FixedGithubSearchToolSchema): class GithubSearchTool(RagTool): name: str = "Search a github repo's content" - description: str = ( - "A tool that can be used to semantic search a query from a github repo's content. This is not the GitHub API, but instead a tool that can provide semantic search capabilities." - ) + description: str = "A tool that can be used to semantic search a query from a github repo's content. This is not the GitHub API, but instead a tool that can provide semantic search capabilities." summarize: bool = False gh_token: str args_schema: Type[BaseModel] = GithubSearchToolSchema diff --git a/src/crewai_tools/tools/hyperbrowser_load_tool/README.md b/src/crewai_tools/tools/hyperbrowser_load_tool/README.md new file mode 100644 index 000000000..e95864f5a --- /dev/null +++ b/src/crewai_tools/tools/hyperbrowser_load_tool/README.md @@ -0,0 +1,42 @@ +# HyperbrowserLoadTool + +## Description + +[Hyperbrowser](https://hyperbrowser.ai) is a platform for running and scaling headless browsers. It lets you launch and manage browser sessions at scale and provides easy to use solutions for any webscraping needs, such as scraping a single page or crawling an entire site. 
+ +Key Features: +- Instant Scalability - Spin up hundreds of browser sessions in seconds without infrastructure headaches +- Simple Integration - Works seamlessly with popular tools like Puppeteer and Playwright +- Powerful APIs - Easy-to-use APIs for scraping/crawling any site, and much more +- Bypass Anti-Bot Measures - Built-in stealth mode, ad blocking, automatic CAPTCHA solving, and rotating proxies + +For more information about Hyperbrowser, please visit the [Hyperbrowser website](https://hyperbrowser.ai), or see the [Hyperbrowser docs](https://docs.hyperbrowser.ai) for details. + +## Installation + +- Head to [Hyperbrowser](https://app.hyperbrowser.ai/) to sign up and generate an API key. Once you've done this, set the `HYPERBROWSER_API_KEY` environment variable or pass it to the `HyperbrowserLoadTool` constructor. +- Install the [Hyperbrowser SDK](https://github.com/hyperbrowserai/python-sdk): + +``` +pip install hyperbrowser 'crewai[tools]' +``` + +## Example + +Use the HyperbrowserLoadTool as follows to allow your agent to load websites: + +```python +from crewai_tools import HyperbrowserLoadTool + +tool = HyperbrowserLoadTool() +``` + +## Arguments + +`__init__` arguments: +- `api_key`: Optional. Specifies Hyperbrowser API key. Defaults to the `HYPERBROWSER_API_KEY` environment variable. + +`run` arguments: +- `url`: The base URL to start scraping or crawling from. +- `operation`: Optional. Specifies the operation to perform on the website. Either 'scrape' or 'crawl'. Defaults to 'scrape'. +- `params`: Optional. Specifies the params for the operation. For more information on the supported params, visit https://docs.hyperbrowser.ai/reference/sdks/python/scrape#start-scrape-job-and-wait or https://docs.hyperbrowser.ai/reference/sdks/python/crawl#start-crawl-job-and-wait. diff --git a/src/crewai_tools/tools/hyperbrowser_load_tool/hyperbrowser_load_tool.py b/src/crewai_tools/tools/hyperbrowser_load_tool/hyperbrowser_load_tool.py new file mode 100644 index 000000000..b802d1859 --- /dev/null +++ b/src/crewai_tools/tools/hyperbrowser_load_tool/hyperbrowser_load_tool.py @@ -0,0 +1,99 @@ +import os +from typing import Any, Optional, Type, Dict, Literal, Union + +from crewai.tools import BaseTool +from pydantic import BaseModel, Field + + +class HyperbrowserLoadToolSchema(BaseModel): + url: str = Field(description="Website URL") + operation: Literal['scrape', 'crawl'] = Field(default='scrape', description="Operation to perform on the website. Either 'scrape' or 'crawl'") + params: Optional[Dict] = Field(default=None, description="Optional params for scrape or crawl. For more information on the supported params, visit https://docs.hyperbrowser.ai/reference/sdks/python/scrape#start-scrape-job-and-wait or https://docs.hyperbrowser.ai/reference/sdks/python/crawl#start-crawl-job-and-wait") + +class HyperbrowserLoadTool(BaseTool): + """HyperbrowserLoadTool. + + Scrape or crawl web pages and load the contents with optional parameters for configuring content extraction. + Requires the `hyperbrowser` package.
+ Get your API Key from https://app.hyperbrowser.ai/ + + Args: + api_key: The Hyperbrowser API key, can be set as an environment variable `HYPERBROWSER_API_KEY` or passed directly + """ + name: str = "Hyperbrowser web load tool" + description: str = "Scrape or crawl a website using Hyperbrowser and return the contents in properly formatted markdown or html" + args_schema: Type[BaseModel] = HyperbrowserLoadToolSchema + api_key: Optional[str] = None + hyperbrowser: Optional[Any] = None + + def __init__(self, api_key: Optional[str] = None, **kwargs): + super().__init__(**kwargs) + self.api_key = api_key or os.getenv('HYPERBROWSER_API_KEY') + + try: + from hyperbrowser import Hyperbrowser + except ImportError: + raise ImportError("`hyperbrowser` package not found, please run `pip install hyperbrowser`") + + if not self.api_key: + raise ValueError("HYPERBROWSER_API_KEY is not set. Please provide it either via the constructor with the `api_key` argument or by setting the HYPERBROWSER_API_KEY environment variable.") + + self.hyperbrowser = Hyperbrowser(api_key=self.api_key) + + def _prepare_params(self, params: Dict) -> Dict: + """Prepare session and scrape options parameters.""" + try: + from hyperbrowser.models.session import CreateSessionParams + from hyperbrowser.models.scrape import ScrapeOptions + except ImportError: + raise ImportError( + "`hyperbrowser` package not found, please run `pip install hyperbrowser`" + ) + + if "scrape_options" in params: + if "formats" in params["scrape_options"]: + formats = params["scrape_options"]["formats"] + if not all(fmt in ["markdown", "html"] for fmt in formats): + raise ValueError("formats can only contain 'markdown' or 'html'") + + if "session_options" in params: + params["session_options"] = CreateSessionParams(**params["session_options"]) + if "scrape_options" in params: + params["scrape_options"] = ScrapeOptions(**params["scrape_options"]) + return params + + def _extract_content(self, data: Union[Any, None]): + """Extract content from response data.""" + content = "" + if data: + content = data.markdown or data.html or "" + return content + + def _run(self, url: str, operation: Literal['scrape', 'crawl'] = 'scrape', params: Optional[Dict] = None): + try: + from hyperbrowser.models.scrape import StartScrapeJobParams + from hyperbrowser.models.crawl import StartCrawlJobParams + except ImportError: + raise ImportError( + "`hyperbrowser` package not found, please run `pip install hyperbrowser`" + ) + + params = self._prepare_params(params or {}) + + if operation == 'scrape': + scrape_params = StartScrapeJobParams(url=url, **params) + scrape_resp = self.hyperbrowser.scrape.start_and_wait(scrape_params) + content = self._extract_content(scrape_resp.data) + return content + else: + crawl_params = StartCrawlJobParams(url=url, **params) + crawl_resp = self.hyperbrowser.crawl.start_and_wait(crawl_params) + content = "" + if crawl_resp.data: + for page in crawl_resp.data: + page_content = self._extract_content(page) + if page_content: + content += ( + f"\n{'-'*50}\nUrl: {page.url}\nContent:\n{page_content}\n" + ) + return content diff --git a/src/crewai_tools/tools/jina_scrape_website_tool/jina_scrape_website_tool.py b/src/crewai_tools/tools/jina_scrape_website_tool/jina_scrape_website_tool.py index a10a4ffdb..86f771cd0 --- a/src/crewai_tools/tools/jina_scrape_website_tool/jina_scrape_website_tool.py +++
b/src/crewai_tools/tools/jina_scrape_website_tool/jina_scrape_website_tool.py @@ -13,9 +13,7 @@ class JinaScrapeWebsiteToolInput(BaseModel): class JinaScrapeWebsiteTool(BaseTool): name: str = "JinaScrapeWebsiteTool" - description: str = ( - "A tool that can be used to read a website content using Jina.ai reader and return markdown content." - ) + description: str = "A tool that can be used to read a website content using Jina.ai reader and return markdown content." args_schema: Type[BaseModel] = JinaScrapeWebsiteToolInput website_url: Optional[str] = None api_key: Optional[str] = None diff --git a/src/crewai_tools/tools/linkup/linkup_search_tool.py b/src/crewai_tools/tools/linkup/linkup_search_tool.py index b172ad029..486663d3e 100644 --- a/src/crewai_tools/tools/linkup/linkup_search_tool.py +++ b/src/crewai_tools/tools/linkup/linkup_search_tool.py @@ -2,6 +2,7 @@ from typing import Any try: from linkup import LinkupClient + LINKUP_AVAILABLE = True except ImportError: LINKUP_AVAILABLE = False @@ -9,10 +10,13 @@ except ImportError: from pydantic import PrivateAttr + class LinkupSearchTool: name: str = "Linkup Search Tool" - description: str = "Performs an API call to Linkup to retrieve contextual information." - _client: LinkupClient = PrivateAttr() # type: ignore + description: str = ( + "Performs an API call to Linkup to retrieve contextual information." + ) + _client: LinkupClient = PrivateAttr() # type: ignore def __init__(self, api_key: str): """ @@ -25,7 +29,9 @@ class LinkupSearchTool: ) self._client = LinkupClient(api_key=api_key) - def _run(self, query: str, depth: str = "standard", output_type: str = "searchResults") -> dict: + def _run( + self, query: str, depth: str = "standard", output_type: str = "searchResults" + ) -> dict: """ Executes a search using the Linkup API. @@ -36,9 +42,7 @@ class LinkupSearchTool: """ try: response = self._client.search( - query=query, - depth=depth, - output_type=output_type + query=query, depth=depth, output_type=output_type ) results = [ {"name": result.name, "url": result.url, "content": result.content} diff --git a/src/crewai_tools/tools/mysql_search_tool/mysql_search_tool.py b/src/crewai_tools/tools/mysql_search_tool/mysql_search_tool.py index f931a006b..a472e1761 100644 --- a/src/crewai_tools/tools/mysql_search_tool/mysql_search_tool.py +++ b/src/crewai_tools/tools/mysql_search_tool/mysql_search_tool.py @@ -17,9 +17,7 @@ class MySQLSearchToolSchema(BaseModel): class MySQLSearchTool(RagTool): name: str = "Search a database's table content" - description: str = ( - "A tool that can be used to semantic search a query from a database table's content." - ) + description: str = "A tool that can be used to semantic search a query from a database table's content." 
args_schema: Type[BaseModel] = MySQLSearchToolSchema db_uri: str = Field(..., description="Mandatory database URI") diff --git a/src/crewai_tools/tools/patronus_eval_tool/__init__.py b/src/crewai_tools/tools/patronus_eval_tool/__init__.py new file mode 100644 index 000000000..351cced92 --- /dev/null +++ b/src/crewai_tools/tools/patronus_eval_tool/__init__.py @@ -0,0 +1,3 @@ +from .patronus_eval_tool import PatronusEvalTool +from .patronus_local_evaluator_tool import PatronusLocalEvaluatorTool +from .patronus_predefined_criteria_eval_tool import PatronusPredefinedCriteriaEvalTool diff --git a/src/crewai_tools/tools/patronus_eval_tool/example.py b/src/crewai_tools/tools/patronus_eval_tool/example.py new file mode 100644 index 000000000..185e9f485 --- /dev/null +++ b/src/crewai_tools/tools/patronus_eval_tool/example.py @@ -0,0 +1,55 @@ +import random + +from crewai import Agent, Crew, Task +from patronus import Client, EvaluationResult +from patronus_local_evaluator_tool import PatronusLocalEvaluatorTool + +# Test the PatronusLocalEvaluatorTool where agent uses the local evaluator +client = Client() + + +# Example of an evaluator that returns a random pass/fail result +@client.register_local_evaluator("random_evaluator") +def random_evaluator(**kwargs): + score = random.random() + return EvaluationResult( + score_raw=score, + pass_=score >= 0.5, + explanation="example explanation", # Optional justification for LLM judges + ) + + +# 1. Uses PatronusEvalTool: agent can pick the best evaluator and criteria +# patronus_eval_tool = PatronusEvalTool() + +# 2. Uses PatronusPredefinedCriteriaEvalTool: agent uses the defined evaluator and criteria +# patronus_eval_tool = PatronusPredefinedCriteriaEvalTool( +# evaluators=[{"evaluator": "judge", "criteria": "contains-code"}] +# ) + +# 3. Uses PatronusLocalEvaluatorTool: agent uses user defined evaluator +patronus_eval_tool = PatronusLocalEvaluatorTool( + patronus_client=client, + evaluator="random_evaluator", + evaluated_model_gold_answer="example label", +) + +# Create a new agent +coding_agent = Agent( + role="Coding Agent", + goal="Generate high quality code and verify that the output is code by using Patronus AI's evaluation tool.", + backstory="You are an experienced coder who can generate high quality python code. You can follow complex instructions accurately and effectively.", + tools=[patronus_eval_tool], + verbose=True, +) + +# Define tasks +generate_code = Task( + description="Create a simple program to generate the first N numbers in the Fibonacci sequence. 
Select the most appropriate evaluator and criteria for evaluating your output.", + expected_output="Program that generates the first N numbers in the Fibonacci sequence.", + agent=coding_agent, +) + +crew = Crew(agents=[coding_agent], tasks=[generate_code]) + +crew.kickoff() diff --git a/src/crewai_tools/tools/patronus_eval_tool/patronus_eval_tool.py b/src/crewai_tools/tools/patronus_eval_tool/patronus_eval_tool.py new file mode 100644 index 000000000..be1f410e2 --- /dev/null +++ b/src/crewai_tools/tools/patronus_eval_tool/patronus_eval_tool.py @@ -0,0 +1,141 @@ +import json +import os +import warnings +from typing import Any, Dict, List, Optional + +import requests +from crewai.tools import BaseTool + + +class PatronusEvalTool(BaseTool): + name: str = "Patronus Evaluation Tool" + evaluate_url: str = "https://api.patronus.ai/v1/evaluate" + evaluators: List[Dict[str, str]] = [] + criteria: List[Dict[str, str]] = [] + description: str = "" + + def __init__(self, **kwargs: Any): + super().__init__(**kwargs) + temp_evaluators, temp_criteria = self._init_run() + self.evaluators = temp_evaluators + self.criteria = temp_criteria + self.description = self._generate_description() + warnings.warn( + "You are allowing the agent to select the best evaluator and criteria when you use the `PatronusEvalTool`. If this is not intended then please use `PatronusPredefinedCriteriaEvalTool` instead." + ) + + def _init_run(self): + evaluators_set = json.loads( + requests.get( + "https://api.patronus.ai/v1/evaluators", + headers={ + "accept": "application/json", + "X-API-KEY": os.environ["PATRONUS_API_KEY"], + }, + ).text + )["evaluators"] + ids, evaluators = set(), [] + for ev in evaluators_set: + if not ev["deprecated"] and ev["id"] not in ids: + evaluators.append( + { + "id": ev["id"], + "name": ev["name"], + "description": ev["description"], + "aliases": ev["aliases"], + } + ) + ids.add(ev["id"]) + + criteria_set = json.loads( + requests.get( + "https://api.patronus.ai/v1/evaluator-criteria", + headers={ + "accept": "application/json", + "X-API-KEY": os.environ["PATRONUS_API_KEY"], + }, + ).text + )["evaluator_criteria"] + criteria = [] + for cr in criteria_set: + if cr["config"].get("pass_criteria", None): + if cr["config"].get("rubric", None): + criteria.append( + { + "evaluator": cr["evaluator_family"], + "name": cr["name"], + "pass_criteria": cr["config"]["pass_criteria"], + "rubric": cr["config"]["rubric"], + } + ) + else: + criteria.append( + { + "evaluator": cr["evaluator_family"], + "name": cr["name"], + "pass_criteria": cr["config"]["pass_criteria"], + } + ) + elif cr["description"]: + criteria.append( + { + "evaluator": cr["evaluator_family"], + "name": cr["name"], + "description": cr["description"], + } + ) + + return evaluators, criteria + + def _generate_description(self) -> str: + criteria = "\n".join([json.dumps(i) for i in self.criteria]) + return f"""This tool calls the Patronus Evaluation API that takes the following arguments: + 1. evaluated_model_input: str: The agent's task description in simple text + 2. evaluated_model_output: str: The agent's output of the task + 3. evaluated_model_retrieved_context: str: The agent's context + 4. evaluators: This is a list of dictionaries containing one of the following evaluators and the corresponding criteria. 
An example input for this field: [{{"evaluator": "Judge", "criteria": "patronus:is-code"}}] + + Evaluators: + {criteria} + + You must ONLY choose the most appropriate evaluator and criteria based on the "pass_criteria" or "description" fields for your evaluation task and nothing from outside of the options present.""" + + def _run( + self, + evaluated_model_input: Optional[str], + evaluated_model_output: Optional[str], + evaluated_model_retrieved_context: Optional[str], + evaluators: List[Dict[str, str]], + ) -> Any: + # Assert correct format of evaluators + evals = [] + for ev in evaluators: + evals.append( + { + "evaluator": ev["evaluator"].lower(), + "criteria": ev["name"] if "name" in ev else ev["criteria"], + } + ) + + data = { + "evaluated_model_input": evaluated_model_input, + "evaluated_model_output": evaluated_model_output, + "evaluated_model_retrieved_context": evaluated_model_retrieved_context, + "evaluators": evals, + } + + headers = { + "X-API-KEY": os.getenv("PATRONUS_API_KEY"), + "accept": "application/json", + "content-type": "application/json", + } + + response = requests.post( + self.evaluate_url, headers=headers, data=json.dumps(data) + ) + if response.status_code != 200: + raise Exception( + f"Failed to evaluate model input and output. Response status code: {response.status_code}. Reason: {response.text}" + ) + + return response.json() diff --git a/src/crewai_tools/tools/patronus_eval_tool/patronus_local_evaluator_tool.py b/src/crewai_tools/tools/patronus_eval_tool/patronus_local_evaluator_tool.py new file mode 100644 index 000000000..66781c593 --- /dev/null +++ b/src/crewai_tools/tools/patronus_eval_tool/patronus_local_evaluator_tool.py @@ -0,0 +1,90 @@ +from typing import Any, Type + +from crewai.tools import BaseTool +from patronus import Client +from pydantic import BaseModel, Field + + +class FixedLocalEvaluatorToolSchema(BaseModel): + evaluated_model_input: str = Field( + ..., description="The agent's task description in simple text" + ) + evaluated_model_output: str = Field( + ..., description="The agent's output of the task" + ) + evaluated_model_retrieved_context: str = Field( + ..., description="The agent's context" + ) + evaluated_model_gold_answer: str = Field( + ..., description="The agent's gold answer only if available" + ) + evaluator: str = Field(..., description="The registered local evaluator") + + +class PatronusLocalEvaluatorTool(BaseTool): + name: str = "Patronus Local Evaluator Tool" + evaluator: str = "The registered local evaluator" + evaluated_model_gold_answer: str = "The agent's gold answer" + description: str = "This tool is used to evaluate the model input and output using custom function evaluators." 
+ client: Any = None + args_schema: Type[BaseModel] = FixedLocalEvaluatorToolSchema + + class Config: + arbitrary_types_allowed = True + + def __init__( + self, + patronus_client: Client, + evaluator: str, + evaluated_model_gold_answer: str, + **kwargs: Any, + ): + super().__init__(**kwargs) + self.client = patronus_client + if evaluator: + self.evaluator = evaluator + self.evaluated_model_gold_answer = evaluated_model_gold_answer + self.description = f"This tool calls the Patronus Evaluation API with the following preset arguments:\n evaluator={evaluator}, evaluated_model_gold_answer={evaluated_model_gold_answer}" + self._generate_description() + print( + f"Updating judge evaluator, gold_answer to: {self.evaluator}, {self.evaluated_model_gold_answer}" + ) + + def _run( + self, + **kwargs: Any, + ) -> Any: + evaluated_model_input = kwargs.get("evaluated_model_input") + evaluated_model_output = kwargs.get("evaluated_model_output") + evaluated_model_retrieved_context = kwargs.get( + "evaluated_model_retrieved_context" + ) + evaluated_model_gold_answer = self.evaluated_model_gold_answer + evaluator = self.evaluator + + result = self.client.evaluate( + evaluator=evaluator, + evaluated_model_input=( + evaluated_model_input + if isinstance(evaluated_model_input, str) + else evaluated_model_input.get("description") + ), + evaluated_model_output=( + evaluated_model_output + if isinstance(evaluated_model_output, str) + else evaluated_model_output.get("description") + ), + evaluated_model_retrieved_context=( + evaluated_model_retrieved_context + if isinstance(evaluated_model_retrieved_context, str) + else evaluated_model_retrieved_context.get("description") + ), + evaluated_model_gold_answer=( + evaluated_model_gold_answer + if isinstance(evaluated_model_gold_answer, str) + else evaluated_model_gold_answer.get("description") + ), + tags={}, # Optional metadata, supports arbitrary kv pairs + ) + output = f"Evaluation result: {result.pass_}, Explanation: {result.explanation}" + return output diff --git a/src/crewai_tools/tools/patronus_eval_tool/patronus_predefined_criteria_eval_tool.py b/src/crewai_tools/tools/patronus_eval_tool/patronus_predefined_criteria_eval_tool.py new file mode 100644 index 000000000..cf906586d --- /dev/null +++ b/src/crewai_tools/tools/patronus_eval_tool/patronus_predefined_criteria_eval_tool.py @@ -0,0 +1,104 @@ +import json +import os +from typing import Any, Dict, List, Type + +import requests +from crewai.tools import BaseTool +from pydantic import BaseModel, Field + + +class FixedBaseToolSchema(BaseModel): + evaluated_model_input: Dict = Field( + ..., description="The agent's task description in simple text" + ) + evaluated_model_output: Dict = Field( + ..., description="The agent's output of the task" + ) + evaluated_model_retrieved_context: Dict = Field( + ..., description="The agent's context" + ) + evaluated_model_gold_answer: Dict = Field( + ..., description="The agent's gold answer only if available" + ) + evaluators: List[Dict[str, str]] = Field( + ..., + description="List of dictionaries containing the evaluator and criteria to evaluate the model input and output. An example input for this field: [{'evaluator': '[evaluator-from-user]', 'criteria': '[criteria-from-user]'}]", + ) + + +class PatronusPredefinedCriteriaEvalTool(BaseTool): + """ + PatronusPredefinedCriteriaEvalTool is a tool to automatically evaluate and score agent interactions using predefined criteria.
+ + Results are logged to the Patronus platform at app.patronus.ai + """ + + name: str = "Call Patronus API tool for evaluation of model inputs and outputs" + description: str = """This tool calls the Patronus Evaluation API that takes the following arguments:""" + evaluate_url: str = "https://api.patronus.ai/v1/evaluate" + args_schema: Type[BaseModel] = FixedBaseToolSchema + evaluators: List[Dict[str, str]] = [] + + def __init__(self, evaluators: List[Dict[str, str]], **kwargs: Any): + super().__init__(**kwargs) + if evaluators: + self.evaluators = evaluators + self.description = f"This tool calls the Patronus Evaluation API with the following preset evaluators and criteria:\n evaluators={evaluators}" + self._generate_description() + print(f"Updating judge criteria to: {self.evaluators}") + + def _run( + self, + **kwargs: Any, + ) -> Any: + evaluated_model_input = kwargs.get("evaluated_model_input") + evaluated_model_output = kwargs.get("evaluated_model_output") + evaluated_model_retrieved_context = kwargs.get( + "evaluated_model_retrieved_context" + ) + evaluated_model_gold_answer = kwargs.get("evaluated_model_gold_answer") + evaluators = self.evaluators + + headers = { + "X-API-KEY": os.getenv("PATRONUS_API_KEY"), + "accept": "application/json", + "content-type": "application/json", + } + + data = { + "evaluated_model_input": ( + evaluated_model_input + if isinstance(evaluated_model_input, str) + else evaluated_model_input.get("description") + ), + "evaluated_model_output": ( + evaluated_model_output + if isinstance(evaluated_model_output, str) + else evaluated_model_output.get("description") + ), + "evaluated_model_retrieved_context": ( + evaluated_model_retrieved_context + if isinstance(evaluated_model_retrieved_context, str) + else evaluated_model_retrieved_context.get("description") + ), + "evaluated_model_gold_answer": ( + evaluated_model_gold_answer + if isinstance(evaluated_model_gold_answer, str) + else evaluated_model_gold_answer.get("description") + ), + "evaluators": ( + evaluators + if isinstance(evaluators, list) + else evaluators.get("description") + ), + } + + response = requests.post( + self.evaluate_url, headers=headers, data=json.dumps(data) + ) + if response.status_code != 200: + raise Exception( + f"Failed to evaluate model input and output. Status code: {response.status_code}.
Reason: {response.text}" + ) + + return response.json() diff --git a/src/crewai_tools/tools/pdf_text_writing_tool/pdf_text_writing_tool.py b/src/crewai_tools/tools/pdf_text_writing_tool/pdf_text_writing_tool.py index 851593167..ad4d847b6 100644 --- a/src/crewai_tools/tools/pdf_text_writing_tool/pdf_text_writing_tool.py +++ b/src/crewai_tools/tools/pdf_text_writing_tool/pdf_text_writing_tool.py @@ -1,7 +1,9 @@ -from typing import Optional, Type -from pydantic import BaseModel, Field -from pypdf import PdfReader, PdfWriter, PageObject, ContentStream, NameObject, Font from pathlib import Path +from typing import Optional, Type + +from pydantic import BaseModel, Field +from pypdf import ContentStream, Font, NameObject, PageObject, PdfReader, PdfWriter + from crewai_tools.tools.rag.rag_tool import RagTool diff --git a/src/crewai_tools/tools/pg_seach_tool/pg_search_tool.py b/src/crewai_tools/tools/pg_seach_tool/pg_search_tool.py index dc75470a2..ec0207aa7 100644 --- a/src/crewai_tools/tools/pg_seach_tool/pg_search_tool.py +++ b/src/crewai_tools/tools/pg_seach_tool/pg_search_tool.py @@ -17,9 +17,7 @@ class PGSearchToolSchema(BaseModel): class PGSearchTool(RagTool): name: str = "Search a database's table content" - description: str = ( - "A tool that can be used to semantic search a query from a database table's content." - ) + description: str = "A tool that can be used to semantic search a query from a database table's content." args_schema: Type[BaseModel] = PGSearchToolSchema db_uri: str = Field(..., description="Mandatory database URI") diff --git a/src/crewai_tools/tools/scrape_element_from_website/scrape_element_from_website.py b/src/crewai_tools/tools/scrape_element_from_website/scrape_element_from_website.py index 14757d247..f1e215bf3 100644 --- a/src/crewai_tools/tools/scrape_element_from_website/scrape_element_from_website.py +++ b/src/crewai_tools/tools/scrape_element_from_website/scrape_element_from_website.py @@ -10,8 +10,6 @@ from pydantic import BaseModel, Field class FixedScrapeElementFromWebsiteToolSchema(BaseModel): """Input for ScrapeElementFromWebsiteTool.""" - pass - class ScrapeElementFromWebsiteToolSchema(FixedScrapeElementFromWebsiteToolSchema): """Input for ScrapeElementFromWebsiteTool.""" diff --git a/src/crewai_tools/tools/scrape_website_tool/scrape_website_tool.py b/src/crewai_tools/tools/scrape_website_tool/scrape_website_tool.py index 8cfc5d136..0e7e25ca6 100644 --- a/src/crewai_tools/tools/scrape_website_tool/scrape_website_tool.py +++ b/src/crewai_tools/tools/scrape_website_tool/scrape_website_tool.py @@ -11,8 +11,6 @@ from pydantic import BaseModel, Field class FixedScrapeWebsiteToolSchema(BaseModel): """Input for ScrapeWebsiteTool.""" - pass - class ScrapeWebsiteToolSchema(FixedScrapeWebsiteToolSchema): """Input for ScrapeWebsiteTool.""" diff --git a/src/crewai_tools/tools/scrapegraph_scrape_tool/scrapegraph_scrape_tool.py b/src/crewai_tools/tools/scrapegraph_scrape_tool/scrapegraph_scrape_tool.py index 9b5806b19..65f630c46 100644 --- a/src/crewai_tools/tools/scrapegraph_scrape_tool/scrapegraph_scrape_tool.py +++ b/src/crewai_tools/tools/scrapegraph_scrape_tool/scrapegraph_scrape_tool.py @@ -10,17 +10,14 @@ from scrapegraph_py.logger import sgai_logger class ScrapegraphError(Exception): """Base exception for Scrapegraph-related errors""" - pass class RateLimitError(ScrapegraphError): """Raised when API rate limits are exceeded""" - pass class FixedScrapegraphScrapeToolSchema(BaseModel): """Input for ScrapegraphScrapeTool when website_url is fixed.""" - pass class 
ScrapegraphScrapeToolSchema(FixedScrapegraphScrapeToolSchema): @@ -32,7 +29,7 @@ class ScrapegraphScrapeToolSchema(FixedScrapegraphScrapeToolSchema): description="Prompt to guide the extraction of content", ) - @validator('website_url') + @validator("website_url") def validate_url(cls, v): """Validate URL format""" try: @@ -41,13 +38,15 @@ class ScrapegraphScrapeToolSchema(FixedScrapegraphScrapeToolSchema): raise ValueError return v except Exception: - raise ValueError("Invalid URL format. URL must include scheme (http/https) and domain") + raise ValueError( + "Invalid URL format. URL must include scheme (http/https) and domain" + ) class ScrapegraphScrapeTool(BaseTool): """ A tool that uses Scrapegraph AI to intelligently scrape website content. - + Raises: ValueError: If API key is missing or URL format is invalid RateLimitError: If API rate limits are exceeded @@ -55,7 +54,9 @@ class ScrapegraphScrapeTool(BaseTool): """ name: str = "Scrapegraph website scraper" - description: str = "A tool that uses Scrapegraph AI to intelligently scrape website content." + description: str = ( + "A tool that uses Scrapegraph AI to intelligently scrape website content." + ) args_schema: Type[BaseModel] = ScrapegraphScrapeToolSchema website_url: Optional[str] = None user_prompt: Optional[str] = None @@ -72,8 +73,7 @@ class ScrapegraphScrapeTool(BaseTool): ): super().__init__(**kwargs) self.api_key = api_key or os.getenv("SCRAPEGRAPH_API_KEY") - self.enable_logging = enable_logging - + if not self.api_key: raise ValueError("Scrapegraph API key is required") @@ -82,7 +82,7 @@ class ScrapegraphScrapeTool(BaseTool): self.website_url = website_url self.description = f"A tool that uses Scrapegraph AI to intelligently scrape {website_url}'s content." self.args_schema = FixedScrapegraphScrapeToolSchema - + if user_prompt is not None: self.user_prompt = user_prompt @@ -98,14 +98,19 @@ class ScrapegraphScrapeTool(BaseTool): if not all([result.scheme, result.netloc]): raise ValueError except Exception: - raise ValueError("Invalid URL format. URL must include scheme (http/https) and domain") + raise ValueError( + "Invalid URL format. URL must include scheme (http/https) and domain" + ) def _run( self, **kwargs: Any, ) -> Any: website_url = kwargs.get("website_url", self.website_url) - user_prompt = kwargs.get("user_prompt", self.user_prompt) or "Extract the main content of the webpage" + user_prompt = ( + kwargs.get("user_prompt", self.user_prompt) + or "Extract the main content of the webpage" + ) if not website_url: raise ValueError("website_url is required") diff --git a/src/crewai_tools/tools/selenium_scraping_tool/selenium_scraping_tool.py b/src/crewai_tools/tools/selenium_scraping_tool/selenium_scraping_tool.py index d7a55428d..8099a06ab 100644 --- a/src/crewai_tools/tools/selenium_scraping_tool/selenium_scraping_tool.py +++ b/src/crewai_tools/tools/selenium_scraping_tool/selenium_scraping_tool.py @@ -17,33 +17,36 @@ class FixedSeleniumScrapingToolSchema(BaseModel): class SeleniumScrapingToolSchema(FixedSeleniumScrapingToolSchema): """Input for SeleniumScrapingTool.""" - website_url: str = Field(..., description="Mandatory website url to read the file. Must start with http:// or https://") + website_url: str = Field( + ..., + description="Mandatory website url to read the file. 
Must start with http:// or https://", + ) css_element: str = Field( ..., description="Mandatory css reference for element to scrape from the website", ) - @validator('website_url') + @validator("website_url") def validate_website_url(cls, v): if not v: raise ValueError("Website URL cannot be empty") - + if len(v) > 2048: # Common maximum URL length raise ValueError("URL is too long (max 2048 characters)") - - if not re.match(r'^https?://', v): + + if not re.match(r"^https?://", v): raise ValueError("URL must start with http:// or https://") - + try: result = urlparse(v) if not all([result.scheme, result.netloc]): raise ValueError("Invalid URL format") except Exception as e: raise ValueError(f"Invalid URL: {str(e)}") - - if re.search(r'\s', v): + + if re.search(r"\s", v): raise ValueError("URL cannot contain whitespace") - + return v @@ -130,11 +133,11 @@ class SeleniumScrapingTool(BaseTool): def _create_driver(self, url, cookie, wait_time): if not url: raise ValueError("URL cannot be empty") - + # Validate URL format - if not re.match(r'^https?://', url): + if not re.match(r"^https?://", url): raise ValueError("URL must start with http:// or https://") - + options = Options() options.add_argument("--headless") driver = self.driver(options=options) diff --git a/src/crewai_tools/tools/serpapi_tool/serpapi_base_tool.py b/src/crewai_tools/tools/serpapi_tool/serpapi_base_tool.py index 98491190c..895f3aadc 100644 --- a/src/crewai_tools/tools/serpapi_tool/serpapi_base_tool.py +++ b/src/crewai_tools/tools/serpapi_tool/serpapi_base_tool.py @@ -1,9 +1,10 @@ import os import re -from typing import Optional, Any, Union +from typing import Any, Optional, Union from crewai.tools import BaseTool + class SerpApiBaseTool(BaseTool): """Base class for SerpApi functionality with shared capabilities.""" diff --git a/src/crewai_tools/tools/serpapi_tool/serpapi_google_search_tool.py b/src/crewai_tools/tools/serpapi_tool/serpapi_google_search_tool.py index 199b7f5a2..86f40ef03 100644 --- a/src/crewai_tools/tools/serpapi_tool/serpapi_google_search_tool.py +++ b/src/crewai_tools/tools/serpapi_tool/serpapi_google_search_tool.py @@ -1,14 +1,20 @@ -from typing import Any, Type, Optional +from typing import Any, Optional, Type -import re from pydantic import BaseModel, Field from .serpapi_base_tool import SerpApiBaseTool -from serpapi import HTTPError +from urllib.error import HTTPError + class SerpApiGoogleSearchToolSchema(BaseModel): """Input for Google Search.""" + + search_query: str = Field( + ..., description="Mandatory search query you want to use to Google search." + ) + location: Optional[str] = Field( + None, description="Location you want the search to be performed in."
+    )
+
 
 class SerpApiGoogleSearchTool(SerpApiBaseTool):
     name: str = "Google Search"
@@ -22,19 +28,25 @@ class SerpApiGoogleSearchTool(SerpApiBaseTool):
         **kwargs: Any,
     ) -> Any:
         try:
-            results = self.client.search({
-                "q": kwargs.get("search_query"),
-                "location": kwargs.get("location"),
-            }).as_dict()
+            results = self.client.search(
+                {
+                    "q": kwargs.get("search_query"),
+                    "location": kwargs.get("location"),
+                }
+            ).as_dict()
             self._omit_fields(
-                results,
-                [r"search_metadata", r"search_parameters", r"serpapi_.+", r".+_token", r"displayed_link", r"pagination"]
+                results,
+                [
+                    r"search_metadata",
+                    r"search_parameters",
+                    r"serpapi_.+",
+                    r".+_token",
+                    r"displayed_link",
+                    r"pagination",
+                ],
             )
             return results
         except HTTPError as e:
             return f"An error occurred: {str(e)}. Some parameters may be invalid."
-    
-    
-    
\ No newline at end of file
diff --git a/src/crewai_tools/tools/serpapi_tool/serpapi_google_shopping_tool.py b/src/crewai_tools/tools/serpapi_tool/serpapi_google_shopping_tool.py
index b44b3a809..2dda9aa4c 100644
--- a/src/crewai_tools/tools/serpapi_tool/serpapi_google_shopping_tool.py
+++ b/src/crewai_tools/tools/serpapi_tool/serpapi_google_shopping_tool.py
@@ -1,14 +1,19 @@
-from typing import Any, Type, Optional
+from typing import Any, Optional, Type
 
-import re
 from pydantic import BaseModel, Field
 
 from .serpapi_base_tool import SerpApiBaseTool
-from serpapi import HTTPError
+from urllib.error import HTTPError
+
 
 class SerpApiGoogleShoppingToolSchema(BaseModel):
     """Input for Google Shopping."""
-    search_query: str = Field(..., description="Mandatory search query you want to use to Google shopping.")
-    location: Optional[str] = Field(None, description="Location you want the search to be performed in.")
+
+    search_query: str = Field(
+        ..., description="Mandatory search query you want to use to Google shopping."
+    )
+    location: Optional[str] = Field(
+        None, description="Location you want the search to be performed in."
+    )
 
 
 class SerpApiGoogleShoppingTool(SerpApiBaseTool):
@@ -23,20 +28,25 @@ class SerpApiGoogleShoppingTool(SerpApiBaseTool):
         **kwargs: Any,
     ) -> Any:
         try:
-            results = self.client.search({
-                "engine": "google_shopping",
-                "q": kwargs.get("search_query"),
-                "location": kwargs.get("location")
-            }).as_dict()
+            results = self.client.search(
+                {
+                    "engine": "google_shopping",
+                    "q": kwargs.get("search_query"),
+                    "location": kwargs.get("location"),
+                }
+            ).as_dict()
             self._omit_fields(
-                results,
-                [r"search_metadata", r"search_parameters", r"serpapi_.+", r"filters", r"pagination"]
+                results,
+                [
+                    r"search_metadata",
+                    r"search_parameters",
+                    r"serpapi_.+",
+                    r"filters",
+                    r"pagination",
+                ],
             )
             return results
         except HTTPError as e:
             return f"An error occurred: {str(e)}. Some parameters may be invalid."
-    
-    
-    
\ No newline at end of file
diff --git a/src/crewai_tools/tools/serper_dev_tool/serper_dev_tool.py b/src/crewai_tools/tools/serper_dev_tool/serper_dev_tool.py
index 5e8986c7e..2db347190 100644
--- a/src/crewai_tools/tools/serper_dev_tool/serper_dev_tool.py
+++ b/src/crewai_tools/tools/serper_dev_tool/serper_dev_tool.py
@@ -1,19 +1,19 @@
 import datetime
 import json
-import os
 import logging
+import os
 from typing import Any, Type
 
 import requests
 from crewai.tools import BaseTool
 from pydantic import BaseModel, Field
 
-
 logging.basicConfig(
     level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
 )
 logger = logging.getLogger(__name__)
 
+
 def _save_results_to_file(content: str) -> None:
     """Saves the search results to a file."""
     try:
@@ -35,7 +35,7 @@ class SerperDevToolSchema(BaseModel):
 
 
 class SerperDevTool(BaseTool):
-    name: str = "Search the internet"
+    name: str = "Search the internet with Serper"
     description: str = (
         "A tool that can be used to search the internet with a search_query. "
         "Supports different search types: 'search' (default), 'news'"
     )
diff --git a/src/crewai_tools/tools/serply_api_tool/serply_webpage_to_markdown_tool.py b/src/crewai_tools/tools/serply_api_tool/serply_webpage_to_markdown_tool.py
index e09a36fd9..4010236cc 100644
--- a/src/crewai_tools/tools/serply_api_tool/serply_webpage_to_markdown_tool.py
+++ b/src/crewai_tools/tools/serply_api_tool/serply_webpage_to_markdown_tool.py
@@ -18,9 +18,7 @@ class SerplyWebpageToMarkdownToolSchema(BaseModel):
 
 class SerplyWebpageToMarkdownTool(RagTool):
     name: str = "Webpage to Markdown"
-    description: str = (
-        "A tool to perform convert a webpage to markdown to make it easier for LLMs to understand"
-    )
+    description: str = "A tool to convert a webpage to markdown, making it easier for LLMs to understand"
     args_schema: Type[BaseModel] = SerplyWebpageToMarkdownToolSchema
     request_url: str = "https://api.serply.io/v1/request"
     proxy_location: Optional[str] = "US"
diff --git a/src/crewai_tools/tools/snowflake_search_tool/README.md b/src/crewai_tools/tools/snowflake_search_tool/README.md
new file mode 100644
index 000000000..fc0b845c3
--- /dev/null
+++ b/src/crewai_tools/tools/snowflake_search_tool/README.md
@@ -0,0 +1,155 @@
+# Snowflake Search Tool
+
+A tool for executing queries on Snowflake data warehouse with built-in connection pooling, retry logic, and async execution support.
+
+## Installation
+
+```bash
+uv sync --extra snowflake
+# or
+uv pip install "snowflake-connector-python>=3.5.0" "snowflake-sqlalchemy>=1.5.0" "cryptography>=41.0.0"
+# or
+pip install "snowflake-connector-python>=3.5.0" "snowflake-sqlalchemy>=1.5.0" "cryptography>=41.0.0"
+```
+
+## Quick Start
+
+```python
+import asyncio
+from crewai_tools import SnowflakeSearchTool, SnowflakeConfig
+
+# Create configuration
+config = SnowflakeConfig(
+    account="your_account",
+    user="your_username",
+    password="your_password",
+    warehouse="COMPUTE_WH",
+    database="your_database",
+    snowflake_schema="your_schema"  # Note: Uses snowflake_schema instead of schema
+)
+
+# Initialize tool
+tool = SnowflakeSearchTool(
+    config=config,
+    pool_size=5,
+    max_retries=3,
+    enable_caching=True
+)
+
+# Execute query
+async def main():
+    results = await tool._run(
+        query="SELECT * FROM your_table LIMIT 10",
+        timeout=300
+    )
+    print(f"Retrieved {len(results)} rows")
+
+if __name__ == "__main__":
+    asyncio.run(main())
+```
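+
+### Running Queries Concurrently
+
+Because `_run` is a coroutine and connections come from a shared pool, several queries can be dispatched at once with `asyncio.gather`. A minimal sketch (table names are placeholders):
+
+```python
+import asyncio
+
+async def run_reports(tool):
+    # Each query checks a connection out of the shared pool
+    return await asyncio.gather(
+        tool._run(query="SELECT COUNT(*) FROM orders"),
+        tool._run(query="SELECT COUNT(*) FROM customers"),
+    )
+```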
+
+## Features
+
+- ✨ Asynchronous query execution
+- 🚀 Connection pooling for better performance
+- 🔄 Automatic retries for transient failures
+- 💾 Query result caching (optional)
+- 🔒 Support for both password and key-pair authentication
+- 📝 Comprehensive error handling and logging
+
+## Configuration Options
+
+### SnowflakeConfig Parameters
+
+| Parameter | Required | Description |
+|-----------|----------|-------------|
+| account | Yes | Snowflake account identifier |
+| user | Yes | Snowflake username |
+| password | Yes* | Snowflake password |
+| private_key_path | No* | Path to private key file (alternative to password) |
+| warehouse | Yes | Snowflake warehouse name |
+| database | Yes | Default database |
+| snowflake_schema | Yes | Default schema |
+| role | No | Snowflake role |
+| session_parameters | No | Custom session parameters dict |
+
+\* Either password or private_key_path must be provided
+
+### Tool Parameters
+
+| Parameter | Default | Description |
+|-----------|---------|-------------|
+| pool_size | 5 | Number of connections in the pool |
+| max_retries | 3 | Maximum retry attempts for failed queries |
+| retry_delay | 1.0 | Delay between retries in seconds |
+| enable_caching | True | Enable/disable query result caching |
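+
+For example, a tool tuned for a heavier workload might be configured as follows (values are illustrative, not recommendations; `config` is the `SnowflakeConfig` from the Quick Start):
+
+```python
+tool = SnowflakeSearchTool(
+    config=config,
+    pool_size=10,          # allow more pooled connections
+    max_retries=5,         # keep retrying transient failures longer
+    retry_delay=2.0,       # base delay; the tool doubles it on each retry
+    enable_caching=False,  # always hit the warehouse for fresh results
+)
+```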
+
+## Advanced Usage
+
+### Using Key-Pair Authentication
+
+```python
+config = SnowflakeConfig(
+    account="your_account",
+    user="your_username",
+    private_key_path="/path/to/private_key.p8",
+    warehouse="your_warehouse",
+    database="your_database",
+    snowflake_schema="your_schema"
+)
+```
+
+### Custom Session Parameters
+
+```python
+config = SnowflakeConfig(
+    # ... other config parameters ...
+    session_parameters={
+        "QUERY_TAG": "my_app",
+        "TIMEZONE": "America/Los_Angeles"
+    }
+)
+```
+
+## Best Practices
+
+1. **Error Handling**: Always wrap query execution in try-except blocks
+2. **Logging**: Enable logging to track query execution and errors
+3. **Connection Management**: Use appropriate pool sizes for your workload
+4. **Timeouts**: Set reasonable query timeouts to prevent hanging
+5. **Security**: Use key-pair auth in production and never hardcode credentials
+
+## Example with Logging
+
+```python
+import logging
+
+# Configure logging
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+)
+logger = logging.getLogger(__name__)
+
+async def main():
+    try:
+        # ... tool initialization ...
+        results = await tool._run(query="SELECT * FROM table LIMIT 10")
+        logger.info(f"Query completed successfully. Retrieved {len(results)} rows")
+    except Exception as e:
+        logger.error(f"Query failed: {str(e)}")
+        raise
+```
+
+## Error Handling
+
+The tool automatically handles common Snowflake errors:
+- DatabaseError
+- OperationalError
+- ProgrammingError
+- Network timeouts
+- Connection issues
+
+Errors are logged and retried based on your retry configuration.
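+
+For instance, a caller can handle the errors that remain after all retries are exhausted (a sketch, reusing the logger configured in the example above):
+
+```python
+from snowflake.connector.errors import DatabaseError, OperationalError
+
+async def safe_query(tool, sql):
+    try:
+        return await tool._run(query=sql, timeout=60)
+    except (DatabaseError, OperationalError) as e:
+        # Raised only once max_retries attempts have all failed
+        logger.error(f"Query failed after retries: {e}")
+        return []
+```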
\ No newline at end of file
diff --git a/src/crewai_tools/tools/snowflake_search_tool/__init__.py b/src/crewai_tools/tools/snowflake_search_tool/__init__.py
new file mode 100644
index 000000000..abc1a45f5
--- /dev/null
+++ b/src/crewai_tools/tools/snowflake_search_tool/__init__.py
@@ -0,0 +1,11 @@
+from .snowflake_search_tool import (
+    SnowflakeConfig,
+    SnowflakeSearchTool,
+    SnowflakeSearchToolInput,
+)
+
+__all__ = [
+    "SnowflakeSearchTool",
+    "SnowflakeSearchToolInput",
+    "SnowflakeConfig",
+]
diff --git a/src/crewai_tools/tools/snowflake_search_tool/snowflake_search_tool.py b/src/crewai_tools/tools/snowflake_search_tool/snowflake_search_tool.py
new file mode 100644
index 000000000..75c671d21
--- /dev/null
+++ b/src/crewai_tools/tools/snowflake_search_tool/snowflake_search_tool.py
@@ -0,0 +1,201 @@
+import asyncio
+import logging
+from concurrent.futures import ThreadPoolExecutor
+from typing import Any, Dict, List, Optional, Type
+
+import snowflake.connector
+from crewai.tools.base_tool import BaseTool
+from cryptography.hazmat.backends import default_backend
+from cryptography.hazmat.primitives import serialization
+from pydantic import BaseModel, ConfigDict, Field, SecretStr
+from snowflake.connector.connection import SnowflakeConnection
+from snowflake.connector.errors import DatabaseError, OperationalError
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+# Cache for query results
+_query_cache = {}
+
+
+class SnowflakeConfig(BaseModel):
+    """Configuration for Snowflake connection."""
+
+    model_config = ConfigDict(protected_namespaces=())
+
+    account: str = Field(
+        ..., description="Snowflake account identifier", pattern=r"^[a-zA-Z0-9\-_]+$"
+    )
+    user: str = Field(..., description="Snowflake username")
+    password: Optional[SecretStr] = Field(None, description="Snowflake password")
+    private_key_path: Optional[str] = Field(
+        None, description="Path to private key file"
+    )
+    warehouse: Optional[str] = Field(None, description="Snowflake warehouse")
+    database: Optional[str] = Field(None, description="Default database")
+    snowflake_schema: Optional[str] = Field(None, description="Default schema")
+    role: Optional[str] = Field(None, description="Snowflake role")
+    session_parameters: Optional[Dict[str, Any]] = Field(
+        default_factory=dict, description="Session parameters"
+    )
+
+    @property
+    def has_auth(self) -> bool:
+        return bool(self.password or self.private_key_path)
+
+    def model_post_init(self, *args, **kwargs):
+        if not self.has_auth:
+            raise ValueError("Either password or private_key_path must be provided")
+
+
+class SnowflakeSearchToolInput(BaseModel):
+    """Input schema for SnowflakeSearchTool."""
+
+    model_config = ConfigDict(protected_namespaces=())
+
+    query: str = Field(..., description="SQL query or semantic search query to execute")
+    database: Optional[str] = Field(None, description="Override default database")
+    snowflake_schema: Optional[str] = Field(None, description="Override default schema")
+    timeout: Optional[int] = Field(300, description="Query timeout in seconds")
+
+
+class SnowflakeSearchTool(BaseTool):
+    """Tool for executing queries and semantic search on Snowflake."""
+
+    name: str = "Snowflake Database Search"
+    description: str = (
+        "Execute SQL queries or semantic search on Snowflake data warehouse. "
+        "Supports both raw SQL and natural language queries."
+    )
+    args_schema: Type[BaseModel] = SnowflakeSearchToolInput
+
+    # Define Pydantic fields
+    config: SnowflakeConfig = Field(
+        ..., description="Snowflake connection configuration"
+    )
+    pool_size: int = Field(default=5, description="Size of connection pool")
+    max_retries: int = Field(default=3, description="Maximum retry attempts")
+    retry_delay: float = Field(
+        default=1.0, description="Delay between retries in seconds"
+    )
+    enable_caching: bool = Field(
+        default=True, description="Enable query result caching"
+    )
+
+    model_config = ConfigDict(arbitrary_types_allowed=True)
+
+    def __init__(self, **data):
+        """Initialize SnowflakeSearchTool."""
+        super().__init__(**data)
+        self._connection_pool: List[SnowflakeConnection] = []
+        self._pool_lock = asyncio.Lock()
+        self._thread_pool = ThreadPoolExecutor(max_workers=self.pool_size)
+
+    async def _get_connection(self) -> SnowflakeConnection:
+        """Get a connection from the pool or create a new one."""
+        async with self._pool_lock:
+            if not self._connection_pool:
+                conn = self._create_connection()
+                self._connection_pool.append(conn)
+            return self._connection_pool.pop()
+
+    def _create_connection(self) -> SnowflakeConnection:
+        """Create a new Snowflake connection."""
+        conn_params = {
+            "account": self.config.account,
+            "user": self.config.user,
+            "warehouse": self.config.warehouse,
+            "database": self.config.database,
+            "schema": self.config.snowflake_schema,
+            "role": self.config.role,
+            "session_parameters": self.config.session_parameters,
+        }
+
+        if self.config.password:
+            conn_params["password"] = self.config.password.get_secret_value()
+        elif self.config.private_key_path:
+            with open(self.config.private_key_path, "rb") as key_file:
+                p_key = serialization.load_pem_private_key(
+                    key_file.read(), password=None, backend=default_backend()
+                )
+            conn_params["private_key"] = p_key
+
+        return snowflake.connector.connect(**conn_params)
+
+    def _get_cache_key(self, query: str, timeout: int) -> str:
+        """Generate a cache key for the query."""
+        return f"{self.config.account}:{self.config.database}:{self.config.snowflake_schema}:{query}:{timeout}"
+
+    async def _execute_query(
+        self, query: str, timeout: int = 300
+    ) -> List[Dict[str, Any]]:
+        """Execute a query with retries and return results."""
+        if self.enable_caching:
+            cache_key = self._get_cache_key(query, timeout)
+            if cache_key in _query_cache:
+                logger.info("Returning cached result")
+                return _query_cache[cache_key]
+
+        for attempt in range(self.max_retries):
+            try:
+                conn = await self._get_connection()
+                try:
+                    cursor = conn.cursor()
+                    cursor.execute(query, timeout=timeout)
+
+                    if not cursor.description:
+                        return []
+
+                    columns = [col[0] for col in cursor.description]
+                    results = [dict(zip(columns, row)) for row in cursor.fetchall()]
+
+                    if self.enable_caching:
+                        _query_cache[self._get_cache_key(query, timeout)] = results
+
+                    return results
+                finally:
+                    cursor.close()
+                    async with self._pool_lock:
+                        self._connection_pool.append(conn)
+            except (DatabaseError, OperationalError) as e:
+                if attempt == self.max_retries - 1:
+                    raise
+                await asyncio.sleep(self.retry_delay * (2**attempt))
+                logger.warning(f"Query failed, attempt {attempt + 1}: {str(e)}")
+                continue
+
+    async def _run(
+        self,
+        query: str,
+        database: Optional[str] = None,
+        snowflake_schema: Optional[str] = None,
+        timeout: int = 300,
+        **kwargs: Any,
+    ) -> Any:
+        """Execute the search query."""
+        try:
+            # Override database/schema if provided
+            if database:
+                await self._execute_query(f"USE DATABASE {database}")
+            if snowflake_schema:
+                await self._execute_query(f"USE SCHEMA {snowflake_schema}")
+
+            results = await self._execute_query(query, timeout)
+            return results
+        except Exception as e:
+            logger.error(f"Error executing query: {str(e)}")
+            raise
+
+    def __del__(self):
+        """Cleanup connections on deletion."""
+        try:
+            for conn in getattr(self, "_connection_pool", []):
+                try:
+                    conn.close()
+                except Exception:
+                    pass
+            if hasattr(self, "_thread_pool"):
+                self._thread_pool.shutdown()
+        except Exception:
+            pass
diff --git a/src/crewai_tools/tools/stagehand_tool/stagehand_tool.py b/src/crewai_tools/tools/stagehand_tool/stagehand_tool.py
index 07c76c8c3..37b414509 100644
--- a/src/crewai_tools/tools/stagehand_tool/stagehand_tool.py
+++ b/src/crewai_tools/tools/stagehand_tool/stagehand_tool.py
@@ -14,9 +14,8 @@ import os
 from functools import lru_cache
 from typing import Any, Dict, List, Optional, Type, Union
 
-from pydantic import BaseModel, Field
-
 from crewai.tools.base_tool import BaseTool
+from pydantic import BaseModel, Field
 
 # Set up logging
 logger = logging.getLogger(__name__)
@@ -25,6 +24,7 @@ logger = logging.getLogger(__name__)
 STAGEHAND_AVAILABLE = False
 try:
     import stagehand
+
     STAGEHAND_AVAILABLE = True
 except ImportError:
     pass  # Keep STAGEHAND_AVAILABLE as False
@@ -32,33 +32,45 @@ except ImportError:
 
 class StagehandResult(BaseModel):
     """Result from a Stagehand operation.
-    
+
     Attributes:
         success: Whether the operation completed successfully
         data: The result data from the operation
         error: Optional error message if the operation failed
     """
-    success: bool = Field(..., description="Whether the operation completed successfully")
-    data: Union[str, Dict, List] = Field(..., description="The result data from the operation")
-    error: Optional[str] = Field(None, description="Optional error message if the operation failed")
+
+    success: bool = Field(
+        ..., description="Whether the operation completed successfully"
+    )
+    data: Union[str, Dict, List] = Field(
+        ..., description="The result data from the operation"
+    )
+    error: Optional[str] = Field(
+        None, description="Optional error message if the operation failed"
+    )
 
 
 class StagehandToolConfig(BaseModel):
     """Configuration for the StagehandTool.
-    
+
     Attributes:
         api_key: OpenAI API key for Stagehand authentication
         timeout: Maximum time in seconds to wait for operations (default: 30)
         retry_attempts: Number of times to retry failed operations (default: 3)
     """
+
     api_key: str = Field(..., description="OpenAI API key for Stagehand authentication")
-    timeout: int = Field(30, description="Maximum time in seconds to wait for operations")
-    retry_attempts: int = Field(3, description="Number of times to retry failed operations")
+    timeout: int = Field(
+        30, description="Maximum time in seconds to wait for operations"
+    )
+    retry_attempts: int = Field(
+        3, description="Number of times to retry failed operations"
+    )
 
 
 class StagehandToolSchema(BaseModel):
     """Schema for the StagehandTool input parameters.
- + Examples: ```python # Using the 'act' API to click a button @@ -66,13 +78,13 @@ class StagehandToolSchema(BaseModel): api_method="act", instruction="Click the 'Sign In' button" ) - + # Using the 'extract' API to get text tool.run( api_method="extract", instruction="Get the text content of the main article" ) - + # Using the 'observe' API to monitor changes tool.run( api_method="observe", @@ -80,48 +92,49 @@ class StagehandToolSchema(BaseModel): ) ``` """ + api_method: str = Field( ..., description="The Stagehand API to use: 'act' for interactions, 'extract' for getting content, or 'observe' for monitoring changes", - pattern="^(act|extract|observe)$" + pattern="^(act|extract|observe)$", ) instruction: str = Field( ..., description="An atomic instruction for Stagehand to execute. Instructions should be simple and specific to increase reliability.", min_length=1, - max_length=500 + max_length=500, ) class StagehandTool(BaseTool): """A tool for using Stagehand's AI-powered web automation capabilities. - + This tool provides access to Stagehand's three core APIs: - act: Perform web interactions (e.g., clicking buttons, filling forms) - extract: Extract information from web pages (e.g., getting text content) - observe: Monitor web page changes (e.g., watching for updates) - + Each function takes atomic instructions to increase reliability. - + Required Environment Variables: OPENAI_API_KEY: API key for OpenAI (required by Stagehand) - + Examples: ```python tool = StagehandTool() - + # Perform a web interaction result = tool.run( api_method="act", instruction="Click the 'Sign In' button" ) - + # Extract content from a page content = tool.run( api_method="extract", instruction="Get the text content of the main article" ) - + # Monitor for changes changes = tool.run( api_method="observe", @@ -129,7 +142,7 @@ class StagehandTool(BaseTool): ) ``` """ - + name: str = "StagehandTool" description: str = ( "A tool that uses Stagehand's AI-powered web automation to interact with websites. " @@ -137,27 +150,29 @@ class StagehandTool(BaseTool): "Each instruction should be atomic (simple and specific) to increase reliability." ) args_schema: Type[BaseModel] = StagehandToolSchema - - def __init__(self, config: StagehandToolConfig | None = None, **kwargs: Any) -> None: + + def __init__( + self, config: StagehandToolConfig | None = None, **kwargs: Any + ) -> None: """Initialize the StagehandTool. - + Args: config: Optional configuration for the tool. If not provided, will attempt to use OPENAI_API_KEY from environment. **kwargs: Additional keyword arguments passed to the base class. - + Raises: ImportError: If the stagehand package is not installed ValueError: If no API key is provided via config or environment """ super().__init__(**kwargs) - + if not STAGEHAND_AVAILABLE: raise ImportError( "The 'stagehand' package is required to use this tool. " "Please install it with: pip install stagehand" ) - + # Use config if provided, otherwise try environment variable if config is not None: self.config = config @@ -168,24 +183,22 @@ class StagehandTool(BaseTool): "Either provide config with api_key or set OPENAI_API_KEY environment variable" ) self.config = StagehandToolConfig( - api_key=api_key, - timeout=30, - retry_attempts=3 + api_key=api_key, timeout=30, retry_attempts=3 ) - + @lru_cache(maxsize=100) def _cached_run(self, api_method: str, instruction: str) -> Any: """Execute a cached Stagehand command. - + This method is cached to improve performance for repeated operations. 
- + Args: api_method: The Stagehand API to use ('act', 'extract', or 'observe') instruction: An atomic instruction for Stagehand to execute - + Returns: The raw result from the Stagehand API call - + Raises: ValueError: If an invalid api_method is provided Exception: If the Stagehand API call fails @@ -193,23 +206,25 @@ class StagehandTool(BaseTool): logger.debug( "Cache operation - Method: %s, Instruction length: %d", api_method, - len(instruction) + len(instruction), ) - + # Initialize Stagehand with configuration logger.info( "Initializing Stagehand (timeout=%ds, retries=%d)", self.config.timeout, - self.config.retry_attempts + self.config.retry_attempts, ) st = stagehand.Stagehand( api_key=self.config.api_key, timeout=self.config.timeout, - retry_attempts=self.config.retry_attempts + retry_attempts=self.config.retry_attempts, ) - + # Call the appropriate Stagehand API based on the method - logger.info("Executing %s operation with instruction: %s", api_method, instruction[:100]) + logger.info( + "Executing %s operation with instruction: %s", api_method, instruction[:100] + ) try: if api_method == "act": result = st.act(instruction) @@ -219,28 +234,27 @@ class StagehandTool(BaseTool): result = st.observe(instruction) else: raise ValueError(f"Unknown api_method: {api_method}") - - + logger.info("Successfully executed %s operation", api_method) return result - + except Exception as e: logger.warning( "Operation failed (method=%s, error=%s), will be retried on next attempt", api_method, - str(e) + str(e), ) raise def _run(self, api_method: str, instruction: str, **kwargs: Any) -> StagehandResult: """Execute a Stagehand command using the specified API method. - + Args: api_method: The Stagehand API to use ('act', 'extract', or 'observe') instruction: An atomic instruction for Stagehand to execute **kwargs: Additional keyword arguments passed to the Stagehand API - - Returns: + + Returns: StagehandResult containing the operation result and status """ try: @@ -249,56 +263,36 @@ class StagehandTool(BaseTool): "Starting operation - Method: %s, Instruction length: %d, Args: %s", api_method, len(instruction), - kwargs + kwargs, ) - + # Use cached execution result = self._cached_run(api_method, instruction) logger.info("Operation completed successfully") return StagehandResult(success=True, data=result) - + except stagehand.AuthenticationError as e: logger.error( - "Authentication failed - Method: %s, Error: %s", - api_method, - str(e) + "Authentication failed - Method: %s, Error: %s", api_method, str(e) ) return StagehandResult( - success=False, - data={}, - error=f"Authentication failed: {str(e)}" + success=False, data={}, error=f"Authentication failed: {str(e)}" ) except stagehand.APIError as e: - logger.error( - "API error - Method: %s, Error: %s", - api_method, - str(e) - ) - return StagehandResult( - success=False, - data={}, - error=f"API error: {str(e)}" - ) + logger.error("API error - Method: %s, Error: %s", api_method, str(e)) + return StagehandResult(success=False, data={}, error=f"API error: {str(e)}") except stagehand.BrowserError as e: - logger.error( - "Browser error - Method: %s, Error: %s", - api_method, - str(e) - ) + logger.error("Browser error - Method: %s, Error: %s", api_method, str(e)) return StagehandResult( - success=False, - data={}, - error=f"Browser error: {str(e)}" + success=False, data={}, error=f"Browser error: {str(e)}" ) except Exception as e: logger.error( "Unexpected error - Method: %s, Error type: %s, Message: %s", api_method, type(e).__name__, - str(e) + str(e), 
) return StagehandResult( - success=False, - data={}, - error=f"Unexpected error: {str(e)}" + success=False, data={}, error=f"Unexpected error: {str(e)}" ) diff --git a/src/crewai_tools/tools/vision_tool/vision_tool.py b/src/crewai_tools/tools/vision_tool/vision_tool.py index 4fbc1df0e..594be0b22 100644 --- a/src/crewai_tools/tools/vision_tool/vision_tool.py +++ b/src/crewai_tools/tools/vision_tool/vision_tool.py @@ -1,30 +1,36 @@ import base64 -from typing import Type, Optional from pathlib import Path +from typing import Optional, Type + from crewai.tools import BaseTool from openai import OpenAI from pydantic import BaseModel, validator + class ImagePromptSchema(BaseModel): """Input for Vision Tool.""" + image_path_url: str = "The image path or URL." @validator("image_path_url") def validate_image_path_url(cls, v: str) -> str: if v.startswith("http"): return v - + path = Path(v) if not path.exists(): raise ValueError(f"Image file does not exist: {v}") - + # Validate supported formats valid_extensions = {".jpg", ".jpeg", ".png", ".gif", ".webp"} if path.suffix.lower() not in valid_extensions: - raise ValueError(f"Unsupported image format. Supported formats: {valid_extensions}") - + raise ValueError( + f"Unsupported image format. Supported formats: {valid_extensions}" + ) + return v + class VisionTool(BaseTool): name: str = "Vision Tool" description: str = ( @@ -45,10 +51,10 @@ class VisionTool(BaseTool): image_path_url = kwargs.get("image_path_url") if not image_path_url: return "Image Path or URL is required." - + # Validate input using Pydantic ImagePromptSchema(image_path_url=image_path_url) - + if image_path_url.startswith("http"): image_data = image_path_url else: @@ -68,12 +74,12 @@ class VisionTool(BaseTool): { "type": "image_url", "image_url": {"url": image_data}, - } + }, ], } ], max_tokens=300, - ) + ) return response.choices[0].message.content diff --git a/src/crewai_tools/tools/weaviate_tool/vector_search.py b/src/crewai_tools/tools/weaviate_tool/vector_search.py index 14e10d7c5..53f641272 100644 --- a/src/crewai_tools/tools/weaviate_tool/vector_search.py +++ b/src/crewai_tools/tools/weaviate_tool/vector_search.py @@ -15,9 +15,8 @@ except ImportError: Vectorizers = Any Auth = Any -from pydantic import BaseModel, Field - from crewai.tools import BaseTool +from pydantic import BaseModel, Field class WeaviateToolSchema(BaseModel): diff --git a/src/crewai_tools/tools/website_search/website_search_tool.py b/src/crewai_tools/tools/website_search/website_search_tool.py index faa1a02e8..842462546 100644 --- a/src/crewai_tools/tools/website_search/website_search_tool.py +++ b/src/crewai_tools/tools/website_search/website_search_tool.py @@ -25,9 +25,7 @@ class WebsiteSearchToolSchema(FixedWebsiteSearchToolSchema): class WebsiteSearchTool(RagTool): name: str = "Search in a specific website" - description: str = ( - "A tool that can be used to semantic search a query from a specific URL content." - ) + description: str = "A tool that can be used to semantic search a query from a specific URL content." 
args_schema: Type[BaseModel] = WebsiteSearchToolSchema def __init__(self, website: Optional[str] = None, **kwargs): diff --git a/src/crewai_tools/tools/youtube_channel_search_tool/youtube_channel_search_tool.py b/src/crewai_tools/tools/youtube_channel_search_tool/youtube_channel_search_tool.py index b0c6209f1..81ecc30c3 100644 --- a/src/crewai_tools/tools/youtube_channel_search_tool/youtube_channel_search_tool.py +++ b/src/crewai_tools/tools/youtube_channel_search_tool/youtube_channel_search_tool.py @@ -25,9 +25,7 @@ class YoutubeChannelSearchToolSchema(FixedYoutubeChannelSearchToolSchema): class YoutubeChannelSearchTool(RagTool): name: str = "Search a Youtube Channels content" - description: str = ( - "A tool that can be used to semantic search a query from a Youtube Channels content." - ) + description: str = "A tool that can be used to semantic search a query from a Youtube Channels content." args_schema: Type[BaseModel] = YoutubeChannelSearchToolSchema def __init__(self, youtube_channel_handle: Optional[str] = None, **kwargs): diff --git a/src/crewai_tools/tools/youtube_video_search_tool/youtube_video_search_tool.py b/src/crewai_tools/tools/youtube_video_search_tool/youtube_video_search_tool.py index 6852fafb4..1ad8434c8 100644 --- a/src/crewai_tools/tools/youtube_video_search_tool/youtube_video_search_tool.py +++ b/src/crewai_tools/tools/youtube_video_search_tool/youtube_video_search_tool.py @@ -25,9 +25,7 @@ class YoutubeVideoSearchToolSchema(FixedYoutubeVideoSearchToolSchema): class YoutubeVideoSearchTool(RagTool): name: str = "Search a Youtube Video content" - description: str = ( - "A tool that can be used to semantic search a query from a Youtube Video content." - ) + description: str = "A tool that can be used to semantic search a query from a Youtube Video content." args_schema: Type[BaseModel] = YoutubeVideoSearchToolSchema def __init__(self, youtube_video_url: Optional[str] = None, **kwargs): diff --git a/tests/base_tool_test.py b/tests/base_tool_test.py index 4a4e40783..e6f4f127d 100644 --- a/tests/base_tool_test.py +++ b/tests/base_tool_test.py @@ -1,69 +1,104 @@ from typing import Callable + from crewai.tools import BaseTool, tool from crewai.tools.base_tool import to_langchain + def test_creating_a_tool_using_annotation(): - @tool("Name of my tool") - def my_tool(question: str) -> str: - """Clear description for what this tool is useful for, you agent will need this information to use it.""" - return question + @tool("Name of my tool") + def my_tool(question: str) -> str: + """Clear description for what this tool is useful for, you agent will need this information to use it.""" + return question - # Assert all the right attributes were defined - assert my_tool.name == "Name of my tool" - assert my_tool.description == "Tool Name: Name of my tool\nTool Arguments: {'question': {'description': None, 'type': 'str'}}\nTool Description: Clear description for what this tool is useful for, you agent will need this information to use it." - assert my_tool.args_schema.schema()["properties"] == {'question': {'title': 'Question', 'type': 'string'}} - assert my_tool.func("What is the meaning of life?") == "What is the meaning of life?" + # Assert all the right attributes were defined + assert my_tool.name == "Name of my tool" + assert ( + my_tool.description + == "Tool Name: Name of my tool\nTool Arguments: {'question': {'description': None, 'type': 'str'}}\nTool Description: Clear description for what this tool is useful for, you agent will need this information to use it." 
+ ) + assert my_tool.args_schema.schema()["properties"] == { + "question": {"title": "Question", "type": "string"} + } + assert ( + my_tool.func("What is the meaning of life?") == "What is the meaning of life?" + ) + + # Assert the langchain tool conversion worked as expected + converted_tool = to_langchain([my_tool])[0] + assert converted_tool.name == "Name of my tool" + assert ( + converted_tool.description + == "Tool Name: Name of my tool\nTool Arguments: {'question': {'description': None, 'type': 'str'}}\nTool Description: Clear description for what this tool is useful for, you agent will need this information to use it." + ) + assert converted_tool.args_schema.schema()["properties"] == { + "question": {"title": "Question", "type": "string"} + } + assert ( + converted_tool.func("What is the meaning of life?") + == "What is the meaning of life?" + ) - # Assert the langchain tool conversion worked as expected - converted_tool = to_langchain([my_tool])[0] - assert converted_tool.name == "Name of my tool" - assert converted_tool.description == "Tool Name: Name of my tool\nTool Arguments: {'question': {'description': None, 'type': 'str'}}\nTool Description: Clear description for what this tool is useful for, you agent will need this information to use it." - assert converted_tool.args_schema.schema()["properties"] == {'question': {'title': 'Question', 'type': 'string'}} - assert converted_tool.func("What is the meaning of life?") == "What is the meaning of life?" def test_creating_a_tool_using_baseclass(): - class MyCustomTool(BaseTool): - name: str = "Name of my tool" - description: str = "Clear description for what this tool is useful for, you agent will need this information to use it." + class MyCustomTool(BaseTool): + name: str = "Name of my tool" + description: str = "Clear description for what this tool is useful for, you agent will need this information to use it." - def _run(self, question: str) -> str: - return question + def _run(self, question: str) -> str: + return question - my_tool = MyCustomTool() - # Assert all the right attributes were defined - assert my_tool.name == "Name of my tool" - assert my_tool.description == "Tool Name: Name of my tool\nTool Arguments: {'question': {'description': None, 'type': 'str'}}\nTool Description: Clear description for what this tool is useful for, you agent will need this information to use it." - assert my_tool.args_schema.schema()["properties"] == {'question': {'title': 'Question', 'type': 'string'}} - assert my_tool._run("What is the meaning of life?") == "What is the meaning of life?" + my_tool = MyCustomTool() + # Assert all the right attributes were defined + assert my_tool.name == "Name of my tool" + assert ( + my_tool.description + == "Tool Name: Name of my tool\nTool Arguments: {'question': {'description': None, 'type': 'str'}}\nTool Description: Clear description for what this tool is useful for, you agent will need this information to use it." + ) + assert my_tool.args_schema.schema()["properties"] == { + "question": {"title": "Question", "type": "string"} + } + assert ( + my_tool._run("What is the meaning of life?") == "What is the meaning of life?" 
+ ) + + # Assert the langchain tool conversion worked as expected + converted_tool = to_langchain([my_tool])[0] + assert converted_tool.name == "Name of my tool" + assert ( + converted_tool.description + == "Tool Name: Name of my tool\nTool Arguments: {'question': {'description': None, 'type': 'str'}}\nTool Description: Clear description for what this tool is useful for, you agent will need this information to use it." + ) + assert converted_tool.args_schema.schema()["properties"] == { + "question": {"title": "Question", "type": "string"} + } + assert ( + converted_tool.invoke({"question": "What is the meaning of life?"}) + == "What is the meaning of life?" + ) - # Assert the langchain tool conversion worked as expected - converted_tool = to_langchain([my_tool])[0] - assert converted_tool.name == "Name of my tool" - assert converted_tool.description == "Tool Name: Name of my tool\nTool Arguments: {'question': {'description': None, 'type': 'str'}}\nTool Description: Clear description for what this tool is useful for, you agent will need this information to use it." - assert converted_tool.args_schema.schema()["properties"] == {'question': {'title': 'Question', 'type': 'string'}} - assert converted_tool.invoke({"question": "What is the meaning of life?"}) == "What is the meaning of life?" def test_setting_cache_function(): - class MyCustomTool(BaseTool): - name: str = "Name of my tool" - description: str = "Clear description for what this tool is useful for, you agent will need this information to use it." - cache_function: Callable = lambda: False + class MyCustomTool(BaseTool): + name: str = "Name of my tool" + description: str = "Clear description for what this tool is useful for, you agent will need this information to use it." + cache_function: Callable = lambda: False - def _run(self, question: str) -> str: - return question + def _run(self, question: str) -> str: + return question + + my_tool = MyCustomTool() + # Assert all the right attributes were defined + assert my_tool.cache_function() == False - my_tool = MyCustomTool() - # Assert all the right attributes were defined - assert my_tool.cache_function() == False def test_default_cache_function_is_true(): - class MyCustomTool(BaseTool): - name: str = "Name of my tool" - description: str = "Clear description for what this tool is useful for, you agent will need this information to use it." + class MyCustomTool(BaseTool): + name: str = "Name of my tool" + description: str = "Clear description for what this tool is useful for, you agent will need this information to use it." 
- def _run(self, question: str) -> str: - return question + def _run(self, question: str) -> str: + return question - my_tool = MyCustomTool() - # Assert all the right attributes were defined - assert my_tool.cache_function() == True \ No newline at end of file + my_tool = MyCustomTool() + # Assert all the right attributes were defined + assert my_tool.cache_function() == True diff --git a/tests/file_read_tool_test.py b/tests/file_read_tool_test.py index 4646df24c..5957f863b 100644 --- a/tests/file_read_tool_test.py +++ b/tests/file_read_tool_test.py @@ -1,7 +1,8 @@ import os -import pytest + from crewai_tools import FileReadTool + def test_file_read_tool_constructor(): """Test FileReadTool initialization with file_path.""" # Create a temporary test file @@ -18,6 +19,7 @@ def test_file_read_tool_constructor(): # Clean up os.remove(test_file) + def test_file_read_tool_run(): """Test FileReadTool _run method with file_path at runtime.""" # Create a temporary test file @@ -34,6 +36,7 @@ def test_file_read_tool_run(): # Clean up os.remove(test_file) + def test_file_read_tool_error_handling(): """Test FileReadTool error handling.""" # Test missing file path @@ -58,6 +61,7 @@ def test_file_read_tool_error_handling(): os.chmod(test_file, 0o666) # Restore permissions to delete os.remove(test_file) + def test_file_read_tool_constructor_and_run(): """Test FileReadTool using both constructor and runtime file paths.""" # Create two test files diff --git a/tests/it/tools/__init__.py b/tests/it/tools/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/it/tools/conftest.py b/tests/it/tools/conftest.py new file mode 100644 index 000000000..a633c22c7 --- /dev/null +++ b/tests/it/tools/conftest.py @@ -0,0 +1,21 @@ +import pytest + + +def pytest_configure(config): + """Register custom markers.""" + config.addinivalue_line("markers", "integration: mark test as an integration test") + config.addinivalue_line("markers", "asyncio: mark test as an async test") + + # Set the asyncio loop scope through ini configuration + config.inicfg["asyncio_mode"] = "auto" + + +@pytest.fixture(scope="function") +def event_loop(): + """Create an instance of the default event loop for each test case.""" + import asyncio + + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + yield loop + loop.close() diff --git a/tests/it/tools/snowflake_search_tool_test.py b/tests/it/tools/snowflake_search_tool_test.py new file mode 100644 index 000000000..70dc07953 --- /dev/null +++ b/tests/it/tools/snowflake_search_tool_test.py @@ -0,0 +1,219 @@ +import asyncio +import json +from decimal import Decimal + +import pytest +from snowflake.connector.errors import DatabaseError, OperationalError + +from crewai_tools import SnowflakeConfig, SnowflakeSearchTool + +# Test Data +MENU_ITEMS = [ + (10001, "Ice Cream", "Freezing Point", "Lemonade", "Beverage", "Cold Option", 1, 4), + ( + 10002, + "Ice Cream", + "Freezing Point", + "Vanilla Ice Cream", + "Dessert", + "Ice Cream", + 2, + 6, + ), +] + +INVALID_QUERIES = [ + ("SELECT * FROM nonexistent_table", "relation 'nonexistent_table' does not exist"), + ("SELECT invalid_column FROM menu", "invalid identifier 'invalid_column'"), + ("INVALID SQL QUERY", "SQL compilation error"), +] + + +# Integration Test Fixtures +@pytest.fixture +def config(): + """Create a Snowflake configuration with test credentials.""" + return SnowflakeConfig( + account="lwyhjun-wx11931", + user="crewgitci", + password="crewaiT00ls_publicCIpass123", + warehouse="COMPUTE_WH", + 
database="tasty_bytes_sample_data", + snowflake_schema="raw_pos", + ) + + +@pytest.fixture +def snowflake_tool(config): + """Create a SnowflakeSearchTool instance.""" + return SnowflakeSearchTool(config=config) + + +# Integration Tests with Real Snowflake Connection +@pytest.mark.integration +@pytest.mark.asyncio +@pytest.mark.parametrize( + "menu_id,expected_type,brand,item_name,category,subcategory,cost,price", MENU_ITEMS +) +async def test_menu_items( + snowflake_tool, + menu_id, + expected_type, + brand, + item_name, + category, + subcategory, + cost, + price, +): + """Test menu items with parameterized data for multiple test cases.""" + results = await snowflake_tool._run( + query=f"SELECT * FROM menu WHERE menu_id = {menu_id}" + ) + assert len(results) == 1 + menu_item = results[0] + + # Validate all fields + assert menu_item["MENU_ID"] == menu_id + assert menu_item["MENU_TYPE"] == expected_type + assert menu_item["TRUCK_BRAND_NAME"] == brand + assert menu_item["MENU_ITEM_NAME"] == item_name + assert menu_item["ITEM_CATEGORY"] == category + assert menu_item["ITEM_SUBCATEGORY"] == subcategory + assert menu_item["COST_OF_GOODS_USD"] == cost + assert menu_item["SALE_PRICE_USD"] == price + + # Validate health metrics JSON structure + health_metrics = json.loads(menu_item["MENU_ITEM_HEALTH_METRICS_OBJ"]) + assert "menu_item_health_metrics" in health_metrics + metrics = health_metrics["menu_item_health_metrics"][0] + assert "ingredients" in metrics + assert isinstance(metrics["ingredients"], list) + assert all(isinstance(ingredient, str) for ingredient in metrics["ingredients"]) + assert metrics["is_dairy_free_flag"] in ["Y", "N"] + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_menu_categories_aggregation(snowflake_tool): + """Test complex aggregation query on menu categories with detailed validations.""" + results = await snowflake_tool._run( + query=""" + SELECT + item_category, + COUNT(*) as item_count, + AVG(sale_price_usd) as avg_price, + SUM(sale_price_usd - cost_of_goods_usd) as total_margin, + COUNT(DISTINCT menu_type) as menu_type_count, + MIN(sale_price_usd) as min_price, + MAX(sale_price_usd) as max_price + FROM menu + GROUP BY item_category + HAVING COUNT(*) > 1 + ORDER BY item_count DESC + """ + ) + + assert len(results) > 0 + for category in results: + # Basic presence checks + assert all( + key in category + for key in [ + "ITEM_CATEGORY", + "ITEM_COUNT", + "AVG_PRICE", + "TOTAL_MARGIN", + "MENU_TYPE_COUNT", + "MIN_PRICE", + "MAX_PRICE", + ] + ) + + # Value validations + assert category["ITEM_COUNT"] > 1 # Due to HAVING clause + assert category["MIN_PRICE"] <= category["MAX_PRICE"] + assert category["AVG_PRICE"] >= category["MIN_PRICE"] + assert category["AVG_PRICE"] <= category["MAX_PRICE"] + assert category["MENU_TYPE_COUNT"] >= 1 + assert isinstance(category["TOTAL_MARGIN"], (float, Decimal)) + + +@pytest.mark.integration +@pytest.mark.asyncio +@pytest.mark.parametrize("invalid_query,expected_error", INVALID_QUERIES) +async def test_invalid_queries(snowflake_tool, invalid_query, expected_error): + """Test error handling for invalid queries.""" + with pytest.raises((DatabaseError, OperationalError)) as exc_info: + await snowflake_tool._run(query=invalid_query) + assert expected_error.lower() in str(exc_info.value).lower() + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_concurrent_queries(snowflake_tool): + """Test handling of concurrent queries.""" + queries = [ + "SELECT COUNT(*) FROM menu", + "SELECT COUNT(DISTINCT menu_type) FROM 
menu", + "SELECT COUNT(DISTINCT item_category) FROM menu", + ] + + tasks = [snowflake_tool._run(query=query) for query in queries] + results = await asyncio.gather(*tasks) + + assert len(results) == 3 + assert all(isinstance(result, list) for result in results) + assert all(len(result) == 1 for result in results) + assert all(isinstance(result[0], dict) for result in results) + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_query_timeout(snowflake_tool): + """Test query timeout handling with a complex query.""" + with pytest.raises((DatabaseError, OperationalError)) as exc_info: + await snowflake_tool._run( + query=""" + WITH RECURSIVE numbers AS ( + SELECT 1 as n + UNION ALL + SELECT n + 1 + FROM numbers + WHERE n < 1000000 + ) + SELECT COUNT(*) FROM numbers + """ + ) + assert ( + "timeout" in str(exc_info.value).lower() + or "execution time" in str(exc_info.value).lower() + ) + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_caching_behavior(snowflake_tool): + """Test query caching behavior and performance.""" + query = "SELECT * FROM menu LIMIT 5" + + # First execution + start_time = asyncio.get_event_loop().time() + results1 = await snowflake_tool._run(query=query) + first_duration = asyncio.get_event_loop().time() - start_time + + # Second execution (should be cached) + start_time = asyncio.get_event_loop().time() + results2 = await snowflake_tool._run(query=query) + second_duration = asyncio.get_event_loop().time() - start_time + + # Verify results + assert results1 == results2 + assert len(results1) == 5 + assert second_duration < first_duration + + # Verify cache invalidation with different query + different_query = "SELECT * FROM menu LIMIT 10" + different_results = await snowflake_tool._run(query=different_query) + assert len(different_results) == 10 + assert different_results != results1 diff --git a/tests/spider_tool_test.py b/tests/spider_tool_test.py index 264394777..7f5613fe6 100644 --- a/tests/spider_tool_test.py +++ b/tests/spider_tool_test.py @@ -1,5 +1,7 @@ +from crewai import Agent, Crew, Task + from crewai_tools.tools.spider_tool.spider_tool import SpiderTool -from crewai import Agent, Task, Crew + def test_spider_tool(): spider_tool = SpiderTool() @@ -10,38 +12,35 @@ def test_spider_tool(): backstory="An expert web researcher that uses the web extremely well", tools=[spider_tool], verbose=True, - cache=False + cache=False, ) choose_between_scrape_crawl = Task( description="Scrape the page of spider.cloud and return a summary of how fast it is", expected_output="spider.cloud is a fast scraping and crawling tool", - agent=searcher + agent=searcher, ) return_metadata = Task( description="Scrape https://spider.cloud with a limit of 1 and enable metadata", expected_output="Metadata and 10 word summary of spider.cloud", - agent=searcher + agent=searcher, ) css_selector = Task( description="Scrape one page of spider.cloud with the `body > div > main > section.grid.md\:grid-cols-2.gap-10.place-items-center.md\:max-w-screen-xl.mx-auto.pb-8.pt-20 > div:nth-child(1) > h1` CSS selector", expected_output="The content of the element with the css selector body > div > main > section.grid.md\:grid-cols-2.gap-10.place-items-center.md\:max-w-screen-xl.mx-auto.pb-8.pt-20 > div:nth-child(1) > h1", - agent=searcher + agent=searcher, ) crew = Crew( agents=[searcher], - tasks=[ - choose_between_scrape_crawl, - return_metadata, - css_selector - ], - verbose=True + tasks=[choose_between_scrape_crawl, return_metadata, css_selector], + verbose=True, ) 
crew.kickoff() + if __name__ == "__main__": test_spider_tool() diff --git a/tests/tools/snowflake_search_tool_test.py b/tests/tools/snowflake_search_tool_test.py new file mode 100644 index 000000000..d4851b8ab --- /dev/null +++ b/tests/tools/snowflake_search_tool_test.py @@ -0,0 +1,103 @@ +import asyncio +from unittest.mock import MagicMock, patch + +import pytest + +from crewai_tools import SnowflakeConfig, SnowflakeSearchTool + + +# Unit Test Fixtures +@pytest.fixture +def mock_snowflake_connection(): + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_cursor.description = [("col1",), ("col2",)] + mock_cursor.fetchall.return_value = [(1, "value1"), (2, "value2")] + mock_cursor.execute.return_value = None + mock_conn.cursor.return_value = mock_cursor + return mock_conn + + +@pytest.fixture +def mock_config(): + return SnowflakeConfig( + account="test_account", + user="test_user", + password="test_password", + warehouse="test_warehouse", + database="test_db", + snowflake_schema="test_schema", + ) + + +@pytest.fixture +def snowflake_tool(mock_config): + with patch("snowflake.connector.connect") as mock_connect: + tool = SnowflakeSearchTool(config=mock_config) + yield tool + + +# Unit Tests +@pytest.mark.asyncio +async def test_successful_query_execution(snowflake_tool, mock_snowflake_connection): + with patch.object(snowflake_tool, "_create_connection") as mock_create_conn: + mock_create_conn.return_value = mock_snowflake_connection + + results = await snowflake_tool._run( + query="SELECT * FROM test_table", timeout=300 + ) + + assert len(results) == 2 + assert results[0]["col1"] == 1 + assert results[0]["col2"] == "value1" + mock_snowflake_connection.cursor.assert_called_once() + + +@pytest.mark.asyncio +async def test_connection_pooling(snowflake_tool, mock_snowflake_connection): + with patch.object(snowflake_tool, "_create_connection") as mock_create_conn: + mock_create_conn.return_value = mock_snowflake_connection + + # Execute multiple queries + await asyncio.gather( + snowflake_tool._run("SELECT 1"), + snowflake_tool._run("SELECT 2"), + snowflake_tool._run("SELECT 3"), + ) + + # Should reuse connections from pool + assert mock_create_conn.call_count <= snowflake_tool.pool_size + + +@pytest.mark.asyncio +async def test_cleanup_on_deletion(snowflake_tool, mock_snowflake_connection): + with patch.object(snowflake_tool, "_create_connection") as mock_create_conn: + mock_create_conn.return_value = mock_snowflake_connection + + # Add connection to pool + await snowflake_tool._get_connection() + + # Return connection to pool + async with snowflake_tool._pool_lock: + snowflake_tool._connection_pool.append(mock_snowflake_connection) + + # Trigger cleanup + snowflake_tool.__del__() + + mock_snowflake_connection.close.assert_called_once() + + +def test_config_validation(): + # Test missing required fields + with pytest.raises(ValueError): + SnowflakeConfig() + + # Test invalid account format + with pytest.raises(ValueError): + SnowflakeConfig( + account="invalid//account", user="test_user", password="test_pass" + ) + + # Test missing authentication + with pytest.raises(ValueError): + SnowflakeConfig(account="test_account", user="test_user") diff --git a/tests/tools/test_code_interpreter_tool.py b/tests/tools/test_code_interpreter_tool.py index 6470c9dc1..e281fffaf 100644 --- a/tests/tools/test_code_interpreter_tool.py +++ b/tests/tools/test_code_interpreter_tool.py @@ -7,7 +7,9 @@ from crewai_tools.tools.code_interpreter_tool.code_interpreter_tool import ( class 
TestCodeInterpreterTool(unittest.TestCase): - @patch("crewai_tools.tools.code_interpreter_tool.code_interpreter_tool.docker_from_env") + @patch( + "crewai_tools.tools.code_interpreter_tool.code_interpreter_tool.docker_from_env" + ) def test_run_code_in_docker(self, docker_mock): tool = CodeInterpreterTool() code = "print('Hello, World!')" @@ -15,14 +17,14 @@ class TestCodeInterpreterTool(unittest.TestCase): expected_output = "Hello, World!\n" docker_mock().containers.run().exec_run().exit_code = 0 - docker_mock().containers.run().exec_run().output = ( - expected_output.encode() - ) + docker_mock().containers.run().exec_run().output = expected_output.encode() result = tool.run_code_in_docker(code, libraries_used) self.assertEqual(result, expected_output) - @patch("crewai_tools.tools.code_interpreter_tool.code_interpreter_tool.docker_from_env") + @patch( + "crewai_tools.tools.code_interpreter_tool.code_interpreter_tool.docker_from_env" + ) def test_run_code_in_docker_with_error(self, docker_mock): tool = CodeInterpreterTool() code = "print(1/0)" @@ -37,7 +39,9 @@ class TestCodeInterpreterTool(unittest.TestCase): self.assertEqual(result, expected_output) - @patch("crewai_tools.tools.code_interpreter_tool.code_interpreter_tool.docker_from_env") + @patch( + "crewai_tools.tools.code_interpreter_tool.code_interpreter_tool.docker_from_env" + ) def test_run_code_in_docker_with_script(self, docker_mock): tool = CodeInterpreterTool() code = """print("This is line 1")