mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 15:48:29 +00:00
Fix CrewStructuredTool invoke() method bugs
- Fix RuntimeError from asyncio.run() in nested event loops - Fix double execution of sync functions - Fix inconsistent coroutine handling - Add comprehensive tests for all scenarios - Properly detect event loop context to avoid asyncio.run() conflicts Fixes #3447 Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
@@ -1,17 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
import inspect
|
||||
import textwrap
|
||||
from typing import Any, Callable, Optional, Union, get_type_hints
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union, get_type_hints
|
||||
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
|
||||
from crewai.utilities.logger import Logger
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
|
||||
@@ -192,7 +190,7 @@ class CrewStructuredTool:
|
||||
f"not found in args_schema"
|
||||
)
|
||||
|
||||
def _parse_args(self, raw_args: Union[str, dict]) -> dict:
|
||||
def _parse_args(self, raw_args: str | dict) -> dict:
|
||||
"""Parse and validate the input arguments against the schema.
|
||||
|
||||
Args:
|
||||
@@ -217,7 +215,7 @@ class CrewStructuredTool:
|
||||
|
||||
async def ainvoke(
|
||||
self,
|
||||
input: Union[str, dict],
|
||||
input: str | dict,
|
||||
config: Optional[dict] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
@@ -261,7 +259,7 @@ class CrewStructuredTool:
|
||||
return self.invoke(input_dict)
|
||||
|
||||
def invoke(
|
||||
self, input: Union[str, dict], config: Optional[dict] = None, **kwargs: Any
|
||||
self, input: str | dict, config: Optional[dict] = None, **kwargs: Any
|
||||
) -> Any:
|
||||
"""Main method for tool execution."""
|
||||
parsed_args = self._parse_args(input)
|
||||
@@ -273,22 +271,40 @@ class CrewStructuredTool:
|
||||
|
||||
self._increment_usage_count()
|
||||
|
||||
if inspect.iscoroutinefunction(self.func):
|
||||
result = asyncio.run(self.func(**parsed_args, **kwargs))
|
||||
return result
|
||||
|
||||
try:
|
||||
result = self.func(**parsed_args, **kwargs)
|
||||
if inspect.iscoroutinefunction(self.func):
|
||||
coro = self.func(**parsed_args, **kwargs)
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
raise RuntimeError(
|
||||
f"Cannot call async tool '{self.name}' from synchronous context within an event loop. "
|
||||
f"Use ainvoke() instead or call from outside the event loop."
|
||||
)
|
||||
except RuntimeError as e:
|
||||
if "Cannot call async tool" in str(e):
|
||||
raise
|
||||
else:
|
||||
return asyncio.run(coro)
|
||||
else:
|
||||
result = self.func(**parsed_args, **kwargs)
|
||||
|
||||
if asyncio.iscoroutine(result):
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
raise RuntimeError(
|
||||
f"Sync function '{self.name}' returned a coroutine but we're in an event loop. "
|
||||
f"Use ainvoke() instead or call from outside the event loop."
|
||||
)
|
||||
except RuntimeError as e:
|
||||
if "returned a coroutine but we're in an event loop" in str(e):
|
||||
raise
|
||||
else:
|
||||
return asyncio.run(result)
|
||||
|
||||
return result
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
result = self.func(**parsed_args, **kwargs)
|
||||
|
||||
if asyncio.iscoroutine(result):
|
||||
return asyncio.run(result)
|
||||
|
||||
return result
|
||||
|
||||
def has_reached_max_usage_count(self) -> bool:
|
||||
"""Check if the tool has reached its maximum usage count."""
|
||||
return (
|
||||
|
||||
@@ -39,6 +39,7 @@ def test_initialization(basic_function, schema_class):
|
||||
assert tool.func == basic_function
|
||||
assert tool.args_schema == schema_class
|
||||
|
||||
|
||||
def test_from_function(basic_function):
|
||||
"""Test creating tool from function"""
|
||||
tool = CrewStructuredTool.from_function(
|
||||
@@ -50,6 +51,7 @@ def test_from_function(basic_function):
|
||||
assert tool.func == basic_function
|
||||
assert isinstance(tool.args_schema, type(BaseModel))
|
||||
|
||||
|
||||
def test_validate_function_signature(basic_function, schema_class):
|
||||
"""Test function signature validation"""
|
||||
tool = CrewStructuredTool(
|
||||
@@ -62,6 +64,7 @@ def test_validate_function_signature(basic_function, schema_class):
|
||||
# Should not raise any exceptions
|
||||
tool._validate_function_signature()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ainvoke(basic_function):
|
||||
"""Test asynchronous invocation"""
|
||||
@@ -70,6 +73,7 @@ async def test_ainvoke(basic_function):
|
||||
result = await tool.ainvoke(input={"param1": "test"})
|
||||
assert result == "test 0"
|
||||
|
||||
|
||||
def test_parse_args_dict(basic_function):
|
||||
"""Test parsing dictionary arguments"""
|
||||
tool = CrewStructuredTool.from_function(func=basic_function, name="test_tool")
|
||||
@@ -78,6 +82,7 @@ def test_parse_args_dict(basic_function):
|
||||
assert parsed["param1"] == "test"
|
||||
assert parsed["param2"] == 42
|
||||
|
||||
|
||||
def test_parse_args_string(basic_function):
|
||||
"""Test parsing string arguments"""
|
||||
tool = CrewStructuredTool.from_function(func=basic_function, name="test_tool")
|
||||
@@ -86,6 +91,7 @@ def test_parse_args_string(basic_function):
|
||||
assert parsed["param1"] == "test"
|
||||
assert parsed["param2"] == 42
|
||||
|
||||
|
||||
def test_complex_types():
|
||||
"""Test handling of complex parameter types"""
|
||||
|
||||
@@ -99,6 +105,7 @@ def test_complex_types():
|
||||
result = tool.invoke({"nested": {"key": "value"}, "items": [1, 2, 3]})
|
||||
assert result == "Processed 3 items with 1 nested keys"
|
||||
|
||||
|
||||
def test_schema_inheritance():
|
||||
"""Test tool creation with inherited schema"""
|
||||
|
||||
@@ -119,6 +126,7 @@ def test_schema_inheritance():
|
||||
result = tool.invoke({"base_param": "test", "extra_param": 42})
|
||||
assert result == "test 42"
|
||||
|
||||
|
||||
def test_default_values_in_schema():
|
||||
"""Test handling of default values in schema"""
|
||||
|
||||
@@ -144,6 +152,7 @@ def test_default_values_in_schema():
|
||||
)
|
||||
assert result == "test custom 42"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def custom_tool_decorator():
|
||||
from crewai.tools import tool
|
||||
@@ -155,6 +164,7 @@ def custom_tool_decorator():
|
||||
|
||||
return custom_tool
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def custom_tool():
|
||||
from crewai.tools import BaseTool
|
||||
@@ -169,18 +179,27 @@ def custom_tool():
|
||||
|
||||
return CustomTool()
|
||||
|
||||
def build_simple_crew(tool):
|
||||
from crewai import Agent, Task, Crew
|
||||
|
||||
agent1 = Agent(role="Simple role", goal="Simple goal", backstory="Simple backstory", tools=[tool])
|
||||
def build_simple_crew(tool):
|
||||
from crewai import Agent, Crew, Task
|
||||
|
||||
agent1 = Agent(
|
||||
role="Simple role",
|
||||
goal="Simple goal",
|
||||
backstory="Simple backstory",
|
||||
tools=[tool],
|
||||
)
|
||||
|
||||
say_hi_task = Task(
|
||||
description="Use the custom tool result as answer.", agent=agent1, expected_output="Use the tool result"
|
||||
description="Use the custom tool result as answer.",
|
||||
agent=agent1,
|
||||
expected_output="Use the tool result",
|
||||
)
|
||||
|
||||
crew = Crew(agents=[agent1], tasks=[say_hi_task])
|
||||
return crew
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_async_tool_using_within_isolated_crew(custom_tool):
|
||||
crew = build_simple_crew(custom_tool)
|
||||
@@ -188,6 +207,7 @@ def test_async_tool_using_within_isolated_crew(custom_tool):
|
||||
|
||||
assert result.raw == "Hello World from Custom Tool"
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_async_tool_using_decorator_within_isolated_crew(custom_tool_decorator):
|
||||
crew = build_simple_crew(custom_tool_decorator)
|
||||
@@ -195,6 +215,7 @@ def test_async_tool_using_decorator_within_isolated_crew(custom_tool_decorator):
|
||||
|
||||
assert result.raw == "Hello World from Custom Tool"
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_async_tool_within_flow(custom_tool):
|
||||
from crewai.flow.flow import Flow
|
||||
@@ -219,6 +240,7 @@ def test_async_tool_using_decorator_within_flow(custom_tool_decorator):
|
||||
|
||||
class StructuredExampleFlow(Flow):
|
||||
from crewai.flow.flow import start
|
||||
|
||||
@start()
|
||||
async def start(self):
|
||||
crew = build_simple_crew(custom_tool_decorator)
|
||||
@@ -227,4 +249,96 @@ def test_async_tool_using_decorator_within_flow(custom_tool_decorator):
|
||||
|
||||
flow = StructuredExampleFlow()
|
||||
result = flow.kickoff()
|
||||
assert result.raw == "Hello World from Custom Tool"
|
||||
assert result.raw == "Hello World from Custom Tool"
|
||||
|
||||
|
||||
def test_invoke_sync_function_single_execution():
|
||||
"""Test that sync functions are called only once, not twice."""
|
||||
call_count = 0
|
||||
|
||||
def counting_func(message: str) -> str:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return f"Called {call_count} times with: {message}"
|
||||
|
||||
tool = CrewStructuredTool.from_function(
|
||||
func=counting_func, name="counting_tool", description="A tool that counts calls"
|
||||
)
|
||||
|
||||
result = tool.invoke({"message": "test"})
|
||||
assert call_count == 1, f"Function was called {call_count} times, expected 1"
|
||||
assert result == "Called 1 times with: test"
|
||||
|
||||
|
||||
def test_invoke_async_function_outside_event_loop():
|
||||
"""Test that async functions work correctly when called outside event loop."""
|
||||
|
||||
async def async_func(message: str) -> str:
|
||||
return f"Async result: {message}"
|
||||
|
||||
tool = CrewStructuredTool.from_function(
|
||||
func=async_func, name="async_tool", description="An async tool"
|
||||
)
|
||||
|
||||
result = tool.invoke({"message": "test"})
|
||||
assert result == "Async result: test"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_async_function_in_event_loop_raises_error():
|
||||
"""Test that async functions raise RuntimeError when called from within event loop."""
|
||||
|
||||
async def async_func(message: str) -> str:
|
||||
return f"Async result: {message}"
|
||||
|
||||
tool = CrewStructuredTool.from_function(
|
||||
func=async_func, name="async_tool", description="An async tool"
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
RuntimeError,
|
||||
match="Cannot call async tool.*from synchronous context within an event loop",
|
||||
):
|
||||
tool.invoke({"message": "test"})
|
||||
|
||||
|
||||
def test_invoke_sync_function_returning_coroutine():
|
||||
"""Test handling of sync functions that return coroutines."""
|
||||
|
||||
async def inner_async(message: str) -> str:
|
||||
return f"Inner async: {message}"
|
||||
|
||||
def sync_func_returning_coro(message: str):
|
||||
return inner_async(message)
|
||||
|
||||
tool = CrewStructuredTool.from_function(
|
||||
func=sync_func_returning_coro,
|
||||
name="sync_coro_tool",
|
||||
description="A sync tool that returns coroutine",
|
||||
)
|
||||
|
||||
result = tool.invoke({"message": "test"})
|
||||
assert result == "Inner async: test"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_sync_function_returning_coroutine_in_event_loop_raises_error():
|
||||
"""Test that sync functions returning coroutines raise RuntimeError in event loop."""
|
||||
|
||||
async def inner_async(message: str) -> str:
|
||||
return f"Inner async: {message}"
|
||||
|
||||
def sync_func_returning_coro(message: str):
|
||||
return inner_async(message)
|
||||
|
||||
tool = CrewStructuredTool.from_function(
|
||||
func=sync_func_returning_coro,
|
||||
name="sync_coro_tool",
|
||||
description="A sync tool that returns coroutine",
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
RuntimeError,
|
||||
match="Sync function.*returned a coroutine but we're in an event loop",
|
||||
):
|
||||
tool.invoke({"message": "test"})
|
||||
|
||||
Reference in New Issue
Block a user