mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-21 00:48:10 +00:00
Compare commits
19 Commits
lorenze/fe
...
fix/runtim
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0e9167dec3 | ||
|
|
1eb2326e8a | ||
|
|
a0e5d91364 | ||
|
|
c6643d4071 | ||
|
|
0d5c2c81e0 | ||
|
|
2346f12c43 | ||
|
|
dc047743b8 | ||
|
|
163b3b592d | ||
|
|
83ba64c334 | ||
|
|
fc480409bd | ||
|
|
15a423ad3c | ||
|
|
c37afab1ff | ||
|
|
f385b91a63 | ||
|
|
0991f7994a | ||
|
|
3ceb9a287a | ||
|
|
0f3a57b3b9 | ||
|
|
b07c1439a3 | ||
|
|
97e959cb0c | ||
|
|
752d9b45d6 |
@@ -31,13 +31,13 @@ from crewai.agents.tools_handler import ToolsHandler
|
||||
from crewai.events.base_events import set_emission_counter
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.event_context import restore_event_scope, set_last_event_id
|
||||
from crewai.knowledge.knowledge import Knowledge
|
||||
from crewai.knowledge.knowledge import Knowledge, _resolve_knowledge_sources
|
||||
from crewai.knowledge.knowledge_config import KnowledgeConfig
|
||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||
from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.mcp.config import MCPServerConfig
|
||||
from crewai.memory.memory_scope import MemoryScope, MemorySlice
|
||||
from crewai.memory.memory_scope import MemoryScope, MemorySlice, _ensure_memory_kind
|
||||
from crewai.memory.unified_memory import Memory
|
||||
from crewai.rag.embeddings.types import EmbedderConfig
|
||||
from crewai.security.security_config import SecurityConfig
|
||||
@@ -127,6 +127,13 @@ def _validate_executor_ref(value: Any) -> Any:
|
||||
return value
|
||||
|
||||
|
||||
def _serialize_executor_ref(value: Any) -> dict[str, Any] | None:
|
||||
if value is None:
|
||||
return None
|
||||
result: dict[str, Any] = value.model_dump(mode="json")
|
||||
return result
|
||||
|
||||
|
||||
def _serialize_llm_ref(value: Any) -> dict[str, Any] | None:
|
||||
if value is None:
|
||||
return None
|
||||
@@ -251,14 +258,13 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
|
||||
max_iter: int = Field(
|
||||
default=25, description="Maximum iterations for an agent to execute a task"
|
||||
)
|
||||
agent_executor: SerializeAsAny[BaseAgentExecutor] | None = Field(
|
||||
default=None, description="An instance of the CrewAgentExecutor class."
|
||||
)
|
||||
|
||||
@field_validator("agent_executor", mode="before")
|
||||
@classmethod
|
||||
def _validate_agent_executor(cls, v: Any) -> Any:
|
||||
return _validate_executor_ref(v)
|
||||
agent_executor: Annotated[
|
||||
SerializeAsAny[BaseAgentExecutor] | None,
|
||||
BeforeValidator(_validate_executor_ref),
|
||||
PlainSerializer(
|
||||
_serialize_executor_ref, return_type=dict | None, when_used="json"
|
||||
),
|
||||
] = Field(default=None, description="An instance of the CrewAgentExecutor class.")
|
||||
|
||||
llm: Annotated[
|
||||
str | BaseLLM | None,
|
||||
@@ -288,7 +294,10 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
|
||||
knowledge: Knowledge | None = Field(
|
||||
default=None, description="Knowledge for the agent."
|
||||
)
|
||||
knowledge_sources: list[BaseKnowledgeSource] | None = Field(
|
||||
knowledge_sources: Annotated[
|
||||
list[BaseKnowledgeSource] | None,
|
||||
BeforeValidator(_resolve_knowledge_sources),
|
||||
] = Field(
|
||||
default=None,
|
||||
description="Knowledge sources for the agent.",
|
||||
)
|
||||
@@ -326,7 +335,14 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
|
||||
default=None,
|
||||
description="List of MCP server references. Supports 'https://server.com/path' for external servers and bare slugs like 'notion' for connected MCP integrations. Use '#tool_name' suffix for specific tools.",
|
||||
)
|
||||
memory: bool | Memory | MemoryScope | MemorySlice | None = Field(
|
||||
memory: Annotated[
|
||||
bool
|
||||
| Annotated[
|
||||
Memory | MemoryScope | MemorySlice, Field(discriminator="memory_kind")
|
||||
]
|
||||
| None,
|
||||
BeforeValidator(_ensure_memory_kind),
|
||||
] = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Enable agent memory. Pass True for default Memory(), "
|
||||
@@ -397,8 +413,21 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
|
||||
self.agent_executor._resuming = True
|
||||
if self.checkpoint_kickoff_event_id is not None:
|
||||
self._kickoff_event_id = self.checkpoint_kickoff_event_id
|
||||
self._rebind_memory_view()
|
||||
self._restore_event_scope(state)
|
||||
|
||||
def _rebind_memory_view(self) -> None:
    """Reattach a fresh ``Memory`` to a restored ``MemoryScope``/``MemorySlice``.

    Checkpoint JSON omits the live ``Memory`` dependency, so scoped
    memory views raise ``RuntimeError`` on first use after restore.
    Only acts when ``self.memory`` is a scope/slice view whose private
    ``_memory`` is still ``None``; any other value is left untouched.
    """
    view = self.memory
    if not isinstance(view, MemoryScope | MemorySlice):
        return
    if view._memory is not None:
        # Already bound to a backing store; keep it.
        return
    view.bind(Memory())
|
||||
|
||||
def _restore_event_scope(self, state: RuntimeState) -> None:
|
||||
"""Rebuild the event scope stack from the checkpoint's event record.
|
||||
|
||||
|
||||
@@ -93,11 +93,11 @@ from crewai.events.types.crew_events import (
|
||||
CrewTrainStartedEvent,
|
||||
)
|
||||
from crewai.flow.flow_trackable import FlowTrackable
|
||||
from crewai.knowledge.knowledge import Knowledge
|
||||
from crewai.knowledge.knowledge import Knowledge, _resolve_knowledge_sources
|
||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||
from crewai.llm import LLM
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.memory.memory_scope import MemoryScope, MemorySlice
|
||||
from crewai.memory.memory_scope import MemoryScope, MemorySlice, _ensure_memory_kind
|
||||
from crewai.memory.unified_memory import Memory
|
||||
from crewai.process import Process
|
||||
from crewai.rag.embeddings.types import EmbedderConfig
|
||||
@@ -223,7 +223,14 @@ class Crew(FlowTrackable, BaseModel):
|
||||
] = Field(default_factory=list)
|
||||
process: Process = Field(default=Process.sequential)
|
||||
verbose: bool = Field(default=False)
|
||||
memory: bool | Memory | MemoryScope | MemorySlice | None = Field(
|
||||
memory: Annotated[
|
||||
bool
|
||||
| Annotated[
|
||||
Memory | MemoryScope | MemorySlice, Field(discriminator="memory_kind")
|
||||
]
|
||||
| None,
|
||||
BeforeValidator(_ensure_memory_kind),
|
||||
] = Field(
|
||||
default=False,
|
||||
description=(
|
||||
"Enable crew memory. Pass True for default Memory(), "
|
||||
@@ -322,7 +329,10 @@ class Crew(FlowTrackable, BaseModel):
|
||||
default_factory=list,
|
||||
description="list of execution logs for tasks",
|
||||
)
|
||||
knowledge_sources: list[BaseKnowledgeSource] | None = Field(
|
||||
knowledge_sources: Annotated[
|
||||
list[BaseKnowledgeSource] | None,
|
||||
BeforeValidator(_resolve_knowledge_sources),
|
||||
] = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Knowledge sources for the crew. Add knowledge sources to the "
|
||||
@@ -477,8 +487,42 @@ class Crew(FlowTrackable, BaseModel):
|
||||
if self.checkpoint_train is not None:
|
||||
self._train = self.checkpoint_train
|
||||
|
||||
self._rebind_memory_views()
|
||||
self._restore_event_scope()
|
||||
|
||||
def _rebind_memory_views(self) -> None:
    """Reattach a live ``Memory`` to restored ``MemoryScope``/``MemorySlice`` views.

    Checkpoint JSON omits the live ``Memory`` dependency on scope/slice
    views, so after restore they raise ``RuntimeError`` on first use.
    Prefer the crew's restored ``Memory`` (from ``create_crew_memory``
    or a ``Crew.memory=Memory(...)`` instance) so all views share one
    backing store; fall back to a fresh ``Memory()`` only if nothing is
    available.
    """
    from crewai.memory.memory_scope import MemoryScope, MemorySlice
    from crewai.memory.unified_memory import Memory

    # Pick the shared backing store: the private ``_memory`` wins over the
    # public ``memory`` field when both are ``Memory`` instances.
    backing: Memory | None = None
    if isinstance(self._memory, Memory):
        backing = self._memory
    elif isinstance(self.memory, Memory):
        backing = self.memory

    def _ensure(view: Any) -> None:
        """Bind ``view`` to the shared backing store if it is an unbound view."""
        nonlocal backing
        if not isinstance(view, MemoryScope | MemorySlice):
            return
        if view._memory is not None:
            return
        if backing is None:
            # Created lazily, and at most once, so every unbound view
            # ends up sharing the same fallback instance.
            backing = Memory()
        view.bind(backing)

    _ensure(self.memory)
    for agent in self.agents:
        _ensure(agent.memory)
|
||||
|
||||
def _restore_event_scope(self) -> None:
|
||||
"""Rebuild the event scope stack from the checkpoint's event record."""
|
||||
from crewai.events.base_events import set_emission_counter
|
||||
|
||||
@@ -113,7 +113,7 @@ from crewai.flow.utils import (
|
||||
is_flow_method_name,
|
||||
is_simple_flow_condition,
|
||||
)
|
||||
from crewai.memory.memory_scope import MemoryScope, MemorySlice
|
||||
from crewai.memory.memory_scope import MemoryScope, MemorySlice, _ensure_memory_kind
|
||||
from crewai.memory.unified_memory import Memory
|
||||
from crewai.state.checkpoint_config import (
|
||||
CheckpointConfig,
|
||||
@@ -159,6 +159,39 @@ def _resolve_persistence(value: Any) -> Any:
|
||||
return value
|
||||
|
||||
|
||||
def _serialize_persistence(value: Any) -> dict[str, Any] | None:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, FlowPersistence):
|
||||
return value.model_dump(mode="json")
|
||||
raise TypeError(
|
||||
f"Cannot serialize Flow.persistence of type {type(value).__name__}: "
|
||||
"expected FlowPersistence or None."
|
||||
)
|
||||
|
||||
|
||||
def _validate_input_provider(value: Any) -> Any:
    """Validate (and if needed resolve) a flow ``input_provider`` value.

    ``None`` and objects that already satisfy ``InputProvider`` are
    returned unchanged. Anything else is handed to
    ``_dotted_path_to_instance`` and the resolved object is checked
    against ``InputProvider``.

    Args:
        value: ``None``, an ``InputProvider``, or a value that
            ``_dotted_path_to_instance`` can resolve.

    Returns:
        The original or resolved provider, or ``None``.

    Raises:
        ValueError: If the resolved object is not an ``InputProvider``.
    """
    if value is None or isinstance(value, InputProvider):
        return value

    # Imported lazily, only on the path that actually needs resolution.
    from crewai.types.callback import _dotted_path_to_instance

    resolved = _dotted_path_to_instance(value)
    if resolved is None or isinstance(resolved, InputProvider):
        return resolved
    raise ValueError(
        f"Resolved input_provider {resolved!r} does not implement the "
        "InputProvider protocol (missing request_input)."
    )
|
||||
|
||||
|
||||
def _serialize_input_provider(value: Any) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
from crewai.types.callback import _instance_to_dotted_path
|
||||
|
||||
return _instance_to_dotted_path(value)
|
||||
|
||||
|
||||
_INITIAL_STATE_CLASS_MARKER = "__crewai_pydantic_class_schema__"
|
||||
|
||||
|
||||
@@ -949,15 +982,30 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
name: str | None = Field(default=None)
|
||||
tracing: bool | None = Field(default=None)
|
||||
stream: bool = Field(default=False)
|
||||
memory: Memory | MemoryScope | MemorySlice | None = Field(default=None)
|
||||
input_provider: InputProvider | None = Field(default=None)
|
||||
memory: Annotated[
|
||||
Annotated[
|
||||
Memory | MemoryScope | MemorySlice, Field(discriminator="memory_kind")
|
||||
]
|
||||
| None,
|
||||
BeforeValidator(_ensure_memory_kind),
|
||||
] = Field(default=None)
|
||||
input_provider: Annotated[
|
||||
InputProvider | None,
|
||||
BeforeValidator(_validate_input_provider),
|
||||
PlainSerializer(
|
||||
_serialize_input_provider, return_type=str | None, when_used="json"
|
||||
),
|
||||
] = Field(default=None)
|
||||
suppress_flow_events: bool = Field(default=False)
|
||||
human_feedback_history: list[HumanFeedbackResult] = Field(default_factory=list)
|
||||
last_human_feedback: HumanFeedbackResult | None = Field(default=None)
|
||||
|
||||
persistence: Annotated[
|
||||
SerializeAsAny[FlowPersistence] | Any,
|
||||
SerializeAsAny[FlowPersistence] | None,
|
||||
BeforeValidator(lambda v, _: _resolve_persistence(v)),
|
||||
PlainSerializer(
|
||||
_serialize_persistence, return_type=dict | None, when_used="json"
|
||||
),
|
||||
] = Field(default=None)
|
||||
max_method_calls: int = Field(default=100)
|
||||
|
||||
@@ -1050,6 +1098,11 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
}
|
||||
if self.checkpoint_state is not None:
|
||||
self._restore_state(self.checkpoint_state)
|
||||
if (
|
||||
isinstance(self.memory, MemoryScope | MemorySlice)
|
||||
and self.memory._memory is None
|
||||
):
|
||||
self.memory.bind(Memory())
|
||||
restore_event_scope(())
|
||||
reset_last_event_id()
|
||||
|
||||
|
||||
@@ -1,16 +1,89 @@
|
||||
import os
|
||||
from typing import Annotated, Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from pydantic import BaseModel, BeforeValidator, ConfigDict, Field, PlainSerializer
|
||||
|
||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||
from crewai.knowledge.source.crew_docling_source import CrewDoclingSource
|
||||
from crewai.knowledge.source.csv_knowledge_source import CSVKnowledgeSource
|
||||
from crewai.knowledge.source.excel_knowledge_source import ExcelKnowledgeSource
|
||||
from crewai.knowledge.source.json_knowledge_source import JSONKnowledgeSource
|
||||
from crewai.knowledge.source.pdf_knowledge_source import PDFKnowledgeSource
|
||||
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource
|
||||
from crewai.knowledge.source.text_file_knowledge_source import (
|
||||
TextFileKnowledgeSource,
|
||||
)
|
||||
from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
from crewai.rag.embeddings.types import EmbedderConfig
|
||||
from crewai.rag.types import SearchResult
|
||||
|
||||
|
||||
_KNOWN_SOURCES: dict[str, type[BaseKnowledgeSource]] = {
|
||||
"string": StringKnowledgeSource,
|
||||
"docling": CrewDoclingSource,
|
||||
"csv": CSVKnowledgeSource,
|
||||
"excel": ExcelKnowledgeSource,
|
||||
"json": JSONKnowledgeSource,
|
||||
"pdf": PDFKnowledgeSource,
|
||||
"text_file": TextFileKnowledgeSource,
|
||||
}
|
||||
|
||||
|
||||
def _resolve_knowledge_sources(value: Any) -> Any:
    """Coerce list of dicts into typed BaseKnowledgeSource subclasses via source_type.

    Pass-through for anything else (existing instances, mocks): non-list
    values are returned as-is, and list items that are not dicts, or are
    dicts without a string ``source_type``, are kept unchanged.

    Args:
        value: The raw ``knowledge_sources`` / ``sources`` input.

    Returns:
        ``value`` itself when it is not a list; otherwise a new list in
        which every tagged dict has been replaced by the model produced by
        the matching class in ``_KNOWN_SOURCES``.

    Raises:
        ValueError: If a ``source_type`` tag is not in ``_KNOWN_SOURCES``,
            or if validating a tagged dict fails (the original exception
            is chained as the cause).
    """
    if not isinstance(value, list):
        return value

    resolved: list[Any] = []
    for idx, item in enumerate(value):
        tag = item.get("source_type") if isinstance(item, dict) else None
        if not isinstance(tag, str):
            # Not a tagged dict: leave it for downstream validation.
            resolved.append(item)
            continue

        cls = _KNOWN_SOURCES.get(tag)
        if cls is None:
            raise ValueError(
                f"Unknown source_type={tag!r} at index {idx}: "
                f"expected one of {sorted(_KNOWN_SOURCES)}"
            )
        try:
            resolved.append(cls.model_validate(item))
        except Exception as exc:
            # Deliberately broad: re-raise any failure with the index and
            # tag attached so the offending entry can be located.
            raise ValueError(
                f"Failed to validate knowledge source at index {idx} "
                f"with source_type={tag!r}: {exc}"
            ) from exc
    return resolved
