chore: refactor project decorators and imports for clarity

This commit is contained in:
Greyson Lalonde
2025-10-13 12:29:57 -04:00
parent 2ebb2e845f
commit 541eec0639
3 changed files with 76 additions and 63 deletions

View File

@@ -1,4 +1,6 @@
from .annotations import ( """Project package for CrewAI."""
from crewai.project.annotations import (
after_kickoff, after_kickoff,
agent, agent,
before_kickoff, before_kickoff,
@@ -11,19 +13,19 @@ from .annotations import (
task, task,
tool, tool,
) )
from .crew_base import CrewBase from crewai.project.crew_base import CrewBase
__all__ = [ __all__ = [
"CrewBase",
"after_kickoff",
"agent", "agent",
"before_kickoff",
"cache_handler",
"callback",
"crew", "crew",
"task", "llm",
"output_json", "output_json",
"output_pydantic", "output_pydantic",
"task",
"tool", "tool",
"callback",
"CrewBase",
"llm",
"cache_handler",
"before_kickoff",
"after_kickoff",
] ]

View File

@@ -1,50 +1,52 @@
"""Decorators for defining crew components and their behaviors."""
from collections.abc import Callable
from functools import wraps from functools import wraps
from typing import Callable from typing import Any, Concatenate, ParamSpec, TypeVar
from crewai import Crew from crewai import Crew
from crewai.project.utils import memoize from crewai.project.utils import memoize
"""Decorators for defining crew components and their behaviors.""" P = ParamSpec("P")
R = TypeVar("R")
def before_kickoff(func): def before_kickoff(meth):
"""Marks a method to execute before crew kickoff.""" """Marks a method to execute before crew kickoff."""
func.is_before_kickoff = True meth.is_before_kickoff = True
return func return meth
def after_kickoff(func): def after_kickoff(meth):
"""Marks a method to execute after crew kickoff.""" """Marks a method to execute after crew kickoff."""
func.is_after_kickoff = True meth.is_after_kickoff = True
return func return meth
def task(func): def task(meth):
"""Marks a method as a crew task.""" """Marks a method as a crew task."""
func.is_task = True meth.is_task = True
@wraps(func) @wraps(meth)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
result = func(*args, **kwargs) result = meth(*args, **kwargs)
if not result.name: if not result.name:
result.name = func.__name__ result.name = meth.__name__
return result return result
return memoize(wrapper) return memoize(wrapper)
def agent(func): def agent(meth):
"""Marks a method as a crew agent.""" """Marks a method as a crew agent."""
func.is_agent = True meth.is_agent = True
func = memoize(func) return memoize(meth)
return func
def llm(func): def llm(meth):
"""Marks a method as an LLM provider.""" """Marks a method as an LLM provider."""
func.is_llm = True meth.is_llm = True
func = memoize(func) return memoize(meth)
return func
def output_json(cls): def output_json(cls):
@@ -59,28 +61,28 @@ def output_pydantic(cls):
return cls return cls
def tool(func): def tool(meth):
"""Marks a method as a crew tool.""" """Marks a method as a crew tool."""
func.is_tool = True meth.is_tool = True
return memoize(func) return memoize(meth)
def callback(func): def callback(meth):
"""Marks a method as a crew callback.""" """Marks a method as a crew callback."""
func.is_callback = True meth.is_callback = True
return memoize(func) return memoize(meth)
def cache_handler(func): def cache_handler(meth):
"""Marks a method as a cache handler.""" """Marks a method as a cache handler."""
func.is_cache_handler = True meth.is_cache_handler = True
return memoize(func) return memoize(meth)
def crew(func) -> Callable[..., Crew]: def crew(meth) -> Callable[..., Crew]:
"""Marks a method as the main crew execution point.""" """Marks a method as the main crew execution point."""
@wraps(func) @wraps(meth)
def wrapper(self, *args, **kwargs) -> Crew: def wrapper(self, *args, **kwargs) -> Crew:
instantiated_tasks = [] instantiated_tasks = []
instantiated_agents = [] instantiated_agents = []
@@ -91,7 +93,7 @@ def crew(func) -> Callable[..., Crew]:
agents = self._original_agents.items() agents = self._original_agents.items()
# Instantiate tasks in order # Instantiate tasks in order
for task_name, task_method in tasks: for _, task_method in tasks:
task_instance = task_method(self) task_instance = task_method(self)
instantiated_tasks.append(task_instance) instantiated_tasks.append(task_instance)
agent_instance = getattr(task_instance, "agent", None) agent_instance = getattr(task_instance, "agent", None)
@@ -100,7 +102,7 @@ def crew(func) -> Callable[..., Crew]:
agent_roles.add(agent_instance.role) agent_roles.add(agent_instance.role)
# Instantiate agents not included by tasks # Instantiate agents not included by tasks
for agent_name, agent_method in agents: for _, agent_method in agents:
agent_instance = agent_method(self) agent_instance = agent_method(self)
if agent_instance.role not in agent_roles: if agent_instance.role not in agent_roles:
instantiated_agents.append(agent_instance) instantiated_agents.append(agent_instance)
@@ -109,19 +111,25 @@ def crew(func) -> Callable[..., Crew]:
self.agents = instantiated_agents self.agents = instantiated_agents
self.tasks = instantiated_tasks self.tasks = instantiated_tasks
crew = func(self, *args, **kwargs) crew_instance = meth(self, *args, **kwargs)
def callback_wrapper(callback, instance): def callback_wrapper(
def wrapper(*args, **kwargs): hook: Callable[Concatenate[Any, P], R], instance: Any
return callback(instance, *args, **kwargs) ) -> Callable[P, R]:
def bound_callback(*cb_args: P.args, **cb_kwargs: P.kwargs) -> R:
return hook(instance, *cb_args, **cb_kwargs)
return wrapper return bound_callback
for _, callback in self._before_kickoff.items(): for hook_callback in self._before_kickoff.values():
crew.before_kickoff_callbacks.append(callback_wrapper(callback, self)) crew_instance.before_kickoff_callbacks.append(
for _, callback in self._after_kickoff.items(): callback_wrapper(hook_callback, self)
crew.after_kickoff_callbacks.append(callback_wrapper(callback, self)) )
for hook_callback in self._after_kickoff.values():
crew_instance.after_kickoff_callbacks.append(
callback_wrapper(hook_callback, self)
)
return crew return crew_instance
return memoize(wrapper) return memoize(wrapper)

View File

@@ -1,3 +1,5 @@
"""Base decorator for creating crew classes with configuration and function management."""
import inspect import inspect
import logging import logging
from collections.abc import Callable from collections.abc import Callable
@@ -13,8 +15,6 @@ load_dotenv()
T = TypeVar("T", bound=type) T = TypeVar("T", bound=type)
"""Base decorator for creating crew classes with configuration and function management."""
def CrewBase(cls: T) -> T: # noqa: N802 def CrewBase(cls: T) -> T: # noqa: N802
"""Wraps a class with crew functionality and configuration management.""" """Wraps a class with crew functionality and configuration management."""
@@ -72,11 +72,11 @@ def CrewBase(cls: T) -> T: # noqa: N802
# Add close mcp server method to after kickoff # Add close mcp server method to after kickoff
bound_method = self._create_close_mcp_server_method() bound_method = self._create_close_mcp_server_method()
self._after_kickoff['_close_mcp_server'] = bound_method self._after_kickoff["_close_mcp_server"] = bound_method
def _create_close_mcp_server_method(self): def _create_close_mcp_server_method(self):
def _close_mcp_server(self, instance, outputs): def _close_mcp_server(self, instance, outputs):
adapter = getattr(self, '_mcp_server_adapter', None) adapter = getattr(self, "_mcp_server_adapter", None)
if adapter is not None: if adapter is not None:
try: try:
adapter.stop() adapter.stop()
@@ -87,6 +87,7 @@ def CrewBase(cls: T) -> T: # noqa: N802
_close_mcp_server.is_after_kickoff = True _close_mcp_server.is_after_kickoff = True
import types import types
return types.MethodType(_close_mcp_server, self) return types.MethodType(_close_mcp_server, self)
def get_mcp_tools(self, *tool_names: list[str]) -> list[BaseTool]: def get_mcp_tools(self, *tool_names: list[str]) -> list[BaseTool]:
@@ -95,16 +96,14 @@ def CrewBase(cls: T) -> T: # noqa: N802
from crewai_tools import MCPServerAdapter # type: ignore[import-untyped] from crewai_tools import MCPServerAdapter # type: ignore[import-untyped]
adapter = getattr(self, '_mcp_server_adapter', None) adapter = getattr(self, "_mcp_server_adapter", None)
if not adapter: if not adapter:
self._mcp_server_adapter = MCPServerAdapter( self._mcp_server_adapter = MCPServerAdapter(
self.mcp_server_params, self.mcp_server_params, connect_timeout=self.mcp_connect_timeout
connect_timeout=self.mcp_connect_timeout
) )
return self._mcp_server_adapter.tools.filter_by_names(tool_names or None) return self._mcp_server_adapter.tools.filter_by_names(tool_names or None)
def load_configurations(self): def load_configurations(self):
"""Load agent and task configurations from YAML files.""" """Load agent and task configurations from YAML files."""
if isinstance(self.original_agents_config_path, str): if isinstance(self.original_agents_config_path, str):
@@ -209,9 +208,13 @@ def CrewBase(cls: T) -> T: # noqa: N802
if function_calling_llm := agent_info.get("function_calling_llm"): if function_calling_llm := agent_info.get("function_calling_llm"):
try: try:
self.agents_config[agent_name]["function_calling_llm"] = llms[function_calling_llm]() self.agents_config[agent_name]["function_calling_llm"] = llms[
function_calling_llm
]()
except KeyError: except KeyError:
self.agents_config[agent_name]["function_calling_llm"] = function_calling_llm self.agents_config[agent_name]["function_calling_llm"] = (
function_calling_llm
)
if step_callback := agent_info.get("step_callback"): if step_callback := agent_info.get("step_callback"):
self.agents_config[agent_name]["step_callback"] = callbacks[ self.agents_config[agent_name]["step_callback"] = callbacks[