mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 00:28:31 +00:00
refactor(project): improve type safety and consolidate crew metadata
This commit is contained in:
@@ -10,6 +10,7 @@ from crewai.project.utils import memoize
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai import Agent, Crew, Task
|
||||
|
||||
from crewai.project.wrappers import (
|
||||
AfterKickoffMethod,
|
||||
AgentMethod,
|
||||
@@ -181,8 +182,8 @@ def crew(
|
||||
agent_roles: set[str] = set()
|
||||
|
||||
# Use the preserved task and agent information
|
||||
tasks = self._original_tasks.items()
|
||||
agents = self._original_agents.items()
|
||||
tasks = self.__crew_metadata__["original_tasks"].items()
|
||||
agents = self.__crew_metadata__["original_agents"].items()
|
||||
|
||||
# Instantiate tasks in order
|
||||
for _, task_method in tasks:
|
||||
@@ -232,11 +233,11 @@ def crew(
|
||||
|
||||
return bound_callback
|
||||
|
||||
for hook_callback in self._before_kickoff.values():
|
||||
for hook_callback in self.__crew_metadata__["before_kickoff"].values():
|
||||
crew_instance.before_kickoff_callbacks.append(
|
||||
callback_wrapper(hook_callback, self)
|
||||
)
|
||||
for hook_callback in self._after_kickoff.values():
|
||||
for hook_callback in self.__crew_metadata__["after_kickoff"].values():
|
||||
crew_instance.after_kickoff_callbacks.append(
|
||||
callback_wrapper(hook_callback, self)
|
||||
)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Base metaclass for creating crew classes with configuration and function management."""
|
||||
"""Base metaclass for creating crew classes with configuration and method management."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -6,19 +6,118 @@ import inspect
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, Literal, TypedDict, TypeGuard, TypeVar, cast
|
||||
|
||||
import yaml
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from crewai.project.wrappers import CrewClass, CrewMetadata
|
||||
from crewai.tools import BaseTool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai import Agent, Task
|
||||
from crewai.agents.cache.cache_handler import CacheHandler
|
||||
from crewai.crews.crew_output import CrewOutput
|
||||
from crewai.project.wrappers import CrewInstance
|
||||
from crewai.project.wrappers import (
|
||||
CrewInstance,
|
||||
OutputJsonClass,
|
||||
OutputPydanticClass,
|
||||
)
|
||||
from crewai.tasks.task_output import TaskOutput
|
||||
|
||||
|
||||
class AgentConfig(TypedDict, total=False):
|
||||
"""Type definition for agent configuration dictionary.
|
||||
|
||||
All fields are optional as they come from YAML configuration files.
|
||||
Fields can be either string references (from YAML) or actual instances (after processing).
|
||||
"""
|
||||
|
||||
llm: str
|
||||
tools: list[str] | list[BaseTool]
|
||||
function_calling_llm: str
|
||||
step_callback: str
|
||||
cache_handler: str | CacheHandler
|
||||
|
||||
|
||||
class TaskConfig(TypedDict, total=False):
|
||||
"""Type definition for task configuration dictionary.
|
||||
|
||||
All fields are optional as they come from YAML configuration files.
|
||||
Fields can be either string references (from YAML) or actual instances (after processing).
|
||||
"""
|
||||
|
||||
context: list[str]
|
||||
tools: list[str] | list[BaseTool]
|
||||
agent: str
|
||||
output_json: str
|
||||
output_pydantic: str
|
||||
callbacks: list[str]
|
||||
guardrail: Callable[[TaskOutput], tuple[bool, Any]] | str
|
||||
|
||||
|
||||
load_dotenv()
|
||||
|
||||
CallableT = TypeVar("CallableT", bound=Callable[..., Any])
|
||||
|
||||
|
||||
def _set_base_directory(cls: type[CrewClass]) -> None:
|
||||
"""Set the base directory for the crew class.
|
||||
|
||||
Args:
|
||||
cls: Crew class to configure.
|
||||
"""
|
||||
try:
|
||||
cls.base_directory = Path(inspect.getfile(cls)).parent
|
||||
except (TypeError, OSError):
|
||||
cls.base_directory = Path.cwd()
|
||||
|
||||
|
||||
def _set_config_paths(cls: type[CrewClass]) -> None:
|
||||
"""Set the configuration file paths for the crew class.
|
||||
|
||||
Args:
|
||||
cls: Crew class to configure.
|
||||
"""
|
||||
cls.original_agents_config_path = getattr(
|
||||
cls, "agents_config", "config/agents.yaml"
|
||||
)
|
||||
cls.original_tasks_config_path = getattr(cls, "tasks_config", "config/tasks.yaml")
|
||||
|
||||
|
||||
def _set_mcp_params(cls: type[CrewClass]) -> None:
|
||||
"""Set the MCP server parameters for the crew class.
|
||||
|
||||
Args:
|
||||
cls: Crew class to configure.
|
||||
"""
|
||||
cls.mcp_server_params = getattr(cls, "mcp_server_params", None)
|
||||
cls.mcp_connect_timeout = getattr(cls, "mcp_connect_timeout", 30)
|
||||
|
||||
|
||||
def _is_string_list(value: list[str] | list[BaseTool]) -> TypeGuard[list[str]]:
|
||||
"""Type guard to check if list contains strings rather than BaseTool instances.
|
||||
|
||||
Args:
|
||||
value: List that may contain strings or BaseTool instances.
|
||||
|
||||
Returns:
|
||||
True if all elements are strings, False otherwise.
|
||||
"""
|
||||
return all(isinstance(item, str) for item in value)
|
||||
|
||||
|
||||
def _is_string_value(value: str | CacheHandler) -> TypeGuard[str]:
|
||||
"""Type guard to check if value is a string rather than a CacheHandler instance.
|
||||
|
||||
Args:
|
||||
value: Value that may be a string or CacheHandler instance.
|
||||
|
||||
Returns:
|
||||
True if value is a string, False otherwise.
|
||||
"""
|
||||
return isinstance(value, str)
|
||||
|
||||
|
||||
class CrewBaseMeta(type):
|
||||
"""Metaclass that adds crew functionality to classes."""
|
||||
@@ -29,7 +128,7 @@ class CrewBaseMeta(type):
|
||||
bases: tuple[type, ...],
|
||||
namespace: dict[str, Any],
|
||||
**kwargs: Any,
|
||||
) -> type:
|
||||
) -> type[CrewClass]:
|
||||
"""Create crew class with configuration and method injection.
|
||||
|
||||
Args:
|
||||
@@ -41,35 +140,18 @@ class CrewBaseMeta(type):
|
||||
Returns:
|
||||
New crew class with injected methods and attributes.
|
||||
"""
|
||||
cls = super().__new__(mcs, name, bases, namespace)
|
||||
|
||||
cls.is_crew_class = True # type: ignore[attr-defined]
|
||||
cls._crew_name = name # type: ignore[attr-defined]
|
||||
|
||||
try:
|
||||
cls.base_directory = Path(inspect.getfile(cls)).parent # type: ignore[attr-defined]
|
||||
except (TypeError, OSError):
|
||||
cls.base_directory = Path.cwd() # type: ignore[attr-defined]
|
||||
|
||||
cls.original_agents_config_path = getattr( # type: ignore[attr-defined]
|
||||
cls, "agents_config", "config/agents.yaml"
|
||||
)
|
||||
cls.original_tasks_config_path = getattr( # type: ignore[attr-defined]
|
||||
cls, "tasks_config", "config/tasks.yaml"
|
||||
cls = cast(
|
||||
type[CrewClass], cast(object, super().__new__(mcs, name, bases, namespace))
|
||||
)
|
||||
|
||||
cls.mcp_server_params = getattr(cls, "mcp_server_params", None) # type: ignore[attr-defined]
|
||||
cls.mcp_connect_timeout = getattr(cls, "mcp_connect_timeout", 30) # type: ignore[attr-defined]
|
||||
cls.is_crew_class = True
|
||||
cls._crew_name = name
|
||||
|
||||
cls._close_mcp_server = _close_mcp_server # type: ignore[attr-defined]
|
||||
cls.get_mcp_tools = get_mcp_tools # type: ignore[attr-defined]
|
||||
cls._load_config = _load_config # type: ignore[attr-defined]
|
||||
cls.load_configurations = load_configurations # type: ignore[attr-defined]
|
||||
cls.load_yaml = staticmethod(load_yaml) # type: ignore[attr-defined]
|
||||
cls.map_all_agent_variables = map_all_agent_variables # type: ignore[attr-defined]
|
||||
cls._map_agent_variables = _map_agent_variables # type: ignore[attr-defined]
|
||||
cls.map_all_task_variables = map_all_task_variables # type: ignore[attr-defined]
|
||||
cls._map_task_variables = _map_task_variables # type: ignore[attr-defined]
|
||||
for setup_fn in _CLASS_SETUP_FUNCTIONS:
|
||||
setup_fn(cls)
|
||||
|
||||
for method in _METHODS_TO_INJECT:
|
||||
setattr(cls, method.__name__, method)
|
||||
|
||||
return cls
|
||||
|
||||
@@ -97,11 +179,11 @@ class CrewBaseMeta(type):
|
||||
"""
|
||||
instance._mcp_server_adapter = None
|
||||
instance.load_configurations()
|
||||
instance._all_functions = _get_all_functions(instance)
|
||||
instance._all_methods = _get_all_methods(instance)
|
||||
instance.map_all_agent_variables()
|
||||
instance.map_all_task_variables()
|
||||
|
||||
instance._original_functions = {
|
||||
original_methods = {
|
||||
name: method
|
||||
for name, method in cls.__dict__.items()
|
||||
if any(
|
||||
@@ -116,23 +198,17 @@ class CrewBaseMeta(type):
|
||||
)
|
||||
}
|
||||
|
||||
instance._original_tasks = _filter_functions(
|
||||
instance._original_functions, "is_task"
|
||||
)
|
||||
instance._original_agents = _filter_functions(
|
||||
instance._original_functions, "is_agent"
|
||||
)
|
||||
instance._before_kickoff = _filter_functions(
|
||||
instance._original_functions, "is_before_kickoff"
|
||||
)
|
||||
instance._after_kickoff = _filter_functions(
|
||||
instance._original_functions, "is_after_kickoff"
|
||||
)
|
||||
instance._kickoff = _filter_functions(
|
||||
instance._original_functions, "is_kickoff"
|
||||
)
|
||||
after_kickoff_callbacks = _filter_methods(original_methods, "is_after_kickoff")
|
||||
after_kickoff_callbacks["_close_mcp_server"] = instance._close_mcp_server
|
||||
|
||||
instance._after_kickoff["_close_mcp_server"] = instance._close_mcp_server
|
||||
instance.__crew_metadata__ = CrewMetadata(
|
||||
original_methods=original_methods,
|
||||
original_tasks=_filter_methods(original_methods, "is_task"),
|
||||
original_agents=_filter_methods(original_methods, "is_agent"),
|
||||
before_kickoff=_filter_methods(original_methods, "is_before_kickoff"),
|
||||
after_kickoff=after_kickoff_callbacks,
|
||||
kickoff=_filter_methods(original_methods, "is_kickoff"),
|
||||
)
|
||||
|
||||
|
||||
def _close_mcp_server(
|
||||
@@ -156,7 +232,7 @@ def _close_mcp_server(
|
||||
return outputs
|
||||
|
||||
|
||||
def get_mcp_tools(self: Any, *tool_names: str) -> list[BaseTool]:
|
||||
def get_mcp_tools(self: CrewInstance, *tool_names: str) -> list[BaseTool]:
|
||||
"""Get MCP tools filtered by name.
|
||||
|
||||
Args:
|
||||
@@ -180,13 +256,14 @@ def get_mcp_tools(self: Any, *tool_names: str) -> list[BaseTool]:
|
||||
|
||||
|
||||
def _load_config(
|
||||
self: Any, config_path: str | None, config_type: str
|
||||
self: CrewInstance, config_path: str | None, config_type: Literal["agent", "task"]
|
||||
) -> dict[str, Any]:
|
||||
"""Load YAML config file or return empty dict if not found.
|
||||
|
||||
Args:
|
||||
self: Crew instance with base directory and load_yaml method.
|
||||
config_path: Relative path to config file.
|
||||
config_type: Config type for logging (e.g., "agent", "task").
|
||||
config_type: Config type for logging, either "agent" or "task".
|
||||
|
||||
Returns:
|
||||
Config dictionary or empty dict.
|
||||
@@ -209,7 +286,7 @@ def _load_config(
|
||||
return {}
|
||||
|
||||
|
||||
def load_configurations(self: Any) -> None:
|
||||
def load_configurations(self: CrewInstance) -> None:
|
||||
"""Load agent and task YAML configurations.
|
||||
|
||||
Args:
|
||||
@@ -219,34 +296,35 @@ def load_configurations(self: Any) -> None:
|
||||
self.tasks_config = self._load_config(self.original_tasks_config_path, "task")
|
||||
|
||||
|
||||
def load_yaml(config_path: Path) -> Any:
|
||||
"""Load and parse YAML file.
|
||||
def load_yaml(config_path: Path) -> dict[str, Any]:
|
||||
"""Load and parse YAML configuration file.
|
||||
|
||||
Args:
|
||||
config_path: Path to YAML configuration file.
|
||||
|
||||
Returns:
|
||||
Parsed YAML content.
|
||||
Parsed YAML content as a dictionary. Returns empty dict if file is empty.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If config file does not exist.
|
||||
"""
|
||||
try:
|
||||
with open(config_path, encoding="utf-8") as file:
|
||||
return yaml.safe_load(file)
|
||||
content = yaml.safe_load(file)
|
||||
return content if isinstance(content, dict) else {}
|
||||
except FileNotFoundError:
|
||||
print(f"File not found: {config_path}")
|
||||
logging.warning(f"File not found: {config_path}")
|
||||
raise
|
||||
|
||||
|
||||
def _get_all_functions(self: Any) -> dict[str, Callable[..., Any]]:
|
||||
"""Return all non-dunder callable attributes.
|
||||
def _get_all_methods(self: CrewInstance) -> dict[str, Callable[..., Any]]:
|
||||
"""Return all non-dunder callable attributes (methods).
|
||||
|
||||
Args:
|
||||
self: Instance to inspect for callable attributes.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping attribute names to callable objects.
|
||||
Dictionary mapping method names to bound method objects.
|
||||
"""
|
||||
return {
|
||||
name: getattr(self, name)
|
||||
@@ -256,50 +334,53 @@ def _get_all_functions(self: Any) -> dict[str, Callable[..., Any]]:
|
||||
}
|
||||
|
||||
|
||||
def _filter_functions(
|
||||
functions: dict[str, Callable[..., Any]], attribute: str
|
||||
) -> dict[str, Callable[..., Any]]:
|
||||
"""Filter functions by attribute presence.
|
||||
def _filter_methods(
|
||||
methods: dict[str, CallableT], attribute: str
|
||||
) -> dict[str, CallableT]:
|
||||
"""Filter methods by attribute presence, preserving exact callable types.
|
||||
|
||||
Args:
|
||||
functions: Dictionary of functions to filter.
|
||||
methods: Dictionary of methods to filter.
|
||||
attribute: Attribute name to check for.
|
||||
|
||||
Returns:
|
||||
Dictionary containing only functions with the specified attribute.
|
||||
Dictionary containing only methods with the specified attribute.
|
||||
The return type matches the input callable type exactly.
|
||||
"""
|
||||
return {name: func for name, func in functions.items() if hasattr(func, attribute)}
|
||||
return {
|
||||
name: method for name, method in methods.items() if hasattr(method, attribute)
|
||||
}
|
||||
|
||||
|
||||
def map_all_agent_variables(self: Any) -> None:
|
||||
def map_all_agent_variables(self: CrewInstance) -> None:
|
||||
"""Map agent configuration variables to callable instances.
|
||||
|
||||
Args:
|
||||
self: Crew instance with agent configurations to map.
|
||||
"""
|
||||
llms = _filter_functions(self._all_functions, "is_llm")
|
||||
tool_functions = _filter_functions(self._all_functions, "is_tool")
|
||||
cache_handler_functions = _filter_functions(self._all_functions, "is_cache_handler")
|
||||
callbacks = _filter_functions(self._all_functions, "is_callback")
|
||||
llms = _filter_methods(self._all_methods, "is_llm")
|
||||
tool_functions = _filter_methods(self._all_methods, "is_tool")
|
||||
cache_handler_functions = _filter_methods(self._all_methods, "is_cache_handler")
|
||||
callbacks = _filter_methods(self._all_methods, "is_callback")
|
||||
|
||||
for agent_name, agent_info in self.agents_config.items():
|
||||
self._map_agent_variables(
|
||||
agent_name,
|
||||
agent_info,
|
||||
llms,
|
||||
tool_functions,
|
||||
cache_handler_functions,
|
||||
callbacks,
|
||||
agent_name=agent_name,
|
||||
agent_info=agent_info,
|
||||
llms=llms,
|
||||
tool_functions=tool_functions,
|
||||
cache_handler_functions=cache_handler_functions,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
|
||||
|
||||
def _map_agent_variables(
|
||||
self: Any,
|
||||
self: CrewInstance,
|
||||
agent_name: str,
|
||||
agent_info: dict[str, Any],
|
||||
llms: dict[str, Callable[..., Any]],
|
||||
tool_functions: dict[str, Callable[..., Any]],
|
||||
cache_handler_functions: dict[str, Callable[..., Any]],
|
||||
agent_info: AgentConfig,
|
||||
llms: dict[str, Callable[[], Any]],
|
||||
tool_functions: dict[str, Callable[[], BaseTool]],
|
||||
cache_handler_functions: dict[str, Callable[[], Any]],
|
||||
callbacks: dict[str, Callable[..., Any]],
|
||||
) -> None:
|
||||
"""Resolve and map variables for a single agent.
|
||||
@@ -307,93 +388,87 @@ def _map_agent_variables(
|
||||
Args:
|
||||
self: Crew instance with agent configurations.
|
||||
agent_name: Name of agent to configure.
|
||||
agent_info: Agent configuration dictionary.
|
||||
llms: Dictionary of available LLM providers.
|
||||
tool_functions: Dictionary of available tools.
|
||||
cache_handler_functions: Dictionary of available cache handlers.
|
||||
agent_info: Agent configuration dictionary with optional fields.
|
||||
llms: Dictionary mapping names to LLM factory functions.
|
||||
tool_functions: Dictionary mapping names to tool factory functions.
|
||||
cache_handler_functions: Dictionary mapping names to cache handler factory functions.
|
||||
callbacks: Dictionary of available callbacks.
|
||||
"""
|
||||
if llm := agent_info.get("llm"):
|
||||
try:
|
||||
self.agents_config[agent_name]["llm"] = llms[llm]()
|
||||
except KeyError:
|
||||
self.agents_config[agent_name]["llm"] = llm
|
||||
factory = llms.get(llm)
|
||||
self.agents_config[agent_name]["llm"] = factory() if factory else llm
|
||||
|
||||
if tools := agent_info.get("tools"):
|
||||
self.agents_config[agent_name]["tools"] = [
|
||||
tool_functions[tool]() for tool in tools
|
||||
]
|
||||
if _is_string_list(tools):
|
||||
self.agents_config[agent_name]["tools"] = [
|
||||
tool_functions[tool]() for tool in tools
|
||||
]
|
||||
|
||||
if function_calling_llm := agent_info.get("function_calling_llm"):
|
||||
try:
|
||||
self.agents_config[agent_name]["function_calling_llm"] = llms[
|
||||
function_calling_llm
|
||||
]()
|
||||
except KeyError:
|
||||
self.agents_config[agent_name]["function_calling_llm"] = (
|
||||
function_calling_llm
|
||||
)
|
||||
factory = llms.get(function_calling_llm)
|
||||
self.agents_config[agent_name]["function_calling_llm"] = (
|
||||
factory() if factory else function_calling_llm
|
||||
)
|
||||
|
||||
if step_callback := agent_info.get("step_callback"):
|
||||
self.agents_config[agent_name]["step_callback"] = callbacks[step_callback]()
|
||||
|
||||
if cache_handler := agent_info.get("cache_handler"):
|
||||
self.agents_config[agent_name]["cache_handler"] = cache_handler_functions[
|
||||
cache_handler
|
||||
]()
|
||||
if _is_string_value(cache_handler):
|
||||
self.agents_config[agent_name]["cache_handler"] = cache_handler_functions[
|
||||
cache_handler
|
||||
]()
|
||||
|
||||
|
||||
def map_all_task_variables(self: Any) -> None:
|
||||
def map_all_task_variables(self: CrewInstance) -> None:
|
||||
"""Map task configuration variables to callable instances.
|
||||
|
||||
Args:
|
||||
self: Crew instance with task configurations to map.
|
||||
"""
|
||||
agents = _filter_functions(self._all_functions, "is_agent")
|
||||
tasks = _filter_functions(self._all_functions, "is_task")
|
||||
output_json_functions = _filter_functions(self._all_functions, "is_output_json")
|
||||
tool_functions = _filter_functions(self._all_functions, "is_tool")
|
||||
callback_functions = _filter_functions(self._all_functions, "is_callback")
|
||||
output_pydantic_functions = _filter_functions(
|
||||
self._all_functions, "is_output_pydantic"
|
||||
)
|
||||
agents = _filter_methods(self._all_methods, "is_agent")
|
||||
tasks = _filter_methods(self._all_methods, "is_task")
|
||||
output_json_functions = _filter_methods(self._all_methods, "is_output_json")
|
||||
tool_functions = _filter_methods(self._all_methods, "is_tool")
|
||||
callback_functions = _filter_methods(self._all_methods, "is_callback")
|
||||
output_pydantic_functions = _filter_methods(self._all_methods, "is_output_pydantic")
|
||||
|
||||
for task_name, task_info in self.tasks_config.items():
|
||||
self._map_task_variables(
|
||||
task_name,
|
||||
task_info,
|
||||
agents,
|
||||
tasks,
|
||||
output_json_functions,
|
||||
tool_functions,
|
||||
callback_functions,
|
||||
output_pydantic_functions,
|
||||
task_name=task_name,
|
||||
task_info=task_info,
|
||||
agents=agents,
|
||||
tasks=tasks,
|
||||
output_json_functions=output_json_functions,
|
||||
tool_functions=tool_functions,
|
||||
callback_functions=callback_functions,
|
||||
output_pydantic_functions=output_pydantic_functions,
|
||||
)
|
||||
|
||||
|
||||
def _map_task_variables(
|
||||
self: Any,
|
||||
self: CrewInstance,
|
||||
task_name: str,
|
||||
task_info: dict[str, Any],
|
||||
agents: dict[str, Callable[..., Any]],
|
||||
tasks: dict[str, Callable[..., Any]],
|
||||
output_json_functions: dict[str, Callable[..., Any]],
|
||||
tool_functions: dict[str, Callable[..., Any]],
|
||||
task_info: TaskConfig,
|
||||
agents: dict[str, Callable[[], Agent]],
|
||||
tasks: dict[str, Callable[[], Task]],
|
||||
output_json_functions: dict[str, OutputJsonClass[Any]],
|
||||
tool_functions: dict[str, Callable[[], BaseTool]],
|
||||
callback_functions: dict[str, Callable[..., Any]],
|
||||
output_pydantic_functions: dict[str, Callable[..., Any]],
|
||||
output_pydantic_functions: dict[str, OutputPydanticClass[Any]],
|
||||
) -> None:
|
||||
"""Resolve and map variables for a single task.
|
||||
|
||||
Args:
|
||||
self: Crew instance with task configurations.
|
||||
task_name: Name of task to configure.
|
||||
task_info: Task configuration dictionary.
|
||||
agents: Dictionary of available agents.
|
||||
tasks: Dictionary of available tasks.
|
||||
output_json_functions: Dictionary of available JSON output schemas.
|
||||
tool_functions: Dictionary of available tools.
|
||||
task_info: Task configuration dictionary with optional fields.
|
||||
agents: Dictionary mapping names to agent factory functions.
|
||||
tasks: Dictionary mapping names to task factory functions.
|
||||
output_json_functions: Dictionary of JSON output class wrappers.
|
||||
tool_functions: Dictionary mapping names to tool factory functions.
|
||||
callback_functions: Dictionary of available callbacks.
|
||||
output_pydantic_functions: Dictionary of available Pydantic output schemas.
|
||||
output_pydantic_functions: Dictionary of Pydantic output class wrappers.
|
||||
"""
|
||||
if context_list := task_info.get("context"):
|
||||
self.tasks_config[task_name]["context"] = [
|
||||
@@ -401,9 +476,10 @@ def _map_task_variables(
|
||||
]
|
||||
|
||||
if tools := task_info.get("tools"):
|
||||
self.tasks_config[task_name]["tools"] = [
|
||||
tool_functions[tool]() for tool in tools
|
||||
]
|
||||
if _is_string_list(tools):
|
||||
self.tasks_config[task_name]["tools"] = [
|
||||
tool_functions[tool]() for tool in tools
|
||||
]
|
||||
|
||||
if agent_name := task_info.get("agent"):
|
||||
self.tasks_config[task_name]["agent"] = agents[agent_name]()
|
||||
@@ -425,23 +501,57 @@ def _map_task_variables(
|
||||
self.tasks_config[task_name]["guardrail"] = guardrail
|
||||
|
||||
|
||||
def CrewBase(cls: type) -> type: # noqa: N802
|
||||
"""Apply CrewBaseMeta metaclass to a class for decorator syntax compatibility.
|
||||
_CLASS_SETUP_FUNCTIONS: tuple[Callable[[type[CrewClass]], None], ...] = (
|
||||
_set_base_directory,
|
||||
_set_config_paths,
|
||||
_set_mcp_params,
|
||||
)
|
||||
|
||||
Args:
|
||||
cls: Class to apply metaclass to.
|
||||
_METHODS_TO_INJECT = (
|
||||
_close_mcp_server,
|
||||
get_mcp_tools,
|
||||
_load_config,
|
||||
load_configurations,
|
||||
staticmethod(load_yaml),
|
||||
map_all_agent_variables,
|
||||
_map_agent_variables,
|
||||
map_all_task_variables,
|
||||
_map_task_variables,
|
||||
)
|
||||
|
||||
Returns:
|
||||
New class with CrewBaseMeta metaclass applied.
|
||||
|
||||
class _CrewBaseType(type):
|
||||
"""Metaclass for CrewBase that makes it callable as a decorator."""
|
||||
|
||||
def __call__(cls, decorated_cls: type) -> type[CrewClass]:
|
||||
"""Apply CrewBaseMeta to the decorated class.
|
||||
|
||||
Args:
|
||||
decorated_cls: Class to transform with CrewBaseMeta metaclass.
|
||||
|
||||
Returns:
|
||||
New class with CrewBaseMeta metaclass applied.
|
||||
"""
|
||||
__name = str(decorated_cls.__name__)
|
||||
__bases = tuple(decorated_cls.__bases__)
|
||||
__dict = {
|
||||
key: value
|
||||
for key, value in decorated_cls.__dict__.items()
|
||||
if key not in ("__dict__", "__weakref__")
|
||||
}
|
||||
for slot in __dict.get("__slots__", tuple()):
|
||||
__dict.pop(slot, None)
|
||||
__dict["__metaclass__"] = CrewBaseMeta
|
||||
return cast(type[CrewClass], CrewBaseMeta(__name, __bases, __dict))
|
||||
|
||||
|
||||
class CrewBase(metaclass=_CrewBaseType):
|
||||
"""Class decorator that applies CrewBaseMeta metaclass.
|
||||
|
||||
Applies CrewBaseMeta metaclass to a class via decorator syntax rather than
|
||||
explicit metaclass declaration. Use as @CrewBase instead of
|
||||
class Foo(metaclass=CrewBaseMeta).
|
||||
|
||||
Note:
|
||||
Reference: https://stackoverflow.com/questions/11091609/setting-a-class-metaclass-using-a-decorator
|
||||
"""
|
||||
if isinstance(cls, CrewBaseMeta):
|
||||
return cls
|
||||
|
||||
namespace = {
|
||||
key: value
|
||||
for key, value in cls.__dict__.items()
|
||||
if not key.startswith("__")
|
||||
or key in ("__module__", "__qualname__", "__annotations__")
|
||||
}
|
||||
|
||||
return CrewBaseMeta(cls.__name__, cls.__bases__, namespace)
|
||||
|
||||
@@ -4,11 +4,38 @@ from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Any, Generic, ParamSpec, Protocol, Self, TypeVar
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Generic,
|
||||
Literal,
|
||||
ParamSpec,
|
||||
Protocol,
|
||||
Self,
|
||||
TypedDict,
|
||||
TypeVar,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai import Agent, Task
|
||||
from crewai.crews.crew_output import CrewOutput
|
||||
from crewai.tools import BaseTool
|
||||
|
||||
|
||||
class CrewMetadata(TypedDict):
|
||||
"""Type definition for crew metadata dictionary.
|
||||
|
||||
Stores framework-injected metadata about decorated methods and callbacks.
|
||||
"""
|
||||
|
||||
original_methods: dict[str, Callable[..., Any]]
|
||||
original_tasks: dict[str, Callable[..., Task]]
|
||||
original_agents: dict[str, Callable[..., Agent]]
|
||||
before_kickoff: dict[str, Callable[..., Any]]
|
||||
after_kickoff: dict[str, Callable[..., Any]]
|
||||
kickoff: dict[str, Callable[..., Any]]
|
||||
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
@@ -24,35 +51,82 @@ class TaskResult(Protocol):
|
||||
TaskResultT = TypeVar("TaskResultT", bound=TaskResult)
|
||||
|
||||
|
||||
def _copy_function_metadata(wrapper: Any, func: Callable[..., Any]) -> None:
|
||||
"""Copy function metadata to a wrapper object.
|
||||
def _copy_method_metadata(wrapper: Any, meth: Callable[..., Any]) -> None:
|
||||
"""Copy method metadata to a wrapper object.
|
||||
|
||||
Args:
|
||||
wrapper: The wrapper object to update.
|
||||
func: The function to copy metadata from.
|
||||
meth: The method to copy metadata from.
|
||||
"""
|
||||
wrapper.__name__ = func.__name__
|
||||
wrapper.__doc__ = func.__doc__
|
||||
wrapper.__name__ = meth.__name__
|
||||
wrapper.__doc__ = meth.__doc__
|
||||
|
||||
|
||||
class CrewInstance(Protocol):
|
||||
"""Protocol for crew class instances with required attributes."""
|
||||
|
||||
__crew_metadata__: CrewMetadata
|
||||
_mcp_server_adapter: Any
|
||||
_all_functions: dict[str, Callable[..., Any]]
|
||||
_original_functions: dict[str, Callable[..., Any]]
|
||||
_original_tasks: dict[str, Callable[[Self], Task]]
|
||||
_original_agents: dict[str, Callable[[Self], Agent]]
|
||||
_before_kickoff: dict[str, Callable[..., Any]]
|
||||
_after_kickoff: dict[str, Callable[..., Any]]
|
||||
_kickoff: dict[str, Callable[..., Any]]
|
||||
_all_methods: dict[str, Callable[..., Any]]
|
||||
agents: list[Agent]
|
||||
tasks: list[Task]
|
||||
base_directory: Path
|
||||
original_agents_config_path: str
|
||||
original_tasks_config_path: str
|
||||
agents_config: dict[str, Any]
|
||||
tasks_config: dict[str, Any]
|
||||
mcp_server_params: Any
|
||||
mcp_connect_timeout: int
|
||||
|
||||
def load_configurations(self) -> None: ...
|
||||
def map_all_agent_variables(self) -> None: ...
|
||||
def map_all_task_variables(self) -> None: ...
|
||||
def _close_mcp_server(self, instance: Self, outputs: CrewOutput) -> CrewOutput: ...
|
||||
def _load_config(
|
||||
self, config_path: str | None, config_type: Literal["agent", "task"]
|
||||
) -> dict[str, Any]: ...
|
||||
def _map_agent_variables(
|
||||
self,
|
||||
agent_name: str,
|
||||
agent_info: dict[str, Any],
|
||||
llms: dict[str, Callable[..., Any]],
|
||||
tool_functions: dict[str, Callable[..., Any]],
|
||||
cache_handler_functions: dict[str, Callable[..., Any]],
|
||||
callbacks: dict[str, Callable[..., Any]],
|
||||
) -> None: ...
|
||||
def _map_task_variables(
|
||||
self,
|
||||
task_name: str,
|
||||
task_info: dict[str, Any],
|
||||
agents: dict[str, Callable[..., Any]],
|
||||
tasks: dict[str, Callable[..., Any]],
|
||||
output_json_functions: dict[str, Callable[..., Any]],
|
||||
tool_functions: dict[str, Callable[..., Any]],
|
||||
callback_functions: dict[str, Callable[..., Any]],
|
||||
output_pydantic_functions: dict[str, Callable[..., Any]],
|
||||
) -> None: ...
|
||||
def load_yaml(self, config_path: Path) -> dict[str, Any]: ...
|
||||
|
||||
|
||||
class CrewClass(Protocol):
|
||||
"""Protocol describing class attributes injected by CrewBaseMeta."""
|
||||
|
||||
is_crew_class: bool
|
||||
_crew_name: str
|
||||
base_directory: Path
|
||||
original_agents_config_path: str
|
||||
original_tasks_config_path: str
|
||||
mcp_server_params: Any
|
||||
mcp_connect_timeout: int
|
||||
_close_mcp_server: Callable[..., Any]
|
||||
get_mcp_tools: Callable[..., list[BaseTool]]
|
||||
_load_config: Callable[..., dict[str, Any]]
|
||||
load_configurations: Callable[..., None]
|
||||
load_yaml: staticmethod
|
||||
map_all_agent_variables: Callable[..., None]
|
||||
_map_agent_variables: Callable[..., None]
|
||||
map_all_task_variables: Callable[..., None]
|
||||
_map_task_variables: Callable[..., None]
|
||||
|
||||
|
||||
class DecoratedMethod(Generic[P, R]):
|
||||
@@ -69,7 +143,7 @@ class DecoratedMethod(Generic[P, R]):
|
||||
meth: The method to wrap.
|
||||
"""
|
||||
self._meth = meth
|
||||
_copy_function_metadata(self, meth)
|
||||
_copy_method_metadata(self, meth)
|
||||
|
||||
def __get__(
|
||||
self, obj: Any, objtype: type[Any] | None = None
|
||||
@@ -158,8 +232,8 @@ class BoundTaskMethod(Generic[TaskResultT]):
|
||||
Returns:
|
||||
The task result with name ensured.
|
||||
"""
|
||||
result = self._task_method._meth(self._obj, *args, **kwargs)
|
||||
return self._task_method._ensure_task_name(result)
|
||||
result = self._task_method.unwrap()(self._obj, *args, **kwargs)
|
||||
return self._task_method.ensure_task_name(result)
|
||||
|
||||
|
||||
class TaskMethod(Generic[P, TaskResultT]):
|
||||
@@ -174,9 +248,9 @@ class TaskMethod(Generic[P, TaskResultT]):
|
||||
meth: The method to wrap.
|
||||
"""
|
||||
self._meth = meth
|
||||
_copy_function_metadata(self, meth)
|
||||
_copy_method_metadata(self, meth)
|
||||
|
||||
def _ensure_task_name(self, result: TaskResultT) -> TaskResultT:
|
||||
def ensure_task_name(self, result: TaskResultT) -> TaskResultT:
|
||||
"""Ensure task result has a name set.
|
||||
|
||||
Args:
|
||||
@@ -215,7 +289,7 @@ class TaskMethod(Generic[P, TaskResultT]):
|
||||
Returns:
|
||||
The task instance with name set if not already provided.
|
||||
"""
|
||||
return self._ensure_task_name(self._meth(*args, **kwargs))
|
||||
return self.ensure_task_name(self._meth(*args, **kwargs))
|
||||
|
||||
def unwrap(self) -> Callable[P, TaskResultT]:
|
||||
"""Get the original unwrapped method.
|
||||
|
||||
Reference in New Issue
Block a user