mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-07-04 14:39:23 +00:00
fix: harden RuntimeState serialization across entity fields
Adds missing serializers, discriminators, and exclude markers on entity fields that previously crashed model_dump_json or restored ambiguously: - Flow.persistence: add _serialize_persistence; drop | Any escape hatch - Flow.input_provider: SerializableInstance dotted-path round-trip - BaseAgent.agent_executor: add _serialize_executor_ref - BaseAgent.tools_handler / cache_handler: exclude=True - Memory / MemoryScope / MemorySlice: memory_kind Literal discriminator - Knowledge.storage / .embedder: exclude live client, serialize spec - BaseKnowledgeSource subclasses: source_type Literal + dict-resolver - BaseKnowledgeSource.storage / chunk_embeddings: exclude=True
This commit is contained in:
@@ -127,6 +127,13 @@ def _validate_executor_ref(value: Any) -> Any:
|
||||
return value
|
||||
|
||||
|
||||
def _serialize_executor_ref(value: Any) -> dict[str, Any] | None:
|
||||
if value is None:
|
||||
return None
|
||||
result: dict[str, Any] = value.model_dump(mode="json")
|
||||
return result
|
||||
|
||||
|
||||
def _serialize_llm_ref(value: Any) -> dict[str, Any] | None:
|
||||
if value is None:
|
||||
return None
|
||||
@@ -251,14 +258,13 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
|
||||
max_iter: int = Field(
|
||||
default=25, description="Maximum iterations for an agent to execute a task"
|
||||
)
|
||||
agent_executor: SerializeAsAny[BaseAgentExecutor] | None = Field(
|
||||
default=None, description="An instance of the CrewAgentExecutor class."
|
||||
)
|
||||
|
||||
@field_validator("agent_executor", mode="before")
|
||||
@classmethod
|
||||
def _validate_agent_executor(cls, v: Any) -> Any:
|
||||
return _validate_executor_ref(v)
|
||||
agent_executor: Annotated[
|
||||
SerializeAsAny[BaseAgentExecutor] | None,
|
||||
BeforeValidator(_validate_executor_ref),
|
||||
PlainSerializer(
|
||||
_serialize_executor_ref, return_type=dict | None, when_used="json"
|
||||
),
|
||||
] = Field(default=None, description="An instance of the CrewAgentExecutor class.")
|
||||
|
||||
llm: Annotated[
|
||||
str | BaseLLM | None,
|
||||
@@ -326,7 +332,13 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
|
||||
default=None,
|
||||
description="List of MCP server references. Supports 'https://server.com/path' for external servers and bare slugs like 'notion' for connected MCP integrations. Use '#tool_name' suffix for specific tools.",
|
||||
)
|
||||
memory: bool | Memory | MemoryScope | MemorySlice | None = Field(
|
||||
memory: (
|
||||
bool
|
||||
| Annotated[
|
||||
Memory | MemoryScope | MemorySlice, Field(discriminator="memory_kind")
|
||||
]
|
||||
| None
|
||||
) = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Enable agent memory. Pass True for default Memory(), "
|
||||
|
||||
@@ -223,7 +223,13 @@ class Crew(FlowTrackable, BaseModel):
|
||||
] = Field(default_factory=list)
|
||||
process: Process = Field(default=Process.sequential)
|
||||
verbose: bool = Field(default=False)
|
||||
memory: bool | Memory | MemoryScope | MemorySlice | None = Field(
|
||||
memory: (
|
||||
bool
|
||||
| Annotated[
|
||||
Memory | MemoryScope | MemorySlice, Field(discriminator="memory_kind")
|
||||
]
|
||||
| None
|
||||
) = Field(
|
||||
default=False,
|
||||
description=(
|
||||
"Enable crew memory. Pass True for default Memory(), "
|
||||
|
||||
@@ -120,6 +120,7 @@ from crewai.state.checkpoint_config import (
|
||||
_coerce_checkpoint,
|
||||
apply_checkpoint,
|
||||
)
|
||||
from crewai.types.callback import SerializableInstance
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -159,6 +160,14 @@ def _resolve_persistence(value: Any) -> Any:
|
||||
return value
|
||||
|
||||
|
||||
def _serialize_persistence(value: Any) -> dict[str, Any] | None:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, FlowPersistence):
|
||||
return value.model_dump(mode="json")
|
||||
return None
|
||||
|
||||
|
||||
_INITIAL_STATE_CLASS_MARKER = "__crewai_pydantic_class_schema__"
|
||||
|
||||
|
||||
@@ -949,15 +958,23 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
name: str | None = Field(default=None)
|
||||
tracing: bool | None = Field(default=None)
|
||||
stream: bool = Field(default=False)
|
||||
memory: Memory | MemoryScope | MemorySlice | None = Field(default=None)
|
||||
input_provider: InputProvider | None = Field(default=None)
|
||||
memory: (
|
||||
Annotated[
|
||||
Memory | MemoryScope | MemorySlice, Field(discriminator="memory_kind")
|
||||
]
|
||||
| None
|
||||
) = Field(default=None)
|
||||
input_provider: SerializableInstance | None = Field(default=None)
|
||||
suppress_flow_events: bool = Field(default=False)
|
||||
human_feedback_history: list[HumanFeedbackResult] = Field(default_factory=list)
|
||||
last_human_feedback: HumanFeedbackResult | None = Field(default=None)
|
||||
|
||||
persistence: Annotated[
|
||||
SerializeAsAny[FlowPersistence] | Any,
|
||||
SerializeAsAny[FlowPersistence] | None,
|
||||
BeforeValidator(lambda v, _: _resolve_persistence(v)),
|
||||
PlainSerializer(
|
||||
_serialize_persistence, return_type=dict | None, when_used="json"
|
||||
),
|
||||
] = Field(default=None)
|
||||
max_method_calls: int = Field(default=100)
|
||||
|
||||
@@ -3172,7 +3189,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
from crewai.flow.flow_config import flow_config
|
||||
|
||||
if self.input_provider is not None:
|
||||
return self.input_provider
|
||||
return cast(InputProvider, self.input_provider)
|
||||
if flow_config.input_provider is not None:
|
||||
return flow_config.input_provider
|
||||
return ConsoleProvider()
|
||||
|
||||
@@ -1,16 +1,71 @@
|
||||
import os
|
||||
from typing import Annotated, Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from pydantic import BaseModel, BeforeValidator, ConfigDict, Field, PlainSerializer
|
||||
|
||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||
from crewai.knowledge.source.crew_docling_source import CrewDoclingSource
|
||||
from crewai.knowledge.source.csv_knowledge_source import CSVKnowledgeSource
|
||||
from crewai.knowledge.source.excel_knowledge_source import ExcelKnowledgeSource
|
||||
from crewai.knowledge.source.json_knowledge_source import JSONKnowledgeSource
|
||||
from crewai.knowledge.source.pdf_knowledge_source import PDFKnowledgeSource
|
||||
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource
|
||||
from crewai.knowledge.source.text_file_knowledge_source import (
|
||||
TextFileKnowledgeSource,
|
||||
)
|
||||
from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
from crewai.rag.embeddings.types import EmbedderConfig
|
||||
from crewai.rag.types import SearchResult
|
||||
|
||||
|
||||
_KNOWN_SOURCES: dict[str, type[BaseKnowledgeSource]] = {
|
||||
"string": StringKnowledgeSource,
|
||||
"docling": CrewDoclingSource,
|
||||
"csv": CSVKnowledgeSource,
|
||||
"excel": ExcelKnowledgeSource,
|
||||
"json": JSONKnowledgeSource,
|
||||
"pdf": PDFKnowledgeSource,
|
||||
"text_file": TextFileKnowledgeSource,
|
||||
}
|
||||
|
||||
|
||||
def _resolve_knowledge_sources(value: Any) -> Any:
|
||||
"""Coerce list of dicts into typed BaseKnowledgeSource subclasses via source_type.
|
||||
|
||||
Pass-through for anything else (existing instances, mocks).
|
||||
"""
|
||||
if not isinstance(value, list):
|
||||
return value
|
||||
resolved: list[Any] = []
|
||||
for item in value:
|
||||
if isinstance(item, dict):
|
||||
tag = item.get("source_type")
|
||||
cls = _KNOWN_SOURCES.get(tag) if isinstance(tag, str) else None
|
||||
if cls is None:
|
||||
resolved.append(item)
|
||||
else:
|
||||
resolved.append(cls.model_validate(item))
|
||||
else:
|
||||
resolved.append(item)
|
||||
return resolved
|
||||
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false" # removes logging from fastembed
|
||||
|
||||
|
||||
def _serialize_embedder_spec(value: Any) -> dict[str, Any] | None:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, BaseEmbeddingsProvider):
|
||||
return value.model_dump(mode="json")
|
||||
if isinstance(value, type) and issubclass(value, BaseEmbeddingsProvider):
|
||||
return {"provider_class": f"{value.__module__}.{value.__qualname__}"}
|
||||
if isinstance(value, dict):
|
||||
return value
|
||||
return None
|
||||
|
||||
|
||||
class Knowledge(BaseModel):
|
||||
"""
|
||||
Knowledge is a collection of sources and setup for the vector store to save and query relevant context.
|
||||
@@ -20,10 +75,18 @@ class Knowledge(BaseModel):
|
||||
embedder: EmbedderConfig | None = None
|
||||
"""
|
||||
|
||||
sources: list[BaseKnowledgeSource] = Field(default_factory=list)
|
||||
sources: Annotated[
|
||||
list[BaseKnowledgeSource],
|
||||
BeforeValidator(_resolve_knowledge_sources),
|
||||
] = Field(default_factory=list)
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
storage: KnowledgeStorage | None = Field(default=None)
|
||||
embedder: EmbedderConfig | None = None
|
||||
embedder: Annotated[
|
||||
EmbedderConfig | None,
|
||||
PlainSerializer(
|
||||
_serialize_embedder_spec, return_type=dict | None, when_used="json"
|
||||
),
|
||||
] = None
|
||||
collection_name: str | None = None
|
||||
|
||||
def __init__(
|
||||
|
||||
@@ -13,7 +13,9 @@ class BaseKnowledgeSource(BaseModel, ABC):
|
||||
chunk_size: int = 4000
|
||||
chunk_overlap: int = 200
|
||||
chunks: list[str] = Field(default_factory=list)
|
||||
chunk_embeddings: list[np.ndarray[Any, np.dtype[Any]]] = Field(default_factory=list)
|
||||
chunk_embeddings: list[np.ndarray[Any, np.dtype[Any]]] = Field(
|
||||
default_factory=list, exclude=True
|
||||
)
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
storage: KnowledgeStorage | None = Field(default=None)
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
from urllib.parse import urlparse
|
||||
|
||||
|
||||
@@ -45,6 +45,7 @@ class CrewDoclingSource(BaseKnowledgeSource):
|
||||
|
||||
_logger: Logger = Logger(verbose=True)
|
||||
|
||||
source_type: Literal["docling"] = "docling"
|
||||
file_path: list[Path | str] | None = Field(default=None)
|
||||
file_paths: list[Path | str] = Field(default_factory=list)
|
||||
chunks: list[str] = Field(default_factory=list)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import csv
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource
|
||||
|
||||
@@ -7,6 +8,8 @@ from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledge
|
||||
class CSVKnowledgeSource(BaseFileKnowledgeSource):
|
||||
"""A knowledge source that stores and queries CSV file content using embeddings."""
|
||||
|
||||
source_type: Literal["csv"] = "csv"
|
||||
|
||||
def load_content(self) -> dict[Path, str]:
|
||||
"""Load and preprocess CSV file content."""
|
||||
content_dict = {}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
from typing import Any
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
|
||||
@@ -16,6 +16,7 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
|
||||
|
||||
_logger: Logger = Logger(verbose=True)
|
||||
|
||||
source_type: Literal["excel"] = "excel"
|
||||
file_path: Path | list[Path] | str | list[str] | None = Field(
|
||||
default=None,
|
||||
description="[Deprecated] The path to the file. Use file_paths instead.",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import Any, Literal
|
||||
|
||||
from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource
|
||||
|
||||
@@ -8,6 +8,8 @@ from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledge
|
||||
class JSONKnowledgeSource(BaseFileKnowledgeSource):
|
||||
"""A knowledge source that stores and queries JSON file content using embeddings."""
|
||||
|
||||
source_type: Literal["json"] = "json"
|
||||
|
||||
def load_content(self) -> dict[Path, str]:
|
||||
"""Load and preprocess JSON file content."""
|
||||
content: dict[Path, str] = {}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
from typing import Literal
|
||||
|
||||
from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource
|
||||
|
||||
@@ -7,6 +8,8 @@ from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledge
|
||||
class PDFKnowledgeSource(BaseFileKnowledgeSource):
|
||||
"""A knowledge source that stores and queries PDF file content using embeddings."""
|
||||
|
||||
source_type: Literal["pdf"] = "pdf"
|
||||
|
||||
def load_content(self) -> dict[Path, str]:
|
||||
"""Load and preprocess PDF file content."""
|
||||
pdfplumber = self._import_pdfplumber()
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
@@ -8,6 +8,7 @@ from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||
class StringKnowledgeSource(BaseKnowledgeSource):
|
||||
"""A knowledge source that stores and queries plain text content using embeddings."""
|
||||
|
||||
source_type: Literal["string"] = "string"
|
||||
content: str = Field(...)
|
||||
collection_name: str | None = Field(default=None)
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource
|
||||
|
||||
@@ -6,6 +7,8 @@ from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledge
|
||||
class TextFileKnowledgeSource(BaseFileKnowledgeSource):
|
||||
"""A knowledge source that stores and queries text file content using embeddings."""
|
||||
|
||||
source_type: Literal["text_file"] = "text_file"
|
||||
|
||||
def load_content(self) -> dict[Path, str]:
|
||||
"""Load and preprocess text file content."""
|
||||
content = {}
|
||||
|
||||
@@ -21,6 +21,8 @@ class MemoryScope(BaseModel):
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
memory_kind: Literal["scope"] = "scope"
|
||||
|
||||
root_path: str = Field(default="/")
|
||||
|
||||
_memory: Memory = PrivateAttr()
|
||||
@@ -191,6 +193,8 @@ class MemorySlice(BaseModel):
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
memory_kind: Literal["slice"] = "slice"
|
||||
|
||||
scopes: list[str] = Field(default_factory=list)
|
||||
categories: list[str] | None = Field(default=None)
|
||||
read_only: bool = Field(default=True)
|
||||
|
||||
@@ -63,6 +63,8 @@ class Memory(BaseModel):
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
memory_kind: Literal["memory"] = "memory"
|
||||
|
||||
llm: Annotated[BaseLLM | str, PlainValidator(_passthrough)] = Field(
|
||||
default="gpt-4o-mini",
|
||||
description="LLM for analysis (model name or BaseLLM instance).",
|
||||
|
||||
@@ -150,3 +150,50 @@ SerializableCallable = Annotated[
|
||||
PlainSerializer(callable_to_string, return_type=str, when_used="json"),
|
||||
WithJsonSchema({"type": "string"}),
|
||||
]
|
||||
|
||||
|
||||
def _instance_to_dotted_path(value: Any) -> str:
|
||||
"""Serialize an instance to a dotted path naming its class."""
|
||||
cls = type(value)
|
||||
module = getattr(cls, "__module__", None)
|
||||
qualname = getattr(cls, "__qualname__", None)
|
||||
if module is None or qualname is None:
|
||||
raise ValueError(
|
||||
f"Cannot serialize {value!r}: class missing __module__ or __qualname__. "
|
||||
"Use a module-level class for checkpointable instances."
|
||||
)
|
||||
if qualname.endswith("<lambda>") or "<locals>" in qualname:
|
||||
raise ValueError(
|
||||
f"Cannot serialize {value!r}: class defined in <locals>. "
|
||||
"Use a module-level class for checkpointable instances."
|
||||
)
|
||||
return f"{module}.{qualname}"
|
||||
|
||||
|
||||
def _dotted_path_to_instance(value: Any) -> Any:
|
||||
"""Resolve a dotted path to a class and instantiate it with no args.
|
||||
|
||||
If *value* is already a non-string object it is returned as-is.
|
||||
"""
|
||||
if value is None or not isinstance(value, str):
|
||||
return value
|
||||
if "." not in value:
|
||||
raise ValueError(
|
||||
f"Invalid provider path {value!r}: expected 'module.name' format"
|
||||
)
|
||||
if not os.environ.get("CREWAI_DESERIALIZE_CALLBACKS"):
|
||||
raise ValueError(
|
||||
f"Refusing to resolve provider path {value!r}: "
|
||||
"set CREWAI_DESERIALIZE_CALLBACKS=1 to allow. "
|
||||
"Only enable this for trusted checkpoint data."
|
||||
)
|
||||
cls = _resolve_dotted_path(value)
|
||||
return cls()
|
||||
|
||||
|
||||
SerializableInstance = Annotated[
|
||||
Any,
|
||||
BeforeValidator(_dotted_path_to_instance),
|
||||
PlainSerializer(_instance_to_dotted_path, return_type=str, when_used="json"),
|
||||
WithJsonSchema({"type": "string"}),
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user