mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-07-01 05:08:12 +00:00
feat(storage): pluggable default backends for memory, knowledge, rag, flow (#6079)
Some checks failed
Some checks failed
Add opt-in extension seams so an application can route memory, knowledge, RAG, and flow persistence through a custom backend without subclassing or threading an explicit instance through every construction site -- mirroring the existing crewai_core.lock_store.set_lock_backend seam. - memory: crewai.memory.storage.factory.set_memory_storage_factory - knowledge: crewai.knowledge.storage.factory.set_knowledge_storage_factory - rag: crewai.rag.factory.register_rag_client_factory (provider registry) - flow: crewai.flow.persistence.factory.set_flow_persistence_factory Each construction site consults the registered factory and falls back to the built-in default when none is set; an explicit instance always wins. Widen Knowledge.storage and the knowledge source base classes to BaseKnowledgeStorage (consistent with BaseAgent.knowledge_storage) so any base-interface backend plugs in. Runtime-free tests cover each seam.
This commit is contained in:
@@ -35,7 +35,7 @@ from crewai_core.printer import PRINTER
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai.flow.persistence.base import FlowPersistence
|
||||
from crewai.flow.persistence.sqlite import SQLiteFlowPersistence
|
||||
from crewai.flow.persistence.factory import default_flow_persistence
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -171,7 +171,9 @@ def persist(
|
||||
|
||||
Args:
|
||||
persistence: Optional FlowPersistence implementation to use.
|
||||
If not provided, uses SQLiteFlowPersistence.
|
||||
If not provided, uses ``default_flow_persistence()`` (the
|
||||
registered factory when present, else the built-in SQLite
|
||||
fallback).
|
||||
verbose: Whether to log persistence operations. Defaults to False.
|
||||
|
||||
Returns:
|
||||
@@ -190,7 +192,9 @@ def persist(
|
||||
"""
|
||||
|
||||
def decorator(target: type | Callable[..., T]) -> type | Callable[..., T]:
|
||||
actual_persistence = persistence or SQLiteFlowPersistence()
|
||||
actual_persistence = (
|
||||
persistence if persistence is not None else default_flow_persistence()
|
||||
)
|
||||
|
||||
if isinstance(target, type):
|
||||
_stamp_persistence_metadata(target, actual_persistence, verbose)
|
||||
|
||||
60
lib/crewai/src/crewai/flow/persistence/factory.py
Normal file
60
lib/crewai/src/crewai/flow/persistence/factory.py
Normal file
@@ -0,0 +1,60 @@
|
||||
"""Pluggable default persistence backend for flows.
|
||||
|
||||
By default, ``@persist`` and the flow runtime persist state with
|
||||
:class:`~crewai.flow.persistence.sqlite.SQLiteFlowPersistence` when no explicit
|
||||
``persistence=`` is given. Registering a factory via
|
||||
:func:`set_flow_persistence_factory` lets an application back flow state with a
|
||||
custom :class:`~crewai.flow.persistence.base.FlowPersistence` -- a database, a
|
||||
remote service, an in-memory fake for tests -- without passing a
|
||||
``persistence=`` instance at every ``@persist`` / kickoff site.
|
||||
|
||||
This mirrors :func:`crewai_core.lock_store.set_lock_backend`: a one-time,
|
||||
process-wide setter intended for application startup. Pass ``None`` to restore
|
||||
the built-in SQLite default. Call :func:`default_flow_persistence` to build the
|
||||
default backend (the registered factory if any, else SQLite).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.flow.persistence.base import FlowPersistence
|
||||
|
||||
FlowPersistenceFactory = Callable[[], "FlowPersistence"]
|
||||
|
||||
_factory: FlowPersistenceFactory | None = None
|
||||
|
||||
|
||||
def set_flow_persistence_factory(factory: FlowPersistenceFactory | None) -> None:
|
||||
"""Replace the process-wide default flow persistence factory.
|
||||
|
||||
Intended for one-time setup at startup. Pass ``None`` to restore the
|
||||
built-in ``SQLiteFlowPersistence``. Only affects flows that fall back to
|
||||
the default; an explicit ``persistence=`` instance always wins.
|
||||
|
||||
The default is resolved at each fall-back site (``@persist`` and the
|
||||
runtime's pause/resume paths), so the factory may be called more than once
|
||||
for a single flow. Return instances backed by shared durable state (or a
|
||||
singleton) so state saved on one call is visible to the next -- the
|
||||
built-in SQLite default satisfies this by sharing one on-disk file.
|
||||
"""
|
||||
global _factory
|
||||
_factory = factory
|
||||
|
||||
|
||||
def default_flow_persistence() -> FlowPersistence:
|
||||
"""Build the default flow persistence backend.
|
||||
|
||||
Returns the result of the registered factory if one is set, otherwise a
|
||||
built-in :class:`~crewai.flow.persistence.sqlite.SQLiteFlowPersistence`.
|
||||
"""
|
||||
factory = _factory
|
||||
if factory is not None:
|
||||
return factory()
|
||||
|
||||
from crewai.flow.persistence.sqlite import SQLiteFlowPersistence
|
||||
|
||||
return SQLiteFlowPersistence()
|
||||
@@ -1252,7 +1252,9 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
Args:
|
||||
flow_id: The unique identifier of the paused flow (from state.id)
|
||||
persistence: The persistence backend where the state was saved.
|
||||
If not provided, defaults to SQLiteFlowPersistence().
|
||||
If not provided, uses ``default_flow_persistence()`` (the
|
||||
registered factory when present, else the built-in SQLite
|
||||
fallback).
|
||||
**kwargs: Additional keyword arguments passed to the Flow constructor
|
||||
|
||||
Returns:
|
||||
@@ -1274,9 +1276,9 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
```
|
||||
"""
|
||||
if persistence is None:
|
||||
from crewai.flow.persistence import SQLiteFlowPersistence
|
||||
from crewai.flow.persistence.factory import default_flow_persistence
|
||||
|
||||
persistence = SQLiteFlowPersistence()
|
||||
persistence = default_flow_persistence()
|
||||
|
||||
loaded = persistence.load_pending_feedback(flow_id)
|
||||
if loaded is None:
|
||||
@@ -1463,7 +1465,7 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
|
||||
self._pending_feedback_context = None
|
||||
|
||||
if self.persistence:
|
||||
if self.persistence is not None:
|
||||
self.persistence.clear_pending_feedback(context.flow_id)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
@@ -1505,9 +1507,9 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
self._pending_feedback_context = e.context
|
||||
|
||||
if self.persistence is None:
|
||||
from crewai.flow.persistence import SQLiteFlowPersistence
|
||||
from crewai.flow.persistence.factory import default_flow_persistence
|
||||
|
||||
self.persistence = SQLiteFlowPersistence()
|
||||
self.persistence = default_flow_persistence()
|
||||
|
||||
state_data = (
|
||||
self._state
|
||||
@@ -2244,9 +2246,11 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
if isinstance(e, HumanFeedbackPending):
|
||||
# Auto-save pending feedback (create default persistence if needed)
|
||||
if self.persistence is None:
|
||||
from crewai.flow.persistence import SQLiteFlowPersistence
|
||||
from crewai.flow.persistence.factory import (
|
||||
default_flow_persistence,
|
||||
)
|
||||
|
||||
self.persistence = SQLiteFlowPersistence()
|
||||
self.persistence = default_flow_persistence()
|
||||
|
||||
state_data = (
|
||||
self._state
|
||||
@@ -2597,9 +2601,9 @@ class Flow(_ConversationalMixin, BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
e.context.method_name = method_name
|
||||
|
||||
if self.persistence is None:
|
||||
from crewai.flow.persistence import SQLiteFlowPersistence
|
||||
from crewai.flow.persistence.factory import default_flow_persistence
|
||||
|
||||
self.persistence = SQLiteFlowPersistence()
|
||||
self.persistence = default_flow_persistence()
|
||||
|
||||
# Emit paused event (not failed)
|
||||
if not self.suppress_flow_events:
|
||||
|
||||
@@ -13,6 +13,7 @@ from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSourc
|
||||
from crewai.knowledge.source.text_file_knowledge_source import (
|
||||
TextFileKnowledgeSource,
|
||||
)
|
||||
from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
|
||||
from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
from crewai.rag.embeddings.types import EmbedderConfig
|
||||
@@ -89,7 +90,7 @@ class Knowledge(BaseModel):
|
||||
Knowledge is a collection of sources and setup for the vector store to save and query relevant context.
|
||||
Args:
|
||||
sources: list[BaseKnowledgeSource] = Field(default_factory=list)
|
||||
storage: KnowledgeStorage | None = Field(default=None)
|
||||
storage: BaseKnowledgeStorage | None = Field(default=None)
|
||||
embedder: EmbedderConfig | None = None
|
||||
"""
|
||||
|
||||
@@ -98,7 +99,7 @@ class Knowledge(BaseModel):
|
||||
BeforeValidator(_resolve_knowledge_sources),
|
||||
] = Field(default_factory=list)
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
storage: KnowledgeStorage | None = Field(default=None)
|
||||
storage: BaseKnowledgeStorage | None = Field(default=None)
|
||||
embedder: Annotated[
|
||||
EmbedderConfig | None,
|
||||
PlainSerializer(
|
||||
@@ -112,15 +113,22 @@ class Knowledge(BaseModel):
|
||||
collection_name: str,
|
||||
sources: list[BaseKnowledgeSource],
|
||||
embedder: EmbedderConfig | None = None,
|
||||
storage: KnowledgeStorage | None = None,
|
||||
storage: BaseKnowledgeStorage | None = None,
|
||||
**data: object,
|
||||
) -> None:
|
||||
super().__init__(**data)
|
||||
if storage:
|
||||
if storage is not None:
|
||||
self.storage = storage
|
||||
else:
|
||||
self.storage = KnowledgeStorage(
|
||||
embedder=embedder, collection_name=collection_name
|
||||
from crewai.knowledge.storage.factory import resolve_knowledge_storage
|
||||
|
||||
custom = resolve_knowledge_storage(embedder, collection_name)
|
||||
self.storage = (
|
||||
custom
|
||||
if custom is not None
|
||||
else KnowledgeStorage(
|
||||
embedder=embedder, collection_name=collection_name
|
||||
)
|
||||
)
|
||||
self.sources = sources
|
||||
|
||||
@@ -152,10 +160,9 @@ class Knowledge(BaseModel):
|
||||
raise e
|
||||
|
||||
def reset(self) -> None:
|
||||
if self.storage:
|
||||
self.storage.reset()
|
||||
else:
|
||||
if self.storage is None:
|
||||
raise ValueError("Storage is not initialized.")
|
||||
self.storage.reset()
|
||||
|
||||
async def aquery(
|
||||
self, query: list[str], results_limit: int = 5, score_threshold: float = 0.6
|
||||
@@ -193,7 +200,6 @@ class Knowledge(BaseModel):
|
||||
|
||||
async def areset(self) -> None:
|
||||
"""Reset the knowledge base asynchronously."""
|
||||
if self.storage:
|
||||
await self.storage.areset()
|
||||
else:
|
||||
if self.storage is None:
|
||||
raise ValueError("Storage is not initialized.")
|
||||
await self.storage.areset()
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Any
|
||||
from pydantic import Field, field_validator
|
||||
|
||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||
from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage
|
||||
from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
|
||||
from crewai.utilities.constants import KNOWLEDGE_DIRECTORY
|
||||
from crewai.utilities.logger import Logger
|
||||
|
||||
@@ -22,7 +22,7 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
|
||||
default_factory=list, description="The path to the file"
|
||||
)
|
||||
content: dict[Path, str] = Field(init=False, default_factory=dict)
|
||||
storage: KnowledgeStorage | None = Field(default=None)
|
||||
storage: BaseKnowledgeStorage | None = Field(default=None)
|
||||
safe_file_paths: list[Path] = Field(default_factory=list)
|
||||
|
||||
@field_validator("file_path", "file_paths", mode="before")
|
||||
@@ -70,14 +70,14 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
|
||||
|
||||
def _save_documents(self) -> None:
|
||||
"""Save the documents to the storage."""
|
||||
if self.storage:
|
||||
if self.storage is not None:
|
||||
self.storage.save(self.chunks)
|
||||
else:
|
||||
raise ValueError("No storage found to save documents.")
|
||||
|
||||
async def _asave_documents(self) -> None:
|
||||
"""Save the documents to the storage asynchronously."""
|
||||
if self.storage:
|
||||
if self.storage is not None:
|
||||
await self.storage.asave(self.chunks)
|
||||
else:
|
||||
raise ValueError("No storage found to save documents.")
|
||||
|
||||
@@ -4,9 +4,15 @@ from typing import Any
|
||||
import numpy as np
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
|
||||
from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage
|
||||
|
||||
|
||||
# ``KnowledgeStorage`` is re-exported for backwards compatibility; the ``storage``
|
||||
# field below is typed to the base interface so any backend plugs in.
|
||||
__all__ = ["BaseKnowledgeSource", "KnowledgeStorage"]
|
||||
|
||||
|
||||
class BaseKnowledgeSource(BaseModel, ABC):
|
||||
"""Abstract base class for knowledge sources."""
|
||||
|
||||
@@ -18,7 +24,7 @@ class BaseKnowledgeSource(BaseModel, ABC):
|
||||
)
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
storage: KnowledgeStorage | None = Field(default=None)
|
||||
storage: BaseKnowledgeStorage | None = Field(default=None)
|
||||
metadata: dict[str, Any] = Field(default_factory=dict) # Currently unused
|
||||
collection_name: str | None = Field(default=None)
|
||||
|
||||
@@ -49,7 +55,7 @@ class BaseKnowledgeSource(BaseModel, ABC):
|
||||
Raises:
|
||||
ValueError: If no storage is configured.
|
||||
"""
|
||||
if self.storage:
|
||||
if self.storage is not None:
|
||||
self.storage.save(self.chunks)
|
||||
else:
|
||||
raise ValueError("No storage found to save documents.")
|
||||
@@ -66,7 +72,7 @@ class BaseKnowledgeSource(BaseModel, ABC):
|
||||
Raises:
|
||||
ValueError: If no storage is configured.
|
||||
"""
|
||||
if self.storage:
|
||||
if self.storage is not None:
|
||||
await self.storage.asave(self.chunks)
|
||||
else:
|
||||
raise ValueError("No storage found to save documents.")
|
||||
|
||||
56
lib/crewai/src/crewai/knowledge/storage/factory.py
Normal file
56
lib/crewai/src/crewai/knowledge/storage/factory.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""Pluggable default storage backend for knowledge collections.
|
||||
|
||||
By default, :class:`~crewai.knowledge.knowledge.Knowledge` builds a
|
||||
:class:`~crewai.knowledge.storage.knowledge_storage.KnowledgeStorage` when no
|
||||
explicit ``storage=`` is given. Registering a factory via
|
||||
:func:`set_knowledge_storage_factory` lets an application back knowledge with a
|
||||
custom :class:`~crewai.knowledge.storage.base_knowledge_storage.BaseKnowledgeStorage`
|
||||
without subclassing ``Knowledge`` or passing a ``storage=`` instance at every
|
||||
call site.
|
||||
|
||||
This mirrors :func:`crewai_core.lock_store.set_lock_backend`: a one-time,
|
||||
process-wide setter intended for application startup. Pass ``None`` to restore
|
||||
the built-in default.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
|
||||
from crewai.rag.embeddings.types import EmbedderConfig
|
||||
|
||||
# Receives the same inputs as the built-in default -- the embedder config and
|
||||
# collection name -- and returns a storage backend, or ``None`` to defer to the
|
||||
# built-in ``KnowledgeStorage``.
|
||||
KnowledgeStorageFactory = Callable[
|
||||
["EmbedderConfig | None", "str | None"], "BaseKnowledgeStorage | None"
|
||||
]
|
||||
|
||||
_factory: KnowledgeStorageFactory | None = None
|
||||
|
||||
|
||||
def set_knowledge_storage_factory(factory: KnowledgeStorageFactory | None) -> None:
|
||||
"""Replace the process-wide default knowledge storage factory.
|
||||
|
||||
Intended for one-time setup at startup. Pass ``None`` to restore the
|
||||
built-in ``KnowledgeStorage``. Only affects ``Knowledge`` instances
|
||||
constructed afterwards; an explicit ``storage=`` instance always wins.
|
||||
"""
|
||||
global _factory
|
||||
_factory = factory
|
||||
|
||||
|
||||
def resolve_knowledge_storage(
|
||||
embedder: EmbedderConfig | None, collection_name: str | None
|
||||
) -> BaseKnowledgeStorage | None:
|
||||
"""Return the registered factory's backend, or ``None`` for the built-in.
|
||||
|
||||
``None`` means no factory is registered or it declined; the caller then
|
||||
falls back to the built-in ``KnowledgeStorage``.
|
||||
"""
|
||||
factory = _factory
|
||||
return factory(embedder, collection_name) if factory is not None else None
|
||||
55
lib/crewai/src/crewai/memory/storage/factory.py
Normal file
55
lib/crewai/src/crewai/memory/storage/factory.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""Pluggable default storage backend for the unified memory system.
|
||||
|
||||
By default, :class:`~crewai.memory.unified_memory.Memory` builds a built-in
|
||||
vector store from its ``storage`` spec string (LanceDB, or Qdrant for the
|
||||
``"qdrant-edge"`` spec). Registering a factory via
|
||||
:func:`set_memory_storage_factory` lets an application route memory through a
|
||||
custom :class:`~crewai.memory.storage.backend.StorageBackend` -- a different
|
||||
vector store, a remote service, an in-memory fake for tests -- without
|
||||
subclassing ``Memory`` or threading an explicit ``storage=`` instance through
|
||||
every construction site.
|
||||
|
||||
This mirrors :func:`crewai_core.lock_store.set_lock_backend`: a one-time,
|
||||
process-wide setter intended for application startup. Pass ``None`` to restore
|
||||
the built-in default.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.memory.storage.backend import StorageBackend
|
||||
|
||||
# Receives the raw ``storage`` spec string and returns a backend to use, or
|
||||
# ``None`` to defer to the built-in selection for that spec.
|
||||
MemoryStorageFactory = Callable[[str], "StorageBackend | None"]
|
||||
|
||||
_factory: MemoryStorageFactory | None = None
|
||||
|
||||
|
||||
def set_memory_storage_factory(factory: MemoryStorageFactory | None) -> None:
|
||||
"""Replace the process-wide default memory storage factory.
|
||||
|
||||
Intended for one-time setup at startup. Pass ``None`` to restore the
|
||||
built-in LanceDB/Qdrant selection. Only affects ``Memory`` instances
|
||||
constructed afterwards; an explicit ``storage=`` instance always wins.
|
||||
|
||||
The factory is consulted for every string ``storage`` spec, so it must
|
||||
return ``None`` for specs it does not handle to let the built-in
|
||||
LanceDB/Qdrant/path selection take over.
|
||||
"""
|
||||
global _factory
|
||||
_factory = factory
|
||||
|
||||
|
||||
def resolve_memory_storage(spec: str) -> StorageBackend | None:
|
||||
"""Return the registered factory's backend for ``spec``, or ``None``.
|
||||
|
||||
``None`` means no factory is registered or it declined this spec; the
|
||||
caller then falls back to the built-in selection.
|
||||
"""
|
||||
factory = _factory
|
||||
return factory(spec) if factory is not None else None
|
||||
@@ -204,7 +204,12 @@ class Memory(BaseModel):
|
||||
)
|
||||
|
||||
if isinstance(self.storage, str):
|
||||
if self.storage == "qdrant-edge":
|
||||
from crewai.memory.storage.factory import resolve_memory_storage
|
||||
|
||||
custom = resolve_memory_storage(self.storage)
|
||||
if custom is not None:
|
||||
self._storage = custom
|
||||
elif self.storage == "qdrant-edge":
|
||||
from crewai.memory.storage.qdrant_edge_storage import QdrantEdgeStorage
|
||||
|
||||
self._storage = QdrantEdgeStorage()
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Factory functions for creating RAG clients from configuration."""
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import cast
|
||||
|
||||
from crewai.rag.config.optional_imports.protocols import (
|
||||
@@ -11,6 +12,32 @@ from crewai.rag.core.base_client import BaseClient
|
||||
from crewai.utilities.import_utils import require
|
||||
|
||||
|
||||
# RAG uses a provider-keyed registry (rather than the single-default setter
|
||||
# used by the memory/knowledge/flow seams) because ``create_client`` already
|
||||
# dispatches on ``config.provider`` -- the natural seam here is per-provider.
|
||||
# A factory receives the RAG config and returns a client; one registered for a
|
||||
# built-in provider name overrides the built-in for that provider.
|
||||
RagClientFactory = Callable[[RagConfigType], BaseClient]
|
||||
|
||||
_factories: dict[str, RagClientFactory] = {}
|
||||
|
||||
|
||||
def register_rag_client_factory(provider: str, factory: RagClientFactory) -> None:
|
||||
"""Register a client factory for a RAG ``provider`` name.
|
||||
|
||||
Lets an application plug in a client for a new provider, or override a
|
||||
built-in provider (``"chromadb"`` / ``"qdrant"``), without modifying
|
||||
:func:`create_client`. Registered factories take precedence over the
|
||||
built-ins. Intended for one-time setup at startup.
|
||||
"""
|
||||
_factories[provider] = factory
|
||||
|
||||
|
||||
def unregister_rag_client_factory(provider: str) -> None:
|
||||
"""Remove a previously registered factory; a no-op if none is registered."""
|
||||
_factories.pop(provider, None)
|
||||
|
||||
|
||||
def create_client(config: RagConfigType) -> BaseClient:
|
||||
"""Create a client from configuration using the appropriate factory.
|
||||
|
||||
@@ -24,6 +51,10 @@ def create_client(config: RagConfigType) -> BaseClient:
|
||||
ValueError: If the configuration provider is not supported.
|
||||
"""
|
||||
|
||||
factory = _factories.get(config.provider)
|
||||
if factory is not None:
|
||||
return factory(config)
|
||||
|
||||
if config.provider == "chromadb":
|
||||
chromadb_mod = cast(
|
||||
ChromaFactoryModule,
|
||||
|
||||
130
lib/crewai/tests/knowledge/test_storage_factory.py
Normal file
130
lib/crewai/tests/knowledge/test_storage_factory.py
Normal file
@@ -0,0 +1,130 @@
|
||||
"""Tests for the pluggable knowledge storage factory seam.
|
||||
|
||||
We verify our own logic: the set/get round-trip, that a registered factory is
|
||||
consulted when no explicit ``storage=`` is given (and receives the embedder and
|
||||
collection name), and that an explicit ``storage=`` instance bypasses it.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
import crewai.knowledge.storage.factory as factory
|
||||
from crewai.knowledge.knowledge import Knowledge
|
||||
from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
|
||||
from crewai.rag.types import SearchResult
|
||||
|
||||
|
||||
class _FakeKnowledgeStorage(BaseKnowledgeStorage):
|
||||
"""Minimal stand-in implementing the abstract interface."""
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: list[str],
|
||||
limit: int = 5,
|
||||
metadata_filter: dict[str, Any] | None = None,
|
||||
score_threshold: float = 0.6,
|
||||
) -> list[SearchResult]:
|
||||
return []
|
||||
|
||||
async def asearch(
|
||||
self,
|
||||
query: list[str],
|
||||
limit: int = 5,
|
||||
metadata_filter: dict[str, Any] | None = None,
|
||||
score_threshold: float = 0.6,
|
||||
) -> list[SearchResult]:
|
||||
return []
|
||||
|
||||
def save(self, documents: list[str]) -> None:
|
||||
return None
|
||||
|
||||
async def asave(self, documents: list[str]) -> None:
|
||||
return None
|
||||
|
||||
def reset(self) -> None:
|
||||
return None
|
||||
|
||||
async def areset(self) -> None:
|
||||
return None
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_factory():
|
||||
"""Reset the factory around each test without clobbering preexisting state."""
|
||||
original = factory._factory
|
||||
factory.set_knowledge_storage_factory(None)
|
||||
yield
|
||||
factory.set_knowledge_storage_factory(original)
|
||||
|
||||
|
||||
def test_resolve_reflects_registered_factory():
|
||||
fake = _FakeKnowledgeStorage()
|
||||
assert factory.resolve_knowledge_storage(None, "docs") is None
|
||||
|
||||
factory.set_knowledge_storage_factory(lambda embedder, name: fake)
|
||||
assert factory.resolve_knowledge_storage(None, "docs") is fake
|
||||
|
||||
|
||||
def test_factory_used_when_no_explicit_storage():
|
||||
fake = _FakeKnowledgeStorage()
|
||||
factory.set_knowledge_storage_factory(lambda embedder, name: fake)
|
||||
|
||||
knowledge = Knowledge(collection_name="docs", sources=[])
|
||||
|
||||
assert knowledge.storage is fake
|
||||
|
||||
|
||||
def test_factory_receives_embedder_and_collection_name():
|
||||
seen: list[tuple[object, object]] = []
|
||||
|
||||
def make(embedder, collection_name):
|
||||
seen.append((embedder, collection_name))
|
||||
return _FakeKnowledgeStorage()
|
||||
|
||||
factory.set_knowledge_storage_factory(make)
|
||||
Knowledge(collection_name="docs", sources=[])
|
||||
|
||||
assert seen == [(None, "docs")]
|
||||
|
||||
|
||||
def test_explicit_storage_bypasses_factory():
|
||||
factory_called = False
|
||||
|
||||
def make(embedder, name):
|
||||
nonlocal factory_called
|
||||
factory_called = True
|
||||
return _FakeKnowledgeStorage()
|
||||
|
||||
factory.set_knowledge_storage_factory(make)
|
||||
|
||||
explicit = _FakeKnowledgeStorage()
|
||||
knowledge = Knowledge(collection_name="docs", sources=[], storage=explicit)
|
||||
|
||||
assert knowledge.storage is explicit
|
||||
assert factory_called is False
|
||||
|
||||
|
||||
def test_falsy_explicit_storage_is_honored():
|
||||
# A custom backend that is falsy (defines __bool__/__len__) must still be
|
||||
# used and operated on, not silently treated as "not initialized" by a
|
||||
# truthiness check in __init__, reset, or the source save path.
|
||||
reset_calls: list[bool] = []
|
||||
|
||||
class _FalsyStorage(_FakeKnowledgeStorage):
|
||||
def __bool__(self) -> bool:
|
||||
return False
|
||||
|
||||
def reset(self) -> None:
|
||||
reset_calls.append(True)
|
||||
|
||||
explicit = _FalsyStorage()
|
||||
knowledge = Knowledge(collection_name="docs", sources=[], storage=explicit)
|
||||
|
||||
assert knowledge.storage is explicit
|
||||
|
||||
# reset must call the backend, not raise "Storage is not initialized."
|
||||
knowledge.reset()
|
||||
assert reset_calls == [True]
|
||||
72
lib/crewai/tests/memory/test_storage_factory.py
Normal file
72
lib/crewai/tests/memory/test_storage_factory.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""Tests for the pluggable memory storage factory seam.
|
||||
|
||||
We verify our own logic: the set/get round-trip, that a registered factory is
|
||||
consulted for string ``storage`` specs (and receives the spec), and that an
|
||||
explicit ``storage=`` instance bypasses the factory entirely.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
import crewai.memory.storage.factory as factory
|
||||
from crewai.memory.unified_memory import Memory
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_factory():
|
||||
"""Reset the factory around each test without clobbering preexisting state."""
|
||||
original = factory._factory
|
||||
factory.set_memory_storage_factory(None)
|
||||
yield
|
||||
factory.set_memory_storage_factory(original)
|
||||
|
||||
|
||||
def test_resolve_reflects_registered_factory():
|
||||
sentinel = object()
|
||||
assert factory.resolve_memory_storage("lancedb") is None
|
||||
|
||||
factory.set_memory_storage_factory(lambda spec: sentinel)
|
||||
assert factory.resolve_memory_storage("lancedb") is sentinel
|
||||
|
||||
factory.set_memory_storage_factory(None)
|
||||
assert factory.resolve_memory_storage("lancedb") is None
|
||||
|
||||
|
||||
def test_factory_backend_used_for_string_spec():
|
||||
sentinel = object()
|
||||
factory.set_memory_storage_factory(lambda spec: sentinel)
|
||||
|
||||
mem = Memory(storage="lancedb")
|
||||
|
||||
assert mem._storage is sentinel
|
||||
|
||||
|
||||
def test_factory_receives_the_raw_spec():
|
||||
seen: list[str] = []
|
||||
|
||||
def make(spec):
|
||||
seen.append(spec)
|
||||
return object()
|
||||
|
||||
factory.set_memory_storage_factory(make)
|
||||
Memory(storage="some/custom/path")
|
||||
|
||||
assert seen == ["some/custom/path"]
|
||||
|
||||
|
||||
def test_explicit_storage_instance_bypasses_factory():
|
||||
factory_called = False
|
||||
|
||||
def make(spec):
|
||||
nonlocal factory_called
|
||||
factory_called = True
|
||||
return object()
|
||||
|
||||
factory.set_memory_storage_factory(make)
|
||||
|
||||
explicit = object()
|
||||
mem = Memory(storage=explicit) # type: ignore[arg-type]
|
||||
|
||||
assert mem._storage is explicit
|
||||
assert factory_called is False
|
||||
66
lib/crewai/tests/rag/test_client_factory_registry.py
Normal file
66
lib/crewai/tests/rag/test_client_factory_registry.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""Tests for the RAG client factory registry seam.
|
||||
|
||||
We verify our own logic: a registered factory is used for its provider,
|
||||
factories override the built-in providers, unregister removes them, and an
|
||||
unknown provider still raises.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
import crewai.rag.factory as factory
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_registry():
|
||||
"""Reset the registry around each test without clobbering preexisting state."""
|
||||
original = dict(factory._factories)
|
||||
factory._factories.clear()
|
||||
yield
|
||||
factory._factories.clear()
|
||||
factory._factories.update(original)
|
||||
|
||||
|
||||
def test_registered_factory_is_used_for_its_provider():
|
||||
sentinel = object()
|
||||
factory.register_rag_client_factory("custom", lambda config: sentinel)
|
||||
|
||||
assert factory.create_client(SimpleNamespace(provider="custom")) is sentinel
|
||||
|
||||
|
||||
def test_factory_receives_the_config():
|
||||
seen: list[object] = []
|
||||
config = SimpleNamespace(provider="custom")
|
||||
factory.register_rag_client_factory("custom", lambda cfg: seen.append(cfg) or object())
|
||||
|
||||
factory.create_client(config)
|
||||
|
||||
assert seen == [config]
|
||||
|
||||
|
||||
def test_factory_overrides_builtin_provider():
|
||||
sentinel = object()
|
||||
factory.register_rag_client_factory("chromadb", lambda config: sentinel)
|
||||
|
||||
# Resolves via the registry without importing the built-in chromadb factory.
|
||||
assert factory.create_client(SimpleNamespace(provider="chromadb")) is sentinel
|
||||
|
||||
|
||||
def test_unregister_removes_factory():
|
||||
factory.register_rag_client_factory("custom", lambda config: object())
|
||||
factory.unregister_rag_client_factory("custom")
|
||||
|
||||
with pytest.raises(ValueError, match="Unsupported provider: custom"):
|
||||
factory.create_client(SimpleNamespace(provider="custom"))
|
||||
|
||||
|
||||
def test_unregister_unknown_provider_is_noop():
|
||||
factory.unregister_rag_client_factory("never-registered")
|
||||
|
||||
|
||||
def test_unknown_provider_still_raises():
|
||||
with pytest.raises(ValueError, match="Unsupported provider: nope"):
|
||||
factory.create_client(SimpleNamespace(provider="nope"))
|
||||
68
lib/crewai/tests/test_flow_persistence_factory.py
Normal file
68
lib/crewai/tests/test_flow_persistence_factory.py
Normal file
@@ -0,0 +1,68 @@
|
||||
"""Tests for the pluggable flow persistence factory seam.
|
||||
|
||||
We verify our own logic: that ``default_flow_persistence`` returns the
|
||||
registered factory's result, and that it falls back to the built-in SQLite
|
||||
persistence when no factory is registered.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
import crewai.flow.persistence.factory as factory
|
||||
from crewai.flow.persistence.base import FlowPersistence
|
||||
from crewai.flow.persistence.decorators import persist
|
||||
from crewai.flow.persistence.sqlite import SQLiteFlowPersistence
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_factory():
|
||||
"""Reset the factory around each test without clobbering preexisting state."""
|
||||
original = factory._factory
|
||||
factory.set_flow_persistence_factory(None)
|
||||
yield
|
||||
factory.set_flow_persistence_factory(original)
|
||||
|
||||
|
||||
def test_default_uses_registered_factory():
|
||||
sentinel = SQLiteFlowPersistence()
|
||||
factory.set_flow_persistence_factory(lambda: sentinel)
|
||||
|
||||
assert factory.default_flow_persistence() is sentinel
|
||||
|
||||
|
||||
def test_default_falls_back_to_sqlite():
|
||||
assert isinstance(factory.default_flow_persistence(), SQLiteFlowPersistence)
|
||||
|
||||
|
||||
def test_persist_decorator_honors_falsy_persistence():
|
||||
# @persist with an explicit but falsy FlowPersistence must keep it, not
|
||||
# replace it with the default via a truthiness check.
|
||||
class _FalsyPersistence(FlowPersistence):
|
||||
def __bool__(self) -> bool:
|
||||
return False
|
||||
|
||||
def init_db(self) -> None:
|
||||
pass
|
||||
|
||||
def save_state(
|
||||
self,
|
||||
flow_uuid: str,
|
||||
method_name: str,
|
||||
state_data: dict[str, Any] | BaseModel,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
def load_state(self, flow_uuid: str) -> dict[str, Any] | None:
|
||||
return None
|
||||
|
||||
falsy = _FalsyPersistence()
|
||||
|
||||
@persist(persistence=falsy)
|
||||
class _DummyFlow:
|
||||
pass
|
||||
|
||||
assert _DummyFlow.__flow_persistence_config__.persistence is falsy
|
||||
Reference in New Issue
Block a user