mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-11 09:08:31 +00:00
Fix CrewStructuredTool invoke() method bugs
- Fix RuntimeError from asyncio.run() in nested event loops - Fix double execution of sync functions - Fix inconsistent coroutine handling - Add comprehensive tests for all scenarios - Properly detect event loop context to avoid asyncio.run() conflicts Fixes #3447 Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
@@ -39,6 +39,7 @@ def test_initialization(basic_function, schema_class):
|
||||
assert tool.func == basic_function
|
||||
assert tool.args_schema == schema_class
|
||||
|
||||
|
||||
def test_from_function(basic_function):
|
||||
"""Test creating tool from function"""
|
||||
tool = CrewStructuredTool.from_function(
|
||||
@@ -50,6 +51,7 @@ def test_from_function(basic_function):
|
||||
assert tool.func == basic_function
|
||||
assert isinstance(tool.args_schema, type(BaseModel))
|
||||
|
||||
|
||||
def test_validate_function_signature(basic_function, schema_class):
|
||||
"""Test function signature validation"""
|
||||
tool = CrewStructuredTool(
|
||||
@@ -62,6 +64,7 @@ def test_validate_function_signature(basic_function, schema_class):
|
||||
# Should not raise any exceptions
|
||||
tool._validate_function_signature()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ainvoke(basic_function):
|
||||
"""Test asynchronous invocation"""
|
||||
@@ -70,6 +73,7 @@ async def test_ainvoke(basic_function):
|
||||
result = await tool.ainvoke(input={"param1": "test"})
|
||||
assert result == "test 0"
|
||||
|
||||
|
||||
def test_parse_args_dict(basic_function):
|
||||
"""Test parsing dictionary arguments"""
|
||||
tool = CrewStructuredTool.from_function(func=basic_function, name="test_tool")
|
||||
@@ -78,6 +82,7 @@ def test_parse_args_dict(basic_function):
|
||||
assert parsed["param1"] == "test"
|
||||
assert parsed["param2"] == 42
|
||||
|
||||
|
||||
def test_parse_args_string(basic_function):
|
||||
"""Test parsing string arguments"""
|
||||
tool = CrewStructuredTool.from_function(func=basic_function, name="test_tool")
|
||||
@@ -86,6 +91,7 @@ def test_parse_args_string(basic_function):
|
||||
assert parsed["param1"] == "test"
|
||||
assert parsed["param2"] == 42
|
||||
|
||||
|
||||
def test_complex_types():
|
||||
"""Test handling of complex parameter types"""
|
||||
|
||||
@@ -99,6 +105,7 @@ def test_complex_types():
|
||||
result = tool.invoke({"nested": {"key": "value"}, "items": [1, 2, 3]})
|
||||
assert result == "Processed 3 items with 1 nested keys"
|
||||
|
||||
|
||||
def test_schema_inheritance():
|
||||
"""Test tool creation with inherited schema"""
|
||||
|
||||
@@ -119,6 +126,7 @@ def test_schema_inheritance():
|
||||
result = tool.invoke({"base_param": "test", "extra_param": 42})
|
||||
assert result == "test 42"
|
||||
|
||||
|
||||
def test_default_values_in_schema():
|
||||
"""Test handling of default values in schema"""
|
||||
|
||||
@@ -144,6 +152,7 @@ def test_default_values_in_schema():
|
||||
)
|
||||
assert result == "test custom 42"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def custom_tool_decorator():
|
||||
from crewai.tools import tool
|
||||
@@ -155,6 +164,7 @@ def custom_tool_decorator():
|
||||
|
||||
return custom_tool
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def custom_tool():
|
||||
from crewai.tools import BaseTool
|
||||
@@ -169,18 +179,27 @@ def custom_tool():
|
||||
|
||||
return CustomTool()
|
||||
|
||||
def build_simple_crew(tool):
|
||||
from crewai import Agent, Task, Crew
|
||||
|
||||
agent1 = Agent(role="Simple role", goal="Simple goal", backstory="Simple backstory", tools=[tool])
|
||||
def build_simple_crew(tool):
|
||||
from crewai import Agent, Crew, Task
|
||||
|
||||
agent1 = Agent(
|
||||
role="Simple role",
|
||||
goal="Simple goal",
|
||||
backstory="Simple backstory",
|
||||
tools=[tool],
|
||||
)
|
||||
|
||||
say_hi_task = Task(
|
||||
description="Use the custom tool result as answer.", agent=agent1, expected_output="Use the tool result"
|
||||
description="Use the custom tool result as answer.",
|
||||
agent=agent1,
|
||||
expected_output="Use the tool result",
|
||||
)
|
||||
|
||||
crew = Crew(agents=[agent1], tasks=[say_hi_task])
|
||||
return crew
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_async_tool_using_within_isolated_crew(custom_tool):
|
||||
crew = build_simple_crew(custom_tool)
|
||||
@@ -188,6 +207,7 @@ def test_async_tool_using_within_isolated_crew(custom_tool):
|
||||
|
||||
assert result.raw == "Hello World from Custom Tool"
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_async_tool_using_decorator_within_isolated_crew(custom_tool_decorator):
|
||||
crew = build_simple_crew(custom_tool_decorator)
|
||||
@@ -195,6 +215,7 @@ def test_async_tool_using_decorator_within_isolated_crew(custom_tool_decorator):
|
||||
|
||||
assert result.raw == "Hello World from Custom Tool"
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_async_tool_within_flow(custom_tool):
|
||||
from crewai.flow.flow import Flow
|
||||
@@ -219,6 +240,7 @@ def test_async_tool_using_decorator_within_flow(custom_tool_decorator):
|
||||
|
||||
class StructuredExampleFlow(Flow):
|
||||
from crewai.flow.flow import start
|
||||
|
||||
@start()
|
||||
async def start(self):
|
||||
crew = build_simple_crew(custom_tool_decorator)
|
||||
@@ -227,4 +249,96 @@ def test_async_tool_using_decorator_within_flow(custom_tool_decorator):
|
||||
|
||||
flow = StructuredExampleFlow()
|
||||
result = flow.kickoff()
|
||||
assert result.raw == "Hello World from Custom Tool"
|
||||
assert result.raw == "Hello World from Custom Tool"
|
||||
|
||||
|
||||
def test_invoke_sync_function_single_execution():
|
||||
"""Test that sync functions are called only once, not twice."""
|
||||
call_count = 0
|
||||
|
||||
def counting_func(message: str) -> str:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return f"Called {call_count} times with: {message}"
|
||||
|
||||
tool = CrewStructuredTool.from_function(
|
||||
func=counting_func, name="counting_tool", description="A tool that counts calls"
|
||||
)
|
||||
|
||||
result = tool.invoke({"message": "test"})
|
||||
assert call_count == 1, f"Function was called {call_count} times, expected 1"
|
||||
assert result == "Called 1 times with: test"
|
||||
|
||||
|
||||
def test_invoke_async_function_outside_event_loop():
|
||||
"""Test that async functions work correctly when called outside event loop."""
|
||||
|
||||
async def async_func(message: str) -> str:
|
||||
return f"Async result: {message}"
|
||||
|
||||
tool = CrewStructuredTool.from_function(
|
||||
func=async_func, name="async_tool", description="An async tool"
|
||||
)
|
||||
|
||||
result = tool.invoke({"message": "test"})
|
||||
assert result == "Async result: test"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_async_function_in_event_loop_raises_error():
|
||||
"""Test that async functions raise RuntimeError when called from within event loop."""
|
||||
|
||||
async def async_func(message: str) -> str:
|
||||
return f"Async result: {message}"
|
||||
|
||||
tool = CrewStructuredTool.from_function(
|
||||
func=async_func, name="async_tool", description="An async tool"
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
RuntimeError,
|
||||
match="Cannot call async tool.*from synchronous context within an event loop",
|
||||
):
|
||||
tool.invoke({"message": "test"})
|
||||
|
||||
|
||||
def test_invoke_sync_function_returning_coroutine():
|
||||
"""Test handling of sync functions that return coroutines."""
|
||||
|
||||
async def inner_async(message: str) -> str:
|
||||
return f"Inner async: {message}"
|
||||
|
||||
def sync_func_returning_coro(message: str):
|
||||
return inner_async(message)
|
||||
|
||||
tool = CrewStructuredTool.from_function(
|
||||
func=sync_func_returning_coro,
|
||||
name="sync_coro_tool",
|
||||
description="A sync tool that returns coroutine",
|
||||
)
|
||||
|
||||
result = tool.invoke({"message": "test"})
|
||||
assert result == "Inner async: test"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_sync_function_returning_coroutine_in_event_loop_raises_error():
|
||||
"""Test that sync functions returning coroutines raise RuntimeError in event loop."""
|
||||
|
||||
async def inner_async(message: str) -> str:
|
||||
return f"Inner async: {message}"
|
||||
|
||||
def sync_func_returning_coro(message: str):
|
||||
return inner_async(message)
|
||||
|
||||
tool = CrewStructuredTool.from_function(
|
||||
func=sync_func_returning_coro,
|
||||
name="sync_coro_tool",
|
||||
description="A sync tool that returns coroutine",
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
RuntimeError,
|
||||
match="Sync function.*returned a coroutine but we're in an event loop",
|
||||
):
|
||||
tool.invoke({"message": "test"})
|
||||
|
||||
Reference in New Issue
Block a user