mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 15:48:29 +00:00
Fix type-checking errors in structured_tool.py and tests
- Add comprehensive type annotations for all function parameters and return types - Fix generic type parameters for dict, Callable, and Flow classes - Add proper type ignore comments for complex type inference scenarios - Resolve all 27 mypy errors across Python 3.10-3.13 - Ensure compatibility with strict type checking requirements Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
@@ -66,7 +66,7 @@ class CrewStructuredTool:
|
||||
@classmethod
|
||||
def from_function(
|
||||
cls,
|
||||
func: Callable,
|
||||
func: Callable[..., Any],
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
return_direct: bool = False,
|
||||
@@ -127,7 +127,7 @@ class CrewStructuredTool:
|
||||
@staticmethod
|
||||
def _create_schema_from_function(
|
||||
name: str,
|
||||
func: Callable,
|
||||
func: Callable[..., Any],
|
||||
) -> type[BaseModel]:
|
||||
"""Create a Pydantic schema from a function's signature.
|
||||
|
||||
@@ -162,7 +162,7 @@ class CrewStructuredTool:
|
||||
|
||||
# Create model
|
||||
schema_name = f"{name.title()}Schema"
|
||||
return create_model(schema_name, **fields)
|
||||
return create_model(schema_name, **fields) # type: ignore[call-overload,no-any-return]
|
||||
|
||||
def _validate_function_signature(self) -> None:
|
||||
"""Validate that the function signature matches the args schema."""
|
||||
@@ -190,7 +190,7 @@ class CrewStructuredTool:
|
||||
f"not found in args_schema"
|
||||
)
|
||||
|
||||
def _parse_args(self, raw_args: str | dict) -> dict:
|
||||
def _parse_args(self, raw_args: str | dict[str, Any]) -> dict[str, Any]:
|
||||
"""Parse and validate the input arguments against the schema.
|
||||
|
||||
Args:
|
||||
@@ -215,8 +215,8 @@ class CrewStructuredTool:
|
||||
|
||||
async def ainvoke(
|
||||
self,
|
||||
input: str | dict,
|
||||
config: Optional[dict] = None,
|
||||
input: str | dict[str, Any],
|
||||
config: Optional[dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Asynchronously invoke the tool.
|
||||
@@ -251,7 +251,7 @@ class CrewStructuredTool:
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
def _run(self, *args, **kwargs) -> Any:
|
||||
def _run(self, *args: Any, **kwargs: Any) -> Any:
|
||||
"""Legacy method for compatibility."""
|
||||
# Convert args/kwargs to our expected format
|
||||
input_dict = dict(zip(self.args_schema.model_fields.keys(), args))
|
||||
@@ -259,7 +259,10 @@ class CrewStructuredTool:
|
||||
return self.invoke(input_dict)
|
||||
|
||||
def invoke(
|
||||
self, input: str | dict, config: Optional[dict] = None, **kwargs: Any
|
||||
self,
|
||||
input: str | dict[str, Any],
|
||||
config: Optional[dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Main method for tool execution."""
|
||||
parsed_args = self._parse_args(input)
|
||||
@@ -319,9 +322,9 @@ class CrewStructuredTool:
|
||||
self._original_tool.current_usage_count = self.current_usage_count
|
||||
|
||||
@property
|
||||
def args(self) -> dict:
|
||||
def args(self) -> dict[str, Any]:
|
||||
"""Get the tool's input arguments schema."""
|
||||
return self.args_schema.model_json_schema()["properties"]
|
||||
return self.args_schema.model_json_schema()["properties"] # type: ignore[no-any-return]
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
from typing import Optional
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Optional
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.tools import BaseTool
|
||||
from crewai.tools.structured_tool import CrewStructuredTool
|
||||
|
||||
|
||||
# Test fixtures
|
||||
@pytest.fixture
|
||||
def basic_function():
|
||||
def basic_function() -> Callable[[str, int], str]:
|
||||
def test_func(param1: str, param2: int = 0) -> str:
|
||||
"""Test function with basic params."""
|
||||
return f"{param1} {param2}"
|
||||
@@ -17,7 +19,7 @@ def basic_function():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def schema_class():
|
||||
def schema_class() -> type[BaseModel]:
|
||||
class TestSchema(BaseModel):
|
||||
param1: str
|
||||
param2: int = Field(default=0)
|
||||
@@ -25,7 +27,9 @@ def schema_class():
|
||||
return TestSchema
|
||||
|
||||
|
||||
def test_initialization(basic_function, schema_class):
|
||||
def test_initialization(
|
||||
basic_function: Callable[[str], str], schema_class: type[BaseModel]
|
||||
) -> None:
|
||||
"""Test basic initialization of CrewStructuredTool"""
|
||||
tool = CrewStructuredTool(
|
||||
name="test_tool",
|
||||
@@ -40,7 +44,7 @@ def test_initialization(basic_function, schema_class):
|
||||
assert tool.args_schema == schema_class
|
||||
|
||||
|
||||
def test_from_function(basic_function):
|
||||
def test_from_function(basic_function: Callable[[str], str]) -> None:
|
||||
"""Test creating tool from function"""
|
||||
tool = CrewStructuredTool.from_function(
|
||||
func=basic_function, name="test_tool", description="Test description"
|
||||
@@ -52,7 +56,9 @@ def test_from_function(basic_function):
|
||||
assert isinstance(tool.args_schema, type(BaseModel))
|
||||
|
||||
|
||||
def test_validate_function_signature(basic_function, schema_class):
|
||||
def test_validate_function_signature(
|
||||
basic_function: Callable[[str, int], str], schema_class: type[BaseModel]
|
||||
) -> None:
|
||||
"""Test function signature validation"""
|
||||
tool = CrewStructuredTool(
|
||||
name="test_tool",
|
||||
@@ -66,7 +72,7 @@ def test_validate_function_signature(basic_function, schema_class):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ainvoke(basic_function):
|
||||
async def test_ainvoke(basic_function: Callable[[str, int], str]) -> None:
|
||||
"""Test asynchronous invocation"""
|
||||
tool = CrewStructuredTool.from_function(func=basic_function, name="test_tool")
|
||||
|
||||
@@ -74,7 +80,7 @@ async def test_ainvoke(basic_function):
|
||||
assert result == "test 0"
|
||||
|
||||
|
||||
def test_parse_args_dict(basic_function):
|
||||
def test_parse_args_dict(basic_function: Callable[[str, int], str]) -> None:
|
||||
"""Test parsing dictionary arguments"""
|
||||
tool = CrewStructuredTool.from_function(func=basic_function, name="test_tool")
|
||||
|
||||
@@ -83,7 +89,7 @@ def test_parse_args_dict(basic_function):
|
||||
assert parsed["param2"] == 42
|
||||
|
||||
|
||||
def test_parse_args_string(basic_function):
|
||||
def test_parse_args_string(basic_function: Callable[[str, int], str]) -> None:
|
||||
"""Test parsing string arguments"""
|
||||
tool = CrewStructuredTool.from_function(func=basic_function, name="test_tool")
|
||||
|
||||
@@ -92,10 +98,10 @@ def test_parse_args_string(basic_function):
|
||||
assert parsed["param2"] == 42
|
||||
|
||||
|
||||
def test_complex_types():
|
||||
def test_complex_types() -> None:
|
||||
"""Test handling of complex parameter types"""
|
||||
|
||||
def complex_func(nested: dict, items: list) -> str:
|
||||
def complex_func(nested: dict[str, Any], items: list[Any]) -> str:
|
||||
"""Process complex types."""
|
||||
return f"Processed {len(items)} items with {len(nested)} nested keys"
|
||||
|
||||
@@ -106,7 +112,7 @@ def test_complex_types():
|
||||
assert result == "Processed 3 items with 1 nested keys"
|
||||
|
||||
|
||||
def test_schema_inheritance():
|
||||
def test_schema_inheritance() -> None:
|
||||
"""Test tool creation with inherited schema"""
|
||||
|
||||
def extended_func(base_param: str, extra_param: int) -> str:
|
||||
@@ -127,7 +133,7 @@ def test_schema_inheritance():
|
||||
assert result == "test 42"
|
||||
|
||||
|
||||
def test_default_values_in_schema():
|
||||
def test_default_values_in_schema() -> None:
|
||||
"""Test handling of default values in schema"""
|
||||
|
||||
def default_func(
|
||||
@@ -154,11 +160,11 @@ def test_default_values_in_schema():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def custom_tool_decorator():
|
||||
def custom_tool_decorator() -> Any:
|
||||
from crewai.tools import tool
|
||||
|
||||
@tool("custom_tool", result_as_answer=True)
|
||||
async def custom_tool():
|
||||
async def custom_tool() -> str:
|
||||
"""This is a tool that does something"""
|
||||
return "Hello World from Custom Tool"
|
||||
|
||||
@@ -166,7 +172,7 @@ def custom_tool_decorator():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def custom_tool():
|
||||
def custom_tool() -> BaseTool:
|
||||
from crewai.tools import BaseTool
|
||||
|
||||
class CustomTool(BaseTool):
|
||||
@@ -174,13 +180,13 @@ def custom_tool():
|
||||
description: str = "This is a tool that does something"
|
||||
result_as_answer: bool = True
|
||||
|
||||
async def _run(self):
|
||||
async def _run(self) -> str:
|
||||
return "Hello World from Custom Tool"
|
||||
|
||||
return CustomTool()
|
||||
|
||||
|
||||
def build_simple_crew(tool):
|
||||
def build_simple_crew(tool: Any) -> Any:
|
||||
from crewai import Agent, Crew, Task
|
||||
|
||||
agent1 = Agent(
|
||||
@@ -201,7 +207,7 @@ def build_simple_crew(tool):
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_async_tool_using_within_isolated_crew(custom_tool):
|
||||
def test_async_tool_using_within_isolated_crew(custom_tool: BaseTool) -> None:
|
||||
crew = build_simple_crew(custom_tool)
|
||||
result = crew.kickoff()
|
||||
|
||||
@@ -209,7 +215,9 @@ def test_async_tool_using_within_isolated_crew(custom_tool):
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_async_tool_using_decorator_within_isolated_crew(custom_tool_decorator):
|
||||
def test_async_tool_using_decorator_within_isolated_crew(
|
||||
custom_tool_decorator: Any,
|
||||
) -> None:
|
||||
crew = build_simple_crew(custom_tool_decorator)
|
||||
result = crew.kickoff()
|
||||
|
||||
@@ -217,14 +225,12 @@ def test_async_tool_using_decorator_within_isolated_crew(custom_tool_decorator):
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_async_tool_within_flow(custom_tool):
|
||||
from crewai.flow.flow import Flow
|
||||
|
||||
class StructuredExampleFlow(Flow):
|
||||
from crewai.flow.flow import start
|
||||
def test_async_tool_within_flow(custom_tool: BaseTool) -> None:
|
||||
from crewai.flow.flow import Flow, start
|
||||
|
||||
class StructuredExampleFlow(Flow): # type: ignore[type-arg]
|
||||
@start()
|
||||
async def start(self):
|
||||
async def start(self) -> Any:
|
||||
crew = build_simple_crew(custom_tool)
|
||||
result = await crew.kickoff_async()
|
||||
return result
|
||||
@@ -235,14 +241,12 @@ def test_async_tool_within_flow(custom_tool):
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_async_tool_using_decorator_within_flow(custom_tool_decorator):
|
||||
from crewai.flow.flow import Flow
|
||||
|
||||
class StructuredExampleFlow(Flow):
|
||||
from crewai.flow.flow import start
|
||||
def test_async_tool_using_decorator_within_flow(custom_tool_decorator: Any) -> None:
|
||||
from crewai.flow.flow import Flow, start
|
||||
|
||||
class StructuredExampleFlow(Flow): # type: ignore[type-arg]
|
||||
@start()
|
||||
async def start(self):
|
||||
async def start(self) -> Any:
|
||||
crew = build_simple_crew(custom_tool_decorator)
|
||||
result = await crew.kickoff_async()
|
||||
return result
|
||||
@@ -252,7 +256,7 @@ def test_async_tool_using_decorator_within_flow(custom_tool_decorator):
|
||||
assert result.raw == "Hello World from Custom Tool"
|
||||
|
||||
|
||||
def test_invoke_sync_function_single_execution():
|
||||
def test_invoke_sync_function_single_execution() -> None:
|
||||
"""Test that sync functions are called only once, not twice."""
|
||||
call_count = 0
|
||||
|
||||
@@ -270,7 +274,7 @@ def test_invoke_sync_function_single_execution():
|
||||
assert result == "Called 1 times with: test"
|
||||
|
||||
|
||||
def test_invoke_async_function_outside_event_loop():
|
||||
def test_invoke_async_function_outside_event_loop() -> None:
|
||||
"""Test that async functions work correctly when called outside event loop."""
|
||||
|
||||
async def async_func(message: str) -> str:
|
||||
@@ -285,7 +289,7 @@ def test_invoke_async_function_outside_event_loop():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_async_function_in_event_loop_raises_error():
|
||||
async def test_invoke_async_function_in_event_loop_raises_error() -> None:
|
||||
"""Test that async functions raise RuntimeError when called from within event loop."""
|
||||
|
||||
async def async_func(message: str) -> str:
|
||||
@@ -302,13 +306,13 @@ async def test_invoke_async_function_in_event_loop_raises_error():
|
||||
tool.invoke({"message": "test"})
|
||||
|
||||
|
||||
def test_invoke_sync_function_returning_coroutine():
|
||||
def test_invoke_sync_function_returning_coroutine() -> None:
|
||||
"""Test handling of sync functions that return coroutines."""
|
||||
|
||||
async def inner_async(message: str) -> str:
|
||||
return f"Inner async: {message}"
|
||||
|
||||
def sync_func_returning_coro(message: str):
|
||||
def sync_func_returning_coro(message: str) -> Any:
|
||||
return inner_async(message)
|
||||
|
||||
tool = CrewStructuredTool.from_function(
|
||||
@@ -322,13 +326,15 @@ def test_invoke_sync_function_returning_coroutine():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_sync_function_returning_coroutine_in_event_loop_raises_error():
|
||||
async def test_invoke_sync_function_returning_coroutine_in_event_loop_raises_error() -> (
|
||||
None
|
||||
):
|
||||
"""Test that sync functions returning coroutines raise RuntimeError in event loop."""
|
||||
|
||||
async def inner_async(message: str) -> str:
|
||||
return f"Inner async: {message}"
|
||||
|
||||
def sync_func_returning_coro(message: str):
|
||||
def sync_func_returning_coro(message: str) -> Any:
|
||||
return inner_async(message)
|
||||
|
||||
tool = CrewStructuredTool.from_function(
|
||||
|
||||
Reference in New Issue
Block a user