refactor: convert CrewBase decorator to metaclass pattern

This commit is contained in:
Greyson Lalonde
2025-10-15 01:28:59 -04:00
parent 2a23bc604c
commit 8465350f1d
2 changed files with 415 additions and 259 deletions

View File

@@ -1,301 +1,447 @@
"""Base decorator for creating crew classes with configuration and function management.""" """Base metaclass for creating crew classes with configuration and function management."""
from __future__ import annotations
import inspect import inspect
import logging import logging
from collections.abc import Callable from collections.abc import Callable
from pathlib import Path from pathlib import Path
from typing import Any, TypeVar, cast from typing import TYPE_CHECKING, Any
import yaml import yaml
from dotenv import load_dotenv from dotenv import load_dotenv
from crewai.tools import BaseTool from crewai.tools import BaseTool
if TYPE_CHECKING:
from crewai.crews.crew_output import CrewOutput
from crewai.project.wrappers import CrewInstance
load_dotenv() load_dotenv()
T = TypeVar("T", bound=type)
class CrewBaseMeta(type):
"""Metaclass that adds crew functionality to classes."""
def CrewBase(cls: T) -> T: # noqa: N802 def __new__(
"""Wraps a class with crew functionality and configuration management.""" mcs,
name: str,
bases: tuple[type, ...],
namespace: dict[str, Any],
**kwargs: Any,
) -> type:
"""Create crew class with configuration and method injection.
class WrappedClass(cls): # type: ignore Args:
is_crew_class: bool = True # type: ignore name: Class name.
bases: Base classes.
namespace: Class namespace dictionary.
**kwargs: Additional keyword arguments.
# Get the directory of the class being decorated Returns:
base_directory = Path(inspect.getfile(cls)).parent New crew class with injected methods and attributes.
"""
cls = super().__new__(mcs, name, bases, namespace)
original_agents_config_path = getattr( cls.is_crew_class = True # type: ignore[attr-defined]
cls._crew_name = name # type: ignore[attr-defined]
try:
cls.base_directory = Path(inspect.getfile(cls)).parent # type: ignore[attr-defined]
except (TypeError, OSError):
cls.base_directory = Path.cwd() # type: ignore[attr-defined]
cls.original_agents_config_path = getattr( # type: ignore[attr-defined]
cls, "agents_config", "config/agents.yaml" cls, "agents_config", "config/agents.yaml"
) )
original_tasks_config_path = getattr(cls, "tasks_config", "config/tasks.yaml") cls.original_tasks_config_path = getattr( # type: ignore[attr-defined]
cls, "tasks_config", "config/tasks.yaml"
)
mcp_server_params: Any = getattr(cls, "mcp_server_params", None) cls.mcp_server_params = getattr(cls, "mcp_server_params", None) # type: ignore[attr-defined]
mcp_connect_timeout: int = getattr(cls, "mcp_connect_timeout", 30) cls.mcp_connect_timeout = getattr(cls, "mcp_connect_timeout", 30) # type: ignore[attr-defined]
def __init__(self, *args, **kwargs): cls._close_mcp_server = _close_mcp_server # type: ignore[attr-defined]
super().__init__(*args, **kwargs) cls.get_mcp_tools = get_mcp_tools # type: ignore[attr-defined]
self.load_configurations() cls._load_config = _load_config # type: ignore[attr-defined]
self.map_all_agent_variables() cls.load_configurations = load_configurations # type: ignore[attr-defined]
self.map_all_task_variables() cls.load_yaml = staticmethod(load_yaml) # type: ignore[attr-defined]
# Preserve all decorated functions cls.map_all_agent_variables = map_all_agent_variables # type: ignore[attr-defined]
self._original_functions = { cls._map_agent_variables = _map_agent_variables # type: ignore[attr-defined]
name: method cls.map_all_task_variables = map_all_task_variables # type: ignore[attr-defined]
for name, method in cls.__dict__.items() cls._map_task_variables = _map_task_variables # type: ignore[attr-defined]
if any(
hasattr(method, attr)
for attr in [
"is_task",
"is_agent",
"is_before_kickoff",
"is_after_kickoff",
"is_kickoff",
]
)
}
# Store specific function types
self._original_tasks = self._filter_functions(
self._original_functions, "is_task"
)
self._original_agents = self._filter_functions(
self._original_functions, "is_agent"
)
self._before_kickoff = self._filter_functions(
self._original_functions, "is_before_kickoff"
)
self._after_kickoff = self._filter_functions(
self._original_functions, "is_after_kickoff"
)
self._kickoff = self._filter_functions(
self._original_functions, "is_kickoff"
)
# Add close mcp server method to after kickoff return cls
bound_method = self._create_close_mcp_server_method()
self._after_kickoff["_close_mcp_server"] = bound_method
def _create_close_mcp_server_method(self): def __call__(cls, *args: Any, **kwargs: Any) -> CrewInstance:
def _close_mcp_server(self, instance, outputs): """Intercept instance creation to initialize crew functionality.
adapter = getattr(self, "_mcp_server_adapter", None)
if adapter is not None:
try:
adapter.stop()
except Exception as e:
logging.warning(f"Error stopping MCP server: {e}")
return outputs
_close_mcp_server.is_after_kickoff = True Args:
*args: Positional arguments for instance creation.
**kwargs: Keyword arguments for instance creation.
import types Returns:
Initialized crew instance.
"""
instance: CrewInstance = super().__call__(*args, **kwargs)
CrewBaseMeta._initialize_crew_instance(instance, cls)
return instance
return types.MethodType(_close_mcp_server, self) @staticmethod
def _initialize_crew_instance(instance: CrewInstance, cls: type) -> None:
"""Initialize crew instance attributes and load configurations.
def get_mcp_tools(self, *tool_names: list[str]) -> list[BaseTool]: Args:
if not self.mcp_server_params: instance: Crew instance to initialize.
return [] cls: Crew class type.
"""
instance._mcp_server_adapter = None
instance.load_configurations()
instance._all_functions = _get_all_functions(instance)
instance.map_all_agent_variables()
instance.map_all_task_variables()
from crewai_tools import MCPServerAdapter # type: ignore[import-untyped] instance._original_functions = {
name: method
adapter = getattr(self, "_mcp_server_adapter", None) for name, method in cls.__dict__.items()
if not adapter: if any(
self._mcp_server_adapter = MCPServerAdapter( hasattr(method, attr)
self.mcp_server_params, connect_timeout=self.mcp_connect_timeout for attr in [
) "is_task",
"is_agent",
return self._mcp_server_adapter.tools.filter_by_names(tool_names or None) "is_before_kickoff",
"is_after_kickoff",
def load_configurations(self): "is_kickoff",
"""Load agent and task configurations from YAML files."""
if isinstance(self.original_agents_config_path, str):
agents_config_path = (
self.base_directory / self.original_agents_config_path
)
try:
self.agents_config = self.load_yaml(agents_config_path)
except FileNotFoundError:
logging.warning(
f"Agent config file not found at {agents_config_path}. "
"Proceeding with empty agent configurations."
)
self.agents_config = {}
else:
logging.warning(
"No agent configuration path provided. Proceeding with empty agent configurations."
)
self.agents_config = {}
if isinstance(self.original_tasks_config_path, str):
tasks_config_path = (
self.base_directory / self.original_tasks_config_path
)
try:
self.tasks_config = self.load_yaml(tasks_config_path)
except FileNotFoundError:
logging.warning(
f"Task config file not found at {tasks_config_path}. "
"Proceeding with empty task configurations."
)
self.tasks_config = {}
else:
logging.warning(
"No task configuration path provided. Proceeding with empty task configurations."
)
self.tasks_config = {}
@staticmethod
def load_yaml(config_path: Path):
try:
with open(config_path, "r", encoding="utf-8") as file:
return yaml.safe_load(file)
except FileNotFoundError:
print(f"File not found: {config_path}")
raise
def _get_all_functions(self):
return {
name: getattr(self, name)
for name in dir(self)
if callable(getattr(self, name))
}
def _filter_functions(
self, functions: dict[str, Callable], attribute: str
) -> dict[str, Callable]:
return {
name: func
for name, func in functions.items()
if hasattr(func, attribute)
}
def map_all_agent_variables(self) -> None:
all_functions = self._get_all_functions()
llms = self._filter_functions(all_functions, "is_llm")
tool_functions = self._filter_functions(all_functions, "is_tool")
cache_handler_functions = self._filter_functions(
all_functions, "is_cache_handler"
)
callbacks = self._filter_functions(all_functions, "is_callback")
for agent_name, agent_info in self.agents_config.items():
self._map_agent_variables(
agent_name,
agent_info,
llms,
tool_functions,
cache_handler_functions,
callbacks,
)
def _map_agent_variables(
self,
agent_name: str,
agent_info: dict[str, Any],
llms: dict[str, Callable],
tool_functions: dict[str, Callable],
cache_handler_functions: dict[str, Callable],
callbacks: dict[str, Callable],
) -> None:
if llm := agent_info.get("llm"):
try:
self.agents_config[agent_name]["llm"] = llms[llm]()
except KeyError:
self.agents_config[agent_name]["llm"] = llm
if tools := agent_info.get("tools"):
self.agents_config[agent_name]["tools"] = [
tool_functions[tool]() for tool in tools
] ]
if function_calling_llm := agent_info.get("function_calling_llm"):
try:
self.agents_config[agent_name]["function_calling_llm"] = llms[
function_calling_llm
]()
except KeyError:
self.agents_config[agent_name]["function_calling_llm"] = (
function_calling_llm
)
if step_callback := agent_info.get("step_callback"):
self.agents_config[agent_name]["step_callback"] = callbacks[
step_callback
]()
if cache_handler := agent_info.get("cache_handler"):
self.agents_config[agent_name]["cache_handler"] = (
cache_handler_functions[cache_handler]()
)
def map_all_task_variables(self) -> None:
all_functions = self._get_all_functions()
agents = self._filter_functions(all_functions, "is_agent")
tasks = self._filter_functions(all_functions, "is_task")
output_json_functions = self._filter_functions(
all_functions, "is_output_json"
) )
tool_functions = self._filter_functions(all_functions, "is_tool") }
callback_functions = self._filter_functions(all_functions, "is_callback")
output_pydantic_functions = self._filter_functions( instance._original_tasks = _filter_functions(
all_functions, "is_output_pydantic" instance._original_functions, "is_task"
)
instance._original_agents = _filter_functions(
instance._original_functions, "is_agent"
)
instance._before_kickoff = _filter_functions(
instance._original_functions, "is_before_kickoff"
)
instance._after_kickoff = _filter_functions(
instance._original_functions, "is_after_kickoff"
)
instance._kickoff = _filter_functions(
instance._original_functions, "is_kickoff"
)
instance._after_kickoff["_close_mcp_server"] = instance._close_mcp_server
def _close_mcp_server(
self: CrewInstance, _instance: CrewInstance, outputs: CrewOutput
) -> CrewOutput:
"""Stop MCP server adapter and return outputs.
Args:
self: Crew instance with MCP server adapter.
_instance: Crew instance (unused, required by callback signature).
outputs: Crew execution outputs.
Returns:
Unmodified crew outputs.
"""
if self._mcp_server_adapter is not None:
try:
self._mcp_server_adapter.stop()
except Exception as e:
logging.warning(f"Error stopping MCP server: {e}")
return outputs
def get_mcp_tools(self: Any, *tool_names: str) -> list[BaseTool]:
"""Get MCP tools filtered by name.
Args:
self: Crew instance with MCP server configuration.
*tool_names: Optional tool names to filter by.
Returns:
List of filtered MCP tools, or empty list if no MCP server configured.
"""
if not self.mcp_server_params:
return []
from crewai_tools import MCPServerAdapter # type: ignore[import-untyped]
if self._mcp_server_adapter is None:
self._mcp_server_adapter = MCPServerAdapter(
self.mcp_server_params, connect_timeout=self.mcp_connect_timeout
)
return self._mcp_server_adapter.tools.filter_by_names(tool_names or None)
def _load_config(
self: Any, config_path: str | None, config_type: str
) -> dict[str, Any]:
"""Load YAML config file or return empty dict if not found.
Args:
config_path: Relative path to config file.
config_type: Config type for logging (e.g., "agent", "task").
Returns:
Config dictionary or empty dict.
"""
if isinstance(config_path, str):
full_path = self.base_directory / config_path
try:
return self.load_yaml(full_path)
except FileNotFoundError:
logging.warning(
f"{config_type.capitalize()} config file not found at {full_path}. "
f"Proceeding with empty {config_type} configurations."
)
return {}
else:
logging.warning(
f"No {config_type} configuration path provided. "
f"Proceeding with empty {config_type} configurations."
)
return {}
def load_configurations(self: Any) -> None:
"""Load agent and task YAML configurations.
Args:
self: Crew instance with configuration paths.
"""
self.agents_config = self._load_config(self.original_agents_config_path, "agent")
self.tasks_config = self._load_config(self.original_tasks_config_path, "task")
def load_yaml(config_path: Path) -> Any:
"""Load and parse YAML file.
Args:
config_path: Path to YAML configuration file.
Returns:
Parsed YAML content.
Raises:
FileNotFoundError: If config file does not exist.
"""
try:
with open(config_path, encoding="utf-8") as file:
return yaml.safe_load(file)
except FileNotFoundError:
print(f"File not found: {config_path}")
raise
def _get_all_functions(self: Any) -> dict[str, Callable[..., Any]]:
"""Return all non-dunder callable attributes.
Args:
self: Instance to inspect for callable attributes.
Returns:
Dictionary mapping attribute names to callable objects.
"""
return {
name: getattr(self, name)
for name in dir(self)
if not (name.startswith("__") and name.endswith("__"))
and callable(getattr(self, name, None))
}
def _filter_functions(
functions: dict[str, Callable[..., Any]], attribute: str
) -> dict[str, Callable[..., Any]]:
"""Filter functions by attribute presence.
Args:
functions: Dictionary of functions to filter.
attribute: Attribute name to check for.
Returns:
Dictionary containing only functions with the specified attribute.
"""
return {name: func for name, func in functions.items() if hasattr(func, attribute)}
def map_all_agent_variables(self: Any) -> None:
"""Map agent configuration variables to callable instances.
Args:
self: Crew instance with agent configurations to map.
"""
llms = _filter_functions(self._all_functions, "is_llm")
tool_functions = _filter_functions(self._all_functions, "is_tool")
cache_handler_functions = _filter_functions(self._all_functions, "is_cache_handler")
callbacks = _filter_functions(self._all_functions, "is_callback")
for agent_name, agent_info in self.agents_config.items():
self._map_agent_variables(
agent_name,
agent_info,
llms,
tool_functions,
cache_handler_functions,
callbacks,
)
def _map_agent_variables(
self: Any,
agent_name: str,
agent_info: dict[str, Any],
llms: dict[str, Callable[..., Any]],
tool_functions: dict[str, Callable[..., Any]],
cache_handler_functions: dict[str, Callable[..., Any]],
callbacks: dict[str, Callable[..., Any]],
) -> None:
"""Resolve and map variables for a single agent.
Args:
self: Crew instance with agent configurations.
agent_name: Name of agent to configure.
agent_info: Agent configuration dictionary.
llms: Dictionary of available LLM providers.
tool_functions: Dictionary of available tools.
cache_handler_functions: Dictionary of available cache handlers.
callbacks: Dictionary of available callbacks.
"""
if llm := agent_info.get("llm"):
try:
self.agents_config[agent_name]["llm"] = llms[llm]()
except KeyError:
self.agents_config[agent_name]["llm"] = llm
if tools := agent_info.get("tools"):
self.agents_config[agent_name]["tools"] = [
tool_functions[tool]() for tool in tools
]
if function_calling_llm := agent_info.get("function_calling_llm"):
try:
self.agents_config[agent_name]["function_calling_llm"] = llms[
function_calling_llm
]()
except KeyError:
self.agents_config[agent_name]["function_calling_llm"] = (
function_calling_llm
) )
for task_name, task_info in self.tasks_config.items(): if step_callback := agent_info.get("step_callback"):
self._map_task_variables( self.agents_config[agent_name]["step_callback"] = callbacks[step_callback]()
task_name,
task_info,
agents,
tasks,
output_json_functions,
tool_functions,
callback_functions,
output_pydantic_functions,
)
def _map_task_variables( if cache_handler := agent_info.get("cache_handler"):
self, self.agents_config[agent_name]["cache_handler"] = cache_handler_functions[
task_name: str, cache_handler
task_info: dict[str, Any], ]()
agents: dict[str, Callable],
tasks: dict[str, Callable],
output_json_functions: dict[str, Callable],
tool_functions: dict[str, Callable],
callback_functions: dict[str, Callable],
output_pydantic_functions: dict[str, Callable],
) -> None:
if context_list := task_info.get("context"):
self.tasks_config[task_name]["context"] = [
tasks[context_task_name]() for context_task_name in context_list
]
if tools := task_info.get("tools"):
self.tasks_config[task_name]["tools"] = [
tool_functions[tool]() for tool in tools
]
if agent_name := task_info.get("agent"): def map_all_task_variables(self: Any) -> None:
self.tasks_config[task_name]["agent"] = agents[agent_name]() """Map task configuration variables to callable instances.
if output_json := task_info.get("output_json"): Args:
self.tasks_config[task_name]["output_json"] = output_json_functions[ self: Crew instance with task configurations to map.
output_json """
] agents = _filter_functions(self._all_functions, "is_agent")
tasks = _filter_functions(self._all_functions, "is_task")
output_json_functions = _filter_functions(self._all_functions, "is_output_json")
tool_functions = _filter_functions(self._all_functions, "is_tool")
callback_functions = _filter_functions(self._all_functions, "is_callback")
output_pydantic_functions = _filter_functions(
self._all_functions, "is_output_pydantic"
)
if output_pydantic := task_info.get("output_pydantic"): for task_name, task_info in self.tasks_config.items():
self.tasks_config[task_name]["output_pydantic"] = ( self._map_task_variables(
output_pydantic_functions[output_pydantic] task_name,
) task_info,
agents,
tasks,
output_json_functions,
tool_functions,
callback_functions,
output_pydantic_functions,
)
if callbacks := task_info.get("callbacks"):
self.tasks_config[task_name]["callbacks"] = [
callback_functions[callback]() for callback in callbacks
]
if guardrail := task_info.get("guardrail"): def _map_task_variables(
self.tasks_config[task_name]["guardrail"] = guardrail self: Any,
task_name: str,
task_info: dict[str, Any],
agents: dict[str, Callable[..., Any]],
tasks: dict[str, Callable[..., Any]],
output_json_functions: dict[str, Callable[..., Any]],
tool_functions: dict[str, Callable[..., Any]],
callback_functions: dict[str, Callable[..., Any]],
output_pydantic_functions: dict[str, Callable[..., Any]],
) -> None:
"""Resolve and map variables for a single task.
# Include base class (qual)name in the wrapper class (qual)name. Args:
WrappedClass.__name__ = CrewBase.__name__ + "(" + cls.__name__ + ")" self: Crew instance with task configurations.
WrappedClass.__qualname__ = CrewBase.__qualname__ + "(" + cls.__name__ + ")" task_name: Name of task to configure.
WrappedClass._crew_name = cls.__name__ task_info: Task configuration dictionary.
agents: Dictionary of available agents.
tasks: Dictionary of available tasks.
output_json_functions: Dictionary of available JSON output schemas.
tool_functions: Dictionary of available tools.
callback_functions: Dictionary of available callbacks.
output_pydantic_functions: Dictionary of available Pydantic output schemas.
"""
if context_list := task_info.get("context"):
self.tasks_config[task_name]["context"] = [
tasks[context_task_name]() for context_task_name in context_list
]
return cast(T, WrappedClass) if tools := task_info.get("tools"):
self.tasks_config[task_name]["tools"] = [
tool_functions[tool]() for tool in tools
]
if agent_name := task_info.get("agent"):
self.tasks_config[task_name]["agent"] = agents[agent_name]()
if output_json := task_info.get("output_json"):
self.tasks_config[task_name]["output_json"] = output_json_functions[output_json]
if output_pydantic := task_info.get("output_pydantic"):
self.tasks_config[task_name]["output_pydantic"] = output_pydantic_functions[
output_pydantic
]
if callbacks := task_info.get("callbacks"):
self.tasks_config[task_name]["callbacks"] = [
callback_functions[callback]() for callback in callbacks
]
if guardrail := task_info.get("guardrail"):
self.tasks_config[task_name]["guardrail"] = guardrail
def CrewBase(cls: type) -> type: # noqa: N802
"""Apply CrewBaseMeta metaclass to a class for decorator syntax compatibility.
Args:
cls: Class to apply metaclass to.
Returns:
New class with CrewBaseMeta metaclass applied.
"""
if isinstance(cls, CrewBaseMeta):
return cls
namespace = {
key: value
for key, value in cls.__dict__.items()
if not key.startswith("__")
or key in ("__module__", "__qualname__", "__annotations__")
}
return CrewBaseMeta(cls.__name__, cls.__bases__, namespace)

