mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-04-24 11:52:34 +00:00
fix: pass cache_function from BaseTool to CrewStructuredTool
This commit is contained in:
@@ -281,6 +281,7 @@ class BaseTool(BaseModel, ABC):
|
||||
result_as_answer=self.result_as_answer,
|
||||
max_usage_count=self.max_usage_count,
|
||||
current_usage_count=self.current_usage_count,
|
||||
cache_function=self.cache_function,
|
||||
)
|
||||
structured_tool._original_tool = self
|
||||
return structured_tool
|
||||
|
||||
@@ -58,6 +58,7 @@ class CrewStructuredTool:
|
||||
result_as_answer: bool = False,
|
||||
max_usage_count: int | None = None,
|
||||
current_usage_count: int = 0,
|
||||
cache_function: Callable[..., bool] | None = None,
|
||||
) -> None:
|
||||
"""Initialize the structured tool.
|
||||
|
||||
@@ -69,6 +70,7 @@ class CrewStructuredTool:
|
||||
result_as_answer: Whether to return the output directly
|
||||
max_usage_count: Maximum number of times this tool can be used. None means unlimited usage.
|
||||
current_usage_count: Current number of times this tool has been used.
|
||||
cache_function: Function to determine if the tool result should be cached.
|
||||
"""
|
||||
self.name = name
|
||||
self.description = description
|
||||
@@ -78,6 +80,7 @@ class CrewStructuredTool:
|
||||
self.result_as_answer = result_as_answer
|
||||
self.max_usage_count = max_usage_count
|
||||
self.current_usage_count = current_usage_count
|
||||
self.cache_function = cache_function
|
||||
self._original_tool: BaseTool | None = None
|
||||
|
||||
# Validate the function signature matches the schema
|
||||
@@ -86,7 +89,7 @@ class CrewStructuredTool:
|
||||
@classmethod
|
||||
def from_function(
|
||||
cls,
|
||||
func: Callable,
|
||||
func: Callable[..., Any],
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
return_direct: bool = False,
|
||||
@@ -147,7 +150,7 @@ class CrewStructuredTool:
|
||||
@staticmethod
|
||||
def _create_schema_from_function(
|
||||
name: str,
|
||||
func: Callable,
|
||||
func: Callable[..., Any],
|
||||
) -> type[BaseModel]:
|
||||
"""Create a Pydantic schema from a function's signature.
|
||||
|
||||
@@ -182,7 +185,7 @@ class CrewStructuredTool:
|
||||
|
||||
# Create model
|
||||
schema_name = f"{name.title()}Schema"
|
||||
return create_model(schema_name, **fields) # type: ignore[call-overload]
|
||||
return create_model(schema_name, **fields) # type: ignore[call-overload, no-any-return]
|
||||
|
||||
def _validate_function_signature(self) -> None:
|
||||
"""Validate that the function signature matches the args schema."""
|
||||
@@ -210,7 +213,7 @@ class CrewStructuredTool:
|
||||
f"not found in args_schema"
|
||||
)
|
||||
|
||||
def _parse_args(self, raw_args: str | dict) -> dict:
|
||||
def _parse_args(self, raw_args: str | dict[str, Any]) -> dict[str, Any]:
|
||||
"""Parse and validate the input arguments against the schema.
|
||||
|
||||
Args:
|
||||
@@ -234,8 +237,8 @@ class CrewStructuredTool:
|
||||
|
||||
async def ainvoke(
|
||||
self,
|
||||
input: str | dict,
|
||||
config: dict | None = None,
|
||||
input: str | dict[str, Any],
|
||||
config: dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Asynchronously invoke the tool.
|
||||
@@ -269,7 +272,7 @@ class CrewStructuredTool:
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
def _run(self, *args, **kwargs) -> Any:
|
||||
def _run(self, *args: Any, **kwargs: Any) -> Any:
|
||||
"""Legacy method for compatibility."""
|
||||
# Convert args/kwargs to our expected format
|
||||
input_dict = dict(zip(self.args_schema.model_fields.keys(), args, strict=False))
|
||||
@@ -277,7 +280,10 @@ class CrewStructuredTool:
|
||||
return self.invoke(input_dict)
|
||||
|
||||
def invoke(
|
||||
self, input: str | dict, config: dict | None = None, **kwargs: Any
|
||||
self,
|
||||
input: str | dict[str, Any],
|
||||
config: dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Main method for tool execution."""
|
||||
parsed_args = self._parse_args(input)
|
||||
@@ -313,9 +319,10 @@ class CrewStructuredTool:
|
||||
self._original_tool.current_usage_count = self.current_usage_count
|
||||
|
||||
@property
|
||||
def args(self) -> dict:
|
||||
def args(self) -> dict[str, Any]:
|
||||
"""Get the tool's input arguments schema."""
|
||||
return self.args_schema.model_json_schema()["properties"]
|
||||
schema: dict[str, Any] = self.args_schema.model_json_schema()["properties"]
|
||||
return schema
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"CrewStructuredTool(name='{sanitize_tool_name(self.name)}', description='{self.description}')"
|
||||
|
||||
@@ -38,6 +38,44 @@ def test_initialization(basic_function, schema_class):
|
||||
assert tool.args_schema == schema_class
|
||||
|
||||
|
||||
def test_cache_function_passed_through(basic_function, schema_class):
|
||||
"""Test that cache_function is stored on CrewStructuredTool."""
|
||||
|
||||
def no_cache(_args: dict, _result: str) -> bool:
|
||||
return False
|
||||
|
||||
tool = CrewStructuredTool(
|
||||
name="test_tool",
|
||||
description="Test tool description",
|
||||
func=basic_function,
|
||||
args_schema=schema_class,
|
||||
cache_function=no_cache,
|
||||
)
|
||||
|
||||
assert tool.cache_function is no_cache
|
||||
|
||||
|
||||
def test_base_tool_passes_cache_function_to_structured_tool():
|
||||
"""Test that BaseTool.to_structured_tool propagates cache_function."""
|
||||
from crewai.tools import BaseTool
|
||||
|
||||
def no_cache(_args: dict, _result: str) -> bool:
|
||||
return False
|
||||
|
||||
class MyCacheTool(BaseTool):
|
||||
name: str = "cache_test"
|
||||
description: str = "tool for testing cache passthrough"
|
||||
|
||||
def _run(self, query: str = "") -> str:
|
||||
return "result"
|
||||
|
||||
my_tool = MyCacheTool()
|
||||
my_tool.cache_function = no_cache # type: ignore[assignment]
|
||||
structured = my_tool.to_structured_tool()
|
||||
|
||||
assert structured.cache_function is no_cache
|
||||
|
||||
|
||||
def test_from_function(basic_function):
|
||||
"""Test creating tool from function"""
|
||||
tool = CrewStructuredTool.from_function(
|
||||
|
||||
Reference in New Issue
Block a user