mirror of
https://github.com/crewAIInc/crewAI.git
synced 2025-12-16 04:18:35 +00:00
fix: ensure full type signature for tools
This commit is contained in:
@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
from inspect import signature
|
||||
from typing import Any, cast, get_args, get_origin
|
||||
from typing import Any, get_args, get_origin
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
@@ -55,7 +55,7 @@ class BaseTool(BaseModel, ABC):
|
||||
default=False, description="Flag to check if the description has been updated."
|
||||
)
|
||||
|
||||
cache_function: Callable = Field(
|
||||
cache_function: Callable[..., bool] = Field(
|
||||
default=lambda _args=None, _result=None: True,
|
||||
description="Function that will be used to determine if the tool should be cached, should return a boolean. If None, the tool will be cached.",
|
||||
)
|
||||
@@ -80,20 +80,21 @@ class BaseTool(BaseModel, ABC):
|
||||
if v != cls._ArgsSchemaPlaceholder:
|
||||
return v
|
||||
|
||||
return cast(
|
||||
type[PydanticBaseModel],
|
||||
type(
|
||||
f"{cls.__name__}Schema",
|
||||
(PydanticBaseModel,),
|
||||
{
|
||||
"__annotations__": {
|
||||
k: v
|
||||
for k, v in cls._run.__annotations__.items()
|
||||
if k != "return"
|
||||
},
|
||||
},
|
||||
),
|
||||
)
|
||||
run_sig = signature(cls._run)
|
||||
fields: dict[str, Any] = {}
|
||||
|
||||
for param_name, param in run_sig.parameters.items():
|
||||
if param_name in ("self", "return"):
|
||||
continue
|
||||
|
||||
annotation = param.annotation if param.annotation != param.empty else Any
|
||||
|
||||
if param.default is param.empty:
|
||||
fields[param_name] = (annotation, ...)
|
||||
else:
|
||||
fields[param_name] = (annotation, param.default)
|
||||
|
||||
return create_model(f"{cls.__name__}Schema", **fields)
|
||||
|
||||
@field_validator("max_usage_count", mode="before")
|
||||
@classmethod
|
||||
@@ -164,24 +165,21 @@ class BaseTool(BaseModel, ABC):
|
||||
args_schema = getattr(tool, "args_schema", None)
|
||||
|
||||
if args_schema is None:
|
||||
# Infer args_schema from the function signature if not provided
|
||||
func_signature = signature(tool.func)
|
||||
annotations = func_signature.parameters
|
||||
args_fields: dict[str, Any] = {}
|
||||
for name, param in annotations.items():
|
||||
if name != "self":
|
||||
param_annotation = (
|
||||
param.annotation if param.annotation != param.empty else Any
|
||||
)
|
||||
field_info = Field(
|
||||
default=...,
|
||||
description="",
|
||||
)
|
||||
args_fields[name] = (param_annotation, field_info)
|
||||
if args_fields:
|
||||
args_schema = create_model(f"{tool.name}Input", **args_fields)
|
||||
fields: dict[str, Any] = {}
|
||||
for name, param in func_signature.parameters.items():
|
||||
if name == "self":
|
||||
continue
|
||||
param_annotation = (
|
||||
param.annotation if param.annotation != param.empty else Any
|
||||
)
|
||||
if param.default is param.empty:
|
||||
fields[name] = (param_annotation, ...)
|
||||
else:
|
||||
fields[name] = (param_annotation, param.default)
|
||||
if fields:
|
||||
args_schema = create_model(f"{tool.name}Input", **fields)
|
||||
else:
|
||||
# Create a default schema with no fields if no parameters are found
|
||||
args_schema = create_model(
|
||||
f"{tool.name}Input", __base__=PydanticBaseModel
|
||||
)
|
||||
@@ -195,20 +193,24 @@ class BaseTool(BaseModel, ABC):
|
||||
|
||||
def _set_args_schema(self) -> None:
|
||||
if self.args_schema is None:
|
||||
class_name = f"{self.__class__.__name__}Schema"
|
||||
self.args_schema = cast(
|
||||
type[PydanticBaseModel],
|
||||
type(
|
||||
class_name,
|
||||
(PydanticBaseModel,),
|
||||
{
|
||||
"__annotations__": {
|
||||
k: v
|
||||
for k, v in self._run.__annotations__.items()
|
||||
if k != "return"
|
||||
},
|
||||
},
|
||||
),
|
||||
run_sig = signature(self._run)
|
||||
fields: dict[str, Any] = {}
|
||||
|
||||
for param_name, param in run_sig.parameters.items():
|
||||
if param_name in ("self", "return"):
|
||||
continue
|
||||
|
||||
annotation = (
|
||||
param.annotation if param.annotation != param.empty else Any
|
||||
)
|
||||
|
||||
if param.default is param.empty:
|
||||
fields[param_name] = (annotation, ...)
|
||||
else:
|
||||
fields[param_name] = (annotation, param.default)
|
||||
|
||||
self.args_schema = create_model(
|
||||
f"{self.__class__.__name__}Schema", **fields
|
||||
)
|
||||
|
||||
def _generate_description(self) -> None:
|
||||
@@ -241,13 +243,13 @@ class BaseTool(BaseModel, ABC):
|
||||
args_str = ", ".join(BaseTool._get_arg_annotations(arg) for arg in args)
|
||||
return f"{origin.__name__}[{args_str}]"
|
||||
|
||||
return origin.__name__
|
||||
return str(origin.__name__)
|
||||
|
||||
|
||||
class Tool(BaseTool):
|
||||
"""The function that will be executed when the tool is called."""
|
||||
|
||||
func: Callable
|
||||
func: Callable[..., Any]
|
||||
|
||||
def _run(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return self.func(*args, **kwargs)
|
||||
@@ -275,24 +277,21 @@ class Tool(BaseTool):
|
||||
args_schema = getattr(tool, "args_schema", None)
|
||||
|
||||
if args_schema is None:
|
||||
# Infer args_schema from the function signature if not provided
|
||||
func_signature = signature(tool.func)
|
||||
annotations = func_signature.parameters
|
||||
args_fields: dict[str, Any] = {}
|
||||
for name, param in annotations.items():
|
||||
if name != "self":
|
||||
param_annotation = (
|
||||
param.annotation if param.annotation != param.empty else Any
|
||||
)
|
||||
field_info = Field(
|
||||
default=...,
|
||||
description="",
|
||||
)
|
||||
args_fields[name] = (param_annotation, field_info)
|
||||
if args_fields:
|
||||
args_schema = create_model(f"{tool.name}Input", **args_fields)
|
||||
fields: dict[str, Any] = {}
|
||||
for name, param in func_signature.parameters.items():
|
||||
if name == "self":
|
||||
continue
|
||||
param_annotation = (
|
||||
param.annotation if param.annotation != param.empty else Any
|
||||
)
|
||||
if param.default is param.empty:
|
||||
fields[name] = (param_annotation, ...)
|
||||
else:
|
||||
fields[name] = (param_annotation, param.default)
|
||||
if fields:
|
||||
args_schema = create_model(f"{tool.name}Input", **fields)
|
||||
else:
|
||||
# Create a default schema with no fields if no parameters are found
|
||||
args_schema = create_model(
|
||||
f"{tool.name}Input", __base__=PydanticBaseModel
|
||||
)
|
||||
@@ -312,10 +311,11 @@ def to_langchain(
|
||||
|
||||
|
||||
def tool(
|
||||
*args, result_as_answer: bool = False, max_usage_count: int | None = None
|
||||
) -> Callable:
|
||||
"""
|
||||
Decorator to create a tool from a function.
|
||||
*args: Callable[..., Any] | str,
|
||||
result_as_answer: bool = False,
|
||||
max_usage_count: int | None = None,
|
||||
) -> Callable[..., Any] | BaseTool:
|
||||
"""Decorator to create a tool from a function.
|
||||
|
||||
Args:
|
||||
*args: Positional arguments, either the function to decorate or the tool name.
|
||||
@@ -323,26 +323,31 @@ def tool(
|
||||
max_usage_count: Maximum number of times this tool can be used. None means unlimited usage.
|
||||
"""
|
||||
|
||||
def _make_with_name(tool_name: str) -> Callable:
|
||||
def _make_tool(f: Callable) -> BaseTool:
|
||||
def _make_with_name(tool_name: str) -> Callable[[Callable[..., Any]], BaseTool]:
|
||||
def _make_tool(f: Callable[..., Any]) -> BaseTool:
|
||||
if f.__doc__ is None:
|
||||
raise ValueError("Function must have a docstring")
|
||||
if f.__annotations__ is None:
|
||||
raise ValueError("Function must have type annotations")
|
||||
|
||||
func_sig = signature(f)
|
||||
fields: dict[str, Any] = {}
|
||||
|
||||
for param_name, param in func_sig.parameters.items():
|
||||
if param_name == "return":
|
||||
continue
|
||||
|
||||
annotation = (
|
||||
param.annotation if param.annotation != param.empty else Any
|
||||
)
|
||||
|
||||
if param.default is param.empty:
|
||||
fields[param_name] = (annotation, ...)
|
||||
else:
|
||||
fields[param_name] = (annotation, param.default)
|
||||
|
||||
class_name = "".join(tool_name.split()).title()
|
||||
args_schema = cast(
|
||||
type[PydanticBaseModel],
|
||||
type(
|
||||
class_name,
|
||||
(PydanticBaseModel,),
|
||||
{
|
||||
"__annotations__": {
|
||||
k: v for k, v in f.__annotations__.items() if k != "return"
|
||||
},
|
||||
},
|
||||
),
|
||||
)
|
||||
args_schema = create_model(class_name, **fields)
|
||||
|
||||
return Tool(
|
||||
name=tool_name,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -230,3 +230,67 @@ def test_max_usage_count_is_respected():
|
||||
crew.kickoff()
|
||||
assert tool.max_usage_count == 5
|
||||
assert tool.current_usage_count == 5
|
||||
|
||||
|
||||
def test_tool_schema_respects_default_values():
|
||||
"""Test that tool schema correctly marks optional parameters with defaults."""
|
||||
|
||||
class ToolWithDefaults(BaseTool):
|
||||
name: str = "tool_with_defaults"
|
||||
description: str = "A tool with optional parameters"
|
||||
|
||||
def _run(
|
||||
self,
|
||||
query: str,
|
||||
similarity_threshold: float | None = None,
|
||||
limit: int = 5,
|
||||
) -> str:
|
||||
return f"{query} - {similarity_threshold} - {limit}"
|
||||
|
||||
tool = ToolWithDefaults()
|
||||
schema = tool.args_schema.model_json_schema()
|
||||
|
||||
assert schema["required"] == ["query"]
|
||||
|
||||
props = schema["properties"]
|
||||
assert "default" in props["similarity_threshold"]
|
||||
assert props["similarity_threshold"]["default"] is None
|
||||
assert "default" in props["limit"]
|
||||
assert props["limit"]["default"] == 5
|
||||
|
||||
|
||||
def test_tool_decorator_respects_default_values():
|
||||
"""Test that @tool decorator correctly handles optional parameters with defaults."""
|
||||
|
||||
@tool("search_tool")
|
||||
def search_with_defaults(
|
||||
query: str, max_results: int = 10, sort_by: str | None = None
|
||||
) -> str:
|
||||
"""Search for information with optional parameters."""
|
||||
return f"{query} - {max_results} - {sort_by}"
|
||||
|
||||
schema = search_with_defaults.args_schema.model_json_schema()
|
||||
|
||||
assert schema["required"] == ["query"]
|
||||
|
||||
props = schema["properties"]
|
||||
assert "default" in props["max_results"]
|
||||
assert props["max_results"]["default"] == 10
|
||||
assert "default" in props["sort_by"]
|
||||
assert props["sort_by"]["default"] is None
|
||||
|
||||
|
||||
def test_tool_schema_all_required_when_no_defaults():
|
||||
"""Test that all parameters are required when no defaults are provided."""
|
||||
|
||||
class AllRequiredTool(BaseTool):
|
||||
name: str = "all_required"
|
||||
description: str = "All params required"
|
||||
|
||||
def _run(self, param1: str, param2: int, param3: bool) -> str:
|
||||
return f"{param1} - {param2} - {param3}"
|
||||
|
||||
tool = AllRequiredTool()
|
||||
schema = tool.args_schema.model_json_schema()
|
||||
|
||||
assert set(schema["required"]) == {"param1", "param2", "param3"}
|
||||
|
||||
Reference in New Issue
Block a user