diff --git a/src/crewai/tools/structured_tool.py b/src/crewai/tools/structured_tool.py index 47faa1d9c..0409819cc 100644 --- a/src/crewai/tools/structured_tool.py +++ b/src/crewai/tools/structured_tool.py @@ -66,7 +66,7 @@ class CrewStructuredTool: @classmethod def from_function( cls, - func: Callable, + func: Callable[..., Any], name: Optional[str] = None, description: Optional[str] = None, return_direct: bool = False, @@ -127,7 +127,7 @@ class CrewStructuredTool: @staticmethod def _create_schema_from_function( name: str, - func: Callable, + func: Callable[..., Any], ) -> type[BaseModel]: """Create a Pydantic schema from a function's signature. @@ -162,7 +162,7 @@ class CrewStructuredTool: # Create model schema_name = f"{name.title()}Schema" - return create_model(schema_name, **fields) + return create_model(schema_name, **fields) # type: ignore[call-overload,no-any-return] def _validate_function_signature(self) -> None: """Validate that the function signature matches the args schema.""" @@ -190,7 +190,7 @@ class CrewStructuredTool: f"not found in args_schema" ) - def _parse_args(self, raw_args: str | dict) -> dict: + def _parse_args(self, raw_args: str | dict[str, Any]) -> dict[str, Any]: """Parse and validate the input arguments against the schema. Args: @@ -215,8 +215,8 @@ class CrewStructuredTool: async def ainvoke( self, - input: str | dict, - config: Optional[dict] = None, + input: str | dict[str, Any], + config: Optional[dict[str, Any]] = None, **kwargs: Any, ) -> Any: """Asynchronously invoke the tool. @@ -251,7 +251,7 @@ class CrewStructuredTool: except Exception: raise - def _run(self, *args, **kwargs) -> Any: + def _run(self, *args: Any, **kwargs: Any) -> Any: """Legacy method for compatibility.""" # Convert args/kwargs to our expected format input_dict = dict(zip(self.args_schema.model_fields.keys(), args)) @@ -259,7 +259,10 @@ class CrewStructuredTool: return self.invoke(input_dict) def invoke( - self, input: str | dict, config: Optional[dict] = None, **kwargs: Any + self, + input: str | dict[str, Any], + config: Optional[dict[str, Any]] = None, + **kwargs: Any, ) -> Any: """Main method for tool execution.""" parsed_args = self._parse_args(input) @@ -319,9 +322,9 @@ class CrewStructuredTool: self._original_tool.current_usage_count = self.current_usage_count @property - def args(self) -> dict: + def args(self) -> dict[str, Any]: """Get the tool's input arguments schema.""" - return self.args_schema.model_json_schema()["properties"] + return self.args_schema.model_json_schema()["properties"] # type: ignore[no-any-return] def __repr__(self) -> str: return ( diff --git a/tests/tools/test_structured_tool.py b/tests/tools/test_structured_tool.py index b65149768..c454641ec 100644 --- a/tests/tools/test_structured_tool.py +++ b/tests/tools/test_structured_tool.py @@ -1,14 +1,16 @@ -from typing import Optional +from collections.abc import Callable +from typing import Any, Optional import pytest from pydantic import BaseModel, Field +from crewai.tools import BaseTool from crewai.tools.structured_tool import CrewStructuredTool # Test fixtures @pytest.fixture -def basic_function(): +def basic_function() -> Callable[[str, int], str]: def test_func(param1: str, param2: int = 0) -> str: """Test function with basic params.""" return f"{param1} {param2}" @@ -17,7 +19,7 @@ def basic_function(): @pytest.fixture -def schema_class(): +def schema_class() -> type[BaseModel]: class TestSchema(BaseModel): param1: str param2: int = Field(default=0) @@ -25,7 +27,9 @@ def schema_class(): return TestSchema -def test_initialization(basic_function, schema_class): +def test_initialization( + basic_function: Callable[[str], str], schema_class: type[BaseModel] +) -> None: """Test basic initialization of CrewStructuredTool""" tool = CrewStructuredTool( name="test_tool", @@ -40,7 +44,7 @@ def test_initialization(basic_function, schema_class): assert tool.args_schema == schema_class -def test_from_function(basic_function): +def test_from_function(basic_function: Callable[[str], str]) -> None: """Test creating tool from function""" tool = CrewStructuredTool.from_function( func=basic_function, name="test_tool", description="Test description" @@ -52,7 +56,9 @@ def test_from_function(basic_function): assert isinstance(tool.args_schema, type(BaseModel)) -def test_validate_function_signature(basic_function, schema_class): +def test_validate_function_signature( + basic_function: Callable[[str, int], str], schema_class: type[BaseModel] +) -> None: """Test function signature validation""" tool = CrewStructuredTool( name="test_tool", @@ -66,7 +72,7 @@ def test_validate_function_signature(basic_function, schema_class): @pytest.mark.asyncio -async def test_ainvoke(basic_function): +async def test_ainvoke(basic_function: Callable[[str, int], str]) -> None: """Test asynchronous invocation""" tool = CrewStructuredTool.from_function(func=basic_function, name="test_tool") @@ -74,7 +80,7 @@ async def test_ainvoke(basic_function): assert result == "test 0" -def test_parse_args_dict(basic_function): +def test_parse_args_dict(basic_function: Callable[[str, int], str]) -> None: """Test parsing dictionary arguments""" tool = CrewStructuredTool.from_function(func=basic_function, name="test_tool") @@ -83,7 +89,7 @@ def test_parse_args_dict(basic_function): assert parsed["param2"] == 42 -def test_parse_args_string(basic_function): +def test_parse_args_string(basic_function: Callable[[str, int], str]) -> None: """Test parsing string arguments""" tool = CrewStructuredTool.from_function(func=basic_function, name="test_tool") @@ -92,10 +98,10 @@ def test_parse_args_string(basic_function): assert parsed["param2"] == 42 -def test_complex_types(): +def test_complex_types() -> None: """Test handling of complex parameter types""" - def complex_func(nested: dict, items: list) -> str: + def complex_func(nested: dict[str, Any], items: list[Any]) -> str: """Process complex types.""" return f"Processed {len(items)} items with {len(nested)} nested keys" @@ -106,7 +112,7 @@ def test_complex_types(): assert result == "Processed 3 items with 1 nested keys" -def test_schema_inheritance(): +def test_schema_inheritance() -> None: """Test tool creation with inherited schema""" def extended_func(base_param: str, extra_param: int) -> str: @@ -127,7 +133,7 @@ def test_schema_inheritance(): assert result == "test 42" -def test_default_values_in_schema(): +def test_default_values_in_schema() -> None: """Test handling of default values in schema""" def default_func( @@ -154,11 +160,11 @@ def test_default_values_in_schema(): @pytest.fixture -def custom_tool_decorator(): +def custom_tool_decorator() -> Any: from crewai.tools import tool @tool("custom_tool", result_as_answer=True) - async def custom_tool(): + async def custom_tool() -> str: """This is a tool that does something""" return "Hello World from Custom Tool" @@ -166,7 +172,7 @@ def custom_tool_decorator(): @pytest.fixture -def custom_tool(): +def custom_tool() -> BaseTool: from crewai.tools import BaseTool class CustomTool(BaseTool): @@ -174,13 +180,13 @@ def custom_tool(): description: str = "This is a tool that does something" result_as_answer: bool = True - async def _run(self): + async def _run(self) -> str: return "Hello World from Custom Tool" return CustomTool() -def build_simple_crew(tool): +def build_simple_crew(tool: Any) -> Any: from crewai import Agent, Crew, Task agent1 = Agent( @@ -201,7 +207,7 @@ def build_simple_crew(tool): @pytest.mark.vcr(filter_headers=["authorization"]) -def test_async_tool_using_within_isolated_crew(custom_tool): +def test_async_tool_using_within_isolated_crew(custom_tool: BaseTool) -> None: crew = build_simple_crew(custom_tool) result = crew.kickoff() @@ -209,7 +215,9 @@ def test_async_tool_using_within_isolated_crew(custom_tool): @pytest.mark.vcr(filter_headers=["authorization"]) -def test_async_tool_using_decorator_within_isolated_crew(custom_tool_decorator): +def test_async_tool_using_decorator_within_isolated_crew( + custom_tool_decorator: Any, +) -> None: crew = build_simple_crew(custom_tool_decorator) result = crew.kickoff() @@ -217,14 +225,12 @@ def test_async_tool_using_decorator_within_isolated_crew(custom_tool_decorator): @pytest.mark.vcr(filter_headers=["authorization"]) -def test_async_tool_within_flow(custom_tool): - from crewai.flow.flow import Flow - - class StructuredExampleFlow(Flow): - from crewai.flow.flow import start +def test_async_tool_within_flow(custom_tool: BaseTool) -> None: + from crewai.flow.flow import Flow, start + class StructuredExampleFlow(Flow): # type: ignore[type-arg] @start() - async def start(self): + async def start(self) -> Any: crew = build_simple_crew(custom_tool) result = await crew.kickoff_async() return result @@ -235,14 +241,12 @@ def test_async_tool_within_flow(custom_tool): @pytest.mark.vcr(filter_headers=["authorization"]) -def test_async_tool_using_decorator_within_flow(custom_tool_decorator): - from crewai.flow.flow import Flow - - class StructuredExampleFlow(Flow): - from crewai.flow.flow import start +def test_async_tool_using_decorator_within_flow(custom_tool_decorator: Any) -> None: + from crewai.flow.flow import Flow, start + class StructuredExampleFlow(Flow): # type: ignore[type-arg] @start() - async def start(self): + async def start(self) -> Any: crew = build_simple_crew(custom_tool_decorator) result = await crew.kickoff_async() return result @@ -252,7 +256,7 @@ def test_async_tool_using_decorator_within_flow(custom_tool_decorator): assert result.raw == "Hello World from Custom Tool" -def test_invoke_sync_function_single_execution(): +def test_invoke_sync_function_single_execution() -> None: """Test that sync functions are called only once, not twice.""" call_count = 0 @@ -270,7 +274,7 @@ def test_invoke_sync_function_single_execution(): assert result == "Called 1 times with: test" -def test_invoke_async_function_outside_event_loop(): +def test_invoke_async_function_outside_event_loop() -> None: """Test that async functions work correctly when called outside event loop.""" async def async_func(message: str) -> str: @@ -285,7 +289,7 @@ def test_invoke_async_function_outside_event_loop(): @pytest.mark.asyncio -async def test_invoke_async_function_in_event_loop_raises_error(): +async def test_invoke_async_function_in_event_loop_raises_error() -> None: """Test that async functions raise RuntimeError when called from within event loop.""" async def async_func(message: str) -> str: @@ -302,13 +306,13 @@ async def test_invoke_async_function_in_event_loop_raises_error(): tool.invoke({"message": "test"}) -def test_invoke_sync_function_returning_coroutine(): +def test_invoke_sync_function_returning_coroutine() -> None: """Test handling of sync functions that return coroutines.""" async def inner_async(message: str) -> str: return f"Inner async: {message}" - def sync_func_returning_coro(message: str): + def sync_func_returning_coro(message: str) -> Any: return inner_async(message) tool = CrewStructuredTool.from_function( @@ -322,13 +326,15 @@ def test_invoke_sync_function_returning_coroutine(): @pytest.mark.asyncio -async def test_invoke_sync_function_returning_coroutine_in_event_loop_raises_error(): +async def test_invoke_sync_function_returning_coroutine_in_event_loop_raises_error() -> ( + None +): """Test that sync functions returning coroutines raise RuntimeError in event loop.""" async def inner_async(message: str) -> str: return f"Inner async: {message}" - def sync_func_returning_coro(message: str): + def sync_func_returning_coro(message: str) -> Any: return inner_async(message) tool = CrewStructuredTool.from_function(