From ea5ae9086a18af11fde25cf68f96197f726e6479 Mon Sep 17 00:00:00 2001 From: Vidit Ostwal <110953813+Vidit-Ostwal@users.noreply.github.com> Date: Fri, 11 Apr 2025 22:26:37 +0530 Subject: [PATCH] =?UTF-8?q?added=20condition=20to=20check=20whether=20=5Fr?= =?UTF-8?q?un=20function=20returns=20a=20coroutine=20ob=E2=80=A6=20(#2570)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * added condition to check whether _run function returns a coroutine object * Cleaned the code * Fixed the test modules, Class -> Functions --- src/crewai/tools/base_tool.py | 9 ++++- tests/tools/test_base_tool.py | 72 ++++++++++++++++++++++++++++++++++- 2 files changed, 79 insertions(+), 2 deletions(-) diff --git a/src/crewai/tools/base_tool.py b/src/crewai/tools/base_tool.py index 2d6526266..0e8a7a22b 100644 --- a/src/crewai/tools/base_tool.py +++ b/src/crewai/tools/base_tool.py @@ -1,3 +1,4 @@ +import asyncio import warnings from abc import ABC, abstractmethod from inspect import signature @@ -65,7 +66,13 @@ class BaseTool(BaseModel, ABC): **kwargs: Any, ) -> Any: print(f"Using Tool: {self.name}") - return self._run(*args, **kwargs) + result = self._run(*args, **kwargs) + + # If _run is async, we safely run it + if asyncio.iscoroutine(result): + return asyncio.run(result) + + return result @abstractmethod def _run( diff --git a/tests/tools/test_base_tool.py b/tests/tools/test_base_tool.py index a1eb7a407..b4f3d2488 100644 --- a/tests/tools/test_base_tool.py +++ b/tests/tools/test_base_tool.py @@ -1,4 +1,8 @@ -from typing import Callable +import asyncio +import inspect +import unittest +from typing import Any, Callable, Dict, List +from unittest.mock import patch from crewai.tools import BaseTool, tool @@ -122,3 +126,69 @@ def test_result_as_answer_in_tool_decorator(): converted_tool = my_tool_with_default.to_structured_tool() assert converted_tool.result_as_answer is False + + +class SyncTool(BaseTool): + """Test implementation with a synchronous _run method""" + name: str = "sync_tool" + description: str = "A synchronous tool for testing" + + def _run(self, input_text: str) -> str: + """Process input text synchronously.""" + return f"Processed {input_text} synchronously" + + +class AsyncTool(BaseTool): + """Test implementation with an asynchronous _run method""" + name: str = "async_tool" + description: str = "An asynchronous tool for testing" + + async def _run(self, input_text: str) -> str: + """Process input text asynchronously.""" + await asyncio.sleep(0.1) # Simulate async operation + return f"Processed {input_text} asynchronously" + + +def test_sync_run_returns_direct_result(): + """Test that _run in a synchronous tool returns a direct result, not a coroutine.""" + tool = SyncTool() + result = tool._run(input_text="hello") + + assert not asyncio.iscoroutine(result) + assert result == "Processed hello synchronously" + + run_result = tool.run(input_text="hello") + assert run_result == "Processed hello synchronously" + + +def test_async_run_returns_coroutine(): + """Test that _run in an asynchronous tool returns a coroutine object.""" + tool = AsyncTool() + result = tool._run(input_text="hello") + + assert asyncio.iscoroutine(result) + result.close() # Clean up the coroutine + + +def test_run_calls_asyncio_run_for_async_tools(): + """Test that asyncio.run is called when using async tools.""" + async_tool = AsyncTool() + + with patch('asyncio.run') as mock_run: + mock_run.return_value = "Processed test asynchronously" + async_result = async_tool.run(input_text="test") + + mock_run.assert_called_once() + assert async_result == "Processed test asynchronously" + + +def test_run_does_not_call_asyncio_run_for_sync_tools(): + """Test that asyncio.run is NOT called when using sync tools.""" + sync_tool = SyncTool() + + with patch('asyncio.run') as mock_run: + sync_result = sync_tool.run(input_text="test") + + mock_run.assert_not_called() + assert sync_result == "Processed test synchronously" +