Fix CrewStructuredTool invoke() method bugs

- Fix RuntimeError from asyncio.run() in nested event loops
- Fix double execution of sync functions
- Fix inconsistent coroutine handling
- Add comprehensive tests for all scenarios
- Properly detect event loop context to avoid asyncio.run() conflicts

Fixes #3447

Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
Devin AI
2025-09-04 08:43:43 +00:00
parent f0def350a4
commit 89fcd2a5b4
2 changed files with 154 additions and 24 deletions

View File

@@ -1,17 +1,15 @@
from __future__ import annotations
import asyncio
import inspect
import textwrap
from typing import Any, Callable, Optional, Union, get_type_hints
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, Optional, Union, get_type_hints
from pydantic import BaseModel, Field, create_model
from crewai.utilities.logger import Logger
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from crewai.tools.base_tool import BaseTool
@@ -192,7 +190,7 @@ class CrewStructuredTool:
f"not found in args_schema"
)
def _parse_args(self, raw_args: Union[str, dict]) -> dict:
def _parse_args(self, raw_args: str | dict) -> dict:
"""Parse and validate the input arguments against the schema.
Args:
@@ -217,7 +215,7 @@ class CrewStructuredTool:
async def ainvoke(
self,
input: Union[str, dict],
input: str | dict,
config: Optional[dict] = None,
**kwargs: Any,
) -> Any:
@@ -261,7 +259,7 @@ class CrewStructuredTool:
return self.invoke(input_dict)
def invoke(
self, input: Union[str, dict], config: Optional[dict] = None, **kwargs: Any
self, input: str | dict, config: Optional[dict] = None, **kwargs: Any
) -> Any:
"""Main method for tool execution."""
parsed_args = self._parse_args(input)
@@ -273,22 +271,40 @@ class CrewStructuredTool:
self._increment_usage_count()
if inspect.iscoroutinefunction(self.func):
result = asyncio.run(self.func(**parsed_args, **kwargs))
return result
try:
result = self.func(**parsed_args, **kwargs)
if inspect.iscoroutinefunction(self.func):
coro = self.func(**parsed_args, **kwargs)
try:
asyncio.get_running_loop()
raise RuntimeError(
f"Cannot call async tool '{self.name}' from synchronous context within an event loop. "
f"Use ainvoke() instead or call from outside the event loop."
)
except RuntimeError as e:
if "Cannot call async tool" in str(e):
raise
else:
return asyncio.run(coro)
else:
result = self.func(**parsed_args, **kwargs)
if asyncio.iscoroutine(result):
try:
asyncio.get_running_loop()
raise RuntimeError(
f"Sync function '{self.name}' returned a coroutine but we're in an event loop. "
f"Use ainvoke() instead or call from outside the event loop."
)
except RuntimeError as e:
if "returned a coroutine but we're in an event loop" in str(e):
raise
else:
return asyncio.run(result)
return result
except Exception:
raise
result = self.func(**parsed_args, **kwargs)
if asyncio.iscoroutine(result):
return asyncio.run(result)
return result
def has_reached_max_usage_count(self) -> bool:
"""Check if the tool has reached its maximum usage count."""
return (

View File

@@ -39,6 +39,7 @@ def test_initialization(basic_function, schema_class):
assert tool.func == basic_function
assert tool.args_schema == schema_class
def test_from_function(basic_function):
"""Test creating tool from function"""
tool = CrewStructuredTool.from_function(
@@ -50,6 +51,7 @@ def test_from_function(basic_function):
assert tool.func == basic_function
assert isinstance(tool.args_schema, type(BaseModel))
def test_validate_function_signature(basic_function, schema_class):
"""Test function signature validation"""
tool = CrewStructuredTool(
@@ -62,6 +64,7 @@ def test_validate_function_signature(basic_function, schema_class):
# Should not raise any exceptions
tool._validate_function_signature()
@pytest.mark.asyncio
async def test_ainvoke(basic_function):
"""Test asynchronous invocation"""
@@ -70,6 +73,7 @@ async def test_ainvoke(basic_function):
result = await tool.ainvoke(input={"param1": "test"})
assert result == "test 0"
def test_parse_args_dict(basic_function):
"""Test parsing dictionary arguments"""
tool = CrewStructuredTool.from_function(func=basic_function, name="test_tool")
@@ -78,6 +82,7 @@ def test_parse_args_dict(basic_function):
assert parsed["param1"] == "test"
assert parsed["param2"] == 42
def test_parse_args_string(basic_function):
"""Test parsing string arguments"""
tool = CrewStructuredTool.from_function(func=basic_function, name="test_tool")
@@ -86,6 +91,7 @@ def test_parse_args_string(basic_function):
assert parsed["param1"] == "test"
assert parsed["param2"] == 42
def test_complex_types():
"""Test handling of complex parameter types"""
@@ -99,6 +105,7 @@ def test_complex_types():
result = tool.invoke({"nested": {"key": "value"}, "items": [1, 2, 3]})
assert result == "Processed 3 items with 1 nested keys"
def test_schema_inheritance():
"""Test tool creation with inherited schema"""
@@ -119,6 +126,7 @@ def test_schema_inheritance():
result = tool.invoke({"base_param": "test", "extra_param": 42})
assert result == "test 42"
def test_default_values_in_schema():
"""Test handling of default values in schema"""
@@ -144,6 +152,7 @@ def test_default_values_in_schema():
)
assert result == "test custom 42"
@pytest.fixture
def custom_tool_decorator():
from crewai.tools import tool
@@ -155,6 +164,7 @@ def custom_tool_decorator():
return custom_tool
@pytest.fixture
def custom_tool():
from crewai.tools import BaseTool
@@ -169,18 +179,27 @@ def custom_tool():
return CustomTool()
def build_simple_crew(tool):
from crewai import Agent, Task, Crew
agent1 = Agent(role="Simple role", goal="Simple goal", backstory="Simple backstory", tools=[tool])
def build_simple_crew(tool):
from crewai import Agent, Crew, Task
agent1 = Agent(
role="Simple role",
goal="Simple goal",
backstory="Simple backstory",
tools=[tool],
)
say_hi_task = Task(
description="Use the custom tool result as answer.", agent=agent1, expected_output="Use the tool result"
description="Use the custom tool result as answer.",
agent=agent1,
expected_output="Use the tool result",
)
crew = Crew(agents=[agent1], tasks=[say_hi_task])
return crew
@pytest.mark.vcr(filter_headers=["authorization"])
def test_async_tool_using_within_isolated_crew(custom_tool):
crew = build_simple_crew(custom_tool)
@@ -188,6 +207,7 @@ def test_async_tool_using_within_isolated_crew(custom_tool):
assert result.raw == "Hello World from Custom Tool"
@pytest.mark.vcr(filter_headers=["authorization"])
def test_async_tool_using_decorator_within_isolated_crew(custom_tool_decorator):
crew = build_simple_crew(custom_tool_decorator)
@@ -195,6 +215,7 @@ def test_async_tool_using_decorator_within_isolated_crew(custom_tool_decorator):
assert result.raw == "Hello World from Custom Tool"
@pytest.mark.vcr(filter_headers=["authorization"])
def test_async_tool_within_flow(custom_tool):
from crewai.flow.flow import Flow
@@ -219,6 +240,7 @@ def test_async_tool_using_decorator_within_flow(custom_tool_decorator):
class StructuredExampleFlow(Flow):
from crewai.flow.flow import start
@start()
async def start(self):
crew = build_simple_crew(custom_tool_decorator)
@@ -227,4 +249,96 @@ def test_async_tool_using_decorator_within_flow(custom_tool_decorator):
flow = StructuredExampleFlow()
result = flow.kickoff()
assert result.raw == "Hello World from Custom Tool"
assert result.raw == "Hello World from Custom Tool"
def test_invoke_sync_function_single_execution():
"""Test that sync functions are called only once, not twice."""
call_count = 0
def counting_func(message: str) -> str:
nonlocal call_count
call_count += 1
return f"Called {call_count} times with: {message}"
tool = CrewStructuredTool.from_function(
func=counting_func, name="counting_tool", description="A tool that counts calls"
)
result = tool.invoke({"message": "test"})
assert call_count == 1, f"Function was called {call_count} times, expected 1"
assert result == "Called 1 times with: test"
def test_invoke_async_function_outside_event_loop():
"""Test that async functions work correctly when called outside event loop."""
async def async_func(message: str) -> str:
return f"Async result: {message}"
tool = CrewStructuredTool.from_function(
func=async_func, name="async_tool", description="An async tool"
)
result = tool.invoke({"message": "test"})
assert result == "Async result: test"
@pytest.mark.asyncio
async def test_invoke_async_function_in_event_loop_raises_error():
"""Test that async functions raise RuntimeError when called from within event loop."""
async def async_func(message: str) -> str:
return f"Async result: {message}"
tool = CrewStructuredTool.from_function(
func=async_func, name="async_tool", description="An async tool"
)
with pytest.raises(
RuntimeError,
match="Cannot call async tool.*from synchronous context within an event loop",
):
tool.invoke({"message": "test"})
def test_invoke_sync_function_returning_coroutine():
"""Test handling of sync functions that return coroutines."""
async def inner_async(message: str) -> str:
return f"Inner async: {message}"
def sync_func_returning_coro(message: str):
return inner_async(message)
tool = CrewStructuredTool.from_function(
func=sync_func_returning_coro,
name="sync_coro_tool",
description="A sync tool that returns coroutine",
)
result = tool.invoke({"message": "test"})
assert result == "Inner async: test"
@pytest.mark.asyncio
async def test_invoke_sync_function_returning_coroutine_in_event_loop_raises_error():
"""Test that sync functions returning coroutines raise RuntimeError in event loop."""
async def inner_async(message: str) -> str:
return f"Inner async: {message}"
def sync_func_returning_coro(message: str):
return inner_async(message)
tool = CrewStructuredTool.from_function(
func=sync_func_returning_coro,
name="sync_coro_tool",
description="A sync tool that returns coroutine",
)
with pytest.raises(
RuntimeError,
match="Sync function.*returned a coroutine but we're in an event loop",
):
tool.invoke({"message": "test"})