diff --git a/lib/crewai/src/crewai/project/__init__.py b/lib/crewai/src/crewai/project/__init__.py index 7aabbebe1..b712138cc 100644 --- a/lib/crewai/src/crewai/project/__init__.py +++ b/lib/crewai/src/crewai/project/__init__.py @@ -1,4 +1,6 @@ -from .annotations import ( +"""Project package for CrewAI.""" + +from crewai.project.annotations import ( after_kickoff, agent, before_kickoff, @@ -11,7 +13,8 @@ from .annotations import ( task, tool, ) -from .crew_base import CrewBase +from crewai.project.crew_base import CrewBase + __all__ = [ "CrewBase", diff --git a/lib/crewai/src/crewai/project/annotations.py b/lib/crewai/src/crewai/project/annotations.py index b5f560ad1..17a07ddad 100644 --- a/lib/crewai/src/crewai/project/annotations.py +++ b/lib/crewai/src/crewai/project/annotations.py @@ -1,95 +1,194 @@ -from functools import wraps -from typing import Callable - -from crewai import Crew -from crewai.project.utils import memoize - """Decorators for defining crew components and their behaviors.""" +from __future__ import annotations -def before_kickoff(func): - """Marks a method to execute before crew kickoff.""" - func.is_before_kickoff = True - return func +from collections.abc import Callable +from functools import wraps +from typing import TYPE_CHECKING, Concatenate, ParamSpec, TypeVar + +from crewai.project.utils import memoize -def after_kickoff(func): - """Marks a method to execute after crew kickoff.""" - func.is_after_kickoff = True - return func +if TYPE_CHECKING: + from crewai import Agent, Crew, Task + +from crewai.project.wrappers import ( + AfterKickoffMethod, + AgentMethod, + BeforeKickoffMethod, + CacheHandlerMethod, + CallbackMethod, + CrewInstance, + LLMMethod, + OutputJsonClass, + OutputPydanticClass, + TaskMethod, + TaskResultT, + ToolMethod, +) -def task(func): - """Marks a method as a crew task.""" - func.is_task = True - - @wraps(func) - def wrapper(*args, **kwargs): - result = func(*args, **kwargs) - if not result.name: - result.name = func.__name__ - return result - - return memoize(wrapper) +P = ParamSpec("P") +P2 = ParamSpec("P2") +R = TypeVar("R") +R2 = TypeVar("R2") +T = TypeVar("T") -def agent(func): - """Marks a method as a crew agent.""" - func.is_agent = True - return memoize(func) +def before_kickoff(meth: Callable[P, R]) -> BeforeKickoffMethod[P, R]: + """Marks a method to execute before crew kickoff. + + Args: + meth: The method to mark. + + Returns: + A wrapped method marked for before kickoff execution. + """ + return BeforeKickoffMethod(meth) -def llm(func): - """Marks a method as an LLM provider.""" - func.is_llm = True - return memoize(func) +def after_kickoff(meth: Callable[P, R]) -> AfterKickoffMethod[P, R]: + """Marks a method to execute after crew kickoff. + + Args: + meth: The method to mark. + + Returns: + A wrapped method marked for after kickoff execution. + """ + return AfterKickoffMethod(meth) -def output_json(cls): - """Marks a class as JSON output format.""" - cls.is_output_json = True - return cls +def task(meth: Callable[P, TaskResultT]) -> TaskMethod[P, TaskResultT]: + """Marks a method as a crew task. + + Args: + meth: The method to mark. + + Returns: + A wrapped method marked as a task with memoization. + """ + return TaskMethod(memoize(meth)) -def output_pydantic(cls): - """Marks a class as Pydantic output format.""" - cls.is_output_pydantic = True - return cls +def agent(meth: Callable[P, R]) -> AgentMethod[P, R]: + """Marks a method as a crew agent. + + Args: + meth: The method to mark. + + Returns: + A wrapped method marked as an agent with memoization. + """ + return AgentMethod(memoize(meth)) -def tool(func): - """Marks a method as a crew tool.""" - func.is_tool = True - return memoize(func) +def llm(meth: Callable[P, R]) -> LLMMethod[P, R]: + """Marks a method as an LLM provider. + + Args: + meth: The method to mark. + + Returns: + A wrapped method marked as an LLM provider with memoization. + """ + return LLMMethod(memoize(meth)) -def callback(func): - """Marks a method as a crew callback.""" - func.is_callback = True - return memoize(func) +def output_json(cls: type[T]) -> OutputJsonClass[T]: + """Marks a class as JSON output format. + + Args: + cls: The class to mark. + + Returns: + A wrapped class marked as JSON output format. + """ + return OutputJsonClass(cls) -def cache_handler(func): - """Marks a method as a cache handler.""" - func.is_cache_handler = True - return memoize(func) +def output_pydantic(cls: type[T]) -> OutputPydanticClass[T]: + """Marks a class as Pydantic output format. + + Args: + cls: The class to mark. + + Returns: + A wrapped class marked as Pydantic output format. + """ + return OutputPydanticClass(cls) -def crew(func) -> Callable[..., Crew]: - """Marks a method as the main crew execution point.""" +def tool(meth: Callable[P, R]) -> ToolMethod[P, R]: + """Marks a method as a crew tool. - @wraps(func) - def wrapper(self, *args, **kwargs) -> Crew: - instantiated_tasks = [] - instantiated_agents = [] - agent_roles = set() + Args: + meth: The method to mark. + + Returns: + A wrapped method marked as a tool with memoization. + """ + return ToolMethod(memoize(meth)) + + +def callback(meth: Callable[P, R]) -> CallbackMethod[P, R]: + """Marks a method as a crew callback. + + Args: + meth: The method to mark. + + Returns: + A wrapped method marked as a callback with memoization. + """ + return CallbackMethod(memoize(meth)) + + +def cache_handler(meth: Callable[P, R]) -> CacheHandlerMethod[P, R]: + """Marks a method as a cache handler. + + Args: + meth: The method to mark. + + Returns: + A wrapped method marked as a cache handler with memoization. + """ + return CacheHandlerMethod(memoize(meth)) + + +def crew( + meth: Callable[Concatenate[CrewInstance, P], Crew], +) -> Callable[Concatenate[CrewInstance, P], Crew]: + """Marks a method as the main crew execution point. + + Args: + meth: The method to mark as crew execution point. + + Returns: + A wrapped method that instantiates tasks and agents before execution. + """ + + @wraps(meth) + def wrapper(self: CrewInstance, *args: P.args, **kwargs: P.kwargs) -> Crew: + """Wrapper that sets up crew before calling the decorated method. + + Args: + self: The crew class instance. + *args: Additional positional arguments. + **kwargs: Keyword arguments to pass to the method. + + Returns: + The configured Crew instance with callbacks attached. + """ + instantiated_tasks: list[Task] = [] + instantiated_agents: list[Agent] = [] + agent_roles: set[str] = set() # Use the preserved task and agent information - tasks = self._original_tasks.items() - agents = self._original_agents.items() + tasks = self.__crew_metadata__["original_tasks"].items() + agents = self.__crew_metadata__["original_agents"].items() # Instantiate tasks in order - for _task_name, task_method in tasks: + for _, task_method in tasks: task_instance = task_method(self) instantiated_tasks.append(task_instance) agent_instance = getattr(task_instance, "agent", None) @@ -98,7 +197,7 @@ def crew(func) -> Callable[..., Crew]: agent_roles.add(agent_instance.role) # Instantiate agents not included by tasks - for _agent_name, agent_method in agents: + for _, agent_method in agents: agent_instance = agent_method(self) if agent_instance.role not in agent_roles: instantiated_agents.append(agent_instance) @@ -107,19 +206,44 @@ def crew(func) -> Callable[..., Crew]: self.agents = instantiated_agents self.tasks = instantiated_tasks - crew = func(self, *args, **kwargs) + crew_instance = meth(self, *args, **kwargs) - def callback_wrapper(callback, instance): - def wrapper(*args, **kwargs): - return callback(instance, *args, **kwargs) + def callback_wrapper( + hook: Callable[Concatenate[CrewInstance, P2], R2], instance: CrewInstance + ) -> Callable[P2, R2]: + """Bind a hook callback to an instance. - return wrapper + Args: + hook: The callback hook to bind. + instance: The instance to bind to. - for callback in self._before_kickoff.values(): - crew.before_kickoff_callbacks.append(callback_wrapper(callback, self)) - for callback in self._after_kickoff.values(): - crew.after_kickoff_callbacks.append(callback_wrapper(callback, self)) + Returns: + A bound callback function. + """ - return crew + def bound_callback(*cb_args: P2.args, **cb_kwargs: P2.kwargs) -> R2: + """Execute the bound callback. + + Args: + *cb_args: Positional arguments for the callback. + **cb_kwargs: Keyword arguments for the callback. + + Returns: + The result of the callback execution. + """ + return hook(instance, *cb_args, **cb_kwargs) + + return bound_callback + + for hook_callback in self.__crew_metadata__["before_kickoff"].values(): + crew_instance.before_kickoff_callbacks.append( + callback_wrapper(hook_callback, self) + ) + for hook_callback in self.__crew_metadata__["after_kickoff"].values(): + crew_instance.after_kickoff_callbacks.append( + callback_wrapper(hook_callback, self) + ) + + return crew_instance return memoize(wrapper) diff --git a/lib/crewai/src/crewai/project/crew_base.py b/lib/crewai/src/crewai/project/crew_base.py index 1065012c9..d2ba2d794 100644 --- a/lib/crewai/src/crewai/project/crew_base.py +++ b/lib/crewai/src/crewai/project/crew_base.py @@ -1,303 +1,632 @@ +"""Base metaclass for creating crew classes with configuration and method management.""" + +from __future__ import annotations + +from collections.abc import Callable import inspect import logging -from collections.abc import Callable from pathlib import Path -from typing import Any, TypeVar, cast +from typing import TYPE_CHECKING, Any, Literal, TypeGuard, TypeVar, TypedDict, cast -import yaml from dotenv import load_dotenv +import yaml +from crewai.project.wrappers import CrewClass, CrewMetadata from crewai.tools import BaseTool -from crewai.utilities.printer import Printer + + +if TYPE_CHECKING: + from crewai import Agent, Task + from crewai.agents.cache.cache_handler import CacheHandler + from crewai.crews.crew_output import CrewOutput + from crewai.project.wrappers import ( + CrewInstance, + OutputJsonClass, + OutputPydanticClass, + ) + from crewai.tasks.task_output import TaskOutput + + +class AgentConfig(TypedDict, total=False): + """Type definition for agent configuration dictionary. + + All fields are optional as they come from YAML configuration files. + Fields can be either string references (from YAML) or actual instances (after processing). + """ + + # Core agent attributes (from BaseAgent) + role: str + goal: str + backstory: str + cache: bool + verbose: bool + max_rpm: int + allow_delegation: bool + max_iter: int + max_tokens: int + callbacks: list[str] + + # LLM configuration + llm: str + function_calling_llm: str + use_system_prompt: bool + + # Template configuration + system_template: str + prompt_template: str + response_template: str + + # Tools and handlers (can be string references or instances) + tools: list[str] | list[BaseTool] + step_callback: str + cache_handler: str | CacheHandler + + # Code execution + allow_code_execution: bool + code_execution_mode: Literal["safe", "unsafe"] + + # Context and performance + respect_context_window: bool + max_retry_limit: int + + # Multimodal and reasoning + multimodal: bool + reasoning: bool + max_reasoning_attempts: int + + # Knowledge configuration + knowledge_sources: list[str] | list[Any] + knowledge_storage: str | Any + knowledge_config: dict[str, Any] + embedder: dict[str, Any] + agent_knowledge_context: str + crew_knowledge_context: str + knowledge_search_query: str + + # Misc configuration + inject_date: bool + date_format: str + from_repository: str + guardrail: Callable[[Any], tuple[bool, Any]] | str + guardrail_max_retries: int + + +class TaskConfig(TypedDict, total=False): + """Type definition for task configuration dictionary. + + All fields are optional as they come from YAML configuration files. + Fields can be either string references (from YAML) or actual instances (after processing). + """ + + # Core task attributes + name: str + description: str + expected_output: str + + # Agent and context + agent: str + context: list[str] + + # Tools and callbacks (can be string references or instances) + tools: list[str] | list[BaseTool] + callback: str + callbacks: list[str] + + # Output configuration + output_json: str + output_pydantic: str + output_file: str + create_directory: bool + + # Execution configuration + async_execution: bool + human_input: bool + markdown: bool + + # Guardrail configuration + guardrail: Callable[[TaskOutput], tuple[bool, Any]] | str + guardrail_max_retries: int + + # Misc configuration + allow_crewai_trigger_context: bool + load_dotenv() -_printer = Printer() -T = TypeVar("T", bound=type) - -"""Base decorator for creating crew classes with configuration and function management.""" +CallableT = TypeVar("CallableT", bound=Callable[..., Any]) -def CrewBase(cls: T) -> T: # noqa: N802 - """Wraps a class with crew functionality and configuration management.""" +def _set_base_directory(cls: type[CrewClass]) -> None: + """Set the base directory for the crew class. - class WrappedClass(cls): # type: ignore - is_crew_class: bool = True # type: ignore + Args: + cls: Crew class to configure. + """ + try: + cls.base_directory = Path(inspect.getfile(cls)).parent + except (TypeError, OSError): + cls.base_directory = Path.cwd() - # Get the directory of the class being decorated - base_directory = Path(inspect.getfile(cls)).parent - original_agents_config_path = getattr( - cls, "agents_config", "config/agents.yaml" +def _set_config_paths(cls: type[CrewClass]) -> None: + """Set the configuration file paths for the crew class. + + Args: + cls: Crew class to configure. + """ + cls.original_agents_config_path = getattr( + cls, "agents_config", "config/agents.yaml" + ) + cls.original_tasks_config_path = getattr(cls, "tasks_config", "config/tasks.yaml") + + +def _set_mcp_params(cls: type[CrewClass]) -> None: + """Set the MCP server parameters for the crew class. + + Args: + cls: Crew class to configure. + """ + cls.mcp_server_params = getattr(cls, "mcp_server_params", None) + cls.mcp_connect_timeout = getattr(cls, "mcp_connect_timeout", 30) + + +def _is_string_list(value: list[str] | list[BaseTool]) -> TypeGuard[list[str]]: + """Type guard to check if list contains strings rather than BaseTool instances. + + Args: + value: List that may contain strings or BaseTool instances. + + Returns: + True if all elements are strings, False otherwise. + """ + return all(isinstance(item, str) for item in value) + + +def _is_string_value(value: str | CacheHandler) -> TypeGuard[str]: + """Type guard to check if value is a string rather than a CacheHandler instance. + + Args: + value: Value that may be a string or CacheHandler instance. + + Returns: + True if value is a string, False otherwise. + """ + return isinstance(value, str) + + +class CrewBaseMeta(type): + """Metaclass that adds crew functionality to classes.""" + + def __new__( + mcs, + name: str, + bases: tuple[type, ...], + namespace: dict[str, Any], + **kwargs: Any, + ) -> type[CrewClass]: + """Create crew class with configuration and method injection. + + Args: + name: Class name. + bases: Base classes. + namespace: Class namespace dictionary. + **kwargs: Additional keyword arguments. + + Returns: + New crew class with injected methods and attributes. + """ + cls = cast( + type[CrewClass], cast(object, super().__new__(mcs, name, bases, namespace)) ) - original_tasks_config_path = getattr(cls, "tasks_config", "config/tasks.yaml") - mcp_server_params: Any = getattr(cls, "mcp_server_params", None) - mcp_connect_timeout: int = getattr(cls, "mcp_connect_timeout", 30) + cls.is_crew_class = True + cls._crew_name = name - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.load_configurations() - self.map_all_agent_variables() - self.map_all_task_variables() - # Preserve all decorated functions - self._original_functions = { - name: method - for name, method in cls.__dict__.items() - if any( - hasattr(method, attr) - for attr in [ - "is_task", - "is_agent", - "is_before_kickoff", - "is_after_kickoff", - "is_kickoff", - ] - ) - } - # Store specific function types - self._original_tasks = self._filter_functions( - self._original_functions, "is_task" - ) - self._original_agents = self._filter_functions( - self._original_functions, "is_agent" - ) - self._before_kickoff = self._filter_functions( - self._original_functions, "is_before_kickoff" - ) - self._after_kickoff = self._filter_functions( - self._original_functions, "is_after_kickoff" - ) - self._kickoff = self._filter_functions( - self._original_functions, "is_kickoff" - ) + for setup_fn in _CLASS_SETUP_FUNCTIONS: + setup_fn(cls) - # Add close mcp server method to after kickoff - bound_method = self._create_close_mcp_server_method() - self._after_kickoff["_close_mcp_server"] = bound_method + for method in _METHODS_TO_INJECT: + setattr(cls, method.__name__, method) - def _create_close_mcp_server_method(self): - def _close_mcp_server(self, instance, outputs): - adapter = getattr(self, "_mcp_server_adapter", None) - if adapter is not None: - try: - adapter.stop() - except Exception as e: - logging.warning(f"Error stopping MCP server: {e}") - return outputs + return cls - _close_mcp_server.is_after_kickoff = True + def __call__(cls, *args: Any, **kwargs: Any) -> CrewInstance: + """Intercept instance creation to initialize crew functionality. - import types + Args: + *args: Positional arguments for instance creation. + **kwargs: Keyword arguments for instance creation. - return types.MethodType(_close_mcp_server, self) + Returns: + Initialized crew instance. + """ + instance: CrewInstance = super().__call__(*args, **kwargs) + CrewBaseMeta._initialize_crew_instance(instance, cls) + return instance - def get_mcp_tools(self, *tool_names: list[str]) -> list[BaseTool]: - if not self.mcp_server_params: - return [] + @staticmethod + def _initialize_crew_instance(instance: CrewInstance, cls: type) -> None: + """Initialize crew instance attributes and load configurations. - from crewai_tools import MCPServerAdapter # type: ignore[import-untyped] + Args: + instance: Crew instance to initialize. + cls: Crew class type. + """ + instance._mcp_server_adapter = None + instance.load_configurations() + instance._all_methods = _get_all_methods(instance) + instance.map_all_agent_variables() + instance.map_all_task_variables() - adapter = getattr(self, "_mcp_server_adapter", None) - if not adapter: - self._mcp_server_adapter = MCPServerAdapter( - self.mcp_server_params, connect_timeout=self.mcp_connect_timeout - ) - - return self._mcp_server_adapter.tools.filter_by_names(tool_names or None) - - def load_configurations(self): - """Load agent and task configurations from YAML files.""" - if isinstance(self.original_agents_config_path, str): - agents_config_path = ( - self.base_directory / self.original_agents_config_path - ) - try: - self.agents_config = self.load_yaml(agents_config_path) - except FileNotFoundError: - logging.warning( - f"Agent config file not found at {agents_config_path}. " - "Proceeding with empty agent configurations." - ) - self.agents_config = {} - else: - logging.warning( - "No agent configuration path provided. Proceeding with empty agent configurations." - ) - self.agents_config = {} - - if isinstance(self.original_tasks_config_path, str): - tasks_config_path = ( - self.base_directory / self.original_tasks_config_path - ) - try: - self.tasks_config = self.load_yaml(tasks_config_path) - except FileNotFoundError: - logging.warning( - f"Task config file not found at {tasks_config_path}. " - "Proceeding with empty task configurations." - ) - self.tasks_config = {} - else: - logging.warning( - "No task configuration path provided. Proceeding with empty task configurations." - ) - self.tasks_config = {} - - @staticmethod - def load_yaml(config_path: Path): - try: - with open(config_path, "r", encoding="utf-8") as file: - return yaml.safe_load(file) - except FileNotFoundError: - _printer.print(f"File not found: {config_path}", color="red") - raise - - def _get_all_functions(self): - return { - name: getattr(self, name) - for name in dir(self) - if callable(getattr(self, name)) - } - - def _filter_functions( - self, functions: dict[str, Callable], attribute: str - ) -> dict[str, Callable]: - return { - name: func - for name, func in functions.items() - if hasattr(func, attribute) - } - - def map_all_agent_variables(self) -> None: - all_functions = self._get_all_functions() - llms = self._filter_functions(all_functions, "is_llm") - tool_functions = self._filter_functions(all_functions, "is_tool") - cache_handler_functions = self._filter_functions( - all_functions, "is_cache_handler" - ) - callbacks = self._filter_functions(all_functions, "is_callback") - - for agent_name, agent_info in self.agents_config.items(): - self._map_agent_variables( - agent_name, - agent_info, - llms, - tool_functions, - cache_handler_functions, - callbacks, - ) - - def _map_agent_variables( - self, - agent_name: str, - agent_info: dict[str, Any], - llms: dict[str, Callable], - tool_functions: dict[str, Callable], - cache_handler_functions: dict[str, Callable], - callbacks: dict[str, Callable], - ) -> None: - if llm := agent_info.get("llm"): - try: - self.agents_config[agent_name]["llm"] = llms[llm]() - except KeyError: - self.agents_config[agent_name]["llm"] = llm - - if tools := agent_info.get("tools"): - self.agents_config[agent_name]["tools"] = [ - tool_functions[tool]() for tool in tools + original_methods = { + name: method + for name, method in cls.__dict__.items() + if any( + hasattr(method, attr) + for attr in [ + "is_task", + "is_agent", + "is_before_kickoff", + "is_after_kickoff", + "is_kickoff", ] - - if function_calling_llm := agent_info.get("function_calling_llm"): - try: - self.agents_config[agent_name]["function_calling_llm"] = llms[ - function_calling_llm - ]() - except KeyError: - self.agents_config[agent_name]["function_calling_llm"] = ( - function_calling_llm - ) - - if step_callback := agent_info.get("step_callback"): - self.agents_config[agent_name]["step_callback"] = callbacks[ - step_callback - ]() - - if cache_handler := agent_info.get("cache_handler"): - self.agents_config[agent_name]["cache_handler"] = ( - cache_handler_functions[cache_handler]() - ) - - def map_all_task_variables(self) -> None: - all_functions = self._get_all_functions() - agents = self._filter_functions(all_functions, "is_agent") - tasks = self._filter_functions(all_functions, "is_task") - output_json_functions = self._filter_functions( - all_functions, "is_output_json" ) - tool_functions = self._filter_functions(all_functions, "is_tool") - callback_functions = self._filter_functions(all_functions, "is_callback") - output_pydantic_functions = self._filter_functions( - all_functions, "is_output_pydantic" + } + + after_kickoff_callbacks = _filter_methods(original_methods, "is_after_kickoff") + after_kickoff_callbacks["close_mcp_server"] = instance.close_mcp_server + + instance.__crew_metadata__ = CrewMetadata( + original_methods=original_methods, + original_tasks=_filter_methods(original_methods, "is_task"), + original_agents=_filter_methods(original_methods, "is_agent"), + before_kickoff=_filter_methods(original_methods, "is_before_kickoff"), + after_kickoff=after_kickoff_callbacks, + kickoff=_filter_methods(original_methods, "is_kickoff"), + ) + + +def close_mcp_server( + self: CrewInstance, _instance: CrewInstance, outputs: CrewOutput +) -> CrewOutput: + """Stop MCP server adapter and return outputs. + + Args: + self: Crew instance with MCP server adapter. + _instance: Crew instance (unused, required by callback signature). + outputs: Crew execution outputs. + + Returns: + Unmodified crew outputs. + """ + if self._mcp_server_adapter is not None: + try: + self._mcp_server_adapter.stop() + except Exception as e: + logging.warning(f"Error stopping MCP server: {e}") + return outputs + + +def get_mcp_tools(self: CrewInstance, *tool_names: str) -> list[BaseTool]: + """Get MCP tools filtered by name. + + Args: + self: Crew instance with MCP server configuration. + *tool_names: Optional tool names to filter by. + + Returns: + List of filtered MCP tools, or empty list if no MCP server configured. + """ + if not self.mcp_server_params: + return [] + + from crewai_tools import MCPServerAdapter # type: ignore[import-untyped] + + if self._mcp_server_adapter is None: + self._mcp_server_adapter = MCPServerAdapter( + self.mcp_server_params, connect_timeout=self.mcp_connect_timeout + ) + + return self._mcp_server_adapter.tools.filter_by_names(tool_names or None) + + +def _load_config( + self: CrewInstance, config_path: str | None, config_type: Literal["agent", "task"] +) -> dict[str, Any]: + """Load YAML config file or return empty dict if not found. + + Args: + self: Crew instance with base directory and load_yaml method. + config_path: Relative path to config file. + config_type: Config type for logging, either "agent" or "task". + + Returns: + Config dictionary or empty dict. + """ + if isinstance(config_path, str): + full_path = self.base_directory / config_path + try: + return self.load_yaml(full_path) + except FileNotFoundError: + logging.warning( + f"{config_type.capitalize()} config file not found at {full_path}. " + f"Proceeding with empty {config_type} configurations." ) + return {} + else: + logging.warning( + f"No {config_type} configuration path provided. " + f"Proceeding with empty {config_type} configurations." + ) + return {} - for task_name, task_info in self.tasks_config.items(): - self._map_task_variables( - task_name, - task_info, - agents, - tasks, - output_json_functions, - tool_functions, - callback_functions, - output_pydantic_functions, - ) - def _map_task_variables( - self, - task_name: str, - task_info: dict[str, Any], - agents: dict[str, Callable], - tasks: dict[str, Callable], - output_json_functions: dict[str, Callable], - tool_functions: dict[str, Callable], - callback_functions: dict[str, Callable], - output_pydantic_functions: dict[str, Callable], - ) -> None: - if context_list := task_info.get("context"): - self.tasks_config[task_name]["context"] = [ - tasks[context_task_name]() for context_task_name in context_list - ] +def load_configurations(self: CrewInstance) -> None: + """Load agent and task YAML configurations. - if tools := task_info.get("tools"): - self.tasks_config[task_name]["tools"] = [ - tool_functions[tool]() for tool in tools - ] + Args: + self: Crew instance with configuration paths. + """ + self.agents_config = self._load_config(self.original_agents_config_path, "agent") + self.tasks_config = self._load_config(self.original_tasks_config_path, "task") - if agent_name := task_info.get("agent"): - self.tasks_config[task_name]["agent"] = agents[agent_name]() - if output_json := task_info.get("output_json"): - self.tasks_config[task_name]["output_json"] = output_json_functions[ - output_json - ] +def load_yaml(config_path: Path) -> dict[str, Any]: + """Load and parse YAML configuration file. - if output_pydantic := task_info.get("output_pydantic"): - self.tasks_config[task_name]["output_pydantic"] = ( - output_pydantic_functions[output_pydantic] - ) + Args: + config_path: Path to YAML configuration file. - if callbacks := task_info.get("callbacks"): - self.tasks_config[task_name]["callbacks"] = [ - callback_functions[callback]() for callback in callbacks - ] + Returns: + Parsed YAML content as a dictionary. Returns empty dict if file is empty. - if guardrail := task_info.get("guardrail"): - self.tasks_config[task_name]["guardrail"] = guardrail + Raises: + FileNotFoundError: If config file does not exist. + """ + try: + with open(config_path, encoding="utf-8") as file: + content = yaml.safe_load(file) + return content if isinstance(content, dict) else {} + except FileNotFoundError: + logging.warning(f"File not found: {config_path}") + raise - # Include base class (qual)name in the wrapper class (qual)name. - WrappedClass.__name__ = CrewBase.__name__ + "(" + cls.__name__ + ")" - WrappedClass.__qualname__ = CrewBase.__qualname__ + "(" + cls.__name__ + ")" - WrappedClass._crew_name = cls.__name__ - return cast(T, WrappedClass) +def _get_all_methods(self: CrewInstance) -> dict[str, Callable[..., Any]]: + """Return all non-dunder callable attributes (methods). + + Args: + self: Instance to inspect for callable attributes. + + Returns: + Dictionary mapping method names to bound method objects. + """ + return { + name: getattr(self, name) + for name in dir(self) + if not (name.startswith("__") and name.endswith("__")) + and callable(getattr(self, name, None)) + } + + +def _filter_methods( + methods: dict[str, CallableT], attribute: str +) -> dict[str, CallableT]: + """Filter methods by attribute presence, preserving exact callable types. + + Args: + methods: Dictionary of methods to filter. + attribute: Attribute name to check for. + + Returns: + Dictionary containing only methods with the specified attribute. + The return type matches the input callable type exactly. + """ + return { + name: method for name, method in methods.items() if hasattr(method, attribute) + } + + +def map_all_agent_variables(self: CrewInstance) -> None: + """Map agent configuration variables to callable instances. + + Args: + self: Crew instance with agent configurations to map. + """ + llms = _filter_methods(self._all_methods, "is_llm") + tool_functions = _filter_methods(self._all_methods, "is_tool") + cache_handler_functions = _filter_methods(self._all_methods, "is_cache_handler") + callbacks = _filter_methods(self._all_methods, "is_callback") + + for agent_name, agent_info in self.agents_config.items(): + self._map_agent_variables( + agent_name=agent_name, + agent_info=agent_info, + llms=llms, + tool_functions=tool_functions, + cache_handler_functions=cache_handler_functions, + callbacks=callbacks, + ) + + +def _map_agent_variables( + self: CrewInstance, + agent_name: str, + agent_info: AgentConfig, + llms: dict[str, Callable[[], Any]], + tool_functions: dict[str, Callable[[], BaseTool]], + cache_handler_functions: dict[str, Callable[[], Any]], + callbacks: dict[str, Callable[..., Any]], +) -> None: + """Resolve and map variables for a single agent. + + Args: + self: Crew instance with agent configurations. + agent_name: Name of agent to configure. + agent_info: Agent configuration dictionary with optional fields. + llms: Dictionary mapping names to LLM factory functions. + tool_functions: Dictionary mapping names to tool factory functions. + cache_handler_functions: Dictionary mapping names to cache handler factory functions. + callbacks: Dictionary of available callbacks. + """ + if llm := agent_info.get("llm"): + factory = llms.get(llm) + self.agents_config[agent_name]["llm"] = factory() if factory else llm + + if tools := agent_info.get("tools"): + if _is_string_list(tools): + self.agents_config[agent_name]["tools"] = [ + tool_functions[tool]() for tool in tools + ] + + if function_calling_llm := agent_info.get("function_calling_llm"): + factory = llms.get(function_calling_llm) + self.agents_config[agent_name]["function_calling_llm"] = ( + factory() if factory else function_calling_llm + ) + + if step_callback := agent_info.get("step_callback"): + self.agents_config[agent_name]["step_callback"] = callbacks[step_callback]() + + if cache_handler := agent_info.get("cache_handler"): + if _is_string_value(cache_handler): + self.agents_config[agent_name]["cache_handler"] = cache_handler_functions[ + cache_handler + ]() + + +def map_all_task_variables(self: CrewInstance) -> None: + """Map task configuration variables to callable instances. + + Args: + self: Crew instance with task configurations to map. + """ + agents = _filter_methods(self._all_methods, "is_agent") + tasks = _filter_methods(self._all_methods, "is_task") + output_json_functions = _filter_methods(self._all_methods, "is_output_json") + tool_functions = _filter_methods(self._all_methods, "is_tool") + callback_functions = _filter_methods(self._all_methods, "is_callback") + output_pydantic_functions = _filter_methods(self._all_methods, "is_output_pydantic") + + for task_name, task_info in self.tasks_config.items(): + self._map_task_variables( + task_name=task_name, + task_info=task_info, + agents=agents, + tasks=tasks, + output_json_functions=output_json_functions, + tool_functions=tool_functions, + callback_functions=callback_functions, + output_pydantic_functions=output_pydantic_functions, + ) + + +def _map_task_variables( + self: CrewInstance, + task_name: str, + task_info: TaskConfig, + agents: dict[str, Callable[[], Agent]], + tasks: dict[str, Callable[[], Task]], + output_json_functions: dict[str, OutputJsonClass[Any]], + tool_functions: dict[str, Callable[[], BaseTool]], + callback_functions: dict[str, Callable[..., Any]], + output_pydantic_functions: dict[str, OutputPydanticClass[Any]], +) -> None: + """Resolve and map variables for a single task. + + Args: + self: Crew instance with task configurations. + task_name: Name of task to configure. + task_info: Task configuration dictionary with optional fields. + agents: Dictionary mapping names to agent factory functions. + tasks: Dictionary mapping names to task factory functions. + output_json_functions: Dictionary of JSON output class wrappers. + tool_functions: Dictionary mapping names to tool factory functions. + callback_functions: Dictionary of available callbacks. + output_pydantic_functions: Dictionary of Pydantic output class wrappers. + """ + if context_list := task_info.get("context"): + self.tasks_config[task_name]["context"] = [ + tasks[context_task_name]() for context_task_name in context_list + ] + + if tools := task_info.get("tools"): + if _is_string_list(tools): + self.tasks_config[task_name]["tools"] = [ + tool_functions[tool]() for tool in tools + ] + + if agent_name := task_info.get("agent"): + self.tasks_config[task_name]["agent"] = agents[agent_name]() + + if output_json := task_info.get("output_json"): + self.tasks_config[task_name]["output_json"] = output_json_functions[output_json] + + if output_pydantic := task_info.get("output_pydantic"): + self.tasks_config[task_name]["output_pydantic"] = output_pydantic_functions[ + output_pydantic + ] + + if callbacks := task_info.get("callbacks"): + self.tasks_config[task_name]["callbacks"] = [ + callback_functions[callback]() for callback in callbacks + ] + + if guardrail := task_info.get("guardrail"): + self.tasks_config[task_name]["guardrail"] = guardrail + + +_CLASS_SETUP_FUNCTIONS: tuple[Callable[[type[CrewClass]], None], ...] = ( + _set_base_directory, + _set_config_paths, + _set_mcp_params, +) + +_METHODS_TO_INJECT = ( + close_mcp_server, + get_mcp_tools, + _load_config, + load_configurations, + staticmethod(load_yaml), + map_all_agent_variables, + _map_agent_variables, + map_all_task_variables, + _map_task_variables, +) + + +class _CrewBaseType(type): + """Metaclass for CrewBase that makes it callable as a decorator.""" + + def __call__(cls, decorated_cls: type) -> type[CrewClass]: + """Apply CrewBaseMeta to the decorated class. + + Args: + decorated_cls: Class to transform with CrewBaseMeta metaclass. + + Returns: + New class with CrewBaseMeta metaclass applied. + """ + __name = str(decorated_cls.__name__) + __bases = tuple(decorated_cls.__bases__) + __dict = { + key: value + for key, value in decorated_cls.__dict__.items() + if key not in ("__dict__", "__weakref__") + } + for slot in __dict.get("__slots__", tuple()): + __dict.pop(slot, None) + __dict["__metaclass__"] = CrewBaseMeta + return cast(type[CrewClass], CrewBaseMeta(__name, __bases, __dict)) + + +class CrewBase(metaclass=_CrewBaseType): + """Class decorator that applies CrewBaseMeta metaclass. + + Applies CrewBaseMeta metaclass to a class via decorator syntax rather than + explicit metaclass declaration. Use as @CrewBase instead of + class Foo(metaclass=CrewBaseMeta). + + Note: + Reference: https://stackoverflow.com/questions/11091609/setting-a-class-metaclass-using-a-decorator + """ diff --git a/lib/crewai/src/crewai/project/utils.py b/lib/crewai/src/crewai/project/utils.py index e8876d941..4d73145c2 100644 --- a/lib/crewai/src/crewai/project/utils.py +++ b/lib/crewai/src/crewai/project/utils.py @@ -1,14 +1,42 @@ +"""Utility functions for the crewai project module.""" + +from collections.abc import Callable from functools import wraps +from typing import Any, ParamSpec, TypeVar -def memoize(func): - cache = {} +P = ParamSpec("P") +R = TypeVar("R") - @wraps(func) - def memoized_func(*args, **kwargs): + +def memoize(meth: Callable[P, R]) -> Callable[P, R]: + """Memoize a method by caching its results based on arguments. + + Args: + meth: The method to memoize. + + Returns: + A memoized version of the method that caches results. + + Notes: + - TODO: Need to make this thread-safe for concurrent access, prevent memory leaks. + """ + cache: dict[Any, R] = {} + + @wraps(meth) + def memoized_func(*args: P.args, **kwargs: P.kwargs) -> R: + """Memoized wrapper method. + + Args: + *args: Positional arguments to pass to the method. + **kwargs: Keyword arguments to pass to the method. + + Returns: + The cached or computed result of the method. + """ key = (args, tuple(kwargs.items())) if key not in cache: - cache[key] = func(*args, **kwargs) + cache[key] = meth(*args, **kwargs) return cache[key] return memoized_func diff --git a/lib/crewai/src/crewai/project/wrappers.py b/lib/crewai/src/crewai/project/wrappers.py new file mode 100644 index 000000000..566dd2268 --- /dev/null +++ b/lib/crewai/src/crewai/project/wrappers.py @@ -0,0 +1,389 @@ +"""Wrapper classes for decorated methods with type-safe metadata.""" + +from __future__ import annotations + +from collections.abc import Callable +from functools import partial +from pathlib import Path +from typing import ( + TYPE_CHECKING, + Any, + Generic, + Literal, + ParamSpec, + Protocol, + TypeVar, + TypedDict, +) + +from typing_extensions import Self + + +if TYPE_CHECKING: + from crewai import Agent, Task + from crewai.crews.crew_output import CrewOutput + from crewai.tools import BaseTool + + +class CrewMetadata(TypedDict): + """Type definition for crew metadata dictionary. + + Stores framework-injected metadata about decorated methods and callbacks. + """ + + original_methods: dict[str, Callable[..., Any]] + original_tasks: dict[str, Callable[..., Task]] + original_agents: dict[str, Callable[..., Agent]] + before_kickoff: dict[str, Callable[..., Any]] + after_kickoff: dict[str, Callable[..., Any]] + kickoff: dict[str, Callable[..., Any]] + + +P = ParamSpec("P") +R = TypeVar("R") +T = TypeVar("T") + + +class TaskResult(Protocol): + """Protocol for task objects that have a name attribute.""" + + name: str | None + + +TaskResultT = TypeVar("TaskResultT", bound=TaskResult) + + +def _copy_method_metadata(wrapper: Any, meth: Callable[..., Any]) -> None: + """Copy method metadata to a wrapper object. + + Args: + wrapper: The wrapper object to update. + meth: The method to copy metadata from. + """ + wrapper.__name__ = meth.__name__ + wrapper.__doc__ = meth.__doc__ + + +class CrewInstance(Protocol): + """Protocol for crew class instances with required attributes.""" + + __crew_metadata__: CrewMetadata + _mcp_server_adapter: Any + _all_methods: dict[str, Callable[..., Any]] + agents: list[Agent] + tasks: list[Task] + base_directory: Path + original_agents_config_path: str + original_tasks_config_path: str + agents_config: dict[str, Any] + tasks_config: dict[str, Any] + mcp_server_params: Any + mcp_connect_timeout: int + + def load_configurations(self) -> None: ... + def map_all_agent_variables(self) -> None: ... + def map_all_task_variables(self) -> None: ... + def close_mcp_server(self, instance: Self, outputs: CrewOutput) -> CrewOutput: ... + def _load_config( + self, config_path: str | None, config_type: Literal["agent", "task"] + ) -> dict[str, Any]: ... + def _map_agent_variables( + self, + agent_name: str, + agent_info: dict[str, Any], + llms: dict[str, Callable[..., Any]], + tool_functions: dict[str, Callable[..., Any]], + cache_handler_functions: dict[str, Callable[..., Any]], + callbacks: dict[str, Callable[..., Any]], + ) -> None: ... + def _map_task_variables( + self, + task_name: str, + task_info: dict[str, Any], + agents: dict[str, Callable[..., Any]], + tasks: dict[str, Callable[..., Any]], + output_json_functions: dict[str, Callable[..., Any]], + tool_functions: dict[str, Callable[..., Any]], + callback_functions: dict[str, Callable[..., Any]], + output_pydantic_functions: dict[str, Callable[..., Any]], + ) -> None: ... + def load_yaml(self, config_path: Path) -> dict[str, Any]: ... + + +class CrewClass(Protocol): + """Protocol describing class attributes injected by CrewBaseMeta.""" + + is_crew_class: bool + _crew_name: str + base_directory: Path + original_agents_config_path: str + original_tasks_config_path: str + mcp_server_params: Any + mcp_connect_timeout: int + close_mcp_server: Callable[..., Any] + get_mcp_tools: Callable[..., list[BaseTool]] + _load_config: Callable[..., dict[str, Any]] + load_configurations: Callable[..., None] + load_yaml: staticmethod + map_all_agent_variables: Callable[..., None] + _map_agent_variables: Callable[..., None] + map_all_task_variables: Callable[..., None] + _map_task_variables: Callable[..., None] + + +class DecoratedMethod(Generic[P, R]): + """Base wrapper for methods with decorator metadata. + + This class provides a type-safe way to add metadata to methods + while preserving their callable signature and attributes. + """ + + def __init__(self, meth: Callable[P, R]) -> None: + """Initialize the decorated method wrapper. + + Args: + meth: The method to wrap. + """ + self._meth = meth + _copy_method_metadata(self, meth) + + def __get__( + self, obj: Any, objtype: type[Any] | None = None + ) -> Self | Callable[..., R]: + """Support instance methods by implementing the descriptor protocol. + + Args: + obj: The instance that the method is accessed through. + objtype: The type of the instance. + + Returns: + Self when accessed through class, bound method when accessed through instance. + """ + if obj is None: + return self + bound = partial(self._meth, obj) + for attr in ( + "is_agent", + "is_llm", + "is_tool", + "is_callback", + "is_cache_handler", + "is_before_kickoff", + "is_after_kickoff", + "is_crew", + ): + if hasattr(self, attr): + setattr(bound, attr, getattr(self, attr)) + return bound + + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: + """Call the wrapped method. + + Args: + *args: Positional arguments. + **kwargs: Keyword arguments. + + Returns: + The result of calling the wrapped method. + """ + return self._meth(*args, **kwargs) + + def unwrap(self) -> Callable[P, R]: + """Get the original unwrapped method. + + Returns: + The original method before decoration. + """ + return self._meth + + +class BeforeKickoffMethod(DecoratedMethod[P, R]): + """Wrapper for methods marked to execute before crew kickoff.""" + + is_before_kickoff: bool = True + + +class AfterKickoffMethod(DecoratedMethod[P, R]): + """Wrapper for methods marked to execute after crew kickoff.""" + + is_after_kickoff: bool = True + + +class BoundTaskMethod(Generic[TaskResultT]): + """Bound task method with task marker attribute.""" + + is_task: bool = True + + def __init__(self, task_method: TaskMethod[Any, TaskResultT], obj: Any) -> None: + """Initialize the bound task method. + + Args: + task_method: The TaskMethod descriptor instance. + obj: The instance to bind to. + """ + self._task_method = task_method + self._obj = obj + + def __call__(self, *args: Any, **kwargs: Any) -> TaskResultT: + """Execute the bound task method. + + Args: + *args: Positional arguments. + **kwargs: Keyword arguments. + + Returns: + The task result with name ensured. + """ + result = self._task_method.unwrap()(self._obj, *args, **kwargs) + return self._task_method.ensure_task_name(result) + + +class TaskMethod(Generic[P, TaskResultT]): + """Wrapper for methods marked as crew tasks.""" + + is_task: bool = True + + def __init__(self, meth: Callable[P, TaskResultT]) -> None: + """Initialize the task method wrapper. + + Args: + meth: The method to wrap. + """ + self._meth = meth + _copy_method_metadata(self, meth) + + def ensure_task_name(self, result: TaskResultT) -> TaskResultT: + """Ensure task result has a name set. + + Args: + result: The task result to check. + + Returns: + The task result with name ensured. + """ + if not result.name: + result.name = self._meth.__name__ + return result + + def __get__( + self, obj: Any, objtype: type[Any] | None = None + ) -> Self | BoundTaskMethod[TaskResultT]: + """Support instance methods by implementing the descriptor protocol. + + Args: + obj: The instance that the method is accessed through. + objtype: The type of the instance. + + Returns: + Self when accessed through class, bound method when accessed through instance. + """ + if obj is None: + return self + return BoundTaskMethod(self, obj) + + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> TaskResultT: + """Call the wrapped method and set task name if not provided. + + Args: + *args: Positional arguments. + **kwargs: Keyword arguments. + + Returns: + The task instance with name set if not already provided. + """ + return self.ensure_task_name(self._meth(*args, **kwargs)) + + def unwrap(self) -> Callable[P, TaskResultT]: + """Get the original unwrapped method. + + Returns: + The original method before decoration. + """ + return self._meth + + +class AgentMethod(DecoratedMethod[P, R]): + """Wrapper for methods marked as crew agents.""" + + is_agent: bool = True + + +class LLMMethod(DecoratedMethod[P, R]): + """Wrapper for methods marked as LLM providers.""" + + is_llm: bool = True + + +class ToolMethod(DecoratedMethod[P, R]): + """Wrapper for methods marked as crew tools.""" + + is_tool: bool = True + + +class CallbackMethod(DecoratedMethod[P, R]): + """Wrapper for methods marked as crew callbacks.""" + + is_callback: bool = True + + +class CacheHandlerMethod(DecoratedMethod[P, R]): + """Wrapper for methods marked as cache handlers.""" + + is_cache_handler: bool = True + + +class CrewMethod(DecoratedMethod[P, R]): + """Wrapper for methods marked as the main crew execution point.""" + + is_crew: bool = True + + +class OutputClass(Generic[T]): + """Base wrapper for classes marked as output format.""" + + def __init__(self, cls: type[T]) -> None: + """Initialize the output class wrapper. + + Args: + cls: The class to wrap. + """ + self._cls = cls + self.__name__ = cls.__name__ + self.__qualname__ = cls.__qualname__ + self.__module__ = cls.__module__ + self.__doc__ = cls.__doc__ + + def __call__(self, *args: Any, **kwargs: Any) -> T: + """Create an instance of the wrapped class. + + Args: + *args: Positional arguments for the class constructor. + **kwargs: Keyword arguments for the class constructor. + + Returns: + An instance of the wrapped class. + """ + return self._cls(*args, **kwargs) + + def __getattr__(self, name: str) -> Any: + """Delegate attribute access to the wrapped class. + + Args: + name: The attribute name. + + Returns: + The attribute from the wrapped class. + """ + return getattr(self._cls, name) + + +class OutputJsonClass(OutputClass[T]): + """Wrapper for classes marked as JSON output format.""" + + is_output_json: bool = True + + +class OutputPydanticClass(OutputClass[T]): + """Wrapper for classes marked as Pydantic output format.""" + + is_output_pydantic: bool = True