Qdrant RAG Provider Support (#3400)

* Added Qdrant provider support with factory, config, and protocols
* Improved default embeddings and type definitions
* Fixed ChromaDB factory embedding assignment
This commit is contained in:
Greyson LaLonde
2025-08-26 08:44:02 -04:00
committed by GitHub
parent 7ac482c7c9
commit 869bb115c8
14 changed files with 175 additions and 24 deletions

View File

@@ -1,6 +1,5 @@
"""ChromaDB configuration model.""" """ChromaDB configuration model."""
import os
import warnings import warnings
from dataclasses import field from dataclasses import field
from typing import Literal, cast from typing import Literal, cast
@@ -9,9 +8,12 @@ from chromadb.config import Settings
from chromadb.utils.embedding_functions import DefaultEmbeddingFunction from chromadb.utils.embedding_functions import DefaultEmbeddingFunction
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
from crewai.utilities.paths import db_storage_path
from crewai.rag.config.base import BaseRagConfig from crewai.rag.config.base import BaseRagConfig
from crewai.rag.chromadb.constants import DEFAULT_TENANT, DEFAULT_DATABASE from crewai.rag.chromadb.constants import (
DEFAULT_TENANT,
DEFAULT_DATABASE,
DEFAULT_STORAGE_PATH,
)
warnings.filterwarnings( warnings.filterwarnings(
@@ -29,7 +31,7 @@ def _default_settings() -> Settings:
Settings with persistent storage and reset enabled. Settings with persistent storage and reset enabled.
""" """
return Settings( return Settings(
persist_directory=os.path.join(db_storage_path(), "chromadb"), persist_directory=DEFAULT_STORAGE_PATH,
allow_reset=True, allow_reset=True,
is_persistent=True, is_persistent=True,
) )
@@ -39,7 +41,7 @@ def _default_embedding_function() -> ChromaEmbeddingFunctionWrapper:
"""Create default ChromaDB embedding function. """Create default ChromaDB embedding function.
Returns: Returns:
Default embedding function cast to proper type. Default embedding function using all-MiniLM-L6-v2 via ONNX.
""" """
return cast(ChromaEmbeddingFunctionWrapper, DefaultEmbeddingFunction()) return cast(ChromaEmbeddingFunctionWrapper, DefaultEmbeddingFunction())
@@ -52,6 +54,6 @@ class ChromaDBConfig(BaseRagConfig):
tenant: str = DEFAULT_TENANT tenant: str = DEFAULT_TENANT
database: str = DEFAULT_DATABASE database: str = DEFAULT_DATABASE
settings: Settings = field(default_factory=_default_settings) settings: Settings = field(default_factory=_default_settings)
embedding_function: ChromaEmbeddingFunctionWrapper | None = field( embedding_function: ChromaEmbeddingFunctionWrapper = field(
default_factory=_default_embedding_function default_factory=_default_embedding_function
) )

View File

@@ -1,6 +1,10 @@
"""Constants for ChromaDB configuration.""" """Constants for ChromaDB configuration."""
import os
from typing import Final from typing import Final
from crewai.utilities.paths import db_storage_path
DEFAULT_TENANT: Final[str] = "default_tenant" DEFAULT_TENANT: Final[str] = "default_tenant"
DEFAULT_DATABASE: Final[str] = "default_database" DEFAULT_DATABASE: Final[str] = "default_database"
DEFAULT_STORAGE_PATH: Final[str] = os.path.join(db_storage_path(), "chromadb")

View File

@@ -2,7 +2,10 @@
from typing import cast from typing import cast
from crewai.rag.config.optional_imports.protocols import ChromaFactoryModule from crewai.rag.config.optional_imports.protocols import (
ChromaFactoryModule,
QdrantFactoryModule,
)
from crewai.rag.core.base_client import BaseClient from crewai.rag.core.base_client import BaseClient
from crewai.rag.config.types import RagConfigType from crewai.rag.config.types import RagConfigType
from crewai.utilities.import_utils import require from crewai.utilities.import_utils import require
@@ -22,11 +25,21 @@ def create_client(config: RagConfigType) -> BaseClient:
""" """
if config.provider == "chromadb": if config.provider == "chromadb":
mod = cast( chromadb_mod = cast(
ChromaFactoryModule, ChromaFactoryModule,
require( require(
"crewai.rag.chromadb.factory", "crewai.rag.chromadb.factory",
purpose="The 'chromadb' provider", purpose="The 'chromadb' provider",
), ),
) )
return mod.create_client(config) return chromadb_mod.create_client(config)
if config.provider == "qdrant":
qdrant_mod = cast(
QdrantFactoryModule,
require(
"crewai.rag.qdrant.factory",
purpose="The 'qdrant' provider",
),
)
return qdrant_mod.create_client(config)

View File

@@ -14,7 +14,9 @@ class _MissingProvider:
Raises RuntimeError when instantiated to indicate missing dependencies. Raises RuntimeError when instantiated to indicate missing dependencies.
""" """
provider: Literal["chromadb", "__missing__"] = field(default="__missing__") provider: Literal["chromadb", "qdrant", "__missing__"] = field(
default="__missing__"
)
def __post_init__(self) -> None: def __post_init__(self) -> None:
"""Raises error indicating the provider is not installed.""" """Raises error indicating the provider is not installed."""

View File

@@ -1,14 +1,27 @@
"""Protocol definitions for RAG factory modules.""" """Protocol definitions for RAG factory modules."""
from typing import Protocol from __future__ import annotations
from crewai.rag.config.types import RagConfigType from typing import Protocol, TYPE_CHECKING
from crewai.rag.core.base_client import BaseClient
if TYPE_CHECKING:
from crewai.rag.chromadb.client import ChromaDBClient
from crewai.rag.chromadb.config import ChromaDBConfig
from crewai.rag.qdrant.client import QdrantClient
from crewai.rag.qdrant.config import QdrantConfig
class ChromaFactoryModule(Protocol): class ChromaFactoryModule(Protocol):
"""Protocol for ChromaDB factory module.""" """Protocol for ChromaDB factory module."""
def create_client(self, config: RagConfigType) -> BaseClient: def create_client(self, config: ChromaDBConfig) -> ChromaDBClient:
"""Creates a ChromaDB client from configuration.""" """Creates a ChromaDB client from configuration."""
... ...
class QdrantFactoryModule(Protocol):
"""Protocol for Qdrant factory module."""
def create_client(self, config: QdrantConfig) -> QdrantClient:
"""Creates a Qdrant client from configuration."""
...

View File

@@ -13,3 +13,10 @@ class MissingChromaDBConfig(_MissingProvider):
"""Placeholder for missing ChromaDB configuration.""" """Placeholder for missing ChromaDB configuration."""
provider: Literal["chromadb"] = field(default="chromadb") provider: Literal["chromadb"] = field(default="chromadb")
@pyd_dataclass(config=ConfigDict(extra="forbid"))
class MissingQdrantConfig(_MissingProvider):
"""Placeholder for missing Qdrant configuration."""
provider: Literal["qdrant"] = field(default="qdrant")

View File

@@ -3,6 +3,6 @@
from typing import Annotated, Literal from typing import Annotated, Literal
SupportedProvider = Annotated[ SupportedProvider = Annotated[
Literal["chromadb"], Literal["chromadb", "qdrant"],
"Supported RAG provider types, add providers here as they become available", "Supported RAG provider types, add providers here as they become available",
] ]

View File

@@ -1,12 +1,18 @@
"""Type definitions for RAG configuration.""" """Type definitions for RAG configuration."""
from typing import TYPE_CHECKING, Annotated, TypeAlias from typing import Annotated, TypeAlias, TYPE_CHECKING
from pydantic import Field from pydantic import Field
from crewai.rag.config.constants import DISCRIMINATOR from crewai.rag.config.constants import DISCRIMINATOR
# Linter freaks out on conditional imports, assigning in the type checking fixes it
if TYPE_CHECKING: if TYPE_CHECKING:
from crewai.rag.chromadb.config import ChromaDBConfig from crewai.rag.chromadb.config import ChromaDBConfig as ChromaDBConfig_
ChromaDBConfig = ChromaDBConfig_
from crewai.rag.qdrant.config import QdrantConfig as QdrantConfig_
QdrantConfig = QdrantConfig_
else: else:
try: try:
from crewai.rag.chromadb.config import ChromaDBConfig from crewai.rag.chromadb.config import ChromaDBConfig
@@ -15,7 +21,14 @@ else:
MissingChromaDBConfig as ChromaDBConfig, MissingChromaDBConfig as ChromaDBConfig,
) )
SupportedProviderConfig: TypeAlias = ChromaDBConfig try:
from crewai.rag.qdrant.config import QdrantConfig
except ImportError:
from crewai.rag.config.optional_imports.providers import (
MissingQdrantConfig as QdrantConfig,
)
SupportedProviderConfig: TypeAlias = ChromaDBConfig | QdrantConfig
RagConfigType: TypeAlias = Annotated[ RagConfigType: TypeAlias = Annotated[
SupportedProviderConfig, Field(discriminator=DISCRIMINATOR) SupportedProviderConfig, Field(discriminator=DISCRIMINATOR)
] ]

View File

@@ -1,8 +1,8 @@
"""Protocol for vector database client implementations.""" """Protocol for vector database client implementations."""
from abc import abstractmethod from abc import abstractmethod
from typing import Any, Protocol, runtime_checkable, TypedDict, Annotated from typing import Any, Protocol, runtime_checkable, Annotated
from typing_extensions import Unpack, Required from typing_extensions import Unpack, Required, TypedDict
from pydantic import GetCoreSchemaHandler from pydantic import GetCoreSchemaHandler
from pydantic_core import CoreSchema, core_schema from pydantic_core import CoreSchema, core_schema

View File

@@ -0,0 +1,54 @@
"""Qdrant configuration model."""
from dataclasses import field
from typing import Literal, cast
from pydantic.dataclasses import dataclass as pyd_dataclass
from crewai.rag.config.base import BaseRagConfig
from crewai.rag.qdrant.types import QdrantClientParams, QdrantEmbeddingFunctionWrapper
from crewai.rag.qdrant.constants import DEFAULT_EMBEDDING_MODEL, DEFAULT_STORAGE_PATH
def _default_options() -> QdrantClientParams:
"""Create default Qdrant client options.
Returns:
Default options with file-based storage.
"""
return QdrantClientParams(path=DEFAULT_STORAGE_PATH)
def _default_embedding_function() -> QdrantEmbeddingFunctionWrapper:
"""Create default Qdrant embedding function.
Returns:
Default embedding function using fastembed with all-MiniLM-L6-v2.
"""
from fastembed import TextEmbedding
model = TextEmbedding(model_name=DEFAULT_EMBEDDING_MODEL)
def embed_fn(text: str) -> list[float]:
"""Embed a single text string.
Args:
text: Text to embed.
Returns:
Embedding vector as list of floats.
"""
embeddings = list(model.embed([text]))
return embeddings[0].tolist() if embeddings else []
return cast(QdrantEmbeddingFunctionWrapper, embed_fn)
@pyd_dataclass(frozen=True)
class QdrantConfig(BaseRagConfig):
"""Configuration for Qdrant client."""
provider: Literal["qdrant"] = field(default="qdrant", init=False)
options: QdrantClientParams = field(default_factory=_default_options)
embedding_function: QdrantEmbeddingFunctionWrapper = field(
default_factory=_default_embedding_function
)

View File

@@ -1,7 +1,12 @@
"""Constants for Qdrant implementation.""" """Constants for Qdrant implementation."""
import os
from typing import Final from typing import Final
from qdrant_client.models import Distance, VectorParams from qdrant_client.models import Distance, VectorParams
from crewai.utilities.paths import db_storage_path
DEFAULT_VECTOR_PARAMS: Final = VectorParams(size=384, distance=Distance.COSINE) DEFAULT_VECTOR_PARAMS: Final = VectorParams(size=384, distance=Distance.COSINE)
DEFAULT_EMBEDDING_MODEL: Final[str] = "sentence-transformers/all-MiniLM-L6-v2"
DEFAULT_STORAGE_PATH: Final[str] = os.path.join(db_storage_path(), "qdrant")

View File

@@ -0,0 +1,21 @@
"""Factory functions for creating Qdrant clients from configuration."""
from qdrant_client import QdrantClient as SyncQdrantClientBase
from crewai.rag.qdrant.client import QdrantClient
from crewai.rag.qdrant.config import QdrantConfig
def create_client(config: QdrantConfig) -> QdrantClient:
"""Create a Qdrant client from configuration.
Args:
config: The Qdrant configuration.
Returns:
A configured QdrantClient instance.
"""
qdrant_client = SyncQdrantClientBase(**config.options)
return QdrantClient(
client=qdrant_client, embedding_function=config.embedding_function
)

View File

@@ -1,10 +1,12 @@
"""Type definitions specific to Qdrant implementation.""" """Type definitions specific to Qdrant implementation."""
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from typing import Annotated, Any, Protocol, TypeAlias, TypedDict from typing import Annotated, Any, Protocol, TypeAlias
from typing_extensions import NotRequired from typing_extensions import NotRequired, TypedDict
import numpy as np import numpy as np
from pydantic import GetCoreSchemaHandler
from pydantic_core import CoreSchema, core_schema
from qdrant_client import AsyncQdrantClient, QdrantClient as SyncQdrantClient from qdrant_client import AsyncQdrantClient, QdrantClient as SyncQdrantClient
from qdrant_client.models import ( from qdrant_client.models import (
FieldCondition, FieldCondition,
@@ -53,6 +55,21 @@ class EmbeddingFunction(Protocol):
... ...
class QdrantEmbeddingFunctionWrapper(EmbeddingFunction):
"""Base class for Qdrant EmbeddingFunction to work with Pydantic validation."""
@classmethod
def __get_pydantic_core_schema__(
cls, _source_type: Any, _handler: GetCoreSchemaHandler
) -> CoreSchema:
"""Generate Pydantic core schema for Qdrant EmbeddingFunction.
This allows Pydantic to handle Qdrant's EmbeddingFunction type
without requiring arbitrary_types_allowed=True.
"""
return core_schema.any_schema()
class AsyncEmbeddingFunction(Protocol): class AsyncEmbeddingFunction(Protocol):
"""Protocol for async embedding functions that convert text to vectors.""" """Protocol for async embedding functions that convert text to vectors."""

View File

@@ -1,9 +1,9 @@
"""Type definitions for RAG (Retrieval-Augmented Generation) systems.""" """Type definitions for RAG (Retrieval-Augmented Generation) systems."""
from collections.abc import Callable, Mapping from collections.abc import Callable, Mapping
from typing import TypeAlias, TypedDict, Any from typing import TypeAlias, Any
from typing_extensions import Required from typing_extensions import Required, TypedDict
class BaseRecord(TypedDict, total=False): class BaseRecord(TypedDict, total=False):