added condition to check whether _run function returns a coroutine ob… (#2570)

* added condition to check whether _run function returns a coroutine object

* Cleaned the code

* Fixed the test modules, Class -> Functions
This commit is contained in:
Vidit Ostwal
2025-04-11 22:26:37 +05:30
committed by GitHub
parent 0cd524af86
commit ea5ae9086a
2 changed files with 79 additions and 2 deletions

View File

@@ -1,3 +1,4 @@
import asyncio
import warnings
from abc import ABC, abstractmethod
from inspect import signature
@@ -65,7 +66,13 @@ class BaseTool(BaseModel, ABC):
**kwargs: Any,
) -> Any:
print(f"Using Tool: {self.name}")
return self._run(*args, **kwargs)
result = self._run(*args, **kwargs)
# If _run is async, we safely run it
if asyncio.iscoroutine(result):
return asyncio.run(result)
return result
@abstractmethod
def _run(

View File

@@ -1,4 +1,8 @@
from typing import Callable
import asyncio
import inspect
import unittest
from typing import Any, Callable, Dict, List
from unittest.mock import patch
from crewai.tools import BaseTool, tool
@@ -122,3 +126,69 @@ def test_result_as_answer_in_tool_decorator():
converted_tool = my_tool_with_default.to_structured_tool()
assert converted_tool.result_as_answer is False
class SyncTool(BaseTool):
"""Test implementation with a synchronous _run method"""
name: str = "sync_tool"
description: str = "A synchronous tool for testing"
def _run(self, input_text: str) -> str:
"""Process input text synchronously."""
return f"Processed {input_text} synchronously"
class AsyncTool(BaseTool):
"""Test implementation with an asynchronous _run method"""
name: str = "async_tool"
description: str = "An asynchronous tool for testing"
async def _run(self, input_text: str) -> str:
"""Process input text asynchronously."""
await asyncio.sleep(0.1) # Simulate async operation
return f"Processed {input_text} asynchronously"
def test_sync_run_returns_direct_result():
"""Test that _run in a synchronous tool returns a direct result, not a coroutine."""
tool = SyncTool()
result = tool._run(input_text="hello")
assert not asyncio.iscoroutine(result)
assert result == "Processed hello synchronously"
run_result = tool.run(input_text="hello")
assert run_result == "Processed hello synchronously"
def test_async_run_returns_coroutine():
"""Test that _run in an asynchronous tool returns a coroutine object."""
tool = AsyncTool()
result = tool._run(input_text="hello")
assert asyncio.iscoroutine(result)
result.close() # Clean up the coroutine
def test_run_calls_asyncio_run_for_async_tools():
"""Test that asyncio.run is called when using async tools."""
async_tool = AsyncTool()
with patch('asyncio.run') as mock_run:
mock_run.return_value = "Processed test asynchronously"
async_result = async_tool.run(input_text="test")
mock_run.assert_called_once()
assert async_result == "Processed test asynchronously"
def test_run_does_not_call_asyncio_run_for_sync_tools():
"""Test that asyncio.run is NOT called when using sync tools."""
sync_tool = SyncTool()
with patch('asyncio.run') as mock_run:
sync_result = sync_tool.run(input_text="test")
mock_run.assert_not_called()
assert sync_result == "Processed test synchronously"