mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-07-04 22:49:23 +00:00
fix: bot review follow-ups for serialization hardening
- Reject classes and builtin values in _instance_to_dotted_path - Require classes in _dotted_path_to_instance - Drop unused SerializableInstance alias - Raise on unknown FlowPersistence types in _serialize_persistence - Gate Knowledge.embedder provider_class restore behind CREWAI_DESERIALIZE_CALLBACKS - Raise on unknown source_type tags in _resolve_knowledge_sources - Tighten _backfill_source_type: only infer 'string' when content is str; raise otherwise so legacy file-based sources fail loudly - Add BeforeValidator(_ensure_memory_kind) to Crew/Agent/Flow memory fields so legacy dict configs get the discriminator at construction - Default MemoryScope/MemorySlice._memory to None; add _require_memory() helper and route all internal accesses through it - Convert test_flow_ask persistence mocks to RecordingPersistence
This commit is contained in:
@@ -37,7 +37,7 @@ from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||
from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.mcp.config import MCPServerConfig
|
||||
from crewai.memory.memory_scope import MemoryScope, MemorySlice
|
||||
from crewai.memory.memory_scope import MemoryScope, MemorySlice, _ensure_memory_kind
|
||||
from crewai.memory.unified_memory import Memory
|
||||
from crewai.rag.embeddings.types import EmbedderConfig
|
||||
from crewai.security.security_config import SecurityConfig
|
||||
@@ -332,13 +332,14 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
|
||||
default=None,
|
||||
description="List of MCP server references. Supports 'https://server.com/path' for external servers and bare slugs like 'notion' for connected MCP integrations. Use '#tool_name' suffix for specific tools.",
|
||||
)
|
||||
memory: (
|
||||
memory: Annotated[
|
||||
bool
|
||||
| Annotated[
|
||||
Memory | MemoryScope | MemorySlice, Field(discriminator="memory_kind")
|
||||
]
|
||||
| None
|
||||
) = Field(
|
||||
| None,
|
||||
BeforeValidator(_ensure_memory_kind),
|
||||
] = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Enable agent memory. Pass True for default Memory(), "
|
||||
|
||||
@@ -97,7 +97,7 @@ from crewai.knowledge.knowledge import Knowledge
|
||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||
from crewai.llm import LLM
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.memory.memory_scope import MemoryScope, MemorySlice
|
||||
from crewai.memory.memory_scope import MemoryScope, MemorySlice, _ensure_memory_kind
|
||||
from crewai.memory.unified_memory import Memory
|
||||
from crewai.process import Process
|
||||
from crewai.rag.embeddings.types import EmbedderConfig
|
||||
@@ -223,13 +223,14 @@ class Crew(FlowTrackable, BaseModel):
|
||||
] = Field(default_factory=list)
|
||||
process: Process = Field(default=Process.sequential)
|
||||
verbose: bool = Field(default=False)
|
||||
memory: (
|
||||
memory: Annotated[
|
||||
bool
|
||||
| Annotated[
|
||||
Memory | MemoryScope | MemorySlice, Field(discriminator="memory_kind")
|
||||
]
|
||||
| None
|
||||
) = Field(
|
||||
| None,
|
||||
BeforeValidator(_ensure_memory_kind),
|
||||
] = Field(
|
||||
default=False,
|
||||
description=(
|
||||
"Enable crew memory. Pass True for default Memory(), "
|
||||
|
||||
@@ -113,7 +113,7 @@ from crewai.flow.utils import (
|
||||
is_flow_method_name,
|
||||
is_simple_flow_condition,
|
||||
)
|
||||
from crewai.memory.memory_scope import MemoryScope, MemorySlice
|
||||
from crewai.memory.memory_scope import MemoryScope, MemorySlice, _ensure_memory_kind
|
||||
from crewai.memory.unified_memory import Memory
|
||||
from crewai.state.checkpoint_config import (
|
||||
CheckpointConfig,
|
||||
@@ -164,7 +164,10 @@ def _serialize_persistence(value: Any) -> dict[str, Any] | None:
|
||||
return None
|
||||
if isinstance(value, FlowPersistence):
|
||||
return value.model_dump(mode="json")
|
||||
return None
|
||||
raise TypeError(
|
||||
f"Cannot serialize Flow.persistence of type {type(value).__name__}: "
|
||||
"expected FlowPersistence or None."
|
||||
)
|
||||
|
||||
|
||||
def _validate_input_provider(value: Any) -> Any:
|
||||
@@ -979,12 +982,13 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
name: str | None = Field(default=None)
|
||||
tracing: bool | None = Field(default=None)
|
||||
stream: bool = Field(default=False)
|
||||
memory: (
|
||||
memory: Annotated[
|
||||
Annotated[
|
||||
Memory | MemoryScope | MemorySlice, Field(discriminator="memory_kind")
|
||||
]
|
||||
| None
|
||||
) = Field(default=None)
|
||||
| None,
|
||||
BeforeValidator(_ensure_memory_kind),
|
||||
] = Field(default=None)
|
||||
input_provider: Annotated[
|
||||
InputProvider | None,
|
||||
BeforeValidator(_validate_input_provider),
|
||||
|
||||
@@ -41,17 +41,22 @@ def _resolve_knowledge_sources(value: Any) -> Any:
|
||||
for idx, item in enumerate(value):
|
||||
if isinstance(item, dict):
|
||||
tag = item.get("source_type")
|
||||
cls = _KNOWN_SOURCES.get(tag) if isinstance(tag, str) else None
|
||||
if cls is None:
|
||||
if not isinstance(tag, str):
|
||||
resolved.append(item)
|
||||
else:
|
||||
try:
|
||||
resolved.append(cls.model_validate(item))
|
||||
except Exception as exc:
|
||||
raise ValueError(
|
||||
f"Failed to validate knowledge source at index {idx} "
|
||||
f"with source_type={tag!r}: {exc}"
|
||||
) from exc
|
||||
continue
|
||||
cls = _KNOWN_SOURCES.get(tag)
|
||||
if cls is None:
|
||||
raise ValueError(
|
||||
f"Unknown source_type={tag!r} at index {idx}: "
|
||||
f"expected one of {sorted(_KNOWN_SOURCES)}"
|
||||
)
|
||||
try:
|
||||
resolved.append(cls.model_validate(item))
|
||||
except Exception as exc:
|
||||
raise ValueError(
|
||||
f"Failed to validate knowledge source at index {idx} "
|
||||
f"with source_type={tag!r}: {exc}"
|
||||
) from exc
|
||||
else:
|
||||
resolved.append(item)
|
||||
return resolved
|
||||
@@ -78,6 +83,13 @@ def _serialize_embedder_spec(value: Any) -> dict[str, Any] | None:
|
||||
def _validate_embedder_spec(value: Any) -> Any:
|
||||
"""Resolve provider_class dotted-path dicts back to a class on restore."""
|
||||
if isinstance(value, dict) and set(value.keys()) == {"provider_class"}:
|
||||
if not os.environ.get("CREWAI_DESERIALIZE_CALLBACKS"):
|
||||
raise ValueError(
|
||||
f"Refusing to resolve embedder provider_class "
|
||||
f"{value['provider_class']!r}: set "
|
||||
"CREWAI_DESERIALIZE_CALLBACKS=1 to allow. Only enable this "
|
||||
"for trusted checkpoint data."
|
||||
)
|
||||
from crewai.types.callback import _resolve_dotted_path
|
||||
|
||||
cls = _resolve_dotted_path(value["provider_class"])
|
||||
|
||||
@@ -17,6 +17,24 @@ from crewai.memory.types import (
|
||||
from crewai.memory.unified_memory import Memory
|
||||
|
||||
|
||||
def _ensure_memory_kind(value: Any) -> Any:
|
||||
"""Backfill ``memory_kind`` on legacy dicts that predate the discriminator.
|
||||
|
||||
Lets pre-1.14.6 configs/checkpoints flow into the discriminated
|
||||
``Memory | MemoryScope | MemorySlice`` union without crashing. Inference:
|
||||
``scopes`` key → ``slice``; ``root_path`` → ``scope``; else ``memory``.
|
||||
Pass-through for non-dict values (instances, ``bool``, ``None``).
|
||||
"""
|
||||
if isinstance(value, dict) and "memory_kind" not in value:
|
||||
if "scopes" in value:
|
||||
value["memory_kind"] = "slice"
|
||||
elif "root_path" in value:
|
||||
value["memory_kind"] = "scope"
|
||||
else:
|
||||
value["memory_kind"] = "memory"
|
||||
return value
|
||||
|
||||
|
||||
class MemoryScope(BaseModel):
|
||||
"""View of Memory restricted to a root path. All operations are scoped under that path."""
|
||||
|
||||
@@ -26,8 +44,8 @@ class MemoryScope(BaseModel):
|
||||
|
||||
root_path: str = Field(default="/")
|
||||
|
||||
_memory: Memory = PrivateAttr()
|
||||
_root: str = PrivateAttr()
|
||||
_memory: Memory | None = PrivateAttr(default=None)
|
||||
_root: str = PrivateAttr(default="")
|
||||
|
||||
@model_validator(mode="wrap")
|
||||
@classmethod
|
||||
@@ -56,10 +74,19 @@ class MemoryScope(BaseModel):
|
||||
self._memory = memory
|
||||
return self
|
||||
|
||||
def _require_memory(self) -> Memory:
|
||||
"""Return the bound ``Memory`` or raise a clear error if missing."""
|
||||
if self._memory is None:
|
||||
raise RuntimeError(
|
||||
"MemoryScope is not bound to a Memory; call .bind(memory) "
|
||||
"after restore."
|
||||
)
|
||||
return self._memory
|
||||
|
||||
@property
|
||||
def read_only(self) -> bool:
|
||||
"""Whether the underlying memory is read-only."""
|
||||
return self._memory.read_only
|
||||
return self._require_memory().read_only
|
||||
|
||||
def _scope_path(self, scope: str | None) -> str:
|
||||
if not scope or scope == "/":
|
||||
@@ -84,7 +111,7 @@ class MemoryScope(BaseModel):
|
||||
) -> MemoryRecord | None:
|
||||
"""Remember content; scope is relative to this scope's root."""
|
||||
path = self._scope_path(scope)
|
||||
return self._memory.remember(
|
||||
return self._require_memory().remember(
|
||||
content,
|
||||
scope=path,
|
||||
categories=categories,
|
||||
@@ -107,7 +134,7 @@ class MemoryScope(BaseModel):
|
||||
) -> list[MemoryRecord]:
|
||||
"""Remember multiple items; scope is relative to this scope's root."""
|
||||
path = self._scope_path(scope)
|
||||
return self._memory.remember_many(
|
||||
return self._require_memory().remember_many(
|
||||
contents,
|
||||
scope=path,
|
||||
categories=categories,
|
||||
@@ -130,7 +157,7 @@ class MemoryScope(BaseModel):
|
||||
) -> list[MemoryMatch]:
|
||||
"""Recall within this scope (root path and below)."""
|
||||
search_scope = self._scope_path(scope) if scope else (self._root or "/")
|
||||
return self._memory.recall(
|
||||
return self._require_memory().recall(
|
||||
query,
|
||||
scope=search_scope,
|
||||
categories=categories,
|
||||
@@ -142,7 +169,7 @@ class MemoryScope(BaseModel):
|
||||
|
||||
def extract_memories(self, content: str) -> list[str]:
|
||||
"""Extract discrete memories from content; delegates to underlying Memory."""
|
||||
return self._memory.extract_memories(content)
|
||||
return self._require_memory().extract_memories(content)
|
||||
|
||||
def forget(
|
||||
self,
|
||||
@@ -154,7 +181,7 @@ class MemoryScope(BaseModel):
|
||||
) -> int:
|
||||
"""Forget within this scope."""
|
||||
prefix = self._scope_path(scope) if scope else (self._root or "/")
|
||||
return self._memory.forget(
|
||||
return self._require_memory().forget(
|
||||
scope=prefix,
|
||||
categories=categories,
|
||||
older_than=older_than,
|
||||
@@ -165,27 +192,27 @@ class MemoryScope(BaseModel):
|
||||
def list_scopes(self, path: str = "/") -> list[str]:
|
||||
"""List child scopes under path (relative to this scope's root)."""
|
||||
full = self._scope_path(path)
|
||||
return self._memory.list_scopes(full)
|
||||
return self._require_memory().list_scopes(full)
|
||||
|
||||
def info(self, path: str = "/") -> ScopeInfo:
|
||||
"""Info for path under this scope."""
|
||||
full = self._scope_path(path)
|
||||
return self._memory.info(full)
|
||||
return self._require_memory().info(full)
|
||||
|
||||
def tree(self, path: str = "/", max_depth: int = 3) -> str:
|
||||
"""Tree under path within this scope."""
|
||||
full = self._scope_path(path)
|
||||
return self._memory.tree(full, max_depth=max_depth)
|
||||
return self._require_memory().tree(full, max_depth=max_depth)
|
||||
|
||||
def list_categories(self, path: str | None = None) -> dict[str, int]:
|
||||
"""Categories in this scope; path None means this scope root."""
|
||||
full = self._scope_path(path) if path else (self._root or "/")
|
||||
return self._memory.list_categories(full)
|
||||
return self._require_memory().list_categories(full)
|
||||
|
||||
def reset(self, scope: str | None = None) -> None:
|
||||
"""Reset within this scope."""
|
||||
prefix = self._scope_path(scope) if scope else (self._root or "/")
|
||||
self._memory.reset(scope=prefix)
|
||||
self._require_memory().reset(scope=prefix)
|
||||
|
||||
def subscope(self, path: str) -> MemoryScope:
|
||||
"""Return a narrower scope under this scope."""
|
||||
@@ -208,7 +235,7 @@ class MemorySlice(BaseModel):
|
||||
categories: list[str] | None = Field(default=None)
|
||||
read_only: bool = Field(default=True)
|
||||
|
||||
_memory: Memory = PrivateAttr()
|
||||
_memory: Memory | None = PrivateAttr(default=None)
|
||||
|
||||
@model_validator(mode="wrap")
|
||||
@classmethod
|
||||
@@ -230,6 +257,15 @@ class MemorySlice(BaseModel):
|
||||
self._memory = memory
|
||||
return self
|
||||
|
||||
def _require_memory(self) -> Memory:
|
||||
"""Return the bound ``Memory`` or raise a clear error if missing."""
|
||||
if self._memory is None:
|
||||
raise RuntimeError(
|
||||
"MemorySlice is not bound to a Memory; call .bind(memory) "
|
||||
"after restore."
|
||||
)
|
||||
return self._memory
|
||||
|
||||
def remember(
|
||||
self,
|
||||
content: str,
|
||||
@@ -243,7 +279,7 @@ class MemorySlice(BaseModel):
|
||||
"""Remember into an explicit scope. No-op when read_only=True."""
|
||||
if self.read_only:
|
||||
return None
|
||||
return self._memory.remember(
|
||||
return self._require_memory().remember(
|
||||
content,
|
||||
scope=scope,
|
||||
categories=categories,
|
||||
@@ -267,7 +303,7 @@ class MemorySlice(BaseModel):
|
||||
cats = categories or self.categories
|
||||
all_matches: list[MemoryMatch] = []
|
||||
for sc in self.scopes:
|
||||
matches = self._memory.recall(
|
||||
matches = self._require_memory().recall(
|
||||
query,
|
||||
scope=sc,
|
||||
categories=cats,
|
||||
@@ -289,14 +325,14 @@ class MemorySlice(BaseModel):
|
||||
|
||||
def extract_memories(self, content: str) -> list[str]:
|
||||
"""Extract discrete memories from content; delegates to underlying Memory."""
|
||||
return self._memory.extract_memories(content)
|
||||
return self._require_memory().extract_memories(content)
|
||||
|
||||
def list_scopes(self, path: str = "/") -> list[str]:
|
||||
"""List scopes across all slice roots."""
|
||||
out: list[str] = []
|
||||
for sc in self.scopes:
|
||||
full = f"{sc.rstrip('/')}{path}" if sc != "/" else path
|
||||
out.extend(self._memory.list_scopes(full))
|
||||
out.extend(self._require_memory().list_scopes(full))
|
||||
return sorted(set(out))
|
||||
|
||||
def info(self, path: str = "/") -> ScopeInfo:
|
||||
@@ -308,7 +344,7 @@ class MemorySlice(BaseModel):
|
||||
children: list[str] = []
|
||||
for sc in self.scopes:
|
||||
full = f"{sc.rstrip('/')}{path}" if sc != "/" else path
|
||||
inf = self._memory.info(full)
|
||||
inf = self._require_memory().info(full)
|
||||
total_records += inf.record_count
|
||||
all_categories.update(inf.categories)
|
||||
if inf.oldest_record:
|
||||
@@ -338,6 +374,6 @@ class MemorySlice(BaseModel):
|
||||
counts: dict[str, int] = {}
|
||||
for sc in self.scopes:
|
||||
full = (f"{sc.rstrip('/')}{path}" if sc != "/" else path) if path else sc
|
||||
for k, v in self._memory.list_categories(full).items():
|
||||
for k, v in self._require_memory().list_categories(full).items():
|
||||
counts[k] = counts.get(k, 0) + v
|
||||
return counts
|
||||
|
||||
@@ -133,11 +133,22 @@ def _backfill_memory_kind(value: Any) -> None:
|
||||
|
||||
|
||||
def _backfill_source_type(source: Any) -> None:
|
||||
"""Infer ``source_type`` for legacy knowledge source dicts when possible."""
|
||||
"""Infer ``source_type`` for legacy knowledge source dicts when possible.
|
||||
|
||||
Only StringKnowledgeSource is reliably inferrable: it stores ``content``
|
||||
as a plain string. File-based sources (CSV/PDF/Excel/JSON/docling) also
|
||||
have a ``content`` field but populate it with dicts/lists, so we leave
|
||||
those untagged and let downstream validation surface a clear error.
|
||||
"""
|
||||
if not isinstance(source, dict) or "source_type" in source:
|
||||
return
|
||||
if "content" in source:
|
||||
if isinstance(source.get("content"), str):
|
||||
source["source_type"] = "string"
|
||||
return
|
||||
raise ValueError(
|
||||
"Legacy knowledge source is missing 'source_type' and could not be "
|
||||
"inferred during migration. Re-checkpoint after upgrading to 1.14.6+."
|
||||
)
|
||||
|
||||
|
||||
def _backfill_discriminators(entity: Any) -> None:
|
||||
|
||||
@@ -154,7 +154,16 @@ SerializableCallable = Annotated[
|
||||
|
||||
def _instance_to_dotted_path(value: Any) -> str:
|
||||
"""Serialize an instance to a dotted path naming its class."""
|
||||
if inspect.isclass(value):
|
||||
raise ValueError(
|
||||
f"Expected an instance, got class {value.__module__}.{value.__qualname__}."
|
||||
)
|
||||
cls = type(value)
|
||||
if cls.__module__ == "builtins":
|
||||
raise ValueError(
|
||||
f"Cannot serialize {value!r}: builtin values are not "
|
||||
"checkpointable instances."
|
||||
)
|
||||
module = getattr(cls, "__module__", None)
|
||||
qualname = getattr(cls, "__qualname__", None)
|
||||
if module is None or qualname is None:
|
||||
@@ -194,11 +203,3 @@ def _dotted_path_to_instance(value: Any) -> Any:
|
||||
f"{type(cls).__name__}"
|
||||
)
|
||||
return cls()
|
||||
|
||||
|
||||
SerializableInstance = Annotated[
|
||||
Any,
|
||||
BeforeValidator(_dotted_path_to_instance),
|
||||
PlainSerializer(_instance_to_dotted_path, return_type=str, when_used="json"),
|
||||
WithJsonSchema({"type": "string"}),
|
||||
]
|
||||
|
||||
@@ -12,15 +12,71 @@ from datetime import datetime
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai.flow import Flow, flow_config, listen, start
|
||||
from crewai.flow.async_feedback.providers import ConsoleProvider
|
||||
from crewai.flow.flow import FlowState
|
||||
from crewai.flow.input_provider import InputProvider, InputResponse
|
||||
from crewai.flow.persistence.base import FlowPersistence
|
||||
|
||||
|
||||
# ── Test helpers ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
class _SaveCall:
|
||||
"""Lightweight stand-in for ``MagicMock.call_args`` entries."""
|
||||
|
||||
__slots__ = ("args", "kwargs")
|
||||
|
||||
def __init__(self, args: tuple[Any, ...], kwargs: dict[str, Any]) -> None:
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
|
||||
|
||||
class _SaveStateRecorder:
|
||||
"""Callable that records each ``save_state`` invocation."""
|
||||
|
||||
def __init__(self, owner: RecordingPersistence) -> None:
|
||||
self._owner = owner
|
||||
self.call_args_list: list[_SaveCall] = []
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
flow_uuid: str,
|
||||
method_name: str,
|
||||
state_data: dict[str, Any] | BaseModel,
|
||||
) -> None:
|
||||
self.call_args_list.append(
|
||||
_SaveCall((flow_uuid, method_name, state_data), {})
|
||||
)
|
||||
self._owner._states[flow_uuid] = state_data
|
||||
|
||||
|
||||
class RecordingPersistence(FlowPersistence):
|
||||
"""In-memory FlowPersistence that records ``save_state`` invocations."""
|
||||
|
||||
persistence_type: str = "RecordingPersistence"
|
||||
|
||||
def model_post_init(self, _: Any) -> None:
|
||||
object.__setattr__(self, "_states", {})
|
||||
object.__setattr__(self, "save_state", _SaveStateRecorder(self))
|
||||
|
||||
def init_db(self) -> None:
|
||||
return None
|
||||
|
||||
def save_state( # type: ignore[no-redef]
|
||||
self,
|
||||
flow_uuid: str,
|
||||
method_name: str,
|
||||
state_data: dict[str, Any] | BaseModel,
|
||||
) -> None:
|
||||
return None
|
||||
|
||||
def load_state(self, flow_uuid: str) -> dict[str, Any] | None:
|
||||
return None
|
||||
|
||||
|
||||
class MockInputProvider:
|
||||
"""Mock input provider that returns pre-configured responses."""
|
||||
|
||||
@@ -436,8 +492,7 @@ class TestAskCheckpoint:
|
||||
|
||||
def test_ask_checkpoints_state_before_waiting(self) -> None:
|
||||
"""State is saved to persistence before waiting for input."""
|
||||
mock_persistence = MagicMock()
|
||||
mock_persistence.load_state.return_value = None
|
||||
mock_persistence = RecordingPersistence()
|
||||
|
||||
class TestFlow(Flow):
|
||||
input_provider = MockInputProvider(["answer"])
|
||||
@@ -480,8 +535,7 @@ class TestAskCheckpoint:
|
||||
server crashes while waiting for input, previously gathered data
|
||||
is safe.
|
||||
"""
|
||||
mock_persistence = MagicMock()
|
||||
mock_persistence.load_state.return_value = None
|
||||
mock_persistence = RecordingPersistence()
|
||||
|
||||
class GatherFlow(Flow):
|
||||
input_provider = MockInputProvider(["AI", "detailed"])
|
||||
@@ -678,8 +732,7 @@ class TestAskIntegration:
|
||||
|
||||
def test_ask_with_state_persistence_recovery(self) -> None:
|
||||
"""Ask checkpoints state so previously gathered values survive."""
|
||||
mock_persistence = MagicMock()
|
||||
mock_persistence.load_state.return_value = None
|
||||
mock_persistence = RecordingPersistence()
|
||||
|
||||
class RecoverableFlow(Flow):
|
||||
input_provider = MockInputProvider(["AI", "detailed"])
|
||||
|
||||
Reference in New Issue
Block a user