mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 16:18:30 +00:00
Support to collect extra package dependencies of Tools (#330)
* feat: add explictly package_dependencies in the Tools * feat: collect package_dependencies from Tool to add in tool.specs.json * feat: add default value in run_params Tool' specs * fix: support get boolean values This commit also refactor test to make easier define newest attributes into a Tool
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from typing import Type, Optional, Dict, Any
|
||||
from typing import Type, Optional, Dict, Any, List
|
||||
import os
|
||||
import json
|
||||
import uuid
|
||||
@@ -29,6 +29,7 @@ class BedrockInvokeAgentTool(BaseTool):
|
||||
session_id: str = None
|
||||
enable_trace: bool = False
|
||||
end_session: bool = False
|
||||
package_dependencies: List[str] = ["boto3"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -51,7 +52,7 @@ class BedrockInvokeAgentTool(BaseTool):
|
||||
description (Optional[str]): Custom description for the tool
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
# Get values from environment variables if not provided
|
||||
self.agent_id = agent_id or os.getenv('BEDROCK_AGENT_ID')
|
||||
self.agent_alias_id = agent_alias_id or os.getenv('BEDROCK_AGENT_ALIAS_ID')
|
||||
@@ -62,7 +63,7 @@ class BedrockInvokeAgentTool(BaseTool):
|
||||
# Update the description if provided
|
||||
if description:
|
||||
self.description = description
|
||||
|
||||
|
||||
# Validate parameters
|
||||
self._validate_parameters()
|
||||
|
||||
@@ -74,17 +75,17 @@ class BedrockInvokeAgentTool(BaseTool):
|
||||
raise BedrockValidationError("agent_id cannot be empty")
|
||||
if not isinstance(self.agent_id, str):
|
||||
raise BedrockValidationError("agent_id must be a string")
|
||||
|
||||
|
||||
# Validate agent_alias_id
|
||||
if not self.agent_alias_id:
|
||||
raise BedrockValidationError("agent_alias_id cannot be empty")
|
||||
if not isinstance(self.agent_alias_id, str):
|
||||
raise BedrockValidationError("agent_alias_id must be a string")
|
||||
|
||||
|
||||
# Validate session_id if provided
|
||||
if self.session_id and not isinstance(self.session_id, str):
|
||||
raise BedrockValidationError("session_id must be a string")
|
||||
|
||||
|
||||
except BedrockValidationError as e:
|
||||
raise BedrockValidationError(f"Parameter validation failed: {str(e)}")
|
||||
|
||||
@@ -123,7 +124,7 @@ Below is the users query or task. Complete it and answer it consicely and to the
|
||||
|
||||
# Process the response
|
||||
completion = ""
|
||||
|
||||
|
||||
# Check if response contains a completion field
|
||||
if 'completion' in response:
|
||||
# Process streaming response format
|
||||
@@ -134,7 +135,7 @@ Below is the users query or task. Complete it and answer it consicely and to the
|
||||
completion += chunk_bytes.decode('utf-8')
|
||||
else:
|
||||
completion += str(chunk_bytes)
|
||||
|
||||
|
||||
# If no completion found in streaming format, try direct format
|
||||
if not completion and 'chunk' in response and 'bytes' in response['chunk']:
|
||||
chunk_bytes = response['chunk']['bytes']
|
||||
@@ -142,31 +143,31 @@ Below is the users query or task. Complete it and answer it consicely and to the
|
||||
completion = chunk_bytes.decode('utf-8')
|
||||
else:
|
||||
completion = str(chunk_bytes)
|
||||
|
||||
|
||||
# If still no completion, return debug info
|
||||
if not completion:
|
||||
debug_info = {
|
||||
"error": "Could not extract completion from response",
|
||||
"response_keys": list(response.keys())
|
||||
}
|
||||
|
||||
|
||||
# Add more debug info
|
||||
if 'chunk' in response:
|
||||
debug_info["chunk_keys"] = list(response['chunk'].keys())
|
||||
|
||||
|
||||
raise BedrockAgentError(f"Failed to extract completion: {json.dumps(debug_info, indent=2)}")
|
||||
|
||||
|
||||
return completion
|
||||
|
||||
except ClientError as e:
|
||||
error_code = "Unknown"
|
||||
error_message = str(e)
|
||||
|
||||
|
||||
# Try to extract error code if available
|
||||
if hasattr(e, 'response') and 'Error' in e.response:
|
||||
error_code = e.response['Error'].get('Code', 'Unknown')
|
||||
error_message = e.response['Error'].get('Message', str(e))
|
||||
|
||||
|
||||
raise BedrockAgentError(f"Error ({error_code}): {error_message}")
|
||||
except BedrockAgentError:
|
||||
# Re-raise BedrockAgentError exceptions
|
||||
|
||||
@@ -26,6 +26,7 @@ class BedrockKBRetrieverTool(BaseTool):
|
||||
retrieval_configuration: Optional[Dict[str, Any]] = None
|
||||
guardrail_configuration: Optional[Dict[str, Any]] = None
|
||||
next_token: Optional[str] = None
|
||||
package_dependencies: List[str] = ["boto3"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -46,13 +47,13 @@ class BedrockKBRetrieverTool(BaseTool):
|
||||
next_token (Optional[str], optional): Token for retrieving the next batch of results. Defaults to None.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
# Get knowledge_base_id from environment variable if not provided
|
||||
self.knowledge_base_id = knowledge_base_id or os.getenv('BEDROCK_KB_ID')
|
||||
self.number_of_results = number_of_results
|
||||
self.guardrail_configuration = guardrail_configuration
|
||||
self.next_token = next_token
|
||||
|
||||
|
||||
# Initialize retrieval_configuration with provided parameters or use the one provided
|
||||
if retrieval_configuration is None:
|
||||
self.retrieval_configuration = self._build_retrieval_configuration()
|
||||
@@ -67,16 +68,16 @@ class BedrockKBRetrieverTool(BaseTool):
|
||||
|
||||
def _build_retrieval_configuration(self) -> Dict[str, Any]:
|
||||
"""Build the retrieval configuration based on provided parameters.
|
||||
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: The constructed retrieval configuration
|
||||
"""
|
||||
vector_search_config = {}
|
||||
|
||||
|
||||
# Add number of results if provided
|
||||
if self.number_of_results is not None:
|
||||
vector_search_config["numberOfResults"] = self.number_of_results
|
||||
|
||||
|
||||
return {"vectorSearchConfiguration": vector_search_config}
|
||||
|
||||
def _validate_parameters(self):
|
||||
@@ -91,7 +92,7 @@ class BedrockKBRetrieverTool(BaseTool):
|
||||
raise BedrockValidationError("knowledge_base_id must be 10 characters or less")
|
||||
if not all(c.isalnum() for c in self.knowledge_base_id):
|
||||
raise BedrockValidationError("knowledge_base_id must contain only alphanumeric characters")
|
||||
|
||||
|
||||
# Validate next_token if provided
|
||||
if self.next_token:
|
||||
if not isinstance(self.next_token, str):
|
||||
@@ -100,23 +101,23 @@ class BedrockKBRetrieverTool(BaseTool):
|
||||
raise BedrockValidationError("next_token must be between 1 and 2048 characters")
|
||||
if ' ' in self.next_token:
|
||||
raise BedrockValidationError("next_token cannot contain spaces")
|
||||
|
||||
|
||||
# Validate number_of_results if provided
|
||||
if self.number_of_results is not None:
|
||||
if not isinstance(self.number_of_results, int):
|
||||
raise BedrockValidationError("number_of_results must be an integer")
|
||||
if self.number_of_results < 1:
|
||||
raise BedrockValidationError("number_of_results must be greater than 0")
|
||||
|
||||
|
||||
except BedrockValidationError as e:
|
||||
raise BedrockValidationError(f"Parameter validation failed: {str(e)}")
|
||||
|
||||
def _process_retrieval_result(self, result: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Process a single retrieval result from Bedrock Knowledge Base.
|
||||
|
||||
|
||||
Args:
|
||||
result (Dict[str, Any]): Raw result from Bedrock Knowledge Base
|
||||
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Processed result with standardized format
|
||||
"""
|
||||
@@ -124,12 +125,12 @@ class BedrockKBRetrieverTool(BaseTool):
|
||||
content_obj = result.get('content', {})
|
||||
content = content_obj.get('text', '')
|
||||
content_type = content_obj.get('type', 'text')
|
||||
|
||||
|
||||
# Extract location information
|
||||
location = result.get('location', {})
|
||||
location_type = location.get('type', 'unknown')
|
||||
source_uri = None
|
||||
|
||||
|
||||
# Map for location types and their URI fields
|
||||
location_mapping = {
|
||||
's3Location': {'field': 'uri', 'type': 'S3'},
|
||||
@@ -141,7 +142,7 @@ class BedrockKBRetrieverTool(BaseTool):
|
||||
'kendraDocumentLocation': {'field': 'uri', 'type': 'KendraDocument'},
|
||||
'sqlLocation': {'field': 'query', 'type': 'SQL'}
|
||||
}
|
||||
|
||||
|
||||
# Extract the URI based on location type
|
||||
for loc_key, config in location_mapping.items():
|
||||
if loc_key in location:
|
||||
@@ -149,7 +150,7 @@ class BedrockKBRetrieverTool(BaseTool):
|
||||
if not location_type or location_type == 'unknown':
|
||||
location_type = config['type']
|
||||
break
|
||||
|
||||
|
||||
# Create result object
|
||||
result_object = {
|
||||
'content': content,
|
||||
@@ -157,22 +158,22 @@ class BedrockKBRetrieverTool(BaseTool):
|
||||
'source_type': location_type,
|
||||
'source_uri': source_uri
|
||||
}
|
||||
|
||||
|
||||
# Add optional fields if available
|
||||
if 'score' in result:
|
||||
result_object['score'] = result['score']
|
||||
|
||||
|
||||
if 'metadata' in result:
|
||||
result_object['metadata'] = result['metadata']
|
||||
|
||||
|
||||
# Handle byte content if present
|
||||
if 'byteContent' in content_obj:
|
||||
result_object['byte_content'] = content_obj['byteContent']
|
||||
|
||||
|
||||
# Handle row content if present
|
||||
if 'row' in content_obj:
|
||||
result_object['row_content'] = content_obj['row']
|
||||
|
||||
|
||||
return result_object
|
||||
|
||||
def _run(self, query: str) -> str:
|
||||
@@ -201,10 +202,10 @@ class BedrockKBRetrieverTool(BaseTool):
|
||||
# Add optional parameters if provided
|
||||
if self.retrieval_configuration:
|
||||
retrieve_params['retrievalConfiguration'] = self.retrieval_configuration
|
||||
|
||||
|
||||
if self.guardrail_configuration:
|
||||
retrieve_params['guardrailConfiguration'] = self.guardrail_configuration
|
||||
|
||||
|
||||
if self.next_token:
|
||||
retrieve_params['nextToken'] = self.next_token
|
||||
|
||||
@@ -223,10 +224,10 @@ class BedrockKBRetrieverTool(BaseTool):
|
||||
response_object["results"] = results
|
||||
else:
|
||||
response_object["message"] = "No results found for the given query."
|
||||
|
||||
|
||||
if "nextToken" in response:
|
||||
response_object["nextToken"] = response["nextToken"]
|
||||
|
||||
|
||||
if "guardrailAction" in response:
|
||||
response_object["guardrailAction"] = response["guardrailAction"]
|
||||
|
||||
@@ -236,12 +237,12 @@ class BedrockKBRetrieverTool(BaseTool):
|
||||
except ClientError as e:
|
||||
error_code = "Unknown"
|
||||
error_message = str(e)
|
||||
|
||||
|
||||
# Try to extract error code if available
|
||||
if hasattr(e, 'response') and 'Error' in e.response:
|
||||
error_code = e.response['Error'].get('Code', 'Unknown')
|
||||
error_message = e.response['Error'].get('Message', str(e))
|
||||
|
||||
|
||||
raise BedrockKnowledgeBaseError(f"Error ({error_code}): {error_message}")
|
||||
except Exception as e:
|
||||
raise BedrockKnowledgeBaseError(f"Unexpected error: {str(e)}")
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Type
|
||||
from typing import Any, Type, List
|
||||
import os
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
@@ -15,6 +15,7 @@ class S3ReaderTool(BaseTool):
|
||||
name: str = "S3 Reader Tool"
|
||||
description: str = "Reads a file from Amazon S3 given an S3 file path"
|
||||
args_schema: Type[BaseModel] = S3ReaderToolInput
|
||||
package_dependencies: List[str] = ["boto3"]
|
||||
|
||||
def _run(self, file_path: str) -> str:
|
||||
try:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Type
|
||||
from typing import Type, List
|
||||
import os
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
@@ -14,6 +14,7 @@ class S3WriterTool(BaseTool):
|
||||
name: str = "S3 Writer Tool"
|
||||
description: str = "Writes content to a file in Amazon S3 given an S3 file path"
|
||||
args_schema: Type[BaseModel] = S3WriterToolInput
|
||||
package_dependencies: List[str] = ["boto3"]
|
||||
|
||||
def _run(self, file_path: str, content: str) -> str:
|
||||
try:
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import os
|
||||
import secrets
|
||||
from typing import Any, Dict, List, Optional, Text, Type
|
||||
from typing import Any, Dict, List, Optional, Type
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
from openai import OpenAI
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class AIMindToolConstants:
|
||||
@@ -16,7 +16,7 @@ class AIMindToolConstants:
|
||||
class AIMindToolInputSchema(BaseModel):
|
||||
"""Input for AIMind Tool."""
|
||||
|
||||
query: str = "Question in natural language to ask the AI-Mind"
|
||||
query: str = Field(description="Question in natural language to ask the AI-Mind")
|
||||
|
||||
|
||||
class AIMindTool(BaseTool):
|
||||
@@ -31,9 +31,10 @@ class AIMindTool(BaseTool):
|
||||
args_schema: Type[BaseModel] = AIMindToolInputSchema
|
||||
api_key: Optional[str] = None
|
||||
datasources: Optional[List[Dict[str, Any]]] = None
|
||||
mind_name: Optional[Text] = None
|
||||
mind_name: Optional[str] = None
|
||||
package_dependencies: List[str] = ["minds-sdk"]
|
||||
|
||||
def __init__(self, api_key: Optional[Text] = None, **kwargs):
|
||||
def __init__(self, api_key: Optional[str] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.api_key = api_key or os.getenv("MINDS_API_KEY")
|
||||
if not self.api_key:
|
||||
@@ -72,7 +73,7 @@ class AIMindTool(BaseTool):
|
||||
|
||||
def _run(
|
||||
self,
|
||||
query: Text
|
||||
query: str
|
||||
):
|
||||
# Run the query on the AI-Mind.
|
||||
# The Minds API is OpenAI compatible and therefore, the OpenAI client can be used.
|
||||
|
||||
@@ -38,6 +38,7 @@ class ApifyActorsTool(BaseTool):
|
||||
print(f"Content: {result.get('markdown', 'N/A')[:100]}...")
|
||||
"""
|
||||
actor_tool: '_ApifyActorsTool' = Field(description="Apify Actor Tool")
|
||||
package_dependencies: List[str] = ["langchain-apify"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import os
|
||||
from typing import Any, Optional, Type
|
||||
from typing import Any, Optional, Type, List
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -19,6 +19,7 @@ class BrowserbaseLoadTool(BaseTool):
|
||||
session_id: Optional[str] = None
|
||||
proxy: Optional[bool] = None
|
||||
browserbase: Optional[Any] = None
|
||||
package_dependencies: List[str] = ["browserbase"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -3,13 +3,13 @@ from typing import Type
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
from openai import OpenAI
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ImagePromptSchema(BaseModel):
|
||||
"""Input for Dall-E Tool."""
|
||||
|
||||
image_description: str = "Description of the image to be generated by Dall-E."
|
||||
image_description: str = Field(description="Description of the image to be generated by Dall-E.")
|
||||
|
||||
|
||||
class DallETool(BaseTool):
|
||||
|
||||
@@ -69,6 +69,7 @@ class DatabricksQueryTool(BaseTool):
|
||||
default_warehouse_id: Optional[str] = None
|
||||
|
||||
_workspace_client: Optional["WorkspaceClient"] = None
|
||||
package_dependencies: List[str] = ["databricks-sdk"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Optional, Type
|
||||
from typing import Any, Optional, Type, List
|
||||
from pydantic import BaseModel, Field
|
||||
from crewai.tools import BaseTool
|
||||
|
||||
@@ -35,6 +35,7 @@ class EXASearchTool(BaseTool):
|
||||
content: Optional[bool] = False
|
||||
summary: Optional[bool] = False
|
||||
type: Optional[str] = "auto"
|
||||
package_dependencies: List[str] = ["exa_py"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Optional, Type
|
||||
from typing import Any, Optional, Type, List
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
|
||||
@@ -55,6 +55,7 @@ class FirecrawlCrawlWebsiteTool(BaseTool):
|
||||
}
|
||||
)
|
||||
_firecrawl: Optional["FirecrawlApp"] = PrivateAttr(None)
|
||||
package_dependencies: List[str] = ["firecrawl-py"]
|
||||
|
||||
def __init__(self, api_key: Optional[str] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Optional, Type, Dict
|
||||
from typing import Any, Optional, Type, Dict, List
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
|
||||
@@ -48,6 +48,7 @@ class FirecrawlScrapeWebsiteTool(BaseTool):
|
||||
)
|
||||
|
||||
_firecrawl: Optional["FirecrawlApp"] = PrivateAttr(None)
|
||||
package_dependencies: List[str] = ["firecrawl-py"]
|
||||
|
||||
def __init__(self, api_key: Optional[str] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Type
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Type, List
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
|
||||
@@ -57,6 +57,7 @@ class FirecrawlSearchTool(BaseTool):
|
||||
}
|
||||
)
|
||||
_firecrawl: Optional["FirecrawlApp"] = PrivateAttr(None)
|
||||
package_dependencies: List[str] = ["firecrawl-py"]
|
||||
|
||||
def __init__(self, api_key: Optional[str] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import os
|
||||
from typing import Any, Optional, Type, Dict, Literal, Union
|
||||
from typing import Any, Optional, Type, Dict, Literal, Union, List
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -25,6 +25,7 @@ class HyperbrowserLoadTool(BaseTool):
|
||||
args_schema: Type[BaseModel] = HyperbrowserLoadToolSchema
|
||||
api_key: Optional[str] = None
|
||||
hyperbrowser: Optional[Any] = None
|
||||
package_dependencies: List[str] = ["hyperbrowser"]
|
||||
|
||||
def __init__(self, api_key: Optional[str] = None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
@@ -65,7 +66,7 @@ class HyperbrowserLoadTool(BaseTool):
|
||||
if "scrape_options" in params:
|
||||
params["scrape_options"] = ScrapeOptions(**params["scrape_options"])
|
||||
return params
|
||||
|
||||
|
||||
def _extract_content(self, data: Union[Any, None]):
|
||||
"""Extract content from response data."""
|
||||
content = ""
|
||||
@@ -81,7 +82,7 @@ class HyperbrowserLoadTool(BaseTool):
|
||||
raise ImportError(
|
||||
"`hyperbrowser` package not found, please run `pip install hyperbrowser`"
|
||||
)
|
||||
|
||||
|
||||
params = self._prepare_params(params)
|
||||
|
||||
if operation == 'scrape':
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any
|
||||
from typing import Any, List
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
|
||||
@@ -23,6 +23,7 @@ class LinkupSearchTool(BaseTool):
|
||||
"Performs an API call to Linkup to retrieve contextual information."
|
||||
)
|
||||
_client: LinkupClient = PrivateAttr() # type: ignore
|
||||
package_dependencies: List[str] = ["linkup-sdk"]
|
||||
|
||||
def __init__(self, api_key: str):
|
||||
"""
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Multion tool spec."""
|
||||
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Optional, List
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
|
||||
@@ -16,6 +16,7 @@ class MultiOnTool(BaseTool):
|
||||
session_id: Optional[str] = None
|
||||
local: bool = False
|
||||
max_steps: int = 3
|
||||
package_dependencies: List[str] = ["multion"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import TYPE_CHECKING, Any, Type
|
||||
from typing import TYPE_CHECKING, Any, Type, List
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
@@ -41,6 +41,7 @@ class PatronusLocalEvaluatorTool(BaseTool):
|
||||
evaluated_model_gold_answer: str
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
package_dependencies: List[str] = ["patronus"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Callable, Optional, Type
|
||||
from typing import Any, Callable, Optional, Type, List
|
||||
|
||||
|
||||
try:
|
||||
@@ -74,6 +74,7 @@ class QdrantVectorSearchTool(BaseTool):
|
||||
default=None,
|
||||
description="A custom embedding function to use for vectorization. If not provided, the default model will be used.",
|
||||
)
|
||||
package_dependencies: List[str] = ["qdrant-client"]
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, Optional, Type
|
||||
from typing import TYPE_CHECKING, Any, Optional, Type, List
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
@@ -67,6 +67,7 @@ class ScrapegraphScrapeTool(BaseTool):
|
||||
api_key: Optional[str] = None
|
||||
enable_logging: bool = False
|
||||
_client: Optional["Client"] = None
|
||||
package_dependencies: List[str] = ["scrapegraph-py"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import logging
|
||||
from typing import Any, Dict, Literal, Optional, Type
|
||||
from typing import Any, Dict, Literal, Optional, Type, List
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -28,6 +28,7 @@ class ScrapflyScrapeWebsiteTool(BaseTool):
|
||||
args_schema: Type[BaseModel] = ScrapflyScrapeWebsiteToolSchema
|
||||
api_key: str = None
|
||||
scrapfly: Optional[Any] = None
|
||||
package_dependencies: List[str] = ["scrapfly-sdk"]
|
||||
|
||||
def __init__(self, api_key: str):
|
||||
super().__init__()
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import re
|
||||
import time
|
||||
from typing import Any, Optional, Type
|
||||
from typing import Any, Optional, Type, List
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
@@ -58,6 +58,7 @@ class SeleniumScrapingTool(BaseTool):
|
||||
css_element: Optional[str] = None
|
||||
return_html: Optional[bool] = False
|
||||
_by: Optional[Any] = None
|
||||
package_dependencies: List[str] = ["selenium", "webdriver-manager"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import os
|
||||
import re
|
||||
from typing import Any, Optional, Union
|
||||
from typing import Any, Optional, Union, List
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
|
||||
@@ -8,6 +8,8 @@ from crewai.tools import BaseTool
|
||||
class SerpApiBaseTool(BaseTool):
|
||||
"""Base class for SerpApi functionality with shared capabilities."""
|
||||
|
||||
package_dependencies: List[str] = ["serpapi"]
|
||||
|
||||
client: Optional[Any] = None
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
|
||||
@@ -99,6 +99,7 @@ class SnowflakeSearchTool(BaseTool):
|
||||
_pool_lock: Optional[asyncio.Lock] = None
|
||||
_thread_pool: Optional[ThreadPoolExecutor] = None
|
||||
_model_rebuilt: bool = False
|
||||
package_dependencies: List[str] = ["snowflake-connector-python", "snowflake-sqlalchemy", "cryptography"]
|
||||
|
||||
def __init__(self, **data):
|
||||
"""Initialize SnowflakeSearchTool."""
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import logging
|
||||
from typing import Any, Dict, Literal, Optional, Type
|
||||
from typing import Any, Dict, Literal, Optional, Type, List
|
||||
from urllib.parse import unquote, urlparse
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
@@ -53,6 +53,7 @@ class SpiderTool(BaseTool):
|
||||
spider: Any = None
|
||||
log_failures: bool = True
|
||||
config: SpiderToolConfig = SpiderToolConfig()
|
||||
package_dependencies: List[str] = ["spider-client"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -88,6 +88,7 @@ class StagehandToolSchema(BaseModel):
|
||||
|
||||
|
||||
class StagehandTool(BaseTool):
|
||||
package_dependencies: List[str] = ["stagehand"]
|
||||
"""
|
||||
A tool that uses Stagehand to automate web browser interactions using natural language.
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ class TavilyExtractorToolSchema(BaseModel):
|
||||
|
||||
|
||||
class TavilyExtractorTool(BaseTool):
|
||||
package_dependencies: List[str] = ["tavily-python"]
|
||||
"""
|
||||
Tool that uses the Tavily API to extract content from web pages.
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from crewai.tools import BaseTool
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional, Type, Any, Union, Literal, Sequence
|
||||
from typing import Optional, Type, Any, Union, Literal, Sequence, List
|
||||
from dotenv import load_dotenv
|
||||
import os
|
||||
import json
|
||||
@@ -101,6 +101,7 @@ class TavilySearchTool(BaseTool):
|
||||
default=1000,
|
||||
description="Maximum length for the 'content' of each search result to avoid context window issues.",
|
||||
)
|
||||
package_dependencies: List[str] = ["tavily-python"]
|
||||
|
||||
def __init__(self, **kwargs: Any):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Optional, Type
|
||||
from typing import Any, Optional, Type, List
|
||||
|
||||
try:
|
||||
import weaviate
|
||||
@@ -31,6 +31,7 @@ class WeaviateToolSchema(BaseModel):
|
||||
class WeaviateVectorSearchTool(BaseTool):
|
||||
"""Tool to search the Weaviate database"""
|
||||
|
||||
package_dependencies: List[str] = ["weaviate-client"]
|
||||
name: str = "WeaviateVectorSearchTool"
|
||||
description: str = "A tool to search the Weaviate database for relevant information on internal documents."
|
||||
args_schema: Type[BaseModel] = WeaviateToolSchema
|
||||
@@ -48,6 +49,7 @@ class WeaviateVectorSearchTool(BaseTool):
|
||||
...,
|
||||
description="The API key for the Weaviate cluster",
|
||||
)
|
||||
package_dependencies: List[str] = ["weaviate-client"]
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
import json
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Type
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel, Field
|
||||
from unittest import mock
|
||||
|
||||
from generate_tool_specs import ToolSpecExtractor
|
||||
from crewai.tools.base_tool import EnvVar
|
||||
from crewai.tools.base_tool import BaseTool, EnvVar
|
||||
|
||||
class MockToolSchema(BaseModel):
|
||||
query: str = Field(..., description="The query parameter")
|
||||
@@ -14,54 +14,26 @@ class MockToolSchema(BaseModel):
|
||||
filters: Optional[List[str]] = Field(None, description="Optional filters to apply")
|
||||
|
||||
|
||||
class MockTool:
|
||||
name = "Mock Search Tool"
|
||||
description = "A tool that mocks search functionality"
|
||||
args_schema = MockToolSchema
|
||||
class MockTool(BaseTool):
|
||||
name: str = "Mock Search Tool"
|
||||
description: str = "A tool that mocks search functionality"
|
||||
args_schema: Type[BaseModel] = MockToolSchema
|
||||
|
||||
another_parameter: str = Field("Another way to define a default value", description="")
|
||||
my_parameter: str = Field("This is default value", description="What a description")
|
||||
my_parameter_bool: bool = Field(False)
|
||||
package_dependencies: List[str] = Field(["this-is-a-required-package", "another-required-package"], description="")
|
||||
env_vars: List[EnvVar] = [
|
||||
EnvVar(name="SERPER_API_KEY", description="API key for Serper", required=True, default=None),
|
||||
EnvVar(name="API_RATE_LIMIT", description="API rate limit", required=False, default="100")
|
||||
]
|
||||
|
||||
@pytest.fixture
|
||||
def extractor():
|
||||
ext = ToolSpecExtractor()
|
||||
MockTool.__pydantic_core_schema__ = create_mock_schema(MockTool)
|
||||
MockTool.args_schema.__pydantic_core_schema__ = create_mock_schema_args(MockTool.args_schema)
|
||||
return ext
|
||||
|
||||
|
||||
def create_mock_schema(cls):
|
||||
return {
|
||||
"type": "model",
|
||||
"cls": cls,
|
||||
"schema": {
|
||||
"type": "model-fields",
|
||||
"fields": {
|
||||
"name": {"type": "model-field", "schema": {"type": "default", "schema": {"type": "str"}, "default": cls.name}, "metadata": {}},
|
||||
"description": {"type": "model-field", "schema": {"type": "default", "schema": {"type": "str"}, "default": cls.description}, "metadata": {}},
|
||||
"args_schema": {"type": "model-field", "schema": {"type": "default", "schema": {"type": "is-subclass", "cls": BaseModel}, "default": cls.args_schema}, "metadata": {}},
|
||||
"env_vars": {
|
||||
"type": "model-field", "schema": {"type": "default", "schema": {"type": "list", "items_schema": {"type": "model", "cls": "INSPECT CLASS", "schema": {"type": "model-fields", "fields": {"name": {"type": "model-field", "schema": {"type": "str"}, "metadata": {}}, "description": {"type": "model-field", "schema": {"type": "str"}, "metadata": {}}, "required": {"type": "model-field", "schema": {"type": "default", "schema": {"type": "bool"}, "default": True}, "metadata": {}}, "default": {"type": "model-field", "schema": {"type": "default", "schema": {"type": "nullable", "schema": {"type": "str"}}, "default": None}, "metadata": {}},}, "model_name": "EnvVar", "computed_fields": []}, "custom_init": False, "root_model": False, "config": {"title": "EnvVar"}, "ref": "crewai.tools.base_tool.EnvVar:4593650640", "metadata": {"pydantic_js_functions": ["INSPECT __get_pydantic_json_schema__"]}}}, "default": [EnvVar(name='SERPER_API_KEY', description='API key for Serper', required=True, default=None), EnvVar(name='API_RATE_LIMIT', description='API rate limit', required=False, default="100")]}, "metadata": {}
|
||||
}
|
||||
},
|
||||
"model_name": cls.__name__
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def create_mock_schema_args(cls):
|
||||
return {
|
||||
"type": "model",
|
||||
"cls": cls,
|
||||
"schema": {
|
||||
"type": "model-fields",
|
||||
"fields": {
|
||||
"query": {"type": "model-field", "schema": {"type": "default", "schema": {"type": "str"}, "default": "The query parameter"}},
|
||||
"count": {"type": "model-field", "schema": {"type": "default", "schema": {"type": "int"}, "default": 5}, "metadata": {"pydantic_js_updates": {"description": "Number of results to return"}}},
|
||||
"filters": {"type": "model-field", "schema": {"type": "nullable", "schema": {"type": "list", "items_schema": {"type": "str"}}}}
|
||||
},
|
||||
"model_name": cls.__name__
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def test_unwrap_schema(extractor):
|
||||
nested_schema = {
|
||||
"type": "function-after",
|
||||
@@ -72,19 +44,6 @@ def test_unwrap_schema(extractor):
|
||||
assert result["value"] == "test"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"field, fallback, expected",
|
||||
[
|
||||
({"schema": {"default": "test_value"}}, None, "test_value"),
|
||||
({}, "fallback_value", "fallback_value"),
|
||||
({"schema": {"default": 123}}, "fallback_value", "fallback_value")
|
||||
]
|
||||
)
|
||||
def test_extract_field_default(extractor, field, fallback, expected):
|
||||
result = extractor._extract_field_default(field, fallback=fallback)
|
||||
assert result == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"schema, expected",
|
||||
[
|
||||
@@ -112,7 +71,7 @@ def test_extract_param_type(extractor, info, expected_type):
|
||||
assert extractor._extract_param_type(info) == expected_type
|
||||
|
||||
|
||||
def test_extract_tool_info(extractor):
|
||||
def test_extract_all_tools(extractor):
|
||||
with mock.patch("generate_tool_specs.dir", return_value=["MockTool"]), \
|
||||
mock.patch("generate_tool_specs.getattr", return_value=MockTool):
|
||||
extractor.extract_all_tools()
|
||||
@@ -120,6 +79,16 @@ def test_extract_tool_info(extractor):
|
||||
assert len(extractor.tools_spec) == 1
|
||||
tool_info = extractor.tools_spec[0]
|
||||
|
||||
assert tool_info.keys() == {
|
||||
"name",
|
||||
"humanized_name",
|
||||
"description",
|
||||
"run_params",
|
||||
"env_vars",
|
||||
"init_params",
|
||||
"package_dependencies",
|
||||
}
|
||||
|
||||
assert tool_info["name"] == "MockTool"
|
||||
assert tool_info["humanized_name"] == "Mock Search Tool"
|
||||
assert tool_info["description"] == "A tool that mocks search functionality"
|
||||
@@ -142,12 +111,16 @@ def test_extract_tool_info(extractor):
|
||||
params = {p["name"]: p for p in tool_info["run_params"]}
|
||||
assert params["query"]["description"] == "The query parameter"
|
||||
assert params["query"]["type"] == "str"
|
||||
assert params["query"]["default"] == ""
|
||||
|
||||
assert params["count"]["description"] == "Number of results to return"
|
||||
assert params["count"]["type"] == "int"
|
||||
assert params["count"]["default"] == 5
|
||||
|
||||
assert params["filters"]["description"] == ""
|
||||
assert params["filters"]["description"] == "Optional filters to apply"
|
||||
assert params["filters"]["type"] == "list[str]"
|
||||
assert params["filters"]["default"] == ""
|
||||
|
||||
assert tool_info["package_dependencies"] == ["this-is-a-required-package", "another-required-package"]
|
||||
|
||||
|
||||
def test_save_to_json(extractor, tmp_path):
|
||||
|
||||
Reference in New Issue
Block a user