mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-12 01:28:30 +00:00
chore: refactor annotation decorators with type-safe wrappers
This commit is contained in:
@@ -6,77 +6,154 @@ from typing import Any, Concatenate, ParamSpec, TypeVar
|
||||
|
||||
from crewai import Crew
|
||||
from crewai.project.utils import memoize
|
||||
from crewai.project.wrappers import (
|
||||
AfterKickoffMethod,
|
||||
AgentMethod,
|
||||
BeforeKickoffMethod,
|
||||
CacheHandlerMethod,
|
||||
CallbackMethod,
|
||||
LLMMethod,
|
||||
TaskMethod,
|
||||
TaskResultT,
|
||||
ToolMethod,
|
||||
)
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
def before_kickoff(meth):
|
||||
"""Marks a method to execute before crew kickoff."""
|
||||
meth.is_before_kickoff = True
|
||||
return meth
|
||||
def before_kickoff(meth: Callable[P, R]) -> BeforeKickoffMethod[P, R]:
|
||||
"""Marks a method to execute before crew kickoff.
|
||||
|
||||
Args:
|
||||
meth: The method to mark.
|
||||
|
||||
Returns:
|
||||
A wrapped method marked for before kickoff execution.
|
||||
"""
|
||||
return BeforeKickoffMethod(meth)
|
||||
|
||||
|
||||
def after_kickoff(meth):
|
||||
"""Marks a method to execute after crew kickoff."""
|
||||
meth.is_after_kickoff = True
|
||||
return meth
|
||||
def after_kickoff(meth: Callable[P, R]) -> AfterKickoffMethod[P, R]:
|
||||
"""Marks a method to execute after crew kickoff.
|
||||
|
||||
Args:
|
||||
meth: The method to mark.
|
||||
|
||||
Returns:
|
||||
A wrapped method marked for after kickoff execution.
|
||||
"""
|
||||
return AfterKickoffMethod(meth)
|
||||
|
||||
|
||||
def task(meth):
|
||||
"""Marks a method as a crew task."""
|
||||
meth.is_task = True
|
||||
def task(meth: Callable[P, TaskResultT]) -> TaskMethod[P, TaskResultT]:
|
||||
"""Marks a method as a crew task.
|
||||
|
||||
@wraps(meth)
|
||||
def wrapper(*args, **kwargs):
|
||||
result = meth(*args, **kwargs)
|
||||
if not result.name:
|
||||
result.name = meth.__name__
|
||||
return result
|
||||
Args:
|
||||
meth: The method to mark.
|
||||
|
||||
return memoize(wrapper)
|
||||
Returns:
|
||||
A wrapped method marked as a task with memoization.
|
||||
"""
|
||||
wrapped = TaskMethod(meth)
|
||||
wrapped._wrapped = memoize(wrapped._meth)
|
||||
return wrapped
|
||||
|
||||
|
||||
def agent(meth):
|
||||
"""Marks a method as a crew agent."""
|
||||
meth.is_agent = True
|
||||
return memoize(meth)
|
||||
def agent(meth: Callable[P, R]) -> AgentMethod[P, R]:
|
||||
"""Marks a method as a crew agent.
|
||||
|
||||
Args:
|
||||
meth: The method to mark.
|
||||
|
||||
Returns:
|
||||
A wrapped method marked as an agent with memoization.
|
||||
"""
|
||||
wrapped = AgentMethod(meth)
|
||||
wrapped._wrapped = memoize(wrapped._meth)
|
||||
return wrapped
|
||||
|
||||
|
||||
def llm(meth):
|
||||
"""Marks a method as an LLM provider."""
|
||||
meth.is_llm = True
|
||||
return memoize(meth)
|
||||
def llm(meth: Callable[P, R]) -> LLMMethod[P, R]:
|
||||
"""Marks a method as an LLM provider.
|
||||
|
||||
Args:
|
||||
meth: The method to mark.
|
||||
|
||||
Returns:
|
||||
A wrapped method marked as an LLM provider with memoization.
|
||||
"""
|
||||
wrapped = LLMMethod(meth)
|
||||
wrapped._wrapped = memoize(wrapped._meth)
|
||||
return wrapped
|
||||
|
||||
|
||||
def output_json(cls):
|
||||
"""Marks a class as JSON output format."""
|
||||
"""Marks a class as JSON output format.
|
||||
|
||||
Args:
|
||||
cls: The class to mark.
|
||||
|
||||
Returns:
|
||||
The class with is_output_json attribute set.
|
||||
"""
|
||||
cls.is_output_json = True
|
||||
return cls
|
||||
|
||||
|
||||
def output_pydantic(cls):
|
||||
"""Marks a class as Pydantic output format."""
|
||||
"""Marks a class as Pydantic output format.
|
||||
|
||||
Args:
|
||||
cls: The class to mark.
|
||||
|
||||
Returns:
|
||||
The class with is_output_pydantic attribute set.
|
||||
"""
|
||||
cls.is_output_pydantic = True
|
||||
return cls
|
||||
|
||||
|
||||
def tool(meth):
|
||||
"""Marks a method as a crew tool."""
|
||||
meth.is_tool = True
|
||||
return memoize(meth)
|
||||
def tool(meth: Callable[P, R]) -> ToolMethod[P, R]:
|
||||
"""Marks a method as a crew tool.
|
||||
|
||||
Args:
|
||||
meth: The method to mark.
|
||||
|
||||
Returns:
|
||||
A wrapped method marked as a tool with memoization.
|
||||
"""
|
||||
wrapped = ToolMethod(meth)
|
||||
wrapped._wrapped = memoize(wrapped._meth)
|
||||
return wrapped
|
||||
|
||||
|
||||
def callback(meth):
|
||||
"""Marks a method as a crew callback."""
|
||||
meth.is_callback = True
|
||||
return memoize(meth)
|
||||
def callback(meth: Callable[P, R]) -> CallbackMethod[P, R]:
|
||||
"""Marks a method as a crew callback.
|
||||
|
||||
Args:
|
||||
meth: The method to mark.
|
||||
|
||||
Returns:
|
||||
A wrapped method marked as a callback with memoization.
|
||||
"""
|
||||
wrapped = CallbackMethod(meth)
|
||||
wrapped._wrapped = memoize(wrapped._meth)
|
||||
return wrapped
|
||||
|
||||
|
||||
def cache_handler(meth):
|
||||
"""Marks a method as a cache handler."""
|
||||
meth.is_cache_handler = True
|
||||
return memoize(meth)
|
||||
def cache_handler(meth: Callable[P, R]) -> CacheHandlerMethod[P, R]:
|
||||
"""Marks a method as a cache handler.
|
||||
|
||||
Args:
|
||||
meth: The method to mark.
|
||||
|
||||
Returns:
|
||||
A wrapped method marked as a cache handler with memoization.
|
||||
"""
|
||||
wrapped = CacheHandlerMethod(meth)
|
||||
wrapped._wrapped = memoize(wrapped._meth)
|
||||
return wrapped
|
||||
|
||||
|
||||
def crew(meth) -> Callable[..., Crew]:
|
||||
|
||||
168
src/crewai/project/wrappers.py
Normal file
168
src/crewai/project/wrappers.py
Normal file
@@ -0,0 +1,168 @@
|
||||
"""Wrapper classes for decorated methods with type-safe metadata."""
|
||||
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Generic, ParamSpec, Protocol, TypeVar
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
class TaskResult(Protocol):
|
||||
"""Protocol for task objects that have a name property."""
|
||||
|
||||
@property
|
||||
def name(self) -> str | None:
|
||||
"""Get the task name."""
|
||||
...
|
||||
|
||||
@name.setter
|
||||
def name(self, value: str) -> None:
|
||||
"""Set the task name."""
|
||||
...
|
||||
|
||||
|
||||
TaskResultT = TypeVar("TaskResultT", bound=TaskResult)
|
||||
|
||||
|
||||
class DecoratedMethod(Generic[P, R]):
|
||||
"""Base wrapper for methods with decorator metadata.
|
||||
|
||||
This class provides a type-safe way to add metadata to methods
|
||||
while preserving their callable signature and attributes.
|
||||
"""
|
||||
|
||||
def __init__(self, meth: Callable[P, R]) -> None:
|
||||
"""Initialize the decorated method wrapper.
|
||||
|
||||
Args:
|
||||
meth: The method to wrap.
|
||||
"""
|
||||
self._meth = meth
|
||||
self._wrapped: Callable[P, R] | None = None
|
||||
# Preserve function metadata
|
||||
wraps(meth)(self)
|
||||
|
||||
@property
|
||||
def __name__(self) -> str:
|
||||
"""Get the name of the wrapped method."""
|
||||
return self._meth.__name__
|
||||
|
||||
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
"""Call the wrapped method.
|
||||
|
||||
Args:
|
||||
*args: Positional arguments.
|
||||
**kwargs: Keyword arguments.
|
||||
|
||||
Returns:
|
||||
The result of calling the wrapped method.
|
||||
"""
|
||||
if self._wrapped:
|
||||
return self._wrapped(*args, **kwargs)
|
||||
return self._meth(*args, **kwargs)
|
||||
|
||||
def unwrap(self) -> Callable[P, R]:
|
||||
"""Get the original unwrapped method.
|
||||
|
||||
Returns:
|
||||
The original method before decoration.
|
||||
"""
|
||||
return self._meth
|
||||
|
||||
|
||||
class BeforeKickoffMethod(DecoratedMethod[P, R]):
|
||||
"""Wrapper for methods marked to execute before crew kickoff."""
|
||||
|
||||
is_before_kickoff: bool = True
|
||||
|
||||
|
||||
class AfterKickoffMethod(DecoratedMethod[P, R]):
|
||||
"""Wrapper for methods marked to execute after crew kickoff."""
|
||||
|
||||
is_after_kickoff: bool = True
|
||||
|
||||
|
||||
class TaskMethod(Generic[P, TaskResultT]):
|
||||
"""Wrapper for methods marked as crew tasks."""
|
||||
|
||||
is_task: bool = True
|
||||
|
||||
def __init__(self, meth: Callable[P, TaskResultT]) -> None:
|
||||
"""Initialize the task method wrapper.
|
||||
|
||||
Args:
|
||||
meth: The method to wrap.
|
||||
"""
|
||||
self._meth = meth
|
||||
self._wrapped: Callable[P, TaskResultT] | None = None
|
||||
# Preserve function metadata
|
||||
wraps(meth)(self)
|
||||
|
||||
@property
|
||||
def __name__(self) -> str:
|
||||
"""Get the name of the wrapped method."""
|
||||
return self._meth.__name__
|
||||
|
||||
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> TaskResultT:
|
||||
"""Call the wrapped method and set task name if not provided.
|
||||
|
||||
Args:
|
||||
*args: Positional arguments.
|
||||
**kwargs: Keyword arguments.
|
||||
|
||||
Returns:
|
||||
The task instance with name set if not already provided.
|
||||
"""
|
||||
if self._wrapped:
|
||||
result = self._wrapped(*args, **kwargs)
|
||||
else:
|
||||
result = self._meth(*args, **kwargs)
|
||||
|
||||
if not result.name:
|
||||
result.name = self._meth.__name__
|
||||
return result
|
||||
|
||||
def unwrap(self) -> Callable[P, TaskResultT]:
|
||||
"""Get the original unwrapped method.
|
||||
|
||||
Returns:
|
||||
The original method before decoration.
|
||||
"""
|
||||
return self._meth
|
||||
|
||||
|
||||
class AgentMethod(DecoratedMethod[P, R]):
|
||||
"""Wrapper for methods marked as crew agents."""
|
||||
|
||||
is_agent: bool = True
|
||||
|
||||
|
||||
class LLMMethod(DecoratedMethod[P, R]):
|
||||
"""Wrapper for methods marked as LLM providers."""
|
||||
|
||||
is_llm: bool = True
|
||||
|
||||
|
||||
class ToolMethod(DecoratedMethod[P, R]):
|
||||
"""Wrapper for methods marked as crew tools."""
|
||||
|
||||
is_tool: bool = True
|
||||
|
||||
|
||||
class CallbackMethod(DecoratedMethod[P, R]):
|
||||
"""Wrapper for methods marked as crew callbacks."""
|
||||
|
||||
is_callback: bool = True
|
||||
|
||||
|
||||
class CacheHandlerMethod(DecoratedMethod[P, R]):
|
||||
"""Wrapper for methods marked as cache handlers."""
|
||||
|
||||
is_cache_handler: bool = True
|
||||
|
||||
|
||||
class CrewMethod(DecoratedMethod[P, R]):
|
||||
"""Wrapper for methods marked as the main crew execution point."""
|
||||
|
||||
is_crew: bool = True
|
||||
Reference in New Issue
Block a user