View File

@@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any, Generic, ParamSpec, Protocol, Self, TypeV
if TYPE_CHECKING: if TYPE_CHECKING:
from crewai import Agent, Task from crewai import Agent, Task
from crewai.crews.crew_output import CrewOutput
P = ParamSpec("P") P = ParamSpec("P")
R = TypeVar("R") R = TypeVar("R")
@@ -37,13 +38,22 @@ def _copy_function_metadata(wrapper: Any, func: Callable[..., Any]) -> None:
class CrewInstance(Protocol): class CrewInstance(Protocol):
"""Protocol for crew class instances with required attributes.""" """Protocol for crew class instances with required attributes."""
_mcp_server_adapter: Any
_all_functions: dict[str, Callable[..., Any]]
_original_functions: dict[str, Callable[..., Any]]
_original_tasks: dict[str, Callable[[Self], Task]] _original_tasks: dict[str, Callable[[Self], Task]]
_original_agents: dict[str, Callable[[Self], Agent]] _original_agents: dict[str, Callable[[Self], Agent]]
_before_kickoff: dict[str, Callable[..., Any]] _before_kickoff: dict[str, Callable[..., Any]]
_after_kickoff: dict[str, Callable[..., Any]] _after_kickoff: dict[str, Callable[..., Any]]
_kickoff: dict[str, Callable[..., Any]]
agents: list[Agent] agents: list[Agent]
tasks: list[Task] tasks: list[Task]
def load_configurations(self) -> None: ...
def map_all_agent_variables(self) -> None: ...
def map_all_task_variables(self) -> None: ...
def _close_mcp_server(self, instance: Self, outputs: CrewOutput) -> CrewOutput: ...
class DecoratedMethod(Generic[P, R]): class DecoratedMethod(Generic[P, R]):
"""Base wrapper for methods with decorator metadata. """Base wrapper for methods with decorator metadata.