From 869bb115c8ccaa3bcced1625b345484c3c13db7f Mon Sep 17 00:00:00 2001 From: Greyson LaLonde Date: Tue, 26 Aug 2025 08:44:02 -0400 Subject: [PATCH] Qdrant RAG Provider Support (#3400) * Added Qdrant provider support with factory, config, and protocols * Improved default embeddings and type definitions * Fixed ChromaDB factory embedding assignment --- src/crewai/rag/chromadb/config.py | 14 ++--- src/crewai/rag/chromadb/constants.py | 4 ++ src/crewai/rag/config/factory.py | 19 +++++-- .../rag/config/optional_imports/base.py | 4 +- .../rag/config/optional_imports/protocols.py | 21 ++++++-- .../rag/config/optional_imports/providers.py | 7 +++ .../rag/config/optional_imports/types.py | 2 +- src/crewai/rag/config/types.py | 19 +++++-- src/crewai/rag/core/base_client.py | 4 +- src/crewai/rag/qdrant/config.py | 54 +++++++++++++++++++ src/crewai/rag/qdrant/constants.py | 5 ++ src/crewai/rag/qdrant/factory.py | 21 ++++++++ src/crewai/rag/qdrant/types.py | 21 +++++++- src/crewai/rag/types.py | 4 +- 14 files changed, 175 insertions(+), 24 deletions(-) create mode 100644 src/crewai/rag/qdrant/config.py create mode 100644 src/crewai/rag/qdrant/factory.py diff --git a/src/crewai/rag/chromadb/config.py b/src/crewai/rag/chromadb/config.py index 43b202c55..1f536dcf6 100644 --- a/src/crewai/rag/chromadb/config.py +++ b/src/crewai/rag/chromadb/config.py @@ -1,6 +1,5 @@ """ChromaDB configuration model.""" -import os import warnings from dataclasses import field from typing import Literal, cast @@ -9,9 +8,12 @@ from chromadb.config import Settings from chromadb.utils.embedding_functions import DefaultEmbeddingFunction from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper -from crewai.utilities.paths import db_storage_path from crewai.rag.config.base import BaseRagConfig -from crewai.rag.chromadb.constants import DEFAULT_TENANT, DEFAULT_DATABASE +from crewai.rag.chromadb.constants import ( + DEFAULT_TENANT, + DEFAULT_DATABASE, + DEFAULT_STORAGE_PATH, +) warnings.filterwarnings( @@ -29,7 +31,7 @@ def _default_settings() -> Settings: Settings with persistent storage and reset enabled. """ return Settings( - persist_directory=os.path.join(db_storage_path(), "chromadb"), + persist_directory=DEFAULT_STORAGE_PATH, allow_reset=True, is_persistent=True, ) @@ -39,7 +41,7 @@ def _default_embedding_function() -> ChromaEmbeddingFunctionWrapper: """Create default ChromaDB embedding function. Returns: - Default embedding function cast to proper type. + Default embedding function using all-MiniLM-L6-v2 via ONNX. """ return cast(ChromaEmbeddingFunctionWrapper, DefaultEmbeddingFunction()) @@ -52,6 +54,6 @@ class ChromaDBConfig(BaseRagConfig): tenant: str = DEFAULT_TENANT database: str = DEFAULT_DATABASE settings: Settings = field(default_factory=_default_settings) - embedding_function: ChromaEmbeddingFunctionWrapper | None = field( + embedding_function: ChromaEmbeddingFunctionWrapper = field( default_factory=_default_embedding_function ) diff --git a/src/crewai/rag/chromadb/constants.py b/src/crewai/rag/chromadb/constants.py index d9c585b6f..8dba23fd0 100644 --- a/src/crewai/rag/chromadb/constants.py +++ b/src/crewai/rag/chromadb/constants.py @@ -1,6 +1,10 @@ """Constants for ChromaDB configuration.""" +import os from typing import Final +from crewai.utilities.paths import db_storage_path + DEFAULT_TENANT: Final[str] = "default_tenant" DEFAULT_DATABASE: Final[str] = "default_database" +DEFAULT_STORAGE_PATH: Final[str] = os.path.join(db_storage_path(), "chromadb") diff --git a/src/crewai/rag/config/factory.py b/src/crewai/rag/config/factory.py index 1f34d6317..16e565e99 100644 --- a/src/crewai/rag/config/factory.py +++ b/src/crewai/rag/config/factory.py @@ -2,7 +2,10 @@ from typing import cast -from crewai.rag.config.optional_imports.protocols import ChromaFactoryModule +from crewai.rag.config.optional_imports.protocols import ( + ChromaFactoryModule, + QdrantFactoryModule, +) from crewai.rag.core.base_client import BaseClient from crewai.rag.config.types import RagConfigType from crewai.utilities.import_utils import require @@ -22,11 +25,21 @@ def create_client(config: RagConfigType) -> BaseClient: """ if config.provider == "chromadb": - mod = cast( + chromadb_mod = cast( ChromaFactoryModule, require( "crewai.rag.chromadb.factory", purpose="The 'chromadb' provider", ), ) - return mod.create_client(config) + return chromadb_mod.create_client(config) + + if config.provider == "qdrant": + qdrant_mod = cast( + QdrantFactoryModule, + require( + "crewai.rag.qdrant.factory", + purpose="The 'qdrant' provider", + ), + ) + return qdrant_mod.create_client(config) diff --git a/src/crewai/rag/config/optional_imports/base.py b/src/crewai/rag/config/optional_imports/base.py index abb35e0bc..ba6657ec4 100644 --- a/src/crewai/rag/config/optional_imports/base.py +++ b/src/crewai/rag/config/optional_imports/base.py @@ -14,7 +14,9 @@ class _MissingProvider: Raises RuntimeError when instantiated to indicate missing dependencies. """ - provider: Literal["chromadb", "__missing__"] = field(default="__missing__") + provider: Literal["chromadb", "qdrant", "__missing__"] = field( + default="__missing__" + ) def __post_init__(self) -> None: """Raises error indicating the provider is not installed.""" diff --git a/src/crewai/rag/config/optional_imports/protocols.py b/src/crewai/rag/config/optional_imports/protocols.py index 2e16b50c8..e7058bb66 100644 --- a/src/crewai/rag/config/optional_imports/protocols.py +++ b/src/crewai/rag/config/optional_imports/protocols.py @@ -1,14 +1,27 @@ """Protocol definitions for RAG factory modules.""" -from typing import Protocol +from __future__ import annotations -from crewai.rag.config.types import RagConfigType -from crewai.rag.core.base_client import BaseClient +from typing import Protocol, TYPE_CHECKING + +if TYPE_CHECKING: + from crewai.rag.chromadb.client import ChromaDBClient + from crewai.rag.chromadb.config import ChromaDBConfig + from crewai.rag.qdrant.client import QdrantClient + from crewai.rag.qdrant.config import QdrantConfig class ChromaFactoryModule(Protocol): """Protocol for ChromaDB factory module.""" - def create_client(self, config: RagConfigType) -> BaseClient: + def create_client(self, config: ChromaDBConfig) -> ChromaDBClient: """Creates a ChromaDB client from configuration.""" ... + + +class QdrantFactoryModule(Protocol): + """Protocol for Qdrant factory module.""" + + def create_client(self, config: QdrantConfig) -> QdrantClient: + """Creates a Qdrant client from configuration.""" + ... diff --git a/src/crewai/rag/config/optional_imports/providers.py b/src/crewai/rag/config/optional_imports/providers.py index 0d774e26b..ff4065d43 100644 --- a/src/crewai/rag/config/optional_imports/providers.py +++ b/src/crewai/rag/config/optional_imports/providers.py @@ -13,3 +13,10 @@ class MissingChromaDBConfig(_MissingProvider): """Placeholder for missing ChromaDB configuration.""" provider: Literal["chromadb"] = field(default="chromadb") + + +@pyd_dataclass(config=ConfigDict(extra="forbid")) +class MissingQdrantConfig(_MissingProvider): + """Placeholder for missing Qdrant configuration.""" + + provider: Literal["qdrant"] = field(default="qdrant") diff --git a/src/crewai/rag/config/optional_imports/types.py b/src/crewai/rag/config/optional_imports/types.py index dbd169cab..184348b1b 100644 --- a/src/crewai/rag/config/optional_imports/types.py +++ b/src/crewai/rag/config/optional_imports/types.py @@ -3,6 +3,6 @@ from typing import Annotated, Literal SupportedProvider = Annotated[ - Literal["chromadb"], + Literal["chromadb", "qdrant"], "Supported RAG provider types, add providers here as they become available", ] diff --git a/src/crewai/rag/config/types.py b/src/crewai/rag/config/types.py index ddd9f5268..d6431e98f 100644 --- a/src/crewai/rag/config/types.py +++ b/src/crewai/rag/config/types.py @@ -1,12 +1,18 @@ """Type definitions for RAG configuration.""" -from typing import TYPE_CHECKING, Annotated, TypeAlias +from typing import Annotated, TypeAlias, TYPE_CHECKING from pydantic import Field from crewai.rag.config.constants import DISCRIMINATOR +# Linter freaks out on conditional imports, assigning in the type checking fixes it if TYPE_CHECKING: - from crewai.rag.chromadb.config import ChromaDBConfig + from crewai.rag.chromadb.config import ChromaDBConfig as ChromaDBConfig_ + + ChromaDBConfig = ChromaDBConfig_ + from crewai.rag.qdrant.config import QdrantConfig as QdrantConfig_ + + QdrantConfig = QdrantConfig_ else: try: from crewai.rag.chromadb.config import ChromaDBConfig @@ -15,7 +21,14 @@ else: MissingChromaDBConfig as ChromaDBConfig, ) -SupportedProviderConfig: TypeAlias = ChromaDBConfig + try: + from crewai.rag.qdrant.config import QdrantConfig + except ImportError: + from crewai.rag.config.optional_imports.providers import ( + MissingQdrantConfig as QdrantConfig, + ) + +SupportedProviderConfig: TypeAlias = ChromaDBConfig | QdrantConfig RagConfigType: TypeAlias = Annotated[ SupportedProviderConfig, Field(discriminator=DISCRIMINATOR) ] diff --git a/src/crewai/rag/core/base_client.py b/src/crewai/rag/core/base_client.py index 6fa4346e1..d7fb48a50 100644 --- a/src/crewai/rag/core/base_client.py +++ b/src/crewai/rag/core/base_client.py @@ -1,8 +1,8 @@ """Protocol for vector database client implementations.""" from abc import abstractmethod -from typing import Any, Protocol, runtime_checkable, TypedDict, Annotated -from typing_extensions import Unpack, Required +from typing import Any, Protocol, runtime_checkable, Annotated +from typing_extensions import Unpack, Required, TypedDict from pydantic import GetCoreSchemaHandler from pydantic_core import CoreSchema, core_schema diff --git a/src/crewai/rag/qdrant/config.py b/src/crewai/rag/qdrant/config.py new file mode 100644 index 000000000..7ae04c7b7 --- /dev/null +++ b/src/crewai/rag/qdrant/config.py @@ -0,0 +1,54 @@ +"""Qdrant configuration model.""" + +from dataclasses import field +from typing import Literal, cast +from pydantic.dataclasses import dataclass as pyd_dataclass + +from crewai.rag.config.base import BaseRagConfig +from crewai.rag.qdrant.types import QdrantClientParams, QdrantEmbeddingFunctionWrapper +from crewai.rag.qdrant.constants import DEFAULT_EMBEDDING_MODEL, DEFAULT_STORAGE_PATH + + +def _default_options() -> QdrantClientParams: + """Create default Qdrant client options. + + Returns: + Default options with file-based storage. + """ + return QdrantClientParams(path=DEFAULT_STORAGE_PATH) + + +def _default_embedding_function() -> QdrantEmbeddingFunctionWrapper: + """Create default Qdrant embedding function. + + Returns: + Default embedding function using fastembed with all-MiniLM-L6-v2. + """ + from fastembed import TextEmbedding + + model = TextEmbedding(model_name=DEFAULT_EMBEDDING_MODEL) + + def embed_fn(text: str) -> list[float]: + """Embed a single text string. + + Args: + text: Text to embed. + + Returns: + Embedding vector as list of floats. + """ + embeddings = list(model.embed([text])) + return embeddings[0].tolist() if embeddings else [] + + return cast(QdrantEmbeddingFunctionWrapper, embed_fn) + + +@pyd_dataclass(frozen=True) +class QdrantConfig(BaseRagConfig): + """Configuration for Qdrant client.""" + + provider: Literal["qdrant"] = field(default="qdrant", init=False) + options: QdrantClientParams = field(default_factory=_default_options) + embedding_function: QdrantEmbeddingFunctionWrapper = field( + default_factory=_default_embedding_function + ) diff --git a/src/crewai/rag/qdrant/constants.py b/src/crewai/rag/qdrant/constants.py index 027e5f1f2..9714c9de6 100644 --- a/src/crewai/rag/qdrant/constants.py +++ b/src/crewai/rag/qdrant/constants.py @@ -1,7 +1,12 @@ """Constants for Qdrant implementation.""" +import os from typing import Final from qdrant_client.models import Distance, VectorParams +from crewai.utilities.paths import db_storage_path + DEFAULT_VECTOR_PARAMS: Final = VectorParams(size=384, distance=Distance.COSINE) +DEFAULT_EMBEDDING_MODEL: Final[str] = "sentence-transformers/all-MiniLM-L6-v2" +DEFAULT_STORAGE_PATH: Final[str] = os.path.join(db_storage_path(), "qdrant") diff --git a/src/crewai/rag/qdrant/factory.py b/src/crewai/rag/qdrant/factory.py new file mode 100644 index 000000000..75529a2a1 --- /dev/null +++ b/src/crewai/rag/qdrant/factory.py @@ -0,0 +1,21 @@ +"""Factory functions for creating Qdrant clients from configuration.""" + +from qdrant_client import QdrantClient as SyncQdrantClientBase +from crewai.rag.qdrant.client import QdrantClient +from crewai.rag.qdrant.config import QdrantConfig + + +def create_client(config: QdrantConfig) -> QdrantClient: + """Create a Qdrant client from configuration. + + Args: + config: The Qdrant configuration. + + Returns: + A configured QdrantClient instance. + """ + + qdrant_client = SyncQdrantClientBase(**config.options) + return QdrantClient( + client=qdrant_client, embedding_function=config.embedding_function + ) diff --git a/src/crewai/rag/qdrant/types.py b/src/crewai/rag/qdrant/types.py index 306f66c91..706a4f155 100644 --- a/src/crewai/rag/qdrant/types.py +++ b/src/crewai/rag/qdrant/types.py @@ -1,10 +1,12 @@ """Type definitions specific to Qdrant implementation.""" from collections.abc import Awaitable, Callable -from typing import Annotated, Any, Protocol, TypeAlias, TypedDict -from typing_extensions import NotRequired +from typing import Annotated, Any, Protocol, TypeAlias +from typing_extensions import NotRequired, TypedDict import numpy as np +from pydantic import GetCoreSchemaHandler +from pydantic_core import CoreSchema, core_schema from qdrant_client import AsyncQdrantClient, QdrantClient as SyncQdrantClient from qdrant_client.models import ( FieldCondition, @@ -53,6 +55,21 @@ class EmbeddingFunction(Protocol): ... +class QdrantEmbeddingFunctionWrapper(EmbeddingFunction): + """Base class for Qdrant EmbeddingFunction to work with Pydantic validation.""" + + @classmethod + def __get_pydantic_core_schema__( + cls, _source_type: Any, _handler: GetCoreSchemaHandler + ) -> CoreSchema: + """Generate Pydantic core schema for Qdrant EmbeddingFunction. + + This allows Pydantic to handle Qdrant's EmbeddingFunction type + without requiring arbitrary_types_allowed=True. + """ + return core_schema.any_schema() + + class AsyncEmbeddingFunction(Protocol): """Protocol for async embedding functions that convert text to vectors.""" diff --git a/src/crewai/rag/types.py b/src/crewai/rag/types.py index 0f44422a8..a1caf164a 100644 --- a/src/crewai/rag/types.py +++ b/src/crewai/rag/types.py @@ -1,9 +1,9 @@ """Type definitions for RAG (Retrieval-Augmented Generation) systems.""" from collections.abc import Callable, Mapping -from typing import TypeAlias, TypedDict, Any +from typing import TypeAlias, Any -from typing_extensions import Required +from typing_extensions import Required, TypedDict class BaseRecord(TypedDict, total=False):