mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 08:08:32 +00:00
chore: refactor decorators and add output class wrappers
Refactored task, agent, llm, tool, callback, and cache_handler decorators to directly memoize the method before wrapping, simplifying their implementation. Introduced OutputJsonClass and OutputPydanticClass wrappers to mark classes for JSON and Pydantic output formats, replacing the previous attribute-based approach.
This commit is contained in:
@@ -13,6 +13,8 @@ from crewai.project.wrappers import (
|
||||
CacheHandlerMethod,
|
||||
CallbackMethod,
|
||||
LLMMethod,
|
||||
OutputJsonClass,
|
||||
OutputPydanticClass,
|
||||
TaskMethod,
|
||||
TaskResultT,
|
||||
ToolMethod,
|
||||
@@ -20,6 +22,7 @@ from crewai.project.wrappers import (
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def before_kickoff(meth: Callable[P, R]) -> BeforeKickoffMethod[P, R]:
|
||||
@@ -55,9 +58,7 @@ def task(meth: Callable[P, TaskResultT]) -> TaskMethod[P, TaskResultT]:
|
||||
Returns:
|
||||
A wrapped method marked as a task with memoization.
|
||||
"""
|
||||
wrapped = TaskMethod(meth)
|
||||
wrapped._wrapped = memoize(wrapped._meth)
|
||||
return wrapped
|
||||
return TaskMethod(memoize(meth))
|
||||
|
||||
|
||||
def agent(meth: Callable[P, R]) -> AgentMethod[P, R]:
|
||||
@@ -69,9 +70,7 @@ def agent(meth: Callable[P, R]) -> AgentMethod[P, R]:
|
||||
Returns:
|
||||
A wrapped method marked as an agent with memoization.
|
||||
"""
|
||||
wrapped = AgentMethod(meth)
|
||||
wrapped._wrapped = memoize(wrapped._meth)
|
||||
return wrapped
|
||||
return AgentMethod(memoize(meth))
|
||||
|
||||
|
||||
def llm(meth: Callable[P, R]) -> LLMMethod[P, R]:
|
||||
@@ -83,35 +82,31 @@ def llm(meth: Callable[P, R]) -> LLMMethod[P, R]:
|
||||
Returns:
|
||||
A wrapped method marked as an LLM provider with memoization.
|
||||
"""
|
||||
wrapped = LLMMethod(meth)
|
||||
wrapped._wrapped = memoize(wrapped._meth)
|
||||
return wrapped
|
||||
return LLMMethod(memoize(meth))
|
||||
|
||||
|
||||
def output_json(cls):
|
||||
def output_json(cls: type[T]) -> OutputJsonClass[T]:
|
||||
"""Marks a class as JSON output format.
|
||||
|
||||
Args:
|
||||
cls: The class to mark.
|
||||
|
||||
Returns:
|
||||
The class with is_output_json attribute set.
|
||||
A wrapped class marked as JSON output format.
|
||||
"""
|
||||
cls.is_output_json = True
|
||||
return cls
|
||||
return OutputJsonClass(cls)
|
||||
|
||||
|
||||
def output_pydantic(cls):
|
||||
def output_pydantic(cls: type[T]) -> OutputPydanticClass[T]:
|
||||
"""Marks a class as Pydantic output format.
|
||||
|
||||
Args:
|
||||
cls: The class to mark.
|
||||
|
||||
Returns:
|
||||
The class with is_output_pydantic attribute set.
|
||||
A wrapped class marked as Pydantic output format.
|
||||
"""
|
||||
cls.is_output_pydantic = True
|
||||
return cls
|
||||
return OutputPydanticClass(cls)
|
||||
|
||||
|
||||
def tool(meth: Callable[P, R]) -> ToolMethod[P, R]:
|
||||
@@ -123,9 +118,7 @@ def tool(meth: Callable[P, R]) -> ToolMethod[P, R]:
|
||||
Returns:
|
||||
A wrapped method marked as a tool with memoization.
|
||||
"""
|
||||
wrapped = ToolMethod(meth)
|
||||
wrapped._wrapped = memoize(wrapped._meth)
|
||||
return wrapped
|
||||
return ToolMethod(memoize(meth))
|
||||
|
||||
|
||||
def callback(meth: Callable[P, R]) -> CallbackMethod[P, R]:
|
||||
@@ -137,9 +130,7 @@ def callback(meth: Callable[P, R]) -> CallbackMethod[P, R]:
|
||||
Returns:
|
||||
A wrapped method marked as a callback with memoization.
|
||||
"""
|
||||
wrapped = CallbackMethod(meth)
|
||||
wrapped._wrapped = memoize(wrapped._meth)
|
||||
return wrapped
|
||||
return CallbackMethod(memoize(meth))
|
||||
|
||||
|
||||
def cache_handler(meth: Callable[P, R]) -> CacheHandlerMethod[P, R]:
|
||||
@@ -151,9 +142,7 @@ def cache_handler(meth: Callable[P, R]) -> CacheHandlerMethod[P, R]:
|
||||
Returns:
|
||||
A wrapped method marked as a cache handler with memoization.
|
||||
"""
|
||||
wrapped = CacheHandlerMethod(meth)
|
||||
wrapped._wrapped = memoize(wrapped._meth)
|
||||
return wrapped
|
||||
return CacheHandlerMethod(memoize(meth))
|
||||
|
||||
|
||||
def crew(meth) -> Callable[..., Crew]:
|
||||
|
||||
@@ -39,7 +39,6 @@ class DecoratedMethod(Generic[P, R]):
|
||||
meth: The method to wrap.
|
||||
"""
|
||||
self._meth = meth
|
||||
self._wrapped: Callable[P, R] | None = None
|
||||
# Preserve function metadata
|
||||
wraps(meth)(self)
|
||||
|
||||
@@ -58,8 +57,6 @@ class DecoratedMethod(Generic[P, R]):
|
||||
Returns:
|
||||
The result of calling the wrapped method.
|
||||
"""
|
||||
if self._wrapped:
|
||||
return self._wrapped(*args, **kwargs)
|
||||
return self._meth(*args, **kwargs)
|
||||
|
||||
def unwrap(self) -> Callable[P, R]:
|
||||
@@ -95,7 +92,6 @@ class TaskMethod(Generic[P, TaskResultT]):
|
||||
meth: The method to wrap.
|
||||
"""
|
||||
self._meth = meth
|
||||
self._wrapped: Callable[P, TaskResultT] | None = None
|
||||
# Preserve function metadata
|
||||
wraps(meth)(self)
|
||||
|
||||
@@ -114,11 +110,7 @@ class TaskMethod(Generic[P, TaskResultT]):
|
||||
Returns:
|
||||
The task instance with name set if not already provided.
|
||||
"""
|
||||
if self._wrapped:
|
||||
result = self._wrapped(*args, **kwargs)
|
||||
else:
|
||||
result = self._meth(*args, **kwargs)
|
||||
|
||||
result = self._meth(*args, **kwargs)
|
||||
if not result.name:
|
||||
result.name = self._meth.__name__
|
||||
return result
|
||||
@@ -166,3 +158,90 @@ class CrewMethod(DecoratedMethod[P, R]):
|
||||
"""Wrapper for methods marked as the main crew execution point."""
|
||||
|
||||
is_crew: bool = True
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class OutputJsonClass(Generic[T]):
|
||||
"""Wrapper for classes marked as JSON output format."""
|
||||
|
||||
is_output_json: bool = True
|
||||
|
||||
def __init__(self, cls: type[T]) -> None:
|
||||
"""Initialize the output JSON class wrapper.
|
||||
|
||||
Args:
|
||||
cls: The class to wrap.
|
||||
"""
|
||||
self._cls = cls
|
||||
# Copy class attributes
|
||||
self.__name__ = cls.__name__
|
||||
self.__qualname__ = cls.__qualname__
|
||||
self.__module__ = cls.__module__
|
||||
self.__doc__ = cls.__doc__
|
||||
|
||||
def __call__(self, *args, **kwargs) -> T:
|
||||
"""Create an instance of the wrapped class.
|
||||
|
||||
Args:
|
||||
*args: Positional arguments for the class constructor.
|
||||
**kwargs: Keyword arguments for the class constructor.
|
||||
|
||||
Returns:
|
||||
An instance of the wrapped class.
|
||||
"""
|
||||
return self._cls(*args, **kwargs)
|
||||
|
||||
def __getattr__(self, name: str):
|
||||
"""Delegate attribute access to the wrapped class.
|
||||
|
||||
Args:
|
||||
name: The attribute name.
|
||||
|
||||
Returns:
|
||||
The attribute from the wrapped class.
|
||||
"""
|
||||
return getattr(self._cls, name)
|
||||
|
||||
|
||||
class OutputPydanticClass(Generic[T]):
|
||||
"""Wrapper for classes marked as Pydantic output format."""
|
||||
|
||||
is_output_pydantic: bool = True
|
||||
|
||||
def __init__(self, cls: type[T]) -> None:
|
||||
"""Initialize the output Pydantic class wrapper.
|
||||
|
||||
Args:
|
||||
cls: The class to wrap.
|
||||
"""
|
||||
self._cls = cls
|
||||
# Copy class attributes
|
||||
self.__name__ = cls.__name__
|
||||
self.__qualname__ = cls.__qualname__
|
||||
self.__module__ = cls.__module__
|
||||
self.__doc__ = cls.__doc__
|
||||
|
||||
def __call__(self, *args, **kwargs) -> T:
|
||||
"""Create an instance of the wrapped class.
|
||||
|
||||
Args:
|
||||
*args: Positional arguments for the class constructor.
|
||||
**kwargs: Keyword arguments for the class constructor.
|
||||
|
||||
Returns:
|
||||
An instance of the wrapped class.
|
||||
"""
|
||||
return self._cls(*args, **kwargs)
|
||||
|
||||
def __getattr__(self, name: str):
|
||||
"""Delegate attribute access to the wrapped class.
|
||||
|
||||
Args:
|
||||
name: The attribute name.
|
||||
|
||||
Returns:
|
||||
The attribute from the wrapped class.
|
||||
"""
|
||||
return getattr(self._cls, name)
|
||||
|
||||
Reference in New Issue
Block a user