diff --git a/lib/crewai/src/crewai/project/annotations.py b/lib/crewai/src/crewai/project/annotations.py index a36999052..160359540 100644 --- a/lib/crewai/src/crewai/project/annotations.py +++ b/lib/crewai/src/crewai/project/annotations.py @@ -2,8 +2,10 @@ from __future__ import annotations +import asyncio from collections.abc import Callable from functools import wraps +import inspect from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, TypeVar, overload from crewai.project.utils import memoize @@ -156,6 +158,23 @@ def cache_handler(meth: Callable[P, R]) -> CacheHandlerMethod[P, R]: return CacheHandlerMethod(memoize(meth)) +def _call_method(method: Callable[..., Any], *args: Any, **kwargs: Any) -> Any: + """Call a method, awaiting it if async and running in an event loop.""" + result = method(*args, **kwargs) + if inspect.iscoroutine(result): + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + if loop and loop.is_running(): + import concurrent.futures + + with concurrent.futures.ThreadPoolExecutor() as pool: + return pool.submit(asyncio.run, result).result() + return asyncio.run(result) + return result + + @overload def crew( meth: Callable[Concatenate[SelfT, P], Crew], @@ -198,7 +217,7 @@ def crew( # Instantiate tasks in order for _, task_method in tasks: - task_instance = task_method(self) + task_instance = _call_method(task_method, self) instantiated_tasks.append(task_instance) agent_instance = getattr(task_instance, "agent", None) if agent_instance and agent_instance.role not in agent_roles: @@ -207,7 +226,7 @@ def crew( # Instantiate agents not included by tasks for _, agent_method in agents: - agent_instance = agent_method(self) + agent_instance = _call_method(agent_method, self) if agent_instance.role not in agent_roles: instantiated_agents.append(agent_instance) agent_roles.add(agent_instance.role) @@ -215,7 +234,7 @@ def crew( self.agents = instantiated_agents self.tasks = instantiated_tasks - crew_instance = meth(self, *args, **kwargs) + crew_instance: Crew = _call_method(meth, self, *args, **kwargs) def callback_wrapper( hook: Callable[Concatenate[CrewInstance, P2], R2], instance: CrewInstance diff --git a/lib/crewai/src/crewai/project/utils.py b/lib/crewai/src/crewai/project/utils.py index eae363b0d..b46a4dc44 100644 --- a/lib/crewai/src/crewai/project/utils.py +++ b/lib/crewai/src/crewai/project/utils.py @@ -1,7 +1,8 @@ """Utility functions for the crewai project module.""" -from collections.abc import Callable +from collections.abc import Callable, Coroutine from functools import wraps +import inspect from typing import Any, ParamSpec, TypeVar, cast from pydantic import BaseModel @@ -37,8 +38,8 @@ def _make_hashable(arg: Any) -> Any: def memoize(meth: Callable[P, R]) -> Callable[P, R]: """Memoize a method by caching its results based on arguments. - Handles Pydantic BaseModel instances by converting them to JSON strings - before hashing for cache lookup. + Handles both sync and async methods. Pydantic BaseModel instances are + converted to JSON strings before hashing for cache lookup. Args: meth: The method to memoize. @@ -46,18 +47,16 @@ def memoize(meth: Callable[P, R]) -> Callable[P, R]: Returns: A memoized version of the method that caches results. """ + if inspect.iscoroutinefunction(meth): + return cast(Callable[P, R], _memoize_async(meth)) + return _memoize_sync(meth) + + +def _memoize_sync(meth: Callable[P, R]) -> Callable[P, R]: + """Memoize a synchronous method.""" @wraps(meth) def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: - """Wrapper that converts arguments to hashable form before caching. - - Args: - *args: Positional arguments to the memoized method. - **kwargs: Keyword arguments to the memoized method. - - Returns: - The result of the memoized method call. - """ hashable_args = tuple(_make_hashable(arg) for arg in args) hashable_kwargs = tuple( sorted((k, _make_hashable(v)) for k, v in kwargs.items()) @@ -73,3 +72,27 @@ def memoize(meth: Callable[P, R]) -> Callable[P, R]: return result return cast(Callable[P, R], wrapper) + + +def _memoize_async( + meth: Callable[P, Coroutine[Any, Any, R]], +) -> Callable[P, Coroutine[Any, Any, R]]: + """Memoize an async method.""" + + @wraps(meth) + async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + hashable_args = tuple(_make_hashable(arg) for arg in args) + hashable_kwargs = tuple( + sorted((k, _make_hashable(v)) for k, v in kwargs.items()) + ) + cache_key = str((hashable_args, hashable_kwargs)) + + cached_result: R | None = cache.read(tool=meth.__name__, input=cache_key) + if cached_result is not None: + return cached_result + + result = await meth(*args, **kwargs) + cache.add(tool=meth.__name__, input=cache_key, output=result) + return result + + return wrapper diff --git a/lib/crewai/src/crewai/project/wrappers.py b/lib/crewai/src/crewai/project/wrappers.py index bfe28aa22..28cd39525 100644 --- a/lib/crewai/src/crewai/project/wrappers.py +++ b/lib/crewai/src/crewai/project/wrappers.py @@ -2,8 +2,10 @@ from __future__ import annotations +import asyncio from collections.abc import Callable from functools import partial +import inspect from pathlib import Path from typing import ( TYPE_CHECKING, @@ -132,6 +134,22 @@ class CrewClass(Protocol): crew: Callable[..., Crew] +def _resolve_result(result: Any) -> Any: + """Resolve a potentially async result to its value.""" + if inspect.iscoroutine(result): + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + if loop and loop.is_running(): + import concurrent.futures + + with concurrent.futures.ThreadPoolExecutor() as pool: + return pool.submit(asyncio.run, result).result() + return asyncio.run(result) + return result + + class DecoratedMethod(Generic[P, R]): """Base wrapper for methods with decorator metadata. @@ -162,7 +180,12 @@ class DecoratedMethod(Generic[P, R]): """ if obj is None: return self - bound = partial(self._meth, obj) + inner = partial(self._meth, obj) + + def _bound(*args: Any, **kwargs: Any) -> R: + result: R = _resolve_result(inner(*args, **kwargs)) # type: ignore[call-arg] + return result + for attr in ( "is_agent", "is_llm", @@ -174,8 +197,8 @@ class DecoratedMethod(Generic[P, R]): "is_crew", ): if hasattr(self, attr): - setattr(bound, attr, getattr(self, attr)) - return bound + setattr(_bound, attr, getattr(self, attr)) + return _bound def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: """Call the wrapped method. @@ -236,6 +259,7 @@ class BoundTaskMethod(Generic[TaskResultT]): The task result with name ensured. """ result = self._task_method.unwrap()(self._obj, *args, **kwargs) + result = _resolve_result(result) return self._task_method.ensure_task_name(result) @@ -292,7 +316,9 @@ class TaskMethod(Generic[P, TaskResultT]): Returns: The task instance with name set if not already provided. """ - return self.ensure_task_name(self._meth(*args, **kwargs)) + result = self._meth(*args, **kwargs) + result = _resolve_result(result) + return self.ensure_task_name(result) def unwrap(self) -> Callable[P, TaskResultT]: """Get the original unwrapped method. diff --git a/lib/crewai/tests/test_project.py b/lib/crewai/tests/test_project.py index ebc3dfb82..33cf228f7 100644 --- a/lib/crewai/tests/test_project.py +++ b/lib/crewai/tests/test_project.py @@ -272,6 +272,99 @@ def another_simple_tool(): return "Hi!" +class TestAsyncDecoratorSupport: + """Tests for async method support in @agent, @task decorators.""" + + def test_async_agent_memoization(self): + """Async agent methods should be properly memoized.""" + + class AsyncAgentCrew: + call_count = 0 + + @agent + async def async_agent(self): + AsyncAgentCrew.call_count += 1 + return Agent( + role="Async Agent", goal="Async Goal", backstory="Async Backstory" + ) + + crew = AsyncAgentCrew() + first_call = crew.async_agent() + second_call = crew.async_agent() + + assert first_call is second_call, "Async agent memoization failed" + assert AsyncAgentCrew.call_count == 1, "Async agent called more than once" + + def test_async_task_memoization(self): + """Async task methods should be properly memoized.""" + + class AsyncTaskCrew: + call_count = 0 + + @task + async def async_task(self): + AsyncTaskCrew.call_count += 1 + return Task( + description="Async Description", expected_output="Async Output" + ) + + crew = AsyncTaskCrew() + first_call = crew.async_task() + second_call = crew.async_task() + + assert first_call is second_call, "Async task memoization failed" + assert AsyncTaskCrew.call_count == 1, "Async task called more than once" + + def test_async_task_name_inference(self): + """Async task should have name inferred from method name.""" + + class AsyncTaskNameCrew: + @task + async def my_async_task(self): + return Task( + description="Async Description", expected_output="Async Output" + ) + + crew = AsyncTaskNameCrew() + task_instance = crew.my_async_task() + + assert task_instance.name == "my_async_task", ( + "Async task name not inferred correctly" + ) + + def test_async_agent_returns_agent_not_coroutine(self): + """Async agent decorator should return Agent, not coroutine.""" + + class AsyncAgentTypeCrew: + @agent + async def typed_async_agent(self): + return Agent( + role="Typed Agent", goal="Typed Goal", backstory="Typed Backstory" + ) + + crew = AsyncAgentTypeCrew() + result = crew.typed_async_agent() + + assert isinstance(result, Agent), ( + f"Expected Agent, got {type(result).__name__}" + ) + + def test_async_task_returns_task_not_coroutine(self): + """Async task decorator should return Task, not coroutine.""" + + class AsyncTaskTypeCrew: + @task + async def typed_async_task(self): + return Task( + description="Typed Description", expected_output="Typed Output" + ) + + crew = AsyncTaskTypeCrew() + result = crew.typed_async_task() + + assert isinstance(result, Task), f"Expected Task, got {type(result).__name__}" + + def test_internal_crew_with_mcp(): from crewai_tools.adapters.tool_collection import ToolCollection