Merge branch 'main' of github.com:crewAIInc/crewAI-tools into fix/optional-dependencies

This commit is contained in:
Lorenze Jay
2025-01-21 15:53:14 -08:00
50 changed files with 1242 additions and 188 deletions

View File

@@ -1,4 +1,5 @@
from .tools import (
AIMindTool,
BraveSearchTool,
BrowserbaseLoadTool,
CodeDocsSearchTool,
@@ -16,6 +17,7 @@ from .tools import (
FirecrawlScrapeWebsiteTool,
FirecrawlSearchTool,
GithubSearchTool,
HyperbrowserLoadTool,
JSONSearchTool,
LinkupSearchTool,
LlamaIndexTool,
@@ -43,6 +45,8 @@ from .tools import (
SerplyScholarSearchTool,
SerplyWebpageToMarkdownTool,
SerplyWebSearchTool,
SnowflakeConfig,
SnowflakeSearchTool,
SpiderTool,
TXTSearchTool,
VisionTool,

View File

@@ -1,3 +1,4 @@
from .ai_mind_tool.ai_mind_tool import AIMindTool
from .brave_search_tool.brave_search_tool import BraveSearchTool
from .browserbase_load_tool.browserbase_load_tool import BrowserbaseLoadTool
from .code_docs_search_tool.code_docs_search_tool import CodeDocsSearchTool
@@ -19,6 +20,7 @@ from .firecrawl_scrape_website_tool.firecrawl_scrape_website_tool import (
)
from .firecrawl_search_tool.firecrawl_search_tool import FirecrawlSearchTool
from .github_search_tool.github_search_tool import GithubSearchTool
from .hyperbrowser_load_tool.hyperbrowser_load_tool import HyperbrowserLoadTool
from .json_search_tool.json_search_tool import JSONSearchTool
from .linkup.linkup_search_tool import LinkupSearchTool
from .llamaindex_tool.llamaindex_tool import LlamaIndexTool
@@ -54,6 +56,11 @@ from .serply_api_tool.serply_news_search_tool import SerplyNewsSearchTool
from .serply_api_tool.serply_scholar_search_tool import SerplyScholarSearchTool
from .serply_api_tool.serply_web_search_tool import SerplyWebSearchTool
from .serply_api_tool.serply_webpage_to_markdown_tool import SerplyWebpageToMarkdownTool
from .snowflake_search_tool import (
SnowflakeConfig,
SnowflakeSearchTool,
SnowflakeSearchToolInput,
)
from .spider_tool.spider_tool import SpiderTool
from .txt_search_tool.txt_search_tool import TXTSearchTool
from .vision_tool.vision_tool import VisionTool

View File

@@ -0,0 +1,79 @@
# AIMind Tool
## Description
[Minds](https://mindsdb.com/minds) are AI systems provided by [MindsDB](https://mindsdb.com/) that work similarly to large language models (LLMs) but go beyond by answering any question from any data.
This is accomplished by selecting the most relevant data for an answer using parametric search, understanding the meaning and providing responses within the correct context through semantic search, and finally, delivering precise answers by analyzing data and using machine learning (ML) models.
The `AIMindTool` can be used to query data sources in natural language by simply configuring their connection parameters.
## Installation
1. Install the `crewai[tools]` package:
```shell
pip install 'crewai[tools]'
```
2. Install the Minds SDK:
```shell
pip install minds-sdk
```
3. Sign for a Minds account [here](https://mdb.ai/register), and obtain an API key.
4. Set the Minds API key in an environment variable named `MINDS_API_KEY`.
## Usage
```python
from crewai_tools import AIMindTool
# Initialize the AIMindTool.
aimind_tool = AIMindTool(
datasources=[
{
"description": "house sales data",
"engine": "postgres",
"connection_data": {
"user": "demo_user",
"password": "demo_password",
"host": "samples.mindsdb.com",
"port": 5432,
"database": "demo",
"schema": "demo_data"
},
"tables": ["house_sales"]
}
]
)
aimind_tool.run("How many 3 bedroom houses were sold in 2008?")
```
The `datasources` parameter is a list of dictionaries, each containing the following keys:
- `description`: A description of the data contained in the datasource.
- `engine`: The engine (or type) of the datasource. Find a list of supported engines in the link below.
- `connection_data`: A dictionary containing the connection parameters for the datasource. Find a list of connection parameters for each engine in the link below.
- `tables`: A list of tables that the data source will use. This is optional and can be omitted if all tables in the data source are to be used.
A list of supported data sources and their connection parameters can be found [here](https://docs.mdb.ai/docs/data_sources).
```python
from crewai import Agent
from crewai.project import agent
# Define an agent with the AIMindTool.
@agent
def researcher(self) -> Agent:
return Agent(
config=self.agents_config["researcher"],
allow_delegation=False,
tools=[aimind_tool]
)
```

View File

@@ -0,0 +1,87 @@
import os
import secrets
from typing import Any, Dict, List, Optional, Text, Type
from crewai.tools import BaseTool
from openai import OpenAI
from pydantic import BaseModel
class AIMindToolConstants:
MINDS_API_BASE_URL = "https://mdb.ai/"
MIND_NAME_PREFIX = "crwai_mind_"
DATASOURCE_NAME_PREFIX = "crwai_ds_"
class AIMindToolInputSchema(BaseModel):
"""Input for AIMind Tool."""
query: str = "Question in natural language to ask the AI-Mind"
class AIMindTool(BaseTool):
name: str = "AIMind Tool"
description: str = (
"A wrapper around [AI-Minds](https://mindsdb.com/minds). "
"Useful for when you need answers to questions from your data, stored in "
"data sources including PostgreSQL, MySQL, MariaDB, ClickHouse, Snowflake "
"and Google BigQuery. "
"Input should be a question in natural language."
)
args_schema: Type[BaseModel] = AIMindToolInputSchema
api_key: Optional[str] = None
datasources: Optional[List[Dict[str, Any]]] = None
mind_name: Optional[Text] = None
def __init__(self, api_key: Optional[Text] = None, **kwargs):
super().__init__(**kwargs)
self.api_key = api_key or os.getenv("MINDS_API_KEY")
if not self.api_key:
raise ValueError("API key must be provided either through constructor or MINDS_API_KEY environment variable")
try:
from minds.client import Client # type: ignore
from minds.datasources import DatabaseConfig # type: ignore
except ImportError:
raise ImportError(
"`minds_sdk` package not found, please run `pip install minds-sdk`"
)
minds_client = Client(api_key=self.api_key)
# Convert the datasources to DatabaseConfig objects.
datasources = []
for datasource in self.datasources:
config = DatabaseConfig(
name=f"{AIMindToolConstants.DATASOURCE_NAME_PREFIX}_{secrets.token_hex(5)}",
engine=datasource["engine"],
description=datasource["description"],
connection_data=datasource["connection_data"],
tables=datasource["tables"],
)
datasources.append(config)
# Generate a random name for the Mind.
name = f"{AIMindToolConstants.MIND_NAME_PREFIX}_{secrets.token_hex(5)}"
mind = minds_client.minds.create(
name=name, datasources=datasources, replace=True
)
self.mind_name = mind.name
def _run(
self,
query: Text
):
# Run the query on the AI-Mind.
# The Minds API is OpenAI compatible and therefore, the OpenAI client can be used.
openai_client = OpenAI(base_url=AIMindToolConstants.MINDS_API_BASE_URL, api_key=self.api_key)
completion = openai_client.chat.completions.create(
model=self.mind_name,
messages=[{"role": "user", "content": query}],
stream=False,
)
return completion.choices[0].message.content

View File

@@ -1,8 +1,8 @@
import os
from typing import Any, Optional, Type
from pydantic import BaseModel, Field
from crewai.tools import BaseTool
from pydantic import BaseModel, Field
class BrowserbaseLoadToolSchema(BaseModel):

View File

@@ -2,10 +2,12 @@ import importlib.util
import os
from typing import List, Optional, Type
from crewai.tools import BaseTool
from docker import from_env as docker_from_env
from docker import DockerClient
from docker.models.containers import Container
from docker.errors import ImageNotFound, NotFound
from crewai.tools import BaseTool
from docker.models.containers import Container
from pydantic import BaseModel, Field
@@ -30,7 +32,7 @@ class CodeInterpreterTool(BaseTool):
default_image_tag: str = "code-interpreter:latest"
code: Optional[str] = None
user_dockerfile_path: Optional[str] = None
user_docker_base_url: Optional[str] = None
user_docker_base_url: Optional[str] = None
unsafe_mode: bool = False
@staticmethod
@@ -43,7 +45,11 @@ class CodeInterpreterTool(BaseTool):
Verify if the Docker image is available. Optionally use a user-provided Dockerfile.
"""
client = docker_from_env() if self.user_docker_base_url == None else docker.DockerClient(base_url=self.user_docker_base_url)
client = (
docker_from_env()
if self.user_docker_base_url == None
else DockerClient(base_url=self.user_docker_base_url)
)
try:
client.images.get(self.default_image_tag)
@@ -76,9 +82,7 @@ class CodeInterpreterTool(BaseTool):
else:
return self.run_code_in_docker(code, libraries_used)
def _install_libraries(
self, container: Container, libraries: List[str]
) -> None:
def _install_libraries(self, container: Container, libraries: List[str]) -> None:
"""
Install missing libraries in the Docker container
"""
@@ -135,4 +139,4 @@ class CodeInterpreterTool(BaseTool):
exec(code, {}, exec_locals)
return exec_locals.get("result", "No result variable found.")
except Exception as e:
return f"An error occurred: {str(e)}"
return f"An error occurred: {str(e)}"

View File

@@ -8,8 +8,6 @@ from pydantic import BaseModel, Field
class FixedDirectoryReadToolSchema(BaseModel):
"""Input for DirectoryReadTool."""
pass
class DirectoryReadToolSchema(FixedDirectoryReadToolSchema):
"""Input for DirectoryReadTool."""

View File

@@ -32,6 +32,7 @@ class FileReadTool(BaseTool):
>>> content = tool.run() # Reads /path/to/file.txt
>>> content = tool.run(file_path="/path/to/other.txt") # Reads other.txt
"""
name: str = "Read a file's content"
description: str = "A tool that reads the content of a file. To use this tool, provide a 'file_path' parameter with the path to the file you want to read."
args_schema: Type[BaseModel] = FileReadToolSchema
@@ -45,10 +46,11 @@ class FileReadTool(BaseTool):
this becomes the default file path for the tool.
**kwargs: Additional keyword arguments passed to BaseTool.
"""
super().__init__(**kwargs)
if file_path is not None:
self.file_path = file_path
self.description = f"A tool that reads file content. The default file is {file_path}, but you can provide a different 'file_path' parameter to read another file."
kwargs['description'] = f"A tool that reads file content. The default file is {file_path}, but you can provide a different 'file_path' parameter to read another file."
super().__init__(**kwargs)
self.file_path = file_path
def _run(
self,
@@ -57,7 +59,7 @@ class FileReadTool(BaseTool):
file_path = kwargs.get("file_path", self.file_path)
if file_path is None:
return "Error: No file path provided. Please provide a file path either in the constructor or as an argument."
try:
with open(file_path, "r") as file:
return file.read()
@@ -66,16 +68,4 @@ class FileReadTool(BaseTool):
except PermissionError:
return f"Error: Permission denied when trying to read file: {file_path}"
except Exception as e:
return f"Error: Failed to read file {file_path}. {str(e)}"
def _generate_description(self) -> None:
"""Generate the tool description based on file path.
This method updates the tool's description to include information about
the default file path while maintaining the ability to specify a different
file at runtime.
Returns:
None
"""
self.description = f"A tool that can be used to read {self.file_path}'s content."
return f"Error: Failed to read file {file_path}. {str(e)}"

View File

@@ -15,9 +15,7 @@ class FileWriterToolInput(BaseModel):
class FileWriterTool(BaseTool):
name: str = "File Writer Tool"
description: str = (
"A tool to write content to a specified file. Accepts filename, content, and optionally a directory path and overwrite flag as input."
)
description: str = "A tool to write content to a specified file. Accepts filename, content, and optionally a directory path and overwrite flag as input."
args_schema: Type[BaseModel] = FileWriterToolInput
def _run(self, **kwargs: Any) -> str:

View File

@@ -88,4 +88,3 @@ except ImportError:
"""
When this tool is not used, then exception can be ignored.
"""
pass

View File

@@ -78,4 +78,3 @@ except ImportError:
"""
When this tool is not used, then exception can be ignored.
"""
pass

View File

@@ -27,9 +27,7 @@ class GithubSearchToolSchema(FixedGithubSearchToolSchema):
class GithubSearchTool(RagTool):
name: str = "Search a github repo's content"
description: str = (
"A tool that can be used to semantic search a query from a github repo's content. This is not the GitHub API, but instead a tool that can provide semantic search capabilities."
)
description: str = "A tool that can be used to semantic search a query from a github repo's content. This is not the GitHub API, but instead a tool that can provide semantic search capabilities."
summarize: bool = False
gh_token: str
args_schema: Type[BaseModel] = GithubSearchToolSchema

View File

@@ -0,0 +1,42 @@
# HyperbrowserLoadTool
## Description
[Hyperbrowser](https://hyperbrowser.ai) is a platform for running and scaling headless browsers. It lets you launch and manage browser sessions at scale and provides easy to use solutions for any webscraping needs, such as scraping a single page or crawling an entire site.
Key Features:
- Instant Scalability - Spin up hundreds of browser sessions in seconds without infrastructure headaches
- Simple Integration - Works seamlessly with popular tools like Puppeteer and Playwright
- Powerful APIs - Easy to use APIs for scraping/crawling any site, and much more
- Bypass Anti-Bot Measures - Built-in stealth mode, ad blocking, automatic CAPTCHA solving, and rotating proxies
For more information about Hyperbrowser, please visit the [Hyperbrowser website](https://hyperbrowser.ai) or if you want to check out the docs, you can visit the [Hyperbrowser docs](https://docs.hyperbrowser.ai).
## Installation
- Head to [Hyperbrowser](https://app.hyperbrowser.ai/) to sign up and generate an API key. Once you've done this set the `HYPERBROWSER_API_KEY` environment variable or you can pass it to the `HyperbrowserLoadTool` constructor.
- Install the [Hyperbrowser SDK](https://github.com/hyperbrowserai/python-sdk):
```
pip install hyperbrowser 'crewai[tools]'
```
## Example
Utilize the HyperbrowserLoadTool as follows to allow your agent to load websites:
```python
from crewai_tools import HyperbrowserLoadTool
tool = HyperbrowserLoadTool()
```
## Arguments
`__init__` arguments:
- `api_key`: Optional. Specifies Hyperbrowser API key. Defaults to the `HYPERBROWSER_API_KEY` environment variable.
`run` arguments:
- `url`: The base URL to start scraping or crawling from.
- `operation`: Optional. Specifies the operation to perform on the website. Either 'scrape' or 'crawl'. Defaults is 'scrape'.
- `params`: Optional. Specifies the params for the operation. For more information on the supported params, visit https://docs.hyperbrowser.ai/reference/sdks/python/scrape#start-scrape-job-and-wait or https://docs.hyperbrowser.ai/reference/sdks/python/crawl#start-crawl-job-and-wait.

View File

@@ -0,0 +1,103 @@
import os
from typing import Any, Optional, Type, Dict, Literal, Union
from crewai.tools import BaseTool
from pydantic import BaseModel, Field
class HyperbrowserLoadToolSchema(BaseModel):
url: str = Field(description="Website URL")
operation: Literal['scrape', 'crawl'] = Field(description="Operation to perform on the website. Either 'scrape' or 'crawl'")
params: Optional[Dict] = Field(description="Optional params for scrape or crawl. For more information on the supported params, visit https://docs.hyperbrowser.ai/reference/sdks/python/scrape#start-scrape-job-and-wait or https://docs.hyperbrowser.ai/reference/sdks/python/crawl#start-crawl-job-and-wait")
class HyperbrowserLoadTool(BaseTool):
"""HyperbrowserLoadTool.
Scrape or crawl web pages and load the contents with optional parameters for configuring content extraction.
Requires the `hyperbrowser` package.
Get your API Key from https://app.hyperbrowser.ai/
Args:
api_key: The Hyperbrowser API key, can be set as an environment variable `HYPERBROWSER_API_KEY` or passed directly
"""
name: str = "Hyperbrowser web load tool"
description: str = "Scrape or crawl a website using Hyperbrowser and return the contents in properly formatted markdown or html"
args_schema: Type[BaseModel] = HyperbrowserLoadToolSchema
api_key: Optional[str] = None
hyperbrowser: Optional[Any] = None
def __init__(self, api_key: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
self.api_key = api_key or os.getenv('HYPERBROWSER_API_KEY')
if not api_key:
raise ValueError(
"`api_key` is required, please set the `HYPERBROWSER_API_KEY` environment variable or pass it directly"
)
try:
from hyperbrowser import Hyperbrowser
except ImportError:
raise ImportError("`hyperbrowser` package not found, please run `pip install hyperbrowser`")
if not self.api_key:
raise ValueError("HYPERBROWSER_API_KEY is not set. Please provide it either via the constructor with the `api_key` argument or by setting the HYPERBROWSER_API_KEY environment variable.")
self.hyperbrowser = Hyperbrowser(api_key=self.api_key)
def _prepare_params(self, params: Dict) -> Dict:
"""Prepare session and scrape options parameters."""
try:
from hyperbrowser.models.session import CreateSessionParams
from hyperbrowser.models.scrape import ScrapeOptions
except ImportError:
raise ImportError(
"`hyperbrowser` package not found, please run `pip install hyperbrowser`"
)
if "scrape_options" in params:
if "formats" in params["scrape_options"]:
formats = params["scrape_options"]["formats"]
if not all(fmt in ["markdown", "html"] for fmt in formats):
raise ValueError("formats can only contain 'markdown' or 'html'")
if "session_options" in params:
params["session_options"] = CreateSessionParams(**params["session_options"])
if "scrape_options" in params:
params["scrape_options"] = ScrapeOptions(**params["scrape_options"])
return params
def _extract_content(self, data: Union[Any, None]):
"""Extract content from response data."""
content = ""
if data:
content = data.markdown or data.html or ""
return content
def _run(self, url: str, operation: Literal['scrape', 'crawl'] = 'scrape', params: Optional[Dict] = {}):
try:
from hyperbrowser.models.scrape import StartScrapeJobParams
from hyperbrowser.models.crawl import StartCrawlJobParams
except ImportError:
raise ImportError(
"`hyperbrowser` package not found, please run `pip install hyperbrowser`"
)
params = self._prepare_params(params)
if operation == 'scrape':
scrape_params = StartScrapeJobParams(url=url, **params)
scrape_resp = self.hyperbrowser.scrape.start_and_wait(scrape_params)
content = self._extract_content(scrape_resp.data)
return content
else:
crawl_params = StartCrawlJobParams(url=url, **params)
crawl_resp = self.hyperbrowser.crawl.start_and_wait(crawl_params)
content = ""
if crawl_resp.data:
for page in crawl_resp.data:
page_content = self._extract_content(page)
if page_content:
content += (
f"\n{'-'*50}\nUrl: {page.url}\nContent:\n{page_content}\n"
)
return content

View File

@@ -13,9 +13,7 @@ class JinaScrapeWebsiteToolInput(BaseModel):
class JinaScrapeWebsiteTool(BaseTool):
name: str = "JinaScrapeWebsiteTool"
description: str = (
"A tool that can be used to read a website content using Jina.ai reader and return markdown content."
)
description: str = "A tool that can be used to read a website content using Jina.ai reader and return markdown content."
args_schema: Type[BaseModel] = JinaScrapeWebsiteToolInput
website_url: Optional[str] = None
api_key: Optional[str] = None

View File

@@ -19,6 +19,10 @@ class LinkupSearchTool(BaseTool):
"Performs an API call to Linkup to retrieve contextual information."
)
_client: LinkupClient = PrivateAttr() # type: ignore
description: str = (
"Performs an API call to Linkup to retrieve contextual information."
)
_client: LinkupClient = PrivateAttr() # type: ignore
def __init__(self, api_key: str):
"""

View File

@@ -17,9 +17,7 @@ class MySQLSearchToolSchema(BaseModel):
class MySQLSearchTool(RagTool):
name: str = "Search a database's table content"
description: str = (
"A tool that can be used to semantic search a query from a database table's content."
)
description: str = "A tool that can be used to semantic search a query from a database table's content."
args_schema: Type[BaseModel] = MySQLSearchToolSchema
db_uri: str = Field(..., description="Mandatory database URI")

View File

@@ -1,30 +1,24 @@
from crewai import Agent, Crew, Task
from patronus_eval_tool import (
PatronusEvalTool,
)
from patronus_local_evaluator_tool import (
PatronusLocalEvaluatorTool,
)
from patronus_predefined_criteria_eval_tool import (
PatronusPredefinedCriteriaEvalTool,
)
from patronus import Client, EvaluationResult
import random
from crewai import Agent, Crew, Task
from patronus import Client, EvaluationResult
from patronus_local_evaluator_tool import PatronusLocalEvaluatorTool
# Test the PatronusLocalEvaluatorTool where agent uses the local evaluator
client = Client()
# Example of an evaluator that returns a random pass/fail result
@client.register_local_evaluator("random_evaluator")
def random_evaluator(**kwargs):
score = random.random()
return EvaluationResult(
score_raw=score,
pass_=score >= 0.5,
explanation="example explanation" # Optional justification for LLM judges
score_raw=score,
pass_=score >= 0.5,
explanation="example explanation", # Optional justification for LLM judges
)
# 1. Uses PatronusEvalTool: agent can pick the best evaluator and criteria
# patronus_eval_tool = PatronusEvalTool()
@@ -35,7 +29,9 @@ def random_evaluator(**kwargs):
# 3. Uses PatronusLocalEvaluatorTool: agent uses user defined evaluator
patronus_eval_tool = PatronusLocalEvaluatorTool(
patronus_client=client, evaluator="random_evaluator", evaluated_model_gold_answer="example label"
patronus_client=client,
evaluator="random_evaluator",
evaluated_model_gold_answer="example label",
)
# Create a new agent

View File

@@ -1,8 +1,9 @@
import os
import json
import requests
import os
import warnings
from typing import Any, List, Dict, Optional
from typing import Any, Dict, List, Optional
import requests
from crewai.tools import BaseTool
@@ -19,7 +20,9 @@ class PatronusEvalTool(BaseTool):
self.evaluators = temp_evaluators
self.criteria = temp_criteria
self.description = self._generate_description()
warnings.warn("You are allowing the agent to select the best evaluator and criteria when you use the `PatronusEvalTool`. If this is not intended then please use `PatronusPredefinedCriteriaEvalTool` instead.")
warnings.warn(
"You are allowing the agent to select the best evaluator and criteria when you use the `PatronusEvalTool`. If this is not intended then please use `PatronusPredefinedCriteriaEvalTool` instead."
)
def _init_run(self):
evaluators_set = json.loads(
@@ -104,7 +107,6 @@ class PatronusEvalTool(BaseTool):
evaluated_model_retrieved_context: Optional[str],
evaluators: List[Dict[str, str]],
) -> Any:
# Assert correct format of evaluators
evals = []
for ev in evaluators:
@@ -136,4 +138,4 @@ class PatronusEvalTool(BaseTool):
f"Failed to evaluate model input and output. Response status code: {response.status_code}. Reason: {response.text}"
)
return response.json()
return response.json()

View File

@@ -1,5 +1,7 @@
from typing import Any, Type
from crewai.tools import BaseTool
from patronus import Client
from pydantic import BaseModel, Field
try:
@@ -32,12 +34,20 @@ class PatronusLocalEvaluatorTool(BaseTool):
evaluator: str = "The registered local evaluator"
evaluated_model_gold_answer: str = "The agent's gold answer"
description: str = "This tool is used to evaluate the model input and output using custom function evaluators."
description: str = "This tool is used to evaluate the model input and output using custom function evaluators."
client: Any = None
args_schema: Type[BaseModel] = FixedLocalEvaluatorToolSchema
class Config:
arbitrary_types_allowed = True
def __init__(
self,
patronus_client: Client,
evaluator: str,
evaluated_model_gold_answer: str,
**kwargs: Any,
):
def __init__(
self,
patronus_client: Client,
@@ -105,6 +115,7 @@ class PatronusLocalEvaluatorTool(BaseTool):
else evaluated_model_gold_answer.get("description")
),
tags={}, # Optional metadata, supports arbitrary kv pairs
tags={}, # Optional metadata, supports arbitrary kv pairs
)
output = f"Evaluation result: {result.pass_}, Explanation: {result.explanation}"
return output

View File

@@ -1,7 +1,8 @@
import os
import json
import os
from typing import Any, Dict, List, Type
import requests
from typing import Any, List, Dict, Type
from crewai.tools import BaseTool
from pydantic import BaseModel, Field
@@ -33,9 +34,7 @@ class PatronusPredefinedCriteriaEvalTool(BaseTool):
"""
name: str = "Call Patronus API tool for evaluation of model inputs and outputs"
description: str = (
"""This tool calls the Patronus Evaluation API that takes the following arguments:"""
)
description: str = """This tool calls the Patronus Evaluation API that takes the following arguments:"""
evaluate_url: str = "https://api.patronus.ai/v1/evaluate"
args_schema: Type[BaseModel] = FixedBaseToolSchema
evaluators: List[Dict[str, str]] = []
@@ -52,7 +51,6 @@ class PatronusPredefinedCriteriaEvalTool(BaseTool):
self,
**kwargs: Any,
) -> Any:
evaluated_model_input = kwargs.get("evaluated_model_input")
evaluated_model_output = kwargs.get("evaluated_model_output")
evaluated_model_retrieved_context = kwargs.get(
@@ -103,4 +101,4 @@ class PatronusPredefinedCriteriaEvalTool(BaseTool):
f"Failed to evaluate model input and output. Status code: {response.status_code}. Reason: {response.text}"
)
return response.json()
return response.json()

View File

@@ -1,7 +1,9 @@
from typing import Optional, Type
from pydantic import BaseModel, Field
from pypdf import PdfReader, PdfWriter, PageObject, ContentStream, NameObject, Font
from pathlib import Path
from typing import Optional, Type
from pydantic import BaseModel, Field
from pypdf import ContentStream, Font, NameObject, PageObject, PdfReader, PdfWriter
from crewai_tools.tools.rag.rag_tool import RagTool

View File

@@ -17,9 +17,7 @@ class PGSearchToolSchema(BaseModel):
class PGSearchTool(RagTool):
name: str = "Search a database's table content"
description: str = (
"A tool that can be used to semantic search a query from a database table's content."
)
description: str = "A tool that can be used to semantic search a query from a database table's content."
args_schema: Type[BaseModel] = PGSearchToolSchema
db_uri: str = Field(..., description="Mandatory database URI")

View File

@@ -10,8 +10,6 @@ from pydantic import BaseModel, Field
class FixedScrapeElementFromWebsiteToolSchema(BaseModel):
"""Input for ScrapeElementFromWebsiteTool."""
pass
class ScrapeElementFromWebsiteToolSchema(FixedScrapeElementFromWebsiteToolSchema):
"""Input for ScrapeElementFromWebsiteTool."""

View File

@@ -11,8 +11,6 @@ from pydantic import BaseModel, Field
class FixedScrapeWebsiteToolSchema(BaseModel):
"""Input for ScrapeWebsiteTool."""
pass
class ScrapeWebsiteToolSchema(FixedScrapeWebsiteToolSchema):
"""Input for ScrapeWebsiteTool."""

View File

@@ -13,20 +13,14 @@ if TYPE_CHECKING:
class ScrapegraphError(Exception):
"""Base exception for Scrapegraph-related errors"""
pass
class RateLimitError(ScrapegraphError):
"""Raised when API rate limits are exceeded"""
pass
class FixedScrapegraphScrapeToolSchema(BaseModel):
"""Input for ScrapegraphScrapeTool when website_url is fixed."""
pass
class ScrapegraphScrapeToolSchema(FixedScrapegraphScrapeToolSchema):
"""Input for ScrapegraphScrapeTool."""
@@ -71,6 +65,7 @@ class ScrapegraphScrapeTool(BaseTool):
website_url: Optional[str] = None
user_prompt: Optional[str] = None
api_key: Optional[str] = None
enable_logging: bool = False
_client: Optional["Client"] = None
def __init__(
@@ -78,6 +73,7 @@ class ScrapegraphScrapeTool(BaseTool):
website_url: Optional[str] = None,
user_prompt: Optional[str] = None,
api_key: Optional[str] = None,
enable_logging: bool = False,
**kwargs,
):
super().__init__(**kwargs)
@@ -118,8 +114,9 @@ class ScrapegraphScrapeTool(BaseTool):
if user_prompt is not None:
self.user_prompt = user_prompt
# Configure logging
sgai_logger.set_logging(level="INFO")
# Configure logging only if enabled
if self.enable_logging:
sgai_logger.set_logging(level="INFO")
@staticmethod
def _validate_url(url: str) -> None:
@@ -172,8 +169,7 @@ class ScrapegraphScrapeTool(BaseTool):
user_prompt=user_prompt,
)
# Handle and validate the response
return self._handle_api_response(response)
return response
except RateLimitError:
raise # Re-raise rate limit errors

View File

@@ -163,7 +163,7 @@ class SeleniumScrapingTool(BaseTool):
if not re.match(r"^https?://", url):
raise ValueError("URL must start with http:// or https://")
options = self._options()
options = Options()
options.add_argument("--headless")
driver = self.driver(options=options)
driver.get(url)

View File

@@ -1,6 +1,6 @@
import os
import re
from typing import Optional, Any, Union
from typing import Any, Optional, Union
from crewai.tools import BaseTool

View File

@@ -1,4 +1,4 @@
from typing import Any, Type, Optional
from typing import Any, Optional, Type
import re
from pydantic import BaseModel, Field, ConfigDict

View File

@@ -1,6 +1,5 @@
from typing import Any, Type, Optional
from typing import Any, Optional, Type
import re
from pydantic import BaseModel, Field
from .serpapi_base_tool import SerpApiBaseTool
from pydantic import ConfigDict

View File

@@ -1,19 +1,19 @@
import datetime
import json
import os
import logging
import os
from typing import Any, Type
import requests
from crewai.tools import BaseTool
from pydantic import BaseModel, Field
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
def _save_results_to_file(content: str) -> None:
"""Saves the search results to a file."""
try:

View File

@@ -18,9 +18,7 @@ class SerplyWebpageToMarkdownToolSchema(BaseModel):
class SerplyWebpageToMarkdownTool(RagTool):
name: str = "Webpage to Markdown"
description: str = (
"A tool to perform convert a webpage to markdown to make it easier for LLMs to understand"
)
description: str = "A tool to perform convert a webpage to markdown to make it easier for LLMs to understand"
args_schema: Type[BaseModel] = SerplyWebpageToMarkdownToolSchema
request_url: str = "https://api.serply.io/v1/request"
proxy_location: Optional[str] = "US"

View File

@@ -0,0 +1,155 @@
# Snowflake Search Tool
A tool for executing queries on Snowflake data warehouse with built-in connection pooling, retry logic, and async execution support.
## Installation
```bash
uv sync --extra snowflake
OR
uv pip install snowflake-connector-python>=3.5.0 snowflake-sqlalchemy>=1.5.0 cryptography>=41.0.0
OR
pip install snowflake-connector-python>=3.5.0 snowflake-sqlalchemy>=1.5.0 cryptography>=41.0.0
```
## Quick Start
```python
import asyncio
from crewai_tools import SnowflakeSearchTool, SnowflakeConfig
# Create configuration
config = SnowflakeConfig(
account="your_account",
user="your_username",
password="your_password",
warehouse="COMPUTE_WH",
database="your_database",
snowflake_schema="your_schema" # Note: Uses snowflake_schema instead of schema
)
# Initialize tool
tool = SnowflakeSearchTool(
config=config,
pool_size=5,
max_retries=3,
enable_caching=True
)
# Execute query
async def main():
results = await tool._run(
query="SELECT * FROM your_table LIMIT 10",
timeout=300
)
print(f"Retrieved {len(results)} rows")
if __name__ == "__main__":
asyncio.run(main())
```
## Features
- ✨ Asynchronous query execution
- 🚀 Connection pooling for better performance
- 🔄 Automatic retries for transient failures
- 💾 Query result caching (optional)
- 🔒 Support for both password and key-pair authentication
- 📝 Comprehensive error handling and logging
## Configuration Options
### SnowflakeConfig Parameters
| Parameter | Required | Description |
|-----------|----------|-------------|
| account | Yes | Snowflake account identifier |
| user | Yes | Snowflake username |
| password | Yes* | Snowflake password |
| private_key_path | No* | Path to private key file (alternative to password) |
| warehouse | Yes | Snowflake warehouse name |
| database | Yes | Default database |
| snowflake_schema | Yes | Default schema |
| role | No | Snowflake role |
| session_parameters | No | Custom session parameters dict |
\* Either password or private_key_path must be provided
### Tool Parameters
| Parameter | Default | Description |
|-----------|---------|-------------|
| pool_size | 5 | Number of connections in the pool |
| max_retries | 3 | Maximum retry attempts for failed queries |
| retry_delay | 1.0 | Delay between retries in seconds |
| enable_caching | True | Enable/disable query result caching |
## Advanced Usage
### Using Key-Pair Authentication
```python
config = SnowflakeConfig(
account="your_account",
user="your_username",
private_key_path="/path/to/private_key.p8",
warehouse="your_warehouse",
database="your_database",
snowflake_schema="your_schema"
)
```
### Custom Session Parameters
```python
config = SnowflakeConfig(
# ... other config parameters ...
session_parameters={
"QUERY_TAG": "my_app",
"TIMEZONE": "America/Los_Angeles"
}
)
```
## Best Practices
1. **Error Handling**: Always wrap query execution in try-except blocks
2. **Logging**: Enable logging to track query execution and errors
3. **Connection Management**: Use appropriate pool sizes for your workload
4. **Timeouts**: Set reasonable query timeouts to prevent hanging
5. **Security**: Use key-pair auth in production and never hardcode credentials
## Example with Logging
```python
import logging
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
async def main():
try:
# ... tool initialization ...
results = await tool._run(query="SELECT * FROM table LIMIT 10")
logger.info(f"Query completed successfully. Retrieved {len(results)} rows")
except Exception as e:
logger.error(f"Query failed: {str(e)}")
raise
```
## Error Handling
The tool automatically handles common Snowflake errors:
- DatabaseError
- OperationalError
- ProgrammingError
- Network timeouts
- Connection issues
Errors are logged and retried based on your retry configuration.

View File

@@ -0,0 +1,11 @@
from .snowflake_search_tool import (
SnowflakeConfig,
SnowflakeSearchTool,
SnowflakeSearchToolInput,
)
__all__ = [
"SnowflakeSearchTool",
"SnowflakeSearchToolInput",
"SnowflakeConfig",
]

View File

@@ -0,0 +1,201 @@
import asyncio
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Optional, Type
import snowflake.connector
from crewai.tools.base_tool import BaseTool
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from pydantic import BaseModel, ConfigDict, Field, SecretStr
from snowflake.connector.connection import SnowflakeConnection
from snowflake.connector.errors import DatabaseError, OperationalError
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Cache for query results
_query_cache = {}
class SnowflakeConfig(BaseModel):
"""Configuration for Snowflake connection."""
model_config = ConfigDict(protected_namespaces=())
account: str = Field(
..., description="Snowflake account identifier", pattern=r"^[a-zA-Z0-9\-_]+$"
)
user: str = Field(..., description="Snowflake username")
password: Optional[SecretStr] = Field(None, description="Snowflake password")
private_key_path: Optional[str] = Field(
None, description="Path to private key file"
)
warehouse: Optional[str] = Field(None, description="Snowflake warehouse")
database: Optional[str] = Field(None, description="Default database")
snowflake_schema: Optional[str] = Field(None, description="Default schema")
role: Optional[str] = Field(None, description="Snowflake role")
session_parameters: Optional[Dict[str, Any]] = Field(
default_factory=dict, description="Session parameters"
)
@property
def has_auth(self) -> bool:
return bool(self.password or self.private_key_path)
def model_post_init(self, *args, **kwargs):
if not self.has_auth:
raise ValueError("Either password or private_key_path must be provided")
class SnowflakeSearchToolInput(BaseModel):
"""Input schema for SnowflakeSearchTool."""
model_config = ConfigDict(protected_namespaces=())
query: str = Field(..., description="SQL query or semantic search query to execute")
database: Optional[str] = Field(None, description="Override default database")
snowflake_schema: Optional[str] = Field(None, description="Override default schema")
timeout: Optional[int] = Field(300, description="Query timeout in seconds")
class SnowflakeSearchTool(BaseTool):
"""Tool for executing queries and semantic search on Snowflake."""
name: str = "Snowflake Database Search"
description: str = (
"Execute SQL queries or semantic search on Snowflake data warehouse. "
"Supports both raw SQL and natural language queries."
)
args_schema: Type[BaseModel] = SnowflakeSearchToolInput
# Define Pydantic fields
config: SnowflakeConfig = Field(
..., description="Snowflake connection configuration"
)
pool_size: int = Field(default=5, description="Size of connection pool")
max_retries: int = Field(default=3, description="Maximum retry attempts")
retry_delay: float = Field(
default=1.0, description="Delay between retries in seconds"
)
enable_caching: bool = Field(
default=True, description="Enable query result caching"
)
model_config = ConfigDict(arbitrary_types_allowed=True)
def __init__(self, **data):
"""Initialize SnowflakeSearchTool."""
super().__init__(**data)
self._connection_pool: List[SnowflakeConnection] = []
self._pool_lock = asyncio.Lock()
self._thread_pool = ThreadPoolExecutor(max_workers=self.pool_size)
async def _get_connection(self) -> SnowflakeConnection:
"""Get a connection from the pool or create a new one."""
async with self._pool_lock:
if not self._connection_pool:
conn = self._create_connection()
self._connection_pool.append(conn)
return self._connection_pool.pop()
def _create_connection(self) -> SnowflakeConnection:
"""Create a new Snowflake connection."""
conn_params = {
"account": self.config.account,
"user": self.config.user,
"warehouse": self.config.warehouse,
"database": self.config.database,
"schema": self.config.snowflake_schema,
"role": self.config.role,
"session_parameters": self.config.session_parameters,
}
if self.config.password:
conn_params["password"] = self.config.password.get_secret_value()
elif self.config.private_key_path:
with open(self.config.private_key_path, "rb") as key_file:
p_key = serialization.load_pem_private_key(
key_file.read(), password=None, backend=default_backend()
)
conn_params["private_key"] = p_key
return snowflake.connector.connect(**conn_params)
def _get_cache_key(self, query: str, timeout: int) -> str:
"""Generate a cache key for the query."""
return f"{self.config.account}:{self.config.database}:{self.config.snowflake_schema}:{query}:{timeout}"
async def _execute_query(
self, query: str, timeout: int = 300
) -> List[Dict[str, Any]]:
"""Execute a query with retries and return results."""
if self.enable_caching:
cache_key = self._get_cache_key(query, timeout)
if cache_key in _query_cache:
logger.info("Returning cached result")
return _query_cache[cache_key]
for attempt in range(self.max_retries):
try:
conn = await self._get_connection()
try:
cursor = conn.cursor()
cursor.execute(query, timeout=timeout)
if not cursor.description:
return []
columns = [col[0] for col in cursor.description]
results = [dict(zip(columns, row)) for row in cursor.fetchall()]
if self.enable_caching:
_query_cache[self._get_cache_key(query, timeout)] = results
return results
finally:
cursor.close()
async with self._pool_lock:
self._connection_pool.append(conn)
except (DatabaseError, OperationalError) as e:
if attempt == self.max_retries - 1:
raise
await asyncio.sleep(self.retry_delay * (2**attempt))
logger.warning(f"Query failed, attempt {attempt + 1}: {str(e)}")
continue
async def _run(
self,
query: str,
database: Optional[str] = None,
snowflake_schema: Optional[str] = None,
timeout: int = 300,
**kwargs: Any,
) -> Any:
"""Execute the search query."""
try:
# Override database/schema if provided
if database:
await self._execute_query(f"USE DATABASE {database}")
if snowflake_schema:
await self._execute_query(f"USE SCHEMA {snowflake_schema}")
results = await self._execute_query(query, timeout)
return results
except Exception as e:
logger.error(f"Error executing query: {str(e)}")
raise
def __del__(self):
"""Cleanup connections on deletion."""
try:
for conn in getattr(self, "_connection_pool", []):
try:
conn.close()
except:
pass
if hasattr(self, "_thread_pool"):
self._thread_pool.shutdown()
except:
pass

View File

@@ -14,9 +14,8 @@ import os
from functools import lru_cache
from typing import Any, Dict, List, Optional, Type, Union
from pydantic import BaseModel, Field
from crewai.tools.base_tool import BaseTool
from pydantic import BaseModel, Field
# Set up logging
logger = logging.getLogger(__name__)

View File

@@ -1,30 +1,36 @@
import base64
from typing import Type, Optional
from pathlib import Path
from typing import Optional, Type
from crewai.tools import BaseTool
from openai import OpenAI
from pydantic import BaseModel, validator
class ImagePromptSchema(BaseModel):
"""Input for Vision Tool."""
image_path_url: str = "The image path or URL."
@validator("image_path_url")
def validate_image_path_url(cls, v: str) -> str:
if v.startswith("http"):
return v
path = Path(v)
if not path.exists():
raise ValueError(f"Image file does not exist: {v}")
# Validate supported formats
valid_extensions = {".jpg", ".jpeg", ".png", ".gif", ".webp"}
if path.suffix.lower() not in valid_extensions:
raise ValueError(f"Unsupported image format. Supported formats: {valid_extensions}")
raise ValueError(
f"Unsupported image format. Supported formats: {valid_extensions}"
)
return v
class VisionTool(BaseTool):
name: str = "Vision Tool"
description: str = (
@@ -45,10 +51,10 @@ class VisionTool(BaseTool):
image_path_url = kwargs.get("image_path_url")
if not image_path_url:
return "Image Path or URL is required."
# Validate input using Pydantic
ImagePromptSchema(image_path_url=image_path_url)
if image_path_url.startswith("http"):
image_data = image_path_url
else:
@@ -68,12 +74,12 @@ class VisionTool(BaseTool):
{
"type": "image_url",
"image_url": {"url": image_data},
}
},
],
}
],
max_tokens=300,
)
)
return response.choices[0].message.content

View File

@@ -15,9 +15,8 @@ except ImportError:
Vectorizers = Any
Auth = Any
from pydantic import BaseModel, Field
from crewai.tools import BaseTool
from pydantic import BaseModel, Field
class WeaviateToolSchema(BaseModel):

View File

@@ -25,9 +25,7 @@ class WebsiteSearchToolSchema(FixedWebsiteSearchToolSchema):
class WebsiteSearchTool(RagTool):
name: str = "Search in a specific website"
description: str = (
"A tool that can be used to semantic search a query from a specific URL content."
)
description: str = "A tool that can be used to semantic search a query from a specific URL content."
args_schema: Type[BaseModel] = WebsiteSearchToolSchema
def __init__(self, website: Optional[str] = None, **kwargs):

View File

@@ -25,9 +25,7 @@ class YoutubeChannelSearchToolSchema(FixedYoutubeChannelSearchToolSchema):
class YoutubeChannelSearchTool(RagTool):
name: str = "Search a Youtube Channels content"
description: str = (
"A tool that can be used to semantic search a query from a Youtube Channels content."
)
description: str = "A tool that can be used to semantic search a query from a Youtube Channels content."
args_schema: Type[BaseModel] = YoutubeChannelSearchToolSchema
def __init__(self, youtube_channel_handle: Optional[str] = None, **kwargs):

View File

@@ -25,9 +25,7 @@ class YoutubeVideoSearchToolSchema(FixedYoutubeVideoSearchToolSchema):
class YoutubeVideoSearchTool(RagTool):
name: str = "Search a Youtube Video content"
description: str = (
"A tool that can be used to semantic search a query from a Youtube Video content."
)
description: str = "A tool that can be used to semantic search a query from a Youtube Video content."
args_schema: Type[BaseModel] = YoutubeVideoSearchToolSchema
def __init__(self, youtube_video_url: Optional[str] = None, **kwargs):

View File

@@ -1,69 +1,104 @@
from typing import Callable
from crewai.tools import BaseTool, tool
from crewai.tools.base_tool import to_langchain
def test_creating_a_tool_using_annotation():
@tool("Name of my tool")
def my_tool(question: str) -> str:
"""Clear description for what this tool is useful for, you agent will need this information to use it."""
return question
@tool("Name of my tool")
def my_tool(question: str) -> str:
"""Clear description for what this tool is useful for, you agent will need this information to use it."""
return question
# Assert all the right attributes were defined
assert my_tool.name == "Name of my tool"
assert my_tool.description == "Tool Name: Name of my tool\nTool Arguments: {'question': {'description': None, 'type': 'str'}}\nTool Description: Clear description for what this tool is useful for, you agent will need this information to use it."
assert my_tool.args_schema.schema()["properties"] == {'question': {'title': 'Question', 'type': 'string'}}
assert my_tool.func("What is the meaning of life?") == "What is the meaning of life?"
# Assert all the right attributes were defined
assert my_tool.name == "Name of my tool"
assert (
my_tool.description
== "Tool Name: Name of my tool\nTool Arguments: {'question': {'description': None, 'type': 'str'}}\nTool Description: Clear description for what this tool is useful for, you agent will need this information to use it."
)
assert my_tool.args_schema.schema()["properties"] == {
"question": {"title": "Question", "type": "string"}
}
assert (
my_tool.func("What is the meaning of life?") == "What is the meaning of life?"
)
# Assert the langchain tool conversion worked as expected
converted_tool = to_langchain([my_tool])[0]
assert converted_tool.name == "Name of my tool"
assert (
converted_tool.description
== "Tool Name: Name of my tool\nTool Arguments: {'question': {'description': None, 'type': 'str'}}\nTool Description: Clear description for what this tool is useful for, you agent will need this information to use it."
)
assert converted_tool.args_schema.schema()["properties"] == {
"question": {"title": "Question", "type": "string"}
}
assert (
converted_tool.func("What is the meaning of life?")
== "What is the meaning of life?"
)
# Assert the langchain tool conversion worked as expected
converted_tool = to_langchain([my_tool])[0]
assert converted_tool.name == "Name of my tool"
assert converted_tool.description == "Tool Name: Name of my tool\nTool Arguments: {'question': {'description': None, 'type': 'str'}}\nTool Description: Clear description for what this tool is useful for, you agent will need this information to use it."
assert converted_tool.args_schema.schema()["properties"] == {'question': {'title': 'Question', 'type': 'string'}}
assert converted_tool.func("What is the meaning of life?") == "What is the meaning of life?"
def test_creating_a_tool_using_baseclass():
class MyCustomTool(BaseTool):
name: str = "Name of my tool"
description: str = "Clear description for what this tool is useful for, you agent will need this information to use it."
class MyCustomTool(BaseTool):
name: str = "Name of my tool"
description: str = "Clear description for what this tool is useful for, you agent will need this information to use it."
def _run(self, question: str) -> str:
return question
def _run(self, question: str) -> str:
return question
my_tool = MyCustomTool()
# Assert all the right attributes were defined
assert my_tool.name == "Name of my tool"
assert my_tool.description == "Tool Name: Name of my tool\nTool Arguments: {'question': {'description': None, 'type': 'str'}}\nTool Description: Clear description for what this tool is useful for, you agent will need this information to use it."
assert my_tool.args_schema.schema()["properties"] == {'question': {'title': 'Question', 'type': 'string'}}
assert my_tool._run("What is the meaning of life?") == "What is the meaning of life?"
my_tool = MyCustomTool()
# Assert all the right attributes were defined
assert my_tool.name == "Name of my tool"
assert (
my_tool.description
== "Tool Name: Name of my tool\nTool Arguments: {'question': {'description': None, 'type': 'str'}}\nTool Description: Clear description for what this tool is useful for, you agent will need this information to use it."
)
assert my_tool.args_schema.schema()["properties"] == {
"question": {"title": "Question", "type": "string"}
}
assert (
my_tool._run("What is the meaning of life?") == "What is the meaning of life?"
)
# Assert the langchain tool conversion worked as expected
converted_tool = to_langchain([my_tool])[0]
assert converted_tool.name == "Name of my tool"
assert (
converted_tool.description
== "Tool Name: Name of my tool\nTool Arguments: {'question': {'description': None, 'type': 'str'}}\nTool Description: Clear description for what this tool is useful for, you agent will need this information to use it."
)
assert converted_tool.args_schema.schema()["properties"] == {
"question": {"title": "Question", "type": "string"}
}
assert (
converted_tool.invoke({"question": "What is the meaning of life?"})
== "What is the meaning of life?"
)
# Assert the langchain tool conversion worked as expected
converted_tool = to_langchain([my_tool])[0]
assert converted_tool.name == "Name of my tool"
assert converted_tool.description == "Tool Name: Name of my tool\nTool Arguments: {'question': {'description': None, 'type': 'str'}}\nTool Description: Clear description for what this tool is useful for, you agent will need this information to use it."
assert converted_tool.args_schema.schema()["properties"] == {'question': {'title': 'Question', 'type': 'string'}}
assert converted_tool.invoke({"question": "What is the meaning of life?"}) == "What is the meaning of life?"
def test_setting_cache_function():
class MyCustomTool(BaseTool):
name: str = "Name of my tool"
description: str = "Clear description for what this tool is useful for, you agent will need this information to use it."
cache_function: Callable = lambda: False
class MyCustomTool(BaseTool):
name: str = "Name of my tool"
description: str = "Clear description for what this tool is useful for, you agent will need this information to use it."
cache_function: Callable = lambda: False
def _run(self, question: str) -> str:
return question
def _run(self, question: str) -> str:
return question
my_tool = MyCustomTool()
# Assert all the right attributes were defined
assert my_tool.cache_function() == False
my_tool = MyCustomTool()
# Assert all the right attributes were defined
assert my_tool.cache_function() == False
def test_default_cache_function_is_true():
class MyCustomTool(BaseTool):
name: str = "Name of my tool"
description: str = "Clear description for what this tool is useful for, you agent will need this information to use it."
class MyCustomTool(BaseTool):
name: str = "Name of my tool"
description: str = "Clear description for what this tool is useful for, you agent will need this information to use it."
def _run(self, question: str) -> str:
return question
def _run(self, question: str) -> str:
return question
my_tool = MyCustomTool()
# Assert all the right attributes were defined
assert my_tool.cache_function() == True
my_tool = MyCustomTool()
# Assert all the right attributes were defined
assert my_tool.cache_function() == True

View File

@@ -1,7 +1,8 @@
import os
import pytest
from crewai_tools import FileReadTool
def test_file_read_tool_constructor():
"""Test FileReadTool initialization with file_path."""
# Create a temporary test file
@@ -18,6 +19,7 @@ def test_file_read_tool_constructor():
# Clean up
os.remove(test_file)
def test_file_read_tool_run():
"""Test FileReadTool _run method with file_path at runtime."""
# Create a temporary test file
@@ -34,6 +36,7 @@ def test_file_read_tool_run():
# Clean up
os.remove(test_file)
def test_file_read_tool_error_handling():
"""Test FileReadTool error handling."""
# Test missing file path
@@ -58,6 +61,7 @@ def test_file_read_tool_error_handling():
os.chmod(test_file, 0o666) # Restore permissions to delete
os.remove(test_file)
def test_file_read_tool_constructor_and_run():
"""Test FileReadTool using both constructor and runtime file paths."""
# Create two test files

View File

View File

@@ -0,0 +1,21 @@
import pytest
def pytest_configure(config):
"""Register custom markers."""
config.addinivalue_line("markers", "integration: mark test as an integration test")
config.addinivalue_line("markers", "asyncio: mark test as an async test")
# Set the asyncio loop scope through ini configuration
config.inicfg["asyncio_mode"] = "auto"
@pytest.fixture(scope="function")
def event_loop():
"""Create an instance of the default event loop for each test case."""
import asyncio
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
yield loop
loop.close()

View File

@@ -0,0 +1,219 @@
import asyncio
import json
from decimal import Decimal
import pytest
from snowflake.connector.errors import DatabaseError, OperationalError
from crewai_tools import SnowflakeConfig, SnowflakeSearchTool
# Test Data
MENU_ITEMS = [
(10001, "Ice Cream", "Freezing Point", "Lemonade", "Beverage", "Cold Option", 1, 4),
(
10002,
"Ice Cream",
"Freezing Point",
"Vanilla Ice Cream",
"Dessert",
"Ice Cream",
2,
6,
),
]
INVALID_QUERIES = [
("SELECT * FROM nonexistent_table", "relation 'nonexistent_table' does not exist"),
("SELECT invalid_column FROM menu", "invalid identifier 'invalid_column'"),
("INVALID SQL QUERY", "SQL compilation error"),
]
# Integration Test Fixtures
@pytest.fixture
def config():
"""Create a Snowflake configuration with test credentials."""
return SnowflakeConfig(
account="lwyhjun-wx11931",
user="crewgitci",
password="crewaiT00ls_publicCIpass123",
warehouse="COMPUTE_WH",
database="tasty_bytes_sample_data",
snowflake_schema="raw_pos",
)
@pytest.fixture
def snowflake_tool(config):
"""Create a SnowflakeSearchTool instance."""
return SnowflakeSearchTool(config=config)
# Integration Tests with Real Snowflake Connection
@pytest.mark.integration
@pytest.mark.asyncio
@pytest.mark.parametrize(
"menu_id,expected_type,brand,item_name,category,subcategory,cost,price", MENU_ITEMS
)
async def test_menu_items(
snowflake_tool,
menu_id,
expected_type,
brand,
item_name,
category,
subcategory,
cost,
price,
):
"""Test menu items with parameterized data for multiple test cases."""
results = await snowflake_tool._run(
query=f"SELECT * FROM menu WHERE menu_id = {menu_id}"
)
assert len(results) == 1
menu_item = results[0]
# Validate all fields
assert menu_item["MENU_ID"] == menu_id
assert menu_item["MENU_TYPE"] == expected_type
assert menu_item["TRUCK_BRAND_NAME"] == brand
assert menu_item["MENU_ITEM_NAME"] == item_name
assert menu_item["ITEM_CATEGORY"] == category
assert menu_item["ITEM_SUBCATEGORY"] == subcategory
assert menu_item["COST_OF_GOODS_USD"] == cost
assert menu_item["SALE_PRICE_USD"] == price
# Validate health metrics JSON structure
health_metrics = json.loads(menu_item["MENU_ITEM_HEALTH_METRICS_OBJ"])
assert "menu_item_health_metrics" in health_metrics
metrics = health_metrics["menu_item_health_metrics"][0]
assert "ingredients" in metrics
assert isinstance(metrics["ingredients"], list)
assert all(isinstance(ingredient, str) for ingredient in metrics["ingredients"])
assert metrics["is_dairy_free_flag"] in ["Y", "N"]
@pytest.mark.integration
@pytest.mark.asyncio
async def test_menu_categories_aggregation(snowflake_tool):
"""Test complex aggregation query on menu categories with detailed validations."""
results = await snowflake_tool._run(
query="""
SELECT
item_category,
COUNT(*) as item_count,
AVG(sale_price_usd) as avg_price,
SUM(sale_price_usd - cost_of_goods_usd) as total_margin,
COUNT(DISTINCT menu_type) as menu_type_count,
MIN(sale_price_usd) as min_price,
MAX(sale_price_usd) as max_price
FROM menu
GROUP BY item_category
HAVING COUNT(*) > 1
ORDER BY item_count DESC
"""
)
assert len(results) > 0
for category in results:
# Basic presence checks
assert all(
key in category
for key in [
"ITEM_CATEGORY",
"ITEM_COUNT",
"AVG_PRICE",
"TOTAL_MARGIN",
"MENU_TYPE_COUNT",
"MIN_PRICE",
"MAX_PRICE",
]
)
# Value validations
assert category["ITEM_COUNT"] > 1 # Due to HAVING clause
assert category["MIN_PRICE"] <= category["MAX_PRICE"]
assert category["AVG_PRICE"] >= category["MIN_PRICE"]
assert category["AVG_PRICE"] <= category["MAX_PRICE"]
assert category["MENU_TYPE_COUNT"] >= 1
assert isinstance(category["TOTAL_MARGIN"], (float, Decimal))
@pytest.mark.integration
@pytest.mark.asyncio
@pytest.mark.parametrize("invalid_query,expected_error", INVALID_QUERIES)
async def test_invalid_queries(snowflake_tool, invalid_query, expected_error):
"""Test error handling for invalid queries."""
with pytest.raises((DatabaseError, OperationalError)) as exc_info:
await snowflake_tool._run(query=invalid_query)
assert expected_error.lower() in str(exc_info.value).lower()
@pytest.mark.integration
@pytest.mark.asyncio
async def test_concurrent_queries(snowflake_tool):
"""Test handling of concurrent queries."""
queries = [
"SELECT COUNT(*) FROM menu",
"SELECT COUNT(DISTINCT menu_type) FROM menu",
"SELECT COUNT(DISTINCT item_category) FROM menu",
]
tasks = [snowflake_tool._run(query=query) for query in queries]
results = await asyncio.gather(*tasks)
assert len(results) == 3
assert all(isinstance(result, list) for result in results)
assert all(len(result) == 1 for result in results)
assert all(isinstance(result[0], dict) for result in results)
@pytest.mark.integration
@pytest.mark.asyncio
async def test_query_timeout(snowflake_tool):
"""Test query timeout handling with a complex query."""
with pytest.raises((DatabaseError, OperationalError)) as exc_info:
await snowflake_tool._run(
query="""
WITH RECURSIVE numbers AS (
SELECT 1 as n
UNION ALL
SELECT n + 1
FROM numbers
WHERE n < 1000000
)
SELECT COUNT(*) FROM numbers
"""
)
assert (
"timeout" in str(exc_info.value).lower()
or "execution time" in str(exc_info.value).lower()
)
@pytest.mark.integration
@pytest.mark.asyncio
async def test_caching_behavior(snowflake_tool):
"""Test query caching behavior and performance."""
query = "SELECT * FROM menu LIMIT 5"
# First execution
start_time = asyncio.get_event_loop().time()
results1 = await snowflake_tool._run(query=query)
first_duration = asyncio.get_event_loop().time() - start_time
# Second execution (should be cached)
start_time = asyncio.get_event_loop().time()
results2 = await snowflake_tool._run(query=query)
second_duration = asyncio.get_event_loop().time() - start_time
# Verify results
assert results1 == results2
assert len(results1) == 5
assert second_duration < first_duration
# Verify cache invalidation with different query
different_query = "SELECT * FROM menu LIMIT 10"
different_results = await snowflake_tool._run(query=different_query)
assert len(different_results) == 10
assert different_results != results1

View File

@@ -1,5 +1,7 @@
from crewai import Agent, Crew, Task
from crewai_tools.tools.spider_tool.spider_tool import SpiderTool
from crewai import Agent, Task, Crew
def test_spider_tool():
spider_tool = SpiderTool()
@@ -10,38 +12,35 @@ def test_spider_tool():
backstory="An expert web researcher that uses the web extremely well",
tools=[spider_tool],
verbose=True,
cache=False
cache=False,
)
choose_between_scrape_crawl = Task(
description="Scrape the page of spider.cloud and return a summary of how fast it is",
expected_output="spider.cloud is a fast scraping and crawling tool",
agent=searcher
agent=searcher,
)
return_metadata = Task(
description="Scrape https://spider.cloud with a limit of 1 and enable metadata",
expected_output="Metadata and 10 word summary of spider.cloud",
agent=searcher
agent=searcher,
)
css_selector = Task(
description="Scrape one page of spider.cloud with the `body > div > main > section.grid.md\:grid-cols-2.gap-10.place-items-center.md\:max-w-screen-xl.mx-auto.pb-8.pt-20 > div:nth-child(1) > h1` CSS selector",
expected_output="The content of the element with the css selector body > div > main > section.grid.md\:grid-cols-2.gap-10.place-items-center.md\:max-w-screen-xl.mx-auto.pb-8.pt-20 > div:nth-child(1) > h1",
agent=searcher
agent=searcher,
)
crew = Crew(
agents=[searcher],
tasks=[
choose_between_scrape_crawl,
return_metadata,
css_selector
],
verbose=True
tasks=[choose_between_scrape_crawl, return_metadata, css_selector],
verbose=True,
)
crew.kickoff()
if __name__ == "__main__":
test_spider_tool()

View File

@@ -0,0 +1,103 @@
import asyncio
from unittest.mock import MagicMock, patch
import pytest
from crewai_tools import SnowflakeConfig, SnowflakeSearchTool
# Unit Test Fixtures
@pytest.fixture
def mock_snowflake_connection():
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_cursor.description = [("col1",), ("col2",)]
mock_cursor.fetchall.return_value = [(1, "value1"), (2, "value2")]
mock_cursor.execute.return_value = None
mock_conn.cursor.return_value = mock_cursor
return mock_conn
@pytest.fixture
def mock_config():
return SnowflakeConfig(
account="test_account",
user="test_user",
password="test_password",
warehouse="test_warehouse",
database="test_db",
snowflake_schema="test_schema",
)
@pytest.fixture
def snowflake_tool(mock_config):
with patch("snowflake.connector.connect") as mock_connect:
tool = SnowflakeSearchTool(config=mock_config)
yield tool
# Unit Tests
@pytest.mark.asyncio
async def test_successful_query_execution(snowflake_tool, mock_snowflake_connection):
with patch.object(snowflake_tool, "_create_connection") as mock_create_conn:
mock_create_conn.return_value = mock_snowflake_connection
results = await snowflake_tool._run(
query="SELECT * FROM test_table", timeout=300
)
assert len(results) == 2
assert results[0]["col1"] == 1
assert results[0]["col2"] == "value1"
mock_snowflake_connection.cursor.assert_called_once()
@pytest.mark.asyncio
async def test_connection_pooling(snowflake_tool, mock_snowflake_connection):
with patch.object(snowflake_tool, "_create_connection") as mock_create_conn:
mock_create_conn.return_value = mock_snowflake_connection
# Execute multiple queries
await asyncio.gather(
snowflake_tool._run("SELECT 1"),
snowflake_tool._run("SELECT 2"),
snowflake_tool._run("SELECT 3"),
)
# Should reuse connections from pool
assert mock_create_conn.call_count <= snowflake_tool.pool_size
@pytest.mark.asyncio
async def test_cleanup_on_deletion(snowflake_tool, mock_snowflake_connection):
with patch.object(snowflake_tool, "_create_connection") as mock_create_conn:
mock_create_conn.return_value = mock_snowflake_connection
# Add connection to pool
await snowflake_tool._get_connection()
# Return connection to pool
async with snowflake_tool._pool_lock:
snowflake_tool._connection_pool.append(mock_snowflake_connection)
# Trigger cleanup
snowflake_tool.__del__()
mock_snowflake_connection.close.assert_called_once()
def test_config_validation():
# Test missing required fields
with pytest.raises(ValueError):
SnowflakeConfig()
# Test invalid account format
with pytest.raises(ValueError):
SnowflakeConfig(
account="invalid//account", user="test_user", password="test_pass"
)
# Test missing authentication
with pytest.raises(ValueError):
SnowflakeConfig(account="test_account", user="test_user")

View File

@@ -7,7 +7,9 @@ from crewai_tools.tools.code_interpreter_tool.code_interpreter_tool import (
class TestCodeInterpreterTool(unittest.TestCase):
@patch("crewai_tools.tools.code_interpreter_tool.code_interpreter_tool.docker_from_env")
@patch(
"crewai_tools.tools.code_interpreter_tool.code_interpreter_tool.docker_from_env"
)
def test_run_code_in_docker(self, docker_mock):
tool = CodeInterpreterTool()
code = "print('Hello, World!')"
@@ -15,14 +17,14 @@ class TestCodeInterpreterTool(unittest.TestCase):
expected_output = "Hello, World!\n"
docker_mock().containers.run().exec_run().exit_code = 0
docker_mock().containers.run().exec_run().output = (
expected_output.encode()
)
docker_mock().containers.run().exec_run().output = expected_output.encode()
result = tool.run_code_in_docker(code, libraries_used)
self.assertEqual(result, expected_output)
@patch("crewai_tools.tools.code_interpreter_tool.code_interpreter_tool.docker_from_env")
@patch(
"crewai_tools.tools.code_interpreter_tool.code_interpreter_tool.docker_from_env"
)
def test_run_code_in_docker_with_error(self, docker_mock):
tool = CodeInterpreterTool()
code = "print(1/0)"
@@ -37,7 +39,9 @@ class TestCodeInterpreterTool(unittest.TestCase):
self.assertEqual(result, expected_output)
@patch("crewai_tools.tools.code_interpreter_tool.code_interpreter_tool.docker_from_env")
@patch(
"crewai_tools.tools.code_interpreter_tool.code_interpreter_tool.docker_from_env"
)
def test_run_code_in_docker_with_script(self, docker_mock):
tool = CodeInterpreterTool()
code = """print("This is line 1")