mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-07-02 13:48:09 +00:00
fix: round-trip safety for input_provider, memory scopes, embedder class
- input_provider: enforce InputProvider protocol via dedicated validator/serializer; reject non-class dotted paths in _dotted_path_to_instance - MemoryScope/MemorySlice: allow restore without live Memory; expose bind() to reattach the dependency post-restore - Knowledge.embedder: add BeforeValidator that resolves provider_class dotted paths back to a BaseEmbeddingsProvider subclass
This commit is contained in:
@@ -120,7 +120,6 @@ from crewai.state.checkpoint_config import (
|
||||
_coerce_checkpoint,
|
||||
apply_checkpoint,
|
||||
)
|
||||
from crewai.types.callback import SerializableInstance
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -168,6 +167,28 @@ def _serialize_persistence(value: Any) -> dict[str, Any] | None:
|
||||
return None
|
||||
|
||||
|
||||
def _validate_input_provider(value: Any) -> Any:
|
||||
if value is None or isinstance(value, InputProvider):
|
||||
return value
|
||||
from crewai.types.callback import _dotted_path_to_instance
|
||||
|
||||
resolved = _dotted_path_to_instance(value)
|
||||
if resolved is None or isinstance(resolved, InputProvider):
|
||||
return resolved
|
||||
raise ValueError(
|
||||
f"Resolved input_provider {resolved!r} does not implement the "
|
||||
"InputProvider protocol (missing request_input)."
|
||||
)
|
||||
|
||||
|
||||
def _serialize_input_provider(value: Any) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
from crewai.types.callback import _instance_to_dotted_path
|
||||
|
||||
return _instance_to_dotted_path(value)
|
||||
|
||||
|
||||
_INITIAL_STATE_CLASS_MARKER = "__crewai_pydantic_class_schema__"
|
||||
|
||||
|
||||
@@ -964,7 +985,13 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
]
|
||||
| None
|
||||
) = Field(default=None)
|
||||
input_provider: SerializableInstance | None = Field(default=None)
|
||||
input_provider: Annotated[
|
||||
InputProvider | None,
|
||||
BeforeValidator(_validate_input_provider),
|
||||
PlainSerializer(
|
||||
_serialize_input_provider, return_type=str | None, when_used="json"
|
||||
),
|
||||
] = Field(default=None)
|
||||
suppress_flow_events: bool = Field(default=False)
|
||||
human_feedback_history: list[HumanFeedbackResult] = Field(default_factory=list)
|
||||
last_human_feedback: HumanFeedbackResult | None = Field(default=None)
|
||||
@@ -3189,7 +3216,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
from crewai.flow.flow_config import flow_config
|
||||
|
||||
if self.input_provider is not None:
|
||||
return cast(InputProvider, self.input_provider)
|
||||
return self.input_provider
|
||||
if flow_config.input_provider is not None:
|
||||
return flow_config.input_provider
|
||||
return ConsoleProvider()
|
||||
|
||||
@@ -75,6 +75,21 @@ def _serialize_embedder_spec(value: Any) -> dict[str, Any] | None:
|
||||
)
|
||||
|
||||
|
||||
def _validate_embedder_spec(value: Any) -> Any:
|
||||
"""Resolve provider_class dotted-path dicts back to a class on restore."""
|
||||
if isinstance(value, dict) and set(value.keys()) == {"provider_class"}:
|
||||
from crewai.types.callback import _resolve_dotted_path
|
||||
|
||||
cls = _resolve_dotted_path(value["provider_class"])
|
||||
if not isinstance(cls, type) or not issubclass(cls, BaseEmbeddingsProvider):
|
||||
raise ValueError(
|
||||
f"provider_class {value['provider_class']!r} did not resolve to a "
|
||||
"BaseEmbeddingsProvider subclass."
|
||||
)
|
||||
return cls
|
||||
return value
|
||||
|
||||
|
||||
class Knowledge(BaseModel):
|
||||
"""
|
||||
Knowledge is a collection of sources and setup for the vector store to save and query relevant context.
|
||||
@@ -92,6 +107,7 @@ class Knowledge(BaseModel):
|
||||
storage: KnowledgeStorage | None = Field(default=None)
|
||||
embedder: Annotated[
|
||||
EmbedderConfig | None,
|
||||
BeforeValidator(_validate_embedder_spec),
|
||||
PlainSerializer(
|
||||
_serialize_embedder_spec, return_type=dict | None, when_used="json"
|
||||
),
|
||||
|
||||
@@ -6,6 +6,7 @@ from datetime import datetime
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.memory.types import (
|
||||
_RECALL_OVERSAMPLE_FACTOR,
|
||||
@@ -36,17 +37,25 @@ class MemoryScope(BaseModel):
|
||||
return data
|
||||
if not isinstance(data, dict):
|
||||
raise ValueError(f"Expected dict or MemoryScope, got {type(data).__name__}")
|
||||
if "memory" not in data:
|
||||
raise ValueError("MemoryScope requires a 'memory' key")
|
||||
memory = data.pop("memory")
|
||||
memory = data.pop("memory", None)
|
||||
instance: MemoryScope = handler(data)
|
||||
instance._memory = memory
|
||||
if memory is not None:
|
||||
instance._memory = memory
|
||||
root = instance.root_path.rstrip("/") or ""
|
||||
if root and not root.startswith("/"):
|
||||
root = "/" + root
|
||||
instance._root = root
|
||||
return instance
|
||||
|
||||
def bind(self, memory: Memory) -> Self:
|
||||
"""Rebind the runtime ``Memory`` dependency after restore.
|
||||
|
||||
Required after deserializing from a checkpoint, since the live
|
||||
``Memory`` cannot be serialized.
|
||||
"""
|
||||
self._memory = memory
|
||||
return self
|
||||
|
||||
@property
|
||||
def read_only(self) -> bool:
|
||||
"""Whether the underlying memory is read-only."""
|
||||
@@ -209,14 +218,18 @@ class MemorySlice(BaseModel):
|
||||
return data
|
||||
if not isinstance(data, dict):
|
||||
raise ValueError(f"Expected dict or MemorySlice, got {type(data).__name__}")
|
||||
if "memory" not in data:
|
||||
raise ValueError("MemorySlice requires a 'memory' key")
|
||||
memory = data.pop("memory")
|
||||
memory = data.pop("memory", None)
|
||||
data["scopes"] = [s.rstrip("/") or "/" for s in data.get("scopes", [])]
|
||||
instance: MemorySlice = handler(data)
|
||||
instance._memory = memory
|
||||
if memory is not None:
|
||||
instance._memory = memory
|
||||
return instance
|
||||
|
||||
def bind(self, memory: Memory) -> Self:
|
||||
"""Rebind the runtime ``Memory`` dependency after restore."""
|
||||
self._memory = memory
|
||||
return self
|
||||
|
||||
def remember(
|
||||
self,
|
||||
content: str,
|
||||
|
||||
@@ -188,6 +188,11 @@ def _dotted_path_to_instance(value: Any) -> Any:
|
||||
"Only enable this for trusted checkpoint data."
|
||||
)
|
||||
cls = _resolve_dotted_path(value)
|
||||
if not inspect.isclass(cls):
|
||||
raise ValueError(
|
||||
f"Invalid provider path {value!r}: expected a class, got "
|
||||
f"{type(cls).__name__}"
|
||||
)
|
||||
return cls()
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user