mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-01 23:32:39 +00:00
fix: ensure async methods are executable for annotations
This commit is contained in:
@@ -2,8 +2,10 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
|
import inspect
|
||||||
from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, TypeVar, overload
|
from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, TypeVar, overload
|
||||||
|
|
||||||
from crewai.project.utils import memoize
|
from crewai.project.utils import memoize
|
||||||
@@ -156,6 +158,23 @@ def cache_handler(meth: Callable[P, R]) -> CacheHandlerMethod[P, R]:
|
|||||||
return CacheHandlerMethod(memoize(meth))
|
return CacheHandlerMethod(memoize(meth))
|
||||||
|
|
||||||
|
|
||||||
|
def _call_method(method: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
|
||||||
|
"""Call a method, awaiting it if async and running in an event loop."""
|
||||||
|
result = method(*args, **kwargs)
|
||||||
|
if inspect.iscoroutine(result):
|
||||||
|
try:
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
except RuntimeError:
|
||||||
|
loop = None
|
||||||
|
if loop and loop.is_running():
|
||||||
|
import concurrent.futures
|
||||||
|
|
||||||
|
with concurrent.futures.ThreadPoolExecutor() as pool:
|
||||||
|
return pool.submit(asyncio.run, result).result()
|
||||||
|
return asyncio.run(result)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def crew(
|
def crew(
|
||||||
meth: Callable[Concatenate[SelfT, P], Crew],
|
meth: Callable[Concatenate[SelfT, P], Crew],
|
||||||
@@ -198,7 +217,7 @@ def crew(
|
|||||||
|
|
||||||
# Instantiate tasks in order
|
# Instantiate tasks in order
|
||||||
for _, task_method in tasks:
|
for _, task_method in tasks:
|
||||||
task_instance = task_method(self)
|
task_instance = _call_method(task_method, self)
|
||||||
instantiated_tasks.append(task_instance)
|
instantiated_tasks.append(task_instance)
|
||||||
agent_instance = getattr(task_instance, "agent", None)
|
agent_instance = getattr(task_instance, "agent", None)
|
||||||
if agent_instance and agent_instance.role not in agent_roles:
|
if agent_instance and agent_instance.role not in agent_roles:
|
||||||
@@ -207,7 +226,7 @@ def crew(
|
|||||||
|
|
||||||
# Instantiate agents not included by tasks
|
# Instantiate agents not included by tasks
|
||||||
for _, agent_method in agents:
|
for _, agent_method in agents:
|
||||||
agent_instance = agent_method(self)
|
agent_instance = _call_method(agent_method, self)
|
||||||
if agent_instance.role not in agent_roles:
|
if agent_instance.role not in agent_roles:
|
||||||
instantiated_agents.append(agent_instance)
|
instantiated_agents.append(agent_instance)
|
||||||
agent_roles.add(agent_instance.role)
|
agent_roles.add(agent_instance.role)
|
||||||
@@ -215,7 +234,7 @@ def crew(
|
|||||||
self.agents = instantiated_agents
|
self.agents = instantiated_agents
|
||||||
self.tasks = instantiated_tasks
|
self.tasks = instantiated_tasks
|
||||||
|
|
||||||
crew_instance = meth(self, *args, **kwargs)
|
crew_instance: Crew = _call_method(meth, self, *args, **kwargs)
|
||||||
|
|
||||||
def callback_wrapper(
|
def callback_wrapper(
|
||||||
hook: Callable[Concatenate[CrewInstance, P2], R2], instance: CrewInstance
|
hook: Callable[Concatenate[CrewInstance, P2], R2], instance: CrewInstance
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
"""Utility functions for the crewai project module."""
|
"""Utility functions for the crewai project module."""
|
||||||
|
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable, Coroutine
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
|
import inspect
|
||||||
from typing import Any, ParamSpec, TypeVar, cast
|
from typing import Any, ParamSpec, TypeVar, cast
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@@ -37,8 +38,8 @@ def _make_hashable(arg: Any) -> Any:
|
|||||||
def memoize(meth: Callable[P, R]) -> Callable[P, R]:
|
def memoize(meth: Callable[P, R]) -> Callable[P, R]:
|
||||||
"""Memoize a method by caching its results based on arguments.
|
"""Memoize a method by caching its results based on arguments.
|
||||||
|
|
||||||
Handles Pydantic BaseModel instances by converting them to JSON strings
|
Handles both sync and async methods. Pydantic BaseModel instances are
|
||||||
before hashing for cache lookup.
|
converted to JSON strings before hashing for cache lookup.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
meth: The method to memoize.
|
meth: The method to memoize.
|
||||||
@@ -46,18 +47,16 @@ def memoize(meth: Callable[P, R]) -> Callable[P, R]:
|
|||||||
Returns:
|
Returns:
|
||||||
A memoized version of the method that caches results.
|
A memoized version of the method that caches results.
|
||||||
"""
|
"""
|
||||||
|
if inspect.iscoroutinefunction(meth):
|
||||||
|
return cast(Callable[P, R], _memoize_async(meth))
|
||||||
|
return _memoize_sync(meth)
|
||||||
|
|
||||||
|
|
||||||
|
def _memoize_sync(meth: Callable[P, R]) -> Callable[P, R]:
|
||||||
|
"""Memoize a synchronous method."""
|
||||||
|
|
||||||
@wraps(meth)
|
@wraps(meth)
|
||||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||||
"""Wrapper that converts arguments to hashable form before caching.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
*args: Positional arguments to the memoized method.
|
|
||||||
**kwargs: Keyword arguments to the memoized method.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The result of the memoized method call.
|
|
||||||
"""
|
|
||||||
hashable_args = tuple(_make_hashable(arg) for arg in args)
|
hashable_args = tuple(_make_hashable(arg) for arg in args)
|
||||||
hashable_kwargs = tuple(
|
hashable_kwargs = tuple(
|
||||||
sorted((k, _make_hashable(v)) for k, v in kwargs.items())
|
sorted((k, _make_hashable(v)) for k, v in kwargs.items())
|
||||||
@@ -73,3 +72,27 @@ def memoize(meth: Callable[P, R]) -> Callable[P, R]:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
return cast(Callable[P, R], wrapper)
|
return cast(Callable[P, R], wrapper)
|
||||||
|
|
||||||
|
|
||||||
|
def _memoize_async(
|
||||||
|
meth: Callable[P, Coroutine[Any, Any, R]],
|
||||||
|
) -> Callable[P, Coroutine[Any, Any, R]]:
|
||||||
|
"""Memoize an async method."""
|
||||||
|
|
||||||
|
@wraps(meth)
|
||||||
|
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||||
|
hashable_args = tuple(_make_hashable(arg) for arg in args)
|
||||||
|
hashable_kwargs = tuple(
|
||||||
|
sorted((k, _make_hashable(v)) for k, v in kwargs.items())
|
||||||
|
)
|
||||||
|
cache_key = str((hashable_args, hashable_kwargs))
|
||||||
|
|
||||||
|
cached_result: R | None = cache.read(tool=meth.__name__, input=cache_key)
|
||||||
|
if cached_result is not None:
|
||||||
|
return cached_result
|
||||||
|
|
||||||
|
result = await meth(*args, **kwargs)
|
||||||
|
cache.add(tool=meth.__name__, input=cache_key, output=result)
|
||||||
|
return result
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|||||||
@@ -2,8 +2,10 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
import inspect
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
@@ -132,6 +134,22 @@ class CrewClass(Protocol):
|
|||||||
crew: Callable[..., Crew]
|
crew: Callable[..., Crew]
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_result(result: Any) -> Any:
|
||||||
|
"""Resolve a potentially async result to its value."""
|
||||||
|
if inspect.iscoroutine(result):
|
||||||
|
try:
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
except RuntimeError:
|
||||||
|
loop = None
|
||||||
|
if loop and loop.is_running():
|
||||||
|
import concurrent.futures
|
||||||
|
|
||||||
|
with concurrent.futures.ThreadPoolExecutor() as pool:
|
||||||
|
return pool.submit(asyncio.run, result).result()
|
||||||
|
return asyncio.run(result)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
class DecoratedMethod(Generic[P, R]):
|
class DecoratedMethod(Generic[P, R]):
|
||||||
"""Base wrapper for methods with decorator metadata.
|
"""Base wrapper for methods with decorator metadata.
|
||||||
|
|
||||||
@@ -162,7 +180,12 @@ class DecoratedMethod(Generic[P, R]):
|
|||||||
"""
|
"""
|
||||||
if obj is None:
|
if obj is None:
|
||||||
return self
|
return self
|
||||||
bound = partial(self._meth, obj)
|
inner = partial(self._meth, obj)
|
||||||
|
|
||||||
|
def _bound(*args: Any, **kwargs: Any) -> R:
|
||||||
|
result: R = _resolve_result(inner(*args, **kwargs)) # type: ignore[call-arg]
|
||||||
|
return result
|
||||||
|
|
||||||
for attr in (
|
for attr in (
|
||||||
"is_agent",
|
"is_agent",
|
||||||
"is_llm",
|
"is_llm",
|
||||||
@@ -174,8 +197,8 @@ class DecoratedMethod(Generic[P, R]):
|
|||||||
"is_crew",
|
"is_crew",
|
||||||
):
|
):
|
||||||
if hasattr(self, attr):
|
if hasattr(self, attr):
|
||||||
setattr(bound, attr, getattr(self, attr))
|
setattr(_bound, attr, getattr(self, attr))
|
||||||
return bound
|
return _bound
|
||||||
|
|
||||||
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
|
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||||
"""Call the wrapped method.
|
"""Call the wrapped method.
|
||||||
@@ -236,6 +259,7 @@ class BoundTaskMethod(Generic[TaskResultT]):
|
|||||||
The task result with name ensured.
|
The task result with name ensured.
|
||||||
"""
|
"""
|
||||||
result = self._task_method.unwrap()(self._obj, *args, **kwargs)
|
result = self._task_method.unwrap()(self._obj, *args, **kwargs)
|
||||||
|
result = _resolve_result(result)
|
||||||
return self._task_method.ensure_task_name(result)
|
return self._task_method.ensure_task_name(result)
|
||||||
|
|
||||||
|
|
||||||
@@ -292,7 +316,9 @@ class TaskMethod(Generic[P, TaskResultT]):
|
|||||||
Returns:
|
Returns:
|
||||||
The task instance with name set if not already provided.
|
The task instance with name set if not already provided.
|
||||||
"""
|
"""
|
||||||
return self.ensure_task_name(self._meth(*args, **kwargs))
|
result = self._meth(*args, **kwargs)
|
||||||
|
result = _resolve_result(result)
|
||||||
|
return self.ensure_task_name(result)
|
||||||
|
|
||||||
def unwrap(self) -> Callable[P, TaskResultT]:
|
def unwrap(self) -> Callable[P, TaskResultT]:
|
||||||
"""Get the original unwrapped method.
|
"""Get the original unwrapped method.
|
||||||
|
|||||||
@@ -272,6 +272,99 @@ def another_simple_tool():
|
|||||||
return "Hi!"
|
return "Hi!"
|
||||||
|
|
||||||
|
|
||||||
|
class TestAsyncDecoratorSupport:
|
||||||
|
"""Tests for async method support in @agent, @task decorators."""
|
||||||
|
|
||||||
|
def test_async_agent_memoization(self):
|
||||||
|
"""Async agent methods should be properly memoized."""
|
||||||
|
|
||||||
|
class AsyncAgentCrew:
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
@agent
|
||||||
|
async def async_agent(self):
|
||||||
|
AsyncAgentCrew.call_count += 1
|
||||||
|
return Agent(
|
||||||
|
role="Async Agent", goal="Async Goal", backstory="Async Backstory"
|
||||||
|
)
|
||||||
|
|
||||||
|
crew = AsyncAgentCrew()
|
||||||
|
first_call = crew.async_agent()
|
||||||
|
second_call = crew.async_agent()
|
||||||
|
|
||||||
|
assert first_call is second_call, "Async agent memoization failed"
|
||||||
|
assert AsyncAgentCrew.call_count == 1, "Async agent called more than once"
|
||||||
|
|
||||||
|
def test_async_task_memoization(self):
|
||||||
|
"""Async task methods should be properly memoized."""
|
||||||
|
|
||||||
|
class AsyncTaskCrew:
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
@task
|
||||||
|
async def async_task(self):
|
||||||
|
AsyncTaskCrew.call_count += 1
|
||||||
|
return Task(
|
||||||
|
description="Async Description", expected_output="Async Output"
|
||||||
|
)
|
||||||
|
|
||||||
|
crew = AsyncTaskCrew()
|
||||||
|
first_call = crew.async_task()
|
||||||
|
second_call = crew.async_task()
|
||||||
|
|
||||||
|
assert first_call is second_call, "Async task memoization failed"
|
||||||
|
assert AsyncTaskCrew.call_count == 1, "Async task called more than once"
|
||||||
|
|
||||||
|
def test_async_task_name_inference(self):
|
||||||
|
"""Async task should have name inferred from method name."""
|
||||||
|
|
||||||
|
class AsyncTaskNameCrew:
|
||||||
|
@task
|
||||||
|
async def my_async_task(self):
|
||||||
|
return Task(
|
||||||
|
description="Async Description", expected_output="Async Output"
|
||||||
|
)
|
||||||
|
|
||||||
|
crew = AsyncTaskNameCrew()
|
||||||
|
task_instance = crew.my_async_task()
|
||||||
|
|
||||||
|
assert task_instance.name == "my_async_task", (
|
||||||
|
"Async task name not inferred correctly"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_async_agent_returns_agent_not_coroutine(self):
|
||||||
|
"""Async agent decorator should return Agent, not coroutine."""
|
||||||
|
|
||||||
|
class AsyncAgentTypeCrew:
|
||||||
|
@agent
|
||||||
|
async def typed_async_agent(self):
|
||||||
|
return Agent(
|
||||||
|
role="Typed Agent", goal="Typed Goal", backstory="Typed Backstory"
|
||||||
|
)
|
||||||
|
|
||||||
|
crew = AsyncAgentTypeCrew()
|
||||||
|
result = crew.typed_async_agent()
|
||||||
|
|
||||||
|
assert isinstance(result, Agent), (
|
||||||
|
f"Expected Agent, got {type(result).__name__}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_async_task_returns_task_not_coroutine(self):
|
||||||
|
"""Async task decorator should return Task, not coroutine."""
|
||||||
|
|
||||||
|
class AsyncTaskTypeCrew:
|
||||||
|
@task
|
||||||
|
async def typed_async_task(self):
|
||||||
|
return Task(
|
||||||
|
description="Typed Description", expected_output="Typed Output"
|
||||||
|
)
|
||||||
|
|
||||||
|
crew = AsyncTaskTypeCrew()
|
||||||
|
result = crew.typed_async_task()
|
||||||
|
|
||||||
|
assert isinstance(result, Task), f"Expected Task, got {type(result).__name__}"
|
||||||
|
|
||||||
|
|
||||||
def test_internal_crew_with_mcp():
|
def test_internal_crew_with_mcp():
|
||||||
from crewai_tools.adapters.tool_collection import ToolCollection
|
from crewai_tools.adapters.tool_collection import ToolCollection
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user