Compare commits

...

4 Commits

Author SHA1 Message Date
Greyson LaLonde
0f3a57b3b9 fix: round-trip safety for input_provider, memory scopes, embedder class
- input_provider: enforce InputProvider protocol via dedicated
  validator/serializer; reject non-class dotted paths in
  _dotted_path_to_instance
- MemoryScope/MemorySlice: allow restore without live Memory; expose
  bind() to reattach the dependency post-restore
- Knowledge.embedder: add BeforeValidator that resolves provider_class
  dotted paths back to a BaseEmbeddingsProvider subclass
2026-05-21 00:30:14 +08:00
Greyson LaLonde
b07c1439a3 fix: backfill legacy discriminators and add source validation context 2026-05-21 00:22:53 +08:00
Greyson LaLonde
97e959cb0c fix: raise on unrecognized embedder shape in serializer 2026-05-21 00:17:52 +08:00
Greyson LaLonde
752d9b45d6 fix: harden RuntimeState serialization across entity fields
Adds missing serializers, discriminators, and exclude markers on entity
fields that previously crashed model_dump_json or restored ambiguously:

- Flow.persistence: add _serialize_persistence; drop | Any escape hatch
- Flow.input_provider: SerializableInstance dotted-path round-trip
- BaseAgent.agent_executor: add _serialize_executor_ref
- BaseAgent.tools_handler / cache_handler: exclude=True
- Memory / MemoryScope / MemorySlice: memory_kind Literal discriminator
- Knowledge.storage / .embedder: exclude live client, serialize spec
- BaseKnowledgeSource subclasses: source_type Literal + dict-resolver
- BaseKnowledgeSource.storage / chunk_embeddings: exclude=True
2026-05-21 00:12:20 +08:00
16 changed files with 304 additions and 31 deletions

View File

