mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-03-20 18:58:13 +00:00
Compare commits
15 Commits
luzk/missi
...
refactor/s
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9929aea830 | ||
|
|
0a3d2f596e | ||
|
|
0325703901 | ||
|
|
c1abefbbf3 | ||
|
|
870113335b | ||
|
|
82b9b98fd2 | ||
|
|
f2223281a9 | ||
|
|
47338c1efa | ||
|
|
0732582241 | ||
|
|
56eaa1d27b | ||
|
|
756e939543 | ||
|
|
d983eca2dd | ||
|
|
0d6b2ef8b8 | ||
|
|
aadb2d6694 | ||
|
|
395f6f1ae2 |
@@ -66,6 +66,7 @@ from crewai.mcp.tool_resolver import MCPToolResolver
|
||||
from crewai.rag.embeddings.types import EmbedderConfig
|
||||
from crewai.security.fingerprint import Fingerprint
|
||||
from crewai.tools.agent_tools.agent_tools import AgentTools
|
||||
from crewai.types.callback import SerializableCallable
|
||||
from crewai.utilities.agent_utils import (
|
||||
get_tool_names,
|
||||
is_inside_event_loop,
|
||||
@@ -143,7 +144,7 @@ class Agent(BaseAgent):
|
||||
default=None,
|
||||
description="Maximum execution time for an agent to execute a task",
|
||||
)
|
||||
step_callback: Any | None = Field(
|
||||
step_callback: SerializableCallable | None = Field(
|
||||
default=None,
|
||||
description="Callback to be executed after each step of the agent execution.",
|
||||
)
|
||||
@@ -151,10 +152,10 @@ class Agent(BaseAgent):
|
||||
default=True,
|
||||
description="Use system prompt for the agent.",
|
||||
)
|
||||
llm: str | InstanceOf[BaseLLM] | Any = Field(
|
||||
llm: str | InstanceOf[BaseLLM] | None = Field(
|
||||
description="Language model that will run the agent.", default=None
|
||||
)
|
||||
function_calling_llm: str | InstanceOf[BaseLLM] | Any | None = Field(
|
||||
function_calling_llm: str | InstanceOf[BaseLLM] | None = Field(
|
||||
description="Language model that will run the agent.", default=None
|
||||
)
|
||||
system_template: str | None = Field(
|
||||
@@ -340,7 +341,7 @@ class Agent(BaseAgent):
|
||||
return (
|
||||
hasattr(self.llm, "supports_function_calling")
|
||||
and callable(getattr(self.llm, "supports_function_calling", None))
|
||||
and self.llm.supports_function_calling()
|
||||
and self.llm.supports_function_calling() # type: ignore[union-attr]
|
||||
and len(tools) > 0
|
||||
)
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from copy import copy as shallow_copy
|
||||
from hashlib import md5
|
||||
import re
|
||||
@@ -12,6 +11,7 @@ from pydantic import (
|
||||
UUID4,
|
||||
BaseModel,
|
||||
Field,
|
||||
InstanceOf,
|
||||
PrivateAttr,
|
||||
field_validator,
|
||||
model_validator,
|
||||
@@ -26,10 +26,14 @@ from crewai.agents.tools_handler import ToolsHandler
|
||||
from crewai.knowledge.knowledge import Knowledge
|
||||
from crewai.knowledge.knowledge_config import KnowledgeConfig
|
||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||
from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
|
||||
from crewai.mcp.config import MCPServerConfig
|
||||
from crewai.memory.memory_scope import MemoryScope, MemorySlice
|
||||
from crewai.memory.unified_memory import Memory
|
||||
from crewai.rag.embeddings.types import EmbedderConfig
|
||||
from crewai.security.security_config import SecurityConfig
|
||||
from crewai.tools.base_tool import BaseTool, Tool
|
||||
from crewai.types.callback import SerializableCallable
|
||||
from crewai.utilities.config import process_config
|
||||
from crewai.utilities.i18n import I18N, get_i18n
|
||||
from crewai.utilities.logger import Logger
|
||||
@@ -179,7 +183,7 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
|
||||
default=None,
|
||||
description="Knowledge sources for the agent.",
|
||||
)
|
||||
knowledge_storage: Any | None = Field(
|
||||
knowledge_storage: InstanceOf[BaseKnowledgeStorage] | None = Field(
|
||||
default=None,
|
||||
description="Custom knowledge storage for the agent.",
|
||||
)
|
||||
@@ -187,7 +191,7 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
|
||||
default_factory=SecurityConfig,
|
||||
description="Security configuration for the agent, including fingerprinting.",
|
||||
)
|
||||
callbacks: list[Callable[[Any], Any]] = Field(
|
||||
callbacks: list[SerializableCallable] = Field(
|
||||
default_factory=list, description="Callbacks to be used for the agent"
|
||||
)
|
||||
adapted_agent: bool = Field(
|
||||
@@ -205,7 +209,7 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
|
||||
default=None,
|
||||
description="List of MCP server references. Supports 'https://server.com/path' for external servers and bare slugs like 'notion' for connected MCP integrations. Use '#tool_name' suffix for specific tools.",
|
||||
)
|
||||
memory: Any = Field(
|
||||
memory: bool | Memory | MemoryScope | MemorySlice | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Enable agent memory. Pass True for default Memory(), "
|
||||
|
||||
@@ -35,6 +35,7 @@ from typing_extensions import Self
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai_files import FileInput
|
||||
from opentelemetry.trace import Span
|
||||
|
||||
try:
|
||||
from crewai_files import get_supported_content_types
|
||||
@@ -83,6 +84,8 @@ from crewai.knowledge.knowledge import Knowledge
|
||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||
from crewai.llm import LLM
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.memory.memory_scope import MemoryScope, MemorySlice
|
||||
from crewai.memory.unified_memory import Memory
|
||||
from crewai.process import Process
|
||||
from crewai.rag.embeddings.types import EmbedderConfig
|
||||
from crewai.rag.types import SearchResult
|
||||
@@ -94,6 +97,7 @@ from crewai.tasks.task_output import TaskOutput
|
||||
from crewai.tools.agent_tools.agent_tools import AgentTools
|
||||
from crewai.tools.agent_tools.read_file_tool import ReadFileTool
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.types.callback import SerializableCallable
|
||||
from crewai.types.streaming import CrewStreamingOutput
|
||||
from crewai.types.usage_metrics import UsageMetrics
|
||||
from crewai.utilities.constants import NOT_SPECIFIED, TRAINING_DATA_FILE
|
||||
@@ -166,12 +170,12 @@ class Crew(FlowTrackable, BaseModel):
|
||||
"""
|
||||
|
||||
__hash__ = object.__hash__
|
||||
_execution_span: Any = PrivateAttr()
|
||||
_execution_span: Span | None = PrivateAttr()
|
||||
_rpm_controller: RPMController = PrivateAttr()
|
||||
_logger: Logger = PrivateAttr()
|
||||
_file_handler: FileHandler = PrivateAttr()
|
||||
_cache_handler: InstanceOf[CacheHandler] = PrivateAttr(default_factory=CacheHandler)
|
||||
_memory: Any = PrivateAttr(default=None) # Unified Memory | MemoryScope
|
||||
_memory: Memory | MemoryScope | MemorySlice | None = PrivateAttr(default=None)
|
||||
_train: bool | None = PrivateAttr(default=False)
|
||||
_train_iteration: int | None = PrivateAttr()
|
||||
_inputs: dict[str, Any] | None = PrivateAttr(default=None)
|
||||
@@ -189,7 +193,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
agents: list[BaseAgent] = Field(default_factory=list)
|
||||
process: Process = Field(default=Process.sequential)
|
||||
verbose: bool = Field(default=False)
|
||||
memory: bool | Any = Field(
|
||||
memory: bool | Memory | MemoryScope | MemorySlice | None = Field(
|
||||
default=False,
|
||||
description=(
|
||||
"Enable crew memory. Pass True for default Memory(), "
|
||||
@@ -204,36 +208,34 @@ class Crew(FlowTrackable, BaseModel):
|
||||
default=None,
|
||||
description="Metrics for the LLM usage during all tasks execution.",
|
||||
)
|
||||
manager_llm: str | InstanceOf[BaseLLM] | Any | None = Field(
|
||||
manager_llm: str | InstanceOf[BaseLLM] | None = Field(
|
||||
description="Language model that will run the agent.", default=None
|
||||
)
|
||||
manager_agent: BaseAgent | None = Field(
|
||||
description="Custom agent that will be used as manager.", default=None
|
||||
)
|
||||
function_calling_llm: str | InstanceOf[LLM] | Any | None = Field(
|
||||
function_calling_llm: str | InstanceOf[LLM] | None = Field(
|
||||
description="Language model that will run the agent.", default=None
|
||||
)
|
||||
config: Json[dict[str, Any]] | dict[str, Any] | None = Field(default=None)
|
||||
id: UUID4 = Field(default_factory=uuid.uuid4, frozen=True)
|
||||
share_crew: bool | None = Field(default=False)
|
||||
step_callback: Any | None = Field(
|
||||
step_callback: SerializableCallable | None = Field(
|
||||
default=None,
|
||||
description="Callback to be executed after each step for all agents execution.",
|
||||
)
|
||||
task_callback: Any | None = Field(
|
||||
task_callback: SerializableCallable | None = Field(
|
||||
default=None,
|
||||
description="Callback to be executed after each task for all agents execution.",
|
||||
)
|
||||
before_kickoff_callbacks: list[
|
||||
Callable[[dict[str, Any] | None], dict[str, Any] | None]
|
||||
] = Field(
|
||||
before_kickoff_callbacks: list[SerializableCallable] = Field(
|
||||
default_factory=list,
|
||||
description=(
|
||||
"List of callbacks to be executed before crew kickoff. "
|
||||
"It may be used to adjust inputs before the crew is executed."
|
||||
),
|
||||
)
|
||||
after_kickoff_callbacks: list[Callable[[CrewOutput], CrewOutput]] = Field(
|
||||
after_kickoff_callbacks: list[SerializableCallable] = Field(
|
||||
default_factory=list,
|
||||
description=(
|
||||
"List of callbacks to be executed after crew kickoff. "
|
||||
@@ -349,7 +351,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
self._file_handler = FileHandler(self.output_log_file)
|
||||
self._rpm_controller = RPMController(max_rpm=self.max_rpm, logger=self._logger)
|
||||
if self.function_calling_llm and not isinstance(self.function_calling_llm, LLM):
|
||||
self.function_calling_llm = create_llm(self.function_calling_llm)
|
||||
self.function_calling_llm = create_llm(self.function_calling_llm) # type: ignore[assignment]
|
||||
|
||||
return self
|
||||
|
||||
@@ -363,7 +365,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
if self.embedder is not None:
|
||||
from crewai.rag.embeddings.factory import build_embedder
|
||||
|
||||
embedder = build_embedder(self.embedder)
|
||||
embedder = build_embedder(self.embedder) # type: ignore[arg-type]
|
||||
self._memory = Memory(embedder=embedder)
|
||||
elif self.memory:
|
||||
# User passed a Memory / MemoryScope / MemorySlice instance
|
||||
|
||||
@@ -81,6 +81,7 @@ from crewai.flow.flow_wrappers import (
|
||||
SimpleFlowCondition,
|
||||
StartMethod,
|
||||
)
|
||||
from crewai.flow.input_provider import InputProvider
|
||||
from crewai.flow.persistence.base import FlowPersistence
|
||||
from crewai.flow.types import (
|
||||
FlowExecutionData,
|
||||
@@ -99,6 +100,8 @@ from crewai.flow.utils import (
|
||||
is_flow_method_name,
|
||||
is_simple_flow_condition,
|
||||
)
|
||||
from crewai.memory.memory_scope import MemoryScope, MemorySlice
|
||||
from crewai.memory.unified_memory import Memory
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -501,7 +504,7 @@ class LockedListProxy(list, Generic[T]): # type: ignore[type-arg]
|
||||
|
||||
def index(
|
||||
self, value: T, start: SupportsIndex = 0, stop: SupportsIndex | None = None
|
||||
) -> int: # type: ignore[override]
|
||||
) -> int:
|
||||
if stop is None:
|
||||
return self._list.index(value, start)
|
||||
return self._list.index(value, start, stop)
|
||||
@@ -520,13 +523,13 @@ class LockedListProxy(list, Generic[T]): # type: ignore[type-arg]
|
||||
def copy(self) -> list[T]:
|
||||
return self._list.copy()
|
||||
|
||||
def __add__(self, other: list[T]) -> list[T]:
|
||||
def __add__(self, other: list[T]) -> list[T]: # type: ignore[override]
|
||||
return self._list + other
|
||||
|
||||
def __radd__(self, other: list[T]) -> list[T]:
|
||||
return other + self._list
|
||||
|
||||
def __iadd__(self, other: Iterable[T]) -> LockedListProxy[T]:
|
||||
def __iadd__(self, other: Iterable[T]) -> LockedListProxy[T]: # type: ignore[override]
|
||||
with self._lock:
|
||||
self._list += list(other)
|
||||
return self
|
||||
@@ -630,13 +633,13 @@ class LockedDictProxy(dict, Generic[T]): # type: ignore[type-arg]
|
||||
def copy(self) -> dict[str, T]:
|
||||
return self._dict.copy()
|
||||
|
||||
def __or__(self, other: dict[str, T]) -> dict[str, T]:
|
||||
def __or__(self, other: dict[str, T]) -> dict[str, T]: # type: ignore[override]
|
||||
return self._dict | other
|
||||
|
||||
def __ror__(self, other: dict[str, T]) -> dict[str, T]:
|
||||
def __ror__(self, other: dict[str, T]) -> dict[str, T]: # type: ignore[override]
|
||||
return other | self._dict
|
||||
|
||||
def __ior__(self, other: dict[str, T]) -> LockedDictProxy[T]:
|
||||
def __ior__(self, other: dict[str, T]) -> LockedDictProxy[T]: # type: ignore[override]
|
||||
with self._lock:
|
||||
self._dict |= other
|
||||
return self
|
||||
@@ -822,10 +825,8 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
name: str | None = None
|
||||
tracing: bool | None = None
|
||||
stream: bool = False
|
||||
memory: Any = (
|
||||
None # Memory | MemoryScope | MemorySlice | None; auto-created if not set
|
||||
)
|
||||
input_provider: Any = None # InputProvider | None; per-flow override for self.ask()
|
||||
memory: Memory | MemoryScope | MemorySlice | None = None
|
||||
input_provider: InputProvider | None = None
|
||||
|
||||
def __class_getitem__(cls: type[Flow[T]], item: type[T]) -> type[Flow[T]]:
|
||||
class _FlowGeneric(cls): # type: ignore
|
||||
@@ -904,8 +905,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
# Internal flows (RecallFlow, EncodingFlow) set _skip_auto_memory
|
||||
# to avoid creating a wasteful standalone Memory instance.
|
||||
if self.memory is None and not getattr(self, "_skip_auto_memory", False):
|
||||
from crewai.memory.unified_memory import Memory
|
||||
|
||||
self.memory = Memory()
|
||||
|
||||
# Register all flow-related methods
|
||||
@@ -951,10 +950,16 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
|
||||
Raises:
|
||||
ValueError: If no memory is configured for this flow.
|
||||
TypeError: If batch remember is attempted on a MemoryScope or MemorySlice.
|
||||
"""
|
||||
if self.memory is None:
|
||||
raise ValueError("No memory configured for this flow")
|
||||
if isinstance(content, list):
|
||||
if not isinstance(self.memory, Memory):
|
||||
raise TypeError(
|
||||
"Batch remember requires a Memory instance, "
|
||||
f"got {type(self.memory).__name__}"
|
||||
)
|
||||
return self.memory.remember_many(content, **kwargs)
|
||||
return self.memory.remember(content, **kwargs)
|
||||
|
||||
@@ -2725,7 +2730,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
|
||||
# ── User Input (self.ask) ────────────────────────────────────────
|
||||
|
||||
def _resolve_input_provider(self) -> Any:
|
||||
def _resolve_input_provider(self) -> InputProvider:
|
||||
"""Resolve the input provider using the priority chain.
|
||||
|
||||
Resolution order:
|
||||
|
||||
@@ -6,7 +6,7 @@ customize Flow behavior at runtime.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -32,17 +32,17 @@ class FlowConfig:
|
||||
self._input_provider: InputProvider | None = None
|
||||
|
||||
@property
|
||||
def hitl_provider(self) -> Any:
|
||||
def hitl_provider(self) -> HumanFeedbackProvider | None:
|
||||
"""Get the configured HITL provider."""
|
||||
return self._hitl_provider
|
||||
|
||||
@hitl_provider.setter
|
||||
def hitl_provider(self, provider: Any) -> None:
|
||||
def hitl_provider(self, provider: HumanFeedbackProvider | None) -> None:
|
||||
"""Set the HITL provider."""
|
||||
self._hitl_provider = provider
|
||||
|
||||
@property
|
||||
def input_provider(self) -> Any:
|
||||
def input_provider(self) -> InputProvider | None:
|
||||
"""Get the configured input provider for ``Flow.ask()``.
|
||||
|
||||
Returns:
|
||||
@@ -52,7 +52,7 @@ class FlowConfig:
|
||||
return self._input_provider
|
||||
|
||||
@input_provider.setter
|
||||
def input_provider(self, provider: Any) -> None:
|
||||
def input_provider(self, provider: InputProvider | None) -> None:
|
||||
"""Set the input provider for ``Flow.ask()``.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -22,7 +22,6 @@ from crewai.events.types.memory_events import (
|
||||
)
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.memory.analyze import extract_memories_from_content
|
||||
from crewai.memory.recall_flow import RecallFlow
|
||||
from crewai.memory.storage.backend import StorageBackend
|
||||
from crewai.memory.types import (
|
||||
MemoryConfig,
|
||||
@@ -620,6 +619,8 @@ class Memory(BaseModel):
|
||||
)
|
||||
results.sort(key=lambda m: m.score, reverse=True)
|
||||
else:
|
||||
from crewai.memory.recall_flow import RecallFlow
|
||||
|
||||
flow = RecallFlow(
|
||||
storage=self._storage,
|
||||
llm=self._llm,
|
||||
|
||||
@@ -67,6 +67,7 @@ except ImportError:
|
||||
return []
|
||||
|
||||
|
||||
from crewai.types.callback import SerializableCallable
|
||||
from crewai.utilities.guardrail import (
|
||||
process_guardrail,
|
||||
)
|
||||
@@ -124,7 +125,7 @@ class Task(BaseModel):
|
||||
description="Configuration for the agent",
|
||||
default=None,
|
||||
)
|
||||
callback: Any | None = Field(
|
||||
callback: SerializableCallable | None = Field(
|
||||
description="Callback to be executed after the task is completed.", default=None
|
||||
)
|
||||
agent: BaseAgent | None = Field(
|
||||
|
||||
152
lib/crewai/src/crewai/types/callback.py
Normal file
152
lib/crewai/src/crewai/types/callback.py
Normal file
@@ -0,0 +1,152 @@
|
||||
"""Serializable callback type for Pydantic models.
|
||||
|
||||
Provides a ``SerializableCallable`` type alias that enables full JSON
|
||||
round-tripping of callback fields, e.g. ``"builtins.print"`` ↔ ``print``.
|
||||
Lambdas and closures serialize to a dotted path but cannot be deserialized
|
||||
back — use module-level named functions for checkpointable callbacks.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
import importlib
|
||||
import inspect
|
||||
import os
|
||||
from typing import Annotated, Any
|
||||
import warnings
|
||||
|
||||
from pydantic import BeforeValidator, WithJsonSchema
|
||||
from pydantic.functional_serializers import PlainSerializer
|
||||
|
||||
|
||||
def _is_non_roundtrippable(fn: object) -> bool:
|
||||
"""Return ``True`` if *fn* cannot survive a serialize/deserialize round-trip.
|
||||
|
||||
Built-in functions, plain module-level functions, and classes produce
|
||||
dotted paths that :func:`_resolve_dotted_path` can reliably resolve.
|
||||
Bound methods, ``functools.partial`` objects, callable class instances,
|
||||
lambdas, and closures all fail or silently change semantics during
|
||||
round-tripping.
|
||||
|
||||
Args:
|
||||
fn: The object to check.
|
||||
|
||||
Returns:
|
||||
``True`` if *fn* would not round-trip through JSON serialization.
|
||||
"""
|
||||
if inspect.isbuiltin(fn) or inspect.isclass(fn):
|
||||
return False
|
||||
if inspect.isfunction(fn):
|
||||
qualname = getattr(fn, "__qualname__", "")
|
||||
return qualname.endswith("<lambda>") or "<locals>" in qualname
|
||||
return True
|
||||
|
||||
|
||||
def string_to_callable(value: Any) -> Callable[..., Any]:
    """Resolve *value* into a callable.

    Callables pass straight through (with a warning when they cannot be
    round-tripped through JSON).  Strings are treated as dotted paths of
    the form ``"module.qualname"`` and resolved via
    :func:`_resolve_dotted_path`, but only when the
    ``CREWAI_DESERIALIZE_CALLBACKS`` environment variable is set.

    Args:
        value: A callable or a dotted-path string e.g. ``"builtins.print"``.

    Returns:
        The resolved callable.

    Raises:
        ValueError: If *value* is not callable or a resolvable dotted-path
            string, or if callback deserialization is disabled.
    """
    if callable(value):
        # Already a callable: keep it, but flag objects that would break
        # checkpointing when serialized later.
        if _is_non_roundtrippable(value):
            warnings.warn(
                f"{type(value).__name__} callbacks cannot be serialized "
                "and will prevent checkpointing. "
                "Use a module-level named function instead.",
                UserWarning,
                stacklevel=2,
            )
        return value  # type: ignore[no-any-return]

    if isinstance(value, str):
        if "." not in value:
            raise ValueError(
                f"Invalid callback path {value!r}: expected 'module.name' format"
            )
        # Resolving arbitrary dotted paths imports arbitrary modules, so it
        # is opt-in and should only be enabled for trusted checkpoint data.
        if not os.environ.get("CREWAI_DESERIALIZE_CALLBACKS"):
            raise ValueError(
                f"Refusing to resolve callback path {value!r}: "
                "set CREWAI_DESERIALIZE_CALLBACKS=1 to allow. "
                "Only enable this for trusted checkpoint data."
            )
        return _resolve_dotted_path(value)

    raise ValueError(
        f"Expected a callable or dotted-path string, got {type(value).__name__}"
    )
|
||||
|
||||
|
||||
def _resolve_dotted_path(path: str) -> Callable[..., Any]:
|
||||
"""Import a module and walk attribute lookups to resolve a dotted path.
|
||||
|
||||
Handles multi-level qualified names like ``"module.ClassName.method"``
|
||||
by trying progressively shorter module paths and resolving the remainder
|
||||
as chained attribute lookups.
|
||||
|
||||
Args:
|
||||
path: A dotted string e.g. ``"builtins.print"`` or
|
||||
``"mymodule.MyClass.my_method"``.
|
||||
|
||||
Returns:
|
||||
The resolved callable.
|
||||
|
||||
Raises:
|
||||
ValueError: If no valid module can be imported from the path.
|
||||
"""
|
||||
parts = path.split(".")
|
||||
# Try importing progressively shorter prefixes as the module.
|
||||
for i in range(len(parts), 0, -1):
|
||||
module_path = ".".join(parts[:i])
|
||||
try:
|
||||
obj: Any = importlib.import_module(module_path)
|
||||
except (ImportError, TypeError, ValueError):
|
||||
continue
|
||||
# Walk the remaining attribute chain.
|
||||
try:
|
||||
for attr in parts[i:]:
|
||||
obj = getattr(obj, attr)
|
||||
except AttributeError:
|
||||
continue
|
||||
if callable(obj):
|
||||
return obj # type: ignore[no-any-return]
|
||||
raise ValueError(f"Cannot resolve callback {path!r}")
|
||||
|
||||
|
||||
def callable_to_string(fn: Callable[..., Any]) -> str:
    """Serialize *fn* to a ``"module.qualname"`` dotted-path string.

    Note that lambdas and closures yield paths containing ``<lambda>`` or
    ``<locals>``, which :func:`string_to_callable` cannot resolve again.

    Args:
        fn: The callable to serialize.

    Returns:
        A dotted string built from ``fn.__module__`` and ``fn.__qualname__``.

    Raises:
        ValueError: If *fn* lacks ``__module__`` or ``__qualname__``.
    """
    mod = getattr(fn, "__module__", None)
    qual = getattr(fn, "__qualname__", None)
    if mod is not None and qual is not None:
        return f"{mod}.{qual}"
    raise ValueError(
        f"Cannot serialize {fn!r}: missing __module__ or __qualname__. "
        "Use a module-level named function for checkpointable callbacks."
    )
|
||||
|
||||
|
||||
# Annotated Callable that Pydantic validates from either a live callable or a
# dotted-path string, and serializes back to a string in JSON mode.  The JSON
# schema advertises the field as a plain string.  Metadata order matters to
# Pydantic, so the validator/serializer/schema entries must stay as listed.
SerializableCallable = Annotated[
    Callable[..., Any],
    BeforeValidator(string_to_callable),
    PlainSerializer(callable_to_string, return_type=str, when_used="json"),
    WithJsonSchema({"type": "string"}),
]
|
||||
@@ -1690,7 +1690,10 @@ def test_agent_with_knowledge_sources_works_with_copy():
|
||||
with patch(
|
||||
"crewai.knowledge.storage.knowledge_storage.KnowledgeStorage"
|
||||
) as mock_knowledge_storage:
|
||||
from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
|
||||
|
||||
mock_knowledge_storage_instance = mock_knowledge_storage.return_value
|
||||
mock_knowledge_storage_instance.__class__ = BaseKnowledgeStorage
|
||||
agent.knowledge_storage = mock_knowledge_storage_instance
|
||||
|
||||
agent_copy = agent.copy()
|
||||
|
||||
237
lib/crewai/tests/test_callback.py
Normal file
237
lib/crewai/tests/test_callback.py
Normal file
@@ -0,0 +1,237 @@
|
||||
"""Tests for crewai.types.callback — SerializableCallable round-tripping."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import os
|
||||
from typing import Any
|
||||
import pytest
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from crewai.types.callback import (
|
||||
SerializableCallable,
|
||||
_is_non_roundtrippable,
|
||||
_resolve_dotted_path,
|
||||
callable_to_string,
|
||||
string_to_callable,
|
||||
)
|
||||
|
||||
|
||||
# ── Helpers ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def module_level_function() -> str:
    """Module-level named function — the round-trippable reference case."""
    return "hello"
|
||||
|
||||
|
||||
class _CallableInstance:
|
||||
"""Callable class instance — non-roundtrippable."""
|
||||
|
||||
def __call__(self) -> str:
|
||||
return "instance"
|
||||
|
||||
|
||||
class _HasMethod:
|
||||
def method(self) -> str:
|
||||
return "method"
|
||||
|
||||
|
||||
class _Model(BaseModel):
    """Minimal Pydantic model with one optional serializable callback field."""

    cb: SerializableCallable | None = None
|
||||
|
||||
|
||||
# ── _is_non_roundtrippable ───────────────────────────────────────────
|
||||
|
||||
|
||||
class TestIsNonRoundtrippable:
    """Classification behavior of ``_is_non_roundtrippable``."""

    def test_builtin_is_roundtrippable(self) -> None:
        for builtin_fn in (print, len):
            assert _is_non_roundtrippable(builtin_fn) is False

    def test_class_is_roundtrippable(self) -> None:
        for klass in (dict, _CallableInstance):
            assert _is_non_roundtrippable(klass) is False

    def test_module_level_function_is_roundtrippable(self) -> None:
        assert _is_non_roundtrippable(module_level_function) is False

    def test_lambda_is_non_roundtrippable(self) -> None:
        assert _is_non_roundtrippable(lambda: None) is True

    def test_closure_is_non_roundtrippable(self) -> None:
        captured = 1

        def closure() -> int:
            return captured

        assert _is_non_roundtrippable(closure) is True

    def test_bound_method_is_non_roundtrippable(self) -> None:
        bound = _HasMethod().method
        assert _is_non_roundtrippable(bound) is True

    def test_partial_is_non_roundtrippable(self) -> None:
        partial_fn = functools.partial(print, "hi")
        assert _is_non_roundtrippable(partial_fn) is True

    def test_callable_instance_is_non_roundtrippable(self) -> None:
        assert _is_non_roundtrippable(_CallableInstance()) is True
|
||||
|
||||
|
||||
# ── callable_to_string ───────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestCallableToString:
    """Dotted-path serialization via ``callable_to_string``."""

    def test_module_level_function(self) -> None:
        expected = f"{__name__}.module_level_function"
        assert callable_to_string(module_level_function) == expected

    def test_class(self) -> None:
        assert callable_to_string(dict) == "builtins.dict"

    def test_builtin(self) -> None:
        assert callable_to_string(print) == "builtins.print"

    def test_lambda_produces_locals_path(self) -> None:
        anonymous = lambda: None  # noqa: E731
        assert "<lambda>" in callable_to_string(anonymous)

    def test_missing_qualname_raises(self) -> None:
        no_qualname = type("NoQual", (), {"__module__": "test"})()
        no_qualname.__qualname__ = None  # type: ignore[assignment]
        with pytest.raises(ValueError, match="missing __module__ or __qualname__"):
            callable_to_string(no_qualname)

    def test_missing_module_raises(self) -> None:
        # getattr(obj, "__module__", None) must yield None for this object.
        attrs: dict[str, Any] = {"__qualname__": "x", "__module__": None}
        no_module = type("NoMod", (), attrs)()
        with pytest.raises(ValueError, match="missing __module__"):
            callable_to_string(no_module)
|
||||
|
||||
|
||||
# ── string_to_callable ───────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestStringToCallable:
    """Validation and resolution behavior of ``string_to_callable``."""

    def test_callable_passthrough(self) -> None:
        assert string_to_callable(print) is print

    def test_roundtrippable_callable_no_warning(self, recwarn: pytest.WarningsChecker) -> None:
        string_to_callable(module_level_function)
        emitted = [w for w in recwarn if "cannot be serialized" in str(w.message)]
        assert not emitted

    def test_non_roundtrippable_warns(self) -> None:
        with pytest.warns(UserWarning, match="cannot be serialized"):
            string_to_callable(functools.partial(print))

    def test_non_callable_non_string_raises(self) -> None:
        with pytest.raises(ValueError, match="Expected a callable"):
            string_to_callable(42)

    def test_string_without_dot_raises(self) -> None:
        with pytest.raises(ValueError, match="expected 'module.name' format"):
            string_to_callable("nodots")

    def test_string_refused_without_env_var(self, monkeypatch: pytest.MonkeyPatch) -> None:
        monkeypatch.delenv("CREWAI_DESERIALIZE_CALLBACKS", raising=False)
        with pytest.raises(ValueError, match="Refusing to resolve"):
            string_to_callable("builtins.print")

    def test_string_resolves_with_env_var(self, monkeypatch: pytest.MonkeyPatch) -> None:
        monkeypatch.setenv("CREWAI_DESERIALIZE_CALLBACKS", "1")
        assert string_to_callable("builtins.print") is print

    def test_string_resolves_multi_level_path(self, monkeypatch: pytest.MonkeyPatch) -> None:
        monkeypatch.setenv("CREWAI_DESERIALIZE_CALLBACKS", "1")
        assert string_to_callable("os.path.join") is os.path.join

    def test_unresolvable_path_raises(self, monkeypatch: pytest.MonkeyPatch) -> None:
        monkeypatch.setenv("CREWAI_DESERIALIZE_CALLBACKS", "1")
        with pytest.raises(ValueError, match="Cannot resolve"):
            string_to_callable("nonexistent.module.func")
|
||||
|
||||
|
||||
# ── _resolve_dotted_path ─────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestResolveDottedPath:
    """Direct resolution behavior of ``_resolve_dotted_path``."""

    def test_builtin(self) -> None:
        assert _resolve_dotted_path("builtins.print") is print

    def test_nested_module_attribute(self) -> None:
        assert _resolve_dotted_path("os.path.join") is os.path.join

    def test_class_on_module(self) -> None:
        from collections import OrderedDict

        assert _resolve_dotted_path("collections.OrderedDict") is OrderedDict

    def test_nonexistent_raises(self) -> None:
        with pytest.raises(ValueError, match="Cannot resolve"):
            _resolve_dotted_path("no.such.module.func")

    def test_non_callable_attribute_skipped(self) -> None:
        # os.sep resolves to a plain string, so it must be rejected.
        with pytest.raises(ValueError, match="Cannot resolve"):
            _resolve_dotted_path("os.sep")
|
||||
|
||||
|
||||
# ── Pydantic integration round-trip ──────────────────────────────────
|
||||
|
||||
|
||||
class TestSerializableCallableRoundTrip:
    """End-to-end Pydantic (de)serialization of ``SerializableCallable`` fields."""

    def test_json_serialize_module_function(self) -> None:
        dumped = _Model(cb=module_level_function).model_dump(mode="json")
        assert dumped["cb"] == f"{__name__}.module_level_function"

    def test_json_round_trip(self, monkeypatch: pytest.MonkeyPatch) -> None:
        monkeypatch.setenv("CREWAI_DESERIALIZE_CALLBACKS", "1")
        payload = _Model(cb=print).model_dump_json()
        assert _Model.model_validate_json(payload).cb is print

    def test_json_round_trip_class(self, monkeypatch: pytest.MonkeyPatch) -> None:
        monkeypatch.setenv("CREWAI_DESERIALIZE_CALLBACKS", "1")
        payload = _Model(cb=dict).model_dump_json()
        assert _Model.model_validate_json(payload).cb is dict

    def test_python_mode_preserves_callable(self) -> None:
        dumped = _Model(cb=module_level_function).model_dump(mode="python")
        assert dumped["cb"] is module_level_function

    def test_none_field(self) -> None:
        model = _Model(cb=None)
        assert model.cb is None
        assert model.model_dump(mode="json")["cb"] is None

    def test_validation_error_for_int(self) -> None:
        with pytest.raises(ValidationError):
            _Model(cb=42)  # type: ignore[arg-type]

    def test_deserialization_refused_without_env(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        monkeypatch.delenv("CREWAI_DESERIALIZE_CALLBACKS", raising=False)
        with pytest.raises(ValidationError, match="Refusing to resolve"):
            _Model.model_validate({"cb": "builtins.print"})

    def test_json_schema_is_string(self) -> None:
        cb_schema = _Model.model_json_schema()["properties"]["cb"]
        # An Optional field renders as anyOf with a string and a null entry.
        present = {entry.get("type") for entry in cb_schema.get("anyOf", [cb_schema])}
        assert "string" in present
|
||||
@@ -6,6 +6,7 @@ from crewai.agent import Agent
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.crew import Crew
|
||||
from crewai.llm import LLM
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.project import (
|
||||
CrewBase,
|
||||
after_kickoff,
|
||||
@@ -371,9 +372,12 @@ def test_internal_crew_with_mcp():
|
||||
mock_adapter = Mock()
|
||||
mock_adapter.tools = ToolCollection([simple_tool, another_simple_tool])
|
||||
|
||||
mock_llm = Mock()
|
||||
mock_llm.__class__ = BaseLLM
|
||||
|
||||
with (
|
||||
patch("crewai_tools.MCPServerAdapter", return_value=mock_adapter) as adapter_mock,
|
||||
patch("crewai.llm.LLM.__new__", return_value=Mock()),
|
||||
patch("crewai.llm.LLM.__new__", return_value=mock_llm),
|
||||
):
|
||||
crew = InternalCrewWithMCP()
|
||||
assert crew.reporting_analyst().tools == [simple_tool, another_simple_tool]
|
||||
|
||||
Reference in New Issue
Block a user