fix: pass cache_function from BaseTool to CrewStructuredTool

This commit is contained in:
Greyson LaLonde
2026-03-20 16:04:52 -04:00
committed by GitHub
parent 8e427164ca
commit f13d307534
3 changed files with 56 additions and 10 deletions

View File

@@ -281,6 +281,7 @@ class BaseTool(BaseModel, ABC):
result_as_answer=self.result_as_answer,
max_usage_count=self.max_usage_count,
current_usage_count=self.current_usage_count,
cache_function=self.cache_function,
)
structured_tool._original_tool = self
return structured_tool

View File

@@ -58,6 +58,7 @@ class CrewStructuredTool:
result_as_answer: bool = False,
max_usage_count: int | None = None,
current_usage_count: int = 0,
cache_function: Callable[..., bool] | None = None,
) -> None:
"""Initialize the structured tool.
@@ -69,6 +70,7 @@ class CrewStructuredTool:
result_as_answer: Whether to return the output directly
max_usage_count: Maximum number of times this tool can be used. None means unlimited usage.
current_usage_count: Current number of times this tool has been used.
cache_function: Function to determine if the tool result should be cached.
"""
self.name = name
self.description = description
@@ -78,6 +80,7 @@ class CrewStructuredTool:
self.result_as_answer = result_as_answer
self.max_usage_count = max_usage_count
self.current_usage_count = current_usage_count
self.cache_function = cache_function
self._original_tool: BaseTool | None = None
# Validate the function signature matches the schema
@@ -86,7 +89,7 @@ class CrewStructuredTool:
@classmethod
def from_function(
cls,
func: Callable,
func: Callable[..., Any],
name: str | None = None,
description: str | None = None,
return_direct: bool = False,
@@ -147,7 +150,7 @@ class CrewStructuredTool:
@staticmethod
def _create_schema_from_function(
name: str,
func: Callable,
func: Callable[..., Any],
) -> type[BaseModel]:
"""Create a Pydantic schema from a function's signature.
@@ -182,7 +185,7 @@ class CrewStructuredTool:
# Create model
schema_name = f"{name.title()}Schema"
return create_model(schema_name, **fields) # type: ignore[call-overload]
return create_model(schema_name, **fields) # type: ignore[call-overload, no-any-return]
def _validate_function_signature(self) -> None:
"""Validate that the function signature matches the args schema."""
@@ -210,7 +213,7 @@ class CrewStructuredTool:
f"not found in args_schema"
)
def _parse_args(self, raw_args: str | dict) -> dict:
def _parse_args(self, raw_args: str | dict[str, Any]) -> dict[str, Any]:
"""Parse and validate the input arguments against the schema.
Args:
@@ -234,8 +237,8 @@ class CrewStructuredTool:
async def ainvoke(
self,
input: str | dict,
config: dict | None = None,
input: str | dict[str, Any],
config: dict[str, Any] | None = None,
**kwargs: Any,
) -> Any:
"""Asynchronously invoke the tool.
@@ -269,7 +272,7 @@ class CrewStructuredTool:
except Exception:
raise
def _run(self, *args, **kwargs) -> Any:
def _run(self, *args: Any, **kwargs: Any) -> Any:
"""Legacy method for compatibility."""
# Convert args/kwargs to our expected format
input_dict = dict(zip(self.args_schema.model_fields.keys(), args, strict=False))
@@ -277,7 +280,10 @@ class CrewStructuredTool:
return self.invoke(input_dict)
def invoke(
self, input: str | dict, config: dict | None = None, **kwargs: Any
self,
input: str | dict[str, Any],
config: dict[str, Any] | None = None,
**kwargs: Any,
) -> Any:
"""Main method for tool execution."""
parsed_args = self._parse_args(input)
@@ -313,9 +319,10 @@ class CrewStructuredTool:
self._original_tool.current_usage_count = self.current_usage_count
@property
def args(self) -> dict:
def args(self) -> dict[str, Any]:
"""Get the tool's input arguments schema."""
return self.args_schema.model_json_schema()["properties"]
schema: dict[str, Any] = self.args_schema.model_json_schema()["properties"]
return schema
def __repr__(self) -> str:
return f"CrewStructuredTool(name='{sanitize_tool_name(self.name)}', description='{self.description}')"

View File

@@ -38,6 +38,44 @@ def test_initialization(basic_function, schema_class):
assert tool.args_schema == schema_class
def test_cache_function_passed_through(basic_function, schema_class):
"""Test that cache_function is stored on CrewStructuredTool."""
def no_cache(_args: dict, _result: str) -> bool:
return False
tool = CrewStructuredTool(
name="test_tool",
description="Test tool description",
func=basic_function,
args_schema=schema_class,
cache_function=no_cache,
)
assert tool.cache_function is no_cache
def test_base_tool_passes_cache_function_to_structured_tool():
"""Test that BaseTool.to_structured_tool propagates cache_function."""
from crewai.tools import BaseTool
def no_cache(_args: dict, _result: str) -> bool:
return False
class MyCacheTool(BaseTool):
name: str = "cache_test"
description: str = "tool for testing cache passthrough"
def _run(self, query: str = "") -> str:
return "result"
my_tool = MyCacheTool()
my_tool.cache_function = no_cache # type: ignore[assignment]
structured = my_tool.to_structured_tool()
assert structured.cache_function is no_cache
def test_from_function(basic_function):
"""Test creating tool from function"""
tool = CrewStructuredTool.from_function(