diff --git a/src/crewai/project/crew_base.py b/src/crewai/project/crew_base.py index 44faba522..3dcb2ae00 100644 --- a/src/crewai/project/crew_base.py +++ b/src/crewai/project/crew_base.py @@ -1,301 +1,447 @@ -"""Base decorator for creating crew classes with configuration and function management.""" +"""Base metaclass for creating crew classes with configuration and function management.""" + +from __future__ import annotations import inspect import logging from collections.abc import Callable from pathlib import Path -from typing import Any, TypeVar, cast +from typing import TYPE_CHECKING, Any import yaml from dotenv import load_dotenv from crewai.tools import BaseTool +if TYPE_CHECKING: + from crewai.crews.crew_output import CrewOutput + from crewai.project.wrappers import CrewInstance + load_dotenv() -T = TypeVar("T", bound=type) +class CrewBaseMeta(type): + """Metaclass that adds crew functionality to classes.""" -def CrewBase(cls: T) -> T: # noqa: N802 - """Wraps a class with crew functionality and configuration management.""" + def __new__( + mcs, + name: str, + bases: tuple[type, ...], + namespace: dict[str, Any], + **kwargs: Any, + ) -> type: + """Create crew class with configuration and method injection. - class WrappedClass(cls): # type: ignore - is_crew_class: bool = True # type: ignore + Args: + name: Class name. + bases: Base classes. + namespace: Class namespace dictionary. + **kwargs: Additional keyword arguments. - # Get the directory of the class being decorated - base_directory = Path(inspect.getfile(cls)).parent + Returns: + New crew class with injected methods and attributes. + """ + cls = super().__new__(mcs, name, bases, namespace) - original_agents_config_path = getattr( + cls.is_crew_class = True # type: ignore[attr-defined] + cls._crew_name = name # type: ignore[attr-defined] + + try: + cls.base_directory = Path(inspect.getfile(cls)).parent # type: ignore[attr-defined] + except (TypeError, OSError): + cls.base_directory = Path.cwd() # type: ignore[attr-defined] + + cls.original_agents_config_path = getattr( # type: ignore[attr-defined] cls, "agents_config", "config/agents.yaml" ) - original_tasks_config_path = getattr(cls, "tasks_config", "config/tasks.yaml") + cls.original_tasks_config_path = getattr( # type: ignore[attr-defined] + cls, "tasks_config", "config/tasks.yaml" + ) - mcp_server_params: Any = getattr(cls, "mcp_server_params", None) - mcp_connect_timeout: int = getattr(cls, "mcp_connect_timeout", 30) + cls.mcp_server_params = getattr(cls, "mcp_server_params", None) # type: ignore[attr-defined] + cls.mcp_connect_timeout = getattr(cls, "mcp_connect_timeout", 30) # type: ignore[attr-defined] - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.load_configurations() - self.map_all_agent_variables() - self.map_all_task_variables() - # Preserve all decorated functions - self._original_functions = { - name: method - for name, method in cls.__dict__.items() - if any( - hasattr(method, attr) - for attr in [ - "is_task", - "is_agent", - "is_before_kickoff", - "is_after_kickoff", - "is_kickoff", - ] - ) - } - # Store specific function types - self._original_tasks = self._filter_functions( - self._original_functions, "is_task" - ) - self._original_agents = self._filter_functions( - self._original_functions, "is_agent" - ) - self._before_kickoff = self._filter_functions( - self._original_functions, "is_before_kickoff" - ) - self._after_kickoff = self._filter_functions( - self._original_functions, "is_after_kickoff" - ) - self._kickoff = self._filter_functions( - self._original_functions, "is_kickoff" - ) + cls._close_mcp_server = _close_mcp_server # type: ignore[attr-defined] + cls.get_mcp_tools = get_mcp_tools # type: ignore[attr-defined] + cls._load_config = _load_config # type: ignore[attr-defined] + cls.load_configurations = load_configurations # type: ignore[attr-defined] + cls.load_yaml = staticmethod(load_yaml) # type: ignore[attr-defined] + cls.map_all_agent_variables = map_all_agent_variables # type: ignore[attr-defined] + cls._map_agent_variables = _map_agent_variables # type: ignore[attr-defined] + cls.map_all_task_variables = map_all_task_variables # type: ignore[attr-defined] + cls._map_task_variables = _map_task_variables # type: ignore[attr-defined] - # Add close mcp server method to after kickoff - bound_method = self._create_close_mcp_server_method() - self._after_kickoff["_close_mcp_server"] = bound_method + return cls - def _create_close_mcp_server_method(self): - def _close_mcp_server(self, instance, outputs): - adapter = getattr(self, "_mcp_server_adapter", None) - if adapter is not None: - try: - adapter.stop() - except Exception as e: - logging.warning(f"Error stopping MCP server: {e}") - return outputs + def __call__(cls, *args: Any, **kwargs: Any) -> CrewInstance: + """Intercept instance creation to initialize crew functionality. - _close_mcp_server.is_after_kickoff = True + Args: + *args: Positional arguments for instance creation. + **kwargs: Keyword arguments for instance creation. - import types + Returns: + Initialized crew instance. + """ + instance: CrewInstance = super().__call__(*args, **kwargs) + CrewBaseMeta._initialize_crew_instance(instance, cls) + return instance - return types.MethodType(_close_mcp_server, self) + @staticmethod + def _initialize_crew_instance(instance: CrewInstance, cls: type) -> None: + """Initialize crew instance attributes and load configurations. - def get_mcp_tools(self, *tool_names: list[str]) -> list[BaseTool]: - if not self.mcp_server_params: - return [] + Args: + instance: Crew instance to initialize. + cls: Crew class type. + """ + instance._mcp_server_adapter = None + instance.load_configurations() + instance._all_functions = _get_all_functions(instance) + instance.map_all_agent_variables() + instance.map_all_task_variables() - from crewai_tools import MCPServerAdapter # type: ignore[import-untyped] - - adapter = getattr(self, "_mcp_server_adapter", None) - if not adapter: - self._mcp_server_adapter = MCPServerAdapter( - self.mcp_server_params, connect_timeout=self.mcp_connect_timeout - ) - - return self._mcp_server_adapter.tools.filter_by_names(tool_names or None) - - def load_configurations(self): - """Load agent and task configurations from YAML files.""" - if isinstance(self.original_agents_config_path, str): - agents_config_path = ( - self.base_directory / self.original_agents_config_path - ) - try: - self.agents_config = self.load_yaml(agents_config_path) - except FileNotFoundError: - logging.warning( - f"Agent config file not found at {agents_config_path}. " - "Proceeding with empty agent configurations." - ) - self.agents_config = {} - else: - logging.warning( - "No agent configuration path provided. Proceeding with empty agent configurations." - ) - self.agents_config = {} - - if isinstance(self.original_tasks_config_path, str): - tasks_config_path = ( - self.base_directory / self.original_tasks_config_path - ) - try: - self.tasks_config = self.load_yaml(tasks_config_path) - except FileNotFoundError: - logging.warning( - f"Task config file not found at {tasks_config_path}. " - "Proceeding with empty task configurations." - ) - self.tasks_config = {} - else: - logging.warning( - "No task configuration path provided. Proceeding with empty task configurations." - ) - self.tasks_config = {} - - @staticmethod - def load_yaml(config_path: Path): - try: - with open(config_path, "r", encoding="utf-8") as file: - return yaml.safe_load(file) - except FileNotFoundError: - print(f"File not found: {config_path}") - raise - - def _get_all_functions(self): - return { - name: getattr(self, name) - for name in dir(self) - if callable(getattr(self, name)) - } - - def _filter_functions( - self, functions: dict[str, Callable], attribute: str - ) -> dict[str, Callable]: - return { - name: func - for name, func in functions.items() - if hasattr(func, attribute) - } - - def map_all_agent_variables(self) -> None: - all_functions = self._get_all_functions() - llms = self._filter_functions(all_functions, "is_llm") - tool_functions = self._filter_functions(all_functions, "is_tool") - cache_handler_functions = self._filter_functions( - all_functions, "is_cache_handler" - ) - callbacks = self._filter_functions(all_functions, "is_callback") - - for agent_name, agent_info in self.agents_config.items(): - self._map_agent_variables( - agent_name, - agent_info, - llms, - tool_functions, - cache_handler_functions, - callbacks, - ) - - def _map_agent_variables( - self, - agent_name: str, - agent_info: dict[str, Any], - llms: dict[str, Callable], - tool_functions: dict[str, Callable], - cache_handler_functions: dict[str, Callable], - callbacks: dict[str, Callable], - ) -> None: - if llm := agent_info.get("llm"): - try: - self.agents_config[agent_name]["llm"] = llms[llm]() - except KeyError: - self.agents_config[agent_name]["llm"] = llm - - if tools := agent_info.get("tools"): - self.agents_config[agent_name]["tools"] = [ - tool_functions[tool]() for tool in tools + instance._original_functions = { + name: method + for name, method in cls.__dict__.items() + if any( + hasattr(method, attr) + for attr in [ + "is_task", + "is_agent", + "is_before_kickoff", + "is_after_kickoff", + "is_kickoff", ] - - if function_calling_llm := agent_info.get("function_calling_llm"): - try: - self.agents_config[agent_name]["function_calling_llm"] = llms[ - function_calling_llm - ]() - except KeyError: - self.agents_config[agent_name]["function_calling_llm"] = ( - function_calling_llm - ) - - if step_callback := agent_info.get("step_callback"): - self.agents_config[agent_name]["step_callback"] = callbacks[ - step_callback - ]() - - if cache_handler := agent_info.get("cache_handler"): - self.agents_config[agent_name]["cache_handler"] = ( - cache_handler_functions[cache_handler]() - ) - - def map_all_task_variables(self) -> None: - all_functions = self._get_all_functions() - agents = self._filter_functions(all_functions, "is_agent") - tasks = self._filter_functions(all_functions, "is_task") - output_json_functions = self._filter_functions( - all_functions, "is_output_json" ) - tool_functions = self._filter_functions(all_functions, "is_tool") - callback_functions = self._filter_functions(all_functions, "is_callback") - output_pydantic_functions = self._filter_functions( - all_functions, "is_output_pydantic" + } + + instance._original_tasks = _filter_functions( + instance._original_functions, "is_task" + ) + instance._original_agents = _filter_functions( + instance._original_functions, "is_agent" + ) + instance._before_kickoff = _filter_functions( + instance._original_functions, "is_before_kickoff" + ) + instance._after_kickoff = _filter_functions( + instance._original_functions, "is_after_kickoff" + ) + instance._kickoff = _filter_functions( + instance._original_functions, "is_kickoff" + ) + + instance._after_kickoff["_close_mcp_server"] = instance._close_mcp_server + + +def _close_mcp_server( + self: CrewInstance, _instance: CrewInstance, outputs: CrewOutput +) -> CrewOutput: + """Stop MCP server adapter and return outputs. + + Args: + self: Crew instance with MCP server adapter. + _instance: Crew instance (unused, required by callback signature). + outputs: Crew execution outputs. + + Returns: + Unmodified crew outputs. + """ + if self._mcp_server_adapter is not None: + try: + self._mcp_server_adapter.stop() + except Exception as e: + logging.warning(f"Error stopping MCP server: {e}") + return outputs + + +def get_mcp_tools(self: Any, *tool_names: str) -> list[BaseTool]: + """Get MCP tools filtered by name. + + Args: + self: Crew instance with MCP server configuration. + *tool_names: Optional tool names to filter by. + + Returns: + List of filtered MCP tools, or empty list if no MCP server configured. + """ + if not self.mcp_server_params: + return [] + + from crewai_tools import MCPServerAdapter # type: ignore[import-untyped] + + if self._mcp_server_adapter is None: + self._mcp_server_adapter = MCPServerAdapter( + self.mcp_server_params, connect_timeout=self.mcp_connect_timeout + ) + + return self._mcp_server_adapter.tools.filter_by_names(tool_names or None) + + +def _load_config( + self: Any, config_path: str | None, config_type: str +) -> dict[str, Any]: + """Load YAML config file or return empty dict if not found. + + Args: + config_path: Relative path to config file. + config_type: Config type for logging (e.g., "agent", "task"). + + Returns: + Config dictionary or empty dict. + """ + if isinstance(config_path, str): + full_path = self.base_directory / config_path + try: + return self.load_yaml(full_path) + except FileNotFoundError: + logging.warning( + f"{config_type.capitalize()} config file not found at {full_path}. " + f"Proceeding with empty {config_type} configurations." + ) + return {} + else: + logging.warning( + f"No {config_type} configuration path provided. " + f"Proceeding with empty {config_type} configurations." + ) + return {} + + +def load_configurations(self: Any) -> None: + """Load agent and task YAML configurations. + + Args: + self: Crew instance with configuration paths. + """ + self.agents_config = self._load_config(self.original_agents_config_path, "agent") + self.tasks_config = self._load_config(self.original_tasks_config_path, "task") + + +def load_yaml(config_path: Path) -> Any: + """Load and parse YAML file. + + Args: + config_path: Path to YAML configuration file. + + Returns: + Parsed YAML content. + + Raises: + FileNotFoundError: If config file does not exist. + """ + try: + with open(config_path, encoding="utf-8") as file: + return yaml.safe_load(file) + except FileNotFoundError: + print(f"File not found: {config_path}") + raise + + +def _get_all_functions(self: Any) -> dict[str, Callable[..., Any]]: + """Return all non-dunder callable attributes. + + Args: + self: Instance to inspect for callable attributes. + + Returns: + Dictionary mapping attribute names to callable objects. + """ + return { + name: getattr(self, name) + for name in dir(self) + if not (name.startswith("__") and name.endswith("__")) + and callable(getattr(self, name, None)) + } + + +def _filter_functions( + functions: dict[str, Callable[..., Any]], attribute: str +) -> dict[str, Callable[..., Any]]: + """Filter functions by attribute presence. + + Args: + functions: Dictionary of functions to filter. + attribute: Attribute name to check for. + + Returns: + Dictionary containing only functions with the specified attribute. + """ + return {name: func for name, func in functions.items() if hasattr(func, attribute)} + + +def map_all_agent_variables(self: Any) -> None: + """Map agent configuration variables to callable instances. + + Args: + self: Crew instance with agent configurations to map. + """ + llms = _filter_functions(self._all_functions, "is_llm") + tool_functions = _filter_functions(self._all_functions, "is_tool") + cache_handler_functions = _filter_functions(self._all_functions, "is_cache_handler") + callbacks = _filter_functions(self._all_functions, "is_callback") + + for agent_name, agent_info in self.agents_config.items(): + self._map_agent_variables( + agent_name, + agent_info, + llms, + tool_functions, + cache_handler_functions, + callbacks, + ) + + +def _map_agent_variables( + self: Any, + agent_name: str, + agent_info: dict[str, Any], + llms: dict[str, Callable[..., Any]], + tool_functions: dict[str, Callable[..., Any]], + cache_handler_functions: dict[str, Callable[..., Any]], + callbacks: dict[str, Callable[..., Any]], +) -> None: + """Resolve and map variables for a single agent. + + Args: + self: Crew instance with agent configurations. + agent_name: Name of agent to configure. + agent_info: Agent configuration dictionary. + llms: Dictionary of available LLM providers. + tool_functions: Dictionary of available tools. + cache_handler_functions: Dictionary of available cache handlers. + callbacks: Dictionary of available callbacks. + """ + if llm := agent_info.get("llm"): + try: + self.agents_config[agent_name]["llm"] = llms[llm]() + except KeyError: + self.agents_config[agent_name]["llm"] = llm + + if tools := agent_info.get("tools"): + self.agents_config[agent_name]["tools"] = [ + tool_functions[tool]() for tool in tools + ] + + if function_calling_llm := agent_info.get("function_calling_llm"): + try: + self.agents_config[agent_name]["function_calling_llm"] = llms[ + function_calling_llm + ]() + except KeyError: + self.agents_config[agent_name]["function_calling_llm"] = ( + function_calling_llm ) - for task_name, task_info in self.tasks_config.items(): - self._map_task_variables( - task_name, - task_info, - agents, - tasks, - output_json_functions, - tool_functions, - callback_functions, - output_pydantic_functions, - ) + if step_callback := agent_info.get("step_callback"): + self.agents_config[agent_name]["step_callback"] = callbacks[step_callback]() - def _map_task_variables( - self, - task_name: str, - task_info: dict[str, Any], - agents: dict[str, Callable], - tasks: dict[str, Callable], - output_json_functions: dict[str, Callable], - tool_functions: dict[str, Callable], - callback_functions: dict[str, Callable], - output_pydantic_functions: dict[str, Callable], - ) -> None: - if context_list := task_info.get("context"): - self.tasks_config[task_name]["context"] = [ - tasks[context_task_name]() for context_task_name in context_list - ] + if cache_handler := agent_info.get("cache_handler"): + self.agents_config[agent_name]["cache_handler"] = cache_handler_functions[ + cache_handler + ]() - if tools := task_info.get("tools"): - self.tasks_config[task_name]["tools"] = [ - tool_functions[tool]() for tool in tools - ] - if agent_name := task_info.get("agent"): - self.tasks_config[task_name]["agent"] = agents[agent_name]() +def map_all_task_variables(self: Any) -> None: + """Map task configuration variables to callable instances. - if output_json := task_info.get("output_json"): - self.tasks_config[task_name]["output_json"] = output_json_functions[ - output_json - ] + Args: + self: Crew instance with task configurations to map. + """ + agents = _filter_functions(self._all_functions, "is_agent") + tasks = _filter_functions(self._all_functions, "is_task") + output_json_functions = _filter_functions(self._all_functions, "is_output_json") + tool_functions = _filter_functions(self._all_functions, "is_tool") + callback_functions = _filter_functions(self._all_functions, "is_callback") + output_pydantic_functions = _filter_functions( + self._all_functions, "is_output_pydantic" + ) - if output_pydantic := task_info.get("output_pydantic"): - self.tasks_config[task_name]["output_pydantic"] = ( - output_pydantic_functions[output_pydantic] - ) + for task_name, task_info in self.tasks_config.items(): + self._map_task_variables( + task_name, + task_info, + agents, + tasks, + output_json_functions, + tool_functions, + callback_functions, + output_pydantic_functions, + ) - if callbacks := task_info.get("callbacks"): - self.tasks_config[task_name]["callbacks"] = [ - callback_functions[callback]() for callback in callbacks - ] - if guardrail := task_info.get("guardrail"): - self.tasks_config[task_name]["guardrail"] = guardrail +def _map_task_variables( + self: Any, + task_name: str, + task_info: dict[str, Any], + agents: dict[str, Callable[..., Any]], + tasks: dict[str, Callable[..., Any]], + output_json_functions: dict[str, Callable[..., Any]], + tool_functions: dict[str, Callable[..., Any]], + callback_functions: dict[str, Callable[..., Any]], + output_pydantic_functions: dict[str, Callable[..., Any]], +) -> None: + """Resolve and map variables for a single task. - # Include base class (qual)name in the wrapper class (qual)name. - WrappedClass.__name__ = CrewBase.__name__ + "(" + cls.__name__ + ")" - WrappedClass.__qualname__ = CrewBase.__qualname__ + "(" + cls.__name__ + ")" - WrappedClass._crew_name = cls.__name__ + Args: + self: Crew instance with task configurations. + task_name: Name of task to configure. + task_info: Task configuration dictionary. + agents: Dictionary of available agents. + tasks: Dictionary of available tasks. + output_json_functions: Dictionary of available JSON output schemas. + tool_functions: Dictionary of available tools. + callback_functions: Dictionary of available callbacks. + output_pydantic_functions: Dictionary of available Pydantic output schemas. + """ + if context_list := task_info.get("context"): + self.tasks_config[task_name]["context"] = [ + tasks[context_task_name]() for context_task_name in context_list + ] - return cast(T, WrappedClass) + if tools := task_info.get("tools"): + self.tasks_config[task_name]["tools"] = [ + tool_functions[tool]() for tool in tools + ] + + if agent_name := task_info.get("agent"): + self.tasks_config[task_name]["agent"] = agents[agent_name]() + + if output_json := task_info.get("output_json"): + self.tasks_config[task_name]["output_json"] = output_json_functions[output_json] + + if output_pydantic := task_info.get("output_pydantic"): + self.tasks_config[task_name]["output_pydantic"] = output_pydantic_functions[ + output_pydantic + ] + + if callbacks := task_info.get("callbacks"): + self.tasks_config[task_name]["callbacks"] = [ + callback_functions[callback]() for callback in callbacks + ] + + if guardrail := task_info.get("guardrail"): + self.tasks_config[task_name]["guardrail"] = guardrail + + +def CrewBase(cls: type) -> type: # noqa: N802 + """Apply CrewBaseMeta metaclass to a class for decorator syntax compatibility. + + Args: + cls: Class to apply metaclass to. + + Returns: + New class with CrewBaseMeta metaclass applied. + """ + if isinstance(cls, CrewBaseMeta): + return cls + + namespace = { + key: value + for key, value in cls.__dict__.items() + if not key.startswith("__") + or key in ("__module__", "__qualname__", "__annotations__") + } + + return CrewBaseMeta(cls.__name__, cls.__bases__, namespace) diff --git a/src/crewai/project/wrappers.py b/src/crewai/project/wrappers.py index 722692e31..9bef5f4b1 100644 --- a/src/crewai/project/wrappers.py +++ b/src/crewai/project/wrappers.py @@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any, Generic, ParamSpec, Protocol, Self, TypeV if TYPE_CHECKING: from crewai import Agent, Task + from crewai.crews.crew_output import CrewOutput P = ParamSpec("P") R = TypeVar("R") @@ -37,13 +38,22 @@ def _copy_function_metadata(wrapper: Any, func: Callable[..., Any]) -> None: class CrewInstance(Protocol): """Protocol for crew class instances with required attributes.""" + _mcp_server_adapter: Any + _all_functions: dict[str, Callable[..., Any]] + _original_functions: dict[str, Callable[..., Any]] _original_tasks: dict[str, Callable[[Self], Task]] _original_agents: dict[str, Callable[[Self], Agent]] _before_kickoff: dict[str, Callable[..., Any]] _after_kickoff: dict[str, Callable[..., Any]] + _kickoff: dict[str, Callable[..., Any]] agents: list[Agent] tasks: list[Task] + def load_configurations(self) -> None: ... + def map_all_agent_variables(self) -> None: ... + def map_all_task_variables(self) -> None: ... + def _close_mcp_server(self, instance: Self, outputs: CrewOutput) -> CrewOutput: ... + class DecoratedMethod(Generic[P, R]): """Base wrapper for methods with decorator metadata.