fix: ensure full type signature for tools

This commit is contained in:
Greyson LaLonde
2025-11-24 19:21:03 -05:00
parent 4ae8c36815
commit 610c1bb067
3 changed files with 1416 additions and 487 deletions

View File

@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
import asyncio import asyncio
from collections.abc import Callable from collections.abc import Callable
from inspect import signature from inspect import signature
from typing import Any, cast, get_args, get_origin from typing import Any, get_args, get_origin
from pydantic import ( from pydantic import (
BaseModel, BaseModel,
@@ -55,7 +55,7 @@ class BaseTool(BaseModel, ABC):
default=False, description="Flag to check if the description has been updated." default=False, description="Flag to check if the description has been updated."
) )
cache_function: Callable = Field( cache_function: Callable[..., bool] = Field(
default=lambda _args=None, _result=None: True, default=lambda _args=None, _result=None: True,
description="Function that will be used to determine if the tool should be cached, should return a boolean. If None, the tool will be cached.", description="Function that will be used to determine if the tool should be cached, should return a boolean. If None, the tool will be cached.",
) )
@@ -80,20 +80,21 @@ class BaseTool(BaseModel, ABC):
if v != cls._ArgsSchemaPlaceholder: if v != cls._ArgsSchemaPlaceholder:
return v return v
return cast( run_sig = signature(cls._run)
type[PydanticBaseModel], fields: dict[str, Any] = {}
type(
f"{cls.__name__}Schema", for param_name, param in run_sig.parameters.items():
(PydanticBaseModel,), if param_name in ("self", "return"):
{ continue
"__annotations__": {
k: v annotation = param.annotation if param.annotation != param.empty else Any
for k, v in cls._run.__annotations__.items()
if k != "return" if param.default is param.empty:
}, fields[param_name] = (annotation, ...)
}, else:
), fields[param_name] = (annotation, param.default)
)
return create_model(f"{cls.__name__}Schema", **fields)
@field_validator("max_usage_count", mode="before") @field_validator("max_usage_count", mode="before")
@classmethod @classmethod
@@ -164,24 +165,21 @@ class BaseTool(BaseModel, ABC):
args_schema = getattr(tool, "args_schema", None) args_schema = getattr(tool, "args_schema", None)
if args_schema is None: if args_schema is None:
# Infer args_schema from the function signature if not provided
func_signature = signature(tool.func) func_signature = signature(tool.func)
annotations = func_signature.parameters fields: dict[str, Any] = {}
args_fields: dict[str, Any] = {} for name, param in func_signature.parameters.items():
for name, param in annotations.items(): if name == "self":
if name != "self": continue
param_annotation = ( param_annotation = (
param.annotation if param.annotation != param.empty else Any param.annotation if param.annotation != param.empty else Any
) )
field_info = Field( if param.default is param.empty:
default=..., fields[name] = (param_annotation, ...)
description="", else:
) fields[name] = (param_annotation, param.default)
args_fields[name] = (param_annotation, field_info) if fields:
if args_fields: args_schema = create_model(f"{tool.name}Input", **fields)
args_schema = create_model(f"{tool.name}Input", **args_fields)
else: else:
# Create a default schema with no fields if no parameters are found
args_schema = create_model( args_schema = create_model(
f"{tool.name}Input", __base__=PydanticBaseModel f"{tool.name}Input", __base__=PydanticBaseModel
) )
@@ -195,20 +193,24 @@ class BaseTool(BaseModel, ABC):
def _set_args_schema(self) -> None: def _set_args_schema(self) -> None:
if self.args_schema is None: if self.args_schema is None:
class_name = f"{self.__class__.__name__}Schema" run_sig = signature(self._run)
self.args_schema = cast( fields: dict[str, Any] = {}
type[PydanticBaseModel],
type( for param_name, param in run_sig.parameters.items():
class_name, if param_name in ("self", "return"):
(PydanticBaseModel,), continue
{
"__annotations__": { annotation = (
k: v param.annotation if param.annotation != param.empty else Any
for k, v in self._run.__annotations__.items() )
if k != "return"
}, if param.default is param.empty:
}, fields[param_name] = (annotation, ...)
), else:
fields[param_name] = (annotation, param.default)
self.args_schema = create_model(
f"{self.__class__.__name__}Schema", **fields
) )
def _generate_description(self) -> None: def _generate_description(self) -> None:
@@ -241,13 +243,13 @@ class BaseTool(BaseModel, ABC):
args_str = ", ".join(BaseTool._get_arg_annotations(arg) for arg in args) args_str = ", ".join(BaseTool._get_arg_annotations(arg) for arg in args)
return f"{origin.__name__}[{args_str}]" return f"{origin.__name__}[{args_str}]"
return origin.__name__ return str(origin.__name__)
class Tool(BaseTool): class Tool(BaseTool):
"""The function that will be executed when the tool is called.""" """The function that will be executed when the tool is called."""
func: Callable func: Callable[..., Any]
def _run(self, *args: Any, **kwargs: Any) -> Any: def _run(self, *args: Any, **kwargs: Any) -> Any:
return self.func(*args, **kwargs) return self.func(*args, **kwargs)
@@ -275,24 +277,21 @@ class Tool(BaseTool):
args_schema = getattr(tool, "args_schema", None) args_schema = getattr(tool, "args_schema", None)
if args_schema is None: if args_schema is None:
# Infer args_schema from the function signature if not provided
func_signature = signature(tool.func) func_signature = signature(tool.func)
annotations = func_signature.parameters fields: dict[str, Any] = {}
args_fields: dict[str, Any] = {} for name, param in func_signature.parameters.items():
for name, param in annotations.items(): if name == "self":
if name != "self": continue
param_annotation = ( param_annotation = (
param.annotation if param.annotation != param.empty else Any param.annotation if param.annotation != param.empty else Any
) )
field_info = Field( if param.default is param.empty:
default=..., fields[name] = (param_annotation, ...)
description="", else:
) fields[name] = (param_annotation, param.default)
args_fields[name] = (param_annotation, field_info) if fields:
if args_fields: args_schema = create_model(f"{tool.name}Input", **fields)
args_schema = create_model(f"{tool.name}Input", **args_fields)
else: else:
# Create a default schema with no fields if no parameters are found
args_schema = create_model( args_schema = create_model(
f"{tool.name}Input", __base__=PydanticBaseModel f"{tool.name}Input", __base__=PydanticBaseModel
) )
@@ -312,10 +311,11 @@ def to_langchain(
def tool( def tool(
*args, result_as_answer: bool = False, max_usage_count: int | None = None *args: Callable[..., Any] | str,
) -> Callable: result_as_answer: bool = False,
""" max_usage_count: int | None = None,
Decorator to create a tool from a function. ) -> Callable[..., Any] | BaseTool:
"""Decorator to create a tool from a function.
Args: Args:
*args: Positional arguments, either the function to decorate or the tool name. *args: Positional arguments, either the function to decorate or the tool name.
@@ -323,26 +323,31 @@ def tool(
max_usage_count: Maximum number of times this tool can be used. None means unlimited usage. max_usage_count: Maximum number of times this tool can be used. None means unlimited usage.
""" """
def _make_with_name(tool_name: str) -> Callable: def _make_with_name(tool_name: str) -> Callable[[Callable[..., Any]], BaseTool]:
def _make_tool(f: Callable) -> BaseTool: def _make_tool(f: Callable[..., Any]) -> BaseTool:
if f.__doc__ is None: if f.__doc__ is None:
raise ValueError("Function must have a docstring") raise ValueError("Function must have a docstring")
if f.__annotations__ is None: if f.__annotations__ is None:
raise ValueError("Function must have type annotations") raise ValueError("Function must have type annotations")
func_sig = signature(f)
fields: dict[str, Any] = {}
for param_name, param in func_sig.parameters.items():
if param_name == "return":
continue
annotation = (
param.annotation if param.annotation != param.empty else Any
)
if param.default is param.empty:
fields[param_name] = (annotation, ...)
else:
fields[param_name] = (annotation, param.default)
class_name = "".join(tool_name.split()).title() class_name = "".join(tool_name.split()).title()
args_schema = cast( args_schema = create_model(class_name, **fields)
type[PydanticBaseModel],
type(
class_name,
(PydanticBaseModel,),
{
"__annotations__": {
k: v for k, v in f.__annotations__.items() if k != "return"
},
},
),
)
return Tool( return Tool(
name=tool_name, name=tool_name,

View File

@@ -230,3 +230,67 @@ def test_max_usage_count_is_respected():
crew.kickoff() crew.kickoff()
assert tool.max_usage_count == 5 assert tool.max_usage_count == 5
assert tool.current_usage_count == 5 assert tool.current_usage_count == 5
def test_tool_schema_respects_default_values():
"""Test that tool schema correctly marks optional parameters with defaults."""
class ToolWithDefaults(BaseTool):
name: str = "tool_with_defaults"
description: str = "A tool with optional parameters"
def _run(
self,
query: str,
similarity_threshold: float | None = None,
limit: int = 5,
) -> str:
return f"{query} - {similarity_threshold} - {limit}"
tool = ToolWithDefaults()
schema = tool.args_schema.model_json_schema()
assert schema["required"] == ["query"]
props = schema["properties"]
assert "default" in props["similarity_threshold"]
assert props["similarity_threshold"]["default"] is None
assert "default" in props["limit"]
assert props["limit"]["default"] == 5
def test_tool_decorator_respects_default_values():
"""Test that @tool decorator correctly handles optional parameters with defaults."""
@tool("search_tool")
def search_with_defaults(
query: str, max_results: int = 10, sort_by: str | None = None
) -> str:
"""Search for information with optional parameters."""
return f"{query} - {max_results} - {sort_by}"
schema = search_with_defaults.args_schema.model_json_schema()
assert schema["required"] == ["query"]
props = schema["properties"]
assert "default" in props["max_results"]
assert props["max_results"]["default"] == 10
assert "default" in props["sort_by"]
assert props["sort_by"]["default"] is None
def test_tool_schema_all_required_when_no_defaults():
"""Test that all parameters are required when no defaults are provided."""
class AllRequiredTool(BaseTool):
name: str = "all_required"
description: str = "All params required"
def _run(self, param1: str, param2: int, param3: bool) -> str:
return f"{param1} - {param2} - {param3}"
tool = AllRequiredTool()
schema = tool.args_schema.model_json_schema()
assert set(schema["required"]) == {"param1", "param2", "param3"}