fix: bot review follow-ups for serialization hardening

- Reject classes and builtin values in _instance_to_dotted_path
- Require classes in _dotted_path_to_instance
- Drop unused SerializableInstance alias
- Raise on unknown FlowPersistence types in _serialize_persistence
- Gate Knowledge.embedder provider_class restore behind
  CREWAI_DESERIALIZE_CALLBACKS
- Raise on unknown source_type tags in _resolve_knowledge_sources
- Tighten _backfill_source_type: only infer 'string' when content is
  str; raise otherwise so legacy file-based sources fail loudly
- Add BeforeValidator(_ensure_memory_kind) to Crew/Agent/Flow memory
  fields so legacy dict configs get the discriminator at construction
- Default MemoryScope/MemorySlice._memory to None; add _require_memory()
  helper and route all internal accesses through it
- Convert test_flow_ask persistence mocks to RecordingPersistence
This commit is contained in:
Greyson LaLonde
2026-05-21 00:47:40 +08:00
parent 0f3a57b3b9
commit 3ceb9a287a
8 changed files with 178 additions and 59 deletions

View File

@@ -37,7 +37,7 @@ from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
from crewai.llms.base_llm import BaseLLM
from crewai.mcp.config import MCPServerConfig
from crewai.memory.memory_scope import MemoryScope, MemorySlice
from crewai.memory.memory_scope import MemoryScope, MemorySlice, _ensure_memory_kind
from crewai.memory.unified_memory import Memory
from crewai.rag.embeddings.types import EmbedderConfig
from crewai.security.security_config import SecurityConfig
@@ -332,13 +332,14 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
default=None,
description="List of MCP server references. Supports 'https://server.com/path' for external servers and bare slugs like 'notion' for connected MCP integrations. Use '#tool_name' suffix for specific tools.",
)
memory: (
memory: Annotated[
bool
| Annotated[
Memory | MemoryScope | MemorySlice, Field(discriminator="memory_kind")
]
| None
) = Field(
| None,
BeforeValidator(_ensure_memory_kind),
] = Field(
default=None,
description=(
"Enable agent memory. Pass True for default Memory(), "

View File

@@ -97,7 +97,7 @@ from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.llm import LLM
from crewai.llms.base_llm import BaseLLM
from crewai.memory.memory_scope import MemoryScope, MemorySlice
from crewai.memory.memory_scope import MemoryScope, MemorySlice, _ensure_memory_kind
from crewai.memory.unified_memory import Memory
from crewai.process import Process
from crewai.rag.embeddings.types import EmbedderConfig
@@ -223,13 +223,14 @@ class Crew(FlowTrackable, BaseModel):
] = Field(default_factory=list)
process: Process = Field(default=Process.sequential)
verbose: bool = Field(default=False)
memory: (
memory: Annotated[
bool
| Annotated[
Memory | MemoryScope | MemorySlice, Field(discriminator="memory_kind")
]
| None
) = Field(
| None,
BeforeValidator(_ensure_memory_kind),
] = Field(
default=False,
description=(
"Enable crew memory. Pass True for default Memory(), "

View File

@@ -113,7 +113,7 @@ from crewai.flow.utils import (
is_flow_method_name,
is_simple_flow_condition,
)
from crewai.memory.memory_scope import MemoryScope, MemorySlice
from crewai.memory.memory_scope import MemoryScope, MemorySlice, _ensure_memory_kind
from crewai.memory.unified_memory import Memory
from crewai.state.checkpoint_config import (
CheckpointConfig,
@@ -164,7 +164,10 @@ def _serialize_persistence(value: Any) -> dict[str, Any] | None:
return None
if isinstance(value, FlowPersistence):
return value.model_dump(mode="json")
return None
raise TypeError(
f"Cannot serialize Flow.persistence of type {type(value).__name__}: "
"expected FlowPersistence or None."
)
def _validate_input_provider(value: Any) -> Any:
@@ -979,12 +982,13 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
name: str | None = Field(default=None)
tracing: bool | None = Field(default=None)
stream: bool = Field(default=False)
memory: (
memory: Annotated[
Annotated[
Memory | MemoryScope | MemorySlice, Field(discriminator="memory_kind")
]
| None
) = Field(default=None)
| None,
BeforeValidator(_ensure_memory_kind),
] = Field(default=None)
input_provider: Annotated[
InputProvider | None,
BeforeValidator(_validate_input_provider),

View File

@@ -41,17 +41,22 @@ def _resolve_knowledge_sources(value: Any) -> Any:
for idx, item in enumerate(value):
if isinstance(item, dict):
tag = item.get("source_type")
cls = _KNOWN_SOURCES.get(tag) if isinstance(tag, str) else None
if cls is None:
if not isinstance(tag, str):
resolved.append(item)
else:
try:
resolved.append(cls.model_validate(item))
except Exception as exc:
raise ValueError(
f"Failed to validate knowledge source at index {idx} "
f"with source_type={tag!r}: {exc}"
) from exc
continue
cls = _KNOWN_SOURCES.get(tag)
if cls is None:
raise ValueError(
f"Unknown source_type={tag!r} at index {idx}: "
f"expected one of {sorted(_KNOWN_SOURCES)}"
)
try:
resolved.append(cls.model_validate(item))
except Exception as exc:
raise ValueError(
f"Failed to validate knowledge source at index {idx} "
f"with source_type={tag!r}: {exc}"
) from exc
else:
resolved.append(item)
return resolved
@@ -78,6 +83,13 @@ def _serialize_embedder_spec(value: Any) -> dict[str, Any] | None:
def _validate_embedder_spec(value: Any) -> Any:
"""Resolve provider_class dotted-path dicts back to a class on restore."""
if isinstance(value, dict) and set(value.keys()) == {"provider_class"}:
if not os.environ.get("CREWAI_DESERIALIZE_CALLBACKS"):
raise ValueError(
f"Refusing to resolve embedder provider_class "
f"{value['provider_class']!r}: set "
"CREWAI_DESERIALIZE_CALLBACKS=1 to allow. Only enable this "
"for trusted checkpoint data."
)
from crewai.types.callback import _resolve_dotted_path
cls = _resolve_dotted_path(value["provider_class"])

View File

@@ -17,6 +17,24 @@ from crewai.memory.types import (
from crewai.memory.unified_memory import Memory
def _ensure_memory_kind(value: Any) -> Any:
"""Backfill ``memory_kind`` on legacy dicts that predate the discriminator.
Lets pre-1.14.6 configs/checkpoints flow into the discriminated
``Memory | MemoryScope | MemorySlice`` union without crashing. Inference:
``scopes`` key → ``slice``; ``root_path`` → ``scope``; else ``memory``.
Pass-through for non-dict values (instances, ``bool``, ``None``).
"""
if isinstance(value, dict) and "memory_kind" not in value:
if "scopes" in value:
value["memory_kind"] = "slice"
elif "root_path" in value:
value["memory_kind"] = "scope"
else:
value["memory_kind"] = "memory"
return value
class MemoryScope(BaseModel):
"""View of Memory restricted to a root path. All operations are scoped under that path."""
@@ -26,8 +44,8 @@ class MemoryScope(BaseModel):
root_path: str = Field(default="/")
_memory: Memory = PrivateAttr()
_root: str = PrivateAttr()
_memory: Memory | None = PrivateAttr(default=None)
_root: str = PrivateAttr(default="")
@model_validator(mode="wrap")
@classmethod
@@ -56,10 +74,19 @@ class MemoryScope(BaseModel):
self._memory = memory
return self
def _require_memory(self) -> Memory:
"""Return the bound ``Memory`` or raise a clear error if missing."""
if self._memory is None:
raise RuntimeError(
"MemoryScope is not bound to a Memory; call .bind(memory) "
"after restore."
)
return self._memory
@property
def read_only(self) -> bool:
"""Whether the underlying memory is read-only."""
return self._memory.read_only
return self._require_memory().read_only
def _scope_path(self, scope: str | None) -> str:
if not scope or scope == "/":
@@ -84,7 +111,7 @@ class MemoryScope(BaseModel):
) -> MemoryRecord | None:
"""Remember content; scope is relative to this scope's root."""
path = self._scope_path(scope)
return self._memory.remember(
return self._require_memory().remember(
content,
scope=path,
categories=categories,
@@ -107,7 +134,7 @@ class MemoryScope(BaseModel):
) -> list[MemoryRecord]:
"""Remember multiple items; scope is relative to this scope's root."""
path = self._scope_path(scope)
return self._memory.remember_many(
return self._require_memory().remember_many(
contents,
scope=path,
categories=categories,
@@ -130,7 +157,7 @@ class MemoryScope(BaseModel):
) -> list[MemoryMatch]:
"""Recall within this scope (root path and below)."""
search_scope = self._scope_path(scope) if scope else (self._root or "/")
return self._memory.recall(
return self._require_memory().recall(
query,
scope=search_scope,
categories=categories,
@@ -142,7 +169,7 @@ class MemoryScope(BaseModel):
def extract_memories(self, content: str) -> list[str]:
"""Extract discrete memories from content; delegates to underlying Memory."""
return self._memory.extract_memories(content)
return self._require_memory().extract_memories(content)
def forget(
self,
@@ -154,7 +181,7 @@ class MemoryScope(BaseModel):
) -> int:
"""Forget within this scope."""
prefix = self._scope_path(scope) if scope else (self._root or "/")
return self._memory.forget(
return self._require_memory().forget(
scope=prefix,
categories=categories,
older_than=older_than,
@@ -165,27 +192,27 @@ class MemoryScope(BaseModel):
def list_scopes(self, path: str = "/") -> list[str]:
"""List child scopes under path (relative to this scope's root)."""
full = self._scope_path(path)
return self._memory.list_scopes(full)
return self._require_memory().list_scopes(full)
def info(self, path: str = "/") -> ScopeInfo:
"""Info for path under this scope."""
full = self._scope_path(path)
return self._memory.info(full)
return self._require_memory().info(full)
def tree(self, path: str = "/", max_depth: int = 3) -> str:
"""Tree under path within this scope."""
full = self._scope_path(path)
return self._memory.tree(full, max_depth=max_depth)
return self._require_memory().tree(full, max_depth=max_depth)
def list_categories(self, path: str | None = None) -> dict[str, int]:
"""Categories in this scope; path None means this scope root."""
full = self._scope_path(path) if path else (self._root or "/")
return self._memory.list_categories(full)
return self._require_memory().list_categories(full)
def reset(self, scope: str | None = None) -> None:
"""Reset within this scope."""
prefix = self._scope_path(scope) if scope else (self._root or "/")
self._memory.reset(scope=prefix)
self._require_memory().reset(scope=prefix)
def subscope(self, path: str) -> MemoryScope:
"""Return a narrower scope under this scope."""
@@ -208,7 +235,7 @@ class MemorySlice(BaseModel):
categories: list[str] | None = Field(default=None)
read_only: bool = Field(default=True)
_memory: Memory = PrivateAttr()
_memory: Memory | None = PrivateAttr(default=None)
@model_validator(mode="wrap")
@classmethod
@@ -230,6 +257,15 @@ class MemorySlice(BaseModel):
self._memory = memory
return self
def _require_memory(self) -> Memory:
"""Return the bound ``Memory`` or raise a clear error if missing."""
if self._memory is None:
raise RuntimeError(
"MemorySlice is not bound to a Memory; call .bind(memory) "
"after restore."
)
return self._memory
def remember(
self,
content: str,
@@ -243,7 +279,7 @@ class MemorySlice(BaseModel):
"""Remember into an explicit scope. No-op when read_only=True."""
if self.read_only:
return None
return self._memory.remember(
return self._require_memory().remember(
content,
scope=scope,
categories=categories,
@@ -267,7 +303,7 @@ class MemorySlice(BaseModel):
cats = categories or self.categories
all_matches: list[MemoryMatch] = []
for sc in self.scopes:
matches = self._memory.recall(
matches = self._require_memory().recall(
query,
scope=sc,
categories=cats,
@@ -289,14 +325,14 @@ class MemorySlice(BaseModel):
def extract_memories(self, content: str) -> list[str]:
"""Extract discrete memories from content; delegates to underlying Memory."""
return self._memory.extract_memories(content)
return self._require_memory().extract_memories(content)
def list_scopes(self, path: str = "/") -> list[str]:
"""List scopes across all slice roots."""
out: list[str] = []
for sc in self.scopes:
full = f"{sc.rstrip('/')}{path}" if sc != "/" else path
out.extend(self._memory.list_scopes(full))
out.extend(self._require_memory().list_scopes(full))
return sorted(set(out))
def info(self, path: str = "/") -> ScopeInfo:
@@ -308,7 +344,7 @@ class MemorySlice(BaseModel):
children: list[str] = []
for sc in self.scopes:
full = f"{sc.rstrip('/')}{path}" if sc != "/" else path
inf = self._memory.info(full)
inf = self._require_memory().info(full)
total_records += inf.record_count
all_categories.update(inf.categories)
if inf.oldest_record:
@@ -338,6 +374,6 @@ class MemorySlice(BaseModel):
counts: dict[str, int] = {}
for sc in self.scopes:
full = (f"{sc.rstrip('/')}{path}" if sc != "/" else path) if path else sc
for k, v in self._memory.list_categories(full).items():
for k, v in self._require_memory().list_categories(full).items():
counts[k] = counts.get(k, 0) + v
return counts

View File

@@ -133,11 +133,22 @@ def _backfill_memory_kind(value: Any) -> None:
def _backfill_source_type(source: Any) -> None:
"""Infer ``source_type`` for legacy knowledge source dicts when possible."""
"""Infer ``source_type`` for legacy knowledge source dicts when possible.
Only StringKnowledgeSource is reliably inferrable: it stores ``content``
as a plain string. File-based sources (CSV/PDF/Excel/JSON/docling) also
have a ``content`` field but populate it with dicts/lists, so we leave
those untagged and let downstream validation surface a clear error.
"""
if not isinstance(source, dict) or "source_type" in source:
return
if "content" in source:
if isinstance(source.get("content"), str):
source["source_type"] = "string"
return
raise ValueError(
"Legacy knowledge source is missing 'source_type' and could not be "
"inferred during migration. Re-checkpoint after upgrading to 1.14.6+."
)
def _backfill_discriminators(entity: Any) -> None:

View File

@@ -154,7 +154,16 @@ SerializableCallable = Annotated[
def _instance_to_dotted_path(value: Any) -> str:
"""Serialize an instance to a dotted path naming its class."""
if inspect.isclass(value):
raise ValueError(
f"Expected an instance, got class {value.__module__}.{value.__qualname__}."
)
cls = type(value)
if cls.__module__ == "builtins":
raise ValueError(
f"Cannot serialize {value!r}: builtin values are not "
"checkpointable instances."
)
module = getattr(cls, "__module__", None)
qualname = getattr(cls, "__qualname__", None)
if module is None or qualname is None:
@@ -194,11 +203,3 @@ def _dotted_path_to_instance(value: Any) -> Any:
f"{type(cls).__name__}"
)
return cls()
SerializableInstance = Annotated[
Any,
BeforeValidator(_dotted_path_to_instance),
PlainSerializer(_instance_to_dotted_path, return_type=str, when_used="json"),
WithJsonSchema({"type": "string"}),
]

View File

@@ -12,15 +12,71 @@ from datetime import datetime
from typing import Any
from unittest.mock import MagicMock, patch
from pydantic import BaseModel
from crewai.flow import Flow, flow_config, listen, start
from crewai.flow.async_feedback.providers import ConsoleProvider
from crewai.flow.flow import FlowState
from crewai.flow.input_provider import InputProvider, InputResponse
from crewai.flow.persistence.base import FlowPersistence
# ── Test helpers ─────────────────────────────────────────────────
class _SaveCall:
"""Lightweight stand-in for ``MagicMock.call_args`` entries."""
__slots__ = ("args", "kwargs")
def __init__(self, args: tuple[Any, ...], kwargs: dict[str, Any]) -> None:
self.args = args
self.kwargs = kwargs
class _SaveStateRecorder:
"""Callable that records each ``save_state`` invocation."""
def __init__(self, owner: RecordingPersistence) -> None:
self._owner = owner
self.call_args_list: list[_SaveCall] = []
def __call__(
self,
flow_uuid: str,
method_name: str,
state_data: dict[str, Any] | BaseModel,
) -> None:
self.call_args_list.append(
_SaveCall((flow_uuid, method_name, state_data), {})
)
self._owner._states[flow_uuid] = state_data
class RecordingPersistence(FlowPersistence):
"""In-memory FlowPersistence that records ``save_state`` invocations."""
persistence_type: str = "RecordingPersistence"
def model_post_init(self, _: Any) -> None:
object.__setattr__(self, "_states", {})
object.__setattr__(self, "save_state", _SaveStateRecorder(self))
def init_db(self) -> None:
return None
def save_state( # type: ignore[no-redef]
self,
flow_uuid: str,
method_name: str,
state_data: dict[str, Any] | BaseModel,
) -> None:
return None
def load_state(self, flow_uuid: str) -> dict[str, Any] | None:
return None
class MockInputProvider:
"""Mock input provider that returns pre-configured responses."""
@@ -436,8 +492,7 @@ class TestAskCheckpoint:
def test_ask_checkpoints_state_before_waiting(self) -> None:
"""State is saved to persistence before waiting for input."""
mock_persistence = MagicMock()
mock_persistence.load_state.return_value = None
mock_persistence = RecordingPersistence()
class TestFlow(Flow):
input_provider = MockInputProvider(["answer"])
@@ -480,8 +535,7 @@ class TestAskCheckpoint:
server crashes while waiting for input, previously gathered data
is safe.
"""
mock_persistence = MagicMock()
mock_persistence.load_state.return_value = None
mock_persistence = RecordingPersistence()
class GatherFlow(Flow):
input_provider = MockInputProvider(["AI", "detailed"])
@@ -678,8 +732,7 @@ class TestAskIntegration:
def test_ask_with_state_persistence_recovery(self) -> None:
"""Ask checkpoints state so previously gathered values survive."""
mock_persistence = MagicMock()
mock_persistence.load_state.return_value = None
mock_persistence = RecordingPersistence()
class RecoverableFlow(Flow):
input_provider = MockInputProvider(["AI", "detailed"])