Compare commits

...

15 Commits

Author SHA1 Message Date
Greyson LaLonde
9929aea830 fix: convert remaining callback lists to SerializableCallable and add tests
Address PR review: convert before_kickoff_callbacks, after_kickoff_callbacks
(crew.py) and callbacks (base_agent.py) from plain Callable to
SerializableCallable for consistent JSON round-trip support. Add 36 unit
tests for callback.py covering serialization, deserialization, env-var
gating, round-trip detection, and full Pydantic integration.
2026-03-20 13:29:29 -04:00
Greyson LaLonde
0a3d2f596e fix: catch TypeError from relative imports in dotted path resolution 2026-03-19 21:35:57 -04:00
Greyson LaLonde
0325703901 fix: gate callback string resolution behind CREWAI_DESERIALIZE_CALLBACKS env var 2026-03-19 21:33:59 -04:00
Greyson LaLonde
c1abefbbf3 fix: catch ImportError in dotted path resolution for broken modules 2026-03-19 21:25:29 -04:00
Greyson LaLonde
870113335b fix: allow classes in roundtrip check and use type-specific warning messages 2026-03-19 21:14:01 -04:00
Greyson LaLonde
82b9b98fd2 fix: raise clear ValueError when serializing callables missing __qualname__ 2026-03-19 21:00:55 -04:00
Greyson LaLonde
f2223281a9 fix: warn for bound methods, partials, and callable instances in callbacks 2026-03-19 20:46:56 -04:00
Greyson LaLonde
47338c1efa fix: resolve multi-level qualified names in callback deserialization 2026-03-19 20:43:55 -04:00
Greyson LaLonde
0732582241 fix: guard remember_many behind Memory isinstance check 2026-03-19 20:41:06 -04:00
Greyson LaLonde
56eaa1d27b fix: raise ValueError in callback validator for clean Pydantic error messages 2026-03-19 20:38:21 -04:00
Greyson LaLonde
756e939543 fix: serialize callbacks to string only in JSON mode to preserve closures during copy 2026-03-19 20:13:54 -04:00
Greyson LaLonde
d983eca2dd fix: mock LLM must pass InstanceOf[BaseLLM] validation in MCP test 2026-03-19 19:57:13 -04:00
Greyson LaLonde
0d6b2ef8b8 fix: detect closures and nested functions as non-roundtrippable callbacks 2026-03-19 19:45:05 -04:00
Greyson LaLonde
aadb2d6694 refactor: type Flow memory and input_provider fields 2026-03-19 19:37:14 -04:00
Greyson LaLonde
395f6f1ae2 refactor: replace Any-typed callback and model fields with serializable types 2026-03-19 19:17:24 -04:00
11 changed files with 452 additions and 42 deletions

View File

