From 620df7176388905bfc722d148995e5f029cbfcad Mon Sep 17 00:00:00 2001 From: Greyson Lalonde Date: Mon, 13 Oct 2025 13:01:14 -0400 Subject: [PATCH] chore: refactor decorators and add output class wrappers Refactored task, agent, llm, tool, callback, and cache_handler decorators to directly memoize the method before wrapping, simplifying their implementation. Introduced OutputJsonClass and OutputPydanticClass wrappers to mark classes for JSON and Pydantic output formats, replacing the previous attribute-based approach. --- src/crewai/project/annotations.py | 41 +++++-------- src/crewai/project/wrappers.py | 97 ++++++++++++++++++++++++++++--- 2 files changed, 103 insertions(+), 35 deletions(-) diff --git a/src/crewai/project/annotations.py b/src/crewai/project/annotations.py index f3205ffa6..564223edf 100644 --- a/src/crewai/project/annotations.py +++ b/src/crewai/project/annotations.py @@ -13,6 +13,8 @@ from crewai.project.wrappers import ( CacheHandlerMethod, CallbackMethod, LLMMethod, + OutputJsonClass, + OutputPydanticClass, TaskMethod, TaskResultT, ToolMethod, @@ -20,6 +22,7 @@ from crewai.project.wrappers import ( P = ParamSpec("P") R = TypeVar("R") +T = TypeVar("T") def before_kickoff(meth: Callable[P, R]) -> BeforeKickoffMethod[P, R]: @@ -55,9 +58,7 @@ def task(meth: Callable[P, TaskResultT]) -> TaskMethod[P, TaskResultT]: Returns: A wrapped method marked as a task with memoization. """ - wrapped = TaskMethod(meth) - wrapped._wrapped = memoize(wrapped._meth) - return wrapped + return TaskMethod(memoize(meth)) def agent(meth: Callable[P, R]) -> AgentMethod[P, R]: @@ -69,9 +70,7 @@ def agent(meth: Callable[P, R]) -> AgentMethod[P, R]: Returns: A wrapped method marked as an agent with memoization. """ - wrapped = AgentMethod(meth) - wrapped._wrapped = memoize(wrapped._meth) - return wrapped + return AgentMethod(memoize(meth)) def llm(meth: Callable[P, R]) -> LLMMethod[P, R]: @@ -83,35 +82,31 @@ def llm(meth: Callable[P, R]) -> LLMMethod[P, R]: Returns: A wrapped method marked as an LLM provider with memoization. """ - wrapped = LLMMethod(meth) - wrapped._wrapped = memoize(wrapped._meth) - return wrapped + return LLMMethod(memoize(meth)) -def output_json(cls): +def output_json(cls: type[T]) -> OutputJsonClass[T]: """Marks a class as JSON output format. Args: cls: The class to mark. Returns: - The class with is_output_json attribute set. + A wrapped class marked as JSON output format. """ - cls.is_output_json = True - return cls + return OutputJsonClass(cls) -def output_pydantic(cls): +def output_pydantic(cls: type[T]) -> OutputPydanticClass[T]: """Marks a class as Pydantic output format. Args: cls: The class to mark. Returns: - The class with is_output_pydantic attribute set. + A wrapped class marked as Pydantic output format. """ - cls.is_output_pydantic = True - return cls + return OutputPydanticClass(cls) def tool(meth: Callable[P, R]) -> ToolMethod[P, R]: @@ -123,9 +118,7 @@ def tool(meth: Callable[P, R]) -> ToolMethod[P, R]: Returns: A wrapped method marked as a tool with memoization. """ - wrapped = ToolMethod(meth) - wrapped._wrapped = memoize(wrapped._meth) - return wrapped + return ToolMethod(memoize(meth)) def callback(meth: Callable[P, R]) -> CallbackMethod[P, R]: @@ -137,9 +130,7 @@ def callback(meth: Callable[P, R]) -> CallbackMethod[P, R]: Returns: A wrapped method marked as a callback with memoization. """ - wrapped = CallbackMethod(meth) - wrapped._wrapped = memoize(wrapped._meth) - return wrapped + return CallbackMethod(memoize(meth)) def cache_handler(meth: Callable[P, R]) -> CacheHandlerMethod[P, R]: @@ -151,9 +142,7 @@ def cache_handler(meth: Callable[P, R]) -> CacheHandlerMethod[P, R]: Returns: A wrapped method marked as a cache handler with memoization. """ - wrapped = CacheHandlerMethod(meth) - wrapped._wrapped = memoize(wrapped._meth) - return wrapped + return CacheHandlerMethod(memoize(meth)) def crew(meth) -> Callable[..., Crew]: diff --git a/src/crewai/project/wrappers.py b/src/crewai/project/wrappers.py index a7f007b96..6f7c24b54 100644 --- a/src/crewai/project/wrappers.py +++ b/src/crewai/project/wrappers.py @@ -39,7 +39,6 @@ class DecoratedMethod(Generic[P, R]): meth: The method to wrap. """ self._meth = meth - self._wrapped: Callable[P, R] | None = None # Preserve function metadata wraps(meth)(self) @@ -58,8 +57,6 @@ class DecoratedMethod(Generic[P, R]): Returns: The result of calling the wrapped method. """ - if self._wrapped: - return self._wrapped(*args, **kwargs) return self._meth(*args, **kwargs) def unwrap(self) -> Callable[P, R]: @@ -95,7 +92,6 @@ class TaskMethod(Generic[P, TaskResultT]): meth: The method to wrap. """ self._meth = meth - self._wrapped: Callable[P, TaskResultT] | None = None # Preserve function metadata wraps(meth)(self) @@ -114,11 +110,7 @@ class TaskMethod(Generic[P, TaskResultT]): Returns: The task instance with name set if not already provided. """ - if self._wrapped: - result = self._wrapped(*args, **kwargs) - else: - result = self._meth(*args, **kwargs) - + result = self._meth(*args, **kwargs) if not result.name: result.name = self._meth.__name__ return result @@ -166,3 +158,90 @@ class CrewMethod(DecoratedMethod[P, R]): """Wrapper for methods marked as the main crew execution point.""" is_crew: bool = True + + +T = TypeVar("T") + + +class OutputJsonClass(Generic[T]): + """Wrapper for classes marked as JSON output format.""" + + is_output_json: bool = True + + def __init__(self, cls: type[T]) -> None: + """Initialize the output JSON class wrapper. + + Args: + cls: The class to wrap. + """ + self._cls = cls + # Copy class attributes + self.__name__ = cls.__name__ + self.__qualname__ = cls.__qualname__ + self.__module__ = cls.__module__ + self.__doc__ = cls.__doc__ + + def __call__(self, *args, **kwargs) -> T: + """Create an instance of the wrapped class. + + Args: + *args: Positional arguments for the class constructor. + **kwargs: Keyword arguments for the class constructor. + + Returns: + An instance of the wrapped class. + """ + return self._cls(*args, **kwargs) + + def __getattr__(self, name: str): + """Delegate attribute access to the wrapped class. + + Args: + name: The attribute name. + + Returns: + The attribute from the wrapped class. + """ + return getattr(self._cls, name) + + +class OutputPydanticClass(Generic[T]): + """Wrapper for classes marked as Pydantic output format.""" + + is_output_pydantic: bool = True + + def __init__(self, cls: type[T]) -> None: + """Initialize the output Pydantic class wrapper. + + Args: + cls: The class to wrap. + """ + self._cls = cls + # Copy class attributes + self.__name__ = cls.__name__ + self.__qualname__ = cls.__qualname__ + self.__module__ = cls.__module__ + self.__doc__ = cls.__doc__ + + def __call__(self, *args, **kwargs) -> T: + """Create an instance of the wrapped class. + + Args: + *args: Positional arguments for the class constructor. + **kwargs: Keyword arguments for the class constructor. + + Returns: + An instance of the wrapped class. + """ + return self._cls(*args, **kwargs) + + def __getattr__(self, name: str): + """Delegate attribute access to the wrapped class. + + Args: + name: The attribute name. + + Returns: + The attribute from the wrapped class. + """ + return getattr(self._cls, name)