fix: reject class-typed embedders at serialize time, drop unused validator

This commit is contained in:
Greyson LaLonde
2026-05-21 03:20:32 +08:00
parent 1eb2326e8a
commit 0e9167dec3

View File

@@ -70,40 +70,20 @@ def _serialize_embedder_spec(value: Any) -> dict[str, Any] | None:
return None
if isinstance(value, BaseEmbeddingsProvider):
return value.model_dump(mode="json")
if isinstance(value, type) and issubclass(value, BaseEmbeddingsProvider):
return {"provider_class": f"{value.__module__}.{value.__qualname__}"}
if isinstance(value, dict):
return value
if isinstance(value, type) and issubclass(value, BaseEmbeddingsProvider):
raise TypeError(
f"Cannot checkpoint embedder class {value.__module__}.{value.__qualname__}: "
"build_embedder requires an instance or ProviderSpec dict, not a class. "
"Instantiate the provider before assigning it to Knowledge.embedder."
)
raise TypeError(
f"Cannot serialize embedder of type {type(value).__name__}: "
"expected ProviderSpec dict, BaseEmbeddingsProvider instance, or subclass."
"expected ProviderSpec dict or BaseEmbeddingsProvider instance."
)
def _validate_embedder_spec(value: Any) -> Any:
"""Resolve provider_class dotted-path dicts back to a class on restore."""
if isinstance(value, dict) and set(value.keys()) == {"provider_class"}:
from crewai.types.callback import _trusted_deserialize
if not _trusted_deserialize():
raise ValueError(
f"Refusing to resolve embedder provider_class "
f"{value['provider_class']!r}: set "
"CREWAI_DESERIALIZE_CALLBACKS=1 to allow. Only enable this "
"for trusted checkpoint data."
)
from crewai.types.callback import _resolve_dotted_path
cls = _resolve_dotted_path(value["provider_class"])
if not isinstance(cls, type) or not issubclass(cls, BaseEmbeddingsProvider):
raise ValueError(
f"provider_class {value['provider_class']!r} did not resolve to a "
"BaseEmbeddingsProvider subclass."
)
return cls
return value
class Knowledge(BaseModel):
"""
Knowledge is a collection of sources and setup for the vector store to save and query relevant context.
@@ -121,7 +101,6 @@ class Knowledge(BaseModel):
storage: KnowledgeStorage | None = Field(default=None)
embedder: Annotated[
EmbedderConfig | None,
BeforeValidator(_validate_embedder_spec),
PlainSerializer(
_serialize_embedder_spec, return_type=dict | None, when_used="json"
),