mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 08:38:30 +00:00
feat: rag configuration with optional dependency support (#3394)
### RAG Config System * Added ChromaDB client creation via config with sensible defaults * Introduced optional imports and shared RAG config utilities/schema * Enabled embedding function support with ChromaDB provider integration * Refactored configs for immutability and stronger type safety * Removed unused code and expanded test coverage
This commit is contained in:
@@ -1 +1,58 @@
|
|||||||
"""RAG (Retrieval-Augmented Generation) infrastructure for CrewAI."""
|
"""RAG (Retrieval-Augmented Generation) infrastructure for CrewAI."""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import importlib
|
||||||
|
from types import ModuleType
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from crewai.rag.config.types import RagConfigType
|
||||||
|
from crewai.rag.config.utils import set_rag_config
|
||||||
|
|
||||||
|
|
||||||
|
# Captured at import time so _RagModule can expose the same __path__ after it replaces this module in sys.modules.
_module_path = __path__
|
||||||
|
# Captured at import time so _RagModule can expose the same __file__ after it replaces this module in sys.modules.
_module_file = __file__
|
||||||
|
|
||||||
|
class _RagModule(ModuleType):
    """Module wrapper that routes ``config`` assignment to ``set_rag_config``.

    Assigning ``<module>.config = value`` forwards ``value`` to
    ``set_rag_config`` instead of rebinding a module attribute. Every other
    assignment is rejected, and attribute reads that miss are resolved by
    importing the submodule of the same name.
    """

    # Reuse the original module's search path and file location so submodule
    # imports keep resolving through this wrapper.
    __path__ = _module_path
    __file__ = _module_file

    def __init__(self, module_name: str):
        """Initialize the module wrapper.

        Args:
            module_name: Name of the module.
        """
        super().__init__(module_name)

    def __setattr__(self, name: str, value: RagConfigType) -> None:
        """Set module attributes.

        Args:
            name: Attribute name; only ``"config"`` is accepted.
            value: RAG configuration forwarded to ``set_rag_config``.

        Raises:
            AttributeError: If ``name`` is anything other than ``"config"``.
        """
        if name != "config":
            raise AttributeError(f"Setting attribute '{name}' is not allowed.")
        set_rag_config(value)

    def __getattr__(self, name: str) -> Any:
        """Get module attributes by importing the matching submodule.

        Only called when normal attribute lookup fails.

        Args:
            name: Attribute name.

        Returns:
            The imported submodule ``<module name>.<name>``.

        Raises:
            AttributeError: If the submodule cannot be imported. The original
                ImportError is chained as ``__cause__`` so a failure inside
                the submodule (for example a missing optional dependency)
                stays visible in the traceback.
        """
        try:
            return importlib.import_module(f"{self.__name__}.{name}")
        except ImportError as exc:
            raise AttributeError(
                f"module '{self.__name__}' has no attribute '{name}'"
            ) from exc
|
||||||
|
|
||||||
|
|
||||||
|
# Replace this module's sys.modules entry with the wrapper so attribute access on the module goes through _RagModule.
sys.modules[__name__] = _RagModule(__name__)
|
||||||
|
|||||||
57
src/crewai/rag/chromadb/config.py
Normal file
57
src/crewai/rag/chromadb/config.py
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
"""ChromaDB configuration model."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import warnings
|
||||||
|
from dataclasses import field
|
||||||
|
from typing import Literal, cast
|
||||||
|
from pydantic.dataclasses import dataclass as pyd_dataclass
|
||||||
|
from chromadb.config import Settings
|
||||||
|
from chromadb.utils.embedding_functions import DefaultEmbeddingFunction
|
||||||
|
|
||||||
|
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
|
||||||
|
from crewai.utilities.paths import db_storage_path
|
||||||
|
from crewai.rag.config.base import BaseRagConfig
|
||||||
|
from crewai.rag.chromadb.constants import DEFAULT_TENANT, DEFAULT_DATABASE
|
||||||
|
|
||||||
|
|
||||||
|
# Ignore UserWarnings raised from pydantic's schema generation whose message
# matches "Mixing V1 models and V2 models".
warnings.filterwarnings(
    action="ignore",
    category=UserWarning,
    module="pydantic._internal._generate_schema",
    message=".*Mixing V1 models and V2 models.*",
)
|
||||||
|
|
||||||
|
|
||||||
|
def _default_settings() -> Settings:
    """Build the ChromaDB settings used when none are supplied.

    Returns:
        Settings that persist to a ``chromadb`` directory under the path
        returned by ``db_storage_path()``, with reset allowed.
    """
    storage_dir = os.path.join(db_storage_path(), "chromadb")
    return Settings(
        is_persistent=True,
        allow_reset=True,
        persist_directory=storage_dir,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _default_embedding_function() -> ChromaEmbeddingFunctionWrapper:
    """Build the embedding function used when none is supplied.

    Returns:
        A ``DefaultEmbeddingFunction`` instance, typed as
        ``ChromaEmbeddingFunctionWrapper``.
    """
    default_fn = DefaultEmbeddingFunction()
    # cast() only informs the type checker; the object is returned unchanged.
    return cast(ChromaEmbeddingFunctionWrapper, default_fn)
|
||||||
|
|
||||||
|
|
||||||
|
@pyd_dataclass(frozen=True)
class ChromaDBConfig(BaseRagConfig):
    """Frozen configuration from which a ChromaDB client is built."""

    # Fixed provider tag; excluded from __init__ so callers cannot pass it.
    provider: Literal["chromadb"] = field(default="chromadb", init=False)
    tenant: str = field(default=DEFAULT_TENANT)
    database: str = field(default=DEFAULT_DATABASE)
    # Factories run per instance, so defaults are built when a config is created.
    settings: Settings = field(default_factory=_default_settings)
    embedding_function: ChromaEmbeddingFunctionWrapper | None = field(
        default_factory=_default_embedding_function
    )
|
||||||
6
src/crewai/rag/chromadb/constants.py
Normal file
6
src/crewai/rag/chromadb/constants.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
"""Constants for ChromaDB configuration."""
|
||||||
|
|
||||||
|
from typing import Final
|
||||||
|
|
||||||
|
# Tenant name ChromaDBConfig falls back to when the caller does not supply one.
DEFAULT_TENANT: Final[str] = "default_tenant"
|
||||||
|
# Database name ChromaDBConfig falls back to when the caller does not supply one.
DEFAULT_DATABASE: Final[str] = "default_database"
|
||||||
26
src/crewai/rag/chromadb/factory.py
Normal file
26
src/crewai/rag/chromadb/factory.py
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
"""Factory functions for creating ChromaDB clients."""
|
||||||
|
|
||||||
|
from chromadb import Client
|
||||||
|
|
||||||
|
from crewai.rag.chromadb.config import ChromaDBConfig
|
||||||
|
from crewai.rag.chromadb.client import ChromaDBClient
|
||||||
|
|
||||||
|
|
||||||
|
def create_client(config: ChromaDBConfig) -> ChromaDBClient:
    """Build a ChromaDBClient from a ChromaDB configuration.

    Args:
        config: ChromaDB configuration supplying the settings, tenant,
            database and embedding function.

    Returns:
        ChromaDBClient whose ``client`` and ``embedding_function`` attributes
        are populated from ``config``.
    """
    # Create the underlying chromadb client first, then attach it to the wrapper.
    native_client = Client(
        settings=config.settings,
        tenant=config.tenant,
        database=config.database,
    )

    wrapper = ChromaDBClient()
    wrapper.client = native_client
    wrapper.embedding_function = config.embedding_function
    return wrapper
|
||||||
@@ -3,6 +3,8 @@
|
|||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
from typing import Any, NamedTuple
|
from typing import Any, NamedTuple
|
||||||
|
|
||||||
|
from pydantic import GetCoreSchemaHandler
|
||||||
|
from pydantic_core import CoreSchema, core_schema
|
||||||
from chromadb.api import ClientAPI, AsyncClientAPI
|
from chromadb.api import ClientAPI, AsyncClientAPI
|
||||||
from chromadb.api.configuration import CollectionConfigurationInterface
|
from chromadb.api.configuration import CollectionConfigurationInterface
|
||||||
from chromadb.api.types import (
|
from chromadb.api.types import (
|
||||||
@@ -21,6 +23,21 @@ from crewai.rag.core.base_client import BaseCollectionParams, BaseCollectionSear
|
|||||||
ChromaDBClientType = ClientAPI | AsyncClientAPI
|
ChromaDBClientType = ClientAPI | AsyncClientAPI
|
||||||
|
|
||||||
|
|
||||||
|
class ChromaEmbeddingFunctionWrapper(ChromaEmbeddingFunction[Embeddable]):
    """ChromaDB EmbeddingFunction base type that Pydantic can handle."""

    @classmethod
    def __get_pydantic_core_schema__(
        cls, _source_type: Any, _handler: GetCoreSchemaHandler
    ) -> CoreSchema:
        """Return the Pydantic core schema for this type.

        Providing the schema here lets Pydantic handle ChromaDB's
        EmbeddingFunction type without ``arbitrary_types_allowed=True``.

        Args:
            _source_type: Source type passed by Pydantic; unused.
            _handler: Schema handler passed by Pydantic; unused.

        Returns:
            The schema produced by ``core_schema.any_schema()``.
        """
        permissive_schema = core_schema.any_schema()
        return permissive_schema
|
||||||
|
|
||||||
|
|
||||||
class PreparedDocuments(NamedTuple):
|
class PreparedDocuments(NamedTuple):
|
||||||
"""Prepared documents ready for ChromaDB insertion.
|
"""Prepared documents ready for ChromaDB insertion.
|
||||||
|
|
||||||
|
|||||||
@@ -10,10 +10,8 @@ from chromadb.api.types import (
|
|||||||
IncludeEnum,
|
IncludeEnum,
|
||||||
QueryResult,
|
QueryResult,
|
||||||
)
|
)
|
||||||
|
|
||||||
from chromadb.api.models.AsyncCollection import AsyncCollection
|
from chromadb.api.models.AsyncCollection import AsyncCollection
|
||||||
from chromadb.api.models.Collection import Collection
|
from chromadb.api.models.Collection import Collection
|
||||||
|
|
||||||
from crewai.rag.chromadb.types import (
|
from crewai.rag.chromadb.types import (
|
||||||
ChromaDBClientType,
|
ChromaDBClientType,
|
||||||
ChromaDBCollectionSearchParams,
|
ChromaDBCollectionSearchParams,
|
||||||
|
|||||||
1
src/crewai/rag/config/__init__.py
Normal file
1
src/crewai/rag/config/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""RAG client configuration management using ContextVars for thread-safe provider switching."""
|
||||||
16
src/crewai/rag/config/base.py
Normal file
16
src/crewai/rag/config/base.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
"""Base configuration class for RAG providers."""
|
||||||
|
|
||||||
|
from dataclasses import field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic.dataclasses import dataclass as pyd_dataclass
|
||||||
|
|
||||||
|
from crewai.rag.config.optional_imports.types import SupportedProvider
|
||||||
|
|
||||||
|
|
||||||
|
@pyd_dataclass(frozen=True)
class BaseRagConfig:
    """Frozen Pydantic dataclass that RAG provider configurations extend."""

    # Declared without a default and excluded from __init__.
    provider: SupportedProvider = field(init=False)
    embedding_function: Any | None = None
|
||||||
8
src/crewai/rag/config/constants.py
Normal file
8
src/crewai/rag/config/constants.py
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
"""Constants for RAG configuration."""
|
||||||
|
|
||||||
|
from typing import Final
|
||||||
|
|
||||||
|
# Name of the config field used as the Pydantic discriminator for RAG config types.
DISCRIMINATOR: Final[str] = "provider"
|
||||||
|
|
||||||
|
# Dotted path of the module that provides the default RAG configuration class.
DEFAULT_RAG_CONFIG_PATH: Final[str] = "crewai.rag.chromadb.config"
|
||||||
|
# Name of the default configuration class inside DEFAULT_RAG_CONFIG_PATH.
DEFAULT_RAG_CONFIG_CLASS: Final[str] = "ChromaDBConfig"
|
||||||
32
src/crewai/rag/config/factory.py
Normal file
32
src/crewai/rag/config/factory.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
"""Factory functions for creating RAG clients from configuration."""
|
||||||
|
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
|
from crewai.rag.config.optional_imports.protocols import ChromaFactoryModule
|
||||||
|
from crewai.rag.core.base_client import BaseClient
|
||||||
|
from crewai.rag.config.types import RagConfigType
|
||||||
|
from crewai.utilities.import_utils import require
|
||||||
|
|
||||||
|
|
||||||
|
def create_client(config: RagConfigType) -> BaseClient | None:
    """Create a client from configuration using the appropriate factory.

    Args:
        config: The RAG client configuration.

    Returns:
        The created client instance, or ``None`` when ``config.provider`` is
        not a provider this function knows how to build. No exception is
        raised for an unsupported provider.
    """
    if config.provider == "chromadb":
        # Load the provider factory through require() only when this provider
        # is actually requested.
        mod = cast(
            ChromaFactoryModule,
            require(
                "crewai.rag.chromadb.factory",
                purpose="The 'chromadb' provider",
            ),
        )
        return mod.create_client(config)

    # Unsupported provider: made explicit instead of falling off the end of
    # the function.
    return None
|
||||||
1
src/crewai/rag/config/optional_imports/__init__.py
Normal file
1
src/crewai/rag/config/optional_imports/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Optional imports for RAG configuration providers."""
|
||||||
24
src/crewai/rag/config/optional_imports/base.py
Normal file
24
src/crewai/rag/config/optional_imports/base.py
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
"""Base classes for missing provider configurations."""
|
||||||
|
|
||||||
|
from typing import Literal
|
||||||
|
from dataclasses import field
|
||||||
|
|
||||||
|
from pydantic import ConfigDict
|
||||||
|
from pydantic.dataclasses import dataclass as pyd_dataclass
|
||||||
|
|
||||||
|
|
||||||
|
@pyd_dataclass(config=ConfigDict(extra="forbid"))
class _MissingProvider:
    """Placeholder configuration for a provider that is not installed.

    Instantiation always fails: ``__post_init__`` raises RuntimeError naming
    the provider and the extra to install.
    """

    provider: Literal["chromadb", "__missing__"] = "__missing__"

    def __post_init__(self) -> None:
        """Raise RuntimeError reporting that the provider is not installed."""
        message = (
            f"provider '{self.provider}' requested but not installed. "
            f"Install the extra: `uv add crewai'[{self.provider}]'`."
        )
        raise RuntimeError(message)
|
||||||
14
src/crewai/rag/config/optional_imports/protocols.py
Normal file
14
src/crewai/rag/config/optional_imports/protocols.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
"""Protocol definitions for RAG factory modules."""
|
||||||
|
|
||||||
|
from typing import Protocol
|
||||||
|
|
||||||
|
from crewai.rag.config.types import RagConfigType
|
||||||
|
from crewai.rag.core.base_client import BaseClient
|
||||||
|
|
||||||
|
|
||||||
|
class ChromaFactoryModule(Protocol):
    """Structural type describing the ChromaDB factory module."""

    def create_client(self, config: RagConfigType) -> BaseClient:
        """Build a ChromaDB client from ``config``."""
        ...
|
||||||
15
src/crewai/rag/config/optional_imports/providers.py
Normal file
15
src/crewai/rag/config/optional_imports/providers.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
"""Provider-specific missing configuration classes."""
|
||||||
|
|
||||||
|
from typing import Literal
|
||||||
|
from dataclasses import field
|
||||||
|
from pydantic import ConfigDict
|
||||||
|
from pydantic.dataclasses import dataclass as pyd_dataclass
|
||||||
|
|
||||||
|
from crewai.rag.config.optional_imports.base import _MissingProvider
|
||||||
|
|
||||||
|
|
||||||
|
@pyd_dataclass(config=ConfigDict(extra="forbid"))
class MissingChromaDBConfig(_MissingProvider):
    """Stand-in for the ChromaDB configuration when it cannot be imported."""

    provider: Literal["chromadb"] = "chromadb"
|
||||||
8
src/crewai/rag/config/optional_imports/types.py
Normal file
8
src/crewai/rag/config/optional_imports/types.py
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
"""Type definitions for optional imports."""
|
||||||
|
|
||||||
|
from typing import Annotated, Literal
|
||||||
|
|
||||||
|
# Literal of accepted provider names; the string is Annotated metadata only.
SupportedProvider = Annotated[
    Literal["chromadb"],
    "Supported RAG provider types, add providers here as they become available",
]
|
||||||
21
src/crewai/rag/config/types.py
Normal file
21
src/crewai/rag/config/types.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
"""Type definitions for RAG configuration."""
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Annotated, TypeAlias
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
from crewai.rag.config.constants import DISCRIMINATOR
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
    # Static analysis always resolves the real configuration class.
    from crewai.rag.chromadb.config import ChromaDBConfig
else:
    try:
        from crewai.rag.chromadb.config import ChromaDBConfig
    except ImportError:
        # The real config could not be imported; bind the placeholder under
        # the same name so the aliases below can still be defined.
        from crewai.rag.config.optional_imports.providers import (
            MissingChromaDBConfig as ChromaDBConfig,
        )

# Provider config type(s) accepted by the RAG layer.
SupportedProviderConfig: TypeAlias = ChromaDBConfig

# Config type with the discriminator field attached for Pydantic.
RagConfigType: TypeAlias = Annotated[
    SupportedProviderConfig, Field(discriminator=DISCRIMINATOR)
]
|
||||||
86
src/crewai/rag/config/utils.py
Normal file
86
src/crewai/rag/config/utils.py
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
"""RAG client configuration utilities."""
|
||||||
|
|
||||||
|
from contextvars import ContextVar
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from crewai.utilities.import_utils import require
|
||||||
|
from crewai.rag.core.base_client import BaseClient
|
||||||
|
from crewai.rag.config.types import RagConfigType
|
||||||
|
from crewai.rag.config.constants import (
|
||||||
|
DEFAULT_RAG_CONFIG_PATH,
|
||||||
|
DEFAULT_RAG_CONFIG_CLASS,
|
||||||
|
)
|
||||||
|
from crewai.rag.config.factory import create_client
|
||||||
|
|
||||||
|
|
||||||
|
class RagContext(BaseModel):
    """Pairs a RAG provider configuration with its client instance."""

    # Required field: no default is supplied.
    config: RagConfigType = Field(description="RAG provider configuration")
    client: BaseClient | None = Field(
        default=None,
        description="Instantiated RAG client",
    )
|
||||||
|
|
||||||
|
|
||||||
|
# Context variable holding the active RagContext; None until a configuration has been set.
_rag_context: ContextVar[RagContext | None] = ContextVar("_rag_context", default=None)
|
||||||
|
|
||||||
|
|
||||||
|
def set_rag_config(config: RagConfigType) -> None:
    """Install ``config`` as the active RAG configuration.

    A client is created from ``config`` right away and stored together with
    it in the ``_rag_context`` context variable.

    Args:
        config: The RAG client configuration (ChromaDBConfig).
    """
    _rag_context.set(RagContext(config=config, client=create_client(config)))
|
||||||
|
|
||||||
|
|
||||||
|
def get_rag_config() -> RagConfigType:
    """Return the active RAG configuration, installing the default if unset.

    Returns:
        The current RAG configuration object.

    Raises:
        ValueError: If no configuration is available even after the default
            has been installed.
    """
    active = _rag_context.get()
    if active is None:
        # Nothing configured yet: instantiate the default config class and
        # install it, which also creates its client.
        default_module = require(DEFAULT_RAG_CONFIG_PATH, purpose="RAG configuration")
        default_cls = getattr(default_module, DEFAULT_RAG_CONFIG_CLASS)
        set_rag_config(default_cls())
        active = _rag_context.get()

    if active is not None and active.config is not None:
        return active.config

    raise ValueError(
        "RAG configuration is not set. Please set the RAG config first."
    )
|
||||||
|
|
||||||
|
|
||||||
|
def get_rag_client() -> BaseClient:
    """Return the active RAG client, building whatever is missing.

    Returns:
        The current RAG client, creating one if needed.

    Raises:
        ValueError: If no client is available after attempting to create one.
    """
    active = _rag_context.get()
    if active is None:
        # get_rag_config() installs the default configuration as a side effect.
        get_rag_config()
        active = _rag_context.get()

    # A context may exist without a client; create the client on demand.
    if active and active.client is None:
        active.client = create_client(active.config)

    if active is not None and active.client is not None:
        return active.client

    raise ValueError(
        "RAG client is not configured. Please set the RAG config first."
    )
|
||||||
|
|
||||||
|
|
||||||
|
def clear_rag_config() -> None:
    """Discard the active RAG configuration and client.

    The context variable is reset to ``None``, so the next call to
    ``get_rag_config`` installs the default configuration again.
    """
    _rag_context.set(None)
|
||||||
@@ -3,6 +3,8 @@
|
|||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from typing import Any, Protocol, runtime_checkable, TypedDict, Annotated
|
from typing import Any, Protocol, runtime_checkable, TypedDict, Annotated
|
||||||
from typing_extensions import Unpack, Required
|
from typing_extensions import Unpack, Required
|
||||||
|
from pydantic import GetCoreSchemaHandler
|
||||||
|
from pydantic_core import CoreSchema, core_schema
|
||||||
|
|
||||||
|
|
||||||
from crewai.rag.types import (
|
from crewai.rag.types import (
|
||||||
@@ -96,6 +98,17 @@ class BaseClient(Protocol):
|
|||||||
client: Any
|
client: Any
|
||||||
embedding_function: EmbeddingFunction
|
embedding_function: EmbeddingFunction
|
||||||
|
|
||||||
|
@classmethod
def __get_pydantic_core_schema__(
    cls, _source_type: Any, _handler: GetCoreSchemaHandler
) -> CoreSchema:
    """Return the Pydantic core schema for the BaseClient Protocol.

    Providing the schema here lets the Protocol appear as a field type in
    Pydantic models without ``arbitrary_types_allowed=True``.

    Args:
        _source_type: Source type passed by Pydantic; unused.
        _handler: Schema handler passed by Pydantic; unused.

    Returns:
        The schema produced by ``core_schema.any_schema()``.
    """
    permissive_schema = core_schema.any_schema()
    return permissive_schema
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def create_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
|
def create_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
|
||||||
"""Create a new collection/index in the vector database.
|
"""Create a new collection/index in the vector database.
|
||||||
|
|||||||
@@ -1,30 +0,0 @@
|
|||||||
"""Base provider protocol for vector database client creation."""
|
|
||||||
|
|
||||||
from abc import ABC
|
|
||||||
from typing import Any, Protocol, runtime_checkable, Union
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
from crewai.rag.types import EmbeddingFunction
|
|
||||||
from crewai.rag.embeddings.types import EmbeddingOptions
|
|
||||||
|
|
||||||
|
|
||||||
class BaseProviderOptions(BaseModel, ABC):
    """Option fields shared by every provider configuration."""

    # Required field: no default is supplied.
    client_type: str = Field(description="Type of client to create")
    embedding_config: Union[EmbeddingOptions, EmbeddingFunction, None] = Field(
        description="Embedding configuration - either options for built-in providers or a custom function",
        default=None,
    )
    options: Any = Field(
        description="Additional provider-specific options",
        default=None,
    )
|
|
||||||
|
|
||||||
|
|
||||||
@runtime_checkable
class BaseProvider(Protocol):
    """Callable protocol for vector database client providers."""

    def __call__(self, options: BaseProviderOptions) -> Any:
        """Build a client configured from ``options`` and return it."""
        ...
|
|
||||||
34
tests/rag/config/test_factory.py
Normal file
34
tests/rag/config/test_factory.py
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
"""Tests for RAG config factory."""
|
||||||
|
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
from crewai.rag.config.factory import create_client
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_client_chromadb():
    """Check that a chromadb config is routed to the ChromaDB factory module."""
    config = Mock()
    config.provider = "chromadb"

    expected_client = Mock()
    factory_module = Mock()
    factory_module.create_client.return_value = expected_client

    with patch("crewai.rag.config.factory.require") as mock_require:
        mock_require.return_value = factory_module
        result = create_client(config)

    # The mocks keep their call records after the patch is undone.
    assert result == expected_client
    mock_require.assert_called_once_with(
        "crewai.rag.chromadb.factory", purpose="The 'chromadb' provider"
    )
    factory_module.create_client.assert_called_once_with(config)
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_client_unsupported_provider():
    """Check that an unsupported provider currently yields None."""
    config = Mock()
    config.provider = "unsupported"

    assert create_client(config) is None
|
||||||
22
tests/rag/config/test_optional_imports.py
Normal file
22
tests/rag/config/test_optional_imports.py
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
"""Tests for optional imports."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from crewai.rag.config.optional_imports.base import _MissingProvider
|
||||||
|
from crewai.rag.config.optional_imports.providers import MissingChromaDBConfig
|
||||||
|
|
||||||
|
|
||||||
|
def test_missing_provider_raises_runtime_error():
    """Check that instantiating _MissingProvider raises RuntimeError."""
    expected_message = "provider '__missing__' requested but not installed"
    with pytest.raises(RuntimeError, match=expected_message):
        _MissingProvider()
|
||||||
|
|
||||||
|
|
||||||
|
def test_missing_chromadb_config_raises_runtime_error():
    """Check that instantiating MissingChromaDBConfig raises RuntimeError."""
    expected_message = "provider 'chromadb' requested but not installed"
    with pytest.raises(RuntimeError, match=expected_message):
        MissingChromaDBConfig()
|
||||||
Reference in New Issue
Block a user