Support collecting extra package dependencies of Tools (#330)

* feat: explicitly add package_dependencies to Tools

* feat: collect package_dependencies from each Tool and add them to tool.specs.json

* feat: add default values to run_params in the Tool specs

* fix: support extracting boolean values

This commit also refactors the tests to make it easier to define new attributes on a Tool.
Lucas Gomide
2025-06-16 11:09:19 -03:00
committed by GitHub
parent 5a99f07765
commit fac32d9503
29 changed files with 129 additions and 127 deletions
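
For reference, a minimal sketch of what a tool looks like after this change (the tool name, input schema, and package name below are hypothetical; only the BaseTool base class and the package_dependencies field come from the diffs that follow):

from typing import List, Type

from crewai.tools import BaseTool
from pydantic import BaseModel, Field


class ExampleToolInput(BaseModel):
    """Input schema for the hypothetical example tool."""
    query: str = Field(description="Query to run")


class ExampleTool(BaseTool):
    name: str = "Example Tool"
    description: str = "A hypothetical tool that needs an extra pip package."
    args_schema: Type[BaseModel] = ExampleToolInput
    # Extra packages this tool requires; the spec extractor collects this
    # list into tool.specs.json under "package_dependencies".
    package_dependencies: List[str] = ["example-extra-package"]

    def _run(self, query: str) -> str:
        return f"ran with {query}"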

View File

@@ -1,4 +1,4 @@
from typing import Type, Optional, Dict, Any
from typing import Type, Optional, Dict, Any, List
import os
import json
import uuid
@@ -29,6 +29,7 @@ class BedrockInvokeAgentTool(BaseTool):
session_id: str = None
enable_trace: bool = False
end_session: bool = False
package_dependencies: List[str] = ["boto3"]
def __init__(
self,

View File

@@ -26,6 +26,7 @@ class BedrockKBRetrieverTool(BaseTool):
retrieval_configuration: Optional[Dict[str, Any]] = None
guardrail_configuration: Optional[Dict[str, Any]] = None
next_token: Optional[str] = None
package_dependencies: List[str] = ["boto3"]
def __init__(
self,

View File

@@ -1,4 +1,4 @@
from typing import Any, Type
from typing import Any, Type, List
import os
from crewai.tools import BaseTool
@@ -15,6 +15,7 @@ class S3ReaderTool(BaseTool):
name: str = "S3 Reader Tool"
description: str = "Reads a file from Amazon S3 given an S3 file path"
args_schema: Type[BaseModel] = S3ReaderToolInput
package_dependencies: List[str] = ["boto3"]
def _run(self, file_path: str) -> str:
try:

View File

@@ -1,4 +1,4 @@
from typing import Any, Type
from typing import Type, List
import os
from crewai.tools import BaseTool
@@ -14,6 +14,7 @@ class S3WriterTool(BaseTool):
name: str = "S3 Writer Tool"
description: str = "Writes content to a file in Amazon S3 given an S3 file path"
args_schema: Type[BaseModel] = S3WriterToolInput
package_dependencies: List[str] = ["boto3"]
def _run(self, file_path: str, content: str) -> str:
try:

View File

@@ -1,10 +1,10 @@
import os
import secrets
from typing import Any, Dict, List, Optional, Text, Type
from typing import Any, Dict, List, Optional, Type
from crewai.tools import BaseTool
from openai import OpenAI
from pydantic import BaseModel
from pydantic import BaseModel, Field
class AIMindToolConstants:
@@ -16,7 +16,7 @@ class AIMindToolConstants:
class AIMindToolInputSchema(BaseModel):
"""Input for AIMind Tool."""
query: str = "Question in natural language to ask the AI-Mind"
query: str = Field(description="Question in natural language to ask the AI-Mind")
class AIMindTool(BaseTool):
@@ -31,9 +31,10 @@ class AIMindTool(BaseTool):
args_schema: Type[BaseModel] = AIMindToolInputSchema
api_key: Optional[str] = None
datasources: Optional[List[Dict[str, Any]]] = None
mind_name: Optional[Text] = None
mind_name: Optional[str] = None
package_dependencies: List[str] = ["minds-sdk"]
def __init__(self, api_key: Optional[Text] = None, **kwargs):
def __init__(self, api_key: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
self.api_key = api_key or os.getenv("MINDS_API_KEY")
if not self.api_key:
@@ -72,7 +73,7 @@ class AIMindTool(BaseTool):
def _run(
self,
query: Text
query: str
):
# Run the query on the AI-Mind.
# The Minds API is OpenAI compatible and therefore, the OpenAI client can be used.

View File

@@ -38,6 +38,7 @@ class ApifyActorsTool(BaseTool):
print(f"Content: {result.get('markdown', 'N/A')[:100]}...")
"""
actor_tool: '_ApifyActorsTool' = Field(description="Apify Actor Tool")
package_dependencies: List[str] = ["langchain-apify"]
def __init__(
self,

View File

@@ -1,5 +1,5 @@
import os
from typing import Any, Optional, Type
from typing import Any, Optional, Type, List
from crewai.tools import BaseTool
from pydantic import BaseModel, Field
@@ -19,6 +19,7 @@ class BrowserbaseLoadTool(BaseTool):
session_id: Optional[str] = None
proxy: Optional[bool] = None
browserbase: Optional[Any] = None
package_dependencies: List[str] = ["browserbase"]
def __init__(
self,

View File

@@ -3,13 +3,13 @@ from typing import Type
from crewai.tools import BaseTool
from openai import OpenAI
from pydantic import BaseModel
from pydantic import BaseModel, Field
class ImagePromptSchema(BaseModel):
"""Input for Dall-E Tool."""
image_description: str = "Description of the image to be generated by Dall-E."
image_description: str = Field(description="Description of the image to be generated by Dall-E.")
class DallETool(BaseTool):

View File

@@ -69,6 +69,7 @@ class DatabricksQueryTool(BaseTool):
default_warehouse_id: Optional[str] = None
_workspace_client: Optional["WorkspaceClient"] = None
package_dependencies: List[str] = ["databricks-sdk"]
def __init__(
self,

View File

@@ -1,4 +1,4 @@
from typing import Any, Optional, Type
from typing import Any, Optional, Type, List
from pydantic import BaseModel, Field
from crewai.tools import BaseTool
@@ -35,6 +35,7 @@ class EXASearchTool(BaseTool):
content: Optional[bool] = False
summary: Optional[bool] = False
type: Optional[str] = "auto"
package_dependencies: List[str] = ["exa_py"]
def __init__(
self,

View File

@@ -1,4 +1,4 @@
from typing import Any, Optional, Type
from typing import Any, Optional, Type, List
from crewai.tools import BaseTool
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
@@ -55,6 +55,7 @@ class FirecrawlCrawlWebsiteTool(BaseTool):
}
)
_firecrawl: Optional["FirecrawlApp"] = PrivateAttr(None)
package_dependencies: List[str] = ["firecrawl-py"]
def __init__(self, api_key: Optional[str] = None, **kwargs):
super().__init__(**kwargs)

View File

@@ -1,4 +1,4 @@
from typing import Any, Optional, Type, Dict
from typing import Any, Optional, Type, Dict, List
from crewai.tools import BaseTool
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
@@ -48,6 +48,7 @@ class FirecrawlScrapeWebsiteTool(BaseTool):
)
_firecrawl: Optional["FirecrawlApp"] = PrivateAttr(None)
package_dependencies: List[str] = ["firecrawl-py"]
def __init__(self, api_key: Optional[str] = None, **kwargs):
super().__init__(**kwargs)

View File

@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Any, Dict, Optional, Type
from typing import TYPE_CHECKING, Any, Dict, Optional, Type, List
from crewai.tools import BaseTool
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
@@ -57,6 +57,7 @@ class FirecrawlSearchTool(BaseTool):
}
)
_firecrawl: Optional["FirecrawlApp"] = PrivateAttr(None)
package_dependencies: List[str] = ["firecrawl-py"]
def __init__(self, api_key: Optional[str] = None, **kwargs):
super().__init__(**kwargs)

View File

@@ -1,5 +1,5 @@
import os
from typing import Any, Optional, Type, Dict, Literal, Union
from typing import Any, Optional, Type, Dict, Literal, Union, List
from crewai.tools import BaseTool
from pydantic import BaseModel, Field
@@ -25,6 +25,7 @@ class HyperbrowserLoadTool(BaseTool):
args_schema: Type[BaseModel] = HyperbrowserLoadToolSchema
api_key: Optional[str] = None
hyperbrowser: Optional[Any] = None
package_dependencies: List[str] = ["hyperbrowser"]
def __init__(self, api_key: Optional[str] = None, **kwargs):
super().__init__(**kwargs)

View File

@@ -1,4 +1,4 @@
from typing import Any
from typing import Any, List
from crewai.tools import BaseTool
@@ -23,6 +23,7 @@ class LinkupSearchTool(BaseTool):
"Performs an API call to Linkup to retrieve contextual information."
)
_client: LinkupClient = PrivateAttr() # type: ignore
package_dependencies: List[str] = ["linkup-sdk"]
def __init__(self, api_key: str):
"""

View File

@@ -1,6 +1,6 @@
"""Multion tool spec."""
from typing import Any, Optional
from typing import Any, Optional, List
from crewai.tools import BaseTool
@@ -16,6 +16,7 @@ class MultiOnTool(BaseTool):
session_id: Optional[str] = None
local: bool = False
max_steps: int = 3
package_dependencies: List[str] = ["multion"]
def __init__(
self,

View File

@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Any, Type
from typing import TYPE_CHECKING, Any, Type, List
from crewai.tools import BaseTool
from pydantic import BaseModel, ConfigDict, Field
@@ -41,6 +41,7 @@ class PatronusLocalEvaluatorTool(BaseTool):
evaluated_model_gold_answer: str
model_config = ConfigDict(arbitrary_types_allowed=True)
package_dependencies: List[str] = ["patronus"]
def __init__(
self,

View File

@@ -1,6 +1,6 @@
import json
import os
from typing import Any, Callable, Optional, Type
from typing import Any, Callable, Optional, Type, List
try:
@@ -74,6 +74,7 @@ class QdrantVectorSearchTool(BaseTool):
default=None,
description="A custom embedding function to use for vectorization. If not provided, the default model will be used.",
)
package_dependencies: List[str] = ["qdrant-client"]
def __init__(self, **kwargs):
super().__init__(**kwargs)

View File

@@ -1,5 +1,5 @@
import os
from typing import TYPE_CHECKING, Any, Optional, Type
from typing import TYPE_CHECKING, Any, Optional, Type, List
from urllib.parse import urlparse
from crewai.tools import BaseTool
@@ -67,6 +67,7 @@ class ScrapegraphScrapeTool(BaseTool):
api_key: Optional[str] = None
enable_logging: bool = False
_client: Optional["Client"] = None
package_dependencies: List[str] = ["scrapegraph-py"]
def __init__(
self,

View File

@@ -1,5 +1,5 @@
import logging
from typing import Any, Dict, Literal, Optional, Type
from typing import Any, Dict, Literal, Optional, Type, List
from crewai.tools import BaseTool
from pydantic import BaseModel, Field
@@ -28,6 +28,7 @@ class ScrapflyScrapeWebsiteTool(BaseTool):
args_schema: Type[BaseModel] = ScrapflyScrapeWebsiteToolSchema
api_key: str = None
scrapfly: Optional[Any] = None
package_dependencies: List[str] = ["scrapfly-sdk"]
def __init__(self, api_key: str):
super().__init__()

View File

@@ -1,6 +1,6 @@
import re
import time
from typing import Any, Optional, Type
from typing import Any, Optional, Type, List
from urllib.parse import urlparse
from crewai.tools import BaseTool
@@ -58,6 +58,7 @@ class SeleniumScrapingTool(BaseTool):
css_element: Optional[str] = None
return_html: Optional[bool] = False
_by: Optional[Any] = None
package_dependencies: List[str] = ["selenium", "webdriver-manager"]
def __init__(
self,

View File

@@ -1,6 +1,6 @@
import os
import re
from typing import Any, Optional, Union
from typing import Any, Optional, Union, List
from crewai.tools import BaseTool
@@ -8,6 +8,8 @@ from crewai.tools import BaseTool
class SerpApiBaseTool(BaseTool):
"""Base class for SerpApi functionality with shared capabilities."""
package_dependencies: List[str] = ["serpapi"]
client: Optional[Any] = None
def __init__(self, **kwargs):

View File

@@ -99,6 +99,7 @@ class SnowflakeSearchTool(BaseTool):
_pool_lock: Optional[asyncio.Lock] = None
_thread_pool: Optional[ThreadPoolExecutor] = None
_model_rebuilt: bool = False
package_dependencies: List[str] = ["snowflake-connector-python", "snowflake-sqlalchemy", "cryptography"]
def __init__(self, **data):
"""Initialize SnowflakeSearchTool."""

View File

@@ -1,5 +1,5 @@
import logging
from typing import Any, Dict, Literal, Optional, Type
from typing import Any, Dict, Literal, Optional, Type, List
from urllib.parse import unquote, urlparse
from crewai.tools import BaseTool
@@ -53,6 +53,7 @@ class SpiderTool(BaseTool):
spider: Any = None
log_failures: bool = True
config: SpiderToolConfig = SpiderToolConfig()
package_dependencies: List[str] = ["spider-client"]
def __init__(
self,

View File

@@ -88,6 +88,7 @@ class StagehandToolSchema(BaseModel):
class StagehandTool(BaseTool):
package_dependencies: List[str] = ["stagehand"]
"""
A tool that uses Stagehand to automate web browser interactions using natural language.

View File

@@ -26,6 +26,7 @@ class TavilyExtractorToolSchema(BaseModel):
class TavilyExtractorTool(BaseTool):
package_dependencies: List[str] = ["tavily-python"]
"""
Tool that uses the Tavily API to extract content from web pages.

View File

@@ -1,6 +1,6 @@
from crewai.tools import BaseTool
from pydantic import BaseModel, Field
from typing import Optional, Type, Any, Union, Literal, Sequence
from typing import Optional, Type, Any, Union, Literal, Sequence, List
from dotenv import load_dotenv
import os
import json
@@ -101,6 +101,7 @@ class TavilySearchTool(BaseTool):
default=1000,
description="Maximum length for the 'content' of each search result to avoid context window issues.",
)
package_dependencies: List[str] = ["tavily-python"]
def __init__(self, **kwargs: Any):
super().__init__(**kwargs)

View File

@@ -1,6 +1,6 @@
import json
import os
from typing import Any, Optional, Type
from typing import Any, Optional, Type, List
try:
import weaviate
@@ -31,6 +31,7 @@ class WeaviateToolSchema(BaseModel):
class WeaviateVectorSearchTool(BaseTool):
"""Tool to search the Weaviate database"""
package_dependencies: List[str] = ["weaviate-client"]
name: str = "WeaviateVectorSearchTool"
description: str = "A tool to search the Weaviate database for relevant information on internal documents."
args_schema: Type[BaseModel] = WeaviateToolSchema
@@ -48,6 +49,7 @@ class WeaviateVectorSearchTool(BaseTool):
...,
description="The API key for the Weaviate cluster",
)
package_dependencies: List[str] = ["weaviate-client"]
def __init__(self, **kwargs):
super().__init__(**kwargs)

View File

@@ -1,12 +1,12 @@
import json
from typing import List, Optional
from typing import List, Optional, Type
import pytest
from pydantic import BaseModel, Field
from unittest import mock
from generate_tool_specs import ToolSpecExtractor
from crewai.tools.base_tool import EnvVar
from crewai.tools.base_tool import BaseTool, EnvVar
class MockToolSchema(BaseModel):
query: str = Field(..., description="The query parameter")
@@ -14,54 +14,26 @@ class MockToolSchema(BaseModel):
filters: Optional[List[str]] = Field(None, description="Optional filters to apply")
class MockTool:
name = "Mock Search Tool"
description = "A tool that mocks search functionality"
args_schema = MockToolSchema
class MockTool(BaseTool):
name: str = "Mock Search Tool"
description: str = "A tool that mocks search functionality"
args_schema: Type[BaseModel] = MockToolSchema
another_parameter: str = Field("Another way to define a default value", description="")
my_parameter: str = Field("This is default value", description="What a description")
my_parameter_bool: bool = Field(False)
package_dependencies: List[str] = Field(["this-is-a-required-package", "another-required-package"], description="")
env_vars: List[EnvVar] = [
EnvVar(name="SERPER_API_KEY", description="API key for Serper", required=True, default=None),
EnvVar(name="API_RATE_LIMIT", description="API rate limit", required=False, default="100")
]
@pytest.fixture
def extractor():
ext = ToolSpecExtractor()
MockTool.__pydantic_core_schema__ = create_mock_schema(MockTool)
MockTool.args_schema.__pydantic_core_schema__ = create_mock_schema_args(MockTool.args_schema)
return ext
def create_mock_schema(cls):
return {
"type": "model",
"cls": cls,
"schema": {
"type": "model-fields",
"fields": {
"name": {"type": "model-field", "schema": {"type": "default", "schema": {"type": "str"}, "default": cls.name}, "metadata": {}},
"description": {"type": "model-field", "schema": {"type": "default", "schema": {"type": "str"}, "default": cls.description}, "metadata": {}},
"args_schema": {"type": "model-field", "schema": {"type": "default", "schema": {"type": "is-subclass", "cls": BaseModel}, "default": cls.args_schema}, "metadata": {}},
"env_vars": {
"type": "model-field", "schema": {"type": "default", "schema": {"type": "list", "items_schema": {"type": "model", "cls": "INSPECT CLASS", "schema": {"type": "model-fields", "fields": {"name": {"type": "model-field", "schema": {"type": "str"}, "metadata": {}}, "description": {"type": "model-field", "schema": {"type": "str"}, "metadata": {}}, "required": {"type": "model-field", "schema": {"type": "default", "schema": {"type": "bool"}, "default": True}, "metadata": {}}, "default": {"type": "model-field", "schema": {"type": "default", "schema": {"type": "nullable", "schema": {"type": "str"}}, "default": None}, "metadata": {}},}, "model_name": "EnvVar", "computed_fields": []}, "custom_init": False, "root_model": False, "config": {"title": "EnvVar"}, "ref": "crewai.tools.base_tool.EnvVar:4593650640", "metadata": {"pydantic_js_functions": ["INSPECT __get_pydantic_json_schema__"]}}}, "default": [EnvVar(name='SERPER_API_KEY', description='API key for Serper', required=True, default=None), EnvVar(name='API_RATE_LIMIT', description='API rate limit', required=False, default="100")]}, "metadata": {}
}
},
"model_name": cls.__name__
}
}
def create_mock_schema_args(cls):
return {
"type": "model",
"cls": cls,
"schema": {
"type": "model-fields",
"fields": {
"query": {"type": "model-field", "schema": {"type": "default", "schema": {"type": "str"}, "default": "The query parameter"}},
"count": {"type": "model-field", "schema": {"type": "default", "schema": {"type": "int"}, "default": 5}, "metadata": {"pydantic_js_updates": {"description": "Number of results to return"}}},
"filters": {"type": "model-field", "schema": {"type": "nullable", "schema": {"type": "list", "items_schema": {"type": "str"}}}}
},
"model_name": cls.__name__
}
}
def test_unwrap_schema(extractor):
nested_schema = {
"type": "function-after",
@@ -72,19 +44,6 @@ def test_unwrap_schema(extractor):
assert result["value"] == "test"
@pytest.mark.parametrize(
"field, fallback, expected",
[
({"schema": {"default": "test_value"}}, None, "test_value"),
({}, "fallback_value", "fallback_value"),
({"schema": {"default": 123}}, "fallback_value", "fallback_value")
]
)
def test_extract_field_default(extractor, field, fallback, expected):
result = extractor._extract_field_default(field, fallback=fallback)
assert result == expected
@pytest.mark.parametrize(
"schema, expected",
[
@@ -112,7 +71,7 @@ def test_extract_param_type(extractor, info, expected_type):
assert extractor._extract_param_type(info) == expected_type
def test_extract_tool_info(extractor):
def test_extract_all_tools(extractor):
with mock.patch("generate_tool_specs.dir", return_value=["MockTool"]), \
mock.patch("generate_tool_specs.getattr", return_value=MockTool):
extractor.extract_all_tools()
@@ -120,6 +79,16 @@ def test_extract_tool_info(extractor):
assert len(extractor.tools_spec) == 1
tool_info = extractor.tools_spec[0]
assert tool_info.keys() == {
"name",
"humanized_name",
"description",
"run_params",
"env_vars",
"init_params",
"package_dependencies",
}
assert tool_info["name"] == "MockTool"
assert tool_info["humanized_name"] == "Mock Search Tool"
assert tool_info["description"] == "A tool that mocks search functionality"
@@ -142,12 +111,16 @@ def test_extract_tool_info(extractor):
params = {p["name"]: p for p in tool_info["run_params"]}
assert params["query"]["description"] == "The query parameter"
assert params["query"]["type"] == "str"
assert params["query"]["default"] == ""
assert params["count"]["description"] == "Number of results to return"
assert params["count"]["type"] == "int"
assert params["count"]["default"] == 5
assert params["filters"]["description"] == ""
assert params["filters"]["description"] == "Optional filters to apply"
assert params["filters"]["type"] == "list[str]"
assert params["filters"]["default"] == ""
assert tool_info["package_dependencies"] == ["this-is-a-required-package", "another-required-package"]
def test_save_to_json(extractor, tmp_path):