diff --git a/lib/crewai/src/crewai/agents/agent_builder/base_agent.py b/lib/crewai/src/crewai/agents/agent_builder/base_agent.py index 365e3403b..4e1912752 100644 --- a/lib/crewai/src/crewai/agents/agent_builder/base_agent.py +++ b/lib/crewai/src/crewai/agents/agent_builder/base_agent.py @@ -37,7 +37,7 @@ from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage from crewai.llms.base_llm import BaseLLM from crewai.mcp.config import MCPServerConfig -from crewai.memory.memory_scope import MemoryScope, MemorySlice +from crewai.memory.memory_scope import MemoryScope, MemorySlice, _ensure_memory_kind from crewai.memory.unified_memory import Memory from crewai.rag.embeddings.types import EmbedderConfig from crewai.security.security_config import SecurityConfig @@ -332,13 +332,14 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta): default=None, description="List of MCP server references. Supports 'https://server.com/path' for external servers and bare slugs like 'notion' for connected MCP integrations. Use '#tool_name' suffix for specific tools.", ) - memory: ( + memory: Annotated[ bool | Annotated[ Memory | MemoryScope | MemorySlice, Field(discriminator="memory_kind") ] - | None - ) = Field( + | None, + BeforeValidator(_ensure_memory_kind), + ] = Field( default=None, description=( "Enable agent memory. Pass True for default Memory(), " diff --git a/lib/crewai/src/crewai/crew.py b/lib/crewai/src/crewai/crew.py index 972426459..e2e67368e 100644 --- a/lib/crewai/src/crewai/crew.py +++ b/lib/crewai/src/crewai/crew.py @@ -97,7 +97,7 @@ from crewai.knowledge.knowledge import Knowledge from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource from crewai.llm import LLM from crewai.llms.base_llm import BaseLLM -from crewai.memory.memory_scope import MemoryScope, MemorySlice +from crewai.memory.memory_scope import MemoryScope, MemorySlice, _ensure_memory_kind from crewai.memory.unified_memory import Memory from crewai.process import Process from crewai.rag.embeddings.types import EmbedderConfig @@ -223,13 +223,14 @@ class Crew(FlowTrackable, BaseModel): ] = Field(default_factory=list) process: Process = Field(default=Process.sequential) verbose: bool = Field(default=False) - memory: ( + memory: Annotated[ bool | Annotated[ Memory | MemoryScope | MemorySlice, Field(discriminator="memory_kind") ] - | None - ) = Field( + | None, + BeforeValidator(_ensure_memory_kind), + ] = Field( default=False, description=( "Enable crew memory. Pass True for default Memory(), " diff --git a/lib/crewai/src/crewai/flow/flow.py b/lib/crewai/src/crewai/flow/flow.py index d5895d0bf..5b71a2fbd 100644 --- a/lib/crewai/src/crewai/flow/flow.py +++ b/lib/crewai/src/crewai/flow/flow.py @@ -113,7 +113,7 @@ from crewai.flow.utils import ( is_flow_method_name, is_simple_flow_condition, ) -from crewai.memory.memory_scope import MemoryScope, MemorySlice +from crewai.memory.memory_scope import MemoryScope, MemorySlice, _ensure_memory_kind from crewai.memory.unified_memory import Memory from crewai.state.checkpoint_config import ( CheckpointConfig, @@ -164,7 +164,10 @@ def _serialize_persistence(value: Any) -> dict[str, Any] | None: return None if isinstance(value, FlowPersistence): return value.model_dump(mode="json") - return None + raise TypeError( + f"Cannot serialize Flow.persistence of type {type(value).__name__}: " + "expected FlowPersistence or None." + ) def _validate_input_provider(value: Any) -> Any: @@ -979,12 +982,13 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): name: str | None = Field(default=None) tracing: bool | None = Field(default=None) stream: bool = Field(default=False) - memory: ( + memory: Annotated[ Annotated[ Memory | MemoryScope | MemorySlice, Field(discriminator="memory_kind") ] - | None - ) = Field(default=None) + | None, + BeforeValidator(_ensure_memory_kind), + ] = Field(default=None) input_provider: Annotated[ InputProvider | None, BeforeValidator(_validate_input_provider), diff --git a/lib/crewai/src/crewai/knowledge/knowledge.py b/lib/crewai/src/crewai/knowledge/knowledge.py index a14125c22..b8d57bc05 100644 --- a/lib/crewai/src/crewai/knowledge/knowledge.py +++ b/lib/crewai/src/crewai/knowledge/knowledge.py @@ -41,17 +41,22 @@ def _resolve_knowledge_sources(value: Any) -> Any: for idx, item in enumerate(value): if isinstance(item, dict): tag = item.get("source_type") - cls = _KNOWN_SOURCES.get(tag) if isinstance(tag, str) else None - if cls is None: + if not isinstance(tag, str): resolved.append(item) - else: - try: - resolved.append(cls.model_validate(item)) - except Exception as exc: - raise ValueError( - f"Failed to validate knowledge source at index {idx} " - f"with source_type={tag!r}: {exc}" - ) from exc + continue + cls = _KNOWN_SOURCES.get(tag) + if cls is None: + raise ValueError( + f"Unknown source_type={tag!r} at index {idx}: " + f"expected one of {sorted(_KNOWN_SOURCES)}" + ) + try: + resolved.append(cls.model_validate(item)) + except Exception as exc: + raise ValueError( + f"Failed to validate knowledge source at index {idx} " + f"with source_type={tag!r}: {exc}" + ) from exc else: resolved.append(item) return resolved @@ -78,6 +83,13 @@ def _serialize_embedder_spec(value: Any) -> dict[str, Any] | None: def _validate_embedder_spec(value: Any) -> Any: """Resolve provider_class dotted-path dicts back to a class on restore.""" if isinstance(value, dict) and set(value.keys()) == {"provider_class"}: + if not os.environ.get("CREWAI_DESERIALIZE_CALLBACKS"): + raise ValueError( + f"Refusing to resolve embedder provider_class " + f"{value['provider_class']!r}: set " + "CREWAI_DESERIALIZE_CALLBACKS=1 to allow. Only enable this " + "for trusted checkpoint data." + ) from crewai.types.callback import _resolve_dotted_path cls = _resolve_dotted_path(value["provider_class"]) diff --git a/lib/crewai/src/crewai/memory/memory_scope.py b/lib/crewai/src/crewai/memory/memory_scope.py index 0becc14e3..1cd09d476 100644 --- a/lib/crewai/src/crewai/memory/memory_scope.py +++ b/lib/crewai/src/crewai/memory/memory_scope.py @@ -17,6 +17,24 @@ from crewai.memory.types import ( from crewai.memory.unified_memory import Memory +def _ensure_memory_kind(value: Any) -> Any: + """Backfill ``memory_kind`` on legacy dicts that predate the discriminator. + + Lets pre-1.14.6 configs/checkpoints flow into the discriminated + ``Memory | MemoryScope | MemorySlice`` union without crashing. Inference: + ``scopes`` key → ``slice``; ``root_path`` → ``scope``; else ``memory``. + Pass-through for non-dict values (instances, ``bool``, ``None``). + """ + if isinstance(value, dict) and "memory_kind" not in value: + if "scopes" in value: + value["memory_kind"] = "slice" + elif "root_path" in value: + value["memory_kind"] = "scope" + else: + value["memory_kind"] = "memory" + return value + + class MemoryScope(BaseModel): """View of Memory restricted to a root path. All operations are scoped under that path.""" @@ -26,8 +44,8 @@ class MemoryScope(BaseModel): root_path: str = Field(default="/") - _memory: Memory = PrivateAttr() - _root: str = PrivateAttr() + _memory: Memory | None = PrivateAttr(default=None) + _root: str = PrivateAttr(default="") @model_validator(mode="wrap") @classmethod @@ -56,10 +74,19 @@ class MemoryScope(BaseModel): self._memory = memory return self + def _require_memory(self) -> Memory: + """Return the bound ``Memory`` or raise a clear error if missing.""" + if self._memory is None: + raise RuntimeError( + "MemoryScope is not bound to a Memory; call .bind(memory) " + "after restore." + ) + return self._memory + @property def read_only(self) -> bool: """Whether the underlying memory is read-only.""" - return self._memory.read_only + return self._require_memory().read_only def _scope_path(self, scope: str | None) -> str: if not scope or scope == "/": @@ -84,7 +111,7 @@ class MemoryScope(BaseModel): ) -> MemoryRecord | None: """Remember content; scope is relative to this scope's root.""" path = self._scope_path(scope) - return self._memory.remember( + return self._require_memory().remember( content, scope=path, categories=categories, @@ -107,7 +134,7 @@ class MemoryScope(BaseModel): ) -> list[MemoryRecord]: """Remember multiple items; scope is relative to this scope's root.""" path = self._scope_path(scope) - return self._memory.remember_many( + return self._require_memory().remember_many( contents, scope=path, categories=categories, @@ -130,7 +157,7 @@ class MemoryScope(BaseModel): ) -> list[MemoryMatch]: """Recall within this scope (root path and below).""" search_scope = self._scope_path(scope) if scope else (self._root or "/") - return self._memory.recall( + return self._require_memory().recall( query, scope=search_scope, categories=categories, @@ -142,7 +169,7 @@ class MemoryScope(BaseModel): def extract_memories(self, content: str) -> list[str]: """Extract discrete memories from content; delegates to underlying Memory.""" - return self._memory.extract_memories(content) + return self._require_memory().extract_memories(content) def forget( self, @@ -154,7 +181,7 @@ class MemoryScope(BaseModel): ) -> int: """Forget within this scope.""" prefix = self._scope_path(scope) if scope else (self._root or "/") - return self._memory.forget( + return self._require_memory().forget( scope=prefix, categories=categories, older_than=older_than, @@ -165,27 +192,27 @@ class MemoryScope(BaseModel): def list_scopes(self, path: str = "/") -> list[str]: """List child scopes under path (relative to this scope's root).""" full = self._scope_path(path) - return self._memory.list_scopes(full) + return self._require_memory().list_scopes(full) def info(self, path: str = "/") -> ScopeInfo: """Info for path under this scope.""" full = self._scope_path(path) - return self._memory.info(full) + return self._require_memory().info(full) def tree(self, path: str = "/", max_depth: int = 3) -> str: """Tree under path within this scope.""" full = self._scope_path(path) - return self._memory.tree(full, max_depth=max_depth) + return self._require_memory().tree(full, max_depth=max_depth) def list_categories(self, path: str | None = None) -> dict[str, int]: """Categories in this scope; path None means this scope root.""" full = self._scope_path(path) if path else (self._root or "/") - return self._memory.list_categories(full) + return self._require_memory().list_categories(full) def reset(self, scope: str | None = None) -> None: """Reset within this scope.""" prefix = self._scope_path(scope) if scope else (self._root or "/") - self._memory.reset(scope=prefix) + self._require_memory().reset(scope=prefix) def subscope(self, path: str) -> MemoryScope: """Return a narrower scope under this scope.""" @@ -208,7 +235,7 @@ class MemorySlice(BaseModel): categories: list[str] | None = Field(default=None) read_only: bool = Field(default=True) - _memory: Memory = PrivateAttr() + _memory: Memory | None = PrivateAttr(default=None) @model_validator(mode="wrap") @classmethod @@ -230,6 +257,15 @@ class MemorySlice(BaseModel): self._memory = memory return self + def _require_memory(self) -> Memory: + """Return the bound ``Memory`` or raise a clear error if missing.""" + if self._memory is None: + raise RuntimeError( + "MemorySlice is not bound to a Memory; call .bind(memory) " + "after restore." + ) + return self._memory + def remember( self, content: str, @@ -243,7 +279,7 @@ class MemorySlice(BaseModel): """Remember into an explicit scope. No-op when read_only=True.""" if self.read_only: return None - return self._memory.remember( + return self._require_memory().remember( content, scope=scope, categories=categories, @@ -267,7 +303,7 @@ class MemorySlice(BaseModel): cats = categories or self.categories all_matches: list[MemoryMatch] = [] for sc in self.scopes: - matches = self._memory.recall( + matches = self._require_memory().recall( query, scope=sc, categories=cats, @@ -289,14 +325,14 @@ class MemorySlice(BaseModel): def extract_memories(self, content: str) -> list[str]: """Extract discrete memories from content; delegates to underlying Memory.""" - return self._memory.extract_memories(content) + return self._require_memory().extract_memories(content) def list_scopes(self, path: str = "/") -> list[str]: """List scopes across all slice roots.""" out: list[str] = [] for sc in self.scopes: full = f"{sc.rstrip('/')}{path}" if sc != "/" else path - out.extend(self._memory.list_scopes(full)) + out.extend(self._require_memory().list_scopes(full)) return sorted(set(out)) def info(self, path: str = "/") -> ScopeInfo: @@ -308,7 +344,7 @@ class MemorySlice(BaseModel): children: list[str] = [] for sc in self.scopes: full = f"{sc.rstrip('/')}{path}" if sc != "/" else path - inf = self._memory.info(full) + inf = self._require_memory().info(full) total_records += inf.record_count all_categories.update(inf.categories) if inf.oldest_record: @@ -338,6 +374,6 @@ class MemorySlice(BaseModel): counts: dict[str, int] = {} for sc in self.scopes: full = (f"{sc.rstrip('/')}{path}" if sc != "/" else path) if path else sc - for k, v in self._memory.list_categories(full).items(): + for k, v in self._require_memory().list_categories(full).items(): counts[k] = counts.get(k, 0) + v return counts diff --git a/lib/crewai/src/crewai/state/runtime.py b/lib/crewai/src/crewai/state/runtime.py index 8b08632d9..392897048 100644 --- a/lib/crewai/src/crewai/state/runtime.py +++ b/lib/crewai/src/crewai/state/runtime.py @@ -133,11 +133,22 @@ def _backfill_memory_kind(value: Any) -> None: def _backfill_source_type(source: Any) -> None: - """Infer ``source_type`` for legacy knowledge source dicts when possible.""" + """Infer ``source_type`` for legacy knowledge source dicts when possible. + + Only StringKnowledgeSource is reliably inferrable: it stores ``content`` + as a plain string. File-based sources (CSV/PDF/Excel/JSON/docling) also + have a ``content`` field but populate it with dicts/lists, so we leave + those untagged and let downstream validation surface a clear error. + """ if not isinstance(source, dict) or "source_type" in source: return - if "content" in source: + if isinstance(source.get("content"), str): source["source_type"] = "string" + return + raise ValueError( + "Legacy knowledge source is missing 'source_type' and could not be " + "inferred during migration. Re-checkpoint after upgrading to 1.14.6+." + ) def _backfill_discriminators(entity: Any) -> None: diff --git a/lib/crewai/src/crewai/types/callback.py b/lib/crewai/src/crewai/types/callback.py index f08b03b2e..d4ec14b5a 100644 --- a/lib/crewai/src/crewai/types/callback.py +++ b/lib/crewai/src/crewai/types/callback.py @@ -154,7 +154,16 @@ SerializableCallable = Annotated[ def _instance_to_dotted_path(value: Any) -> str: """Serialize an instance to a dotted path naming its class.""" + if inspect.isclass(value): + raise ValueError( + f"Expected an instance, got class {value.__module__}.{value.__qualname__}." + ) cls = type(value) + if cls.__module__ == "builtins": + raise ValueError( + f"Cannot serialize {value!r}: builtin values are not " + "checkpointable instances." + ) module = getattr(cls, "__module__", None) qualname = getattr(cls, "__qualname__", None) if module is None or qualname is None: @@ -194,11 +203,3 @@ def _dotted_path_to_instance(value: Any) -> Any: f"{type(cls).__name__}" ) return cls() - - -SerializableInstance = Annotated[ - Any, - BeforeValidator(_dotted_path_to_instance), - PlainSerializer(_instance_to_dotted_path, return_type=str, when_used="json"), - WithJsonSchema({"type": "string"}), -] diff --git a/lib/crewai/tests/test_flow_ask.py b/lib/crewai/tests/test_flow_ask.py index d198e261c..ed616d1be 100644 --- a/lib/crewai/tests/test_flow_ask.py +++ b/lib/crewai/tests/test_flow_ask.py @@ -12,15 +12,71 @@ from datetime import datetime from typing import Any from unittest.mock import MagicMock, patch +from pydantic import BaseModel + from crewai.flow import Flow, flow_config, listen, start from crewai.flow.async_feedback.providers import ConsoleProvider from crewai.flow.flow import FlowState from crewai.flow.input_provider import InputProvider, InputResponse +from crewai.flow.persistence.base import FlowPersistence # ── Test helpers ───────────────────────────────────────────────── +class _SaveCall: + """Lightweight stand-in for ``MagicMock.call_args`` entries.""" + + __slots__ = ("args", "kwargs") + + def __init__(self, args: tuple[Any, ...], kwargs: dict[str, Any]) -> None: + self.args = args + self.kwargs = kwargs + + +class _SaveStateRecorder: + """Callable that records each ``save_state`` invocation.""" + + def __init__(self, owner: RecordingPersistence) -> None: + self._owner = owner + self.call_args_list: list[_SaveCall] = [] + + def __call__( + self, + flow_uuid: str, + method_name: str, + state_data: dict[str, Any] | BaseModel, + ) -> None: + self.call_args_list.append( + _SaveCall((flow_uuid, method_name, state_data), {}) + ) + self._owner._states[flow_uuid] = state_data + + +class RecordingPersistence(FlowPersistence): + """In-memory FlowPersistence that records ``save_state`` invocations.""" + + persistence_type: str = "RecordingPersistence" + + def model_post_init(self, _: Any) -> None: + object.__setattr__(self, "_states", {}) + object.__setattr__(self, "save_state", _SaveStateRecorder(self)) + + def init_db(self) -> None: + return None + + def save_state( # type: ignore[no-redef] + self, + flow_uuid: str, + method_name: str, + state_data: dict[str, Any] | BaseModel, + ) -> None: + return None + + def load_state(self, flow_uuid: str) -> dict[str, Any] | None: + return None + + class MockInputProvider: """Mock input provider that returns pre-configured responses.""" @@ -436,8 +492,7 @@ class TestAskCheckpoint: def test_ask_checkpoints_state_before_waiting(self) -> None: """State is saved to persistence before waiting for input.""" - mock_persistence = MagicMock() - mock_persistence.load_state.return_value = None + mock_persistence = RecordingPersistence() class TestFlow(Flow): input_provider = MockInputProvider(["answer"]) @@ -480,8 +535,7 @@ class TestAskCheckpoint: server crashes while waiting for input, previously gathered data is safe. """ - mock_persistence = MagicMock() - mock_persistence.load_state.return_value = None + mock_persistence = RecordingPersistence() class GatherFlow(Flow): input_provider = MockInputProvider(["AI", "detailed"]) @@ -678,8 +732,7 @@ class TestAskIntegration: def test_ask_with_state_persistence_recovery(self) -> None: """Ask checkpoints state so previously gathered values survive.""" - mock_persistence = MagicMock() - mock_persistence.load_state.return_value = None + mock_persistence = RecordingPersistence() class RecoverableFlow(Flow): input_provider = MockInputProvider(["AI", "detailed"])