chore: refactor annotation decorators with type-safe wrappers

This commit is contained in:
Greyson Lalonde
2025-10-13 12:44:02 -04:00
parent 541eec0639
commit 7d6324dfa3
2 changed files with 285 additions and 40 deletions

View File

@@ -6,77 +6,154 @@ from typing import Any, Concatenate, ParamSpec, TypeVar
from crewai import Crew
from crewai.project.utils import memoize
from crewai.project.wrappers import (
AfterKickoffMethod,
AgentMethod,
BeforeKickoffMethod,
CacheHandlerMethod,
CallbackMethod,
LLMMethod,
TaskMethod,
TaskResultT,
ToolMethod,
)
P = ParamSpec("P")
R = TypeVar("R")
def before_kickoff(meth):
"""Marks a method to execute before crew kickoff."""
meth.is_before_kickoff = True
return meth
def before_kickoff(meth: Callable[P, R]) -> BeforeKickoffMethod[P, R]:
"""Marks a method to execute before crew kickoff.
Args:
meth: The method to mark.
Returns:
A wrapped method marked for before kickoff execution.
"""
return BeforeKickoffMethod(meth)
def after_kickoff(meth):
"""Marks a method to execute after crew kickoff."""
meth.is_after_kickoff = True
return meth
def after_kickoff(meth: Callable[P, R]) -> AfterKickoffMethod[P, R]:
"""Marks a method to execute after crew kickoff.
Args:
meth: The method to mark.
Returns:
A wrapped method marked for after kickoff execution.
"""
return AfterKickoffMethod(meth)
def task(meth):
"""Marks a method as a crew task."""
meth.is_task = True
def task(meth: Callable[P, TaskResultT]) -> TaskMethod[P, TaskResultT]:
"""Marks a method as a crew task.
@wraps(meth)
def wrapper(*args, **kwargs):
result = meth(*args, **kwargs)
if not result.name:
result.name = meth.__name__
return result
Args:
meth: The method to mark.
return memoize(wrapper)
Returns:
A wrapped method marked as a task with memoization.
"""
wrapped = TaskMethod(meth)
wrapped._wrapped = memoize(wrapped._meth)
return wrapped
def agent(meth):
"""Marks a method as a crew agent."""
meth.is_agent = True
return memoize(meth)
def agent(meth: Callable[P, R]) -> AgentMethod[P, R]:
"""Marks a method as a crew agent.
Args:
meth: The method to mark.
Returns:
A wrapped method marked as an agent with memoization.
"""
wrapped = AgentMethod(meth)
wrapped._wrapped = memoize(wrapped._meth)
return wrapped
def llm(meth):
"""Marks a method as an LLM provider."""
meth.is_llm = True
return memoize(meth)
def llm(meth: Callable[P, R]) -> LLMMethod[P, R]:
"""Marks a method as an LLM provider.
Args:
meth: The method to mark.
Returns:
A wrapped method marked as an LLM provider with memoization.
"""
wrapped = LLMMethod(meth)
wrapped._wrapped = memoize(wrapped._meth)
return wrapped
def output_json(cls):
"""Marks a class as JSON output format."""
"""Marks a class as JSON output format.
Args:
cls: The class to mark.
Returns:
The class with is_output_json attribute set.
"""
cls.is_output_json = True
return cls
def output_pydantic(cls):
"""Marks a class as Pydantic output format."""
"""Marks a class as Pydantic output format.
Args:
cls: The class to mark.
Returns:
The class with is_output_pydantic attribute set.
"""
cls.is_output_pydantic = True
return cls
def tool(meth):
"""Marks a method as a crew tool."""
meth.is_tool = True
return memoize(meth)
def tool(meth: Callable[P, R]) -> ToolMethod[P, R]:
"""Marks a method as a crew tool.
Args:
meth: The method to mark.
Returns:
A wrapped method marked as a tool with memoization.
"""
wrapped = ToolMethod(meth)
wrapped._wrapped = memoize(wrapped._meth)
return wrapped
def callback(meth):
"""Marks a method as a crew callback."""
meth.is_callback = True
return memoize(meth)
def callback(meth: Callable[P, R]) -> CallbackMethod[P, R]:
"""Marks a method as a crew callback.
Args:
meth: The method to mark.
Returns:
A wrapped method marked as a callback with memoization.
"""
wrapped = CallbackMethod(meth)
wrapped._wrapped = memoize(wrapped._meth)
return wrapped
def cache_handler(meth):
"""Marks a method as a cache handler."""
meth.is_cache_handler = True
return memoize(meth)
def cache_handler(meth: Callable[P, R]) -> CacheHandlerMethod[P, R]:
"""Marks a method as a cache handler.
Args:
meth: The method to mark.
Returns:
A wrapped method marked as a cache handler with memoization.
"""
wrapped = CacheHandlerMethod(meth)
wrapped._wrapped = memoize(wrapped._meth)
return wrapped
def crew(meth) -> Callable[..., Crew]:

