diff --git a/lib/crewai/src/crewai/flow/persistence/decorators.py b/lib/crewai/src/crewai/flow/persistence/decorators.py index 5776e6867..5b0a594e8 100644 --- a/lib/crewai/src/crewai/flow/persistence/decorators.py +++ b/lib/crewai/src/crewai/flow/persistence/decorators.py @@ -35,7 +35,7 @@ from crewai_core.printer import PRINTER from pydantic import BaseModel from crewai.flow.persistence.base import FlowPersistence -from crewai.flow.persistence.sqlite import SQLiteFlowPersistence +from crewai.flow.persistence.factory import default_flow_persistence if TYPE_CHECKING: @@ -171,7 +171,9 @@ def persist( Args: persistence: Optional FlowPersistence implementation to use. - If not provided, uses SQLiteFlowPersistence. + If not provided, uses ``default_flow_persistence()`` (the + registered factory when present, else the built-in SQLite + fallback). verbose: Whether to log persistence operations. Defaults to False. Returns: @@ -190,7 +192,9 @@ def persist( """ def decorator(target: type | Callable[..., T]) -> type | Callable[..., T]: - actual_persistence = persistence or SQLiteFlowPersistence() + actual_persistence = ( + persistence if persistence is not None else default_flow_persistence() + ) if isinstance(target, type): _stamp_persistence_metadata(target, actual_persistence, verbose) diff --git a/lib/crewai/src/crewai/flow/persistence/factory.py b/lib/crewai/src/crewai/flow/persistence/factory.py new file mode 100644 index 000000000..399e6e9ca --- /dev/null +++ b/lib/crewai/src/crewai/flow/persistence/factory.py @@ -0,0 +1,60 @@ +"""Pluggable default persistence backend for flows. + +By default, ``@persist`` and the flow runtime persist state with +:class:`~crewai.flow.persistence.sqlite.SQLiteFlowPersistence` when no explicit +``persistence=`` is given. Registering a factory via +:func:`set_flow_persistence_factory` lets an application back flow state with a +custom :class:`~crewai.flow.persistence.base.FlowPersistence` -- a database, a +remote service, an in-memory fake for tests -- without passing a +``persistence=`` instance at every ``@persist`` / kickoff site. + +This mirrors :func:`crewai_core.lock_store.set_lock_backend`: a one-time, +process-wide setter intended for application startup. Pass ``None`` to restore +the built-in SQLite default. Call :func:`default_flow_persistence` to build the +default backend (the registered factory if any, else SQLite). +""" + +from __future__ import annotations + +from collections.abc import Callable +from typing import TYPE_CHECKING + + +if TYPE_CHECKING: + from crewai.flow.persistence.base import FlowPersistence + +FlowPersistenceFactory = Callable[[], "FlowPersistence"] + +_factory: FlowPersistenceFactory | None = None + + +def set_flow_persistence_factory(factory: FlowPersistenceFactory | None) -> None: + """Replace the process-wide default flow persistence factory. + + Intended for one-time setup at startup. Pass ``None`` to restore the + built-in ``SQLiteFlowPersistence``. Only affects flows that fall back to + the default; an explicit ``persistence=`` instance always wins. + + The default is resolved at each fall-back site (``@persist`` and the + runtime's pause/resume paths), so the factory may be called more than once + for a single flow. Return instances backed by shared durable state (or a + singleton) so state saved on one call is visible to the next -- the + built-in SQLite default satisfies this by sharing one on-disk file. + """ + global _factory + _factory = factory + + +def default_flow_persistence() -> FlowPersistence: + """Build the default flow persistence backend. + + Returns the result of the registered factory if one is set, otherwise a + built-in :class:`~crewai.flow.persistence.sqlite.SQLiteFlowPersistence`. + """ + factory = _factory + if factory is not None: + return factory() + + from crewai.flow.persistence.sqlite import SQLiteFlowPersistence + + return SQLiteFlowPersistence() diff --git a/lib/crewai/src/crewai/flow/runtime.py b/lib/crewai/src/crewai/flow/runtime.py index 7f71c8930..34b796f07 100644 --- a/lib/crewai/src/crewai/flow/runtime.py +++ b/lib/crewai/src/crewai/flow/runtime.py @@ -1252,7 +1252,9 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta): Args: flow_id: The unique identifier of the paused flow (from state.id) persistence: The persistence backend where the state was saved. - If not provided, defaults to SQLiteFlowPersistence(). + If not provided, uses ``default_flow_persistence()`` (the + registered factory when present, else the built-in SQLite + fallback). **kwargs: Additional keyword arguments passed to the Flow constructor Returns: @@ -1274,9 +1276,9 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta): ``` """ if persistence is None: - from crewai.flow.persistence import SQLiteFlowPersistence + from crewai.flow.persistence.factory import default_flow_persistence - persistence = SQLiteFlowPersistence() + persistence = default_flow_persistence() loaded = persistence.load_pending_feedback(flow_id) if loaded is None: @@ -1463,7 +1465,7 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta): self._pending_feedback_context = None - if self.persistence: + if self.persistence is not None: self.persistence.clear_pending_feedback(context.flow_id) crewai_event_bus.emit( @@ -1505,9 +1507,9 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta): self._pending_feedback_context = e.context if self.persistence is None: - from crewai.flow.persistence import SQLiteFlowPersistence + from crewai.flow.persistence.factory import default_flow_persistence - self.persistence = SQLiteFlowPersistence() + self.persistence = default_flow_persistence() state_data = ( self._state @@ -2244,9 +2246,11 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta): if isinstance(e, HumanFeedbackPending): # Auto-save pending feedback (create default persistence if needed) if self.persistence is None: - from crewai.flow.persistence import SQLiteFlowPersistence + from crewai.flow.persistence.factory import ( + default_flow_persistence, + ) - self.persistence = SQLiteFlowPersistence() + self.persistence = default_flow_persistence() state_data = ( self._state @@ -2597,9 +2601,9 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta): e.context.method_name = method_name if self.persistence is None: - from crewai.flow.persistence import SQLiteFlowPersistence + from crewai.flow.persistence.factory import default_flow_persistence - self.persistence = SQLiteFlowPersistence() + self.persistence = default_flow_persistence() # Emit paused event (not failed) if not self.suppress_flow_events: diff --git a/lib/crewai/src/crewai/knowledge/knowledge.py b/lib/crewai/src/crewai/knowledge/knowledge.py index fd391635e..76198fec9 100644 --- a/lib/crewai/src/crewai/knowledge/knowledge.py +++ b/lib/crewai/src/crewai/knowledge/knowledge.py @@ -13,6 +13,7 @@ from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSourc from crewai.knowledge.source.text_file_knowledge_source import ( TextFileKnowledgeSource, ) +from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider from crewai.rag.embeddings.types import EmbedderConfig @@ -89,7 +90,7 @@ class Knowledge(BaseModel): Knowledge is a collection of sources and setup for the vector store to save and query relevant context. Args: sources: list[BaseKnowledgeSource] = Field(default_factory=list) - storage: KnowledgeStorage | None = Field(default=None) + storage: BaseKnowledgeStorage | None = Field(default=None) embedder: EmbedderConfig | None = None """ @@ -98,7 +99,7 @@ class Knowledge(BaseModel): BeforeValidator(_resolve_knowledge_sources), ] = Field(default_factory=list) model_config = ConfigDict(arbitrary_types_allowed=True) - storage: KnowledgeStorage | None = Field(default=None) + storage: BaseKnowledgeStorage | None = Field(default=None) embedder: Annotated[ EmbedderConfig | None, PlainSerializer( @@ -112,15 +113,22 @@ class Knowledge(BaseModel): collection_name: str, sources: list[BaseKnowledgeSource], embedder: EmbedderConfig | None = None, - storage: KnowledgeStorage | None = None, + storage: BaseKnowledgeStorage | None = None, **data: object, ) -> None: super().__init__(**data) - if storage: + if storage is not None: self.storage = storage else: - self.storage = KnowledgeStorage( - embedder=embedder, collection_name=collection_name + from crewai.knowledge.storage.factory import resolve_knowledge_storage + + custom = resolve_knowledge_storage(embedder, collection_name) + self.storage = ( + custom + if custom is not None + else KnowledgeStorage( + embedder=embedder, collection_name=collection_name + ) ) self.sources = sources @@ -152,10 +160,9 @@ class Knowledge(BaseModel): raise e def reset(self) -> None: - if self.storage: - self.storage.reset() - else: + if self.storage is None: raise ValueError("Storage is not initialized.") + self.storage.reset() async def aquery( self, query: list[str], results_limit: int = 5, score_threshold: float = 0.6 @@ -193,7 +200,6 @@ class Knowledge(BaseModel): async def areset(self) -> None: """Reset the knowledge base asynchronously.""" - if self.storage: - await self.storage.areset() - else: + if self.storage is None: raise ValueError("Storage is not initialized.") + await self.storage.areset() diff --git a/lib/crewai/src/crewai/knowledge/source/base_file_knowledge_source.py b/lib/crewai/src/crewai/knowledge/source/base_file_knowledge_source.py index 1ceeff5b4..1a668c0c6 100644 --- a/lib/crewai/src/crewai/knowledge/source/base_file_knowledge_source.py +++ b/lib/crewai/src/crewai/knowledge/source/base_file_knowledge_source.py @@ -5,7 +5,7 @@ from typing import Any from pydantic import Field, field_validator from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource -from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage +from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage from crewai.utilities.constants import KNOWLEDGE_DIRECTORY from crewai.utilities.logger import Logger @@ -22,7 +22,7 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC): default_factory=list, description="The path to the file" ) content: dict[Path, str] = Field(init=False, default_factory=dict) - storage: KnowledgeStorage | None = Field(default=None) + storage: BaseKnowledgeStorage | None = Field(default=None) safe_file_paths: list[Path] = Field(default_factory=list) @field_validator("file_path", "file_paths", mode="before") @@ -70,14 +70,14 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC): def _save_documents(self) -> None: """Save the documents to the storage.""" - if self.storage: + if self.storage is not None: self.storage.save(self.chunks) else: raise ValueError("No storage found to save documents.") async def _asave_documents(self) -> None: """Save the documents to the storage asynchronously.""" - if self.storage: + if self.storage is not None: await self.storage.asave(self.chunks) else: raise ValueError("No storage found to save documents.") diff --git a/lib/crewai/src/crewai/knowledge/source/base_knowledge_source.py b/lib/crewai/src/crewai/knowledge/source/base_knowledge_source.py index 8c99b47b0..a5c557cb5 100644 --- a/lib/crewai/src/crewai/knowledge/source/base_knowledge_source.py +++ b/lib/crewai/src/crewai/knowledge/source/base_knowledge_source.py @@ -4,9 +4,15 @@ from typing import Any import numpy as np from pydantic import BaseModel, ConfigDict, Field +from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage +# ``KnowledgeStorage`` is re-exported for backwards compatibility; the ``storage`` +# field below is typed to the base interface so any backend plugs in. +__all__ = ["BaseKnowledgeSource", "KnowledgeStorage"] + + class BaseKnowledgeSource(BaseModel, ABC): """Abstract base class for knowledge sources.""" @@ -18,7 +24,7 @@ class BaseKnowledgeSource(BaseModel, ABC): ) model_config = ConfigDict(arbitrary_types_allowed=True) - storage: KnowledgeStorage | None = Field(default=None) + storage: BaseKnowledgeStorage | None = Field(default=None) metadata: dict[str, Any] = Field(default_factory=dict) # Currently unused collection_name: str | None = Field(default=None) @@ -49,7 +55,7 @@ class BaseKnowledgeSource(BaseModel, ABC): Raises: ValueError: If no storage is configured. """ - if self.storage: + if self.storage is not None: self.storage.save(self.chunks) else: raise ValueError("No storage found to save documents.") @@ -66,7 +72,7 @@ class BaseKnowledgeSource(BaseModel, ABC): Raises: ValueError: If no storage is configured. """ - if self.storage: + if self.storage is not None: await self.storage.asave(self.chunks) else: raise ValueError("No storage found to save documents.") diff --git a/lib/crewai/src/crewai/knowledge/storage/factory.py b/lib/crewai/src/crewai/knowledge/storage/factory.py new file mode 100644 index 000000000..4a401b80a --- /dev/null +++ b/lib/crewai/src/crewai/knowledge/storage/factory.py @@ -0,0 +1,56 @@ +"""Pluggable default storage backend for knowledge collections. + +By default, :class:`~crewai.knowledge.knowledge.Knowledge` builds a +:class:`~crewai.knowledge.storage.knowledge_storage.KnowledgeStorage` when no +explicit ``storage=`` is given. Registering a factory via +:func:`set_knowledge_storage_factory` lets an application back knowledge with a +custom :class:`~crewai.knowledge.storage.base_knowledge_storage.BaseKnowledgeStorage` +without subclassing ``Knowledge`` or passing a ``storage=`` instance at every +call site. + +This mirrors :func:`crewai_core.lock_store.set_lock_backend`: a one-time, +process-wide setter intended for application startup. Pass ``None`` to restore +the built-in default. +""" + +from __future__ import annotations + +from collections.abc import Callable +from typing import TYPE_CHECKING + + +if TYPE_CHECKING: + from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage + from crewai.rag.embeddings.types import EmbedderConfig + +# Receives the same inputs as the built-in default -- the embedder config and +# collection name -- and returns a storage backend, or ``None`` to defer to the +# built-in ``KnowledgeStorage``. +KnowledgeStorageFactory = Callable[ + ["EmbedderConfig | None", "str | None"], "BaseKnowledgeStorage | None" +] + +_factory: KnowledgeStorageFactory | None = None + + +def set_knowledge_storage_factory(factory: KnowledgeStorageFactory | None) -> None: + """Replace the process-wide default knowledge storage factory. + + Intended for one-time setup at startup. Pass ``None`` to restore the + built-in ``KnowledgeStorage``. Only affects ``Knowledge`` instances + constructed afterwards; an explicit ``storage=`` instance always wins. + """ + global _factory + _factory = factory + + +def resolve_knowledge_storage( + embedder: EmbedderConfig | None, collection_name: str | None +) -> BaseKnowledgeStorage | None: + """Return the registered factory's backend, or ``None`` for the built-in. + + ``None`` means no factory is registered or it declined; the caller then + falls back to the built-in ``KnowledgeStorage``. + """ + factory = _factory + return factory(embedder, collection_name) if factory is not None else None diff --git a/lib/crewai/src/crewai/memory/storage/factory.py b/lib/crewai/src/crewai/memory/storage/factory.py new file mode 100644 index 000000000..3dac6dcd4 --- /dev/null +++ b/lib/crewai/src/crewai/memory/storage/factory.py @@ -0,0 +1,55 @@ +"""Pluggable default storage backend for the unified memory system. + +By default, :class:`~crewai.memory.unified_memory.Memory` builds a built-in +vector store from its ``storage`` spec string (LanceDB, or Qdrant for the +``"qdrant-edge"`` spec). Registering a factory via +:func:`set_memory_storage_factory` lets an application route memory through a +custom :class:`~crewai.memory.storage.backend.StorageBackend` -- a different +vector store, a remote service, an in-memory fake for tests -- without +subclassing ``Memory`` or threading an explicit ``storage=`` instance through +every construction site. + +This mirrors :func:`crewai_core.lock_store.set_lock_backend`: a one-time, +process-wide setter intended for application startup. Pass ``None`` to restore +the built-in default. +""" + +from __future__ import annotations + +from collections.abc import Callable +from typing import TYPE_CHECKING + + +if TYPE_CHECKING: + from crewai.memory.storage.backend import StorageBackend + +# Receives the raw ``storage`` spec string and returns a backend to use, or +# ``None`` to defer to the built-in selection for that spec. +MemoryStorageFactory = Callable[[str], "StorageBackend | None"] + +_factory: MemoryStorageFactory | None = None + + +def set_memory_storage_factory(factory: MemoryStorageFactory | None) -> None: + """Replace the process-wide default memory storage factory. + + Intended for one-time setup at startup. Pass ``None`` to restore the + built-in LanceDB/Qdrant selection. Only affects ``Memory`` instances + constructed afterwards; an explicit ``storage=`` instance always wins. + + The factory is consulted for every string ``storage`` spec, so it must + return ``None`` for specs it does not handle to let the built-in + LanceDB/Qdrant/path selection take over. + """ + global _factory + _factory = factory + + +def resolve_memory_storage(spec: str) -> StorageBackend | None: + """Return the registered factory's backend for ``spec``, or ``None``. + + ``None`` means no factory is registered or it declined this spec; the + caller then falls back to the built-in selection. + """ + factory = _factory + return factory(spec) if factory is not None else None diff --git a/lib/crewai/src/crewai/memory/unified_memory.py b/lib/crewai/src/crewai/memory/unified_memory.py index 02c181822..75191b203 100644 --- a/lib/crewai/src/crewai/memory/unified_memory.py +++ b/lib/crewai/src/crewai/memory/unified_memory.py @@ -204,7 +204,12 @@ class Memory(BaseModel): ) if isinstance(self.storage, str): - if self.storage == "qdrant-edge": + from crewai.memory.storage.factory import resolve_memory_storage + + custom = resolve_memory_storage(self.storage) + if custom is not None: + self._storage = custom + elif self.storage == "qdrant-edge": from crewai.memory.storage.qdrant_edge_storage import QdrantEdgeStorage self._storage = QdrantEdgeStorage() diff --git a/lib/crewai/src/crewai/rag/factory.py b/lib/crewai/src/crewai/rag/factory.py index 47fc6cb62..0993c445e 100644 --- a/lib/crewai/src/crewai/rag/factory.py +++ b/lib/crewai/src/crewai/rag/factory.py @@ -1,5 +1,6 @@ """Factory functions for creating RAG clients from configuration.""" +from collections.abc import Callable from typing import cast from crewai.rag.config.optional_imports.protocols import ( @@ -11,6 +12,32 @@ from crewai.rag.core.base_client import BaseClient from crewai.utilities.import_utils import require +# RAG uses a provider-keyed registry (rather than the single-default setter +# used by the memory/knowledge/flow seams) because ``create_client`` already +# dispatches on ``config.provider`` -- the natural seam here is per-provider. +# A factory receives the RAG config and returns a client; one registered for a +# built-in provider name overrides the built-in for that provider. +RagClientFactory = Callable[[RagConfigType], BaseClient] + +_factories: dict[str, RagClientFactory] = {} + + +def register_rag_client_factory(provider: str, factory: RagClientFactory) -> None: + """Register a client factory for a RAG ``provider`` name. + + Lets an application plug in a client for a new provider, or override a + built-in provider (``"chromadb"`` / ``"qdrant"``), without modifying + :func:`create_client`. Registered factories take precedence over the + built-ins. Intended for one-time setup at startup. + """ + _factories[provider] = factory + + +def unregister_rag_client_factory(provider: str) -> None: + """Remove a previously registered factory; a no-op if none is registered.""" + _factories.pop(provider, None) + + def create_client(config: RagConfigType) -> BaseClient: """Create a client from configuration using the appropriate factory. @@ -24,6 +51,10 @@ def create_client(config: RagConfigType) -> BaseClient: ValueError: If the configuration provider is not supported. """ + factory = _factories.get(config.provider) + if factory is not None: + return factory(config) + if config.provider == "chromadb": chromadb_mod = cast( ChromaFactoryModule, diff --git a/lib/crewai/tests/knowledge/test_storage_factory.py b/lib/crewai/tests/knowledge/test_storage_factory.py new file mode 100644 index 000000000..5d8512f7c --- /dev/null +++ b/lib/crewai/tests/knowledge/test_storage_factory.py @@ -0,0 +1,130 @@ +"""Tests for the pluggable knowledge storage factory seam. + +We verify our own logic: the set/get round-trip, that a registered factory is +consulted when no explicit ``storage=`` is given (and receives the embedder and +collection name), and that an explicit ``storage=`` instance bypasses it. +""" + +from __future__ import annotations + +from typing import Any + +import pytest + +import crewai.knowledge.storage.factory as factory +from crewai.knowledge.knowledge import Knowledge +from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage +from crewai.rag.types import SearchResult + + +class _FakeKnowledgeStorage(BaseKnowledgeStorage): + """Minimal stand-in implementing the abstract interface.""" + + def search( + self, + query: list[str], + limit: int = 5, + metadata_filter: dict[str, Any] | None = None, + score_threshold: float = 0.6, + ) -> list[SearchResult]: + return [] + + async def asearch( + self, + query: list[str], + limit: int = 5, + metadata_filter: dict[str, Any] | None = None, + score_threshold: float = 0.6, + ) -> list[SearchResult]: + return [] + + def save(self, documents: list[str]) -> None: + return None + + async def asave(self, documents: list[str]) -> None: + return None + + def reset(self) -> None: + return None + + async def areset(self) -> None: + return None + + +@pytest.fixture(autouse=True) +def reset_factory(): + """Reset the factory around each test without clobbering preexisting state.""" + original = factory._factory + factory.set_knowledge_storage_factory(None) + yield + factory.set_knowledge_storage_factory(original) + + +def test_resolve_reflects_registered_factory(): + fake = _FakeKnowledgeStorage() + assert factory.resolve_knowledge_storage(None, "docs") is None + + factory.set_knowledge_storage_factory(lambda embedder, name: fake) + assert factory.resolve_knowledge_storage(None, "docs") is fake + + +def test_factory_used_when_no_explicit_storage(): + fake = _FakeKnowledgeStorage() + factory.set_knowledge_storage_factory(lambda embedder, name: fake) + + knowledge = Knowledge(collection_name="docs", sources=[]) + + assert knowledge.storage is fake + + +def test_factory_receives_embedder_and_collection_name(): + seen: list[tuple[object, object]] = [] + + def make(embedder, collection_name): + seen.append((embedder, collection_name)) + return _FakeKnowledgeStorage() + + factory.set_knowledge_storage_factory(make) + Knowledge(collection_name="docs", sources=[]) + + assert seen == [(None, "docs")] + + +def test_explicit_storage_bypasses_factory(): + factory_called = False + + def make(embedder, name): + nonlocal factory_called + factory_called = True + return _FakeKnowledgeStorage() + + factory.set_knowledge_storage_factory(make) + + explicit = _FakeKnowledgeStorage() + knowledge = Knowledge(collection_name="docs", sources=[], storage=explicit) + + assert knowledge.storage is explicit + assert factory_called is False + + +def test_falsy_explicit_storage_is_honored(): + # A custom backend that is falsy (defines __bool__/__len__) must still be + # used and operated on, not silently treated as "not initialized" by a + # truthiness check in __init__, reset, or the source save path. + reset_calls: list[bool] = [] + + class _FalsyStorage(_FakeKnowledgeStorage): + def __bool__(self) -> bool: + return False + + def reset(self) -> None: + reset_calls.append(True) + + explicit = _FalsyStorage() + knowledge = Knowledge(collection_name="docs", sources=[], storage=explicit) + + assert knowledge.storage is explicit + + # reset must call the backend, not raise "Storage is not initialized." + knowledge.reset() + assert reset_calls == [True] diff --git a/lib/crewai/tests/memory/test_storage_factory.py b/lib/crewai/tests/memory/test_storage_factory.py new file mode 100644 index 000000000..45774108b --- /dev/null +++ b/lib/crewai/tests/memory/test_storage_factory.py @@ -0,0 +1,72 @@ +"""Tests for the pluggable memory storage factory seam. + +We verify our own logic: the set/get round-trip, that a registered factory is +consulted for string ``storage`` specs (and receives the spec), and that an +explicit ``storage=`` instance bypasses the factory entirely. +""" + +from __future__ import annotations + +import pytest + +import crewai.memory.storage.factory as factory +from crewai.memory.unified_memory import Memory + + +@pytest.fixture(autouse=True) +def reset_factory(): + """Reset the factory around each test without clobbering preexisting state.""" + original = factory._factory + factory.set_memory_storage_factory(None) + yield + factory.set_memory_storage_factory(original) + + +def test_resolve_reflects_registered_factory(): + sentinel = object() + assert factory.resolve_memory_storage("lancedb") is None + + factory.set_memory_storage_factory(lambda spec: sentinel) + assert factory.resolve_memory_storage("lancedb") is sentinel + + factory.set_memory_storage_factory(None) + assert factory.resolve_memory_storage("lancedb") is None + + +def test_factory_backend_used_for_string_spec(): + sentinel = object() + factory.set_memory_storage_factory(lambda spec: sentinel) + + mem = Memory(storage="lancedb") + + assert mem._storage is sentinel + + +def test_factory_receives_the_raw_spec(): + seen: list[str] = [] + + def make(spec): + seen.append(spec) + return object() + + factory.set_memory_storage_factory(make) + Memory(storage="some/custom/path") + + assert seen == ["some/custom/path"] + + +def test_explicit_storage_instance_bypasses_factory(): + factory_called = False + + def make(spec): + nonlocal factory_called + factory_called = True + return object() + + factory.set_memory_storage_factory(make) + + explicit = object() + mem = Memory(storage=explicit) # type: ignore[arg-type] + + assert mem._storage is explicit + assert factory_called is False diff --git a/lib/crewai/tests/rag/test_client_factory_registry.py b/lib/crewai/tests/rag/test_client_factory_registry.py new file mode 100644 index 000000000..f97f830e7 --- /dev/null +++ b/lib/crewai/tests/rag/test_client_factory_registry.py @@ -0,0 +1,66 @@ +"""Tests for the RAG client factory registry seam. + +We verify our own logic: a registered factory is used for its provider, +factories override the built-in providers, unregister removes them, and an +unknown provider still raises. +""" + +from __future__ import annotations + +from types import SimpleNamespace + +import pytest + +import crewai.rag.factory as factory + + +@pytest.fixture(autouse=True) +def reset_registry(): + """Reset the registry around each test without clobbering preexisting state.""" + original = dict(factory._factories) + factory._factories.clear() + yield + factory._factories.clear() + factory._factories.update(original) + + +def test_registered_factory_is_used_for_its_provider(): + sentinel = object() + factory.register_rag_client_factory("custom", lambda config: sentinel) + + assert factory.create_client(SimpleNamespace(provider="custom")) is sentinel + + +def test_factory_receives_the_config(): + seen: list[object] = [] + config = SimpleNamespace(provider="custom") + factory.register_rag_client_factory("custom", lambda cfg: seen.append(cfg) or object()) + + factory.create_client(config) + + assert seen == [config] + + +def test_factory_overrides_builtin_provider(): + sentinel = object() + factory.register_rag_client_factory("chromadb", lambda config: sentinel) + + # Resolves via the registry without importing the built-in chromadb factory. + assert factory.create_client(SimpleNamespace(provider="chromadb")) is sentinel + + +def test_unregister_removes_factory(): + factory.register_rag_client_factory("custom", lambda config: object()) + factory.unregister_rag_client_factory("custom") + + with pytest.raises(ValueError, match="Unsupported provider: custom"): + factory.create_client(SimpleNamespace(provider="custom")) + + +def test_unregister_unknown_provider_is_noop(): + factory.unregister_rag_client_factory("never-registered") + + +def test_unknown_provider_still_raises(): + with pytest.raises(ValueError, match="Unsupported provider: nope"): + factory.create_client(SimpleNamespace(provider="nope")) diff --git a/lib/crewai/tests/test_flow_persistence_factory.py b/lib/crewai/tests/test_flow_persistence_factory.py new file mode 100644 index 000000000..b90c00aa8 --- /dev/null +++ b/lib/crewai/tests/test_flow_persistence_factory.py @@ -0,0 +1,68 @@ +"""Tests for the pluggable flow persistence factory seam. + +We verify our own logic: that ``default_flow_persistence`` returns the +registered factory's result, and that it falls back to the built-in SQLite +persistence when no factory is registered. +""" + +from __future__ import annotations + +from typing import Any + +import pytest +from pydantic import BaseModel + +import crewai.flow.persistence.factory as factory +from crewai.flow.persistence.base import FlowPersistence +from crewai.flow.persistence.decorators import persist +from crewai.flow.persistence.sqlite import SQLiteFlowPersistence + + +@pytest.fixture(autouse=True) +def reset_factory(): + """Reset the factory around each test without clobbering preexisting state.""" + original = factory._factory + factory.set_flow_persistence_factory(None) + yield + factory.set_flow_persistence_factory(original) + + +def test_default_uses_registered_factory(): + sentinel = SQLiteFlowPersistence() + factory.set_flow_persistence_factory(lambda: sentinel) + + assert factory.default_flow_persistence() is sentinel + + +def test_default_falls_back_to_sqlite(): + assert isinstance(factory.default_flow_persistence(), SQLiteFlowPersistence) + + +def test_persist_decorator_honors_falsy_persistence(): + # @persist with an explicit but falsy FlowPersistence must keep it, not + # replace it with the default via a truthiness check. + class _FalsyPersistence(FlowPersistence): + def __bool__(self) -> bool: + return False + + def init_db(self) -> None: + pass + + def save_state( + self, + flow_uuid: str, + method_name: str, + state_data: dict[str, Any] | BaseModel, + ) -> None: + pass + + def load_state(self, flow_uuid: str) -> dict[str, Any] | None: + return None + + falsy = _FalsyPersistence() + + @persist(persistence=falsy) + class _DummyFlow: + pass + + assert _DummyFlow.__flow_persistence_config__.persistence is falsy