From 7ac482c7c95b8746d69a9ab03aace9964f2bdfe9 Mon Sep 17 00:00:00 2001 From: Greyson LaLonde Date: Tue, 26 Aug 2025 00:00:22 -0400 Subject: [PATCH] feat: rag configuration with optional dependency support (#3394) ### RAG Config System * Added ChromaDB client creation via config with sensible defaults * Introduced optional imports and shared RAG config utilities/schema * Enabled embedding function support with ChromaDB provider integration * Refactored configs for immutability and stronger type safety * Removed unused code and expanded test coverage --- src/crewai/rag/__init__.py | 59 ++++++++++++- src/crewai/rag/chromadb/config.py | 57 ++++++++++++ src/crewai/rag/chromadb/constants.py | 6 ++ src/crewai/rag/chromadb/factory.py | 26 ++++++ src/crewai/rag/chromadb/types.py | 17 ++++ src/crewai/rag/chromadb/utils.py | 2 - src/crewai/rag/config/__init__.py | 1 + src/crewai/rag/config/base.py | 16 ++++ src/crewai/rag/config/constants.py | 8 ++ src/crewai/rag/config/factory.py | 32 +++++++ .../rag/config/optional_imports/__init__.py | 1 + .../rag/config/optional_imports/base.py | 24 ++++++ .../rag/config/optional_imports/protocols.py | 14 +++ .../rag/config/optional_imports/providers.py | 15 ++++ .../rag/config/optional_imports/types.py | 8 ++ src/crewai/rag/config/types.py | 21 +++++ src/crewai/rag/config/utils.py | 86 +++++++++++++++++++ src/crewai/rag/core/base_client.py | 13 +++ src/crewai/rag/core/base_provider.py | 30 ------- tests/rag/config/test_factory.py | 34 ++++++++ tests/rag/config/test_optional_imports.py | 22 +++++ 21 files changed, 459 insertions(+), 33 deletions(-) create mode 100644 src/crewai/rag/chromadb/config.py create mode 100644 src/crewai/rag/chromadb/constants.py create mode 100644 src/crewai/rag/chromadb/factory.py create mode 100644 src/crewai/rag/config/__init__.py create mode 100644 src/crewai/rag/config/base.py create mode 100644 src/crewai/rag/config/constants.py create mode 100644 src/crewai/rag/config/factory.py create mode 100644 src/crewai/rag/config/optional_imports/__init__.py create mode 100644 src/crewai/rag/config/optional_imports/base.py create mode 100644 src/crewai/rag/config/optional_imports/protocols.py create mode 100644 src/crewai/rag/config/optional_imports/providers.py create mode 100644 src/crewai/rag/config/optional_imports/types.py create mode 100644 src/crewai/rag/config/types.py create mode 100644 src/crewai/rag/config/utils.py delete mode 100644 src/crewai/rag/core/base_provider.py create mode 100644 tests/rag/config/test_factory.py create mode 100644 tests/rag/config/test_optional_imports.py diff --git a/src/crewai/rag/__init__.py b/src/crewai/rag/__init__.py index 3aaee2cef..3b39accd5 100644 --- a/src/crewai/rag/__init__.py +++ b/src/crewai/rag/__init__.py @@ -1 +1,58 @@ -"""RAG (Retrieval-Augmented Generation) infrastructure for CrewAI.""" \ No newline at end of file +"""RAG (Retrieval-Augmented Generation) infrastructure for CrewAI.""" + +import sys +import importlib +from types import ModuleType +from typing import Any + +from crewai.rag.config.types import RagConfigType +from crewai.rag.config.utils import set_rag_config + + +_module_path = __path__ +_module_file = __file__ + +class _RagModule(ModuleType): + """Module wrapper to intercept attribute setting for config.""" + + __path__ = _module_path + __file__ = _module_file + + def __init__(self, module_name: str): + """Initialize the module wrapper. + + Args: + module_name: Name of the module. + """ + super().__init__(module_name) + + def __setattr__(self, name: str, value: RagConfigType) -> None: + """Set module attributes. + + Args: + name: Attribute name. + value: Attribute value. + """ + if name == "config": + return set_rag_config(value) + raise AttributeError(f"Setting attribute '{name}' is not allowed.") + + def __getattr__(self, name: str) -> Any: + """Get module attributes. + + Args: + name: Attribute name. + + Returns: + The requested attribute. + + Raises: + AttributeError: If attribute doesn't exist. + """ + try: + return importlib.import_module(f"{self.__name__}.{name}") + except ImportError: + raise AttributeError(f"module '{self.__name__}' has no attribute '{name}'") + + +sys.modules[__name__] = _RagModule(__name__) diff --git a/src/crewai/rag/chromadb/config.py b/src/crewai/rag/chromadb/config.py new file mode 100644 index 000000000..43b202c55 --- /dev/null +++ b/src/crewai/rag/chromadb/config.py @@ -0,0 +1,57 @@ +"""ChromaDB configuration model.""" + +import os +import warnings +from dataclasses import field +from typing import Literal, cast +from pydantic.dataclasses import dataclass as pyd_dataclass +from chromadb.config import Settings +from chromadb.utils.embedding_functions import DefaultEmbeddingFunction + +from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper +from crewai.utilities.paths import db_storage_path +from crewai.rag.config.base import BaseRagConfig +from crewai.rag.chromadb.constants import DEFAULT_TENANT, DEFAULT_DATABASE + + +warnings.filterwarnings( + "ignore", + message=".*Mixing V1 models and V2 models.*", + category=UserWarning, + module="pydantic._internal._generate_schema", +) + + +def _default_settings() -> Settings: + """Create default ChromaDB settings. + + Returns: + Settings with persistent storage and reset enabled. + """ + return Settings( + persist_directory=os.path.join(db_storage_path(), "chromadb"), + allow_reset=True, + is_persistent=True, + ) + + +def _default_embedding_function() -> ChromaEmbeddingFunctionWrapper: + """Create default ChromaDB embedding function. + + Returns: + Default embedding function cast to proper type. + """ + return cast(ChromaEmbeddingFunctionWrapper, DefaultEmbeddingFunction()) + + +@pyd_dataclass(frozen=True) +class ChromaDBConfig(BaseRagConfig): + """Configuration for ChromaDB client.""" + + provider: Literal["chromadb"] = field(default="chromadb", init=False) + tenant: str = DEFAULT_TENANT + database: str = DEFAULT_DATABASE + settings: Settings = field(default_factory=_default_settings) + embedding_function: ChromaEmbeddingFunctionWrapper | None = field( + default_factory=_default_embedding_function + ) diff --git a/src/crewai/rag/chromadb/constants.py b/src/crewai/rag/chromadb/constants.py new file mode 100644 index 000000000..d9c585b6f --- /dev/null +++ b/src/crewai/rag/chromadb/constants.py @@ -0,0 +1,6 @@ +"""Constants for ChromaDB configuration.""" + +from typing import Final + +DEFAULT_TENANT: Final[str] = "default_tenant" +DEFAULT_DATABASE: Final[str] = "default_database" diff --git a/src/crewai/rag/chromadb/factory.py b/src/crewai/rag/chromadb/factory.py new file mode 100644 index 000000000..fff9f2dc1 --- /dev/null +++ b/src/crewai/rag/chromadb/factory.py @@ -0,0 +1,26 @@ +"""Factory functions for creating ChromaDB clients.""" + +from chromadb import Client + +from crewai.rag.chromadb.config import ChromaDBConfig +from crewai.rag.chromadb.client import ChromaDBClient + + +def create_client(config: ChromaDBConfig) -> ChromaDBClient: + """Create a ChromaDBClient from configuration. + + Args: + config: ChromaDB configuration object. + + Returns: + Configured ChromaDBClient instance. + """ + chromadb_client = Client( + settings=config.settings, tenant=config.tenant, database=config.database + ) + + client = ChromaDBClient() + client.client = chromadb_client + client.embedding_function = config.embedding_function + + return client diff --git a/src/crewai/rag/chromadb/types.py b/src/crewai/rag/chromadb/types.py index 54a03df39..11c480ea3 100644 --- a/src/crewai/rag/chromadb/types.py +++ b/src/crewai/rag/chromadb/types.py @@ -3,6 +3,8 @@ from collections.abc import Mapping from typing import Any, NamedTuple +from pydantic import GetCoreSchemaHandler +from pydantic_core import CoreSchema, core_schema from chromadb.api import ClientAPI, AsyncClientAPI from chromadb.api.configuration import CollectionConfigurationInterface from chromadb.api.types import ( @@ -21,6 +23,21 @@ from crewai.rag.core.base_client import BaseCollectionParams, BaseCollectionSear ChromaDBClientType = ClientAPI | AsyncClientAPI +class ChromaEmbeddingFunctionWrapper(ChromaEmbeddingFunction[Embeddable]): + """Base class for ChromaDB EmbeddingFunction to work with Pydantic validation.""" + + @classmethod + def __get_pydantic_core_schema__( + cls, _source_type: Any, _handler: GetCoreSchemaHandler + ) -> CoreSchema: + """Generate Pydantic core schema for ChromaDB EmbeddingFunction. + + This allows Pydantic to handle ChromaDB's EmbeddingFunction type + without requiring arbitrary_types_allowed=True. + """ + return core_schema.any_schema() + + class PreparedDocuments(NamedTuple): """Prepared documents ready for ChromaDB insertion. diff --git a/src/crewai/rag/chromadb/utils.py b/src/crewai/rag/chromadb/utils.py index f7e5c4ebd..1226be80c 100644 --- a/src/crewai/rag/chromadb/utils.py +++ b/src/crewai/rag/chromadb/utils.py @@ -10,10 +10,8 @@ from chromadb.api.types import ( IncludeEnum, QueryResult, ) - from chromadb.api.models.AsyncCollection import AsyncCollection from chromadb.api.models.Collection import Collection - from crewai.rag.chromadb.types import ( ChromaDBClientType, ChromaDBCollectionSearchParams, diff --git a/src/crewai/rag/config/__init__.py b/src/crewai/rag/config/__init__.py new file mode 100644 index 000000000..7a43cb46b --- /dev/null +++ b/src/crewai/rag/config/__init__.py @@ -0,0 +1 @@ +"""RAG client configuration management using ContextVars for thread-safe provider switching.""" diff --git a/src/crewai/rag/config/base.py b/src/crewai/rag/config/base.py new file mode 100644 index 000000000..b287b6ea6 --- /dev/null +++ b/src/crewai/rag/config/base.py @@ -0,0 +1,16 @@ +"""Base configuration class for RAG providers.""" + +from dataclasses import field +from typing import Any + +from pydantic.dataclasses import dataclass as pyd_dataclass + +from crewai.rag.config.optional_imports.types import SupportedProvider + + +@pyd_dataclass(frozen=True) +class BaseRagConfig: + """Base class for RAG configuration with Pydantic serialization support.""" + + provider: SupportedProvider = field(init=False) + embedding_function: Any | None = field(default=None) diff --git a/src/crewai/rag/config/constants.py b/src/crewai/rag/config/constants.py new file mode 100644 index 000000000..d0d360db1 --- /dev/null +++ b/src/crewai/rag/config/constants.py @@ -0,0 +1,8 @@ +"""Constants for RAG configuration.""" + +from typing import Final + +DISCRIMINATOR: Final[str] = "provider" + +DEFAULT_RAG_CONFIG_PATH: Final[str] = "crewai.rag.chromadb.config" +DEFAULT_RAG_CONFIG_CLASS: Final[str] = "ChromaDBConfig" diff --git a/src/crewai/rag/config/factory.py b/src/crewai/rag/config/factory.py new file mode 100644 index 000000000..1f34d6317 --- /dev/null +++ b/src/crewai/rag/config/factory.py @@ -0,0 +1,32 @@ +"""Factory functions for creating RAG clients from configuration.""" + +from typing import cast + +from crewai.rag.config.optional_imports.protocols import ChromaFactoryModule +from crewai.rag.core.base_client import BaseClient +from crewai.rag.config.types import RagConfigType +from crewai.utilities.import_utils import require + + +def create_client(config: RagConfigType) -> BaseClient: + """Create a client from configuration using the appropriate factory. + + Args: + config: The RAG client configuration. + + Returns: + The created client instance. + + Raises: + ValueError: If the configuration provider is not supported. + """ + + if config.provider == "chromadb": + mod = cast( + ChromaFactoryModule, + require( + "crewai.rag.chromadb.factory", + purpose="The 'chromadb' provider", + ), + ) + return mod.create_client(config) diff --git a/src/crewai/rag/config/optional_imports/__init__.py b/src/crewai/rag/config/optional_imports/__init__.py new file mode 100644 index 000000000..ad6a61f92 --- /dev/null +++ b/src/crewai/rag/config/optional_imports/__init__.py @@ -0,0 +1 @@ +"""Optional imports for RAG configuration providers.""" \ No newline at end of file diff --git a/src/crewai/rag/config/optional_imports/base.py b/src/crewai/rag/config/optional_imports/base.py new file mode 100644 index 000000000..abb35e0bc --- /dev/null +++ b/src/crewai/rag/config/optional_imports/base.py @@ -0,0 +1,24 @@ +"""Base classes for missing provider configurations.""" + +from typing import Literal +from dataclasses import field + +from pydantic import ConfigDict +from pydantic.dataclasses import dataclass as pyd_dataclass + + +@pyd_dataclass(config=ConfigDict(extra="forbid")) +class _MissingProvider: + """Base class for missing provider configurations. + + Raises RuntimeError when instantiated to indicate missing dependencies. + """ + + provider: Literal["chromadb", "__missing__"] = field(default="__missing__") + + def __post_init__(self) -> None: + """Raises error indicating the provider is not installed.""" + raise RuntimeError( + f"provider '{self.provider}' requested but not installed. " + f"Install the extra: `uv add crewai'[{self.provider}]'`." + ) diff --git a/src/crewai/rag/config/optional_imports/protocols.py b/src/crewai/rag/config/optional_imports/protocols.py new file mode 100644 index 000000000..2e16b50c8 --- /dev/null +++ b/src/crewai/rag/config/optional_imports/protocols.py @@ -0,0 +1,14 @@ +"""Protocol definitions for RAG factory modules.""" + +from typing import Protocol + +from crewai.rag.config.types import RagConfigType +from crewai.rag.core.base_client import BaseClient + + +class ChromaFactoryModule(Protocol): + """Protocol for ChromaDB factory module.""" + + def create_client(self, config: RagConfigType) -> BaseClient: + """Creates a ChromaDB client from configuration.""" + ... diff --git a/src/crewai/rag/config/optional_imports/providers.py b/src/crewai/rag/config/optional_imports/providers.py new file mode 100644 index 000000000..0d774e26b --- /dev/null +++ b/src/crewai/rag/config/optional_imports/providers.py @@ -0,0 +1,15 @@ +"""Provider-specific missing configuration classes.""" + +from typing import Literal +from dataclasses import field +from pydantic import ConfigDict +from pydantic.dataclasses import dataclass as pyd_dataclass + +from crewai.rag.config.optional_imports.base import _MissingProvider + + +@pyd_dataclass(config=ConfigDict(extra="forbid")) +class MissingChromaDBConfig(_MissingProvider): + """Placeholder for missing ChromaDB configuration.""" + + provider: Literal["chromadb"] = field(default="chromadb") diff --git a/src/crewai/rag/config/optional_imports/types.py b/src/crewai/rag/config/optional_imports/types.py new file mode 100644 index 000000000..dbd169cab --- /dev/null +++ b/src/crewai/rag/config/optional_imports/types.py @@ -0,0 +1,8 @@ +"""Type definitions for optional imports.""" + +from typing import Annotated, Literal + +SupportedProvider = Annotated[ + Literal["chromadb"], + "Supported RAG provider types, add providers here as they become available", +] diff --git a/src/crewai/rag/config/types.py b/src/crewai/rag/config/types.py new file mode 100644 index 000000000..ddd9f5268 --- /dev/null +++ b/src/crewai/rag/config/types.py @@ -0,0 +1,21 @@ +"""Type definitions for RAG configuration.""" + +from typing import TYPE_CHECKING, Annotated, TypeAlias +from pydantic import Field + +from crewai.rag.config.constants import DISCRIMINATOR + +if TYPE_CHECKING: + from crewai.rag.chromadb.config import ChromaDBConfig +else: + try: + from crewai.rag.chromadb.config import ChromaDBConfig + except ImportError: + from crewai.rag.config.optional_imports.providers import ( + MissingChromaDBConfig as ChromaDBConfig, + ) + +SupportedProviderConfig: TypeAlias = ChromaDBConfig +RagConfigType: TypeAlias = Annotated[ + SupportedProviderConfig, Field(discriminator=DISCRIMINATOR) +] diff --git a/src/crewai/rag/config/utils.py b/src/crewai/rag/config/utils.py new file mode 100644 index 000000000..0eaef87f1 --- /dev/null +++ b/src/crewai/rag/config/utils.py @@ -0,0 +1,86 @@ +"""RAG client configuration utilities.""" + +from contextvars import ContextVar + +from pydantic import BaseModel, Field + +from crewai.utilities.import_utils import require +from crewai.rag.core.base_client import BaseClient +from crewai.rag.config.types import RagConfigType +from crewai.rag.config.constants import ( + DEFAULT_RAG_CONFIG_PATH, + DEFAULT_RAG_CONFIG_CLASS, +) +from crewai.rag.config.factory import create_client + + +class RagContext(BaseModel): + """Context holding RAG configuration and client instance.""" + + config: RagConfigType = Field(..., description="RAG provider configuration") + client: BaseClient | None = Field( + default=None, description="Instantiated RAG client" + ) + + +_rag_context: ContextVar[RagContext | None] = ContextVar("_rag_context", default=None) + + +def set_rag_config(config: RagConfigType) -> None: + """Set global RAG client configuration and instantiate the client. + + Args: + config: The RAG client configuration (ChromaDBConfig). + """ + client = create_client(config) + context = RagContext(config=config, client=client) + _rag_context.set(context) + + +def get_rag_config() -> RagConfigType: + """Get current RAG configuration. + + Returns: + The current RAG configuration object. + """ + context = _rag_context.get() + if context is None: + module = require(DEFAULT_RAG_CONFIG_PATH, purpose="RAG configuration") + config_class = getattr(module, DEFAULT_RAG_CONFIG_CLASS) + default_config = config_class() + set_rag_config(default_config) + context = _rag_context.get() + + if context is None or context.config is None: + raise ValueError( + "RAG configuration is not set. Please set the RAG config first." + ) + + return context.config + + +def get_rag_client() -> BaseClient: + """Get the current RAG client instance. + + Returns: + The current RAG client, creating one if needed. + """ + context = _rag_context.get() + if context is None: + get_rag_config() + context = _rag_context.get() + + if context and context.client is None: + context.client = create_client(context.config) + + if context is None or context.client is None: + raise ValueError( + "RAG client is not configured. Please set the RAG config first." + ) + + return context.client + + +def clear_rag_config() -> None: + """Clear the current RAG configuration and client, reverting to defaults.""" + _rag_context.set(None) diff --git a/src/crewai/rag/core/base_client.py b/src/crewai/rag/core/base_client.py index c3bdcd3b0..6fa4346e1 100644 --- a/src/crewai/rag/core/base_client.py +++ b/src/crewai/rag/core/base_client.py @@ -3,6 +3,8 @@ from abc import abstractmethod from typing import Any, Protocol, runtime_checkable, TypedDict, Annotated from typing_extensions import Unpack, Required +from pydantic import GetCoreSchemaHandler +from pydantic_core import CoreSchema, core_schema from crewai.rag.types import ( @@ -96,6 +98,17 @@ class BaseClient(Protocol): client: Any embedding_function: EmbeddingFunction + @classmethod + def __get_pydantic_core_schema__( + cls, _source_type: Any, _handler: GetCoreSchemaHandler + ) -> CoreSchema: + """Generate Pydantic core schema for BaseClient Protocol. + + This allows the Protocol to be used in Pydantic models without + requiring arbitrary_types_allowed=True. + """ + return core_schema.any_schema() + @abstractmethod def create_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None: """Create a new collection/index in the vector database. diff --git a/src/crewai/rag/core/base_provider.py b/src/crewai/rag/core/base_provider.py deleted file mode 100644 index 0651ce540..000000000 --- a/src/crewai/rag/core/base_provider.py +++ /dev/null @@ -1,30 +0,0 @@ -"""Base provider protocol for vector database client creation.""" - -from abc import ABC -from typing import Any, Protocol, runtime_checkable, Union -from pydantic import BaseModel, Field - -from crewai.rag.types import EmbeddingFunction -from crewai.rag.embeddings.types import EmbeddingOptions - - -class BaseProviderOptions(BaseModel, ABC): - """Base configuration for all provider options.""" - - client_type: str = Field(..., description="Type of client to create") - embedding_config: Union[EmbeddingOptions, EmbeddingFunction, None] = Field( - default=None, - description="Embedding configuration - either options for built-in providers or a custom function", - ) - options: Any = Field( - default=None, description="Additional provider-specific options" - ) - - -@runtime_checkable -class BaseProvider(Protocol): - """Protocol for vector database client providers.""" - - def __call__(self, options: BaseProviderOptions) -> Any: - """Create and return a configured client instance.""" - ... diff --git a/tests/rag/config/test_factory.py b/tests/rag/config/test_factory.py new file mode 100644 index 000000000..91f30329e --- /dev/null +++ b/tests/rag/config/test_factory.py @@ -0,0 +1,34 @@ +"""Tests for RAG config factory.""" + +from unittest.mock import Mock, patch + +from crewai.rag.config.factory import create_client + + +def test_create_client_chromadb(): + """Test ChromaDB client creation.""" + mock_config = Mock() + mock_config.provider = "chromadb" + + with patch("crewai.rag.config.factory.require") as mock_require: + mock_module = Mock() + mock_client = Mock() + mock_module.create_client.return_value = mock_client + mock_require.return_value = mock_module + + result = create_client(mock_config) + + assert result == mock_client + mock_require.assert_called_once_with( + "crewai.rag.chromadb.factory", purpose="The 'chromadb' provider" + ) + mock_module.create_client.assert_called_once_with(mock_config) + + +def test_create_client_unsupported_provider(): + """Test unsupported provider returns None for now.""" + mock_config = Mock() + mock_config.provider = "unsupported" + + result = create_client(mock_config) + assert result is None diff --git a/tests/rag/config/test_optional_imports.py b/tests/rag/config/test_optional_imports.py new file mode 100644 index 000000000..11dad9855 --- /dev/null +++ b/tests/rag/config/test_optional_imports.py @@ -0,0 +1,22 @@ +"""Tests for optional imports.""" + +import pytest + +from crewai.rag.config.optional_imports.base import _MissingProvider +from crewai.rag.config.optional_imports.providers import MissingChromaDBConfig + + +def test_missing_provider_raises_runtime_error(): + """Test that _MissingProvider raises RuntimeError on instantiation.""" + with pytest.raises( + RuntimeError, match="provider '__missing__' requested but not installed" + ): + _MissingProvider() + + +def test_missing_chromadb_config_raises_runtime_error(): + """Test that MissingChromaDBConfig raises RuntimeError on instantiation.""" + with pytest.raises( + RuntimeError, match="provider 'chromadb' requested but not installed" + ): + MissingChromaDBConfig()