mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-11 09:08:31 +00:00
feat: rag configuration with optional dependency support (#3394)
### RAG Config System * Added ChromaDB client creation via config with sensible defaults * Introduced optional imports and shared RAG config utilities/schema * Enabled embedding function support with ChromaDB provider integration * Refactored configs for immutability and stronger type safety * Removed unused code and expanded test coverage
This commit is contained in:
@@ -1 +1,58 @@
|
||||
"""RAG (Retrieval-Augmented Generation) infrastructure for CrewAI."""
|
||||
"""RAG (Retrieval-Augmented Generation) infrastructure for CrewAI."""
|
||||
|
||||
import sys
|
||||
import importlib
|
||||
from types import ModuleType
|
||||
from typing import Any
|
||||
|
||||
from crewai.rag.config.types import RagConfigType
|
||||
from crewai.rag.config.utils import set_rag_config
|
||||
|
||||
|
||||
_module_path = __path__
|
||||
_module_file = __file__
|
||||
|
||||
class _RagModule(ModuleType):
|
||||
"""Module wrapper to intercept attribute setting for config."""
|
||||
|
||||
__path__ = _module_path
|
||||
__file__ = _module_file
|
||||
|
||||
def __init__(self, module_name: str):
|
||||
"""Initialize the module wrapper.
|
||||
|
||||
Args:
|
||||
module_name: Name of the module.
|
||||
"""
|
||||
super().__init__(module_name)
|
||||
|
||||
def __setattr__(self, name: str, value: RagConfigType) -> None:
|
||||
"""Set module attributes.
|
||||
|
||||
Args:
|
||||
name: Attribute name.
|
||||
value: Attribute value.
|
||||
"""
|
||||
if name == "config":
|
||||
return set_rag_config(value)
|
||||
raise AttributeError(f"Setting attribute '{name}' is not allowed.")
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
"""Get module attributes.
|
||||
|
||||
Args:
|
||||
name: Attribute name.
|
||||
|
||||
Returns:
|
||||
The requested attribute.
|
||||
|
||||
Raises:
|
||||
AttributeError: If attribute doesn't exist.
|
||||
"""
|
||||
try:
|
||||
return importlib.import_module(f"{self.__name__}.{name}")
|
||||
except ImportError:
|
||||
raise AttributeError(f"module '{self.__name__}' has no attribute '{name}'")
|
||||
|
||||
|
||||
sys.modules[__name__] = _RagModule(__name__)
|
||||
|
||||
57
src/crewai/rag/chromadb/config.py
Normal file
57
src/crewai/rag/chromadb/config.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""ChromaDB configuration model."""
|
||||
|
||||
import os
|
||||
import warnings
|
||||
from dataclasses import field
|
||||
from typing import Literal, cast
|
||||
from pydantic.dataclasses import dataclass as pyd_dataclass
|
||||
from chromadb.config import Settings
|
||||
from chromadb.utils.embedding_functions import DefaultEmbeddingFunction
|
||||
|
||||
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
|
||||
from crewai.utilities.paths import db_storage_path
|
||||
from crewai.rag.config.base import BaseRagConfig
|
||||
from crewai.rag.chromadb.constants import DEFAULT_TENANT, DEFAULT_DATABASE
|
||||
|
||||
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
message=".*Mixing V1 models and V2 models.*",
|
||||
category=UserWarning,
|
||||
module="pydantic._internal._generate_schema",
|
||||
)
|
||||
|
||||
|
||||
def _default_settings() -> Settings:
|
||||
"""Create default ChromaDB settings.
|
||||
|
||||
Returns:
|
||||
Settings with persistent storage and reset enabled.
|
||||
"""
|
||||
return Settings(
|
||||
persist_directory=os.path.join(db_storage_path(), "chromadb"),
|
||||
allow_reset=True,
|
||||
is_persistent=True,
|
||||
)
|
||||
|
||||
|
||||
def _default_embedding_function() -> ChromaEmbeddingFunctionWrapper:
|
||||
"""Create default ChromaDB embedding function.
|
||||
|
||||
Returns:
|
||||
Default embedding function cast to proper type.
|
||||
"""
|
||||
return cast(ChromaEmbeddingFunctionWrapper, DefaultEmbeddingFunction())
|
||||
|
||||
|
||||
@pyd_dataclass(frozen=True)
|
||||
class ChromaDBConfig(BaseRagConfig):
|
||||
"""Configuration for ChromaDB client."""
|
||||
|
||||
provider: Literal["chromadb"] = field(default="chromadb", init=False)
|
||||
tenant: str = DEFAULT_TENANT
|
||||
database: str = DEFAULT_DATABASE
|
||||
settings: Settings = field(default_factory=_default_settings)
|
||||
embedding_function: ChromaEmbeddingFunctionWrapper | None = field(
|
||||
default_factory=_default_embedding_function
|
||||
)
|
||||
6
src/crewai/rag/chromadb/constants.py
Normal file
6
src/crewai/rag/chromadb/constants.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""Constants for ChromaDB configuration."""
|
||||
|
||||
from typing import Final
|
||||
|
||||
DEFAULT_TENANT: Final[str] = "default_tenant"
|
||||
DEFAULT_DATABASE: Final[str] = "default_database"
|
||||
26
src/crewai/rag/chromadb/factory.py
Normal file
26
src/crewai/rag/chromadb/factory.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""Factory functions for creating ChromaDB clients."""
|
||||
|
||||
from chromadb import Client
|
||||
|
||||
from crewai.rag.chromadb.config import ChromaDBConfig
|
||||
from crewai.rag.chromadb.client import ChromaDBClient
|
||||
|
||||
|
||||
def create_client(config: ChromaDBConfig) -> ChromaDBClient:
|
||||
"""Create a ChromaDBClient from configuration.
|
||||
|
||||
Args:
|
||||
config: ChromaDB configuration object.
|
||||
|
||||
Returns:
|
||||
Configured ChromaDBClient instance.
|
||||
"""
|
||||
chromadb_client = Client(
|
||||
settings=config.settings, tenant=config.tenant, database=config.database
|
||||
)
|
||||
|
||||
client = ChromaDBClient()
|
||||
client.client = chromadb_client
|
||||
client.embedding_function = config.embedding_function
|
||||
|
||||
return client
|
||||
@@ -3,6 +3,8 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
from pydantic import GetCoreSchemaHandler
|
||||
from pydantic_core import CoreSchema, core_schema
|
||||
from chromadb.api import ClientAPI, AsyncClientAPI
|
||||
from chromadb.api.configuration import CollectionConfigurationInterface
|
||||
from chromadb.api.types import (
|
||||
@@ -21,6 +23,21 @@ from crewai.rag.core.base_client import BaseCollectionParams, BaseCollectionSear
|
||||
ChromaDBClientType = ClientAPI | AsyncClientAPI
|
||||
|
||||
|
||||
class ChromaEmbeddingFunctionWrapper(ChromaEmbeddingFunction[Embeddable]):
|
||||
"""Base class for ChromaDB EmbeddingFunction to work with Pydantic validation."""
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(
|
||||
cls, _source_type: Any, _handler: GetCoreSchemaHandler
|
||||
) -> CoreSchema:
|
||||
"""Generate Pydantic core schema for ChromaDB EmbeddingFunction.
|
||||
|
||||
This allows Pydantic to handle ChromaDB's EmbeddingFunction type
|
||||
without requiring arbitrary_types_allowed=True.
|
||||
"""
|
||||
return core_schema.any_schema()
|
||||
|
||||
|
||||
class PreparedDocuments(NamedTuple):
|
||||
"""Prepared documents ready for ChromaDB insertion.
|
||||
|
||||
|
||||
@@ -10,10 +10,8 @@ from chromadb.api.types import (
|
||||
IncludeEnum,
|
||||
QueryResult,
|
||||
)
|
||||
|
||||
from chromadb.api.models.AsyncCollection import AsyncCollection
|
||||
from chromadb.api.models.Collection import Collection
|
||||
|
||||
from crewai.rag.chromadb.types import (
|
||||
ChromaDBClientType,
|
||||
ChromaDBCollectionSearchParams,
|
||||
|
||||
1
src/crewai/rag/config/__init__.py
Normal file
1
src/crewai/rag/config/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""RAG client configuration management using ContextVars for thread-safe provider switching."""
|
||||
16
src/crewai/rag/config/base.py
Normal file
16
src/crewai/rag/config/base.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""Base configuration class for RAG providers."""
|
||||
|
||||
from dataclasses import field
|
||||
from typing import Any
|
||||
|
||||
from pydantic.dataclasses import dataclass as pyd_dataclass
|
||||
|
||||
from crewai.rag.config.optional_imports.types import SupportedProvider
|
||||
|
||||
|
||||
@pyd_dataclass(frozen=True)
|
||||
class BaseRagConfig:
|
||||
"""Base class for RAG configuration with Pydantic serialization support."""
|
||||
|
||||
provider: SupportedProvider = field(init=False)
|
||||
embedding_function: Any | None = field(default=None)
|
||||
8
src/crewai/rag/config/constants.py
Normal file
8
src/crewai/rag/config/constants.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""Constants for RAG configuration."""
|
||||
|
||||
from typing import Final
|
||||
|
||||
DISCRIMINATOR: Final[str] = "provider"
|
||||
|
||||
DEFAULT_RAG_CONFIG_PATH: Final[str] = "crewai.rag.chromadb.config"
|
||||
DEFAULT_RAG_CONFIG_CLASS: Final[str] = "ChromaDBConfig"
|
||||
32
src/crewai/rag/config/factory.py
Normal file
32
src/crewai/rag/config/factory.py
Normal file
@@ -0,0 +1,32 @@
|
||||
"""Factory functions for creating RAG clients from configuration."""
|
||||
|
||||
from typing import cast
|
||||
|
||||
from crewai.rag.config.optional_imports.protocols import ChromaFactoryModule
|
||||
from crewai.rag.core.base_client import BaseClient
|
||||
from crewai.rag.config.types import RagConfigType
|
||||
from crewai.utilities.import_utils import require
|
||||
|
||||
|
||||
def create_client(config: RagConfigType) -> BaseClient:
|
||||
"""Create a client from configuration using the appropriate factory.
|
||||
|
||||
Args:
|
||||
config: The RAG client configuration.
|
||||
|
||||
Returns:
|
||||
The created client instance.
|
||||
|
||||
Raises:
|
||||
ValueError: If the configuration provider is not supported.
|
||||
"""
|
||||
|
||||
if config.provider == "chromadb":
|
||||
mod = cast(
|
||||
ChromaFactoryModule,
|
||||
require(
|
||||
"crewai.rag.chromadb.factory",
|
||||
purpose="The 'chromadb' provider",
|
||||
),
|
||||
)
|
||||
return mod.create_client(config)
|
||||
1
src/crewai/rag/config/optional_imports/__init__.py
Normal file
1
src/crewai/rag/config/optional_imports/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Optional imports for RAG configuration providers."""
|
||||
24
src/crewai/rag/config/optional_imports/base.py
Normal file
24
src/crewai/rag/config/optional_imports/base.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""Base classes for missing provider configurations."""
|
||||
|
||||
from typing import Literal
|
||||
from dataclasses import field
|
||||
|
||||
from pydantic import ConfigDict
|
||||
from pydantic.dataclasses import dataclass as pyd_dataclass
|
||||
|
||||
|
||||
@pyd_dataclass(config=ConfigDict(extra="forbid"))
|
||||
class _MissingProvider:
|
||||
"""Base class for missing provider configurations.
|
||||
|
||||
Raises RuntimeError when instantiated to indicate missing dependencies.
|
||||
"""
|
||||
|
||||
provider: Literal["chromadb", "__missing__"] = field(default="__missing__")
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Raises error indicating the provider is not installed."""
|
||||
raise RuntimeError(
|
||||
f"provider '{self.provider}' requested but not installed. "
|
||||
f"Install the extra: `uv add crewai'[{self.provider}]'`."
|
||||
)
|
||||
14
src/crewai/rag/config/optional_imports/protocols.py
Normal file
14
src/crewai/rag/config/optional_imports/protocols.py
Normal file
@@ -0,0 +1,14 @@
|
||||
"""Protocol definitions for RAG factory modules."""
|
||||
|
||||
from typing import Protocol
|
||||
|
||||
from crewai.rag.config.types import RagConfigType
|
||||
from crewai.rag.core.base_client import BaseClient
|
||||
|
||||
|
||||
class ChromaFactoryModule(Protocol):
|
||||
"""Protocol for ChromaDB factory module."""
|
||||
|
||||
def create_client(self, config: RagConfigType) -> BaseClient:
|
||||
"""Creates a ChromaDB client from configuration."""
|
||||
...
|
||||
15
src/crewai/rag/config/optional_imports/providers.py
Normal file
15
src/crewai/rag/config/optional_imports/providers.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""Provider-specific missing configuration classes."""
|
||||
|
||||
from typing import Literal
|
||||
from dataclasses import field
|
||||
from pydantic import ConfigDict
|
||||
from pydantic.dataclasses import dataclass as pyd_dataclass
|
||||
|
||||
from crewai.rag.config.optional_imports.base import _MissingProvider
|
||||
|
||||
|
||||
@pyd_dataclass(config=ConfigDict(extra="forbid"))
|
||||
class MissingChromaDBConfig(_MissingProvider):
|
||||
"""Placeholder for missing ChromaDB configuration."""
|
||||
|
||||
provider: Literal["chromadb"] = field(default="chromadb")
|
||||
8
src/crewai/rag/config/optional_imports/types.py
Normal file
8
src/crewai/rag/config/optional_imports/types.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""Type definitions for optional imports."""
|
||||
|
||||
from typing import Annotated, Literal
|
||||
|
||||
SupportedProvider = Annotated[
|
||||
Literal["chromadb"],
|
||||
"Supported RAG provider types, add providers here as they become available",
|
||||
]
|
||||
21
src/crewai/rag/config/types.py
Normal file
21
src/crewai/rag/config/types.py
Normal file
@@ -0,0 +1,21 @@
|
||||
"""Type definitions for RAG configuration."""
|
||||
|
||||
from typing import TYPE_CHECKING, Annotated, TypeAlias
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.config.constants import DISCRIMINATOR
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.rag.chromadb.config import ChromaDBConfig
|
||||
else:
|
||||
try:
|
||||
from crewai.rag.chromadb.config import ChromaDBConfig
|
||||
except ImportError:
|
||||
from crewai.rag.config.optional_imports.providers import (
|
||||
MissingChromaDBConfig as ChromaDBConfig,
|
||||
)
|
||||
|
||||
SupportedProviderConfig: TypeAlias = ChromaDBConfig
|
||||
RagConfigType: TypeAlias = Annotated[
|
||||
SupportedProviderConfig, Field(discriminator=DISCRIMINATOR)
|
||||
]
|
||||
86
src/crewai/rag/config/utils.py
Normal file
86
src/crewai/rag/config/utils.py
Normal file
@@ -0,0 +1,86 @@
|
||||
"""RAG client configuration utilities."""
|
||||
|
||||
from contextvars import ContextVar
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.utilities.import_utils import require
|
||||
from crewai.rag.core.base_client import BaseClient
|
||||
from crewai.rag.config.types import RagConfigType
|
||||
from crewai.rag.config.constants import (
|
||||
DEFAULT_RAG_CONFIG_PATH,
|
||||
DEFAULT_RAG_CONFIG_CLASS,
|
||||
)
|
||||
from crewai.rag.config.factory import create_client
|
||||
|
||||
|
||||
class RagContext(BaseModel):
|
||||
"""Context holding RAG configuration and client instance."""
|
||||
|
||||
config: RagConfigType = Field(..., description="RAG provider configuration")
|
||||
client: BaseClient | None = Field(
|
||||
default=None, description="Instantiated RAG client"
|
||||
)
|
||||
|
||||
|
||||
_rag_context: ContextVar[RagContext | None] = ContextVar("_rag_context", default=None)
|
||||
|
||||
|
||||
def set_rag_config(config: RagConfigType) -> None:
|
||||
"""Set global RAG client configuration and instantiate the client.
|
||||
|
||||
Args:
|
||||
config: The RAG client configuration (ChromaDBConfig).
|
||||
"""
|
||||
client = create_client(config)
|
||||
context = RagContext(config=config, client=client)
|
||||
_rag_context.set(context)
|
||||
|
||||
|
||||
def get_rag_config() -> RagConfigType:
|
||||
"""Get current RAG configuration.
|
||||
|
||||
Returns:
|
||||
The current RAG configuration object.
|
||||
"""
|
||||
context = _rag_context.get()
|
||||
if context is None:
|
||||
module = require(DEFAULT_RAG_CONFIG_PATH, purpose="RAG configuration")
|
||||
config_class = getattr(module, DEFAULT_RAG_CONFIG_CLASS)
|
||||
default_config = config_class()
|
||||
set_rag_config(default_config)
|
||||
context = _rag_context.get()
|
||||
|
||||
if context is None or context.config is None:
|
||||
raise ValueError(
|
||||
"RAG configuration is not set. Please set the RAG config first."
|
||||
)
|
||||
|
||||
return context.config
|
||||
|
||||
|
||||
def get_rag_client() -> BaseClient:
|
||||
"""Get the current RAG client instance.
|
||||
|
||||
Returns:
|
||||
The current RAG client, creating one if needed.
|
||||
"""
|
||||
context = _rag_context.get()
|
||||
if context is None:
|
||||
get_rag_config()
|
||||
context = _rag_context.get()
|
||||
|
||||
if context and context.client is None:
|
||||
context.client = create_client(context.config)
|
||||
|
||||
if context is None or context.client is None:
|
||||
raise ValueError(
|
||||
"RAG client is not configured. Please set the RAG config first."
|
||||
)
|
||||
|
||||
return context.client
|
||||
|
||||
|
||||
def clear_rag_config() -> None:
|
||||
"""Clear the current RAG configuration and client, reverting to defaults."""
|
||||
_rag_context.set(None)
|
||||
@@ -3,6 +3,8 @@
|
||||
from abc import abstractmethod
|
||||
from typing import Any, Protocol, runtime_checkable, TypedDict, Annotated
|
||||
from typing_extensions import Unpack, Required
|
||||
from pydantic import GetCoreSchemaHandler
|
||||
from pydantic_core import CoreSchema, core_schema
|
||||
|
||||
|
||||
from crewai.rag.types import (
|
||||
@@ -96,6 +98,17 @@ class BaseClient(Protocol):
|
||||
client: Any
|
||||
embedding_function: EmbeddingFunction
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(
|
||||
cls, _source_type: Any, _handler: GetCoreSchemaHandler
|
||||
) -> CoreSchema:
|
||||
"""Generate Pydantic core schema for BaseClient Protocol.
|
||||
|
||||
This allows the Protocol to be used in Pydantic models without
|
||||
requiring arbitrary_types_allowed=True.
|
||||
"""
|
||||
return core_schema.any_schema()
|
||||
|
||||
@abstractmethod
|
||||
def create_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
|
||||
"""Create a new collection/index in the vector database.
|
||||
|
||||
@@ -1,30 +0,0 @@
|
||||
"""Base provider protocol for vector database client creation."""
|
||||
|
||||
from abc import ABC
|
||||
from typing import Any, Protocol, runtime_checkable, Union
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.rag.types import EmbeddingFunction
|
||||
from crewai.rag.embeddings.types import EmbeddingOptions
|
||||
|
||||
|
||||
class BaseProviderOptions(BaseModel, ABC):
|
||||
"""Base configuration for all provider options."""
|
||||
|
||||
client_type: str = Field(..., description="Type of client to create")
|
||||
embedding_config: Union[EmbeddingOptions, EmbeddingFunction, None] = Field(
|
||||
default=None,
|
||||
description="Embedding configuration - either options for built-in providers or a custom function",
|
||||
)
|
||||
options: Any = Field(
|
||||
default=None, description="Additional provider-specific options"
|
||||
)
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class BaseProvider(Protocol):
|
||||
"""Protocol for vector database client providers."""
|
||||
|
||||
def __call__(self, options: BaseProviderOptions) -> Any:
|
||||
"""Create and return a configured client instance."""
|
||||
...
|
||||
Reference in New Issue
Block a user