From 9c4c4219cd18b75f56fce8279a7cca1eb7672829 Mon Sep 17 00:00:00 2001 From: ChethanUK Date: Fri, 17 Jan 2025 02:23:06 +0530 Subject: [PATCH 1/3] Adding Snowflake search tool --- src/crewai_tools/__init__.py | 2 + src/crewai_tools/tools/__init__.py | 5 + .../browserbase_load_tool.py | 14 +- .../code_interpreter_tool.py | 18 +- .../directory_read_tool.py | 2 - .../tools/file_read_tool/file_read_tool.py | 7 +- .../file_writer_tool/file_writer_tool.py | 4 +- .../firecrawl_crawl_website_tool.py | 1 - .../firecrawl_scrape_website_tool.py | 1 - .../github_search_tool/github_search_tool.py | 4 +- .../jina_scrape_website_tool.py | 4 +- .../tools/linkup/linkup_search_tool.py | 16 +- .../mysql_search_tool/mysql_search_tool.py | 4 +- .../tools/patronus_eval_tool/example.py | 26 +-- .../patronus_eval_tool/patronus_eval_tool.py | 14 +- .../patronus_local_evaluator_tool.py | 17 +- .../patronus_predefined_criteria_eval_tool.py | 12 +- .../pdf_text_writing_tool.py | 8 +- .../tools/pg_seach_tool/pg_search_tool.py | 4 +- .../scrape_element_from_website.py | 2 - .../scrape_website_tool.py | 2 - .../scrapegraph_scrape_tool.py | 34 +-- .../selenium_scraping_tool.py | 27 ++- .../tools/serpapi_tool/serpapi_base_tool.py | 3 +- .../serpapi_google_search_tool.py | 41 ++-- .../serpapi_google_shopping_tool.py | 41 ++-- .../tools/serper_dev_tool/serper_dev_tool.py | 4 +- .../serply_webpage_to_markdown_tool.py | 4 +- .../tools/snowflake_search_tool/README.md | 155 +++++++++++++ .../tools/snowflake_search_tool/__init__.py | 11 + .../snowflake_search_tool.py | 201 ++++++++++++++++ .../tools/stagehand_tool/stagehand_tool.py | 154 ++++++------ .../tools/vision_tool/vision_tool.py | 24 +- .../tools/weaviate_tool/vector_search.py | 3 +- .../website_search/website_search_tool.py | 4 +- .../youtube_channel_search_tool.py | 4 +- .../youtube_video_search_tool.py | 4 +- tests/base_tool_test.py | 133 +++++++---- tests/file_read_tool_test.py | 6 +- tests/it/tools/__init__.py | 0 tests/it/tools/conftest.py | 
21 ++ tests/it/tools/snowflake_search_tool_test.py | 219 ++++++++++++++++++ tests/spider_tool_test.py | 21 +- tests/tools/snowflake_search_tool_test.py | 103 ++++++++ tests/tools/test_code_interpreter_tool.py | 16 +- 45 files changed, 1089 insertions(+), 311 deletions(-) create mode 100644 src/crewai_tools/tools/snowflake_search_tool/README.md create mode 100644 src/crewai_tools/tools/snowflake_search_tool/__init__.py create mode 100644 src/crewai_tools/tools/snowflake_search_tool/snowflake_search_tool.py create mode 100644 tests/it/tools/__init__.py create mode 100644 tests/it/tools/conftest.py create mode 100644 tests/it/tools/snowflake_search_tool_test.py create mode 100644 tests/tools/snowflake_search_tool_test.py diff --git a/src/crewai_tools/__init__.py b/src/crewai_tools/__init__.py index 2db0fa05f..9c7e9d9a9 100644 --- a/src/crewai_tools/__init__.py +++ b/src/crewai_tools/__init__.py @@ -43,6 +43,8 @@ from .tools import ( SerplyScholarSearchTool, SerplyWebpageToMarkdownTool, SerplyWebSearchTool, + SnowflakeConfig, + SnowflakeSearchTool, SpiderTool, TXTSearchTool, VisionTool, diff --git a/src/crewai_tools/tools/__init__.py b/src/crewai_tools/tools/__init__.py index e4288a310..ea5a87ce1 100644 --- a/src/crewai_tools/tools/__init__.py +++ b/src/crewai_tools/tools/__init__.py @@ -54,6 +54,11 @@ from .serply_api_tool.serply_news_search_tool import SerplyNewsSearchTool from .serply_api_tool.serply_scholar_search_tool import SerplyScholarSearchTool from .serply_api_tool.serply_web_search_tool import SerplyWebSearchTool from .serply_api_tool.serply_webpage_to_markdown_tool import SerplyWebpageToMarkdownTool +from .snowflake_search_tool import ( + SnowflakeConfig, + SnowflakeSearchTool, + SnowflakeSearchToolInput, +) from .spider_tool.spider_tool import SpiderTool from .txt_search_tool.txt_search_tool import TXTSearchTool from .vision_tool.vision_tool import VisionTool diff --git a/src/crewai_tools/tools/browserbase_load_tool/browserbase_load_tool.py 
b/src/crewai_tools/tools/browserbase_load_tool/browserbase_load_tool.py index 2ca1b95fc..d3f76e0a6 100644 --- a/src/crewai_tools/tools/browserbase_load_tool/browserbase_load_tool.py +++ b/src/crewai_tools/tools/browserbase_load_tool/browserbase_load_tool.py @@ -1,8 +1,8 @@ import os from typing import Any, Optional, Type -from pydantic import BaseModel, Field from crewai.tools import BaseTool +from pydantic import BaseModel, Field class BrowserbaseLoadToolSchema(BaseModel): @@ -11,12 +11,10 @@ class BrowserbaseLoadToolSchema(BaseModel): class BrowserbaseLoadTool(BaseTool): name: str = "Browserbase web load tool" - description: str = ( - "Load webpages url in a headless browser using Browserbase and return the contents" - ) + description: str = "Load webpages url in a headless browser using Browserbase and return the contents" args_schema: Type[BaseModel] = BrowserbaseLoadToolSchema - api_key: Optional[str] = os.getenv('BROWSERBASE_API_KEY') - project_id: Optional[str] = os.getenv('BROWSERBASE_PROJECT_ID') + api_key: Optional[str] = os.getenv("BROWSERBASE_API_KEY") + project_id: Optional[str] = os.getenv("BROWSERBASE_PROJECT_ID") text_content: Optional[bool] = False session_id: Optional[str] = None proxy: Optional[bool] = None @@ -33,7 +31,9 @@ class BrowserbaseLoadTool(BaseTool): ): super().__init__(**kwargs) if not self.api_key: - raise EnvironmentError("BROWSERBASE_API_KEY environment variable is required for initialization") + raise EnvironmentError( + "BROWSERBASE_API_KEY environment variable is required for initialization" + ) try: from browserbase import Browserbase # type: ignore except ImportError: diff --git a/src/crewai_tools/tools/code_interpreter_tool/code_interpreter_tool.py b/src/crewai_tools/tools/code_interpreter_tool/code_interpreter_tool.py index fd0d39932..b508e4b6a 100644 --- a/src/crewai_tools/tools/code_interpreter_tool/code_interpreter_tool.py +++ b/src/crewai_tools/tools/code_interpreter_tool/code_interpreter_tool.py @@ -2,10 +2,10 @@ import 
importlib.util import os from typing import List, Optional, Type -from docker import from_env as docker_from_env -from docker.models.containers import Container -from docker.errors import ImageNotFound, NotFound from crewai.tools import BaseTool +from docker import from_env as docker_from_env +from docker.errors import ImageNotFound, NotFound +from docker.models.containers import Container from pydantic import BaseModel, Field @@ -30,7 +30,7 @@ class CodeInterpreterTool(BaseTool): default_image_tag: str = "code-interpreter:latest" code: Optional[str] = None user_dockerfile_path: Optional[str] = None - user_docker_base_url: Optional[str] = None + user_docker_base_url: Optional[str] = None unsafe_mode: bool = False @staticmethod @@ -43,7 +43,11 @@ class CodeInterpreterTool(BaseTool): Verify if the Docker image is available. Optionally use a user-provided Dockerfile. """ - client = docker_from_env() if self.user_docker_base_url == None else docker.DockerClient(base_url=self.user_docker_base_url) + client = ( + docker_from_env() + if self.user_docker_base_url == None + else docker.DockerClient(base_url=self.user_docker_base_url) + ) try: client.images.get(self.default_image_tag) @@ -76,9 +80,7 @@ class CodeInterpreterTool(BaseTool): else: return self.run_code_in_docker(code, libraries_used) - def _install_libraries( - self, container: Container, libraries: List[str] - ) -> None: + def _install_libraries(self, container: Container, libraries: List[str]) -> None: """ Install missing libraries in the Docker container """ diff --git a/src/crewai_tools/tools/directory_read_tool/directory_read_tool.py b/src/crewai_tools/tools/directory_read_tool/directory_read_tool.py index 6033202be..8488f391e 100644 --- a/src/crewai_tools/tools/directory_read_tool/directory_read_tool.py +++ b/src/crewai_tools/tools/directory_read_tool/directory_read_tool.py @@ -8,8 +8,6 @@ from pydantic import BaseModel, Field class FixedDirectoryReadToolSchema(BaseModel): """Input for 
DirectoryReadTool.""" - pass - class DirectoryReadToolSchema(FixedDirectoryReadToolSchema): """Input for DirectoryReadTool.""" diff --git a/src/crewai_tools/tools/file_read_tool/file_read_tool.py b/src/crewai_tools/tools/file_read_tool/file_read_tool.py index 323a26d51..384b97f40 100644 --- a/src/crewai_tools/tools/file_read_tool/file_read_tool.py +++ b/src/crewai_tools/tools/file_read_tool/file_read_tool.py @@ -32,6 +32,7 @@ class FileReadTool(BaseTool): >>> content = tool.run() # Reads /path/to/file.txt >>> content = tool.run(file_path="/path/to/other.txt") # Reads other.txt """ + name: str = "Read a file's content" description: str = "A tool that reads the content of a file. To use this tool, provide a 'file_path' parameter with the path to the file you want to read." args_schema: Type[BaseModel] = FileReadToolSchema @@ -57,7 +58,7 @@ class FileReadTool(BaseTool): file_path = kwargs.get("file_path", self.file_path) if file_path is None: return "Error: No file path provided. Please provide a file path either in the constructor or as an argument." - + try: with open(file_path, "r") as file: return file.read() @@ -78,4 +79,6 @@ class FileReadTool(BaseTool): Returns: None """ - self.description = f"A tool that can be used to read {self.file_path}'s content." + self.description = ( + f"A tool that can be used to read {self.file_path}'s content." + ) diff --git a/src/crewai_tools/tools/file_writer_tool/file_writer_tool.py b/src/crewai_tools/tools/file_writer_tool/file_writer_tool.py index ed454a1bd..f975d3301 100644 --- a/src/crewai_tools/tools/file_writer_tool/file_writer_tool.py +++ b/src/crewai_tools/tools/file_writer_tool/file_writer_tool.py @@ -15,9 +15,7 @@ class FileWriterToolInput(BaseModel): class FileWriterTool(BaseTool): name: str = "File Writer Tool" - description: str = ( - "A tool to write content to a specified file. Accepts filename, content, and optionally a directory path and overwrite flag as input." 
- ) + description: str = "A tool to write content to a specified file. Accepts filename, content, and optionally a directory path and overwrite flag as input." args_schema: Type[BaseModel] = FileWriterToolInput def _run(self, **kwargs: Any) -> str: diff --git a/src/crewai_tools/tools/firecrawl_crawl_website_tool/firecrawl_crawl_website_tool.py b/src/crewai_tools/tools/firecrawl_crawl_website_tool/firecrawl_crawl_website_tool.py index 6c7c4ffd9..dcb70e291 100644 --- a/src/crewai_tools/tools/firecrawl_crawl_website_tool/firecrawl_crawl_website_tool.py +++ b/src/crewai_tools/tools/firecrawl_crawl_website_tool/firecrawl_crawl_website_tool.py @@ -72,4 +72,3 @@ except ImportError: """ When this tool is not used, then exception can be ignored. """ - pass diff --git a/src/crewai_tools/tools/firecrawl_scrape_website_tool/firecrawl_scrape_website_tool.py b/src/crewai_tools/tools/firecrawl_scrape_website_tool/firecrawl_scrape_website_tool.py index 9458e7a4f..3f5f8c4c4 100644 --- a/src/crewai_tools/tools/firecrawl_scrape_website_tool/firecrawl_scrape_website_tool.py +++ b/src/crewai_tools/tools/firecrawl_scrape_website_tool/firecrawl_scrape_website_tool.py @@ -63,4 +63,3 @@ except ImportError: """ When this tool is not used, then exception can be ignored. """ - pass diff --git a/src/crewai_tools/tools/github_search_tool/github_search_tool.py b/src/crewai_tools/tools/github_search_tool/github_search_tool.py index 4bf8b9e05..6ba7b919c 100644 --- a/src/crewai_tools/tools/github_search_tool/github_search_tool.py +++ b/src/crewai_tools/tools/github_search_tool/github_search_tool.py @@ -27,9 +27,7 @@ class GithubSearchToolSchema(FixedGithubSearchToolSchema): class GithubSearchTool(RagTool): name: str = "Search a github repo's content" - description: str = ( - "A tool that can be used to semantic search a query from a github repo's content. This is not the GitHub API, but instead a tool that can provide semantic search capabilities." 
- ) + description: str = "A tool that can be used to semantic search a query from a github repo's content. This is not the GitHub API, but instead a tool that can provide semantic search capabilities." summarize: bool = False gh_token: str args_schema: Type[BaseModel] = GithubSearchToolSchema diff --git a/src/crewai_tools/tools/jina_scrape_website_tool/jina_scrape_website_tool.py b/src/crewai_tools/tools/jina_scrape_website_tool/jina_scrape_website_tool.py index a10a4ffdb..86f771cd0 100644 --- a/src/crewai_tools/tools/jina_scrape_website_tool/jina_scrape_website_tool.py +++ b/src/crewai_tools/tools/jina_scrape_website_tool/jina_scrape_website_tool.py @@ -13,9 +13,7 @@ class JinaScrapeWebsiteToolInput(BaseModel): class JinaScrapeWebsiteTool(BaseTool): name: str = "JinaScrapeWebsiteTool" - description: str = ( - "A tool that can be used to read a website content using Jina.ai reader and return markdown content." - ) + description: str = "A tool that can be used to read a website content using Jina.ai reader and return markdown content." args_schema: Type[BaseModel] = JinaScrapeWebsiteToolInput website_url: Optional[str] = None api_key: Optional[str] = None diff --git a/src/crewai_tools/tools/linkup/linkup_search_tool.py b/src/crewai_tools/tools/linkup/linkup_search_tool.py index b172ad029..486663d3e 100644 --- a/src/crewai_tools/tools/linkup/linkup_search_tool.py +++ b/src/crewai_tools/tools/linkup/linkup_search_tool.py @@ -2,6 +2,7 @@ from typing import Any try: from linkup import LinkupClient + LINKUP_AVAILABLE = True except ImportError: LINKUP_AVAILABLE = False @@ -9,10 +10,13 @@ except ImportError: from pydantic import PrivateAttr + class LinkupSearchTool: name: str = "Linkup Search Tool" - description: str = "Performs an API call to Linkup to retrieve contextual information." - _client: LinkupClient = PrivateAttr() # type: ignore + description: str = ( + "Performs an API call to Linkup to retrieve contextual information." 
+ ) + _client: LinkupClient = PrivateAttr() # type: ignore def __init__(self, api_key: str): """ @@ -25,7 +29,9 @@ class LinkupSearchTool: ) self._client = LinkupClient(api_key=api_key) - def _run(self, query: str, depth: str = "standard", output_type: str = "searchResults") -> dict: + def _run( + self, query: str, depth: str = "standard", output_type: str = "searchResults" + ) -> dict: """ Executes a search using the Linkup API. @@ -36,9 +42,7 @@ class LinkupSearchTool: """ try: response = self._client.search( - query=query, - depth=depth, - output_type=output_type + query=query, depth=depth, output_type=output_type ) results = [ {"name": result.name, "url": result.url, "content": result.content} diff --git a/src/crewai_tools/tools/mysql_search_tool/mysql_search_tool.py b/src/crewai_tools/tools/mysql_search_tool/mysql_search_tool.py index f931a006b..a472e1761 100644 --- a/src/crewai_tools/tools/mysql_search_tool/mysql_search_tool.py +++ b/src/crewai_tools/tools/mysql_search_tool/mysql_search_tool.py @@ -17,9 +17,7 @@ class MySQLSearchToolSchema(BaseModel): class MySQLSearchTool(RagTool): name: str = "Search a database's table content" - description: str = ( - "A tool that can be used to semantic search a query from a database table's content." - ) + description: str = "A tool that can be used to semantic search a query from a database table's content." 
args_schema: Type[BaseModel] = MySQLSearchToolSchema db_uri: str = Field(..., description="Mandatory database URI") diff --git a/src/crewai_tools/tools/patronus_eval_tool/example.py b/src/crewai_tools/tools/patronus_eval_tool/example.py index b9e1bad5e..185e9f485 100644 --- a/src/crewai_tools/tools/patronus_eval_tool/example.py +++ b/src/crewai_tools/tools/patronus_eval_tool/example.py @@ -1,30 +1,24 @@ -from crewai import Agent, Crew, Task -from patronus_eval_tool import ( - PatronusEvalTool, -) -from patronus_local_evaluator_tool import ( - PatronusLocalEvaluatorTool, -) -from patronus_predefined_criteria_eval_tool import ( - PatronusPredefinedCriteriaEvalTool, -) -from patronus import Client, EvaluationResult import random +from crewai import Agent, Crew, Task +from patronus import Client, EvaluationResult +from patronus_local_evaluator_tool import PatronusLocalEvaluatorTool # Test the PatronusLocalEvaluatorTool where agent uses the local evaluator client = Client() + # Example of an evaluator that returns a random pass/fail result @client.register_local_evaluator("random_evaluator") def random_evaluator(**kwargs): score = random.random() return EvaluationResult( - score_raw=score, - pass_=score >= 0.5, - explanation="example explanation" # Optional justification for LLM judges + score_raw=score, + pass_=score >= 0.5, + explanation="example explanation", # Optional justification for LLM judges ) + # 1. Uses PatronusEvalTool: agent can pick the best evaluator and criteria # patronus_eval_tool = PatronusEvalTool() @@ -35,7 +29,9 @@ def random_evaluator(**kwargs): # 3. 
Uses PatronusLocalEvaluatorTool: agent uses user defined evaluator patronus_eval_tool = PatronusLocalEvaluatorTool( - patronus_client=client, evaluator="random_evaluator", evaluated_model_gold_answer="example label" + patronus_client=client, + evaluator="random_evaluator", + evaluated_model_gold_answer="example label", ) # Create a new agent diff --git a/src/crewai_tools/tools/patronus_eval_tool/patronus_eval_tool.py b/src/crewai_tools/tools/patronus_eval_tool/patronus_eval_tool.py index 23ffe2fd4..be1f410e2 100644 --- a/src/crewai_tools/tools/patronus_eval_tool/patronus_eval_tool.py +++ b/src/crewai_tools/tools/patronus_eval_tool/patronus_eval_tool.py @@ -1,8 +1,9 @@ -import os import json -import requests +import os import warnings -from typing import Any, List, Dict, Optional +from typing import Any, Dict, List, Optional + +import requests from crewai.tools import BaseTool @@ -19,7 +20,9 @@ class PatronusEvalTool(BaseTool): self.evaluators = temp_evaluators self.criteria = temp_criteria self.description = self._generate_description() - warnings.warn("You are allowing the agent to select the best evaluator and criteria when you use the `PatronusEvalTool`. If this is not intended then please use `PatronusPredefinedCriteriaEvalTool` instead.") + warnings.warn( + "You are allowing the agent to select the best evaluator and criteria when you use the `PatronusEvalTool`. If this is not intended then please use `PatronusPredefinedCriteriaEvalTool` instead." + ) def _init_run(self): evaluators_set = json.loads( @@ -104,7 +107,6 @@ class PatronusEvalTool(BaseTool): evaluated_model_retrieved_context: Optional[str], evaluators: List[Dict[str, str]], ) -> Any: - # Assert correct format of evaluators evals = [] for ev in evaluators: @@ -136,4 +138,4 @@ class PatronusEvalTool(BaseTool): f"Failed to evaluate model input and output. Response status code: {response.status_code}. 
Reason: {response.text}" ) - return response.json() \ No newline at end of file + return response.json() diff --git a/src/crewai_tools/tools/patronus_eval_tool/patronus_local_evaluator_tool.py b/src/crewai_tools/tools/patronus_eval_tool/patronus_local_evaluator_tool.py index e65cb342d..66781c593 100644 --- a/src/crewai_tools/tools/patronus_eval_tool/patronus_local_evaluator_tool.py +++ b/src/crewai_tools/tools/patronus_eval_tool/patronus_local_evaluator_tool.py @@ -1,7 +1,8 @@ from typing import Any, Type + from crewai.tools import BaseTool -from pydantic import BaseModel, Field from patronus import Client +from pydantic import BaseModel, Field class FixedLocalEvaluatorToolSchema(BaseModel): @@ -24,16 +25,20 @@ class PatronusLocalEvaluatorTool(BaseTool): name: str = "Patronus Local Evaluator Tool" evaluator: str = "The registered local evaluator" evaluated_model_gold_answer: str = "The agent's gold answer" - description: str = ( - "This tool is used to evaluate the model input and output using custom function evaluators." - ) + description: str = "This tool is used to evaluate the model input and output using custom function evaluators." 
client: Any = None args_schema: Type[BaseModel] = FixedLocalEvaluatorToolSchema class Config: arbitrary_types_allowed = True - def __init__(self, patronus_client: Client, evaluator: str, evaluated_model_gold_answer: str, **kwargs: Any): + def __init__( + self, + patronus_client: Client, + evaluator: str, + evaluated_model_gold_answer: str, + **kwargs: Any, + ): super().__init__(**kwargs) self.client = patronus_client if evaluator: @@ -79,7 +84,7 @@ class PatronusLocalEvaluatorTool(BaseTool): if isinstance(evaluated_model_gold_answer, str) else evaluated_model_gold_answer.get("description") ), - tags={}, # Optional metadata, supports arbitrary kv pairs + tags={}, # Optional metadata, supports arbitrary kv pairs ) output = f"Evaluation result: {result.pass_}, Explanation: {result.explanation}" return output diff --git a/src/crewai_tools/tools/patronus_eval_tool/patronus_predefined_criteria_eval_tool.py b/src/crewai_tools/tools/patronus_eval_tool/patronus_predefined_criteria_eval_tool.py index 28ffc2912..cf906586d 100644 --- a/src/crewai_tools/tools/patronus_eval_tool/patronus_predefined_criteria_eval_tool.py +++ b/src/crewai_tools/tools/patronus_eval_tool/patronus_predefined_criteria_eval_tool.py @@ -1,7 +1,8 @@ -import os import json +import os +from typing import Any, Dict, List, Type + import requests -from typing import Any, List, Dict, Type from crewai.tools import BaseTool from pydantic import BaseModel, Field @@ -33,9 +34,7 @@ class PatronusPredefinedCriteriaEvalTool(BaseTool): """ name: str = "Call Patronus API tool for evaluation of model inputs and outputs" - description: str = ( - """This tool calls the Patronus Evaluation API that takes the following arguments:""" - ) + description: str = """This tool calls the Patronus Evaluation API that takes the following arguments:""" evaluate_url: str = "https://api.patronus.ai/v1/evaluate" args_schema: Type[BaseModel] = FixedBaseToolSchema evaluators: List[Dict[str, str]] = [] @@ -52,7 +51,6 @@ class 
PatronusPredefinedCriteriaEvalTool(BaseTool): self, **kwargs: Any, ) -> Any: - evaluated_model_input = kwargs.get("evaluated_model_input") evaluated_model_output = kwargs.get("evaluated_model_output") evaluated_model_retrieved_context = kwargs.get( @@ -103,4 +101,4 @@ class PatronusPredefinedCriteriaEvalTool(BaseTool): f"Failed to evaluate model input and output. Status code: {response.status_code}. Reason: {response.text}" ) - return response.json() \ No newline at end of file + return response.json() diff --git a/src/crewai_tools/tools/pdf_text_writing_tool/pdf_text_writing_tool.py b/src/crewai_tools/tools/pdf_text_writing_tool/pdf_text_writing_tool.py index 851593167..ad4d847b6 100644 --- a/src/crewai_tools/tools/pdf_text_writing_tool/pdf_text_writing_tool.py +++ b/src/crewai_tools/tools/pdf_text_writing_tool/pdf_text_writing_tool.py @@ -1,7 +1,9 @@ -from typing import Optional, Type -from pydantic import BaseModel, Field -from pypdf import PdfReader, PdfWriter, PageObject, ContentStream, NameObject, Font from pathlib import Path +from typing import Optional, Type + +from pydantic import BaseModel, Field +from pypdf import ContentStream, Font, NameObject, PageObject, PdfReader, PdfWriter + from crewai_tools.tools.rag.rag_tool import RagTool diff --git a/src/crewai_tools/tools/pg_seach_tool/pg_search_tool.py b/src/crewai_tools/tools/pg_seach_tool/pg_search_tool.py index dc75470a2..ec0207aa7 100644 --- a/src/crewai_tools/tools/pg_seach_tool/pg_search_tool.py +++ b/src/crewai_tools/tools/pg_seach_tool/pg_search_tool.py @@ -17,9 +17,7 @@ class PGSearchToolSchema(BaseModel): class PGSearchTool(RagTool): name: str = "Search a database's table content" - description: str = ( - "A tool that can be used to semantic search a query from a database table's content." - ) + description: str = "A tool that can be used to semantic search a query from a database table's content." 
args_schema: Type[BaseModel] = PGSearchToolSchema db_uri: str = Field(..., description="Mandatory database URI") diff --git a/src/crewai_tools/tools/scrape_element_from_website/scrape_element_from_website.py b/src/crewai_tools/tools/scrape_element_from_website/scrape_element_from_website.py index 14757d247..f1e215bf3 100644 --- a/src/crewai_tools/tools/scrape_element_from_website/scrape_element_from_website.py +++ b/src/crewai_tools/tools/scrape_element_from_website/scrape_element_from_website.py @@ -10,8 +10,6 @@ from pydantic import BaseModel, Field class FixedScrapeElementFromWebsiteToolSchema(BaseModel): """Input for ScrapeElementFromWebsiteTool.""" - pass - class ScrapeElementFromWebsiteToolSchema(FixedScrapeElementFromWebsiteToolSchema): """Input for ScrapeElementFromWebsiteTool.""" diff --git a/src/crewai_tools/tools/scrape_website_tool/scrape_website_tool.py b/src/crewai_tools/tools/scrape_website_tool/scrape_website_tool.py index 8cfc5d136..0e7e25ca6 100644 --- a/src/crewai_tools/tools/scrape_website_tool/scrape_website_tool.py +++ b/src/crewai_tools/tools/scrape_website_tool/scrape_website_tool.py @@ -11,8 +11,6 @@ from pydantic import BaseModel, Field class FixedScrapeWebsiteToolSchema(BaseModel): """Input for ScrapeWebsiteTool.""" - pass - class ScrapeWebsiteToolSchema(FixedScrapeWebsiteToolSchema): """Input for ScrapeWebsiteTool.""" diff --git a/src/crewai_tools/tools/scrapegraph_scrape_tool/scrapegraph_scrape_tool.py b/src/crewai_tools/tools/scrapegraph_scrape_tool/scrapegraph_scrape_tool.py index 906bf6376..29c132ea9 100644 --- a/src/crewai_tools/tools/scrapegraph_scrape_tool/scrapegraph_scrape_tool.py +++ b/src/crewai_tools/tools/scrapegraph_scrape_tool/scrapegraph_scrape_tool.py @@ -10,17 +10,14 @@ from scrapegraph_py.logger import sgai_logger class ScrapegraphError(Exception): """Base exception for Scrapegraph-related errors""" - pass class RateLimitError(ScrapegraphError): """Raised when API rate limits are exceeded""" - pass class 
FixedScrapegraphScrapeToolSchema(BaseModel): """Input for ScrapegraphScrapeTool when website_url is fixed.""" - pass class ScrapegraphScrapeToolSchema(FixedScrapegraphScrapeToolSchema): @@ -32,7 +29,7 @@ class ScrapegraphScrapeToolSchema(FixedScrapegraphScrapeToolSchema): description="Prompt to guide the extraction of content", ) - @validator('website_url') + @validator("website_url") def validate_url(cls, v): """Validate URL format""" try: @@ -41,13 +38,15 @@ class ScrapegraphScrapeToolSchema(FixedScrapegraphScrapeToolSchema): raise ValueError return v except Exception: - raise ValueError("Invalid URL format. URL must include scheme (http/https) and domain") + raise ValueError( + "Invalid URL format. URL must include scheme (http/https) and domain" + ) class ScrapegraphScrapeTool(BaseTool): """ A tool that uses Scrapegraph AI to intelligently scrape website content. - + Raises: ValueError: If API key is missing or URL format is invalid RateLimitError: If API rate limits are exceeded @@ -55,7 +54,9 @@ class ScrapegraphScrapeTool(BaseTool): """ name: str = "Scrapegraph website scraper" - description: str = "A tool that uses Scrapegraph AI to intelligently scrape website content." + description: str = ( + "A tool that uses Scrapegraph AI to intelligently scrape website content." + ) args_schema: Type[BaseModel] = ScrapegraphScrapeToolSchema website_url: Optional[str] = None user_prompt: Optional[str] = None @@ -70,7 +71,7 @@ class ScrapegraphScrapeTool(BaseTool): ): super().__init__(**kwargs) self.api_key = api_key or os.getenv("SCRAPEGRAPH_API_KEY") - + if not self.api_key: raise ValueError("Scrapegraph API key is required") @@ -79,7 +80,7 @@ class ScrapegraphScrapeTool(BaseTool): self.website_url = website_url self.description = f"A tool that uses Scrapegraph AI to intelligently scrape {website_url}'s content." 
self.args_schema = FixedScrapegraphScrapeToolSchema - + if user_prompt is not None: self.user_prompt = user_prompt @@ -94,22 +95,24 @@ class ScrapegraphScrapeTool(BaseTool): if not all([result.scheme, result.netloc]): raise ValueError except Exception: - raise ValueError("Invalid URL format. URL must include scheme (http/https) and domain") + raise ValueError( + "Invalid URL format. URL must include scheme (http/https) and domain" + ) def _handle_api_response(self, response: dict) -> str: """Handle and validate API response""" if not response: raise RuntimeError("Empty response from Scrapegraph API") - + if "error" in response: error_msg = response.get("error", {}).get("message", "Unknown error") if "rate limit" in error_msg.lower(): raise RateLimitError(f"Rate limit exceeded: {error_msg}") raise RuntimeError(f"API error: {error_msg}") - + if "result" not in response: raise RuntimeError("Invalid response format from Scrapegraph API") - + return response["result"] def _run( @@ -117,7 +120,10 @@ class ScrapegraphScrapeTool(BaseTool): **kwargs: Any, ) -> Any: website_url = kwargs.get("website_url", self.website_url) - user_prompt = kwargs.get("user_prompt", self.user_prompt) or "Extract the main content of the webpage" + user_prompt = ( + kwargs.get("user_prompt", self.user_prompt) + or "Extract the main content of the webpage" + ) if not website_url: raise ValueError("website_url is required") diff --git a/src/crewai_tools/tools/selenium_scraping_tool/selenium_scraping_tool.py b/src/crewai_tools/tools/selenium_scraping_tool/selenium_scraping_tool.py index d7a55428d..8099a06ab 100644 --- a/src/crewai_tools/tools/selenium_scraping_tool/selenium_scraping_tool.py +++ b/src/crewai_tools/tools/selenium_scraping_tool/selenium_scraping_tool.py @@ -17,33 +17,36 @@ class FixedSeleniumScrapingToolSchema(BaseModel): class SeleniumScrapingToolSchema(FixedSeleniumScrapingToolSchema): """Input for SeleniumScrapingTool.""" - website_url: str = Field(..., description="Mandatory 
website url to read the file. Must start with http:// or https://") + website_url: str = Field( + ..., + description="Mandatory website url to read the file. Must start with http:// or https://", + ) css_element: str = Field( ..., description="Mandatory css reference for element to scrape from the website", ) - @validator('website_url') + @validator("website_url") def validate_website_url(cls, v): if not v: raise ValueError("Website URL cannot be empty") - + if len(v) > 2048: # Common maximum URL length raise ValueError("URL is too long (max 2048 characters)") - - if not re.match(r'^https?://', v): + + if not re.match(r"^https?://", v): raise ValueError("URL must start with http:// or https://") - + try: result = urlparse(v) if not all([result.scheme, result.netloc]): raise ValueError("Invalid URL format") except Exception as e: raise ValueError(f"Invalid URL: {str(e)}") - - if re.search(r'\s', v): + + if re.search(r"\s", v): raise ValueError("URL cannot contain whitespace") - + return v @@ -130,11 +133,11 @@ class SeleniumScrapingTool(BaseTool): def _create_driver(self, url, cookie, wait_time): if not url: raise ValueError("URL cannot be empty") - + # Validate URL format - if not re.match(r'^https?://', url): + if not re.match(r"^https?://", url): raise ValueError("URL must start with http:// or https://") - + options = Options() options.add_argument("--headless") driver = self.driver(options=options) diff --git a/src/crewai_tools/tools/serpapi_tool/serpapi_base_tool.py b/src/crewai_tools/tools/serpapi_tool/serpapi_base_tool.py index 98491190c..895f3aadc 100644 --- a/src/crewai_tools/tools/serpapi_tool/serpapi_base_tool.py +++ b/src/crewai_tools/tools/serpapi_tool/serpapi_base_tool.py @@ -1,9 +1,10 @@ import os import re -from typing import Optional, Any, Union +from typing import Any, Optional, Union from crewai.tools import BaseTool + class SerpApiBaseTool(BaseTool): """Base class for SerpApi functionality with shared capabilities.""" diff --git 
a/src/crewai_tools/tools/serpapi_tool/serpapi_google_search_tool.py b/src/crewai_tools/tools/serpapi_tool/serpapi_google_search_tool.py index 199b7f5a2..c1a877f23 100644 --- a/src/crewai_tools/tools/serpapi_tool/serpapi_google_search_tool.py +++ b/src/crewai_tools/tools/serpapi_tool/serpapi_google_search_tool.py @@ -1,14 +1,21 @@ -from typing import Any, Type, Optional +from typing import Any, Optional, Type -import re from pydantic import BaseModel, Field -from .serpapi_base_tool import SerpApiBaseTool from serpapi import HTTPError +from .serpapi_base_tool import SerpApiBaseTool + + class SerpApiGoogleSearchToolSchema(BaseModel): """Input for Google Search.""" - search_query: str = Field(..., description="Mandatory search query you want to use to Google search.") - location: Optional[str] = Field(None, description="Location you want the search to be performed in.") + + search_query: str = Field( + ..., description="Mandatory search query you want to use to Google search." + ) + location: Optional[str] = Field( + None, description="Location you want the search to be performed in." + ) + class SerpApiGoogleSearchTool(SerpApiBaseTool): name: str = "Google Search" @@ -22,19 +29,25 @@ class SerpApiGoogleSearchTool(SerpApiBaseTool): **kwargs: Any, ) -> Any: try: - results = self.client.search({ - "q": kwargs.get("search_query"), - "location": kwargs.get("location"), - }).as_dict() + results = self.client.search( + { + "q": kwargs.get("search_query"), + "location": kwargs.get("location"), + } + ).as_dict() self._omit_fields( - results, - [r"search_metadata", r"search_parameters", r"serpapi_.+", r".+_token", r"displayed_link", r"pagination"] + results, + [ + r"search_metadata", + r"search_parameters", + r"serpapi_.+", + r".+_token", + r"displayed_link", + r"pagination", + ], ) return results except HTTPError as e: return f"An error occurred: {str(e)}. Some parameters may be invalid." 
- - - \ No newline at end of file diff --git a/src/crewai_tools/tools/serpapi_tool/serpapi_google_shopping_tool.py b/src/crewai_tools/tools/serpapi_tool/serpapi_google_shopping_tool.py index b44b3a809..ec9477351 100644 --- a/src/crewai_tools/tools/serpapi_tool/serpapi_google_shopping_tool.py +++ b/src/crewai_tools/tools/serpapi_tool/serpapi_google_shopping_tool.py @@ -1,14 +1,20 @@ -from typing import Any, Type, Optional +from typing import Any, Optional, Type -import re from pydantic import BaseModel, Field -from .serpapi_base_tool import SerpApiBaseTool from serpapi import HTTPError +from .serpapi_base_tool import SerpApiBaseTool + + class SerpApiGoogleShoppingToolSchema(BaseModel): """Input for Google Shopping.""" - search_query: str = Field(..., description="Mandatory search query you want to use to Google shopping.") - location: Optional[str] = Field(None, description="Location you want the search to be performed in.") + + search_query: str = Field( + ..., description="Mandatory search query you want to use to Google shopping." + ) + location: Optional[str] = Field( + None, description="Location you want the search to be performed in." + ) class SerpApiGoogleShoppingTool(SerpApiBaseTool): @@ -23,20 +29,25 @@ class SerpApiGoogleShoppingTool(SerpApiBaseTool): **kwargs: Any, ) -> Any: try: - results = self.client.search({ - "engine": "google_shopping", - "q": kwargs.get("search_query"), - "location": kwargs.get("location") - }).as_dict() + results = self.client.search( + { + "engine": "google_shopping", + "q": kwargs.get("search_query"), + "location": kwargs.get("location"), + } + ).as_dict() self._omit_fields( - results, - [r"search_metadata", r"search_parameters", r"serpapi_.+", r"filters", r"pagination"] + results, + [ + r"search_metadata", + r"search_parameters", + r"serpapi_.+", + r"filters", + r"pagination", + ], ) return results except HTTPError as e: return f"An error occurred: {str(e)}. Some parameters may be invalid." 
- - - \ No newline at end of file diff --git a/src/crewai_tools/tools/serper_dev_tool/serper_dev_tool.py b/src/crewai_tools/tools/serper_dev_tool/serper_dev_tool.py index e9eab56a2..2db347190 100644 --- a/src/crewai_tools/tools/serper_dev_tool/serper_dev_tool.py +++ b/src/crewai_tools/tools/serper_dev_tool/serper_dev_tool.py @@ -1,19 +1,19 @@ import datetime import json -import os import logging +import os from typing import Any, Type import requests from crewai.tools import BaseTool from pydantic import BaseModel, Field - logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" ) logger = logging.getLogger(__name__) + def _save_results_to_file(content: str) -> None: """Saves the search results to a file.""" try: diff --git a/src/crewai_tools/tools/serply_api_tool/serply_webpage_to_markdown_tool.py b/src/crewai_tools/tools/serply_api_tool/serply_webpage_to_markdown_tool.py index e09a36fd9..4010236cc 100644 --- a/src/crewai_tools/tools/serply_api_tool/serply_webpage_to_markdown_tool.py +++ b/src/crewai_tools/tools/serply_api_tool/serply_webpage_to_markdown_tool.py @@ -18,9 +18,7 @@ class SerplyWebpageToMarkdownToolSchema(BaseModel): class SerplyWebpageToMarkdownTool(RagTool): name: str = "Webpage to Markdown" - description: str = ( - "A tool to perform convert a webpage to markdown to make it easier for LLMs to understand" - ) + description: str = "A tool to perform convert a webpage to markdown to make it easier for LLMs to understand" args_schema: Type[BaseModel] = SerplyWebpageToMarkdownToolSchema request_url: str = "https://api.serply.io/v1/request" proxy_location: Optional[str] = "US" diff --git a/src/crewai_tools/tools/snowflake_search_tool/README.md b/src/crewai_tools/tools/snowflake_search_tool/README.md new file mode 100644 index 000000000..fc0b845c3 --- /dev/null +++ b/src/crewai_tools/tools/snowflake_search_tool/README.md @@ -0,0 +1,155 @@ +# Snowflake Search Tool + +A tool for executing queries on Snowflake 
data warehouse with built-in connection pooling, retry logic, and async execution support. + +## Installation + +```bash +uv sync --extra snowflake + +OR +uv pip install "snowflake-connector-python>=3.5.0" "snowflake-sqlalchemy>=1.5.0" "cryptography>=41.0.0" + +OR +pip install "snowflake-connector-python>=3.5.0" "snowflake-sqlalchemy>=1.5.0" "cryptography>=41.0.0" +``` + +## Quick Start + +```python +import asyncio +from crewai_tools import SnowflakeSearchTool, SnowflakeConfig + +# Create configuration +config = SnowflakeConfig( + account="your_account", + user="your_username", + password="your_password", + warehouse="COMPUTE_WH", + database="your_database", + snowflake_schema="your_schema" # Note: Uses snowflake_schema instead of schema +) + +# Initialize tool +tool = SnowflakeSearchTool( + config=config, + pool_size=5, + max_retries=3, + enable_caching=True +) + +# Execute query +async def main(): + results = await tool._run( + query="SELECT * FROM your_table LIMIT 10", + timeout=300 + ) + print(f"Retrieved {len(results)} rows") + +if __name__ == "__main__": + asyncio.run(main()) +``` + +## Features + +- ✨ Asynchronous query execution +- 🚀 Connection pooling for better performance +- 🔄 Automatic retries for transient failures +- 💾 Query result caching (optional) +- 🔒 Support for both password and key-pair authentication +- 📝 Comprehensive error handling and logging + +## Configuration Options + +### SnowflakeConfig Parameters + +| Parameter | Required | Description | +|-----------|----------|-------------| +| account | Yes | Snowflake account identifier | +| user | Yes | Snowflake username | +| password | Yes* | Snowflake password | +| private_key_path | No* | Path to private key file (alternative to password) | +| warehouse | No | Snowflake warehouse name | +| database | No | Default database | +| snowflake_schema | No | Default schema | +| role | No | Snowflake role | +| session_parameters | No | Custom session parameters dict | + +\* Either password or private_key_path
must be provided + +### Tool Parameters + +| Parameter | Default | Description | +|-----------|---------|-------------| +| pool_size | 5 | Number of connections in the pool | +| max_retries | 3 | Maximum retry attempts for failed queries | +| retry_delay | 1.0 | Delay between retries in seconds | +| enable_caching | True | Enable/disable query result caching | + +## Advanced Usage + +### Using Key-Pair Authentication + +```python +config = SnowflakeConfig( + account="your_account", + user="your_username", + private_key_path="/path/to/private_key.p8", + warehouse="your_warehouse", + database="your_database", + snowflake_schema="your_schema" +) +``` + +### Custom Session Parameters + +```python +config = SnowflakeConfig( + # ... other config parameters ... + session_parameters={ + "QUERY_TAG": "my_app", + "TIMEZONE": "America/Los_Angeles" + } +) +``` + +## Best Practices + +1. **Error Handling**: Always wrap query execution in try-except blocks +2. **Logging**: Enable logging to track query execution and errors +3. **Connection Management**: Use appropriate pool sizes for your workload +4. **Timeouts**: Set reasonable query timeouts to prevent hanging +5. **Security**: Use key-pair auth in production and never hardcode credentials + +## Example with Logging + +```python +import logging + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + +async def main(): + try: + # ... tool initialization ... + results = await tool._run(query="SELECT * FROM table LIMIT 10") + logger.info(f"Query completed successfully. 
Retrieved {len(results)} rows") + except Exception as e: + logger.error(f"Query failed: {str(e)}") + raise +``` + +## Error Handling + +The tool automatically handles common Snowflake errors: +- DatabaseError +- OperationalError +- ProgrammingError +- Network timeouts +- Connection issues + +Errors are logged and retried based on your retry configuration. \ No newline at end of file diff --git a/src/crewai_tools/tools/snowflake_search_tool/__init__.py b/src/crewai_tools/tools/snowflake_search_tool/__init__.py new file mode 100644 index 000000000..abc1a45f5 --- /dev/null +++ b/src/crewai_tools/tools/snowflake_search_tool/__init__.py @@ -0,0 +1,11 @@ +from .snowflake_search_tool import ( + SnowflakeConfig, + SnowflakeSearchTool, + SnowflakeSearchToolInput, +) + +__all__ = [ + "SnowflakeSearchTool", + "SnowflakeSearchToolInput", + "SnowflakeConfig", +] diff --git a/src/crewai_tools/tools/snowflake_search_tool/snowflake_search_tool.py b/src/crewai_tools/tools/snowflake_search_tool/snowflake_search_tool.py new file mode 100644 index 000000000..75c671d21 --- /dev/null +++ b/src/crewai_tools/tools/snowflake_search_tool/snowflake_search_tool.py @@ -0,0 +1,201 @@ +import asyncio +import logging +from concurrent.futures import ThreadPoolExecutor +from typing import Any, Dict, List, Optional, Type + +import snowflake.connector +from crewai.tools.base_tool import BaseTool +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import serialization +from pydantic import BaseModel, ConfigDict, Field, SecretStr +from snowflake.connector.connection import SnowflakeConnection +from snowflake.connector.errors import DatabaseError, OperationalError + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# Cache for query results +_query_cache = {} + + +class SnowflakeConfig(BaseModel): + """Configuration for Snowflake connection.""" + + model_config = ConfigDict(protected_namespaces=()) + + 
account: str = Field( + ..., description="Snowflake account identifier", pattern=r"^[a-zA-Z0-9\-_]+$" + ) + user: str = Field(..., description="Snowflake username") + password: Optional[SecretStr] = Field(None, description="Snowflake password") + private_key_path: Optional[str] = Field( + None, description="Path to private key file" + ) + warehouse: Optional[str] = Field(None, description="Snowflake warehouse") + database: Optional[str] = Field(None, description="Default database") + snowflake_schema: Optional[str] = Field(None, description="Default schema") + role: Optional[str] = Field(None, description="Snowflake role") + session_parameters: Optional[Dict[str, Any]] = Field( + default_factory=dict, description="Session parameters" + ) + + @property + def has_auth(self) -> bool: + return bool(self.password or self.private_key_path) + + def model_post_init(self, *args, **kwargs): + if not self.has_auth: + raise ValueError("Either password or private_key_path must be provided") + + +class SnowflakeSearchToolInput(BaseModel): + """Input schema for SnowflakeSearchTool.""" + + model_config = ConfigDict(protected_namespaces=()) + + query: str = Field(..., description="SQL query or semantic search query to execute") + database: Optional[str] = Field(None, description="Override default database") + snowflake_schema: Optional[str] = Field(None, description="Override default schema") + timeout: Optional[int] = Field(300, description="Query timeout in seconds") + + +class SnowflakeSearchTool(BaseTool): + """Tool for executing queries and semantic search on Snowflake.""" + + name: str = "Snowflake Database Search" + description: str = ( + "Execute SQL queries or semantic search on Snowflake data warehouse. " + "Supports both raw SQL and natural language queries." 
+ ) + args_schema: Type[BaseModel] = SnowflakeSearchToolInput + + # Define Pydantic fields + config: SnowflakeConfig = Field( + ..., description="Snowflake connection configuration" + ) + pool_size: int = Field(default=5, description="Size of connection pool") + max_retries: int = Field(default=3, description="Maximum retry attempts") + retry_delay: float = Field( + default=1.0, description="Delay between retries in seconds" + ) + enable_caching: bool = Field( + default=True, description="Enable query result caching" + ) + + model_config = ConfigDict(arbitrary_types_allowed=True) + + def __init__(self, **data): + """Initialize SnowflakeSearchTool.""" + super().__init__(**data) + self._connection_pool: List[SnowflakeConnection] = [] + self._pool_lock = asyncio.Lock() + self._thread_pool = ThreadPoolExecutor(max_workers=self.pool_size) + + async def _get_connection(self) -> SnowflakeConnection: + """Get a connection from the pool or create a new one.""" + async with self._pool_lock: + if not self._connection_pool: + conn = self._create_connection() + self._connection_pool.append(conn) + return self._connection_pool.pop() + + def _create_connection(self) -> SnowflakeConnection: + """Create a new Snowflake connection.""" + conn_params = { + "account": self.config.account, + "user": self.config.user, + "warehouse": self.config.warehouse, + "database": self.config.database, + "schema": self.config.snowflake_schema, + "role": self.config.role, + "session_parameters": self.config.session_parameters, + } + + if self.config.password: + conn_params["password"] = self.config.password.get_secret_value() + elif self.config.private_key_path: + with open(self.config.private_key_path, "rb") as key_file: + p_key = serialization.load_pem_private_key( + key_file.read(), password=None, backend=default_backend() + ) + conn_params["private_key"] = p_key + + return snowflake.connector.connect(**conn_params) + + def _get_cache_key(self, query: str, timeout: int) -> str: + """Generate a 
cache key for the query.""" + return f"{self.config.account}:{self.config.database}:{self.config.snowflake_schema}:{query}:{timeout}" + + async def _execute_query( + self, query: str, timeout: int = 300 + ) -> List[Dict[str, Any]]: + """Execute a query with retries and return results.""" + if self.enable_caching: + cache_key = self._get_cache_key(query, timeout) + if cache_key in _query_cache: + logger.info("Returning cached result") + return _query_cache[cache_key] + + for attempt in range(self.max_retries): + try: + conn = await self._get_connection() + try: + cursor = conn.cursor() + cursor.execute(query, timeout=timeout) + + if not cursor.description: + return [] + + columns = [col[0] for col in cursor.description] + results = [dict(zip(columns, row)) for row in cursor.fetchall()] + + if self.enable_caching: + _query_cache[self._get_cache_key(query, timeout)] = results + + return results + finally: + cursor.close() + async with self._pool_lock: + self._connection_pool.append(conn) + except (DatabaseError, OperationalError) as e: + if attempt == self.max_retries - 1: + raise + await asyncio.sleep(self.retry_delay * (2**attempt)) + logger.warning(f"Query failed, attempt {attempt + 1}: {str(e)}") + continue + + async def _run( + self, + query: str, + database: Optional[str] = None, + snowflake_schema: Optional[str] = None, + timeout: int = 300, + **kwargs: Any, + ) -> Any: + """Execute the search query.""" + try: + # Override database/schema if provided + if database: + await self._execute_query(f"USE DATABASE {database}") + if snowflake_schema: + await self._execute_query(f"USE SCHEMA {snowflake_schema}") + + results = await self._execute_query(query, timeout) + return results + except Exception as e: + logger.error(f"Error executing query: {str(e)}") + raise + + def __del__(self): + """Cleanup connections on deletion.""" + try: + for conn in getattr(self, "_connection_pool", []): + try: + conn.close() + except: + pass + if hasattr(self, "_thread_pool"): + 
self._thread_pool.shutdown() + except: + pass diff --git a/src/crewai_tools/tools/stagehand_tool/stagehand_tool.py b/src/crewai_tools/tools/stagehand_tool/stagehand_tool.py index 07c76c8c3..37b414509 100644 --- a/src/crewai_tools/tools/stagehand_tool/stagehand_tool.py +++ b/src/crewai_tools/tools/stagehand_tool/stagehand_tool.py @@ -14,9 +14,8 @@ import os from functools import lru_cache from typing import Any, Dict, List, Optional, Type, Union -from pydantic import BaseModel, Field - from crewai.tools.base_tool import BaseTool +from pydantic import BaseModel, Field # Set up logging logger = logging.getLogger(__name__) @@ -25,6 +24,7 @@ logger = logging.getLogger(__name__) STAGEHAND_AVAILABLE = False try: import stagehand + STAGEHAND_AVAILABLE = True except ImportError: pass # Keep STAGEHAND_AVAILABLE as False @@ -32,33 +32,45 @@ except ImportError: class StagehandResult(BaseModel): """Result from a Stagehand operation. - + Attributes: success: Whether the operation completed successfully data: The result data from the operation error: Optional error message if the operation failed """ - success: bool = Field(..., description="Whether the operation completed successfully") - data: Union[str, Dict, List] = Field(..., description="The result data from the operation") - error: Optional[str] = Field(None, description="Optional error message if the operation failed") + + success: bool = Field( + ..., description="Whether the operation completed successfully" + ) + data: Union[str, Dict, List] = Field( + ..., description="The result data from the operation" + ) + error: Optional[str] = Field( + None, description="Optional error message if the operation failed" + ) class StagehandToolConfig(BaseModel): """Configuration for the StagehandTool. 
- + Attributes: api_key: OpenAI API key for Stagehand authentication timeout: Maximum time in seconds to wait for operations (default: 30) retry_attempts: Number of times to retry failed operations (default: 3) """ + api_key: str = Field(..., description="OpenAI API key for Stagehand authentication") - timeout: int = Field(30, description="Maximum time in seconds to wait for operations") - retry_attempts: int = Field(3, description="Number of times to retry failed operations") + timeout: int = Field( + 30, description="Maximum time in seconds to wait for operations" + ) + retry_attempts: int = Field( + 3, description="Number of times to retry failed operations" + ) class StagehandToolSchema(BaseModel): """Schema for the StagehandTool input parameters. - + Examples: ```python # Using the 'act' API to click a button @@ -66,13 +78,13 @@ class StagehandToolSchema(BaseModel): api_method="act", instruction="Click the 'Sign In' button" ) - + # Using the 'extract' API to get text tool.run( api_method="extract", instruction="Get the text content of the main article" ) - + # Using the 'observe' API to monitor changes tool.run( api_method="observe", @@ -80,48 +92,49 @@ class StagehandToolSchema(BaseModel): ) ``` """ + api_method: str = Field( ..., description="The Stagehand API to use: 'act' for interactions, 'extract' for getting content, or 'observe' for monitoring changes", - pattern="^(act|extract|observe)$" + pattern="^(act|extract|observe)$", ) instruction: str = Field( ..., description="An atomic instruction for Stagehand to execute. Instructions should be simple and specific to increase reliability.", min_length=1, - max_length=500 + max_length=500, ) class StagehandTool(BaseTool): """A tool for using Stagehand's AI-powered web automation capabilities. 
- + This tool provides access to Stagehand's three core APIs: - act: Perform web interactions (e.g., clicking buttons, filling forms) - extract: Extract information from web pages (e.g., getting text content) - observe: Monitor web page changes (e.g., watching for updates) - + Each function takes atomic instructions to increase reliability. - + Required Environment Variables: OPENAI_API_KEY: API key for OpenAI (required by Stagehand) - + Examples: ```python tool = StagehandTool() - + # Perform a web interaction result = tool.run( api_method="act", instruction="Click the 'Sign In' button" ) - + # Extract content from a page content = tool.run( api_method="extract", instruction="Get the text content of the main article" ) - + # Monitor for changes changes = tool.run( api_method="observe", @@ -129,7 +142,7 @@ class StagehandTool(BaseTool): ) ``` """ - + name: str = "StagehandTool" description: str = ( "A tool that uses Stagehand's AI-powered web automation to interact with websites. " @@ -137,27 +150,29 @@ class StagehandTool(BaseTool): "Each instruction should be atomic (simple and specific) to increase reliability." ) args_schema: Type[BaseModel] = StagehandToolSchema - - def __init__(self, config: StagehandToolConfig | None = None, **kwargs: Any) -> None: + + def __init__( + self, config: StagehandToolConfig | None = None, **kwargs: Any + ) -> None: """Initialize the StagehandTool. - + Args: config: Optional configuration for the tool. If not provided, will attempt to use OPENAI_API_KEY from environment. **kwargs: Additional keyword arguments passed to the base class. - + Raises: ImportError: If the stagehand package is not installed ValueError: If no API key is provided via config or environment """ super().__init__(**kwargs) - + if not STAGEHAND_AVAILABLE: raise ImportError( "The 'stagehand' package is required to use this tool. 
" "Please install it with: pip install stagehand" ) - + # Use config if provided, otherwise try environment variable if config is not None: self.config = config @@ -168,24 +183,22 @@ class StagehandTool(BaseTool): "Either provide config with api_key or set OPENAI_API_KEY environment variable" ) self.config = StagehandToolConfig( - api_key=api_key, - timeout=30, - retry_attempts=3 + api_key=api_key, timeout=30, retry_attempts=3 ) - + @lru_cache(maxsize=100) def _cached_run(self, api_method: str, instruction: str) -> Any: """Execute a cached Stagehand command. - + This method is cached to improve performance for repeated operations. - + Args: api_method: The Stagehand API to use ('act', 'extract', or 'observe') instruction: An atomic instruction for Stagehand to execute - + Returns: The raw result from the Stagehand API call - + Raises: ValueError: If an invalid api_method is provided Exception: If the Stagehand API call fails @@ -193,23 +206,25 @@ class StagehandTool(BaseTool): logger.debug( "Cache operation - Method: %s, Instruction length: %d", api_method, - len(instruction) + len(instruction), ) - + # Initialize Stagehand with configuration logger.info( "Initializing Stagehand (timeout=%ds, retries=%d)", self.config.timeout, - self.config.retry_attempts + self.config.retry_attempts, ) st = stagehand.Stagehand( api_key=self.config.api_key, timeout=self.config.timeout, - retry_attempts=self.config.retry_attempts + retry_attempts=self.config.retry_attempts, ) - + # Call the appropriate Stagehand API based on the method - logger.info("Executing %s operation with instruction: %s", api_method, instruction[:100]) + logger.info( + "Executing %s operation with instruction: %s", api_method, instruction[:100] + ) try: if api_method == "act": result = st.act(instruction) @@ -219,28 +234,27 @@ class StagehandTool(BaseTool): result = st.observe(instruction) else: raise ValueError(f"Unknown api_method: {api_method}") - - + logger.info("Successfully executed %s operation", 
api_method) return result - + except Exception as e: logger.warning( "Operation failed (method=%s, error=%s), will be retried on next attempt", api_method, - str(e) + str(e), ) raise def _run(self, api_method: str, instruction: str, **kwargs: Any) -> StagehandResult: """Execute a Stagehand command using the specified API method. - + Args: api_method: The Stagehand API to use ('act', 'extract', or 'observe') instruction: An atomic instruction for Stagehand to execute **kwargs: Additional keyword arguments passed to the Stagehand API - - Returns: + + Returns: StagehandResult containing the operation result and status """ try: @@ -249,56 +263,36 @@ class StagehandTool(BaseTool): "Starting operation - Method: %s, Instruction length: %d, Args: %s", api_method, len(instruction), - kwargs + kwargs, ) - + # Use cached execution result = self._cached_run(api_method, instruction) logger.info("Operation completed successfully") return StagehandResult(success=True, data=result) - + except stagehand.AuthenticationError as e: logger.error( - "Authentication failed - Method: %s, Error: %s", - api_method, - str(e) + "Authentication failed - Method: %s, Error: %s", api_method, str(e) ) return StagehandResult( - success=False, - data={}, - error=f"Authentication failed: {str(e)}" + success=False, data={}, error=f"Authentication failed: {str(e)}" ) except stagehand.APIError as e: - logger.error( - "API error - Method: %s, Error: %s", - api_method, - str(e) - ) - return StagehandResult( - success=False, - data={}, - error=f"API error: {str(e)}" - ) + logger.error("API error - Method: %s, Error: %s", api_method, str(e)) + return StagehandResult(success=False, data={}, error=f"API error: {str(e)}") except stagehand.BrowserError as e: - logger.error( - "Browser error - Method: %s, Error: %s", - api_method, - str(e) - ) + logger.error("Browser error - Method: %s, Error: %s", api_method, str(e)) return StagehandResult( - success=False, - data={}, - error=f"Browser error: {str(e)}" + 
success=False, data={}, error=f"Browser error: {str(e)}" ) except Exception as e: logger.error( "Unexpected error - Method: %s, Error type: %s, Message: %s", api_method, type(e).__name__, - str(e) + str(e), ) return StagehandResult( - success=False, - data={}, - error=f"Unexpected error: {str(e)}" + success=False, data={}, error=f"Unexpected error: {str(e)}" ) diff --git a/src/crewai_tools/tools/vision_tool/vision_tool.py b/src/crewai_tools/tools/vision_tool/vision_tool.py index 4fbc1df0e..594be0b22 100644 --- a/src/crewai_tools/tools/vision_tool/vision_tool.py +++ b/src/crewai_tools/tools/vision_tool/vision_tool.py @@ -1,30 +1,36 @@ import base64 -from typing import Type, Optional from pathlib import Path +from typing import Optional, Type + from crewai.tools import BaseTool from openai import OpenAI from pydantic import BaseModel, validator + class ImagePromptSchema(BaseModel): """Input for Vision Tool.""" + image_path_url: str = "The image path or URL." @validator("image_path_url") def validate_image_path_url(cls, v: str) -> str: if v.startswith("http"): return v - + path = Path(v) if not path.exists(): raise ValueError(f"Image file does not exist: {v}") - + # Validate supported formats valid_extensions = {".jpg", ".jpeg", ".png", ".gif", ".webp"} if path.suffix.lower() not in valid_extensions: - raise ValueError(f"Unsupported image format. Supported formats: {valid_extensions}") - + raise ValueError( + f"Unsupported image format. Supported formats: {valid_extensions}" + ) + return v + class VisionTool(BaseTool): name: str = "Vision Tool" description: str = ( @@ -45,10 +51,10 @@ class VisionTool(BaseTool): image_path_url = kwargs.get("image_path_url") if not image_path_url: return "Image Path or URL is required." 
- + # Validate input using Pydantic ImagePromptSchema(image_path_url=image_path_url) - + if image_path_url.startswith("http"): image_data = image_path_url else: @@ -68,12 +74,12 @@ class VisionTool(BaseTool): { "type": "image_url", "image_url": {"url": image_data}, - } + }, ], } ], max_tokens=300, - ) + ) return response.choices[0].message.content diff --git a/src/crewai_tools/tools/weaviate_tool/vector_search.py b/src/crewai_tools/tools/weaviate_tool/vector_search.py index 14e10d7c5..53f641272 100644 --- a/src/crewai_tools/tools/weaviate_tool/vector_search.py +++ b/src/crewai_tools/tools/weaviate_tool/vector_search.py @@ -15,9 +15,8 @@ except ImportError: Vectorizers = Any Auth = Any -from pydantic import BaseModel, Field - from crewai.tools import BaseTool +from pydantic import BaseModel, Field class WeaviateToolSchema(BaseModel): diff --git a/src/crewai_tools/tools/website_search/website_search_tool.py b/src/crewai_tools/tools/website_search/website_search_tool.py index faa1a02e8..842462546 100644 --- a/src/crewai_tools/tools/website_search/website_search_tool.py +++ b/src/crewai_tools/tools/website_search/website_search_tool.py @@ -25,9 +25,7 @@ class WebsiteSearchToolSchema(FixedWebsiteSearchToolSchema): class WebsiteSearchTool(RagTool): name: str = "Search in a specific website" - description: str = ( - "A tool that can be used to semantic search a query from a specific URL content." - ) + description: str = "A tool that can be used to semantic search a query from a specific URL content." 
args_schema: Type[BaseModel] = WebsiteSearchToolSchema def __init__(self, website: Optional[str] = None, **kwargs): diff --git a/src/crewai_tools/tools/youtube_channel_search_tool/youtube_channel_search_tool.py b/src/crewai_tools/tools/youtube_channel_search_tool/youtube_channel_search_tool.py index b0c6209f1..81ecc30c3 100644 --- a/src/crewai_tools/tools/youtube_channel_search_tool/youtube_channel_search_tool.py +++ b/src/crewai_tools/tools/youtube_channel_search_tool/youtube_channel_search_tool.py @@ -25,9 +25,7 @@ class YoutubeChannelSearchToolSchema(FixedYoutubeChannelSearchToolSchema): class YoutubeChannelSearchTool(RagTool): name: str = "Search a Youtube Channels content" - description: str = ( - "A tool that can be used to semantic search a query from a Youtube Channels content." - ) + description: str = "A tool that can be used to semantic search a query from a Youtube Channels content." args_schema: Type[BaseModel] = YoutubeChannelSearchToolSchema def __init__(self, youtube_channel_handle: Optional[str] = None, **kwargs): diff --git a/src/crewai_tools/tools/youtube_video_search_tool/youtube_video_search_tool.py b/src/crewai_tools/tools/youtube_video_search_tool/youtube_video_search_tool.py index 6852fafb4..1ad8434c8 100644 --- a/src/crewai_tools/tools/youtube_video_search_tool/youtube_video_search_tool.py +++ b/src/crewai_tools/tools/youtube_video_search_tool/youtube_video_search_tool.py @@ -25,9 +25,7 @@ class YoutubeVideoSearchToolSchema(FixedYoutubeVideoSearchToolSchema): class YoutubeVideoSearchTool(RagTool): name: str = "Search a Youtube Video content" - description: str = ( - "A tool that can be used to semantic search a query from a Youtube Video content." - ) + description: str = "A tool that can be used to semantic search a query from a Youtube Video content." 
args_schema: Type[BaseModel] = YoutubeVideoSearchToolSchema def __init__(self, youtube_video_url: Optional[str] = None, **kwargs): diff --git a/tests/base_tool_test.py b/tests/base_tool_test.py index 4a4e40783..e6f4f127d 100644 --- a/tests/base_tool_test.py +++ b/tests/base_tool_test.py @@ -1,69 +1,104 @@ from typing import Callable + from crewai.tools import BaseTool, tool from crewai.tools.base_tool import to_langchain + def test_creating_a_tool_using_annotation(): - @tool("Name of my tool") - def my_tool(question: str) -> str: - """Clear description for what this tool is useful for, you agent will need this information to use it.""" - return question + @tool("Name of my tool") + def my_tool(question: str) -> str: + """Clear description for what this tool is useful for, you agent will need this information to use it.""" + return question - # Assert all the right attributes were defined - assert my_tool.name == "Name of my tool" - assert my_tool.description == "Tool Name: Name of my tool\nTool Arguments: {'question': {'description': None, 'type': 'str'}}\nTool Description: Clear description for what this tool is useful for, you agent will need this information to use it." - assert my_tool.args_schema.schema()["properties"] == {'question': {'title': 'Question', 'type': 'string'}} - assert my_tool.func("What is the meaning of life?") == "What is the meaning of life?" + # Assert all the right attributes were defined + assert my_tool.name == "Name of my tool" + assert ( + my_tool.description + == "Tool Name: Name of my tool\nTool Arguments: {'question': {'description': None, 'type': 'str'}}\nTool Description: Clear description for what this tool is useful for, you agent will need this information to use it." + ) + assert my_tool.args_schema.schema()["properties"] == { + "question": {"title": "Question", "type": "string"} + } + assert ( + my_tool.func("What is the meaning of life?") == "What is the meaning of life?" 
+ ) + + # Assert the langchain tool conversion worked as expected + converted_tool = to_langchain([my_tool])[0] + assert converted_tool.name == "Name of my tool" + assert ( + converted_tool.description + == "Tool Name: Name of my tool\nTool Arguments: {'question': {'description': None, 'type': 'str'}}\nTool Description: Clear description for what this tool is useful for, you agent will need this information to use it." + ) + assert converted_tool.args_schema.schema()["properties"] == { + "question": {"title": "Question", "type": "string"} + } + assert ( + converted_tool.func("What is the meaning of life?") + == "What is the meaning of life?" + ) - # Assert the langchain tool conversion worked as expected - converted_tool = to_langchain([my_tool])[0] - assert converted_tool.name == "Name of my tool" - assert converted_tool.description == "Tool Name: Name of my tool\nTool Arguments: {'question': {'description': None, 'type': 'str'}}\nTool Description: Clear description for what this tool is useful for, you agent will need this information to use it." - assert converted_tool.args_schema.schema()["properties"] == {'question': {'title': 'Question', 'type': 'string'}} - assert converted_tool.func("What is the meaning of life?") == "What is the meaning of life?" def test_creating_a_tool_using_baseclass(): - class MyCustomTool(BaseTool): - name: str = "Name of my tool" - description: str = "Clear description for what this tool is useful for, you agent will need this information to use it." + class MyCustomTool(BaseTool): + name: str = "Name of my tool" + description: str = "Clear description for what this tool is useful for, you agent will need this information to use it." 
- def _run(self, question: str) -> str: - return question + def _run(self, question: str) -> str: + return question - my_tool = MyCustomTool() - # Assert all the right attributes were defined - assert my_tool.name == "Name of my tool" - assert my_tool.description == "Tool Name: Name of my tool\nTool Arguments: {'question': {'description': None, 'type': 'str'}}\nTool Description: Clear description for what this tool is useful for, you agent will need this information to use it." - assert my_tool.args_schema.schema()["properties"] == {'question': {'title': 'Question', 'type': 'string'}} - assert my_tool._run("What is the meaning of life?") == "What is the meaning of life?" + my_tool = MyCustomTool() + # Assert all the right attributes were defined + assert my_tool.name == "Name of my tool" + assert ( + my_tool.description + == "Tool Name: Name of my tool\nTool Arguments: {'question': {'description': None, 'type': 'str'}}\nTool Description: Clear description for what this tool is useful for, you agent will need this information to use it." + ) + assert my_tool.args_schema.schema()["properties"] == { + "question": {"title": "Question", "type": "string"} + } + assert ( + my_tool._run("What is the meaning of life?") == "What is the meaning of life?" + ) + + # Assert the langchain tool conversion worked as expected + converted_tool = to_langchain([my_tool])[0] + assert converted_tool.name == "Name of my tool" + assert ( + converted_tool.description + == "Tool Name: Name of my tool\nTool Arguments: {'question': {'description': None, 'type': 'str'}}\nTool Description: Clear description for what this tool is useful for, you agent will need this information to use it." + ) + assert converted_tool.args_schema.schema()["properties"] == { + "question": {"title": "Question", "type": "string"} + } + assert ( + converted_tool.invoke({"question": "What is the meaning of life?"}) + == "What is the meaning of life?" 
+ ) - # Assert the langchain tool conversion worked as expected - converted_tool = to_langchain([my_tool])[0] - assert converted_tool.name == "Name of my tool" - assert converted_tool.description == "Tool Name: Name of my tool\nTool Arguments: {'question': {'description': None, 'type': 'str'}}\nTool Description: Clear description for what this tool is useful for, you agent will need this information to use it." - assert converted_tool.args_schema.schema()["properties"] == {'question': {'title': 'Question', 'type': 'string'}} - assert converted_tool.invoke({"question": "What is the meaning of life?"}) == "What is the meaning of life?" def test_setting_cache_function(): - class MyCustomTool(BaseTool): - name: str = "Name of my tool" - description: str = "Clear description for what this tool is useful for, you agent will need this information to use it." - cache_function: Callable = lambda: False + class MyCustomTool(BaseTool): + name: str = "Name of my tool" + description: str = "Clear description for what this tool is useful for, you agent will need this information to use it." + cache_function: Callable = lambda: False - def _run(self, question: str) -> str: - return question + def _run(self, question: str) -> str: + return question + + my_tool = MyCustomTool() + # Assert all the right attributes were defined + assert my_tool.cache_function() == False - my_tool = MyCustomTool() - # Assert all the right attributes were defined - assert my_tool.cache_function() == False def test_default_cache_function_is_true(): - class MyCustomTool(BaseTool): - name: str = "Name of my tool" - description: str = "Clear description for what this tool is useful for, you agent will need this information to use it." + class MyCustomTool(BaseTool): + name: str = "Name of my tool" + description: str = "Clear description for what this tool is useful for, you agent will need this information to use it." 
- def _run(self, question: str) -> str: - return question + def _run(self, question: str) -> str: + return question - my_tool = MyCustomTool() - # Assert all the right attributes were defined - assert my_tool.cache_function() == True \ No newline at end of file + my_tool = MyCustomTool() + # Assert all the right attributes were defined + assert my_tool.cache_function() == True diff --git a/tests/file_read_tool_test.py b/tests/file_read_tool_test.py index 4646df24c..5957f863b 100644 --- a/tests/file_read_tool_test.py +++ b/tests/file_read_tool_test.py @@ -1,7 +1,8 @@ import os -import pytest + from crewai_tools import FileReadTool + def test_file_read_tool_constructor(): """Test FileReadTool initialization with file_path.""" # Create a temporary test file @@ -18,6 +19,7 @@ def test_file_read_tool_constructor(): # Clean up os.remove(test_file) + def test_file_read_tool_run(): """Test FileReadTool _run method with file_path at runtime.""" # Create a temporary test file @@ -34,6 +36,7 @@ def test_file_read_tool_run(): # Clean up os.remove(test_file) + def test_file_read_tool_error_handling(): """Test FileReadTool error handling.""" # Test missing file path @@ -58,6 +61,7 @@ def test_file_read_tool_error_handling(): os.chmod(test_file, 0o666) # Restore permissions to delete os.remove(test_file) + def test_file_read_tool_constructor_and_run(): """Test FileReadTool using both constructor and runtime file paths.""" # Create two test files diff --git a/tests/it/tools/__init__.py b/tests/it/tools/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/it/tools/conftest.py b/tests/it/tools/conftest.py new file mode 100644 index 000000000..a633c22c7 --- /dev/null +++ b/tests/it/tools/conftest.py @@ -0,0 +1,21 @@ +import pytest + + +def pytest_configure(config): + """Register custom markers.""" + config.addinivalue_line("markers", "integration: mark test as an integration test") + config.addinivalue_line("markers", "asyncio: mark test as an async test") + 
+ # Set the asyncio loop scope through ini configuration + config.inicfg["asyncio_mode"] = "auto" + + +@pytest.fixture(scope="function") +def event_loop(): + """Create an instance of the default event loop for each test case.""" + import asyncio + + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + yield loop + loop.close() diff --git a/tests/it/tools/snowflake_search_tool_test.py b/tests/it/tools/snowflake_search_tool_test.py new file mode 100644 index 000000000..70dc07953 --- /dev/null +++ b/tests/it/tools/snowflake_search_tool_test.py @@ -0,0 +1,219 @@ +import asyncio +import json +from decimal import Decimal + +import pytest +from snowflake.connector.errors import DatabaseError, OperationalError + +from crewai_tools import SnowflakeConfig, SnowflakeSearchTool + +# Test Data +MENU_ITEMS = [ + (10001, "Ice Cream", "Freezing Point", "Lemonade", "Beverage", "Cold Option", 1, 4), + ( + 10002, + "Ice Cream", + "Freezing Point", + "Vanilla Ice Cream", + "Dessert", + "Ice Cream", + 2, + 6, + ), +] + +INVALID_QUERIES = [ + ("SELECT * FROM nonexistent_table", "relation 'nonexistent_table' does not exist"), + ("SELECT invalid_column FROM menu", "invalid identifier 'invalid_column'"), + ("INVALID SQL QUERY", "SQL compilation error"), +] + + +# Integration Test Fixtures +@pytest.fixture +def config(): + """Create a Snowflake configuration with test credentials.""" + return SnowflakeConfig( + account="lwyhjun-wx11931", + user="crewgitci", + password="crewaiT00ls_publicCIpass123", + warehouse="COMPUTE_WH", + database="tasty_bytes_sample_data", + snowflake_schema="raw_pos", + ) + + +@pytest.fixture +def snowflake_tool(config): + """Create a SnowflakeSearchTool instance.""" + return SnowflakeSearchTool(config=config) + + +# Integration Tests with Real Snowflake Connection +@pytest.mark.integration +@pytest.mark.asyncio +@pytest.mark.parametrize( + "menu_id,expected_type,brand,item_name,category,subcategory,cost,price", MENU_ITEMS +) +async def test_menu_items( + 
snowflake_tool, + menu_id, + expected_type, + brand, + item_name, + category, + subcategory, + cost, + price, +): + """Test menu items with parameterized data for multiple test cases.""" + results = await snowflake_tool._run( + query=f"SELECT * FROM menu WHERE menu_id = {menu_id}" + ) + assert len(results) == 1 + menu_item = results[0] + + # Validate all fields + assert menu_item["MENU_ID"] == menu_id + assert menu_item["MENU_TYPE"] == expected_type + assert menu_item["TRUCK_BRAND_NAME"] == brand + assert menu_item["MENU_ITEM_NAME"] == item_name + assert menu_item["ITEM_CATEGORY"] == category + assert menu_item["ITEM_SUBCATEGORY"] == subcategory + assert menu_item["COST_OF_GOODS_USD"] == cost + assert menu_item["SALE_PRICE_USD"] == price + + # Validate health metrics JSON structure + health_metrics = json.loads(menu_item["MENU_ITEM_HEALTH_METRICS_OBJ"]) + assert "menu_item_health_metrics" in health_metrics + metrics = health_metrics["menu_item_health_metrics"][0] + assert "ingredients" in metrics + assert isinstance(metrics["ingredients"], list) + assert all(isinstance(ingredient, str) for ingredient in metrics["ingredients"]) + assert metrics["is_dairy_free_flag"] in ["Y", "N"] + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_menu_categories_aggregation(snowflake_tool): + """Test complex aggregation query on menu categories with detailed validations.""" + results = await snowflake_tool._run( + query=""" + SELECT + item_category, + COUNT(*) as item_count, + AVG(sale_price_usd) as avg_price, + SUM(sale_price_usd - cost_of_goods_usd) as total_margin, + COUNT(DISTINCT menu_type) as menu_type_count, + MIN(sale_price_usd) as min_price, + MAX(sale_price_usd) as max_price + FROM menu + GROUP BY item_category + HAVING COUNT(*) > 1 + ORDER BY item_count DESC + """ + ) + + assert len(results) > 0 + for category in results: + # Basic presence checks + assert all( + key in category + for key in [ + "ITEM_CATEGORY", + "ITEM_COUNT", + "AVG_PRICE", + 
"TOTAL_MARGIN", + "MENU_TYPE_COUNT", + "MIN_PRICE", + "MAX_PRICE", + ] + ) + + # Value validations + assert category["ITEM_COUNT"] > 1 # Due to HAVING clause + assert category["MIN_PRICE"] <= category["MAX_PRICE"] + assert category["AVG_PRICE"] >= category["MIN_PRICE"] + assert category["AVG_PRICE"] <= category["MAX_PRICE"] + assert category["MENU_TYPE_COUNT"] >= 1 + assert isinstance(category["TOTAL_MARGIN"], (float, Decimal)) + + +@pytest.mark.integration +@pytest.mark.asyncio +@pytest.mark.parametrize("invalid_query,expected_error", INVALID_QUERIES) +async def test_invalid_queries(snowflake_tool, invalid_query, expected_error): + """Test error handling for invalid queries.""" + with pytest.raises((DatabaseError, OperationalError)) as exc_info: + await snowflake_tool._run(query=invalid_query) + assert expected_error.lower() in str(exc_info.value).lower() + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_concurrent_queries(snowflake_tool): + """Test handling of concurrent queries.""" + queries = [ + "SELECT COUNT(*) FROM menu", + "SELECT COUNT(DISTINCT menu_type) FROM menu", + "SELECT COUNT(DISTINCT item_category) FROM menu", + ] + + tasks = [snowflake_tool._run(query=query) for query in queries] + results = await asyncio.gather(*tasks) + + assert len(results) == 3 + assert all(isinstance(result, list) for result in results) + assert all(len(result) == 1 for result in results) + assert all(isinstance(result[0], dict) for result in results) + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_query_timeout(snowflake_tool): + """Test query timeout handling with a complex query.""" + with pytest.raises((DatabaseError, OperationalError)) as exc_info: + await snowflake_tool._run( + query=""" + WITH RECURSIVE numbers AS ( + SELECT 1 as n + UNION ALL + SELECT n + 1 + FROM numbers + WHERE n < 1000000 + ) + SELECT COUNT(*) FROM numbers + """ + ) + assert ( + "timeout" in str(exc_info.value).lower() + or "execution time" in 
str(exc_info.value).lower() + ) + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_caching_behavior(snowflake_tool): + """Test query caching behavior and performance.""" + query = "SELECT * FROM menu LIMIT 5" + + # First execution + start_time = asyncio.get_event_loop().time() + results1 = await snowflake_tool._run(query=query) + first_duration = asyncio.get_event_loop().time() - start_time + + # Second execution (should be cached) + start_time = asyncio.get_event_loop().time() + results2 = await snowflake_tool._run(query=query) + second_duration = asyncio.get_event_loop().time() - start_time + + # Verify results + assert results1 == results2 + assert len(results1) == 5 + assert second_duration < first_duration + + # Verify cache invalidation with different query + different_query = "SELECT * FROM menu LIMIT 10" + different_results = await snowflake_tool._run(query=different_query) + assert len(different_results) == 10 + assert different_results != results1 diff --git a/tests/spider_tool_test.py b/tests/spider_tool_test.py index 264394777..7f5613fe6 100644 --- a/tests/spider_tool_test.py +++ b/tests/spider_tool_test.py @@ -1,5 +1,7 @@ +from crewai import Agent, Crew, Task + from crewai_tools.tools.spider_tool.spider_tool import SpiderTool -from crewai import Agent, Task, Crew + def test_spider_tool(): spider_tool = SpiderTool() @@ -10,38 +12,35 @@ def test_spider_tool(): backstory="An expert web researcher that uses the web extremely well", tools=[spider_tool], verbose=True, - cache=False + cache=False, ) choose_between_scrape_crawl = Task( description="Scrape the page of spider.cloud and return a summary of how fast it is", expected_output="spider.cloud is a fast scraping and crawling tool", - agent=searcher + agent=searcher, ) return_metadata = Task( description="Scrape https://spider.cloud with a limit of 1 and enable metadata", expected_output="Metadata and 10 word summary of spider.cloud", - agent=searcher + agent=searcher, ) css_selector = 
Task( description="Scrape one page of spider.cloud with the `body > div > main > section.grid.md\:grid-cols-2.gap-10.place-items-center.md\:max-w-screen-xl.mx-auto.pb-8.pt-20 > div:nth-child(1) > h1` CSS selector", expected_output="The content of the element with the css selector body > div > main > section.grid.md\:grid-cols-2.gap-10.place-items-center.md\:max-w-screen-xl.mx-auto.pb-8.pt-20 > div:nth-child(1) > h1", - agent=searcher + agent=searcher, ) crew = Crew( agents=[searcher], - tasks=[ - choose_between_scrape_crawl, - return_metadata, - css_selector - ], - verbose=True + tasks=[choose_between_scrape_crawl, return_metadata, css_selector], + verbose=True, ) crew.kickoff() + if __name__ == "__main__": test_spider_tool() diff --git a/tests/tools/snowflake_search_tool_test.py b/tests/tools/snowflake_search_tool_test.py new file mode 100644 index 000000000..d4851b8ab --- /dev/null +++ b/tests/tools/snowflake_search_tool_test.py @@ -0,0 +1,103 @@ +import asyncio +from unittest.mock import MagicMock, patch + +import pytest + +from crewai_tools import SnowflakeConfig, SnowflakeSearchTool + + +# Unit Test Fixtures +@pytest.fixture +def mock_snowflake_connection(): + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_cursor.description = [("col1",), ("col2",)] + mock_cursor.fetchall.return_value = [(1, "value1"), (2, "value2")] + mock_cursor.execute.return_value = None + mock_conn.cursor.return_value = mock_cursor + return mock_conn + + +@pytest.fixture +def mock_config(): + return SnowflakeConfig( + account="test_account", + user="test_user", + password="test_password", + warehouse="test_warehouse", + database="test_db", + snowflake_schema="test_schema", + ) + + +@pytest.fixture +def snowflake_tool(mock_config): + with patch("snowflake.connector.connect") as mock_connect: + tool = SnowflakeSearchTool(config=mock_config) + yield tool + + +# Unit Tests +@pytest.mark.asyncio +async def test_successful_query_execution(snowflake_tool, mock_snowflake_connection): 
+ with patch.object(snowflake_tool, "_create_connection") as mock_create_conn: + mock_create_conn.return_value = mock_snowflake_connection + + results = await snowflake_tool._run( + query="SELECT * FROM test_table", timeout=300 + ) + + assert len(results) == 2 + assert results[0]["col1"] == 1 + assert results[0]["col2"] == "value1" + mock_snowflake_connection.cursor.assert_called_once() + + +@pytest.mark.asyncio +async def test_connection_pooling(snowflake_tool, mock_snowflake_connection): + with patch.object(snowflake_tool, "_create_connection") as mock_create_conn: + mock_create_conn.return_value = mock_snowflake_connection + + # Execute multiple queries + await asyncio.gather( + snowflake_tool._run("SELECT 1"), + snowflake_tool._run("SELECT 2"), + snowflake_tool._run("SELECT 3"), + ) + + # Should reuse connections from pool + assert mock_create_conn.call_count <= snowflake_tool.pool_size + + +@pytest.mark.asyncio +async def test_cleanup_on_deletion(snowflake_tool, mock_snowflake_connection): + with patch.object(snowflake_tool, "_create_connection") as mock_create_conn: + mock_create_conn.return_value = mock_snowflake_connection + + # Add connection to pool + await snowflake_tool._get_connection() + + # Return connection to pool + async with snowflake_tool._pool_lock: + snowflake_tool._connection_pool.append(mock_snowflake_connection) + + # Trigger cleanup + snowflake_tool.__del__() + + mock_snowflake_connection.close.assert_called_once() + + +def test_config_validation(): + # Test missing required fields + with pytest.raises(ValueError): + SnowflakeConfig() + + # Test invalid account format + with pytest.raises(ValueError): + SnowflakeConfig( + account="invalid//account", user="test_user", password="test_pass" + ) + + # Test missing authentication + with pytest.raises(ValueError): + SnowflakeConfig(account="test_account", user="test_user") diff --git a/tests/tools/test_code_interpreter_tool.py b/tests/tools/test_code_interpreter_tool.py index 
6470c9dc1..e281fffaf 100644 --- a/tests/tools/test_code_interpreter_tool.py +++ b/tests/tools/test_code_interpreter_tool.py @@ -7,7 +7,9 @@ from crewai_tools.tools.code_interpreter_tool.code_interpreter_tool import ( class TestCodeInterpreterTool(unittest.TestCase): - @patch("crewai_tools.tools.code_interpreter_tool.code_interpreter_tool.docker_from_env") + @patch( + "crewai_tools.tools.code_interpreter_tool.code_interpreter_tool.docker_from_env" + ) def test_run_code_in_docker(self, docker_mock): tool = CodeInterpreterTool() code = "print('Hello, World!')" @@ -15,14 +17,14 @@ class TestCodeInterpreterTool(unittest.TestCase): expected_output = "Hello, World!\n" docker_mock().containers.run().exec_run().exit_code = 0 - docker_mock().containers.run().exec_run().output = ( - expected_output.encode() - ) + docker_mock().containers.run().exec_run().output = expected_output.encode() result = tool.run_code_in_docker(code, libraries_used) self.assertEqual(result, expected_output) - @patch("crewai_tools.tools.code_interpreter_tool.code_interpreter_tool.docker_from_env") + @patch( + "crewai_tools.tools.code_interpreter_tool.code_interpreter_tool.docker_from_env" + ) def test_run_code_in_docker_with_error(self, docker_mock): tool = CodeInterpreterTool() code = "print(1/0)" @@ -37,7 +39,9 @@ class TestCodeInterpreterTool(unittest.TestCase): self.assertEqual(result, expected_output) - @patch("crewai_tools.tools.code_interpreter_tool.code_interpreter_tool.docker_from_env") + @patch( + "crewai_tools.tools.code_interpreter_tool.code_interpreter_tool.docker_from_env" + ) def test_run_code_in_docker_with_script(self, docker_mock): tool = CodeInterpreterTool() code = """print("This is line 1") From a606f48b70b346e70bf3bbfdad78290367f4469b Mon Sep 17 00:00:00 2001 From: ArchiusVuong-sudo Date: Sat, 18 Jan 2025 21:58:50 +0700 Subject: [PATCH 2/3] FIX: Fix HTTPError cannot be found in serperai --- .../tools/serpapi_tool/serpapi_google_search_tool.py | 2 +- 1 file changed, 1 
insertion(+), 1 deletion(-) diff --git a/src/crewai_tools/tools/serpapi_tool/serpapi_google_search_tool.py b/src/crewai_tools/tools/serpapi_tool/serpapi_google_search_tool.py index 199b7f5a2..f8edd6458 100644 --- a/src/crewai_tools/tools/serpapi_tool/serpapi_google_search_tool.py +++ b/src/crewai_tools/tools/serpapi_tool/serpapi_google_search_tool.py @@ -3,7 +3,7 @@ from typing import Any, Type, Optional import re from pydantic import BaseModel, Field from .serpapi_base_tool import SerpApiBaseTool -from serpapi import HTTPError +from urllib.error import HTTPError class SerpApiGoogleSearchToolSchema(BaseModel): """Input for Google Search.""" From 659cb6279e2b2833fea0d4c8da4946160100befd Mon Sep 17 00:00:00 2001 From: ArchiusVuong-sudo Date: Sat, 18 Jan 2025 23:01:01 +0700 Subject: [PATCH 3/3] fix: Fixed all from urllib.error import HTTPError --- .../tools/serpapi_tool/serpapi_google_shopping_tool.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/crewai_tools/tools/serpapi_tool/serpapi_google_shopping_tool.py b/src/crewai_tools/tools/serpapi_tool/serpapi_google_shopping_tool.py index b44b3a809..5863239c5 100644 --- a/src/crewai_tools/tools/serpapi_tool/serpapi_google_shopping_tool.py +++ b/src/crewai_tools/tools/serpapi_tool/serpapi_google_shopping_tool.py @@ -3,7 +3,7 @@ from typing import Any, Type, Optional import re from pydantic import BaseModel, Field from .serpapi_base_tool import SerpApiBaseTool -from serpapi import HTTPError +from urllib.error import HTTPError class SerpApiGoogleShoppingToolSchema(BaseModel): """Input for Google Shopping."""