mirror of https://github.com/crewAIInc/crewAI.git
synced 2026-01-11 00:58:30 +00:00

Merge branch 'main' of github.com:crewAIInc/crewAI-tools into fix/optional-dependencies
@@ -1,4 +1,5 @@
 from .tools import (
+    AIMindTool,
     BraveSearchTool,
     BrowserbaseLoadTool,
     CodeDocsSearchTool,
@@ -16,6 +17,7 @@ from .tools import (
     FirecrawlScrapeWebsiteTool,
     FirecrawlSearchTool,
     GithubSearchTool,
+    HyperbrowserLoadTool,
     JSONSearchTool,
     LinkupSearchTool,
     LlamaIndexTool,
@@ -43,6 +45,8 @@ from .tools import (
     SerplyScholarSearchTool,
     SerplyWebpageToMarkdownTool,
     SerplyWebSearchTool,
+    SnowflakeConfig,
+    SnowflakeSearchTool,
     SpiderTool,
     TXTSearchTool,
     VisionTool,
@@ -1,3 +1,4 @@
+from .ai_mind_tool.ai_mind_tool import AIMindTool
 from .brave_search_tool.brave_search_tool import BraveSearchTool
 from .browserbase_load_tool.browserbase_load_tool import BrowserbaseLoadTool
 from .code_docs_search_tool.code_docs_search_tool import CodeDocsSearchTool
@@ -19,6 +20,7 @@ from .firecrawl_scrape_website_tool.firecrawl_scrape_website_tool import (
 )
 from .firecrawl_search_tool.firecrawl_search_tool import FirecrawlSearchTool
 from .github_search_tool.github_search_tool import GithubSearchTool
+from .hyperbrowser_load_tool.hyperbrowser_load_tool import HyperbrowserLoadTool
 from .json_search_tool.json_search_tool import JSONSearchTool
 from .linkup.linkup_search_tool import LinkupSearchTool
 from .llamaindex_tool.llamaindex_tool import LlamaIndexTool
@@ -54,6 +56,11 @@ from .serply_api_tool.serply_news_search_tool import SerplyNewsSearchTool
 from .serply_api_tool.serply_scholar_search_tool import SerplyScholarSearchTool
 from .serply_api_tool.serply_web_search_tool import SerplyWebSearchTool
 from .serply_api_tool.serply_webpage_to_markdown_tool import SerplyWebpageToMarkdownTool
+from .snowflake_search_tool import (
+    SnowflakeConfig,
+    SnowflakeSearchTool,
+    SnowflakeSearchToolInput,
+)
 from .spider_tool.spider_tool import SpiderTool
 from .txt_search_tool.txt_search_tool import TXTSearchTool
 from .vision_tool.vision_tool import VisionTool
79  src/crewai_tools/tools/ai_mind_tool/README.md  Normal file
@@ -0,0 +1,79 @@
# AIMind Tool

## Description

[Minds](https://mindsdb.com/minds) are AI systems provided by [MindsDB](https://mindsdb.com/) that work similarly to large language models (LLMs) but go beyond by answering any question from any data.

This is accomplished by selecting the most relevant data for an answer using parametric search, understanding the meaning and providing responses within the correct context through semantic search, and finally, delivering precise answers by analyzing data and using machine learning (ML) models.

The `AIMindTool` can be used to query data sources in natural language by simply configuring their connection parameters.

## Installation

1. Install the `crewai[tools]` package:

   ```shell
   pip install 'crewai[tools]'
   ```

2. Install the Minds SDK:

   ```shell
   pip install minds-sdk
   ```
3. Sign up for a Minds account [here](https://mdb.ai/register) and obtain an API key.

4. Set the Minds API key in an environment variable named `MINDS_API_KEY`.

## Usage

```python
from crewai_tools import AIMindTool

# Initialize the AIMindTool.
aimind_tool = AIMindTool(
    datasources=[
        {
            "description": "house sales data",
            "engine": "postgres",
            "connection_data": {
                "user": "demo_user",
                "password": "demo_password",
                "host": "samples.mindsdb.com",
                "port": 5432,
                "database": "demo",
                "schema": "demo_data"
            },
            "tables": ["house_sales"]
        }
    ]
)

aimind_tool.run("How many 3 bedroom houses were sold in 2008?")
```

The `datasources` parameter is a list of dictionaries, each containing the following keys:

- `description`: A description of the data contained in the datasource.
- `engine`: The engine (or type) of the datasource. Find a list of supported engines in the link below.
- `connection_data`: A dictionary containing the connection parameters for the datasource. Find a list of connection parameters for each engine in the link below.
- `tables`: A list of tables that the data source will use. This is optional and can be omitted if all tables in the data source are to be used.

A list of supported data sources and their connection parameters can be found [here](https://docs.mdb.ai/docs/data_sources).

```python
from crewai import Agent
from crewai.project import agent

# Define an agent with the AIMindTool.
@agent
def researcher(self) -> Agent:
    return Agent(
        config=self.agents_config["researcher"],
        allow_delegation=False,
        tools=[aimind_tool]
    )
```
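
Beyond the README's `@agent` example above, here is a minimal, hypothetical sketch of running the tool inside a crew; it uses a plain `Agent` rather than the `@agent`-decorated method, and the role, backstory, and task strings are illustrative, not part of this PR:

```python
from crewai import Agent, Crew, Task
from crewai_tools import AIMindTool

# Assumes the demo datasource configuration shown in the Usage section above.
aimind_tool = AIMindTool(datasources=[{
    "description": "house sales data",
    "engine": "postgres",
    "connection_data": {
        "user": "demo_user",
        "password": "demo_password",
        "host": "samples.mindsdb.com",
        "port": 5432,
        "database": "demo",
        "schema": "demo_data",
    },
    "tables": ["house_sales"],
}])

# Illustrative agent wired with the tool.
analyst = Agent(
    role="Data Analyst",
    goal="Answer questions about house sales data",
    backstory="You query the connected datasource in natural language.",
    tools=[aimind_tool],
)

task = Task(
    description="How many 3 bedroom houses were sold in 2008?",
    expected_output="A number with a one-line justification.",
    agent=analyst,
)

result = Crew(agents=[analyst], tasks=[task]).kickoff()
print(result)
```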
0   src/crewai_tools/tools/ai_mind_tool/__init__.py  Normal file
87  src/crewai_tools/tools/ai_mind_tool/ai_mind_tool.py  Normal file
@@ -0,0 +1,87 @@
import os
import secrets
from typing import Any, Dict, List, Optional, Text, Type

from crewai.tools import BaseTool
from openai import OpenAI
from pydantic import BaseModel


class AIMindToolConstants:
    MINDS_API_BASE_URL = "https://mdb.ai/"
    MIND_NAME_PREFIX = "crwai_mind_"
    DATASOURCE_NAME_PREFIX = "crwai_ds_"


class AIMindToolInputSchema(BaseModel):
    """Input for AIMind Tool."""

    query: str = "Question in natural language to ask the AI-Mind"


class AIMindTool(BaseTool):
    name: str = "AIMind Tool"
    description: str = (
        "A wrapper around [AI-Minds](https://mindsdb.com/minds). "
        "Useful for when you need answers to questions from your data, stored in "
        "data sources including PostgreSQL, MySQL, MariaDB, ClickHouse, Snowflake "
        "and Google BigQuery. "
        "Input should be a question in natural language."
    )
    args_schema: Type[BaseModel] = AIMindToolInputSchema
    api_key: Optional[str] = None
    datasources: Optional[List[Dict[str, Any]]] = None
    mind_name: Optional[Text] = None

    def __init__(self, api_key: Optional[Text] = None, **kwargs):
        super().__init__(**kwargs)
        self.api_key = api_key or os.getenv("MINDS_API_KEY")
        if not self.api_key:
            raise ValueError("API key must be provided either through constructor or MINDS_API_KEY environment variable")

        try:
            from minds.client import Client  # type: ignore
            from minds.datasources import DatabaseConfig  # type: ignore
        except ImportError:
            raise ImportError(
                "`minds_sdk` package not found, please run `pip install minds-sdk`"
            )

        minds_client = Client(api_key=self.api_key)

        # Convert the datasources to DatabaseConfig objects.
        datasources = []
        for datasource in self.datasources:
            config = DatabaseConfig(
                name=f"{AIMindToolConstants.DATASOURCE_NAME_PREFIX}_{secrets.token_hex(5)}",
                engine=datasource["engine"],
                description=datasource["description"],
                connection_data=datasource["connection_data"],
                tables=datasource["tables"],
            )
            datasources.append(config)

        # Generate a random name for the Mind.
        name = f"{AIMindToolConstants.MIND_NAME_PREFIX}_{secrets.token_hex(5)}"

        mind = minds_client.minds.create(
            name=name, datasources=datasources, replace=True
        )

        self.mind_name = mind.name

    def _run(
        self,
        query: Text
    ):
        # Run the query on the AI-Mind.
        # The Minds API is OpenAI compatible and therefore, the OpenAI client can be used.
        openai_client = OpenAI(base_url=AIMindToolConstants.MINDS_API_BASE_URL, api_key=self.api_key)

        completion = openai_client.chat.completions.create(
            model=self.mind_name,
            messages=[{"role": "user", "content": query}],
            stream=False,
        )

        return completion.choices[0].message.content
@@ -1,8 +1,8 @@
 import os
 from typing import Any, Optional, Type
-from pydantic import BaseModel, Field

 from crewai.tools import BaseTool
+from pydantic import BaseModel, Field


 class BrowserbaseLoadToolSchema(BaseModel):
@@ -2,10 +2,12 @@ import importlib.util
 import os
 from typing import List, Optional, Type

+from crewai.tools import BaseTool
 from docker import from_env as docker_from_env
+from docker import DockerClient
+from docker.models.containers import Container
 from docker.errors import ImageNotFound, NotFound
-from crewai.tools import BaseTool
-from docker.models.containers import Container
 from pydantic import BaseModel, Field


@@ -30,7 +32,7 @@ class CodeInterpreterTool(BaseTool):
     default_image_tag: str = "code-interpreter:latest"
     code: Optional[str] = None
     user_dockerfile_path: Optional[str] = None
-    user_docker_base_url: Optional[str] = None
+    user_docker_base_url: Optional[str] = None
     unsafe_mode: bool = False

     @staticmethod
@@ -43,7 +45,11 @@ class CodeInterpreterTool(BaseTool):
         Verify if the Docker image is available. Optionally use a user-provided Dockerfile.
         """

-        client = docker_from_env() if self.user_docker_base_url == None else docker.DockerClient(base_url=self.user_docker_base_url)
+        client = (
+            docker_from_env()
+            if self.user_docker_base_url == None
+            else DockerClient(base_url=self.user_docker_base_url)
+        )

         try:
             client.images.get(self.default_image_tag)
@@ -76,9 +82,7 @@ class CodeInterpreterTool(BaseTool):
         else:
             return self.run_code_in_docker(code, libraries_used)

-    def _install_libraries(
-        self, container: Container, libraries: List[str]
-    ) -> None:
+    def _install_libraries(self, container: Container, libraries: List[str]) -> None:
         """
         Install missing libraries in the Docker container
         """
@@ -135,4 +139,4 @@ class CodeInterpreterTool(BaseTool):
             exec(code, {}, exec_locals)
             return exec_locals.get("result", "No result variable found.")
         except Exception as e:
-            return f"An error occurred: {str(e)}"
+            return f"An error occurred: {str(e)}"
@@ -8,8 +8,6 @@ from pydantic import BaseModel, Field
 class FixedDirectoryReadToolSchema(BaseModel):
     """Input for DirectoryReadTool."""

-    pass
-

 class DirectoryReadToolSchema(FixedDirectoryReadToolSchema):
     """Input for DirectoryReadTool."""
@@ -32,6 +32,7 @@ class FileReadTool(BaseTool):
     >>> content = tool.run()  # Reads /path/to/file.txt
     >>> content = tool.run(file_path="/path/to/other.txt")  # Reads other.txt
     """

     name: str = "Read a file's content"
+    description: str = "A tool that reads the content of a file. To use this tool, provide a 'file_path' parameter with the path to the file you want to read."
     args_schema: Type[BaseModel] = FileReadToolSchema
@@ -45,10 +46,11 @@ class FileReadTool(BaseTool):
             this becomes the default file path for the tool.
             **kwargs: Additional keyword arguments passed to BaseTool.
         """
-        super().__init__(**kwargs)
         if file_path is not None:
-            self.file_path = file_path
-            self.description = f"A tool that reads file content. The default file is {file_path}, but you can provide a different 'file_path' parameter to read another file."
+            kwargs['description'] = f"A tool that reads file content. The default file is {file_path}, but you can provide a different 'file_path' parameter to read another file."

+        super().__init__(**kwargs)
+        self.file_path = file_path

     def _run(
         self,
@@ -57,7 +59,7 @@ class FileReadTool(BaseTool):
         file_path = kwargs.get("file_path", self.file_path)
         if file_path is None:
             return "Error: No file path provided. Please provide a file path either in the constructor or as an argument."

         try:
             with open(file_path, "r") as file:
                 return file.read()
@@ -66,16 +68,4 @@ class FileReadTool(BaseTool):
         except PermissionError:
             return f"Error: Permission denied when trying to read file: {file_path}"
         except Exception as e:
-            return f"Error: Failed to read file {file_path}. {str(e)}"
-
-    def _generate_description(self) -> None:
-        """Generate the tool description based on file path.
-
-        This method updates the tool's description to include information about
-        the default file path while maintaining the ability to specify a different
-        file at runtime.
-
-        Returns:
-            None
-        """
-        self.description = f"A tool that can be used to read {self.file_path}'s content."
+            return f"Error: Failed to read file {file_path}. {str(e)}"
@@ -15,9 +15,7 @@ class FileWriterToolInput(BaseModel):

 class FileWriterTool(BaseTool):
     name: str = "File Writer Tool"
-    description: str = (
-        "A tool to write content to a specified file. Accepts filename, content, and optionally a directory path and overwrite flag as input."
-    )
+    description: str = "A tool to write content to a specified file. Accepts filename, content, and optionally a directory path and overwrite flag as input."
     args_schema: Type[BaseModel] = FileWriterToolInput

     def _run(self, **kwargs: Any) -> str:
@@ -88,4 +88,3 @@ except ImportError:
     """
     When this tool is not used, then exception can be ignored.
     """
-    pass
@@ -78,4 +78,3 @@ except ImportError:
     """
     When this tool is not used, then exception can be ignored.
     """
-    pass
@@ -27,9 +27,7 @@ class GithubSearchToolSchema(FixedGithubSearchToolSchema):

 class GithubSearchTool(RagTool):
     name: str = "Search a github repo's content"
-    description: str = (
-        "A tool that can be used to semantic search a query from a github repo's content. This is not the GitHub API, but instead a tool that can provide semantic search capabilities."
-    )
+    description: str = "A tool that can be used to semantic search a query from a github repo's content. This is not the GitHub API, but instead a tool that can provide semantic search capabilities."
     summarize: bool = False
     gh_token: str
     args_schema: Type[BaseModel] = GithubSearchToolSchema
42  src/crewai_tools/tools/hyperbrowser_load_tool/README.md  Normal file
@@ -0,0 +1,42 @@
# HyperbrowserLoadTool

## Description

[Hyperbrowser](https://hyperbrowser.ai) is a platform for running and scaling headless browsers. It lets you launch and manage browser sessions at scale and provides easy-to-use solutions for any web scraping need, such as scraping a single page or crawling an entire site.

Key Features:
- Instant Scalability - Spin up hundreds of browser sessions in seconds without infrastructure headaches
- Simple Integration - Works seamlessly with popular tools like Puppeteer and Playwright
- Powerful APIs - Easy to use APIs for scraping/crawling any site, and much more
- Bypass Anti-Bot Measures - Built-in stealth mode, ad blocking, automatic CAPTCHA solving, and rotating proxies

For more information about Hyperbrowser, please visit the [Hyperbrowser website](https://hyperbrowser.ai) or, if you want to check out the docs, the [Hyperbrowser docs](https://docs.hyperbrowser.ai).

## Installation

- Head to [Hyperbrowser](https://app.hyperbrowser.ai/) to sign up and generate an API key. Once you've done this, set the `HYPERBROWSER_API_KEY` environment variable, or pass it to the `HyperbrowserLoadTool` constructor.
- Install the [Hyperbrowser SDK](https://github.com/hyperbrowserai/python-sdk):

```
pip install hyperbrowser 'crewai[tools]'
```

## Example

Utilize the HyperbrowserLoadTool as follows to allow your agent to load websites:

```python
from crewai_tools import HyperbrowserLoadTool

tool = HyperbrowserLoadTool()
```

## Arguments

`__init__` arguments:
- `api_key`: Optional. Specifies the Hyperbrowser API key. Defaults to the `HYPERBROWSER_API_KEY` environment variable.

`run` arguments:
- `url`: The base URL to start scraping or crawling from.
- `operation`: Optional. Specifies the operation to perform on the website. Either 'scrape' or 'crawl'. Default is 'scrape'.
- `params`: Optional. Specifies the params for the operation. For more information on the supported params, visit https://docs.hyperbrowser.ai/reference/sdks/python/scrape#start-scrape-job-and-wait or https://docs.hyperbrowser.ai/reference/sdks/python/crawl#start-crawl-job-and-wait.
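
As an illustrative addition to this README (not part of the committed text), a hedged sketch of calling `run` with the arguments listed above; the crawl `params` shown lean on the `scrape_options` validation in the tool source below, and the URL is a placeholder:

```python
from crewai_tools import HyperbrowserLoadTool

tool = HyperbrowserLoadTool()  # reads HYPERBROWSER_API_KEY from the environment

# Scrape a single page (operation defaults to 'scrape').
page_markdown = tool.run(url="https://example.com")

# Crawl a site; formats are limited to 'markdown' or 'html'
# by the tool's _prepare_params validation.
site_contents = tool.run(
    url="https://example.com",
    operation="crawl",
    params={"scrape_options": {"formats": ["markdown"]}},
)
print(site_contents[:500])
```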
@@ -0,0 +1,103 @@
import os
from typing import Any, Optional, Type, Dict, Literal, Union

from crewai.tools import BaseTool
from pydantic import BaseModel, Field


class HyperbrowserLoadToolSchema(BaseModel):
    url: str = Field(description="Website URL")
    operation: Literal['scrape', 'crawl'] = Field(description="Operation to perform on the website. Either 'scrape' or 'crawl'")
    params: Optional[Dict] = Field(description="Optional params for scrape or crawl. For more information on the supported params, visit https://docs.hyperbrowser.ai/reference/sdks/python/scrape#start-scrape-job-and-wait or https://docs.hyperbrowser.ai/reference/sdks/python/crawl#start-crawl-job-and-wait")


class HyperbrowserLoadTool(BaseTool):
    """HyperbrowserLoadTool.

    Scrape or crawl web pages and load the contents with optional parameters for configuring content extraction.
    Requires the `hyperbrowser` package.
    Get your API Key from https://app.hyperbrowser.ai/

    Args:
        api_key: The Hyperbrowser API key, can be set as an environment variable `HYPERBROWSER_API_KEY` or passed directly
    """

    name: str = "Hyperbrowser web load tool"
    description: str = "Scrape or crawl a website using Hyperbrowser and return the contents in properly formatted markdown or html"
    args_schema: Type[BaseModel] = HyperbrowserLoadToolSchema
    api_key: Optional[str] = None
    hyperbrowser: Optional[Any] = None

    def __init__(self, api_key: Optional[str] = None, **kwargs):
        super().__init__(**kwargs)
        # Check the resolved key (constructor argument or environment variable),
        # not the raw constructor argument, so env-only configuration works.
        self.api_key = api_key or os.getenv('HYPERBROWSER_API_KEY')

        try:
            from hyperbrowser import Hyperbrowser
        except ImportError:
            raise ImportError("`hyperbrowser` package not found, please run `pip install hyperbrowser`")

        if not self.api_key:
            raise ValueError("HYPERBROWSER_API_KEY is not set. Please provide it either via the constructor with the `api_key` argument or by setting the HYPERBROWSER_API_KEY environment variable.")

        self.hyperbrowser = Hyperbrowser(api_key=self.api_key)

    def _prepare_params(self, params: Dict) -> Dict:
        """Prepare session and scrape options parameters."""
        try:
            from hyperbrowser.models.session import CreateSessionParams
            from hyperbrowser.models.scrape import ScrapeOptions
        except ImportError:
            raise ImportError(
                "`hyperbrowser` package not found, please run `pip install hyperbrowser`"
            )

        if "scrape_options" in params:
            if "formats" in params["scrape_options"]:
                formats = params["scrape_options"]["formats"]
                if not all(fmt in ["markdown", "html"] for fmt in formats):
                    raise ValueError("formats can only contain 'markdown' or 'html'")

        if "session_options" in params:
            params["session_options"] = CreateSessionParams(**params["session_options"])
        if "scrape_options" in params:
            params["scrape_options"] = ScrapeOptions(**params["scrape_options"])
        return params

    def _extract_content(self, data: Union[Any, None]):
        """Extract content from response data."""
        content = ""
        if data:
            content = data.markdown or data.html or ""
        return content

    def _run(self, url: str, operation: Literal['scrape', 'crawl'] = 'scrape', params: Optional[Dict] = {}):
        try:
            from hyperbrowser.models.scrape import StartScrapeJobParams
            from hyperbrowser.models.crawl import StartCrawlJobParams
        except ImportError:
            raise ImportError(
                "`hyperbrowser` package not found, please run `pip install hyperbrowser`"
            )

        params = self._prepare_params(params)

        if operation == 'scrape':
            scrape_params = StartScrapeJobParams(url=url, **params)
            scrape_resp = self.hyperbrowser.scrape.start_and_wait(scrape_params)
            content = self._extract_content(scrape_resp.data)
            return content
        else:
            crawl_params = StartCrawlJobParams(url=url, **params)
            crawl_resp = self.hyperbrowser.crawl.start_and_wait(crawl_params)
            content = ""
            if crawl_resp.data:
                for page in crawl_resp.data:
                    page_content = self._extract_content(page)
                    if page_content:
                        content += (
                            f"\n{'-'*50}\nUrl: {page.url}\nContent:\n{page_content}\n"
                        )
            return content
@@ -13,9 +13,7 @@ class JinaScrapeWebsiteToolInput(BaseModel):

 class JinaScrapeWebsiteTool(BaseTool):
     name: str = "JinaScrapeWebsiteTool"
-    description: str = (
-        "A tool that can be used to read a website content using Jina.ai reader and return markdown content."
-    )
+    description: str = "A tool that can be used to read a website content using Jina.ai reader and return markdown content."
     args_schema: Type[BaseModel] = JinaScrapeWebsiteToolInput
     website_url: Optional[str] = None
     api_key: Optional[str] = None
@@ -19,6 +19,10 @@ class LinkupSearchTool(BaseTool):
         "Performs an API call to Linkup to retrieve contextual information."
     )
     _client: LinkupClient = PrivateAttr()  # type: ignore
+    description: str = (
+        "Performs an API call to Linkup to retrieve contextual information."
+    )
+    _client: LinkupClient = PrivateAttr()  # type: ignore

     def __init__(self, api_key: str):
         """
@@ -17,9 +17,7 @@ class MySQLSearchToolSchema(BaseModel):

 class MySQLSearchTool(RagTool):
     name: str = "Search a database's table content"
-    description: str = (
-        "A tool that can be used to semantic search a query from a database table's content."
-    )
+    description: str = "A tool that can be used to semantic search a query from a database table's content."
     args_schema: Type[BaseModel] = MySQLSearchToolSchema
     db_uri: str = Field(..., description="Mandatory database URI")
@@ -1,30 +1,24 @@
-from crewai import Agent, Crew, Task
-from patronus_eval_tool import (
-    PatronusEvalTool,
-)
-from patronus_local_evaluator_tool import (
-    PatronusLocalEvaluatorTool,
-)
-from patronus_predefined_criteria_eval_tool import (
-    PatronusPredefinedCriteriaEvalTool,
-)
-from patronus import Client, EvaluationResult
 import random

+from crewai import Agent, Crew, Task
+from patronus import Client, EvaluationResult
+from patronus_local_evaluator_tool import PatronusLocalEvaluatorTool

 # Test the PatronusLocalEvaluatorTool where agent uses the local evaluator
 client = Client()


 # Example of an evaluator that returns a random pass/fail result
 @client.register_local_evaluator("random_evaluator")
 def random_evaluator(**kwargs):
     score = random.random()
     return EvaluationResult(
-        score_raw=score,
-        pass_=score >= 0.5,
-        explanation="example explanation"  # Optional justification for LLM judges
+        score_raw=score,
+        pass_=score >= 0.5,
+        explanation="example explanation",  # Optional justification for LLM judges
     )


 # 1. Uses PatronusEvalTool: agent can pick the best evaluator and criteria
 # patronus_eval_tool = PatronusEvalTool()
@@ -35,7 +29,9 @@ def random_evaluator(**kwargs):

 # 3. Uses PatronusLocalEvaluatorTool: agent uses user defined evaluator
 patronus_eval_tool = PatronusLocalEvaluatorTool(
-    patronus_client=client, evaluator="random_evaluator", evaluated_model_gold_answer="example label"
+    patronus_client=client,
+    evaluator="random_evaluator",
+    evaluated_model_gold_answer="example label",
 )

 # Create a new agent
@@ -1,8 +1,9 @@
-import os
 import json
-import requests
+import os
 import warnings
-from typing import Any, List, Dict, Optional
+from typing import Any, Dict, List, Optional

+import requests
 from crewai.tools import BaseTool
@@ -19,7 +20,9 @@ class PatronusEvalTool(BaseTool):
         self.evaluators = temp_evaluators
         self.criteria = temp_criteria
         self.description = self._generate_description()
-        warnings.warn("You are allowing the agent to select the best evaluator and criteria when you use the `PatronusEvalTool`. If this is not intended then please use `PatronusPredefinedCriteriaEvalTool` instead.")
+        warnings.warn(
+            "You are allowing the agent to select the best evaluator and criteria when you use the `PatronusEvalTool`. If this is not intended then please use `PatronusPredefinedCriteriaEvalTool` instead."
+        )

     def _init_run(self):
         evaluators_set = json.loads(
@@ -104,7 +107,6 @@ class PatronusEvalTool(BaseTool):
         evaluated_model_retrieved_context: Optional[str],
         evaluators: List[Dict[str, str]],
     ) -> Any:
-
         # Assert correct format of evaluators
         evals = []
         for ev in evaluators:
@@ -136,4 +138,4 @@ class PatronusEvalTool(BaseTool):
                 f"Failed to evaluate model input and output. Response status code: {response.status_code}. Reason: {response.text}"
             )

-        return response.json()
+        return response.json()
@@ -1,5 +1,7 @@
 from typing import Any, Type
+
 from crewai.tools import BaseTool
+from patronus import Client
 from pydantic import BaseModel, Field

 try:
@@ -32,12 +34,20 @@ class PatronusLocalEvaluatorTool(BaseTool):
     evaluator: str = "The registered local evaluator"
     evaluated_model_gold_answer: str = "The agent's gold answer"
-    description: str = "This tool is used to evaluate the model input and output using custom function evaluators."
+    description: str = "This tool is used to evaluate the model input and output using custom function evaluators."
     client: Any = None
     args_schema: Type[BaseModel] = FixedLocalEvaluatorToolSchema

+    class Config:
+        arbitrary_types_allowed = True
+
     def __init__(
         self,
         patronus_client: Client,
         evaluator: str,
         evaluated_model_gold_answer: str,
         **kwargs: Any,
     ):
@@ -105,6 +115,7 @@ class PatronusLocalEvaluatorTool(BaseTool):
                 else evaluated_model_gold_answer.get("description")
             ),
-            tags={},  # Optional metadata, supports arbitrary kv pairs
+            tags={},  # Optional metadata, supports arbitrary kv pairs
         )
         output = f"Evaluation result: {result.pass_}, Explanation: {result.explanation}"
         return output
@@ -1,7 +1,8 @@
-import os
 import json
+import os
+from typing import Any, Dict, List, Type
+
 import requests
-from typing import Any, List, Dict, Type
 from crewai.tools import BaseTool
 from pydantic import BaseModel, Field
@@ -33,9 +34,7 @@ class PatronusPredefinedCriteriaEvalTool(BaseTool):
     """

     name: str = "Call Patronus API tool for evaluation of model inputs and outputs"
-    description: str = (
-        """This tool calls the Patronus Evaluation API that takes the following arguments:"""
-    )
+    description: str = """This tool calls the Patronus Evaluation API that takes the following arguments:"""
     evaluate_url: str = "https://api.patronus.ai/v1/evaluate"
     args_schema: Type[BaseModel] = FixedBaseToolSchema
     evaluators: List[Dict[str, str]] = []
@@ -52,7 +51,6 @@ class PatronusPredefinedCriteriaEvalTool(BaseTool):
         self,
         **kwargs: Any,
     ) -> Any:
-
         evaluated_model_input = kwargs.get("evaluated_model_input")
         evaluated_model_output = kwargs.get("evaluated_model_output")
         evaluated_model_retrieved_context = kwargs.get(
@@ -103,4 +101,4 @@ class PatronusPredefinedCriteriaEvalTool(BaseTool):
                 f"Failed to evaluate model input and output. Status code: {response.status_code}. Reason: {response.text}"
             )

-        return response.json()
+        return response.json()
@@ -1,7 +1,9 @@
-from typing import Optional, Type
-from pydantic import BaseModel, Field
-from pypdf import PdfReader, PdfWriter, PageObject, ContentStream, NameObject, Font
 from pathlib import Path
+from typing import Optional, Type

+from pydantic import BaseModel, Field
+from pypdf import ContentStream, Font, NameObject, PageObject, PdfReader, PdfWriter

 from crewai_tools.tools.rag.rag_tool import RagTool
@@ -17,9 +17,7 @@ class PGSearchToolSchema(BaseModel):

 class PGSearchTool(RagTool):
     name: str = "Search a database's table content"
-    description: str = (
-        "A tool that can be used to semantic search a query from a database table's content."
-    )
+    description: str = "A tool that can be used to semantic search a query from a database table's content."
     args_schema: Type[BaseModel] = PGSearchToolSchema
     db_uri: str = Field(..., description="Mandatory database URI")
@@ -10,8 +10,6 @@ from pydantic import BaseModel, Field
 class FixedScrapeElementFromWebsiteToolSchema(BaseModel):
     """Input for ScrapeElementFromWebsiteTool."""

-    pass
-

 class ScrapeElementFromWebsiteToolSchema(FixedScrapeElementFromWebsiteToolSchema):
     """Input for ScrapeElementFromWebsiteTool."""
@@ -11,8 +11,6 @@ from pydantic import BaseModel, Field
 class FixedScrapeWebsiteToolSchema(BaseModel):
     """Input for ScrapeWebsiteTool."""

-    pass
-

 class ScrapeWebsiteToolSchema(FixedScrapeWebsiteToolSchema):
     """Input for ScrapeWebsiteTool."""
@@ -13,20 +13,14 @@ if TYPE_CHECKING:
 class ScrapegraphError(Exception):
     """Base exception for Scrapegraph-related errors"""

-    pass
-

 class RateLimitError(ScrapegraphError):
     """Raised when API rate limits are exceeded"""

-    pass
-

 class FixedScrapegraphScrapeToolSchema(BaseModel):
     """Input for ScrapegraphScrapeTool when website_url is fixed."""

-    pass
-

 class ScrapegraphScrapeToolSchema(FixedScrapegraphScrapeToolSchema):
     """Input for ScrapegraphScrapeTool."""
@@ -71,6 +65,7 @@ class ScrapegraphScrapeTool(BaseTool):
     website_url: Optional[str] = None
     user_prompt: Optional[str] = None
     api_key: Optional[str] = None
+    enable_logging: bool = False
     _client: Optional["Client"] = None

     def __init__(
@@ -78,6 +73,7 @@ class ScrapegraphScrapeTool(BaseTool):
         website_url: Optional[str] = None,
         user_prompt: Optional[str] = None,
         api_key: Optional[str] = None,
+        enable_logging: bool = False,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -118,8 +114,9 @@ class ScrapegraphScrapeTool(BaseTool):
         if user_prompt is not None:
             self.user_prompt = user_prompt

-        # Configure logging
-        sgai_logger.set_logging(level="INFO")
+        # Configure logging only if enabled
+        if self.enable_logging:
+            sgai_logger.set_logging(level="INFO")

     @staticmethod
     def _validate_url(url: str) -> None:
@@ -172,8 +169,7 @@ class ScrapegraphScrapeTool(BaseTool):
                 user_prompt=user_prompt,
             )

-            # Handle and validate the response
-            return self._handle_api_response(response)
+            return response

         except RateLimitError:
             raise  # Re-raise rate limit errors
@@ -163,7 +163,7 @@ class SeleniumScrapingTool(BaseTool):
         if not re.match(r"^https?://", url):
             raise ValueError("URL must start with http:// or https://")

-        options = self._options()
+        options = Options()
         options.add_argument("--headless")
         driver = self.driver(options=options)
         driver.get(url)
@@ -1,6 +1,6 @@
 import os
 import re
-from typing import Optional, Any, Union
+from typing import Any, Optional, Union

 from crewai.tools import BaseTool
@@ -1,4 +1,4 @@
-from typing import Any, Type, Optional
+from typing import Any, Optional, Type

 import re
 from pydantic import BaseModel, Field, ConfigDict
@@ -1,6 +1,5 @@
-from typing import Any, Type, Optional
+from typing import Any, Optional, Type

 import re
 from pydantic import BaseModel, Field
 from .serpapi_base_tool import SerpApiBaseTool
-from pydantic import ConfigDict
@@ -1,19 +1,19 @@
 import datetime
 import json
-import os
 import logging
+import os
 from typing import Any, Type

 import requests
 from crewai.tools import BaseTool
 from pydantic import BaseModel, Field


 logging.basicConfig(
     level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
 )
 logger = logging.getLogger(__name__)


 def _save_results_to_file(content: str) -> None:
     """Saves the search results to a file."""
     try:
@@ -18,9 +18,7 @@ class SerplyWebpageToMarkdownToolSchema(BaseModel):

 class SerplyWebpageToMarkdownTool(RagTool):
     name: str = "Webpage to Markdown"
-    description: str = (
-        "A tool to perform convert a webpage to markdown to make it easier for LLMs to understand"
-    )
+    description: str = "A tool to perform convert a webpage to markdown to make it easier for LLMs to understand"
     args_schema: Type[BaseModel] = SerplyWebpageToMarkdownToolSchema
     request_url: str = "https://api.serply.io/v1/request"
     proxy_location: Optional[str] = "US"
155  src/crewai_tools/tools/snowflake_search_tool/README.md  Normal file
@@ -0,0 +1,155 @@
# Snowflake Search Tool

A tool for executing queries on the Snowflake data warehouse with built-in connection pooling, retry logic, and async execution support.

## Installation

```bash
uv sync --extra snowflake

# OR
uv pip install snowflake-connector-python>=3.5.0 snowflake-sqlalchemy>=1.5.0 cryptography>=41.0.0

# OR
pip install snowflake-connector-python>=3.5.0 snowflake-sqlalchemy>=1.5.0 cryptography>=41.0.0
```

## Quick Start

```python
import asyncio
from crewai_tools import SnowflakeSearchTool, SnowflakeConfig

# Create configuration
config = SnowflakeConfig(
    account="your_account",
    user="your_username",
    password="your_password",
    warehouse="COMPUTE_WH",
    database="your_database",
    snowflake_schema="your_schema"  # Note: uses snowflake_schema instead of schema
)

# Initialize tool
tool = SnowflakeSearchTool(
    config=config,
    pool_size=5,
    max_retries=3,
    enable_caching=True
)

# Execute query
async def main():
    results = await tool._run(
        query="SELECT * FROM your_table LIMIT 10",
        timeout=300
    )
    print(f"Retrieved {len(results)} rows")

if __name__ == "__main__":
    asyncio.run(main())
```

## Features

- ✨ Asynchronous query execution
- 🚀 Connection pooling for better performance
- 🔄 Automatic retries for transient failures
- 💾 Query result caching (optional)
- 🔒 Support for both password and key-pair authentication
- 📝 Comprehensive error handling and logging

## Configuration Options

### SnowflakeConfig Parameters

| Parameter | Required | Description |
|-----------|----------|-------------|
| account | Yes | Snowflake account identifier |
| user | Yes | Snowflake username |
| password | Yes* | Snowflake password |
| private_key_path | No* | Path to private key file (alternative to password) |
| warehouse | Yes | Snowflake warehouse name |
| database | Yes | Default database |
| snowflake_schema | Yes | Default schema |
| role | No | Snowflake role |
| session_parameters | No | Custom session parameters dict |

\* Either password or private_key_path must be provided

### Tool Parameters

| Parameter | Default | Description |
|-----------|---------|-------------|
| pool_size | 5 | Number of connections in the pool |
| max_retries | 3 | Maximum retry attempts for failed queries |
| retry_delay | 1.0 | Delay between retries in seconds |
| enable_caching | True | Enable/disable query result caching |
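
Note that `retry_delay` is a base value: per the tool implementation further down in this diff, the sleep between failed attempts grows exponentially as `retry_delay * 2**attempt`. A small illustrative sketch (not part of the committed README):

```python
# Backoff schedule implied by retry_delay=1.0 and max_retries=3
# (delay = retry_delay * 2**attempt; the final attempt re-raises).
retry_delay, max_retries = 1.0, 3
for attempt in range(max_retries - 1):
    print(f"after failed attempt {attempt + 1}: wait {retry_delay * 2**attempt:.1f}s")
# after failed attempt 1: wait 1.0s
# after failed attempt 2: wait 2.0s
```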
## Advanced Usage

### Using Key-Pair Authentication

```python
config = SnowflakeConfig(
    account="your_account",
    user="your_username",
    private_key_path="/path/to/private_key.p8",
    warehouse="your_warehouse",
    database="your_database",
    snowflake_schema="your_schema"
)
```

### Custom Session Parameters

```python
config = SnowflakeConfig(
    # ... other config parameters ...
    session_parameters={
        "QUERY_TAG": "my_app",
        "TIMEZONE": "America/Los_Angeles"
    }
)
```

## Best Practices

1. **Error Handling**: Always wrap query execution in try-except blocks
2. **Logging**: Enable logging to track query execution and errors
3. **Connection Management**: Use appropriate pool sizes for your workload
4. **Timeouts**: Set reasonable query timeouts to prevent hanging
5. **Security**: Use key-pair auth in production and never hardcode credentials

## Example with Logging

```python
import logging

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

async def main():
    try:
        # ... tool initialization ...
        results = await tool._run(query="SELECT * FROM table LIMIT 10")
        logger.info(f"Query completed successfully. Retrieved {len(results)} rows")
    except Exception as e:
        logger.error(f"Query failed: {str(e)}")
        raise
```

## Error Handling

The tool automatically handles common Snowflake errors:
- DatabaseError
- OperationalError
- ProgrammingError
- Network timeouts
- Connection issues

Errors are logged and retried based on your retry configuration.
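
As a usage note beyond this README, a hypothetical sketch of handing the tool to a crewAI agent; the agent fields are illustrative, not part of the committed text:

```python
from crewai import Agent

# Illustrative only: SnowflakeSearchTool plugs into an agent
# like any other crewAI tool.
warehouse_analyst = Agent(
    role="Warehouse Analyst",
    goal="Answer questions from Snowflake tables",
    backstory="You translate questions into SQL against the configured schema.",
    tools=[tool],  # the SnowflakeSearchTool configured in Quick Start above
)
```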
11  src/crewai_tools/tools/snowflake_search_tool/__init__.py  Normal file
@@ -0,0 +1,11 @@
from .snowflake_search_tool import (
    SnowflakeConfig,
    SnowflakeSearchTool,
    SnowflakeSearchToolInput,
)

__all__ = [
    "SnowflakeSearchTool",
    "SnowflakeSearchToolInput",
    "SnowflakeConfig",
]
@@ -0,0 +1,201 @@
import asyncio
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Optional, Type

import snowflake.connector
from crewai.tools.base_tool import BaseTool
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from pydantic import BaseModel, ConfigDict, Field, SecretStr
from snowflake.connector.connection import SnowflakeConnection
from snowflake.connector.errors import DatabaseError, OperationalError

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Cache for query results
_query_cache = {}


class SnowflakeConfig(BaseModel):
    """Configuration for Snowflake connection."""

    model_config = ConfigDict(protected_namespaces=())

    account: str = Field(
        ..., description="Snowflake account identifier", pattern=r"^[a-zA-Z0-9\-_]+$"
    )
    user: str = Field(..., description="Snowflake username")
    password: Optional[SecretStr] = Field(None, description="Snowflake password")
    private_key_path: Optional[str] = Field(
        None, description="Path to private key file"
    )
    warehouse: Optional[str] = Field(None, description="Snowflake warehouse")
    database: Optional[str] = Field(None, description="Default database")
    snowflake_schema: Optional[str] = Field(None, description="Default schema")
    role: Optional[str] = Field(None, description="Snowflake role")
    session_parameters: Optional[Dict[str, Any]] = Field(
        default_factory=dict, description="Session parameters"
    )

    @property
    def has_auth(self) -> bool:
        return bool(self.password or self.private_key_path)

    def model_post_init(self, *args, **kwargs):
        if not self.has_auth:
            raise ValueError("Either password or private_key_path must be provided")


class SnowflakeSearchToolInput(BaseModel):
    """Input schema for SnowflakeSearchTool."""

    model_config = ConfigDict(protected_namespaces=())

    query: str = Field(..., description="SQL query or semantic search query to execute")
    database: Optional[str] = Field(None, description="Override default database")
    snowflake_schema: Optional[str] = Field(None, description="Override default schema")
    timeout: Optional[int] = Field(300, description="Query timeout in seconds")


class SnowflakeSearchTool(BaseTool):
    """Tool for executing queries and semantic search on Snowflake."""

    name: str = "Snowflake Database Search"
    description: str = (
        "Execute SQL queries or semantic search on Snowflake data warehouse. "
        "Supports both raw SQL and natural language queries."
    )
    args_schema: Type[BaseModel] = SnowflakeSearchToolInput

    # Define Pydantic fields
    config: SnowflakeConfig = Field(
        ..., description="Snowflake connection configuration"
    )
    pool_size: int = Field(default=5, description="Size of connection pool")
    max_retries: int = Field(default=3, description="Maximum retry attempts")
    retry_delay: float = Field(
        default=1.0, description="Delay between retries in seconds"
    )
    enable_caching: bool = Field(
        default=True, description="Enable query result caching"
    )

    model_config = ConfigDict(arbitrary_types_allowed=True)

    def __init__(self, **data):
        """Initialize SnowflakeSearchTool."""
        super().__init__(**data)
        self._connection_pool: List[SnowflakeConnection] = []
        self._pool_lock = asyncio.Lock()
        self._thread_pool = ThreadPoolExecutor(max_workers=self.pool_size)

    async def _get_connection(self) -> SnowflakeConnection:
        """Get a connection from the pool or create a new one."""
        async with self._pool_lock:
            if not self._connection_pool:
                conn = self._create_connection()
                self._connection_pool.append(conn)
            return self._connection_pool.pop()

    def _create_connection(self) -> SnowflakeConnection:
        """Create a new Snowflake connection."""
        conn_params = {
            "account": self.config.account,
            "user": self.config.user,
            "warehouse": self.config.warehouse,
            "database": self.config.database,
            "schema": self.config.snowflake_schema,
            "role": self.config.role,
            "session_parameters": self.config.session_parameters,
        }

        if self.config.password:
            conn_params["password"] = self.config.password.get_secret_value()
        elif self.config.private_key_path:
            with open(self.config.private_key_path, "rb") as key_file:
                p_key = serialization.load_pem_private_key(
                    key_file.read(), password=None, backend=default_backend()
                )
            conn_params["private_key"] = p_key

        return snowflake.connector.connect(**conn_params)

    def _get_cache_key(self, query: str, timeout: int) -> str:
        """Generate a cache key for the query."""
        return f"{self.config.account}:{self.config.database}:{self.config.snowflake_schema}:{query}:{timeout}"

    async def _execute_query(
        self, query: str, timeout: int = 300
    ) -> List[Dict[str, Any]]:
        """Execute a query with retries and return results."""
        if self.enable_caching:
            cache_key = self._get_cache_key(query, timeout)
            if cache_key in _query_cache:
                logger.info("Returning cached result")
                return _query_cache[cache_key]

        for attempt in range(self.max_retries):
            try:
                conn = await self._get_connection()
                try:
                    cursor = conn.cursor()
                    cursor.execute(query, timeout=timeout)

                    if not cursor.description:
                        return []

                    columns = [col[0] for col in cursor.description]
                    results = [dict(zip(columns, row)) for row in cursor.fetchall()]

                    if self.enable_caching:
                        _query_cache[self._get_cache_key(query, timeout)] = results

                    return results
                finally:
                    cursor.close()
                    # Return the connection to the pool on success and failure
                    # alike (a plain statement after the try block would be
                    # skipped whenever the query returns early).
                    async with self._pool_lock:
                        self._connection_pool.append(conn)
            except (DatabaseError, OperationalError) as e:
                if attempt == self.max_retries - 1:
                    raise
                await asyncio.sleep(self.retry_delay * (2**attempt))
                logger.warning(f"Query failed, attempt {attempt + 1}: {str(e)}")
                continue

    async def _run(
        self,
        query: str,
        database: Optional[str] = None,
        snowflake_schema: Optional[str] = None,
        timeout: int = 300,
        **kwargs: Any,
    ) -> Any:
        """Execute the search query."""
        try:
            # Override database/schema if provided
            if database:
                await self._execute_query(f"USE DATABASE {database}")
            if snowflake_schema:
                await self._execute_query(f"USE SCHEMA {snowflake_schema}")

            results = await self._execute_query(query, timeout)
            return results
        except Exception as e:
            logger.error(f"Error executing query: {str(e)}")
            raise

    def __del__(self):
        """Cleanup connections on deletion."""
        try:
            for conn in getattr(self, "_connection_pool", []):
                try:
                    conn.close()
                except:
                    pass
            if hasattr(self, "_thread_pool"):
                self._thread_pool.shutdown()
        except:
            pass
@@ -14,9 +14,8 @@ import os
 from functools import lru_cache
 from typing import Any, Dict, List, Optional, Type, Union

-from pydantic import BaseModel, Field
-
 from crewai.tools.base_tool import BaseTool
+from pydantic import BaseModel, Field

 # Set up logging
 logger = logging.getLogger(__name__)
@@ -1,30 +1,36 @@
 import base64
-from typing import Type, Optional
 from pathlib import Path
+from typing import Optional, Type

 from crewai.tools import BaseTool
 from openai import OpenAI
 from pydantic import BaseModel, validator


 class ImagePromptSchema(BaseModel):
     """Input for Vision Tool."""

     image_path_url: str = "The image path or URL."

     @validator("image_path_url")
     def validate_image_path_url(cls, v: str) -> str:
         if v.startswith("http"):
             return v

         path = Path(v)
         if not path.exists():
             raise ValueError(f"Image file does not exist: {v}")

         # Validate supported formats
         valid_extensions = {".jpg", ".jpeg", ".png", ".gif", ".webp"}
         if path.suffix.lower() not in valid_extensions:
-            raise ValueError(f"Unsupported image format. Supported formats: {valid_extensions}")
+            raise ValueError(
+                f"Unsupported image format. Supported formats: {valid_extensions}"
+            )

         return v


 class VisionTool(BaseTool):
     name: str = "Vision Tool"
     description: str = (
@@ -45,10 +51,10 @@ class VisionTool(BaseTool):
         image_path_url = kwargs.get("image_path_url")
         if not image_path_url:
             return "Image Path or URL is required."

         # Validate input using Pydantic
         ImagePromptSchema(image_path_url=image_path_url)

         if image_path_url.startswith("http"):
             image_data = image_path_url
         else:
@@ -68,12 +74,12 @@ class VisionTool(BaseTool):
                         {
                             "type": "image_url",
                             "image_url": {"url": image_data},
-                        }
+                        },
                     ],
                 }
             ],
             max_tokens=300,
-            )
+        )

         return response.choices[0].message.content
@@ -15,9 +15,8 @@ except ImportError:
     Vectorizers = Any
     Auth = Any

-from pydantic import BaseModel, Field
-
 from crewai.tools import BaseTool
+from pydantic import BaseModel, Field


 class WeaviateToolSchema(BaseModel):
@@ -25,9 +25,7 @@ class WebsiteSearchToolSchema(FixedWebsiteSearchToolSchema):

 class WebsiteSearchTool(RagTool):
     name: str = "Search in a specific website"
-    description: str = (
-        "A tool that can be used to semantic search a query from a specific URL content."
-    )
+    description: str = "A tool that can be used to semantic search a query from a specific URL content."
     args_schema: Type[BaseModel] = WebsiteSearchToolSchema

     def __init__(self, website: Optional[str] = None, **kwargs):
@@ -25,9 +25,7 @@ class YoutubeChannelSearchToolSchema(FixedYoutubeChannelSearchToolSchema):

 class YoutubeChannelSearchTool(RagTool):
     name: str = "Search a Youtube Channels content"
-    description: str = (
-        "A tool that can be used to semantic search a query from a Youtube Channels content."
-    )
+    description: str = "A tool that can be used to semantic search a query from a Youtube Channels content."
     args_schema: Type[BaseModel] = YoutubeChannelSearchToolSchema

     def __init__(self, youtube_channel_handle: Optional[str] = None, **kwargs):
@@ -25,9 +25,7 @@ class YoutubeVideoSearchToolSchema(FixedYoutubeVideoSearchToolSchema):

 class YoutubeVideoSearchTool(RagTool):
     name: str = "Search a Youtube Video content"
-    description: str = (
-        "A tool that can be used to semantic search a query from a Youtube Video content."
-    )
+    description: str = "A tool that can be used to semantic search a query from a Youtube Video content."
     args_schema: Type[BaseModel] = YoutubeVideoSearchToolSchema

     def __init__(self, youtube_video_url: Optional[str] = None, **kwargs):