chore: refactor decorators and add output class wrappers

Refactored task, agent, llm, tool, callback, and cache_handler decorators to directly memoize the method before wrapping, simplifying their implementation. Introduced OutputJsonClass and OutputPydanticClass wrappers to mark classes for JSON and Pydantic output formats, replacing the previous attribute-based approach.
This commit is contained in:
Greyson Lalonde
2025-10-13 13:01:14 -04:00
parent 7d6324dfa3
commit 620df71763
2 changed files with 103 additions and 35 deletions

View File

@@ -13,6 +13,8 @@ from crewai.project.wrappers import (
CacheHandlerMethod,
CallbackMethod,
LLMMethod,
OutputJsonClass,
OutputPydanticClass,
TaskMethod,
TaskResultT,
ToolMethod,
@@ -20,6 +22,7 @@ from crewai.project.wrappers import (
P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T")
def before_kickoff(meth: Callable[P, R]) -> BeforeKickoffMethod[P, R]:
@@ -55,9 +58,7 @@ def task(meth: Callable[P, TaskResultT]) -> TaskMethod[P, TaskResultT]:
Returns:
A wrapped method marked as a task with memoization.
"""
wrapped = TaskMethod(meth)
wrapped._wrapped = memoize(wrapped._meth)
return wrapped
return TaskMethod(memoize(meth))
def agent(meth: Callable[P, R]) -> AgentMethod[P, R]:
@@ -69,9 +70,7 @@ def agent(meth: Callable[P, R]) -> AgentMethod[P, R]:
Returns:
A wrapped method marked as an agent with memoization.
"""
wrapped = AgentMethod(meth)
wrapped._wrapped = memoize(wrapped._meth)
return wrapped
return AgentMethod(memoize(meth))
def llm(meth: Callable[P, R]) -> LLMMethod[P, R]:
@@ -83,35 +82,31 @@ def llm(meth: Callable[P, R]) -> LLMMethod[P, R]:
Returns:
A wrapped method marked as an LLM provider with memoization.
"""
wrapped = LLMMethod(meth)
wrapped._wrapped = memoize(wrapped._meth)
return wrapped
return LLMMethod(memoize(meth))
def output_json(cls):
def output_json(cls: type[T]) -> OutputJsonClass[T]:
"""Marks a class as JSON output format.
Args:
cls: The class to mark.
Returns:
The class with is_output_json attribute set.
A wrapped class marked as JSON output format.
"""
cls.is_output_json = True
return cls
return OutputJsonClass(cls)
def output_pydantic(cls):
def output_pydantic(cls: type[T]) -> OutputPydanticClass[T]:
"""Marks a class as Pydantic output format.
Args:
cls: The class to mark.
Returns:
The class with is_output_pydantic attribute set.
A wrapped class marked as Pydantic output format.
"""
cls.is_output_pydantic = True
return cls
return OutputPydanticClass(cls)
def tool(meth: Callable[P, R]) -> ToolMethod[P, R]:
@@ -123,9 +118,7 @@ def tool(meth: Callable[P, R]) -> ToolMethod[P, R]:
Returns:
A wrapped method marked as a tool with memoization.
"""
wrapped = ToolMethod(meth)
wrapped._wrapped = memoize(wrapped._meth)
return wrapped
return ToolMethod(memoize(meth))
def callback(meth: Callable[P, R]) -> CallbackMethod[P, R]:
@@ -137,9 +130,7 @@ def callback(meth: Callable[P, R]) -> CallbackMethod[P, R]:
Returns:
A wrapped method marked as a callback with memoization.
"""
wrapped = CallbackMethod(meth)
wrapped._wrapped = memoize(wrapped._meth)
return wrapped
return CallbackMethod(memoize(meth))
def cache_handler(meth: Callable[P, R]) -> CacheHandlerMethod[P, R]:
@@ -151,9 +142,7 @@ def cache_handler(meth: Callable[P, R]) -> CacheHandlerMethod[P, R]:
Returns:
A wrapped method marked as a cache handler with memoization.
"""
wrapped = CacheHandlerMethod(meth)
wrapped._wrapped = memoize(wrapped._meth)
return wrapped
return CacheHandlerMethod(memoize(meth))
def crew(meth) -> Callable[..., Crew]:

View File

@@ -39,7 +39,6 @@ class DecoratedMethod(Generic[P, R]):
meth: The method to wrap.
"""
self._meth = meth
self._wrapped: Callable[P, R] | None = None
# Preserve function metadata
wraps(meth)(self)
@@ -58,8 +57,6 @@ class DecoratedMethod(Generic[P, R]):
Returns:
The result of calling the wrapped method.
"""
if self._wrapped:
return self._wrapped(*args, **kwargs)
return self._meth(*args, **kwargs)
def unwrap(self) -> Callable[P, R]:
@@ -95,7 +92,6 @@ class TaskMethod(Generic[P, TaskResultT]):
meth: The method to wrap.
"""
self._meth = meth
self._wrapped: Callable[P, TaskResultT] | None = None
# Preserve function metadata
wraps(meth)(self)
@@ -114,11 +110,7 @@ class TaskMethod(Generic[P, TaskResultT]):
Returns:
The task instance with name set if not already provided.
"""
if self._wrapped:
result = self._wrapped(*args, **kwargs)
else:
result = self._meth(*args, **kwargs)
result = self._meth(*args, **kwargs)
if not result.name:
result.name = self._meth.__name__
return result
@@ -166,3 +158,90 @@ class CrewMethod(DecoratedMethod[P, R]):
"""Wrapper for methods marked as the main crew execution point."""
is_crew: bool = True
T = TypeVar("T")
class OutputJsonClass(Generic[T]):
"""Wrapper for classes marked as JSON output format."""
is_output_json: bool = True
def __init__(self, cls: type[T]) -> None:
"""Initialize the output JSON class wrapper.
Args:
cls: The class to wrap.
"""
self._cls = cls
# Copy class attributes
self.__name__ = cls.__name__
self.__qualname__ = cls.__qualname__
self.__module__ = cls.__module__
self.__doc__ = cls.__doc__
def __call__(self, *args, **kwargs) -> T:
"""Create an instance of the wrapped class.
Args:
*args: Positional arguments for the class constructor.
**kwargs: Keyword arguments for the class constructor.
Returns:
An instance of the wrapped class.
"""
return self._cls(*args, **kwargs)
def __getattr__(self, name: str):
"""Delegate attribute access to the wrapped class.
Args:
name: The attribute name.
Returns:
The attribute from the wrapped class.
"""
return getattr(self._cls, name)
class OutputPydanticClass(Generic[T]):
"""Wrapper for classes marked as Pydantic output format."""
is_output_pydantic: bool = True
def __init__(self, cls: type[T]) -> None:
"""Initialize the output Pydantic class wrapper.
Args:
cls: The class to wrap.
"""
self._cls = cls
# Copy class attributes
self.__name__ = cls.__name__
self.__qualname__ = cls.__qualname__
self.__module__ = cls.__module__
self.__doc__ = cls.__doc__
def __call__(self, *args, **kwargs) -> T:
"""Create an instance of the wrapped class.
Args:
*args: Positional arguments for the class constructor.
**kwargs: Keyword arguments for the class constructor.
Returns:
An instance of the wrapped class.
"""
return self._cls(*args, **kwargs)
def __getattr__(self, name: str):
"""Delegate attribute access to the wrapped class.
Args:
name: The attribute name.
Returns:
The attribute from the wrapped class.
"""
return getattr(self._cls, name)