mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 08:38:30 +00:00
fix: ensure full type signature for tools
This commit is contained in:
@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
|
|||||||
import asyncio
|
import asyncio
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from inspect import signature
|
from inspect import signature
|
||||||
from typing import Any, cast, get_args, get_origin
|
from typing import Any, get_args, get_origin
|
||||||
|
|
||||||
from pydantic import (
|
from pydantic import (
|
||||||
BaseModel,
|
BaseModel,
|
||||||
@@ -55,7 +55,7 @@ class BaseTool(BaseModel, ABC):
|
|||||||
default=False, description="Flag to check if the description has been updated."
|
default=False, description="Flag to check if the description has been updated."
|
||||||
)
|
)
|
||||||
|
|
||||||
cache_function: Callable = Field(
|
cache_function: Callable[..., bool] = Field(
|
||||||
default=lambda _args=None, _result=None: True,
|
default=lambda _args=None, _result=None: True,
|
||||||
description="Function that will be used to determine if the tool should be cached, should return a boolean. If None, the tool will be cached.",
|
description="Function that will be used to determine if the tool should be cached, should return a boolean. If None, the tool will be cached.",
|
||||||
)
|
)
|
||||||
@@ -80,20 +80,21 @@ class BaseTool(BaseModel, ABC):
|
|||||||
if v != cls._ArgsSchemaPlaceholder:
|
if v != cls._ArgsSchemaPlaceholder:
|
||||||
return v
|
return v
|
||||||
|
|
||||||
return cast(
|
run_sig = signature(cls._run)
|
||||||
type[PydanticBaseModel],
|
fields: dict[str, Any] = {}
|
||||||
type(
|
|
||||||
f"{cls.__name__}Schema",
|
for param_name, param in run_sig.parameters.items():
|
||||||
(PydanticBaseModel,),
|
if param_name in ("self", "return"):
|
||||||
{
|
continue
|
||||||
"__annotations__": {
|
|
||||||
k: v
|
annotation = param.annotation if param.annotation != param.empty else Any
|
||||||
for k, v in cls._run.__annotations__.items()
|
|
||||||
if k != "return"
|
if param.default is param.empty:
|
||||||
},
|
fields[param_name] = (annotation, ...)
|
||||||
},
|
else:
|
||||||
),
|
fields[param_name] = (annotation, param.default)
|
||||||
)
|
|
||||||
|
return create_model(f"{cls.__name__}Schema", **fields)
|
||||||
|
|
||||||
@field_validator("max_usage_count", mode="before")
|
@field_validator("max_usage_count", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -164,24 +165,21 @@ class BaseTool(BaseModel, ABC):
|
|||||||
args_schema = getattr(tool, "args_schema", None)
|
args_schema = getattr(tool, "args_schema", None)
|
||||||
|
|
||||||
if args_schema is None:
|
if args_schema is None:
|
||||||
# Infer args_schema from the function signature if not provided
|
|
||||||
func_signature = signature(tool.func)
|
func_signature = signature(tool.func)
|
||||||
annotations = func_signature.parameters
|
fields: dict[str, Any] = {}
|
||||||
args_fields: dict[str, Any] = {}
|
for name, param in func_signature.parameters.items():
|
||||||
for name, param in annotations.items():
|
if name == "self":
|
||||||
if name != "self":
|
continue
|
||||||
param_annotation = (
|
param_annotation = (
|
||||||
param.annotation if param.annotation != param.empty else Any
|
param.annotation if param.annotation != param.empty else Any
|
||||||
)
|
)
|
||||||
field_info = Field(
|
if param.default is param.empty:
|
||||||
default=...,
|
fields[name] = (param_annotation, ...)
|
||||||
description="",
|
else:
|
||||||
)
|
fields[name] = (param_annotation, param.default)
|
||||||
args_fields[name] = (param_annotation, field_info)
|
if fields:
|
||||||
if args_fields:
|
args_schema = create_model(f"{tool.name}Input", **fields)
|
||||||
args_schema = create_model(f"{tool.name}Input", **args_fields)
|
|
||||||
else:
|
else:
|
||||||
# Create a default schema with no fields if no parameters are found
|
|
||||||
args_schema = create_model(
|
args_schema = create_model(
|
||||||
f"{tool.name}Input", __base__=PydanticBaseModel
|
f"{tool.name}Input", __base__=PydanticBaseModel
|
||||||
)
|
)
|
||||||
@@ -195,20 +193,24 @@ class BaseTool(BaseModel, ABC):
|
|||||||
|
|
||||||
def _set_args_schema(self) -> None:
|
def _set_args_schema(self) -> None:
|
||||||
if self.args_schema is None:
|
if self.args_schema is None:
|
||||||
class_name = f"{self.__class__.__name__}Schema"
|
run_sig = signature(self._run)
|
||||||
self.args_schema = cast(
|
fields: dict[str, Any] = {}
|
||||||
type[PydanticBaseModel],
|
|
||||||
type(
|
for param_name, param in run_sig.parameters.items():
|
||||||
class_name,
|
if param_name in ("self", "return"):
|
||||||
(PydanticBaseModel,),
|
continue
|
||||||
{
|
|
||||||
"__annotations__": {
|
annotation = (
|
||||||
k: v
|
param.annotation if param.annotation != param.empty else Any
|
||||||
for k, v in self._run.__annotations__.items()
|
)
|
||||||
if k != "return"
|
|
||||||
},
|
if param.default is param.empty:
|
||||||
},
|
fields[param_name] = (annotation, ...)
|
||||||
),
|
else:
|
||||||
|
fields[param_name] = (annotation, param.default)
|
||||||
|
|
||||||
|
self.args_schema = create_model(
|
||||||
|
f"{self.__class__.__name__}Schema", **fields
|
||||||
)
|
)
|
||||||
|
|
||||||
def _generate_description(self) -> None:
|
def _generate_description(self) -> None:
|
||||||
@@ -241,13 +243,13 @@ class BaseTool(BaseModel, ABC):
|
|||||||
args_str = ", ".join(BaseTool._get_arg_annotations(arg) for arg in args)
|
args_str = ", ".join(BaseTool._get_arg_annotations(arg) for arg in args)
|
||||||
return f"{origin.__name__}[{args_str}]"
|
return f"{origin.__name__}[{args_str}]"
|
||||||
|
|
||||||
return origin.__name__
|
return str(origin.__name__)
|
||||||
|
|
||||||
|
|
||||||
class Tool(BaseTool):
|
class Tool(BaseTool):
|
||||||
"""The function that will be executed when the tool is called."""
|
"""The function that will be executed when the tool is called."""
|
||||||
|
|
||||||
func: Callable
|
func: Callable[..., Any]
|
||||||
|
|
||||||
def _run(self, *args: Any, **kwargs: Any) -> Any:
|
def _run(self, *args: Any, **kwargs: Any) -> Any:
|
||||||
return self.func(*args, **kwargs)
|
return self.func(*args, **kwargs)
|
||||||
@@ -275,24 +277,21 @@ class Tool(BaseTool):
|
|||||||
args_schema = getattr(tool, "args_schema", None)
|
args_schema = getattr(tool, "args_schema", None)
|
||||||
|
|
||||||
if args_schema is None:
|
if args_schema is None:
|
||||||
# Infer args_schema from the function signature if not provided
|
|
||||||
func_signature = signature(tool.func)
|
func_signature = signature(tool.func)
|
||||||
annotations = func_signature.parameters
|
fields: dict[str, Any] = {}
|
||||||
args_fields: dict[str, Any] = {}
|
for name, param in func_signature.parameters.items():
|
||||||
for name, param in annotations.items():
|
if name == "self":
|
||||||
if name != "self":
|
continue
|
||||||
param_annotation = (
|
param_annotation = (
|
||||||
param.annotation if param.annotation != param.empty else Any
|
param.annotation if param.annotation != param.empty else Any
|
||||||
)
|
)
|
||||||
field_info = Field(
|
if param.default is param.empty:
|
||||||
default=...,
|
fields[name] = (param_annotation, ...)
|
||||||
description="",
|
else:
|
||||||
)
|
fields[name] = (param_annotation, param.default)
|
||||||
args_fields[name] = (param_annotation, field_info)
|
if fields:
|
||||||
if args_fields:
|
args_schema = create_model(f"{tool.name}Input", **fields)
|
||||||
args_schema = create_model(f"{tool.name}Input", **args_fields)
|
|
||||||
else:
|
else:
|
||||||
# Create a default schema with no fields if no parameters are found
|
|
||||||
args_schema = create_model(
|
args_schema = create_model(
|
||||||
f"{tool.name}Input", __base__=PydanticBaseModel
|
f"{tool.name}Input", __base__=PydanticBaseModel
|
||||||
)
|
)
|
||||||
@@ -312,10 +311,11 @@ def to_langchain(
|
|||||||
|
|
||||||
|
|
||||||
def tool(
|
def tool(
|
||||||
*args, result_as_answer: bool = False, max_usage_count: int | None = None
|
*args: Callable[..., Any] | str,
|
||||||
) -> Callable:
|
result_as_answer: bool = False,
|
||||||
"""
|
max_usage_count: int | None = None,
|
||||||
Decorator to create a tool from a function.
|
) -> Callable[..., Any] | BaseTool:
|
||||||
|
"""Decorator to create a tool from a function.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
*args: Positional arguments, either the function to decorate or the tool name.
|
*args: Positional arguments, either the function to decorate or the tool name.
|
||||||
@@ -323,26 +323,31 @@ def tool(
|
|||||||
max_usage_count: Maximum number of times this tool can be used. None means unlimited usage.
|
max_usage_count: Maximum number of times this tool can be used. None means unlimited usage.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _make_with_name(tool_name: str) -> Callable:
|
def _make_with_name(tool_name: str) -> Callable[[Callable[..., Any]], BaseTool]:
|
||||||
def _make_tool(f: Callable) -> BaseTool:
|
def _make_tool(f: Callable[..., Any]) -> BaseTool:
|
||||||
if f.__doc__ is None:
|
if f.__doc__ is None:
|
||||||
raise ValueError("Function must have a docstring")
|
raise ValueError("Function must have a docstring")
|
||||||
if f.__annotations__ is None:
|
if f.__annotations__ is None:
|
||||||
raise ValueError("Function must have type annotations")
|
raise ValueError("Function must have type annotations")
|
||||||
|
|
||||||
|
func_sig = signature(f)
|
||||||
|
fields: dict[str, Any] = {}
|
||||||
|
|
||||||
|
for param_name, param in func_sig.parameters.items():
|
||||||
|
if param_name == "return":
|
||||||
|
continue
|
||||||
|
|
||||||
|
annotation = (
|
||||||
|
param.annotation if param.annotation != param.empty else Any
|
||||||
|
)
|
||||||
|
|
||||||
|
if param.default is param.empty:
|
||||||
|
fields[param_name] = (annotation, ...)
|
||||||
|
else:
|
||||||
|
fields[param_name] = (annotation, param.default)
|
||||||
|
|
||||||
class_name = "".join(tool_name.split()).title()
|
class_name = "".join(tool_name.split()).title()
|
||||||
args_schema = cast(
|
args_schema = create_model(class_name, **fields)
|
||||||
type[PydanticBaseModel],
|
|
||||||
type(
|
|
||||||
class_name,
|
|
||||||
(PydanticBaseModel,),
|
|
||||||
{
|
|
||||||
"__annotations__": {
|
|
||||||
k: v for k, v in f.__annotations__.items() if k != "return"
|
|
||||||
},
|
|
||||||
},
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
return Tool(
|
return Tool(
|
||||||
name=tool_name,
|
name=tool_name,
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -230,3 +230,67 @@ def test_max_usage_count_is_respected():
|
|||||||
crew.kickoff()
|
crew.kickoff()
|
||||||
assert tool.max_usage_count == 5
|
assert tool.max_usage_count == 5
|
||||||
assert tool.current_usage_count == 5
|
assert tool.current_usage_count == 5
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_schema_respects_default_values():
|
||||||
|
"""Test that tool schema correctly marks optional parameters with defaults."""
|
||||||
|
|
||||||
|
class ToolWithDefaults(BaseTool):
|
||||||
|
name: str = "tool_with_defaults"
|
||||||
|
description: str = "A tool with optional parameters"
|
||||||
|
|
||||||
|
def _run(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
similarity_threshold: float | None = None,
|
||||||
|
limit: int = 5,
|
||||||
|
) -> str:
|
||||||
|
return f"{query} - {similarity_threshold} - {limit}"
|
||||||
|
|
||||||
|
tool = ToolWithDefaults()
|
||||||
|
schema = tool.args_schema.model_json_schema()
|
||||||
|
|
||||||
|
assert schema["required"] == ["query"]
|
||||||
|
|
||||||
|
props = schema["properties"]
|
||||||
|
assert "default" in props["similarity_threshold"]
|
||||||
|
assert props["similarity_threshold"]["default"] is None
|
||||||
|
assert "default" in props["limit"]
|
||||||
|
assert props["limit"]["default"] == 5
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_decorator_respects_default_values():
|
||||||
|
"""Test that @tool decorator correctly handles optional parameters with defaults."""
|
||||||
|
|
||||||
|
@tool("search_tool")
|
||||||
|
def search_with_defaults(
|
||||||
|
query: str, max_results: int = 10, sort_by: str | None = None
|
||||||
|
) -> str:
|
||||||
|
"""Search for information with optional parameters."""
|
||||||
|
return f"{query} - {max_results} - {sort_by}"
|
||||||
|
|
||||||
|
schema = search_with_defaults.args_schema.model_json_schema()
|
||||||
|
|
||||||
|
assert schema["required"] == ["query"]
|
||||||
|
|
||||||
|
props = schema["properties"]
|
||||||
|
assert "default" in props["max_results"]
|
||||||
|
assert props["max_results"]["default"] == 10
|
||||||
|
assert "default" in props["sort_by"]
|
||||||
|
assert props["sort_by"]["default"] is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_schema_all_required_when_no_defaults():
|
||||||
|
"""Test that all parameters are required when no defaults are provided."""
|
||||||
|
|
||||||
|
class AllRequiredTool(BaseTool):
|
||||||
|
name: str = "all_required"
|
||||||
|
description: str = "All params required"
|
||||||
|
|
||||||
|
def _run(self, param1: str, param2: int, param3: bool) -> str:
|
||||||
|
return f"{param1} - {param2} - {param3}"
|
||||||
|
|
||||||
|
tool = AllRequiredTool()
|
||||||
|
schema = tool.args_schema.model_json_schema()
|
||||||
|
|
||||||
|
assert set(schema["required"]) == {"param1", "param2", "param3"}
|
||||||
|
|||||||
Reference in New Issue
Block a user