Mirror of https://github.com/crewAIInc/crewAI.git
Synced 2026-01-10 08:38:30 +00:00

Commit: Merge branch 'main' into main
@@ -16,6 +16,7 @@ from .tools import (
    FirecrawlScrapeWebsiteTool,
    FirecrawlSearchTool,
    GithubSearchTool,
    HyperbrowserLoadTool,
    JSONSearchTool,
    LinkupSearchTool,
    LlamaIndexTool,
@@ -23,6 +24,9 @@ from .tools import (
    MultiOnTool,
    MySQLSearchTool,
    NL2SQLTool,
    PatronusEvalTool,
    PatronusLocalEvaluatorTool,
    PatronusPredefinedCriteriaEvalTool,
    PDFSearchTool,
    PGSearchTool,
    RagTool,
@@ -32,20 +36,22 @@ from .tools import (
    ScrapeWebsiteTool,
    ScrapflyScrapeWebsiteTool,
    SeleniumScrapingTool,
    SerpApiGoogleSearchTool,
    SerpApiGoogleShoppingTool,
    SerperDevTool,
    SerplyJobSearchTool,
    SerplyNewsSearchTool,
    SerplyScholarSearchTool,
    SerplyWebpageToMarkdownTool,
    SerplyWebSearchTool,
    SnowflakeConfig,
    SnowflakeSearchTool,
    SpiderTool,
    TXTSearchTool,
    VisionTool,
    WeaviateVectorSearchTool,
    WebsiteSearchTool,
    XMLSearchTool,
    YoutubeChannelSearchTool,
    YoutubeVideoSearchTool,
    WeaviateVectorSearchTool,
    SerpApiGoogleSearchTool,
    SerpApiGoogleShoppingTool,
)
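For a quick sanity check of the widened export surface, here is a minimal sketch (assuming `crewai-tools` is installed) that imports a few of the tools newly listed above:

```python
# Minimal import check for tools surfaced by this merge; the names come
# directly from the export list above.
from crewai_tools import (
    HyperbrowserLoadTool,
    LinkupSearchTool,
    PatronusEvalTool,
    SnowflakeSearchTool,
)

print(HyperbrowserLoadTool.__name__, LinkupSearchTool.__name__)
```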
@@ -19,6 +19,7 @@ from .firecrawl_scrape_website_tool.firecrawl_scrape_website_tool import (
)
from .firecrawl_search_tool.firecrawl_search_tool import FirecrawlSearchTool
from .github_search_tool.github_search_tool import GithubSearchTool
from .hyperbrowser_load_tool.hyperbrowser_load_tool import HyperbrowserLoadTool
from .json_search_tool.json_search_tool import JSONSearchTool
from .linkup.linkup_search_tool import LinkupSearchTool
from .llamaindex_tool.llamaindex_tool import LlamaIndexTool
@@ -26,33 +27,46 @@ from .mdx_seach_tool.mdx_search_tool import MDXSearchTool
from .multion_tool.multion_tool import MultiOnTool
from .mysql_search_tool.mysql_search_tool import MySQLSearchTool
from .nl2sql.nl2sql_tool import NL2SQLTool
from .patronus_eval_tool import (
    PatronusEvalTool,
    PatronusLocalEvaluatorTool,
    PatronusPredefinedCriteriaEvalTool,
)
from .pdf_search_tool.pdf_search_tool import PDFSearchTool
from .pg_seach_tool.pg_search_tool import PGSearchTool
from .rag.rag_tool import RagTool
from .scrape_element_from_website.scrape_element_from_website import (
    ScrapeElementFromWebsiteTool,
)
from .scrapegraph_scrape_tool.scrapegraph_scrape_tool import ScrapegraphScrapeTool, ScrapegraphScrapeToolSchema
from .scrape_website_tool.scrape_website_tool import ScrapeWebsiteTool
from .scrapegraph_scrape_tool.scrapegraph_scrape_tool import (
    ScrapegraphScrapeTool,
    ScrapegraphScrapeToolSchema,
)
from .scrapfly_scrape_website_tool.scrapfly_scrape_website_tool import (
    ScrapflyScrapeWebsiteTool,
)
from .selenium_scraping_tool.selenium_scraping_tool import SeleniumScrapingTool
from .serpapi_tool.serpapi_google_search_tool import SerpApiGoogleSearchTool
from .serpapi_tool.serpapi_google_shopping_tool import SerpApiGoogleShoppingTool
from .serper_dev_tool.serper_dev_tool import SerperDevTool
from .serply_api_tool.serply_job_search_tool import SerplyJobSearchTool
from .serply_api_tool.serply_news_search_tool import SerplyNewsSearchTool
from .serply_api_tool.serply_scholar_search_tool import SerplyScholarSearchTool
from .serply_api_tool.serply_web_search_tool import SerplyWebSearchTool
from .serply_api_tool.serply_webpage_to_markdown_tool import SerplyWebpageToMarkdownTool
from .snowflake_search_tool import (
    SnowflakeConfig,
    SnowflakeSearchTool,
    SnowflakeSearchToolInput,
)
from .spider_tool.spider_tool import SpiderTool
from .txt_search_tool.txt_search_tool import TXTSearchTool
from .vision_tool.vision_tool import VisionTool
from .weaviate_tool.vector_search import WeaviateVectorSearchTool
from .website_search.website_search_tool import WebsiteSearchTool
from .xml_search_tool.xml_search_tool import XMLSearchTool
from .youtube_channel_search_tool.youtube_channel_search_tool import (
    YoutubeChannelSearchTool,
)
from .youtube_video_search_tool.youtube_video_search_tool import YoutubeVideoSearchTool
from .weaviate_tool.vector_search import WeaviateVectorSearchTool
from .serpapi_tool.serpapi_google_search_tool import SerpApiGoogleSearchTool
from .serpapi_tool.serpapi_google_shopping_tool import SerpApiGoogleShoppingTool
@@ -1,8 +1,8 @@
import os
from typing import Any, Optional, Type
from pydantic import BaseModel, Field

from crewai.tools import BaseTool
from pydantic import BaseModel, Field


class BrowserbaseLoadToolSchema(BaseModel):
@@ -11,12 +11,10 @@ class BrowserbaseLoadToolSchema(BaseModel):

class BrowserbaseLoadTool(BaseTool):
    name: str = "Browserbase web load tool"
    description: str = (
        "Load webpages url in a headless browser using Browserbase and return the contents"
    )
    description: str = "Load webpages url in a headless browser using Browserbase and return the contents"
    args_schema: Type[BaseModel] = BrowserbaseLoadToolSchema
    api_key: Optional[str] = os.getenv('BROWSERBASE_API_KEY')
    project_id: Optional[str] = os.getenv('BROWSERBASE_PROJECT_ID')
    api_key: Optional[str] = os.getenv("BROWSERBASE_API_KEY")
    project_id: Optional[str] = os.getenv("BROWSERBASE_PROJECT_ID")
    text_content: Optional[bool] = False
    session_id: Optional[str] = None
    proxy: Optional[bool] = None
@@ -33,7 +31,9 @@ class BrowserbaseLoadTool(BaseTool):
    ):
        super().__init__(**kwargs)
        if not self.api_key:
            raise EnvironmentError("BROWSERBASE_API_KEY environment variable is required for initialization")
            raise EnvironmentError(
                "BROWSERBASE_API_KEY environment variable is required for initialization"
            )
        try:
            from browserbase import Browserbase  # type: ignore
        except ImportError:
@@ -2,10 +2,12 @@ import importlib.util
import os
from typing import List, Optional, Type

from crewai.tools import BaseTool
from docker import from_env as docker_from_env
from docker import DockerClient
from docker.models.containers import Container
from docker.errors import ImageNotFound, NotFound
from crewai.tools import BaseTool
from docker.models.containers import Container
from pydantic import BaseModel, Field


@@ -43,7 +45,11 @@ class CodeInterpreterTool(BaseTool):
        Verify if the Docker image is available. Optionally use a user-provided Dockerfile.
        """

        client = docker_from_env() if self.user_docker_base_url == None else docker.DockerClient(base_url=self.user_docker_base_url)
        client = (
            docker_from_env()
            if self.user_docker_base_url == None
            else DockerClient(base_url=self.user_docker_base_url)
        )

        try:
            client.images.get(self.default_image_tag)
@@ -76,9 +82,7 @@ class CodeInterpreterTool(BaseTool):
        else:
            return self.run_code_in_docker(code, libraries_used)

    def _install_libraries(
        self, container: Container, libraries: List[str]
    ) -> None:
    def _install_libraries(self, container: Container, libraries: List[str]) -> None:
        """
        Install missing libraries in the Docker container
        """
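For context on the `user_docker_base_url` branch above, a hedged sketch of how the two code paths would be exercised; it assumes the field is settable through the pydantic constructor, which this diff does not show:

```python
from crewai_tools import CodeInterpreterTool

# Default path: docker_from_env() talks to the local Docker daemon.
local_tool = CodeInterpreterTool()

# Hypothetical remote daemon: with user_docker_base_url set, the tool
# constructs DockerClient(base_url=...) instead of reading the environment.
remote_tool = CodeInterpreterTool(user_docker_base_url="tcp://127.0.0.1:2375")
```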
@@ -8,8 +8,6 @@ from pydantic import BaseModel, Field
class FixedDirectoryReadToolSchema(BaseModel):
    """Input for DirectoryReadTool."""

    pass


class DirectoryReadToolSchema(FixedDirectoryReadToolSchema):
    """Input for DirectoryReadTool."""
@@ -32,6 +32,7 @@ class FileReadTool(BaseTool):
        >>> content = tool.run()  # Reads /path/to/file.txt
        >>> content = tool.run(file_path="/path/to/other.txt")  # Reads other.txt
    """

    name: str = "Read a file's content"
    description: str = "A tool that reads the content of a file. To use this tool, provide a 'file_path' parameter with the path to the file you want to read."
    args_schema: Type[BaseModel] = FileReadToolSchema
@@ -45,10 +46,11 @@ class FileReadTool(BaseTool):
                this becomes the default file path for the tool.
            **kwargs: Additional keyword arguments passed to BaseTool.
        """
        super().__init__(**kwargs)
        if file_path is not None:
            self.file_path = file_path
            self.description = f"A tool that reads file content. The default file is {file_path}, but you can provide a different 'file_path' parameter to read another file."
            kwargs['description'] = f"A tool that reads file content. The default file is {file_path}, but you can provide a different 'file_path' parameter to read another file."

        super().__init__(**kwargs)
        self.file_path = file_path

    def _run(
        self,
@@ -67,15 +69,3 @@ class FileReadTool(BaseTool):
            return f"Error: Permission denied when trying to read file: {file_path}"
        except Exception as e:
            return f"Error: Failed to read file {file_path}. {str(e)}"

    def _generate_description(self) -> None:
        """Generate the tool description based on file path.

        This method updates the tool's description to include information about
        the default file path while maintaining the ability to specify a different
        file at runtime.

        Returns:
            None
        """
        self.description = f"A tool that can be used to read {self.file_path}'s content."
@@ -15,9 +15,7 @@ class FileWriterToolInput(BaseModel):

class FileWriterTool(BaseTool):
    name: str = "File Writer Tool"
    description: str = (
        "A tool to write content to a specified file. Accepts filename, content, and optionally a directory path and overwrite flag as input."
    )
    description: str = "A tool to write content to a specified file. Accepts filename, content, and optionally a directory path and overwrite flag as input."
    args_schema: Type[BaseModel] = FileWriterToolInput

    def _run(self, **kwargs: Any) -> str:
@@ -1,9 +1,8 @@
import os
from typing import TYPE_CHECKING, Any, Dict, Optional, Type

from pydantic import BaseModel, ConfigDict, Field

from crewai.tools import BaseTool
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr

# Type checking import
if TYPE_CHECKING:
@@ -12,6 +11,14 @@ if TYPE_CHECKING:

class FirecrawlCrawlWebsiteToolSchema(BaseModel):
    url: str = Field(description="Website URL")
    crawler_options: Optional[Dict[str, Any]] = Field(
        default=None, description="Options for crawling"
    )
    timeout: Optional[int] = Field(
        default=30000,
        description="Timeout in milliseconds for the crawling operation. The default value is 30000.",
    )


class FirecrawlCrawlWebsiteTool(BaseTool):
    model_config = ConfigDict(
@@ -20,25 +27,10 @@ class FirecrawlCrawlWebsiteTool(BaseTool):
    name: str = "Firecrawl web crawl tool"
    description: str = "Crawl webpages using Firecrawl and return the contents"
    args_schema: Type[BaseModel] = FirecrawlCrawlWebsiteToolSchema
    firecrawl_app: Optional["FirecrawlApp"] = None
    api_key: Optional[str] = None
    url: Optional[str] = None
    params: Optional[Dict[str, Any]] = None
    poll_interval: Optional[int] = 2
    idempotency_key: Optional[str] = None
    _firecrawl: Optional["FirecrawlApp"] = PrivateAttr(None)

    def __init__(self, api_key: Optional[str] = None, **kwargs):
        """Initialize FirecrawlCrawlWebsiteTool.

        Args:
            api_key (Optional[str]): Firecrawl API key. If not provided, will check FIRECRAWL_API_KEY env var.
            url (Optional[str]): Base URL to crawl. Can be overridden by the _run method.
            firecrawl_app (Optional[FirecrawlApp]): Previously created FirecrawlApp instance.
            params (Optional[Dict[str, Any]]): Additional parameters to pass to the FirecrawlApp.
            poll_interval (Optional[int]): Poll interval for the FirecrawlApp.
            idempotency_key (Optional[str]): Idempotency key for the FirecrawlApp.
            **kwargs: Additional arguments passed to BaseTool.
        """
        super().__init__(**kwargs)
        try:
            from firecrawl import FirecrawlApp  # type: ignore
@@ -47,28 +39,28 @@ class FirecrawlCrawlWebsiteTool(BaseTool):
                "`firecrawl` package not found, please run `pip install firecrawl-py`"
            )

        # Allows passing a previously created FirecrawlApp instance
        # or builds a new one with the provided API key
        if not self.firecrawl_app:
            client_api_key = api_key or os.getenv("FIRECRAWL_API_KEY")
            if not client_api_key:
                raise ValueError(
                    "FIRECRAWL_API_KEY is not set. Please provide it either via the constructor "
                    "with the `api_key` argument or by setting the FIRECRAWL_API_KEY environment variable."
                )
            self.firecrawl_app = FirecrawlApp(api_key=client_api_key)
        client_api_key = api_key or os.getenv("FIRECRAWL_API_KEY")
        if not client_api_key:
            raise ValueError(
                "FIRECRAWL_API_KEY is not set. Please provide it either via the constructor "
                "with the `api_key` argument or by setting the FIRECRAWL_API_KEY environment variable."
            )
        self._firecrawl = FirecrawlApp(api_key=client_api_key)

    def _run(self, url: str):
        # Unless url has been previously set via constructor by the user,
        # use the url argument provided by the agent at runtime.
        base_url = self.url or url
    def _run(
        self,
        url: str,
        crawler_options: Optional[Dict[str, Any]] = None,
        timeout: Optional[int] = 30000,
    ):
        if crawler_options is None:
            crawler_options = {}

        return self.firecrawl_app.crawl_url(
            base_url,
            params=self.params,
            poll_interval=self.poll_interval,
            idempotency_key=self.idempotency_key
        )
        options = {
            "crawlerOptions": crawler_options,
            "timeout": timeout,
        }
        return self._firecrawl.crawl_url(url, options)


try:
@@ -80,4 +72,3 @@ except ImportError:
    """
    When this tool is not used, then exception can be ignored.
    """
    pass
@@ -1,7 +1,7 @@
from typing import TYPE_CHECKING, Any, Dict, Optional, Type
from typing import TYPE_CHECKING, Optional, Type

from crewai.tools import BaseTool
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr

# Type checking import
if TYPE_CHECKING:
@@ -10,14 +10,8 @@ if TYPE_CHECKING:

class FirecrawlScrapeWebsiteToolSchema(BaseModel):
    url: str = Field(description="Website URL")
    page_options: Optional[Dict[str, Any]] = Field(
        default=None, description="Options for page scraping"
    )
    extractor_options: Optional[Dict[str, Any]] = Field(
        default=None, description="Options for data extraction"
    )
    timeout: Optional[int] = Field(
        default=None,
        default=30000,
        description="Timeout in milliseconds for the scraping operation. The default value is 30000.",
    )

@@ -27,10 +21,10 @@ class FirecrawlScrapeWebsiteTool(BaseTool):
        arbitrary_types_allowed=True, validate_assignment=True, frozen=False
    )
    name: str = "Firecrawl web scrape tool"
    description: str = "Scrape webpages url using Firecrawl and return the contents"
    description: str = "Scrape webpages using Firecrawl and return the contents"
    args_schema: Type[BaseModel] = FirecrawlScrapeWebsiteToolSchema
    api_key: Optional[str] = None
    firecrawl: Optional["FirecrawlApp"] = None  # Updated to use TYPE_CHECKING
    _firecrawl: Optional["FirecrawlApp"] = PrivateAttr(None)

    def __init__(self, api_key: Optional[str] = None, **kwargs):
        super().__init__(**kwargs)
@@ -41,28 +35,23 @@ class FirecrawlScrapeWebsiteTool(BaseTool):
            "`firecrawl` package not found, please run `pip install firecrawl-py`"
        )

        self.firecrawl = FirecrawlApp(api_key=api_key)
        self._firecrawl = FirecrawlApp(api_key=api_key)

    def _run(
        self,
        url: str,
        page_options: Optional[Dict[str, Any]] = None,
        extractor_options: Optional[Dict[str, Any]] = None,
        timeout: Optional[int] = None,
        timeout: Optional[int] = 30000,
    ):
        if page_options is None:
            page_options = {}
        if extractor_options is None:
            extractor_options = {}
        if timeout is None:
            timeout = 30000

        options = {
            "pageOptions": page_options,
            "extractorOptions": extractor_options,
            "formats": ["markdown"],
            "onlyMainContent": True,
            "includeTags": [],
            "excludeTags": [],
            "headers": {},
            "waitFor": 0,
            "timeout": timeout,
        }
        return self.firecrawl.scrape_url(url, options)
        return self._firecrawl.scrape_url(url, options)


try:
@@ -74,4 +63,3 @@ except ImportError:
    """
    When this tool is not used, then exception can be ignored.
    """
    pass
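A usage sketch for the reworked scrape tool, assuming `firecrawl-py` is installed and `FIRECRAWL_API_KEY` is set; the target URL is a placeholder and the argument names mirror `FirecrawlScrapeWebsiteToolSchema` above:

```python
from crewai_tools import FirecrawlScrapeWebsiteTool

tool = FirecrawlScrapeWebsiteTool()  # api_key falls back to FIRECRAWL_API_KEY

# timeout is in milliseconds and defaults to 30000 per the schema.
result = tool.run(url="https://example.com", timeout=30000)
print(result)
```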
@@ -1,7 +1,7 @@
from typing import TYPE_CHECKING, Any, Dict, Optional, Type

from crewai.tools import BaseTool
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr

# Type checking import
if TYPE_CHECKING:
@@ -10,20 +10,34 @@ if TYPE_CHECKING:

class FirecrawlSearchToolSchema(BaseModel):
    query: str = Field(description="Search query")
    page_options: Optional[Dict[str, Any]] = Field(
        default=None, description="Options for result formatting"
    limit: Optional[int] = Field(
        default=5, description="Maximum number of results to return"
    )
    search_options: Optional[Dict[str, Any]] = Field(
        default=None, description="Options for searching"
    tbs: Optional[str] = Field(default=None, description="Time-based search parameter")
    lang: Optional[str] = Field(
        default="en", description="Language code for search results"
    )
    country: Optional[str] = Field(
        default="us", description="Country code for search results"
    )
    location: Optional[str] = Field(
        default=None, description="Location parameter for search results"
    )
    timeout: Optional[int] = Field(default=60000, description="Timeout in milliseconds")
    scrape_options: Optional[Dict[str, Any]] = Field(
        default=None, description="Options for scraping search results"
    )


class FirecrawlSearchTool(BaseTool):
    model_config = ConfigDict(
        arbitrary_types_allowed=True, validate_assignment=True, frozen=False
    )
    name: str = "Firecrawl web search tool"
    description: str = "Search webpages using Firecrawl and return the results"
    args_schema: Type[BaseModel] = FirecrawlSearchToolSchema
    api_key: Optional[str] = None
    firecrawl: Optional["FirecrawlApp"] = None
    _firecrawl: Optional["FirecrawlApp"] = PrivateAttr(None)

    def __init__(self, api_key: Optional[str] = None, **kwargs):
        super().__init__(**kwargs)
@@ -33,19 +47,39 @@ class FirecrawlSearchTool(BaseTool):
        raise ImportError(
            "`firecrawl` package not found, please run `pip install firecrawl-py`"
        )

        self.firecrawl = FirecrawlApp(api_key=api_key)
        self._firecrawl = FirecrawlApp(api_key=api_key)

    def _run(
        self,
        query: str,
        page_options: Optional[Dict[str, Any]] = None,
        result_options: Optional[Dict[str, Any]] = None,
        limit: Optional[int] = 5,
        tbs: Optional[str] = None,
        lang: Optional[str] = "en",
        country: Optional[str] = "us",
        location: Optional[str] = None,
        timeout: Optional[int] = 60000,
        scrape_options: Optional[Dict[str, Any]] = None,
    ):
        if page_options is None:
            page_options = {}
        if result_options is None:
            result_options = {}
        if scrape_options is None:
            scrape_options = {}

        options = {"pageOptions": page_options, "resultOptions": result_options}
        return self.firecrawl.search(query, **options)
        options = {
            "limit": limit,
            "tbs": tbs,
            "lang": lang,
            "country": country,
            "location": location,
            "timeout": timeout,
            "scrapeOptions": scrape_options,
        }
        return self._firecrawl.search(query, options)


try:
    from firecrawl import FirecrawlApp

    # Rebuild the model after class is defined
    FirecrawlSearchTool.model_rebuild()
except ImportError:
    # Exception can be ignored if the tool is not used
    pass
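Similarly, a hedged sketch of the search tool after this change; parameter names come from `FirecrawlSearchToolSchema` above, and the query is a placeholder:

```python
from crewai_tools import FirecrawlSearchTool

tool = FirecrawlSearchTool()  # api_key falls back to FIRECRAWL_API_KEY

results = tool.run(
    query="open source agent frameworks",
    limit=5,
    lang="en",
    country="us",
)
print(results)
```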
@@ -27,9 +27,7 @@ class GithubSearchToolSchema(FixedGithubSearchToolSchema):

class GithubSearchTool(RagTool):
    name: str = "Search a github repo's content"
    description: str = (
        "A tool that can be used to semantic search a query from a github repo's content. This is not the GitHub API, but instead a tool that can provide semantic search capabilities."
    )
    description: str = "A tool that can be used to semantic search a query from a github repo's content. This is not the GitHub API, but instead a tool that can provide semantic search capabilities."
    summarize: bool = False
    gh_token: str
    args_schema: Type[BaseModel] = GithubSearchToolSchema
src/crewai_tools/tools/hyperbrowser_load_tool/README.md (new file, 42 lines)
@@ -0,0 +1,42 @@
# HyperbrowserLoadTool

## Description

[Hyperbrowser](https://hyperbrowser.ai) is a platform for running and scaling headless browsers. It lets you launch and manage browser sessions at scale and provides easy-to-use solutions for any web scraping need, such as scraping a single page or crawling an entire site.

Key Features:
- Instant Scalability - Spin up hundreds of browser sessions in seconds without infrastructure headaches
- Simple Integration - Works seamlessly with popular tools like Puppeteer and Playwright
- Powerful APIs - Easy to use APIs for scraping/crawling any site, and much more
- Bypass Anti-Bot Measures - Built-in stealth mode, ad blocking, automatic CAPTCHA solving, and rotating proxies

For more information about Hyperbrowser, please visit the [Hyperbrowser website](https://hyperbrowser.ai) or, if you want to check out the docs, the [Hyperbrowser docs](https://docs.hyperbrowser.ai).

## Installation

- Head to [Hyperbrowser](https://app.hyperbrowser.ai/) to sign up and generate an API key. Once you've done this, set the `HYPERBROWSER_API_KEY` environment variable, or pass the key to the `HyperbrowserLoadTool` constructor.
- Install the [Hyperbrowser SDK](https://github.com/hyperbrowserai/python-sdk):

```
pip install hyperbrowser 'crewai[tools]'
```

## Example

Utilize the HyperbrowserLoadTool as follows to allow your agent to load websites:

```python
from crewai_tools import HyperbrowserLoadTool

tool = HyperbrowserLoadTool()
```

## Arguments

`__init__` arguments:
- `api_key`: Optional. Specifies the Hyperbrowser API key. Defaults to the `HYPERBROWSER_API_KEY` environment variable.

`run` arguments:
- `url`: The base URL to start scraping or crawling from.
- `operation`: Optional. Specifies the operation to perform on the website. Either 'scrape' or 'crawl'. Default is 'scrape'.
- `params`: Optional. Specifies the params for the operation. For more information on the supported params, visit https://docs.hyperbrowser.ai/reference/sdks/python/scrape#start-scrape-job-and-wait or https://docs.hyperbrowser.ai/reference/sdks/python/crawl#start-crawl-job-and-wait.
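Extending the README's example, a sketch of a full `run` call using the arguments listed above; the URL is a placeholder, and the `params` keys follow the Hyperbrowser docs linked above:

```python
from crewai_tools import HyperbrowserLoadTool

tool = HyperbrowserLoadTool()  # reads HYPERBROWSER_API_KEY from the environment

# Scrape a single page as markdown; formats may only contain 'markdown' or 'html'.
content = tool.run(
    url="https://example.com",
    operation="scrape",
    params={"scrape_options": {"formats": ["markdown"]}},
)
print(content)
```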
src/crewai_tools/tools/hyperbrowser_load_tool/hyperbrowser_load_tool.py (new file, 103 lines)
@@ -0,0 +1,103 @@
import os
from typing import Any, Optional, Type, Dict, Literal, Union

from crewai.tools import BaseTool
from pydantic import BaseModel, Field


class HyperbrowserLoadToolSchema(BaseModel):
    url: str = Field(description="Website URL")
    operation: Literal['scrape', 'crawl'] = Field(description="Operation to perform on the website. Either 'scrape' or 'crawl'")
    params: Optional[Dict] = Field(description="Optional params for scrape or crawl. For more information on the supported params, visit https://docs.hyperbrowser.ai/reference/sdks/python/scrape#start-scrape-job-and-wait or https://docs.hyperbrowser.ai/reference/sdks/python/crawl#start-crawl-job-and-wait")


class HyperbrowserLoadTool(BaseTool):
    """HyperbrowserLoadTool.

    Scrape or crawl web pages and load the contents with optional parameters for configuring content extraction.
    Requires the `hyperbrowser` package.
    Get your API Key from https://app.hyperbrowser.ai/

    Args:
        api_key: The Hyperbrowser API key, can be set as an environment variable `HYPERBROWSER_API_KEY` or passed directly
    """

    name: str = "Hyperbrowser web load tool"
    description: str = "Scrape or crawl a website using Hyperbrowser and return the contents in properly formatted markdown or html"
    args_schema: Type[BaseModel] = HyperbrowserLoadToolSchema
    api_key: Optional[str] = None
    hyperbrowser: Optional[Any] = None

    def __init__(self, api_key: Optional[str] = None, **kwargs):
        super().__init__(**kwargs)
        self.api_key = api_key or os.getenv('HYPERBROWSER_API_KEY')
        if not api_key:
            raise ValueError(
                "`api_key` is required, please set the `HYPERBROWSER_API_KEY` environment variable or pass it directly"
            )

        try:
            from hyperbrowser import Hyperbrowser
        except ImportError:
            raise ImportError("`hyperbrowser` package not found, please run `pip install hyperbrowser`")

        if not self.api_key:
            raise ValueError("HYPERBROWSER_API_KEY is not set. Please provide it either via the constructor with the `api_key` argument or by setting the HYPERBROWSER_API_KEY environment variable.")

        self.hyperbrowser = Hyperbrowser(api_key=self.api_key)

    def _prepare_params(self, params: Dict) -> Dict:
        """Prepare session and scrape options parameters."""
        try:
            from hyperbrowser.models.session import CreateSessionParams
            from hyperbrowser.models.scrape import ScrapeOptions
        except ImportError:
            raise ImportError(
                "`hyperbrowser` package not found, please run `pip install hyperbrowser`"
            )

        if "scrape_options" in params:
            if "formats" in params["scrape_options"]:
                formats = params["scrape_options"]["formats"]
                if not all(fmt in ["markdown", "html"] for fmt in formats):
                    raise ValueError("formats can only contain 'markdown' or 'html'")

        if "session_options" in params:
            params["session_options"] = CreateSessionParams(**params["session_options"])
        if "scrape_options" in params:
            params["scrape_options"] = ScrapeOptions(**params["scrape_options"])
        return params

    def _extract_content(self, data: Union[Any, None]):
        """Extract content from response data."""
        content = ""
        if data:
            content = data.markdown or data.html or ""
        return content

    def _run(self, url: str, operation: Literal['scrape', 'crawl'] = 'scrape', params: Optional[Dict] = {}):
        try:
            from hyperbrowser.models.scrape import StartScrapeJobParams
            from hyperbrowser.models.crawl import StartCrawlJobParams
        except ImportError:
            raise ImportError(
                "`hyperbrowser` package not found, please run `pip install hyperbrowser`"
            )

        params = self._prepare_params(params)

        if operation == 'scrape':
            scrape_params = StartScrapeJobParams(url=url, **params)
            scrape_resp = self.hyperbrowser.scrape.start_and_wait(scrape_params)
            content = self._extract_content(scrape_resp.data)
            return content
        else:
            crawl_params = StartCrawlJobParams(url=url, **params)
            crawl_resp = self.hyperbrowser.crawl.start_and_wait(crawl_params)
            content = ""
            if crawl_resp.data:
                for page in crawl_resp.data:
                    page_content = self._extract_content(page)
                    if page_content:
                        content += (
                            f"\n{'-'*50}\nUrl: {page.url}\nContent:\n{page_content}\n"
                        )
            return content
@@ -13,9 +13,7 @@ class JinaScrapeWebsiteToolInput(BaseModel):

class JinaScrapeWebsiteTool(BaseTool):
    name: str = "JinaScrapeWebsiteTool"
    description: str = (
        "A tool that can be used to read a website content using Jina.ai reader and return markdown content."
    )
    description: str = "A tool that can be used to read a website content using Jina.ai reader and return markdown content."
    args_schema: Type[BaseModel] = JinaScrapeWebsiteToolInput
    website_url: Optional[str] = None
    api_key: Optional[str] = None
@@ -2,6 +2,7 @@ from typing import Any

try:
    from linkup import LinkupClient

    LINKUP_AVAILABLE = True
except ImportError:
    LINKUP_AVAILABLE = False
@@ -9,10 +10,13 @@ except ImportError:

from pydantic import PrivateAttr


class LinkupSearchTool:
    name: str = "Linkup Search Tool"
    description: str = "Performs an API call to Linkup to retrieve contextual information."
    _client: LinkupClient = PrivateAttr()  # type: ignore
    description: str = (
        "Performs an API call to Linkup to retrieve contextual information."
    )
    _client: LinkupClient = PrivateAttr()  # type: ignore

    def __init__(self, api_key: str):
        """
@@ -25,7 +29,9 @@ class LinkupSearchTool:
        )
        self._client = LinkupClient(api_key=api_key)

    def _run(self, query: str, depth: str = "standard", output_type: str = "searchResults") -> dict:
    def _run(
        self, query: str, depth: str = "standard", output_type: str = "searchResults"
    ) -> dict:
        """
        Executes a search using the Linkup API.

@@ -36,9 +42,7 @@ class LinkupSearchTool:
        """
        try:
            response = self._client.search(
                query=query,
                depth=depth,
                output_type=output_type
                query=query, depth=depth, output_type=output_type
            )
            results = [
                {"name": result.name, "url": result.url, "content": result.content}
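A small usage sketch based on the signatures visible in this hunk; note that the class as shown is a plain Python class, so the call goes through `_run` directly, and the API key is a placeholder:

```python
from crewai_tools import LinkupSearchTool

tool = LinkupSearchTool(api_key="lk-...")  # placeholder key

# depth and output_type defaults come from the _run signature above.
response = tool._run(query="contextual web search APIs", depth="standard")
print(response)
```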
@@ -17,9 +17,7 @@ class MySQLSearchToolSchema(BaseModel):

class MySQLSearchTool(RagTool):
    name: str = "Search a database's table content"
    description: str = (
        "A tool that can be used to semantic search a query from a database table's content."
    )
    description: str = "A tool that can be used to semantic search a query from a database table's content."
    args_schema: Type[BaseModel] = MySQLSearchToolSchema
    db_uri: str = Field(..., description="Mandatory database URI")
src/crewai_tools/tools/patronus_eval_tool/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
from .patronus_eval_tool import PatronusEvalTool
from .patronus_local_evaluator_tool import PatronusLocalEvaluatorTool
from .patronus_predefined_criteria_eval_tool import PatronusPredefinedCriteriaEvalTool
src/crewai_tools/tools/patronus_eval_tool/example.py (new file, 55 lines)
@@ -0,0 +1,55 @@
import random

from crewai import Agent, Crew, Task
from patronus import Client, EvaluationResult
from patronus_local_evaluator_tool import PatronusLocalEvaluatorTool

# Test the PatronusLocalEvaluatorTool where agent uses the local evaluator
client = Client()


# Example of an evaluator that returns a random pass/fail result
@client.register_local_evaluator("random_evaluator")
def random_evaluator(**kwargs):
    score = random.random()
    return EvaluationResult(
        score_raw=score,
        pass_=score >= 0.5,
        explanation="example explanation",  # Optional justification for LLM judges
    )


# 1. Uses PatronusEvalTool: agent can pick the best evaluator and criteria
# patronus_eval_tool = PatronusEvalTool()

# 2. Uses PatronusPredefinedCriteriaEvalTool: agent uses the defined evaluator and criteria
# patronus_eval_tool = PatronusPredefinedCriteriaEvalTool(
#     evaluators=[{"evaluator": "judge", "criteria": "contains-code"}]
# )

# 3. Uses PatronusLocalEvaluatorTool: agent uses user defined evaluator
patronus_eval_tool = PatronusLocalEvaluatorTool(
    patronus_client=client,
    evaluator="random_evaluator",
    evaluated_model_gold_answer="example label",
)

# Create a new agent
coding_agent = Agent(
    role="Coding Agent",
    goal="Generate high quality code and verify that the output is code by using Patronus AI's evaluation tool.",
    backstory="You are an experienced coder who can generate high quality python code. You can follow complex instructions accurately and effectively.",
    tools=[patronus_eval_tool],
    verbose=True,
)

# Define tasks
generate_code = Task(
    description="Create a simple program to generate the first N numbers in the Fibonacci sequence. Select the most appropriate evaluator and criteria for evaluating your output.",
    expected_output="Program that generates the first N numbers in the Fibonacci sequence.",
    agent=coding_agent,
)

crew = Crew(agents=[coding_agent], tasks=[generate_code])

crew.kickoff()
src/crewai_tools/tools/patronus_eval_tool/patronus_eval_tool.py (new file, 141 lines)
@@ -0,0 +1,141 @@
import json
import os
import warnings
from typing import Any, Dict, List, Optional

import requests
from crewai.tools import BaseTool


class PatronusEvalTool(BaseTool):
    name: str = "Patronus Evaluation Tool"
    evaluate_url: str = "https://api.patronus.ai/v1/evaluate"
    evaluators: List[Dict[str, str]] = []
    criteria: List[Dict[str, str]] = []
    description: str = ""

    def __init__(self, **kwargs: Any):
        super().__init__(**kwargs)
        temp_evaluators, temp_criteria = self._init_run()
        self.evaluators = temp_evaluators
        self.criteria = temp_criteria
        self.description = self._generate_description()
        warnings.warn(
            "You are allowing the agent to select the best evaluator and criteria when you use the `PatronusEvalTool`. If this is not intended then please use `PatronusPredefinedCriteriaEvalTool` instead."
        )

    def _init_run(self):
        evaluators_set = json.loads(
            requests.get(
                "https://api.patronus.ai/v1/evaluators",
                headers={
                    "accept": "application/json",
                    "X-API-KEY": os.environ["PATRONUS_API_KEY"],
                },
            ).text
        )["evaluators"]
        ids, evaluators = set(), []
        for ev in evaluators_set:
            if not ev["deprecated"] and ev["id"] not in ids:
                evaluators.append(
                    {
                        "id": ev["id"],
                        "name": ev["name"],
                        "description": ev["description"],
                        "aliases": ev["aliases"],
                    }
                )
                ids.add(ev["id"])

        criteria_set = json.loads(
            requests.get(
                "https://api.patronus.ai/v1/evaluator-criteria",
                headers={
                    "accept": "application/json",
                    "X-API-KEY": os.environ["PATRONUS_API_KEY"],
                },
            ).text
        )["evaluator_criteria"]
        criteria = []
        for cr in criteria_set:
            if cr["config"].get("pass_criteria", None):
                if cr["config"].get("rubric", None):
                    criteria.append(
                        {
                            "evaluator": cr["evaluator_family"],
                            "name": cr["name"],
                            "pass_criteria": cr["config"]["pass_criteria"],
                            "rubric": cr["config"]["rubric"],
                        }
                    )
                else:
                    criteria.append(
                        {
                            "evaluator": cr["evaluator_family"],
                            "name": cr["name"],
                            "pass_criteria": cr["config"]["pass_criteria"],
                        }
                    )
            elif cr["description"]:
                criteria.append(
                    {
                        "evaluator": cr["evaluator_family"],
                        "name": cr["name"],
                        "description": cr["description"],
                    }
                )

        return evaluators, criteria

    def _generate_description(self) -> str:
        criteria = "\n".join([json.dumps(i) for i in self.criteria])
        return f"""This tool calls the Patronus Evaluation API that takes the following arguments:
1. evaluated_model_input: str: The agent's task description in simple text
2. evaluated_model_output: str: The agent's output of the task
3. evaluated_model_retrieved_context: str: The agent's context
4. evaluators: This is a list of dictionaries containing one of the following evaluators and the corresponding criteria. An example input for this field: [{{"evaluator": "Judge", "criteria": "patronus:is-code"}}]

Evaluators:
{criteria}

You must ONLY choose the most appropriate evaluator and criteria based on the "pass_criteria" or "description" fields for your evaluation task and nothing from outside of the options present."""

    def _run(
        self,
        evaluated_model_input: Optional[str],
        evaluated_model_output: Optional[str],
        evaluated_model_retrieved_context: Optional[str],
        evaluators: List[Dict[str, str]],
    ) -> Any:
        # Assert correct format of evaluators
        evals = []
        for ev in evaluators:
            evals.append(
                {
                    "evaluator": ev["evaluator"].lower(),
                    "criteria": ev["name"] if "name" in ev else ev["criteria"],
                }
            )

        data = {
            "evaluated_model_input": evaluated_model_input,
            "evaluated_model_output": evaluated_model_output,
            "evaluated_model_retrieved_context": evaluated_model_retrieved_context,
            "evaluators": evals,
        }

        headers = {
            "X-API-KEY": os.getenv("PATRONUS_API_KEY"),
            "accept": "application/json",
            "content-type": "application/json",
        }

        response = requests.post(
            self.evaluate_url, headers=headers, data=json.dumps(data)
        )
        if response.status_code != 200:
            raise Exception(
                f"Failed to evaluate model input and output. Response status code: {response.status_code}. Reason: {response.text}"
            )

        return response.json()
src/crewai_tools/tools/patronus_eval_tool/patronus_local_evaluator_tool.py (new file, 90 lines)
@@ -0,0 +1,90 @@
from typing import Any, Type

from crewai.tools import BaseTool
from patronus import Client
from pydantic import BaseModel, Field


class FixedLocalEvaluatorToolSchema(BaseModel):
    evaluated_model_input: str = Field(
        ..., description="The agent's task description in simple text"
    )
    evaluated_model_output: str = Field(
        ..., description="The agent's output of the task"
    )
    evaluated_model_retrieved_context: str = Field(
        ..., description="The agent's context"
    )
    evaluated_model_gold_answer: str = Field(
        ..., description="The agent's gold answer only if available"
    )
    evaluator: str = Field(..., description="The registered local evaluator")


class PatronusLocalEvaluatorTool(BaseTool):
    name: str = "Patronus Local Evaluator Tool"
    evaluator: str = "The registered local evaluator"
    evaluated_model_gold_answer: str = "The agent's gold answer"
    description: str = "This tool is used to evaluate the model input and output using custom function evaluators."
    client: Any = None
    args_schema: Type[BaseModel] = FixedLocalEvaluatorToolSchema

    class Config:
        arbitrary_types_allowed = True

    def __init__(
        self,
        patronus_client: Client,
        evaluator: str,
        evaluated_model_gold_answer: str,
        **kwargs: Any,
    ):
        super().__init__(**kwargs)
        self.client = patronus_client
        if evaluator:
            self.evaluator = evaluator
            self.evaluated_model_gold_answer = evaluated_model_gold_answer
            self.description = f"This tool calls the Patronus Evaluation API that takes an additional argument in addition to the following new argument:\n evaluators={evaluator}, evaluated_model_gold_answer={evaluated_model_gold_answer}"
            self._generate_description()
            print(
                f"Updating judge evaluator, gold_answer to: {self.evaluator}, {self.evaluated_model_gold_answer}"
            )

    def _run(
        self,
        **kwargs: Any,
    ) -> Any:
        evaluated_model_input = kwargs.get("evaluated_model_input")
        evaluated_model_output = kwargs.get("evaluated_model_output")
        evaluated_model_retrieved_context = kwargs.get(
            "evaluated_model_retrieved_context"
        )
        evaluated_model_gold_answer = self.evaluated_model_gold_answer
        evaluator = self.evaluator

        result = self.client.evaluate(
            evaluator=evaluator,
            evaluated_model_input=(
                evaluated_model_input
                if isinstance(evaluated_model_input, str)
                else evaluated_model_input.get("description")
            ),
            evaluated_model_output=(
                evaluated_model_output
                if isinstance(evaluated_model_output, str)
                else evaluated_model_output.get("description")
            ),
            evaluated_model_retrieved_context=(
                evaluated_model_retrieved_context
                if isinstance(evaluated_model_retrieved_context, str)
                else evaluated_model_retrieved_context.get("description")
            ),
            evaluated_model_gold_answer=(
                evaluated_model_gold_answer
                if isinstance(evaluated_model_gold_answer, str)
                else evaluated_model_gold_answer.get("description")
            ),
            tags={},  # Optional metadata, supports arbitrary kv pairs
        )
        output = f"Evaluation result: {result.pass_}, Explanation: {result.explanation}"
        return output
src/crewai_tools/tools/patronus_eval_tool/patronus_predefined_criteria_eval_tool.py (new file, 104 lines)
@@ -0,0 +1,104 @@
import json
import os
from typing import Any, Dict, List, Type

import requests
from crewai.tools import BaseTool
from pydantic import BaseModel, Field


class FixedBaseToolSchema(BaseModel):
    evaluated_model_input: Dict = Field(
        ..., description="The agent's task description in simple text"
    )
    evaluated_model_output: Dict = Field(
        ..., description="The agent's output of the task"
    )
    evaluated_model_retrieved_context: Dict = Field(
        ..., description="The agent's context"
    )
    evaluated_model_gold_answer: Dict = Field(
        ..., description="The agent's gold answer only if available"
    )
    evaluators: List[Dict[str, str]] = Field(
        ...,
        description="List of dictionaries containing the evaluator and criteria to evaluate the model input and output. An example input for this field: [{'evaluator': '[evaluator-from-user]', 'criteria': '[criteria-from-user]'}]",
    )


class PatronusPredefinedCriteriaEvalTool(BaseTool):
    """
    PatronusEvalTool is a tool to automatically evaluate and score agent interactions.

    Results are logged to the Patronus platform at app.patronus.ai
    """

    name: str = "Call Patronus API tool for evaluation of model inputs and outputs"
    description: str = """This tool calls the Patronus Evaluation API that takes the following arguments:"""
    evaluate_url: str = "https://api.patronus.ai/v1/evaluate"
    args_schema: Type[BaseModel] = FixedBaseToolSchema
    evaluators: List[Dict[str, str]] = []

    def __init__(self, evaluators: List[Dict[str, str]], **kwargs: Any):
        super().__init__(**kwargs)
        if evaluators:
            self.evaluators = evaluators
            self.description = f"This tool calls the Patronus Evaluation API that takes an additional argument in addition to the following new argument:\n evaluators={evaluators}"
            self._generate_description()
            print(f"Updating judge criteria to: {self.evaluators}")

    def _run(
        self,
        **kwargs: Any,
    ) -> Any:
        evaluated_model_input = kwargs.get("evaluated_model_input")
        evaluated_model_output = kwargs.get("evaluated_model_output")
        evaluated_model_retrieved_context = kwargs.get(
            "evaluated_model_retrieved_context"
        )
        evaluated_model_gold_answer = kwargs.get("evaluated_model_gold_answer")
        evaluators = self.evaluators

        headers = {
            "X-API-KEY": os.getenv("PATRONUS_API_KEY"),
            "accept": "application/json",
            "content-type": "application/json",
        }

        data = {
            "evaluated_model_input": (
                evaluated_model_input
                if isinstance(evaluated_model_input, str)
                else evaluated_model_input.get("description")
            ),
            "evaluated_model_output": (
                evaluated_model_output
                if isinstance(evaluated_model_output, str)
                else evaluated_model_output.get("description")
            ),
            "evaluated_model_retrieved_context": (
                evaluated_model_retrieved_context
                if isinstance(evaluated_model_retrieved_context, str)
                else evaluated_model_retrieved_context.get("description")
            ),
            "evaluated_model_gold_answer": (
                evaluated_model_gold_answer
                if isinstance(evaluated_model_gold_answer, str)
                else evaluated_model_gold_answer.get("description")
            ),
            "evaluators": (
                evaluators
                if isinstance(evaluators, list)
                else evaluators.get("description")
            ),
        }

        response = requests.post(
            self.evaluate_url, headers=headers, data=json.dumps(data)
        )
        if response.status_code != 200:
            raise Exception(
                f"Failed to evaluate model input and output. Status code: {response.status_code}. Reason: {response.text}"
            )

        return response.json()
@@ -1,7 +1,9 @@
from typing import Optional, Type
from pydantic import BaseModel, Field
from pypdf import PdfReader, PdfWriter, PageObject, ContentStream, NameObject, Font
from pathlib import Path
from typing import Optional, Type

from pydantic import BaseModel, Field
from pypdf import ContentStream, Font, NameObject, PageObject, PdfReader, PdfWriter

from crewai_tools.tools.rag.rag_tool import RagTool
@@ -17,9 +17,7 @@ class PGSearchToolSchema(BaseModel):

class PGSearchTool(RagTool):
    name: str = "Search a database's table content"
    description: str = (
        "A tool that can be used to semantic search a query from a database table's content."
    )
    description: str = "A tool that can be used to semantic search a query from a database table's content."
    args_schema: Type[BaseModel] = PGSearchToolSchema
    db_uri: str = Field(..., description="Mandatory database URI")
@@ -10,8 +10,6 @@ from pydantic import BaseModel, Field
class FixedScrapeElementFromWebsiteToolSchema(BaseModel):
    """Input for ScrapeElementFromWebsiteTool."""

    pass


class ScrapeElementFromWebsiteToolSchema(FixedScrapeElementFromWebsiteToolSchema):
    """Input for ScrapeElementFromWebsiteTool."""
@@ -11,8 +11,6 @@ from pydantic import BaseModel, Field
class FixedScrapeWebsiteToolSchema(BaseModel):
    """Input for ScrapeWebsiteTool."""

    pass


class ScrapeWebsiteToolSchema(FixedScrapeWebsiteToolSchema):
    """Input for ScrapeWebsiteTool."""
@@ -10,17 +10,14 @@ from scrapegraph_py.logger import sgai_logger

class ScrapegraphError(Exception):
    """Base exception for Scrapegraph-related errors"""
    pass


class RateLimitError(ScrapegraphError):
    """Raised when API rate limits are exceeded"""
    pass


class FixedScrapegraphScrapeToolSchema(BaseModel):
    """Input for ScrapegraphScrapeTool when website_url is fixed."""
    pass


class ScrapegraphScrapeToolSchema(FixedScrapegraphScrapeToolSchema):
@@ -32,7 +29,7 @@ class ScrapegraphScrapeToolSchema(FixedScrapegraphScrapeToolSchema):
        description="Prompt to guide the extraction of content",
    )

    @validator('website_url')
    @validator("website_url")
    def validate_url(cls, v):
        """Validate URL format"""
        try:
@@ -41,7 +38,9 @@ class ScrapegraphScrapeToolSchema(FixedScrapegraphScrapeToolSchema):
                raise ValueError
            return v
        except Exception:
            raise ValueError("Invalid URL format. URL must include scheme (http/https) and domain")
            raise ValueError(
                "Invalid URL format. URL must include scheme (http/https) and domain"
            )


class ScrapegraphScrapeTool(BaseTool):
@@ -55,7 +54,9 @@ class ScrapegraphScrapeTool(BaseTool):
    """

    name: str = "Scrapegraph website scraper"
    description: str = "A tool that uses Scrapegraph AI to intelligently scrape website content."
    description: str = (
        "A tool that uses Scrapegraph AI to intelligently scrape website content."
    )
    args_schema: Type[BaseModel] = ScrapegraphScrapeToolSchema
    website_url: Optional[str] = None
    user_prompt: Optional[str] = None
@@ -72,7 +73,6 @@ class ScrapegraphScrapeTool(BaseTool):
    ):
        super().__init__(**kwargs)
        self.api_key = api_key or os.getenv("SCRAPEGRAPH_API_KEY")
        self.enable_logging = enable_logging

        if not self.api_key:
            raise ValueError("Scrapegraph API key is required")
@@ -98,14 +98,19 @@ class ScrapegraphScrapeTool(BaseTool):
            if not all([result.scheme, result.netloc]):
                raise ValueError
        except Exception:
            raise ValueError("Invalid URL format. URL must include scheme (http/https) and domain")
            raise ValueError(
                "Invalid URL format. URL must include scheme (http/https) and domain"
            )

    def _run(
        self,
        **kwargs: Any,
    ) -> Any:
        website_url = kwargs.get("website_url", self.website_url)
        user_prompt = kwargs.get("user_prompt", self.user_prompt) or "Extract the main content of the webpage"
        user_prompt = (
            kwargs.get("user_prompt", self.user_prompt)
            or "Extract the main content of the webpage"
        )

        if not website_url:
            raise ValueError("website_url is required")
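A usage sketch for the tool above, assuming a valid Scrapegraph API key (the key, URL, and prompt are placeholders; the kwarg names come from the `_run` shown in the diff):

```python
from crewai_tools import ScrapegraphScrapeTool

# api_key also falls back to the SCRAPEGRAPH_API_KEY environment variable.
tool = ScrapegraphScrapeTool(api_key="sgai-...")

result = tool.run(
    website_url="https://example.com",
    user_prompt="Extract the main content of the webpage",
)
print(result)
```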
@@ -17,13 +17,16 @@ class FixedSeleniumScrapingToolSchema(BaseModel):
class SeleniumScrapingToolSchema(FixedSeleniumScrapingToolSchema):
    """Input for SeleniumScrapingTool."""

    website_url: str = Field(..., description="Mandatory website url to read the file. Must start with http:// or https://")
    website_url: str = Field(
        ...,
        description="Mandatory website url to read the file. Must start with http:// or https://",
    )
    css_element: str = Field(
        ...,
        description="Mandatory css reference for element to scrape from the website",
    )

    @validator('website_url')
    @validator("website_url")
    def validate_website_url(cls, v):
        if not v:
            raise ValueError("Website URL cannot be empty")
@@ -31,7 +34,7 @@ class SeleniumScrapingToolSchema(FixedSeleniumScrapingToolSchema):
        if len(v) > 2048:  # Common maximum URL length
            raise ValueError("URL is too long (max 2048 characters)")

        if not re.match(r'^https?://', v):
        if not re.match(r"^https?://", v):
            raise ValueError("URL must start with http:// or https://")

        try:
@@ -41,7 +44,7 @@ class SeleniumScrapingToolSchema(FixedSeleniumScrapingToolSchema):
        except Exception as e:
            raise ValueError(f"Invalid URL: {str(e)}")

        if re.search(r'\s', v):
        if re.search(r"\s", v):
            raise ValueError("URL cannot contain whitespace")

        return v
@@ -132,7 +135,7 @@ class SeleniumScrapingTool(BaseTool):
            raise ValueError("URL cannot be empty")

        # Validate URL format
        if not re.match(r'^https?://', url):
        if not re.match(r"^https?://", url):
            raise ValueError("URL must start with http:// or https://")

        options = Options()
@@ -1,9 +1,10 @@
import os
import re
from typing import Optional, Any, Union
from typing import Any, Optional, Union

from crewai.tools import BaseTool


class SerpApiBaseTool(BaseTool):
    """Base class for SerpApi functionality with shared capabilities."""
@@ -1,14 +1,22 @@
from typing import Any, Type, Optional
from typing import Any, Optional, Type

import re
from pydantic import BaseModel, Field
from .serpapi_base_tool import SerpApiBaseTool
from serpapi import HTTPError
from urllib.error import HTTPError

from .serpapi_base_tool import SerpApiBaseTool


class SerpApiGoogleSearchToolSchema(BaseModel):
    """Input for Google Search."""
    search_query: str = Field(..., description="Mandatory search query you want to use to Google search.")
    location: Optional[str] = Field(None, description="Location you want the search to be performed in.")

    search_query: str = Field(
        ..., description="Mandatory search query you want to use to Google search."
    )
    location: Optional[str] = Field(
        None, description="Location you want the search to be performed in."
    )


class SerpApiGoogleSearchTool(SerpApiBaseTool):
    name: str = "Google Search"
@@ -22,19 +30,25 @@ class SerpApiGoogleSearchTool(SerpApiBaseTool):
        **kwargs: Any,
    ) -> Any:
        try:
            results = self.client.search({
                "q": kwargs.get("search_query"),
                "location": kwargs.get("location"),
            }).as_dict()
            results = self.client.search(
                {
                    "q": kwargs.get("search_query"),
                    "location": kwargs.get("location"),
                }
            ).as_dict()

            self._omit_fields(
                results,
                [r"search_metadata", r"search_parameters", r"serpapi_.+", r".+_token", r"displayed_link", r"pagination"]
                [
                    r"search_metadata",
                    r"search_parameters",
                    r"serpapi_.+",
                    r".+_token",
                    r"displayed_link",
                    r"pagination",
                ],
            )

            return results
        except HTTPError as e:
            return f"An error occurred: {str(e)}. Some parameters may be invalid."
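A brief usage sketch matching the schema above; it assumes the SerpApi key is configured for `SerpApiBaseTool` (conventionally the SERPAPI_API_KEY environment variable, though this diff does not show how the base class reads it):

```python
from crewai_tools import SerpApiGoogleSearchTool

tool = SerpApiGoogleSearchTool()

# search_query is mandatory; location is optional per the schema above.
results = tool.run(search_query="crewai tools", location="Austin, Texas")
print(results)
```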
@@ -1,14 +1,21 @@
from typing import Any, Type, Optional
from typing import Any, Optional, Type

import re
from pydantic import BaseModel, Field
from .serpapi_base_tool import SerpApiBaseTool
from serpapi import HTTPError
from urllib.error import HTTPError

from .serpapi_base_tool import SerpApiBaseTool


class SerpApiGoogleShoppingToolSchema(BaseModel):
    """Input for Google Shopping."""
    search_query: str = Field(..., description="Mandatory search query you want to use to Google shopping.")
    location: Optional[str] = Field(None, description="Location you want the search to be performed in.")

    search_query: str = Field(
        ..., description="Mandatory search query you want to use to Google shopping."
    )
    location: Optional[str] = Field(
        None, description="Location you want the search to be performed in."
    )


class SerpApiGoogleShoppingTool(SerpApiBaseTool):
@@ -23,20 +30,25 @@ class SerpApiGoogleShoppingTool(SerpApiBaseTool):
        **kwargs: Any,
    ) -> Any:
        try:
            results = self.client.search({
                "engine": "google_shopping",
                "q": kwargs.get("search_query"),
                "location": kwargs.get("location")
            }).as_dict()
            results = self.client.search(
                {
                    "engine": "google_shopping",
                    "q": kwargs.get("search_query"),
                    "location": kwargs.get("location"),
                }
            ).as_dict()

            self._omit_fields(
                results,
                [r"search_metadata", r"search_parameters", r"serpapi_.+", r"filters", r"pagination"]
                [
                    r"search_metadata",
                    r"search_parameters",
                    r"serpapi_.+",
                    r"filters",
                    r"pagination",
                ],
            )

            return results
        except HTTPError as e:
            return f"An error occurred: {str(e)}. Some parameters may be invalid."
@@ -1,19 +1,19 @@
import datetime
import json
import os
import logging
import os
from typing import Any, Type

import requests
from crewai.tools import BaseTool
from pydantic import BaseModel, Field


logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)


def _save_results_to_file(content: str) -> None:
    """Saves the search results to a file."""
    try:
@@ -35,7 +35,7 @@ class SerperDevToolSchema(BaseModel):


class SerperDevTool(BaseTool):
    name: str = "Search the internet"
    name: str = "Search the internet with Serper"
    description: str = (
        "A tool that can be used to search the internet with a search_query. "
        "Supports different search types: 'search' (default), 'news'"
@@ -18,9 +18,7 @@ class SerplyWebpageToMarkdownToolSchema(BaseModel):

class SerplyWebpageToMarkdownTool(RagTool):
    name: str = "Webpage to Markdown"
    description: str = (
        "A tool to perform convert a webpage to markdown to make it easier for LLMs to understand"
    )
    description: str = "A tool to perform convert a webpage to markdown to make it easier for LLMs to understand"
    args_schema: Type[BaseModel] = SerplyWebpageToMarkdownToolSchema
    request_url: str = "https://api.serply.io/v1/request"
    proxy_location: Optional[str] = "US"
155
src/crewai_tools/tools/snowflake_search_tool/README.md
Normal file
@@ -0,0 +1,155 @@
# Snowflake Search Tool

A tool for executing queries on the Snowflake data warehouse with built-in connection pooling, retry logic, and async execution support.

## Installation

```bash
uv sync --extra snowflake

OR
uv pip install snowflake-connector-python>=3.5.0 snowflake-sqlalchemy>=1.5.0 cryptography>=41.0.0

OR
pip install snowflake-connector-python>=3.5.0 snowflake-sqlalchemy>=1.5.0 cryptography>=41.0.0
```

## Quick Start

```python
import asyncio
from crewai_tools import SnowflakeSearchTool, SnowflakeConfig

# Create configuration
config = SnowflakeConfig(
    account="your_account",
    user="your_username",
    password="your_password",
    warehouse="COMPUTE_WH",
    database="your_database",
    snowflake_schema="your_schema"  # Note: Uses snowflake_schema instead of schema
)

# Initialize tool
tool = SnowflakeSearchTool(
    config=config,
    pool_size=5,
    max_retries=3,
    enable_caching=True
)

# Execute query
async def main():
    results = await tool._run(
        query="SELECT * FROM your_table LIMIT 10",
        timeout=300
    )
    print(f"Retrieved {len(results)} rows")

if __name__ == "__main__":
    asyncio.run(main())
```

## Features

- ✨ Asynchronous query execution
- 🚀 Connection pooling for better performance
- 🔄 Automatic retries for transient failures
- 💾 Query result caching (optional)
- 🔒 Support for both password and key-pair authentication
- 📝 Comprehensive error handling and logging

## Configuration Options

### SnowflakeConfig Parameters

| Parameter | Required | Description |
|-----------|----------|-------------|
| account | Yes | Snowflake account identifier |
| user | Yes | Snowflake username |
| password | Yes* | Snowflake password |
| private_key_path | No* | Path to private key file (alternative to password) |
| warehouse | Yes | Snowflake warehouse name |
| database | Yes | Default database |
| snowflake_schema | Yes | Default schema |
| role | No | Snowflake role |
| session_parameters | No | Custom session parameters dict |

\* Either password or private_key_path must be provided
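Passing neither credential fails fast at construction time. A minimal sketch of that behavior (the account and user values are placeholders):

```python
from crewai_tools import SnowflakeConfig

try:
    # No password and no private_key_path, so validation fails.
    SnowflakeConfig(account="my_account", user="my_user")
except ValueError as e:
    print(e)  # "Either password or private_key_path must be provided"
```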
### Tool Parameters

| Parameter | Default | Description |
|-----------|---------|-------------|
| pool_size | 5 | Number of connections in the pool |
| max_retries | 3 | Maximum retry attempts for failed queries |
| retry_delay | 1.0 | Delay between retries in seconds |
| enable_caching | True | Enable/disable query result caching |
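As a rough sketch, a tool tuned for a heavier workload might look like this (the numbers are illustrative, not recommendations; `config` is a `SnowflakeConfig` as in the Quick Start):

```python
tool = SnowflakeSearchTool(
    config=config,          # SnowflakeConfig from the Quick Start
    pool_size=10,           # allow more concurrent connections
    max_retries=5,          # retry transient failures more aggressively
    retry_delay=2.0,        # base delay; the backoff doubles it on each attempt
    enable_caching=False,   # skip caching for frequently changing data
)
```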
## Advanced Usage

### Using Key-Pair Authentication

```python
config = SnowflakeConfig(
    account="your_account",
    user="your_username",
    private_key_path="/path/to/private_key.p8",
    warehouse="your_warehouse",
    database="your_database",
    snowflake_schema="your_schema"
)
```

### Custom Session Parameters

```python
config = SnowflakeConfig(
    # ... other config parameters ...
    session_parameters={
        "QUERY_TAG": "my_app",
        "TIMEZONE": "America/Los_Angeles"
    }
)
```

## Best Practices

1. **Error Handling**: Always wrap query execution in try-except blocks
2. **Logging**: Enable logging to track query execution and errors
3. **Connection Management**: Use appropriate pool sizes for your workload
4. **Timeouts**: Set reasonable query timeouts to prevent hanging
5. **Security**: Use key-pair auth in production and never hardcode credentials

## Example with Logging

```python
import logging

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

async def main():
    try:
        # ... tool initialization ...
        results = await tool._run(query="SELECT * FROM table LIMIT 10")
        logger.info(f"Query completed successfully. Retrieved {len(results)} rows")
    except Exception as e:
        logger.error(f"Query failed: {str(e)}")
        raise
```

## Error Handling

The tool automatically handles common Snowflake errors:
- DatabaseError
- OperationalError
- ProgrammingError
- Network timeouts
- Connection issues

Errors are logged and retried based on your retry configuration.
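If a query still fails after the last retry, the underlying exception propagates to the caller, so application code can catch the Snowflake error types directly. A minimal sketch (the table name and helper name are illustrative):

```python
from snowflake.connector.errors import DatabaseError, OperationalError

async def safe_query(tool):
    try:
        return await tool._run(query="SELECT * FROM menu LIMIT 5", timeout=60)
    except (DatabaseError, OperationalError) as e:
        # Raised only after max_retries attempts have been exhausted.
        logger.error(f"Query failed after retries: {e}")
        return []
```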
11
src/crewai_tools/tools/snowflake_search_tool/__init__.py
Normal file
@@ -0,0 +1,11 @@
from .snowflake_search_tool import (
    SnowflakeConfig,
    SnowflakeSearchTool,
    SnowflakeSearchToolInput,
)

__all__ = [
    "SnowflakeSearchTool",
    "SnowflakeSearchToolInput",
    "SnowflakeConfig",
]
@@ -0,0 +1,201 @@
import asyncio
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Optional, Type

import snowflake.connector
from crewai.tools.base_tool import BaseTool
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from pydantic import BaseModel, ConfigDict, Field, SecretStr
from snowflake.connector.connection import SnowflakeConnection
from snowflake.connector.errors import DatabaseError, OperationalError

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Cache for query results
_query_cache = {}


class SnowflakeConfig(BaseModel):
    """Configuration for Snowflake connection."""

    model_config = ConfigDict(protected_namespaces=())

    account: str = Field(
        ..., description="Snowflake account identifier", pattern=r"^[a-zA-Z0-9\-_]+$"
    )
    user: str = Field(..., description="Snowflake username")
    password: Optional[SecretStr] = Field(None, description="Snowflake password")
    private_key_path: Optional[str] = Field(
        None, description="Path to private key file"
    )
    warehouse: Optional[str] = Field(None, description="Snowflake warehouse")
    database: Optional[str] = Field(None, description="Default database")
    snowflake_schema: Optional[str] = Field(None, description="Default schema")
    role: Optional[str] = Field(None, description="Snowflake role")
    session_parameters: Optional[Dict[str, Any]] = Field(
        default_factory=dict, description="Session parameters"
    )

    @property
    def has_auth(self) -> bool:
        return bool(self.password or self.private_key_path)

    def model_post_init(self, *args, **kwargs):
        if not self.has_auth:
            raise ValueError("Either password or private_key_path must be provided")


class SnowflakeSearchToolInput(BaseModel):
    """Input schema for SnowflakeSearchTool."""

    model_config = ConfigDict(protected_namespaces=())

    query: str = Field(..., description="SQL query or semantic search query to execute")
    database: Optional[str] = Field(None, description="Override default database")
    snowflake_schema: Optional[str] = Field(None, description="Override default schema")
    timeout: Optional[int] = Field(300, description="Query timeout in seconds")


class SnowflakeSearchTool(BaseTool):
    """Tool for executing queries and semantic search on Snowflake."""

    name: str = "Snowflake Database Search"
    description: str = (
        "Execute SQL queries or semantic search on Snowflake data warehouse. "
        "Supports both raw SQL and natural language queries."
    )
    args_schema: Type[BaseModel] = SnowflakeSearchToolInput

    # Define Pydantic fields
    config: SnowflakeConfig = Field(
        ..., description="Snowflake connection configuration"
    )
    pool_size: int = Field(default=5, description="Size of connection pool")
    max_retries: int = Field(default=3, description="Maximum retry attempts")
    retry_delay: float = Field(
        default=1.0, description="Delay between retries in seconds"
    )
    enable_caching: bool = Field(
        default=True, description="Enable query result caching"
    )

    model_config = ConfigDict(arbitrary_types_allowed=True)

    def __init__(self, **data):
        """Initialize SnowflakeSearchTool."""
        super().__init__(**data)
        self._connection_pool: List[SnowflakeConnection] = []
        self._pool_lock = asyncio.Lock()
        self._thread_pool = ThreadPoolExecutor(max_workers=self.pool_size)

    async def _get_connection(self) -> SnowflakeConnection:
        """Get a connection from the pool or create a new one."""
        async with self._pool_lock:
            if not self._connection_pool:
                conn = self._create_connection()
                self._connection_pool.append(conn)
            return self._connection_pool.pop()

    def _create_connection(self) -> SnowflakeConnection:
        """Create a new Snowflake connection."""
        conn_params = {
            "account": self.config.account,
            "user": self.config.user,
            "warehouse": self.config.warehouse,
            "database": self.config.database,
            "schema": self.config.snowflake_schema,
            "role": self.config.role,
            "session_parameters": self.config.session_parameters,
        }

        if self.config.password:
            conn_params["password"] = self.config.password.get_secret_value()
        elif self.config.private_key_path:
            with open(self.config.private_key_path, "rb") as key_file:
                p_key = serialization.load_pem_private_key(
                    key_file.read(), password=None, backend=default_backend()
                )
                conn_params["private_key"] = p_key

        return snowflake.connector.connect(**conn_params)

    def _get_cache_key(self, query: str, timeout: int) -> str:
        """Generate a cache key for the query."""
        return f"{self.config.account}:{self.config.database}:{self.config.snowflake_schema}:{query}:{timeout}"

    async def _execute_query(
        self, query: str, timeout: int = 300
    ) -> List[Dict[str, Any]]:
        """Execute a query with retries and return results."""
        if self.enable_caching:
            cache_key = self._get_cache_key(query, timeout)
            if cache_key in _query_cache:
                logger.info("Returning cached result")
                return _query_cache[cache_key]

        for attempt in range(self.max_retries):
            try:
                conn = await self._get_connection()
                try:
                    cursor = conn.cursor()
                    cursor.execute(query, timeout=timeout)

                    if not cursor.description:
                        return []

                    columns = [col[0] for col in cursor.description]
                    results = [dict(zip(columns, row)) for row in cursor.fetchall()]

                    if self.enable_caching:
                        _query_cache[self._get_cache_key(query, timeout)] = results

                    return results
                finally:
                    cursor.close()
                    async with self._pool_lock:
                        self._connection_pool.append(conn)
            except (DatabaseError, OperationalError) as e:
                if attempt == self.max_retries - 1:
                    raise
                await asyncio.sleep(self.retry_delay * (2**attempt))
                logger.warning(f"Query failed, attempt {attempt + 1}: {str(e)}")
                continue

    async def _run(
        self,
        query: str,
        database: Optional[str] = None,
        snowflake_schema: Optional[str] = None,
        timeout: int = 300,
        **kwargs: Any,
    ) -> Any:
        """Execute the search query."""
        try:
            # Override database/schema if provided
            if database:
                await self._execute_query(f"USE DATABASE {database}")
            if snowflake_schema:
                await self._execute_query(f"USE SCHEMA {snowflake_schema}")

            results = await self._execute_query(query, timeout)
            return results
        except Exception as e:
            logger.error(f"Error executing query: {str(e)}")
            raise

    def __del__(self):
        """Cleanup connections on deletion."""
        try:
            for conn in getattr(self, "_connection_pool", []):
                try:
                    conn.close()
                except:
                    pass
            if hasattr(self, "_thread_pool"):
                self._thread_pool.shutdown()
        except:
            pass
@@ -14,9 +14,8 @@ import os
from functools import lru_cache
from typing import Any, Dict, List, Optional, Type, Union

from pydantic import BaseModel, Field

from crewai.tools.base_tool import BaseTool
from pydantic import BaseModel, Field

# Set up logging
logger = logging.getLogger(__name__)
@@ -25,6 +24,7 @@ logger = logging.getLogger(__name__)
STAGEHAND_AVAILABLE = False
try:
    import stagehand

    STAGEHAND_AVAILABLE = True
except ImportError:
    pass  # Keep STAGEHAND_AVAILABLE as False
@@ -38,9 +38,16 @@ class StagehandResult(BaseModel):
        data: The result data from the operation
        error: Optional error message if the operation failed
    """
    success: bool = Field(..., description="Whether the operation completed successfully")
    data: Union[str, Dict, List] = Field(..., description="The result data from the operation")
    error: Optional[str] = Field(None, description="Optional error message if the operation failed")

    success: bool = Field(
        ..., description="Whether the operation completed successfully"
    )
    data: Union[str, Dict, List] = Field(
        ..., description="The result data from the operation"
    )
    error: Optional[str] = Field(
        None, description="Optional error message if the operation failed"
    )


class StagehandToolConfig(BaseModel):
@@ -51,9 +58,14 @@ class StagehandToolConfig(BaseModel):
        timeout: Maximum time in seconds to wait for operations (default: 30)
        retry_attempts: Number of times to retry failed operations (default: 3)
    """

    api_key: str = Field(..., description="OpenAI API key for Stagehand authentication")
    timeout: int = Field(30, description="Maximum time in seconds to wait for operations")
    retry_attempts: int = Field(3, description="Number of times to retry failed operations")
    timeout: int = Field(
        30, description="Maximum time in seconds to wait for operations"
    )
    retry_attempts: int = Field(
        3, description="Number of times to retry failed operations"
    )


class StagehandToolSchema(BaseModel):
@@ -80,16 +92,17 @@ class StagehandToolSchema(BaseModel):
        )
        ```
    """

    api_method: str = Field(
        ...,
        description="The Stagehand API to use: 'act' for interactions, 'extract' for getting content, or 'observe' for monitoring changes",
        pattern="^(act|extract|observe)$"
        pattern="^(act|extract|observe)$",
    )
    instruction: str = Field(
        ...,
        description="An atomic instruction for Stagehand to execute. Instructions should be simple and specific to increase reliability.",
        min_length=1,
        max_length=500
        max_length=500,
    )


@@ -138,7 +151,9 @@ class StagehandTool(BaseTool):
    )
    args_schema: Type[BaseModel] = StagehandToolSchema

    def __init__(self, config: StagehandToolConfig | None = None, **kwargs: Any) -> None:
    def __init__(
        self, config: StagehandToolConfig | None = None, **kwargs: Any
    ) -> None:
        """Initialize the StagehandTool.

        Args:
@@ -168,9 +183,7 @@ class StagehandTool(BaseTool):
                "Either provide config with api_key or set OPENAI_API_KEY environment variable"
            )
            self.config = StagehandToolConfig(
                api_key=api_key,
                timeout=30,
                retry_attempts=3
                api_key=api_key, timeout=30, retry_attempts=3
            )

    @lru_cache(maxsize=100)
@@ -193,23 +206,25 @@ class StagehandTool(BaseTool):
        logger.debug(
            "Cache operation - Method: %s, Instruction length: %d",
            api_method,
            len(instruction)
            len(instruction),
        )

        # Initialize Stagehand with configuration
        logger.info(
            "Initializing Stagehand (timeout=%ds, retries=%d)",
            self.config.timeout,
            self.config.retry_attempts
            self.config.retry_attempts,
        )
        st = stagehand.Stagehand(
            api_key=self.config.api_key,
            timeout=self.config.timeout,
            retry_attempts=self.config.retry_attempts
            retry_attempts=self.config.retry_attempts,
        )

        # Call the appropriate Stagehand API based on the method
        logger.info("Executing %s operation with instruction: %s", api_method, instruction[:100])
        logger.info(
            "Executing %s operation with instruction: %s", api_method, instruction[:100]
        )
        try:
            if api_method == "act":
                result = st.act(instruction)
@@ -220,7 +235,6 @@ class StagehandTool(BaseTool):
            else:
                raise ValueError(f"Unknown api_method: {api_method}")


            logger.info("Successfully executed %s operation", api_method)
            return result

@@ -228,7 +242,7 @@ class StagehandTool(BaseTool):
            logger.warning(
                "Operation failed (method=%s, error=%s), will be retried on next attempt",
                api_method,
                str(e)
                str(e),
            )
            raise

@@ -249,7 +263,7 @@ class StagehandTool(BaseTool):
            "Starting operation - Method: %s, Instruction length: %d, Args: %s",
            api_method,
            len(instruction),
            kwargs
            kwargs,
        )

        # Use cached execution
@@ -259,46 +273,26 @@ class StagehandTool(BaseTool):

        except stagehand.AuthenticationError as e:
            logger.error(
                "Authentication failed - Method: %s, Error: %s",
                api_method,
                str(e)
                "Authentication failed - Method: %s, Error: %s", api_method, str(e)
            )
            return StagehandResult(
                success=False,
                data={},
                error=f"Authentication failed: {str(e)}"
                success=False, data={}, error=f"Authentication failed: {str(e)}"
            )
        except stagehand.APIError as e:
            logger.error(
                "API error - Method: %s, Error: %s",
                api_method,
                str(e)
            )
            return StagehandResult(
                success=False,
                data={},
                error=f"API error: {str(e)}"
            )
            logger.error("API error - Method: %s, Error: %s", api_method, str(e))
            return StagehandResult(success=False, data={}, error=f"API error: {str(e)}")
        except stagehand.BrowserError as e:
            logger.error(
                "Browser error - Method: %s, Error: %s",
                api_method,
                str(e)
            )
            logger.error("Browser error - Method: %s, Error: %s", api_method, str(e))
            return StagehandResult(
                success=False,
                data={},
                error=f"Browser error: {str(e)}"
                success=False, data={}, error=f"Browser error: {str(e)}"
            )
        except Exception as e:
            logger.error(
                "Unexpected error - Method: %s, Error type: %s, Message: %s",
                api_method,
                type(e).__name__,
                str(e)
                str(e),
            )
            return StagehandResult(
                success=False,
                data={},
                error=f"Unexpected error: {str(e)}"
                success=False, data={}, error=f"Unexpected error: {str(e)}"
            )
@@ -1,12 +1,15 @@
import base64
from typing import Type, Optional
from pathlib import Path
from typing import Optional, Type

from crewai.tools import BaseTool
from openai import OpenAI
from pydantic import BaseModel, validator


class ImagePromptSchema(BaseModel):
    """Input for Vision Tool."""

    image_path_url: str = "The image path or URL."

    @validator("image_path_url")
@@ -21,10 +24,13 @@ class ImagePromptSchema(BaseModel):
        # Validate supported formats
        valid_extensions = {".jpg", ".jpeg", ".png", ".gif", ".webp"}
        if path.suffix.lower() not in valid_extensions:
            raise ValueError(f"Unsupported image format. Supported formats: {valid_extensions}")
            raise ValueError(
                f"Unsupported image format. Supported formats: {valid_extensions}"
            )

        return v


class VisionTool(BaseTool):
    name: str = "Vision Tool"
    description: str = (
@@ -68,7 +74,7 @@ class VisionTool(BaseTool):
                    {
                        "type": "image_url",
                        "image_url": {"url": image_data},
                    }
                    },
                ],
            }
        ],
@@ -15,9 +15,8 @@ except ImportError:
    Vectorizers = Any
    Auth = Any

from pydantic import BaseModel, Field

from crewai.tools import BaseTool
from pydantic import BaseModel, Field


class WeaviateToolSchema(BaseModel):
@@ -25,9 +25,7 @@ class WebsiteSearchToolSchema(FixedWebsiteSearchToolSchema):

class WebsiteSearchTool(RagTool):
    name: str = "Search in a specific website"
    description: str = (
        "A tool that can be used to semantic search a query from a specific URL content."
    )
    description: str = "A tool that can be used to semantic search a query from a specific URL content."
    args_schema: Type[BaseModel] = WebsiteSearchToolSchema

    def __init__(self, website: Optional[str] = None, **kwargs):
@@ -25,9 +25,7 @@ class YoutubeChannelSearchToolSchema(FixedYoutubeChannelSearchToolSchema):

class YoutubeChannelSearchTool(RagTool):
    name: str = "Search a Youtube Channels content"
    description: str = (
        "A tool that can be used to semantic search a query from a Youtube Channels content."
    )
    description: str = "A tool that can be used to semantic search a query from a Youtube Channels content."
    args_schema: Type[BaseModel] = YoutubeChannelSearchToolSchema

    def __init__(self, youtube_channel_handle: Optional[str] = None, **kwargs):
@@ -25,9 +25,7 @@ class YoutubeVideoSearchToolSchema(FixedYoutubeVideoSearchToolSchema):

class YoutubeVideoSearchTool(RagTool):
    name: str = "Search a Youtube Video content"
    description: str = (
        "A tool that can be used to semantic search a query from a Youtube Video content."
    )
    description: str = "A tool that can be used to semantic search a query from a Youtube Video content."
    args_schema: Type[BaseModel] = YoutubeVideoSearchToolSchema

    def __init__(self, youtube_video_url: Optional[str] = None, **kwargs):
@@ -1,69 +1,104 @@
from typing import Callable

from crewai.tools import BaseTool, tool
from crewai.tools.base_tool import to_langchain


def test_creating_a_tool_using_annotation():
    @tool("Name of my tool")
    def my_tool(question: str) -> str:
        """Clear description for what this tool is useful for, you agent will need this information to use it."""
        return question
    @tool("Name of my tool")
    def my_tool(question: str) -> str:
        """Clear description for what this tool is useful for, you agent will need this information to use it."""
        return question

    # Assert all the right attributes were defined
    assert my_tool.name == "Name of my tool"
    assert my_tool.description == "Tool Name: Name of my tool\nTool Arguments: {'question': {'description': None, 'type': 'str'}}\nTool Description: Clear description for what this tool is useful for, you agent will need this information to use it."
    assert my_tool.args_schema.schema()["properties"] == {'question': {'title': 'Question', 'type': 'string'}}
    assert my_tool.func("What is the meaning of life?") == "What is the meaning of life?"
    # Assert all the right attributes were defined
    assert my_tool.name == "Name of my tool"
    assert (
        my_tool.description
        == "Tool Name: Name of my tool\nTool Arguments: {'question': {'description': None, 'type': 'str'}}\nTool Description: Clear description for what this tool is useful for, you agent will need this information to use it."
    )
    assert my_tool.args_schema.schema()["properties"] == {
        "question": {"title": "Question", "type": "string"}
    }
    assert (
        my_tool.func("What is the meaning of life?") == "What is the meaning of life?"
    )

    # Assert the langchain tool conversion worked as expected
    converted_tool = to_langchain([my_tool])[0]
    assert converted_tool.name == "Name of my tool"
    assert (
        converted_tool.description
        == "Tool Name: Name of my tool\nTool Arguments: {'question': {'description': None, 'type': 'str'}}\nTool Description: Clear description for what this tool is useful for, you agent will need this information to use it."
    )
    assert converted_tool.args_schema.schema()["properties"] == {
        "question": {"title": "Question", "type": "string"}
    }
    assert (
        converted_tool.func("What is the meaning of life?")
        == "What is the meaning of life?"
    )

    # Assert the langchain tool conversion worked as expected
    converted_tool = to_langchain([my_tool])[0]
    assert converted_tool.name == "Name of my tool"
    assert converted_tool.description == "Tool Name: Name of my tool\nTool Arguments: {'question': {'description': None, 'type': 'str'}}\nTool Description: Clear description for what this tool is useful for, you agent will need this information to use it."
    assert converted_tool.args_schema.schema()["properties"] == {'question': {'title': 'Question', 'type': 'string'}}
    assert converted_tool.func("What is the meaning of life?") == "What is the meaning of life?"


def test_creating_a_tool_using_baseclass():
    class MyCustomTool(BaseTool):
        name: str = "Name of my tool"
        description: str = "Clear description for what this tool is useful for, you agent will need this information to use it."
    class MyCustomTool(BaseTool):
        name: str = "Name of my tool"
        description: str = "Clear description for what this tool is useful for, you agent will need this information to use it."

        def _run(self, question: str) -> str:
            return question
        def _run(self, question: str) -> str:
            return question

    my_tool = MyCustomTool()
    # Assert all the right attributes were defined
    assert my_tool.name == "Name of my tool"
    assert my_tool.description == "Tool Name: Name of my tool\nTool Arguments: {'question': {'description': None, 'type': 'str'}}\nTool Description: Clear description for what this tool is useful for, you agent will need this information to use it."
    assert my_tool.args_schema.schema()["properties"] == {'question': {'title': 'Question', 'type': 'string'}}
    assert my_tool._run("What is the meaning of life?") == "What is the meaning of life?"
    my_tool = MyCustomTool()
    # Assert all the right attributes were defined
    assert my_tool.name == "Name of my tool"
    assert (
        my_tool.description
        == "Tool Name: Name of my tool\nTool Arguments: {'question': {'description': None, 'type': 'str'}}\nTool Description: Clear description for what this tool is useful for, you agent will need this information to use it."
    )
    assert my_tool.args_schema.schema()["properties"] == {
        "question": {"title": "Question", "type": "string"}
    }
    assert (
        my_tool._run("What is the meaning of life?") == "What is the meaning of life?"
    )

    # Assert the langchain tool conversion worked as expected
    converted_tool = to_langchain([my_tool])[0]
    assert converted_tool.name == "Name of my tool"
    assert (
        converted_tool.description
        == "Tool Name: Name of my tool\nTool Arguments: {'question': {'description': None, 'type': 'str'}}\nTool Description: Clear description for what this tool is useful for, you agent will need this information to use it."
    )
    assert converted_tool.args_schema.schema()["properties"] == {
        "question": {"title": "Question", "type": "string"}
    }
    assert (
        converted_tool.invoke({"question": "What is the meaning of life?"})
        == "What is the meaning of life?"
    )

    # Assert the langchain tool conversion worked as expected
    converted_tool = to_langchain([my_tool])[0]
    assert converted_tool.name == "Name of my tool"
    assert converted_tool.description == "Tool Name: Name of my tool\nTool Arguments: {'question': {'description': None, 'type': 'str'}}\nTool Description: Clear description for what this tool is useful for, you agent will need this information to use it."
    assert converted_tool.args_schema.schema()["properties"] == {'question': {'title': 'Question', 'type': 'string'}}
    assert converted_tool.invoke({"question": "What is the meaning of life?"}) == "What is the meaning of life?"


def test_setting_cache_function():
    class MyCustomTool(BaseTool):
        name: str = "Name of my tool"
        description: str = "Clear description for what this tool is useful for, you agent will need this information to use it."
        cache_function: Callable = lambda: False
    class MyCustomTool(BaseTool):
        name: str = "Name of my tool"
        description: str = "Clear description for what this tool is useful for, you agent will need this information to use it."
        cache_function: Callable = lambda: False

        def _run(self, question: str) -> str:
            return question
        def _run(self, question: str) -> str:
            return question

    my_tool = MyCustomTool()
    # Assert all the right attributes were defined
    assert my_tool.cache_function() == False

    my_tool = MyCustomTool()
    # Assert all the right attributes were defined
    assert my_tool.cache_function() == False


def test_default_cache_function_is_true():
    class MyCustomTool(BaseTool):
        name: str = "Name of my tool"
        description: str = "Clear description for what this tool is useful for, you agent will need this information to use it."
    class MyCustomTool(BaseTool):
        name: str = "Name of my tool"
        description: str = "Clear description for what this tool is useful for, you agent will need this information to use it."

        def _run(self, question: str) -> str:
            return question
        def _run(self, question: str) -> str:
            return question

    my_tool = MyCustomTool()
    # Assert all the right attributes were defined
    assert my_tool.cache_function() == True
    my_tool = MyCustomTool()
    # Assert all the right attributes were defined
    assert my_tool.cache_function() == True
@@ -1,7 +1,8 @@
import os
import pytest

from crewai_tools import FileReadTool


def test_file_read_tool_constructor():
    """Test FileReadTool initialization with file_path."""
    # Create a temporary test file
@@ -18,6 +19,7 @@ def test_file_read_tool_constructor():
    # Clean up
    os.remove(test_file)


def test_file_read_tool_run():
    """Test FileReadTool _run method with file_path at runtime."""
    # Create a temporary test file
@@ -34,6 +36,7 @@ def test_file_read_tool_run():
    # Clean up
    os.remove(test_file)


def test_file_read_tool_error_handling():
    """Test FileReadTool error handling."""
    # Test missing file path
@@ -58,6 +61,7 @@ def test_file_read_tool_error_handling():
    os.chmod(test_file, 0o666)  # Restore permissions to delete
    os.remove(test_file)


def test_file_read_tool_constructor_and_run():
    """Test FileReadTool using both constructor and runtime file paths."""
    # Create two test files
0
tests/it/tools/__init__.py
Normal file
21
tests/it/tools/conftest.py
Normal file
@@ -0,0 +1,21 @@
import pytest


def pytest_configure(config):
    """Register custom markers."""
    config.addinivalue_line("markers", "integration: mark test as an integration test")
    config.addinivalue_line("markers", "asyncio: mark test as an async test")

    # Set the asyncio loop scope through ini configuration
    config.inicfg["asyncio_mode"] = "auto"


@pytest.fixture(scope="function")
def event_loop():
    """Create an instance of the default event loop for each test case."""
    import asyncio

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    yield loop
    loop.close()
219
tests/it/tools/snowflake_search_tool_test.py
Normal file
@@ -0,0 +1,219 @@
import asyncio
import json
from decimal import Decimal

import pytest
from snowflake.connector.errors import DatabaseError, OperationalError

from crewai_tools import SnowflakeConfig, SnowflakeSearchTool

# Test Data
MENU_ITEMS = [
    (10001, "Ice Cream", "Freezing Point", "Lemonade", "Beverage", "Cold Option", 1, 4),
    (
        10002,
        "Ice Cream",
        "Freezing Point",
        "Vanilla Ice Cream",
        "Dessert",
        "Ice Cream",
        2,
        6,
    ),
]

INVALID_QUERIES = [
    ("SELECT * FROM nonexistent_table", "relation 'nonexistent_table' does not exist"),
    ("SELECT invalid_column FROM menu", "invalid identifier 'invalid_column'"),
    ("INVALID SQL QUERY", "SQL compilation error"),
]


# Integration Test Fixtures
@pytest.fixture
def config():
    """Create a Snowflake configuration with test credentials."""
    return SnowflakeConfig(
        account="lwyhjun-wx11931",
        user="crewgitci",
        password="crewaiT00ls_publicCIpass123",
        warehouse="COMPUTE_WH",
        database="tasty_bytes_sample_data",
        snowflake_schema="raw_pos",
    )


@pytest.fixture
def snowflake_tool(config):
    """Create a SnowflakeSearchTool instance."""
    return SnowflakeSearchTool(config=config)


# Integration Tests with Real Snowflake Connection
@pytest.mark.integration
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "menu_id,expected_type,brand,item_name,category,subcategory,cost,price", MENU_ITEMS
)
async def test_menu_items(
    snowflake_tool,
    menu_id,
    expected_type,
    brand,
    item_name,
    category,
    subcategory,
    cost,
    price,
):
    """Test menu items with parameterized data for multiple test cases."""
    results = await snowflake_tool._run(
        query=f"SELECT * FROM menu WHERE menu_id = {menu_id}"
    )
    assert len(results) == 1
    menu_item = results[0]

    # Validate all fields
    assert menu_item["MENU_ID"] == menu_id
    assert menu_item["MENU_TYPE"] == expected_type
    assert menu_item["TRUCK_BRAND_NAME"] == brand
    assert menu_item["MENU_ITEM_NAME"] == item_name
    assert menu_item["ITEM_CATEGORY"] == category
    assert menu_item["ITEM_SUBCATEGORY"] == subcategory
    assert menu_item["COST_OF_GOODS_USD"] == cost
    assert menu_item["SALE_PRICE_USD"] == price

    # Validate health metrics JSON structure
    health_metrics = json.loads(menu_item["MENU_ITEM_HEALTH_METRICS_OBJ"])
    assert "menu_item_health_metrics" in health_metrics
    metrics = health_metrics["menu_item_health_metrics"][0]
    assert "ingredients" in metrics
    assert isinstance(metrics["ingredients"], list)
    assert all(isinstance(ingredient, str) for ingredient in metrics["ingredients"])
    assert metrics["is_dairy_free_flag"] in ["Y", "N"]


@pytest.mark.integration
@pytest.mark.asyncio
async def test_menu_categories_aggregation(snowflake_tool):
    """Test complex aggregation query on menu categories with detailed validations."""
    results = await snowflake_tool._run(
        query="""
        SELECT
            item_category,
            COUNT(*) as item_count,
            AVG(sale_price_usd) as avg_price,
            SUM(sale_price_usd - cost_of_goods_usd) as total_margin,
            COUNT(DISTINCT menu_type) as menu_type_count,
            MIN(sale_price_usd) as min_price,
            MAX(sale_price_usd) as max_price
        FROM menu
        GROUP BY item_category
        HAVING COUNT(*) > 1
        ORDER BY item_count DESC
        """
    )

    assert len(results) > 0
    for category in results:
        # Basic presence checks
        assert all(
            key in category
            for key in [
                "ITEM_CATEGORY",
                "ITEM_COUNT",
                "AVG_PRICE",
                "TOTAL_MARGIN",
                "MENU_TYPE_COUNT",
                "MIN_PRICE",
                "MAX_PRICE",
            ]
        )

        # Value validations
        assert category["ITEM_COUNT"] > 1  # Due to HAVING clause
        assert category["MIN_PRICE"] <= category["MAX_PRICE"]
        assert category["AVG_PRICE"] >= category["MIN_PRICE"]
        assert category["AVG_PRICE"] <= category["MAX_PRICE"]
        assert category["MENU_TYPE_COUNT"] >= 1
        assert isinstance(category["TOTAL_MARGIN"], (float, Decimal))


@pytest.mark.integration
@pytest.mark.asyncio
@pytest.mark.parametrize("invalid_query,expected_error", INVALID_QUERIES)
async def test_invalid_queries(snowflake_tool, invalid_query, expected_error):
    """Test error handling for invalid queries."""
    with pytest.raises((DatabaseError, OperationalError)) as exc_info:
        await snowflake_tool._run(query=invalid_query)
    assert expected_error.lower() in str(exc_info.value).lower()


@pytest.mark.integration
@pytest.mark.asyncio
async def test_concurrent_queries(snowflake_tool):
    """Test handling of concurrent queries."""
    queries = [
        "SELECT COUNT(*) FROM menu",
        "SELECT COUNT(DISTINCT menu_type) FROM menu",
        "SELECT COUNT(DISTINCT item_category) FROM menu",
    ]

    tasks = [snowflake_tool._run(query=query) for query in queries]
    results = await asyncio.gather(*tasks)

    assert len(results) == 3
    assert all(isinstance(result, list) for result in results)
    assert all(len(result) == 1 for result in results)
    assert all(isinstance(result[0], dict) for result in results)


@pytest.mark.integration
@pytest.mark.asyncio
async def test_query_timeout(snowflake_tool):
    """Test query timeout handling with a complex query."""
    with pytest.raises((DatabaseError, OperationalError)) as exc_info:
        await snowflake_tool._run(
            query="""
            WITH RECURSIVE numbers AS (
                SELECT 1 as n
                UNION ALL
                SELECT n + 1
                FROM numbers
                WHERE n < 1000000
            )
            SELECT COUNT(*) FROM numbers
            """
        )
    assert (
        "timeout" in str(exc_info.value).lower()
        or "execution time" in str(exc_info.value).lower()
    )


@pytest.mark.integration
@pytest.mark.asyncio
async def test_caching_behavior(snowflake_tool):
    """Test query caching behavior and performance."""
    query = "SELECT * FROM menu LIMIT 5"

    # First execution
    start_time = asyncio.get_event_loop().time()
    results1 = await snowflake_tool._run(query=query)
    first_duration = asyncio.get_event_loop().time() - start_time

    # Second execution (should be cached)
    start_time = asyncio.get_event_loop().time()
    results2 = await snowflake_tool._run(query=query)
    second_duration = asyncio.get_event_loop().time() - start_time

    # Verify results
    assert results1 == results2
    assert len(results1) == 5
    assert second_duration < first_duration

    # Verify cache invalidation with different query
    different_query = "SELECT * FROM menu LIMIT 10"
    different_results = await snowflake_tool._run(query=different_query)
    assert len(different_results) == 10
    assert different_results != results1
@@ -1,5 +1,7 @@
from crewai import Agent, Crew, Task

from crewai_tools.tools.spider_tool.spider_tool import SpiderTool
from crewai import Agent, Task, Crew


def test_spider_tool():
    spider_tool = SpiderTool()
@@ -10,38 +12,35 @@ def test_spider_tool():
        backstory="An expert web researcher that uses the web extremely well",
        tools=[spider_tool],
        verbose=True,
        cache=False
        cache=False,
    )

    choose_between_scrape_crawl = Task(
        description="Scrape the page of spider.cloud and return a summary of how fast it is",
        expected_output="spider.cloud is a fast scraping and crawling tool",
        agent=searcher
        agent=searcher,
    )

    return_metadata = Task(
        description="Scrape https://spider.cloud with a limit of 1 and enable metadata",
        expected_output="Metadata and 10 word summary of spider.cloud",
        agent=searcher
        agent=searcher,
    )

    css_selector = Task(
        description="Scrape one page of spider.cloud with the `body > div > main > section.grid.md\:grid-cols-2.gap-10.place-items-center.md\:max-w-screen-xl.mx-auto.pb-8.pt-20 > div:nth-child(1) > h1` CSS selector",
        expected_output="The content of the element with the css selector body > div > main > section.grid.md\:grid-cols-2.gap-10.place-items-center.md\:max-w-screen-xl.mx-auto.pb-8.pt-20 > div:nth-child(1) > h1",
        agent=searcher
        agent=searcher,
    )

    crew = Crew(
        agents=[searcher],
        tasks=[
            choose_between_scrape_crawl,
            return_metadata,
            css_selector
        ],
        verbose=True
        tasks=[choose_between_scrape_crawl, return_metadata, css_selector],
        verbose=True,
    )

    crew.kickoff()


if __name__ == "__main__":
    test_spider_tool()
103
tests/tools/snowflake_search_tool_test.py
Normal file
@@ -0,0 +1,103 @@
import asyncio
from unittest.mock import MagicMock, patch

import pytest

from crewai_tools import SnowflakeConfig, SnowflakeSearchTool


# Unit Test Fixtures
@pytest.fixture
def mock_snowflake_connection():
    mock_conn = MagicMock()
    mock_cursor = MagicMock()
    mock_cursor.description = [("col1",), ("col2",)]
    mock_cursor.fetchall.return_value = [(1, "value1"), (2, "value2")]
    mock_cursor.execute.return_value = None
    mock_conn.cursor.return_value = mock_cursor
    return mock_conn


@pytest.fixture
def mock_config():
    return SnowflakeConfig(
        account="test_account",
        user="test_user",
        password="test_password",
        warehouse="test_warehouse",
        database="test_db",
        snowflake_schema="test_schema",
    )


@pytest.fixture
def snowflake_tool(mock_config):
    with patch("snowflake.connector.connect") as mock_connect:
        tool = SnowflakeSearchTool(config=mock_config)
        yield tool


# Unit Tests
@pytest.mark.asyncio
async def test_successful_query_execution(snowflake_tool, mock_snowflake_connection):
    with patch.object(snowflake_tool, "_create_connection") as mock_create_conn:
        mock_create_conn.return_value = mock_snowflake_connection

        results = await snowflake_tool._run(
            query="SELECT * FROM test_table", timeout=300
        )

        assert len(results) == 2
        assert results[0]["col1"] == 1
        assert results[0]["col2"] == "value1"
        mock_snowflake_connection.cursor.assert_called_once()


@pytest.mark.asyncio
async def test_connection_pooling(snowflake_tool, mock_snowflake_connection):
    with patch.object(snowflake_tool, "_create_connection") as mock_create_conn:
        mock_create_conn.return_value = mock_snowflake_connection

        # Execute multiple queries
        await asyncio.gather(
            snowflake_tool._run("SELECT 1"),
            snowflake_tool._run("SELECT 2"),
            snowflake_tool._run("SELECT 3"),
        )

        # Should reuse connections from pool
        assert mock_create_conn.call_count <= snowflake_tool.pool_size


@pytest.mark.asyncio
async def test_cleanup_on_deletion(snowflake_tool, mock_snowflake_connection):
    with patch.object(snowflake_tool, "_create_connection") as mock_create_conn:
        mock_create_conn.return_value = mock_snowflake_connection

        # Add connection to pool
        await snowflake_tool._get_connection()

        # Return connection to pool
        async with snowflake_tool._pool_lock:
            snowflake_tool._connection_pool.append(mock_snowflake_connection)

        # Trigger cleanup
        snowflake_tool.__del__()

        mock_snowflake_connection.close.assert_called_once()


def test_config_validation():
    # Test missing required fields
    with pytest.raises(ValueError):
        SnowflakeConfig()

    # Test invalid account format
    with pytest.raises(ValueError):
        SnowflakeConfig(
            account="invalid//account", user="test_user", password="test_pass"
        )

    # Test missing authentication
    with pytest.raises(ValueError):
        SnowflakeConfig(account="test_account", user="test_user")
@@ -7,7 +7,9 @@ from crewai_tools.tools.code_interpreter_tool.code_interpreter_tool import (


class TestCodeInterpreterTool(unittest.TestCase):
    @patch("crewai_tools.tools.code_interpreter_tool.code_interpreter_tool.docker_from_env")
    @patch(
        "crewai_tools.tools.code_interpreter_tool.code_interpreter_tool.docker_from_env"
    )
    def test_run_code_in_docker(self, docker_mock):
        tool = CodeInterpreterTool()
        code = "print('Hello, World!')"
@@ -15,14 +17,14 @@ class TestCodeInterpreterTool(unittest.TestCase):
        expected_output = "Hello, World!\n"

        docker_mock().containers.run().exec_run().exit_code = 0
        docker_mock().containers.run().exec_run().output = (
            expected_output.encode()
        )
        docker_mock().containers.run().exec_run().output = expected_output.encode()
        result = tool.run_code_in_docker(code, libraries_used)

        self.assertEqual(result, expected_output)

    @patch("crewai_tools.tools.code_interpreter_tool.code_interpreter_tool.docker_from_env")
    @patch(
        "crewai_tools.tools.code_interpreter_tool.code_interpreter_tool.docker_from_env"
    )
    def test_run_code_in_docker_with_error(self, docker_mock):
        tool = CodeInterpreterTool()
        code = "print(1/0)"
@@ -37,7 +39,9 @@ class TestCodeInterpreterTool(unittest.TestCase):

        self.assertEqual(result, expected_output)

    @patch("crewai_tools.tools.code_interpreter_tool.code_interpreter_tool.docker_from_env")
    @patch(
        "crewai_tools.tools.code_interpreter_tool.code_interpreter_tool.docker_from_env"
    )
    def test_run_code_in_docker_with_script(self, docker_mock):
        tool = CodeInterpreterTool()
        code = """print("This is line 1")