diff --git a/src/crewai/tools/structured_tool.py b/src/crewai/tools/structured_tool.py index a514f4229..47faa1d9c 100644 --- a/src/crewai/tools/structured_tool.py +++ b/src/crewai/tools/structured_tool.py @@ -1,17 +1,15 @@ from __future__ import annotations import asyncio - import inspect import textwrap -from typing import Any, Callable, Optional, Union, get_type_hints +from collections.abc import Callable +from typing import TYPE_CHECKING, Any, Optional, Union, get_type_hints from pydantic import BaseModel, Field, create_model from crewai.utilities.logger import Logger -from typing import TYPE_CHECKING - if TYPE_CHECKING: from crewai.tools.base_tool import BaseTool @@ -192,7 +190,7 @@ class CrewStructuredTool: f"not found in args_schema" ) - def _parse_args(self, raw_args: Union[str, dict]) -> dict: + def _parse_args(self, raw_args: str | dict) -> dict: """Parse and validate the input arguments against the schema. Args: @@ -217,7 +215,7 @@ class CrewStructuredTool: async def ainvoke( self, - input: Union[str, dict], + input: str | dict, config: Optional[dict] = None, **kwargs: Any, ) -> Any: @@ -261,7 +259,7 @@ class CrewStructuredTool: return self.invoke(input_dict) def invoke( - self, input: Union[str, dict], config: Optional[dict] = None, **kwargs: Any + self, input: str | dict, config: Optional[dict] = None, **kwargs: Any ) -> Any: """Main method for tool execution.""" parsed_args = self._parse_args(input) @@ -273,22 +271,40 @@ class CrewStructuredTool: self._increment_usage_count() - if inspect.iscoroutinefunction(self.func): - result = asyncio.run(self.func(**parsed_args, **kwargs)) - return result - try: - result = self.func(**parsed_args, **kwargs) + if inspect.iscoroutinefunction(self.func): + coro = self.func(**parsed_args, **kwargs) + try: + asyncio.get_running_loop() + raise RuntimeError( + f"Cannot call async tool '{self.name}' from synchronous context within an event loop. " + f"Use ainvoke() instead or call from outside the event loop." + ) + except RuntimeError as e: + if "Cannot call async tool" in str(e): + raise + else: + return asyncio.run(coro) + else: + result = self.func(**parsed_args, **kwargs) + + if asyncio.iscoroutine(result): + try: + asyncio.get_running_loop() + raise RuntimeError( + f"Sync function '{self.name}' returned a coroutine but we're in an event loop. " + f"Use ainvoke() instead or call from outside the event loop." + ) + except RuntimeError as e: + if "returned a coroutine but we're in an event loop" in str(e): + raise + else: + return asyncio.run(result) + + return result except Exception: raise - result = self.func(**parsed_args, **kwargs) - - if asyncio.iscoroutine(result): - return asyncio.run(result) - - return result - def has_reached_max_usage_count(self) -> bool: """Check if the tool has reached its maximum usage count.""" return ( diff --git a/tests/tools/test_structured_tool.py b/tests/tools/test_structured_tool.py index f347b1db1..b65149768 100644 --- a/tests/tools/test_structured_tool.py +++ b/tests/tools/test_structured_tool.py @@ -39,6 +39,7 @@ def test_initialization(basic_function, schema_class): assert tool.func == basic_function assert tool.args_schema == schema_class + def test_from_function(basic_function): """Test creating tool from function""" tool = CrewStructuredTool.from_function( @@ -50,6 +51,7 @@ def test_from_function(basic_function): assert tool.func == basic_function assert isinstance(tool.args_schema, type(BaseModel)) + def test_validate_function_signature(basic_function, schema_class): """Test function signature validation""" tool = CrewStructuredTool( @@ -62,6 +64,7 @@ def test_validate_function_signature(basic_function, schema_class): # Should not raise any exceptions tool._validate_function_signature() + @pytest.mark.asyncio async def test_ainvoke(basic_function): """Test asynchronous invocation""" @@ -70,6 +73,7 @@ async def test_ainvoke(basic_function): result = await tool.ainvoke(input={"param1": "test"}) assert result == "test 0" + def test_parse_args_dict(basic_function): """Test parsing dictionary arguments""" tool = CrewStructuredTool.from_function(func=basic_function, name="test_tool") @@ -78,6 +82,7 @@ def test_parse_args_dict(basic_function): assert parsed["param1"] == "test" assert parsed["param2"] == 42 + def test_parse_args_string(basic_function): """Test parsing string arguments""" tool = CrewStructuredTool.from_function(func=basic_function, name="test_tool") @@ -86,6 +91,7 @@ def test_parse_args_string(basic_function): assert parsed["param1"] == "test" assert parsed["param2"] == 42 + def test_complex_types(): """Test handling of complex parameter types""" @@ -99,6 +105,7 @@ def test_complex_types(): result = tool.invoke({"nested": {"key": "value"}, "items": [1, 2, 3]}) assert result == "Processed 3 items with 1 nested keys" + def test_schema_inheritance(): """Test tool creation with inherited schema""" @@ -119,6 +126,7 @@ def test_schema_inheritance(): result = tool.invoke({"base_param": "test", "extra_param": 42}) assert result == "test 42" + def test_default_values_in_schema(): """Test handling of default values in schema""" @@ -144,6 +152,7 @@ def test_default_values_in_schema(): ) assert result == "test custom 42" + @pytest.fixture def custom_tool_decorator(): from crewai.tools import tool @@ -155,6 +164,7 @@ def custom_tool_decorator(): return custom_tool + @pytest.fixture def custom_tool(): from crewai.tools import BaseTool @@ -169,18 +179,27 @@ def custom_tool(): return CustomTool() -def build_simple_crew(tool): - from crewai import Agent, Task, Crew - agent1 = Agent(role="Simple role", goal="Simple goal", backstory="Simple backstory", tools=[tool]) +def build_simple_crew(tool): + from crewai import Agent, Crew, Task + + agent1 = Agent( + role="Simple role", + goal="Simple goal", + backstory="Simple backstory", + tools=[tool], + ) say_hi_task = Task( - description="Use the custom tool result as answer.", agent=agent1, expected_output="Use the tool result" + description="Use the custom tool result as answer.", + agent=agent1, + expected_output="Use the tool result", ) crew = Crew(agents=[agent1], tasks=[say_hi_task]) return crew + @pytest.mark.vcr(filter_headers=["authorization"]) def test_async_tool_using_within_isolated_crew(custom_tool): crew = build_simple_crew(custom_tool) @@ -188,6 +207,7 @@ def test_async_tool_using_within_isolated_crew(custom_tool): assert result.raw == "Hello World from Custom Tool" + @pytest.mark.vcr(filter_headers=["authorization"]) def test_async_tool_using_decorator_within_isolated_crew(custom_tool_decorator): crew = build_simple_crew(custom_tool_decorator) @@ -195,6 +215,7 @@ def test_async_tool_using_decorator_within_isolated_crew(custom_tool_decorator): assert result.raw == "Hello World from Custom Tool" + @pytest.mark.vcr(filter_headers=["authorization"]) def test_async_tool_within_flow(custom_tool): from crewai.flow.flow import Flow @@ -219,6 +240,7 @@ def test_async_tool_using_decorator_within_flow(custom_tool_decorator): class StructuredExampleFlow(Flow): from crewai.flow.flow import start + @start() async def start(self): crew = build_simple_crew(custom_tool_decorator) @@ -227,4 +249,96 @@ def test_async_tool_using_decorator_within_flow(custom_tool_decorator): flow = StructuredExampleFlow() result = flow.kickoff() - assert result.raw == "Hello World from Custom Tool" \ No newline at end of file + assert result.raw == "Hello World from Custom Tool" + + +def test_invoke_sync_function_single_execution(): + """Test that sync functions are called only once, not twice.""" + call_count = 0 + + def counting_func(message: str) -> str: + nonlocal call_count + call_count += 1 + return f"Called {call_count} times with: {message}" + + tool = CrewStructuredTool.from_function( + func=counting_func, name="counting_tool", description="A tool that counts calls" + ) + + result = tool.invoke({"message": "test"}) + assert call_count == 1, f"Function was called {call_count} times, expected 1" + assert result == "Called 1 times with: test" + + +def test_invoke_async_function_outside_event_loop(): + """Test that async functions work correctly when called outside event loop.""" + + async def async_func(message: str) -> str: + return f"Async result: {message}" + + tool = CrewStructuredTool.from_function( + func=async_func, name="async_tool", description="An async tool" + ) + + result = tool.invoke({"message": "test"}) + assert result == "Async result: test" + + +@pytest.mark.asyncio +async def test_invoke_async_function_in_event_loop_raises_error(): + """Test that async functions raise RuntimeError when called from within event loop.""" + + async def async_func(message: str) -> str: + return f"Async result: {message}" + + tool = CrewStructuredTool.from_function( + func=async_func, name="async_tool", description="An async tool" + ) + + with pytest.raises( + RuntimeError, + match="Cannot call async tool.*from synchronous context within an event loop", + ): + tool.invoke({"message": "test"}) + + +def test_invoke_sync_function_returning_coroutine(): + """Test handling of sync functions that return coroutines.""" + + async def inner_async(message: str) -> str: + return f"Inner async: {message}" + + def sync_func_returning_coro(message: str): + return inner_async(message) + + tool = CrewStructuredTool.from_function( + func=sync_func_returning_coro, + name="sync_coro_tool", + description="A sync tool that returns coroutine", + ) + + result = tool.invoke({"message": "test"}) + assert result == "Inner async: test" + + +@pytest.mark.asyncio +async def test_invoke_sync_function_returning_coroutine_in_event_loop_raises_error(): + """Test that sync functions returning coroutines raise RuntimeError in event loop.""" + + async def inner_async(message: str) -> str: + return f"Inner async: {message}" + + def sync_func_returning_coro(message: str): + return inner_async(message) + + tool = CrewStructuredTool.from_function( + func=sync_func_returning_coro, + name="sync_coro_tool", + description="A sync tool that returns coroutine", + ) + + with pytest.raises( + RuntimeError, + match="Sync function.*returned a coroutine but we're in an event loop", + ): + tool.invoke({"message": "test"})