Fix type-checking errors in structured_tool.py and tests

- Add comprehensive type annotations for all function parameters and return types
- Fix generic type parameters for dict, Callable, and Flow classes
- Add proper type ignore comments for complex type inference scenarios
- Resolve all 27 mypy errors across Python 3.10-3.13
- Ensure compatibility with strict type checking requirements

Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
Devin AI
2025-09-04 08:58:43 +00:00
parent 89fcd2a5b4
commit 75b7c579f6
2 changed files with 58 additions and 49 deletions

View File

@@ -66,7 +66,7 @@ class CrewStructuredTool:
@classmethod @classmethod
def from_function( def from_function(
cls, cls,
func: Callable, func: Callable[..., Any],
name: Optional[str] = None, name: Optional[str] = None,
description: Optional[str] = None, description: Optional[str] = None,
return_direct: bool = False, return_direct: bool = False,
@@ -127,7 +127,7 @@ class CrewStructuredTool:
@staticmethod @staticmethod
def _create_schema_from_function( def _create_schema_from_function(
name: str, name: str,
func: Callable, func: Callable[..., Any],
) -> type[BaseModel]: ) -> type[BaseModel]:
"""Create a Pydantic schema from a function's signature. """Create a Pydantic schema from a function's signature.
@@ -162,7 +162,7 @@ class CrewStructuredTool:
# Create model # Create model
schema_name = f"{name.title()}Schema" schema_name = f"{name.title()}Schema"
return create_model(schema_name, **fields) return create_model(schema_name, **fields) # type: ignore[call-overload,no-any-return]
def _validate_function_signature(self) -> None: def _validate_function_signature(self) -> None:
"""Validate that the function signature matches the args schema.""" """Validate that the function signature matches the args schema."""
@@ -190,7 +190,7 @@ class CrewStructuredTool:
f"not found in args_schema" f"not found in args_schema"
) )
def _parse_args(self, raw_args: str | dict) -> dict: def _parse_args(self, raw_args: str | dict[str, Any]) -> dict[str, Any]:
"""Parse and validate the input arguments against the schema. """Parse and validate the input arguments against the schema.
Args: Args:
@@ -215,8 +215,8 @@ class CrewStructuredTool:
async def ainvoke( async def ainvoke(
self, self,
input: str | dict, input: str | dict[str, Any],
config: Optional[dict] = None, config: Optional[dict[str, Any]] = None,
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
"""Asynchronously invoke the tool. """Asynchronously invoke the tool.
@@ -251,7 +251,7 @@ class CrewStructuredTool:
except Exception: except Exception:
raise raise
def _run(self, *args, **kwargs) -> Any: def _run(self, *args: Any, **kwargs: Any) -> Any:
"""Legacy method for compatibility.""" """Legacy method for compatibility."""
# Convert args/kwargs to our expected format # Convert args/kwargs to our expected format
input_dict = dict(zip(self.args_schema.model_fields.keys(), args)) input_dict = dict(zip(self.args_schema.model_fields.keys(), args))
@@ -259,7 +259,10 @@ class CrewStructuredTool:
return self.invoke(input_dict) return self.invoke(input_dict)
def invoke( def invoke(
self, input: str | dict, config: Optional[dict] = None, **kwargs: Any self,
input: str | dict[str, Any],
config: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> Any: ) -> Any:
"""Main method for tool execution.""" """Main method for tool execution."""
parsed_args = self._parse_args(input) parsed_args = self._parse_args(input)
@@ -319,9 +322,9 @@ class CrewStructuredTool:
self._original_tool.current_usage_count = self.current_usage_count self._original_tool.current_usage_count = self.current_usage_count
@property @property
def args(self) -> dict: def args(self) -> dict[str, Any]:
"""Get the tool's input arguments schema.""" """Get the tool's input arguments schema."""
return self.args_schema.model_json_schema()["properties"] return self.args_schema.model_json_schema()["properties"] # type: ignore[no-any-return]
def __repr__(self) -> str: def __repr__(self) -> str:
return ( return (

View File

@@ -1,14 +1,16 @@
from typing import Optional from collections.abc import Callable
from typing import Any, Optional
import pytest import pytest
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from crewai.tools import BaseTool
from crewai.tools.structured_tool import CrewStructuredTool from crewai.tools.structured_tool import CrewStructuredTool
# Test fixtures # Test fixtures
@pytest.fixture @pytest.fixture
def basic_function(): def basic_function() -> Callable[[str, int], str]:
def test_func(param1: str, param2: int = 0) -> str: def test_func(param1: str, param2: int = 0) -> str:
"""Test function with basic params.""" """Test function with basic params."""
return f"{param1} {param2}" return f"{param1} {param2}"
@@ -17,7 +19,7 @@ def basic_function():
@pytest.fixture @pytest.fixture
def schema_class(): def schema_class() -> type[BaseModel]:
class TestSchema(BaseModel): class TestSchema(BaseModel):
param1: str param1: str
param2: int = Field(default=0) param2: int = Field(default=0)
@@ -25,7 +27,9 @@ def schema_class():
return TestSchema return TestSchema
def test_initialization(basic_function, schema_class): def test_initialization(
basic_function: Callable[[str], str], schema_class: type[BaseModel]
) -> None:
"""Test basic initialization of CrewStructuredTool""" """Test basic initialization of CrewStructuredTool"""
tool = CrewStructuredTool( tool = CrewStructuredTool(
name="test_tool", name="test_tool",
@@ -40,7 +44,7 @@ def test_initialization(basic_function, schema_class):
assert tool.args_schema == schema_class assert tool.args_schema == schema_class
def test_from_function(basic_function): def test_from_function(basic_function: Callable[[str], str]) -> None:
"""Test creating tool from function""" """Test creating tool from function"""
tool = CrewStructuredTool.from_function( tool = CrewStructuredTool.from_function(
func=basic_function, name="test_tool", description="Test description" func=basic_function, name="test_tool", description="Test description"
@@ -52,7 +56,9 @@ def test_from_function(basic_function):
assert isinstance(tool.args_schema, type(BaseModel)) assert isinstance(tool.args_schema, type(BaseModel))
def test_validate_function_signature(basic_function, schema_class): def test_validate_function_signature(
basic_function: Callable[[str, int], str], schema_class: type[BaseModel]
) -> None:
"""Test function signature validation""" """Test function signature validation"""
tool = CrewStructuredTool( tool = CrewStructuredTool(
name="test_tool", name="test_tool",
@@ -66,7 +72,7 @@ def test_validate_function_signature(basic_function, schema_class):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_ainvoke(basic_function): async def test_ainvoke(basic_function: Callable[[str, int], str]) -> None:
"""Test asynchronous invocation""" """Test asynchronous invocation"""
tool = CrewStructuredTool.from_function(func=basic_function, name="test_tool") tool = CrewStructuredTool.from_function(func=basic_function, name="test_tool")
@@ -74,7 +80,7 @@ async def test_ainvoke(basic_function):
assert result == "test 0" assert result == "test 0"
def test_parse_args_dict(basic_function): def test_parse_args_dict(basic_function: Callable[[str, int], str]) -> None:
"""Test parsing dictionary arguments""" """Test parsing dictionary arguments"""
tool = CrewStructuredTool.from_function(func=basic_function, name="test_tool") tool = CrewStructuredTool.from_function(func=basic_function, name="test_tool")
@@ -83,7 +89,7 @@ def test_parse_args_dict(basic_function):
assert parsed["param2"] == 42 assert parsed["param2"] == 42
def test_parse_args_string(basic_function): def test_parse_args_string(basic_function: Callable[[str, int], str]) -> None:
"""Test parsing string arguments""" """Test parsing string arguments"""
tool = CrewStructuredTool.from_function(func=basic_function, name="test_tool") tool = CrewStructuredTool.from_function(func=basic_function, name="test_tool")
@@ -92,10 +98,10 @@ def test_parse_args_string(basic_function):
assert parsed["param2"] == 42 assert parsed["param2"] == 42
def test_complex_types(): def test_complex_types() -> None:
"""Test handling of complex parameter types""" """Test handling of complex parameter types"""
def complex_func(nested: dict, items: list) -> str: def complex_func(nested: dict[str, Any], items: list[Any]) -> str:
"""Process complex types.""" """Process complex types."""
return f"Processed {len(items)} items with {len(nested)} nested keys" return f"Processed {len(items)} items with {len(nested)} nested keys"
@@ -106,7 +112,7 @@ def test_complex_types():
assert result == "Processed 3 items with 1 nested keys" assert result == "Processed 3 items with 1 nested keys"
def test_schema_inheritance(): def test_schema_inheritance() -> None:
"""Test tool creation with inherited schema""" """Test tool creation with inherited schema"""
def extended_func(base_param: str, extra_param: int) -> str: def extended_func(base_param: str, extra_param: int) -> str:
@@ -127,7 +133,7 @@ def test_schema_inheritance():
assert result == "test 42" assert result == "test 42"
def test_default_values_in_schema(): def test_default_values_in_schema() -> None:
"""Test handling of default values in schema""" """Test handling of default values in schema"""
def default_func( def default_func(
@@ -154,11 +160,11 @@ def test_default_values_in_schema():
@pytest.fixture @pytest.fixture
def custom_tool_decorator(): def custom_tool_decorator() -> Any:
from crewai.tools import tool from crewai.tools import tool
@tool("custom_tool", result_as_answer=True) @tool("custom_tool", result_as_answer=True)
async def custom_tool(): async def custom_tool() -> str:
"""This is a tool that does something""" """This is a tool that does something"""
return "Hello World from Custom Tool" return "Hello World from Custom Tool"
@@ -166,7 +172,7 @@ def custom_tool_decorator():
@pytest.fixture @pytest.fixture
def custom_tool(): def custom_tool() -> BaseTool:
from crewai.tools import BaseTool from crewai.tools import BaseTool
class CustomTool(BaseTool): class CustomTool(BaseTool):
@@ -174,13 +180,13 @@ def custom_tool():
description: str = "This is a tool that does something" description: str = "This is a tool that does something"
result_as_answer: bool = True result_as_answer: bool = True
async def _run(self): async def _run(self) -> str:
return "Hello World from Custom Tool" return "Hello World from Custom Tool"
return CustomTool() return CustomTool()
def build_simple_crew(tool): def build_simple_crew(tool: Any) -> Any:
from crewai import Agent, Crew, Task from crewai import Agent, Crew, Task
agent1 = Agent( agent1 = Agent(
@@ -201,7 +207,7 @@ def build_simple_crew(tool):
@pytest.mark.vcr(filter_headers=["authorization"]) @pytest.mark.vcr(filter_headers=["authorization"])
def test_async_tool_using_within_isolated_crew(custom_tool): def test_async_tool_using_within_isolated_crew(custom_tool: BaseTool) -> None:
crew = build_simple_crew(custom_tool) crew = build_simple_crew(custom_tool)
result = crew.kickoff() result = crew.kickoff()
@@ -209,7 +215,9 @@ def test_async_tool_using_within_isolated_crew(custom_tool):
@pytest.mark.vcr(filter_headers=["authorization"]) @pytest.mark.vcr(filter_headers=["authorization"])
def test_async_tool_using_decorator_within_isolated_crew(custom_tool_decorator): def test_async_tool_using_decorator_within_isolated_crew(
custom_tool_decorator: Any,
) -> None:
crew = build_simple_crew(custom_tool_decorator) crew = build_simple_crew(custom_tool_decorator)
result = crew.kickoff() result = crew.kickoff()
@@ -217,14 +225,12 @@ def test_async_tool_using_decorator_within_isolated_crew(custom_tool_decorator):
@pytest.mark.vcr(filter_headers=["authorization"]) @pytest.mark.vcr(filter_headers=["authorization"])
def test_async_tool_within_flow(custom_tool): def test_async_tool_within_flow(custom_tool: BaseTool) -> None:
from crewai.flow.flow import Flow from crewai.flow.flow import Flow, start
class StructuredExampleFlow(Flow):
from crewai.flow.flow import start
class StructuredExampleFlow(Flow): # type: ignore[type-arg]
@start() @start()
async def start(self): async def start(self) -> Any:
crew = build_simple_crew(custom_tool) crew = build_simple_crew(custom_tool)
result = await crew.kickoff_async() result = await crew.kickoff_async()
return result return result
@@ -235,14 +241,12 @@ def test_async_tool_within_flow(custom_tool):
@pytest.mark.vcr(filter_headers=["authorization"]) @pytest.mark.vcr(filter_headers=["authorization"])
def test_async_tool_using_decorator_within_flow(custom_tool_decorator): def test_async_tool_using_decorator_within_flow(custom_tool_decorator: Any) -> None:
from crewai.flow.flow import Flow from crewai.flow.flow import Flow, start
class StructuredExampleFlow(Flow):
from crewai.flow.flow import start
class StructuredExampleFlow(Flow): # type: ignore[type-arg]
@start() @start()
async def start(self): async def start(self) -> Any:
crew = build_simple_crew(custom_tool_decorator) crew = build_simple_crew(custom_tool_decorator)
result = await crew.kickoff_async() result = await crew.kickoff_async()
return result return result
@@ -252,7 +256,7 @@ def test_async_tool_using_decorator_within_flow(custom_tool_decorator):
assert result.raw == "Hello World from Custom Tool" assert result.raw == "Hello World from Custom Tool"
def test_invoke_sync_function_single_execution(): def test_invoke_sync_function_single_execution() -> None:
"""Test that sync functions are called only once, not twice.""" """Test that sync functions are called only once, not twice."""
call_count = 0 call_count = 0
@@ -270,7 +274,7 @@ def test_invoke_sync_function_single_execution():
assert result == "Called 1 times with: test" assert result == "Called 1 times with: test"
def test_invoke_async_function_outside_event_loop(): def test_invoke_async_function_outside_event_loop() -> None:
"""Test that async functions work correctly when called outside event loop.""" """Test that async functions work correctly when called outside event loop."""
async def async_func(message: str) -> str: async def async_func(message: str) -> str:
@@ -285,7 +289,7 @@ def test_invoke_async_function_outside_event_loop():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_invoke_async_function_in_event_loop_raises_error(): async def test_invoke_async_function_in_event_loop_raises_error() -> None:
"""Test that async functions raise RuntimeError when called from within event loop.""" """Test that async functions raise RuntimeError when called from within event loop."""
async def async_func(message: str) -> str: async def async_func(message: str) -> str:
@@ -302,13 +306,13 @@ async def test_invoke_async_function_in_event_loop_raises_error():
tool.invoke({"message": "test"}) tool.invoke({"message": "test"})
def test_invoke_sync_function_returning_coroutine(): def test_invoke_sync_function_returning_coroutine() -> None:
"""Test handling of sync functions that return coroutines.""" """Test handling of sync functions that return coroutines."""
async def inner_async(message: str) -> str: async def inner_async(message: str) -> str:
return f"Inner async: {message}" return f"Inner async: {message}"
def sync_func_returning_coro(message: str): def sync_func_returning_coro(message: str) -> Any:
return inner_async(message) return inner_async(message)
tool = CrewStructuredTool.from_function( tool = CrewStructuredTool.from_function(
@@ -322,13 +326,15 @@ def test_invoke_sync_function_returning_coroutine():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_invoke_sync_function_returning_coroutine_in_event_loop_raises_error(): async def test_invoke_sync_function_returning_coroutine_in_event_loop_raises_error() -> (
None
):
"""Test that sync functions returning coroutines raise RuntimeError in event loop.""" """Test that sync functions returning coroutines raise RuntimeError in event loop."""
async def inner_async(message: str) -> str: async def inner_async(message: str) -> str:
return f"Inner async: {message}" return f"Inner async: {message}"
def sync_func_returning_coro(message: str): def sync_func_returning_coro(message: str) -> Any:
return inner_async(message) return inner_async(message)
tool = CrewStructuredTool.from_function( tool = CrewStructuredTool.from_function(