View File

@@ -0,0 +1,168 @@
"""Wrapper classes for decorated methods with type-safe metadata."""
from collections.abc import Callable
from functools import wraps
from typing import Generic, ParamSpec, Protocol, TypeVar
P = ParamSpec("P")
R = TypeVar("R")
class TaskResult(Protocol):
"""Protocol for task objects that have a name property."""
@property
def name(self) -> str | None:
"""Get the task name."""
...
@name.setter
def name(self, value: str) -> None:
"""Set the task name."""
...
TaskResultT = TypeVar("TaskResultT", bound=TaskResult)
class DecoratedMethod(Generic[P, R]):
"""Base wrapper for methods with decorator metadata.
This class provides a type-safe way to add metadata to methods
while preserving their callable signature and attributes.
"""
def __init__(self, meth: Callable[P, R]) -> None:
"""Initialize the decorated method wrapper.
Args:
meth: The method to wrap.
"""
self._meth = meth
self._wrapped: Callable[P, R] | None = None
# Preserve function metadata
wraps(meth)(self)
@property
def __name__(self) -> str:
"""Get the name of the wrapped method."""
return self._meth.__name__
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
"""Call the wrapped method.
Args:
*args: Positional arguments.
**kwargs: Keyword arguments.
Returns:
The result of calling the wrapped method.
"""
if self._wrapped:
return self._wrapped(*args, **kwargs)
return self._meth(*args, **kwargs)
def unwrap(self) -> Callable[P, R]:
"""Get the original unwrapped method.
Returns:
The original method before decoration.
"""
return self._meth
class BeforeKickoffMethod(DecoratedMethod[P, R]):
"""Wrapper for methods marked to execute before crew kickoff."""
is_before_kickoff: bool = True
class AfterKickoffMethod(DecoratedMethod[P, R]):
"""Wrapper for methods marked to execute after crew kickoff."""
is_after_kickoff: bool = True
class TaskMethod(Generic[P, TaskResultT]):
"""Wrapper for methods marked as crew tasks."""
is_task: bool = True
def __init__(self, meth: Callable[P, TaskResultT]) -> None:
"""Initialize the task method wrapper.
Args:
meth: The method to wrap.
"""
self._meth = meth
self._wrapped: Callable[P, TaskResultT] | None = None
# Preserve function metadata
wraps(meth)(self)
@property
def __name__(self) -> str:
"""Get the name of the wrapped method."""
return self._meth.__name__
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> TaskResultT:
"""Call the wrapped method and set task name if not provided.
Args:
*args: Positional arguments.
**kwargs: Keyword arguments.
Returns:
The task instance with name set if not already provided.
"""
if self._wrapped:
result = self._wrapped(*args, **kwargs)
else:
result = self._meth(*args, **kwargs)
if not result.name:
result.name = self._meth.__name__
return result
def unwrap(self) -> Callable[P, TaskResultT]:
"""Get the original unwrapped method.
Returns:
The original method before decoration.
"""
return self._meth
class AgentMethod(DecoratedMethod[P, R]):
"""Wrapper for methods marked as crew agents."""
is_agent: bool = True
class LLMMethod(DecoratedMethod[P, R]):
"""Wrapper for methods marked as LLM providers."""
is_llm: bool = True
class ToolMethod(DecoratedMethod[P, R]):
"""Wrapper for methods marked as crew tools."""
is_tool: bool = True
class CallbackMethod(DecoratedMethod[P, R]):
"""Wrapper for methods marked as crew callbacks."""
is_callback: bool = True
class CacheHandlerMethod(DecoratedMethod[P, R]):
"""Wrapper for methods marked as cache handlers."""
is_cache_handler: bool = True
class CrewMethod(DecoratedMethod[P, R]):
"""Wrapper for methods marked as the main crew execution point."""
is_crew: bool = True