mirror of
https://github.com/crewAIInc/crewAI.git
synced 2025-12-16 12:28:30 +00:00
Compare commits
2 Commits
1.6.1
...
devin/1760
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4cb8df09dd | ||
|
|
0938279b89 |
@@ -14,6 +14,7 @@ from pydantic import (
|
||||
from pydantic import BaseModel as PydanticBaseModel
|
||||
|
||||
from crewai.tools.structured_tool import CrewStructuredTool
|
||||
from crewai.utilities.asyncio_utils import run_coroutine_sync
|
||||
|
||||
|
||||
class EnvVar(BaseModel):
|
||||
@@ -90,7 +91,7 @@ class BaseTool(BaseModel, ABC):
|
||||
|
||||
# If _run is async, we safely run it
|
||||
if asyncio.iscoroutine(result):
|
||||
result = asyncio.run(result)
|
||||
result = run_coroutine_sync(result)
|
||||
|
||||
self.current_usage_count += 1
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any, get_type_hints
|
||||
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
|
||||
from crewai.utilities.asyncio_utils import run_coroutine_sync
|
||||
from crewai.utilities.logger import Logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -269,12 +270,12 @@ class CrewStructuredTool:
|
||||
self._increment_usage_count()
|
||||
|
||||
if inspect.iscoroutinefunction(self.func):
|
||||
return asyncio.run(self.func(**parsed_args, **kwargs))
|
||||
return run_coroutine_sync(self.func(**parsed_args, **kwargs))
|
||||
|
||||
result = self.func(**parsed_args, **kwargs)
|
||||
|
||||
if asyncio.iscoroutine(result):
|
||||
return asyncio.run(result)
|
||||
return run_coroutine_sync(result)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
53
src/crewai/utilities/asyncio_utils.py
Normal file
53
src/crewai/utilities/asyncio_utils.py
Normal file
@@ -0,0 +1,53 @@
|
||||
"""Utilities for handling asyncio operations safely across different contexts."""
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Coroutine
|
||||
from typing import Any
|
||||
|
||||
|
||||
def run_coroutine_sync(coro: Coroutine) -> Any:
|
||||
"""
|
||||
Run a coroutine synchronously, handling both cases where an event loop
|
||||
is already running and where it's not.
|
||||
|
||||
This is useful when you need to run async code from sync code, but you're
|
||||
not sure if you're already in an async context (e.g., when using asyncio.to_thread).
|
||||
|
||||
Args:
|
||||
coro: The coroutine to run
|
||||
|
||||
Returns:
|
||||
The result of the coroutine
|
||||
|
||||
Raises:
|
||||
Any exception raised by the coroutine
|
||||
"""
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
return asyncio.run(coro)
|
||||
else:
|
||||
import threading
|
||||
|
||||
result = None
|
||||
exception = None
|
||||
|
||||
def run_in_new_loop():
|
||||
nonlocal result, exception
|
||||
try:
|
||||
new_loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(new_loop)
|
||||
try:
|
||||
result = new_loop.run_until_complete(coro)
|
||||
finally:
|
||||
new_loop.close()
|
||||
except Exception as e:
|
||||
exception = e
|
||||
|
||||
thread = threading.Thread(target=run_in_new_loop)
|
||||
thread.start()
|
||||
thread.join()
|
||||
|
||||
if exception:
|
||||
raise exception
|
||||
return result
|
||||
142
tests/test_asyncio_tools.py
Normal file
142
tests/test_asyncio_tools.py
Normal file
@@ -0,0 +1,142 @@
|
||||
"""Tests for asyncio tool execution in different contexts."""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai import Agent, Crew, Task
|
||||
from crewai.tools import tool
|
||||
|
||||
|
||||
@tool
|
||||
async def async_test_tool(message: str) -> str:
|
||||
"""An async tool that returns a message."""
|
||||
await asyncio.sleep(0.01)
|
||||
return f"Processed: {message}"
|
||||
|
||||
|
||||
@tool
|
||||
def sync_test_tool(message: str) -> str:
|
||||
"""A sync tool that returns a message."""
|
||||
return f"Sync: {message}"
|
||||
|
||||
|
||||
class TestAsyncioToolExecution:
|
||||
"""Test that tools work correctly in different asyncio contexts."""
|
||||
|
||||
@patch("crewai.Agent.execute_task")
|
||||
def test_async_tool_with_asyncio_to_thread(self, mock_execute_task):
|
||||
"""Test that async tools work when crew is run with asyncio.to_thread."""
|
||||
mock_execute_task.return_value = "Task completed"
|
||||
|
||||
agent = Agent(
|
||||
role="Test Agent",
|
||||
goal="Test async tool execution",
|
||||
backstory="A test agent",
|
||||
)
|
||||
|
||||
task = Task(
|
||||
description="Test task",
|
||||
expected_output="A result",
|
||||
agent=agent,
|
||||
tools=[async_test_tool],
|
||||
)
|
||||
|
||||
crew = Crew(agents=[agent], tasks=[task], verbose=False)
|
||||
|
||||
async def run_with_to_thread():
|
||||
"""Run crew with asyncio.to_thread - this should not hang."""
|
||||
result = await asyncio.to_thread(crew.kickoff)
|
||||
return result
|
||||
|
||||
result = asyncio.run(run_with_to_thread())
|
||||
assert result is not None
|
||||
|
||||
@patch("crewai.Agent.execute_task")
|
||||
def test_sync_tool_with_asyncio_to_thread(self, mock_execute_task):
|
||||
"""Test that sync tools work when crew is run with asyncio.to_thread."""
|
||||
mock_execute_task.return_value = "Task completed"
|
||||
|
||||
agent = Agent(
|
||||
role="Test Agent",
|
||||
goal="Test sync tool execution",
|
||||
backstory="A test agent",
|
||||
)
|
||||
|
||||
task = Task(
|
||||
description="Test task",
|
||||
expected_output="A result",
|
||||
agent=agent,
|
||||
tools=[sync_test_tool],
|
||||
)
|
||||
|
||||
crew = Crew(agents=[agent], tasks=[task], verbose=False)
|
||||
|
||||
async def run_with_to_thread():
|
||||
"""Run crew with asyncio.to_thread."""
|
||||
result = await asyncio.to_thread(crew.kickoff)
|
||||
return result
|
||||
|
||||
result = asyncio.run(run_with_to_thread())
|
||||
assert result is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("crewai.Agent.execute_task")
|
||||
async def test_async_tool_with_kickoff_async(self, mock_execute_task):
|
||||
"""Test that async tools work with kickoff_async."""
|
||||
mock_execute_task.return_value = "Task completed"
|
||||
|
||||
agent = Agent(
|
||||
role="Test Agent",
|
||||
goal="Test async tool execution",
|
||||
backstory="A test agent",
|
||||
)
|
||||
|
||||
task = Task(
|
||||
description="Test task",
|
||||
expected_output="A result",
|
||||
agent=agent,
|
||||
tools=[async_test_tool],
|
||||
)
|
||||
|
||||
crew = Crew(agents=[agent], tasks=[task], verbose=False)
|
||||
|
||||
result = await crew.kickoff_async()
|
||||
assert result is not None
|
||||
|
||||
def test_async_tool_direct_invocation(self):
|
||||
"""Test that async tools can be invoked directly from sync context."""
|
||||
structured_tool = async_test_tool.to_structured_tool()
|
||||
result = structured_tool.invoke({"message": "test"})
|
||||
assert result == "Processed: test"
|
||||
|
||||
def test_async_tool_invocation_from_thread(self):
|
||||
"""Test that async tools work when invoked from a thread pool."""
|
||||
structured_tool = async_test_tool.to_structured_tool()
|
||||
|
||||
def invoke_tool():
|
||||
return structured_tool.invoke({"message": "test"})
|
||||
|
||||
async def run_in_thread():
|
||||
result = await asyncio.to_thread(invoke_tool)
|
||||
return result
|
||||
|
||||
result = asyncio.run(run_in_thread())
|
||||
assert result == "Processed: test"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_async_tools_concurrent(self):
|
||||
"""Test multiple async tool invocations concurrently."""
|
||||
structured_tool = async_test_tool.to_structured_tool()
|
||||
|
||||
async def invoke_async():
|
||||
return await structured_tool.ainvoke({"message": "test"})
|
||||
|
||||
results = await asyncio.gather(
|
||||
invoke_async(), invoke_async(), invoke_async()
|
||||
)
|
||||
|
||||
assert len(results) == 3
|
||||
for r in results:
|
||||
assert "test" in str(r)
|
||||
Reference in New Issue
Block a user