mirror of https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 15:48:29 +00:00
Merge branch 'main' into feature/add-hyperbrowser
@@ -44,6 +44,8 @@ from .tools import (
    SerplyScholarSearchTool,
    SerplyWebpageToMarkdownTool,
    SerplyWebSearchTool,
    SnowflakeConfig,
    SnowflakeSearchTool,
    SpiderTool,
    TXTSearchTool,
    VisionTool,

@@ -55,6 +55,11 @@ from .serply_api_tool.serply_news_search_tool import SerplyNewsSearchTool
from .serply_api_tool.serply_scholar_search_tool import SerplyScholarSearchTool
from .serply_api_tool.serply_web_search_tool import SerplyWebSearchTool
from .serply_api_tool.serply_webpage_to_markdown_tool import SerplyWebpageToMarkdownTool
from .snowflake_search_tool import (
    SnowflakeConfig,
    SnowflakeSearchTool,
    SnowflakeSearchToolInput,
)
from .spider_tool.spider_tool import SpiderTool
from .txt_search_tool.txt_search_tool import TXTSearchTool
from .vision_tool.vision_tool import VisionTool

@@ -1,8 +1,8 @@
import os
from typing import Any, Optional, Type
from pydantic import BaseModel, Field

from crewai.tools import BaseTool
from pydantic import BaseModel, Field


class BrowserbaseLoadToolSchema(BaseModel):

@@ -11,12 +11,10 @@ class BrowserbaseLoadToolSchema(BaseModel):

class BrowserbaseLoadTool(BaseTool):
    name: str = "Browserbase web load tool"
    description: str = (
        "Load webpages url in a headless browser using Browserbase and return the contents"
    )
    description: str = "Load webpages url in a headless browser using Browserbase and return the contents"
    args_schema: Type[BaseModel] = BrowserbaseLoadToolSchema
    api_key: Optional[str] = os.getenv('BROWSERBASE_API_KEY')
    project_id: Optional[str] = os.getenv('BROWSERBASE_PROJECT_ID')
    api_key: Optional[str] = os.getenv("BROWSERBASE_API_KEY")
    project_id: Optional[str] = os.getenv("BROWSERBASE_PROJECT_ID")
    text_content: Optional[bool] = False
    session_id: Optional[str] = None
    proxy: Optional[bool] = None

@@ -33,7 +31,9 @@ class BrowserbaseLoadTool(BaseTool):
    ):
        super().__init__(**kwargs)
        if not self.api_key:
            raise EnvironmentError("BROWSERBASE_API_KEY environment variable is required for initialization")
            raise EnvironmentError(
                "BROWSERBASE_API_KEY environment variable is required for initialization"
            )
        try:
            from browserbase import Browserbase  # type: ignore
        except ImportError:
@@ -2,10 +2,12 @@ import importlib.util
import os
from typing import List, Optional, Type

from crewai.tools import BaseTool
from docker import from_env as docker_from_env
from docker import DockerClient
from docker.models.containers import Container
from docker.errors import ImageNotFound, NotFound
from crewai.tools import BaseTool
from docker.models.containers import Container
from pydantic import BaseModel, Field


@@ -30,7 +32,7 @@ class CodeInterpreterTool(BaseTool):
    default_image_tag: str = "code-interpreter:latest"
    code: Optional[str] = None
    user_dockerfile_path: Optional[str] = None
    user_docker_base_url: Optional[str] = None
    user_docker_base_url: Optional[str] = None
    unsafe_mode: bool = False

    @staticmethod

@@ -43,7 +45,11 @@ class CodeInterpreterTool(BaseTool):
        Verify if the Docker image is available. Optionally use a user-provided Dockerfile.
        """

        client = docker_from_env() if self.user_docker_base_url == None else docker.DockerClient(base_url=self.user_docker_base_url)
        client = (
            docker_from_env()
            if self.user_docker_base_url == None
            else DockerClient(base_url=self.user_docker_base_url)
        )

        try:
            client.images.get(self.default_image_tag)

@@ -76,9 +82,7 @@ class CodeInterpreterTool(BaseTool):
        else:
            return self.run_code_in_docker(code, libraries_used)

    def _install_libraries(
        self, container: Container, libraries: List[str]
    ) -> None:
    def _install_libraries(self, container: Container, libraries: List[str]) -> None:
        """
        Install missing libraries in the Docker container
        """

@@ -135,4 +139,4 @@ class CodeInterpreterTool(BaseTool):
            exec(code, {}, exec_locals)
            return exec_locals.get("result", "No result variable found.")
        except Exception as e:
            return f"An error occurred: {str(e)}"
            return f"An error occurred: {str(e)}"
@@ -8,8 +8,6 @@ from pydantic import BaseModel, Field
class FixedDirectoryReadToolSchema(BaseModel):
    """Input for DirectoryReadTool."""

    pass


class DirectoryReadToolSchema(FixedDirectoryReadToolSchema):
    """Input for DirectoryReadTool."""

@@ -32,6 +32,7 @@ class FileReadTool(BaseTool):
        >>> content = tool.run()  # Reads /path/to/file.txt
        >>> content = tool.run(file_path="/path/to/other.txt")  # Reads other.txt
    """

    name: str = "Read a file's content"
    description: str = "A tool that reads the content of a file. To use this tool, provide a 'file_path' parameter with the path to the file you want to read."
    args_schema: Type[BaseModel] = FileReadToolSchema

@@ -45,10 +46,11 @@ class FileReadTool(BaseTool):
            this becomes the default file path for the tool.
            **kwargs: Additional keyword arguments passed to BaseTool.
        """
        super().__init__(**kwargs)
        if file_path is not None:
            self.file_path = file_path
            self.description = f"A tool that reads file content. The default file is {file_path}, but you can provide a different 'file_path' parameter to read another file."
            kwargs['description'] = f"A tool that reads file content. The default file is {file_path}, but you can provide a different 'file_path' parameter to read another file."

        super().__init__(**kwargs)
        self.file_path = file_path

    def _run(
        self,

@@ -57,7 +59,7 @@ class FileReadTool(BaseTool):
        file_path = kwargs.get("file_path", self.file_path)
        if file_path is None:
            return "Error: No file path provided. Please provide a file path either in the constructor or as an argument."

        try:
            with open(file_path, "r") as file:
                return file.read()

@@ -66,16 +68,4 @@ class FileReadTool(BaseTool):
        except PermissionError:
            return f"Error: Permission denied when trying to read file: {file_path}"
        except Exception as e:
            return f"Error: Failed to read file {file_path}. {str(e)}"

    def _generate_description(self) -> None:
        """Generate the tool description based on file path.

        This method updates the tool's description to include information about
        the default file path while maintaining the ability to specify a different
        file at runtime.

        Returns:
            None
        """
        self.description = f"A tool that can be used to read {self.file_path}'s content."
            return f"Error: Failed to read file {file_path}. {str(e)}"

@@ -15,9 +15,7 @@ class FileWriterToolInput(BaseModel):

class FileWriterTool(BaseTool):
    name: str = "File Writer Tool"
    description: str = (
        "A tool to write content to a specified file. Accepts filename, content, and optionally a directory path and overwrite flag as input."
    )
    description: str = "A tool to write content to a specified file. Accepts filename, content, and optionally a directory path and overwrite flag as input."
    args_schema: Type[BaseModel] = FileWriterToolInput

    def _run(self, **kwargs: Any) -> str:

@@ -72,4 +72,3 @@ except ImportError:
    """
    When this tool is not used, then exception can be ignored.
    """
    pass

@@ -63,4 +63,3 @@ except ImportError:
    """
    When this tool is not used, then exception can be ignored.
    """
    pass
@@ -27,9 +27,7 @@ class GithubSearchToolSchema(FixedGithubSearchToolSchema):

class GithubSearchTool(RagTool):
    name: str = "Search a github repo's content"
    description: str = (
        "A tool that can be used to semantic search a query from a github repo's content. This is not the GitHub API, but instead a tool that can provide semantic search capabilities."
    )
    description: str = "A tool that can be used to semantic search a query from a github repo's content. This is not the GitHub API, but instead a tool that can provide semantic search capabilities."
    summarize: bool = False
    gh_token: str
    args_schema: Type[BaseModel] = GithubSearchToolSchema

@@ -13,9 +13,7 @@ class JinaScrapeWebsiteToolInput(BaseModel):

class JinaScrapeWebsiteTool(BaseTool):
    name: str = "JinaScrapeWebsiteTool"
    description: str = (
        "A tool that can be used to read a website content using Jina.ai reader and return markdown content."
    )
    description: str = "A tool that can be used to read a website content using Jina.ai reader and return markdown content."
    args_schema: Type[BaseModel] = JinaScrapeWebsiteToolInput
    website_url: Optional[str] = None
    api_key: Optional[str] = None

@@ -2,6 +2,7 @@ from typing import Any

try:
    from linkup import LinkupClient

    LINKUP_AVAILABLE = True
except ImportError:
    LINKUP_AVAILABLE = False

@@ -9,10 +10,13 @@ except ImportError:

from pydantic import PrivateAttr


class LinkupSearchTool:
    name: str = "Linkup Search Tool"
    description: str = "Performs an API call to Linkup to retrieve contextual information."
    _client: LinkupClient = PrivateAttr()  # type: ignore
    description: str = (
        "Performs an API call to Linkup to retrieve contextual information."
    )
    _client: LinkupClient = PrivateAttr()  # type: ignore

    def __init__(self, api_key: str):
        """

@@ -25,7 +29,9 @@ class LinkupSearchTool:
        )
        self._client = LinkupClient(api_key=api_key)

    def _run(self, query: str, depth: str = "standard", output_type: str = "searchResults") -> dict:
    def _run(
        self, query: str, depth: str = "standard", output_type: str = "searchResults"
    ) -> dict:
        """
        Executes a search using the Linkup API.

@@ -36,9 +42,7 @@ class LinkupSearchTool:
        """
        try:
            response = self._client.search(
                query=query,
                depth=depth,
                output_type=output_type
                query=query, depth=depth, output_type=output_type
            )
            results = [
                {"name": result.name, "url": result.url, "content": result.content}

@@ -17,9 +17,7 @@ class MySQLSearchToolSchema(BaseModel):

class MySQLSearchTool(RagTool):
    name: str = "Search a database's table content"
    description: str = (
        "A tool that can be used to semantic search a query from a database table's content."
    )
    description: str = "A tool that can be used to semantic search a query from a database table's content."
    args_schema: Type[BaseModel] = MySQLSearchToolSchema
    db_uri: str = Field(..., description="Mandatory database URI")
@@ -1,30 +1,24 @@
from crewai import Agent, Crew, Task
from patronus_eval_tool import (
    PatronusEvalTool,
)
from patronus_local_evaluator_tool import (
    PatronusLocalEvaluatorTool,
)
from patronus_predefined_criteria_eval_tool import (
    PatronusPredefinedCriteriaEvalTool,
)
from patronus import Client, EvaluationResult
import random

from crewai import Agent, Crew, Task
from patronus import Client, EvaluationResult
from patronus_local_evaluator_tool import PatronusLocalEvaluatorTool

# Test the PatronusLocalEvaluatorTool where agent uses the local evaluator
client = Client()


# Example of an evaluator that returns a random pass/fail result
@client.register_local_evaluator("random_evaluator")
def random_evaluator(**kwargs):
    score = random.random()
    return EvaluationResult(
        score_raw=score,
        pass_=score >= 0.5,
        explanation="example explanation"  # Optional justification for LLM judges
        score_raw=score,
        pass_=score >= 0.5,
        explanation="example explanation",  # Optional justification for LLM judges
    )


# 1. Uses PatronusEvalTool: agent can pick the best evaluator and criteria
# patronus_eval_tool = PatronusEvalTool()

@@ -35,7 +29,9 @@ def random_evaluator(**kwargs):

# 3. Uses PatronusLocalEvaluatorTool: agent uses user defined evaluator
patronus_eval_tool = PatronusLocalEvaluatorTool(
    patronus_client=client, evaluator="random_evaluator", evaluated_model_gold_answer="example label"
    patronus_client=client,
    evaluator="random_evaluator",
    evaluated_model_gold_answer="example label",
)

# Create a new agent
@@ -1,8 +1,9 @@
import os
import json
import requests
import os
import warnings
from typing import Any, List, Dict, Optional
from typing import Any, Dict, List, Optional

import requests
from crewai.tools import BaseTool


@@ -19,7 +20,9 @@ class PatronusEvalTool(BaseTool):
        self.evaluators = temp_evaluators
        self.criteria = temp_criteria
        self.description = self._generate_description()
        warnings.warn("You are allowing the agent to select the best evaluator and criteria when you use the `PatronusEvalTool`. If this is not intended then please use `PatronusPredefinedCriteriaEvalTool` instead.")
        warnings.warn(
            "You are allowing the agent to select the best evaluator and criteria when you use the `PatronusEvalTool`. If this is not intended then please use `PatronusPredefinedCriteriaEvalTool` instead."
        )

    def _init_run(self):
        evaluators_set = json.loads(

@@ -104,7 +107,6 @@ class PatronusEvalTool(BaseTool):
        evaluated_model_retrieved_context: Optional[str],
        evaluators: List[Dict[str, str]],
    ) -> Any:

        # Assert correct format of evaluators
        evals = []
        for ev in evaluators:

@@ -136,4 +138,4 @@ class PatronusEvalTool(BaseTool):
                f"Failed to evaluate model input and output. Response status code: {response.status_code}. Reason: {response.text}"
            )

        return response.json()
        return response.json()
@@ -1,7 +1,8 @@
from typing import Any, Type

from crewai.tools import BaseTool
from pydantic import BaseModel, Field
from patronus import Client
from pydantic import BaseModel, Field


class FixedLocalEvaluatorToolSchema(BaseModel):

@@ -24,16 +25,20 @@ class PatronusLocalEvaluatorTool(BaseTool):
    name: str = "Patronus Local Evaluator Tool"
    evaluator: str = "The registered local evaluator"
    evaluated_model_gold_answer: str = "The agent's gold answer"
    description: str = (
        "This tool is used to evaluate the model input and output using custom function evaluators."
    )
    description: str = "This tool is used to evaluate the model input and output using custom function evaluators."
    client: Any = None
    args_schema: Type[BaseModel] = FixedLocalEvaluatorToolSchema

    class Config:
        arbitrary_types_allowed = True

    def __init__(self, patronus_client: Client, evaluator: str, evaluated_model_gold_answer: str, **kwargs: Any):
    def __init__(
        self,
        patronus_client: Client,
        evaluator: str,
        evaluated_model_gold_answer: str,
        **kwargs: Any,
    ):
        super().__init__(**kwargs)
        self.client = patronus_client
        if evaluator:

@@ -79,7 +84,7 @@ class PatronusLocalEvaluatorTool(BaseTool):
                if isinstance(evaluated_model_gold_answer, str)
                else evaluated_model_gold_answer.get("description")
            ),
            tags={},  # Optional metadata, supports arbitrary kv pairs
            tags={},  # Optional metadata, supports arbitrary kv pairs
        )
        output = f"Evaluation result: {result.pass_}, Explanation: {result.explanation}"
        return output
@@ -1,7 +1,8 @@
import os
import json
import os
from typing import Any, Dict, List, Type

import requests
from typing import Any, List, Dict, Type
from crewai.tools import BaseTool
from pydantic import BaseModel, Field

@@ -33,9 +34,7 @@ class PatronusPredefinedCriteriaEvalTool(BaseTool):
    """

    name: str = "Call Patronus API tool for evaluation of model inputs and outputs"
    description: str = (
        """This tool calls the Patronus Evaluation API that takes the following arguments:"""
    )
    description: str = """This tool calls the Patronus Evaluation API that takes the following arguments:"""
    evaluate_url: str = "https://api.patronus.ai/v1/evaluate"
    args_schema: Type[BaseModel] = FixedBaseToolSchema
    evaluators: List[Dict[str, str]] = []

@@ -52,7 +51,6 @@ class PatronusPredefinedCriteriaEvalTool(BaseTool):
        self,
        **kwargs: Any,
    ) -> Any:

        evaluated_model_input = kwargs.get("evaluated_model_input")
        evaluated_model_output = kwargs.get("evaluated_model_output")
        evaluated_model_retrieved_context = kwargs.get(

@@ -103,4 +101,4 @@ class PatronusPredefinedCriteriaEvalTool(BaseTool):
                f"Failed to evaluate model input and output. Status code: {response.status_code}. Reason: {response.text}"
            )

        return response.json()
        return response.json()
@@ -1,7 +1,9 @@
from typing import Optional, Type
from pydantic import BaseModel, Field
from pypdf import PdfReader, PdfWriter, PageObject, ContentStream, NameObject, Font
from pathlib import Path
from typing import Optional, Type

from pydantic import BaseModel, Field
from pypdf import ContentStream, Font, NameObject, PageObject, PdfReader, PdfWriter

from crewai_tools.tools.rag.rag_tool import RagTool


@@ -17,9 +17,7 @@ class PGSearchToolSchema(BaseModel):

class PGSearchTool(RagTool):
    name: str = "Search a database's table content"
    description: str = (
        "A tool that can be used to semantic search a query from a database table's content."
    )
    description: str = "A tool that can be used to semantic search a query from a database table's content."
    args_schema: Type[BaseModel] = PGSearchToolSchema
    db_uri: str = Field(..., description="Mandatory database URI")

@@ -10,8 +10,6 @@ from pydantic import BaseModel, Field
class FixedScrapeElementFromWebsiteToolSchema(BaseModel):
    """Input for ScrapeElementFromWebsiteTool."""

    pass


class ScrapeElementFromWebsiteToolSchema(FixedScrapeElementFromWebsiteToolSchema):
    """Input for ScrapeElementFromWebsiteTool."""

@@ -11,8 +11,6 @@ from pydantic import BaseModel, Field
class FixedScrapeWebsiteToolSchema(BaseModel):
    """Input for ScrapeWebsiteTool."""

    pass


class ScrapeWebsiteToolSchema(FixedScrapeWebsiteToolSchema):
    """Input for ScrapeWebsiteTool."""
@@ -10,17 +10,14 @@ from scrapegraph_py.logger import sgai_logger

class ScrapegraphError(Exception):
    """Base exception for Scrapegraph-related errors"""
    pass


class RateLimitError(ScrapegraphError):
    """Raised when API rate limits are exceeded"""
    pass


class FixedScrapegraphScrapeToolSchema(BaseModel):
    """Input for ScrapegraphScrapeTool when website_url is fixed."""
    pass


class ScrapegraphScrapeToolSchema(FixedScrapegraphScrapeToolSchema):

@@ -32,7 +29,7 @@ class ScrapegraphScrapeToolSchema(FixedScrapegraphScrapeToolSchema):
        description="Prompt to guide the extraction of content",
    )

    @validator('website_url')
    @validator("website_url")
    def validate_url(cls, v):
        """Validate URL format"""
        try:

@@ -41,13 +38,15 @@ class ScrapegraphScrapeToolSchema(FixedScrapegraphScrapeToolSchema):
                raise ValueError
            return v
        except Exception:
            raise ValueError("Invalid URL format. URL must include scheme (http/https) and domain")
            raise ValueError(
                "Invalid URL format. URL must include scheme (http/https) and domain"
            )


class ScrapegraphScrapeTool(BaseTool):
    """
    A tool that uses Scrapegraph AI to intelligently scrape website content.

    Raises:
        ValueError: If API key is missing or URL format is invalid
        RateLimitError: If API rate limits are exceeded

@@ -55,7 +54,9 @@ class ScrapegraphScrapeTool(BaseTool):
    """

    name: str = "Scrapegraph website scraper"
    description: str = "A tool that uses Scrapegraph AI to intelligently scrape website content."
    description: str = (
        "A tool that uses Scrapegraph AI to intelligently scrape website content."
    )
    args_schema: Type[BaseModel] = ScrapegraphScrapeToolSchema
    website_url: Optional[str] = None
    user_prompt: Optional[str] = None

@@ -70,7 +71,7 @@ class ScrapegraphScrapeTool(BaseTool):
    ):
        super().__init__(**kwargs)
        self.api_key = api_key or os.getenv("SCRAPEGRAPH_API_KEY")

        if not self.api_key:
            raise ValueError("Scrapegraph API key is required")

@@ -79,7 +80,7 @@ class ScrapegraphScrapeTool(BaseTool):
            self.website_url = website_url
            self.description = f"A tool that uses Scrapegraph AI to intelligently scrape {website_url}'s content."
            self.args_schema = FixedScrapegraphScrapeToolSchema

        if user_prompt is not None:
            self.user_prompt = user_prompt

@@ -94,22 +95,24 @@ class ScrapegraphScrapeTool(BaseTool):
            if not all([result.scheme, result.netloc]):
                raise ValueError
        except Exception:
            raise ValueError("Invalid URL format. URL must include scheme (http/https) and domain")
            raise ValueError(
                "Invalid URL format. URL must include scheme (http/https) and domain"
            )

    def _handle_api_response(self, response: dict) -> str:
        """Handle and validate API response"""
        if not response:
            raise RuntimeError("Empty response from Scrapegraph API")

        if "error" in response:
            error_msg = response.get("error", {}).get("message", "Unknown error")
            if "rate limit" in error_msg.lower():
                raise RateLimitError(f"Rate limit exceeded: {error_msg}")
            raise RuntimeError(f"API error: {error_msg}")

        if "result" not in response:
            raise RuntimeError("Invalid response format from Scrapegraph API")

        return response["result"]

    def _run(

@@ -117,7 +120,10 @@ class ScrapegraphScrapeTool(BaseTool):
        **kwargs: Any,
    ) -> Any:
        website_url = kwargs.get("website_url", self.website_url)
        user_prompt = kwargs.get("user_prompt", self.user_prompt) or "Extract the main content of the webpage"
        user_prompt = (
            kwargs.get("user_prompt", self.user_prompt)
            or "Extract the main content of the webpage"
        )

        if not website_url:
            raise ValueError("website_url is required")
@@ -17,33 +17,36 @@ class FixedSeleniumScrapingToolSchema(BaseModel):
class SeleniumScrapingToolSchema(FixedSeleniumScrapingToolSchema):
    """Input for SeleniumScrapingTool."""

    website_url: str = Field(..., description="Mandatory website url to read the file. Must start with http:// or https://")
    website_url: str = Field(
        ...,
        description="Mandatory website url to read the file. Must start with http:// or https://",
    )
    css_element: str = Field(
        ...,
        description="Mandatory css reference for element to scrape from the website",
    )

    @validator('website_url')
    @validator("website_url")
    def validate_website_url(cls, v):
        if not v:
            raise ValueError("Website URL cannot be empty")

        if len(v) > 2048:  # Common maximum URL length
            raise ValueError("URL is too long (max 2048 characters)")

        if not re.match(r'^https?://', v):
        if not re.match(r"^https?://", v):
            raise ValueError("URL must start with http:// or https://")

        try:
            result = urlparse(v)
            if not all([result.scheme, result.netloc]):
                raise ValueError("Invalid URL format")
        except Exception as e:
            raise ValueError(f"Invalid URL: {str(e)}")

        if re.search(r'\s', v):
        if re.search(r"\s", v):
            raise ValueError("URL cannot contain whitespace")

        return v


@@ -130,11 +133,11 @@ class SeleniumScrapingTool(BaseTool):
    def _create_driver(self, url, cookie, wait_time):
        if not url:
            raise ValueError("URL cannot be empty")

        # Validate URL format
        if not re.match(r'^https?://', url):
        if not re.match(r"^https?://", url):
            raise ValueError("URL must start with http:// or https://")

        options = Options()
        options.add_argument("--headless")
        driver = self.driver(options=options)
@@ -1,9 +1,10 @@
import os
import re
from typing import Optional, Any, Union
from typing import Any, Optional, Union

from crewai.tools import BaseTool


class SerpApiBaseTool(BaseTool):
    """Base class for SerpApi functionality with shared capabilities."""


@@ -1,14 +1,22 @@
from typing import Any, Type, Optional
from typing import Any, Optional, Type

import re
from pydantic import BaseModel, Field
from .serpapi_base_tool import SerpApiBaseTool
from serpapi import HTTPError
from urllib.error import HTTPError

from .serpapi_base_tool import SerpApiBaseTool


class SerpApiGoogleSearchToolSchema(BaseModel):
    """Input for Google Search."""
    search_query: str = Field(..., description="Mandatory search query you want to use to Google search.")
    location: Optional[str] = Field(None, description="Location you want the search to be performed in.")

    search_query: str = Field(
        ..., description="Mandatory search query you want to use to Google search."
    )
    location: Optional[str] = Field(
        None, description="Location you want the search to be performed in."
    )


class SerpApiGoogleSearchTool(SerpApiBaseTool):
    name: str = "Google Search"

@@ -22,19 +30,25 @@ class SerpApiGoogleSearchTool(SerpApiBaseTool):
        **kwargs: Any,
    ) -> Any:
        try:
            results = self.client.search({
                "q": kwargs.get("search_query"),
                "location": kwargs.get("location"),
            }).as_dict()
            results = self.client.search(
                {
                    "q": kwargs.get("search_query"),
                    "location": kwargs.get("location"),
                }
            ).as_dict()

            self._omit_fields(
                results,
                [r"search_metadata", r"search_parameters", r"serpapi_.+", r".+_token", r"displayed_link", r"pagination"]
                results,
                [
                    r"search_metadata",
                    r"search_parameters",
                    r"serpapi_.+",
                    r".+_token",
                    r"displayed_link",
                    r"pagination",
                ],
            )

            return results
        except HTTPError as e:
            return f"An error occurred: {str(e)}. Some parameters may be invalid."
@@ -1,14 +1,21 @@
from typing import Any, Type, Optional
from typing import Any, Optional, Type

import re
from pydantic import BaseModel, Field
from .serpapi_base_tool import SerpApiBaseTool
from serpapi import HTTPError
from urllib.error import HTTPError

from .serpapi_base_tool import SerpApiBaseTool


class SerpApiGoogleShoppingToolSchema(BaseModel):
    """Input for Google Shopping."""
    search_query: str = Field(..., description="Mandatory search query you want to use to Google shopping.")
    location: Optional[str] = Field(None, description="Location you want the search to be performed in.")

    search_query: str = Field(
        ..., description="Mandatory search query you want to use to Google shopping."
    )
    location: Optional[str] = Field(
        None, description="Location you want the search to be performed in."
    )


class SerpApiGoogleShoppingTool(SerpApiBaseTool):

@@ -23,20 +30,25 @@ class SerpApiGoogleShoppingTool(SerpApiBaseTool):
        **kwargs: Any,
    ) -> Any:
        try:
            results = self.client.search({
                "engine": "google_shopping",
                "q": kwargs.get("search_query"),
                "location": kwargs.get("location")
            }).as_dict()
            results = self.client.search(
                {
                    "engine": "google_shopping",
                    "q": kwargs.get("search_query"),
                    "location": kwargs.get("location"),
                }
            ).as_dict()

            self._omit_fields(
                results,
                [r"search_metadata", r"search_parameters", r"serpapi_.+", r"filters", r"pagination"]
                results,
                [
                    r"search_metadata",
                    r"search_parameters",
                    r"serpapi_.+",
                    r"filters",
                    r"pagination",
                ],
            )

            return results
        except HTTPError as e:
            return f"An error occurred: {str(e)}. Some parameters may be invalid."
@@ -1,19 +1,19 @@
import datetime
import json
import os
import logging
import os
from typing import Any, Type

import requests
from crewai.tools import BaseTool
from pydantic import BaseModel, Field


logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)


def _save_results_to_file(content: str) -> None:
    """Saves the search results to a file."""
    try:
@@ -18,9 +18,7 @@ class SerplyWebpageToMarkdownToolSchema(BaseModel):

class SerplyWebpageToMarkdownTool(RagTool):
    name: str = "Webpage to Markdown"
    description: str = (
        "A tool to perform convert a webpage to markdown to make it easier for LLMs to understand"
    )
    description: str = "A tool to perform convert a webpage to markdown to make it easier for LLMs to understand"
    args_schema: Type[BaseModel] = SerplyWebpageToMarkdownToolSchema
    request_url: str = "https://api.serply.io/v1/request"
    proxy_location: Optional[str] = "US"

src/crewai_tools/tools/snowflake_search_tool/README.md (new file, 155 lines)
@@ -0,0 +1,155 @@
# Snowflake Search Tool

A tool for executing queries on Snowflake data warehouse with built-in connection pooling, retry logic, and async execution support.

## Installation

```bash
uv sync --extra snowflake

# or
uv pip install "snowflake-connector-python>=3.5.0" "snowflake-sqlalchemy>=1.5.0" "cryptography>=41.0.0"

# or
pip install "snowflake-connector-python>=3.5.0" "snowflake-sqlalchemy>=1.5.0" "cryptography>=41.0.0"
```
## Quick Start

```python
import asyncio
from crewai_tools import SnowflakeSearchTool, SnowflakeConfig

# Create configuration
config = SnowflakeConfig(
    account="your_account",
    user="your_username",
    password="your_password",
    warehouse="COMPUTE_WH",
    database="your_database",
    snowflake_schema="your_schema"  # Note: Uses snowflake_schema instead of schema
)

# Initialize tool
tool = SnowflakeSearchTool(
    config=config,
    pool_size=5,
    max_retries=3,
    enable_caching=True
)

# Execute query
async def main():
    results = await tool._run(
        query="SELECT * FROM your_table LIMIT 10",
        timeout=300
    )
    print(f"Retrieved {len(results)} rows")

if __name__ == "__main__":
    asyncio.run(main())
```

## Features

- ✨ Asynchronous query execution
- 🚀 Connection pooling for better performance
- 🔄 Automatic retries for transient failures
- 💾 Query result caching (optional)
- 🔒 Support for both password and key-pair authentication
- 📝 Comprehensive error handling and logging
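When caching is enabled, results are stored in a module-level dictionary keyed by account, database, schema, query text, and timeout (see `_get_cache_key` in the tool source below), so an identical repeated call skips the round-trip to Snowflake. A minimal sketch of that behavior; the table name is illustrative:

```python
async def demo(tool):
    # First call executes on Snowflake; the identical second call is served
    # from the module-level _query_cache, so no second round-trip is made.
    first = await tool._run(query="SELECT COUNT(*) FROM orders")   # hypothetical table
    second = await tool._run(query="SELECT COUNT(*) FROM orders")  # cache hit
    assert first == second
```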
## Configuration Options

### SnowflakeConfig Parameters

| Parameter | Required | Description |
|-----------|----------|-------------|
| account | Yes | Snowflake account identifier |
| user | Yes | Snowflake username |
| password | Yes* | Snowflake password |
| private_key_path | No* | Path to private key file (alternative to password) |
| warehouse | Yes | Snowflake warehouse name |
| database | Yes | Default database |
| snowflake_schema | Yes | Default schema |
| role | No | Snowflake role |
| session_parameters | No | Custom session parameters dict |

\* Either password or private_key_path must be provided

### Tool Parameters

| Parameter | Default | Description |
|-----------|---------|-------------|
| pool_size | 5 | Number of connections in the pool |
| max_retries | 3 | Maximum retry attempts for failed queries |
| retry_delay | 1.0 | Delay between retries in seconds |
| enable_caching | True | Enable/disable query result caching |
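One relationship worth spelling out: `max_retries` and `retry_delay` combine into an exponential backoff, since the tool sleeps `retry_delay * 2 ** attempt` seconds after each failed attempt (see `_execute_query` in the tool source below). A small sketch of the schedule the defaults produce — the numbers are just the arithmetic, not measured behavior:

```python
retry_delay = 1.0
max_retries = 3

# Mirrors `await asyncio.sleep(self.retry_delay * (2 ** attempt))` in the
# tool's retry loop; there is no sleep after the final attempt, which re-raises.
for attempt in range(max_retries - 1):
    print(f"attempt {attempt + 1} failed -> wait {retry_delay * 2 ** attempt:.1f}s")
# attempt 1 failed -> wait 1.0s
# attempt 2 failed -> wait 2.0s
```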
## Advanced Usage

### Using Key-Pair Authentication

```python
config = SnowflakeConfig(
    account="your_account",
    user="your_username",
    private_key_path="/path/to/private_key.p8",
    warehouse="your_warehouse",
    database="your_database",
    snowflake_schema="your_schema"
)
```

### Custom Session Parameters

```python
config = SnowflakeConfig(
    # ... other config parameters ...
    session_parameters={
        "QUERY_TAG": "my_app",
        "TIMEZONE": "America/Los_Angeles"
    }
)
```
## Best Practices

1. **Error Handling**: Always wrap query execution in try-except blocks
2. **Logging**: Enable logging to track query execution and errors
3. **Connection Management**: Use appropriate pool sizes for your workload
4. **Timeouts**: Set reasonable query timeouts to prevent hanging
5. **Security**: Use key-pair auth in production and never hardcode credentials (see the sketch after this list)
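Expanding on point 5, a minimal sketch of loading credentials from the environment instead of source code. The environment variable names are illustrative, not required by the tool; the `password` field is a Pydantic `SecretStr`, so the value is not echoed in reprs or logs:

```python
import os

from crewai_tools import SnowflakeConfig

config = SnowflakeConfig(
    account=os.environ["SNOWFLAKE_ACCOUNT"],    # hypothetical variable names --
    user=os.environ["SNOWFLAKE_USER"],          # use whatever your deployment defines
    password=os.environ["SNOWFLAKE_PASSWORD"],  # coerced to SecretStr by the config model
    warehouse="COMPUTE_WH",
    database="your_database",
    snowflake_schema="your_schema",
)
```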
## Example with Logging

```python
import logging

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

async def main():
    try:
        # ... tool initialization ...
        results = await tool._run(query="SELECT * FROM table LIMIT 10")
        logger.info(f"Query completed successfully. Retrieved {len(results)} rows")
    except Exception as e:
        logger.error(f"Query failed: {str(e)}")
        raise
```
## Error Handling

The tool automatically handles common Snowflake errors:
- DatabaseError
- OperationalError
- ProgrammingError
- Network timeouts
- Connection issues

Errors are logged and retried based on your retry configuration.
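In the retry loop shown in the tool source below, `DatabaseError` and `OperationalError` are the exceptions retried with exponential backoff; once `max_retries` is exhausted, the last exception propagates to the caller. A minimal sketch of handling that case:

```python
from snowflake.connector.errors import DatabaseError, OperationalError

async def run_with_fallback(tool, query: str):
    # Transient errors are retried internally; after the final attempt
    # the exception reaches this handler.
    try:
        return await tool._run(query=query, timeout=300)
    except (DatabaseError, OperationalError) as e:
        print(f"Query failed after all retries: {e}")
        return []
```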
src/crewai_tools/tools/snowflake_search_tool/__init__.py (new file, 11 lines)

@@ -0,0 +1,11 @@
from .snowflake_search_tool import (
    SnowflakeConfig,
    SnowflakeSearchTool,
    SnowflakeSearchToolInput,
)

__all__ = [
    "SnowflakeSearchTool",
    "SnowflakeSearchToolInput",
    "SnowflakeConfig",
]
@@ -0,0 +1,201 @@
import asyncio
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Optional, Type

import snowflake.connector
from crewai.tools.base_tool import BaseTool
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from pydantic import BaseModel, ConfigDict, Field, SecretStr
from snowflake.connector.connection import SnowflakeConnection
from snowflake.connector.errors import DatabaseError, OperationalError

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Cache for query results
_query_cache = {}


class SnowflakeConfig(BaseModel):
    """Configuration for Snowflake connection."""

    model_config = ConfigDict(protected_namespaces=())

    account: str = Field(
        ..., description="Snowflake account identifier", pattern=r"^[a-zA-Z0-9\-_]+$"
    )
    user: str = Field(..., description="Snowflake username")
    password: Optional[SecretStr] = Field(None, description="Snowflake password")
    private_key_path: Optional[str] = Field(
        None, description="Path to private key file"
    )
    warehouse: Optional[str] = Field(None, description="Snowflake warehouse")
    database: Optional[str] = Field(None, description="Default database")
    snowflake_schema: Optional[str] = Field(None, description="Default schema")
    role: Optional[str] = Field(None, description="Snowflake role")
    session_parameters: Optional[Dict[str, Any]] = Field(
        default_factory=dict, description="Session parameters"
    )

    @property
    def has_auth(self) -> bool:
        return bool(self.password or self.private_key_path)

    def model_post_init(self, *args, **kwargs):
        if not self.has_auth:
            raise ValueError("Either password or private_key_path must be provided")


class SnowflakeSearchToolInput(BaseModel):
    """Input schema for SnowflakeSearchTool."""

    model_config = ConfigDict(protected_namespaces=())

    query: str = Field(..., description="SQL query or semantic search query to execute")
    database: Optional[str] = Field(None, description="Override default database")
    snowflake_schema: Optional[str] = Field(None, description="Override default schema")
    timeout: Optional[int] = Field(300, description="Query timeout in seconds")


class SnowflakeSearchTool(BaseTool):
    """Tool for executing queries and semantic search on Snowflake."""

    name: str = "Snowflake Database Search"
    description: str = (
        "Execute SQL queries or semantic search on Snowflake data warehouse. "
        "Supports both raw SQL and natural language queries."
    )
    args_schema: Type[BaseModel] = SnowflakeSearchToolInput

    # Define Pydantic fields
    config: SnowflakeConfig = Field(
        ..., description="Snowflake connection configuration"
    )
    pool_size: int = Field(default=5, description="Size of connection pool")
    max_retries: int = Field(default=3, description="Maximum retry attempts")
    retry_delay: float = Field(
        default=1.0, description="Delay between retries in seconds"
    )
    enable_caching: bool = Field(
        default=True, description="Enable query result caching"
    )

    model_config = ConfigDict(arbitrary_types_allowed=True)

    def __init__(self, **data):
        """Initialize SnowflakeSearchTool."""
        super().__init__(**data)
        self._connection_pool: List[SnowflakeConnection] = []
        self._pool_lock = asyncio.Lock()
        self._thread_pool = ThreadPoolExecutor(max_workers=self.pool_size)

    async def _get_connection(self) -> SnowflakeConnection:
        """Get a connection from the pool or create a new one."""
        async with self._pool_lock:
            if not self._connection_pool:
                conn = self._create_connection()
                self._connection_pool.append(conn)
            return self._connection_pool.pop()

    def _create_connection(self) -> SnowflakeConnection:
        """Create a new Snowflake connection."""
        conn_params = {
            "account": self.config.account,
            "user": self.config.user,
            "warehouse": self.config.warehouse,
            "database": self.config.database,
            "schema": self.config.snowflake_schema,
            "role": self.config.role,
            "session_parameters": self.config.session_parameters,
        }

        if self.config.password:
            conn_params["password"] = self.config.password.get_secret_value()
        elif self.config.private_key_path:
            with open(self.config.private_key_path, "rb") as key_file:
                p_key = serialization.load_pem_private_key(
                    key_file.read(), password=None, backend=default_backend()
                )
                conn_params["private_key"] = p_key

        return snowflake.connector.connect(**conn_params)

    def _get_cache_key(self, query: str, timeout: int) -> str:
        """Generate a cache key for the query."""
        return f"{self.config.account}:{self.config.database}:{self.config.snowflake_schema}:{query}:{timeout}"

    async def _execute_query(
        self, query: str, timeout: int = 300
    ) -> List[Dict[str, Any]]:
        """Execute a query with retries and return results."""
        if self.enable_caching:
            cache_key = self._get_cache_key(query, timeout)
            if cache_key in _query_cache:
                logger.info("Returning cached result")
                return _query_cache[cache_key]

        for attempt in range(self.max_retries):
            try:
                conn = await self._get_connection()
                try:
                    cursor = conn.cursor()
                    cursor.execute(query, timeout=timeout)

                    if not cursor.description:
                        return []

                    columns = [col[0] for col in cursor.description]
                    results = [dict(zip(columns, row)) for row in cursor.fetchall()]

                    if self.enable_caching:
                        _query_cache[self._get_cache_key(query, timeout)] = results

                    return results
                finally:
                    cursor.close()
                    async with self._pool_lock:
                        self._connection_pool.append(conn)
            except (DatabaseError, OperationalError) as e:
                if attempt == self.max_retries - 1:
                    raise
                await asyncio.sleep(self.retry_delay * (2**attempt))
                logger.warning(f"Query failed, attempt {attempt + 1}: {str(e)}")
                continue

    async def _run(
        self,
        query: str,
        database: Optional[str] = None,
        snowflake_schema: Optional[str] = None,
        timeout: int = 300,
        **kwargs: Any,
    ) -> Any:
        """Execute the search query."""
        try:
            # Override database/schema if provided
            if database:
                await self._execute_query(f"USE DATABASE {database}")
            if snowflake_schema:
                await self._execute_query(f"USE SCHEMA {snowflake_schema}")

            results = await self._execute_query(query, timeout)
            return results
        except Exception as e:
            logger.error(f"Error executing query: {str(e)}")
            raise

    def __del__(self):
        """Cleanup connections on deletion."""
        try:
            for conn in getattr(self, "_connection_pool", []):
                try:
                    conn.close()
                except:
                    pass
            if hasattr(self, "_thread_pool"):
                self._thread_pool.shutdown()
        except:
            pass
@@ -14,9 +14,8 @@ import os
|
||||
from functools import lru_cache
|
||||
from typing import Any, Dict, List, Optional, Type, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# Set up logging
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -25,6 +24,7 @@ logger = logging.getLogger(__name__)
|
||||
STAGEHAND_AVAILABLE = False
|
||||
try:
|
||||
import stagehand
|
||||
|
||||
STAGEHAND_AVAILABLE = True
|
||||
except ImportError:
|
||||
pass # Keep STAGEHAND_AVAILABLE as False
|
||||
@@ -32,33 +32,45 @@ except ImportError:
|
||||
|
||||
class StagehandResult(BaseModel):
|
||||
"""Result from a Stagehand operation.
|
||||
|
||||
|
||||
Attributes:
|
||||
success: Whether the operation completed successfully
|
||||
data: The result data from the operation
|
||||
error: Optional error message if the operation failed
|
||||
"""
|
||||
success: bool = Field(..., description="Whether the operation completed successfully")
|
||||
data: Union[str, Dict, List] = Field(..., description="The result data from the operation")
|
||||
error: Optional[str] = Field(None, description="Optional error message if the operation failed")
|
||||
|
||||
success: bool = Field(
|
||||
..., description="Whether the operation completed successfully"
|
||||
)
|
||||
data: Union[str, Dict, List] = Field(
|
||||
..., description="The result data from the operation"
|
||||
)
|
||||
error: Optional[str] = Field(
|
||||
None, description="Optional error message if the operation failed"
|
||||
)
|
||||
|
||||
|
||||
class StagehandToolConfig(BaseModel):
|
||||
"""Configuration for the StagehandTool.
|
||||
|
||||
|
||||
Attributes:
|
||||
api_key: OpenAI API key for Stagehand authentication
|
||||
timeout: Maximum time in seconds to wait for operations (default: 30)
|
||||
retry_attempts: Number of times to retry failed operations (default: 3)
|
||||
"""
|
||||
|
||||
api_key: str = Field(..., description="OpenAI API key for Stagehand authentication")
|
||||
timeout: int = Field(30, description="Maximum time in seconds to wait for operations")
|
||||
retry_attempts: int = Field(3, description="Number of times to retry failed operations")
|
||||
timeout: int = Field(
|
||||
30, description="Maximum time in seconds to wait for operations"
|
||||
)
|
||||
retry_attempts: int = Field(
|
||||
3, description="Number of times to retry failed operations"
|
||||
)
|
||||
|
||||
|
||||
class StagehandToolSchema(BaseModel):
|
||||
"""Schema for the StagehandTool input parameters.
|
||||
|
||||
|
||||
Examples:
|
||||
```python
|
||||
# Using the 'act' API to click a button
|
||||
@@ -66,13 +78,13 @@ class StagehandToolSchema(BaseModel):
|
||||
api_method="act",
|
||||
instruction="Click the 'Sign In' button"
|
||||
)
|
||||
|
||||
|
||||
# Using the 'extract' API to get text
|
||||
tool.run(
|
||||
api_method="extract",
|
||||
instruction="Get the text content of the main article"
|
||||
)
|
||||
|
||||
|
||||
# Using the 'observe' API to monitor changes
|
||||
tool.run(
|
||||
api_method="observe",
|
||||
@@ -80,48 +92,49 @@ class StagehandToolSchema(BaseModel):
|
||||
)
|
||||
```
|
||||
"""
|
||||
|
||||
api_method: str = Field(
|
||||
...,
|
||||
description="The Stagehand API to use: 'act' for interactions, 'extract' for getting content, or 'observe' for monitoring changes",
|
||||
pattern="^(act|extract|observe)$"
|
||||
pattern="^(act|extract|observe)$",
|
||||
)
|
||||
instruction: str = Field(
|
||||
...,
|
||||
description="An atomic instruction for Stagehand to execute. Instructions should be simple and specific to increase reliability.",
|
||||
min_length=1,
|
||||
max_length=500
|
||||
max_length=500,
|
||||
)
|
||||
|
||||
|
||||
class StagehandTool(BaseTool):
|
||||
"""A tool for using Stagehand's AI-powered web automation capabilities.
|
||||
|
||||
|
||||
This tool provides access to Stagehand's three core APIs:
|
||||
- act: Perform web interactions (e.g., clicking buttons, filling forms)
|
||||
- extract: Extract information from web pages (e.g., getting text content)
|
||||
- observe: Monitor web page changes (e.g., watching for updates)
|
||||
|
||||
|
||||
Each function takes atomic instructions to increase reliability.
|
||||
|
||||
|
||||
Required Environment Variables:
|
||||
OPENAI_API_KEY: API key for OpenAI (required by Stagehand)
|
||||
|
||||
|
||||
Examples:
|
||||
```python
|
||||
tool = StagehandTool()
|
||||
|
||||
|
||||
# Perform a web interaction
|
||||
result = tool.run(
|
||||
api_method="act",
|
||||
instruction="Click the 'Sign In' button"
|
||||
)
|
||||
|
||||
|
||||
# Extract content from a page
|
||||
content = tool.run(
|
||||
api_method="extract",
|
||||
instruction="Get the text content of the main article"
|
||||
)
|
||||
|
||||
|
||||
# Monitor for changes
|
||||
changes = tool.run(
|
||||
api_method="observe",
|
||||
@@ -129,7 +142,7 @@ class StagehandTool(BaseTool):
|
||||
)
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
name: str = "StagehandTool"
|
||||
description: str = (
|
||||
"A tool that uses Stagehand's AI-powered web automation to interact with websites. "
|
||||
@@ -137,27 +150,29 @@ class StagehandTool(BaseTool):
|
||||
"Each instruction should be atomic (simple and specific) to increase reliability."
|
||||
)
|
||||
args_schema: Type[BaseModel] = StagehandToolSchema
|
||||
|
||||
def __init__(self, config: StagehandToolConfig | None = None, **kwargs: Any) -> None:
|
||||
|
||||
def __init__(
|
||||
self, config: StagehandToolConfig | None = None, **kwargs: Any
|
||||
) -> None:
|
||||
"""Initialize the StagehandTool.
|
||||
|
||||
|
||||
Args:
|
||||
config: Optional configuration for the tool. If not provided,
|
||||
will attempt to use OPENAI_API_KEY from environment.
|
||||
**kwargs: Additional keyword arguments passed to the base class.
|
||||
|
||||
|
||||
Raises:
|
||||
ImportError: If the stagehand package is not installed
|
||||
ValueError: If no API key is provided via config or environment
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
if not STAGEHAND_AVAILABLE:
|
||||
raise ImportError(
|
||||
"The 'stagehand' package is required to use this tool. "
|
||||
"Please install it with: pip install stagehand"
|
||||
)
|
||||
|
||||
|
||||
# Use config if provided, otherwise try environment variable
|
||||
if config is not None:
|
||||
self.config = config
|
||||
@@ -168,24 +183,22 @@ class StagehandTool(BaseTool):
|
||||
"Either provide config with api_key or set OPENAI_API_KEY environment variable"
|
||||
)
|
||||
self.config = StagehandToolConfig(
|
||||
api_key=api_key,
|
||||
timeout=30,
|
||||
retry_attempts=3
|
||||
api_key=api_key, timeout=30, retry_attempts=3
|
||||
)
|
||||
|
||||
|
||||
@lru_cache(maxsize=100)
|
||||
def _cached_run(self, api_method: str, instruction: str) -> Any:
|
||||
"""Execute a cached Stagehand command.
|
||||
|
||||
|
||||
This method is cached to improve performance for repeated operations.
|
||||
|
||||
|
||||
Args:
|
||||
api_method: The Stagehand API to use ('act', 'extract', or 'observe')
|
||||
instruction: An atomic instruction for Stagehand to execute
|
||||
|
||||
|
||||
Returns:
|
||||
The raw result from the Stagehand API call
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: If an invalid api_method is provided
|
||||
Exception: If the Stagehand API call fails
|
||||
@@ -193,23 +206,25 @@ class StagehandTool(BaseTool):
|
||||
logger.debug(
|
||||
"Cache operation - Method: %s, Instruction length: %d",
|
||||
api_method,
|
||||
len(instruction)
|
||||
len(instruction),
|
||||
)
|
||||
|
||||
|
||||
# Initialize Stagehand with configuration
|
||||
logger.info(
|
||||
"Initializing Stagehand (timeout=%ds, retries=%d)",
|
||||
self.config.timeout,
|
||||
self.config.retry_attempts
|
||||
self.config.retry_attempts,
|
||||
)
|
||||
st = stagehand.Stagehand(
|
||||
api_key=self.config.api_key,
|
||||
timeout=self.config.timeout,
|
||||
retry_attempts=self.config.retry_attempts
|
||||
            retry_attempts=self.config.retry_attempts,
        )

        # Call the appropriate Stagehand API based on the method
        logger.info("Executing %s operation with instruction: %s", api_method, instruction[:100])
        logger.info(
            "Executing %s operation with instruction: %s", api_method, instruction[:100]
        )
        try:
            if api_method == "act":
                result = st.act(instruction)
@@ -219,28 +234,27 @@ class StagehandTool(BaseTool):
                result = st.observe(instruction)
            else:
                raise ValueError(f"Unknown api_method: {api_method}")

            logger.info("Successfully executed %s operation", api_method)
            return result

        except Exception as e:
            logger.warning(
                "Operation failed (method=%s, error=%s), will be retried on next attempt",
                api_method,
                str(e)
                str(e),
            )
            raise

    def _run(self, api_method: str, instruction: str, **kwargs: Any) -> StagehandResult:
        """Execute a Stagehand command using the specified API method.

        Args:
            api_method: The Stagehand API to use ('act', 'extract', or 'observe')
            instruction: An atomic instruction for Stagehand to execute
            **kwargs: Additional keyword arguments passed to the Stagehand API

        Returns:

        Returns:
            StagehandResult containing the operation result and status
        """
        try:
@@ -249,56 +263,36 @@ class StagehandTool(BaseTool):
                "Starting operation - Method: %s, Instruction length: %d, Args: %s",
                api_method,
                len(instruction),
                kwargs
                kwargs,
            )

            # Use cached execution
            result = self._cached_run(api_method, instruction)
            logger.info("Operation completed successfully")
            return StagehandResult(success=True, data=result)

        except stagehand.AuthenticationError as e:
            logger.error(
                "Authentication failed - Method: %s, Error: %s",
                api_method,
                str(e)
                "Authentication failed - Method: %s, Error: %s", api_method, str(e)
            )
            return StagehandResult(
                success=False,
                data={},
                error=f"Authentication failed: {str(e)}"
                success=False, data={}, error=f"Authentication failed: {str(e)}"
            )
        except stagehand.APIError as e:
            logger.error(
                "API error - Method: %s, Error: %s",
                api_method,
                str(e)
            )
            return StagehandResult(
                success=False,
                data={},
                error=f"API error: {str(e)}"
            )
            logger.error("API error - Method: %s, Error: %s", api_method, str(e))
            return StagehandResult(success=False, data={}, error=f"API error: {str(e)}")
        except stagehand.BrowserError as e:
            logger.error(
                "Browser error - Method: %s, Error: %s",
                api_method,
                str(e)
            )
            logger.error("Browser error - Method: %s, Error: %s", api_method, str(e))
            return StagehandResult(
                success=False,
                data={},
                error=f"Browser error: {str(e)}"
                success=False, data={}, error=f"Browser error: {str(e)}"
            )
        except Exception as e:
            logger.error(
                "Unexpected error - Method: %s, Error type: %s, Message: %s",
                api_method,
                type(e).__name__,
                str(e)
                str(e),
            )
            return StagehandResult(
                success=False,
                data={},
                error=f"Unexpected error: {str(e)}"
                success=False, data={}, error=f"Unexpected error: {str(e)}"
            )
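Note: as the hunk above shows, `_run` reports failures through the returned `StagehandResult` instead of raising, so callers branch on `success`. A minimal consumer sketch (the default construction and the instruction text are illustrative assumptions, not part of this diff):

    tool = StagehandTool()
    result = tool._run(api_method="act", instruction="Click the login button")
    if result.success:
        print(result.data)  # payload returned by the Stagehand API call
    else:
        print(f"Stagehand operation failed: {result.error}")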
@@ -1,30 +1,36 @@
import base64
from typing import Type, Optional
from pathlib import Path
from typing import Optional, Type

from crewai.tools import BaseTool
from openai import OpenAI
from pydantic import BaseModel, validator


class ImagePromptSchema(BaseModel):
    """Input for Vision Tool."""

    image_path_url: str = "The image path or URL."

    @validator("image_path_url")
    def validate_image_path_url(cls, v: str) -> str:
        if v.startswith("http"):
            return v

        path = Path(v)
        if not path.exists():
            raise ValueError(f"Image file does not exist: {v}")

        # Validate supported formats
        valid_extensions = {".jpg", ".jpeg", ".png", ".gif", ".webp"}
        if path.suffix.lower() not in valid_extensions:
            raise ValueError(f"Unsupported image format. Supported formats: {valid_extensions}")
            raise ValueError(
                f"Unsupported image format. Supported formats: {valid_extensions}"
            )

        return v


class VisionTool(BaseTool):
    name: str = "Vision Tool"
    description: str = (
@@ -45,10 +51,10 @@ class VisionTool(BaseTool):
        image_path_url = kwargs.get("image_path_url")
        if not image_path_url:
            return "Image Path or URL is required."

        # Validate input using Pydantic
        ImagePromptSchema(image_path_url=image_path_url)

        if image_path_url.startswith("http"):
            image_data = image_path_url
        else:
@@ -68,12 +74,12 @@ class VisionTool(BaseTool):
                        {
                            "type": "image_url",
                            "image_url": {"url": image_data},
                        }
                        },
                    ],
                }
            ],
            max_tokens=300,
        )
        )

        return response.choices[0].message.content
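A minimal usage sketch for the tool above (the image path is an illustrative assumption, and an OpenAI API key is assumed to be configured in the environment):

    tool = VisionTool()
    # Local paths are validated by ImagePromptSchema before the image is sent on;
    # HTTP(S) URLs are passed through unchanged, per the validator in this hunk.
    description = tool.run(image_path_url="./example.png")
    print(description)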
@@ -15,9 +15,8 @@ except ImportError:
    Vectorizers = Any
    Auth = Any

from pydantic import BaseModel, Field

from crewai.tools import BaseTool
from pydantic import BaseModel, Field


class WeaviateToolSchema(BaseModel):
@@ -25,9 +25,7 @@ class WebsiteSearchToolSchema(FixedWebsiteSearchToolSchema):

class WebsiteSearchTool(RagTool):
    name: str = "Search in a specific website"
    description: str = (
        "A tool that can be used to semantic search a query from a specific URL content."
    )
    description: str = "A tool that can be used to semantic search a query from a specific URL content."
    args_schema: Type[BaseModel] = WebsiteSearchToolSchema

    def __init__(self, website: Optional[str] = None, **kwargs):
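A minimal usage sketch for the RAG tool above (the URL is illustrative, and the `search_query` argument name is an assumption about the fixed schema, which is not shown in this hunk):

    tool = WebsiteSearchTool(website="https://docs.crewai.com")
    answer = tool.run(search_query="How do I define a custom tool?")  # assumed field name
    print(answer)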
@@ -25,9 +25,7 @@ class YoutubeChannelSearchToolSchema(FixedYoutubeChannelSearchToolSchema):

class YoutubeChannelSearchTool(RagTool):
    name: str = "Search a Youtube Channels content"
    description: str = (
        "A tool that can be used to semantic search a query from a Youtube Channels content."
    )
    description: str = "A tool that can be used to semantic search a query from a Youtube Channels content."
    args_schema: Type[BaseModel] = YoutubeChannelSearchToolSchema

    def __init__(self, youtube_channel_handle: Optional[str] = None, **kwargs):
@@ -25,9 +25,7 @@ class YoutubeVideoSearchToolSchema(FixedYoutubeVideoSearchToolSchema):

class YoutubeVideoSearchTool(RagTool):
    name: str = "Search a Youtube Video content"
    description: str = (
        "A tool that can be used to semantic search a query from a Youtube Video content."
    )
    description: str = "A tool that can be used to semantic search a query from a Youtube Video content."
    args_schema: Type[BaseModel] = YoutubeVideoSearchToolSchema

    def __init__(self, youtube_video_url: Optional[str] = None, **kwargs):
@@ -1,69 +1,104 @@
from typing import Callable

from crewai.tools import BaseTool, tool
from crewai.tools.base_tool import to_langchain


def test_creating_a_tool_using_annotation():
    @tool("Name of my tool")
    def my_tool(question: str) -> str:
        """Clear description for what this tool is useful for, you agent will need this information to use it."""
        return question
    @tool("Name of my tool")
    def my_tool(question: str) -> str:
        """Clear description for what this tool is useful for, you agent will need this information to use it."""
        return question

    # Assert all the right attributes were defined
    assert my_tool.name == "Name of my tool"
    assert my_tool.description == "Tool Name: Name of my tool\nTool Arguments: {'question': {'description': None, 'type': 'str'}}\nTool Description: Clear description for what this tool is useful for, you agent will need this information to use it."
    assert my_tool.args_schema.schema()["properties"] == {'question': {'title': 'Question', 'type': 'string'}}
    assert my_tool.func("What is the meaning of life?") == "What is the meaning of life?"
    # Assert all the right attributes were defined
    assert my_tool.name == "Name of my tool"
    assert (
        my_tool.description
        == "Tool Name: Name of my tool\nTool Arguments: {'question': {'description': None, 'type': 'str'}}\nTool Description: Clear description for what this tool is useful for, you agent will need this information to use it."
    )
    assert my_tool.args_schema.schema()["properties"] == {
        "question": {"title": "Question", "type": "string"}
    }
    assert (
        my_tool.func("What is the meaning of life?") == "What is the meaning of life?"
    )

    # Assert the langchain tool conversion worked as expected
    converted_tool = to_langchain([my_tool])[0]
    assert converted_tool.name == "Name of my tool"
    assert (
        converted_tool.description
        == "Tool Name: Name of my tool\nTool Arguments: {'question': {'description': None, 'type': 'str'}}\nTool Description: Clear description for what this tool is useful for, you agent will need this information to use it."
    )
    assert converted_tool.args_schema.schema()["properties"] == {
        "question": {"title": "Question", "type": "string"}
    }
    assert (
        converted_tool.func("What is the meaning of life?")
        == "What is the meaning of life?"
    )

    # Assert the langchain tool conversion worked as expected
    converted_tool = to_langchain([my_tool])[0]
    assert converted_tool.name == "Name of my tool"
    assert converted_tool.description == "Tool Name: Name of my tool\nTool Arguments: {'question': {'description': None, 'type': 'str'}}\nTool Description: Clear description for what this tool is useful for, you agent will need this information to use it."
    assert converted_tool.args_schema.schema()["properties"] == {'question': {'title': 'Question', 'type': 'string'}}
    assert converted_tool.func("What is the meaning of life?") == "What is the meaning of life?"


def test_creating_a_tool_using_baseclass():
    class MyCustomTool(BaseTool):
        name: str = "Name of my tool"
        description: str = "Clear description for what this tool is useful for, you agent will need this information to use it."
    class MyCustomTool(BaseTool):
        name: str = "Name of my tool"
        description: str = "Clear description for what this tool is useful for, you agent will need this information to use it."

        def _run(self, question: str) -> str:
            return question
        def _run(self, question: str) -> str:
            return question

    my_tool = MyCustomTool()
    # Assert all the right attributes were defined
    assert my_tool.name == "Name of my tool"
    assert my_tool.description == "Tool Name: Name of my tool\nTool Arguments: {'question': {'description': None, 'type': 'str'}}\nTool Description: Clear description for what this tool is useful for, you agent will need this information to use it."
    assert my_tool.args_schema.schema()["properties"] == {'question': {'title': 'Question', 'type': 'string'}}
    assert my_tool._run("What is the meaning of life?") == "What is the meaning of life?"
    my_tool = MyCustomTool()
    # Assert all the right attributes were defined
    assert my_tool.name == "Name of my tool"
    assert (
        my_tool.description
        == "Tool Name: Name of my tool\nTool Arguments: {'question': {'description': None, 'type': 'str'}}\nTool Description: Clear description for what this tool is useful for, you agent will need this information to use it."
    )
    assert my_tool.args_schema.schema()["properties"] == {
        "question": {"title": "Question", "type": "string"}
    }
    assert (
        my_tool._run("What is the meaning of life?") == "What is the meaning of life?"
    )

    # Assert the langchain tool conversion worked as expected
    converted_tool = to_langchain([my_tool])[0]
    assert converted_tool.name == "Name of my tool"
    assert (
        converted_tool.description
        == "Tool Name: Name of my tool\nTool Arguments: {'question': {'description': None, 'type': 'str'}}\nTool Description: Clear description for what this tool is useful for, you agent will need this information to use it."
    )
    assert converted_tool.args_schema.schema()["properties"] == {
        "question": {"title": "Question", "type": "string"}
    }
    assert (
        converted_tool.invoke({"question": "What is the meaning of life?"})
        == "What is the meaning of life?"
    )

    # Assert the langchain tool conversion worked as expected
    converted_tool = to_langchain([my_tool])[0]
    assert converted_tool.name == "Name of my tool"
    assert converted_tool.description == "Tool Name: Name of my tool\nTool Arguments: {'question': {'description': None, 'type': 'str'}}\nTool Description: Clear description for what this tool is useful for, you agent will need this information to use it."
    assert converted_tool.args_schema.schema()["properties"] == {'question': {'title': 'Question', 'type': 'string'}}
    assert converted_tool.invoke({"question": "What is the meaning of life?"}) == "What is the meaning of life?"


def test_setting_cache_function():
    class MyCustomTool(BaseTool):
        name: str = "Name of my tool"
        description: str = "Clear description for what this tool is useful for, you agent will need this information to use it."
        cache_function: Callable = lambda: False
    class MyCustomTool(BaseTool):
        name: str = "Name of my tool"
        description: str = "Clear description for what this tool is useful for, you agent will need this information to use it."
        cache_function: Callable = lambda: False

        def _run(self, question: str) -> str:
            return question
        def _run(self, question: str) -> str:
            return question

    my_tool = MyCustomTool()
    # Assert all the right attributes were defined
    assert my_tool.cache_function() == False

    my_tool = MyCustomTool()
    # Assert all the right attributes were defined
    assert my_tool.cache_function() == False


def test_default_cache_function_is_true():
    class MyCustomTool(BaseTool):
        name: str = "Name of my tool"
        description: str = "Clear description for what this tool is useful for, you agent will need this information to use it."
    class MyCustomTool(BaseTool):
        name: str = "Name of my tool"
        description: str = "Clear description for what this tool is useful for, you agent will need this information to use it."

        def _run(self, question: str) -> str:
            return question
        def _run(self, question: str) -> str:
            return question

    my_tool = MyCustomTool()
    # Assert all the right attributes were defined
    assert my_tool.cache_function() == True
    my_tool = MyCustomTool()
    # Assert all the right attributes were defined
    assert my_tool.cache_function() == True
@@ -1,7 +1,8 @@
import os

import pytest

from crewai_tools import FileReadTool


def test_file_read_tool_constructor():
    """Test FileReadTool initialization with file_path."""
    # Create a temporary test file
@@ -18,6 +19,7 @@ def test_file_read_tool_constructor():
    # Clean up
    os.remove(test_file)


def test_file_read_tool_run():
    """Test FileReadTool _run method with file_path at runtime."""
    # Create a temporary test file
@@ -34,6 +36,7 @@ def test_file_read_tool_run():
    # Clean up
    os.remove(test_file)


def test_file_read_tool_error_handling():
    """Test FileReadTool error handling."""
    # Test missing file path
@@ -58,6 +61,7 @@ def test_file_read_tool_error_handling():
    os.chmod(test_file, 0o666)  # Restore permissions to delete
    os.remove(test_file)


def test_file_read_tool_constructor_and_run():
    """Test FileReadTool using both constructor and runtime file paths."""
    # Create two test files
0 tests/it/tools/__init__.py Normal file
21 tests/it/tools/conftest.py Normal file
@@ -0,0 +1,21 @@
import pytest


def pytest_configure(config):
    """Register custom markers."""
    config.addinivalue_line("markers", "integration: mark test as an integration test")
    config.addinivalue_line("markers", "asyncio: mark test as an async test")

    # Set the asyncio loop scope through ini configuration
    config.inicfg["asyncio_mode"] = "auto"


@pytest.fixture(scope="function")
def event_loop():
    """Create an instance of the default event loop for each test case."""
    import asyncio

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    yield loop
    loop.close()
219 tests/it/tools/snowflake_search_tool_test.py Normal file
@@ -0,0 +1,219 @@
import asyncio
import json
from decimal import Decimal

import pytest
from snowflake.connector.errors import DatabaseError, OperationalError

from crewai_tools import SnowflakeConfig, SnowflakeSearchTool

# Test Data
MENU_ITEMS = [
    (10001, "Ice Cream", "Freezing Point", "Lemonade", "Beverage", "Cold Option", 1, 4),
    (
        10002,
        "Ice Cream",
        "Freezing Point",
        "Vanilla Ice Cream",
        "Dessert",
        "Ice Cream",
        2,
        6,
    ),
]

INVALID_QUERIES = [
    ("SELECT * FROM nonexistent_table", "relation 'nonexistent_table' does not exist"),
    ("SELECT invalid_column FROM menu", "invalid identifier 'invalid_column'"),
    ("INVALID SQL QUERY", "SQL compilation error"),
]


# Integration Test Fixtures
@pytest.fixture
def config():
    """Create a Snowflake configuration with test credentials."""
    return SnowflakeConfig(
        account="lwyhjun-wx11931",
        user="crewgitci",
        password="crewaiT00ls_publicCIpass123",
        warehouse="COMPUTE_WH",
        database="tasty_bytes_sample_data",
        snowflake_schema="raw_pos",
    )


@pytest.fixture
def snowflake_tool(config):
    """Create a SnowflakeSearchTool instance."""
    return SnowflakeSearchTool(config=config)


# Integration Tests with Real Snowflake Connection
@pytest.mark.integration
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "menu_id,expected_type,brand,item_name,category,subcategory,cost,price", MENU_ITEMS
)
async def test_menu_items(
    snowflake_tool,
    menu_id,
    expected_type,
    brand,
    item_name,
    category,
    subcategory,
    cost,
    price,
):
    """Test menu items with parameterized data for multiple test cases."""
    results = await snowflake_tool._run(
        query=f"SELECT * FROM menu WHERE menu_id = {menu_id}"
    )
    assert len(results) == 1
    menu_item = results[0]

    # Validate all fields
    assert menu_item["MENU_ID"] == menu_id
    assert menu_item["MENU_TYPE"] == expected_type
    assert menu_item["TRUCK_BRAND_NAME"] == brand
    assert menu_item["MENU_ITEM_NAME"] == item_name
    assert menu_item["ITEM_CATEGORY"] == category
    assert menu_item["ITEM_SUBCATEGORY"] == subcategory
    assert menu_item["COST_OF_GOODS_USD"] == cost
    assert menu_item["SALE_PRICE_USD"] == price

    # Validate health metrics JSON structure
    health_metrics = json.loads(menu_item["MENU_ITEM_HEALTH_METRICS_OBJ"])
    assert "menu_item_health_metrics" in health_metrics
    metrics = health_metrics["menu_item_health_metrics"][0]
    assert "ingredients" in metrics
    assert isinstance(metrics["ingredients"], list)
    assert all(isinstance(ingredient, str) for ingredient in metrics["ingredients"])
    assert metrics["is_dairy_free_flag"] in ["Y", "N"]


@pytest.mark.integration
@pytest.mark.asyncio
async def test_menu_categories_aggregation(snowflake_tool):
    """Test complex aggregation query on menu categories with detailed validations."""
    results = await snowflake_tool._run(
        query="""
        SELECT
            item_category,
            COUNT(*) as item_count,
            AVG(sale_price_usd) as avg_price,
            SUM(sale_price_usd - cost_of_goods_usd) as total_margin,
            COUNT(DISTINCT menu_type) as menu_type_count,
            MIN(sale_price_usd) as min_price,
            MAX(sale_price_usd) as max_price
        FROM menu
        GROUP BY item_category
        HAVING COUNT(*) > 1
        ORDER BY item_count DESC
        """
    )

    assert len(results) > 0
    for category in results:
        # Basic presence checks
        assert all(
            key in category
            for key in [
                "ITEM_CATEGORY",
                "ITEM_COUNT",
                "AVG_PRICE",
                "TOTAL_MARGIN",
                "MENU_TYPE_COUNT",
                "MIN_PRICE",
                "MAX_PRICE",
            ]
        )

        # Value validations
        assert category["ITEM_COUNT"] > 1  # Due to HAVING clause
        assert category["MIN_PRICE"] <= category["MAX_PRICE"]
        assert category["AVG_PRICE"] >= category["MIN_PRICE"]
        assert category["AVG_PRICE"] <= category["MAX_PRICE"]
        assert category["MENU_TYPE_COUNT"] >= 1
        assert isinstance(category["TOTAL_MARGIN"], (float, Decimal))


@pytest.mark.integration
@pytest.mark.asyncio
@pytest.mark.parametrize("invalid_query,expected_error", INVALID_QUERIES)
async def test_invalid_queries(snowflake_tool, invalid_query, expected_error):
    """Test error handling for invalid queries."""
    with pytest.raises((DatabaseError, OperationalError)) as exc_info:
        await snowflake_tool._run(query=invalid_query)
    assert expected_error.lower() in str(exc_info.value).lower()


@pytest.mark.integration
@pytest.mark.asyncio
async def test_concurrent_queries(snowflake_tool):
    """Test handling of concurrent queries."""
    queries = [
        "SELECT COUNT(*) FROM menu",
        "SELECT COUNT(DISTINCT menu_type) FROM menu",
        "SELECT COUNT(DISTINCT item_category) FROM menu",
    ]

    tasks = [snowflake_tool._run(query=query) for query in queries]
    results = await asyncio.gather(*tasks)

    assert len(results) == 3
    assert all(isinstance(result, list) for result in results)
    assert all(len(result) == 1 for result in results)
    assert all(isinstance(result[0], dict) for result in results)


@pytest.mark.integration
@pytest.mark.asyncio
async def test_query_timeout(snowflake_tool):
    """Test query timeout handling with a complex query."""
    with pytest.raises((DatabaseError, OperationalError)) as exc_info:
        await snowflake_tool._run(
            query="""
            WITH RECURSIVE numbers AS (
                SELECT 1 as n
                UNION ALL
                SELECT n + 1
                FROM numbers
                WHERE n < 1000000
            )
            SELECT COUNT(*) FROM numbers
            """
        )
    assert (
        "timeout" in str(exc_info.value).lower()
        or "execution time" in str(exc_info.value).lower()
    )


@pytest.mark.integration
@pytest.mark.asyncio
async def test_caching_behavior(snowflake_tool):
    """Test query caching behavior and performance."""
    query = "SELECT * FROM menu LIMIT 5"

    # First execution
    start_time = asyncio.get_event_loop().time()
    results1 = await snowflake_tool._run(query=query)
    first_duration = asyncio.get_event_loop().time() - start_time

    # Second execution (should be cached)
    start_time = asyncio.get_event_loop().time()
    results2 = await snowflake_tool._run(query=query)
    second_duration = asyncio.get_event_loop().time() - start_time

    # Verify results
    assert results1 == results2
    assert len(results1) == 5
    assert second_duration < first_duration

    # Verify cache invalidation with different query
    different_query = "SELECT * FROM menu LIMIT 10"
    different_results = await snowflake_tool._run(query=different_query)
    assert len(different_results) == 10
    assert different_results != results1
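For reference, a minimal standalone sketch of driving the tool outside pytest, based only on the calls visible in the tests above (credentials are placeholders; `_run` is a coroutine, so it needs an event loop):

    import asyncio

    from crewai_tools import SnowflakeConfig, SnowflakeSearchTool

    config = SnowflakeConfig(
        account="my_account",  # placeholder credentials
        user="my_user",
        password="my_password",
        warehouse="COMPUTE_WH",
        database="my_db",
        snowflake_schema="my_schema",
    )
    tool = SnowflakeSearchTool(config=config)

    # Each row comes back as a dict keyed by upper-cased column names,
    # matching the assertions in the integration tests above.
    rows = asyncio.run(tool._run(query="SELECT * FROM menu LIMIT 5"))
    print(rows)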
@@ -1,5 +1,7 @@
from crewai import Agent, Crew, Task

from crewai_tools.tools.spider_tool.spider_tool import SpiderTool
from crewai import Agent, Task, Crew


def test_spider_tool():
    spider_tool = SpiderTool()
@@ -10,38 +12,35 @@ def test_spider_tool():
        backstory="An expert web researcher that uses the web extremely well",
        tools=[spider_tool],
        verbose=True,
        cache=False
        cache=False,
    )

    choose_between_scrape_crawl = Task(
        description="Scrape the page of spider.cloud and return a summary of how fast it is",
        expected_output="spider.cloud is a fast scraping and crawling tool",
        agent=searcher
        agent=searcher,
    )

    return_metadata = Task(
        description="Scrape https://spider.cloud with a limit of 1 and enable metadata",
        expected_output="Metadata and 10 word summary of spider.cloud",
        agent=searcher
        agent=searcher,
    )

    css_selector = Task(
        description="Scrape one page of spider.cloud with the `body > div > main > section.grid.md\:grid-cols-2.gap-10.place-items-center.md\:max-w-screen-xl.mx-auto.pb-8.pt-20 > div:nth-child(1) > h1` CSS selector",
        expected_output="The content of the element with the css selector body > div > main > section.grid.md\:grid-cols-2.gap-10.place-items-center.md\:max-w-screen-xl.mx-auto.pb-8.pt-20 > div:nth-child(1) > h1",
        agent=searcher
        agent=searcher,
    )

    crew = Crew(
        agents=[searcher],
        tasks=[
            choose_between_scrape_crawl,
            return_metadata,
            css_selector
        ],
        verbose=True
        tasks=[choose_between_scrape_crawl, return_metadata, css_selector],
        verbose=True,
    )

    crew.kickoff()


if __name__ == "__main__":
    test_spider_tool()
103 tests/tools/snowflake_search_tool_test.py Normal file
@@ -0,0 +1,103 @@
import asyncio
from unittest.mock import MagicMock, patch

import pytest

from crewai_tools import SnowflakeConfig, SnowflakeSearchTool


# Unit Test Fixtures
@pytest.fixture
def mock_snowflake_connection():
    mock_conn = MagicMock()
    mock_cursor = MagicMock()
    mock_cursor.description = [("col1",), ("col2",)]
    mock_cursor.fetchall.return_value = [(1, "value1"), (2, "value2")]
    mock_cursor.execute.return_value = None
    mock_conn.cursor.return_value = mock_cursor
    return mock_conn


@pytest.fixture
def mock_config():
    return SnowflakeConfig(
        account="test_account",
        user="test_user",
        password="test_password",
        warehouse="test_warehouse",
        database="test_db",
        snowflake_schema="test_schema",
    )


@pytest.fixture
def snowflake_tool(mock_config):
    with patch("snowflake.connector.connect") as mock_connect:
        tool = SnowflakeSearchTool(config=mock_config)
        yield tool


# Unit Tests
@pytest.mark.asyncio
async def test_successful_query_execution(snowflake_tool, mock_snowflake_connection):
    with patch.object(snowflake_tool, "_create_connection") as mock_create_conn:
        mock_create_conn.return_value = mock_snowflake_connection

        results = await snowflake_tool._run(
            query="SELECT * FROM test_table", timeout=300
        )

        assert len(results) == 2
        assert results[0]["col1"] == 1
        assert results[0]["col2"] == "value1"
        mock_snowflake_connection.cursor.assert_called_once()


@pytest.mark.asyncio
async def test_connection_pooling(snowflake_tool, mock_snowflake_connection):
    with patch.object(snowflake_tool, "_create_connection") as mock_create_conn:
        mock_create_conn.return_value = mock_snowflake_connection

        # Execute multiple queries
        await asyncio.gather(
            snowflake_tool._run("SELECT 1"),
            snowflake_tool._run("SELECT 2"),
            snowflake_tool._run("SELECT 3"),
        )

        # Should reuse connections from pool
        assert mock_create_conn.call_count <= snowflake_tool.pool_size


@pytest.mark.asyncio
async def test_cleanup_on_deletion(snowflake_tool, mock_snowflake_connection):
    with patch.object(snowflake_tool, "_create_connection") as mock_create_conn:
        mock_create_conn.return_value = mock_snowflake_connection

        # Add connection to pool
        await snowflake_tool._get_connection()

        # Return connection to pool
        async with snowflake_tool._pool_lock:
            snowflake_tool._connection_pool.append(mock_snowflake_connection)

        # Trigger cleanup
        snowflake_tool.__del__()

        mock_snowflake_connection.close.assert_called_once()


def test_config_validation():
    # Test missing required fields
    with pytest.raises(ValueError):
        SnowflakeConfig()

    # Test invalid account format
    with pytest.raises(ValueError):
        SnowflakeConfig(
            account="invalid//account", user="test_user", password="test_pass"
        )

    # Test missing authentication
    with pytest.raises(ValueError):
        SnowflakeConfig(account="test_account", user="test_user")
@@ -7,7 +7,9 @@ from crewai_tools.tools.code_interpreter_tool.code_interpreter_tool import (


class TestCodeInterpreterTool(unittest.TestCase):
    @patch("crewai_tools.tools.code_interpreter_tool.code_interpreter_tool.docker_from_env")
    @patch(
        "crewai_tools.tools.code_interpreter_tool.code_interpreter_tool.docker_from_env"
    )
    def test_run_code_in_docker(self, docker_mock):
        tool = CodeInterpreterTool()
        code = "print('Hello, World!')"
@@ -15,14 +17,14 @@ class TestCodeInterpreterTool(unittest.TestCase):
        expected_output = "Hello, World!\n"

        docker_mock().containers.run().exec_run().exit_code = 0
        docker_mock().containers.run().exec_run().output = (
            expected_output.encode()
        )
        docker_mock().containers.run().exec_run().output = expected_output.encode()
        result = tool.run_code_in_docker(code, libraries_used)

        self.assertEqual(result, expected_output)

    @patch("crewai_tools.tools.code_interpreter_tool.code_interpreter_tool.docker_from_env")
    @patch(
        "crewai_tools.tools.code_interpreter_tool.code_interpreter_tool.docker_from_env"
    )
    def test_run_code_in_docker_with_error(self, docker_mock):
        tool = CodeInterpreterTool()
        code = "print(1/0)"
@@ -37,7 +39,9 @@ class TestCodeInterpreterTool(unittest.TestCase):

        self.assertEqual(result, expected_output)

    @patch("crewai_tools.tools.code_interpreter_tool.code_interpreter_tool.docker_from_env")
    @patch(
        "crewai_tools.tools.code_interpreter_tool.code_interpreter_tool.docker_from_env"
    )
    def test_run_code_in_docker_with_script(self, docker_mock):
        tool = CodeInterpreterTool()
        code = """print("This is line 1")