feat: rag configuration with optional dependency support (#3394)
Some checks failed
Notify Downstream / notify-downstream (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled

### RAG Config System

* Added ChromaDB client creation via config with sensible defaults
* Introduced optional imports and shared RAG config utilities/schema
* Enabled embedding function support with ChromaDB provider integration
* Refactored configs for immutability and stronger type safety
* Removed unused code and expanded test coverage
This commit is contained in:
Greyson LaLonde
2025-08-26 00:00:22 -04:00
committed by GitHub
parent 2e4bd3f49d
commit 7ac482c7c9
21 changed files with 459 additions and 33 deletions

View File

@@ -1 +1,58 @@
"""RAG (Retrieval-Augmented Generation) infrastructure for CrewAI.""" """RAG (Retrieval-Augmented Generation) infrastructure for CrewAI."""
import sys
import importlib
from types import ModuleType
from typing import Any
from crewai.rag.config.types import RagConfigType
from crewai.rag.config.utils import set_rag_config
_module_path = __path__
_module_file = __file__
class _RagModule(ModuleType):
"""Module wrapper to intercept attribute setting for config."""
__path__ = _module_path
__file__ = _module_file
def __init__(self, module_name: str):
"""Initialize the module wrapper.
Args:
module_name: Name of the module.
"""
super().__init__(module_name)
def __setattr__(self, name: str, value: RagConfigType) -> None:
"""Set module attributes.
Args:
name: Attribute name.
value: Attribute value.
"""
if name == "config":
return set_rag_config(value)
raise AttributeError(f"Setting attribute '{name}' is not allowed.")
def __getattr__(self, name: str) -> Any:
"""Get module attributes.
Args:
name: Attribute name.
Returns:
The requested attribute.
Raises:
AttributeError: If attribute doesn't exist.
"""
try:
return importlib.import_module(f"{self.__name__}.{name}")
except ImportError:
raise AttributeError(f"module '{self.__name__}' has no attribute '{name}'")
sys.modules[__name__] = _RagModule(__name__)

View File

@@ -0,0 +1,57 @@
"""ChromaDB configuration model."""
import os
import warnings
from dataclasses import field
from typing import Literal, cast
from pydantic.dataclasses import dataclass as pyd_dataclass
from chromadb.config import Settings
from chromadb.utils.embedding_functions import DefaultEmbeddingFunction
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
from crewai.utilities.paths import db_storage_path
from crewai.rag.config.base import BaseRagConfig
from crewai.rag.chromadb.constants import DEFAULT_TENANT, DEFAULT_DATABASE
warnings.filterwarnings(
"ignore",
message=".*Mixing V1 models and V2 models.*",
category=UserWarning,
module="pydantic._internal._generate_schema",
)
def _default_settings() -> Settings:
"""Create default ChromaDB settings.
Returns:
Settings with persistent storage and reset enabled.
"""
return Settings(
persist_directory=os.path.join(db_storage_path(), "chromadb"),
allow_reset=True,
is_persistent=True,
)
def _default_embedding_function() -> ChromaEmbeddingFunctionWrapper:
"""Create default ChromaDB embedding function.
Returns:
Default embedding function cast to proper type.
"""
return cast(ChromaEmbeddingFunctionWrapper, DefaultEmbeddingFunction())
@pyd_dataclass(frozen=True)
class ChromaDBConfig(BaseRagConfig):
"""Configuration for ChromaDB client."""
provider: Literal["chromadb"] = field(default="chromadb", init=False)
tenant: str = DEFAULT_TENANT
database: str = DEFAULT_DATABASE
settings: Settings = field(default_factory=_default_settings)
embedding_function: ChromaEmbeddingFunctionWrapper | None = field(
default_factory=_default_embedding_function
)

View File

@@ -0,0 +1,6 @@
"""Constants for ChromaDB configuration."""
from typing import Final
DEFAULT_TENANT: Final[str] = "default_tenant"
DEFAULT_DATABASE: Final[str] = "default_database"

View File

@@ -0,0 +1,26 @@
"""Factory functions for creating ChromaDB clients."""
from chromadb import Client
from crewai.rag.chromadb.config import ChromaDBConfig
from crewai.rag.chromadb.client import ChromaDBClient
def create_client(config: ChromaDBConfig) -> ChromaDBClient:
"""Create a ChromaDBClient from configuration.
Args:
config: ChromaDB configuration object.
Returns:
Configured ChromaDBClient instance.
"""
chromadb_client = Client(
settings=config.settings, tenant=config.tenant, database=config.database
)
client = ChromaDBClient()
client.client = chromadb_client
client.embedding_function = config.embedding_function
return client

View File

@@ -3,6 +3,8 @@
from collections.abc import Mapping from collections.abc import Mapping
from typing import Any, NamedTuple from typing import Any, NamedTuple
from pydantic import GetCoreSchemaHandler
from pydantic_core import CoreSchema, core_schema
from chromadb.api import ClientAPI, AsyncClientAPI from chromadb.api import ClientAPI, AsyncClientAPI
from chromadb.api.configuration import CollectionConfigurationInterface from chromadb.api.configuration import CollectionConfigurationInterface
from chromadb.api.types import ( from chromadb.api.types import (
@@ -21,6 +23,21 @@ from crewai.rag.core.base_client import BaseCollectionParams, BaseCollectionSear
ChromaDBClientType = ClientAPI | AsyncClientAPI ChromaDBClientType = ClientAPI | AsyncClientAPI
class ChromaEmbeddingFunctionWrapper(ChromaEmbeddingFunction[Embeddable]):
"""Base class for ChromaDB EmbeddingFunction to work with Pydantic validation."""
@classmethod
def __get_pydantic_core_schema__(
cls, _source_type: Any, _handler: GetCoreSchemaHandler
) -> CoreSchema:
"""Generate Pydantic core schema for ChromaDB EmbeddingFunction.
This allows Pydantic to handle ChromaDB's EmbeddingFunction type
without requiring arbitrary_types_allowed=True.
"""
return core_schema.any_schema()
class PreparedDocuments(NamedTuple): class PreparedDocuments(NamedTuple):
"""Prepared documents ready for ChromaDB insertion. """Prepared documents ready for ChromaDB insertion.

View File

@@ -10,10 +10,8 @@ from chromadb.api.types import (
IncludeEnum, IncludeEnum,
QueryResult, QueryResult,
) )
from chromadb.api.models.AsyncCollection import AsyncCollection from chromadb.api.models.AsyncCollection import AsyncCollection
from chromadb.api.models.Collection import Collection from chromadb.api.models.Collection import Collection
from crewai.rag.chromadb.types import ( from crewai.rag.chromadb.types import (
ChromaDBClientType, ChromaDBClientType,
ChromaDBCollectionSearchParams, ChromaDBCollectionSearchParams,

View File

@@ -0,0 +1 @@
"""RAG client configuration management using ContextVars for thread-safe provider switching."""

View File

@@ -0,0 +1,16 @@
"""Base configuration class for RAG providers."""
from dataclasses import field
from typing import Any
from pydantic.dataclasses import dataclass as pyd_dataclass
from crewai.rag.config.optional_imports.types import SupportedProvider
@pyd_dataclass(frozen=True)
class BaseRagConfig:
"""Base class for RAG configuration with Pydantic serialization support."""
provider: SupportedProvider = field(init=False)
embedding_function: Any | None = field(default=None)

View File

@@ -0,0 +1,8 @@
"""Constants for RAG configuration."""
from typing import Final
DISCRIMINATOR: Final[str] = "provider"
DEFAULT_RAG_CONFIG_PATH: Final[str] = "crewai.rag.chromadb.config"
DEFAULT_RAG_CONFIG_CLASS: Final[str] = "ChromaDBConfig"

View File

@@ -0,0 +1,32 @@
"""Factory functions for creating RAG clients from configuration."""
from typing import cast
from crewai.rag.config.optional_imports.protocols import ChromaFactoryModule
from crewai.rag.core.base_client import BaseClient
from crewai.rag.config.types import RagConfigType
from crewai.utilities.import_utils import require
def create_client(config: RagConfigType) -> BaseClient:
"""Create a client from configuration using the appropriate factory.
Args:
config: The RAG client configuration.
Returns:
The created client instance.
Raises:
ValueError: If the configuration provider is not supported.
"""
if config.provider == "chromadb":
mod = cast(
ChromaFactoryModule,
require(
"crewai.rag.chromadb.factory",
purpose="The 'chromadb' provider",
),
)
return mod.create_client(config)

View File

@@ -0,0 +1 @@
"""Optional imports for RAG configuration providers."""

View File

@@ -0,0 +1,24 @@
"""Base classes for missing provider configurations."""
from typing import Literal
from dataclasses import field
from pydantic import ConfigDict
from pydantic.dataclasses import dataclass as pyd_dataclass
@pyd_dataclass(config=ConfigDict(extra="forbid"))
class _MissingProvider:
"""Base class for missing provider configurations.
Raises RuntimeError when instantiated to indicate missing dependencies.
"""
provider: Literal["chromadb", "__missing__"] = field(default="__missing__")
def __post_init__(self) -> None:
"""Raises error indicating the provider is not installed."""
raise RuntimeError(
f"provider '{self.provider}' requested but not installed. "
f"Install the extra: `uv add crewai'[{self.provider}]'`."
)

View File

@@ -0,0 +1,14 @@
"""Protocol definitions for RAG factory modules."""
from typing import Protocol
from crewai.rag.config.types import RagConfigType
from crewai.rag.core.base_client import BaseClient
class ChromaFactoryModule(Protocol):
"""Protocol for ChromaDB factory module."""
def create_client(self, config: RagConfigType) -> BaseClient:
"""Creates a ChromaDB client from configuration."""
...

View File

@@ -0,0 +1,15 @@
"""Provider-specific missing configuration classes."""
from typing import Literal
from dataclasses import field
from pydantic import ConfigDict
from pydantic.dataclasses import dataclass as pyd_dataclass
from crewai.rag.config.optional_imports.base import _MissingProvider
@pyd_dataclass(config=ConfigDict(extra="forbid"))
class MissingChromaDBConfig(_MissingProvider):
"""Placeholder for missing ChromaDB configuration."""
provider: Literal["chromadb"] = field(default="chromadb")

View File

@@ -0,0 +1,8 @@
"""Type definitions for optional imports."""
from typing import Annotated, Literal
SupportedProvider = Annotated[
Literal["chromadb"],
"Supported RAG provider types, add providers here as they become available",
]

View File

@@ -0,0 +1,21 @@
"""Type definitions for RAG configuration."""
from typing import TYPE_CHECKING, Annotated, TypeAlias
from pydantic import Field
from crewai.rag.config.constants import DISCRIMINATOR
if TYPE_CHECKING:
from crewai.rag.chromadb.config import ChromaDBConfig
else:
try:
from crewai.rag.chromadb.config import ChromaDBConfig
except ImportError:
from crewai.rag.config.optional_imports.providers import (
MissingChromaDBConfig as ChromaDBConfig,
)
SupportedProviderConfig: TypeAlias = ChromaDBConfig
RagConfigType: TypeAlias = Annotated[
SupportedProviderConfig, Field(discriminator=DISCRIMINATOR)
]

View File

@@ -0,0 +1,86 @@
"""RAG client configuration utilities."""
from contextvars import ContextVar
from pydantic import BaseModel, Field
from crewai.utilities.import_utils import require
from crewai.rag.core.base_client import BaseClient
from crewai.rag.config.types import RagConfigType
from crewai.rag.config.constants import (
DEFAULT_RAG_CONFIG_PATH,
DEFAULT_RAG_CONFIG_CLASS,
)
from crewai.rag.config.factory import create_client
class RagContext(BaseModel):
"""Context holding RAG configuration and client instance."""
config: RagConfigType = Field(..., description="RAG provider configuration")
client: BaseClient | None = Field(
default=None, description="Instantiated RAG client"
)
_rag_context: ContextVar[RagContext | None] = ContextVar("_rag_context", default=None)
def set_rag_config(config: RagConfigType) -> None:
"""Set global RAG client configuration and instantiate the client.
Args:
config: The RAG client configuration (ChromaDBConfig).
"""
client = create_client(config)
context = RagContext(config=config, client=client)
_rag_context.set(context)
def get_rag_config() -> RagConfigType:
"""Get current RAG configuration.
Returns:
The current RAG configuration object.
"""
context = _rag_context.get()
if context is None:
module = require(DEFAULT_RAG_CONFIG_PATH, purpose="RAG configuration")
config_class = getattr(module, DEFAULT_RAG_CONFIG_CLASS)
default_config = config_class()
set_rag_config(default_config)
context = _rag_context.get()
if context is None or context.config is None:
raise ValueError(
"RAG configuration is not set. Please set the RAG config first."
)
return context.config
def get_rag_client() -> BaseClient:
"""Get the current RAG client instance.
Returns:
The current RAG client, creating one if needed.
"""
context = _rag_context.get()
if context is None:
get_rag_config()
context = _rag_context.get()
if context and context.client is None:
context.client = create_client(context.config)
if context is None or context.client is None:
raise ValueError(
"RAG client is not configured. Please set the RAG config first."
)
return context.client
def clear_rag_config() -> None:
"""Clear the current RAG configuration and client, reverting to defaults."""
_rag_context.set(None)

View File

@@ -3,6 +3,8 @@
from abc import abstractmethod from abc import abstractmethod
from typing import Any, Protocol, runtime_checkable, TypedDict, Annotated from typing import Any, Protocol, runtime_checkable, TypedDict, Annotated
from typing_extensions import Unpack, Required from typing_extensions import Unpack, Required
from pydantic import GetCoreSchemaHandler
from pydantic_core import CoreSchema, core_schema
from crewai.rag.types import ( from crewai.rag.types import (
@@ -96,6 +98,17 @@ class BaseClient(Protocol):
client: Any client: Any
embedding_function: EmbeddingFunction embedding_function: EmbeddingFunction
@classmethod
def __get_pydantic_core_schema__(
cls, _source_type: Any, _handler: GetCoreSchemaHandler
) -> CoreSchema:
"""Generate Pydantic core schema for BaseClient Protocol.
This allows the Protocol to be used in Pydantic models without
requiring arbitrary_types_allowed=True.
"""
return core_schema.any_schema()
@abstractmethod @abstractmethod
def create_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None: def create_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
"""Create a new collection/index in the vector database. """Create a new collection/index in the vector database.

View File

@@ -1,30 +0,0 @@
"""Base provider protocol for vector database client creation."""
from abc import ABC
from typing import Any, Protocol, runtime_checkable, Union
from pydantic import BaseModel, Field
from crewai.rag.types import EmbeddingFunction
from crewai.rag.embeddings.types import EmbeddingOptions
class BaseProviderOptions(BaseModel, ABC):
"""Base configuration for all provider options."""
client_type: str = Field(..., description="Type of client to create")
embedding_config: Union[EmbeddingOptions, EmbeddingFunction, None] = Field(
default=None,
description="Embedding configuration - either options for built-in providers or a custom function",
)
options: Any = Field(
default=None, description="Additional provider-specific options"
)
@runtime_checkable
class BaseProvider(Protocol):
"""Protocol for vector database client providers."""
def __call__(self, options: BaseProviderOptions) -> Any:
"""Create and return a configured client instance."""
...

View File

@@ -0,0 +1,34 @@
"""Tests for RAG config factory."""
from unittest.mock import Mock, patch
from crewai.rag.config.factory import create_client
def test_create_client_chromadb():
"""Test ChromaDB client creation."""
mock_config = Mock()
mock_config.provider = "chromadb"
with patch("crewai.rag.config.factory.require") as mock_require:
mock_module = Mock()
mock_client = Mock()
mock_module.create_client.return_value = mock_client
mock_require.return_value = mock_module
result = create_client(mock_config)
assert result == mock_client
mock_require.assert_called_once_with(
"crewai.rag.chromadb.factory", purpose="The 'chromadb' provider"
)
mock_module.create_client.assert_called_once_with(mock_config)
def test_create_client_unsupported_provider():
"""Test unsupported provider returns None for now."""
mock_config = Mock()
mock_config.provider = "unsupported"
result = create_client(mock_config)
assert result is None

View File

@@ -0,0 +1,22 @@
"""Tests for optional imports."""
import pytest
from crewai.rag.config.optional_imports.base import _MissingProvider
from crewai.rag.config.optional_imports.providers import MissingChromaDBConfig
def test_missing_provider_raises_runtime_error():
"""Test that _MissingProvider raises RuntimeError on instantiation."""
with pytest.raises(
RuntimeError, match="provider '__missing__' requested but not installed"
):
_MissingProvider()
def test_missing_chromadb_config_raises_runtime_error():
"""Test that MissingChromaDBConfig raises RuntimeError on instantiation."""
with pytest.raises(
RuntimeError, match="provider 'chromadb' requested but not installed"
):
MissingChromaDBConfig()