@@ -66,6 +66,7 @@ from crewai.mcp.tool_resolver import MCPToolResolver
from crewai.rag.embeddings.types import EmbedderConfig
from crewai.security.fingerprint import Fingerprint
from crewai.tools.agent_tools.agent_tools import AgentTools
from crewai.types.callback import SerializableCallable
from crewai.utilities.agent_utils import (
get_tool_names,
is_inside_event_loop,
@@ -143,7 +144,7 @@ class Agent(BaseAgent):
default=None,
description="Maximum execution time for an agent to execute a task",
)
step_callback: Any | None = Field(
step_callback: SerializableCallable | None = Field(
default=None,
description="Callback to be executed after each step of the agent execution.",
)
@@ -151,10 +152,10 @@ class Agent(BaseAgent):
default=True,
description="Use system prompt for the agent.",
)
llm: str | InstanceOf[BaseLLM] | Any = Field(
llm: str | InstanceOf[BaseLLM] | None = Field(
description="Language model that will run the agent.", default=None
)
function_calling_llm: str | InstanceOf[BaseLLM] | Any | None = Field(
function_calling_llm: str | InstanceOf[BaseLLM] | None = Field(
description="Language model that will run the agent.", default=None
)
system_template: str | None = Field(
@@ -340,7 +341,7 @@ class Agent(BaseAgent):
return (
hasattr(self.llm, "supports_function_calling")
and callable(getattr(self.llm, "supports_function_calling", None))
and self.llm.supports_function_calling()
and self.llm.supports_function_calling() # type: ignore[union-attr]
and len(tools) > 0
)

View File

@@ -1,7 +1,6 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from collections.abc import Callable
from copy import copy as shallow_copy
from hashlib import md5
import re
@@ -12,6 +11,7 @@ from pydantic import (
UUID4,
BaseModel,
Field,
InstanceOf,
PrivateAttr,
field_validator,
model_validator,
@@ -26,10 +26,14 @@ from crewai.agents.tools_handler import ToolsHandler
from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.knowledge_config import KnowledgeConfig
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
from crewai.mcp.config import MCPServerConfig
from crewai.memory.memory_scope import MemoryScope, MemorySlice
from crewai.memory.unified_memory import Memory
from crewai.rag.embeddings.types import EmbedderConfig
from crewai.security.security_config import SecurityConfig
from crewai.tools.base_tool import BaseTool, Tool
from crewai.types.callback import SerializableCallable
from crewai.utilities.config import process_config
from crewai.utilities.i18n import I18N, get_i18n
from crewai.utilities.logger import Logger
@@ -179,7 +183,7 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
default=None,
description="Knowledge sources for the agent.",
)
knowledge_storage: Any | None = Field(
knowledge_storage: InstanceOf[BaseKnowledgeStorage] | None = Field(
default=None,
description="Custom knowledge storage for the agent.",
)
@@ -187,7 +191,7 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
default_factory=SecurityConfig,
description="Security configuration for the agent, including fingerprinting.",
)
callbacks: list[Callable[[Any], Any]] = Field(
callbacks: list[SerializableCallable] = Field(
default_factory=list, description="Callbacks to be used for the agent"
)
adapted_agent: bool = Field(
@@ -205,7 +209,7 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
default=None,
description="List of MCP server references. Supports 'https://server.com/path' for external servers and bare slugs like 'notion' for connected MCP integrations. Use '#tool_name' suffix for specific tools.",
)
memory: Any = Field(
memory: bool | Memory | MemoryScope | MemorySlice | None = Field(
default=None,
description=(
"Enable agent memory. Pass True for default Memory(), "

View File

@@ -35,6 +35,7 @@ from typing_extensions import Self
if TYPE_CHECKING:
from crewai_files import FileInput
from opentelemetry.trace import Span
try:
from crewai_files import get_supported_content_types
@@ -83,6 +84,8 @@ from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.llm import LLM
from crewai.llms.base_llm import BaseLLM
from crewai.memory.memory_scope import MemoryScope, MemorySlice
from crewai.memory.unified_memory import Memory
from crewai.process import Process
from crewai.rag.embeddings.types import EmbedderConfig
from crewai.rag.types import SearchResult
@@ -94,6 +97,7 @@ from crewai.tasks.task_output import TaskOutput
from crewai.tools.agent_tools.agent_tools import AgentTools
from crewai.tools.agent_tools.read_file_tool import ReadFileTool
from crewai.tools.base_tool import BaseTool
from crewai.types.callback import SerializableCallable
from crewai.types.streaming import CrewStreamingOutput
from crewai.types.usage_metrics import UsageMetrics
from crewai.utilities.constants import NOT_SPECIFIED, TRAINING_DATA_FILE
@@ -166,12 +170,12 @@ class Crew(FlowTrackable, BaseModel):
"""
__hash__ = object.__hash__
_execution_span: Any = PrivateAttr()
_execution_span: Span | None = PrivateAttr()
_rpm_controller: RPMController = PrivateAttr()
_logger: Logger = PrivateAttr()
_file_handler: FileHandler = PrivateAttr()
_cache_handler: InstanceOf[CacheHandler] = PrivateAttr(default_factory=CacheHandler)
_memory: Any = PrivateAttr(default=None) # Unified Memory | MemoryScope
_memory: Memory | MemoryScope | MemorySlice | None = PrivateAttr(default=None)
_train: bool | None = PrivateAttr(default=False)
_train_iteration: int | None = PrivateAttr()
_inputs: dict[str, Any] | None = PrivateAttr(default=None)
@@ -189,7 +193,7 @@ class Crew(FlowTrackable, BaseModel):
agents: list[BaseAgent] = Field(default_factory=list)
process: Process = Field(default=Process.sequential)
verbose: bool = Field(default=False)
memory: bool | Any = Field(
memory: bool | Memory | MemoryScope | MemorySlice | None = Field(
default=False,
description=(
"Enable crew memory. Pass True for default Memory(), "
@@ -204,36 +208,34 @@ class Crew(FlowTrackable, BaseModel):
default=None,
description="Metrics for the LLM usage during all tasks execution.",
)
manager_llm: str | InstanceOf[BaseLLM] | Any | None = Field(
manager_llm: str | InstanceOf[BaseLLM] | None = Field(
description="Language model that will run the agent.", default=None
)
manager_agent: BaseAgent | None = Field(
description="Custom agent that will be used as manager.", default=None
)
function_calling_llm: str | InstanceOf[LLM] | Any | None = Field(
function_calling_llm: str | InstanceOf[LLM] | None = Field(
description="Language model that will run the agent.", default=None
)
config: Json[dict[str, Any]] | dict[str, Any] | None = Field(default=None)
id: UUID4 = Field(default_factory=uuid.uuid4, frozen=True)
share_crew: bool | None = Field(default=False)
step_callback: Any | None = Field(
step_callback: SerializableCallable | None = Field(
default=None,
description="Callback to be executed after each step for all agents execution.",
)
task_callback: Any | None = Field(
task_callback: SerializableCallable | None = Field(
default=None,
description="Callback to be executed after each task for all agents execution.",
)
before_kickoff_callbacks: list[
Callable[[dict[str, Any] | None], dict[str, Any] | None]
] = Field(
before_kickoff_callbacks: list[SerializableCallable] = Field(
default_factory=list,
description=(
"List of callbacks to be executed before crew kickoff. "
"It may be used to adjust inputs before the crew is executed."
),
)
after_kickoff_callbacks: list[Callable[[CrewOutput], CrewOutput]] = Field(
after_kickoff_callbacks: list[SerializableCallable] = Field(
default_factory=list,
description=(
"List of callbacks to be executed after crew kickoff. "
@@ -349,7 +351,7 @@ class Crew(FlowTrackable, BaseModel):
self._file_handler = FileHandler(self.output_log_file)
self._rpm_controller = RPMController(max_rpm=self.max_rpm, logger=self._logger)
if self.function_calling_llm and not isinstance(self.function_calling_llm, LLM):
self.function_calling_llm = create_llm(self.function_calling_llm)
self.function_calling_llm = create_llm(self.function_calling_llm) # type: ignore[assignment]
return self
@@ -363,7 +365,7 @@ class Crew(FlowTrackable, BaseModel):
if self.embedder is not None:
from crewai.rag.embeddings.factory import build_embedder
embedder = build_embedder(self.embedder)
embedder = build_embedder(self.embedder) # type: ignore[arg-type]
self._memory = Memory(embedder=embedder)
elif self.memory:
# User passed a Memory / MemoryScope / MemorySlice instance

View File

@@ -81,6 +81,7 @@ from crewai.flow.flow_wrappers import (
SimpleFlowCondition,
StartMethod,
)
from crewai.flow.input_provider import InputProvider
from crewai.flow.persistence.base import FlowPersistence
from crewai.flow.types import (
FlowExecutionData,
@@ -99,6 +100,8 @@ from crewai.flow.utils import (
is_flow_method_name,
is_simple_flow_condition,
)
from crewai.memory.memory_scope import MemoryScope, MemorySlice
from crewai.memory.unified_memory import Memory
if TYPE_CHECKING:
@@ -501,7 +504,7 @@ class LockedListProxy(list, Generic[T]): # type: ignore[type-arg]
def index(
self, value: T, start: SupportsIndex = 0, stop: SupportsIndex | None = None
) -> int: # type: ignore[override]
) -> int:
if stop is None:
return self._list.index(value, start)
return self._list.index(value, start, stop)
@@ -520,13 +523,13 @@ class LockedListProxy(list, Generic[T]): # type: ignore[type-arg]
def copy(self) -> list[T]:
return self._list.copy()
def __add__(self, other: list[T]) -> list[T]:
def __add__(self, other: list[T]) -> list[T]: # type: ignore[override]
return self._list + other
def __radd__(self, other: list[T]) -> list[T]:
return other + self._list
def __iadd__(self, other: Iterable[T]) -> LockedListProxy[T]:
def __iadd__(self, other: Iterable[T]) -> LockedListProxy[T]: # type: ignore[override]
with self._lock:
self._list += list(other)
return self
@@ -630,13 +633,13 @@ class LockedDictProxy(dict, Generic[T]): # type: ignore[type-arg]
def copy(self) -> dict[str, T]:
return self._dict.copy()
def __or__(self, other: dict[str, T]) -> dict[str, T]:
def __or__(self, other: dict[str, T]) -> dict[str, T]: # type: ignore[override]
return self._dict | other
def __ror__(self, other: dict[str, T]) -> dict[str, T]:
def __ror__(self, other: dict[str, T]) -> dict[str, T]: # type: ignore[override]
return other | self._dict
def __ior__(self, other: dict[str, T]) -> LockedDictProxy[T]:
def __ior__(self, other: dict[str, T]) -> LockedDictProxy[T]: # type: ignore[override]
with self._lock:
self._dict |= other
return self
@@ -822,10 +825,8 @@ class Flow(Generic[T], metaclass=FlowMeta):
name: str | None = None
tracing: bool | None = None
stream: bool = False
memory: Any = (
None # Memory | MemoryScope | MemorySlice | None; auto-created if not set
)
input_provider: Any = None # InputProvider | None; per-flow override for self.ask()
memory: Memory | MemoryScope | MemorySlice | None = None
input_provider: InputProvider | None = None
def __class_getitem__(cls: type[Flow[T]], item: type[T]) -> type[Flow[T]]:
class _FlowGeneric(cls): # type: ignore
@@ -904,8 +905,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
# Internal flows (RecallFlow, EncodingFlow) set _skip_auto_memory
# to avoid creating a wasteful standalone Memory instance.
if self.memory is None and not getattr(self, "_skip_auto_memory", False):
from crewai.memory.unified_memory import Memory
self.memory = Memory()
# Register all flow-related methods
@@ -951,10 +950,16 @@ class Flow(Generic[T], metaclass=FlowMeta):
Raises:
ValueError: If no memory is configured for this flow.
TypeError: If batch remember is attempted on a MemoryScope or MemorySlice.
"""
if self.memory is None:
raise ValueError("No memory configured for this flow")
if isinstance(content, list):
if not isinstance(self.memory, Memory):
raise TypeError(
"Batch remember requires a Memory instance, "
f"got {type(self.memory).__name__}"
)
return self.memory.remember_many(content, **kwargs)
return self.memory.remember(content, **kwargs)
@@ -2725,7 +2730,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
# ── User Input (self.ask) ────────────────────────────────────────
def _resolve_input_provider(self) -> Any:
def _resolve_input_provider(self) -> InputProvider:
"""Resolve the input provider using the priority chain.
Resolution order:

View File

@@ -6,7 +6,7 @@ customize Flow behavior at runtime.
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING
if TYPE_CHECKING:
@@ -32,17 +32,17 @@ class FlowConfig:
self._input_provider: InputProvider | None = None
@property
def hitl_provider(self) -> Any:
def hitl_provider(self) -> HumanFeedbackProvider | None:
"""Get the configured HITL provider."""
return self._hitl_provider
@hitl_provider.setter
def hitl_provider(self, provider: Any) -> None:
def hitl_provider(self, provider: HumanFeedbackProvider | None) -> None:
"""Set the HITL provider."""
self._hitl_provider = provider
@property
def input_provider(self) -> Any:
def input_provider(self) -> InputProvider | None:
"""Get the configured input provider for ``Flow.ask()``.
Returns:
@@ -52,7 +52,7 @@ class FlowConfig:
return self._input_provider
@input_provider.setter
def input_provider(self, provider: Any) -> None:
def input_provider(self, provider: InputProvider | None) -> None:
"""Set the input provider for ``Flow.ask()``.
Args:

View File

@@ -22,7 +22,6 @@ from crewai.events.types.memory_events import (
)
from crewai.llms.base_llm import BaseLLM
from crewai.memory.analyze import extract_memories_from_content
from crewai.memory.recall_flow import RecallFlow
from crewai.memory.storage.backend import StorageBackend
from crewai.memory.types import (
MemoryConfig,
@@ -620,6 +619,8 @@ class Memory(BaseModel):
)
results.sort(key=lambda m: m.score, reverse=True)
else:
from crewai.memory.recall_flow import RecallFlow
flow = RecallFlow(
storage=self._storage,
llm=self._llm,

View File

@@ -67,6 +67,7 @@ except ImportError:
return []
from crewai.types.callback import SerializableCallable
from crewai.utilities.guardrail import (
process_guardrail,
)
@@ -124,7 +125,7 @@ class Task(BaseModel):
description="Configuration for the agent",
default=None,
)
callback: Any | None = Field(
callback: SerializableCallable | None = Field(
description="Callback to be executed after the task is completed.", default=None
)
agent: BaseAgent | None = Field(

View File

@@ -0,0 +1,152 @@
"""Serializable callback type for Pydantic models.
Provides a ``SerializableCallable`` type alias that enables full JSON
round-tripping of callback fields, e.g. ``"builtins.print"`` ↔ ``print``.
Lambdas and closures serialize to a dotted path but cannot be deserialized
back — use module-level named functions for checkpointable callbacks.
"""
from __future__ import annotations
from collections.abc import Callable
import importlib
import inspect
import os
from typing import Annotated, Any
import warnings
from pydantic import BeforeValidator, WithJsonSchema
from pydantic.functional_serializers import PlainSerializer
def _is_non_roundtrippable(fn: object) -> bool:
"""Return ``True`` if *fn* cannot survive a serialize/deserialize round-trip.
Built-in functions, plain module-level functions, and classes produce
dotted paths that :func:`_resolve_dotted_path` can reliably resolve.
Bound methods, ``functools.partial`` objects, callable class instances,
lambdas, and closures all fail or silently change semantics during
round-tripping.
Args:
fn: The object to check.
Returns:
``True`` if *fn* would not round-trip through JSON serialization.
"""
if inspect.isbuiltin(fn) or inspect.isclass(fn):
return False
if inspect.isfunction(fn):
qualname = getattr(fn, "__qualname__", "")
return qualname.endswith("<lambda>") or "<locals>" in qualname
return True
def string_to_callable(value: Any) -> Callable[..., Any]:
"""Convert a dotted path string to the callable it references.
If *value* is already callable it is returned as-is, with a warning if
it cannot survive JSON round-tripping. Otherwise, it is treated as
``"module.qualname"`` and resolved via :func:`_resolve_dotted_path`.
Args:
value: A callable or a dotted-path string e.g. ``"builtins.print"``.
Returns:
The resolved callable.
Raises:
ValueError: If *value* is not callable or a resolvable dotted-path string.
"""
if callable(value):
if _is_non_roundtrippable(value):
warnings.warn(
f"{type(value).__name__} callbacks cannot be serialized "
"and will prevent checkpointing. "
"Use a module-level named function instead.",
UserWarning,
stacklevel=2,
)
return value # type: ignore[no-any-return]
if not isinstance(value, str):
raise ValueError(
f"Expected a callable or dotted-path string, got {type(value).__name__}"
)
if "." not in value:
raise ValueError(
f"Invalid callback path {value!r}: expected 'module.name' format"
)
if not os.environ.get("CREWAI_DESERIALIZE_CALLBACKS"):
raise ValueError(
f"Refusing to resolve callback path {value!r}: "
"set CREWAI_DESERIALIZE_CALLBACKS=1 to allow. "
"Only enable this for trusted checkpoint data."
)
return _resolve_dotted_path(value)
def _resolve_dotted_path(path: str) -> Callable[..., Any]:
"""Import a module and walk attribute lookups to resolve a dotted path.
Handles multi-level qualified names like ``"module.ClassName.method"``
by trying progressively shorter module paths and resolving the remainder
as chained attribute lookups.
Args:
path: A dotted string e.g. ``"builtins.print"`` or
``"mymodule.MyClass.my_method"``.
Returns:
The resolved callable.
Raises:
ValueError: If no valid module can be imported from the path.
"""
parts = path.split(".")
# Try importing progressively shorter prefixes as the module.
for i in range(len(parts), 0, -1):
module_path = ".".join(parts[:i])
try:
obj: Any = importlib.import_module(module_path)
except (ImportError, TypeError, ValueError):
continue
# Walk the remaining attribute chain.
try:
for attr in parts[i:]:
obj = getattr(obj, attr)
except AttributeError:
continue
if callable(obj):
return obj # type: ignore[no-any-return]
raise ValueError(f"Cannot resolve callback {path!r}")
def callable_to_string(fn: Callable[..., Any]) -> str:
"""Serialize a callable to its dotted-path string representation.
Uses ``fn.__module__`` and ``fn.__qualname__`` to produce a string such
as ``"builtins.print"``. Lambdas and closures produce paths that contain
``<locals>`` and cannot be round-tripped via :func:`string_to_callable`.
Args:
fn: The callable to serialize.
Returns:
A dotted string of the form ``"module.qualname"``.
"""
module = getattr(fn, "__module__", None)
qualname = getattr(fn, "__qualname__", None)
if module is None or qualname is None:
raise ValueError(
f"Cannot serialize {fn!r}: missing __module__ or __qualname__. "
"Use a module-level named function for checkpointable callbacks."
)
return f"{module}.{qualname}"
SerializableCallable = Annotated[
Callable[..., Any],
BeforeValidator(string_to_callable),
PlainSerializer(callable_to_string, return_type=str, when_used="json"),
WithJsonSchema({"type": "string"}),
]

View File

@@ -1690,7 +1690,10 @@ def test_agent_with_knowledge_sources_works_with_copy():
with patch(
"crewai.knowledge.storage.knowledge_storage.KnowledgeStorage"
) as mock_knowledge_storage:
from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
mock_knowledge_storage_instance = mock_knowledge_storage.return_value
mock_knowledge_storage_instance.__class__ = BaseKnowledgeStorage
agent.knowledge_storage = mock_knowledge_storage_instance
agent_copy = agent.copy()

View File

@@ -0,0 +1,237 @@
"""Tests for crewai.types.callback — SerializableCallable round-tripping."""
from __future__ import annotations
import functools
import os
from typing import Any
import pytest
from pydantic import BaseModel, ValidationError
from crewai.types.callback import (
SerializableCallable,
_is_non_roundtrippable,
_resolve_dotted_path,
callable_to_string,
string_to_callable,
)
# ── Helpers ──────────────────────────────────────────────────────────
def module_level_function() -> str:
"""Plain module-level function that should round-trip."""
return "hello"
class _CallableInstance:
"""Callable class instance — non-roundtrippable."""
def __call__(self) -> str:
return "instance"
class _HasMethod:
def method(self) -> str:
return "method"
class _Model(BaseModel):
cb: SerializableCallable | None = None
# ── _is_non_roundtrippable ───────────────────────────────────────────
class TestIsNonRoundtrippable:
def test_builtin_is_roundtrippable(self) -> None:
assert _is_non_roundtrippable(print) is False
assert _is_non_roundtrippable(len) is False
def test_class_is_roundtrippable(self) -> None:
assert _is_non_roundtrippable(dict) is False
assert _is_non_roundtrippable(_CallableInstance) is False
def test_module_level_function_is_roundtrippable(self) -> None:
assert _is_non_roundtrippable(module_level_function) is False
def test_lambda_is_non_roundtrippable(self) -> None:
assert _is_non_roundtrippable(lambda: None) is True
def test_closure_is_non_roundtrippable(self) -> None:
x = 1
def closure() -> int:
return x
assert _is_non_roundtrippable(closure) is True
def test_bound_method_is_non_roundtrippable(self) -> None:
assert _is_non_roundtrippable(_HasMethod().method) is True
def test_partial_is_non_roundtrippable(self) -> None:
assert _is_non_roundtrippable(functools.partial(print, "hi")) is True
def test_callable_instance_is_non_roundtrippable(self) -> None:
assert _is_non_roundtrippable(_CallableInstance()) is True
# ── callable_to_string ───────────────────────────────────────────────
class TestCallableToString:
def test_module_level_function(self) -> None:
result = callable_to_string(module_level_function)
assert result == f"{__name__}.module_level_function"
def test_class(self) -> None:
result = callable_to_string(dict)
assert result == "builtins.dict"
def test_builtin(self) -> None:
result = callable_to_string(print)
assert result == "builtins.print"
def test_lambda_produces_locals_path(self) -> None:
fn = lambda: None # noqa: E731
result = callable_to_string(fn)
assert "<lambda>" in result
def test_missing_qualname_raises(self) -> None:
obj = type("NoQual", (), {"__module__": "test"})()
obj.__qualname__ = None # type: ignore[assignment]
with pytest.raises(ValueError, match="missing __module__ or __qualname__"):
callable_to_string(obj)
def test_missing_module_raises(self) -> None:
# Create an object where getattr(obj, "__module__", None) returns None
ns: dict[str, Any] = {"__qualname__": "x", "__module__": None}
obj = type("NoMod", (), ns)()
with pytest.raises(ValueError, match="missing __module__"):
callable_to_string(obj)
# ── string_to_callable ───────────────────────────────────────────────
class TestStringToCallable:
def test_callable_passthrough(self) -> None:
assert string_to_callable(print) is print
def test_roundtrippable_callable_no_warning(self, recwarn: pytest.WarningsChecker) -> None:
string_to_callable(module_level_function)
our_warnings = [
w for w in recwarn if "cannot be serialized" in str(w.message)
]
assert our_warnings == []
def test_non_roundtrippable_warns(self) -> None:
with pytest.warns(UserWarning, match="cannot be serialized"):
string_to_callable(functools.partial(print))
def test_non_callable_non_string_raises(self) -> None:
with pytest.raises(ValueError, match="Expected a callable"):
string_to_callable(42)
def test_string_without_dot_raises(self) -> None:
with pytest.raises(ValueError, match="expected 'module.name' format"):
string_to_callable("nodots")
def test_string_refused_without_env_var(self, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.delenv("CREWAI_DESERIALIZE_CALLBACKS", raising=False)
with pytest.raises(ValueError, match="Refusing to resolve"):
string_to_callable("builtins.print")
def test_string_resolves_with_env_var(self, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("CREWAI_DESERIALIZE_CALLBACKS", "1")
result = string_to_callable("builtins.print")
assert result is print
def test_string_resolves_multi_level_path(self, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("CREWAI_DESERIALIZE_CALLBACKS", "1")
result = string_to_callable("os.path.join")
assert result is os.path.join
def test_unresolvable_path_raises(self, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("CREWAI_DESERIALIZE_CALLBACKS", "1")
with pytest.raises(ValueError, match="Cannot resolve"):
string_to_callable("nonexistent.module.func")
# ── _resolve_dotted_path ─────────────────────────────────────────────
class TestResolveDottedPath:
def test_builtin(self) -> None:
assert _resolve_dotted_path("builtins.print") is print
def test_nested_module_attribute(self) -> None:
assert _resolve_dotted_path("os.path.join") is os.path.join
def test_class_on_module(self) -> None:
from collections import OrderedDict
assert _resolve_dotted_path("collections.OrderedDict") is OrderedDict
def test_nonexistent_raises(self) -> None:
with pytest.raises(ValueError, match="Cannot resolve"):
_resolve_dotted_path("no.such.module.func")
def test_non_callable_attribute_skipped(self) -> None:
# os.sep is a string, not callable — should not resolve
with pytest.raises(ValueError, match="Cannot resolve"):
_resolve_dotted_path("os.sep")
# ── Pydantic integration round-trip ──────────────────────────────────
class TestSerializableCallableRoundTrip:
def test_json_serialize_module_function(self) -> None:
m = _Model(cb=module_level_function)
data = m.model_dump(mode="json")
assert data["cb"] == f"{__name__}.module_level_function"
def test_json_round_trip(self, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("CREWAI_DESERIALIZE_CALLBACKS", "1")
m = _Model(cb=print)
json_str = m.model_dump_json()
restored = _Model.model_validate_json(json_str)
assert restored.cb is print
def test_json_round_trip_class(self, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("CREWAI_DESERIALIZE_CALLBACKS", "1")
m = _Model(cb=dict)
json_str = m.model_dump_json()
restored = _Model.model_validate_json(json_str)
assert restored.cb is dict
def test_python_mode_preserves_callable(self) -> None:
m = _Model(cb=module_level_function)
data = m.model_dump(mode="python")
assert data["cb"] is module_level_function
def test_none_field(self) -> None:
m = _Model(cb=None)
assert m.cb is None
data = m.model_dump(mode="json")
assert data["cb"] is None
def test_validation_error_for_int(self) -> None:
with pytest.raises(ValidationError):
_Model(cb=42) # type: ignore[arg-type]
def test_deserialization_refused_without_env(
self, monkeypatch: pytest.MonkeyPatch
) -> None:
monkeypatch.delenv("CREWAI_DESERIALIZE_CALLBACKS", raising=False)
with pytest.raises(ValidationError, match="Refusing to resolve"):
_Model.model_validate({"cb": "builtins.print"})
def test_json_schema_is_string(self) -> None:
schema = _Model.model_json_schema()
cb_schema = schema["properties"]["cb"]
# anyOf for Optional: one string, one null
types = {item.get("type") for item in cb_schema.get("anyOf", [cb_schema])}
assert "string" in types

View File

@@ -6,6 +6,7 @@ from crewai.agent import Agent
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.crew import Crew
from crewai.llm import LLM
from crewai.llms.base_llm import BaseLLM
from crewai.project import (
CrewBase,
after_kickoff,
@@ -371,9 +372,12 @@ def test_internal_crew_with_mcp():
mock_adapter = Mock()
mock_adapter.tools = ToolCollection([simple_tool, another_simple_tool])
mock_llm = Mock()
mock_llm.__class__ = BaseLLM
with (
patch("crewai_tools.MCPServerAdapter", return_value=mock_adapter) as adapter_mock,
patch("crewai.llm.LLM.__new__", return_value=Mock()),
patch("crewai.llm.LLM.__new__", return_value=mock_llm),
):
crew = InternalCrewWithMCP()
assert crew.reporting_analyst().tools == [simple_tool, another_simple_tool]