mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 00:28:31 +00:00
feat: add protocols for Agent, Task, and Crew instances
This commit is contained in:
@@ -2,26 +2,31 @@
|
||||
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Any, Concatenate, ParamSpec, TypeVar
|
||||
from typing import Concatenate, ParamSpec, TypeVar
|
||||
|
||||
from crewai import Crew
|
||||
from crewai.project.utils import memoize
|
||||
from crewai.project.wrappers import (
|
||||
AfterKickoffMethod,
|
||||
AgentInstance,
|
||||
AgentMethod,
|
||||
BeforeKickoffMethod,
|
||||
CacheHandlerMethod,
|
||||
CallbackMethod,
|
||||
CrewInstance,
|
||||
LLMMethod,
|
||||
OutputJsonClass,
|
||||
OutputPydanticClass,
|
||||
TaskInstance,
|
||||
TaskMethod,
|
||||
TaskResultT,
|
||||
ToolMethod,
|
||||
)
|
||||
|
||||
P = ParamSpec("P")
|
||||
P2 = ParamSpec("P2")
|
||||
R = TypeVar("R")
|
||||
R2 = TypeVar("R2")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@@ -145,14 +150,33 @@ def cache_handler(meth: Callable[P, R]) -> CacheHandlerMethod[P, R]:
|
||||
return CacheHandlerMethod(memoize(meth))
|
||||
|
||||
|
||||
def crew(meth) -> Callable[..., Crew]:
|
||||
"""Marks a method as the main crew execution point."""
|
||||
def crew(
|
||||
meth: Callable[Concatenate[CrewInstance, P], Crew],
|
||||
) -> Callable[Concatenate[CrewInstance, P], Crew]:
|
||||
"""Marks a method as the main crew execution point.
|
||||
|
||||
Args:
|
||||
meth: The method to mark as crew execution point.
|
||||
|
||||
Returns:
|
||||
A wrapped method that instantiates tasks and agents before execution.
|
||||
"""
|
||||
|
||||
@wraps(meth)
|
||||
def wrapper(self, *args, **kwargs) -> Crew:
|
||||
instantiated_tasks = []
|
||||
instantiated_agents = []
|
||||
agent_roles = set()
|
||||
def wrapper(self: CrewInstance, *args: P.args, **kwargs: P.kwargs) -> Crew:
|
||||
"""Wrapper that sets up crew before calling the decorated method.
|
||||
|
||||
Args:
|
||||
self: The crew class instance.
|
||||
*args: Additional positional arguments.
|
||||
**kwargs: Keyword arguments to pass to the method.
|
||||
|
||||
Returns:
|
||||
The configured Crew instance with callbacks attached.
|
||||
"""
|
||||
instantiated_tasks: list[TaskInstance] = []
|
||||
instantiated_agents: list[AgentInstance] = []
|
||||
agent_roles: set[str] = set()
|
||||
|
||||
# Use the preserved task and agent information
|
||||
tasks = self._original_tasks.items()
|
||||
@@ -180,9 +204,28 @@ def crew(meth) -> Callable[..., Crew]:
|
||||
crew_instance = meth(self, *args, **kwargs)
|
||||
|
||||
def callback_wrapper(
|
||||
hook: Callable[Concatenate[Any, P], R], instance: Any
|
||||
) -> Callable[P, R]:
|
||||
def bound_callback(*cb_args: P.args, **cb_kwargs: P.kwargs) -> R:
|
||||
hook: Callable[Concatenate[CrewInstance, P2], R2], instance: CrewInstance
|
||||
) -> Callable[P2, R2]:
|
||||
"""Bind a hook callback to an instance.
|
||||
|
||||
Args:
|
||||
hook: The callback hook to bind.
|
||||
instance: The instance to bind to.
|
||||
|
||||
Returns:
|
||||
A bound callback function.
|
||||
"""
|
||||
|
||||
def bound_callback(*cb_args: P2.args, **cb_kwargs: P2.kwargs) -> R2:
|
||||
"""Execute the bound callback.
|
||||
|
||||
Args:
|
||||
*cb_args: Positional arguments for the callback.
|
||||
**cb_kwargs: Keyword arguments for the callback.
|
||||
|
||||
Returns:
|
||||
The result of the callback execution.
|
||||
"""
|
||||
return hook(instance, *cb_args, **cb_kwargs)
|
||||
|
||||
return bound_callback
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Generic, ParamSpec, Protocol, TypeVar
|
||||
from typing import Any, Generic, ParamSpec, Protocol, TypeVar
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
@@ -25,6 +25,29 @@ class TaskResult(Protocol):
|
||||
TaskResultT = TypeVar("TaskResultT", bound=TaskResult)
|
||||
|
||||
|
||||
class AgentInstance(Protocol):
|
||||
"""Protocol for agent instances."""
|
||||
|
||||
role: str
|
||||
|
||||
|
||||
class TaskInstance(Protocol):
|
||||
"""Protocol for task instances."""
|
||||
|
||||
agent: AgentInstance | None
|
||||
|
||||
|
||||
class CrewInstance(Protocol):
|
||||
"""Protocol for crew class instances with required attributes."""
|
||||
|
||||
_original_tasks: dict[str, Callable[..., Any]]
|
||||
_original_agents: dict[str, Callable[..., Any]]
|
||||
_before_kickoff: dict[str, Callable[..., Any]]
|
||||
_after_kickoff: dict[str, Callable[..., Any]]
|
||||
agents: list[AgentInstance]
|
||||
tasks: list[TaskInstance]
|
||||
|
||||
|
||||
class DecoratedMethod(Generic[P, R]):
|
||||
"""Base wrapper for methods with decorator metadata.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user