Qdrant RAG Provider Support (#3400)

* Added Qdrant provider support with factory, config, and protocols
* Improved default embeddings and type definitions
* Fixed ChromaDB factory embedding assignment
This commit is contained in:
Greyson LaLonde
2025-08-26 08:44:02 -04:00
committed by GitHub
parent 7ac482c7c9
commit 869bb115c8
14 changed files with 175 additions and 24 deletions

View File

@@ -1,6 +1,5 @@
"""ChromaDB configuration model."""
import os
import warnings
from dataclasses import field
from typing import Literal, cast
@@ -9,9 +8,12 @@ from chromadb.config import Settings
from chromadb.utils.embedding_functions import DefaultEmbeddingFunction
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
from crewai.utilities.paths import db_storage_path
from crewai.rag.config.base import BaseRagConfig
from crewai.rag.chromadb.constants import DEFAULT_TENANT, DEFAULT_DATABASE
from crewai.rag.chromadb.constants import (
DEFAULT_TENANT,
DEFAULT_DATABASE,
DEFAULT_STORAGE_PATH,
)
warnings.filterwarnings(
@@ -29,7 +31,7 @@ def _default_settings() -> Settings:
Settings with persistent storage and reset enabled.
"""
return Settings(
persist_directory=os.path.join(db_storage_path(), "chromadb"),
persist_directory=DEFAULT_STORAGE_PATH,
allow_reset=True,
is_persistent=True,
)
@@ -39,7 +41,7 @@ def _default_embedding_function() -> ChromaEmbeddingFunctionWrapper:
"""Create default ChromaDB embedding function.
Returns:
Default embedding function cast to proper type.
Default embedding function using all-MiniLM-L6-v2 via ONNX.
"""
return cast(ChromaEmbeddingFunctionWrapper, DefaultEmbeddingFunction())
@@ -52,6 +54,6 @@ class ChromaDBConfig(BaseRagConfig):
tenant: str = DEFAULT_TENANT
database: str = DEFAULT_DATABASE
settings: Settings = field(default_factory=_default_settings)
embedding_function: ChromaEmbeddingFunctionWrapper | None = field(
embedding_function: ChromaEmbeddingFunctionWrapper = field(
default_factory=_default_embedding_function
)

View File

@@ -1,6 +1,10 @@
"""Constants for ChromaDB configuration."""
import os
from typing import Final
from crewai.utilities.paths import db_storage_path
# Default tenant name for ChromaDB configuration.
DEFAULT_TENANT: Final[str] = "default_tenant"
# Default database name for ChromaDB configuration.
DEFAULT_DATABASE: Final[str] = "default_database"
# Default storage location: the "chromadb" subdirectory of db_storage_path().
# The path is computed once, when this module is imported.
DEFAULT_STORAGE_PATH: Final[str] = os.path.join(db_storage_path(), "chromadb")

View File

@@ -2,7 +2,10 @@
from typing import cast
from crewai.rag.config.optional_imports.protocols import ChromaFactoryModule
from crewai.rag.config.optional_imports.protocols import (
ChromaFactoryModule,
QdrantFactoryModule,
)
from crewai.rag.core.base_client import BaseClient
from crewai.rag.config.types import RagConfigType
from crewai.utilities.import_utils import require
@@ -22,11 +25,21 @@ def create_client(config: RagConfigType) -> BaseClient:
"""
if config.provider == "chromadb":
mod = cast(
chromadb_mod = cast(
ChromaFactoryModule,
require(
"crewai.rag.chromadb.factory",
purpose="The 'chromadb' provider",
),
)
return mod.create_client(config)
return chromadb_mod.create_client(config)
if config.provider == "qdrant":
qdrant_mod = cast(
QdrantFactoryModule,
require(
"crewai.rag.qdrant.factory",
purpose="The 'qdrant' provider",
),
)
return qdrant_mod.create_client(config)

View File

@@ -14,7 +14,9 @@ class _MissingProvider:
Raises RuntimeError when instantiated to indicate missing dependencies.
"""
provider: Literal["chromadb", "__missing__"] = field(default="__missing__")
provider: Literal["chromadb", "qdrant", "__missing__"] = field(
default="__missing__"
)
def __post_init__(self) -> None:
"""Raises error indicating the provider is not installed."""

View File

@@ -1,14 +1,27 @@
"""Protocol definitions for RAG factory modules."""
from typing import Protocol
from __future__ import annotations
from crewai.rag.config.types import RagConfigType
from crewai.rag.core.base_client import BaseClient
from typing import Protocol, TYPE_CHECKING
if TYPE_CHECKING:
from crewai.rag.chromadb.client import ChromaDBClient
from crewai.rag.chromadb.config import ChromaDBConfig
from crewai.rag.qdrant.client import QdrantClient
from crewai.rag.qdrant.config import QdrantConfig
class ChromaFactoryModule(Protocol):
"""Protocol for ChromaDB factory module."""
def create_client(self, config: RagConfigType) -> BaseClient:
def create_client(self, config: ChromaDBConfig) -> ChromaDBClient:
"""Creates a ChromaDB client from configuration."""
...
class QdrantFactoryModule(Protocol):
    """Structural type describing the Qdrant factory module.

    Any object exposing a compatible ``create_client`` satisfies it.
    """

    def create_client(self, config: QdrantConfig) -> QdrantClient:
        """Build a Qdrant client from the given configuration."""
        ...

View File

@@ -13,3 +13,10 @@ class MissingChromaDBConfig(_MissingProvider):
"""Placeholder for missing ChromaDB configuration."""
provider: Literal["chromadb"] = field(default="chromadb")
@pyd_dataclass(config=ConfigDict(extra="forbid"))
class MissingQdrantConfig(_MissingProvider):
    """Stand-in for the Qdrant configuration when it is unavailable.

    Attributes:
        provider: Discriminator value, fixed to "qdrant" by default.
    """

    provider: Literal["qdrant"] = field(default="qdrant")

View File

@@ -3,6 +3,6 @@
from typing import Annotated, Literal
SupportedProvider = Annotated[
Literal["chromadb"],
Literal["chromadb", "qdrant"],
"Supported RAG provider types, add providers here as they become available",
]

View File

@@ -1,12 +1,18 @@
"""Type definitions for RAG configuration."""
from typing import TYPE_CHECKING, Annotated, TypeAlias
from typing import Annotated, TypeAlias, TYPE_CHECKING
from pydantic import Field
from crewai.rag.config.constants import DISCRIMINATOR
# The linter flags conditional imports; importing under an alias and re-binding the name inside the TYPE_CHECKING branch avoids that.
if TYPE_CHECKING:
from crewai.rag.chromadb.config import ChromaDBConfig
from crewai.rag.chromadb.config import ChromaDBConfig as ChromaDBConfig_
ChromaDBConfig = ChromaDBConfig_
from crewai.rag.qdrant.config import QdrantConfig as QdrantConfig_
QdrantConfig = QdrantConfig_
else:
try:
from crewai.rag.chromadb.config import ChromaDBConfig
@@ -15,7 +21,14 @@ else:
MissingChromaDBConfig as ChromaDBConfig,
)
SupportedProviderConfig: TypeAlias = ChromaDBConfig
try:
from crewai.rag.qdrant.config import QdrantConfig
except ImportError:
from crewai.rag.config.optional_imports.providers import (
MissingQdrantConfig as QdrantConfig,
)
SupportedProviderConfig: TypeAlias = ChromaDBConfig | QdrantConfig
RagConfigType: TypeAlias = Annotated[
SupportedProviderConfig, Field(discriminator=DISCRIMINATOR)
]

View File

@@ -1,8 +1,8 @@
"""Protocol for vector database client implementations."""
from abc import abstractmethod
from typing import Any, Protocol, runtime_checkable, TypedDict, Annotated
from typing_extensions import Unpack, Required
from typing import Any, Protocol, runtime_checkable, Annotated
from typing_extensions import Unpack, Required, TypedDict
from pydantic import GetCoreSchemaHandler
from pydantic_core import CoreSchema, core_schema

View File

@@ -0,0 +1,54 @@
"""Qdrant configuration model."""
from dataclasses import field
from typing import Literal, cast
from pydantic.dataclasses import dataclass as pyd_dataclass
from crewai.rag.config.base import BaseRagConfig
from crewai.rag.qdrant.types import QdrantClientParams, QdrantEmbeddingFunctionWrapper
from crewai.rag.qdrant.constants import DEFAULT_EMBEDDING_MODEL, DEFAULT_STORAGE_PATH
def _default_options() -> QdrantClientParams:
    """Build the Qdrant client options used when none are supplied.

    Returns:
        Client parameters whose ``path`` is DEFAULT_STORAGE_PATH.
    """
    default_params = QdrantClientParams(path=DEFAULT_STORAGE_PATH)
    return default_params
def _default_embedding_function() -> QdrantEmbeddingFunctionWrapper:
    """Build the embedding callable used when a config supplies none.

    fastembed is imported inside this function rather than at module level,
    so the import only happens when the default is actually built.

    Returns:
        A callable that embeds a single string with the fastembed model
        named by DEFAULT_EMBEDDING_MODEL.
    """
    from fastembed import TextEmbedding

    # The model is created once and captured by the closure below.
    embedder = TextEmbedding(model_name=DEFAULT_EMBEDDING_MODEL)

    def embed_fn(text: str) -> list[float]:
        """Return the embedding vector for one text string.

        Args:
            text: Text to embed.

        Returns:
            The vector as a list of floats, or an empty list when the model
            yields no embedding.
        """
        vectors = list(embedder.embed([text]))
        if not vectors:
            return []
        return vectors[0].tolist()

    return cast(QdrantEmbeddingFunctionWrapper, embed_fn)
@pyd_dataclass(frozen=True)
class QdrantConfig(BaseRagConfig):
    """Frozen configuration for the Qdrant provider.

    Attributes:
        provider: Always "qdrant"; excluded from the constructor arguments.
        options: Qdrant client parameters, defaulting to the result of
            _default_options().
        embedding_function: Embedding callable, defaulting to the result of
            _default_embedding_function().
    """

    provider: Literal["qdrant"] = field(default="qdrant", init=False)
    options: QdrantClientParams = field(default_factory=_default_options)
    embedding_function: QdrantEmbeddingFunctionWrapper = field(
        default_factory=_default_embedding_function
    )

View File

@@ -1,7 +1,12 @@
"""Constants for Qdrant implementation."""
import os
from typing import Final
from qdrant_client.models import Distance, VectorParams
from crewai.utilities.paths import db_storage_path
# Default vector configuration: size 384, cosine distance.
# NOTE(review): 384 presumably matches the output size of DEFAULT_EMBEDDING_MODEL — confirm.
DEFAULT_VECTOR_PARAMS: Final = VectorParams(size=384, distance=Distance.COSINE)
# Name of the default embedding model.
DEFAULT_EMBEDDING_MODEL: Final[str] = "sentence-transformers/all-MiniLM-L6-v2"
# Default storage location: the "qdrant" subdirectory of db_storage_path().
# The path is computed once, when this module is imported.
DEFAULT_STORAGE_PATH: Final[str] = os.path.join(db_storage_path(), "qdrant")

View File

@@ -0,0 +1,21 @@
"""Factory functions for creating Qdrant clients from configuration."""
from qdrant_client import QdrantClient as SyncQdrantClientBase
from crewai.rag.qdrant.client import QdrantClient
from crewai.rag.qdrant.config import QdrantConfig
def create_client(config: QdrantConfig) -> QdrantClient:
    """Build the Qdrant client wrapper described by ``config``.

    Args:
        config: Qdrant configuration. ``config.options`` is unpacked as
            keyword arguments for the underlying qdrant_client client.

    Returns:
        A QdrantClient wrapping the newly created client and using
        ``config.embedding_function``.
    """
    # Keyword arguments evaluate left to right, so the underlying client is
    # still constructed before the embedding function is read.
    return QdrantClient(
        client=SyncQdrantClientBase(**config.options),
        embedding_function=config.embedding_function,
    )

View File

@@ -1,10 +1,12 @@
"""Type definitions specific to Qdrant implementation."""
from collections.abc import Awaitable, Callable
from typing import Annotated, Any, Protocol, TypeAlias, TypedDict
from typing_extensions import NotRequired
from typing import Annotated, Any, Protocol, TypeAlias
from typing_extensions import NotRequired, TypedDict
import numpy as np
from pydantic import GetCoreSchemaHandler
from pydantic_core import CoreSchema, core_schema
from qdrant_client import AsyncQdrantClient, QdrantClient as SyncQdrantClient
from qdrant_client.models import (
FieldCondition,
@@ -53,6 +55,21 @@ class EmbeddingFunction(Protocol):
...
class QdrantEmbeddingFunctionWrapper(EmbeddingFunction):
    """EmbeddingFunction subtype that supplies its own Pydantic schema."""

    @classmethod
    def __get_pydantic_core_schema__(
        cls, _source_type: Any, _handler: GetCoreSchemaHandler
    ) -> CoreSchema:
        """Return the Pydantic core schema for this type.

        The schema is an "any" schema, so fields annotated with this type can
        be declared without setting arbitrary_types_allowed=True.
        """
        any_value_schema = core_schema.any_schema()
        return any_value_schema
class AsyncEmbeddingFunction(Protocol):
"""Protocol for async embedding functions that convert text to vectors."""

View File

@@ -1,9 +1,9 @@
"""Type definitions for RAG (Retrieval-Augmented Generation) systems."""
from collections.abc import Callable, Mapping
from typing import TypeAlias, TypedDict, Any
from typing import TypeAlias, Any
from typing_extensions import Required
from typing_extensions import Required, TypedDict
class BaseRecord(TypedDict, total=False):