From 29da6659cf2eeff61c515642efabfbffd8052969 Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Thu, 2 Jan 2025 19:40:56 +0530 Subject: [PATCH 01/23] added the skeleton for the AIMind tool --- .../tools/ai_minds_tool/README.md | 0 .../tools/ai_minds_tool/__init__.py | 0 .../tools/ai_minds_tool/ai_minds_tool.py | 40 +++++++++++++++++++ 3 files changed, 40 insertions(+) create mode 100644 src/crewai_tools/tools/ai_minds_tool/README.md create mode 100644 src/crewai_tools/tools/ai_minds_tool/__init__.py create mode 100644 src/crewai_tools/tools/ai_minds_tool/ai_minds_tool.py diff --git a/src/crewai_tools/tools/ai_minds_tool/README.md b/src/crewai_tools/tools/ai_minds_tool/README.md new file mode 100644 index 000000000..e69de29bb diff --git a/src/crewai_tools/tools/ai_minds_tool/__init__.py b/src/crewai_tools/tools/ai_minds_tool/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/crewai_tools/tools/ai_minds_tool/ai_minds_tool.py b/src/crewai_tools/tools/ai_minds_tool/ai_minds_tool.py new file mode 100644 index 000000000..99d8e3f8f --- /dev/null +++ b/src/crewai_tools/tools/ai_minds_tool/ai_minds_tool.py @@ -0,0 +1,40 @@ +from typing import Dict, Optional, Type, TYPE_CHECKING + +from crewai.tools import BaseTool +from openai import OpenAI +from pydantic import BaseModel + +if TYPE_CHECKING: + from minds_sdk import Client + + +class AIMindInputSchema(BaseModel): + """Input for AIMind Tool.""" + + query: str = "Question in natural language to ask the AI-Mind" + + +class AIMindTool(BaseTool): + name: str = "AIMind Tool" + description: str = ( + "A wrapper around [AI-Minds](https://mindsdb.com/minds). " + "Useful for when you need answers to questions from your data, stored in " + "data sources including PostgreSQL, MySQL, MariaDB, ClickHouse, Snowflake " + "and Google BigQuery. " + "Input should be a question in natural language." 
+ ) + args_schema: Type[BaseModel] = AIMindInputSchema + api_key: Optional[str] = None + datasources: Optional[Dict] = None + minds_client: Optional["Client"] = None + + def __init__(self, api_key: Optional[str] = None, **kwargs): + super().__init__(**kwargs) + try: + from minds_sdk import Client # type: ignore + except ImportError: + raise ImportError( + "`minds_sdk` package not found, please run `pip install minds-sdk`" + ) + + self.minds_client = Client(api_key=api_key) \ No newline at end of file From 55f669989bca634ba61919f1295ee8f47b4a208c Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Fri, 3 Jan 2025 00:28:30 +0530 Subject: [PATCH 02/23] completed the initialization logic for the tool --- .../tools/ai_minds_tool/ai_minds_tool.py | 33 +++++++++++++++---- 1 file changed, 27 insertions(+), 6 deletions(-) diff --git a/src/crewai_tools/tools/ai_minds_tool/ai_minds_tool.py b/src/crewai_tools/tools/ai_minds_tool/ai_minds_tool.py index 99d8e3f8f..411daf209 100644 --- a/src/crewai_tools/tools/ai_minds_tool/ai_minds_tool.py +++ b/src/crewai_tools/tools/ai_minds_tool/ai_minds_tool.py @@ -1,12 +1,10 @@ -from typing import Dict, Optional, Type, TYPE_CHECKING +import secrets +from typing import Dict, Optional, Text, Type from crewai.tools import BaseTool from openai import OpenAI from pydantic import BaseModel -if TYPE_CHECKING: - from minds_sdk import Client - class AIMindInputSchema(BaseModel): """Input for AIMind Tool.""" @@ -26,15 +24,38 @@ class AIMindTool(BaseTool): args_schema: Type[BaseModel] = AIMindInputSchema api_key: Optional[str] = None datasources: Optional[Dict] = None - minds_client: Optional["Client"] = None + mind_name: Optional[Text] = None def __init__(self, api_key: Optional[str] = None, **kwargs): super().__init__(**kwargs) try: from minds_sdk import Client # type: ignore + from minds.datasources import DatabaseConfig # type: ignore except ImportError: raise ImportError( "`minds_sdk` package not found, please run `pip install minds-sdk`" ) - 
self.minds_client = Client(api_key=api_key) \ No newline at end of file + minds_client = Client(api_key=api_key) + + # Convert the datasources to DatabaseConfig objects. + datasources = [] + for datasource in self.datasources: + if datasource["type"] == "database": + config = DatabaseConfig( + name=datasource["name"], + engine=datasource["engine"], + description=datasource["description"], + connection_data=datasource["connection_data"], + tables=datasource["tables"], + ) + datasources.append(config) + + # Generate a random name for the Mind. + name = f"cai_mind_{secrets.token_hex(5)}" + + mind = minds_client.minds.create( + name=name, datasources=datasources, replace=True + ) + + self.mind_name = mind.name \ No newline at end of file From 0b5f0841bf235eb029e4d24c19efd00c90bbeccd Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Fri, 3 Jan 2025 00:32:24 +0530 Subject: [PATCH 03/23] implemented the run function for the tool --- .../tools/ai_minds_tool/ai_minds_tool.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/src/crewai_tools/tools/ai_minds_tool/ai_minds_tool.py b/src/crewai_tools/tools/ai_minds_tool/ai_minds_tool.py index 411daf209..915ed1ca0 100644 --- a/src/crewai_tools/tools/ai_minds_tool/ai_minds_tool.py +++ b/src/crewai_tools/tools/ai_minds_tool/ai_minds_tool.py @@ -58,4 +58,20 @@ class AIMindTool(BaseTool): name=name, datasources=datasources, replace=True ) - self.mind_name = mind.name \ No newline at end of file + self.mind_name = mind.name + + def _run( + self, + query: Text + ): + # Run the query on the AI-Mind. + # The Minds API is OpenAI compatible and therefore, the OpenAI client can be used. 
+ openai_client = OpenAI(base_url="https://mdb.ai/", api_key=self.api_key) + + completion = openai_client.create( + model=self.mind_name, + messages=[{"role": "user", "content": query}], + stream=False, + ) + + return completion.choices[0].message.content \ No newline at end of file From 555638a654f61c2c07afc68320e272421be94f7b Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Fri, 3 Jan 2025 00:37:12 +0530 Subject: [PATCH 04/23] added the main import statements --- src/crewai_tools/__init__.py | 1 + src/crewai_tools/tools/__init__.py | 1 + src/crewai_tools/tools/ai_minds_tool/ai_minds_tool.py | 2 +- 3 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/crewai_tools/__init__.py b/src/crewai_tools/__init__.py index 87aca8531..a0e384683 100644 --- a/src/crewai_tools/__init__.py +++ b/src/crewai_tools/__init__.py @@ -1,4 +1,5 @@ from .tools import ( + AIMindTool, BraveSearchTool, BrowserbaseLoadTool, CodeDocsSearchTool, diff --git a/src/crewai_tools/tools/__init__.py b/src/crewai_tools/tools/__init__.py index f6c31f45f..c125082f3 100644 --- a/src/crewai_tools/tools/__init__.py +++ b/src/crewai_tools/tools/__init__.py @@ -1,3 +1,4 @@ +from .ai_minds_tool.ai_minds_tool import AIMindTool from .brave_search_tool.brave_search_tool import BraveSearchTool from .browserbase_load_tool.browserbase_load_tool import BrowserbaseLoadTool from .code_docs_search_tool.code_docs_search_tool import CodeDocsSearchTool diff --git a/src/crewai_tools/tools/ai_minds_tool/ai_minds_tool.py b/src/crewai_tools/tools/ai_minds_tool/ai_minds_tool.py index 915ed1ca0..8d7750771 100644 --- a/src/crewai_tools/tools/ai_minds_tool/ai_minds_tool.py +++ b/src/crewai_tools/tools/ai_minds_tool/ai_minds_tool.py @@ -26,7 +26,7 @@ class AIMindTool(BaseTool): datasources: Optional[Dict] = None mind_name: Optional[Text] = None - def __init__(self, api_key: Optional[str] = None, **kwargs): + def __init__(self, api_key: Optional[Text] = None, **kwargs): super().__init__(**kwargs) try: from minds_sdk 
import Client # type: ignore From faff58ba1cc18f2aef203164a75388a2d4f04d3f Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Fri, 3 Jan 2025 01:17:11 +0530 Subject: [PATCH 05/23] fixed a few bugs, type hints and imports --- .../tools/ai_minds_tool/ai_minds_tool.py | 27 +++++++++---------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/src/crewai_tools/tools/ai_minds_tool/ai_minds_tool.py b/src/crewai_tools/tools/ai_minds_tool/ai_minds_tool.py index 8d7750771..222271d7f 100644 --- a/src/crewai_tools/tools/ai_minds_tool/ai_minds_tool.py +++ b/src/crewai_tools/tools/ai_minds_tool/ai_minds_tool.py @@ -1,5 +1,5 @@ import secrets -from typing import Dict, Optional, Text, Type +from typing import Any, Dict, List, Optional, Text, Type from crewai.tools import BaseTool from openai import OpenAI @@ -23,13 +23,13 @@ class AIMindTool(BaseTool): ) args_schema: Type[BaseModel] = AIMindInputSchema api_key: Optional[str] = None - datasources: Optional[Dict] = None + datasources: Optional[List[Dict[str, Any]]] = None mind_name: Optional[Text] = None def __init__(self, api_key: Optional[Text] = None, **kwargs): - super().__init__(**kwargs) + super().__init__(api_key=api_key, **kwargs) try: - from minds_sdk import Client # type: ignore + from minds.client import Client # type: ignore from minds.datasources import DatabaseConfig # type: ignore except ImportError: raise ImportError( @@ -41,15 +41,14 @@ class AIMindTool(BaseTool): # Convert the datasources to DatabaseConfig objects. 
datasources = [] for datasource in self.datasources: - if datasource["type"] == "database": - config = DatabaseConfig( - name=datasource["name"], - engine=datasource["engine"], - description=datasource["description"], - connection_data=datasource["connection_data"], - tables=datasource["tables"], - ) - datasources.append(config) + config = DatabaseConfig( + name=f"cai_ds_{secrets.token_hex(5)}", + engine=datasource["engine"], + description=datasource["description"], + connection_data=datasource["connection_data"], + tables=datasource["tables"], + ) + datasources.append(config) # Generate a random name for the Mind. name = f"cai_mind_{secrets.token_hex(5)}" @@ -68,7 +67,7 @@ class AIMindTool(BaseTool): # The Minds API is OpenAI compatible and therefore, the OpenAI client can be used. openai_client = OpenAI(base_url="https://mdb.ai/", api_key=self.api_key) - completion = openai_client.create( + completion = openai_client.chat.completions.create( model=self.mind_name, messages=[{"role": "user", "content": query}], stream=False, From 64d54bd42352e54615a89211571b9c22557bee8a Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Fri, 3 Jan 2025 01:55:51 +0530 Subject: [PATCH 06/23] updated the content in the README --- .../tools/ai_minds_tool/README.md | 75 +++++++++++++++++++ 1 file changed, 75 insertions(+) diff --git a/src/crewai_tools/tools/ai_minds_tool/README.md b/src/crewai_tools/tools/ai_minds_tool/README.md index e69de29bb..7bb47cde5 100644 --- a/src/crewai_tools/tools/ai_minds_tool/README.md +++ b/src/crewai_tools/tools/ai_minds_tool/README.md @@ -0,0 +1,75 @@ +# AIMind Tool + +## Description + +[Minds](https://mindsdb.com/minds) are AI systems provided by [MindsDB](https://mindsdb.com/) that work similarly to large language models (LLMs) but go beyond by answering any question from any data. 
+
+This is accomplished by selecting the most relevant data for an answer using parametric search, understanding the meaning and providing responses within the correct context through semantic search, and finally, delivering precise answers by analyzing data and using machine learning (ML) models.
+
+## Installation
+
+1. Install the `crewai[tools]` package:
+
+```shell
+pip install 'crewai[tools]'
+```
+
+2. Install the Minds SDK:
+
+```shell
+pip install minds-sdk
+```
+
+3. Sign up for a Minds account [here](https://mdb.ai/register), and obtain an API key.
+
+4. Set the Minds API key in an environment variable named `MINDS_API_KEY`.
+
+## Usage
+
+```python
+from crewai_tools import AIMindTool
+
+
+# Initialize the AIMindTool.
+aimind_tool = AIMindTool(
+    datasources=[
+        {
+            "description": "house sales data",
+            "engine": "postgres",
+            "connection_data": {
+                "user": "demo_user",
+                "password": "demo_password",
+                "host": "samples.mindsdb.com",
+                "port": 5432,
+                "database": "demo",
+                "schema": "demo_data"
+            },
+            "tables": ["house_sales"]
+        }
+    ]
+)
+```
+
+The `datasources` parameter is a list of dictionaries, each containing the following keys:
+
+- `description`: A description of the data contained in the datasource.
+- `engine`: The engine (or type) of the datasource.
+- `connection_data`: A dictionary containing the connection parameters for the datasource.
+- `tables`: A list of tables that the data source will use.
+
+A list of supported data sources and their connection parameters can be found [here](https://docs.mdb.ai/docs/data_sources).
+
+```python
+from crewai import Agent
+from crewai.project import agent
+
+
+# Define an agent with the AIMindTool.
+@agent +def researcher(self) -> Agent: + return Agent( + config=self.agents_config["researcher"], + allow_delegation=False, + tools=[aimind_tool] + ) +``` From 3c29a6cc11cf0a221e2bb84eefcf84797ce6d450 Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Fri, 3 Jan 2025 02:25:14 +0530 Subject: [PATCH 07/23] added an example of running the tool to the README --- src/crewai_tools/tools/ai_minds_tool/README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/crewai_tools/tools/ai_minds_tool/README.md b/src/crewai_tools/tools/ai_minds_tool/README.md index 7bb47cde5..5b3755515 100644 --- a/src/crewai_tools/tools/ai_minds_tool/README.md +++ b/src/crewai_tools/tools/ai_minds_tool/README.md @@ -48,6 +48,8 @@ aimind_tool = AIMindTool( } ] ) + +aimind_tool.run("How many 3 bedroom houses were sold in 2008?") ``` The `datasources` parameter is a list of dictionaries, each containing the following keys: From 94cce06044af904a8f794511449abf250dc64c2f Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Fri, 3 Jan 2025 11:08:38 +0530 Subject: [PATCH 08/23] updated the initialization logic to allow the API key to be passed as env var --- src/crewai_tools/tools/ai_minds_tool/ai_minds_tool.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/crewai_tools/tools/ai_minds_tool/ai_minds_tool.py b/src/crewai_tools/tools/ai_minds_tool/ai_minds_tool.py index 222271d7f..1059d0053 100644 --- a/src/crewai_tools/tools/ai_minds_tool/ai_minds_tool.py +++ b/src/crewai_tools/tools/ai_minds_tool/ai_minds_tool.py @@ -1,3 +1,4 @@ +import os import secrets from typing import Any, Dict, List, Optional, Text, Type @@ -36,7 +37,13 @@ class AIMindTool(BaseTool): "`minds_sdk` package not found, please run `pip install minds-sdk`" ) - minds_client = Client(api_key=api_key) + if os.getenv("MINDS_API_KEY"): + self.api_key = os.getenv("MINDS_API_KEY") + + if self.api_key is None: + raise ValueError("A Minds API key is required to use the AIMind Tool.") + + minds_client = 
Client(api_key=self.api_key) # Convert the datasources to DatabaseConfig objects. datasources = [] From 29a7961ca8c3164f8d20f46011af4d35019895e0 Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Fri, 3 Jan 2025 11:26:16 +0530 Subject: [PATCH 09/23] refined the content in the README --- src/crewai_tools/tools/ai_minds_tool/README.md | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/crewai_tools/tools/ai_minds_tool/README.md b/src/crewai_tools/tools/ai_minds_tool/README.md index 5b3755515..95d2deb42 100644 --- a/src/crewai_tools/tools/ai_minds_tool/README.md +++ b/src/crewai_tools/tools/ai_minds_tool/README.md @@ -6,6 +6,8 @@ This is accomplished by selecting the most relevant data for an answer using parametric search, understanding the meaning and providing responses within the correct context through semantic search, and finally, delivering precise answers by analyzing data and using machine learning (ML) models. +The `AIMindTool` can be used to query data sources in natural language by simply configuring their connection parameters. + ## Installation 1. Install the `crewai[tools]` package: @@ -55,9 +57,9 @@ aimind_tool.run("How many 3 bedroom houses were sold in 2008?") The `datasources` parameter is a list of dictionaries, each containing the following keys: - `description`: A description of the data contained in the datasource. -- `engine`: The engine (or type) of the datasource. -- `connection_data`: A dictionary containing the connection parameters for the datasource. -- `tables`: A list of tables that the data source will use. +- `engine`: The engine (or type) of the datasource. Find a list of supported engines in the link below. +- `connection_data`: A dictionary containing the connection parameters for the datasource. Find a list of connection parameters for each engine in the link below. +- `tables`: A list of tables that the data source will use. 
This is optional and can be omitted if all tables in the data source are to be used. A list of supported data sources and their connection parameters can be found [here](https://docs.mdb.ai/docs/data_sources). From d360906f578c830ea6c80e7bc2e012bfc4195acc Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Fri, 3 Jan 2025 11:41:59 +0530 Subject: [PATCH 10/23] renamed the pkg and module --- src/crewai_tools/tools/__init__.py | 2 +- .../tools/{ai_minds_tool => ai_mind_tool}/README.md | 0 .../tools/{ai_minds_tool => ai_mind_tool}/__init__.py | 0 .../ai_minds_tool.py => ai_mind_tool/ai_mind_tool.py} | 0 4 files changed, 1 insertion(+), 1 deletion(-) rename src/crewai_tools/tools/{ai_minds_tool => ai_mind_tool}/README.md (100%) rename src/crewai_tools/tools/{ai_minds_tool => ai_mind_tool}/__init__.py (100%) rename src/crewai_tools/tools/{ai_minds_tool/ai_minds_tool.py => ai_mind_tool/ai_mind_tool.py} (100%) diff --git a/src/crewai_tools/tools/__init__.py b/src/crewai_tools/tools/__init__.py index c125082f3..33d68fb26 100644 --- a/src/crewai_tools/tools/__init__.py +++ b/src/crewai_tools/tools/__init__.py @@ -1,4 +1,4 @@ -from .ai_minds_tool.ai_minds_tool import AIMindTool +from .ai_mind_tool.ai_mind_tool import AIMindTool from .brave_search_tool.brave_search_tool import BraveSearchTool from .browserbase_load_tool.browserbase_load_tool import BrowserbaseLoadTool from .code_docs_search_tool.code_docs_search_tool import CodeDocsSearchTool diff --git a/src/crewai_tools/tools/ai_minds_tool/README.md b/src/crewai_tools/tools/ai_mind_tool/README.md similarity index 100% rename from src/crewai_tools/tools/ai_minds_tool/README.md rename to src/crewai_tools/tools/ai_mind_tool/README.md diff --git a/src/crewai_tools/tools/ai_minds_tool/__init__.py b/src/crewai_tools/tools/ai_mind_tool/__init__.py similarity index 100% rename from src/crewai_tools/tools/ai_minds_tool/__init__.py rename to src/crewai_tools/tools/ai_mind_tool/__init__.py diff --git 
a/src/crewai_tools/tools/ai_minds_tool/ai_minds_tool.py b/src/crewai_tools/tools/ai_mind_tool/ai_mind_tool.py similarity index 100% rename from src/crewai_tools/tools/ai_minds_tool/ai_minds_tool.py rename to src/crewai_tools/tools/ai_mind_tool/ai_mind_tool.py From d1be5a937f6b19569b76a4aab2cccc7e6355a6dd Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Fri, 3 Jan 2025 11:48:11 +0530 Subject: [PATCH 11/23] moved constants like the base URL to a class --- .../tools/ai_mind_tool/ai_mind_tool.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/crewai_tools/tools/ai_mind_tool/ai_mind_tool.py b/src/crewai_tools/tools/ai_mind_tool/ai_mind_tool.py index 1059d0053..c36400d0b 100644 --- a/src/crewai_tools/tools/ai_mind_tool/ai_mind_tool.py +++ b/src/crewai_tools/tools/ai_mind_tool/ai_mind_tool.py @@ -7,7 +7,13 @@ from openai import OpenAI from pydantic import BaseModel -class AIMindInputSchema(BaseModel): +class AIMindToolConstants: + MINDS_API_BASE_URL = "https://mdb.ai/" + MIND_NAME_PREFIX = "crwai_mind_" + DATASOURCE_NAME_PREFIX = "crwai_ds_" + + +class AIMindToolInputSchema(BaseModel): """Input for AIMind Tool.""" query: str = "Question in natural language to ask the AI-Mind" @@ -22,7 +28,7 @@ class AIMindTool(BaseTool): "and Google BigQuery. " "Input should be a question in natural language." 
) - args_schema: Type[BaseModel] = AIMindInputSchema + args_schema: Type[BaseModel] = AIMindToolInputSchema api_key: Optional[str] = None datasources: Optional[List[Dict[str, Any]]] = None mind_name: Optional[Text] = None @@ -49,7 +55,7 @@ class AIMindTool(BaseTool): datasources = [] for datasource in self.datasources: config = DatabaseConfig( - name=f"cai_ds_{secrets.token_hex(5)}", + name=f"{AIMindToolConstants.DATASOURCE_NAME_PREFIX}_{secrets.token_hex(5)}", engine=datasource["engine"], description=datasource["description"], connection_data=datasource["connection_data"], @@ -58,7 +64,7 @@ class AIMindTool(BaseTool): datasources.append(config) # Generate a random name for the Mind. - name = f"cai_mind_{secrets.token_hex(5)}" + name = f"{AIMindToolConstants.MIND_NAME_PREFIX}_{secrets.token_hex(5)}" mind = minds_client.minds.create( name=name, datasources=datasources, replace=True @@ -72,7 +78,7 @@ class AIMindTool(BaseTool): ): # Run the query on the AI-Mind. # The Minds API is OpenAI compatible and therefore, the OpenAI client can be used. 
- openai_client = OpenAI(base_url="https://mdb.ai/", api_key=self.api_key) + openai_client = OpenAI(base_url=AIMindToolConstants.MINDS_API_BASE_URL, api_key=self.api_key) completion = openai_client.chat.completions.create( model=self.mind_name, From ea85f02e035ba106ead271ee7b83d628feae2215 Mon Sep 17 00:00:00 2001 From: Minura Punchihewa Date: Fri, 3 Jan 2025 11:49:58 +0530 Subject: [PATCH 12/23] refactored the logic for accessing the API key --- src/crewai_tools/tools/ai_mind_tool/ai_mind_tool.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/crewai_tools/tools/ai_mind_tool/ai_mind_tool.py b/src/crewai_tools/tools/ai_mind_tool/ai_mind_tool.py index c36400d0b..b38426e09 100644 --- a/src/crewai_tools/tools/ai_mind_tool/ai_mind_tool.py +++ b/src/crewai_tools/tools/ai_mind_tool/ai_mind_tool.py @@ -34,7 +34,11 @@ class AIMindTool(BaseTool): mind_name: Optional[Text] = None def __init__(self, api_key: Optional[Text] = None, **kwargs): - super().__init__(api_key=api_key, **kwargs) + super().__init__(**kwargs) + self.api_key = api_key or os.getenv("MINDS_API_KEY") + if not self.api_key: + raise ValueError("API key must be provided either through constructor or MINDS_API_KEY environment variable") + try: from minds.client import Client # type: ignore from minds.datasources import DatabaseConfig # type: ignore @@ -43,12 +47,6 @@ class AIMindTool(BaseTool): "`minds_sdk` package not found, please run `pip install minds-sdk`" ) - if os.getenv("MINDS_API_KEY"): - self.api_key = os.getenv("MINDS_API_KEY") - - if self.api_key is None: - raise ValueError("A Minds API key is required to use the AIMind Tool.") - minds_client = Client(api_key=self.api_key) # Convert the datasources to DatabaseConfig objects. 
From c27727b16eb37b00773197162b32312dde16ff17 Mon Sep 17 00:00:00 2001 From: Marco Vinciguerra Date: Tue, 7 Jan 2025 15:51:52 +0100 Subject: [PATCH 13/23] Update scrapegraph_scrape_tool.py --- .../scrapegraph_scrape_tool.py | 27 +++++-------------- 1 file changed, 7 insertions(+), 20 deletions(-) diff --git a/src/crewai_tools/tools/scrapegraph_scrape_tool/scrapegraph_scrape_tool.py b/src/crewai_tools/tools/scrapegraph_scrape_tool/scrapegraph_scrape_tool.py index 906bf6376..9b5806b19 100644 --- a/src/crewai_tools/tools/scrapegraph_scrape_tool/scrapegraph_scrape_tool.py +++ b/src/crewai_tools/tools/scrapegraph_scrape_tool/scrapegraph_scrape_tool.py @@ -60,16 +60,19 @@ class ScrapegraphScrapeTool(BaseTool): website_url: Optional[str] = None user_prompt: Optional[str] = None api_key: Optional[str] = None + enable_logging: bool = False def __init__( self, website_url: Optional[str] = None, user_prompt: Optional[str] = None, api_key: Optional[str] = None, + enable_logging: bool = False, **kwargs, ): super().__init__(**kwargs) self.api_key = api_key or os.getenv("SCRAPEGRAPH_API_KEY") + self.enable_logging = enable_logging if not self.api_key: raise ValueError("Scrapegraph API key is required") @@ -83,8 +86,9 @@ class ScrapegraphScrapeTool(BaseTool): if user_prompt is not None: self.user_prompt = user_prompt - # Configure logging - sgai_logger.set_logging(level="INFO") + # Configure logging only if enabled + if self.enable_logging: + sgai_logger.set_logging(level="INFO") @staticmethod def _validate_url(url: str) -> None: @@ -96,22 +100,6 @@ class ScrapegraphScrapeTool(BaseTool): except Exception: raise ValueError("Invalid URL format. 
URL must include scheme (http/https) and domain") - def _handle_api_response(self, response: dict) -> str: - """Handle and validate API response""" - if not response: - raise RuntimeError("Empty response from Scrapegraph API") - - if "error" in response: - error_msg = response.get("error", {}).get("message", "Unknown error") - if "rate limit" in error_msg.lower(): - raise RateLimitError(f"Rate limit exceeded: {error_msg}") - raise RuntimeError(f"API error: {error_msg}") - - if "result" not in response: - raise RuntimeError("Invalid response format from Scrapegraph API") - - return response["result"] - def _run( self, **kwargs: Any, @@ -135,8 +123,7 @@ class ScrapegraphScrapeTool(BaseTool): user_prompt=user_prompt, ) - # Handle and validate the response - return self._handle_api_response(response) + return response except RateLimitError: raise # Re-raise rate limit errors From 4f4b0619079235dcdc59522879822cad5bf0e32a Mon Sep 17 00:00:00 2001 From: Marco Vinciguerra Date: Tue, 7 Jan 2025 16:13:50 +0100 Subject: [PATCH 14/23] fix: scrapegraph-tool --- .../scrapegraph_scrape_tool.py | 27 +++++-------------- 1 file changed, 7 insertions(+), 20 deletions(-) diff --git a/src/crewai_tools/tools/scrapegraph_scrape_tool/scrapegraph_scrape_tool.py b/src/crewai_tools/tools/scrapegraph_scrape_tool/scrapegraph_scrape_tool.py index 906bf6376..9b5806b19 100644 --- a/src/crewai_tools/tools/scrapegraph_scrape_tool/scrapegraph_scrape_tool.py +++ b/src/crewai_tools/tools/scrapegraph_scrape_tool/scrapegraph_scrape_tool.py @@ -60,16 +60,19 @@ class ScrapegraphScrapeTool(BaseTool): website_url: Optional[str] = None user_prompt: Optional[str] = None api_key: Optional[str] = None + enable_logging: bool = False def __init__( self, website_url: Optional[str] = None, user_prompt: Optional[str] = None, api_key: Optional[str] = None, + enable_logging: bool = False, **kwargs, ): super().__init__(**kwargs) self.api_key = api_key or os.getenv("SCRAPEGRAPH_API_KEY") + self.enable_logging = 
enable_logging if not self.api_key: raise ValueError("Scrapegraph API key is required") @@ -83,8 +86,9 @@ class ScrapegraphScrapeTool(BaseTool): if user_prompt is not None: self.user_prompt = user_prompt - # Configure logging - sgai_logger.set_logging(level="INFO") + # Configure logging only if enabled + if self.enable_logging: + sgai_logger.set_logging(level="INFO") @staticmethod def _validate_url(url: str) -> None: @@ -96,22 +100,6 @@ class ScrapegraphScrapeTool(BaseTool): except Exception: raise ValueError("Invalid URL format. URL must include scheme (http/https) and domain") - def _handle_api_response(self, response: dict) -> str: - """Handle and validate API response""" - if not response: - raise RuntimeError("Empty response from Scrapegraph API") - - if "error" in response: - error_msg = response.get("error", {}).get("message", "Unknown error") - if "rate limit" in error_msg.lower(): - raise RateLimitError(f"Rate limit exceeded: {error_msg}") - raise RuntimeError(f"API error: {error_msg}") - - if "result" not in response: - raise RuntimeError("Invalid response format from Scrapegraph API") - - return response["result"] - def _run( self, **kwargs: Any, @@ -135,8 +123,7 @@ class ScrapegraphScrapeTool(BaseTool): user_prompt=user_prompt, ) - # Handle and validate the response - return self._handle_api_response(response) + return response except RateLimitError: raise # Re-raise rate limit errors From 1a824cf432bbb9feab15ee12b27abeaaa8915e3e Mon Sep 17 00:00:00 2001 From: Nikhil Shahi Date: Mon, 13 Jan 2025 15:48:45 -0600 Subject: [PATCH 15/23] added HyperbrowserLoadTool --- src/crewai_tools/__init__.py | 1 + src/crewai_tools/tools/__init__.py | 1 + .../tools/hyperbrowser_load_tool/README.md | 42 +++++++++ .../hyperbrowser_load_tool.py | 94 +++++++++++++++++++ 4 files changed, 138 insertions(+) create mode 100644 src/crewai_tools/tools/hyperbrowser_load_tool/README.md create mode 100644 src/crewai_tools/tools/hyperbrowser_load_tool/hyperbrowser_load_tool.py diff 
--git a/src/crewai_tools/__init__.py b/src/crewai_tools/__init__.py index 2db0fa05f..ca46c34d2 100644 --- a/src/crewai_tools/__init__.py +++ b/src/crewai_tools/__init__.py @@ -16,6 +16,7 @@ from .tools import ( FirecrawlScrapeWebsiteTool, FirecrawlSearchTool, GithubSearchTool, + HyperbrowserLoadTool, JSONSearchTool, LinkupSearchTool, LlamaIndexTool, diff --git a/src/crewai_tools/tools/__init__.py b/src/crewai_tools/tools/__init__.py index e4288a310..ac42857bc 100644 --- a/src/crewai_tools/tools/__init__.py +++ b/src/crewai_tools/tools/__init__.py @@ -19,6 +19,7 @@ from .firecrawl_scrape_website_tool.firecrawl_scrape_website_tool import ( ) from .firecrawl_search_tool.firecrawl_search_tool import FirecrawlSearchTool from .github_search_tool.github_search_tool import GithubSearchTool +from .hyperbrowser_load_tool.hyperbrowser_load_tool import HyperbrowserLoadTool from .json_search_tool.json_search_tool import JSONSearchTool from .linkup.linkup_search_tool import LinkupSearchTool from .llamaindex_tool.llamaindex_tool import LlamaIndexTool diff --git a/src/crewai_tools/tools/hyperbrowser_load_tool/README.md b/src/crewai_tools/tools/hyperbrowser_load_tool/README.md new file mode 100644 index 000000000..e95864f5a --- /dev/null +++ b/src/crewai_tools/tools/hyperbrowser_load_tool/README.md @@ -0,0 +1,42 @@ +# HyperbrowserLoadTool + +## Description + +[Hyperbrowser](https://hyperbrowser.ai) is a platform for running and scaling headless browsers. It lets you launch and manage browser sessions at scale and provides easy to use solutions for any webscraping needs, such as scraping a single page or crawling an entire site. 
+
+Key Features:
+- Instant Scalability - Spin up hundreds of browser sessions in seconds without infrastructure headaches
+- Simple Integration - Works seamlessly with popular tools like Puppeteer and Playwright
+- Powerful APIs - Easy to use APIs for scraping/crawling any site, and much more
+- Bypass Anti-Bot Measures - Built-in stealth mode, ad blocking, automatic CAPTCHA solving, and rotating proxies
+
+For more information about Hyperbrowser, please visit the [Hyperbrowser website](https://hyperbrowser.ai) or if you want to check out the docs, you can visit the [Hyperbrowser docs](https://docs.hyperbrowser.ai).
+
+## Installation
+
+- Head to [Hyperbrowser](https://app.hyperbrowser.ai/) to sign up and generate an API key. Once you've done this set the `HYPERBROWSER_API_KEY` environment variable or you can pass it to the `HyperbrowserLoadTool` constructor.
+- Install the [Hyperbrowser SDK](https://github.com/hyperbrowserai/python-sdk):
+
+```
+pip install hyperbrowser 'crewai[tools]'
+```
+
+## Example
+
+Utilize the HyperbrowserLoadTool as follows to allow your agent to load websites:
+
+```python
+from crewai_tools import HyperbrowserLoadTool
+
+tool = HyperbrowserLoadTool()
+```
+
+## Arguments
+
+`__init__` arguments:
+- `api_key`: Optional. Specifies Hyperbrowser API key. Defaults to the `HYPERBROWSER_API_KEY` environment variable.
+
+`run` arguments:
+- `url`: The base URL to start scraping or crawling from.
+- `operation`: Optional. Specifies the operation to perform on the website. Either 'scrape' or 'crawl'. Defaults to 'scrape'.
+- `params`: Optional. Specifies the params for the operation. For more information on the supported params, visit https://docs.hyperbrowser.ai/reference/sdks/python/scrape#start-scrape-job-and-wait or https://docs.hyperbrowser.ai/reference/sdks/python/crawl#start-crawl-job-and-wait.
diff --git a/src/crewai_tools/tools/hyperbrowser_load_tool/hyperbrowser_load_tool.py b/src/crewai_tools/tools/hyperbrowser_load_tool/hyperbrowser_load_tool.py new file mode 100644 index 000000000..eb52b151c --- /dev/null +++ b/src/crewai_tools/tools/hyperbrowser_load_tool/hyperbrowser_load_tool.py @@ -0,0 +1,94 @@ +import os +from typing import Any, Optional, Type, Dict, Literal, Union + +from crewai.tools import BaseTool +from pydantic import BaseModel, Field + + +class HyperbrowserLoadToolSchema(BaseModel): + url: str = Field(description="Website URL") + operation: Literal['scrape', 'crawl'] = Field(description="Operation to perform on the website. Either 'scrape' or 'crawl'") + params: Optional[Dict] = Field(description="Optional params for scrape or crawl. For more information on the supported params, visit https://docs.hyperbrowser.ai/reference/sdks/python/scrape#start-scrape-job-and-wait or https://docs.hyperbrowser.ai/reference/sdks/python/crawl#start-crawl-job-and-wait") + +class HyperbrowserLoadTool(BaseTool): + name: str = "Hyperbrowser web load tool" + description: str = "Scrape or crawl a website using Hyperbrowser and return the contents in properly formatted markdown or html" + args_schema: Type[BaseModel] = HyperbrowserLoadToolSchema + api_key: Optional[str] = None + hyperbrowser: Optional[Any] = None + + def __init__(self, api_key: Optional[str] = None, **kwargs): + super().__init__(**kwargs) + self.api_key = api_key or os.getenv('HYPERBROWSER_API_KEY') + if not api_key: + raise ValueError( + "`api_key` is required, please set the `HYPERBROWSER_API_KEY` environment variable or pass it directly" + ) + + try: + from hyperbrowser import Hyperbrowser + except ImportError: + raise ImportError("`hyperbrowser` package not found, please run `pip install hyperbrowser`") + + if not self.api_key: + raise ValueError("HYPERBROWSER_API_KEY is not set. 
Please provide it either via the constructor with the `api_key` argument or by setting the HYPERBROWSER_API_KEY environment variable.") + + self.hyperbrowser = Hyperbrowser(api_key=self.api_key) + + def _prepare_params(self, params: Dict) -> Dict: + """Prepare session and scrape options parameters.""" + try: + from hyperbrowser.models.session import CreateSessionParams + from hyperbrowser.models.scrape import ScrapeOptions + except ImportError: + raise ImportError( + "`hyperbrowser` package not found, please run `pip install hyperbrowser`" + ) + + if "scrape_options" in params: + if "formats" in params["scrape_options"]: + formats = params["scrape_options"]["formats"] + if not all(fmt in ["markdown", "html"] for fmt in formats): + raise ValueError("formats can only contain 'markdown' or 'html'") + + if "session_options" in params: + params["session_options"] = CreateSessionParams(**params["session_options"]) + if "scrape_options" in params: + params["scrape_options"] = ScrapeOptions(**params["scrape_options"]) + return params + + def _extract_content(self, data: Union[Any, None]): + """Extract content from response data.""" + content = "" + if data: + content = data.markdown or data.html or "" + return content + + def _run(self, url: str, operation: Literal['scrape', 'crawl'] = 'scrape', params: Optional[Dict] = {}): + try: + from hyperbrowser.models.scrape import StartScrapeJobParams + from hyperbrowser.models.crawl import StartCrawlJobParams + except ImportError: + raise ImportError( + "`hyperbrowser` package not found, please run `pip install hyperbrowser`" + ) + + params = self._prepare_params(params) + + if operation == 'scrape': + scrape_params = StartScrapeJobParams(url=url, **params) + scrape_resp = self.hyperbrowser.scrape.start_and_wait(scrape_params) + content = self._extract_content(scrape_resp.data) + return content + else: + crawl_params = StartCrawlJobParams(url=url, **params) + crawl_resp = self.hyperbrowser.crawl.start_and_wait(crawl_params) + 
content = "" + if crawl_resp.data: + for page in crawl_resp.data: + page_content = self._extract_content(page) + if page_content: + content += ( + f"\n{'-'*50}\nUrl: {page.url}\nContent:\n{page_content}\n" + ) + return content From e343f26c037f5557c0f31654bf053802b1b534f6 Mon Sep 17 00:00:00 2001 From: Nikhil Shahi Date: Mon, 13 Jan 2025 16:08:11 -0600 Subject: [PATCH 16/23] add docstring --- .../hyperbrowser_load_tool/hyperbrowser_load_tool.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/crewai_tools/tools/hyperbrowser_load_tool/hyperbrowser_load_tool.py b/src/crewai_tools/tools/hyperbrowser_load_tool/hyperbrowser_load_tool.py index eb52b151c..b802d1859 100644 --- a/src/crewai_tools/tools/hyperbrowser_load_tool/hyperbrowser_load_tool.py +++ b/src/crewai_tools/tools/hyperbrowser_load_tool/hyperbrowser_load_tool.py @@ -11,6 +11,15 @@ class HyperbrowserLoadToolSchema(BaseModel): params: Optional[Dict] = Field(description="Optional params for scrape or crawl. For more information on the supported params, visit https://docs.hyperbrowser.ai/reference/sdks/python/scrape#start-scrape-job-and-wait or https://docs.hyperbrowser.ai/reference/sdks/python/crawl#start-crawl-job-and-wait") class HyperbrowserLoadTool(BaseTool): + """HyperbrowserLoadTool. + + Scrape or crawl web pages and load the contents with optional parameters for configuring content extraction. + Requires the `hyperbrowser` package. 
+ Get your API Key from https://app.hyperbrowser.ai/ + + Args: + api_key: The Hyperbrowser API key, can be set as an environment variable `HYPERBROWSER_API_KEY` or passed directly + """ name: str = "Hyperbrowser web load tool" description: str = "Scrape or crawl a website using Hyperbrowser and return the contents in properly formatted markdown or html" args_schema: Type[BaseModel] = HyperbrowserLoadToolSchema From 334beda1810201fcadab8dde8c50c04d9f968549 Mon Sep 17 00:00:00 2001 From: Tom Mahler Date: Tue, 14 Jan 2025 21:06:42 +0200 Subject: [PATCH 17/23] added missing import --- .../tools/code_interpreter_tool/code_interpreter_tool.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/crewai_tools/tools/code_interpreter_tool/code_interpreter_tool.py b/src/crewai_tools/tools/code_interpreter_tool/code_interpreter_tool.py index fd0d39932..8924d52c0 100644 --- a/src/crewai_tools/tools/code_interpreter_tool/code_interpreter_tool.py +++ b/src/crewai_tools/tools/code_interpreter_tool/code_interpreter_tool.py @@ -3,6 +3,7 @@ import os from typing import List, Optional, Type from docker import from_env as docker_from_env +from docker import DockerClient from docker.models.containers import Container from docker.errors import ImageNotFound, NotFound from crewai.tools import BaseTool @@ -43,7 +44,7 @@ class CodeInterpreterTool(BaseTool): Verify if the Docker image is available. Optionally use a user-provided Dockerfile. 
""" - client = docker_from_env() if self.user_docker_base_url == None else docker.DockerClient(base_url=self.user_docker_base_url) + client = docker_from_env() if self.user_docker_base_url == None else DockerClient(base_url=self.user_docker_base_url) try: client.images.get(self.default_image_tag) From 1bd87f514e8984e19b1c6e09cb8cee09a30a5601 Mon Sep 17 00:00:00 2001 From: Tom Mahler Date: Tue, 14 Jan 2025 21:07:08 +0200 Subject: [PATCH 18/23] changed == None to is None --- .../tools/code_interpreter_tool/code_interpreter_tool.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/crewai_tools/tools/code_interpreter_tool/code_interpreter_tool.py b/src/crewai_tools/tools/code_interpreter_tool/code_interpreter_tool.py index 8924d52c0..5d23c580a 100644 --- a/src/crewai_tools/tools/code_interpreter_tool/code_interpreter_tool.py +++ b/src/crewai_tools/tools/code_interpreter_tool/code_interpreter_tool.py @@ -31,7 +31,7 @@ class CodeInterpreterTool(BaseTool): default_image_tag: str = "code-interpreter:latest" code: Optional[str] = None user_dockerfile_path: Optional[str] = None - user_docker_base_url: Optional[str] = None + user_docker_base_url: Optional[str] = None unsafe_mode: bool = False @staticmethod @@ -44,7 +44,7 @@ class CodeInterpreterTool(BaseTool): Verify if the Docker image is available. Optionally use a user-provided Dockerfile. 
""" - client = docker_from_env() if self.user_docker_base_url == None else DockerClient(base_url=self.user_docker_base_url) + client = docker_from_env() if self.user_docker_base_url is None else DockerClient(base_url=self.user_docker_base_url) try: client.images.get(self.default_image_tag) @@ -136,4 +136,4 @@ class CodeInterpreterTool(BaseTool): exec(code, {}, exec_locals) return exec_locals.get("result", "No result variable found.") except Exception as e: - return f"An error occurred: {str(e)}" + return f"An error occurred: {str(e)}" \ No newline at end of file From 1568008db61d63eadb4b6f03657b785e3aec00f1 Mon Sep 17 00:00:00 2001 From: Carter Chen Date: Mon, 13 Jan 2025 21:33:40 -0500 Subject: [PATCH 19/23] remove _generate_description on file_read_tool --- .../tools/file_read_tool/file_read_tool.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/src/crewai_tools/tools/file_read_tool/file_read_tool.py b/src/crewai_tools/tools/file_read_tool/file_read_tool.py index 323a26d51..9106533fa 100644 --- a/src/crewai_tools/tools/file_read_tool/file_read_tool.py +++ b/src/crewai_tools/tools/file_read_tool/file_read_tool.py @@ -49,6 +49,7 @@ class FileReadTool(BaseTool): if file_path is not None: self.file_path = file_path self.description = f"A tool that reads file content. The default file is {file_path}, but you can provide a different 'file_path' parameter to read another file." + self._generate_description() def _run( self, @@ -68,14 +69,3 @@ class FileReadTool(BaseTool): except Exception as e: return f"Error: Failed to read file {file_path}. {str(e)}" - def _generate_description(self) -> None: - """Generate the tool description based on file path. - - This method updates the tool's description to include information about - the default file path while maintaining the ability to specify a different - file at runtime. - - Returns: - None - """ - self.description = f"A tool that can be used to read {self.file_path}'s content." 
From fe2a5abf8d19b72bedd2c7d0c099976e7abf7051 Mon Sep 17 00:00:00 2001 From: Carter Chen Date: Tue, 14 Jan 2025 21:16:11 -0500 Subject: [PATCH 20/23] restructure init statement to remove duplicate call to _generate_description --- src/crewai_tools/tools/file_read_tool/file_read_tool.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/crewai_tools/tools/file_read_tool/file_read_tool.py b/src/crewai_tools/tools/file_read_tool/file_read_tool.py index 9106533fa..22a1204f6 100644 --- a/src/crewai_tools/tools/file_read_tool/file_read_tool.py +++ b/src/crewai_tools/tools/file_read_tool/file_read_tool.py @@ -45,11 +45,11 @@ class FileReadTool(BaseTool): this becomes the default file path for the tool. **kwargs: Additional keyword arguments passed to BaseTool. """ - super().__init__(**kwargs) if file_path is not None: - self.file_path = file_path - self.description = f"A tool that reads file content. The default file is {file_path}, but you can provide a different 'file_path' parameter to read another file." - self._generate_description() + kwargs['description'] = f"A tool that reads file content. The default file is {file_path}, but you can provide a different 'file_path' parameter to read another file." 
+ + super().__init__(**kwargs) + self.file_path = file_path def _run( self, From 9c4c4219cd18b75f56fce8279a7cca1eb7672829 Mon Sep 17 00:00:00 2001 From: ChethanUK Date: Fri, 17 Jan 2025 02:23:06 +0530 Subject: [PATCH 21/23] Adding Snowflake search tool --- src/crewai_tools/__init__.py | 2 + src/crewai_tools/tools/__init__.py | 5 + .../browserbase_load_tool.py | 14 +- .../code_interpreter_tool.py | 18 +- .../directory_read_tool.py | 2 - .../tools/file_read_tool/file_read_tool.py | 7 +- .../file_writer_tool/file_writer_tool.py | 4 +- .../firecrawl_crawl_website_tool.py | 1 - .../firecrawl_scrape_website_tool.py | 1 - .../github_search_tool/github_search_tool.py | 4 +- .../jina_scrape_website_tool.py | 4 +- .../tools/linkup/linkup_search_tool.py | 16 +- .../mysql_search_tool/mysql_search_tool.py | 4 +- .../tools/patronus_eval_tool/example.py | 26 +-- .../patronus_eval_tool/patronus_eval_tool.py | 14 +- .../patronus_local_evaluator_tool.py | 17 +- .../patronus_predefined_criteria_eval_tool.py | 12 +- .../pdf_text_writing_tool.py | 8 +- .../tools/pg_seach_tool/pg_search_tool.py | 4 +- .../scrape_element_from_website.py | 2 - .../scrape_website_tool.py | 2 - .../scrapegraph_scrape_tool.py | 34 +-- .../selenium_scraping_tool.py | 27 ++- .../tools/serpapi_tool/serpapi_base_tool.py | 3 +- .../serpapi_google_search_tool.py | 41 ++-- .../serpapi_google_shopping_tool.py | 41 ++-- .../tools/serper_dev_tool/serper_dev_tool.py | 4 +- .../serply_webpage_to_markdown_tool.py | 4 +- .../tools/snowflake_search_tool/README.md | 155 +++++++++++++ .../tools/snowflake_search_tool/__init__.py | 11 + .../snowflake_search_tool.py | 201 ++++++++++++++++ .../tools/stagehand_tool/stagehand_tool.py | 154 ++++++------ .../tools/vision_tool/vision_tool.py | 24 +- .../tools/weaviate_tool/vector_search.py | 3 +- .../website_search/website_search_tool.py | 4 +- .../youtube_channel_search_tool.py | 4 +- .../youtube_video_search_tool.py | 4 +- tests/base_tool_test.py | 133 +++++++---- 
tests/file_read_tool_test.py | 6 +- tests/it/tools/__init__.py | 0 tests/it/tools/conftest.py | 21 ++ tests/it/tools/snowflake_search_tool_test.py | 219 ++++++++++++++++++ tests/spider_tool_test.py | 21 +- tests/tools/snowflake_search_tool_test.py | 103 ++++++++ tests/tools/test_code_interpreter_tool.py | 16 +- 45 files changed, 1089 insertions(+), 311 deletions(-) create mode 100644 src/crewai_tools/tools/snowflake_search_tool/README.md create mode 100644 src/crewai_tools/tools/snowflake_search_tool/__init__.py create mode 100644 src/crewai_tools/tools/snowflake_search_tool/snowflake_search_tool.py create mode 100644 tests/it/tools/__init__.py create mode 100644 tests/it/tools/conftest.py create mode 100644 tests/it/tools/snowflake_search_tool_test.py create mode 100644 tests/tools/snowflake_search_tool_test.py diff --git a/src/crewai_tools/__init__.py b/src/crewai_tools/__init__.py index 2db0fa05f..9c7e9d9a9 100644 --- a/src/crewai_tools/__init__.py +++ b/src/crewai_tools/__init__.py @@ -43,6 +43,8 @@ from .tools import ( SerplyScholarSearchTool, SerplyWebpageToMarkdownTool, SerplyWebSearchTool, + SnowflakeConfig, + SnowflakeSearchTool, SpiderTool, TXTSearchTool, VisionTool, diff --git a/src/crewai_tools/tools/__init__.py b/src/crewai_tools/tools/__init__.py index e4288a310..ea5a87ce1 100644 --- a/src/crewai_tools/tools/__init__.py +++ b/src/crewai_tools/tools/__init__.py @@ -54,6 +54,11 @@ from .serply_api_tool.serply_news_search_tool import SerplyNewsSearchTool from .serply_api_tool.serply_scholar_search_tool import SerplyScholarSearchTool from .serply_api_tool.serply_web_search_tool import SerplyWebSearchTool from .serply_api_tool.serply_webpage_to_markdown_tool import SerplyWebpageToMarkdownTool +from .snowflake_search_tool import ( + SnowflakeConfig, + SnowflakeSearchTool, + SnowflakeSearchToolInput, +) from .spider_tool.spider_tool import SpiderTool from .txt_search_tool.txt_search_tool import TXTSearchTool from .vision_tool.vision_tool import VisionTool 
diff --git a/src/crewai_tools/tools/browserbase_load_tool/browserbase_load_tool.py b/src/crewai_tools/tools/browserbase_load_tool/browserbase_load_tool.py index 2ca1b95fc..d3f76e0a6 100644 --- a/src/crewai_tools/tools/browserbase_load_tool/browserbase_load_tool.py +++ b/src/crewai_tools/tools/browserbase_load_tool/browserbase_load_tool.py @@ -1,8 +1,8 @@ import os from typing import Any, Optional, Type -from pydantic import BaseModel, Field from crewai.tools import BaseTool +from pydantic import BaseModel, Field class BrowserbaseLoadToolSchema(BaseModel): @@ -11,12 +11,10 @@ class BrowserbaseLoadToolSchema(BaseModel): class BrowserbaseLoadTool(BaseTool): name: str = "Browserbase web load tool" - description: str = ( - "Load webpages url in a headless browser using Browserbase and return the contents" - ) + description: str = "Load webpages url in a headless browser using Browserbase and return the contents" args_schema: Type[BaseModel] = BrowserbaseLoadToolSchema - api_key: Optional[str] = os.getenv('BROWSERBASE_API_KEY') - project_id: Optional[str] = os.getenv('BROWSERBASE_PROJECT_ID') + api_key: Optional[str] = os.getenv("BROWSERBASE_API_KEY") + project_id: Optional[str] = os.getenv("BROWSERBASE_PROJECT_ID") text_content: Optional[bool] = False session_id: Optional[str] = None proxy: Optional[bool] = None @@ -33,7 +31,9 @@ class BrowserbaseLoadTool(BaseTool): ): super().__init__(**kwargs) if not self.api_key: - raise EnvironmentError("BROWSERBASE_API_KEY environment variable is required for initialization") + raise EnvironmentError( + "BROWSERBASE_API_KEY environment variable is required for initialization" + ) try: from browserbase import Browserbase # type: ignore except ImportError: diff --git a/src/crewai_tools/tools/code_interpreter_tool/code_interpreter_tool.py b/src/crewai_tools/tools/code_interpreter_tool/code_interpreter_tool.py index fd0d39932..b508e4b6a 100644 --- a/src/crewai_tools/tools/code_interpreter_tool/code_interpreter_tool.py +++ 
b/src/crewai_tools/tools/code_interpreter_tool/code_interpreter_tool.py @@ -2,10 +2,10 @@ import importlib.util import os from typing import List, Optional, Type -from docker import from_env as docker_from_env -from docker.models.containers import Container -from docker.errors import ImageNotFound, NotFound from crewai.tools import BaseTool +from docker import from_env as docker_from_env +from docker.errors import ImageNotFound, NotFound +from docker.models.containers import Container from pydantic import BaseModel, Field @@ -30,7 +30,7 @@ class CodeInterpreterTool(BaseTool): default_image_tag: str = "code-interpreter:latest" code: Optional[str] = None user_dockerfile_path: Optional[str] = None - user_docker_base_url: Optional[str] = None + user_docker_base_url: Optional[str] = None unsafe_mode: bool = False @staticmethod @@ -43,7 +43,11 @@ class CodeInterpreterTool(BaseTool): Verify if the Docker image is available. Optionally use a user-provided Dockerfile. """ - client = docker_from_env() if self.user_docker_base_url == None else docker.DockerClient(base_url=self.user_docker_base_url) + client = ( + docker_from_env() + if self.user_docker_base_url == None + else docker.DockerClient(base_url=self.user_docker_base_url) + ) try: client.images.get(self.default_image_tag) @@ -76,9 +80,7 @@ class CodeInterpreterTool(BaseTool): else: return self.run_code_in_docker(code, libraries_used) - def _install_libraries( - self, container: Container, libraries: List[str] - ) -> None: + def _install_libraries(self, container: Container, libraries: List[str]) -> None: """ Install missing libraries in the Docker container """ diff --git a/src/crewai_tools/tools/directory_read_tool/directory_read_tool.py b/src/crewai_tools/tools/directory_read_tool/directory_read_tool.py index 6033202be..8488f391e 100644 --- a/src/crewai_tools/tools/directory_read_tool/directory_read_tool.py +++ b/src/crewai_tools/tools/directory_read_tool/directory_read_tool.py @@ -8,8 +8,6 @@ from pydantic import 
BaseModel, Field class FixedDirectoryReadToolSchema(BaseModel): """Input for DirectoryReadTool.""" - pass - class DirectoryReadToolSchema(FixedDirectoryReadToolSchema): """Input for DirectoryReadTool.""" diff --git a/src/crewai_tools/tools/file_read_tool/file_read_tool.py b/src/crewai_tools/tools/file_read_tool/file_read_tool.py index 323a26d51..384b97f40 100644 --- a/src/crewai_tools/tools/file_read_tool/file_read_tool.py +++ b/src/crewai_tools/tools/file_read_tool/file_read_tool.py @@ -32,6 +32,7 @@ class FileReadTool(BaseTool): >>> content = tool.run() # Reads /path/to/file.txt >>> content = tool.run(file_path="/path/to/other.txt") # Reads other.txt """ + name: str = "Read a file's content" description: str = "A tool that reads the content of a file. To use this tool, provide a 'file_path' parameter with the path to the file you want to read." args_schema: Type[BaseModel] = FileReadToolSchema @@ -57,7 +58,7 @@ class FileReadTool(BaseTool): file_path = kwargs.get("file_path", self.file_path) if file_path is None: return "Error: No file path provided. Please provide a file path either in the constructor or as an argument." - + try: with open(file_path, "r") as file: return file.read() @@ -78,4 +79,6 @@ class FileReadTool(BaseTool): Returns: None """ - self.description = f"A tool that can be used to read {self.file_path}'s content." + self.description = ( + f"A tool that can be used to read {self.file_path}'s content." + ) diff --git a/src/crewai_tools/tools/file_writer_tool/file_writer_tool.py b/src/crewai_tools/tools/file_writer_tool/file_writer_tool.py index ed454a1bd..f975d3301 100644 --- a/src/crewai_tools/tools/file_writer_tool/file_writer_tool.py +++ b/src/crewai_tools/tools/file_writer_tool/file_writer_tool.py @@ -15,9 +15,7 @@ class FileWriterToolInput(BaseModel): class FileWriterTool(BaseTool): name: str = "File Writer Tool" - description: str = ( - "A tool to write content to a specified file. 
Accepts filename, content, and optionally a directory path and overwrite flag as input." - ) + description: str = "A tool to write content to a specified file. Accepts filename, content, and optionally a directory path and overwrite flag as input." args_schema: Type[BaseModel] = FileWriterToolInput def _run(self, **kwargs: Any) -> str: diff --git a/src/crewai_tools/tools/firecrawl_crawl_website_tool/firecrawl_crawl_website_tool.py b/src/crewai_tools/tools/firecrawl_crawl_website_tool/firecrawl_crawl_website_tool.py index 6c7c4ffd9..dcb70e291 100644 --- a/src/crewai_tools/tools/firecrawl_crawl_website_tool/firecrawl_crawl_website_tool.py +++ b/src/crewai_tools/tools/firecrawl_crawl_website_tool/firecrawl_crawl_website_tool.py @@ -72,4 +72,3 @@ except ImportError: """ When this tool is not used, then exception can be ignored. """ - pass diff --git a/src/crewai_tools/tools/firecrawl_scrape_website_tool/firecrawl_scrape_website_tool.py b/src/crewai_tools/tools/firecrawl_scrape_website_tool/firecrawl_scrape_website_tool.py index 9458e7a4f..3f5f8c4c4 100644 --- a/src/crewai_tools/tools/firecrawl_scrape_website_tool/firecrawl_scrape_website_tool.py +++ b/src/crewai_tools/tools/firecrawl_scrape_website_tool/firecrawl_scrape_website_tool.py @@ -63,4 +63,3 @@ except ImportError: """ When this tool is not used, then exception can be ignored. """ - pass diff --git a/src/crewai_tools/tools/github_search_tool/github_search_tool.py b/src/crewai_tools/tools/github_search_tool/github_search_tool.py index 4bf8b9e05..6ba7b919c 100644 --- a/src/crewai_tools/tools/github_search_tool/github_search_tool.py +++ b/src/crewai_tools/tools/github_search_tool/github_search_tool.py @@ -27,9 +27,7 @@ class GithubSearchToolSchema(FixedGithubSearchToolSchema): class GithubSearchTool(RagTool): name: str = "Search a github repo's content" - description: str = ( - "A tool that can be used to semantic search a query from a github repo's content. 
This is not the GitHub API, but instead a tool that can provide semantic search capabilities." - ) + description: str = "A tool that can be used to semantic search a query from a github repo's content. This is not the GitHub API, but instead a tool that can provide semantic search capabilities." summarize: bool = False gh_token: str args_schema: Type[BaseModel] = GithubSearchToolSchema diff --git a/src/crewai_tools/tools/jina_scrape_website_tool/jina_scrape_website_tool.py b/src/crewai_tools/tools/jina_scrape_website_tool/jina_scrape_website_tool.py index a10a4ffdb..86f771cd0 100644 --- a/src/crewai_tools/tools/jina_scrape_website_tool/jina_scrape_website_tool.py +++ b/src/crewai_tools/tools/jina_scrape_website_tool/jina_scrape_website_tool.py @@ -13,9 +13,7 @@ class JinaScrapeWebsiteToolInput(BaseModel): class JinaScrapeWebsiteTool(BaseTool): name: str = "JinaScrapeWebsiteTool" - description: str = ( - "A tool that can be used to read a website content using Jina.ai reader and return markdown content." - ) + description: str = "A tool that can be used to read a website content using Jina.ai reader and return markdown content." args_schema: Type[BaseModel] = JinaScrapeWebsiteToolInput website_url: Optional[str] = None api_key: Optional[str] = None diff --git a/src/crewai_tools/tools/linkup/linkup_search_tool.py b/src/crewai_tools/tools/linkup/linkup_search_tool.py index b172ad029..486663d3e 100644 --- a/src/crewai_tools/tools/linkup/linkup_search_tool.py +++ b/src/crewai_tools/tools/linkup/linkup_search_tool.py @@ -2,6 +2,7 @@ from typing import Any try: from linkup import LinkupClient + LINKUP_AVAILABLE = True except ImportError: LINKUP_AVAILABLE = False @@ -9,10 +10,13 @@ except ImportError: from pydantic import PrivateAttr + class LinkupSearchTool: name: str = "Linkup Search Tool" - description: str = "Performs an API call to Linkup to retrieve contextual information." 
- _client: LinkupClient = PrivateAttr() # type: ignore + description: str = ( + "Performs an API call to Linkup to retrieve contextual information." + ) + _client: LinkupClient = PrivateAttr() # type: ignore def __init__(self, api_key: str): """ @@ -25,7 +29,9 @@ class LinkupSearchTool: ) self._client = LinkupClient(api_key=api_key) - def _run(self, query: str, depth: str = "standard", output_type: str = "searchResults") -> dict: + def _run( + self, query: str, depth: str = "standard", output_type: str = "searchResults" + ) -> dict: """ Executes a search using the Linkup API. @@ -36,9 +42,7 @@ class LinkupSearchTool: """ try: response = self._client.search( - query=query, - depth=depth, - output_type=output_type + query=query, depth=depth, output_type=output_type ) results = [ {"name": result.name, "url": result.url, "content": result.content} diff --git a/src/crewai_tools/tools/mysql_search_tool/mysql_search_tool.py b/src/crewai_tools/tools/mysql_search_tool/mysql_search_tool.py index f931a006b..a472e1761 100644 --- a/src/crewai_tools/tools/mysql_search_tool/mysql_search_tool.py +++ b/src/crewai_tools/tools/mysql_search_tool/mysql_search_tool.py @@ -17,9 +17,7 @@ class MySQLSearchToolSchema(BaseModel): class MySQLSearchTool(RagTool): name: str = "Search a database's table content" - description: str = ( - "A tool that can be used to semantic search a query from a database table's content." - ) + description: str = "A tool that can be used to semantic search a query from a database table's content." 
args_schema: Type[BaseModel] = MySQLSearchToolSchema db_uri: str = Field(..., description="Mandatory database URI") diff --git a/src/crewai_tools/tools/patronus_eval_tool/example.py b/src/crewai_tools/tools/patronus_eval_tool/example.py index b9e1bad5e..185e9f485 100644 --- a/src/crewai_tools/tools/patronus_eval_tool/example.py +++ b/src/crewai_tools/tools/patronus_eval_tool/example.py @@ -1,30 +1,24 @@ -from crewai import Agent, Crew, Task -from patronus_eval_tool import ( - PatronusEvalTool, -) -from patronus_local_evaluator_tool import ( - PatronusLocalEvaluatorTool, -) -from patronus_predefined_criteria_eval_tool import ( - PatronusPredefinedCriteriaEvalTool, -) -from patronus import Client, EvaluationResult import random +from crewai import Agent, Crew, Task +from patronus import Client, EvaluationResult +from patronus_local_evaluator_tool import PatronusLocalEvaluatorTool # Test the PatronusLocalEvaluatorTool where agent uses the local evaluator client = Client() + # Example of an evaluator that returns a random pass/fail result @client.register_local_evaluator("random_evaluator") def random_evaluator(**kwargs): score = random.random() return EvaluationResult( - score_raw=score, - pass_=score >= 0.5, - explanation="example explanation" # Optional justification for LLM judges + score_raw=score, + pass_=score >= 0.5, + explanation="example explanation", # Optional justification for LLM judges ) + # 1. Uses PatronusEvalTool: agent can pick the best evaluator and criteria # patronus_eval_tool = PatronusEvalTool() @@ -35,7 +29,9 @@ def random_evaluator(**kwargs): # 3. 
Uses PatronusLocalEvaluatorTool: agent uses user defined evaluator patronus_eval_tool = PatronusLocalEvaluatorTool( - patronus_client=client, evaluator="random_evaluator", evaluated_model_gold_answer="example label" + patronus_client=client, + evaluator="random_evaluator", + evaluated_model_gold_answer="example label", ) # Create a new agent diff --git a/src/crewai_tools/tools/patronus_eval_tool/patronus_eval_tool.py b/src/crewai_tools/tools/patronus_eval_tool/patronus_eval_tool.py index 23ffe2fd4..be1f410e2 100644 --- a/src/crewai_tools/tools/patronus_eval_tool/patronus_eval_tool.py +++ b/src/crewai_tools/tools/patronus_eval_tool/patronus_eval_tool.py @@ -1,8 +1,9 @@ -import os import json -import requests +import os import warnings -from typing import Any, List, Dict, Optional +from typing import Any, Dict, List, Optional + +import requests from crewai.tools import BaseTool @@ -19,7 +20,9 @@ class PatronusEvalTool(BaseTool): self.evaluators = temp_evaluators self.criteria = temp_criteria self.description = self._generate_description() - warnings.warn("You are allowing the agent to select the best evaluator and criteria when you use the `PatronusEvalTool`. If this is not intended then please use `PatronusPredefinedCriteriaEvalTool` instead.") + warnings.warn( + "You are allowing the agent to select the best evaluator and criteria when you use the `PatronusEvalTool`. If this is not intended then please use `PatronusPredefinedCriteriaEvalTool` instead." + ) def _init_run(self): evaluators_set = json.loads( @@ -104,7 +107,6 @@ class PatronusEvalTool(BaseTool): evaluated_model_retrieved_context: Optional[str], evaluators: List[Dict[str, str]], ) -> Any: - # Assert correct format of evaluators evals = [] for ev in evaluators: @@ -136,4 +138,4 @@ class PatronusEvalTool(BaseTool): f"Failed to evaluate model input and output. Response status code: {response.status_code}. 
Reason: {response.text}" ) - return response.json() \ No newline at end of file + return response.json() diff --git a/src/crewai_tools/tools/patronus_eval_tool/patronus_local_evaluator_tool.py b/src/crewai_tools/tools/patronus_eval_tool/patronus_local_evaluator_tool.py index e65cb342d..66781c593 100644 --- a/src/crewai_tools/tools/patronus_eval_tool/patronus_local_evaluator_tool.py +++ b/src/crewai_tools/tools/patronus_eval_tool/patronus_local_evaluator_tool.py @@ -1,7 +1,8 @@ from typing import Any, Type + from crewai.tools import BaseTool -from pydantic import BaseModel, Field from patronus import Client +from pydantic import BaseModel, Field class FixedLocalEvaluatorToolSchema(BaseModel): @@ -24,16 +25,20 @@ class PatronusLocalEvaluatorTool(BaseTool): name: str = "Patronus Local Evaluator Tool" evaluator: str = "The registered local evaluator" evaluated_model_gold_answer: str = "The agent's gold answer" - description: str = ( - "This tool is used to evaluate the model input and output using custom function evaluators." - ) + description: str = "This tool is used to evaluate the model input and output using custom function evaluators." 
client: Any = None args_schema: Type[BaseModel] = FixedLocalEvaluatorToolSchema class Config: arbitrary_types_allowed = True - def __init__(self, patronus_client: Client, evaluator: str, evaluated_model_gold_answer: str, **kwargs: Any): + def __init__( + self, + patronus_client: Client, + evaluator: str, + evaluated_model_gold_answer: str, + **kwargs: Any, + ): super().__init__(**kwargs) self.client = patronus_client if evaluator: @@ -79,7 +84,7 @@ class PatronusLocalEvaluatorTool(BaseTool): if isinstance(evaluated_model_gold_answer, str) else evaluated_model_gold_answer.get("description") ), - tags={}, # Optional metadata, supports arbitrary kv pairs + tags={}, # Optional metadata, supports arbitrary kv pairs ) output = f"Evaluation result: {result.pass_}, Explanation: {result.explanation}" return output diff --git a/src/crewai_tools/tools/patronus_eval_tool/patronus_predefined_criteria_eval_tool.py b/src/crewai_tools/tools/patronus_eval_tool/patronus_predefined_criteria_eval_tool.py index 28ffc2912..cf906586d 100644 --- a/src/crewai_tools/tools/patronus_eval_tool/patronus_predefined_criteria_eval_tool.py +++ b/src/crewai_tools/tools/patronus_eval_tool/patronus_predefined_criteria_eval_tool.py @@ -1,7 +1,8 @@ -import os import json +import os +from typing import Any, Dict, List, Type + import requests -from typing import Any, List, Dict, Type from crewai.tools import BaseTool from pydantic import BaseModel, Field @@ -33,9 +34,7 @@ class PatronusPredefinedCriteriaEvalTool(BaseTool): """ name: str = "Call Patronus API tool for evaluation of model inputs and outputs" - description: str = ( - """This tool calls the Patronus Evaluation API that takes the following arguments:""" - ) + description: str = """This tool calls the Patronus Evaluation API that takes the following arguments:""" evaluate_url: str = "https://api.patronus.ai/v1/evaluate" args_schema: Type[BaseModel] = FixedBaseToolSchema evaluators: List[Dict[str, str]] = [] @@ -52,7 +51,6 @@ class 
PatronusPredefinedCriteriaEvalTool(BaseTool): self, **kwargs: Any, ) -> Any: - evaluated_model_input = kwargs.get("evaluated_model_input") evaluated_model_output = kwargs.get("evaluated_model_output") evaluated_model_retrieved_context = kwargs.get( @@ -103,4 +101,4 @@ class PatronusPredefinedCriteriaEvalTool(BaseTool): f"Failed to evaluate model input and output. Status code: {response.status_code}. Reason: {response.text}" ) - return response.json() \ No newline at end of file + return response.json() diff --git a/src/crewai_tools/tools/pdf_text_writing_tool/pdf_text_writing_tool.py b/src/crewai_tools/tools/pdf_text_writing_tool/pdf_text_writing_tool.py index 851593167..ad4d847b6 100644 --- a/src/crewai_tools/tools/pdf_text_writing_tool/pdf_text_writing_tool.py +++ b/src/crewai_tools/tools/pdf_text_writing_tool/pdf_text_writing_tool.py @@ -1,7 +1,9 @@ -from typing import Optional, Type -from pydantic import BaseModel, Field -from pypdf import PdfReader, PdfWriter, PageObject, ContentStream, NameObject, Font from pathlib import Path +from typing import Optional, Type + +from pydantic import BaseModel, Field +from pypdf import ContentStream, Font, NameObject, PageObject, PdfReader, PdfWriter + from crewai_tools.tools.rag.rag_tool import RagTool diff --git a/src/crewai_tools/tools/pg_seach_tool/pg_search_tool.py b/src/crewai_tools/tools/pg_seach_tool/pg_search_tool.py index dc75470a2..ec0207aa7 100644 --- a/src/crewai_tools/tools/pg_seach_tool/pg_search_tool.py +++ b/src/crewai_tools/tools/pg_seach_tool/pg_search_tool.py @@ -17,9 +17,7 @@ class PGSearchToolSchema(BaseModel): class PGSearchTool(RagTool): name: str = "Search a database's table content" - description: str = ( - "A tool that can be used to semantic search a query from a database table's content." - ) + description: str = "A tool that can be used to semantic search a query from a database table's content." 
args_schema: Type[BaseModel] = PGSearchToolSchema db_uri: str = Field(..., description="Mandatory database URI") diff --git a/src/crewai_tools/tools/scrape_element_from_website/scrape_element_from_website.py b/src/crewai_tools/tools/scrape_element_from_website/scrape_element_from_website.py index 14757d247..f1e215bf3 100644 --- a/src/crewai_tools/tools/scrape_element_from_website/scrape_element_from_website.py +++ b/src/crewai_tools/tools/scrape_element_from_website/scrape_element_from_website.py @@ -10,8 +10,6 @@ from pydantic import BaseModel, Field class FixedScrapeElementFromWebsiteToolSchema(BaseModel): """Input for ScrapeElementFromWebsiteTool.""" - pass - class ScrapeElementFromWebsiteToolSchema(FixedScrapeElementFromWebsiteToolSchema): """Input for ScrapeElementFromWebsiteTool.""" diff --git a/src/crewai_tools/tools/scrape_website_tool/scrape_website_tool.py b/src/crewai_tools/tools/scrape_website_tool/scrape_website_tool.py index 8cfc5d136..0e7e25ca6 100644 --- a/src/crewai_tools/tools/scrape_website_tool/scrape_website_tool.py +++ b/src/crewai_tools/tools/scrape_website_tool/scrape_website_tool.py @@ -11,8 +11,6 @@ from pydantic import BaseModel, Field class FixedScrapeWebsiteToolSchema(BaseModel): """Input for ScrapeWebsiteTool.""" - pass - class ScrapeWebsiteToolSchema(FixedScrapeWebsiteToolSchema): """Input for ScrapeWebsiteTool.""" diff --git a/src/crewai_tools/tools/scrapegraph_scrape_tool/scrapegraph_scrape_tool.py b/src/crewai_tools/tools/scrapegraph_scrape_tool/scrapegraph_scrape_tool.py index 906bf6376..29c132ea9 100644 --- a/src/crewai_tools/tools/scrapegraph_scrape_tool/scrapegraph_scrape_tool.py +++ b/src/crewai_tools/tools/scrapegraph_scrape_tool/scrapegraph_scrape_tool.py @@ -10,17 +10,14 @@ from scrapegraph_py.logger import sgai_logger class ScrapegraphError(Exception): """Base exception for Scrapegraph-related errors""" - pass class RateLimitError(ScrapegraphError): """Raised when API rate limits are exceeded""" - pass class 
FixedScrapegraphScrapeToolSchema(BaseModel): """Input for ScrapegraphScrapeTool when website_url is fixed.""" - pass class ScrapegraphScrapeToolSchema(FixedScrapegraphScrapeToolSchema): @@ -32,7 +29,7 @@ class ScrapegraphScrapeToolSchema(FixedScrapegraphScrapeToolSchema): description="Prompt to guide the extraction of content", ) - @validator('website_url') + @validator("website_url") def validate_url(cls, v): """Validate URL format""" try: @@ -41,13 +38,15 @@ class ScrapegraphScrapeToolSchema(FixedScrapegraphScrapeToolSchema): raise ValueError return v except Exception: - raise ValueError("Invalid URL format. URL must include scheme (http/https) and domain") + raise ValueError( + "Invalid URL format. URL must include scheme (http/https) and domain" + ) class ScrapegraphScrapeTool(BaseTool): """ A tool that uses Scrapegraph AI to intelligently scrape website content. - + Raises: ValueError: If API key is missing or URL format is invalid RateLimitError: If API rate limits are exceeded @@ -55,7 +54,9 @@ class ScrapegraphScrapeTool(BaseTool): """ name: str = "Scrapegraph website scraper" - description: str = "A tool that uses Scrapegraph AI to intelligently scrape website content." + description: str = ( + "A tool that uses Scrapegraph AI to intelligently scrape website content." + ) args_schema: Type[BaseModel] = ScrapegraphScrapeToolSchema website_url: Optional[str] = None user_prompt: Optional[str] = None @@ -70,7 +71,7 @@ class ScrapegraphScrapeTool(BaseTool): ): super().__init__(**kwargs) self.api_key = api_key or os.getenv("SCRAPEGRAPH_API_KEY") - + if not self.api_key: raise ValueError("Scrapegraph API key is required") @@ -79,7 +80,7 @@ class ScrapegraphScrapeTool(BaseTool): self.website_url = website_url self.description = f"A tool that uses Scrapegraph AI to intelligently scrape {website_url}'s content." 
self.args_schema = FixedScrapegraphScrapeToolSchema - + if user_prompt is not None: self.user_prompt = user_prompt @@ -94,22 +95,24 @@ class ScrapegraphScrapeTool(BaseTool): if not all([result.scheme, result.netloc]): raise ValueError except Exception: - raise ValueError("Invalid URL format. URL must include scheme (http/https) and domain") + raise ValueError( + "Invalid URL format. URL must include scheme (http/https) and domain" + ) def _handle_api_response(self, response: dict) -> str: """Handle and validate API response""" if not response: raise RuntimeError("Empty response from Scrapegraph API") - + if "error" in response: error_msg = response.get("error", {}).get("message", "Unknown error") if "rate limit" in error_msg.lower(): raise RateLimitError(f"Rate limit exceeded: {error_msg}") raise RuntimeError(f"API error: {error_msg}") - + if "result" not in response: raise RuntimeError("Invalid response format from Scrapegraph API") - + return response["result"] def _run( @@ -117,7 +120,10 @@ class ScrapegraphScrapeTool(BaseTool): **kwargs: Any, ) -> Any: website_url = kwargs.get("website_url", self.website_url) - user_prompt = kwargs.get("user_prompt", self.user_prompt) or "Extract the main content of the webpage" + user_prompt = ( + kwargs.get("user_prompt", self.user_prompt) + or "Extract the main content of the webpage" + ) if not website_url: raise ValueError("website_url is required") diff --git a/src/crewai_tools/tools/selenium_scraping_tool/selenium_scraping_tool.py b/src/crewai_tools/tools/selenium_scraping_tool/selenium_scraping_tool.py index d7a55428d..8099a06ab 100644 --- a/src/crewai_tools/tools/selenium_scraping_tool/selenium_scraping_tool.py +++ b/src/crewai_tools/tools/selenium_scraping_tool/selenium_scraping_tool.py @@ -17,33 +17,36 @@ class FixedSeleniumScrapingToolSchema(BaseModel): class SeleniumScrapingToolSchema(FixedSeleniumScrapingToolSchema): """Input for SeleniumScrapingTool.""" - website_url: str = Field(..., description="Mandatory 
website url to read the file. Must start with http:// or https://") + website_url: str = Field( + ..., + description="Mandatory website url to read the file. Must start with http:// or https://", + ) css_element: str = Field( ..., description="Mandatory css reference for element to scrape from the website", ) - @validator('website_url') + @validator("website_url") def validate_website_url(cls, v): if not v: raise ValueError("Website URL cannot be empty") - + if len(v) > 2048: # Common maximum URL length raise ValueError("URL is too long (max 2048 characters)") - - if not re.match(r'^https?://', v): + + if not re.match(r"^https?://", v): raise ValueError("URL must start with http:// or https://") - + try: result = urlparse(v) if not all([result.scheme, result.netloc]): raise ValueError("Invalid URL format") except Exception as e: raise ValueError(f"Invalid URL: {str(e)}") - - if re.search(r'\s', v): + + if re.search(r"\s", v): raise ValueError("URL cannot contain whitespace") - + return v @@ -130,11 +133,11 @@ class SeleniumScrapingTool(BaseTool): def _create_driver(self, url, cookie, wait_time): if not url: raise ValueError("URL cannot be empty") - + # Validate URL format - if not re.match(r'^https?://', url): + if not re.match(r"^https?://", url): raise ValueError("URL must start with http:// or https://") - + options = Options() options.add_argument("--headless") driver = self.driver(options=options) diff --git a/src/crewai_tools/tools/serpapi_tool/serpapi_base_tool.py b/src/crewai_tools/tools/serpapi_tool/serpapi_base_tool.py index 98491190c..895f3aadc 100644 --- a/src/crewai_tools/tools/serpapi_tool/serpapi_base_tool.py +++ b/src/crewai_tools/tools/serpapi_tool/serpapi_base_tool.py @@ -1,9 +1,10 @@ import os import re -from typing import Optional, Any, Union +from typing import Any, Optional, Union from crewai.tools import BaseTool + class SerpApiBaseTool(BaseTool): """Base class for SerpApi functionality with shared capabilities.""" diff --git 
a/src/crewai_tools/tools/serpapi_tool/serpapi_google_search_tool.py b/src/crewai_tools/tools/serpapi_tool/serpapi_google_search_tool.py index 199b7f5a2..c1a877f23 100644 --- a/src/crewai_tools/tools/serpapi_tool/serpapi_google_search_tool.py +++ b/src/crewai_tools/tools/serpapi_tool/serpapi_google_search_tool.py @@ -1,14 +1,21 @@ -from typing import Any, Type, Optional +from typing import Any, Optional, Type -import re from pydantic import BaseModel, Field -from .serpapi_base_tool import SerpApiBaseTool from serpapi import HTTPError +from .serpapi_base_tool import SerpApiBaseTool + + class SerpApiGoogleSearchToolSchema(BaseModel): """Input for Google Search.""" - search_query: str = Field(..., description="Mandatory search query you want to use to Google search.") - location: Optional[str] = Field(None, description="Location you want the search to be performed in.") + + search_query: str = Field( + ..., description="Mandatory search query you want to use to Google search." + ) + location: Optional[str] = Field( + None, description="Location you want the search to be performed in." + ) + class SerpApiGoogleSearchTool(SerpApiBaseTool): name: str = "Google Search" @@ -22,19 +29,25 @@ class SerpApiGoogleSearchTool(SerpApiBaseTool): **kwargs: Any, ) -> Any: try: - results = self.client.search({ - "q": kwargs.get("search_query"), - "location": kwargs.get("location"), - }).as_dict() + results = self.client.search( + { + "q": kwargs.get("search_query"), + "location": kwargs.get("location"), + } + ).as_dict() self._omit_fields( - results, - [r"search_metadata", r"search_parameters", r"serpapi_.+", r".+_token", r"displayed_link", r"pagination"] + results, + [ + r"search_metadata", + r"search_parameters", + r"serpapi_.+", + r".+_token", + r"displayed_link", + r"pagination", + ], ) return results except HTTPError as e: return f"An error occurred: {str(e)}. Some parameters may be invalid." 
- - - \ No newline at end of file diff --git a/src/crewai_tools/tools/serpapi_tool/serpapi_google_shopping_tool.py b/src/crewai_tools/tools/serpapi_tool/serpapi_google_shopping_tool.py index b44b3a809..ec9477351 100644 --- a/src/crewai_tools/tools/serpapi_tool/serpapi_google_shopping_tool.py +++ b/src/crewai_tools/tools/serpapi_tool/serpapi_google_shopping_tool.py @@ -1,14 +1,20 @@ -from typing import Any, Type, Optional +from typing import Any, Optional, Type -import re from pydantic import BaseModel, Field -from .serpapi_base_tool import SerpApiBaseTool from serpapi import HTTPError +from .serpapi_base_tool import SerpApiBaseTool + + class SerpApiGoogleShoppingToolSchema(BaseModel): """Input for Google Shopping.""" - search_query: str = Field(..., description="Mandatory search query you want to use to Google shopping.") - location: Optional[str] = Field(None, description="Location you want the search to be performed in.") + + search_query: str = Field( + ..., description="Mandatory search query you want to use to Google shopping." + ) + location: Optional[str] = Field( + None, description="Location you want the search to be performed in." + ) class SerpApiGoogleShoppingTool(SerpApiBaseTool): @@ -23,20 +29,25 @@ class SerpApiGoogleShoppingTool(SerpApiBaseTool): **kwargs: Any, ) -> Any: try: - results = self.client.search({ - "engine": "google_shopping", - "q": kwargs.get("search_query"), - "location": kwargs.get("location") - }).as_dict() + results = self.client.search( + { + "engine": "google_shopping", + "q": kwargs.get("search_query"), + "location": kwargs.get("location"), + } + ).as_dict() self._omit_fields( - results, - [r"search_metadata", r"search_parameters", r"serpapi_.+", r"filters", r"pagination"] + results, + [ + r"search_metadata", + r"search_parameters", + r"serpapi_.+", + r"filters", + r"pagination", + ], ) return results except HTTPError as e: return f"An error occurred: {str(e)}. Some parameters may be invalid." 
- - - \ No newline at end of file diff --git a/src/crewai_tools/tools/serper_dev_tool/serper_dev_tool.py b/src/crewai_tools/tools/serper_dev_tool/serper_dev_tool.py index e9eab56a2..2db347190 100644 --- a/src/crewai_tools/tools/serper_dev_tool/serper_dev_tool.py +++ b/src/crewai_tools/tools/serper_dev_tool/serper_dev_tool.py @@ -1,19 +1,19 @@ import datetime import json -import os import logging +import os from typing import Any, Type import requests from crewai.tools import BaseTool from pydantic import BaseModel, Field - logging.basicConfig( level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" ) logger = logging.getLogger(__name__) + def _save_results_to_file(content: str) -> None: """Saves the search results to a file.""" try: diff --git a/src/crewai_tools/tools/serply_api_tool/serply_webpage_to_markdown_tool.py b/src/crewai_tools/tools/serply_api_tool/serply_webpage_to_markdown_tool.py index e09a36fd9..4010236cc 100644 --- a/src/crewai_tools/tools/serply_api_tool/serply_webpage_to_markdown_tool.py +++ b/src/crewai_tools/tools/serply_api_tool/serply_webpage_to_markdown_tool.py @@ -18,9 +18,7 @@ class SerplyWebpageToMarkdownToolSchema(BaseModel): class SerplyWebpageToMarkdownTool(RagTool): name: str = "Webpage to Markdown" - description: str = ( - "A tool to perform convert a webpage to markdown to make it easier for LLMs to understand" - ) + description: str = "A tool to perform convert a webpage to markdown to make it easier for LLMs to understand" args_schema: Type[BaseModel] = SerplyWebpageToMarkdownToolSchema request_url: str = "https://api.serply.io/v1/request" proxy_location: Optional[str] = "US" diff --git a/src/crewai_tools/tools/snowflake_search_tool/README.md b/src/crewai_tools/tools/snowflake_search_tool/README.md new file mode 100644 index 000000000..fc0b845c3 --- /dev/null +++ b/src/crewai_tools/tools/snowflake_search_tool/README.md @@ -0,0 +1,155 @@ +# Snowflake Search Tool + +A tool for executing queries on Snowflake 
data warehouse with built-in connection pooling, retry logic, and async execution support. + +## Installation + +```bash +uv sync --extra snowflake + +OR +uv pip install snowflake-connector-python>=3.5.0 snowflake-sqlalchemy>=1.5.0 cryptography>=41.0.0 + +OR +pip install snowflake-connector-python>=3.5.0 snowflake-sqlalchemy>=1.5.0 cryptography>=41.0.0 +``` + +## Quick Start + +```python +import asyncio +from crewai_tools import SnowflakeSearchTool, SnowflakeConfig + +# Create configuration +config = SnowflakeConfig( + account="your_account", + user="your_username", + password="your_password", + warehouse="COMPUTE_WH", + database="your_database", + snowflake_schema="your_schema" # Note: Uses snowflake_schema instead of schema +) + +# Initialize tool +tool = SnowflakeSearchTool( + config=config, + pool_size=5, + max_retries=3, + enable_caching=True +) + +# Execute query +async def main(): + results = await tool._run( + query="SELECT * FROM your_table LIMIT 10", + timeout=300 + ) + print(f"Retrieved {len(results)} rows") + +if __name__ == "__main__": + asyncio.run(main()) +``` + +## Features + +- ✨ Asynchronous query execution +- 🚀 Connection pooling for better performance +- 🔄 Automatic retries for transient failures +- 💾 Query result caching (optional) +- 🔒 Support for both password and key-pair authentication +- 📝 Comprehensive error handling and logging + +## Configuration Options + +### SnowflakeConfig Parameters + +| Parameter | Required | Description | +|-----------|----------|-------------| +| account | Yes | Snowflake account identifier | +| user | Yes | Snowflake username | +| password | Yes* | Snowflake password | +| private_key_path | No* | Path to private key file (alternative to password) | +| warehouse | Yes | Snowflake warehouse name | +| database | Yes | Default database | +| snowflake_schema | Yes | Default schema | +| role | No | Snowflake role | +| session_parameters | No | Custom session parameters dict | + +\* Either password or private_key_path 
must be provided + +### Tool Parameters + +| Parameter | Default | Description | +|-----------|---------|-------------| +| pool_size | 5 | Number of connections in the pool | +| max_retries | 3 | Maximum retry attempts for failed queries | +| retry_delay | 1.0 | Delay between retries in seconds | +| enable_caching | True | Enable/disable query result caching | + +## Advanced Usage + +### Using Key-Pair Authentication + +```python +config = SnowflakeConfig( + account="your_account", + user="your_username", + private_key_path="/path/to/private_key.p8", + warehouse="your_warehouse", + database="your_database", + snowflake_schema="your_schema" +) +``` + +### Custom Session Parameters + +```python +config = SnowflakeConfig( + # ... other config parameters ... + session_parameters={ + "QUERY_TAG": "my_app", + "TIMEZONE": "America/Los_Angeles" + } +) +``` + +## Best Practices + +1. **Error Handling**: Always wrap query execution in try-except blocks +2. **Logging**: Enable logging to track query execution and errors +3. **Connection Management**: Use appropriate pool sizes for your workload +4. **Timeouts**: Set reasonable query timeouts to prevent hanging +5. **Security**: Use key-pair auth in production and never hardcode credentials + +## Example with Logging + +```python +import logging + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + +async def main(): + try: + # ... tool initialization ... + results = await tool._run(query="SELECT * FROM table LIMIT 10") + logger.info(f"Query completed successfully. 
Retrieved {len(results)} rows") + except Exception as e: + logger.error(f"Query failed: {str(e)}") + raise +``` + +## Error Handling + +The tool automatically handles common Snowflake errors: +- DatabaseError +- OperationalError +- ProgrammingError +- Network timeouts +- Connection issues + +Errors are logged and retried based on your retry configuration. \ No newline at end of file diff --git a/src/crewai_tools/tools/snowflake_search_tool/__init__.py b/src/crewai_tools/tools/snowflake_search_tool/__init__.py new file mode 100644 index 000000000..abc1a45f5 --- /dev/null +++ b/src/crewai_tools/tools/snowflake_search_tool/__init__.py @@ -0,0 +1,11 @@ +from .snowflake_search_tool import ( + SnowflakeConfig, + SnowflakeSearchTool, + SnowflakeSearchToolInput, +) + +__all__ = [ + "SnowflakeSearchTool", + "SnowflakeSearchToolInput", + "SnowflakeConfig", +] diff --git a/src/crewai_tools/tools/snowflake_search_tool/snowflake_search_tool.py b/src/crewai_tools/tools/snowflake_search_tool/snowflake_search_tool.py new file mode 100644 index 000000000..75c671d21 --- /dev/null +++ b/src/crewai_tools/tools/snowflake_search_tool/snowflake_search_tool.py @@ -0,0 +1,201 @@ +import asyncio +import logging +from concurrent.futures import ThreadPoolExecutor +from typing import Any, Dict, List, Optional, Type + +import snowflake.connector +from crewai.tools.base_tool import BaseTool +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import serialization +from pydantic import BaseModel, ConfigDict, Field, SecretStr +from snowflake.connector.connection import SnowflakeConnection +from snowflake.connector.errors import DatabaseError, OperationalError + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# Cache for query results +_query_cache = {} + + +class SnowflakeConfig(BaseModel): + """Configuration for Snowflake connection.""" + + model_config = ConfigDict(protected_namespaces=()) + + 
account: str = Field( + ..., description="Snowflake account identifier", pattern=r"^[a-zA-Z0-9\-_]+$" + ) + user: str = Field(..., description="Snowflake username") + password: Optional[SecretStr] = Field(None, description="Snowflake password") + private_key_path: Optional[str] = Field( + None, description="Path to private key file" + ) + warehouse: Optional[str] = Field(None, description="Snowflake warehouse") + database: Optional[str] = Field(None, description="Default database") + snowflake_schema: Optional[str] = Field(None, description="Default schema") + role: Optional[str] = Field(None, description="Snowflake role") + session_parameters: Optional[Dict[str, Any]] = Field( + default_factory=dict, description="Session parameters" + ) + + @property + def has_auth(self) -> bool: + return bool(self.password or self.private_key_path) + + def model_post_init(self, *args, **kwargs): + if not self.has_auth: + raise ValueError("Either password or private_key_path must be provided") + + +class SnowflakeSearchToolInput(BaseModel): + """Input schema for SnowflakeSearchTool.""" + + model_config = ConfigDict(protected_namespaces=()) + + query: str = Field(..., description="SQL query or semantic search query to execute") + database: Optional[str] = Field(None, description="Override default database") + snowflake_schema: Optional[str] = Field(None, description="Override default schema") + timeout: Optional[int] = Field(300, description="Query timeout in seconds") + + +class SnowflakeSearchTool(BaseTool): + """Tool for executing queries and semantic search on Snowflake.""" + + name: str = "Snowflake Database Search" + description: str = ( + "Execute SQL queries or semantic search on Snowflake data warehouse. " + "Supports both raw SQL and natural language queries." 
+ ) + args_schema: Type[BaseModel] = SnowflakeSearchToolInput + + # Define Pydantic fields + config: SnowflakeConfig = Field( + ..., description="Snowflake connection configuration" + ) + pool_size: int = Field(default=5, description="Size of connection pool") + max_retries: int = Field(default=3, description="Maximum retry attempts") + retry_delay: float = Field( + default=1.0, description="Delay between retries in seconds" + ) + enable_caching: bool = Field( + default=True, description="Enable query result caching" + ) + + model_config = ConfigDict(arbitrary_types_allowed=True) + + def __init__(self, **data): + """Initialize SnowflakeSearchTool.""" + super().__init__(**data) + self._connection_pool: List[SnowflakeConnection] = [] + self._pool_lock = asyncio.Lock() + self._thread_pool = ThreadPoolExecutor(max_workers=self.pool_size) + + async def _get_connection(self) -> SnowflakeConnection: + """Get a connection from the pool or create a new one.""" + async with self._pool_lock: + if not self._connection_pool: + conn = self._create_connection() + self._connection_pool.append(conn) + return self._connection_pool.pop() + + def _create_connection(self) -> SnowflakeConnection: + """Create a new Snowflake connection.""" + conn_params = { + "account": self.config.account, + "user": self.config.user, + "warehouse": self.config.warehouse, + "database": self.config.database, + "schema": self.config.snowflake_schema, + "role": self.config.role, + "session_parameters": self.config.session_parameters, + } + + if self.config.password: + conn_params["password"] = self.config.password.get_secret_value() + elif self.config.private_key_path: + with open(self.config.private_key_path, "rb") as key_file: + p_key = serialization.load_pem_private_key( + key_file.read(), password=None, backend=default_backend() + ) + conn_params["private_key"] = p_key + + return snowflake.connector.connect(**conn_params) + + def _get_cache_key(self, query: str, timeout: int) -> str: + """Generate a 
cache key for the query.""" + return f"{self.config.account}:{self.config.database}:{self.config.snowflake_schema}:{query}:{timeout}" + + async def _execute_query( + self, query: str, timeout: int = 300 + ) -> List[Dict[str, Any]]: + """Execute a query with retries and return results.""" + if self.enable_caching: + cache_key = self._get_cache_key(query, timeout) + if cache_key in _query_cache: + logger.info("Returning cached result") + return _query_cache[cache_key] + + for attempt in range(self.max_retries): + try: + conn = await self._get_connection() + try: + cursor = conn.cursor() + cursor.execute(query, timeout=timeout) + + if not cursor.description: + return [] + + columns = [col[0] for col in cursor.description] + results = [dict(zip(columns, row)) for row in cursor.fetchall()] + + if self.enable_caching: + _query_cache[self._get_cache_key(query, timeout)] = results + + return results + finally: + cursor.close() + async with self._pool_lock: + self._connection_pool.append(conn) + except (DatabaseError, OperationalError) as e: + if attempt == self.max_retries - 1: + raise + await asyncio.sleep(self.retry_delay * (2**attempt)) + logger.warning(f"Query failed, attempt {attempt + 1}: {str(e)}") + continue + + async def _run( + self, + query: str, + database: Optional[str] = None, + snowflake_schema: Optional[str] = None, + timeout: int = 300, + **kwargs: Any, + ) -> Any: + """Execute the search query.""" + try: + # Override database/schema if provided + if database: + await self._execute_query(f"USE DATABASE {database}") + if snowflake_schema: + await self._execute_query(f"USE SCHEMA {snowflake_schema}") + + results = await self._execute_query(query, timeout) + return results + except Exception as e: + logger.error(f"Error executing query: {str(e)}") + raise + + def __del__(self): + """Cleanup connections on deletion.""" + try: + for conn in getattr(self, "_connection_pool", []): + try: + conn.close() + except: + pass + if hasattr(self, "_thread_pool"): + 
self._thread_pool.shutdown() + except: + pass diff --git a/src/crewai_tools/tools/stagehand_tool/stagehand_tool.py b/src/crewai_tools/tools/stagehand_tool/stagehand_tool.py index 07c76c8c3..37b414509 100644 --- a/src/crewai_tools/tools/stagehand_tool/stagehand_tool.py +++ b/src/crewai_tools/tools/stagehand_tool/stagehand_tool.py @@ -14,9 +14,8 @@ import os from functools import lru_cache from typing import Any, Dict, List, Optional, Type, Union -from pydantic import BaseModel, Field - from crewai.tools.base_tool import BaseTool +from pydantic import BaseModel, Field # Set up logging logger = logging.getLogger(__name__) @@ -25,6 +24,7 @@ logger = logging.getLogger(__name__) STAGEHAND_AVAILABLE = False try: import stagehand + STAGEHAND_AVAILABLE = True except ImportError: pass # Keep STAGEHAND_AVAILABLE as False @@ -32,33 +32,45 @@ except ImportError: class StagehandResult(BaseModel): """Result from a Stagehand operation. - + Attributes: success: Whether the operation completed successfully data: The result data from the operation error: Optional error message if the operation failed """ - success: bool = Field(..., description="Whether the operation completed successfully") - data: Union[str, Dict, List] = Field(..., description="The result data from the operation") - error: Optional[str] = Field(None, description="Optional error message if the operation failed") + + success: bool = Field( + ..., description="Whether the operation completed successfully" + ) + data: Union[str, Dict, List] = Field( + ..., description="The result data from the operation" + ) + error: Optional[str] = Field( + None, description="Optional error message if the operation failed" + ) class StagehandToolConfig(BaseModel): """Configuration for the StagehandTool. 
- + Attributes: api_key: OpenAI API key for Stagehand authentication timeout: Maximum time in seconds to wait for operations (default: 30) retry_attempts: Number of times to retry failed operations (default: 3) """ + api_key: str = Field(..., description="OpenAI API key for Stagehand authentication") - timeout: int = Field(30, description="Maximum time in seconds to wait for operations") - retry_attempts: int = Field(3, description="Number of times to retry failed operations") + timeout: int = Field( + 30, description="Maximum time in seconds to wait for operations" + ) + retry_attempts: int = Field( + 3, description="Number of times to retry failed operations" + ) class StagehandToolSchema(BaseModel): """Schema for the StagehandTool input parameters. - + Examples: ```python # Using the 'act' API to click a button @@ -66,13 +78,13 @@ class StagehandToolSchema(BaseModel): api_method="act", instruction="Click the 'Sign In' button" ) - + # Using the 'extract' API to get text tool.run( api_method="extract", instruction="Get the text content of the main article" ) - + # Using the 'observe' API to monitor changes tool.run( api_method="observe", @@ -80,48 +92,49 @@ class StagehandToolSchema(BaseModel): ) ``` """ + api_method: str = Field( ..., description="The Stagehand API to use: 'act' for interactions, 'extract' for getting content, or 'observe' for monitoring changes", - pattern="^(act|extract|observe)$" + pattern="^(act|extract|observe)$", ) instruction: str = Field( ..., description="An atomic instruction for Stagehand to execute. Instructions should be simple and specific to increase reliability.", min_length=1, - max_length=500 + max_length=500, ) class StagehandTool(BaseTool): """A tool for using Stagehand's AI-powered web automation capabilities. 
- + This tool provides access to Stagehand's three core APIs: - act: Perform web interactions (e.g., clicking buttons, filling forms) - extract: Extract information from web pages (e.g., getting text content) - observe: Monitor web page changes (e.g., watching for updates) - + Each function takes atomic instructions to increase reliability. - + Required Environment Variables: OPENAI_API_KEY: API key for OpenAI (required by Stagehand) - + Examples: ```python tool = StagehandTool() - + # Perform a web interaction result = tool.run( api_method="act", instruction="Click the 'Sign In' button" ) - + # Extract content from a page content = tool.run( api_method="extract", instruction="Get the text content of the main article" ) - + # Monitor for changes changes = tool.run( api_method="observe", @@ -129,7 +142,7 @@ class StagehandTool(BaseTool): ) ``` """ - + name: str = "StagehandTool" description: str = ( "A tool that uses Stagehand's AI-powered web automation to interact with websites. " @@ -137,27 +150,29 @@ class StagehandTool(BaseTool): "Each instruction should be atomic (simple and specific) to increase reliability." ) args_schema: Type[BaseModel] = StagehandToolSchema - - def __init__(self, config: StagehandToolConfig | None = None, **kwargs: Any) -> None: + + def __init__( + self, config: StagehandToolConfig | None = None, **kwargs: Any + ) -> None: """Initialize the StagehandTool. - + Args: config: Optional configuration for the tool. If not provided, will attempt to use OPENAI_API_KEY from environment. **kwargs: Additional keyword arguments passed to the base class. - + Raises: ImportError: If the stagehand package is not installed ValueError: If no API key is provided via config or environment """ super().__init__(**kwargs) - + if not STAGEHAND_AVAILABLE: raise ImportError( "The 'stagehand' package is required to use this tool. 
" "Please install it with: pip install stagehand" ) - + # Use config if provided, otherwise try environment variable if config is not None: self.config = config @@ -168,24 +183,22 @@ class StagehandTool(BaseTool): "Either provide config with api_key or set OPENAI_API_KEY environment variable" ) self.config = StagehandToolConfig( - api_key=api_key, - timeout=30, - retry_attempts=3 + api_key=api_key, timeout=30, retry_attempts=3 ) - + @lru_cache(maxsize=100) def _cached_run(self, api_method: str, instruction: str) -> Any: """Execute a cached Stagehand command. - + This method is cached to improve performance for repeated operations. - + Args: api_method: The Stagehand API to use ('act', 'extract', or 'observe') instruction: An atomic instruction for Stagehand to execute - + Returns: The raw result from the Stagehand API call - + Raises: ValueError: If an invalid api_method is provided Exception: If the Stagehand API call fails @@ -193,23 +206,25 @@ class StagehandTool(BaseTool): logger.debug( "Cache operation - Method: %s, Instruction length: %d", api_method, - len(instruction) + len(instruction), ) - + # Initialize Stagehand with configuration logger.info( "Initializing Stagehand (timeout=%ds, retries=%d)", self.config.timeout, - self.config.retry_attempts + self.config.retry_attempts, ) st = stagehand.Stagehand( api_key=self.config.api_key, timeout=self.config.timeout, - retry_attempts=self.config.retry_attempts + retry_attempts=self.config.retry_attempts, ) - + # Call the appropriate Stagehand API based on the method - logger.info("Executing %s operation with instruction: %s", api_method, instruction[:100]) + logger.info( + "Executing %s operation with instruction: %s", api_method, instruction[:100] + ) try: if api_method == "act": result = st.act(instruction) @@ -219,28 +234,27 @@ class StagehandTool(BaseTool): result = st.observe(instruction) else: raise ValueError(f"Unknown api_method: {api_method}") - - + logger.info("Successfully executed %s operation", 
api_method) return result - + except Exception as e: logger.warning( "Operation failed (method=%s, error=%s), will be retried on next attempt", api_method, - str(e) + str(e), ) raise def _run(self, api_method: str, instruction: str, **kwargs: Any) -> StagehandResult: """Execute a Stagehand command using the specified API method. - + Args: api_method: The Stagehand API to use ('act', 'extract', or 'observe') instruction: An atomic instruction for Stagehand to execute **kwargs: Additional keyword arguments passed to the Stagehand API - - Returns: + + Returns: StagehandResult containing the operation result and status """ try: @@ -249,56 +263,36 @@ class StagehandTool(BaseTool): "Starting operation - Method: %s, Instruction length: %d, Args: %s", api_method, len(instruction), - kwargs + kwargs, ) - + # Use cached execution result = self._cached_run(api_method, instruction) logger.info("Operation completed successfully") return StagehandResult(success=True, data=result) - + except stagehand.AuthenticationError as e: logger.error( - "Authentication failed - Method: %s, Error: %s", - api_method, - str(e) + "Authentication failed - Method: %s, Error: %s", api_method, str(e) ) return StagehandResult( - success=False, - data={}, - error=f"Authentication failed: {str(e)}" + success=False, data={}, error=f"Authentication failed: {str(e)}" ) except stagehand.APIError as e: - logger.error( - "API error - Method: %s, Error: %s", - api_method, - str(e) - ) - return StagehandResult( - success=False, - data={}, - error=f"API error: {str(e)}" - ) + logger.error("API error - Method: %s, Error: %s", api_method, str(e)) + return StagehandResult(success=False, data={}, error=f"API error: {str(e)}") except stagehand.BrowserError as e: - logger.error( - "Browser error - Method: %s, Error: %s", - api_method, - str(e) - ) + logger.error("Browser error - Method: %s, Error: %s", api_method, str(e)) return StagehandResult( - success=False, - data={}, - error=f"Browser error: {str(e)}" + 
success=False, data={}, error=f"Browser error: {str(e)}" ) except Exception as e: logger.error( "Unexpected error - Method: %s, Error type: %s, Message: %s", api_method, type(e).__name__, - str(e) + str(e), ) return StagehandResult( - success=False, - data={}, - error=f"Unexpected error: {str(e)}" + success=False, data={}, error=f"Unexpected error: {str(e)}" ) diff --git a/src/crewai_tools/tools/vision_tool/vision_tool.py b/src/crewai_tools/tools/vision_tool/vision_tool.py index 4fbc1df0e..594be0b22 100644 --- a/src/crewai_tools/tools/vision_tool/vision_tool.py +++ b/src/crewai_tools/tools/vision_tool/vision_tool.py @@ -1,30 +1,36 @@ import base64 -from typing import Type, Optional from pathlib import Path +from typing import Optional, Type + from crewai.tools import BaseTool from openai import OpenAI from pydantic import BaseModel, validator + class ImagePromptSchema(BaseModel): """Input for Vision Tool.""" + image_path_url: str = "The image path or URL." @validator("image_path_url") def validate_image_path_url(cls, v: str) -> str: if v.startswith("http"): return v - + path = Path(v) if not path.exists(): raise ValueError(f"Image file does not exist: {v}") - + # Validate supported formats valid_extensions = {".jpg", ".jpeg", ".png", ".gif", ".webp"} if path.suffix.lower() not in valid_extensions: - raise ValueError(f"Unsupported image format. Supported formats: {valid_extensions}") - + raise ValueError( + f"Unsupported image format. Supported formats: {valid_extensions}" + ) + return v + class VisionTool(BaseTool): name: str = "Vision Tool" description: str = ( @@ -45,10 +51,10 @@ class VisionTool(BaseTool): image_path_url = kwargs.get("image_path_url") if not image_path_url: return "Image Path or URL is required." 
- + # Validate input using Pydantic ImagePromptSchema(image_path_url=image_path_url) - + if image_path_url.startswith("http"): image_data = image_path_url else: @@ -68,12 +74,12 @@ class VisionTool(BaseTool): { "type": "image_url", "image_url": {"url": image_data}, - } + }, ], } ], max_tokens=300, - ) + ) return response.choices[0].message.content diff --git a/src/crewai_tools/tools/weaviate_tool/vector_search.py b/src/crewai_tools/tools/weaviate_tool/vector_search.py index 14e10d7c5..53f641272 100644 --- a/src/crewai_tools/tools/weaviate_tool/vector_search.py +++ b/src/crewai_tools/tools/weaviate_tool/vector_search.py @@ -15,9 +15,8 @@ except ImportError: Vectorizers = Any Auth = Any -from pydantic import BaseModel, Field - from crewai.tools import BaseTool +from pydantic import BaseModel, Field class WeaviateToolSchema(BaseModel): diff --git a/src/crewai_tools/tools/website_search/website_search_tool.py b/src/crewai_tools/tools/website_search/website_search_tool.py index faa1a02e8..842462546 100644 --- a/src/crewai_tools/tools/website_search/website_search_tool.py +++ b/src/crewai_tools/tools/website_search/website_search_tool.py @@ -25,9 +25,7 @@ class WebsiteSearchToolSchema(FixedWebsiteSearchToolSchema): class WebsiteSearchTool(RagTool): name: str = "Search in a specific website" - description: str = ( - "A tool that can be used to semantic search a query from a specific URL content." - ) + description: str = "A tool that can be used to semantic search a query from a specific URL content." 
args_schema: Type[BaseModel] = WebsiteSearchToolSchema def __init__(self, website: Optional[str] = None, **kwargs): diff --git a/src/crewai_tools/tools/youtube_channel_search_tool/youtube_channel_search_tool.py b/src/crewai_tools/tools/youtube_channel_search_tool/youtube_channel_search_tool.py index b0c6209f1..81ecc30c3 100644 --- a/src/crewai_tools/tools/youtube_channel_search_tool/youtube_channel_search_tool.py +++ b/src/crewai_tools/tools/youtube_channel_search_tool/youtube_channel_search_tool.py @@ -25,9 +25,7 @@ class YoutubeChannelSearchToolSchema(FixedYoutubeChannelSearchToolSchema): class YoutubeChannelSearchTool(RagTool): name: str = "Search a Youtube Channels content" - description: str = ( - "A tool that can be used to semantic search a query from a Youtube Channels content." - ) + description: str = "A tool that can be used to semantic search a query from a Youtube Channels content." args_schema: Type[BaseModel] = YoutubeChannelSearchToolSchema def __init__(self, youtube_channel_handle: Optional[str] = None, **kwargs): diff --git a/src/crewai_tools/tools/youtube_video_search_tool/youtube_video_search_tool.py b/src/crewai_tools/tools/youtube_video_search_tool/youtube_video_search_tool.py index 6852fafb4..1ad8434c8 100644 --- a/src/crewai_tools/tools/youtube_video_search_tool/youtube_video_search_tool.py +++ b/src/crewai_tools/tools/youtube_video_search_tool/youtube_video_search_tool.py @@ -25,9 +25,7 @@ class YoutubeVideoSearchToolSchema(FixedYoutubeVideoSearchToolSchema): class YoutubeVideoSearchTool(RagTool): name: str = "Search a Youtube Video content" - description: str = ( - "A tool that can be used to semantic search a query from a Youtube Video content." - ) + description: str = "A tool that can be used to semantic search a query from a Youtube Video content." 
args_schema: Type[BaseModel] = YoutubeVideoSearchToolSchema def __init__(self, youtube_video_url: Optional[str] = None, **kwargs): diff --git a/tests/base_tool_test.py b/tests/base_tool_test.py index 4a4e40783..e6f4f127d 100644 --- a/tests/base_tool_test.py +++ b/tests/base_tool_test.py @@ -1,69 +1,104 @@ from typing import Callable + from crewai.tools import BaseTool, tool from crewai.tools.base_tool import to_langchain + def test_creating_a_tool_using_annotation(): - @tool("Name of my tool") - def my_tool(question: str) -> str: - """Clear description for what this tool is useful for, you agent will need this information to use it.""" - return question + @tool("Name of my tool") + def my_tool(question: str) -> str: + """Clear description for what this tool is useful for, you agent will need this information to use it.""" + return question - # Assert all the right attributes were defined - assert my_tool.name == "Name of my tool" - assert my_tool.description == "Tool Name: Name of my tool\nTool Arguments: {'question': {'description': None, 'type': 'str'}}\nTool Description: Clear description for what this tool is useful for, you agent will need this information to use it." - assert my_tool.args_schema.schema()["properties"] == {'question': {'title': 'Question', 'type': 'string'}} - assert my_tool.func("What is the meaning of life?") == "What is the meaning of life?" + # Assert all the right attributes were defined + assert my_tool.name == "Name of my tool" + assert ( + my_tool.description + == "Tool Name: Name of my tool\nTool Arguments: {'question': {'description': None, 'type': 'str'}}\nTool Description: Clear description for what this tool is useful for, you agent will need this information to use it." + ) + assert my_tool.args_schema.schema()["properties"] == { + "question": {"title": "Question", "type": "string"} + } + assert ( + my_tool.func("What is the meaning of life?") == "What is the meaning of life?" 
+ ) + + # Assert the langchain tool conversion worked as expected + converted_tool = to_langchain([my_tool])[0] + assert converted_tool.name == "Name of my tool" + assert ( + converted_tool.description + == "Tool Name: Name of my tool\nTool Arguments: {'question': {'description': None, 'type': 'str'}}\nTool Description: Clear description for what this tool is useful for, you agent will need this information to use it." + ) + assert converted_tool.args_schema.schema()["properties"] == { + "question": {"title": "Question", "type": "string"} + } + assert ( + converted_tool.func("What is the meaning of life?") + == "What is the meaning of life?" + ) - # Assert the langchain tool conversion worked as expected - converted_tool = to_langchain([my_tool])[0] - assert converted_tool.name == "Name of my tool" - assert converted_tool.description == "Tool Name: Name of my tool\nTool Arguments: {'question': {'description': None, 'type': 'str'}}\nTool Description: Clear description for what this tool is useful for, you agent will need this information to use it." - assert converted_tool.args_schema.schema()["properties"] == {'question': {'title': 'Question', 'type': 'string'}} - assert converted_tool.func("What is the meaning of life?") == "What is the meaning of life?" def test_creating_a_tool_using_baseclass(): - class MyCustomTool(BaseTool): - name: str = "Name of my tool" - description: str = "Clear description for what this tool is useful for, you agent will need this information to use it." + class MyCustomTool(BaseTool): + name: str = "Name of my tool" + description: str = "Clear description for what this tool is useful for, you agent will need this information to use it." 
- def _run(self, question: str) -> str: - return question + def _run(self, question: str) -> str: + return question - my_tool = MyCustomTool() - # Assert all the right attributes were defined - assert my_tool.name == "Name of my tool" - assert my_tool.description == "Tool Name: Name of my tool\nTool Arguments: {'question': {'description': None, 'type': 'str'}}\nTool Description: Clear description for what this tool is useful for, you agent will need this information to use it." - assert my_tool.args_schema.schema()["properties"] == {'question': {'title': 'Question', 'type': 'string'}} - assert my_tool._run("What is the meaning of life?") == "What is the meaning of life?" + my_tool = MyCustomTool() + # Assert all the right attributes were defined + assert my_tool.name == "Name of my tool" + assert ( + my_tool.description + == "Tool Name: Name of my tool\nTool Arguments: {'question': {'description': None, 'type': 'str'}}\nTool Description: Clear description for what this tool is useful for, you agent will need this information to use it." + ) + assert my_tool.args_schema.schema()["properties"] == { + "question": {"title": "Question", "type": "string"} + } + assert ( + my_tool._run("What is the meaning of life?") == "What is the meaning of life?" + ) + + # Assert the langchain tool conversion worked as expected + converted_tool = to_langchain([my_tool])[0] + assert converted_tool.name == "Name of my tool" + assert ( + converted_tool.description + == "Tool Name: Name of my tool\nTool Arguments: {'question': {'description': None, 'type': 'str'}}\nTool Description: Clear description for what this tool is useful for, you agent will need this information to use it." + ) + assert converted_tool.args_schema.schema()["properties"] == { + "question": {"title": "Question", "type": "string"} + } + assert ( + converted_tool.invoke({"question": "What is the meaning of life?"}) + == "What is the meaning of life?" 
+ ) - # Assert the langchain tool conversion worked as expected - converted_tool = to_langchain([my_tool])[0] - assert converted_tool.name == "Name of my tool" - assert converted_tool.description == "Tool Name: Name of my tool\nTool Arguments: {'question': {'description': None, 'type': 'str'}}\nTool Description: Clear description for what this tool is useful for, you agent will need this information to use it." - assert converted_tool.args_schema.schema()["properties"] == {'question': {'title': 'Question', 'type': 'string'}} - assert converted_tool.invoke({"question": "What is the meaning of life?"}) == "What is the meaning of life?" def test_setting_cache_function(): - class MyCustomTool(BaseTool): - name: str = "Name of my tool" - description: str = "Clear description for what this tool is useful for, you agent will need this information to use it." - cache_function: Callable = lambda: False + class MyCustomTool(BaseTool): + name: str = "Name of my tool" + description: str = "Clear description for what this tool is useful for, you agent will need this information to use it." + cache_function: Callable = lambda: False - def _run(self, question: str) -> str: - return question + def _run(self, question: str) -> str: + return question + + my_tool = MyCustomTool() + # Assert all the right attributes were defined + assert my_tool.cache_function() == False - my_tool = MyCustomTool() - # Assert all the right attributes were defined - assert my_tool.cache_function() == False def test_default_cache_function_is_true(): - class MyCustomTool(BaseTool): - name: str = "Name of my tool" - description: str = "Clear description for what this tool is useful for, you agent will need this information to use it." + class MyCustomTool(BaseTool): + name: str = "Name of my tool" + description: str = "Clear description for what this tool is useful for, you agent will need this information to use it." 
- def _run(self, question: str) -> str: - return question + def _run(self, question: str) -> str: + return question - my_tool = MyCustomTool() - # Assert all the right attributes were defined - assert my_tool.cache_function() == True \ No newline at end of file + my_tool = MyCustomTool() + # Assert all the right attributes were defined + assert my_tool.cache_function() == True diff --git a/tests/file_read_tool_test.py b/tests/file_read_tool_test.py index 4646df24c..5957f863b 100644 --- a/tests/file_read_tool_test.py +++ b/tests/file_read_tool_test.py @@ -1,7 +1,8 @@ import os -import pytest + from crewai_tools import FileReadTool + def test_file_read_tool_constructor(): """Test FileReadTool initialization with file_path.""" # Create a temporary test file @@ -18,6 +19,7 @@ def test_file_read_tool_constructor(): # Clean up os.remove(test_file) + def test_file_read_tool_run(): """Test FileReadTool _run method with file_path at runtime.""" # Create a temporary test file @@ -34,6 +36,7 @@ def test_file_read_tool_run(): # Clean up os.remove(test_file) + def test_file_read_tool_error_handling(): """Test FileReadTool error handling.""" # Test missing file path @@ -58,6 +61,7 @@ def test_file_read_tool_error_handling(): os.chmod(test_file, 0o666) # Restore permissions to delete os.remove(test_file) + def test_file_read_tool_constructor_and_run(): """Test FileReadTool using both constructor and runtime file paths.""" # Create two test files diff --git a/tests/it/tools/__init__.py b/tests/it/tools/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/it/tools/conftest.py b/tests/it/tools/conftest.py new file mode 100644 index 000000000..a633c22c7 --- /dev/null +++ b/tests/it/tools/conftest.py @@ -0,0 +1,21 @@ +import pytest + + +def pytest_configure(config): + """Register custom markers.""" + config.addinivalue_line("markers", "integration: mark test as an integration test") + config.addinivalue_line("markers", "asyncio: mark test as an async test") + 
+ # Set the asyncio loop scope through ini configuration + config.inicfg["asyncio_mode"] = "auto" + + +@pytest.fixture(scope="function") +def event_loop(): + """Create an instance of the default event loop for each test case.""" + import asyncio + + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + yield loop + loop.close() diff --git a/tests/it/tools/snowflake_search_tool_test.py b/tests/it/tools/snowflake_search_tool_test.py new file mode 100644 index 000000000..70dc07953 --- /dev/null +++ b/tests/it/tools/snowflake_search_tool_test.py @@ -0,0 +1,219 @@ +import asyncio +import json +from decimal import Decimal + +import pytest +from snowflake.connector.errors import DatabaseError, OperationalError + +from crewai_tools import SnowflakeConfig, SnowflakeSearchTool + +# Test Data +MENU_ITEMS = [ + (10001, "Ice Cream", "Freezing Point", "Lemonade", "Beverage", "Cold Option", 1, 4), + ( + 10002, + "Ice Cream", + "Freezing Point", + "Vanilla Ice Cream", + "Dessert", + "Ice Cream", + 2, + 6, + ), +] + +INVALID_QUERIES = [ + ("SELECT * FROM nonexistent_table", "relation 'nonexistent_table' does not exist"), + ("SELECT invalid_column FROM menu", "invalid identifier 'invalid_column'"), + ("INVALID SQL QUERY", "SQL compilation error"), +] + + +# Integration Test Fixtures +@pytest.fixture +def config(): + """Create a Snowflake configuration with test credentials.""" + return SnowflakeConfig( + account="lwyhjun-wx11931", + user="crewgitci", + password="crewaiT00ls_publicCIpass123", + warehouse="COMPUTE_WH", + database="tasty_bytes_sample_data", + snowflake_schema="raw_pos", + ) + + +@pytest.fixture +def snowflake_tool(config): + """Create a SnowflakeSearchTool instance.""" + return SnowflakeSearchTool(config=config) + + +# Integration Tests with Real Snowflake Connection +@pytest.mark.integration +@pytest.mark.asyncio +@pytest.mark.parametrize( + "menu_id,expected_type,brand,item_name,category,subcategory,cost,price", MENU_ITEMS +) +async def test_menu_items( + 
snowflake_tool, + menu_id, + expected_type, + brand, + item_name, + category, + subcategory, + cost, + price, +): + """Test menu items with parameterized data for multiple test cases.""" + results = await snowflake_tool._run( + query=f"SELECT * FROM menu WHERE menu_id = {menu_id}" + ) + assert len(results) == 1 + menu_item = results[0] + + # Validate all fields + assert menu_item["MENU_ID"] == menu_id + assert menu_item["MENU_TYPE"] == expected_type + assert menu_item["TRUCK_BRAND_NAME"] == brand + assert menu_item["MENU_ITEM_NAME"] == item_name + assert menu_item["ITEM_CATEGORY"] == category + assert menu_item["ITEM_SUBCATEGORY"] == subcategory + assert menu_item["COST_OF_GOODS_USD"] == cost + assert menu_item["SALE_PRICE_USD"] == price + + # Validate health metrics JSON structure + health_metrics = json.loads(menu_item["MENU_ITEM_HEALTH_METRICS_OBJ"]) + assert "menu_item_health_metrics" in health_metrics + metrics = health_metrics["menu_item_health_metrics"][0] + assert "ingredients" in metrics + assert isinstance(metrics["ingredients"], list) + assert all(isinstance(ingredient, str) for ingredient in metrics["ingredients"]) + assert metrics["is_dairy_free_flag"] in ["Y", "N"] + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_menu_categories_aggregation(snowflake_tool): + """Test complex aggregation query on menu categories with detailed validations.""" + results = await snowflake_tool._run( + query=""" + SELECT + item_category, + COUNT(*) as item_count, + AVG(sale_price_usd) as avg_price, + SUM(sale_price_usd - cost_of_goods_usd) as total_margin, + COUNT(DISTINCT menu_type) as menu_type_count, + MIN(sale_price_usd) as min_price, + MAX(sale_price_usd) as max_price + FROM menu + GROUP BY item_category + HAVING COUNT(*) > 1 + ORDER BY item_count DESC + """ + ) + + assert len(results) > 0 + for category in results: + # Basic presence checks + assert all( + key in category + for key in [ + "ITEM_CATEGORY", + "ITEM_COUNT", + "AVG_PRICE", + 
"TOTAL_MARGIN", + "MENU_TYPE_COUNT", + "MIN_PRICE", + "MAX_PRICE", + ] + ) + + # Value validations + assert category["ITEM_COUNT"] > 1 # Due to HAVING clause + assert category["MIN_PRICE"] <= category["MAX_PRICE"] + assert category["AVG_PRICE"] >= category["MIN_PRICE"] + assert category["AVG_PRICE"] <= category["MAX_PRICE"] + assert category["MENU_TYPE_COUNT"] >= 1 + assert isinstance(category["TOTAL_MARGIN"], (float, Decimal)) + + +@pytest.mark.integration +@pytest.mark.asyncio +@pytest.mark.parametrize("invalid_query,expected_error", INVALID_QUERIES) +async def test_invalid_queries(snowflake_tool, invalid_query, expected_error): + """Test error handling for invalid queries.""" + with pytest.raises((DatabaseError, OperationalError)) as exc_info: + await snowflake_tool._run(query=invalid_query) + assert expected_error.lower() in str(exc_info.value).lower() + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_concurrent_queries(snowflake_tool): + """Test handling of concurrent queries.""" + queries = [ + "SELECT COUNT(*) FROM menu", + "SELECT COUNT(DISTINCT menu_type) FROM menu", + "SELECT COUNT(DISTINCT item_category) FROM menu", + ] + + tasks = [snowflake_tool._run(query=query) for query in queries] + results = await asyncio.gather(*tasks) + + assert len(results) == 3 + assert all(isinstance(result, list) for result in results) + assert all(len(result) == 1 for result in results) + assert all(isinstance(result[0], dict) for result in results) + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_query_timeout(snowflake_tool): + """Test query timeout handling with a complex query.""" + with pytest.raises((DatabaseError, OperationalError)) as exc_info: + await snowflake_tool._run( + query=""" + WITH RECURSIVE numbers AS ( + SELECT 1 as n + UNION ALL + SELECT n + 1 + FROM numbers + WHERE n < 1000000 + ) + SELECT COUNT(*) FROM numbers + """ + ) + assert ( + "timeout" in str(exc_info.value).lower() + or "execution time" in 
str(exc_info.value).lower() + ) + + +@pytest.mark.integration +@pytest.mark.asyncio +async def test_caching_behavior(snowflake_tool): + """Test query caching behavior and performance.""" + query = "SELECT * FROM menu LIMIT 5" + + # First execution + start_time = asyncio.get_event_loop().time() + results1 = await snowflake_tool._run(query=query) + first_duration = asyncio.get_event_loop().time() - start_time + + # Second execution (should be cached) + start_time = asyncio.get_event_loop().time() + results2 = await snowflake_tool._run(query=query) + second_duration = asyncio.get_event_loop().time() - start_time + + # Verify results + assert results1 == results2 + assert len(results1) == 5 + assert second_duration < first_duration + + # Verify cache invalidation with different query + different_query = "SELECT * FROM menu LIMIT 10" + different_results = await snowflake_tool._run(query=different_query) + assert len(different_results) == 10 + assert different_results != results1 diff --git a/tests/spider_tool_test.py b/tests/spider_tool_test.py index 264394777..7f5613fe6 100644 --- a/tests/spider_tool_test.py +++ b/tests/spider_tool_test.py @@ -1,5 +1,7 @@ +from crewai import Agent, Crew, Task + from crewai_tools.tools.spider_tool.spider_tool import SpiderTool -from crewai import Agent, Task, Crew + def test_spider_tool(): spider_tool = SpiderTool() @@ -10,38 +12,35 @@ def test_spider_tool(): backstory="An expert web researcher that uses the web extremely well", tools=[spider_tool], verbose=True, - cache=False + cache=False, ) choose_between_scrape_crawl = Task( description="Scrape the page of spider.cloud and return a summary of how fast it is", expected_output="spider.cloud is a fast scraping and crawling tool", - agent=searcher + agent=searcher, ) return_metadata = Task( description="Scrape https://spider.cloud with a limit of 1 and enable metadata", expected_output="Metadata and 10 word summary of spider.cloud", - agent=searcher + agent=searcher, ) css_selector = 
Task( description="Scrape one page of spider.cloud with the `body > div > main > section.grid.md\:grid-cols-2.gap-10.place-items-center.md\:max-w-screen-xl.mx-auto.pb-8.pt-20 > div:nth-child(1) > h1` CSS selector", expected_output="The content of the element with the css selector body > div > main > section.grid.md\:grid-cols-2.gap-10.place-items-center.md\:max-w-screen-xl.mx-auto.pb-8.pt-20 > div:nth-child(1) > h1", - agent=searcher + agent=searcher, ) crew = Crew( agents=[searcher], - tasks=[ - choose_between_scrape_crawl, - return_metadata, - css_selector - ], - verbose=True + tasks=[choose_between_scrape_crawl, return_metadata, css_selector], + verbose=True, ) crew.kickoff() + if __name__ == "__main__": test_spider_tool() diff --git a/tests/tools/snowflake_search_tool_test.py b/tests/tools/snowflake_search_tool_test.py new file mode 100644 index 000000000..d4851b8ab --- /dev/null +++ b/tests/tools/snowflake_search_tool_test.py @@ -0,0 +1,103 @@ +import asyncio +from unittest.mock import MagicMock, patch + +import pytest + +from crewai_tools import SnowflakeConfig, SnowflakeSearchTool + + +# Unit Test Fixtures +@pytest.fixture +def mock_snowflake_connection(): + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_cursor.description = [("col1",), ("col2",)] + mock_cursor.fetchall.return_value = [(1, "value1"), (2, "value2")] + mock_cursor.execute.return_value = None + mock_conn.cursor.return_value = mock_cursor + return mock_conn + + +@pytest.fixture +def mock_config(): + return SnowflakeConfig( + account="test_account", + user="test_user", + password="test_password", + warehouse="test_warehouse", + database="test_db", + snowflake_schema="test_schema", + ) + + +@pytest.fixture +def snowflake_tool(mock_config): + with patch("snowflake.connector.connect") as mock_connect: + tool = SnowflakeSearchTool(config=mock_config) + yield tool + + +# Unit Tests +@pytest.mark.asyncio +async def test_successful_query_execution(snowflake_tool, mock_snowflake_connection): 
+ with patch.object(snowflake_tool, "_create_connection") as mock_create_conn: + mock_create_conn.return_value = mock_snowflake_connection + + results = await snowflake_tool._run( + query="SELECT * FROM test_table", timeout=300 + ) + + assert len(results) == 2 + assert results[0]["col1"] == 1 + assert results[0]["col2"] == "value1" + mock_snowflake_connection.cursor.assert_called_once() + + +@pytest.mark.asyncio +async def test_connection_pooling(snowflake_tool, mock_snowflake_connection): + with patch.object(snowflake_tool, "_create_connection") as mock_create_conn: + mock_create_conn.return_value = mock_snowflake_connection + + # Execute multiple queries + await asyncio.gather( + snowflake_tool._run("SELECT 1"), + snowflake_tool._run("SELECT 2"), + snowflake_tool._run("SELECT 3"), + ) + + # Should reuse connections from pool + assert mock_create_conn.call_count <= snowflake_tool.pool_size + + +@pytest.mark.asyncio +async def test_cleanup_on_deletion(snowflake_tool, mock_snowflake_connection): + with patch.object(snowflake_tool, "_create_connection") as mock_create_conn: + mock_create_conn.return_value = mock_snowflake_connection + + # Add connection to pool + await snowflake_tool._get_connection() + + # Return connection to pool + async with snowflake_tool._pool_lock: + snowflake_tool._connection_pool.append(mock_snowflake_connection) + + # Trigger cleanup + snowflake_tool.__del__() + + mock_snowflake_connection.close.assert_called_once() + + +def test_config_validation(): + # Test missing required fields + with pytest.raises(ValueError): + SnowflakeConfig() + + # Test invalid account format + with pytest.raises(ValueError): + SnowflakeConfig( + account="invalid//account", user="test_user", password="test_pass" + ) + + # Test missing authentication + with pytest.raises(ValueError): + SnowflakeConfig(account="test_account", user="test_user") diff --git a/tests/tools/test_code_interpreter_tool.py b/tests/tools/test_code_interpreter_tool.py index 
6470c9dc1..e281fffaf 100644 --- a/tests/tools/test_code_interpreter_tool.py +++ b/tests/tools/test_code_interpreter_tool.py @@ -7,7 +7,9 @@ from crewai_tools.tools.code_interpreter_tool.code_interpreter_tool import ( class TestCodeInterpreterTool(unittest.TestCase): - @patch("crewai_tools.tools.code_interpreter_tool.code_interpreter_tool.docker_from_env") + @patch( + "crewai_tools.tools.code_interpreter_tool.code_interpreter_tool.docker_from_env" + ) def test_run_code_in_docker(self, docker_mock): tool = CodeInterpreterTool() code = "print('Hello, World!')" @@ -15,14 +17,14 @@ class TestCodeInterpreterTool(unittest.TestCase): expected_output = "Hello, World!\n" docker_mock().containers.run().exec_run().exit_code = 0 - docker_mock().containers.run().exec_run().output = ( - expected_output.encode() - ) + docker_mock().containers.run().exec_run().output = expected_output.encode() result = tool.run_code_in_docker(code, libraries_used) self.assertEqual(result, expected_output) - @patch("crewai_tools.tools.code_interpreter_tool.code_interpreter_tool.docker_from_env") + @patch( + "crewai_tools.tools.code_interpreter_tool.code_interpreter_tool.docker_from_env" + ) def test_run_code_in_docker_with_error(self, docker_mock): tool = CodeInterpreterTool() code = "print(1/0)" @@ -37,7 +39,9 @@ class TestCodeInterpreterTool(unittest.TestCase): self.assertEqual(result, expected_output) - @patch("crewai_tools.tools.code_interpreter_tool.code_interpreter_tool.docker_from_env") + @patch( + "crewai_tools.tools.code_interpreter_tool.code_interpreter_tool.docker_from_env" + ) def test_run_code_in_docker_with_script(self, docker_mock): tool = CodeInterpreterTool() code = """print("This is line 1") From a606f48b70b346e70bf3bbfdad78290367f4469b Mon Sep 17 00:00:00 2001 From: ArchiusVuong-sudo Date: Sat, 18 Jan 2025 21:58:50 +0700 Subject: [PATCH 22/23] FIX: Fix HTTPError cannot be found in serperai --- .../tools/serpapi_tool/serpapi_google_search_tool.py | 2 +- 1 file changed, 1 
insertion(+), 1 deletion(-) diff --git a/src/crewai_tools/tools/serpapi_tool/serpapi_google_search_tool.py b/src/crewai_tools/tools/serpapi_tool/serpapi_google_search_tool.py index 199b7f5a2..f8edd6458 100644 --- a/src/crewai_tools/tools/serpapi_tool/serpapi_google_search_tool.py +++ b/src/crewai_tools/tools/serpapi_tool/serpapi_google_search_tool.py @@ -3,7 +3,7 @@ from typing import Any, Type, Optional import re from pydantic import BaseModel, Field from .serpapi_base_tool import SerpApiBaseTool -from serpapi import HTTPError +from urllib.error import HTTPError class SerpApiGoogleSearchToolSchema(BaseModel): """Input for Google Search.""" From 659cb6279e2b2833fea0d4c8da4946160100befd Mon Sep 17 00:00:00 2001 From: ArchiusVuong-sudo Date: Sat, 18 Jan 2025 23:01:01 +0700 Subject: [PATCH 23/23] fix: Fixed all from urllib.error import HTTPError --- .../tools/serpapi_tool/serpapi_google_shopping_tool.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/crewai_tools/tools/serpapi_tool/serpapi_google_shopping_tool.py b/src/crewai_tools/tools/serpapi_tool/serpapi_google_shopping_tool.py index b44b3a809..5863239c5 100644 --- a/src/crewai_tools/tools/serpapi_tool/serpapi_google_shopping_tool.py +++ b/src/crewai_tools/tools/serpapi_tool/serpapi_google_shopping_tool.py @@ -3,7 +3,7 @@ from typing import Any, Type, Optional import re from pydantic import BaseModel, Field from .serpapi_base_tool import SerpApiBaseTool -from serpapi import HTTPError +from urllib.error import HTTPError class SerpApiGoogleShoppingToolSchema(BaseModel): """Input for Google Shopping."""