diff --git a/lib/crewai/src/crewai/tools/base_tool.py b/lib/crewai/src/crewai/tools/base_tool.py index 37e9fba09..118fa307b 100644 --- a/lib/crewai/src/crewai/tools/base_tool.py +++ b/lib/crewai/src/crewai/tools/base_tool.py @@ -281,6 +281,7 @@ class BaseTool(BaseModel, ABC): result_as_answer=self.result_as_answer, max_usage_count=self.max_usage_count, current_usage_count=self.current_usage_count, + cache_function=self.cache_function, ) structured_tool._original_tool = self return structured_tool diff --git a/lib/crewai/src/crewai/tools/structured_tool.py b/lib/crewai/src/crewai/tools/structured_tool.py index 4b95caeb7..60a457f3b 100644 --- a/lib/crewai/src/crewai/tools/structured_tool.py +++ b/lib/crewai/src/crewai/tools/structured_tool.py @@ -58,6 +58,7 @@ class CrewStructuredTool: result_as_answer: bool = False, max_usage_count: int | None = None, current_usage_count: int = 0, + cache_function: Callable[..., bool] | None = None, ) -> None: """Initialize the structured tool. @@ -69,6 +70,7 @@ class CrewStructuredTool: result_as_answer: Whether to return the output directly max_usage_count: Maximum number of times this tool can be used. None means unlimited usage. current_usage_count: Current number of times this tool has been used. + cache_function: Function to determine if the tool result should be cached. """ self.name = name self.description = description @@ -78,6 +80,7 @@ class CrewStructuredTool: self.result_as_answer = result_as_answer self.max_usage_count = max_usage_count self.current_usage_count = current_usage_count + self.cache_function = cache_function self._original_tool: BaseTool | None = None # Validate the function signature matches the schema @@ -86,7 +89,7 @@ class CrewStructuredTool: @classmethod def from_function( cls, - func: Callable, + func: Callable[..., Any], name: str | None = None, description: str | None = None, return_direct: bool = False, @@ -147,7 +150,7 @@ class CrewStructuredTool: @staticmethod def _create_schema_from_function( name: str, - func: Callable, + func: Callable[..., Any], ) -> type[BaseModel]: """Create a Pydantic schema from a function's signature. @@ -182,7 +185,7 @@ class CrewStructuredTool: # Create model schema_name = f"{name.title()}Schema" - return create_model(schema_name, **fields) # type: ignore[call-overload] + return create_model(schema_name, **fields) # type: ignore[call-overload, no-any-return] def _validate_function_signature(self) -> None: """Validate that the function signature matches the args schema.""" @@ -210,7 +213,7 @@ class CrewStructuredTool: f"not found in args_schema" ) - def _parse_args(self, raw_args: str | dict) -> dict: + def _parse_args(self, raw_args: str | dict[str, Any]) -> dict[str, Any]: """Parse and validate the input arguments against the schema. Args: @@ -234,8 +237,8 @@ class CrewStructuredTool: async def ainvoke( self, - input: str | dict, - config: dict | None = None, + input: str | dict[str, Any], + config: dict[str, Any] | None = None, **kwargs: Any, ) -> Any: """Asynchronously invoke the tool. @@ -269,7 +272,7 @@ class CrewStructuredTool: except Exception: raise - def _run(self, *args, **kwargs) -> Any: + def _run(self, *args: Any, **kwargs: Any) -> Any: """Legacy method for compatibility.""" # Convert args/kwargs to our expected format input_dict = dict(zip(self.args_schema.model_fields.keys(), args, strict=False)) @@ -277,7 +280,10 @@ class CrewStructuredTool: return self.invoke(input_dict) def invoke( - self, input: str | dict, config: dict | None = None, **kwargs: Any + self, + input: str | dict[str, Any], + config: dict[str, Any] | None = None, + **kwargs: Any, ) -> Any: """Main method for tool execution.""" parsed_args = self._parse_args(input) @@ -313,9 +319,10 @@ class CrewStructuredTool: self._original_tool.current_usage_count = self.current_usage_count @property - def args(self) -> dict: + def args(self) -> dict[str, Any]: """Get the tool's input arguments schema.""" - return self.args_schema.model_json_schema()["properties"] + schema: dict[str, Any] = self.args_schema.model_json_schema()["properties"] + return schema def __repr__(self) -> str: return f"CrewStructuredTool(name='{sanitize_tool_name(self.name)}', description='{self.description}')" diff --git a/lib/crewai/tests/tools/test_structured_tool.py b/lib/crewai/tests/tools/test_structured_tool.py index 999c13072..1cb8b3138 100644 --- a/lib/crewai/tests/tools/test_structured_tool.py +++ b/lib/crewai/tests/tools/test_structured_tool.py @@ -38,6 +38,44 @@ def test_initialization(basic_function, schema_class): assert tool.args_schema == schema_class +def test_cache_function_passed_through(basic_function, schema_class): + """Test that cache_function is stored on CrewStructuredTool.""" + + def no_cache(_args: dict, _result: str) -> bool: + return False + + tool = CrewStructuredTool( + name="test_tool", + description="Test tool description", + func=basic_function, + args_schema=schema_class, + cache_function=no_cache, + ) + + assert tool.cache_function is no_cache + + +def test_base_tool_passes_cache_function_to_structured_tool(): + """Test that BaseTool.to_structured_tool propagates cache_function.""" + from crewai.tools import BaseTool + + def no_cache(_args: dict, _result: str) -> bool: + return False + + class MyCacheTool(BaseTool): + name: str = "cache_test" + description: str = "tool for testing cache passthrough" + + def _run(self, query: str = "") -> str: + return "result" + + my_tool = MyCacheTool() + my_tool.cache_function = no_cache # type: ignore[assignment] + structured = my_tool.to_structured_tool() + + assert structured.cache_function is no_cache + + def test_from_function(basic_function): """Test creating tool from function""" tool = CrewStructuredTool.from_function(