|
||||
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false" # removes logging from fastembed
|
||||
|
||||
|
||||
def _serialize_embedder_spec(value: Any) -> dict[str, Any] | None:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, BaseEmbeddingsProvider):
|
||||
return value.model_dump(mode="json")
|
||||
if isinstance(value, dict):
|
||||
return value
|
||||
if isinstance(value, type) and issubclass(value, BaseEmbeddingsProvider):
|
||||
raise TypeError(
|
||||
f"Cannot checkpoint embedder class {value.__module__}.{value.__qualname__}: "
|
||||
"build_embedder requires an instance or ProviderSpec dict, not a class. "
|
||||
"Instantiate the provider before assigning it to Knowledge.embedder."
|
||||
)
|
||||
raise TypeError(
|
||||
f"Cannot serialize embedder of type {type(value).__name__}: "
|
||||
"expected ProviderSpec dict or BaseEmbeddingsProvider instance."
|
||||
)
|
||||
|
||||
|
||||
class Knowledge(BaseModel):
|
||||
"""
|
||||
Knowledge is a collection of sources and setup for the vector store to save and query relevant context.
|
||||
@@ -20,10 +93,18 @@ class Knowledge(BaseModel):
|
||||
embedder: EmbedderConfig | None = None
|
||||
"""
|
||||
|
||||
sources: list[BaseKnowledgeSource] = Field(default_factory=list)
|
||||
sources: Annotated[
|
||||
list[BaseKnowledgeSource],
|
||||
BeforeValidator(_resolve_knowledge_sources),
|
||||
] = Field(default_factory=list)
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
storage: KnowledgeStorage | None = Field(default=None)
|
||||
embedder: EmbedderConfig | None = None
|
||||
embedder: Annotated[
|
||||
EmbedderConfig | None,
|
||||
PlainSerializer(
|
||||
_serialize_embedder_spec, return_type=dict | None, when_used="json"
|
||||
),
|
||||
] = None
|
||||
collection_name: str | None = None
|
||||
|
||||
def __init__(
|
||||
|
||||
@@ -13,7 +13,9 @@ class BaseKnowledgeSource(BaseModel, ABC):
|
||||
chunk_size: int = 4000
|
||||
chunk_overlap: int = 200
|
||||
chunks: list[str] = Field(default_factory=list)
|
||||
chunk_embeddings: list[np.ndarray[Any, np.dtype[Any]]] = Field(default_factory=list)
|
||||
chunk_embeddings: list[np.ndarray[Any, np.dtype[Any]]] = Field(
|
||||
default_factory=list, exclude=True
|
||||
)
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
storage: KnowledgeStorage | None = Field(default=None)
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
from urllib.parse import urlparse
|
||||
|
||||
|
||||
@@ -45,6 +45,7 @@ class CrewDoclingSource(BaseKnowledgeSource):
|
||||
|
||||
_logger: Logger = Logger(verbose=True)
|
||||
|
||||
source_type: Literal["docling"] = "docling"
|
||||
file_path: list[Path | str] | None = Field(default=None)
|
||||
file_paths: list[Path | str] = Field(default_factory=list)
|
||||
chunks: list[str] = Field(default_factory=list)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import csv
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource
|
||||
|
||||
@@ -7,6 +8,8 @@ from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledge
|
||||
class CSVKnowledgeSource(BaseFileKnowledgeSource):
|
||||
"""A knowledge source that stores and queries CSV file content using embeddings."""
|
||||
|
||||
source_type: Literal["csv"] = "csv"
|
||||
|
||||
def load_content(self) -> dict[Path, str]:
|
||||
"""Load and preprocess CSV file content."""
|
||||
content_dict = {}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
from typing import Any
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
|
||||
@@ -16,6 +16,7 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
|
||||
|
||||
_logger: Logger = Logger(verbose=True)
|
||||
|
||||
source_type: Literal["excel"] = "excel"
|
||||
file_path: Path | list[Path] | str | list[str] | None = Field(
|
||||
default=None,
|
||||
description="[Deprecated] The path to the file. Use file_paths instead.",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import Any, Literal
|
||||
|
||||
from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource
|
||||
|
||||
@@ -8,6 +8,8 @@ from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledge
|
||||
class JSONKnowledgeSource(BaseFileKnowledgeSource):
|
||||
"""A knowledge source that stores and queries JSON file content using embeddings."""
|
||||
|
||||
source_type: Literal["json"] = "json"
|
||||
|
||||
def load_content(self) -> dict[Path, str]:
|
||||
"""Load and preprocess JSON file content."""
|
||||
content: dict[Path, str] = {}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
from typing import Literal
|
||||
|
||||
from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource
|
||||
|
||||
@@ -7,6 +8,8 @@ from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledge
|
||||
class PDFKnowledgeSource(BaseFileKnowledgeSource):
|
||||
"""A knowledge source that stores and queries PDF file content using embeddings."""
|
||||
|
||||
source_type: Literal["pdf"] = "pdf"
|
||||
|
||||
def load_content(self) -> dict[Path, str]:
|
||||
"""Load and preprocess PDF file content."""
|
||||
pdfplumber = self._import_pdfplumber()
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
@@ -8,6 +8,7 @@ from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||
class StringKnowledgeSource(BaseKnowledgeSource):
|
||||
"""A knowledge source that stores and queries plain text content using embeddings."""
|
||||
|
||||
source_type: Literal["string"] = "string"
|
||||
content: str = Field(...)
|
||||
collection_name: str | None = Field(default=None)
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource
|
||||
|
||||
@@ -6,6 +7,8 @@ from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledge
|
||||
class TextFileKnowledgeSource(BaseFileKnowledgeSource):
|
||||
"""A knowledge source that stores and queries text file content using embeddings."""
|
||||
|
||||
source_type: Literal["text_file"] = "text_file"
|
||||
|
||||
def load_content(self) -> dict[Path, str]:
|
||||
"""Load and preprocess text file content."""
|
||||
content = {}
|
||||
|
||||
@@ -6,6 +6,7 @@ from datetime import datetime
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.memory.types import (
|
||||
_RECALL_OVERSAMPLE_FACTOR,
|
||||
@@ -16,15 +17,35 @@ from crewai.memory.types import (
|
||||
from crewai.memory.unified_memory import Memory
|
||||
|
||||
|
||||
def _ensure_memory_kind(value: Any) -> Any:
    """Backfill ``memory_kind`` on legacy dicts that predate the discriminator.

    Lets pre-1.14.6 configs/checkpoints flow into the discriminated
    ``Memory | MemoryScope | MemorySlice`` union without crashing. Inference:
    ``scopes`` key → ``slice``; ``root_path`` → ``scope``; else ``memory``.
    Pass-through for non-dict values (instances, ``bool``, ``None``) and
    for dicts that already carry ``memory_kind``.

    Args:
        value: The raw value supplied for a ``memory`` field.

    Returns:
        ``value`` unchanged when no backfill is needed; otherwise a shallow
        copy of the dict with ``memory_kind`` added. The caller's dict is
        never modified.
    """
    if not isinstance(value, dict) or "memory_kind" in value:
        return value

    if "scopes" in value:
        kind = "slice"
    elif "root_path" in value:
        kind = "scope"
    else:
        kind = "memory"
    # Tag a copy: a validator should not leave side effects on the config
    # dict the caller passed in.
    return {**value, "memory_kind": kind}
|
||||
|
||||
|
||||
class MemoryScope(BaseModel):
|
||||
"""View of Memory restricted to a root path. All operations are scoped under that path."""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
memory_kind: Literal["scope"] = "scope"
|
||||
|
||||
root_path: str = Field(default="/")
|
||||
|
||||
_memory: Memory = PrivateAttr()
|
||||
_root: str = PrivateAttr()
|
||||
_memory: Memory | None = PrivateAttr(default=None)
|
||||
_root: str = PrivateAttr(default="")
|
||||
|
||||
@model_validator(mode="wrap")
|
||||
@classmethod
|
||||
@@ -34,21 +55,38 @@ class MemoryScope(BaseModel):
|
||||
return data
|
||||
if not isinstance(data, dict):
|
||||
raise ValueError(f"Expected dict or MemoryScope, got {type(data).__name__}")
|
||||
if "memory" not in data:
|
||||
raise ValueError("MemoryScope requires a 'memory' key")
|
||||
memory = data.pop("memory")
|
||||
memory = data.pop("memory", None)
|
||||
instance: MemoryScope = handler(data)
|
||||
instance._memory = memory
|
||||
if memory is not None:
|
||||
instance._memory = memory
|
||||
root = instance.root_path.rstrip("/") or ""
|
||||
if root and not root.startswith("/"):
|
||||
root = "/" + root
|
||||
instance._root = root
|
||||
return instance
|
||||
|
||||
def bind(self, memory: Memory) -> Self:
    """Rebind the runtime ``Memory`` dependency after restore.

    Required after deserializing from a checkpoint, since the live
    ``Memory`` cannot be serialized.

    Args:
        memory: The ``Memory`` instance this scope should delegate to.

    Returns:
        This scope itself, so the call can be chained.
    """
    self._memory = memory
    return self
|
||||
|
||||
def _require_memory(self) -> Memory:
    """Return the bound ``Memory`` or raise a clear error if missing.

    Returns:
        The ``Memory`` currently stored in ``self._memory``.

    Raises:
        RuntimeError: If no ``Memory`` has been bound to this scope.
    """
    bound = self._memory
    if bound is None:
        raise RuntimeError(
            "MemoryScope is not bound to a Memory; call .bind(memory) "
            "after restore."
        )
    return bound
|
||||
|
||||
@property
|
||||
def read_only(self) -> bool:
|
||||
"""Whether the underlying memory is read-only."""
|
||||
return self._memory.read_only
|
||||
return self._require_memory().read_only
|
||||
|
||||
def _scope_path(self, scope: str | None) -> str:
|
||||
if not scope or scope == "/":
|
||||
@@ -73,7 +111,7 @@ class MemoryScope(BaseModel):
|
||||
) -> MemoryRecord | None:
|
||||
"""Remember content; scope is relative to this scope's root."""
|
||||
path = self._scope_path(scope)
|
||||
return self._memory.remember(
|
||||
return self._require_memory().remember(
|
||||
content,
|
||||
scope=path,
|
||||
categories=categories,
|
||||
@@ -96,7 +134,7 @@ class MemoryScope(BaseModel):
|
||||
) -> list[MemoryRecord]:
|
||||
"""Remember multiple items; scope is relative to this scope's root."""
|
||||
path = self._scope_path(scope)
|
||||
return self._memory.remember_many(
|
||||
return self._require_memory().remember_many(
|
||||
contents,
|
||||
scope=path,
|
||||
categories=categories,
|
||||
@@ -119,7 +157,7 @@ class MemoryScope(BaseModel):
|
||||
) -> list[MemoryMatch]:
|
||||
"""Recall within this scope (root path and below)."""
|
||||
search_scope = self._scope_path(scope) if scope else (self._root or "/")
|
||||
return self._memory.recall(
|
||||
return self._require_memory().recall(
|
||||
query,
|
||||
scope=search_scope,
|
||||
categories=categories,
|
||||
@@ -131,7 +169,7 @@ class MemoryScope(BaseModel):
|
||||
|
||||
def extract_memories(self, content: str) -> list[str]:
|
||||
"""Extract discrete memories from content; delegates to underlying Memory."""
|
||||
return self._memory.extract_memories(content)
|
||||
return self._require_memory().extract_memories(content)
|
||||
|
||||
def forget(
|
||||
self,
|
||||
@@ -143,7 +181,7 @@ class MemoryScope(BaseModel):
|
||||
) -> int:
|
||||
"""Forget within this scope."""
|
||||
prefix = self._scope_path(scope) if scope else (self._root or "/")
|
||||
return self._memory.forget(
|
||||
return self._require_memory().forget(
|
||||
scope=prefix,
|
||||
categories=categories,
|
||||
older_than=older_than,
|
||||
@@ -154,27 +192,27 @@ class MemoryScope(BaseModel):
|
||||
def list_scopes(self, path: str = "/") -> list[str]:
|
||||
"""List child scopes under path (relative to this scope's root)."""
|
||||
full = self._scope_path(path)
|
||||
return self._memory.list_scopes(full)
|
||||
return self._require_memory().list_scopes(full)
|
||||
|
||||
def info(self, path: str = "/") -> ScopeInfo:
|
||||
"""Info for path under this scope."""
|
||||
full = self._scope_path(path)
|
||||
return self._memory.info(full)
|
||||
return self._require_memory().info(full)
|
||||
|
||||
def tree(self, path: str = "/", max_depth: int = 3) -> str:
|
||||
"""Tree under path within this scope."""
|
||||
full = self._scope_path(path)
|
||||
return self._memory.tree(full, max_depth=max_depth)
|
||||
return self._require_memory().tree(full, max_depth=max_depth)
|
||||
|
||||
def list_categories(self, path: str | None = None) -> dict[str, int]:
|
||||
"""Categories in this scope; path None means this scope root."""
|
||||
full = self._scope_path(path) if path else (self._root or "/")
|
||||
return self._memory.list_categories(full)
|
||||
return self._require_memory().list_categories(full)
|
||||
|
||||
def reset(self, scope: str | None = None) -> None:
|
||||
"""Reset within this scope."""
|
||||
prefix = self._scope_path(scope) if scope else (self._root or "/")
|
||||
self._memory.reset(scope=prefix)
|
||||
self._require_memory().reset(scope=prefix)
|
||||
|
||||
def subscope(self, path: str) -> MemoryScope:
|
||||
"""Return a narrower scope under this scope."""
|
||||
@@ -191,11 +229,13 @@ class MemorySlice(BaseModel):
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
memory_kind: Literal["slice"] = "slice"
|
||||
|
||||
scopes: list[str] = Field(default_factory=list)
|
||||
categories: list[str] | None = Field(default=None)
|
||||
read_only: bool = Field(default=True)
|
||||
|
||||
_memory: Memory = PrivateAttr()
|
||||
_memory: Memory | None = PrivateAttr(default=None)
|
||||
|
||||
@model_validator(mode="wrap")
|
||||
@classmethod
|
||||
@@ -205,14 +245,27 @@ class MemorySlice(BaseModel):
|
||||
return data
|
||||
if not isinstance(data, dict):
|
||||
raise ValueError(f"Expected dict or MemorySlice, got {type(data).__name__}")
|
||||
if "memory" not in data:
|
||||
raise ValueError("MemorySlice requires a 'memory' key")
|
||||
memory = data.pop("memory")
|
||||
memory = data.pop("memory", None)
|
||||
data["scopes"] = [s.rstrip("/") or "/" for s in data.get("scopes", [])]
|
||||
instance: MemorySlice = handler(data)
|
||||
instance._memory = memory
|
||||
if memory is not None:
|
||||
instance._memory = memory
|
||||
return instance
|
||||
|
||||
def bind(self, memory: Memory) -> Self:
    """Rebind the runtime ``Memory`` dependency after restore.

    Args:
        memory: The ``Memory`` instance this slice should delegate to.

    Returns:
        This slice itself, so the call can be chained.
    """
    self._memory = memory
    return self
|
||||
|
||||
def _require_memory(self) -> Memory:
    """Return the bound ``Memory`` or raise a clear error if missing.

    Returns:
        The ``Memory`` currently stored in ``self._memory``.

    Raises:
        RuntimeError: If no ``Memory`` has been bound to this slice.
    """
    bound = self._memory
    if bound is None:
        raise RuntimeError(
            "MemorySlice is not bound to a Memory; call .bind(memory) "
            "after restore."
        )
    return bound
|
||||
|
||||
def remember(
|
||||
self,
|
||||
content: str,
|
||||
@@ -226,7 +279,7 @@ class MemorySlice(BaseModel):
|
||||
"""Remember into an explicit scope. No-op when read_only=True."""
|
||||
if self.read_only:
|
||||
return None
|
||||
return self._memory.remember(
|
||||
return self._require_memory().remember(
|
||||
content,
|
||||
scope=scope,
|
||||
categories=categories,
|
||||
@@ -250,7 +303,7 @@ class MemorySlice(BaseModel):
|
||||
cats = categories or self.categories
|
||||
all_matches: list[MemoryMatch] = []
|
||||
for sc in self.scopes:
|
||||
matches = self._memory.recall(
|
||||
matches = self._require_memory().recall(
|
||||
query,
|
||||
scope=sc,
|
||||
categories=cats,
|
||||
@@ -272,14 +325,14 @@ class MemorySlice(BaseModel):
|
||||
|
||||
def extract_memories(self, content: str) -> list[str]:
|
||||
"""Extract discrete memories from content; delegates to underlying Memory."""
|
||||
return self._memory.extract_memories(content)
|
||||
return self._require_memory().extract_memories(content)
|
||||
|
||||
def list_scopes(self, path: str = "/") -> list[str]:
|
||||
"""List scopes across all slice roots."""
|
||||
out: list[str] = []
|
||||
for sc in self.scopes:
|
||||
full = f"{sc.rstrip('/')}{path}" if sc != "/" else path
|
||||
out.extend(self._memory.list_scopes(full))
|
||||
out.extend(self._require_memory().list_scopes(full))
|
||||
return sorted(set(out))
|
||||
|
||||
def info(self, path: str = "/") -> ScopeInfo:
|
||||
@@ -291,7 +344,7 @@ class MemorySlice(BaseModel):
|
||||
children: list[str] = []
|
||||
for sc in self.scopes:
|
||||
full = f"{sc.rstrip('/')}{path}" if sc != "/" else path
|
||||
inf = self._memory.info(full)
|
||||
inf = self._require_memory().info(full)
|
||||
total_records += inf.record_count
|
||||
all_categories.update(inf.categories)
|
||||
if inf.oldest_record:
|
||||
@@ -321,6 +374,6 @@ class MemorySlice(BaseModel):
|
||||
counts: dict[str, int] = {}
|
||||
for sc in self.scopes:
|
||||
full = (f"{sc.rstrip('/')}{path}" if sc != "/" else path) if path else sc
|
||||
for k, v in self._memory.list_categories(full).items():
|
||||
for k, v in self._require_memory().list_categories(full).items():
|
||||
counts[k] = counts.get(k, 0) + v
|
||||
return counts
|
||||
|
||||
@@ -63,6 +63,8 @@ class Memory(BaseModel):
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
memory_kind: Literal["memory"] = "memory"
|
||||
|
||||
llm: Annotated[BaseLLM | str, PlainValidator(_passthrough)] = Field(
|
||||
default="gpt-4o-mini",
|
||||
description="LLM for analysis (model name or BaseLLM instance).",
|
||||
|
||||
@@ -113,12 +113,68 @@ def _migrate(data: dict[str, Any]) -> dict[str, Any]:
|
||||
)
|
||||
|
||||
# --- migrations in version order ---
|
||||
# if stored < Version("X.Y.Z"):
|
||||
# data.setdefault("some_field", "default")
|
||||
if stored < Version("1.14.6"):
|
||||
for entity in data.get("entities") or []:
|
||||
_backfill_discriminators(entity)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def _backfill_memory_kind(value: Any) -> None:
    """Infer ``memory_kind`` from structural fields on legacy memory dicts.

    Mutates *value* in place. Anything that is not a dict, or a dict that
    already has a ``memory_kind`` key, is left untouched.

    Args:
        value: Candidate serialized memory payload (may be any type).
    """
    if not isinstance(value, dict) or "memory_kind" in value:
        return
    # Precedence matters: a dict carrying both keys is tagged as a slice.
    if "scopes" in value:
        value["memory_kind"] = "slice"
    elif "root_path" in value:
        value["memory_kind"] = "scope"
    else:
        value["memory_kind"] = "memory"
|
||||
|
||||
|
||||
def _backfill_source_type(source: Any) -> None:
    """Infer ``source_type`` for legacy knowledge source dicts when possible.

    Only a source whose ``content`` is a plain string is treated as
    inferrable; it is tagged ``"string"`` in place. Per the original note,
    file-based sources (CSV/PDF/Excel/JSON/docling) also have a ``content``
    field but populate it with dicts/lists, so they cannot be tagged here.

    Non-dict values and dicts that already carry ``source_type`` are ignored.

    Args:
        source: Candidate serialized knowledge source (may be any type).

    Raises:
        ValueError: If *source* is a dict without ``source_type`` whose
            ``content`` is not a string, i.e. the type cannot be inferred.
    """
    if not isinstance(source, dict) or "source_type" in source:
        return
    if isinstance(source.get("content"), str):
        source["source_type"] = "string"
        return
    raise ValueError(
        "Legacy knowledge source is missing 'source_type' and could not be "
        "inferred during migration. Re-checkpoint after upgrading to 1.14.6+."
    )
|
||||
|
||||
|
||||
def _backfill_sources_on(container: Any) -> None:
    """Apply source_type backfill to ``sources`` and ``knowledge_sources`` lists.

    Non-dict containers are ignored. Missing or falsy lists are treated as
    empty. Every entry found is passed to ``_backfill_source_type``, whose
    exceptions propagate to the caller.

    Args:
        container: Dict that may hold either list key (may be any type).
    """
    if not isinstance(container, dict):
        return
    for key in ("sources", "knowledge_sources"):
        for src in container.get(key) or []:
            _backfill_source_type(src)
|
||||
|
||||
|
||||
def _backfill_discriminators(entity: Any) -> None:
    """Walk an entity dict and backfill discriminator fields added in 1.14.6.

    The same three steps run on the entity itself and then on each dict in
    its ``agents`` list: tag the ``memory`` payload, tag sources held
    directly on the node, and tag sources held under its ``knowledge`` entry.
    Non-dict entities and non-dict agent entries are skipped.

    Args:
        entity: Serialized entity payload (may be any type).
    """
    if not isinstance(entity, dict):
        return
    _backfill_memory_kind(entity.get("memory"))
    _backfill_sources_on(entity)
    _backfill_sources_on(entity.get("knowledge"))
    for agent in entity.get("agents") or []:
        if not isinstance(agent, dict):
            continue
        _backfill_memory_kind(agent.get("memory"))
        _backfill_sources_on(agent)
        _backfill_sources_on(agent.get("knowledge"))
|
||||
|
||||
|
||||
class RuntimeState(RootModel): # type: ignore[type-arg]
|
||||
root: list[Entity]
|
||||
_provider: BaseProvider = PrivateAttr(default_factory=JsonProvider)
|
||||
|
||||
@@ -19,6 +19,15 @@ from pydantic import BeforeValidator, WithJsonSchema
|
||||
from pydantic.functional_serializers import PlainSerializer
|
||||
|
||||
|
||||
# Normalized (stripped, lower-cased) values of CREWAI_DESERIALIZE_CALLBACKS
# that count as an explicit opt-in.
_TRUSTED_DESERIALIZE_VALUES = frozenset({"1", "true", "yes"})


def _trusted_deserialize() -> bool:
    """Return True only if ``CREWAI_DESERIALIZE_CALLBACKS`` is an explicit yes.

    The variable is read on every call, stripped of surrounding whitespace
    and compared case-insensitively against the allowed values. An unset or
    empty variable, or any other value (for example ``"0"``), yields False.
    """
    raw = os.environ.get("CREWAI_DESERIALIZE_CALLBACKS", "")
    return raw.strip().lower() in _TRUSTED_DESERIALIZE_VALUES
|
||||
|
||||
|
||||
def _is_non_roundtrippable(fn: object) -> bool:
|
||||
"""Return ``True`` if *fn* cannot survive a serialize/deserialize round-trip.
|
||||
|
||||
@@ -76,7 +85,7 @@ def string_to_callable(value: Any) -> Callable[..., Any]:
|
||||
raise ValueError(
|
||||
f"Invalid callback path {value!r}: expected 'module.name' format"
|
||||
)
|
||||
if not os.environ.get("CREWAI_DESERIALIZE_CALLBACKS"):
|
||||
if not _trusted_deserialize():
|
||||
raise ValueError(
|
||||
f"Refusing to resolve callback path {value!r}: "
|
||||
"set CREWAI_DESERIALIZE_CALLBACKS=1 to allow. "
|
||||
@@ -150,3 +159,78 @@ SerializableCallable = Annotated[
|
||||
PlainSerializer(callable_to_string, return_type=str, when_used="json"),
|
||||
WithJsonSchema({"type": "string"}),
|
||||
]
|
||||
|
||||
|
||||
def _instance_to_dotted_path(value: Any) -> str:
    """Serialize an instance to a dotted path naming its class.

    Args:
        value: The instance whose class should be named.

    Returns:
        ``"<module>.<qualname>"`` of ``type(value)``.

    Raises:
        ValueError: If *value* is itself a class, is an instance of a
            builtin type, its class lacks ``__module__``/``__qualname__``,
            or its class is defined inside a function (``<locals>``).
    """
    if inspect.isclass(value):
        module = getattr(value, "__module__", "<unknown>")
        qualname = getattr(
            value, "__qualname__", getattr(value, "__name__", str(type(value)))
        )
        raise ValueError(f"Expected an instance, got class {module}.{qualname}.")
    cls = type(value)
    if cls.__module__ == "builtins":
        raise ValueError(
            f"Cannot serialize {value!r}: builtin values are not "
            "checkpointable instances."
        )
    module = getattr(cls, "__module__", None)
    qualname = getattr(cls, "__qualname__", None)
    if module is None or qualname is None:
        raise ValueError(
            f"Cannot serialize {value!r}: class missing __module__ or __qualname__. "
            "Use a module-level class for checkpointable instances."
        )
    # A qualname containing <locals> is rejected: the class lives inside a
    # function body rather than at module level.
    if qualname.endswith("<lambda>") or "<locals>" in qualname:
        raise ValueError(
            f"Cannot serialize {value!r}: class defined in <locals>. "
            "Use a module-level class for checkpointable instances."
        )
    return f"{module}.{qualname}"
|
||||
|
||||
|
||||
def _dotted_path_to_instance(value: Any) -> Any:
    """Resolve a dotted path to a class and instantiate it with no args.

    ``None`` is returned unchanged. A non-string *value* is returned as-is
    only when it is an instance of a non-builtin class; classes and builtin
    values are rejected. A string is resolved only when
    ``_trusted_deserialize()`` allows it.

    Args:
        value: ``None``, an existing instance, or a ``'module.name'`` path.

    Returns:
        ``None``, the instance that was passed in, or a fresh no-arg instance
        of the class named by the path.

    Raises:
        ValueError: If *value* is a class or a builtin value, the path has
            no dot, resolution is not trusted, the path does not name a
            class, or the class cannot be constructed without arguments.
    """
    if value is None:
        return value
    if not isinstance(value, str):
        if inspect.isclass(value):
            raise ValueError(
                f"Expected an instance or dotted path string, got class "
                f"{getattr(value, '__module__', '<unknown>')}."
                f"{getattr(value, '__qualname__', getattr(value, '__name__', ''))}."
            )
        if type(value).__module__ == "builtins":
            raise ValueError(
                f"Expected an instance of a user-defined class or dotted "
                f"path string, got builtin value {value!r}."
            )
        return value
    if "." not in value:
        raise ValueError(
            f"Invalid provider path {value!r}: expected 'module.name' format"
        )
    # Resolving a path imports and instantiates a named class, so it is
    # gated behind the explicit environment opt-in.
    if not _trusted_deserialize():
        raise ValueError(
            f"Refusing to resolve provider path {value!r}: "
            "set CREWAI_DESERIALIZE_CALLBACKS=1 to allow. "
            "Only enable this for trusted checkpoint data."
        )
    cls = _resolve_dotted_path(value)
    if not inspect.isclass(cls):
        raise ValueError(
            f"Invalid provider path {value!r}: expected a class, got "
            f"{type(cls).__name__}"
        )
    try:
        return cls()
    except TypeError as exc:
        raise ValueError(
            f"Cannot reinstantiate {value!r} with no arguments: {exc}. "
            "Only no-arg constructors are checkpointable; rebuild the "
            "instance manually and assign it after restore."
        ) from exc
|
||||
|
||||
@@ -25,10 +25,16 @@ def _reset_flow_memory(flow: Flow[Any]) -> None:
|
||||
try:
|
||||
if hasattr(mem, "reset"):
|
||||
mem.reset()
|
||||
elif hasattr(mem, "_memory") and hasattr(mem._memory, "reset"):
|
||||
elif hasattr(mem, "_memory") and mem._memory is not None:
|
||||
mem._memory.reset()
|
||||
except (FileNotFoundError, OSError):
|
||||
except FileNotFoundError:
|
||||
# Storage directory was never created — nothing to reset.
|
||||
pass
|
||||
except OSError as exc:
|
||||
click.echo(f"Memory reset skipped: storage I/O error ({exc}).", err=True)
|
||||
except RuntimeError as exc:
|
||||
# Restored MemoryScope/MemorySlice without a rebound Memory.
|
||||
click.echo(f"Memory reset skipped: {exc}", err=True)
|
||||
|
||||
|
||||
def reset_memories_command(
|
||||
|
||||
@@ -7,20 +7,87 @@ durability, input history tracking, and integration with flow machinery.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai.flow import Flow, flow_config, listen, start
|
||||
from crewai.flow.async_feedback.providers import ConsoleProvider
|
||||
from crewai.flow.flow import FlowState
|
||||
from crewai.flow.input_provider import InputProvider, InputResponse
|
||||
from crewai.flow.persistence.base import FlowPersistence
|
||||
|
||||
|
||||
# ── Test helpers ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
class _SaveCall:
    """Lightweight stand-in for ``MagicMock.call_args`` entries.

    Exposes only the ``args`` and ``kwargs`` attributes of a recorded call.
    """

    __slots__ = ("args", "kwargs")

    def __init__(self, args: tuple[Any, ...], kwargs: dict[str, Any]) -> None:
        """Store the positional and keyword arguments of one call."""
        self.args = args
        self.kwargs = kwargs
|
||||
|
||||
|
||||
class _SaveStateRecorder:
    """Callable that records each ``save_state`` invocation.

    Every call is appended to ``call_args_list`` and the latest snapshot per
    flow is written into the owner's ``_states`` mapping.
    """

    def __init__(self, owner: RecordingPersistence) -> None:
        """Bind the recorder to the persistence whose ``_states`` it updates."""
        self._owner = owner
        self.call_args_list: list[_SaveCall] = []

    def __call__(
        self,
        flow_uuid: str,
        method_name: str,
        state_data: dict[str, Any] | BaseModel,
    ) -> None:
        """Record one save, storing an independent deep copy of the state.

        Copying means later mutation of the caller's state object cannot
        change what was recorded for this call.
        """
        snapshot: dict[str, Any] | BaseModel
        if isinstance(state_data, BaseModel):
            snapshot = state_data.model_copy(deep=True)
        else:
            snapshot = copy.deepcopy(state_data)
        self.call_args_list.append(
            _SaveCall((flow_uuid, method_name, snapshot), {})
        )
        self._owner._states[flow_uuid] = snapshot
|
||||
|
||||
|
||||
class RecordingPersistence(FlowPersistence):
    """In-memory FlowPersistence that records ``save_state`` invocations."""

    persistence_type: str = "RecordingPersistence"

    def model_post_init(self, _: Any) -> None:
        """Attach the per-instance state store and the recording callable.

        ``object.__setattr__`` is used for both assignments; presumably this
        bypasses the model's own attribute handling (confirm against the
        base class). The instance-level ``save_state`` shadows the method
        defined on the class below.
        """
        object.__setattr__(self, "_states", {})
        object.__setattr__(self, "save_state", _SaveStateRecorder(self))

    def init_db(self) -> None:
        """No-op: nothing to initialize for the in-memory store."""
        return None

    def save_state(  # type: ignore[no-redef]
        self,
        flow_uuid: str,
        method_name: str,
        state_data: dict[str, Any] | BaseModel,
    ) -> None:
        """Class-level placeholder; instances replace it in ``model_post_init``."""
        return None

    def load_state(self, flow_uuid: str) -> dict[str, Any] | None:
        """Return a copy of the last state saved for *flow_uuid*, or ``None``.

        Model snapshots are returned as a dumped dict; dict snapshots are
        deep-copied so callers cannot mutate the stored state.
        """
        snapshot = self._states.get(flow_uuid)
        if snapshot is None:
            return None
        if isinstance(snapshot, BaseModel):
            return snapshot.model_copy(deep=True).model_dump()
        return copy.deepcopy(snapshot)
|
||||
|
||||
|
||||
class MockInputProvider:
|
||||
"""Mock input provider that returns pre-configured responses."""
|
||||
|
||||
@@ -436,8 +503,7 @@ class TestAskCheckpoint:
|
||||
|
||||
def test_ask_checkpoints_state_before_waiting(self) -> None:
|
||||
"""State is saved to persistence before waiting for input."""
|
||||
mock_persistence = MagicMock()
|
||||
mock_persistence.load_state.return_value = None
|
||||
mock_persistence = RecordingPersistence()
|
||||
|
||||
class TestFlow(Flow):
|
||||
input_provider = MockInputProvider(["answer"])
|
||||
@@ -480,8 +546,7 @@ class TestAskCheckpoint:
|
||||
server crashes while waiting for input, previously gathered data
|
||||
is safe.
|
||||
"""
|
||||
mock_persistence = MagicMock()
|
||||
mock_persistence.load_state.return_value = None
|
||||
mock_persistence = RecordingPersistence()
|
||||
|
||||
class GatherFlow(Flow):
|
||||
input_provider = MockInputProvider(["AI", "detailed"])
|
||||
@@ -678,8 +743,7 @@ class TestAskIntegration:
|
||||
|
||||
def test_ask_with_state_persistence_recovery(self) -> None:
|
||||
"""Ask checkpoints state so previously gathered values survive."""
|
||||
mock_persistence = MagicMock()
|
||||
mock_persistence.load_state.return_value = None
|
||||
mock_persistence = RecordingPersistence()
|
||||
|
||||
class RecoverableFlow(Flow):
|
||||
input_provider = MockInputProvider(["AI", "detailed"])
|
||||
|
||||
Reference in New Issue
Block a user