fix: ensure full type signature for tools

This commit is contained in:
Greyson LaLonde
2025-11-24 19:21:03 -05:00
parent 4ae8c36815
commit 610c1bb067
3 changed files with 1416 additions and 487 deletions

View File

@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
import asyncio
from collections.abc import Callable
from inspect import signature
from typing import Any, cast, get_args, get_origin
from typing import Any, get_args, get_origin
from pydantic import (
BaseModel,
@@ -55,7 +55,7 @@ class BaseTool(BaseModel, ABC):
default=False, description="Flag to check if the description has been updated."
)
cache_function: Callable = Field(
cache_function: Callable[..., bool] = Field(
default=lambda _args=None, _result=None: True,
description="Function that will be used to determine if the tool should be cached, should return a boolean. If None, the tool will be cached.",
)
@@ -80,20 +80,21 @@ class BaseTool(BaseModel, ABC):
if v != cls._ArgsSchemaPlaceholder:
return v
return cast(
type[PydanticBaseModel],
type(
f"{cls.__name__}Schema",
(PydanticBaseModel,),
{
"__annotations__": {
k: v
for k, v in cls._run.__annotations__.items()
if k != "return"
},
},
),
)
run_sig = signature(cls._run)
fields: dict[str, Any] = {}
for param_name, param in run_sig.parameters.items():
if param_name in ("self", "return"):
continue
annotation = param.annotation if param.annotation != param.empty else Any
if param.default is param.empty:
fields[param_name] = (annotation, ...)
else:
fields[param_name] = (annotation, param.default)
return create_model(f"{cls.__name__}Schema", **fields)
@field_validator("max_usage_count", mode="before")
@classmethod
@@ -164,24 +165,21 @@ class BaseTool(BaseModel, ABC):
args_schema = getattr(tool, "args_schema", None)
if args_schema is None:
# Infer args_schema from the function signature if not provided
func_signature = signature(tool.func)
annotations = func_signature.parameters
args_fields: dict[str, Any] = {}
for name, param in annotations.items():
if name != "self":
param_annotation = (
param.annotation if param.annotation != param.empty else Any
)
field_info = Field(
default=...,
description="",
)
args_fields[name] = (param_annotation, field_info)
if args_fields:
args_schema = create_model(f"{tool.name}Input", **args_fields)
fields: dict[str, Any] = {}
for name, param in func_signature.parameters.items():
if name == "self":
continue
param_annotation = (
param.annotation if param.annotation != param.empty else Any
)
if param.default is param.empty:
fields[name] = (param_annotation, ...)
else:
fields[name] = (param_annotation, param.default)
if fields:
args_schema = create_model(f"{tool.name}Input", **fields)
else:
# Create a default schema with no fields if no parameters are found
args_schema = create_model(
f"{tool.name}Input", __base__=PydanticBaseModel
)
@@ -195,20 +193,24 @@ class BaseTool(BaseModel, ABC):
def _set_args_schema(self) -> None:
if self.args_schema is None:
class_name = f"{self.__class__.__name__}Schema"
self.args_schema = cast(
type[PydanticBaseModel],
type(
class_name,
(PydanticBaseModel,),
{
"__annotations__": {
k: v
for k, v in self._run.__annotations__.items()
if k != "return"
},
},
),
run_sig = signature(self._run)
fields: dict[str, Any] = {}
for param_name, param in run_sig.parameters.items():
if param_name in ("self", "return"):
continue
annotation = (
param.annotation if param.annotation != param.empty else Any
)
if param.default is param.empty:
fields[param_name] = (annotation, ...)
else:
fields[param_name] = (annotation, param.default)
self.args_schema = create_model(
f"{self.__class__.__name__}Schema", **fields
)
def _generate_description(self) -> None:
@@ -241,13 +243,13 @@ class BaseTool(BaseModel, ABC):
args_str = ", ".join(BaseTool._get_arg_annotations(arg) for arg in args)
return f"{origin.__name__}[{args_str}]"
return origin.__name__
return str(origin.__name__)
class Tool(BaseTool):
"""The function that will be executed when the tool is called."""
func: Callable
func: Callable[..., Any]
def _run(self, *args: Any, **kwargs: Any) -> Any:
return self.func(*args, **kwargs)
@@ -275,24 +277,21 @@ class Tool(BaseTool):
args_schema = getattr(tool, "args_schema", None)
if args_schema is None:
# Infer args_schema from the function signature if not provided
func_signature = signature(tool.func)
annotations = func_signature.parameters
args_fields: dict[str, Any] = {}
for name, param in annotations.items():
if name != "self":
param_annotation = (
param.annotation if param.annotation != param.empty else Any
)
field_info = Field(
default=...,
description="",
)
args_fields[name] = (param_annotation, field_info)
if args_fields:
args_schema = create_model(f"{tool.name}Input", **args_fields)
fields: dict[str, Any] = {}
for name, param in func_signature.parameters.items():
if name == "self":
continue
param_annotation = (
param.annotation if param.annotation != param.empty else Any
)
if param.default is param.empty:
fields[name] = (param_annotation, ...)
else:
fields[name] = (param_annotation, param.default)
if fields:
args_schema = create_model(f"{tool.name}Input", **fields)
else:
# Create a default schema with no fields if no parameters are found
args_schema = create_model(
f"{tool.name}Input", __base__=PydanticBaseModel
)
@@ -312,10 +311,11 @@ def to_langchain(
def tool(
*args, result_as_answer: bool = False, max_usage_count: int | None = None
) -> Callable:
"""
Decorator to create a tool from a function.
*args: Callable[..., Any] | str,
result_as_answer: bool = False,
max_usage_count: int | None = None,
) -> Callable[..., Any] | BaseTool:
"""Decorator to create a tool from a function.
Args:
*args: Positional arguments, either the function to decorate or the tool name.
@@ -323,26 +323,31 @@ def tool(
max_usage_count: Maximum number of times this tool can be used. None means unlimited usage.
"""
def _make_with_name(tool_name: str) -> Callable:
def _make_tool(f: Callable) -> BaseTool:
def _make_with_name(tool_name: str) -> Callable[[Callable[..., Any]], BaseTool]:
def _make_tool(f: Callable[..., Any]) -> BaseTool:
if f.__doc__ is None:
raise ValueError("Function must have a docstring")
if f.__annotations__ is None:
raise ValueError("Function must have type annotations")
func_sig = signature(f)
fields: dict[str, Any] = {}
for param_name, param in func_sig.parameters.items():
if param_name == "return":
continue
annotation = (
param.annotation if param.annotation != param.empty else Any
)
if param.default is param.empty:
fields[param_name] = (annotation, ...)
else:
fields[param_name] = (annotation, param.default)
class_name = "".join(tool_name.split()).title()
args_schema = cast(
type[PydanticBaseModel],
type(
class_name,
(PydanticBaseModel,),
{
"__annotations__": {
k: v for k, v in f.__annotations__.items() if k != "return"
},
},
),
)
args_schema = create_model(class_name, **fields)
return Tool(
name=tool_name,

View File

@@ -230,3 +230,67 @@ def test_max_usage_count_is_respected():
crew.kickoff()
assert tool.max_usage_count == 5
assert tool.current_usage_count == 5
def test_tool_schema_respects_default_values():
"""Test that tool schema correctly marks optional parameters with defaults."""
class ToolWithDefaults(BaseTool):
name: str = "tool_with_defaults"
description: str = "A tool with optional parameters"
def _run(
self,
query: str,
similarity_threshold: float | None = None,
limit: int = 5,
) -> str:
return f"{query} - {similarity_threshold} - {limit}"
tool = ToolWithDefaults()
schema = tool.args_schema.model_json_schema()
assert schema["required"] == ["query"]
props = schema["properties"]
assert "default" in props["similarity_threshold"]
assert props["similarity_threshold"]["default"] is None
assert "default" in props["limit"]
assert props["limit"]["default"] == 5
def test_tool_decorator_respects_default_values():
"""Test that @tool decorator correctly handles optional parameters with defaults."""
@tool("search_tool")
def search_with_defaults(
query: str, max_results: int = 10, sort_by: str | None = None
) -> str:
"""Search for information with optional parameters."""
return f"{query} - {max_results} - {sort_by}"
schema = search_with_defaults.args_schema.model_json_schema()
assert schema["required"] == ["query"]
props = schema["properties"]
assert "default" in props["max_results"]
assert props["max_results"]["default"] == 10
assert "default" in props["sort_by"]
assert props["sort_by"]["default"] is None
def test_tool_schema_all_required_when_no_defaults():
"""Test that all parameters are required when no defaults are provided."""
class AllRequiredTool(BaseTool):
name: str = "all_required"
description: str = "All params required"
def _run(self, param1: str, param2: int, param3: bool) -> str:
return f"{param1} - {param2} - {param3}"
tool = AllRequiredTool()
schema = tool.args_schema.model_json_schema()
assert set(schema["required"]) == {"param1", "param2", "param3"}