mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 15:48:29 +00:00
added condition to check whether _run function returns a coroutine ob… (#2570)
* added condition to check whether _run function returns a coroutine object * Cleaned the code * Fixed the test modules, Class -> Functions
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from inspect import signature
|
||||
@@ -65,7 +66,13 @@ class BaseTool(BaseModel, ABC):
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
print(f"Using Tool: {self.name}")
|
||||
return self._run(*args, **kwargs)
|
||||
result = self._run(*args, **kwargs)
|
||||
|
||||
# If _run is async, we safely run it
|
||||
if asyncio.iscoroutine(result):
|
||||
return asyncio.run(result)
|
||||
|
||||
return result
|
||||
|
||||
@abstractmethod
|
||||
def _run(
|
||||
|
||||
@@ -1,4 +1,8 @@
|
||||
from typing import Callable
|
||||
import asyncio
|
||||
import inspect
|
||||
import unittest
|
||||
from typing import Any, Callable, Dict, List
|
||||
from unittest.mock import patch
|
||||
|
||||
from crewai.tools import BaseTool, tool
|
||||
|
||||
@@ -122,3 +126,69 @@ def test_result_as_answer_in_tool_decorator():
|
||||
|
||||
converted_tool = my_tool_with_default.to_structured_tool()
|
||||
assert converted_tool.result_as_answer is False
|
||||
|
||||
|
||||
class SyncTool(BaseTool):
|
||||
"""Test implementation with a synchronous _run method"""
|
||||
name: str = "sync_tool"
|
||||
description: str = "A synchronous tool for testing"
|
||||
|
||||
def _run(self, input_text: str) -> str:
|
||||
"""Process input text synchronously."""
|
||||
return f"Processed {input_text} synchronously"
|
||||
|
||||
|
||||
class AsyncTool(BaseTool):
|
||||
"""Test implementation with an asynchronous _run method"""
|
||||
name: str = "async_tool"
|
||||
description: str = "An asynchronous tool for testing"
|
||||
|
||||
async def _run(self, input_text: str) -> str:
|
||||
"""Process input text asynchronously."""
|
||||
await asyncio.sleep(0.1) # Simulate async operation
|
||||
return f"Processed {input_text} asynchronously"
|
||||
|
||||
|
||||
def test_sync_run_returns_direct_result():
|
||||
"""Test that _run in a synchronous tool returns a direct result, not a coroutine."""
|
||||
tool = SyncTool()
|
||||
result = tool._run(input_text="hello")
|
||||
|
||||
assert not asyncio.iscoroutine(result)
|
||||
assert result == "Processed hello synchronously"
|
||||
|
||||
run_result = tool.run(input_text="hello")
|
||||
assert run_result == "Processed hello synchronously"
|
||||
|
||||
|
||||
def test_async_run_returns_coroutine():
|
||||
"""Test that _run in an asynchronous tool returns a coroutine object."""
|
||||
tool = AsyncTool()
|
||||
result = tool._run(input_text="hello")
|
||||
|
||||
assert asyncio.iscoroutine(result)
|
||||
result.close() # Clean up the coroutine
|
||||
|
||||
|
||||
def test_run_calls_asyncio_run_for_async_tools():
|
||||
"""Test that asyncio.run is called when using async tools."""
|
||||
async_tool = AsyncTool()
|
||||
|
||||
with patch('asyncio.run') as mock_run:
|
||||
mock_run.return_value = "Processed test asynchronously"
|
||||
async_result = async_tool.run(input_text="test")
|
||||
|
||||
mock_run.assert_called_once()
|
||||
assert async_result == "Processed test asynchronously"
|
||||
|
||||
|
||||
def test_run_does_not_call_asyncio_run_for_sync_tools():
|
||||
"""Test that asyncio.run is NOT called when using sync tools."""
|
||||
sync_tool = SyncTool()
|
||||
|
||||
with patch('asyncio.run') as mock_run:
|
||||
sync_result = sync_tool.run(input_text="test")
|
||||
|
||||
mock_run.assert_not_called()
|
||||
assert sync_result == "Processed test synchronously"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user