refactor: convert CrewBase decorator to metaclass pattern

This commit is contained in:
Greyson Lalonde
2025-10-15 01:28:59 -04:00
parent 2a23bc604c
commit 8465350f1d
2 changed files with 415 additions and 259 deletions

View File

@@ -1,301 +1,447 @@
"""Base decorator for creating crew classes with configuration and function management."""
"""Base metaclass for creating crew classes with configuration and function management."""
from __future__ import annotations
import inspect
import logging
from collections.abc import Callable
from pathlib import Path
from typing import Any, TypeVar, cast
from typing import TYPE_CHECKING, Any
import yaml
from dotenv import load_dotenv
from crewai.tools import BaseTool
if TYPE_CHECKING:
from crewai.crews.crew_output import CrewOutput
from crewai.project.wrappers import CrewInstance
load_dotenv()
T = TypeVar("T", bound=type)
class CrewBaseMeta(type):
"""Metaclass that adds crew functionality to classes."""
def CrewBase(cls: T) -> T: # noqa: N802
"""Wraps a class with crew functionality and configuration management."""
def __new__(
mcs,
name: str,
bases: tuple[type, ...],
namespace: dict[str, Any],
**kwargs: Any,
) -> type:
"""Create crew class with configuration and method injection.
class WrappedClass(cls): # type: ignore
is_crew_class: bool = True # type: ignore
Args:
name: Class name.
bases: Base classes.
namespace: Class namespace dictionary.
**kwargs: Additional keyword arguments.
# Get the directory of the class being decorated
base_directory = Path(inspect.getfile(cls)).parent
Returns:
New crew class with injected methods and attributes.
"""
cls = super().__new__(mcs, name, bases, namespace)
original_agents_config_path = getattr(
cls.is_crew_class = True # type: ignore[attr-defined]
cls._crew_name = name # type: ignore[attr-defined]
try:
cls.base_directory = Path(inspect.getfile(cls)).parent # type: ignore[attr-defined]
except (TypeError, OSError):
cls.base_directory = Path.cwd() # type: ignore[attr-defined]
cls.original_agents_config_path = getattr( # type: ignore[attr-defined]
cls, "agents_config", "config/agents.yaml"
)
original_tasks_config_path = getattr(cls, "tasks_config", "config/tasks.yaml")
cls.original_tasks_config_path = getattr( # type: ignore[attr-defined]
cls, "tasks_config", "config/tasks.yaml"
)
mcp_server_params: Any = getattr(cls, "mcp_server_params", None)
mcp_connect_timeout: int = getattr(cls, "mcp_connect_timeout", 30)
cls.mcp_server_params = getattr(cls, "mcp_server_params", None) # type: ignore[attr-defined]
cls.mcp_connect_timeout = getattr(cls, "mcp_connect_timeout", 30) # type: ignore[attr-defined]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.load_configurations()
self.map_all_agent_variables()
self.map_all_task_variables()
# Preserve all decorated functions
self._original_functions = {
name: method
for name, method in cls.__dict__.items()
if any(
hasattr(method, attr)
for attr in [
"is_task",
"is_agent",
"is_before_kickoff",
"is_after_kickoff",
"is_kickoff",
]
)
}
# Store specific function types
self._original_tasks = self._filter_functions(
self._original_functions, "is_task"
)
self._original_agents = self._filter_functions(
self._original_functions, "is_agent"
)
self._before_kickoff = self._filter_functions(
self._original_functions, "is_before_kickoff"
)
self._after_kickoff = self._filter_functions(
self._original_functions, "is_after_kickoff"
)
self._kickoff = self._filter_functions(
self._original_functions, "is_kickoff"
)
cls._close_mcp_server = _close_mcp_server # type: ignore[attr-defined]
cls.get_mcp_tools = get_mcp_tools # type: ignore[attr-defined]
cls._load_config = _load_config # type: ignore[attr-defined]
cls.load_configurations = load_configurations # type: ignore[attr-defined]
cls.load_yaml = staticmethod(load_yaml) # type: ignore[attr-defined]
cls.map_all_agent_variables = map_all_agent_variables # type: ignore[attr-defined]
cls._map_agent_variables = _map_agent_variables # type: ignore[attr-defined]
cls.map_all_task_variables = map_all_task_variables # type: ignore[attr-defined]
cls._map_task_variables = _map_task_variables # type: ignore[attr-defined]
# Add close mcp server method to after kickoff
bound_method = self._create_close_mcp_server_method()
self._after_kickoff["_close_mcp_server"] = bound_method
return cls
def _create_close_mcp_server_method(self):
def _close_mcp_server(self, instance, outputs):
adapter = getattr(self, "_mcp_server_adapter", None)
if adapter is not None:
try:
adapter.stop()
except Exception as e:
logging.warning(f"Error stopping MCP server: {e}")
return outputs
def __call__(cls, *args: Any, **kwargs: Any) -> CrewInstance:
"""Intercept instance creation to initialize crew functionality.
_close_mcp_server.is_after_kickoff = True
Args:
*args: Positional arguments for instance creation.
**kwargs: Keyword arguments for instance creation.
import types
Returns:
Initialized crew instance.
"""
instance: CrewInstance = super().__call__(*args, **kwargs)
CrewBaseMeta._initialize_crew_instance(instance, cls)
return instance
return types.MethodType(_close_mcp_server, self)
@staticmethod
def _initialize_crew_instance(instance: CrewInstance, cls: type) -> None:
"""Initialize crew instance attributes and load configurations.
def get_mcp_tools(self, *tool_names: list[str]) -> list[BaseTool]:
if not self.mcp_server_params:
return []
Args:
instance: Crew instance to initialize.
cls: Crew class type.
"""
instance._mcp_server_adapter = None
instance.load_configurations()
instance._all_functions = _get_all_functions(instance)
instance.map_all_agent_variables()
instance.map_all_task_variables()
from crewai_tools import MCPServerAdapter # type: ignore[import-untyped]
adapter = getattr(self, "_mcp_server_adapter", None)
if not adapter:
self._mcp_server_adapter = MCPServerAdapter(
self.mcp_server_params, connect_timeout=self.mcp_connect_timeout
)
return self._mcp_server_adapter.tools.filter_by_names(tool_names or None)
def load_configurations(self):
"""Load agent and task configurations from YAML files."""
if isinstance(self.original_agents_config_path, str):
agents_config_path = (
self.base_directory / self.original_agents_config_path
)
try:
self.agents_config = self.load_yaml(agents_config_path)
except FileNotFoundError:
logging.warning(
f"Agent config file not found at {agents_config_path}. "
"Proceeding with empty agent configurations."
)
self.agents_config = {}
else:
logging.warning(
"No agent configuration path provided. Proceeding with empty agent configurations."
)
self.agents_config = {}
if isinstance(self.original_tasks_config_path, str):
tasks_config_path = (
self.base_directory / self.original_tasks_config_path
)
try:
self.tasks_config = self.load_yaml(tasks_config_path)
except FileNotFoundError:
logging.warning(
f"Task config file not found at {tasks_config_path}. "
"Proceeding with empty task configurations."
)
self.tasks_config = {}
else:
logging.warning(
"No task configuration path provided. Proceeding with empty task configurations."
)
self.tasks_config = {}
@staticmethod
def load_yaml(config_path: Path):
try:
with open(config_path, "r", encoding="utf-8") as file:
return yaml.safe_load(file)
except FileNotFoundError:
print(f"File not found: {config_path}")
raise
def _get_all_functions(self):
return {
name: getattr(self, name)
for name in dir(self)
if callable(getattr(self, name))
}
def _filter_functions(
self, functions: dict[str, Callable], attribute: str
) -> dict[str, Callable]:
return {
name: func
for name, func in functions.items()
if hasattr(func, attribute)
}
def map_all_agent_variables(self) -> None:
all_functions = self._get_all_functions()
llms = self._filter_functions(all_functions, "is_llm")
tool_functions = self._filter_functions(all_functions, "is_tool")
cache_handler_functions = self._filter_functions(
all_functions, "is_cache_handler"
)
callbacks = self._filter_functions(all_functions, "is_callback")
for agent_name, agent_info in self.agents_config.items():
self._map_agent_variables(
agent_name,
agent_info,
llms,
tool_functions,
cache_handler_functions,
callbacks,
)
def _map_agent_variables(
self,
agent_name: str,
agent_info: dict[str, Any],
llms: dict[str, Callable],
tool_functions: dict[str, Callable],
cache_handler_functions: dict[str, Callable],
callbacks: dict[str, Callable],
) -> None:
if llm := agent_info.get("llm"):
try:
self.agents_config[agent_name]["llm"] = llms[llm]()
except KeyError:
self.agents_config[agent_name]["llm"] = llm
if tools := agent_info.get("tools"):
self.agents_config[agent_name]["tools"] = [
tool_functions[tool]() for tool in tools
instance._original_functions = {
name: method
for name, method in cls.__dict__.items()
if any(
hasattr(method, attr)
for attr in [
"is_task",
"is_agent",
"is_before_kickoff",
"is_after_kickoff",
"is_kickoff",
]
if function_calling_llm := agent_info.get("function_calling_llm"):
try:
self.agents_config[agent_name]["function_calling_llm"] = llms[
function_calling_llm
]()
except KeyError:
self.agents_config[agent_name]["function_calling_llm"] = (
function_calling_llm
)
if step_callback := agent_info.get("step_callback"):
self.agents_config[agent_name]["step_callback"] = callbacks[
step_callback
]()
if cache_handler := agent_info.get("cache_handler"):
self.agents_config[agent_name]["cache_handler"] = (
cache_handler_functions[cache_handler]()
)
def map_all_task_variables(self) -> None:
all_functions = self._get_all_functions()
agents = self._filter_functions(all_functions, "is_agent")
tasks = self._filter_functions(all_functions, "is_task")
output_json_functions = self._filter_functions(
all_functions, "is_output_json"
)
tool_functions = self._filter_functions(all_functions, "is_tool")
callback_functions = self._filter_functions(all_functions, "is_callback")
output_pydantic_functions = self._filter_functions(
all_functions, "is_output_pydantic"
}
instance._original_tasks = _filter_functions(
instance._original_functions, "is_task"
)
instance._original_agents = _filter_functions(
instance._original_functions, "is_agent"
)
instance._before_kickoff = _filter_functions(
instance._original_functions, "is_before_kickoff"
)
instance._after_kickoff = _filter_functions(
instance._original_functions, "is_after_kickoff"
)
instance._kickoff = _filter_functions(
instance._original_functions, "is_kickoff"
)
instance._after_kickoff["_close_mcp_server"] = instance._close_mcp_server
def _close_mcp_server(
self: CrewInstance, _instance: CrewInstance, outputs: CrewOutput
) -> CrewOutput:
"""Stop MCP server adapter and return outputs.
Args:
self: Crew instance with MCP server adapter.
_instance: Crew instance (unused, required by callback signature).
outputs: Crew execution outputs.
Returns:
Unmodified crew outputs.
"""
if self._mcp_server_adapter is not None:
try:
self._mcp_server_adapter.stop()
except Exception as e:
logging.warning(f"Error stopping MCP server: {e}")
return outputs
def get_mcp_tools(self: Any, *tool_names: str) -> list[BaseTool]:
"""Get MCP tools filtered by name.
Args:
self: Crew instance with MCP server configuration.
*tool_names: Optional tool names to filter by.
Returns:
List of filtered MCP tools, or empty list if no MCP server configured.
"""
if not self.mcp_server_params:
return []
from crewai_tools import MCPServerAdapter # type: ignore[import-untyped]
if self._mcp_server_adapter is None:
self._mcp_server_adapter = MCPServerAdapter(
self.mcp_server_params, connect_timeout=self.mcp_connect_timeout
)
return self._mcp_server_adapter.tools.filter_by_names(tool_names or None)
def _load_config(
self: Any, config_path: str | None, config_type: str
) -> dict[str, Any]:
"""Load YAML config file or return empty dict if not found.
Args:
config_path: Relative path to config file.
config_type: Config type for logging (e.g., "agent", "task").
Returns:
Config dictionary or empty dict.
"""
if isinstance(config_path, str):
full_path = self.base_directory / config_path
try:
return self.load_yaml(full_path)
except FileNotFoundError:
logging.warning(
f"{config_type.capitalize()} config file not found at {full_path}. "
f"Proceeding with empty {config_type} configurations."
)
return {}
else:
logging.warning(
f"No {config_type} configuration path provided. "
f"Proceeding with empty {config_type} configurations."
)
return {}
def load_configurations(self: Any) -> None:
"""Load agent and task YAML configurations.
Args:
self: Crew instance with configuration paths.
"""
self.agents_config = self._load_config(self.original_agents_config_path, "agent")
self.tasks_config = self._load_config(self.original_tasks_config_path, "task")
def load_yaml(config_path: Path) -> Any:
"""Load and parse YAML file.
Args:
config_path: Path to YAML configuration file.
Returns:
Parsed YAML content.
Raises:
FileNotFoundError: If config file does not exist.
"""
try:
with open(config_path, encoding="utf-8") as file:
return yaml.safe_load(file)
except FileNotFoundError:
print(f"File not found: {config_path}")
raise
def _get_all_functions(self: Any) -> dict[str, Callable[..., Any]]:
"""Return all non-dunder callable attributes.
Args:
self: Instance to inspect for callable attributes.
Returns:
Dictionary mapping attribute names to callable objects.
"""
return {
name: getattr(self, name)
for name in dir(self)
if not (name.startswith("__") and name.endswith("__"))
and callable(getattr(self, name, None))
}
def _filter_functions(
functions: dict[str, Callable[..., Any]], attribute: str
) -> dict[str, Callable[..., Any]]:
"""Filter functions by attribute presence.
Args:
functions: Dictionary of functions to filter.
attribute: Attribute name to check for.
Returns:
Dictionary containing only functions with the specified attribute.
"""
return {name: func for name, func in functions.items() if hasattr(func, attribute)}
def map_all_agent_variables(self: Any) -> None:
"""Map agent configuration variables to callable instances.
Args:
self: Crew instance with agent configurations to map.
"""
llms = _filter_functions(self._all_functions, "is_llm")
tool_functions = _filter_functions(self._all_functions, "is_tool")
cache_handler_functions = _filter_functions(self._all_functions, "is_cache_handler")
callbacks = _filter_functions(self._all_functions, "is_callback")
for agent_name, agent_info in self.agents_config.items():
self._map_agent_variables(
agent_name,
agent_info,
llms,
tool_functions,
cache_handler_functions,
callbacks,
)
def _map_agent_variables(
self: Any,
agent_name: str,
agent_info: dict[str, Any],
llms: dict[str, Callable[..., Any]],
tool_functions: dict[str, Callable[..., Any]],
cache_handler_functions: dict[str, Callable[..., Any]],
callbacks: dict[str, Callable[..., Any]],
) -> None:
"""Resolve and map variables for a single agent.
Args:
self: Crew instance with agent configurations.
agent_name: Name of agent to configure.
agent_info: Agent configuration dictionary.
llms: Dictionary of available LLM providers.
tool_functions: Dictionary of available tools.
cache_handler_functions: Dictionary of available cache handlers.
callbacks: Dictionary of available callbacks.
"""
if llm := agent_info.get("llm"):
try:
self.agents_config[agent_name]["llm"] = llms[llm]()
except KeyError:
self.agents_config[agent_name]["llm"] = llm
if tools := agent_info.get("tools"):
self.agents_config[agent_name]["tools"] = [
tool_functions[tool]() for tool in tools
]
if function_calling_llm := agent_info.get("function_calling_llm"):
try:
self.agents_config[agent_name]["function_calling_llm"] = llms[
function_calling_llm
]()
except KeyError:
self.agents_config[agent_name]["function_calling_llm"] = (
function_calling_llm
)
for task_name, task_info in self.tasks_config.items():
self._map_task_variables(
task_name,
task_info,
agents,
tasks,
output_json_functions,
tool_functions,
callback_functions,
output_pydantic_functions,
)
if step_callback := agent_info.get("step_callback"):
self.agents_config[agent_name]["step_callback"] = callbacks[step_callback]()
def _map_task_variables(
self,
task_name: str,
task_info: dict[str, Any],
agents: dict[str, Callable],
tasks: dict[str, Callable],
output_json_functions: dict[str, Callable],
tool_functions: dict[str, Callable],
callback_functions: dict[str, Callable],
output_pydantic_functions: dict[str, Callable],
) -> None:
if context_list := task_info.get("context"):
self.tasks_config[task_name]["context"] = [
tasks[context_task_name]() for context_task_name in context_list
]
if cache_handler := agent_info.get("cache_handler"):
self.agents_config[agent_name]["cache_handler"] = cache_handler_functions[
cache_handler
]()
if tools := task_info.get("tools"):
self.tasks_config[task_name]["tools"] = [
tool_functions[tool]() for tool in tools
]
if agent_name := task_info.get("agent"):
self.tasks_config[task_name]["agent"] = agents[agent_name]()
def map_all_task_variables(self: Any) -> None:
"""Map task configuration variables to callable instances.
if output_json := task_info.get("output_json"):
self.tasks_config[task_name]["output_json"] = output_json_functions[
output_json
]
Args:
self: Crew instance with task configurations to map.
"""
agents = _filter_functions(self._all_functions, "is_agent")
tasks = _filter_functions(self._all_functions, "is_task")
output_json_functions = _filter_functions(self._all_functions, "is_output_json")
tool_functions = _filter_functions(self._all_functions, "is_tool")
callback_functions = _filter_functions(self._all_functions, "is_callback")
output_pydantic_functions = _filter_functions(
self._all_functions, "is_output_pydantic"
)
if output_pydantic := task_info.get("output_pydantic"):
self.tasks_config[task_name]["output_pydantic"] = (
output_pydantic_functions[output_pydantic]
)
for task_name, task_info in self.tasks_config.items():
self._map_task_variables(
task_name,
task_info,
agents,
tasks,
output_json_functions,
tool_functions,
callback_functions,
output_pydantic_functions,
)
if callbacks := task_info.get("callbacks"):
self.tasks_config[task_name]["callbacks"] = [
callback_functions[callback]() for callback in callbacks
]
if guardrail := task_info.get("guardrail"):
self.tasks_config[task_name]["guardrail"] = guardrail
def _map_task_variables(
self: Any,
task_name: str,
task_info: dict[str, Any],
agents: dict[str, Callable[..., Any]],
tasks: dict[str, Callable[..., Any]],
output_json_functions: dict[str, Callable[..., Any]],
tool_functions: dict[str, Callable[..., Any]],
callback_functions: dict[str, Callable[..., Any]],
output_pydantic_functions: dict[str, Callable[..., Any]],
) -> None:
"""Resolve and map variables for a single task.
# Include base class (qual)name in the wrapper class (qual)name.
WrappedClass.__name__ = CrewBase.__name__ + "(" + cls.__name__ + ")"
WrappedClass.__qualname__ = CrewBase.__qualname__ + "(" + cls.__name__ + ")"
WrappedClass._crew_name = cls.__name__
Args:
self: Crew instance with task configurations.
task_name: Name of task to configure.
task_info: Task configuration dictionary.
agents: Dictionary of available agents.
tasks: Dictionary of available tasks.
output_json_functions: Dictionary of available JSON output schemas.
tool_functions: Dictionary of available tools.
callback_functions: Dictionary of available callbacks.
output_pydantic_functions: Dictionary of available Pydantic output schemas.
"""
if context_list := task_info.get("context"):
self.tasks_config[task_name]["context"] = [
tasks[context_task_name]() for context_task_name in context_list
]
return cast(T, WrappedClass)
if tools := task_info.get("tools"):
self.tasks_config[task_name]["tools"] = [
tool_functions[tool]() for tool in tools
]
if agent_name := task_info.get("agent"):
self.tasks_config[task_name]["agent"] = agents[agent_name]()
if output_json := task_info.get("output_json"):
self.tasks_config[task_name]["output_json"] = output_json_functions[output_json]
if output_pydantic := task_info.get("output_pydantic"):
self.tasks_config[task_name]["output_pydantic"] = output_pydantic_functions[
output_pydantic
]
if callbacks := task_info.get("callbacks"):
self.tasks_config[task_name]["callbacks"] = [
callback_functions[callback]() for callback in callbacks
]
if guardrail := task_info.get("guardrail"):
self.tasks_config[task_name]["guardrail"] = guardrail
def CrewBase(cls: type) -> type: # noqa: N802
"""Apply CrewBaseMeta metaclass to a class for decorator syntax compatibility.
Args:
cls: Class to apply metaclass to.
Returns:
New class with CrewBaseMeta metaclass applied.
"""
if isinstance(cls, CrewBaseMeta):
return cls
namespace = {
key: value
for key, value in cls.__dict__.items()
if not key.startswith("__")
or key in ("__module__", "__qualname__", "__annotations__")
}
return CrewBaseMeta(cls.__name__, cls.__bases__, namespace)

View File

@@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any, Generic, ParamSpec, Protocol, Self, TypeV
if TYPE_CHECKING:
from crewai import Agent, Task
from crewai.crews.crew_output import CrewOutput
P = ParamSpec("P")
R = TypeVar("R")
@@ -37,13 +38,22 @@ def _copy_function_metadata(wrapper: Any, func: Callable[..., Any]) -> None:
class CrewInstance(Protocol):
"""Protocol for crew class instances with required attributes."""
_mcp_server_adapter: Any
_all_functions: dict[str, Callable[..., Any]]
_original_functions: dict[str, Callable[..., Any]]
_original_tasks: dict[str, Callable[[Self], Task]]
_original_agents: dict[str, Callable[[Self], Agent]]
_before_kickoff: dict[str, Callable[..., Any]]
_after_kickoff: dict[str, Callable[..., Any]]
_kickoff: dict[str, Callable[..., Any]]
agents: list[Agent]
tasks: list[Task]
def load_configurations(self) -> None: ...
def map_all_agent_variables(self) -> None: ...
def map_all_task_variables(self) -> None: ...
def _close_mcp_server(self, instance: Self, outputs: CrewOutput) -> CrewOutput: ...
class DecoratedMethod(Generic[P, R]):
"""Base wrapper for methods with decorator metadata.