feat(storage): pluggable default backends for memory, knowledge, rag, flow (#6079)
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Vulnerability Scan / pip-audit (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled

Add opt-in extension seams so an application can route memory, knowledge,
RAG, and flow persistence through a custom backend without subclassing or
threading an explicit instance through every construction site -- mirroring
the existing crewai_core.lock_store.set_lock_backend seam.

- memory:    crewai.memory.storage.factory.set_memory_storage_factory
- knowledge: crewai.knowledge.storage.factory.set_knowledge_storage_factory
- rag:       crewai.rag.factory.register_rag_client_factory (provider registry)
- flow:      crewai.flow.persistence.factory.set_flow_persistence_factory

Each construction site consults the registered factory and falls back to the
built-in default when none is set; an explicit instance always wins. Widen
Knowledge.storage and the knowledge source base classes to BaseKnowledgeStorage
(consistent with BaseAgent.knowledge_storage) so any base-interface backend
plugs in. Runtime-free tests cover each seam.
This commit is contained in:
Matt Aitchison
2026-06-08 21:14:13 -05:00
committed by GitHub
parent 988927006c
commit 8919026326
14 changed files with 596 additions and 33 deletions

View File

@@ -35,7 +35,7 @@ from crewai_core.printer import PRINTER
from pydantic import BaseModel
from crewai.flow.persistence.base import FlowPersistence
from crewai.flow.persistence.sqlite import SQLiteFlowPersistence
from crewai.flow.persistence.factory import default_flow_persistence
if TYPE_CHECKING:
@@ -171,7 +171,9 @@ def persist(
Args:
persistence: Optional FlowPersistence implementation to use.
If not provided, uses SQLiteFlowPersistence.
If not provided, uses ``default_flow_persistence()`` (the
registered factory when present, else the built-in SQLite
fallback).
verbose: Whether to log persistence operations. Defaults to False.
Returns:
@@ -190,7 +192,9 @@ def persist(
"""
def decorator(target: type | Callable[..., T]) -> type | Callable[..., T]:
actual_persistence = persistence or SQLiteFlowPersistence()
actual_persistence = (
persistence if persistence is not None else default_flow_persistence()
)
if isinstance(target, type):
_stamp_persistence_metadata(target, actual_persistence, verbose)

View File

@@ -0,0 +1,60 @@
"""Pluggable default persistence backend for flows.
By default, ``@persist`` and the flow runtime persist state with
:class:`~crewai.flow.persistence.sqlite.SQLiteFlowPersistence` when no explicit
``persistence=`` is given. Registering a factory via
:func:`set_flow_persistence_factory` lets an application back flow state with a
custom :class:`~crewai.flow.persistence.base.FlowPersistence` -- a database, a
remote service, an in-memory fake for tests -- without passing a
``persistence=`` instance at every ``@persist`` / kickoff site.
This mirrors :func:`crewai_core.lock_store.set_lock_backend`: a one-time,
process-wide setter intended for application startup. Pass ``None`` to restore
the built-in SQLite default. Call :func:`default_flow_persistence` to build the
default backend (the registered factory if any, else SQLite).
"""
from __future__ import annotations
from collections.abc import Callable
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from crewai.flow.persistence.base import FlowPersistence
# Zero-argument callable that builds the default ``FlowPersistence`` backend.
FlowPersistenceFactory = Callable[[], "FlowPersistence"]
# Process-wide registered factory; ``None`` makes ``default_flow_persistence``
# fall back to the built-in ``SQLiteFlowPersistence``.
_factory: FlowPersistenceFactory | None = None
def set_flow_persistence_factory(factory: FlowPersistenceFactory | None) -> None:
    """Replace the process-wide default flow persistence factory.

    Intended for one-time setup at startup. Only affects flows that fall back
    to the default; an explicit ``persistence=`` instance always wins.

    The default is resolved at each fall-back site (``@persist`` and the
    runtime's pause/resume paths), so the factory may be called more than once
    for a single flow. Return instances backed by shared durable state (or a
    singleton) so state saved on one call is visible to the next -- the
    built-in SQLite default satisfies this by sharing one on-disk file.

    Args:
        factory: Zero-argument callable returning a ``FlowPersistence``, or
            ``None`` to restore the built-in ``SQLiteFlowPersistence``.

    Raises:
        TypeError: If ``factory`` is neither ``None`` nor callable.
    """
    global _factory
    # Fail at registration time: passing a persistence *instance* instead of a
    # factory would otherwise only surface later, at an unrelated fall-back site.
    if factory is not None and not callable(factory):
        raise TypeError(
            "factory must be a zero-argument callable returning a "
            f"FlowPersistence, or None; got {type(factory).__name__}"
        )
    _factory = factory
def default_flow_persistence() -> FlowPersistence:
    """Build the backend used when no explicit ``persistence=`` is supplied.

    Returns:
        Whatever the registered factory produces when one is set; otherwise a
        new built-in
        :class:`~crewai.flow.persistence.sqlite.SQLiteFlowPersistence`.
    """
    registered = _factory
    if registered is None:
        # Imported lazily, only when the built-in fallback is actually needed.
        from crewai.flow.persistence.sqlite import SQLiteFlowPersistence

        return SQLiteFlowPersistence()
    return registered()

View File

@@ -1252,7 +1252,9 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta):
Args:
flow_id: The unique identifier of the paused flow (from state.id)
persistence: The persistence backend where the state was saved.
If not provided, defaults to SQLiteFlowPersistence().
If not provided, uses ``default_flow_persistence()`` (the
registered factory when present, else the built-in SQLite
fallback).
**kwargs: Additional keyword arguments passed to the Flow constructor
Returns:
@@ -1274,9 +1276,9 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta):
```
"""
if persistence is None:
from crewai.flow.persistence import SQLiteFlowPersistence
from crewai.flow.persistence.factory import default_flow_persistence
persistence = SQLiteFlowPersistence()
persistence = default_flow_persistence()
loaded = persistence.load_pending_feedback(flow_id)
if loaded is None:
@@ -1463,7 +1465,7 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta):
self._pending_feedback_context = None
if self.persistence:
if self.persistence is not None:
self.persistence.clear_pending_feedback(context.flow_id)
crewai_event_bus.emit(
@@ -1505,9 +1507,9 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta):
self._pending_feedback_context = e.context
if self.persistence is None:
from crewai.flow.persistence import SQLiteFlowPersistence
from crewai.flow.persistence.factory import default_flow_persistence
self.persistence = SQLiteFlowPersistence()
self.persistence = default_flow_persistence()
state_data = (
self._state
@@ -2244,9 +2246,11 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta):
if isinstance(e, HumanFeedbackPending):
# Auto-save pending feedback (create default persistence if needed)
if self.persistence is None:
from crewai.flow.persistence import SQLiteFlowPersistence
from crewai.flow.persistence.factory import (
default_flow_persistence,
)
self.persistence = SQLiteFlowPersistence()
self.persistence = default_flow_persistence()
state_data = (
self._state
@@ -2597,9 +2601,9 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta):
e.context.method_name = method_name
if self.persistence is None:
from crewai.flow.persistence import SQLiteFlowPersistence
from crewai.flow.persistence.factory import default_flow_persistence
self.persistence = SQLiteFlowPersistence()
self.persistence = default_flow_persistence()
# Emit paused event (not failed)
if not self.suppress_flow_events:

View File

@@ -13,6 +13,7 @@ from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSourc
from crewai.knowledge.source.text_file_knowledge_source import (
TextFileKnowledgeSource,
)
from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
from crewai.rag.embeddings.types import EmbedderConfig
@@ -89,7 +90,7 @@ class Knowledge(BaseModel):
Knowledge is a collection of sources and setup for the vector store to save and query relevant context.
Args:
sources: list[BaseKnowledgeSource] = Field(default_factory=list)
storage: KnowledgeStorage | None = Field(default=None)
storage: BaseKnowledgeStorage | None = Field(default=None)
embedder: EmbedderConfig | None = None
"""
@@ -98,7 +99,7 @@ class Knowledge(BaseModel):
BeforeValidator(_resolve_knowledge_sources),
] = Field(default_factory=list)
model_config = ConfigDict(arbitrary_types_allowed=True)
storage: KnowledgeStorage | None = Field(default=None)
storage: BaseKnowledgeStorage | None = Field(default=None)
embedder: Annotated[
EmbedderConfig | None,
PlainSerializer(
@@ -112,15 +113,22 @@ class Knowledge(BaseModel):
collection_name: str,
sources: list[BaseKnowledgeSource],
embedder: EmbedderConfig | None = None,
storage: KnowledgeStorage | None = None,
storage: BaseKnowledgeStorage | None = None,
**data: object,
) -> None:
super().__init__(**data)
if storage:
if storage is not None:
self.storage = storage
else:
self.storage = KnowledgeStorage(
embedder=embedder, collection_name=collection_name
from crewai.knowledge.storage.factory import resolve_knowledge_storage
custom = resolve_knowledge_storage(embedder, collection_name)
self.storage = (
custom
if custom is not None
else KnowledgeStorage(
embedder=embedder, collection_name=collection_name
)
)
self.sources = sources
@@ -152,10 +160,9 @@ class Knowledge(BaseModel):
raise e
def reset(self) -> None:
if self.storage:
self.storage.reset()
else:
if self.storage is None:
raise ValueError("Storage is not initialized.")
self.storage.reset()
async def aquery(
self, query: list[str], results_limit: int = 5, score_threshold: float = 0.6
@@ -193,7 +200,6 @@ class Knowledge(BaseModel):
async def areset(self) -> None:
"""Reset the knowledge base asynchronously."""
if self.storage:
await self.storage.areset()
else:
if self.storage is None:
raise ValueError("Storage is not initialized.")
await self.storage.areset()

View File

@@ -5,7 +5,7 @@ from typing import Any
from pydantic import Field, field_validator
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage
from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
from crewai.utilities.constants import KNOWLEDGE_DIRECTORY
from crewai.utilities.logger import Logger
@@ -22,7 +22,7 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
default_factory=list, description="The path to the file"
)
content: dict[Path, str] = Field(init=False, default_factory=dict)
storage: KnowledgeStorage | None = Field(default=None)
storage: BaseKnowledgeStorage | None = Field(default=None)
safe_file_paths: list[Path] = Field(default_factory=list)
@field_validator("file_path", "file_paths", mode="before")
@@ -70,14 +70,14 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
def _save_documents(self) -> None:
"""Save the documents to the storage."""
if self.storage:
if self.storage is not None:
self.storage.save(self.chunks)
else:
raise ValueError("No storage found to save documents.")
async def _asave_documents(self) -> None:
"""Save the documents to the storage asynchronously."""
if self.storage:
if self.storage is not None:
await self.storage.asave(self.chunks)
else:
raise ValueError("No storage found to save documents.")

View File

@@ -4,9 +4,15 @@ from typing import Any
import numpy as np
from pydantic import BaseModel, ConfigDict, Field
from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage
# ``KnowledgeStorage`` is re-exported for backwards compatibility; the ``storage``
# field below is typed to the base interface so any backend plugs in.
__all__ = ["BaseKnowledgeSource", "KnowledgeStorage"]
class BaseKnowledgeSource(BaseModel, ABC):
"""Abstract base class for knowledge sources."""
@@ -18,7 +24,7 @@ class BaseKnowledgeSource(BaseModel, ABC):
)
model_config = ConfigDict(arbitrary_types_allowed=True)
storage: KnowledgeStorage | None = Field(default=None)
storage: BaseKnowledgeStorage | None = Field(default=None)
metadata: dict[str, Any] = Field(default_factory=dict) # Currently unused
collection_name: str | None = Field(default=None)
@@ -49,7 +55,7 @@ class BaseKnowledgeSource(BaseModel, ABC):
Raises:
ValueError: If no storage is configured.
"""
if self.storage:
if self.storage is not None:
self.storage.save(self.chunks)
else:
raise ValueError("No storage found to save documents.")
@@ -66,7 +72,7 @@ class BaseKnowledgeSource(BaseModel, ABC):
Raises:
ValueError: If no storage is configured.
"""
if self.storage:
if self.storage is not None:
await self.storage.asave(self.chunks)
else:
raise ValueError("No storage found to save documents.")

View File

@@ -0,0 +1,56 @@
"""Pluggable default storage backend for knowledge collections.
By default, :class:`~crewai.knowledge.knowledge.Knowledge` builds a
:class:`~crewai.knowledge.storage.knowledge_storage.KnowledgeStorage` when no
explicit ``storage=`` is given. Registering a factory via
:func:`set_knowledge_storage_factory` lets an application back knowledge with a
custom :class:`~crewai.knowledge.storage.base_knowledge_storage.BaseKnowledgeStorage`
without subclassing ``Knowledge`` or passing a ``storage=`` instance at every
call site.
This mirrors :func:`crewai_core.lock_store.set_lock_backend`: a one-time,
process-wide setter intended for application startup. Pass ``None`` to restore
the built-in default.
"""
from __future__ import annotations
from collections.abc import Callable
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
from crewai.rag.embeddings.types import EmbedderConfig
# Receives the same inputs as the built-in default -- the embedder config and
# collection name -- and returns a storage backend, or ``None`` to defer to the
# built-in ``KnowledgeStorage``.
KnowledgeStorageFactory = Callable[
    ["EmbedderConfig | None", "str | None"], "BaseKnowledgeStorage | None"
]
# Process-wide registered factory; while ``None``, ``resolve_knowledge_storage``
# returns ``None`` and the caller uses the built-in ``KnowledgeStorage``.
_factory: KnowledgeStorageFactory | None = None
def set_knowledge_storage_factory(factory: KnowledgeStorageFactory | None) -> None:
    """Replace the process-wide default knowledge storage factory.

    Intended for one-time setup at startup. Only affects ``Knowledge``
    instances constructed afterwards; an explicit ``storage=`` instance always
    wins.

    Args:
        factory: Callable taking ``(embedder, collection_name)`` and returning
            a storage backend (or ``None`` to defer to the built-in), or
            ``None`` to restore the built-in ``KnowledgeStorage``.

    Raises:
        TypeError: If ``factory`` is neither ``None`` nor callable.
    """
    global _factory
    # Reject obvious misuse (e.g. a storage instance instead of a factory) here
    # rather than when the next ``Knowledge`` is constructed.
    if factory is not None and not callable(factory):
        raise TypeError(
            "factory must be a callable taking (embedder, collection_name), "
            f"or None; got {type(factory).__name__}"
        )
    _factory = factory
def resolve_knowledge_storage(
    embedder: EmbedderConfig | None, collection_name: str | None
) -> BaseKnowledgeStorage | None:
    """Ask the registered factory, if any, for a knowledge storage backend.

    Args:
        embedder: Embedder configuration forwarded unchanged to the factory.
        collection_name: Collection name forwarded unchanged to the factory.

    Returns:
        The factory's result, or ``None`` when no factory is registered. A
        ``None`` result (unregistered or declined) tells the caller to fall
        back to the built-in ``KnowledgeStorage``.
    """
    registered = _factory
    if registered is None:
        return None
    return registered(embedder, collection_name)

View File

@@ -0,0 +1,55 @@
"""Pluggable default storage backend for the unified memory system.
By default, :class:`~crewai.memory.unified_memory.Memory` builds a built-in
vector store from its ``storage`` spec string (LanceDB, or Qdrant for the
``"qdrant-edge"`` spec). Registering a factory via
:func:`set_memory_storage_factory` lets an application route memory through a
custom :class:`~crewai.memory.storage.backend.StorageBackend` -- a different
vector store, a remote service, an in-memory fake for tests -- without
subclassing ``Memory`` or threading an explicit ``storage=`` instance through
every construction site.
This mirrors :func:`crewai_core.lock_store.set_lock_backend`: a one-time,
process-wide setter intended for application startup. Pass ``None`` to restore
the built-in default.
"""
from __future__ import annotations
from collections.abc import Callable
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from crewai.memory.storage.backend import StorageBackend
# Receives the raw ``storage`` spec string and returns a backend to use, or
# ``None`` to defer to the built-in selection for that spec.
MemoryStorageFactory = Callable[[str], "StorageBackend | None"]
# Process-wide registered factory; while ``None``, ``resolve_memory_storage``
# returns ``None`` and the caller uses the built-in selection.
_factory: MemoryStorageFactory | None = None
def set_memory_storage_factory(factory: MemoryStorageFactory | None) -> None:
    """Replace the process-wide default memory storage factory.

    Intended for one-time setup at startup. Only affects ``Memory`` instances
    constructed afterwards; an explicit ``storage=`` instance always wins.

    The factory is consulted for every string ``storage`` spec, so it must
    return ``None`` for specs it does not handle to let the built-in
    LanceDB/Qdrant/path selection take over.

    Args:
        factory: Callable taking the raw ``storage`` spec string and returning
            a backend (or ``None`` to defer), or ``None`` to restore the
            built-in LanceDB/Qdrant selection.

    Raises:
        TypeError: If ``factory`` is neither ``None`` nor callable.
    """
    global _factory
    # Reject obvious misuse (e.g. a backend instance instead of a factory) here
    # rather than when the next ``Memory`` is constructed.
    if factory is not None and not callable(factory):
        raise TypeError(
            "factory must be a callable taking the storage spec string, "
            f"or None; got {type(factory).__name__}"
        )
    _factory = factory
def resolve_memory_storage(spec: str) -> StorageBackend | None:
    """Ask the registered factory, if any, for a backend matching ``spec``.

    Args:
        spec: The raw ``storage`` spec string, forwarded unchanged.

    Returns:
        The factory's result, or ``None`` when no factory is registered. A
        ``None`` result (unregistered or declined) tells the caller to fall
        back to the built-in selection.
    """
    registered = _factory
    if registered is None:
        return None
    return registered(spec)

View File

@@ -204,7 +204,12 @@ class Memory(BaseModel):
)
if isinstance(self.storage, str):
if self.storage == "qdrant-edge":
from crewai.memory.storage.factory import resolve_memory_storage
custom = resolve_memory_storage(self.storage)
if custom is not None:
self._storage = custom
elif self.storage == "qdrant-edge":
from crewai.memory.storage.qdrant_edge_storage import QdrantEdgeStorage
self._storage = QdrantEdgeStorage()

View File

@@ -1,5 +1,6 @@
"""Factory functions for creating RAG clients from configuration."""
from collections.abc import Callable
from typing import cast
from crewai.rag.config.optional_imports.protocols import (
@@ -11,6 +12,32 @@ from crewai.rag.core.base_client import BaseClient
from crewai.utilities.import_utils import require
# RAG uses a provider-keyed registry (rather than the single-default setter
# used by the memory/knowledge/flow seams) because ``create_client`` already
# dispatches on ``config.provider`` -- the natural seam here is per-provider.
# A factory receives the RAG config and returns a client; one registered for a
# built-in provider name overrides the built-in for that provider.
RagClientFactory = Callable[[RagConfigType], BaseClient]
# Provider name -> registered factory; ``create_client`` checks this mapping
# before dispatching to the built-in providers.
_factories: dict[str, RagClientFactory] = {}
def register_rag_client_factory(provider: str, factory: RagClientFactory) -> None:
    """Register a client factory for a RAG ``provider`` name.

    Lets an application plug in a client for a new provider, or override a
    built-in provider (``"chromadb"`` / ``"qdrant"``), without modifying
    :func:`create_client`. Registered factories take precedence over the
    built-ins. Intended for one-time setup at startup. Registering the same
    provider again replaces the earlier factory.

    Args:
        provider: Provider name matched against ``config.provider``.
        factory: Callable that receives the RAG config and returns a client.

    Raises:
        TypeError: If ``factory`` is not callable.
    """
    # Fail at registration time; otherwise a bad value only blows up later,
    # inside ``create_client``, far from the offending call.
    if not callable(factory):
        raise TypeError(
            "factory must be a callable taking the RAG config; "
            f"got {type(factory).__name__}"
        )
    _factories[provider] = factory
def unregister_rag_client_factory(provider: str) -> None:
    """Drop the factory registered for ``provider``, doing nothing if absent."""
    try:
        del _factories[provider]
    except KeyError:
        # Nothing was registered under this name; removal is a no-op.
        pass
def create_client(config: RagConfigType) -> BaseClient:
"""Create a client from configuration using the appropriate factory.
@@ -24,6 +51,10 @@ def create_client(config: RagConfigType) -> BaseClient:
ValueError: If the configuration provider is not supported.
"""
factory = _factories.get(config.provider)
if factory is not None:
return factory(config)
if config.provider == "chromadb":
chromadb_mod = cast(
ChromaFactoryModule,

View File

@@ -0,0 +1,130 @@
"""Tests for the pluggable knowledge storage factory seam.
We verify our own logic: the set/resolve round-trip, that a registered factory
is consulted when no explicit ``storage=`` is given (and receives the embedder
and collection name), that an explicit ``storage=`` instance bypasses it, and
that a falsy explicit backend is still honored.
"""
from __future__ import annotations
from typing import Any
import pytest
import crewai.knowledge.storage.factory as factory
from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
from crewai.rag.types import SearchResult
class _FakeKnowledgeStorage(BaseKnowledgeStorage):
    """No-op backend that fills in the abstract storage interface for tests.

    Searches return no results and every write/reset does nothing, so tests
    can focus on which backend instance was selected.
    """

    def search(
        self,
        query: list[str],
        limit: int = 5,
        metadata_filter: dict[str, Any] | None = None,
        score_threshold: float = 0.6,
    ) -> list[SearchResult]:
        """Return an empty result list regardless of the query."""
        no_hits: list[SearchResult] = []
        return no_hits

    async def asearch(
        self,
        query: list[str],
        limit: int = 5,
        metadata_filter: dict[str, Any] | None = None,
        score_threshold: float = 0.6,
    ) -> list[SearchResult]:
        """Async counterpart of :meth:`search`; also returns no results."""
        no_hits: list[SearchResult] = []
        return no_hits

    def save(self, documents: list[str]) -> None:
        """Discard ``documents``."""

    async def asave(self, documents: list[str]) -> None:
        """Discard ``documents`` (async)."""

    def reset(self) -> None:
        """Do nothing; there is no state to clear."""

    async def areset(self) -> None:
        """Do nothing (async); there is no state to clear."""
@pytest.fixture(autouse=True)
def reset_factory():
    """Run each test with no knowledge factory set, then restore the prior one."""
    previous = factory._factory
    factory.set_knowledge_storage_factory(None)

    yield

    # Put back whatever was registered before the test started.
    factory.set_knowledge_storage_factory(previous)
def test_resolve_reflects_registered_factory():
    """Resolution yields ``None`` until a factory is set, then its backend."""
    backend = _FakeKnowledgeStorage()

    before = factory.resolve_knowledge_storage(None, "docs")
    assert before is None

    factory.set_knowledge_storage_factory(lambda embedder, name: backend)
    after = factory.resolve_knowledge_storage(None, "docs")
    assert after is backend
def test_factory_used_when_no_explicit_storage():
    """``Knowledge`` built without ``storage=`` adopts the factory's backend."""
    backend = _FakeKnowledgeStorage()

    def provide(embedder, name):
        return backend

    factory.set_knowledge_storage_factory(provide)

    built = Knowledge(collection_name="docs", sources=[])
    assert built.storage is backend
def test_factory_receives_embedder_and_collection_name():
    """The factory is called once with the embedder and collection name."""
    received: list[tuple[object, object]] = []

    def record_and_build(embedder, collection_name):
        received.append((embedder, collection_name))
        return _FakeKnowledgeStorage()

    factory.set_knowledge_storage_factory(record_and_build)
    Knowledge(collection_name="docs", sources=[])

    expected = [(None, "docs")]
    assert received == expected
def test_explicit_storage_bypasses_factory():
    """An explicit ``storage=`` is used as-is and the factory is never called."""
    invoked = {"factory": False}

    def build(embedder, name):
        invoked["factory"] = True
        return _FakeKnowledgeStorage()

    factory.set_knowledge_storage_factory(build)

    supplied = _FakeKnowledgeStorage()
    built = Knowledge(collection_name="docs", sources=[], storage=supplied)

    assert built.storage is supplied
    assert invoked["factory"] is False
def test_falsy_explicit_storage_is_honored():
    """A backend that evaluates falsy is still stored and operated on."""
    # Guards against truthiness checks (``if storage:``) in __init__, reset, or
    # the source save path treating a falsy custom backend as "not
    # initialized".
    resets_seen: list[bool] = []

    class _FalsyStorage(_FakeKnowledgeStorage):
        def __bool__(self) -> bool:
            return False

        def reset(self) -> None:
            resets_seen.append(True)

    supplied = _FalsyStorage()
    built = Knowledge(collection_name="docs", sources=[], storage=supplied)
    assert built.storage is supplied

    # reset must reach the backend rather than raise
    # "Storage is not initialized."
    built.reset()
    assert resets_seen == [True]

View File

@@ -0,0 +1,72 @@
"""Tests for the pluggable memory storage factory seam.
We verify our own logic: the set/resolve round-trip, that a registered factory
is consulted for string ``storage`` specs (and receives the spec), and that an
explicit ``storage=`` instance bypasses the factory entirely.
"""
from __future__ import annotations
import pytest
import crewai.memory.storage.factory as factory
from crewai.memory.unified_memory import Memory
@pytest.fixture(autouse=True)
def reset_factory():
    """Run each test with no memory factory set, then restore the prior one."""
    previous = factory._factory
    factory.set_memory_storage_factory(None)

    yield

    # Put back whatever was registered before the test started.
    factory.set_memory_storage_factory(previous)
def test_resolve_reflects_registered_factory():
    """Resolution tracks the factory being unset, set, then unset again."""
    marker = object()

    initially = factory.resolve_memory_storage("lancedb")
    assert initially is None

    factory.set_memory_storage_factory(lambda spec: marker)
    while_registered = factory.resolve_memory_storage("lancedb")
    assert while_registered is marker

    factory.set_memory_storage_factory(None)
    after_clearing = factory.resolve_memory_storage("lancedb")
    assert after_clearing is None
def test_factory_backend_used_for_string_spec():
    """A string ``storage`` spec is resolved through the registered factory."""
    marker = object()

    def provide(spec):
        return marker

    factory.set_memory_storage_factory(provide)

    memory = Memory(storage="lancedb")
    assert memory._storage is marker
def test_factory_receives_the_raw_spec():
    """The factory is called once with the spec string exactly as given."""
    received: list[str] = []

    def record_and_build(spec):
        received.append(spec)
        return object()

    factory.set_memory_storage_factory(record_and_build)
    Memory(storage="some/custom/path")

    expected = ["some/custom/path"]
    assert received == expected
def test_explicit_storage_instance_bypasses_factory():
    """A non-string ``storage=`` is used as-is and the factory is never called."""
    invoked = {"factory": False}

    def build(spec):
        invoked["factory"] = True
        return object()

    factory.set_memory_storage_factory(build)

    supplied = object()
    memory = Memory(storage=supplied)  # type: ignore[arg-type]

    assert memory._storage is supplied
    assert invoked["factory"] is False

View File

@@ -0,0 +1,66 @@
"""Tests for the RAG client factory registry seam.
We verify our own logic: a registered factory is used for its provider,
factories override the built-in providers, unregister removes them, and an
unknown provider still raises.
"""
from __future__ import annotations
from types import SimpleNamespace
import pytest
import crewai.rag.factory as factory
@pytest.fixture(autouse=True)
def reset_registry():
    """Run each test against an empty registry, then restore prior entries."""
    snapshot = dict(factory._factories)
    factory._factories.clear()

    yield

    # Drop anything the test registered, then reinstate the saved entries.
    factory._factories.clear()
    factory._factories.update(snapshot)
def test_registered_factory_is_used_for_its_provider():
    """``create_client`` returns what the provider's registered factory builds."""
    client = object()

    def provide(config):
        return client

    factory.register_rag_client_factory("custom", provide)

    config = SimpleNamespace(provider="custom")
    assert factory.create_client(config) is client
def test_factory_receives_the_config():
    """The registered factory is called once with the config object itself."""
    received: list[object] = []
    config = SimpleNamespace(provider="custom")

    def record_and_build(cfg):
        received.append(cfg)
        return object()

    factory.register_rag_client_factory("custom", record_and_build)
    factory.create_client(config)

    assert received == [config]
def test_factory_overrides_builtin_provider():
    """A factory registered under a built-in provider name takes precedence."""
    client = object()

    def provide(config):
        return client

    factory.register_rag_client_factory("chromadb", provide)

    # Resolves via the registry without importing the built-in chromadb factory.
    config = SimpleNamespace(provider="chromadb")
    assert factory.create_client(config) is client
def test_unregister_removes_factory():
    """After unregistering, the provider is rejected as unsupported again."""

    def provide(config):
        return object()

    factory.register_rag_client_factory("custom", provide)
    factory.unregister_rag_client_factory("custom")

    config = SimpleNamespace(provider="custom")
    with pytest.raises(ValueError, match="Unsupported provider: custom"):
        factory.create_client(config)
def test_unregister_unknown_provider_is_noop():
    """Unregistering a provider that was never registered must not raise."""
    unknown_provider = "never-registered"
    factory.unregister_rag_client_factory(unknown_provider)
def test_unknown_provider_still_raises():
    """With nothing registered, an unknown provider is still a ``ValueError``."""
    config = SimpleNamespace(provider="nope")
    with pytest.raises(ValueError, match="Unsupported provider: nope"):
        factory.create_client(config)

View File

@@ -0,0 +1,68 @@
"""Tests for the pluggable flow persistence factory seam.
We verify our own logic: that ``default_flow_persistence`` returns the
registered factory's result, that it falls back to the built-in SQLite
persistence when no factory is registered, and that ``@persist`` keeps an
explicit persistence instance even when that instance is falsy.
"""
from __future__ import annotations
from typing import Any
import pytest
from pydantic import BaseModel
import crewai.flow.persistence.factory as factory
from crewai.flow.persistence.base import FlowPersistence
from crewai.flow.persistence.decorators import persist
from crewai.flow.persistence.sqlite import SQLiteFlowPersistence
@pytest.fixture(autouse=True)
def reset_factory():
    """Run each test with no flow factory set, then restore the prior one."""
    previous = factory._factory
    factory.set_flow_persistence_factory(None)

    yield

    # Put back whatever was registered before the test started.
    factory.set_flow_persistence_factory(previous)
def test_default_uses_registered_factory():
    """The default backend is exactly what the registered factory returns."""
    backend = SQLiteFlowPersistence()

    def provide():
        return backend

    factory.set_flow_persistence_factory(provide)

    resolved = factory.default_flow_persistence()
    assert resolved is backend
def test_default_falls_back_to_sqlite():
    """With no factory registered, the default is the built-in SQLite backend."""
    resolved = factory.default_flow_persistence()
    assert isinstance(resolved, SQLiteFlowPersistence)
def test_persist_decorator_honors_falsy_persistence():
    """``@persist`` keeps an explicit persistence even when it is falsy."""

    # A truthiness check (``persistence or default``) would swap this instance
    # for the default; only an ``is not None`` check keeps it.
    class _FalsyPersistence(FlowPersistence):
        def __bool__(self) -> bool:
            return False

        def init_db(self) -> None:
            pass

        def save_state(
            self,
            flow_uuid: str,
            method_name: str,
            state_data: dict[str, Any] | BaseModel,
        ) -> None:
            pass

        def load_state(self, flow_uuid: str) -> dict[str, Any] | None:
            return None

    supplied = _FalsyPersistence()

    @persist(persistence=supplied)
    class _DummyFlow:
        pass

    stamped = _DummyFlow.__flow_persistence_config__
    assert stamped.persistence is supplied