fix: round-trip safety for input_provider, memory scopes, embedder class

- input_provider: enforce InputProvider protocol via dedicated
  validator/serializer; reject non-class dotted paths in
  _dotted_path_to_instance
- MemoryScope/MemorySlice: allow restore without live Memory; expose
  bind() to reattach the dependency post-restore
- Knowledge.embedder: add BeforeValidator that resolves provider_class
  dotted paths back to a BaseEmbeddingsProvider subclass
This commit is contained in:
Greyson LaLonde
2026-05-21 00:30:14 +08:00
parent b07c1439a3
commit 0f3a57b3b9
4 changed files with 72 additions and 11 deletions

View File

@@ -120,7 +120,6 @@ from crewai.state.checkpoint_config import (
_coerce_checkpoint,
apply_checkpoint,
)
from crewai.types.callback import SerializableInstance
if TYPE_CHECKING:
@@ -168,6 +167,28 @@ def _serialize_persistence(value: Any) -> dict[str, Any] | None:
return None
def _validate_input_provider(value: Any) -> Any:
if value is None or isinstance(value, InputProvider):
return value
from crewai.types.callback import _dotted_path_to_instance
resolved = _dotted_path_to_instance(value)
if resolved is None or isinstance(resolved, InputProvider):
return resolved
raise ValueError(
f"Resolved input_provider {resolved!r} does not implement the "
"InputProvider protocol (missing request_input)."
)
def _serialize_input_provider(value: Any) -> str | None:
if value is None:
return None
from crewai.types.callback import _instance_to_dotted_path
return _instance_to_dotted_path(value)
_INITIAL_STATE_CLASS_MARKER = "__crewai_pydantic_class_schema__"
@@ -964,7 +985,13 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
]
| None
) = Field(default=None)
input_provider: SerializableInstance | None = Field(default=None)
input_provider: Annotated[
InputProvider | None,
BeforeValidator(_validate_input_provider),
PlainSerializer(
_serialize_input_provider, return_type=str | None, when_used="json"
),
] = Field(default=None)
suppress_flow_events: bool = Field(default=False)
human_feedback_history: list[HumanFeedbackResult] = Field(default_factory=list)
last_human_feedback: HumanFeedbackResult | None = Field(default=None)
@@ -3189,7 +3216,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
from crewai.flow.flow_config import flow_config
if self.input_provider is not None:
return cast(InputProvider, self.input_provider)
return self.input_provider
if flow_config.input_provider is not None:
return flow_config.input_provider
return ConsoleProvider()

View File

@@ -75,6 +75,21 @@ def _serialize_embedder_spec(value: Any) -> dict[str, Any] | None:
)
def _validate_embedder_spec(value: Any) -> Any:
"""Resolve provider_class dotted-path dicts back to a class on restore."""
if isinstance(value, dict) and set(value.keys()) == {"provider_class"}:
from crewai.types.callback import _resolve_dotted_path
cls = _resolve_dotted_path(value["provider_class"])
if not isinstance(cls, type) or not issubclass(cls, BaseEmbeddingsProvider):
raise ValueError(
f"provider_class {value['provider_class']!r} did not resolve to a "
"BaseEmbeddingsProvider subclass."
)
return cls
return value
class Knowledge(BaseModel):
"""
Knowledge is a collection of sources and setup for the vector store to save and query relevant context.
@@ -92,6 +107,7 @@ class Knowledge(BaseModel):
storage: KnowledgeStorage | None = Field(default=None)
embedder: Annotated[
EmbedderConfig | None,
BeforeValidator(_validate_embedder_spec),
PlainSerializer(
_serialize_embedder_spec, return_type=dict | None, when_used="json"
),

View File

@@ -6,6 +6,7 @@ from datetime import datetime
from typing import Any, Literal
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
from typing_extensions import Self
from crewai.memory.types import (
_RECALL_OVERSAMPLE_FACTOR,
@@ -36,17 +37,25 @@ class MemoryScope(BaseModel):
return data
if not isinstance(data, dict):
raise ValueError(f"Expected dict or MemoryScope, got {type(data).__name__}")
if "memory" not in data:
raise ValueError("MemoryScope requires a 'memory' key")
memory = data.pop("memory")
memory = data.pop("memory", None)
instance: MemoryScope = handler(data)
instance._memory = memory
if memory is not None:
instance._memory = memory
root = instance.root_path.rstrip("/") or ""
if root and not root.startswith("/"):
root = "/" + root
instance._root = root
return instance
def bind(self, memory: Memory) -> Self:
    """Attach a live ``Memory`` to this scope and return the scope.

    Intended for use after restoring from a checkpoint: the runtime
    ``Memory`` is not part of the serialized data, so it has to be
    supplied again before the scope is used.

    Args:
        memory: The ``Memory`` instance to store on ``_memory``.

    Returns:
        This same instance, so calls can be chained.
    """
    self._memory = memory
    return self
@property
def read_only(self) -> bool:
"""Whether the underlying memory is read-only."""
@@ -209,14 +218,18 @@ class MemorySlice(BaseModel):
return data
if not isinstance(data, dict):
raise ValueError(f"Expected dict or MemorySlice, got {type(data).__name__}")
if "memory" not in data:
raise ValueError("MemorySlice requires a 'memory' key")
memory = data.pop("memory")
memory = data.pop("memory", None)
data["scopes"] = [s.rstrip("/") or "/" for s in data.get("scopes", [])]
instance: MemorySlice = handler(data)
instance._memory = memory
if memory is not None:
instance._memory = memory
return instance
def bind(self, memory: Memory) -> Self:
    """Attach a live ``Memory`` to this slice after restore.

    Args:
        memory: The ``Memory`` instance to store on ``_memory``.

    Returns:
        This same instance, so calls can be chained.
    """
    self._memory = memory
    return self
def remember(
self,
content: str,

View File

@@ -188,6 +188,11 @@ def _dotted_path_to_instance(value: Any) -> Any:
"Only enable this for trusted checkpoint data."
)
cls = _resolve_dotted_path(value)
if not inspect.isclass(cls):
raise ValueError(
f"Invalid provider path {value!r}: expected a class, got "
f"{type(cls).__name__}"
)
return cls()