Support to collect extra package dependencies of Tools (#330)

* feat: add explictly package_dependencies in the Tools

* feat: collect package_dependencies from Tool to add in tool.specs.json

* feat: add default value in run_params Tool' specs

* fix: support get boolean values

This commit also refactor test to make easier define newest attributes into a Tool
This commit is contained in:
Lucas Gomide
2025-06-16 11:09:19 -03:00
committed by GitHub
parent 5a99f07765
commit fac32d9503
29 changed files with 129 additions and 127 deletions

View File

@@ -1,4 +1,4 @@
from typing import Type, Optional, Dict, Any
from typing import Type, Optional, Dict, Any, List
import os
import json
import uuid
@@ -29,6 +29,7 @@ class BedrockInvokeAgentTool(BaseTool):
session_id: str = None
enable_trace: bool = False
end_session: bool = False
package_dependencies: List[str] = ["boto3"]
def __init__(
self,
@@ -51,7 +52,7 @@ class BedrockInvokeAgentTool(BaseTool):
description (Optional[str]): Custom description for the tool
"""
super().__init__(**kwargs)
# Get values from environment variables if not provided
self.agent_id = agent_id or os.getenv('BEDROCK_AGENT_ID')
self.agent_alias_id = agent_alias_id or os.getenv('BEDROCK_AGENT_ALIAS_ID')
@@ -62,7 +63,7 @@ class BedrockInvokeAgentTool(BaseTool):
# Update the description if provided
if description:
self.description = description
# Validate parameters
self._validate_parameters()
@@ -74,17 +75,17 @@ class BedrockInvokeAgentTool(BaseTool):
raise BedrockValidationError("agent_id cannot be empty")
if not isinstance(self.agent_id, str):
raise BedrockValidationError("agent_id must be a string")
# Validate agent_alias_id
if not self.agent_alias_id:
raise BedrockValidationError("agent_alias_id cannot be empty")
if not isinstance(self.agent_alias_id, str):
raise BedrockValidationError("agent_alias_id must be a string")
# Validate session_id if provided
if self.session_id and not isinstance(self.session_id, str):
raise BedrockValidationError("session_id must be a string")
except BedrockValidationError as e:
raise BedrockValidationError(f"Parameter validation failed: {str(e)}")
@@ -123,7 +124,7 @@ Below is the users query or task. Complete it and answer it consicely and to the
# Process the response
completion = ""
# Check if response contains a completion field
if 'completion' in response:
# Process streaming response format
@@ -134,7 +135,7 @@ Below is the users query or task. Complete it and answer it consicely and to the
completion += chunk_bytes.decode('utf-8')
else:
completion += str(chunk_bytes)
# If no completion found in streaming format, try direct format
if not completion and 'chunk' in response and 'bytes' in response['chunk']:
chunk_bytes = response['chunk']['bytes']
@@ -142,31 +143,31 @@ Below is the users query or task. Complete it and answer it consicely and to the
completion = chunk_bytes.decode('utf-8')
else:
completion = str(chunk_bytes)
# If still no completion, return debug info
if not completion:
debug_info = {
"error": "Could not extract completion from response",
"response_keys": list(response.keys())
}
# Add more debug info
if 'chunk' in response:
debug_info["chunk_keys"] = list(response['chunk'].keys())
raise BedrockAgentError(f"Failed to extract completion: {json.dumps(debug_info, indent=2)}")
return completion
except ClientError as e:
error_code = "Unknown"
error_message = str(e)
# Try to extract error code if available
if hasattr(e, 'response') and 'Error' in e.response:
error_code = e.response['Error'].get('Code', 'Unknown')
error_message = e.response['Error'].get('Message', str(e))
raise BedrockAgentError(f"Error ({error_code}): {error_message}")
except BedrockAgentError:
# Re-raise BedrockAgentError exceptions

View File

@@ -26,6 +26,7 @@ class BedrockKBRetrieverTool(BaseTool):
retrieval_configuration: Optional[Dict[str, Any]] = None
guardrail_configuration: Optional[Dict[str, Any]] = None
next_token: Optional[str] = None
package_dependencies: List[str] = ["boto3"]
def __init__(
self,
@@ -46,13 +47,13 @@ class BedrockKBRetrieverTool(BaseTool):
next_token (Optional[str], optional): Token for retrieving the next batch of results. Defaults to None.
"""
super().__init__(**kwargs)
# Get knowledge_base_id from environment variable if not provided
self.knowledge_base_id = knowledge_base_id or os.getenv('BEDROCK_KB_ID')
self.number_of_results = number_of_results
self.guardrail_configuration = guardrail_configuration
self.next_token = next_token
# Initialize retrieval_configuration with provided parameters or use the one provided
if retrieval_configuration is None:
self.retrieval_configuration = self._build_retrieval_configuration()
@@ -67,16 +68,16 @@ class BedrockKBRetrieverTool(BaseTool):
def _build_retrieval_configuration(self) -> Dict[str, Any]:
"""Build the retrieval configuration based on provided parameters.
Returns:
Dict[str, Any]: The constructed retrieval configuration
"""
vector_search_config = {}
# Add number of results if provided
if self.number_of_results is not None:
vector_search_config["numberOfResults"] = self.number_of_results
return {"vectorSearchConfiguration": vector_search_config}
def _validate_parameters(self):
@@ -91,7 +92,7 @@ class BedrockKBRetrieverTool(BaseTool):
raise BedrockValidationError("knowledge_base_id must be 10 characters or less")
if not all(c.isalnum() for c in self.knowledge_base_id):
raise BedrockValidationError("knowledge_base_id must contain only alphanumeric characters")
# Validate next_token if provided
if self.next_token:
if not isinstance(self.next_token, str):
@@ -100,23 +101,23 @@ class BedrockKBRetrieverTool(BaseTool):
raise BedrockValidationError("next_token must be between 1 and 2048 characters")
if ' ' in self.next_token:
raise BedrockValidationError("next_token cannot contain spaces")
# Validate number_of_results if provided
if self.number_of_results is not None:
if not isinstance(self.number_of_results, int):
raise BedrockValidationError("number_of_results must be an integer")
if self.number_of_results < 1:
raise BedrockValidationError("number_of_results must be greater than 0")
except BedrockValidationError as e:
raise BedrockValidationError(f"Parameter validation failed: {str(e)}")
def _process_retrieval_result(self, result: Dict[str, Any]) -> Dict[str, Any]:
"""Process a single retrieval result from Bedrock Knowledge Base.
Args:
result (Dict[str, Any]): Raw result from Bedrock Knowledge Base
Returns:
Dict[str, Any]: Processed result with standardized format
"""
@@ -124,12 +125,12 @@ class BedrockKBRetrieverTool(BaseTool):
content_obj = result.get('content', {})
content = content_obj.get('text', '')
content_type = content_obj.get('type', 'text')
# Extract location information
location = result.get('location', {})
location_type = location.get('type', 'unknown')
source_uri = None
# Map for location types and their URI fields
location_mapping = {
's3Location': {'field': 'uri', 'type': 'S3'},
@@ -141,7 +142,7 @@ class BedrockKBRetrieverTool(BaseTool):
'kendraDocumentLocation': {'field': 'uri', 'type': 'KendraDocument'},
'sqlLocation': {'field': 'query', 'type': 'SQL'}
}
# Extract the URI based on location type
for loc_key, config in location_mapping.items():
if loc_key in location:
@@ -149,7 +150,7 @@ class BedrockKBRetrieverTool(BaseTool):
if not location_type or location_type == 'unknown':
location_type = config['type']
break
# Create result object
result_object = {
'content': content,
@@ -157,22 +158,22 @@ class BedrockKBRetrieverTool(BaseTool):
'source_type': location_type,
'source_uri': source_uri
}
# Add optional fields if available
if 'score' in result:
result_object['score'] = result['score']
if 'metadata' in result:
result_object['metadata'] = result['metadata']
# Handle byte content if present
if 'byteContent' in content_obj:
result_object['byte_content'] = content_obj['byteContent']
# Handle row content if present
if 'row' in content_obj:
result_object['row_content'] = content_obj['row']
return result_object
def _run(self, query: str) -> str:
@@ -201,10 +202,10 @@ class BedrockKBRetrieverTool(BaseTool):
# Add optional parameters if provided
if self.retrieval_configuration:
retrieve_params['retrievalConfiguration'] = self.retrieval_configuration
if self.guardrail_configuration:
retrieve_params['guardrailConfiguration'] = self.guardrail_configuration
if self.next_token:
retrieve_params['nextToken'] = self.next_token
@@ -223,10 +224,10 @@ class BedrockKBRetrieverTool(BaseTool):
response_object["results"] = results
else:
response_object["message"] = "No results found for the given query."
if "nextToken" in response:
response_object["nextToken"] = response["nextToken"]
if "guardrailAction" in response:
response_object["guardrailAction"] = response["guardrailAction"]
@@ -236,12 +237,12 @@ class BedrockKBRetrieverTool(BaseTool):
except ClientError as e:
error_code = "Unknown"
error_message = str(e)
# Try to extract error code if available
if hasattr(e, 'response') and 'Error' in e.response:
error_code = e.response['Error'].get('Code', 'Unknown')
error_message = e.response['Error'].get('Message', str(e))
raise BedrockKnowledgeBaseError(f"Error ({error_code}): {error_message}")
except Exception as e:
raise BedrockKnowledgeBaseError(f"Unexpected error: {str(e)}")

View File

@@ -1,4 +1,4 @@
from typing import Any, Type
from typing import Any, Type, List
import os
from crewai.tools import BaseTool
@@ -15,6 +15,7 @@ class S3ReaderTool(BaseTool):
name: str = "S3 Reader Tool"
description: str = "Reads a file from Amazon S3 given an S3 file path"
args_schema: Type[BaseModel] = S3ReaderToolInput
package_dependencies: List[str] = ["boto3"]
def _run(self, file_path: str) -> str:
try:

View File

@@ -1,4 +1,4 @@
from typing import Any, Type
from typing import Type, List
import os
from crewai.tools import BaseTool
@@ -14,6 +14,7 @@ class S3WriterTool(BaseTool):
name: str = "S3 Writer Tool"
description: str = "Writes content to a file in Amazon S3 given an S3 file path"
args_schema: Type[BaseModel] = S3WriterToolInput
package_dependencies: List[str] = ["boto3"]
def _run(self, file_path: str, content: str) -> str:
try:

View File

@@ -1,10 +1,10 @@
import os
import secrets
from typing import Any, Dict, List, Optional, Text, Type
from typing import Any, Dict, List, Optional, Type
from crewai.tools import BaseTool
from openai import OpenAI
from pydantic import BaseModel
from pydantic import BaseModel, Field
class AIMindToolConstants:
@@ -16,7 +16,7 @@ class AIMindToolConstants:
class AIMindToolInputSchema(BaseModel):
"""Input for AIMind Tool."""
query: str = "Question in natural language to ask the AI-Mind"
query: str = Field(description="Question in natural language to ask the AI-Mind")
class AIMindTool(BaseTool):
@@ -31,9 +31,10 @@ class AIMindTool(BaseTool):
args_schema: Type[BaseModel] = AIMindToolInputSchema
api_key: Optional[str] = None
datasources: Optional[List[Dict[str, Any]]] = None
mind_name: Optional[Text] = None
mind_name: Optional[str] = None
package_dependencies: List[str] = ["minds-sdk"]
def __init__(self, api_key: Optional[Text] = None, **kwargs):
def __init__(self, api_key: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
self.api_key = api_key or os.getenv("MINDS_API_KEY")
if not self.api_key:
@@ -72,7 +73,7 @@ class AIMindTool(BaseTool):
def _run(
self,
query: Text
query: str
):
# Run the query on the AI-Mind.
# The Minds API is OpenAI compatible and therefore, the OpenAI client can be used.

View File

@@ -38,6 +38,7 @@ class ApifyActorsTool(BaseTool):
print(f"Content: {result.get('markdown', 'N/A')[:100]}...")
"""
actor_tool: '_ApifyActorsTool' = Field(description="Apify Actor Tool")
package_dependencies: List[str] = ["langchain-apify"]
def __init__(
self,

View File

@@ -1,5 +1,5 @@
import os
from typing import Any, Optional, Type
from typing import Any, Optional, Type, List
from crewai.tools import BaseTool
from pydantic import BaseModel, Field
@@ -19,6 +19,7 @@ class BrowserbaseLoadTool(BaseTool):
session_id: Optional[str] = None
proxy: Optional[bool] = None
browserbase: Optional[Any] = None
package_dependencies: List[str] = ["browserbase"]
def __init__(
self,

View File

@@ -3,13 +3,13 @@ from typing import Type
from crewai.tools import BaseTool
from openai import OpenAI
from pydantic import BaseModel
from pydantic import BaseModel, Field
class ImagePromptSchema(BaseModel):
"""Input for Dall-E Tool."""
image_description: str = "Description of the image to be generated by Dall-E."
image_description: str = Field(description="Description of the image to be generated by Dall-E.")
class DallETool(BaseTool):

View File

@@ -69,6 +69,7 @@ class DatabricksQueryTool(BaseTool):
default_warehouse_id: Optional[str] = None
_workspace_client: Optional["WorkspaceClient"] = None
package_dependencies: List[str] = ["databricks-sdk"]
def __init__(
self,

View File

@@ -1,4 +1,4 @@
from typing import Any, Optional, Type
from typing import Any, Optional, Type, List
from pydantic import BaseModel, Field
from crewai.tools import BaseTool
@@ -35,6 +35,7 @@ class EXASearchTool(BaseTool):
content: Optional[bool] = False
summary: Optional[bool] = False
type: Optional[str] = "auto"
package_dependencies: List[str] = ["exa_py"]
def __init__(
self,

View File

@@ -1,4 +1,4 @@
from typing import Any, Optional, Type
from typing import Any, Optional, Type, List
from crewai.tools import BaseTool
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
@@ -55,6 +55,7 @@ class FirecrawlCrawlWebsiteTool(BaseTool):
}
)
_firecrawl: Optional["FirecrawlApp"] = PrivateAttr(None)
package_dependencies: List[str] = ["firecrawl-py"]
def __init__(self, api_key: Optional[str] = None, **kwargs):
super().__init__(**kwargs)

View File

@@ -1,4 +1,4 @@
from typing import Any, Optional, Type, Dict
from typing import Any, Optional, Type, Dict, List
from crewai.tools import BaseTool
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
@@ -48,6 +48,7 @@ class FirecrawlScrapeWebsiteTool(BaseTool):
)
_firecrawl: Optional["FirecrawlApp"] = PrivateAttr(None)
package_dependencies: List[str] = ["firecrawl-py"]
def __init__(self, api_key: Optional[str] = None, **kwargs):
super().__init__(**kwargs)

View File

@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Any, Dict, Optional, Type
from typing import TYPE_CHECKING, Any, Dict, Optional, Type, List
from crewai.tools import BaseTool
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
@@ -57,6 +57,7 @@ class FirecrawlSearchTool(BaseTool):
}
)
_firecrawl: Optional["FirecrawlApp"] = PrivateAttr(None)
package_dependencies: List[str] = ["firecrawl-py"]
def __init__(self, api_key: Optional[str] = None, **kwargs):
super().__init__(**kwargs)

View File

@@ -1,5 +1,5 @@
import os
from typing import Any, Optional, Type, Dict, Literal, Union
from typing import Any, Optional, Type, Dict, Literal, Union, List
from crewai.tools import BaseTool
from pydantic import BaseModel, Field
@@ -25,6 +25,7 @@ class HyperbrowserLoadTool(BaseTool):
args_schema: Type[BaseModel] = HyperbrowserLoadToolSchema
api_key: Optional[str] = None
hyperbrowser: Optional[Any] = None
package_dependencies: List[str] = ["hyperbrowser"]
def __init__(self, api_key: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
@@ -65,7 +66,7 @@ class HyperbrowserLoadTool(BaseTool):
if "scrape_options" in params:
params["scrape_options"] = ScrapeOptions(**params["scrape_options"])
return params
def _extract_content(self, data: Union[Any, None]):
"""Extract content from response data."""
content = ""
@@ -81,7 +82,7 @@ class HyperbrowserLoadTool(BaseTool):
raise ImportError(
"`hyperbrowser` package not found, please run `pip install hyperbrowser`"
)
params = self._prepare_params(params)
if operation == 'scrape':

View File

@@ -1,4 +1,4 @@
from typing import Any
from typing import Any, List
from crewai.tools import BaseTool
@@ -23,6 +23,7 @@ class LinkupSearchTool(BaseTool):
"Performs an API call to Linkup to retrieve contextual information."
)
_client: LinkupClient = PrivateAttr() # type: ignore
package_dependencies: List[str] = ["linkup-sdk"]
def __init__(self, api_key: str):
"""

View File

@@ -1,6 +1,6 @@
"""Multion tool spec."""
from typing import Any, Optional
from typing import Any, Optional, List
from crewai.tools import BaseTool
@@ -16,6 +16,7 @@ class MultiOnTool(BaseTool):
session_id: Optional[str] = None
local: bool = False
max_steps: int = 3
package_dependencies: List[str] = ["multion"]
def __init__(
self,

View File

@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Any, Type
from typing import TYPE_CHECKING, Any, Type, List
from crewai.tools import BaseTool
from pydantic import BaseModel, ConfigDict, Field
@@ -41,6 +41,7 @@ class PatronusLocalEvaluatorTool(BaseTool):
evaluated_model_gold_answer: str
model_config = ConfigDict(arbitrary_types_allowed=True)
package_dependencies: List[str] = ["patronus"]
def __init__(
self,

View File

@@ -1,6 +1,6 @@
import json
import os
from typing import Any, Callable, Optional, Type
from typing import Any, Callable, Optional, Type, List
try:
@@ -74,6 +74,7 @@ class QdrantVectorSearchTool(BaseTool):
default=None,
description="A custom embedding function to use for vectorization. If not provided, the default model will be used.",
)
package_dependencies: List[str] = ["qdrant-client"]
def __init__(self, **kwargs):
super().__init__(**kwargs)

View File

@@ -1,5 +1,5 @@
import os
from typing import TYPE_CHECKING, Any, Optional, Type
from typing import TYPE_CHECKING, Any, Optional, Type, List
from urllib.parse import urlparse
from crewai.tools import BaseTool
@@ -67,6 +67,7 @@ class ScrapegraphScrapeTool(BaseTool):
api_key: Optional[str] = None
enable_logging: bool = False
_client: Optional["Client"] = None
package_dependencies: List[str] = ["scrapegraph-py"]
def __init__(
self,

View File

@@ -1,5 +1,5 @@
import logging
from typing import Any, Dict, Literal, Optional, Type
from typing import Any, Dict, Literal, Optional, Type, List
from crewai.tools import BaseTool
from pydantic import BaseModel, Field
@@ -28,6 +28,7 @@ class ScrapflyScrapeWebsiteTool(BaseTool):
args_schema: Type[BaseModel] = ScrapflyScrapeWebsiteToolSchema
api_key: str = None
scrapfly: Optional[Any] = None
package_dependencies: List[str] = ["scrapfly-sdk"]
def __init__(self, api_key: str):
super().__init__()

View File

@@ -1,6 +1,6 @@
import re
import time
from typing import Any, Optional, Type
from typing import Any, Optional, Type, List
from urllib.parse import urlparse
from crewai.tools import BaseTool
@@ -58,6 +58,7 @@ class SeleniumScrapingTool(BaseTool):
css_element: Optional[str] = None
return_html: Optional[bool] = False
_by: Optional[Any] = None
package_dependencies: List[str] = ["selenium", "webdriver-manager"]
def __init__(
self,

View File

@@ -1,6 +1,6 @@
import os
import re
from typing import Any, Optional, Union
from typing import Any, Optional, Union, List
from crewai.tools import BaseTool
@@ -8,6 +8,8 @@ from crewai.tools import BaseTool
class SerpApiBaseTool(BaseTool):
"""Base class for SerpApi functionality with shared capabilities."""
package_dependencies: List[str] = ["serpapi"]
client: Optional[Any] = None
def __init__(self, **kwargs):

View File

@@ -99,6 +99,7 @@ class SnowflakeSearchTool(BaseTool):
_pool_lock: Optional[asyncio.Lock] = None
_thread_pool: Optional[ThreadPoolExecutor] = None
_model_rebuilt: bool = False
package_dependencies: List[str] = ["snowflake-connector-python", "snowflake-sqlalchemy", "cryptography"]
def __init__(self, **data):
"""Initialize SnowflakeSearchTool."""

View File

@@ -1,5 +1,5 @@
import logging
from typing import Any, Dict, Literal, Optional, Type
from typing import Any, Dict, Literal, Optional, Type, List
from urllib.parse import unquote, urlparse
from crewai.tools import BaseTool
@@ -53,6 +53,7 @@ class SpiderTool(BaseTool):
spider: Any = None
log_failures: bool = True
config: SpiderToolConfig = SpiderToolConfig()
package_dependencies: List[str] = ["spider-client"]
def __init__(
self,

View File

@@ -88,6 +88,7 @@ class StagehandToolSchema(BaseModel):
class StagehandTool(BaseTool):
package_dependencies: List[str] = ["stagehand"]
"""
A tool that uses Stagehand to automate web browser interactions using natural language.

View File

@@ -26,6 +26,7 @@ class TavilyExtractorToolSchema(BaseModel):
class TavilyExtractorTool(BaseTool):
package_dependencies: List[str] = ["tavily-python"]
"""
Tool that uses the Tavily API to extract content from web pages.

View File

@@ -1,6 +1,6 @@
from crewai.tools import BaseTool
from pydantic import BaseModel, Field
from typing import Optional, Type, Any, Union, Literal, Sequence
from typing import Optional, Type, Any, Union, Literal, Sequence, List
from dotenv import load_dotenv
import os
import json
@@ -101,6 +101,7 @@ class TavilySearchTool(BaseTool):
default=1000,
description="Maximum length for the 'content' of each search result to avoid context window issues.",
)
package_dependencies: List[str] = ["tavily-python"]
def __init__(self, **kwargs: Any):
super().__init__(**kwargs)

View File

@@ -1,6 +1,6 @@
import json
import os
from typing import Any, Optional, Type
from typing import Any, Optional, Type, List
try:
import weaviate
@@ -31,6 +31,7 @@ class WeaviateToolSchema(BaseModel):
class WeaviateVectorSearchTool(BaseTool):
"""Tool to search the Weaviate database"""
package_dependencies: List[str] = ["weaviate-client"]
name: str = "WeaviateVectorSearchTool"
description: str = "A tool to search the Weaviate database for relevant information on internal documents."
args_schema: Type[BaseModel] = WeaviateToolSchema
@@ -48,6 +49,7 @@ class WeaviateVectorSearchTool(BaseTool):
...,
description="The API key for the Weaviate cluster",
)
package_dependencies: List[str] = ["weaviate-client"]
def __init__(self, **kwargs):
super().__init__(**kwargs)

View File

@@ -1,12 +1,12 @@
import json
from typing import List, Optional
from typing import List, Optional, Type
import pytest
from pydantic import BaseModel, Field
from unittest import mock
from generate_tool_specs import ToolSpecExtractor
from crewai.tools.base_tool import EnvVar
from crewai.tools.base_tool import BaseTool, EnvVar
class MockToolSchema(BaseModel):
query: str = Field(..., description="The query parameter")
@@ -14,54 +14,26 @@ class MockToolSchema(BaseModel):
filters: Optional[List[str]] = Field(None, description="Optional filters to apply")
class MockTool:
name = "Mock Search Tool"
description = "A tool that mocks search functionality"
args_schema = MockToolSchema
class MockTool(BaseTool):
name: str = "Mock Search Tool"
description: str = "A tool that mocks search functionality"
args_schema: Type[BaseModel] = MockToolSchema
another_parameter: str = Field("Another way to define a default value", description="")
my_parameter: str = Field("This is default value", description="What a description")
my_parameter_bool: bool = Field(False)
package_dependencies: List[str] = Field(["this-is-a-required-package", "another-required-package"], description="")
env_vars: List[EnvVar] = [
EnvVar(name="SERPER_API_KEY", description="API key for Serper", required=True, default=None),
EnvVar(name="API_RATE_LIMIT", description="API rate limit", required=False, default="100")
]
@pytest.fixture
def extractor():
ext = ToolSpecExtractor()
MockTool.__pydantic_core_schema__ = create_mock_schema(MockTool)
MockTool.args_schema.__pydantic_core_schema__ = create_mock_schema_args(MockTool.args_schema)
return ext
def create_mock_schema(cls):
return {
"type": "model",
"cls": cls,
"schema": {
"type": "model-fields",
"fields": {
"name": {"type": "model-field", "schema": {"type": "default", "schema": {"type": "str"}, "default": cls.name}, "metadata": {}},
"description": {"type": "model-field", "schema": {"type": "default", "schema": {"type": "str"}, "default": cls.description}, "metadata": {}},
"args_schema": {"type": "model-field", "schema": {"type": "default", "schema": {"type": "is-subclass", "cls": BaseModel}, "default": cls.args_schema}, "metadata": {}},
"env_vars": {
"type": "model-field", "schema": {"type": "default", "schema": {"type": "list", "items_schema": {"type": "model", "cls": "INSPECT CLASS", "schema": {"type": "model-fields", "fields": {"name": {"type": "model-field", "schema": {"type": "str"}, "metadata": {}}, "description": {"type": "model-field", "schema": {"type": "str"}, "metadata": {}}, "required": {"type": "model-field", "schema": {"type": "default", "schema": {"type": "bool"}, "default": True}, "metadata": {}}, "default": {"type": "model-field", "schema": {"type": "default", "schema": {"type": "nullable", "schema": {"type": "str"}}, "default": None}, "metadata": {}},}, "model_name": "EnvVar", "computed_fields": []}, "custom_init": False, "root_model": False, "config": {"title": "EnvVar"}, "ref": "crewai.tools.base_tool.EnvVar:4593650640", "metadata": {"pydantic_js_functions": ["INSPECT __get_pydantic_json_schema__"]}}}, "default": [EnvVar(name='SERPER_API_KEY', description='API key for Serper', required=True, default=None), EnvVar(name='API_RATE_LIMIT', description='API rate limit', required=False, default="100")]}, "metadata": {}
}
},
"model_name": cls.__name__
}
}
def create_mock_schema_args(cls):
return {
"type": "model",
"cls": cls,
"schema": {
"type": "model-fields",
"fields": {
"query": {"type": "model-field", "schema": {"type": "default", "schema": {"type": "str"}, "default": "The query parameter"}},
"count": {"type": "model-field", "schema": {"type": "default", "schema": {"type": "int"}, "default": 5}, "metadata": {"pydantic_js_updates": {"description": "Number of results to return"}}},
"filters": {"type": "model-field", "schema": {"type": "nullable", "schema": {"type": "list", "items_schema": {"type": "str"}}}}
},
"model_name": cls.__name__
}
}
def test_unwrap_schema(extractor):
nested_schema = {
"type": "function-after",
@@ -72,19 +44,6 @@ def test_unwrap_schema(extractor):
assert result["value"] == "test"
@pytest.mark.parametrize(
"field, fallback, expected",
[
({"schema": {"default": "test_value"}}, None, "test_value"),
({}, "fallback_value", "fallback_value"),
({"schema": {"default": 123}}, "fallback_value", "fallback_value")
]
)
def test_extract_field_default(extractor, field, fallback, expected):
result = extractor._extract_field_default(field, fallback=fallback)
assert result == expected
@pytest.mark.parametrize(
"schema, expected",
[
@@ -112,7 +71,7 @@ def test_extract_param_type(extractor, info, expected_type):
assert extractor._extract_param_type(info) == expected_type
def test_extract_tool_info(extractor):
def test_extract_all_tools(extractor):
with mock.patch("generate_tool_specs.dir", return_value=["MockTool"]), \
mock.patch("generate_tool_specs.getattr", return_value=MockTool):
extractor.extract_all_tools()
@@ -120,6 +79,16 @@ def test_extract_tool_info(extractor):
assert len(extractor.tools_spec) == 1
tool_info = extractor.tools_spec[0]
assert tool_info.keys() == {
"name",
"humanized_name",
"description",
"run_params",
"env_vars",
"init_params",
"package_dependencies",
}
assert tool_info["name"] == "MockTool"
assert tool_info["humanized_name"] == "Mock Search Tool"
assert tool_info["description"] == "A tool that mocks search functionality"
@@ -142,12 +111,16 @@ def test_extract_tool_info(extractor):
params = {p["name"]: p for p in tool_info["run_params"]}
assert params["query"]["description"] == "The query parameter"
assert params["query"]["type"] == "str"
assert params["query"]["default"] == ""
assert params["count"]["description"] == "Number of results to return"
assert params["count"]["type"] == "int"
assert params["count"]["default"] == 5
assert params["filters"]["description"] == ""
assert params["filters"]["description"] == "Optional filters to apply"
assert params["filters"]["type"] == "list[str]"
assert params["filters"]["default"] == ""
assert tool_info["package_dependencies"] == ["this-is-a-required-package", "another-required-package"]
def test_save_to_json(extractor, tmp_path):