refactor: convert CrewBase decorator to metaclass pattern

This commit is contained in:
Greyson Lalonde
2025-10-15 01:28:59 -04:00
parent 2a23bc604c
commit 8465350f1d
2 changed files with 415 additions and 259 deletions

View File

@@ -1,45 +1,107 @@
"""Base decorator for creating crew classes with configuration and function management.""" """Base metaclass for creating crew classes with configuration and function management."""
from __future__ import annotations
import inspect import inspect
import logging import logging
from collections.abc import Callable from collections.abc import Callable
from pathlib import Path from pathlib import Path
from typing import Any, TypeVar, cast from typing import TYPE_CHECKING, Any
import yaml import yaml
from dotenv import load_dotenv from dotenv import load_dotenv
from crewai.tools import BaseTool from crewai.tools import BaseTool
if TYPE_CHECKING:
from crewai.crews.crew_output import CrewOutput
from crewai.project.wrappers import CrewInstance
load_dotenv() load_dotenv()
T = TypeVar("T", bound=type)
class CrewBaseMeta(type):
"""Metaclass that adds crew functionality to classes."""
def CrewBase(cls: T) -> T: # noqa: N802 def __new__(
"""Wraps a class with crew functionality and configuration management.""" mcs,
name: str,
bases: tuple[type, ...],
namespace: dict[str, Any],
**kwargs: Any,
) -> type:
"""Create crew class with configuration and method injection.
class WrappedClass(cls): # type: ignore Args:
is_crew_class: bool = True # type: ignore name: Class name.
bases: Base classes.
namespace: Class namespace dictionary.
**kwargs: Additional keyword arguments.
# Get the directory of the class being decorated Returns:
base_directory = Path(inspect.getfile(cls)).parent New crew class with injected methods and attributes.
"""
cls = super().__new__(mcs, name, bases, namespace)
original_agents_config_path = getattr( cls.is_crew_class = True # type: ignore[attr-defined]
cls._crew_name = name # type: ignore[attr-defined]
try:
cls.base_directory = Path(inspect.getfile(cls)).parent # type: ignore[attr-defined]
except (TypeError, OSError):
cls.base_directory = Path.cwd() # type: ignore[attr-defined]
cls.original_agents_config_path = getattr( # type: ignore[attr-defined]
cls, "agents_config", "config/agents.yaml" cls, "agents_config", "config/agents.yaml"
) )
original_tasks_config_path = getattr(cls, "tasks_config", "config/tasks.yaml") cls.original_tasks_config_path = getattr( # type: ignore[attr-defined]
cls, "tasks_config", "config/tasks.yaml"
)
mcp_server_params: Any = getattr(cls, "mcp_server_params", None) cls.mcp_server_params = getattr(cls, "mcp_server_params", None) # type: ignore[attr-defined]
mcp_connect_timeout: int = getattr(cls, "mcp_connect_timeout", 30) cls.mcp_connect_timeout = getattr(cls, "mcp_connect_timeout", 30) # type: ignore[attr-defined]
def __init__(self, *args, **kwargs): cls._close_mcp_server = _close_mcp_server # type: ignore[attr-defined]
super().__init__(*args, **kwargs) cls.get_mcp_tools = get_mcp_tools # type: ignore[attr-defined]
self.load_configurations() cls._load_config = _load_config # type: ignore[attr-defined]
self.map_all_agent_variables() cls.load_configurations = load_configurations # type: ignore[attr-defined]
self.map_all_task_variables() cls.load_yaml = staticmethod(load_yaml) # type: ignore[attr-defined]
# Preserve all decorated functions cls.map_all_agent_variables = map_all_agent_variables # type: ignore[attr-defined]
self._original_functions = { cls._map_agent_variables = _map_agent_variables # type: ignore[attr-defined]
cls.map_all_task_variables = map_all_task_variables # type: ignore[attr-defined]
cls._map_task_variables = _map_task_variables # type: ignore[attr-defined]
return cls
def __call__(cls, *args: Any, **kwargs: Any) -> CrewInstance:
"""Intercept instance creation to initialize crew functionality.
Args:
*args: Positional arguments for instance creation.
**kwargs: Keyword arguments for instance creation.
Returns:
Initialized crew instance.
"""
instance: CrewInstance = super().__call__(*args, **kwargs)
CrewBaseMeta._initialize_crew_instance(instance, cls)
return instance
@staticmethod
def _initialize_crew_instance(instance: CrewInstance, cls: type) -> None:
"""Initialize crew instance attributes and load configurations.
Args:
instance: Crew instance to initialize.
cls: Crew class type.
"""
instance._mcp_server_adapter = None
instance.load_configurations()
instance._all_functions = _get_all_functions(instance)
instance.map_all_agent_variables()
instance.map_all_task_variables()
instance._original_functions = {
name: method name: method
for name, method in cls.__dict__.items() for name, method in cls.__dict__.items()
if any( if any(
@@ -53,128 +115,172 @@ def CrewBase(cls: T) -> T: # noqa: N802
] ]
) )
} }
# Store specific function types
self._original_tasks = self._filter_functions( instance._original_tasks = _filter_functions(
self._original_functions, "is_task" instance._original_functions, "is_task"
) )
self._original_agents = self._filter_functions( instance._original_agents = _filter_functions(
self._original_functions, "is_agent" instance._original_functions, "is_agent"
) )
self._before_kickoff = self._filter_functions( instance._before_kickoff = _filter_functions(
self._original_functions, "is_before_kickoff" instance._original_functions, "is_before_kickoff"
) )
self._after_kickoff = self._filter_functions( instance._after_kickoff = _filter_functions(
self._original_functions, "is_after_kickoff" instance._original_functions, "is_after_kickoff"
) )
self._kickoff = self._filter_functions( instance._kickoff = _filter_functions(
self._original_functions, "is_kickoff" instance._original_functions, "is_kickoff"
) )
# Add close mcp server method to after kickoff instance._after_kickoff["_close_mcp_server"] = instance._close_mcp_server
bound_method = self._create_close_mcp_server_method()
self._after_kickoff["_close_mcp_server"] = bound_method
def _create_close_mcp_server_method(self):
def _close_mcp_server(self, instance, outputs): def _close_mcp_server(
adapter = getattr(self, "_mcp_server_adapter", None) self: CrewInstance, _instance: CrewInstance, outputs: CrewOutput
if adapter is not None: ) -> CrewOutput:
"""Stop MCP server adapter and return outputs.
Args:
self: Crew instance with MCP server adapter.
_instance: Crew instance (unused, required by callback signature).
outputs: Crew execution outputs.
Returns:
Unmodified crew outputs.
"""
if self._mcp_server_adapter is not None:
try: try:
adapter.stop() self._mcp_server_adapter.stop()
except Exception as e: except Exception as e:
logging.warning(f"Error stopping MCP server: {e}") logging.warning(f"Error stopping MCP server: {e}")
return outputs return outputs
_close_mcp_server.is_after_kickoff = True
import types def get_mcp_tools(self: Any, *tool_names: str) -> list[BaseTool]:
"""Get MCP tools filtered by name.
return types.MethodType(_close_mcp_server, self) Args:
self: Crew instance with MCP server configuration.
*tool_names: Optional tool names to filter by.
def get_mcp_tools(self, *tool_names: list[str]) -> list[BaseTool]: Returns:
List of filtered MCP tools, or empty list if no MCP server configured.
"""
if not self.mcp_server_params: if not self.mcp_server_params:
return [] return []
from crewai_tools import MCPServerAdapter # type: ignore[import-untyped] from crewai_tools import MCPServerAdapter # type: ignore[import-untyped]
adapter = getattr(self, "_mcp_server_adapter", None) if self._mcp_server_adapter is None:
if not adapter:
self._mcp_server_adapter = MCPServerAdapter( self._mcp_server_adapter = MCPServerAdapter(
self.mcp_server_params, connect_timeout=self.mcp_connect_timeout self.mcp_server_params, connect_timeout=self.mcp_connect_timeout
) )
return self._mcp_server_adapter.tools.filter_by_names(tool_names or None) return self._mcp_server_adapter.tools.filter_by_names(tool_names or None)
def load_configurations(self):
"""Load agent and task configurations from YAML files.""" def _load_config(
if isinstance(self.original_agents_config_path, str): self: Any, config_path: str | None, config_type: str
agents_config_path = ( ) -> dict[str, Any]:
self.base_directory / self.original_agents_config_path """Load YAML config file or return empty dict if not found.
)
Args:
config_path: Relative path to config file.
config_type: Config type for logging (e.g., "agent", "task").
Returns:
Config dictionary or empty dict.
"""
if isinstance(config_path, str):
full_path = self.base_directory / config_path
try: try:
self.agents_config = self.load_yaml(agents_config_path) return self.load_yaml(full_path)
except FileNotFoundError: except FileNotFoundError:
logging.warning( logging.warning(
f"Agent config file not found at {agents_config_path}. " f"{config_type.capitalize()} config file not found at {full_path}. "
"Proceeding with empty agent configurations." f"Proceeding with empty {config_type} configurations."
) )
self.agents_config = {} return {}
else: else:
logging.warning( logging.warning(
"No agent configuration path provided. Proceeding with empty agent configurations." f"No {config_type} configuration path provided. "
f"Proceeding with empty {config_type} configurations."
) )
self.agents_config = {} return {}
if isinstance(self.original_tasks_config_path, str):
tasks_config_path = (
self.base_directory / self.original_tasks_config_path
)
try:
self.tasks_config = self.load_yaml(tasks_config_path)
except FileNotFoundError:
logging.warning(
f"Task config file not found at {tasks_config_path}. "
"Proceeding with empty task configurations."
)
self.tasks_config = {}
else:
logging.warning(
"No task configuration path provided. Proceeding with empty task configurations."
)
self.tasks_config = {}
@staticmethod def load_configurations(self: Any) -> None:
def load_yaml(config_path: Path): """Load agent and task YAML configurations.
Args:
self: Crew instance with configuration paths.
"""
self.agents_config = self._load_config(self.original_agents_config_path, "agent")
self.tasks_config = self._load_config(self.original_tasks_config_path, "task")
def load_yaml(config_path: Path) -> Any:
"""Load and parse YAML file.
Args:
config_path: Path to YAML configuration file.
Returns:
Parsed YAML content.
Raises:
FileNotFoundError: If config file does not exist.
"""
try: try:
with open(config_path, "r", encoding="utf-8") as file: with open(config_path, encoding="utf-8") as file:
return yaml.safe_load(file) return yaml.safe_load(file)
except FileNotFoundError: except FileNotFoundError:
print(f"File not found: {config_path}") print(f"File not found: {config_path}")
raise raise
def _get_all_functions(self):
def _get_all_functions(self: Any) -> dict[str, Callable[..., Any]]:
"""Return all non-dunder callable attributes.
Args:
self: Instance to inspect for callable attributes.
Returns:
Dictionary mapping attribute names to callable objects.
"""
return { return {
name: getattr(self, name) name: getattr(self, name)
for name in dir(self) for name in dir(self)
if callable(getattr(self, name)) if not (name.startswith("__") and name.endswith("__"))
and callable(getattr(self, name, None))
} }
def _filter_functions( def _filter_functions(
self, functions: dict[str, Callable], attribute: str functions: dict[str, Callable[..., Any]], attribute: str
) -> dict[str, Callable]: ) -> dict[str, Callable[..., Any]]:
return { """Filter functions by attribute presence.
name: func
for name, func in functions.items()
if hasattr(func, attribute)
}
def map_all_agent_variables(self) -> None: Args:
all_functions = self._get_all_functions() functions: Dictionary of functions to filter.
llms = self._filter_functions(all_functions, "is_llm") attribute: Attribute name to check for.
tool_functions = self._filter_functions(all_functions, "is_tool")
cache_handler_functions = self._filter_functions( Returns:
all_functions, "is_cache_handler" Dictionary containing only functions with the specified attribute.
) """
callbacks = self._filter_functions(all_functions, "is_callback") return {name: func for name, func in functions.items() if hasattr(func, attribute)}
def map_all_agent_variables(self: Any) -> None:
"""Map agent configuration variables to callable instances.
Args:
self: Crew instance with agent configurations to map.
"""
llms = _filter_functions(self._all_functions, "is_llm")
tool_functions = _filter_functions(self._all_functions, "is_tool")
cache_handler_functions = _filter_functions(self._all_functions, "is_cache_handler")
callbacks = _filter_functions(self._all_functions, "is_callback")
for agent_name, agent_info in self.agents_config.items(): for agent_name, agent_info in self.agents_config.items():
self._map_agent_variables( self._map_agent_variables(
@@ -186,15 +292,27 @@ def CrewBase(cls: T) -> T: # noqa: N802
callbacks, callbacks,
) )
def _map_agent_variables( def _map_agent_variables(
self, self: Any,
agent_name: str, agent_name: str,
agent_info: dict[str, Any], agent_info: dict[str, Any],
llms: dict[str, Callable], llms: dict[str, Callable[..., Any]],
tool_functions: dict[str, Callable], tool_functions: dict[str, Callable[..., Any]],
cache_handler_functions: dict[str, Callable], cache_handler_functions: dict[str, Callable[..., Any]],
callbacks: dict[str, Callable], callbacks: dict[str, Callable[..., Any]],
) -> None: ) -> None:
"""Resolve and map variables for a single agent.
Args:
self: Crew instance with agent configurations.
agent_name: Name of agent to configure.
agent_info: Agent configuration dictionary.
llms: Dictionary of available LLM providers.
tool_functions: Dictionary of available tools.
cache_handler_functions: Dictionary of available cache handlers.
callbacks: Dictionary of available callbacks.
"""
if llm := agent_info.get("llm"): if llm := agent_info.get("llm"):
try: try:
self.agents_config[agent_name]["llm"] = llms[llm]() self.agents_config[agent_name]["llm"] = llms[llm]()
@@ -217,26 +335,27 @@ def CrewBase(cls: T) -> T: # noqa: N802
) )
if step_callback := agent_info.get("step_callback"): if step_callback := agent_info.get("step_callback"):
self.agents_config[agent_name]["step_callback"] = callbacks[ self.agents_config[agent_name]["step_callback"] = callbacks[step_callback]()
step_callback
]()
if cache_handler := agent_info.get("cache_handler"): if cache_handler := agent_info.get("cache_handler"):
self.agents_config[agent_name]["cache_handler"] = ( self.agents_config[agent_name]["cache_handler"] = cache_handler_functions[
cache_handler_functions[cache_handler]() cache_handler
) ]()
def map_all_task_variables(self) -> None:
all_functions = self._get_all_functions() def map_all_task_variables(self: Any) -> None:
agents = self._filter_functions(all_functions, "is_agent") """Map task configuration variables to callable instances.
tasks = self._filter_functions(all_functions, "is_task")
output_json_functions = self._filter_functions( Args:
all_functions, "is_output_json" self: Crew instance with task configurations to map.
) """
tool_functions = self._filter_functions(all_functions, "is_tool") agents = _filter_functions(self._all_functions, "is_agent")
callback_functions = self._filter_functions(all_functions, "is_callback") tasks = _filter_functions(self._all_functions, "is_task")
output_pydantic_functions = self._filter_functions( output_json_functions = _filter_functions(self._all_functions, "is_output_json")
all_functions, "is_output_pydantic" tool_functions = _filter_functions(self._all_functions, "is_tool")
callback_functions = _filter_functions(self._all_functions, "is_callback")
output_pydantic_functions = _filter_functions(
self._all_functions, "is_output_pydantic"
) )
for task_name, task_info in self.tasks_config.items(): for task_name, task_info in self.tasks_config.items():
@@ -251,17 +370,31 @@ def CrewBase(cls: T) -> T: # noqa: N802
output_pydantic_functions, output_pydantic_functions,
) )
def _map_task_variables( def _map_task_variables(
self, self: Any,
task_name: str, task_name: str,
task_info: dict[str, Any], task_info: dict[str, Any],
agents: dict[str, Callable], agents: dict[str, Callable[..., Any]],
tasks: dict[str, Callable], tasks: dict[str, Callable[..., Any]],
output_json_functions: dict[str, Callable], output_json_functions: dict[str, Callable[..., Any]],
tool_functions: dict[str, Callable], tool_functions: dict[str, Callable[..., Any]],
callback_functions: dict[str, Callable], callback_functions: dict[str, Callable[..., Any]],
output_pydantic_functions: dict[str, Callable], output_pydantic_functions: dict[str, Callable[..., Any]],
) -> None: ) -> None:
"""Resolve and map variables for a single task.
Args:
self: Crew instance with task configurations.
task_name: Name of task to configure.
task_info: Task configuration dictionary.
agents: Dictionary of available agents.
tasks: Dictionary of available tasks.
output_json_functions: Dictionary of available JSON output schemas.
tool_functions: Dictionary of available tools.
callback_functions: Dictionary of available callbacks.
output_pydantic_functions: Dictionary of available Pydantic output schemas.
"""
if context_list := task_info.get("context"): if context_list := task_info.get("context"):
self.tasks_config[task_name]["context"] = [ self.tasks_config[task_name]["context"] = [
tasks[context_task_name]() for context_task_name in context_list tasks[context_task_name]() for context_task_name in context_list
@@ -276,14 +409,12 @@ def CrewBase(cls: T) -> T: # noqa: N802
self.tasks_config[task_name]["agent"] = agents[agent_name]() self.tasks_config[task_name]["agent"] = agents[agent_name]()
if output_json := task_info.get("output_json"): if output_json := task_info.get("output_json"):
self.tasks_config[task_name]["output_json"] = output_json_functions[ self.tasks_config[task_name]["output_json"] = output_json_functions[output_json]
output_json
]
if output_pydantic := task_info.get("output_pydantic"): if output_pydantic := task_info.get("output_pydantic"):
self.tasks_config[task_name]["output_pydantic"] = ( self.tasks_config[task_name]["output_pydantic"] = output_pydantic_functions[
output_pydantic_functions[output_pydantic] output_pydantic
) ]
if callbacks := task_info.get("callbacks"): if callbacks := task_info.get("callbacks"):
self.tasks_config[task_name]["callbacks"] = [ self.tasks_config[task_name]["callbacks"] = [
@@ -293,9 +424,24 @@ def CrewBase(cls: T) -> T: # noqa: N802
if guardrail := task_info.get("guardrail"): if guardrail := task_info.get("guardrail"):
self.tasks_config[task_name]["guardrail"] = guardrail self.tasks_config[task_name]["guardrail"] = guardrail
# Include base class (qual)name in the wrapper class (qual)name.
WrappedClass.__name__ = CrewBase.__name__ + "(" + cls.__name__ + ")"
WrappedClass.__qualname__ = CrewBase.__qualname__ + "(" + cls.__name__ + ")"
WrappedClass._crew_name = cls.__name__
return cast(T, WrappedClass) def CrewBase(cls: type) -> type: # noqa: N802
"""Apply CrewBaseMeta metaclass to a class for decorator syntax compatibility.
Args:
cls: Class to apply metaclass to.
Returns:
New class with CrewBaseMeta metaclass applied.
"""
if isinstance(cls, CrewBaseMeta):
return cls
namespace = {
key: value
for key, value in cls.__dict__.items()
if not key.startswith("__")
or key in ("__module__", "__qualname__", "__annotations__")
}
return CrewBaseMeta(cls.__name__, cls.__bases__, namespace)

View File

@@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any, Generic, ParamSpec, Protocol, Self, TypeV
if TYPE_CHECKING: if TYPE_CHECKING:
from crewai import Agent, Task from crewai import Agent, Task
from crewai.crews.crew_output import CrewOutput
P = ParamSpec("P") P = ParamSpec("P")
R = TypeVar("R") R = TypeVar("R")
@@ -37,13 +38,22 @@ def _copy_function_metadata(wrapper: Any, func: Callable[..., Any]) -> None:
class CrewInstance(Protocol): class CrewInstance(Protocol):
"""Protocol for crew class instances with required attributes.""" """Protocol for crew class instances with required attributes."""
_mcp_server_adapter: Any
_all_functions: dict[str, Callable[..., Any]]
_original_functions: dict[str, Callable[..., Any]]
_original_tasks: dict[str, Callable[[Self], Task]] _original_tasks: dict[str, Callable[[Self], Task]]
_original_agents: dict[str, Callable[[Self], Agent]] _original_agents: dict[str, Callable[[Self], Agent]]
_before_kickoff: dict[str, Callable[..., Any]] _before_kickoff: dict[str, Callable[..., Any]]
_after_kickoff: dict[str, Callable[..., Any]] _after_kickoff: dict[str, Callable[..., Any]]
_kickoff: dict[str, Callable[..., Any]]
agents: list[Agent] agents: list[Agent]
tasks: list[Task] tasks: list[Task]
def load_configurations(self) -> None: ...
def map_all_agent_variables(self) -> None: ...
def map_all_task_variables(self) -> None: ...
def _close_mcp_server(self, instance: Self, outputs: CrewOutput) -> CrewOutput: ...
class DecoratedMethod(Generic[P, R]): class DecoratedMethod(Generic[P, R]):
"""Base wrapper for methods with decorator metadata. """Base wrapper for methods with decorator metadata.