From fac32d9503a734a23b03939262c6c3a9a3347925 Mon Sep 17 00:00:00 2001 From: Lucas Gomide Date: Mon, 16 Jun 2025 11:09:19 -0300 Subject: [PATCH] Support to collect extra package dependencies of Tools (#330) * feat: add explictly package_dependencies in the Tools * feat: collect package_dependencies from Tool to add in tool.specs.json * feat: add default value in run_params Tool' specs * fix: support get boolean values This commit also refactor test to make easier define newest attributes into a Tool --- .../aws/bedrock/agents/invoke_agent_tool.py | 29 +++--- .../bedrock/knowledge_base/retriever_tool.py | 51 ++++++----- src/crewai_tools/aws/s3/reader_tool.py | 3 +- src/crewai_tools/aws/s3/writer_tool.py | 3 +- .../tools/ai_mind_tool/ai_mind_tool.py | 13 +-- .../apify_actors_tool/apify_actors_tool.py | 1 + .../browserbase_load_tool.py | 3 +- .../tools/dalle_tool/dalle_tool.py | 4 +- .../databricks_query_tool.py | 1 + .../tools/exa_tools/exa_search_tool.py | 3 +- .../firecrawl_crawl_website_tool.py | 3 +- .../firecrawl_scrape_website_tool.py | 3 +- .../firecrawl_search_tool.py | 3 +- .../hyperbrowser_load_tool.py | 7 +- .../tools/linkup/linkup_search_tool.py | 3 +- .../tools/multion_tool/multion_tool.py | 3 +- .../patronus_local_evaluator_tool.py | 3 +- .../qdrant_search_tool.py | 3 +- .../scrapegraph_scrape_tool.py | 3 +- .../scrapfly_scrape_website_tool.py | 3 +- .../selenium_scraping_tool.py | 3 +- .../tools/serpapi_tool/serpapi_base_tool.py | 4 +- .../snowflake_search_tool.py | 1 + .../tools/spider_tool/spider_tool.py | 3 +- .../tools/stagehand_tool/stagehand_tool.py | 1 + .../tavily_extractor_tool.py | 1 + .../tavily_search_tool/tavily_search_tool.py | 3 +- .../tools/weaviate_tool/vector_search.py | 4 +- tests/test_generate_tool_specs.py | 91 +++++++------------ 29 files changed, 129 insertions(+), 127 deletions(-) diff --git a/src/crewai_tools/aws/bedrock/agents/invoke_agent_tool.py b/src/crewai_tools/aws/bedrock/agents/invoke_agent_tool.py index c064b9b2d..65280fe7b 100644 --- a/src/crewai_tools/aws/bedrock/agents/invoke_agent_tool.py +++ b/src/crewai_tools/aws/bedrock/agents/invoke_agent_tool.py @@ -1,4 +1,4 @@ -from typing import Type, Optional, Dict, Any +from typing import Type, Optional, Dict, Any, List import os import json import uuid @@ -29,6 +29,7 @@ class BedrockInvokeAgentTool(BaseTool): session_id: str = None enable_trace: bool = False end_session: bool = False + package_dependencies: List[str] = ["boto3"] def __init__( self, @@ -51,7 +52,7 @@ class BedrockInvokeAgentTool(BaseTool): description (Optional[str]): Custom description for the tool """ super().__init__(**kwargs) - + # Get values from environment variables if not provided self.agent_id = agent_id or os.getenv('BEDROCK_AGENT_ID') self.agent_alias_id = agent_alias_id or os.getenv('BEDROCK_AGENT_ALIAS_ID') @@ -62,7 +63,7 @@ class BedrockInvokeAgentTool(BaseTool): # Update the description if provided if description: self.description = description - + # Validate parameters self._validate_parameters() @@ -74,17 +75,17 @@ class BedrockInvokeAgentTool(BaseTool): raise BedrockValidationError("agent_id cannot be empty") if not isinstance(self.agent_id, str): raise BedrockValidationError("agent_id must be a string") - + # Validate agent_alias_id if not self.agent_alias_id: raise BedrockValidationError("agent_alias_id cannot be empty") if not isinstance(self.agent_alias_id, str): raise BedrockValidationError("agent_alias_id must be a string") - + # Validate session_id if provided if self.session_id and not isinstance(self.session_id, str): raise BedrockValidationError("session_id must be a string") - + except BedrockValidationError as e: raise BedrockValidationError(f"Parameter validation failed: {str(e)}") @@ -123,7 +124,7 @@ Below is the users query or task. Complete it and answer it consicely and to the # Process the response completion = "" - + # Check if response contains a completion field if 'completion' in response: # Process streaming response format @@ -134,7 +135,7 @@ Below is the users query or task. Complete it and answer it consicely and to the completion += chunk_bytes.decode('utf-8') else: completion += str(chunk_bytes) - + # If no completion found in streaming format, try direct format if not completion and 'chunk' in response and 'bytes' in response['chunk']: chunk_bytes = response['chunk']['bytes'] @@ -142,31 +143,31 @@ Below is the users query or task. Complete it and answer it consicely and to the completion = chunk_bytes.decode('utf-8') else: completion = str(chunk_bytes) - + # If still no completion, return debug info if not completion: debug_info = { "error": "Could not extract completion from response", "response_keys": list(response.keys()) } - + # Add more debug info if 'chunk' in response: debug_info["chunk_keys"] = list(response['chunk'].keys()) - + raise BedrockAgentError(f"Failed to extract completion: {json.dumps(debug_info, indent=2)}") - + return completion except ClientError as e: error_code = "Unknown" error_message = str(e) - + # Try to extract error code if available if hasattr(e, 'response') and 'Error' in e.response: error_code = e.response['Error'].get('Code', 'Unknown') error_message = e.response['Error'].get('Message', str(e)) - + raise BedrockAgentError(f"Error ({error_code}): {error_message}") except BedrockAgentError: # Re-raise BedrockAgentError exceptions diff --git a/src/crewai_tools/aws/bedrock/knowledge_base/retriever_tool.py b/src/crewai_tools/aws/bedrock/knowledge_base/retriever_tool.py index 15c74077c..06fd3ce38 100644 --- a/src/crewai_tools/aws/bedrock/knowledge_base/retriever_tool.py +++ b/src/crewai_tools/aws/bedrock/knowledge_base/retriever_tool.py @@ -26,6 +26,7 @@ class BedrockKBRetrieverTool(BaseTool): retrieval_configuration: Optional[Dict[str, Any]] = None guardrail_configuration: Optional[Dict[str, Any]] = None next_token: Optional[str] = None + package_dependencies: List[str] = ["boto3"] def __init__( self, @@ -46,13 +47,13 @@ class BedrockKBRetrieverTool(BaseTool): next_token (Optional[str], optional): Token for retrieving the next batch of results. Defaults to None. """ super().__init__(**kwargs) - + # Get knowledge_base_id from environment variable if not provided self.knowledge_base_id = knowledge_base_id or os.getenv('BEDROCK_KB_ID') self.number_of_results = number_of_results self.guardrail_configuration = guardrail_configuration self.next_token = next_token - + # Initialize retrieval_configuration with provided parameters or use the one provided if retrieval_configuration is None: self.retrieval_configuration = self._build_retrieval_configuration() @@ -67,16 +68,16 @@ class BedrockKBRetrieverTool(BaseTool): def _build_retrieval_configuration(self) -> Dict[str, Any]: """Build the retrieval configuration based on provided parameters. - + Returns: Dict[str, Any]: The constructed retrieval configuration """ vector_search_config = {} - + # Add number of results if provided if self.number_of_results is not None: vector_search_config["numberOfResults"] = self.number_of_results - + return {"vectorSearchConfiguration": vector_search_config} def _validate_parameters(self): @@ -91,7 +92,7 @@ class BedrockKBRetrieverTool(BaseTool): raise BedrockValidationError("knowledge_base_id must be 10 characters or less") if not all(c.isalnum() for c in self.knowledge_base_id): raise BedrockValidationError("knowledge_base_id must contain only alphanumeric characters") - + # Validate next_token if provided if self.next_token: if not isinstance(self.next_token, str): @@ -100,23 +101,23 @@ class BedrockKBRetrieverTool(BaseTool): raise BedrockValidationError("next_token must be between 1 and 2048 characters") if ' ' in self.next_token: raise BedrockValidationError("next_token cannot contain spaces") - + # Validate number_of_results if provided if self.number_of_results is not None: if not isinstance(self.number_of_results, int): raise BedrockValidationError("number_of_results must be an integer") if self.number_of_results < 1: raise BedrockValidationError("number_of_results must be greater than 0") - + except BedrockValidationError as e: raise BedrockValidationError(f"Parameter validation failed: {str(e)}") def _process_retrieval_result(self, result: Dict[str, Any]) -> Dict[str, Any]: """Process a single retrieval result from Bedrock Knowledge Base. - + Args: result (Dict[str, Any]): Raw result from Bedrock Knowledge Base - + Returns: Dict[str, Any]: Processed result with standardized format """ @@ -124,12 +125,12 @@ class BedrockKBRetrieverTool(BaseTool): content_obj = result.get('content', {}) content = content_obj.get('text', '') content_type = content_obj.get('type', 'text') - + # Extract location information location = result.get('location', {}) location_type = location.get('type', 'unknown') source_uri = None - + # Map for location types and their URI fields location_mapping = { 's3Location': {'field': 'uri', 'type': 'S3'}, @@ -141,7 +142,7 @@ class BedrockKBRetrieverTool(BaseTool): 'kendraDocumentLocation': {'field': 'uri', 'type': 'KendraDocument'}, 'sqlLocation': {'field': 'query', 'type': 'SQL'} } - + # Extract the URI based on location type for loc_key, config in location_mapping.items(): if loc_key in location: @@ -149,7 +150,7 @@ class BedrockKBRetrieverTool(BaseTool): if not location_type or location_type == 'unknown': location_type = config['type'] break - + # Create result object result_object = { 'content': content, @@ -157,22 +158,22 @@ class BedrockKBRetrieverTool(BaseTool): 'source_type': location_type, 'source_uri': source_uri } - + # Add optional fields if available if 'score' in result: result_object['score'] = result['score'] - + if 'metadata' in result: result_object['metadata'] = result['metadata'] - + # Handle byte content if present if 'byteContent' in content_obj: result_object['byte_content'] = content_obj['byteContent'] - + # Handle row content if present if 'row' in content_obj: result_object['row_content'] = content_obj['row'] - + return result_object def _run(self, query: str) -> str: @@ -201,10 +202,10 @@ class BedrockKBRetrieverTool(BaseTool): # Add optional parameters if provided if self.retrieval_configuration: retrieve_params['retrievalConfiguration'] = self.retrieval_configuration - + if self.guardrail_configuration: retrieve_params['guardrailConfiguration'] = self.guardrail_configuration - + if self.next_token: retrieve_params['nextToken'] = self.next_token @@ -223,10 +224,10 @@ class BedrockKBRetrieverTool(BaseTool): response_object["results"] = results else: response_object["message"] = "No results found for the given query." - + if "nextToken" in response: response_object["nextToken"] = response["nextToken"] - + if "guardrailAction" in response: response_object["guardrailAction"] = response["guardrailAction"] @@ -236,12 +237,12 @@ class BedrockKBRetrieverTool(BaseTool): except ClientError as e: error_code = "Unknown" error_message = str(e) - + # Try to extract error code if available if hasattr(e, 'response') and 'Error' in e.response: error_code = e.response['Error'].get('Code', 'Unknown') error_message = e.response['Error'].get('Message', str(e)) - + raise BedrockKnowledgeBaseError(f"Error ({error_code}): {error_message}") except Exception as e: raise BedrockKnowledgeBaseError(f"Unexpected error: {str(e)}") \ No newline at end of file diff --git a/src/crewai_tools/aws/s3/reader_tool.py b/src/crewai_tools/aws/s3/reader_tool.py index 4b3b9a394..c3f1fa4eb 100644 --- a/src/crewai_tools/aws/s3/reader_tool.py +++ b/src/crewai_tools/aws/s3/reader_tool.py @@ -1,4 +1,4 @@ -from typing import Any, Type +from typing import Any, Type, List import os from crewai.tools import BaseTool @@ -15,6 +15,7 @@ class S3ReaderTool(BaseTool): name: str = "S3 Reader Tool" description: str = "Reads a file from Amazon S3 given an S3 file path" args_schema: Type[BaseModel] = S3ReaderToolInput + package_dependencies: List[str] = ["boto3"] def _run(self, file_path: str) -> str: try: diff --git a/src/crewai_tools/aws/s3/writer_tool.py b/src/crewai_tools/aws/s3/writer_tool.py index f0aaddb28..2e1528d13 100644 --- a/src/crewai_tools/aws/s3/writer_tool.py +++ b/src/crewai_tools/aws/s3/writer_tool.py @@ -1,4 +1,4 @@ -from typing import Any, Type +from typing import Type, List import os from crewai.tools import BaseTool @@ -14,6 +14,7 @@ class S3WriterTool(BaseTool): name: str = "S3 Writer Tool" description: str = "Writes content to a file in Amazon S3 given an S3 file path" args_schema: Type[BaseModel] = S3WriterToolInput + package_dependencies: List[str] = ["boto3"] def _run(self, file_path: str, content: str) -> str: try: diff --git a/src/crewai_tools/tools/ai_mind_tool/ai_mind_tool.py b/src/crewai_tools/tools/ai_mind_tool/ai_mind_tool.py index b38426e09..1a96f62ff 100644 --- a/src/crewai_tools/tools/ai_mind_tool/ai_mind_tool.py +++ b/src/crewai_tools/tools/ai_mind_tool/ai_mind_tool.py @@ -1,10 +1,10 @@ import os import secrets -from typing import Any, Dict, List, Optional, Text, Type +from typing import Any, Dict, List, Optional, Type from crewai.tools import BaseTool from openai import OpenAI -from pydantic import BaseModel +from pydantic import BaseModel, Field class AIMindToolConstants: @@ -16,7 +16,7 @@ class AIMindToolConstants: class AIMindToolInputSchema(BaseModel): """Input for AIMind Tool.""" - query: str = "Question in natural language to ask the AI-Mind" + query: str = Field(description="Question in natural language to ask the AI-Mind") class AIMindTool(BaseTool): @@ -31,9 +31,10 @@ class AIMindTool(BaseTool): args_schema: Type[BaseModel] = AIMindToolInputSchema api_key: Optional[str] = None datasources: Optional[List[Dict[str, Any]]] = None - mind_name: Optional[Text] = None + mind_name: Optional[str] = None + package_dependencies: List[str] = ["minds-sdk"] - def __init__(self, api_key: Optional[Text] = None, **kwargs): + def __init__(self, api_key: Optional[str] = None, **kwargs): super().__init__(**kwargs) self.api_key = api_key or os.getenv("MINDS_API_KEY") if not self.api_key: @@ -72,7 +73,7 @@ class AIMindTool(BaseTool): def _run( self, - query: Text + query: str ): # Run the query on the AI-Mind. # The Minds API is OpenAI compatible and therefore, the OpenAI client can be used. diff --git a/src/crewai_tools/tools/apify_actors_tool/apify_actors_tool.py b/src/crewai_tools/tools/apify_actors_tool/apify_actors_tool.py index 37ae7312b..44c4839e8 100644 --- a/src/crewai_tools/tools/apify_actors_tool/apify_actors_tool.py +++ b/src/crewai_tools/tools/apify_actors_tool/apify_actors_tool.py @@ -38,6 +38,7 @@ class ApifyActorsTool(BaseTool): print(f"Content: {result.get('markdown', 'N/A')[:100]}...") """ actor_tool: '_ApifyActorsTool' = Field(description="Apify Actor Tool") + package_dependencies: List[str] = ["langchain-apify"] def __init__( self, diff --git a/src/crewai_tools/tools/browserbase_load_tool/browserbase_load_tool.py b/src/crewai_tools/tools/browserbase_load_tool/browserbase_load_tool.py index 3a2462f5e..f946baf73 100644 --- a/src/crewai_tools/tools/browserbase_load_tool/browserbase_load_tool.py +++ b/src/crewai_tools/tools/browserbase_load_tool/browserbase_load_tool.py @@ -1,5 +1,5 @@ import os -from typing import Any, Optional, Type +from typing import Any, Optional, Type, List from crewai.tools import BaseTool from pydantic import BaseModel, Field @@ -19,6 +19,7 @@ class BrowserbaseLoadTool(BaseTool): session_id: Optional[str] = None proxy: Optional[bool] = None browserbase: Optional[Any] = None + package_dependencies: List[str] = ["browserbase"] def __init__( self, diff --git a/src/crewai_tools/tools/dalle_tool/dalle_tool.py b/src/crewai_tools/tools/dalle_tool/dalle_tool.py index 7040de11a..8957d9636 100644 --- a/src/crewai_tools/tools/dalle_tool/dalle_tool.py +++ b/src/crewai_tools/tools/dalle_tool/dalle_tool.py @@ -3,13 +3,13 @@ from typing import Type from crewai.tools import BaseTool from openai import OpenAI -from pydantic import BaseModel +from pydantic import BaseModel, Field class ImagePromptSchema(BaseModel): """Input for Dall-E Tool.""" - image_description: str = "Description of the image to be generated by Dall-E." + image_description: str = Field(description="Description of the image to be generated by Dall-E.") class DallETool(BaseTool): diff --git a/src/crewai_tools/tools/databricks_query_tool/databricks_query_tool.py b/src/crewai_tools/tools/databricks_query_tool/databricks_query_tool.py index 428cea5d3..fe73179cb 100644 --- a/src/crewai_tools/tools/databricks_query_tool/databricks_query_tool.py +++ b/src/crewai_tools/tools/databricks_query_tool/databricks_query_tool.py @@ -69,6 +69,7 @@ class DatabricksQueryTool(BaseTool): default_warehouse_id: Optional[str] = None _workspace_client: Optional["WorkspaceClient"] = None + package_dependencies: List[str] = ["databricks-sdk"] def __init__( self, diff --git a/src/crewai_tools/tools/exa_tools/exa_search_tool.py b/src/crewai_tools/tools/exa_tools/exa_search_tool.py index f094b0495..d626c03ed 100644 --- a/src/crewai_tools/tools/exa_tools/exa_search_tool.py +++ b/src/crewai_tools/tools/exa_tools/exa_search_tool.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Type +from typing import Any, Optional, Type, List from pydantic import BaseModel, Field from crewai.tools import BaseTool @@ -35,6 +35,7 @@ class EXASearchTool(BaseTool): content: Optional[bool] = False summary: Optional[bool] = False type: Optional[str] = "auto" + package_dependencies: List[str] = ["exa_py"] def __init__( self, diff --git a/src/crewai_tools/tools/firecrawl_crawl_website_tool/firecrawl_crawl_website_tool.py b/src/crewai_tools/tools/firecrawl_crawl_website_tool/firecrawl_crawl_website_tool.py index ee7e5e3d9..6642fbd54 100644 --- a/src/crewai_tools/tools/firecrawl_crawl_website_tool/firecrawl_crawl_website_tool.py +++ b/src/crewai_tools/tools/firecrawl_crawl_website_tool/firecrawl_crawl_website_tool.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Type +from typing import Any, Optional, Type, List from crewai.tools import BaseTool from pydantic import BaseModel, ConfigDict, Field, PrivateAttr @@ -55,6 +55,7 @@ class FirecrawlCrawlWebsiteTool(BaseTool): } ) _firecrawl: Optional["FirecrawlApp"] = PrivateAttr(None) + package_dependencies: List[str] = ["firecrawl-py"] def __init__(self, api_key: Optional[str] = None, **kwargs): super().__init__(**kwargs) diff --git a/src/crewai_tools/tools/firecrawl_scrape_website_tool/firecrawl_scrape_website_tool.py b/src/crewai_tools/tools/firecrawl_scrape_website_tool/firecrawl_scrape_website_tool.py index fcb5c6c8d..acb1c0af5 100644 --- a/src/crewai_tools/tools/firecrawl_scrape_website_tool/firecrawl_scrape_website_tool.py +++ b/src/crewai_tools/tools/firecrawl_scrape_website_tool/firecrawl_scrape_website_tool.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Type, Dict +from typing import Any, Optional, Type, Dict, List from crewai.tools import BaseTool from pydantic import BaseModel, ConfigDict, Field, PrivateAttr @@ -48,6 +48,7 @@ class FirecrawlScrapeWebsiteTool(BaseTool): ) _firecrawl: Optional["FirecrawlApp"] = PrivateAttr(None) + package_dependencies: List[str] = ["firecrawl-py"] def __init__(self, api_key: Optional[str] = None, **kwargs): super().__init__(**kwargs) diff --git a/src/crewai_tools/tools/firecrawl_search_tool/firecrawl_search_tool.py b/src/crewai_tools/tools/firecrawl_search_tool/firecrawl_search_tool.py index 8b563778c..0fb091b68 100644 --- a/src/crewai_tools/tools/firecrawl_search_tool/firecrawl_search_tool.py +++ b/src/crewai_tools/tools/firecrawl_search_tool/firecrawl_search_tool.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, Dict, Optional, Type +from typing import TYPE_CHECKING, Any, Dict, Optional, Type, List from crewai.tools import BaseTool from pydantic import BaseModel, ConfigDict, Field, PrivateAttr @@ -57,6 +57,7 @@ class FirecrawlSearchTool(BaseTool): } ) _firecrawl: Optional["FirecrawlApp"] = PrivateAttr(None) + package_dependencies: List[str] = ["firecrawl-py"] def __init__(self, api_key: Optional[str] = None, **kwargs): super().__init__(**kwargs) diff --git a/src/crewai_tools/tools/hyperbrowser_load_tool/hyperbrowser_load_tool.py b/src/crewai_tools/tools/hyperbrowser_load_tool/hyperbrowser_load_tool.py index b802d1859..5359427b0 100644 --- a/src/crewai_tools/tools/hyperbrowser_load_tool/hyperbrowser_load_tool.py +++ b/src/crewai_tools/tools/hyperbrowser_load_tool/hyperbrowser_load_tool.py @@ -1,5 +1,5 @@ import os -from typing import Any, Optional, Type, Dict, Literal, Union +from typing import Any, Optional, Type, Dict, Literal, Union, List from crewai.tools import BaseTool from pydantic import BaseModel, Field @@ -25,6 +25,7 @@ class HyperbrowserLoadTool(BaseTool): args_schema: Type[BaseModel] = HyperbrowserLoadToolSchema api_key: Optional[str] = None hyperbrowser: Optional[Any] = None + package_dependencies: List[str] = ["hyperbrowser"] def __init__(self, api_key: Optional[str] = None, **kwargs): super().__init__(**kwargs) @@ -65,7 +66,7 @@ class HyperbrowserLoadTool(BaseTool): if "scrape_options" in params: params["scrape_options"] = ScrapeOptions(**params["scrape_options"]) return params - + def _extract_content(self, data: Union[Any, None]): """Extract content from response data.""" content = "" @@ -81,7 +82,7 @@ class HyperbrowserLoadTool(BaseTool): raise ImportError( "`hyperbrowser` package not found, please run `pip install hyperbrowser`" ) - + params = self._prepare_params(params) if operation == 'scrape': diff --git a/src/crewai_tools/tools/linkup/linkup_search_tool.py b/src/crewai_tools/tools/linkup/linkup_search_tool.py index 4eb2d82b3..c35c7fac3 100644 --- a/src/crewai_tools/tools/linkup/linkup_search_tool.py +++ b/src/crewai_tools/tools/linkup/linkup_search_tool.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, List from crewai.tools import BaseTool @@ -23,6 +23,7 @@ class LinkupSearchTool(BaseTool): "Performs an API call to Linkup to retrieve contextual information." ) _client: LinkupClient = PrivateAttr() # type: ignore + package_dependencies: List[str] = ["linkup-sdk"] def __init__(self, api_key: str): """ diff --git a/src/crewai_tools/tools/multion_tool/multion_tool.py b/src/crewai_tools/tools/multion_tool/multion_tool.py index d49321dc0..3c8d17819 100644 --- a/src/crewai_tools/tools/multion_tool/multion_tool.py +++ b/src/crewai_tools/tools/multion_tool/multion_tool.py @@ -1,6 +1,6 @@ """Multion tool spec.""" -from typing import Any, Optional +from typing import Any, Optional, List from crewai.tools import BaseTool @@ -16,6 +16,7 @@ class MultiOnTool(BaseTool): session_id: Optional[str] = None local: bool = False max_steps: int = 3 + package_dependencies: List[str] = ["multion"] def __init__( self, diff --git a/src/crewai_tools/tools/patronus_eval_tool/patronus_local_evaluator_tool.py b/src/crewai_tools/tools/patronus_eval_tool/patronus_local_evaluator_tool.py index 602b45864..30b78a3c4 100644 --- a/src/crewai_tools/tools/patronus_eval_tool/patronus_local_evaluator_tool.py +++ b/src/crewai_tools/tools/patronus_eval_tool/patronus_local_evaluator_tool.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, Type +from typing import TYPE_CHECKING, Any, Type, List from crewai.tools import BaseTool from pydantic import BaseModel, ConfigDict, Field @@ -41,6 +41,7 @@ class PatronusLocalEvaluatorTool(BaseTool): evaluated_model_gold_answer: str model_config = ConfigDict(arbitrary_types_allowed=True) + package_dependencies: List[str] = ["patronus"] def __init__( self, diff --git a/src/crewai_tools/tools/qdrant_vector_search_tool/qdrant_search_tool.py b/src/crewai_tools/tools/qdrant_vector_search_tool/qdrant_search_tool.py index 29f172cdf..73e373ae8 100644 --- a/src/crewai_tools/tools/qdrant_vector_search_tool/qdrant_search_tool.py +++ b/src/crewai_tools/tools/qdrant_vector_search_tool/qdrant_search_tool.py @@ -1,6 +1,6 @@ import json import os -from typing import Any, Callable, Optional, Type +from typing import Any, Callable, Optional, Type, List try: @@ -74,6 +74,7 @@ class QdrantVectorSearchTool(BaseTool): default=None, description="A custom embedding function to use for vectorization. If not provided, the default model will be used.", ) + package_dependencies: List[str] = ["qdrant-client"] def __init__(self, **kwargs): super().__init__(**kwargs) diff --git a/src/crewai_tools/tools/scrapegraph_scrape_tool/scrapegraph_scrape_tool.py b/src/crewai_tools/tools/scrapegraph_scrape_tool/scrapegraph_scrape_tool.py index 70764c294..bc3bd667b 100644 --- a/src/crewai_tools/tools/scrapegraph_scrape_tool/scrapegraph_scrape_tool.py +++ b/src/crewai_tools/tools/scrapegraph_scrape_tool/scrapegraph_scrape_tool.py @@ -1,5 +1,5 @@ import os -from typing import TYPE_CHECKING, Any, Optional, Type +from typing import TYPE_CHECKING, Any, Optional, Type, List from urllib.parse import urlparse from crewai.tools import BaseTool @@ -67,6 +67,7 @@ class ScrapegraphScrapeTool(BaseTool): api_key: Optional[str] = None enable_logging: bool = False _client: Optional["Client"] = None + package_dependencies: List[str] = ["scrapegraph-py"] def __init__( self, diff --git a/src/crewai_tools/tools/scrapfly_scrape_website_tool/scrapfly_scrape_website_tool.py b/src/crewai_tools/tools/scrapfly_scrape_website_tool/scrapfly_scrape_website_tool.py index 4d6b72b61..60fc75e16 100644 --- a/src/crewai_tools/tools/scrapfly_scrape_website_tool/scrapfly_scrape_website_tool.py +++ b/src/crewai_tools/tools/scrapfly_scrape_website_tool/scrapfly_scrape_website_tool.py @@ -1,5 +1,5 @@ import logging -from typing import Any, Dict, Literal, Optional, Type +from typing import Any, Dict, Literal, Optional, Type, List from crewai.tools import BaseTool from pydantic import BaseModel, Field @@ -28,6 +28,7 @@ class ScrapflyScrapeWebsiteTool(BaseTool): args_schema: Type[BaseModel] = ScrapflyScrapeWebsiteToolSchema api_key: str = None scrapfly: Optional[Any] = None + package_dependencies: List[str] = ["scrapfly-sdk"] def __init__(self, api_key: str): super().__init__() diff --git a/src/crewai_tools/tools/selenium_scraping_tool/selenium_scraping_tool.py b/src/crewai_tools/tools/selenium_scraping_tool/selenium_scraping_tool.py index 3976facef..5f7365c8a 100644 --- a/src/crewai_tools/tools/selenium_scraping_tool/selenium_scraping_tool.py +++ b/src/crewai_tools/tools/selenium_scraping_tool/selenium_scraping_tool.py @@ -1,6 +1,6 @@ import re import time -from typing import Any, Optional, Type +from typing import Any, Optional, Type, List from urllib.parse import urlparse from crewai.tools import BaseTool @@ -58,6 +58,7 @@ class SeleniumScrapingTool(BaseTool): css_element: Optional[str] = None return_html: Optional[bool] = False _by: Optional[Any] = None + package_dependencies: List[str] = ["selenium", "webdriver-manager"] def __init__( self, diff --git a/src/crewai_tools/tools/serpapi_tool/serpapi_base_tool.py b/src/crewai_tools/tools/serpapi_tool/serpapi_base_tool.py index f6e639a37..c0a5ca9c9 100644 --- a/src/crewai_tools/tools/serpapi_tool/serpapi_base_tool.py +++ b/src/crewai_tools/tools/serpapi_tool/serpapi_base_tool.py @@ -1,6 +1,6 @@ import os import re -from typing import Any, Optional, Union +from typing import Any, Optional, Union, List from crewai.tools import BaseTool @@ -8,6 +8,8 @@ from crewai.tools import BaseTool class SerpApiBaseTool(BaseTool): """Base class for SerpApi functionality with shared capabilities.""" + package_dependencies: List[str] = ["serpapi"] + client: Optional[Any] = None def __init__(self, **kwargs): diff --git a/src/crewai_tools/tools/snowflake_search_tool/snowflake_search_tool.py b/src/crewai_tools/tools/snowflake_search_tool/snowflake_search_tool.py index bacec2917..a4cd21044 100644 --- a/src/crewai_tools/tools/snowflake_search_tool/snowflake_search_tool.py +++ b/src/crewai_tools/tools/snowflake_search_tool/snowflake_search_tool.py @@ -99,6 +99,7 @@ class SnowflakeSearchTool(BaseTool): _pool_lock: Optional[asyncio.Lock] = None _thread_pool: Optional[ThreadPoolExecutor] = None _model_rebuilt: bool = False + package_dependencies: List[str] = ["snowflake-connector-python", "snowflake-sqlalchemy", "cryptography"] def __init__(self, **data): """Initialize SnowflakeSearchTool.""" diff --git a/src/crewai_tools/tools/spider_tool/spider_tool.py b/src/crewai_tools/tools/spider_tool/spider_tool.py index ff52a35dc..853833261 100644 --- a/src/crewai_tools/tools/spider_tool/spider_tool.py +++ b/src/crewai_tools/tools/spider_tool/spider_tool.py @@ -1,5 +1,5 @@ import logging -from typing import Any, Dict, Literal, Optional, Type +from typing import Any, Dict, Literal, Optional, Type, List from urllib.parse import unquote, urlparse from crewai.tools import BaseTool @@ -53,6 +53,7 @@ class SpiderTool(BaseTool): spider: Any = None log_failures: bool = True config: SpiderToolConfig = SpiderToolConfig() + package_dependencies: List[str] = ["spider-client"] def __init__( self, diff --git a/src/crewai_tools/tools/stagehand_tool/stagehand_tool.py b/src/crewai_tools/tools/stagehand_tool/stagehand_tool.py index 5a4d5f485..557c6cb6f 100644 --- a/src/crewai_tools/tools/stagehand_tool/stagehand_tool.py +++ b/src/crewai_tools/tools/stagehand_tool/stagehand_tool.py @@ -88,6 +88,7 @@ class StagehandToolSchema(BaseModel): class StagehandTool(BaseTool): + package_dependencies: List[str] = ["stagehand"] """ A tool that uses Stagehand to automate web browser interactions using natural language. diff --git a/src/crewai_tools/tools/tavily_extractor_tool/tavily_extractor_tool.py b/src/crewai_tools/tools/tavily_extractor_tool/tavily_extractor_tool.py index 0320ab104..043e01fac 100644 --- a/src/crewai_tools/tools/tavily_extractor_tool/tavily_extractor_tool.py +++ b/src/crewai_tools/tools/tavily_extractor_tool/tavily_extractor_tool.py @@ -26,6 +26,7 @@ class TavilyExtractorToolSchema(BaseModel): class TavilyExtractorTool(BaseTool): + package_dependencies: List[str] = ["tavily-python"] """ Tool that uses the Tavily API to extract content from web pages. diff --git a/src/crewai_tools/tools/tavily_search_tool/tavily_search_tool.py b/src/crewai_tools/tools/tavily_search_tool/tavily_search_tool.py index 1179be90d..16841c380 100644 --- a/src/crewai_tools/tools/tavily_search_tool/tavily_search_tool.py +++ b/src/crewai_tools/tools/tavily_search_tool/tavily_search_tool.py @@ -1,6 +1,6 @@ from crewai.tools import BaseTool from pydantic import BaseModel, Field -from typing import Optional, Type, Any, Union, Literal, Sequence +from typing import Optional, Type, Any, Union, Literal, Sequence, List from dotenv import load_dotenv import os import json @@ -101,6 +101,7 @@ class TavilySearchTool(BaseTool): default=1000, description="Maximum length for the 'content' of each search result to avoid context window issues.", ) + package_dependencies: List[str] = ["tavily-python"] def __init__(self, **kwargs: Any): super().__init__(**kwargs) diff --git a/src/crewai_tools/tools/weaviate_tool/vector_search.py b/src/crewai_tools/tools/weaviate_tool/vector_search.py index d363ba7e1..fa332f231 100644 --- a/src/crewai_tools/tools/weaviate_tool/vector_search.py +++ b/src/crewai_tools/tools/weaviate_tool/vector_search.py @@ -1,6 +1,6 @@ import json import os -from typing import Any, Optional, Type +from typing import Any, Optional, Type, List try: import weaviate @@ -31,6 +31,7 @@ class WeaviateToolSchema(BaseModel): class WeaviateVectorSearchTool(BaseTool): """Tool to search the Weaviate database""" + package_dependencies: List[str] = ["weaviate-client"] name: str = "WeaviateVectorSearchTool" description: str = "A tool to search the Weaviate database for relevant information on internal documents." args_schema: Type[BaseModel] = WeaviateToolSchema @@ -48,6 +49,7 @@ class WeaviateVectorSearchTool(BaseTool): ..., description="The API key for the Weaviate cluster", ) + package_dependencies: List[str] = ["weaviate-client"] def __init__(self, **kwargs): super().__init__(**kwargs) diff --git a/tests/test_generate_tool_specs.py b/tests/test_generate_tool_specs.py index cb3b18cd5..1315fca49 100644 --- a/tests/test_generate_tool_specs.py +++ b/tests/test_generate_tool_specs.py @@ -1,12 +1,12 @@ import json -from typing import List, Optional +from typing import List, Optional, Type import pytest from pydantic import BaseModel, Field from unittest import mock from generate_tool_specs import ToolSpecExtractor -from crewai.tools.base_tool import EnvVar +from crewai.tools.base_tool import BaseTool, EnvVar class MockToolSchema(BaseModel): query: str = Field(..., description="The query parameter") @@ -14,54 +14,26 @@ class MockToolSchema(BaseModel): filters: Optional[List[str]] = Field(None, description="Optional filters to apply") -class MockTool: - name = "Mock Search Tool" - description = "A tool that mocks search functionality" - args_schema = MockToolSchema +class MockTool(BaseTool): + name: str = "Mock Search Tool" + description: str = "A tool that mocks search functionality" + args_schema: Type[BaseModel] = MockToolSchema + + another_parameter: str = Field("Another way to define a default value", description="") + my_parameter: str = Field("This is default value", description="What a description") + my_parameter_bool: bool = Field(False) + package_dependencies: List[str] = Field(["this-is-a-required-package", "another-required-package"], description="") + env_vars: List[EnvVar] = [ + EnvVar(name="SERPER_API_KEY", description="API key for Serper", required=True, default=None), + EnvVar(name="API_RATE_LIMIT", description="API rate limit", required=False, default="100") + ] @pytest.fixture def extractor(): ext = ToolSpecExtractor() - MockTool.__pydantic_core_schema__ = create_mock_schema(MockTool) - MockTool.args_schema.__pydantic_core_schema__ = create_mock_schema_args(MockTool.args_schema) return ext -def create_mock_schema(cls): - return { - "type": "model", - "cls": cls, - "schema": { - "type": "model-fields", - "fields": { - "name": {"type": "model-field", "schema": {"type": "default", "schema": {"type": "str"}, "default": cls.name}, "metadata": {}}, - "description": {"type": "model-field", "schema": {"type": "default", "schema": {"type": "str"}, "default": cls.description}, "metadata": {}}, - "args_schema": {"type": "model-field", "schema": {"type": "default", "schema": {"type": "is-subclass", "cls": BaseModel}, "default": cls.args_schema}, "metadata": {}}, - "env_vars": { - "type": "model-field", "schema": {"type": "default", "schema": {"type": "list", "items_schema": {"type": "model", "cls": "INSPECT CLASS", "schema": {"type": "model-fields", "fields": {"name": {"type": "model-field", "schema": {"type": "str"}, "metadata": {}}, "description": {"type": "model-field", "schema": {"type": "str"}, "metadata": {}}, "required": {"type": "model-field", "schema": {"type": "default", "schema": {"type": "bool"}, "default": True}, "metadata": {}}, "default": {"type": "model-field", "schema": {"type": "default", "schema": {"type": "nullable", "schema": {"type": "str"}}, "default": None}, "metadata": {}},}, "model_name": "EnvVar", "computed_fields": []}, "custom_init": False, "root_model": False, "config": {"title": "EnvVar"}, "ref": "crewai.tools.base_tool.EnvVar:4593650640", "metadata": {"pydantic_js_functions": ["INSPECT __get_pydantic_json_schema__"]}}}, "default": [EnvVar(name='SERPER_API_KEY', description='API key for Serper', required=True, default=None), EnvVar(name='API_RATE_LIMIT', description='API rate limit', required=False, default="100")]}, "metadata": {} - } - }, - "model_name": cls.__name__ - } - } - - -def create_mock_schema_args(cls): - return { - "type": "model", - "cls": cls, - "schema": { - "type": "model-fields", - "fields": { - "query": {"type": "model-field", "schema": {"type": "default", "schema": {"type": "str"}, "default": "The query parameter"}}, - "count": {"type": "model-field", "schema": {"type": "default", "schema": {"type": "int"}, "default": 5}, "metadata": {"pydantic_js_updates": {"description": "Number of results to return"}}}, - "filters": {"type": "model-field", "schema": {"type": "nullable", "schema": {"type": "list", "items_schema": {"type": "str"}}}} - }, - "model_name": cls.__name__ - } - } - - def test_unwrap_schema(extractor): nested_schema = { "type": "function-after", @@ -72,19 +44,6 @@ def test_unwrap_schema(extractor): assert result["value"] == "test" -@pytest.mark.parametrize( - "field, fallback, expected", - [ - ({"schema": {"default": "test_value"}}, None, "test_value"), - ({}, "fallback_value", "fallback_value"), - ({"schema": {"default": 123}}, "fallback_value", "fallback_value") - ] -) -def test_extract_field_default(extractor, field, fallback, expected): - result = extractor._extract_field_default(field, fallback=fallback) - assert result == expected - - @pytest.mark.parametrize( "schema, expected", [ @@ -112,7 +71,7 @@ def test_extract_param_type(extractor, info, expected_type): assert extractor._extract_param_type(info) == expected_type -def test_extract_tool_info(extractor): +def test_extract_all_tools(extractor): with mock.patch("generate_tool_specs.dir", return_value=["MockTool"]), \ mock.patch("generate_tool_specs.getattr", return_value=MockTool): extractor.extract_all_tools() @@ -120,6 +79,16 @@ def test_extract_tool_info(extractor): assert len(extractor.tools_spec) == 1 tool_info = extractor.tools_spec[0] + assert tool_info.keys() == { + "name", + "humanized_name", + "description", + "run_params", + "env_vars", + "init_params", + "package_dependencies", + } + assert tool_info["name"] == "MockTool" assert tool_info["humanized_name"] == "Mock Search Tool" assert tool_info["description"] == "A tool that mocks search functionality" @@ -142,12 +111,16 @@ def test_extract_tool_info(extractor): params = {p["name"]: p for p in tool_info["run_params"]} assert params["query"]["description"] == "The query parameter" assert params["query"]["type"] == "str" + assert params["query"]["default"] == "" - assert params["count"]["description"] == "Number of results to return" assert params["count"]["type"] == "int" + assert params["count"]["default"] == 5 - assert params["filters"]["description"] == "" + assert params["filters"]["description"] == "Optional filters to apply" assert params["filters"]["type"] == "list[str]" + assert params["filters"]["default"] == "" + + assert tool_info["package_dependencies"] == ["this-is-a-required-package", "another-required-package"] def test_save_to_json(extractor, tmp_path):