@@ -127,6 +127,13 @@ def _validate_executor_ref(value: Any) -> Any:
return value
def _serialize_executor_ref(value: Any) -> dict[str, Any] | None:
if value is None:
return None
result: dict[str, Any] = value.model_dump(mode="json")
return result
def _serialize_llm_ref(value: Any) -> dict[str, Any] | None:
if value is None:
return None
@@ -251,14 +258,13 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
max_iter: int = Field(
default=25, description="Maximum iterations for an agent to execute a task"
)
agent_executor: SerializeAsAny[BaseAgentExecutor] | None = Field(
default=None, description="An instance of the CrewAgentExecutor class."
)
@field_validator("agent_executor", mode="before")
@classmethod
def _validate_agent_executor(cls, v: Any) -> Any:
return _validate_executor_ref(v)
agent_executor: Annotated[
SerializeAsAny[BaseAgentExecutor] | None,
BeforeValidator(_validate_executor_ref),
PlainSerializer(
_serialize_executor_ref, return_type=dict | None, when_used="json"
),
] = Field(default=None, description="An instance of the CrewAgentExecutor class.")
llm: Annotated[
str | BaseLLM | None,
@@ -326,7 +332,13 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
default=None,
description="List of MCP server references. Supports 'https://server.com/path' for external servers and bare slugs like 'notion' for connected MCP integrations. Use '#tool_name' suffix for specific tools.",
)
memory: bool | Memory | MemoryScope | MemorySlice | None = Field(
memory: (
bool
| Annotated[
Memory | MemoryScope | MemorySlice, Field(discriminator="memory_kind")
]
| None
) = Field(
default=None,
description=(
"Enable agent memory. Pass True for default Memory(), "

View File

@@ -223,7 +223,13 @@ class Crew(FlowTrackable, BaseModel):
] = Field(default_factory=list)
process: Process = Field(default=Process.sequential)
verbose: bool = Field(default=False)
memory: bool | Memory | MemoryScope | MemorySlice | None = Field(
memory: (
bool
| Annotated[
Memory | MemoryScope | MemorySlice, Field(discriminator="memory_kind")
]
| None
) = Field(
default=False,
description=(
"Enable crew memory. Pass True for default Memory(), "

View File

@@ -159,6 +159,36 @@ def _resolve_persistence(value: Any) -> Any:
return value
def _serialize_persistence(value: Any) -> dict[str, Any] | None:
if value is None:
return None
if isinstance(value, FlowPersistence):
return value.model_dump(mode="json")
return None
def _validate_input_provider(value: Any) -> Any:
if value is None or isinstance(value, InputProvider):
return value
from crewai.types.callback import _dotted_path_to_instance
resolved = _dotted_path_to_instance(value)
if resolved is None or isinstance(resolved, InputProvider):
return resolved
raise ValueError(
f"Resolved input_provider {resolved!r} does not implement the "
"InputProvider protocol (missing request_input)."
)
def _serialize_input_provider(value: Any) -> str | None:
if value is None:
return None
from crewai.types.callback import _instance_to_dotted_path
return _instance_to_dotted_path(value)
_INITIAL_STATE_CLASS_MARKER = "__crewai_pydantic_class_schema__"
@@ -949,15 +979,29 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
name: str | None = Field(default=None)
tracing: bool | None = Field(default=None)
stream: bool = Field(default=False)
memory: Memory | MemoryScope | MemorySlice | None = Field(default=None)
input_provider: InputProvider | None = Field(default=None)
memory: (
Annotated[
Memory | MemoryScope | MemorySlice, Field(discriminator="memory_kind")
]
| None
) = Field(default=None)
input_provider: Annotated[
InputProvider | None,
BeforeValidator(_validate_input_provider),
PlainSerializer(
_serialize_input_provider, return_type=str | None, when_used="json"
),
] = Field(default=None)
suppress_flow_events: bool = Field(default=False)
human_feedback_history: list[HumanFeedbackResult] = Field(default_factory=list)
last_human_feedback: HumanFeedbackResult | None = Field(default=None)
persistence: Annotated[
SerializeAsAny[FlowPersistence] | Any,
SerializeAsAny[FlowPersistence] | None,
BeforeValidator(lambda v, _: _resolve_persistence(v)),
PlainSerializer(
_serialize_persistence, return_type=dict | None, when_used="json"
),
] = Field(default=None)
max_method_calls: int = Field(default=100)

View File

@@ -1,16 +1,95 @@
import os
from typing import Annotated, Any
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, BeforeValidator, ConfigDict, Field, PlainSerializer
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.knowledge.source.crew_docling_source import CrewDoclingSource
from crewai.knowledge.source.csv_knowledge_source import CSVKnowledgeSource
from crewai.knowledge.source.excel_knowledge_source import ExcelKnowledgeSource
from crewai.knowledge.source.json_knowledge_source import JSONKnowledgeSource
from crewai.knowledge.source.pdf_knowledge_source import PDFKnowledgeSource
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource
from crewai.knowledge.source.text_file_knowledge_source import (
TextFileKnowledgeSource,
)
from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
from crewai.rag.embeddings.types import EmbedderConfig
from crewai.rag.types import SearchResult
_KNOWN_SOURCES: dict[str, type[BaseKnowledgeSource]] = {
"string": StringKnowledgeSource,
"docling": CrewDoclingSource,
"csv": CSVKnowledgeSource,
"excel": ExcelKnowledgeSource,
"json": JSONKnowledgeSource,
"pdf": PDFKnowledgeSource,
"text_file": TextFileKnowledgeSource,
}
def _resolve_knowledge_sources(value: Any) -> Any:
"""Coerce list of dicts into typed BaseKnowledgeSource subclasses via source_type.
Pass-through for anything else (existing instances, mocks).
"""
if not isinstance(value, list):
return value
resolved: list[Any] = []
for idx, item in enumerate(value):
if isinstance(item, dict):
tag = item.get("source_type")
cls = _KNOWN_SOURCES.get(tag) if isinstance(tag, str) else None
if cls is None:
resolved.append(item)
else:
try:
resolved.append(cls.model_validate(item))
except Exception as exc:
raise ValueError(
f"Failed to validate knowledge source at index {idx} "
f"with source_type={tag!r}: {exc}"
) from exc
else:
resolved.append(item)
return resolved
os.environ["TOKENIZERS_PARALLELISM"] = "false" # removes logging from fastembed
def _serialize_embedder_spec(value: Any) -> dict[str, Any] | None:
if value is None:
return None
if isinstance(value, BaseEmbeddingsProvider):
return value.model_dump(mode="json")
if isinstance(value, type) and issubclass(value, BaseEmbeddingsProvider):
return {"provider_class": f"{value.__module__}.{value.__qualname__}"}
if isinstance(value, dict):
return value
raise TypeError(
f"Cannot serialize embedder of type {type(value).__name__}: "
"expected ProviderSpec dict, BaseEmbeddingsProvider instance, or subclass."
)
def _validate_embedder_spec(value: Any) -> Any:
"""Resolve provider_class dotted-path dicts back to a class on restore."""
if isinstance(value, dict) and set(value.keys()) == {"provider_class"}:
from crewai.types.callback import _resolve_dotted_path
cls = _resolve_dotted_path(value["provider_class"])
if not isinstance(cls, type) or not issubclass(cls, BaseEmbeddingsProvider):
raise ValueError(
f"provider_class {value['provider_class']!r} did not resolve to a "
"BaseEmbeddingsProvider subclass."
)
return cls
return value
class Knowledge(BaseModel):
"""
Knowledge is a collection of sources and setup for the vector store to save and query relevant context.
@@ -20,10 +99,19 @@ class Knowledge(BaseModel):
embedder: EmbedderConfig | None = None
"""
sources: list[BaseKnowledgeSource] = Field(default_factory=list)
sources: Annotated[
list[BaseKnowledgeSource],
BeforeValidator(_resolve_knowledge_sources),
] = Field(default_factory=list)
model_config = ConfigDict(arbitrary_types_allowed=True)
storage: KnowledgeStorage | None = Field(default=None)
embedder: EmbedderConfig | None = None
embedder: Annotated[
EmbedderConfig | None,
BeforeValidator(_validate_embedder_spec),
PlainSerializer(
_serialize_embedder_spec, return_type=dict | None, when_used="json"
),
] = None
collection_name: str | None = None
def __init__(

View File

@@ -13,7 +13,9 @@ class BaseKnowledgeSource(BaseModel, ABC):
chunk_size: int = 4000
chunk_overlap: int = 200
chunks: list[str] = Field(default_factory=list)
chunk_embeddings: list[np.ndarray[Any, np.dtype[Any]]] = Field(default_factory=list)
chunk_embeddings: list[np.ndarray[Any, np.dtype[Any]]] = Field(
default_factory=list, exclude=True
)
model_config = ConfigDict(arbitrary_types_allowed=True)
storage: KnowledgeStorage | None = Field(default=None)

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
from collections.abc import Iterator
from pathlib import Path
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Literal
from urllib.parse import urlparse
@@ -45,6 +45,7 @@ class CrewDoclingSource(BaseKnowledgeSource):
_logger: Logger = Logger(verbose=True)
source_type: Literal["docling"] = "docling"
file_path: list[Path | str] | None = Field(default=None)
file_paths: list[Path | str] = Field(default_factory=list)
chunks: list[str] = Field(default_factory=list)

View File

@@ -1,5 +1,6 @@
import csv
from pathlib import Path
from typing import Literal
from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource
@@ -7,6 +8,8 @@ from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledge
class CSVKnowledgeSource(BaseFileKnowledgeSource):
"""A knowledge source that stores and queries CSV file content using embeddings."""
source_type: Literal["csv"] = "csv"
def load_content(self) -> dict[Path, str]:
"""Load and preprocess CSV file content."""
content_dict = {}

View File

@@ -1,6 +1,6 @@
from pathlib import Path
from types import ModuleType
from typing import Any
from typing import Any, Literal
from pydantic import Field, field_validator
@@ -16,6 +16,7 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
_logger: Logger = Logger(verbose=True)
source_type: Literal["excel"] = "excel"
file_path: Path | list[Path] | str | list[str] | None = Field(
default=None,
description="[Deprecated] The path to the file. Use file_paths instead.",

View File

@@ -1,6 +1,6 @@
import json
from pathlib import Path
from typing import Any
from typing import Any, Literal
from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource
@@ -8,6 +8,8 @@ from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledge
class JSONKnowledgeSource(BaseFileKnowledgeSource):
"""A knowledge source that stores and queries JSON file content using embeddings."""
source_type: Literal["json"] = "json"
def load_content(self) -> dict[Path, str]:
"""Load and preprocess JSON file content."""
content: dict[Path, str] = {}

View File

@@ -1,5 +1,6 @@
from pathlib import Path
from types import ModuleType
from typing import Literal
from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource
@@ -7,6 +8,8 @@ from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledge
class PDFKnowledgeSource(BaseFileKnowledgeSource):
"""A knowledge source that stores and queries PDF file content using embeddings."""
source_type: Literal["pdf"] = "pdf"
def load_content(self) -> dict[Path, str]:
"""Load and preprocess PDF file content."""
pdfplumber = self._import_pdfplumber()

View File

@@ -1,4 +1,4 @@
from typing import Any
from typing import Any, Literal
from pydantic import Field
@@ -8,6 +8,7 @@ from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
class StringKnowledgeSource(BaseKnowledgeSource):
"""A knowledge source that stores and queries plain text content using embeddings."""
source_type: Literal["string"] = "string"
content: str = Field(...)
collection_name: str | None = Field(default=None)

View File

@@ -1,4 +1,5 @@
from pathlib import Path
from typing import Literal
from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource
@@ -6,6 +7,8 @@ from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledge
class TextFileKnowledgeSource(BaseFileKnowledgeSource):
"""A knowledge source that stores and queries text file content using embeddings."""
source_type: Literal["text_file"] = "text_file"
def load_content(self) -> dict[Path, str]:
"""Load and preprocess text file content."""
content = {}

View File

@@ -6,6 +6,7 @@ from datetime import datetime
from typing import Any, Literal
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
from typing_extensions import Self
from crewai.memory.types import (
_RECALL_OVERSAMPLE_FACTOR,
@@ -21,6 +22,8 @@ class MemoryScope(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
memory_kind: Literal["scope"] = "scope"
root_path: str = Field(default="/")
_memory: Memory = PrivateAttr()
@@ -34,17 +37,25 @@ class MemoryScope(BaseModel):
return data
if not isinstance(data, dict):
raise ValueError(f"Expected dict or MemoryScope, got {type(data).__name__}")
if "memory" not in data:
raise ValueError("MemoryScope requires a 'memory' key")
memory = data.pop("memory")
memory = data.pop("memory", None)
instance: MemoryScope = handler(data)
instance._memory = memory
if memory is not None:
instance._memory = memory
root = instance.root_path.rstrip("/") or ""
if root and not root.startswith("/"):
root = "/" + root
instance._root = root
return instance
def bind(self, memory: Memory) -> Self:
"""Rebind the runtime ``Memory`` dependency after restore.
Required after deserializing from a checkpoint, since the live
``Memory`` cannot be serialized.
"""
self._memory = memory
return self
@property
def read_only(self) -> bool:
"""Whether the underlying memory is read-only."""
@@ -191,6 +202,8 @@ class MemorySlice(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
memory_kind: Literal["slice"] = "slice"
scopes: list[str] = Field(default_factory=list)
categories: list[str] | None = Field(default=None)
read_only: bool = Field(default=True)
@@ -205,14 +218,18 @@ class MemorySlice(BaseModel):
return data
if not isinstance(data, dict):
raise ValueError(f"Expected dict or MemorySlice, got {type(data).__name__}")
if "memory" not in data:
raise ValueError("MemorySlice requires a 'memory' key")
memory = data.pop("memory")
memory = data.pop("memory", None)
data["scopes"] = [s.rstrip("/") or "/" for s in data.get("scopes", [])]
instance: MemorySlice = handler(data)
instance._memory = memory
if memory is not None:
instance._memory = memory
return instance
def bind(self, memory: Memory) -> Self:
"""Rebind the runtime ``Memory`` dependency after restore."""
self._memory = memory
return self
def remember(
self,
content: str,

View File

@@ -63,6 +63,8 @@ class Memory(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
memory_kind: Literal["memory"] = "memory"
llm: Annotated[BaseLLM | str, PlainValidator(_passthrough)] = Field(
default="gpt-4o-mini",
description="LLM for analysis (model name or BaseLLM instance).",

View File

@@ -113,12 +113,48 @@ def _migrate(data: dict[str, Any]) -> dict[str, Any]:
)
# --- migrations in version order ---
# if stored < Version("X.Y.Z"):
# data.setdefault("some_field", "default")
if stored < Version("1.14.6"):
for entity in data.get("entities") or []:
_backfill_discriminators(entity)
return data
def _backfill_memory_kind(value: Any) -> None:
"""Infer ``memory_kind`` from structural fields on legacy memory dicts."""
if not isinstance(value, dict) or "memory_kind" in value:
return
if "scopes" in value:
value["memory_kind"] = "slice"
elif "root_path" in value:
value["memory_kind"] = "scope"
else:
value["memory_kind"] = "memory"
def _backfill_source_type(source: Any) -> None:
"""Infer ``source_type`` for legacy knowledge source dicts when possible."""
if not isinstance(source, dict) or "source_type" in source:
return
if "content" in source:
source["source_type"] = "string"
def _backfill_discriminators(entity: Any) -> None:
"""Walk an entity dict and backfill discriminator fields added in 1.14.6."""
if not isinstance(entity, dict):
return
_backfill_memory_kind(entity.get("memory"))
for agent in entity.get("agents") or []:
_backfill_memory_kind(agent.get("memory") if isinstance(agent, dict) else None)
for container in (entity.get("knowledge"), entity):
if isinstance(container, dict):
for src in (
container.get("sources") or container.get("knowledge_sources") or []
):
_backfill_source_type(src)
class RuntimeState(RootModel): # type: ignore[type-arg]
root: list[Entity]
_provider: BaseProvider = PrivateAttr(default_factory=JsonProvider)

View File

@@ -150,3 +150,55 @@ SerializableCallable = Annotated[
PlainSerializer(callable_to_string, return_type=str, when_used="json"),
WithJsonSchema({"type": "string"}),
]
def _instance_to_dotted_path(value: Any) -> str:
"""Serialize an instance to a dotted path naming its class."""
cls = type(value)
module = getattr(cls, "__module__", None)
qualname = getattr(cls, "__qualname__", None)
if module is None or qualname is None:
raise ValueError(
f"Cannot serialize {value!r}: class missing __module__ or __qualname__. "
"Use a module-level class for checkpointable instances."
)
if qualname.endswith("<lambda>") or "<locals>" in qualname:
raise ValueError(
f"Cannot serialize {value!r}: class defined in <locals>. "
"Use a module-level class for checkpointable instances."
)
return f"{module}.{qualname}"
def _dotted_path_to_instance(value: Any) -> Any:
"""Resolve a dotted path to a class and instantiate it with no args.
If *value* is already a non-string object it is returned as-is.
"""
if value is None or not isinstance(value, str):
return value
if "." not in value:
raise ValueError(
f"Invalid provider path {value!r}: expected 'module.name' format"
)
if not os.environ.get("CREWAI_DESERIALIZE_CALLBACKS"):
raise ValueError(
f"Refusing to resolve provider path {value!r}: "
"set CREWAI_DESERIALIZE_CALLBACKS=1 to allow. "
"Only enable this for trusted checkpoint data."
)
cls = _resolve_dotted_path(value)
if not inspect.isclass(cls):
raise ValueError(
f"Invalid provider path {value!r}: expected a class, got "
f"{type(cls).__name__}"
)
return cls()
SerializableInstance = Annotated[
Any,
BeforeValidator(_dotted_path_to_instance),
PlainSerializer(_instance_to_dotted_path, return_type=str, when_used="json"),
WithJsonSchema({"type": "string"}),
]