mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 08:08:32 +00:00
Merge branch 'main' into feat/trace-ui-exec-3
This commit is contained in:
@@ -1,3 +1,4 @@
|
|||||||
|
import asyncio
|
||||||
import warnings
|
import warnings
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from inspect import signature
|
from inspect import signature
|
||||||
@@ -65,7 +66,13 @@ class BaseTool(BaseModel, ABC):
|
|||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
print(f"Using Tool: {self.name}")
|
print(f"Using Tool: {self.name}")
|
||||||
return self._run(*args, **kwargs)
|
result = self._run(*args, **kwargs)
|
||||||
|
|
||||||
|
# If _run is async, we safely run it
|
||||||
|
if asyncio.iscoroutine(result):
|
||||||
|
return asyncio.run(result)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _run(
|
def _run(
|
||||||
|
|||||||
@@ -1,4 +1,8 @@
|
|||||||
from typing import Callable
|
import asyncio
|
||||||
|
import inspect
|
||||||
|
import unittest
|
||||||
|
from typing import Any, Callable, Dict, List
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
from crewai.tools import BaseTool, tool
|
from crewai.tools import BaseTool, tool
|
||||||
|
|
||||||
@@ -122,3 +126,69 @@ def test_result_as_answer_in_tool_decorator():
|
|||||||
|
|
||||||
converted_tool = my_tool_with_default.to_structured_tool()
|
converted_tool = my_tool_with_default.to_structured_tool()
|
||||||
assert converted_tool.result_as_answer is False
|
assert converted_tool.result_as_answer is False
|
||||||
|
|
||||||
|
|
||||||
|
class SyncTool(BaseTool):
|
||||||
|
"""Test implementation with a synchronous _run method"""
|
||||||
|
name: str = "sync_tool"
|
||||||
|
description: str = "A synchronous tool for testing"
|
||||||
|
|
||||||
|
def _run(self, input_text: str) -> str:
|
||||||
|
"""Process input text synchronously."""
|
||||||
|
return f"Processed {input_text} synchronously"
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncTool(BaseTool):
|
||||||
|
"""Test implementation with an asynchronous _run method"""
|
||||||
|
name: str = "async_tool"
|
||||||
|
description: str = "An asynchronous tool for testing"
|
||||||
|
|
||||||
|
async def _run(self, input_text: str) -> str:
|
||||||
|
"""Process input text asynchronously."""
|
||||||
|
await asyncio.sleep(0.1) # Simulate async operation
|
||||||
|
return f"Processed {input_text} asynchronously"
|
||||||
|
|
||||||
|
|
||||||
|
def test_sync_run_returns_direct_result():
|
||||||
|
"""Test that _run in a synchronous tool returns a direct result, not a coroutine."""
|
||||||
|
tool = SyncTool()
|
||||||
|
result = tool._run(input_text="hello")
|
||||||
|
|
||||||
|
assert not asyncio.iscoroutine(result)
|
||||||
|
assert result == "Processed hello synchronously"
|
||||||
|
|
||||||
|
run_result = tool.run(input_text="hello")
|
||||||
|
assert run_result == "Processed hello synchronously"
|
||||||
|
|
||||||
|
|
||||||
|
def test_async_run_returns_coroutine():
|
||||||
|
"""Test that _run in an asynchronous tool returns a coroutine object."""
|
||||||
|
tool = AsyncTool()
|
||||||
|
result = tool._run(input_text="hello")
|
||||||
|
|
||||||
|
assert asyncio.iscoroutine(result)
|
||||||
|
result.close() # Clean up the coroutine
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_calls_asyncio_run_for_async_tools():
|
||||||
|
"""Test that asyncio.run is called when using async tools."""
|
||||||
|
async_tool = AsyncTool()
|
||||||
|
|
||||||
|
with patch('asyncio.run') as mock_run:
|
||||||
|
mock_run.return_value = "Processed test asynchronously"
|
||||||
|
async_result = async_tool.run(input_text="test")
|
||||||
|
|
||||||
|
mock_run.assert_called_once()
|
||||||
|
assert async_result == "Processed test asynchronously"
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_does_not_call_asyncio_run_for_sync_tools():
|
||||||
|
"""Test that asyncio.run is NOT called when using sync tools."""
|
||||||
|
sync_tool = SyncTool()
|
||||||
|
|
||||||
|
with patch('asyncio.run') as mock_run:
|
||||||
|
sync_result = sync_tool.run(input_text="test")
|
||||||
|
|
||||||
|
mock_run.assert_not_called()
|
||||||
|
assert sync_result == "Processed test synchronously"
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user