mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-05 06:08:29 +00:00
Compare commits
16 Commits
fix/unsafe
...
gl/chore/p
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6081809c76 | ||
|
|
30a4b712a3 | ||
|
|
8465350f1d | ||
|
|
2a23bc604c | ||
|
|
714f8a8940 | ||
|
|
f0fb349ddf | ||
|
|
bf2e2a42da | ||
|
|
e029de2863 | ||
|
|
814c962196 | ||
|
|
6492852a0c | ||
|
|
fecf7e9a83 | ||
|
|
6bc8818ae9 | ||
|
|
620df71763 | ||
|
|
7d6324dfa3 | ||
|
|
541eec0639 | ||
|
|
2ebb2e845f |
@@ -775,4 +775,3 @@ A: Yes, CrewAI provides extensive beginner-friendly tutorials, courses, and docu
|
||||
### Q: Can CrewAI automate human-in-the-loop workflows?
|
||||
|
||||
A: Yes, CrewAI fully supports human-in-the-loop workflows, allowing seamless collaboration between human experts and AI agents for enhanced decision-making.
|
||||
# test
|
||||
|
||||
@@ -40,7 +40,7 @@ def _suppress_pydantic_deprecation_warnings() -> None:
|
||||
|
||||
_suppress_pydantic_deprecation_warnings()
|
||||
|
||||
__version__ = "0.203.0"
|
||||
__version__ = "0.203.1"
|
||||
_telemetry_submitted = False
|
||||
|
||||
|
||||
|
||||
@@ -30,6 +30,7 @@ def validate_jwt_token(
|
||||
algorithms=["RS256"],
|
||||
audience=audience,
|
||||
issuer=issuer,
|
||||
leeway=10.0,
|
||||
options={
|
||||
"verify_signature": True,
|
||||
"verify_exp": True,
|
||||
|
||||
@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
|
||||
authors = [{ name = "Your Name", email = "you@example.com" }]
|
||||
requires-python = ">=3.10,<3.14"
|
||||
dependencies = [
|
||||
"crewai[tools]>=0.203.0,<1.0.0"
|
||||
"crewai[tools]>=0.203.1,<1.0.0"
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
|
||||
authors = [{ name = "Your Name", email = "you@example.com" }]
|
||||
requires-python = ">=3.10,<3.14"
|
||||
dependencies = [
|
||||
"crewai[tools]>=0.203.0,<1.0.0",
|
||||
"crewai[tools]>=0.203.1,<1.0.0",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
@@ -5,7 +5,7 @@ description = "Power up your crews with {{folder_name}}"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10,<3.14"
|
||||
dependencies = [
|
||||
"crewai[tools]>=0.203.0"
|
||||
"crewai[tools]>=0.203.1"
|
||||
]
|
||||
|
||||
[tool.crewai]
|
||||
|
||||
@@ -358,7 +358,8 @@ def prompt_user_for_trace_viewing(timeout_seconds: int = 20) -> bool:
|
||||
try:
|
||||
response = input().strip().lower()
|
||||
result[0] = response in ["y", "yes"]
|
||||
except (EOFError, KeyboardInterrupt):
|
||||
except (EOFError, KeyboardInterrupt, OSError, LookupError):
|
||||
# Handle all input-related errors silently
|
||||
result[0] = False
|
||||
|
||||
input_thread = threading.Thread(target=get_input, daemon=True)
|
||||
@@ -371,6 +372,7 @@ def prompt_user_for_trace_viewing(timeout_seconds: int = 20) -> bool:
|
||||
return result[0]
|
||||
|
||||
except Exception:
|
||||
# Suppress any warnings or errors and assume "no"
|
||||
return False
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from .annotations import (
|
||||
"""Project package for CrewAI."""
|
||||
|
||||
from crewai.project.annotations import (
|
||||
after_kickoff,
|
||||
agent,
|
||||
before_kickoff,
|
||||
@@ -11,19 +13,19 @@ from .annotations import (
|
||||
task,
|
||||
tool,
|
||||
)
|
||||
from .crew_base import CrewBase
|
||||
from crewai.project.crew_base import CrewBase
|
||||
|
||||
__all__ = [
|
||||
"CrewBase",
|
||||
"after_kickoff",
|
||||
"agent",
|
||||
"before_kickoff",
|
||||
"cache_handler",
|
||||
"callback",
|
||||
"crew",
|
||||
"task",
|
||||
"llm",
|
||||
"output_json",
|
||||
"output_pydantic",
|
||||
"task",
|
||||
"tool",
|
||||
"callback",
|
||||
"CrewBase",
|
||||
"llm",
|
||||
"cache_handler",
|
||||
"before_kickoff",
|
||||
"after_kickoff",
|
||||
]
|
||||
|
||||
@@ -1,97 +1,192 @@
|
||||
from functools import wraps
|
||||
from typing import Callable
|
||||
|
||||
from crewai import Crew
|
||||
from crewai.project.utils import memoize
|
||||
|
||||
"""Decorators for defining crew components and their behaviors."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
def before_kickoff(func):
|
||||
"""Marks a method to execute before crew kickoff."""
|
||||
func.is_before_kickoff = True
|
||||
return func
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import TYPE_CHECKING, Concatenate, ParamSpec, TypeVar
|
||||
|
||||
from crewai.project.utils import memoize
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai import Agent, Crew, Task
|
||||
|
||||
from crewai.project.wrappers import (
|
||||
AfterKickoffMethod,
|
||||
AgentMethod,
|
||||
BeforeKickoffMethod,
|
||||
CacheHandlerMethod,
|
||||
CallbackMethod,
|
||||
CrewInstance,
|
||||
LLMMethod,
|
||||
OutputJsonClass,
|
||||
OutputPydanticClass,
|
||||
TaskMethod,
|
||||
TaskResultT,
|
||||
ToolMethod,
|
||||
)
|
||||
|
||||
P = ParamSpec("P")
|
||||
P2 = ParamSpec("P2")
|
||||
R = TypeVar("R")
|
||||
R2 = TypeVar("R2")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def after_kickoff(func):
|
||||
"""Marks a method to execute after crew kickoff."""
|
||||
func.is_after_kickoff = True
|
||||
return func
|
||||
def before_kickoff(meth: Callable[P, R]) -> BeforeKickoffMethod[P, R]:
|
||||
"""Marks a method to execute before crew kickoff.
|
||||
|
||||
Args:
|
||||
meth: The method to mark.
|
||||
|
||||
Returns:
|
||||
A wrapped method marked for before kickoff execution.
|
||||
"""
|
||||
return BeforeKickoffMethod(meth)
|
||||
|
||||
|
||||
def task(func):
|
||||
"""Marks a method as a crew task."""
|
||||
func.is_task = True
|
||||
def after_kickoff(meth: Callable[P, R]) -> AfterKickoffMethod[P, R]:
|
||||
"""Marks a method to execute after crew kickoff.
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
result = func(*args, **kwargs)
|
||||
if not result.name:
|
||||
result.name = func.__name__
|
||||
return result
|
||||
Args:
|
||||
meth: The method to mark.
|
||||
|
||||
return memoize(wrapper)
|
||||
Returns:
|
||||
A wrapped method marked for after kickoff execution.
|
||||
"""
|
||||
return AfterKickoffMethod(meth)
|
||||
|
||||
|
||||
def agent(func):
|
||||
"""Marks a method as a crew agent."""
|
||||
func.is_agent = True
|
||||
func = memoize(func)
|
||||
return func
|
||||
def task(meth: Callable[P, TaskResultT]) -> TaskMethod[P, TaskResultT]:
|
||||
"""Marks a method as a crew task.
|
||||
|
||||
Args:
|
||||
meth: The method to mark.
|
||||
|
||||
Returns:
|
||||
A wrapped method marked as a task with memoization.
|
||||
"""
|
||||
return TaskMethod(memoize(meth))
|
||||
|
||||
|
||||
def llm(func):
|
||||
"""Marks a method as an LLM provider."""
|
||||
func.is_llm = True
|
||||
func = memoize(func)
|
||||
return func
|
||||
def agent(meth: Callable[P, R]) -> AgentMethod[P, R]:
|
||||
"""Marks a method as a crew agent.
|
||||
|
||||
Args:
|
||||
meth: The method to mark.
|
||||
|
||||
Returns:
|
||||
A wrapped method marked as an agent with memoization.
|
||||
"""
|
||||
return AgentMethod(memoize(meth))
|
||||
|
||||
|
||||
def output_json(cls):
|
||||
"""Marks a class as JSON output format."""
|
||||
cls.is_output_json = True
|
||||
return cls
|
||||
def llm(meth: Callable[P, R]) -> LLMMethod[P, R]:
|
||||
"""Marks a method as an LLM provider.
|
||||
|
||||
Args:
|
||||
meth: The method to mark.
|
||||
|
||||
Returns:
|
||||
A wrapped method marked as an LLM provider with memoization.
|
||||
"""
|
||||
return LLMMethod(memoize(meth))
|
||||
|
||||
|
||||
def output_pydantic(cls):
|
||||
"""Marks a class as Pydantic output format."""
|
||||
cls.is_output_pydantic = True
|
||||
return cls
|
||||
def output_json(cls: type[T]) -> OutputJsonClass[T]:
|
||||
"""Marks a class as JSON output format.
|
||||
|
||||
Args:
|
||||
cls: The class to mark.
|
||||
|
||||
Returns:
|
||||
A wrapped class marked as JSON output format.
|
||||
"""
|
||||
return OutputJsonClass(cls)
|
||||
|
||||
|
||||
def tool(func):
|
||||
"""Marks a method as a crew tool."""
|
||||
func.is_tool = True
|
||||
return memoize(func)
|
||||
def output_pydantic(cls: type[T]) -> OutputPydanticClass[T]:
|
||||
"""Marks a class as Pydantic output format.
|
||||
|
||||
Args:
|
||||
cls: The class to mark.
|
||||
|
||||
Returns:
|
||||
A wrapped class marked as Pydantic output format.
|
||||
"""
|
||||
return OutputPydanticClass(cls)
|
||||
|
||||
|
||||
def callback(func):
|
||||
"""Marks a method as a crew callback."""
|
||||
func.is_callback = True
|
||||
return memoize(func)
|
||||
def tool(meth: Callable[P, R]) -> ToolMethod[P, R]:
|
||||
"""Marks a method as a crew tool.
|
||||
|
||||
Args:
|
||||
meth: The method to mark.
|
||||
|
||||
Returns:
|
||||
A wrapped method marked as a tool with memoization.
|
||||
"""
|
||||
return ToolMethod(memoize(meth))
|
||||
|
||||
|
||||
def cache_handler(func):
|
||||
"""Marks a method as a cache handler."""
|
||||
func.is_cache_handler = True
|
||||
return memoize(func)
|
||||
def callback(meth: Callable[P, R]) -> CallbackMethod[P, R]:
|
||||
"""Marks a method as a crew callback.
|
||||
|
||||
Args:
|
||||
meth: The method to mark.
|
||||
|
||||
Returns:
|
||||
A wrapped method marked as a callback with memoization.
|
||||
"""
|
||||
return CallbackMethod(memoize(meth))
|
||||
|
||||
|
||||
def crew(func) -> Callable[..., Crew]:
|
||||
"""Marks a method as the main crew execution point."""
|
||||
def cache_handler(meth: Callable[P, R]) -> CacheHandlerMethod[P, R]:
|
||||
"""Marks a method as a cache handler.
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(self, *args, **kwargs) -> Crew:
|
||||
instantiated_tasks = []
|
||||
instantiated_agents = []
|
||||
agent_roles = set()
|
||||
Args:
|
||||
meth: The method to mark.
|
||||
|
||||
Returns:
|
||||
A wrapped method marked as a cache handler with memoization.
|
||||
"""
|
||||
return CacheHandlerMethod(memoize(meth))
|
||||
|
||||
|
||||
def crew(
|
||||
meth: Callable[Concatenate[CrewInstance, P], Crew],
|
||||
) -> Callable[Concatenate[CrewInstance, P], Crew]:
|
||||
"""Marks a method as the main crew execution point.
|
||||
|
||||
Args:
|
||||
meth: The method to mark as crew execution point.
|
||||
|
||||
Returns:
|
||||
A wrapped method that instantiates tasks and agents before execution.
|
||||
"""
|
||||
|
||||
@wraps(meth)
|
||||
def wrapper(self: CrewInstance, *args: P.args, **kwargs: P.kwargs) -> Crew:
|
||||
"""Wrapper that sets up crew before calling the decorated method.
|
||||
|
||||
Args:
|
||||
self: The crew class instance.
|
||||
*args: Additional positional arguments.
|
||||
**kwargs: Keyword arguments to pass to the method.
|
||||
|
||||
Returns:
|
||||
The configured Crew instance with callbacks attached.
|
||||
"""
|
||||
instantiated_tasks: list[Task] = []
|
||||
instantiated_agents: list[Agent] = []
|
||||
agent_roles: set[str] = set()
|
||||
|
||||
# Use the preserved task and agent information
|
||||
tasks = self._original_tasks.items()
|
||||
agents = self._original_agents.items()
|
||||
tasks = self.__crew_metadata__["original_tasks"].items()
|
||||
agents = self.__crew_metadata__["original_agents"].items()
|
||||
|
||||
# Instantiate tasks in order
|
||||
for task_name, task_method in tasks:
|
||||
for _, task_method in tasks:
|
||||
task_instance = task_method(self)
|
||||
instantiated_tasks.append(task_instance)
|
||||
agent_instance = getattr(task_instance, "agent", None)
|
||||
@@ -100,7 +195,7 @@ def crew(func) -> Callable[..., Crew]:
|
||||
agent_roles.add(agent_instance.role)
|
||||
|
||||
# Instantiate agents not included by tasks
|
||||
for agent_name, agent_method in agents:
|
||||
for _, agent_method in agents:
|
||||
agent_instance = agent_method(self)
|
||||
if agent_instance.role not in agent_roles:
|
||||
instantiated_agents.append(agent_instance)
|
||||
@@ -109,19 +204,44 @@ def crew(func) -> Callable[..., Crew]:
|
||||
self.agents = instantiated_agents
|
||||
self.tasks = instantiated_tasks
|
||||
|
||||
crew = func(self, *args, **kwargs)
|
||||
crew_instance = meth(self, *args, **kwargs)
|
||||
|
||||
def callback_wrapper(callback, instance):
|
||||
def wrapper(*args, **kwargs):
|
||||
return callback(instance, *args, **kwargs)
|
||||
def callback_wrapper(
|
||||
hook: Callable[Concatenate[CrewInstance, P2], R2], instance: CrewInstance
|
||||
) -> Callable[P2, R2]:
|
||||
"""Bind a hook callback to an instance.
|
||||
|
||||
return wrapper
|
||||
Args:
|
||||
hook: The callback hook to bind.
|
||||
instance: The instance to bind to.
|
||||
|
||||
for _, callback in self._before_kickoff.items():
|
||||
crew.before_kickoff_callbacks.append(callback_wrapper(callback, self))
|
||||
for _, callback in self._after_kickoff.items():
|
||||
crew.after_kickoff_callbacks.append(callback_wrapper(callback, self))
|
||||
Returns:
|
||||
A bound callback function.
|
||||
"""
|
||||
|
||||
return crew
|
||||
def bound_callback(*cb_args: P2.args, **cb_kwargs: P2.kwargs) -> R2:
|
||||
"""Execute the bound callback.
|
||||
|
||||
Args:
|
||||
*cb_args: Positional arguments for the callback.
|
||||
**cb_kwargs: Keyword arguments for the callback.
|
||||
|
||||
Returns:
|
||||
The result of the callback execution.
|
||||
"""
|
||||
return hook(instance, *cb_args, **cb_kwargs)
|
||||
|
||||
return bound_callback
|
||||
|
||||
for hook_callback in self.__crew_metadata__["before_kickoff"].values():
|
||||
crew_instance.before_kickoff_callbacks.append(
|
||||
callback_wrapper(hook_callback, self)
|
||||
)
|
||||
for hook_callback in self.__crew_metadata__["after_kickoff"].values():
|
||||
crew_instance.after_kickoff_callbacks.append(
|
||||
callback_wrapper(hook_callback, self)
|
||||
)
|
||||
|
||||
return crew_instance
|
||||
|
||||
return memoize(wrapper)
|
||||
|
||||
@@ -1,298 +1,631 @@
|
||||
"""Base metaclass for creating crew classes with configuration and method management."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from typing import Any, TypeVar, cast
|
||||
from typing import TYPE_CHECKING, Any, Literal, TypedDict, TypeGuard, TypeVar, cast
|
||||
|
||||
import yaml
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from crewai.project.wrappers import CrewClass, CrewMetadata
|
||||
from crewai.tools import BaseTool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai import Agent, Task
|
||||
from crewai.agents.cache.cache_handler import CacheHandler
|
||||
from crewai.crews.crew_output import CrewOutput
|
||||
from crewai.project.wrappers import (
|
||||
CrewInstance,
|
||||
OutputJsonClass,
|
||||
OutputPydanticClass,
|
||||
)
|
||||
from crewai.tasks.task_output import TaskOutput
|
||||
|
||||
|
||||
class AgentConfig(TypedDict, total=False):
|
||||
"""Type definition for agent configuration dictionary.
|
||||
|
||||
All fields are optional as they come from YAML configuration files.
|
||||
Fields can be either string references (from YAML) or actual instances (after processing).
|
||||
"""
|
||||
|
||||
# Core agent attributes (from BaseAgent)
|
||||
role: str
|
||||
goal: str
|
||||
backstory: str
|
||||
cache: bool
|
||||
verbose: bool
|
||||
max_rpm: int
|
||||
allow_delegation: bool
|
||||
max_iter: int
|
||||
max_tokens: int
|
||||
callbacks: list[str]
|
||||
|
||||
# LLM configuration
|
||||
llm: str
|
||||
function_calling_llm: str
|
||||
use_system_prompt: bool
|
||||
|
||||
# Template configuration
|
||||
system_template: str
|
||||
prompt_template: str
|
||||
response_template: str
|
||||
|
||||
# Tools and handlers (can be string references or instances)
|
||||
tools: list[str] | list[BaseTool]
|
||||
step_callback: str
|
||||
cache_handler: str | CacheHandler
|
||||
|
||||
# Code execution
|
||||
allow_code_execution: bool
|
||||
code_execution_mode: Literal["safe", "unsafe"]
|
||||
|
||||
# Context and performance
|
||||
respect_context_window: bool
|
||||
max_retry_limit: int
|
||||
|
||||
# Multimodal and reasoning
|
||||
multimodal: bool
|
||||
reasoning: bool
|
||||
max_reasoning_attempts: int
|
||||
|
||||
# Knowledge configuration
|
||||
knowledge_sources: list[str] | list[Any]
|
||||
knowledge_storage: str | Any
|
||||
knowledge_config: dict[str, Any]
|
||||
embedder: dict[str, Any]
|
||||
agent_knowledge_context: str
|
||||
crew_knowledge_context: str
|
||||
knowledge_search_query: str
|
||||
|
||||
# Misc configuration
|
||||
inject_date: bool
|
||||
date_format: str
|
||||
from_repository: str
|
||||
guardrail: Callable[[Any], tuple[bool, Any]] | str
|
||||
guardrail_max_retries: int
|
||||
|
||||
|
||||
class TaskConfig(TypedDict, total=False):
|
||||
"""Type definition for task configuration dictionary.
|
||||
|
||||
All fields are optional as they come from YAML configuration files.
|
||||
Fields can be either string references (from YAML) or actual instances (after processing).
|
||||
"""
|
||||
|
||||
# Core task attributes
|
||||
name: str
|
||||
description: str
|
||||
expected_output: str
|
||||
|
||||
# Agent and context
|
||||
agent: str
|
||||
context: list[str]
|
||||
|
||||
# Tools and callbacks (can be string references or instances)
|
||||
tools: list[str] | list[BaseTool]
|
||||
callback: str
|
||||
callbacks: list[str]
|
||||
|
||||
# Output configuration
|
||||
output_json: str
|
||||
output_pydantic: str
|
||||
output_file: str
|
||||
create_directory: bool
|
||||
|
||||
# Execution configuration
|
||||
async_execution: bool
|
||||
human_input: bool
|
||||
markdown: bool
|
||||
|
||||
# Guardrail configuration
|
||||
guardrail: Callable[[TaskOutput], tuple[bool, Any]] | str
|
||||
guardrail_max_retries: int
|
||||
|
||||
# Misc configuration
|
||||
allow_crewai_trigger_context: bool
|
||||
|
||||
|
||||
load_dotenv()
|
||||
|
||||
T = TypeVar("T", bound=type)
|
||||
|
||||
"""Base decorator for creating crew classes with configuration and function management."""
|
||||
CallableT = TypeVar("CallableT", bound=Callable[..., Any])
|
||||
|
||||
|
||||
def CrewBase(cls: T) -> T: # noqa: N802
|
||||
"""Wraps a class with crew functionality and configuration management."""
|
||||
def _set_base_directory(cls: type[CrewClass]) -> None:
|
||||
"""Set the base directory for the crew class.
|
||||
|
||||
class WrappedClass(cls): # type: ignore
|
||||
is_crew_class: bool = True # type: ignore
|
||||
Args:
|
||||
cls: Crew class to configure.
|
||||
"""
|
||||
try:
|
||||
cls.base_directory = Path(inspect.getfile(cls)).parent
|
||||
except (TypeError, OSError):
|
||||
cls.base_directory = Path.cwd()
|
||||
|
||||
# Get the directory of the class being decorated
|
||||
base_directory = Path(inspect.getfile(cls)).parent
|
||||
|
||||
original_agents_config_path = getattr(
|
||||
cls, "agents_config", "config/agents.yaml"
|
||||
def _set_config_paths(cls: type[CrewClass]) -> None:
|
||||
"""Set the configuration file paths for the crew class.
|
||||
|
||||
Args:
|
||||
cls: Crew class to configure.
|
||||
"""
|
||||
cls.original_agents_config_path = getattr(
|
||||
cls, "agents_config", "config/agents.yaml"
|
||||
)
|
||||
cls.original_tasks_config_path = getattr(cls, "tasks_config", "config/tasks.yaml")
|
||||
|
||||
|
||||
def _set_mcp_params(cls: type[CrewClass]) -> None:
|
||||
"""Set the MCP server parameters for the crew class.
|
||||
|
||||
Args:
|
||||
cls: Crew class to configure.
|
||||
"""
|
||||
cls.mcp_server_params = getattr(cls, "mcp_server_params", None)
|
||||
cls.mcp_connect_timeout = getattr(cls, "mcp_connect_timeout", 30)
|
||||
|
||||
|
||||
def _is_string_list(value: list[str] | list[BaseTool]) -> TypeGuard[list[str]]:
|
||||
"""Type guard to check if list contains strings rather than BaseTool instances.
|
||||
|
||||
Args:
|
||||
value: List that may contain strings or BaseTool instances.
|
||||
|
||||
Returns:
|
||||
True if all elements are strings, False otherwise.
|
||||
"""
|
||||
return all(isinstance(item, str) for item in value)
|
||||
|
||||
|
||||
def _is_string_value(value: str | CacheHandler) -> TypeGuard[str]:
|
||||
"""Type guard to check if value is a string rather than a CacheHandler instance.
|
||||
|
||||
Args:
|
||||
value: Value that may be a string or CacheHandler instance.
|
||||
|
||||
Returns:
|
||||
True if value is a string, False otherwise.
|
||||
"""
|
||||
return isinstance(value, str)
|
||||
|
||||
|
||||
class CrewBaseMeta(type):
|
||||
"""Metaclass that adds crew functionality to classes."""
|
||||
|
||||
def __new__(
|
||||
mcs,
|
||||
name: str,
|
||||
bases: tuple[type, ...],
|
||||
namespace: dict[str, Any],
|
||||
**kwargs: Any,
|
||||
) -> type[CrewClass]:
|
||||
"""Create crew class with configuration and method injection.
|
||||
|
||||
Args:
|
||||
name: Class name.
|
||||
bases: Base classes.
|
||||
namespace: Class namespace dictionary.
|
||||
**kwargs: Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
New crew class with injected methods and attributes.
|
||||
"""
|
||||
cls = cast(
|
||||
type[CrewClass], cast(object, super().__new__(mcs, name, bases, namespace))
|
||||
)
|
||||
original_tasks_config_path = getattr(cls, "tasks_config", "config/tasks.yaml")
|
||||
|
||||
mcp_server_params: Any = getattr(cls, "mcp_server_params", None)
|
||||
mcp_connect_timeout: int = getattr(cls, "mcp_connect_timeout", 30)
|
||||
cls.is_crew_class = True
|
||||
cls._crew_name = name
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.load_configurations()
|
||||
self.map_all_agent_variables()
|
||||
self.map_all_task_variables()
|
||||
# Preserve all decorated functions
|
||||
self._original_functions = {
|
||||
name: method
|
||||
for name, method in cls.__dict__.items()
|
||||
if any(
|
||||
hasattr(method, attr)
|
||||
for attr in [
|
||||
"is_task",
|
||||
"is_agent",
|
||||
"is_before_kickoff",
|
||||
"is_after_kickoff",
|
||||
"is_kickoff",
|
||||
]
|
||||
)
|
||||
}
|
||||
# Store specific function types
|
||||
self._original_tasks = self._filter_functions(
|
||||
self._original_functions, "is_task"
|
||||
)
|
||||
self._original_agents = self._filter_functions(
|
||||
self._original_functions, "is_agent"
|
||||
)
|
||||
self._before_kickoff = self._filter_functions(
|
||||
self._original_functions, "is_before_kickoff"
|
||||
)
|
||||
self._after_kickoff = self._filter_functions(
|
||||
self._original_functions, "is_after_kickoff"
|
||||
)
|
||||
self._kickoff = self._filter_functions(
|
||||
self._original_functions, "is_kickoff"
|
||||
)
|
||||
for setup_fn in _CLASS_SETUP_FUNCTIONS:
|
||||
setup_fn(cls)
|
||||
|
||||
# Add close mcp server method to after kickoff
|
||||
bound_method = self._create_close_mcp_server_method()
|
||||
self._after_kickoff['_close_mcp_server'] = bound_method
|
||||
for method in _METHODS_TO_INJECT:
|
||||
setattr(cls, method.__name__, method)
|
||||
|
||||
def _create_close_mcp_server_method(self):
|
||||
def _close_mcp_server(self, instance, outputs):
|
||||
adapter = getattr(self, '_mcp_server_adapter', None)
|
||||
if adapter is not None:
|
||||
try:
|
||||
adapter.stop()
|
||||
except Exception as e:
|
||||
logging.warning(f"Error stopping MCP server: {e}")
|
||||
return outputs
|
||||
return cls
|
||||
|
||||
_close_mcp_server.is_after_kickoff = True
|
||||
def __call__(cls, *args: Any, **kwargs: Any) -> CrewInstance:
|
||||
"""Intercept instance creation to initialize crew functionality.
|
||||
|
||||
import types
|
||||
return types.MethodType(_close_mcp_server, self)
|
||||
Args:
|
||||
*args: Positional arguments for instance creation.
|
||||
**kwargs: Keyword arguments for instance creation.
|
||||
|
||||
def get_mcp_tools(self, *tool_names: list[str]) -> list[BaseTool]:
|
||||
if not self.mcp_server_params:
|
||||
return []
|
||||
Returns:
|
||||
Initialized crew instance.
|
||||
"""
|
||||
instance: CrewInstance = super().__call__(*args, **kwargs)
|
||||
CrewBaseMeta._initialize_crew_instance(instance, cls)
|
||||
return instance
|
||||
|
||||
from crewai_tools import MCPServerAdapter # type: ignore[import-untyped]
|
||||
@staticmethod
|
||||
def _initialize_crew_instance(instance: CrewInstance, cls: type) -> None:
|
||||
"""Initialize crew instance attributes and load configurations.
|
||||
|
||||
adapter = getattr(self, '_mcp_server_adapter', None)
|
||||
if not adapter:
|
||||
self._mcp_server_adapter = MCPServerAdapter(
|
||||
self.mcp_server_params,
|
||||
connect_timeout=self.mcp_connect_timeout
|
||||
)
|
||||
Args:
|
||||
instance: Crew instance to initialize.
|
||||
cls: Crew class type.
|
||||
"""
|
||||
instance._mcp_server_adapter = None
|
||||
instance.load_configurations()
|
||||
instance._all_methods = _get_all_methods(instance)
|
||||
instance.map_all_agent_variables()
|
||||
instance.map_all_task_variables()
|
||||
|
||||
return self._mcp_server_adapter.tools.filter_by_names(tool_names or None)
|
||||
|
||||
|
||||
def load_configurations(self):
|
||||
"""Load agent and task configurations from YAML files."""
|
||||
if isinstance(self.original_agents_config_path, str):
|
||||
agents_config_path = (
|
||||
self.base_directory / self.original_agents_config_path
|
||||
)
|
||||
try:
|
||||
self.agents_config = self.load_yaml(agents_config_path)
|
||||
except FileNotFoundError:
|
||||
logging.warning(
|
||||
f"Agent config file not found at {agents_config_path}. "
|
||||
"Proceeding with empty agent configurations."
|
||||
)
|
||||
self.agents_config = {}
|
||||
else:
|
||||
logging.warning(
|
||||
"No agent configuration path provided. Proceeding with empty agent configurations."
|
||||
)
|
||||
self.agents_config = {}
|
||||
|
||||
if isinstance(self.original_tasks_config_path, str):
|
||||
tasks_config_path = (
|
||||
self.base_directory / self.original_tasks_config_path
|
||||
)
|
||||
try:
|
||||
self.tasks_config = self.load_yaml(tasks_config_path)
|
||||
except FileNotFoundError:
|
||||
logging.warning(
|
||||
f"Task config file not found at {tasks_config_path}. "
|
||||
"Proceeding with empty task configurations."
|
||||
)
|
||||
self.tasks_config = {}
|
||||
else:
|
||||
logging.warning(
|
||||
"No task configuration path provided. Proceeding with empty task configurations."
|
||||
)
|
||||
self.tasks_config = {}
|
||||
|
||||
@staticmethod
|
||||
def load_yaml(config_path: Path):
|
||||
try:
|
||||
with open(config_path, "r", encoding="utf-8") as file:
|
||||
return yaml.safe_load(file)
|
||||
except FileNotFoundError:
|
||||
print(f"File not found: {config_path}")
|
||||
raise
|
||||
|
||||
def _get_all_functions(self):
|
||||
return {
|
||||
name: getattr(self, name)
|
||||
for name in dir(self)
|
||||
if callable(getattr(self, name))
|
||||
}
|
||||
|
||||
def _filter_functions(
|
||||
self, functions: dict[str, Callable], attribute: str
|
||||
) -> dict[str, Callable]:
|
||||
return {
|
||||
name: func
|
||||
for name, func in functions.items()
|
||||
if hasattr(func, attribute)
|
||||
}
|
||||
|
||||
def map_all_agent_variables(self) -> None:
|
||||
all_functions = self._get_all_functions()
|
||||
llms = self._filter_functions(all_functions, "is_llm")
|
||||
tool_functions = self._filter_functions(all_functions, "is_tool")
|
||||
cache_handler_functions = self._filter_functions(
|
||||
all_functions, "is_cache_handler"
|
||||
)
|
||||
callbacks = self._filter_functions(all_functions, "is_callback")
|
||||
|
||||
for agent_name, agent_info in self.agents_config.items():
|
||||
self._map_agent_variables(
|
||||
agent_name,
|
||||
agent_info,
|
||||
llms,
|
||||
tool_functions,
|
||||
cache_handler_functions,
|
||||
callbacks,
|
||||
)
|
||||
|
||||
def _map_agent_variables(
|
||||
self,
|
||||
agent_name: str,
|
||||
agent_info: dict[str, Any],
|
||||
llms: dict[str, Callable],
|
||||
tool_functions: dict[str, Callable],
|
||||
cache_handler_functions: dict[str, Callable],
|
||||
callbacks: dict[str, Callable],
|
||||
) -> None:
|
||||
if llm := agent_info.get("llm"):
|
||||
try:
|
||||
self.agents_config[agent_name]["llm"] = llms[llm]()
|
||||
except KeyError:
|
||||
self.agents_config[agent_name]["llm"] = llm
|
||||
|
||||
if tools := agent_info.get("tools"):
|
||||
self.agents_config[agent_name]["tools"] = [
|
||||
tool_functions[tool]() for tool in tools
|
||||
original_methods = {
|
||||
name: method
|
||||
for name, method in cls.__dict__.items()
|
||||
if any(
|
||||
hasattr(method, attr)
|
||||
for attr in [
|
||||
"is_task",
|
||||
"is_agent",
|
||||
"is_before_kickoff",
|
||||
"is_after_kickoff",
|
||||
"is_kickoff",
|
||||
]
|
||||
|
||||
if function_calling_llm := agent_info.get("function_calling_llm"):
|
||||
try:
|
||||
self.agents_config[agent_name]["function_calling_llm"] = llms[function_calling_llm]()
|
||||
except KeyError:
|
||||
self.agents_config[agent_name]["function_calling_llm"] = function_calling_llm
|
||||
|
||||
if step_callback := agent_info.get("step_callback"):
|
||||
self.agents_config[agent_name]["step_callback"] = callbacks[
|
||||
step_callback
|
||||
]()
|
||||
|
||||
if cache_handler := agent_info.get("cache_handler"):
|
||||
self.agents_config[agent_name]["cache_handler"] = (
|
||||
cache_handler_functions[cache_handler]()
|
||||
)
|
||||
|
||||
def map_all_task_variables(self) -> None:
|
||||
all_functions = self._get_all_functions()
|
||||
agents = self._filter_functions(all_functions, "is_agent")
|
||||
tasks = self._filter_functions(all_functions, "is_task")
|
||||
output_json_functions = self._filter_functions(
|
||||
all_functions, "is_output_json"
|
||||
)
|
||||
tool_functions = self._filter_functions(all_functions, "is_tool")
|
||||
callback_functions = self._filter_functions(all_functions, "is_callback")
|
||||
output_pydantic_functions = self._filter_functions(
|
||||
all_functions, "is_output_pydantic"
|
||||
}
|
||||
|
||||
after_kickoff_callbacks = _filter_methods(original_methods, "is_after_kickoff")
|
||||
after_kickoff_callbacks["close_mcp_server"] = instance.close_mcp_server
|
||||
|
||||
instance.__crew_metadata__ = CrewMetadata(
|
||||
original_methods=original_methods,
|
||||
original_tasks=_filter_methods(original_methods, "is_task"),
|
||||
original_agents=_filter_methods(original_methods, "is_agent"),
|
||||
before_kickoff=_filter_methods(original_methods, "is_before_kickoff"),
|
||||
after_kickoff=after_kickoff_callbacks,
|
||||
kickoff=_filter_methods(original_methods, "is_kickoff"),
|
||||
)
|
||||
|
||||
|
||||
def close_mcp_server(
|
||||
self: CrewInstance, _instance: CrewInstance, outputs: CrewOutput
|
||||
) -> CrewOutput:
|
||||
"""Stop MCP server adapter and return outputs.
|
||||
|
||||
Args:
|
||||
self: Crew instance with MCP server adapter.
|
||||
_instance: Crew instance (unused, required by callback signature).
|
||||
outputs: Crew execution outputs.
|
||||
|
||||
Returns:
|
||||
Unmodified crew outputs.
|
||||
"""
|
||||
if self._mcp_server_adapter is not None:
|
||||
try:
|
||||
self._mcp_server_adapter.stop()
|
||||
except Exception as e:
|
||||
logging.warning(f"Error stopping MCP server: {e}")
|
||||
return outputs
|
||||
|
||||
|
||||
def get_mcp_tools(self: CrewInstance, *tool_names: str) -> list[BaseTool]:
|
||||
"""Get MCP tools filtered by name.
|
||||
|
||||
Args:
|
||||
self: Crew instance with MCP server configuration.
|
||||
*tool_names: Optional tool names to filter by.
|
||||
|
||||
Returns:
|
||||
List of filtered MCP tools, or empty list if no MCP server configured.
|
||||
"""
|
||||
if not self.mcp_server_params:
|
||||
return []
|
||||
|
||||
from crewai_tools import MCPServerAdapter # type: ignore[import-untyped]
|
||||
|
||||
if self._mcp_server_adapter is None:
|
||||
self._mcp_server_adapter = MCPServerAdapter(
|
||||
self.mcp_server_params, connect_timeout=self.mcp_connect_timeout
|
||||
)
|
||||
|
||||
return self._mcp_server_adapter.tools.filter_by_names(tool_names or None)
|
||||
|
||||
|
||||
def _load_config(
|
||||
self: CrewInstance, config_path: str | None, config_type: Literal["agent", "task"]
|
||||
) -> dict[str, Any]:
|
||||
"""Load YAML config file or return empty dict if not found.
|
||||
|
||||
Args:
|
||||
self: Crew instance with base directory and load_yaml method.
|
||||
config_path: Relative path to config file.
|
||||
config_type: Config type for logging, either "agent" or "task".
|
||||
|
||||
Returns:
|
||||
Config dictionary or empty dict.
|
||||
"""
|
||||
if isinstance(config_path, str):
|
||||
full_path = self.base_directory / config_path
|
||||
try:
|
||||
return self.load_yaml(full_path)
|
||||
except FileNotFoundError:
|
||||
logging.warning(
|
||||
f"{config_type.capitalize()} config file not found at {full_path}. "
|
||||
f"Proceeding with empty {config_type} configurations."
|
||||
)
|
||||
return {}
|
||||
else:
|
||||
logging.warning(
|
||||
f"No {config_type} configuration path provided. "
|
||||
f"Proceeding with empty {config_type} configurations."
|
||||
)
|
||||
return {}
|
||||
|
||||
for task_name, task_info in self.tasks_config.items():
|
||||
self._map_task_variables(
|
||||
task_name,
|
||||
task_info,
|
||||
agents,
|
||||
tasks,
|
||||
output_json_functions,
|
||||
tool_functions,
|
||||
callback_functions,
|
||||
output_pydantic_functions,
|
||||
)
|
||||
|
||||
def _map_task_variables(
|
||||
self,
|
||||
task_name: str,
|
||||
task_info: dict[str, Any],
|
||||
agents: dict[str, Callable],
|
||||
tasks: dict[str, Callable],
|
||||
output_json_functions: dict[str, Callable],
|
||||
tool_functions: dict[str, Callable],
|
||||
callback_functions: dict[str, Callable],
|
||||
output_pydantic_functions: dict[str, Callable],
|
||||
) -> None:
|
||||
if context_list := task_info.get("context"):
|
||||
self.tasks_config[task_name]["context"] = [
|
||||
tasks[context_task_name]() for context_task_name in context_list
|
||||
]
|
||||
def load_configurations(self: CrewInstance) -> None:
|
||||
"""Load agent and task YAML configurations.
|
||||
|
||||
if tools := task_info.get("tools"):
|
||||
self.tasks_config[task_name]["tools"] = [
|
||||
tool_functions[tool]() for tool in tools
|
||||
]
|
||||
Args:
|
||||
self: Crew instance with configuration paths.
|
||||
"""
|
||||
self.agents_config = self._load_config(self.original_agents_config_path, "agent")
|
||||
self.tasks_config = self._load_config(self.original_tasks_config_path, "task")
|
||||
|
||||
if agent_name := task_info.get("agent"):
|
||||
self.tasks_config[task_name]["agent"] = agents[agent_name]()
|
||||
|
||||
if output_json := task_info.get("output_json"):
|
||||
self.tasks_config[task_name]["output_json"] = output_json_functions[
|
||||
output_json
|
||||
]
|
||||
def load_yaml(config_path: Path) -> dict[str, Any]:
|
||||
"""Load and parse YAML configuration file.
|
||||
|
||||
if output_pydantic := task_info.get("output_pydantic"):
|
||||
self.tasks_config[task_name]["output_pydantic"] = (
|
||||
output_pydantic_functions[output_pydantic]
|
||||
)
|
||||
Args:
|
||||
config_path: Path to YAML configuration file.
|
||||
|
||||
if callbacks := task_info.get("callbacks"):
|
||||
self.tasks_config[task_name]["callbacks"] = [
|
||||
callback_functions[callback]() for callback in callbacks
|
||||
]
|
||||
Returns:
|
||||
Parsed YAML content as a dictionary. Returns empty dict if file is empty.
|
||||
|
||||
if guardrail := task_info.get("guardrail"):
|
||||
self.tasks_config[task_name]["guardrail"] = guardrail
|
||||
Raises:
|
||||
FileNotFoundError: If config file does not exist.
|
||||
"""
|
||||
try:
|
||||
with open(config_path, encoding="utf-8") as file:
|
||||
content = yaml.safe_load(file)
|
||||
return content if isinstance(content, dict) else {}
|
||||
except FileNotFoundError:
|
||||
logging.warning(f"File not found: {config_path}")
|
||||
raise
|
||||
|
||||
# Include base class (qual)name in the wrapper class (qual)name.
|
||||
WrappedClass.__name__ = CrewBase.__name__ + "(" + cls.__name__ + ")"
|
||||
WrappedClass.__qualname__ = CrewBase.__qualname__ + "(" + cls.__name__ + ")"
|
||||
WrappedClass._crew_name = cls.__name__
|
||||
|
||||
return cast(T, WrappedClass)
|
||||
def _get_all_methods(self: CrewInstance) -> dict[str, Callable[..., Any]]:
|
||||
"""Return all non-dunder callable attributes (methods).
|
||||
|
||||
Args:
|
||||
self: Instance to inspect for callable attributes.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping method names to bound method objects.
|
||||
"""
|
||||
return {
|
||||
name: getattr(self, name)
|
||||
for name in dir(self)
|
||||
if not (name.startswith("__") and name.endswith("__"))
|
||||
and callable(getattr(self, name, None))
|
||||
}
|
||||
|
||||
|
||||
def _filter_methods(
|
||||
methods: dict[str, CallableT], attribute: str
|
||||
) -> dict[str, CallableT]:
|
||||
"""Filter methods by attribute presence, preserving exact callable types.
|
||||
|
||||
Args:
|
||||
methods: Dictionary of methods to filter.
|
||||
attribute: Attribute name to check for.
|
||||
|
||||
Returns:
|
||||
Dictionary containing only methods with the specified attribute.
|
||||
The return type matches the input callable type exactly.
|
||||
"""
|
||||
return {
|
||||
name: method for name, method in methods.items() if hasattr(method, attribute)
|
||||
}
|
||||
|
||||
|
||||
def map_all_agent_variables(self: CrewInstance) -> None:
|
||||
"""Map agent configuration variables to callable instances.
|
||||
|
||||
Args:
|
||||
self: Crew instance with agent configurations to map.
|
||||
"""
|
||||
llms = _filter_methods(self._all_methods, "is_llm")
|
||||
tool_functions = _filter_methods(self._all_methods, "is_tool")
|
||||
cache_handler_functions = _filter_methods(self._all_methods, "is_cache_handler")
|
||||
callbacks = _filter_methods(self._all_methods, "is_callback")
|
||||
|
||||
for agent_name, agent_info in self.agents_config.items():
|
||||
self._map_agent_variables(
|
||||
agent_name=agent_name,
|
||||
agent_info=agent_info,
|
||||
llms=llms,
|
||||
tool_functions=tool_functions,
|
||||
cache_handler_functions=cache_handler_functions,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
|
||||
|
||||
def _map_agent_variables(
|
||||
self: CrewInstance,
|
||||
agent_name: str,
|
||||
agent_info: AgentConfig,
|
||||
llms: dict[str, Callable[[], Any]],
|
||||
tool_functions: dict[str, Callable[[], BaseTool]],
|
||||
cache_handler_functions: dict[str, Callable[[], Any]],
|
||||
callbacks: dict[str, Callable[..., Any]],
|
||||
) -> None:
|
||||
"""Resolve and map variables for a single agent.
|
||||
|
||||
Args:
|
||||
self: Crew instance with agent configurations.
|
||||
agent_name: Name of agent to configure.
|
||||
agent_info: Agent configuration dictionary with optional fields.
|
||||
llms: Dictionary mapping names to LLM factory functions.
|
||||
tool_functions: Dictionary mapping names to tool factory functions.
|
||||
cache_handler_functions: Dictionary mapping names to cache handler factory functions.
|
||||
callbacks: Dictionary of available callbacks.
|
||||
"""
|
||||
if llm := agent_info.get("llm"):
|
||||
factory = llms.get(llm)
|
||||
self.agents_config[agent_name]["llm"] = factory() if factory else llm
|
||||
|
||||
if tools := agent_info.get("tools"):
|
||||
if _is_string_list(tools):
|
||||
self.agents_config[agent_name]["tools"] = [
|
||||
tool_functions[tool]() for tool in tools
|
||||
]
|
||||
|
||||
if function_calling_llm := agent_info.get("function_calling_llm"):
|
||||
factory = llms.get(function_calling_llm)
|
||||
self.agents_config[agent_name]["function_calling_llm"] = (
|
||||
factory() if factory else function_calling_llm
|
||||
)
|
||||
|
||||
if step_callback := agent_info.get("step_callback"):
|
||||
self.agents_config[agent_name]["step_callback"] = callbacks[step_callback]()
|
||||
|
||||
if cache_handler := agent_info.get("cache_handler"):
|
||||
if _is_string_value(cache_handler):
|
||||
self.agents_config[agent_name]["cache_handler"] = cache_handler_functions[
|
||||
cache_handler
|
||||
]()
|
||||
|
||||
|
||||
def map_all_task_variables(self: CrewInstance) -> None:
|
||||
"""Map task configuration variables to callable instances.
|
||||
|
||||
Args:
|
||||
self: Crew instance with task configurations to map.
|
||||
"""
|
||||
agents = _filter_methods(self._all_methods, "is_agent")
|
||||
tasks = _filter_methods(self._all_methods, "is_task")
|
||||
output_json_functions = _filter_methods(self._all_methods, "is_output_json")
|
||||
tool_functions = _filter_methods(self._all_methods, "is_tool")
|
||||
callback_functions = _filter_methods(self._all_methods, "is_callback")
|
||||
output_pydantic_functions = _filter_methods(self._all_methods, "is_output_pydantic")
|
||||
|
||||
for task_name, task_info in self.tasks_config.items():
|
||||
self._map_task_variables(
|
||||
task_name=task_name,
|
||||
task_info=task_info,
|
||||
agents=agents,
|
||||
tasks=tasks,
|
||||
output_json_functions=output_json_functions,
|
||||
tool_functions=tool_functions,
|
||||
callback_functions=callback_functions,
|
||||
output_pydantic_functions=output_pydantic_functions,
|
||||
)
|
||||
|
||||
|
||||
def _map_task_variables(
|
||||
self: CrewInstance,
|
||||
task_name: str,
|
||||
task_info: TaskConfig,
|
||||
agents: dict[str, Callable[[], Agent]],
|
||||
tasks: dict[str, Callable[[], Task]],
|
||||
output_json_functions: dict[str, OutputJsonClass[Any]],
|
||||
tool_functions: dict[str, Callable[[], BaseTool]],
|
||||
callback_functions: dict[str, Callable[..., Any]],
|
||||
output_pydantic_functions: dict[str, OutputPydanticClass[Any]],
|
||||
) -> None:
|
||||
"""Resolve and map variables for a single task.
|
||||
|
||||
Args:
|
||||
self: Crew instance with task configurations.
|
||||
task_name: Name of task to configure.
|
||||
task_info: Task configuration dictionary with optional fields.
|
||||
agents: Dictionary mapping names to agent factory functions.
|
||||
tasks: Dictionary mapping names to task factory functions.
|
||||
output_json_functions: Dictionary of JSON output class wrappers.
|
||||
tool_functions: Dictionary mapping names to tool factory functions.
|
||||
callback_functions: Dictionary of available callbacks.
|
||||
output_pydantic_functions: Dictionary of Pydantic output class wrappers.
|
||||
"""
|
||||
if context_list := task_info.get("context"):
|
||||
self.tasks_config[task_name]["context"] = [
|
||||
tasks[context_task_name]() for context_task_name in context_list
|
||||
]
|
||||
|
||||
if tools := task_info.get("tools"):
|
||||
if _is_string_list(tools):
|
||||
self.tasks_config[task_name]["tools"] = [
|
||||
tool_functions[tool]() for tool in tools
|
||||
]
|
||||
|
||||
if agent_name := task_info.get("agent"):
|
||||
self.tasks_config[task_name]["agent"] = agents[agent_name]()
|
||||
|
||||
if output_json := task_info.get("output_json"):
|
||||
self.tasks_config[task_name]["output_json"] = output_json_functions[output_json]
|
||||
|
||||
if output_pydantic := task_info.get("output_pydantic"):
|
||||
self.tasks_config[task_name]["output_pydantic"] = output_pydantic_functions[
|
||||
output_pydantic
|
||||
]
|
||||
|
||||
if callbacks := task_info.get("callbacks"):
|
||||
self.tasks_config[task_name]["callbacks"] = [
|
||||
callback_functions[callback]() for callback in callbacks
|
||||
]
|
||||
|
||||
if guardrail := task_info.get("guardrail"):
|
||||
self.tasks_config[task_name]["guardrail"] = guardrail
|
||||
|
||||
|
||||
_CLASS_SETUP_FUNCTIONS: tuple[Callable[[type[CrewClass]], None], ...] = (
|
||||
_set_base_directory,
|
||||
_set_config_paths,
|
||||
_set_mcp_params,
|
||||
)
|
||||
|
||||
_METHODS_TO_INJECT = (
|
||||
close_mcp_server,
|
||||
get_mcp_tools,
|
||||
_load_config,
|
||||
load_configurations,
|
||||
staticmethod(load_yaml),
|
||||
map_all_agent_variables,
|
||||
_map_agent_variables,
|
||||
map_all_task_variables,
|
||||
_map_task_variables,
|
||||
)
|
||||
|
||||
|
||||
class _CrewBaseType(type):
|
||||
"""Metaclass for CrewBase that makes it callable as a decorator."""
|
||||
|
||||
def __call__(cls, decorated_cls: type) -> type[CrewClass]:
|
||||
"""Apply CrewBaseMeta to the decorated class.
|
||||
|
||||
Args:
|
||||
decorated_cls: Class to transform with CrewBaseMeta metaclass.
|
||||
|
||||
Returns:
|
||||
New class with CrewBaseMeta metaclass applied.
|
||||
"""
|
||||
__name = str(decorated_cls.__name__)
|
||||
__bases = tuple(decorated_cls.__bases__)
|
||||
__dict = {
|
||||
key: value
|
||||
for key, value in decorated_cls.__dict__.items()
|
||||
if key not in ("__dict__", "__weakref__")
|
||||
}
|
||||
for slot in __dict.get("__slots__", tuple()):
|
||||
__dict.pop(slot, None)
|
||||
__dict["__metaclass__"] = CrewBaseMeta
|
||||
return cast(type[CrewClass], CrewBaseMeta(__name, __bases, __dict))
|
||||
|
||||
|
||||
class CrewBase(metaclass=_CrewBaseType):
|
||||
"""Class decorator that applies CrewBaseMeta metaclass.
|
||||
|
||||
Applies CrewBaseMeta metaclass to a class via decorator syntax rather than
|
||||
explicit metaclass declaration. Use as @CrewBase instead of
|
||||
class Foo(metaclass=CrewBaseMeta).
|
||||
|
||||
Note:
|
||||
Reference: https://stackoverflow.com/questions/11091609/setting-a-class-metaclass-using-a-decorator
|
||||
"""
|
||||
|
||||
@@ -1,14 +1,38 @@
|
||||
"""Utility functions for the crewai project module."""
|
||||
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import Any, ParamSpec, TypeVar
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
def memoize(func):
|
||||
cache = {}
|
||||
def memoize(meth: Callable[P, R]) -> Callable[P, R]:
|
||||
"""Memoize a method by caching its results based on arguments.
|
||||
|
||||
@wraps(func)
|
||||
def memoized_func(*args, **kwargs):
|
||||
Args:
|
||||
meth: The method to memoize.
|
||||
|
||||
Returns:
|
||||
A memoized version of the method that caches results.
|
||||
"""
|
||||
cache: dict[Any, R] = {}
|
||||
|
||||
@wraps(meth)
|
||||
def memoized_func(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
"""Memoized wrapper method.
|
||||
|
||||
Args:
|
||||
*args: Positional arguments to pass to the method.
|
||||
**kwargs: Keyword arguments to pass to the method.
|
||||
|
||||
Returns:
|
||||
The cached or computed result of the method.
|
||||
"""
|
||||
key = (args, tuple(kwargs.items()))
|
||||
if key not in cache:
|
||||
cache[key] = func(*args, **kwargs)
|
||||
cache[key] = meth(*args, **kwargs)
|
||||
return cache[key]
|
||||
|
||||
return memoized_func
|
||||
|
||||
388
src/crewai/project/wrappers.py
Normal file
388
src/crewai/project/wrappers.py
Normal file
@@ -0,0 +1,388 @@
|
||||
"""Wrapper classes for decorated methods with type-safe metadata."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Generic,
|
||||
Literal,
|
||||
ParamSpec,
|
||||
Protocol,
|
||||
TypedDict,
|
||||
TypeVar,
|
||||
)
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai import Agent, Task
|
||||
from crewai.crews.crew_output import CrewOutput
|
||||
from crewai.tools import BaseTool
|
||||
|
||||
|
||||
class CrewMetadata(TypedDict):
|
||||
"""Type definition for crew metadata dictionary.
|
||||
|
||||
Stores framework-injected metadata about decorated methods and callbacks.
|
||||
"""
|
||||
|
||||
original_methods: dict[str, Callable[..., Any]]
|
||||
original_tasks: dict[str, Callable[..., Task]]
|
||||
original_agents: dict[str, Callable[..., Agent]]
|
||||
before_kickoff: dict[str, Callable[..., Any]]
|
||||
after_kickoff: dict[str, Callable[..., Any]]
|
||||
kickoff: dict[str, Callable[..., Any]]
|
||||
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class TaskResult(Protocol):
|
||||
"""Protocol for task objects that have a name attribute."""
|
||||
|
||||
name: str | None
|
||||
|
||||
|
||||
TaskResultT = TypeVar("TaskResultT", bound=TaskResult)
|
||||
|
||||
|
||||
def _copy_method_metadata(wrapper: Any, meth: Callable[..., Any]) -> None:
|
||||
"""Copy method metadata to a wrapper object.
|
||||
|
||||
Args:
|
||||
wrapper: The wrapper object to update.
|
||||
meth: The method to copy metadata from.
|
||||
"""
|
||||
wrapper.__name__ = meth.__name__
|
||||
wrapper.__doc__ = meth.__doc__
|
||||
|
||||
|
||||
class CrewInstance(Protocol):
|
||||
"""Protocol for crew class instances with required attributes."""
|
||||
|
||||
__crew_metadata__: CrewMetadata
|
||||
_mcp_server_adapter: Any
|
||||
_all_methods: dict[str, Callable[..., Any]]
|
||||
agents: list[Agent]
|
||||
tasks: list[Task]
|
||||
base_directory: Path
|
||||
original_agents_config_path: str
|
||||
original_tasks_config_path: str
|
||||
agents_config: dict[str, Any]
|
||||
tasks_config: dict[str, Any]
|
||||
mcp_server_params: Any
|
||||
mcp_connect_timeout: int
|
||||
|
||||
def load_configurations(self) -> None: ...
|
||||
def map_all_agent_variables(self) -> None: ...
|
||||
def map_all_task_variables(self) -> None: ...
|
||||
def close_mcp_server(self, instance: Self, outputs: CrewOutput) -> CrewOutput: ...
|
||||
def _load_config(
|
||||
self, config_path: str | None, config_type: Literal["agent", "task"]
|
||||
) -> dict[str, Any]: ...
|
||||
def _map_agent_variables(
|
||||
self,
|
||||
agent_name: str,
|
||||
agent_info: dict[str, Any],
|
||||
llms: dict[str, Callable[..., Any]],
|
||||
tool_functions: dict[str, Callable[..., Any]],
|
||||
cache_handler_functions: dict[str, Callable[..., Any]],
|
||||
callbacks: dict[str, Callable[..., Any]],
|
||||
) -> None: ...
|
||||
def _map_task_variables(
|
||||
self,
|
||||
task_name: str,
|
||||
task_info: dict[str, Any],
|
||||
agents: dict[str, Callable[..., Any]],
|
||||
tasks: dict[str, Callable[..., Any]],
|
||||
output_json_functions: dict[str, Callable[..., Any]],
|
||||
tool_functions: dict[str, Callable[..., Any]],
|
||||
callback_functions: dict[str, Callable[..., Any]],
|
||||
output_pydantic_functions: dict[str, Callable[..., Any]],
|
||||
) -> None: ...
|
||||
def load_yaml(self, config_path: Path) -> dict[str, Any]: ...
|
||||
|
||||
|
||||
class CrewClass(Protocol):
|
||||
"""Protocol describing class attributes injected by CrewBaseMeta."""
|
||||
|
||||
is_crew_class: bool
|
||||
_crew_name: str
|
||||
base_directory: Path
|
||||
original_agents_config_path: str
|
||||
original_tasks_config_path: str
|
||||
mcp_server_params: Any
|
||||
mcp_connect_timeout: int
|
||||
close_mcp_server: Callable[..., Any]
|
||||
get_mcp_tools: Callable[..., list[BaseTool]]
|
||||
_load_config: Callable[..., dict[str, Any]]
|
||||
load_configurations: Callable[..., None]
|
||||
load_yaml: staticmethod
|
||||
map_all_agent_variables: Callable[..., None]
|
||||
_map_agent_variables: Callable[..., None]
|
||||
map_all_task_variables: Callable[..., None]
|
||||
_map_task_variables: Callable[..., None]
|
||||
|
||||
|
||||
class DecoratedMethod(Generic[P, R]):
|
||||
"""Base wrapper for methods with decorator metadata.
|
||||
|
||||
This class provides a type-safe way to add metadata to methods
|
||||
while preserving their callable signature and attributes.
|
||||
"""
|
||||
|
||||
def __init__(self, meth: Callable[P, R]) -> None:
|
||||
"""Initialize the decorated method wrapper.
|
||||
|
||||
Args:
|
||||
meth: The method to wrap.
|
||||
"""
|
||||
self._meth = meth
|
||||
_copy_method_metadata(self, meth)
|
||||
|
||||
def __get__(
|
||||
self, obj: Any, objtype: type[Any] | None = None
|
||||
) -> Self | Callable[..., R]:
|
||||
"""Support instance methods by implementing the descriptor protocol.
|
||||
|
||||
Args:
|
||||
obj: The instance that the method is accessed through.
|
||||
objtype: The type of the instance.
|
||||
|
||||
Returns:
|
||||
Self when accessed through class, bound method when accessed through instance.
|
||||
"""
|
||||
if obj is None:
|
||||
return self
|
||||
bound = partial(self._meth, obj)
|
||||
for attr in (
|
||||
"is_agent",
|
||||
"is_llm",
|
||||
"is_tool",
|
||||
"is_callback",
|
||||
"is_cache_handler",
|
||||
"is_before_kickoff",
|
||||
"is_after_kickoff",
|
||||
"is_crew",
|
||||
):
|
||||
if hasattr(self, attr):
|
||||
setattr(bound, attr, getattr(self, attr))
|
||||
return bound
|
||||
|
||||
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
"""Call the wrapped method.
|
||||
|
||||
Args:
|
||||
*args: Positional arguments.
|
||||
**kwargs: Keyword arguments.
|
||||
|
||||
Returns:
|
||||
The result of calling the wrapped method.
|
||||
"""
|
||||
return self._meth(*args, **kwargs)
|
||||
|
||||
def unwrap(self) -> Callable[P, R]:
|
||||
"""Get the original unwrapped method.
|
||||
|
||||
Returns:
|
||||
The original method before decoration.
|
||||
"""
|
||||
return self._meth
|
||||
|
||||
|
||||
class BeforeKickoffMethod(DecoratedMethod[P, R]):
|
||||
"""Wrapper for methods marked to execute before crew kickoff."""
|
||||
|
||||
is_before_kickoff: bool = True
|
||||
|
||||
|
||||
class AfterKickoffMethod(DecoratedMethod[P, R]):
|
||||
"""Wrapper for methods marked to execute after crew kickoff."""
|
||||
|
||||
is_after_kickoff: bool = True
|
||||
|
||||
|
||||
class BoundTaskMethod(Generic[TaskResultT]):
|
||||
"""Bound task method with task marker attribute."""
|
||||
|
||||
is_task: bool = True
|
||||
|
||||
def __init__(self, task_method: TaskMethod[Any, TaskResultT], obj: Any) -> None:
|
||||
"""Initialize the bound task method.
|
||||
|
||||
Args:
|
||||
task_method: The TaskMethod descriptor instance.
|
||||
obj: The instance to bind to.
|
||||
"""
|
||||
self._task_method = task_method
|
||||
self._obj = obj
|
||||
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> TaskResultT:
|
||||
"""Execute the bound task method.
|
||||
|
||||
Args:
|
||||
*args: Positional arguments.
|
||||
**kwargs: Keyword arguments.
|
||||
|
||||
Returns:
|
||||
The task result with name ensured.
|
||||
"""
|
||||
result = self._task_method.unwrap()(self._obj, *args, **kwargs)
|
||||
return self._task_method.ensure_task_name(result)
|
||||
|
||||
|
||||
class TaskMethod(Generic[P, TaskResultT]):
|
||||
"""Wrapper for methods marked as crew tasks."""
|
||||
|
||||
is_task: bool = True
|
||||
|
||||
def __init__(self, meth: Callable[P, TaskResultT]) -> None:
|
||||
"""Initialize the task method wrapper.
|
||||
|
||||
Args:
|
||||
meth: The method to wrap.
|
||||
"""
|
||||
self._meth = meth
|
||||
_copy_method_metadata(self, meth)
|
||||
|
||||
def ensure_task_name(self, result: TaskResultT) -> TaskResultT:
|
||||
"""Ensure task result has a name set.
|
||||
|
||||
Args:
|
||||
result: The task result to check.
|
||||
|
||||
Returns:
|
||||
The task result with name ensured.
|
||||
"""
|
||||
if not result.name:
|
||||
result.name = self._meth.__name__
|
||||
return result
|
||||
|
||||
def __get__(
|
||||
self, obj: Any, objtype: type[Any] | None = None
|
||||
) -> Self | BoundTaskMethod[TaskResultT]:
|
||||
"""Support instance methods by implementing the descriptor protocol.
|
||||
|
||||
Args:
|
||||
obj: The instance that the method is accessed through.
|
||||
objtype: The type of the instance.
|
||||
|
||||
Returns:
|
||||
Self when accessed through class, bound method when accessed through instance.
|
||||
"""
|
||||
if obj is None:
|
||||
return self
|
||||
return BoundTaskMethod(self, obj)
|
||||
|
||||
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> TaskResultT:
|
||||
"""Call the wrapped method and set task name if not provided.
|
||||
|
||||
Args:
|
||||
*args: Positional arguments.
|
||||
**kwargs: Keyword arguments.
|
||||
|
||||
Returns:
|
||||
The task instance with name set if not already provided.
|
||||
"""
|
||||
return self.ensure_task_name(self._meth(*args, **kwargs))
|
||||
|
||||
def unwrap(self) -> Callable[P, TaskResultT]:
|
||||
"""Get the original unwrapped method.
|
||||
|
||||
Returns:
|
||||
The original method before decoration.
|
||||
"""
|
||||
return self._meth
|
||||
|
||||
|
||||
class AgentMethod(DecoratedMethod[P, R]):
|
||||
"""Wrapper for methods marked as crew agents."""
|
||||
|
||||
is_agent: bool = True
|
||||
|
||||
|
||||
class LLMMethod(DecoratedMethod[P, R]):
|
||||
"""Wrapper for methods marked as LLM providers."""
|
||||
|
||||
is_llm: bool = True
|
||||
|
||||
|
||||
class ToolMethod(DecoratedMethod[P, R]):
|
||||
"""Wrapper for methods marked as crew tools."""
|
||||
|
||||
is_tool: bool = True
|
||||
|
||||
|
||||
class CallbackMethod(DecoratedMethod[P, R]):
|
||||
"""Wrapper for methods marked as crew callbacks."""
|
||||
|
||||
is_callback: bool = True
|
||||
|
||||
|
||||
class CacheHandlerMethod(DecoratedMethod[P, R]):
|
||||
"""Wrapper for methods marked as cache handlers."""
|
||||
|
||||
is_cache_handler: bool = True
|
||||
|
||||
|
||||
class CrewMethod(DecoratedMethod[P, R]):
|
||||
"""Wrapper for methods marked as the main crew execution point."""
|
||||
|
||||
is_crew: bool = True
|
||||
|
||||
|
||||
class OutputClass(Generic[T]):
|
||||
"""Base wrapper for classes marked as output format."""
|
||||
|
||||
def __init__(self, cls: type[T]) -> None:
|
||||
"""Initialize the output class wrapper.
|
||||
|
||||
Args:
|
||||
cls: The class to wrap.
|
||||
"""
|
||||
self._cls = cls
|
||||
self.__name__ = cls.__name__
|
||||
self.__qualname__ = cls.__qualname__
|
||||
self.__module__ = cls.__module__
|
||||
self.__doc__ = cls.__doc__
|
||||
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> T:
|
||||
"""Create an instance of the wrapped class.
|
||||
|
||||
Args:
|
||||
*args: Positional arguments for the class constructor.
|
||||
**kwargs: Keyword arguments for the class constructor.
|
||||
|
||||
Returns:
|
||||
An instance of the wrapped class.
|
||||
"""
|
||||
return self._cls(*args, **kwargs)
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
"""Delegate attribute access to the wrapped class.
|
||||
|
||||
Args:
|
||||
name: The attribute name.
|
||||
|
||||
Returns:
|
||||
The attribute from the wrapped class.
|
||||
"""
|
||||
return getattr(self._cls, name)
|
||||
|
||||
|
||||
class OutputJsonClass(OutputClass[T]):
|
||||
"""Wrapper for classes marked as JSON output format."""
|
||||
|
||||
is_output_json: bool = True
|
||||
|
||||
|
||||
class OutputPydanticClass(OutputClass[T]):
|
||||
"""Wrapper for classes marked as Pydantic output format."""
|
||||
|
||||
is_output_pydantic: bool = True
|
||||
@@ -7,7 +7,7 @@ import uuid
|
||||
import warnings
|
||||
from collections.abc import Callable
|
||||
from concurrent.futures import Future
|
||||
from copy import copy
|
||||
from copy import copy as shallow_copy
|
||||
from hashlib import md5
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
@@ -672,7 +672,9 @@ Follow these guidelines:
|
||||
copied_data = {k: v for k, v in copied_data.items() if v is not None}
|
||||
|
||||
cloned_context = (
|
||||
[task_mapping[context_task.key] for context_task in self.context]
|
||||
self.context
|
||||
if self.context is NOT_SPECIFIED
|
||||
else [task_mapping[context_task.key] for context_task in self.context]
|
||||
if isinstance(self.context, list)
|
||||
else None
|
||||
)
|
||||
@@ -681,7 +683,7 @@ Follow these guidelines:
|
||||
return next((agent for agent in agents if agent.role == role), None)
|
||||
|
||||
cloned_agent = get_agent_by_role(self.agent.role) if self.agent else None
|
||||
cloned_tools = copy(self.tools) if self.tools else []
|
||||
cloned_tools = shallow_copy(self.tools) if self.tools else []
|
||||
|
||||
return self.__class__(
|
||||
**copied_data,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import ast
|
||||
import datetime
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
from difflib import SequenceMatcher
|
||||
from json import JSONDecodeError
|
||||
@@ -44,183 +44,6 @@ OPENAI_BIGGER_MODELS = [
|
||||
]
|
||||
|
||||
|
||||
def _safe_literal_parse(input_str: str) -> Any:
|
||||
"""
|
||||
Safely parse a limited subset of Python literal syntax without using ast.literal_eval.
|
||||
Only supports: strings (single/double quotes), numbers, booleans, None, lists, dicts.
|
||||
Rejects any input that could lead to code execution.
|
||||
|
||||
Args:
|
||||
input_str: String to parse
|
||||
|
||||
Returns:
|
||||
Parsed Python object
|
||||
|
||||
Raises:
|
||||
ValueError: If input contains unsafe or unsupported syntax
|
||||
"""
|
||||
if not isinstance(input_str, str):
|
||||
raise ValueError("Input must be a string")
|
||||
|
||||
stripped = input_str.strip()
|
||||
if not stripped:
|
||||
raise ValueError("Input cannot be empty")
|
||||
|
||||
# Check for potentially dangerous patterns
|
||||
dangerous_patterns = [
|
||||
r'__.*__', # dunder methods
|
||||
r'import\b', # import statements
|
||||
r'exec\b', # exec function
|
||||
r'eval\b', # eval function
|
||||
r'lambda\b', # lambda functions
|
||||
r'def\b', # function definitions
|
||||
r'class\b', # class definitions
|
||||
r'@\w+', # decorators
|
||||
r'\.\.\.', # ellipsis (could be used in slicing)
|
||||
r'->[^\]]*\]', # type hints in lists
|
||||
]
|
||||
|
||||
for pattern in dangerous_patterns:
|
||||
if re.search(pattern, stripped, re.IGNORECASE):
|
||||
raise ValueError(f"Potentially dangerous pattern detected: {pattern}")
|
||||
|
||||
# Only allow specific characters
|
||||
allowed_chars = r'[\s\w\.\-\+\*/\(\)\[\]\{\}:\'"<>!=,!=\?%&|~^`]'
|
||||
if not re.fullmatch(f'{allowed_chars}*', stripped):
|
||||
raise ValueError("Input contains unsupported characters")
|
||||
|
||||
# Try JSON parsing first (safest)
|
||||
try:
|
||||
return json.loads(stripped)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
# Manual parsing for simple Python literals (JSON with single quotes, etc.)
|
||||
try:
|
||||
return _parse_python_literal_safe(stripped)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to parse input safely: {e}")
|
||||
|
||||
|
||||
def _parse_python_literal_safe(input_str: str) -> Any:
|
||||
"""
|
||||
Parse a limited subset of Python literals safely.
|
||||
|
||||
Args:
|
||||
input_str: String to parse
|
||||
|
||||
Returns:
|
||||
Parsed Python object
|
||||
"""
|
||||
# Handle None
|
||||
if input_str == 'None':
|
||||
return None
|
||||
|
||||
# Handle booleans
|
||||
if input_str == 'True':
|
||||
return True
|
||||
if input_str == 'False':
|
||||
return False
|
||||
|
||||
# Handle numbers
|
||||
if re.fullmatch(r'-?\d+$', input_str):
|
||||
return int(input_str)
|
||||
if re.fullmatch(r'-?\d+\.\d+$', input_str):
|
||||
return float(input_str)
|
||||
|
||||
# Handle strings with single quotes (convert to JSON format)
|
||||
if (input_str.startswith("'") and input_str.endswith("'")) or \
|
||||
(input_str.startswith('"') and input_str.endswith('"')):
|
||||
# Simple string - just remove quotes and escape common sequences
|
||||
inner = input_str[1:-1]
|
||||
# Handle common escape sequences safely
|
||||
inner = inner.replace("\\'", "'").replace('\\"', '"').replace("\\\\", "\\")
|
||||
return inner
|
||||
|
||||
# Handle lists
|
||||
if input_str.startswith('[') and input_str.endswith(']'):
|
||||
inner = input_str[1:-1].strip()
|
||||
if not inner:
|
||||
return []
|
||||
|
||||
items = _split_items_safe(inner)
|
||||
return [_parse_python_literal_safe(item.strip()) for item in items]
|
||||
|
||||
# Handle dictionaries
|
||||
if input_str.startswith('{') and input_str.endswith('}'):
|
||||
inner = input_str[1:-1].strip()
|
||||
if not inner:
|
||||
return {}
|
||||
|
||||
pairs = _split_items_safe(inner)
|
||||
result = {}
|
||||
for pair in pairs:
|
||||
if ':' not in pair:
|
||||
raise ValueError(f"Invalid dict pair: {pair}")
|
||||
|
||||
key_str, value_str = pair.split(':', 1)
|
||||
key = _parse_python_literal_safe(key_str.strip())
|
||||
value = _parse_python_literal_safe(value_str.strip())
|
||||
if not isinstance(key, str):
|
||||
raise ValueError(f"Dict keys must be strings, got {type(key)}")
|
||||
result[key] = value
|
||||
|
||||
return result
|
||||
|
||||
raise ValueError(f"Unsupported literal format: {input_str}")
|
||||
|
||||
|
||||
def _split_items_safe(input_str: str, delimiter: str = ',') -> list[str]:
|
||||
"""
|
||||
Split a list or dict string into items, respecting nested structures.
|
||||
|
||||
Args:
|
||||
input_str: String to split
|
||||
delimiter: Delimiter to split on
|
||||
|
||||
Returns:
|
||||
List of item strings
|
||||
"""
|
||||
items = []
|
||||
current = []
|
||||
depth = 0
|
||||
in_string = False
|
||||
string_char = None
|
||||
i = 0
|
||||
|
||||
while i < len(input_str):
|
||||
char = input_str[i]
|
||||
|
||||
# Handle string literals
|
||||
if char in ('"', "'") and (i == 0 or input_str[i-1] != '\\'):
|
||||
if not in_string:
|
||||
in_string = True
|
||||
string_char = char
|
||||
elif char == string_char:
|
||||
in_string = False
|
||||
string_char = None
|
||||
|
||||
# Track nesting depth when not in strings
|
||||
elif not in_string:
|
||||
if char in ('[', '(', '{'):
|
||||
depth += 1
|
||||
elif char in (']', ')', '}'):
|
||||
depth -= 1
|
||||
elif char == delimiter and depth == 0:
|
||||
items.append(''.join(current).strip())
|
||||
current = []
|
||||
i += 1
|
||||
continue
|
||||
|
||||
current.append(char)
|
||||
i += 1
|
||||
|
||||
if current:
|
||||
items.append(''.join(current).strip())
|
||||
|
||||
return items
|
||||
|
||||
|
||||
class ToolUsageError(Exception):
|
||||
"""Exception raised for errors in the tool usage."""
|
||||
|
||||
@@ -701,14 +524,14 @@ class ToolUsage:
|
||||
except (JSONDecodeError, TypeError):
|
||||
pass # Continue to the next parsing attempt
|
||||
|
||||
# Attempt 2: Parse as Python literal (safe alternative to ast.literal_eval)
|
||||
# Attempt 2: Parse as Python literal
|
||||
try:
|
||||
arguments = _safe_literal_parse(tool_input)
|
||||
arguments = ast.literal_eval(tool_input)
|
||||
if isinstance(arguments, dict):
|
||||
return arguments
|
||||
except ValueError:
|
||||
except (ValueError, SyntaxError):
|
||||
repaired_input = repair_json(tool_input)
|
||||
# Continue to the next parsing attempt
|
||||
pass
|
||||
|
||||
# Attempt 3: Parse as JSON5
|
||||
try:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import jwt
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import jwt
|
||||
|
||||
from crewai.cli.authentication.utils import validate_jwt_token
|
||||
|
||||
@@ -17,19 +17,22 @@ class TestUtils(unittest.TestCase):
|
||||
key="mock_signing_key"
|
||||
)
|
||||
|
||||
jwt_token = "aaaaa.bbbbbb.cccccc" # noqa: S105
|
||||
|
||||
decoded_token = validate_jwt_token(
|
||||
jwt_token="aaaaa.bbbbbb.cccccc",
|
||||
jwt_token=jwt_token,
|
||||
jwks_url="https://mock_jwks_url",
|
||||
issuer="https://mock_issuer",
|
||||
audience="app_id_xxxx",
|
||||
)
|
||||
|
||||
mock_jwt.decode.assert_called_with(
|
||||
"aaaaa.bbbbbb.cccccc",
|
||||
jwt_token,
|
||||
"mock_signing_key",
|
||||
algorithms=["RS256"],
|
||||
audience="app_id_xxxx",
|
||||
issuer="https://mock_issuer",
|
||||
leeway=10.0,
|
||||
options={
|
||||
"verify_signature": True,
|
||||
"verify_exp": True,
|
||||
@@ -43,9 +46,9 @@ class TestUtils(unittest.TestCase):
|
||||
|
||||
def test_validate_jwt_token_expired(self, mock_jwt, mock_pyjwkclient):
|
||||
mock_jwt.decode.side_effect = jwt.ExpiredSignatureError
|
||||
with self.assertRaises(Exception):
|
||||
with self.assertRaises(Exception): # noqa: B017
|
||||
validate_jwt_token(
|
||||
jwt_token="aaaaa.bbbbbb.cccccc",
|
||||
jwt_token="aaaaa.bbbbbb.cccccc", # noqa: S106
|
||||
jwks_url="https://mock_jwks_url",
|
||||
issuer="https://mock_issuer",
|
||||
audience="app_id_xxxx",
|
||||
@@ -53,9 +56,9 @@ class TestUtils(unittest.TestCase):
|
||||
|
||||
def test_validate_jwt_token_invalid_audience(self, mock_jwt, mock_pyjwkclient):
|
||||
mock_jwt.decode.side_effect = jwt.InvalidAudienceError
|
||||
with self.assertRaises(Exception):
|
||||
with self.assertRaises(Exception): # noqa: B017
|
||||
validate_jwt_token(
|
||||
jwt_token="aaaaa.bbbbbb.cccccc",
|
||||
jwt_token="aaaaa.bbbbbb.cccccc", # noqa: S106
|
||||
jwks_url="https://mock_jwks_url",
|
||||
issuer="https://mock_issuer",
|
||||
audience="app_id_xxxx",
|
||||
@@ -63,9 +66,9 @@ class TestUtils(unittest.TestCase):
|
||||
|
||||
def test_validate_jwt_token_invalid_issuer(self, mock_jwt, mock_pyjwkclient):
|
||||
mock_jwt.decode.side_effect = jwt.InvalidIssuerError
|
||||
with self.assertRaises(Exception):
|
||||
with self.assertRaises(Exception): # noqa: B017
|
||||
validate_jwt_token(
|
||||
jwt_token="aaaaa.bbbbbb.cccccc",
|
||||
jwt_token="aaaaa.bbbbbb.cccccc", # noqa: S106
|
||||
jwks_url="https://mock_jwks_url",
|
||||
issuer="https://mock_issuer",
|
||||
audience="app_id_xxxx",
|
||||
@@ -75,9 +78,9 @@ class TestUtils(unittest.TestCase):
|
||||
self, mock_jwt, mock_pyjwkclient
|
||||
):
|
||||
mock_jwt.decode.side_effect = jwt.MissingRequiredClaimError
|
||||
with self.assertRaises(Exception):
|
||||
with self.assertRaises(Exception): # noqa: B017
|
||||
validate_jwt_token(
|
||||
jwt_token="aaaaa.bbbbbb.cccccc",
|
||||
jwt_token="aaaaa.bbbbbb.cccccc", # noqa: S106
|
||||
jwks_url="https://mock_jwks_url",
|
||||
issuer="https://mock_issuer",
|
||||
audience="app_id_xxxx",
|
||||
@@ -85,9 +88,9 @@ class TestUtils(unittest.TestCase):
|
||||
|
||||
def test_validate_jwt_token_jwks_error(self, mock_jwt, mock_pyjwkclient):
|
||||
mock_jwt.decode.side_effect = jwt.exceptions.PyJWKClientError
|
||||
with self.assertRaises(Exception):
|
||||
with self.assertRaises(Exception): # noqa: B017
|
||||
validate_jwt_token(
|
||||
jwt_token="aaaaa.bbbbbb.cccccc",
|
||||
jwt_token="aaaaa.bbbbbb.cccccc", # noqa: S106
|
||||
jwks_url="https://mock_jwks_url",
|
||||
issuer="https://mock_issuer",
|
||||
audience="app_id_xxxx",
|
||||
@@ -95,9 +98,9 @@ class TestUtils(unittest.TestCase):
|
||||
|
||||
def test_validate_jwt_token_invalid_token(self, mock_jwt, mock_pyjwkclient):
|
||||
mock_jwt.decode.side_effect = jwt.InvalidTokenError
|
||||
with self.assertRaises(Exception):
|
||||
with self.assertRaises(Exception): # noqa: B017
|
||||
validate_jwt_token(
|
||||
jwt_token="aaaaa.bbbbbb.cccccc",
|
||||
jwt_token="aaaaa.bbbbbb.cccccc", # noqa: S106
|
||||
jwks_url="https://mock_jwks_url",
|
||||
issuer="https://mock_issuer",
|
||||
audience="app_id_xxxx",
|
||||
|
||||
@@ -1218,7 +1218,7 @@ def test_create_directory_false():
|
||||
assert not resolved_dir.exists()
|
||||
|
||||
with pytest.raises(
|
||||
RuntimeError, match="Directory .* does not exist and create_directory is False"
|
||||
RuntimeError, match=r"Directory .* does not exist and create_directory is False"
|
||||
):
|
||||
task._save_file("test content")
|
||||
|
||||
@@ -1635,3 +1635,48 @@ def test_task_interpolation_with_hyphens():
|
||||
assert "say hello world" in task.prompt()
|
||||
|
||||
assert result.raw == "Hello, World!"
|
||||
|
||||
|
||||
def test_task_copy_with_none_context():
|
||||
original_task = Task(
|
||||
description="Test task",
|
||||
expected_output="Test output",
|
||||
context=None
|
||||
)
|
||||
|
||||
new_task = original_task.copy(agents=[], task_mapping={})
|
||||
assert original_task.context is None
|
||||
assert new_task.context is None
|
||||
|
||||
|
||||
def test_task_copy_with_not_specified_context():
|
||||
from crewai.utilities.constants import NOT_SPECIFIED
|
||||
original_task = Task(
|
||||
description="Test task",
|
||||
expected_output="Test output",
|
||||
)
|
||||
|
||||
new_task = original_task.copy(agents=[], task_mapping={})
|
||||
assert original_task.context is NOT_SPECIFIED
|
||||
assert new_task.context is NOT_SPECIFIED
|
||||
|
||||
|
||||
def test_task_copy_with_list_context():
|
||||
"""Test that copying a task with list context works correctly."""
|
||||
task1 = Task(
|
||||
description="Task 1",
|
||||
expected_output="Output 1"
|
||||
)
|
||||
task2 = Task(
|
||||
description="Task 2",
|
||||
expected_output="Output 2",
|
||||
context=[task1]
|
||||
)
|
||||
|
||||
task_mapping = {task1.key: task1}
|
||||
|
||||
copied_task2 = task2.copy(agents=[], task_mapping=task_mapping)
|
||||
|
||||
assert isinstance(copied_task2.context, list)
|
||||
assert len(copied_task2.context) == 1
|
||||
assert copied_task2.context[0] is task1
|
||||
|
||||
Reference in New Issue
Block a user