diff --git a/lib/crewai/src/crewai/agents/agent_builder/base_agent.py b/lib/crewai/src/crewai/agents/agent_builder/base_agent.py index 9844bee03..8b5e36ff4 100644 --- a/lib/crewai/src/crewai/agents/agent_builder/base_agent.py +++ b/lib/crewai/src/crewai/agents/agent_builder/base_agent.py @@ -31,13 +31,13 @@ from crewai.agents.tools_handler import ToolsHandler from crewai.events.base_events import set_emission_counter from crewai.events.event_bus import crewai_event_bus from crewai.events.event_context import restore_event_scope, set_last_event_id -from crewai.knowledge.knowledge import Knowledge +from crewai.knowledge.knowledge import Knowledge, _resolve_knowledge_sources from crewai.knowledge.knowledge_config import KnowledgeConfig from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage from crewai.llms.base_llm import BaseLLM from crewai.mcp.config import MCPServerConfig -from crewai.memory.memory_scope import MemoryScope, MemorySlice +from crewai.memory.memory_scope import MemoryScope, MemorySlice, _ensure_memory_kind from crewai.memory.unified_memory import Memory from crewai.rag.embeddings.types import EmbedderConfig from crewai.security.security_config import SecurityConfig @@ -127,6 +127,13 @@ def _validate_executor_ref(value: Any) -> Any: return value +def _serialize_executor_ref(value: Any) -> dict[str, Any] | None: + if value is None: + return None + result: dict[str, Any] = value.model_dump(mode="json") + return result + + def _serialize_llm_ref(value: Any) -> dict[str, Any] | None: if value is None: return None @@ -251,14 +258,13 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta): max_iter: int = Field( default=25, description="Maximum iterations for an agent to execute a task" ) - agent_executor: SerializeAsAny[BaseAgentExecutor] | None = Field( - default=None, description="An instance of the CrewAgentExecutor class." - ) - - @field_validator("agent_executor", mode="before") - @classmethod - def _validate_agent_executor(cls, v: Any) -> Any: - return _validate_executor_ref(v) + agent_executor: Annotated[ + SerializeAsAny[BaseAgentExecutor] | None, + BeforeValidator(_validate_executor_ref), + PlainSerializer( + _serialize_executor_ref, return_type=dict | None, when_used="json" + ), + ] = Field(default=None, description="An instance of the CrewAgentExecutor class.") llm: Annotated[ str | BaseLLM | None, @@ -288,7 +294,10 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta): knowledge: Knowledge | None = Field( default=None, description="Knowledge for the agent." ) - knowledge_sources: list[BaseKnowledgeSource] | None = Field( + knowledge_sources: Annotated[ + list[BaseKnowledgeSource] | None, + BeforeValidator(_resolve_knowledge_sources), + ] = Field( default=None, description="Knowledge sources for the agent.", ) @@ -326,7 +335,14 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta): default=None, description="List of MCP server references. Supports 'https://server.com/path' for external servers and bare slugs like 'notion' for connected MCP integrations. Use '#tool_name' suffix for specific tools.", ) - memory: bool | Memory | MemoryScope | MemorySlice | None = Field( + memory: Annotated[ + bool + | Annotated[ + Memory | MemoryScope | MemorySlice, Field(discriminator="memory_kind") + ] + | None, + BeforeValidator(_ensure_memory_kind), + ] = Field( default=None, description=( "Enable agent memory. Pass True for default Memory(), " @@ -397,8 +413,21 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta): self.agent_executor._resuming = True if self.checkpoint_kickoff_event_id is not None: self._kickoff_event_id = self.checkpoint_kickoff_event_id + self._rebind_memory_view() self._restore_event_scope(state) + def _rebind_memory_view(self) -> None: + """Reattach a fresh ``Memory`` to a restored ``MemoryScope``/``MemorySlice``. + + Checkpoint JSON omits the live ``Memory`` dependency, so scoped + memory views raise ``RuntimeError`` on first use after restore. + """ + if ( + isinstance(self.memory, MemoryScope | MemorySlice) + and self.memory._memory is None + ): + self.memory.bind(Memory()) + def _restore_event_scope(self, state: RuntimeState) -> None: """Rebuild the event scope stack from the checkpoint's event record. diff --git a/lib/crewai/src/crewai/crew.py b/lib/crewai/src/crewai/crew.py index acc90e965..9f69129f1 100644 --- a/lib/crewai/src/crewai/crew.py +++ b/lib/crewai/src/crewai/crew.py @@ -93,11 +93,11 @@ from crewai.events.types.crew_events import ( CrewTrainStartedEvent, ) from crewai.flow.flow_trackable import FlowTrackable -from crewai.knowledge.knowledge import Knowledge +from crewai.knowledge.knowledge import Knowledge, _resolve_knowledge_sources from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource from crewai.llm import LLM from crewai.llms.base_llm import BaseLLM -from crewai.memory.memory_scope import MemoryScope, MemorySlice +from crewai.memory.memory_scope import MemoryScope, MemorySlice, _ensure_memory_kind from crewai.memory.unified_memory import Memory from crewai.process import Process from crewai.rag.embeddings.types import EmbedderConfig @@ -223,7 +223,14 @@ class Crew(FlowTrackable, BaseModel): ] = Field(default_factory=list) process: Process = Field(default=Process.sequential) verbose: bool = Field(default=False) - memory: bool | Memory | MemoryScope | MemorySlice | None = Field( + memory: Annotated[ + bool + | Annotated[ + Memory | MemoryScope | MemorySlice, Field(discriminator="memory_kind") + ] + | None, + BeforeValidator(_ensure_memory_kind), + ] = Field( default=False, description=( "Enable crew memory. Pass True for default Memory(), " @@ -322,7 +329,10 @@ class Crew(FlowTrackable, BaseModel): default_factory=list, description="list of execution logs for tasks", ) - knowledge_sources: list[BaseKnowledgeSource] | None = Field( + knowledge_sources: Annotated[ + list[BaseKnowledgeSource] | None, + BeforeValidator(_resolve_knowledge_sources), + ] = Field( default=None, description=( "Knowledge sources for the crew. Add knowledge sources to the " @@ -477,8 +487,42 @@ class Crew(FlowTrackable, BaseModel): if self.checkpoint_train is not None: self._train = self.checkpoint_train + self._rebind_memory_views() self._restore_event_scope() + def _rebind_memory_views(self) -> None: + """Reattach a live ``Memory`` to restored ``MemoryScope``/``MemorySlice`` views. + + Checkpoint JSON omits the live ``Memory`` dependency on scope/slice + views, so after restore they raise ``RuntimeError`` on first use. + Prefer the crew's restored ``Memory`` (from ``create_crew_memory`` + or a ``Crew.memory=Memory(...)`` instance) so all views share one + backing store; fall back to a fresh ``Memory()`` only if nothing is + available. + """ + from crewai.memory.memory_scope import MemoryScope, MemorySlice + from crewai.memory.unified_memory import Memory + + backing: Memory | None = None + if isinstance(self._memory, Memory): + backing = self._memory + elif isinstance(self.memory, Memory): + backing = self.memory + + def _ensure(view: Any) -> None: + nonlocal backing + if not isinstance(view, MemoryScope | MemorySlice): + return + if view._memory is not None: + return + if backing is None: + backing = Memory() + view.bind(backing) + + _ensure(self.memory) + for agent in self.agents: + _ensure(agent.memory) + def _restore_event_scope(self) -> None: """Rebuild the event scope stack from the checkpoint's event record.""" from crewai.events.base_events import set_emission_counter diff --git a/lib/crewai/src/crewai/flow/flow.py b/lib/crewai/src/crewai/flow/flow.py index d22794873..ef9658128 100644 --- a/lib/crewai/src/crewai/flow/flow.py +++ b/lib/crewai/src/crewai/flow/flow.py @@ -113,7 +113,7 @@ from crewai.flow.utils import ( is_flow_method_name, is_simple_flow_condition, ) -from crewai.memory.memory_scope import MemoryScope, MemorySlice +from crewai.memory.memory_scope import MemoryScope, MemorySlice, _ensure_memory_kind from crewai.memory.unified_memory import Memory from crewai.state.checkpoint_config import ( CheckpointConfig, @@ -159,6 +159,39 @@ def _resolve_persistence(value: Any) -> Any: return value +def _serialize_persistence(value: Any) -> dict[str, Any] | None: + if value is None: + return None + if isinstance(value, FlowPersistence): + return value.model_dump(mode="json") + raise TypeError( + f"Cannot serialize Flow.persistence of type {type(value).__name__}: " + "expected FlowPersistence or None." + ) + + +def _validate_input_provider(value: Any) -> Any: + if value is None or isinstance(value, InputProvider): + return value + from crewai.types.callback import _dotted_path_to_instance + + resolved = _dotted_path_to_instance(value) + if resolved is None or isinstance(resolved, InputProvider): + return resolved + raise ValueError( + f"Resolved input_provider {resolved!r} does not implement the " + "InputProvider protocol (missing request_input)." + ) + + +def _serialize_input_provider(value: Any) -> str | None: + if value is None: + return None + from crewai.types.callback import _instance_to_dotted_path + + return _instance_to_dotted_path(value) + + _INITIAL_STATE_CLASS_MARKER = "__crewai_pydantic_class_schema__" @@ -949,15 +982,30 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): name: str | None = Field(default=None) tracing: bool | None = Field(default=None) stream: bool = Field(default=False) - memory: Memory | MemoryScope | MemorySlice | None = Field(default=None) - input_provider: InputProvider | None = Field(default=None) + memory: Annotated[ + Annotated[ + Memory | MemoryScope | MemorySlice, Field(discriminator="memory_kind") + ] + | None, + BeforeValidator(_ensure_memory_kind), + ] = Field(default=None) + input_provider: Annotated[ + InputProvider | None, + BeforeValidator(_validate_input_provider), + PlainSerializer( + _serialize_input_provider, return_type=str | None, when_used="json" + ), + ] = Field(default=None) suppress_flow_events: bool = Field(default=False) human_feedback_history: list[HumanFeedbackResult] = Field(default_factory=list) last_human_feedback: HumanFeedbackResult | None = Field(default=None) persistence: Annotated[ - SerializeAsAny[FlowPersistence] | Any, + SerializeAsAny[FlowPersistence] | None, BeforeValidator(lambda v, _: _resolve_persistence(v)), + PlainSerializer( + _serialize_persistence, return_type=dict | None, when_used="json" + ), ] = Field(default=None) max_method_calls: int = Field(default=100) @@ -1050,6 +1098,11 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): } if self.checkpoint_state is not None: self._restore_state(self.checkpoint_state) + if ( + isinstance(self.memory, MemoryScope | MemorySlice) + and self.memory._memory is None + ): + self.memory.bind(Memory()) restore_event_scope(()) reset_last_event_id() diff --git a/lib/crewai/src/crewai/knowledge/knowledge.py b/lib/crewai/src/crewai/knowledge/knowledge.py index eceef8b99..8dcf38f4e 100644 --- a/lib/crewai/src/crewai/knowledge/knowledge.py +++ b/lib/crewai/src/crewai/knowledge/knowledge.py @@ -1,16 +1,89 @@ import os +from typing import Annotated, Any -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, BeforeValidator, ConfigDict, Field, PlainSerializer from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource +from crewai.knowledge.source.crew_docling_source import CrewDoclingSource +from crewai.knowledge.source.csv_knowledge_source import CSVKnowledgeSource +from crewai.knowledge.source.excel_knowledge_source import ExcelKnowledgeSource +from crewai.knowledge.source.json_knowledge_source import JSONKnowledgeSource +from crewai.knowledge.source.pdf_knowledge_source import PDFKnowledgeSource +from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource +from crewai.knowledge.source.text_file_knowledge_source import ( + TextFileKnowledgeSource, +) from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage +from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider from crewai.rag.embeddings.types import EmbedderConfig from crewai.rag.types import SearchResult +_KNOWN_SOURCES: dict[str, type[BaseKnowledgeSource]] = { + "string": StringKnowledgeSource, + "docling": CrewDoclingSource, + "csv": CSVKnowledgeSource, + "excel": ExcelKnowledgeSource, + "json": JSONKnowledgeSource, + "pdf": PDFKnowledgeSource, + "text_file": TextFileKnowledgeSource, +} + + +def _resolve_knowledge_sources(value: Any) -> Any: + """Coerce list of dicts into typed BaseKnowledgeSource subclasses via source_type. + + Pass-through for anything else (existing instances, mocks). + """ + if not isinstance(value, list): + return value + resolved: list[Any] = [] + for idx, item in enumerate(value): + if isinstance(item, dict): + tag = item.get("source_type") + if not isinstance(tag, str): + resolved.append(item) + continue + cls = _KNOWN_SOURCES.get(tag) + if cls is None: + raise ValueError( + f"Unknown source_type={tag!r} at index {idx}: " + f"expected one of {sorted(_KNOWN_SOURCES)}" + ) + try: + resolved.append(cls.model_validate(item)) + except Exception as exc: + raise ValueError( + f"Failed to validate knowledge source at index {idx} " + f"with source_type={tag!r}: {exc}" + ) from exc + else: + resolved.append(item) + return resolved + + os.environ["TOKENIZERS_PARALLELISM"] = "false" # removes logging from fastembed +def _serialize_embedder_spec(value: Any) -> dict[str, Any] | None: + if value is None: + return None + if isinstance(value, BaseEmbeddingsProvider): + return value.model_dump(mode="json") + if isinstance(value, dict): + return value + if isinstance(value, type) and issubclass(value, BaseEmbeddingsProvider): + raise TypeError( + f"Cannot checkpoint embedder class {value.__module__}.{value.__qualname__}: " + "build_embedder requires an instance or ProviderSpec dict, not a class. " + "Instantiate the provider before assigning it to Knowledge.embedder." + ) + raise TypeError( + f"Cannot serialize embedder of type {type(value).__name__}: " + "expected ProviderSpec dict or BaseEmbeddingsProvider instance." + ) + + class Knowledge(BaseModel): """ Knowledge is a collection of sources and setup for the vector store to save and query relevant context. @@ -20,10 +93,18 @@ class Knowledge(BaseModel): embedder: EmbedderConfig | None = None """ - sources: list[BaseKnowledgeSource] = Field(default_factory=list) + sources: Annotated[ + list[BaseKnowledgeSource], + BeforeValidator(_resolve_knowledge_sources), + ] = Field(default_factory=list) model_config = ConfigDict(arbitrary_types_allowed=True) storage: KnowledgeStorage | None = Field(default=None) - embedder: EmbedderConfig | None = None + embedder: Annotated[ + EmbedderConfig | None, + PlainSerializer( + _serialize_embedder_spec, return_type=dict | None, when_used="json" + ), + ] = None collection_name: str | None = None def __init__( diff --git a/lib/crewai/src/crewai/knowledge/source/base_knowledge_source.py b/lib/crewai/src/crewai/knowledge/source/base_knowledge_source.py index 4f4a53fb0..8c99b47b0 100644 --- a/lib/crewai/src/crewai/knowledge/source/base_knowledge_source.py +++ b/lib/crewai/src/crewai/knowledge/source/base_knowledge_source.py @@ -13,7 +13,9 @@ class BaseKnowledgeSource(BaseModel, ABC): chunk_size: int = 4000 chunk_overlap: int = 200 chunks: list[str] = Field(default_factory=list) - chunk_embeddings: list[np.ndarray[Any, np.dtype[Any]]] = Field(default_factory=list) + chunk_embeddings: list[np.ndarray[Any, np.dtype[Any]]] = Field( + default_factory=list, exclude=True + ) model_config = ConfigDict(arbitrary_types_allowed=True) storage: KnowledgeStorage | None = Field(default=None) diff --git a/lib/crewai/src/crewai/knowledge/source/crew_docling_source.py b/lib/crewai/src/crewai/knowledge/source/crew_docling_source.py index 3dddacfac..42d69049b 100644 --- a/lib/crewai/src/crewai/knowledge/source/crew_docling_source.py +++ b/lib/crewai/src/crewai/knowledge/source/crew_docling_source.py @@ -2,7 +2,7 @@ from __future__ import annotations from collections.abc import Iterator from pathlib import Path -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Literal from urllib.parse import urlparse @@ -45,6 +45,7 @@ class CrewDoclingSource(BaseKnowledgeSource): _logger: Logger = Logger(verbose=True) + source_type: Literal["docling"] = "docling" file_path: list[Path | str] | None = Field(default=None) file_paths: list[Path | str] = Field(default_factory=list) chunks: list[str] = Field(default_factory=list) diff --git a/lib/crewai/src/crewai/knowledge/source/csv_knowledge_source.py b/lib/crewai/src/crewai/knowledge/source/csv_knowledge_source.py index 7da82c3e3..8a87c6fb3 100644 --- a/lib/crewai/src/crewai/knowledge/source/csv_knowledge_source.py +++ b/lib/crewai/src/crewai/knowledge/source/csv_knowledge_source.py @@ -1,5 +1,6 @@ import csv from pathlib import Path +from typing import Literal from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource @@ -7,6 +8,8 @@ from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledge class CSVKnowledgeSource(BaseFileKnowledgeSource): """A knowledge source that stores and queries CSV file content using embeddings.""" + source_type: Literal["csv"] = "csv" + def load_content(self) -> dict[Path, str]: """Load and preprocess CSV file content.""" content_dict = {} diff --git a/lib/crewai/src/crewai/knowledge/source/excel_knowledge_source.py b/lib/crewai/src/crewai/knowledge/source/excel_knowledge_source.py index ece582053..2e492019f 100644 --- a/lib/crewai/src/crewai/knowledge/source/excel_knowledge_source.py +++ b/lib/crewai/src/crewai/knowledge/source/excel_knowledge_source.py @@ -1,6 +1,6 @@ from pathlib import Path from types import ModuleType -from typing import Any +from typing import Any, Literal from pydantic import Field, field_validator @@ -16,6 +16,7 @@ class ExcelKnowledgeSource(BaseKnowledgeSource): _logger: Logger = Logger(verbose=True) + source_type: Literal["excel"] = "excel" file_path: Path | list[Path] | str | list[str] | None = Field( default=None, description="[Deprecated] The path to the file. Use file_paths instead.", diff --git a/lib/crewai/src/crewai/knowledge/source/json_knowledge_source.py b/lib/crewai/src/crewai/knowledge/source/json_knowledge_source.py index ac527af2d..e547f318b 100644 --- a/lib/crewai/src/crewai/knowledge/source/json_knowledge_source.py +++ b/lib/crewai/src/crewai/knowledge/source/json_knowledge_source.py @@ -1,6 +1,6 @@ import json from pathlib import Path -from typing import Any +from typing import Any, Literal from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource @@ -8,6 +8,8 @@ from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledge class JSONKnowledgeSource(BaseFileKnowledgeSource): """A knowledge source that stores and queries JSON file content using embeddings.""" + source_type: Literal["json"] = "json" + def load_content(self) -> dict[Path, str]: """Load and preprocess JSON file content.""" content: dict[Path, str] = {} diff --git a/lib/crewai/src/crewai/knowledge/source/pdf_knowledge_source.py b/lib/crewai/src/crewai/knowledge/source/pdf_knowledge_source.py index 8af860875..733513aea 100644 --- a/lib/crewai/src/crewai/knowledge/source/pdf_knowledge_source.py +++ b/lib/crewai/src/crewai/knowledge/source/pdf_knowledge_source.py @@ -1,5 +1,6 @@ from pathlib import Path from types import ModuleType +from typing import Literal from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource @@ -7,6 +8,8 @@ from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledge class PDFKnowledgeSource(BaseFileKnowledgeSource): """A knowledge source that stores and queries PDF file content using embeddings.""" + source_type: Literal["pdf"] = "pdf" + def load_content(self) -> dict[Path, str]: """Load and preprocess PDF file content.""" pdfplumber = self._import_pdfplumber() diff --git a/lib/crewai/src/crewai/knowledge/source/string_knowledge_source.py b/lib/crewai/src/crewai/knowledge/source/string_knowledge_source.py index b1165c2d1..639ae98bd 100644 --- a/lib/crewai/src/crewai/knowledge/source/string_knowledge_source.py +++ b/lib/crewai/src/crewai/knowledge/source/string_knowledge_source.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, Literal from pydantic import Field @@ -8,6 +8,7 @@ from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource class StringKnowledgeSource(BaseKnowledgeSource): """A knowledge source that stores and queries plain text content using embeddings.""" + source_type: Literal["string"] = "string" content: str = Field(...) collection_name: str | None = Field(default=None) diff --git a/lib/crewai/src/crewai/knowledge/source/text_file_knowledge_source.py b/lib/crewai/src/crewai/knowledge/source/text_file_knowledge_source.py index 00265743d..5e88da46f 100644 --- a/lib/crewai/src/crewai/knowledge/source/text_file_knowledge_source.py +++ b/lib/crewai/src/crewai/knowledge/source/text_file_knowledge_source.py @@ -1,4 +1,5 @@ from pathlib import Path +from typing import Literal from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource @@ -6,6 +7,8 @@ from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledge class TextFileKnowledgeSource(BaseFileKnowledgeSource): """A knowledge source that stores and queries text file content using embeddings.""" + source_type: Literal["text_file"] = "text_file" + def load_content(self) -> dict[Path, str]: """Load and preprocess text file content.""" content = {} diff --git a/lib/crewai/src/crewai/memory/memory_scope.py b/lib/crewai/src/crewai/memory/memory_scope.py index b5418e03f..1cd09d476 100644 --- a/lib/crewai/src/crewai/memory/memory_scope.py +++ b/lib/crewai/src/crewai/memory/memory_scope.py @@ -6,6 +6,7 @@ from datetime import datetime from typing import Any, Literal from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator +from typing_extensions import Self from crewai.memory.types import ( _RECALL_OVERSAMPLE_FACTOR, @@ -16,15 +17,35 @@ from crewai.memory.types import ( from crewai.memory.unified_memory import Memory +def _ensure_memory_kind(value: Any) -> Any: + """Backfill ``memory_kind`` on legacy dicts that predate the discriminator. + + Lets pre-1.14.6 configs/checkpoints flow into the discriminated + ``Memory | MemoryScope | MemorySlice`` union without crashing. Inference: + ``scopes`` key → ``slice``; ``root_path`` → ``scope``; else ``memory``. + Pass-through for non-dict values (instances, ``bool``, ``None``). + """ + if isinstance(value, dict) and "memory_kind" not in value: + if "scopes" in value: + value["memory_kind"] = "slice" + elif "root_path" in value: + value["memory_kind"] = "scope" + else: + value["memory_kind"] = "memory" + return value + + class MemoryScope(BaseModel): """View of Memory restricted to a root path. All operations are scoped under that path.""" model_config = ConfigDict(arbitrary_types_allowed=True) + memory_kind: Literal["scope"] = "scope" + root_path: str = Field(default="/") - _memory: Memory = PrivateAttr() - _root: str = PrivateAttr() + _memory: Memory | None = PrivateAttr(default=None) + _root: str = PrivateAttr(default="") @model_validator(mode="wrap") @classmethod @@ -34,21 +55,38 @@ class MemoryScope(BaseModel): return data if not isinstance(data, dict): raise ValueError(f"Expected dict or MemoryScope, got {type(data).__name__}") - if "memory" not in data: - raise ValueError("MemoryScope requires a 'memory' key") - memory = data.pop("memory") + memory = data.pop("memory", None) instance: MemoryScope = handler(data) - instance._memory = memory + if memory is not None: + instance._memory = memory root = instance.root_path.rstrip("/") or "" if root and not root.startswith("/"): root = "/" + root instance._root = root return instance + def bind(self, memory: Memory) -> Self: + """Rebind the runtime ``Memory`` dependency after restore. + + Required after deserializing from a checkpoint, since the live + ``Memory`` cannot be serialized. + """ + self._memory = memory + return self + + def _require_memory(self) -> Memory: + """Return the bound ``Memory`` or raise a clear error if missing.""" + if self._memory is None: + raise RuntimeError( + "MemoryScope is not bound to a Memory; call .bind(memory) " + "after restore." + ) + return self._memory + @property def read_only(self) -> bool: """Whether the underlying memory is read-only.""" - return self._memory.read_only + return self._require_memory().read_only def _scope_path(self, scope: str | None) -> str: if not scope or scope == "/": @@ -73,7 +111,7 @@ class MemoryScope(BaseModel): ) -> MemoryRecord | None: """Remember content; scope is relative to this scope's root.""" path = self._scope_path(scope) - return self._memory.remember( + return self._require_memory().remember( content, scope=path, categories=categories, @@ -96,7 +134,7 @@ class MemoryScope(BaseModel): ) -> list[MemoryRecord]: """Remember multiple items; scope is relative to this scope's root.""" path = self._scope_path(scope) - return self._memory.remember_many( + return self._require_memory().remember_many( contents, scope=path, categories=categories, @@ -119,7 +157,7 @@ class MemoryScope(BaseModel): ) -> list[MemoryMatch]: """Recall within this scope (root path and below).""" search_scope = self._scope_path(scope) if scope else (self._root or "/") - return self._memory.recall( + return self._require_memory().recall( query, scope=search_scope, categories=categories, @@ -131,7 +169,7 @@ class MemoryScope(BaseModel): def extract_memories(self, content: str) -> list[str]: """Extract discrete memories from content; delegates to underlying Memory.""" - return self._memory.extract_memories(content) + return self._require_memory().extract_memories(content) def forget( self, @@ -143,7 +181,7 @@ class MemoryScope(BaseModel): ) -> int: """Forget within this scope.""" prefix = self._scope_path(scope) if scope else (self._root or "/") - return self._memory.forget( + return self._require_memory().forget( scope=prefix, categories=categories, older_than=older_than, @@ -154,27 +192,27 @@ class MemoryScope(BaseModel): def list_scopes(self, path: str = "/") -> list[str]: """List child scopes under path (relative to this scope's root).""" full = self._scope_path(path) - return self._memory.list_scopes(full) + return self._require_memory().list_scopes(full) def info(self, path: str = "/") -> ScopeInfo: """Info for path under this scope.""" full = self._scope_path(path) - return self._memory.info(full) + return self._require_memory().info(full) def tree(self, path: str = "/", max_depth: int = 3) -> str: """Tree under path within this scope.""" full = self._scope_path(path) - return self._memory.tree(full, max_depth=max_depth) + return self._require_memory().tree(full, max_depth=max_depth) def list_categories(self, path: str | None = None) -> dict[str, int]: """Categories in this scope; path None means this scope root.""" full = self._scope_path(path) if path else (self._root or "/") - return self._memory.list_categories(full) + return self._require_memory().list_categories(full) def reset(self, scope: str | None = None) -> None: """Reset within this scope.""" prefix = self._scope_path(scope) if scope else (self._root or "/") - self._memory.reset(scope=prefix) + self._require_memory().reset(scope=prefix) def subscope(self, path: str) -> MemoryScope: """Return a narrower scope under this scope.""" @@ -191,11 +229,13 @@ class MemorySlice(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) + memory_kind: Literal["slice"] = "slice" + scopes: list[str] = Field(default_factory=list) categories: list[str] | None = Field(default=None) read_only: bool = Field(default=True) - _memory: Memory = PrivateAttr() + _memory: Memory | None = PrivateAttr(default=None) @model_validator(mode="wrap") @classmethod @@ -205,14 +245,27 @@ class MemorySlice(BaseModel): return data if not isinstance(data, dict): raise ValueError(f"Expected dict or MemorySlice, got {type(data).__name__}") - if "memory" not in data: - raise ValueError("MemorySlice requires a 'memory' key") - memory = data.pop("memory") + memory = data.pop("memory", None) data["scopes"] = [s.rstrip("/") or "/" for s in data.get("scopes", [])] instance: MemorySlice = handler(data) - instance._memory = memory + if memory is not None: + instance._memory = memory return instance + def bind(self, memory: Memory) -> Self: + """Rebind the runtime ``Memory`` dependency after restore.""" + self._memory = memory + return self + + def _require_memory(self) -> Memory: + """Return the bound ``Memory`` or raise a clear error if missing.""" + if self._memory is None: + raise RuntimeError( + "MemorySlice is not bound to a Memory; call .bind(memory) " + "after restore." + ) + return self._memory + def remember( self, content: str, @@ -226,7 +279,7 @@ class MemorySlice(BaseModel): """Remember into an explicit scope. No-op when read_only=True.""" if self.read_only: return None - return self._memory.remember( + return self._require_memory().remember( content, scope=scope, categories=categories, @@ -250,7 +303,7 @@ class MemorySlice(BaseModel): cats = categories or self.categories all_matches: list[MemoryMatch] = [] for sc in self.scopes: - matches = self._memory.recall( + matches = self._require_memory().recall( query, scope=sc, categories=cats, @@ -272,14 +325,14 @@ class MemorySlice(BaseModel): def extract_memories(self, content: str) -> list[str]: """Extract discrete memories from content; delegates to underlying Memory.""" - return self._memory.extract_memories(content) + return self._require_memory().extract_memories(content) def list_scopes(self, path: str = "/") -> list[str]: """List scopes across all slice roots.""" out: list[str] = [] for sc in self.scopes: full = f"{sc.rstrip('/')}{path}" if sc != "/" else path - out.extend(self._memory.list_scopes(full)) + out.extend(self._require_memory().list_scopes(full)) return sorted(set(out)) def info(self, path: str = "/") -> ScopeInfo: @@ -291,7 +344,7 @@ class MemorySlice(BaseModel): children: list[str] = [] for sc in self.scopes: full = f"{sc.rstrip('/')}{path}" if sc != "/" else path - inf = self._memory.info(full) + inf = self._require_memory().info(full) total_records += inf.record_count all_categories.update(inf.categories) if inf.oldest_record: @@ -321,6 +374,6 @@ class MemorySlice(BaseModel): counts: dict[str, int] = {} for sc in self.scopes: full = (f"{sc.rstrip('/')}{path}" if sc != "/" else path) if path else sc - for k, v in self._memory.list_categories(full).items(): + for k, v in self._require_memory().list_categories(full).items(): counts[k] = counts.get(k, 0) + v return counts diff --git a/lib/crewai/src/crewai/memory/unified_memory.py b/lib/crewai/src/crewai/memory/unified_memory.py index d879bace0..27a2c109d 100644 --- a/lib/crewai/src/crewai/memory/unified_memory.py +++ b/lib/crewai/src/crewai/memory/unified_memory.py @@ -63,6 +63,8 @@ class Memory(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) + memory_kind: Literal["memory"] = "memory" + llm: Annotated[BaseLLM | str, PlainValidator(_passthrough)] = Field( default="gpt-4o-mini", description="LLM for analysis (model name or BaseLLM instance).", diff --git a/lib/crewai/src/crewai/state/runtime.py b/lib/crewai/src/crewai/state/runtime.py index 2662266d2..59c3171d9 100644 --- a/lib/crewai/src/crewai/state/runtime.py +++ b/lib/crewai/src/crewai/state/runtime.py @@ -113,12 +113,68 @@ def _migrate(data: dict[str, Any]) -> dict[str, Any]: ) # --- migrations in version order --- - # if stored < Version("X.Y.Z"): - # data.setdefault("some_field", "default") + if stored < Version("1.14.6"): + for entity in data.get("entities") or []: + _backfill_discriminators(entity) return data +def _backfill_memory_kind(value: Any) -> None: + """Infer ``memory_kind`` from structural fields on legacy memory dicts.""" + if not isinstance(value, dict) or "memory_kind" in value: + return + if "scopes" in value: + value["memory_kind"] = "slice" + elif "root_path" in value: + value["memory_kind"] = "scope" + else: + value["memory_kind"] = "memory" + + +def _backfill_source_type(source: Any) -> None: + """Infer ``source_type`` for legacy knowledge source dicts when possible. + + Only StringKnowledgeSource is reliably inferrable: it stores ``content`` + as a plain string. File-based sources (CSV/PDF/Excel/JSON/docling) also + have a ``content`` field but populate it with dicts/lists, so we leave + those untagged and let downstream validation surface a clear error. + """ + if not isinstance(source, dict) or "source_type" in source: + return + if isinstance(source.get("content"), str): + source["source_type"] = "string" + return + raise ValueError( + "Legacy knowledge source is missing 'source_type' and could not be " + "inferred during migration. Re-checkpoint after upgrading to 1.14.6+." + ) + + +def _backfill_sources_on(container: Any) -> None: + """Apply source_type backfill to ``sources`` and ``knowledge_sources`` lists.""" + if not isinstance(container, dict): + return + for key in ("sources", "knowledge_sources"): + for src in container.get(key) or []: + _backfill_source_type(src) + + +def _backfill_discriminators(entity: Any) -> None: + """Walk an entity dict and backfill discriminator fields added in 1.14.6.""" + if not isinstance(entity, dict): + return + _backfill_memory_kind(entity.get("memory")) + _backfill_sources_on(entity) + _backfill_sources_on(entity.get("knowledge")) + for agent in entity.get("agents") or []: + if not isinstance(agent, dict): + continue + _backfill_memory_kind(agent.get("memory")) + _backfill_sources_on(agent) + _backfill_sources_on(agent.get("knowledge")) + + class RuntimeState(RootModel): # type: ignore[type-arg] root: list[Entity] _provider: BaseProvider = PrivateAttr(default_factory=JsonProvider) diff --git a/lib/crewai/src/crewai/types/callback.py b/lib/crewai/src/crewai/types/callback.py index 2a8be235e..ea89effdb 100644 --- a/lib/crewai/src/crewai/types/callback.py +++ b/lib/crewai/src/crewai/types/callback.py @@ -19,6 +19,15 @@ from pydantic import BeforeValidator, WithJsonSchema from pydantic.functional_serializers import PlainSerializer +_TRUSTED_DESERIALIZE_VALUES = frozenset({"1", "true", "yes"}) + + +def _trusted_deserialize() -> bool: + """Return True only if ``CREWAI_DESERIALIZE_CALLBACKS`` is an explicit yes.""" + raw = os.environ.get("CREWAI_DESERIALIZE_CALLBACKS", "") + return raw.strip().lower() in _TRUSTED_DESERIALIZE_VALUES + + def _is_non_roundtrippable(fn: object) -> bool: """Return ``True`` if *fn* cannot survive a serialize/deserialize round-trip. @@ -76,7 +85,7 @@ def string_to_callable(value: Any) -> Callable[..., Any]: raise ValueError( f"Invalid callback path {value!r}: expected 'module.name' format" ) - if not os.environ.get("CREWAI_DESERIALIZE_CALLBACKS"): + if not _trusted_deserialize(): raise ValueError( f"Refusing to resolve callback path {value!r}: " "set CREWAI_DESERIALIZE_CALLBACKS=1 to allow. " @@ -150,3 +159,78 @@ SerializableCallable = Annotated[ PlainSerializer(callable_to_string, return_type=str, when_used="json"), WithJsonSchema({"type": "string"}), ] + + +def _instance_to_dotted_path(value: Any) -> str: + """Serialize an instance to a dotted path naming its class.""" + if inspect.isclass(value): + module = getattr(value, "__module__", "") + qualname = getattr( + value, "__qualname__", getattr(value, "__name__", str(type(value))) + ) + raise ValueError(f"Expected an instance, got class {module}.{qualname}.") + cls = type(value) + if cls.__module__ == "builtins": + raise ValueError( + f"Cannot serialize {value!r}: builtin values are not " + "checkpointable instances." + ) + module = getattr(cls, "__module__", None) + qualname = getattr(cls, "__qualname__", None) + if module is None or qualname is None: + raise ValueError( + f"Cannot serialize {value!r}: class missing __module__ or __qualname__. " + "Use a module-level class for checkpointable instances." + ) + if qualname.endswith("") or "" in qualname: + raise ValueError( + f"Cannot serialize {value!r}: class defined in . " + "Use a module-level class for checkpointable instances." + ) + return f"{module}.{qualname}" + + +def _dotted_path_to_instance(value: Any) -> Any: + """Resolve a dotted path to a class and instantiate it with no args. + + If *value* is already a non-string object it is returned as-is. + """ + if value is None: + return value + if not isinstance(value, str): + if inspect.isclass(value): + raise ValueError( + f"Expected an instance or dotted path string, got class " + f"{getattr(value, '__module__', '')}." + f"{getattr(value, '__qualname__', getattr(value, '__name__', ''))}." + ) + if type(value).__module__ == "builtins": + raise ValueError( + f"Expected an instance of a user-defined class or dotted " + f"path string, got builtin value {value!r}." + ) + return value + if "." not in value: + raise ValueError( + f"Invalid provider path {value!r}: expected 'module.name' format" + ) + if not _trusted_deserialize(): + raise ValueError( + f"Refusing to resolve provider path {value!r}: " + "set CREWAI_DESERIALIZE_CALLBACKS=1 to allow. " + "Only enable this for trusted checkpoint data." + ) + cls = _resolve_dotted_path(value) + if not inspect.isclass(cls): + raise ValueError( + f"Invalid provider path {value!r}: expected a class, got " + f"{type(cls).__name__}" + ) + try: + return cls() + except TypeError as exc: + raise ValueError( + f"Cannot reinstantiate {value!r} with no arguments: {exc}. " + "Only no-arg constructors are checkpointable; rebuild the " + "instance manually and assign it after restore." + ) from exc diff --git a/lib/crewai/src/crewai/utilities/reset_memories.py b/lib/crewai/src/crewai/utilities/reset_memories.py index 50d4a633e..e8239b83d 100644 --- a/lib/crewai/src/crewai/utilities/reset_memories.py +++ b/lib/crewai/src/crewai/utilities/reset_memories.py @@ -25,10 +25,16 @@ def _reset_flow_memory(flow: Flow[Any]) -> None: try: if hasattr(mem, "reset"): mem.reset() - elif hasattr(mem, "_memory") and hasattr(mem._memory, "reset"): + elif hasattr(mem, "_memory") and mem._memory is not None: mem._memory.reset() - except (FileNotFoundError, OSError): + except FileNotFoundError: + # Storage directory was never created — nothing to reset. pass + except OSError as exc: + click.echo(f"Memory reset skipped: storage I/O error ({exc}).", err=True) + except RuntimeError as exc: + # Restored MemoryScope/MemorySlice without a rebound Memory. + click.echo(f"Memory reset skipped: {exc}", err=True) def reset_memories_command( diff --git a/lib/crewai/tests/test_flow_ask.py b/lib/crewai/tests/test_flow_ask.py index d198e261c..5ba3729df 100644 --- a/lib/crewai/tests/test_flow_ask.py +++ b/lib/crewai/tests/test_flow_ask.py @@ -7,20 +7,87 @@ durability, input history tracking, and integration with flow machinery. from __future__ import annotations +import copy import time from datetime import datetime from typing import Any from unittest.mock import MagicMock, patch +from pydantic import BaseModel + from crewai.flow import Flow, flow_config, listen, start from crewai.flow.async_feedback.providers import ConsoleProvider from crewai.flow.flow import FlowState from crewai.flow.input_provider import InputProvider, InputResponse +from crewai.flow.persistence.base import FlowPersistence # ── Test helpers ───────────────────────────────────────────────── +class _SaveCall: + """Lightweight stand-in for ``MagicMock.call_args`` entries.""" + + __slots__ = ("args", "kwargs") + + def __init__(self, args: tuple[Any, ...], kwargs: dict[str, Any]) -> None: + self.args = args + self.kwargs = kwargs + + +class _SaveStateRecorder: + """Callable that records each ``save_state`` invocation.""" + + def __init__(self, owner: RecordingPersistence) -> None: + self._owner = owner + self.call_args_list: list[_SaveCall] = [] + + def __call__( + self, + flow_uuid: str, + method_name: str, + state_data: dict[str, Any] | BaseModel, + ) -> None: + snapshot: dict[str, Any] | BaseModel + if isinstance(state_data, BaseModel): + snapshot = state_data.model_copy(deep=True) + else: + snapshot = copy.deepcopy(state_data) + self.call_args_list.append( + _SaveCall((flow_uuid, method_name, snapshot), {}) + ) + self._owner._states[flow_uuid] = snapshot + + +class RecordingPersistence(FlowPersistence): + """In-memory FlowPersistence that records ``save_state`` invocations.""" + + persistence_type: str = "RecordingPersistence" + + def model_post_init(self, _: Any) -> None: + object.__setattr__(self, "_states", {}) + object.__setattr__(self, "save_state", _SaveStateRecorder(self)) + + def init_db(self) -> None: + return None + + def save_state( # type: ignore[no-redef] + self, + flow_uuid: str, + method_name: str, + state_data: dict[str, Any] | BaseModel, + ) -> None: + return None + + def load_state(self, flow_uuid: str) -> dict[str, Any] | None: + snapshot = self._states.get(flow_uuid) + if snapshot is None: + return None + if isinstance(snapshot, BaseModel): + return snapshot.model_copy(deep=True).model_dump() + return copy.deepcopy(snapshot) + + class MockInputProvider: """Mock input provider that returns pre-configured responses.""" @@ -436,8 +503,7 @@ class TestAskCheckpoint: def test_ask_checkpoints_state_before_waiting(self) -> None: """State is saved to persistence before waiting for input.""" - mock_persistence = MagicMock() - mock_persistence.load_state.return_value = None + mock_persistence = RecordingPersistence() class TestFlow(Flow): input_provider = MockInputProvider(["answer"]) @@ -480,8 +546,7 @@ class TestAskCheckpoint: server crashes while waiting for input, previously gathered data is safe. """ - mock_persistence = MagicMock() - mock_persistence.load_state.return_value = None + mock_persistence = RecordingPersistence() class GatherFlow(Flow): input_provider = MockInputProvider(["AI", "detailed"]) @@ -678,8 +743,7 @@ class TestAskIntegration: def test_ask_with_state_persistence_recovery(self) -> None: """Ask checkpoints state so previously gathered values survive.""" - mock_persistence = MagicMock() - mock_persistence.load_state.return_value = None + mock_persistence = RecordingPersistence() class RecoverableFlow(Flow): input_provider = MockInputProvider(["AI", "detailed"])