fix: ensure async methods are executable for annotations

This commit is contained in:
Greyson LaLonde
2025-11-28 19:54:40 -05:00
committed by GitHub
parent 4d8eec96e8
commit c59173a762
4 changed files with 180 additions and 19 deletions

View File

@@ -2,8 +2,10 @@
from __future__ import annotations
import asyncio
from collections.abc import Callable
from functools import wraps
import inspect
from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, TypeVar, overload
from crewai.project.utils import memoize
@@ -156,6 +158,23 @@ def cache_handler(meth: Callable[P, R]) -> CacheHandlerMethod[P, R]:
return CacheHandlerMethod(memoize(meth))
def _call_method(method: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
"""Call a method, awaiting it if async and running in an event loop."""
result = method(*args, **kwargs)
if inspect.iscoroutine(result):
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = None
if loop and loop.is_running():
import concurrent.futures
with concurrent.futures.ThreadPoolExecutor() as pool:
return pool.submit(asyncio.run, result).result()
return asyncio.run(result)
return result
@overload
def crew(
meth: Callable[Concatenate[SelfT, P], Crew],
@@ -198,7 +217,7 @@ def crew(
# Instantiate tasks in order
for _, task_method in tasks:
task_instance = task_method(self)
task_instance = _call_method(task_method, self)
instantiated_tasks.append(task_instance)
agent_instance = getattr(task_instance, "agent", None)
if agent_instance and agent_instance.role not in agent_roles:
@@ -207,7 +226,7 @@ def crew(
# Instantiate agents not included by tasks
for _, agent_method in agents:
agent_instance = agent_method(self)
agent_instance = _call_method(agent_method, self)
if agent_instance.role not in agent_roles:
instantiated_agents.append(agent_instance)
agent_roles.add(agent_instance.role)
@@ -215,7 +234,7 @@ def crew(
self.agents = instantiated_agents
self.tasks = instantiated_tasks
crew_instance = meth(self, *args, **kwargs)
crew_instance: Crew = _call_method(meth, self, *args, **kwargs)
def callback_wrapper(
hook: Callable[Concatenate[CrewInstance, P2], R2], instance: CrewInstance

View File

@@ -1,7 +1,8 @@
"""Utility functions for the crewai project module."""
from collections.abc import Callable
from collections.abc import Callable, Coroutine
from functools import wraps
import inspect
from typing import Any, ParamSpec, TypeVar, cast
from pydantic import BaseModel
@@ -37,8 +38,8 @@ def _make_hashable(arg: Any) -> Any:
def memoize(meth: Callable[P, R]) -> Callable[P, R]:
"""Memoize a method by caching its results based on arguments.
Handles Pydantic BaseModel instances by converting them to JSON strings
before hashing for cache lookup.
Handles both sync and async methods. Pydantic BaseModel instances are
converted to JSON strings before hashing for cache lookup.
Args:
meth: The method to memoize.
@@ -46,18 +47,16 @@ def memoize(meth: Callable[P, R]) -> Callable[P, R]:
Returns:
A memoized version of the method that caches results.
"""
if inspect.iscoroutinefunction(meth):
return cast(Callable[P, R], _memoize_async(meth))
return _memoize_sync(meth)
def _memoize_sync(meth: Callable[P, R]) -> Callable[P, R]:
"""Memoize a synchronous method."""
@wraps(meth)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
"""Wrapper that converts arguments to hashable form before caching.
Args:
*args: Positional arguments to the memoized method.
**kwargs: Keyword arguments to the memoized method.
Returns:
The result of the memoized method call.
"""
hashable_args = tuple(_make_hashable(arg) for arg in args)
hashable_kwargs = tuple(
sorted((k, _make_hashable(v)) for k, v in kwargs.items())
@@ -73,3 +72,27 @@ def memoize(meth: Callable[P, R]) -> Callable[P, R]:
return result
return cast(Callable[P, R], wrapper)
def _memoize_async(
meth: Callable[P, Coroutine[Any, Any, R]],
) -> Callable[P, Coroutine[Any, Any, R]]:
"""Memoize an async method."""
@wraps(meth)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
hashable_args = tuple(_make_hashable(arg) for arg in args)
hashable_kwargs = tuple(
sorted((k, _make_hashable(v)) for k, v in kwargs.items())
)
cache_key = str((hashable_args, hashable_kwargs))
cached_result: R | None = cache.read(tool=meth.__name__, input=cache_key)
if cached_result is not None:
return cached_result
result = await meth(*args, **kwargs)
cache.add(tool=meth.__name__, input=cache_key, output=result)
return result
return wrapper

View File

@@ -2,8 +2,10 @@
from __future__ import annotations
import asyncio
from collections.abc import Callable
from functools import partial
import inspect
from pathlib import Path
from typing import (
TYPE_CHECKING,
@@ -132,6 +134,22 @@ class CrewClass(Protocol):
crew: Callable[..., Crew]
def _resolve_result(result: Any) -> Any:
"""Resolve a potentially async result to its value."""
if inspect.iscoroutine(result):
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = None
if loop and loop.is_running():
import concurrent.futures
with concurrent.futures.ThreadPoolExecutor() as pool:
return pool.submit(asyncio.run, result).result()
return asyncio.run(result)
return result
class DecoratedMethod(Generic[P, R]):
"""Base wrapper for methods with decorator metadata.
@@ -162,7 +180,12 @@ class DecoratedMethod(Generic[P, R]):
"""
if obj is None:
return self
bound = partial(self._meth, obj)
inner = partial(self._meth, obj)
def _bound(*args: Any, **kwargs: Any) -> R:
result: R = _resolve_result(inner(*args, **kwargs)) # type: ignore[call-arg]
return result
for attr in (
"is_agent",
"is_llm",
@@ -174,8 +197,8 @@ class DecoratedMethod(Generic[P, R]):
"is_crew",
):
if hasattr(self, attr):
setattr(bound, attr, getattr(self, attr))
return bound
setattr(_bound, attr, getattr(self, attr))
return _bound
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
"""Call the wrapped method.
@@ -236,6 +259,7 @@ class BoundTaskMethod(Generic[TaskResultT]):
The task result with name ensured.
"""
result = self._task_method.unwrap()(self._obj, *args, **kwargs)
result = _resolve_result(result)
return self._task_method.ensure_task_name(result)
@@ -292,7 +316,9 @@ class TaskMethod(Generic[P, TaskResultT]):
Returns:
The task instance with name set if not already provided.
"""
return self.ensure_task_name(self._meth(*args, **kwargs))
result = self._meth(*args, **kwargs)
result = _resolve_result(result)
return self.ensure_task_name(result)
def unwrap(self) -> Callable[P, TaskResultT]:
"""Get the original unwrapped method.