mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-11 17:18:29 +00:00
Qdrant RAG Provider Support (#3400)
* Added Qdrant provider support with factory, config, and protocols
* Improved default embeddings and type definitions
* Fixed ChromaDB factory embedding assignment
This commit is contained in:
@@ -1,6 +1,5 @@
|
||||
"""ChromaDB configuration model."""
|
||||
|
||||
import os
|
||||
import warnings
|
||||
from dataclasses import field
|
||||
from typing import Literal, cast
|
||||
@@ -9,9 +8,12 @@ from chromadb.config import Settings
|
||||
from chromadb.utils.embedding_functions import DefaultEmbeddingFunction
|
||||
|
||||
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
|
||||
from crewai.utilities.paths import db_storage_path
|
||||
from crewai.rag.config.base import BaseRagConfig
|
||||
from crewai.rag.chromadb.constants import DEFAULT_TENANT, DEFAULT_DATABASE
|
||||
from crewai.rag.chromadb.constants import (
|
||||
DEFAULT_TENANT,
|
||||
DEFAULT_DATABASE,
|
||||
DEFAULT_STORAGE_PATH,
|
||||
)
|
||||
|
||||
|
||||
warnings.filterwarnings(
|
||||
@@ -29,7 +31,7 @@ def _default_settings() -> Settings:
|
||||
Settings with persistent storage and reset enabled.
|
||||
"""
|
||||
return Settings(
|
||||
persist_directory=os.path.join(db_storage_path(), "chromadb"),
|
||||
persist_directory=DEFAULT_STORAGE_PATH,
|
||||
allow_reset=True,
|
||||
is_persistent=True,
|
||||
)
|
||||
@@ -39,7 +41,7 @@ def _default_embedding_function() -> ChromaEmbeddingFunctionWrapper:
|
||||
"""Create default ChromaDB embedding function.
|
||||
|
||||
Returns:
|
||||
Default embedding function cast to proper type.
|
||||
Default embedding function using all-MiniLM-L6-v2 via ONNX.
|
||||
"""
|
||||
return cast(ChromaEmbeddingFunctionWrapper, DefaultEmbeddingFunction())
|
||||
|
||||
@@ -52,6 +54,6 @@ class ChromaDBConfig(BaseRagConfig):
|
||||
tenant: str = DEFAULT_TENANT
|
||||
database: str = DEFAULT_DATABASE
|
||||
settings: Settings = field(default_factory=_default_settings)
|
||||
embedding_function: ChromaEmbeddingFunctionWrapper | None = field(
|
||||
embedding_function: ChromaEmbeddingFunctionWrapper = field(
|
||||
default_factory=_default_embedding_function
|
||||
)
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
"""Constants for ChromaDB configuration."""
|
||||
|
||||
import os
|
||||
from typing import Final
|
||||
|
||||
from crewai.utilities.paths import db_storage_path
|
||||
|
||||
# Tenant and database names used when a ChromaDB config does not override them.
DEFAULT_TENANT: Final[str] = "default_tenant"
DEFAULT_DATABASE: Final[str] = "default_database"

# Persistence directory for ChromaDB: a "chromadb" sub-directory of the
# storage root returned by db_storage_path(), resolved once at import time.
DEFAULT_STORAGE_PATH: Final[str] = os.path.join(db_storage_path(), "chromadb")
|
||||
|
||||
@@ -2,7 +2,10 @@
|
||||
|
||||
from typing import cast
|
||||
|
||||
from crewai.rag.config.optional_imports.protocols import ChromaFactoryModule
|
||||
from crewai.rag.config.optional_imports.protocols import (
|
||||
ChromaFactoryModule,
|
||||
QdrantFactoryModule,
|
||||
)
|
||||
from crewai.rag.core.base_client import BaseClient
|
||||
from crewai.rag.config.types import RagConfigType
|
||||
from crewai.utilities.import_utils import require
|
||||
@@ -22,11 +25,21 @@ def create_client(config: RagConfigType) -> BaseClient:
|
||||
"""
|
||||
|
||||
if config.provider == "chromadb":
|
||||
mod = cast(
|
||||
chromadb_mod = cast(
|
||||
ChromaFactoryModule,
|
||||
require(
|
||||
"crewai.rag.chromadb.factory",
|
||||
purpose="The 'chromadb' provider",
|
||||
),
|
||||
)
|
||||
return mod.create_client(config)
|
||||
return chromadb_mod.create_client(config)
|
||||
|
||||
if config.provider == "qdrant":
|
||||
qdrant_mod = cast(
|
||||
QdrantFactoryModule,
|
||||
require(
|
||||
"crewai.rag.qdrant.factory",
|
||||
purpose="The 'qdrant' provider",
|
||||
),
|
||||
)
|
||||
return qdrant_mod.create_client(config)
|
||||
|
||||
@@ -14,7 +14,9 @@ class _MissingProvider:
|
||||
Raises RuntimeError when instantiated to indicate missing dependencies.
|
||||
"""
|
||||
|
||||
provider: Literal["chromadb", "__missing__"] = field(default="__missing__")
|
||||
provider: Literal["chromadb", "qdrant", "__missing__"] = field(
|
||||
default="__missing__"
|
||||
)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Raises error indicating the provider is not installed."""
|
||||
|
||||
@@ -1,14 +1,27 @@
|
||||
"""Protocol definitions for RAG factory modules."""
|
||||
|
||||
from typing import Protocol
|
||||
from __future__ import annotations
|
||||
|
||||
from crewai.rag.config.types import RagConfigType
|
||||
from crewai.rag.core.base_client import BaseClient
|
||||
from typing import Protocol, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.rag.chromadb.client import ChromaDBClient
|
||||
from crewai.rag.chromadb.config import ChromaDBConfig
|
||||
from crewai.rag.qdrant.client import QdrantClient
|
||||
from crewai.rag.qdrant.config import QdrantConfig
|
||||
|
||||
|
||||
class ChromaFactoryModule(Protocol):
    """Protocol for ChromaDB factory module.

    Describes the shape of a module that exposes a ``create_client``
    callable for the ChromaDB provider.
    """

    def create_client(self, config: ChromaDBConfig) -> ChromaDBClient:
        """Creates a ChromaDB client from configuration.

        Args:
            config: ChromaDB configuration to build the client from.

        Returns:
            A ChromaDB client.
        """
        ...
|
||||
|
||||
|
||||
class QdrantFactoryModule(Protocol):
    """Protocol for Qdrant factory module.

    Describes the shape of a module that exposes a ``create_client``
    callable for the Qdrant provider.
    """

    def create_client(self, config: QdrantConfig) -> QdrantClient:
        """Creates a Qdrant client from configuration.

        Args:
            config: Qdrant configuration to build the client from.

        Returns:
            A Qdrant client.
        """
        ...
|
||||
|
||||
@@ -13,3 +13,10 @@ class MissingChromaDBConfig(_MissingProvider):
|
||||
"""Placeholder for missing ChromaDB configuration."""
|
||||
|
||||
provider: Literal["chromadb"] = field(default="chromadb")
|
||||
|
||||
|
||||
@pyd_dataclass(config=ConfigDict(extra="forbid"))
class MissingQdrantConfig(_MissingProvider):
    """Placeholder for missing Qdrant configuration.

    Stands in for the real Qdrant config when it cannot be imported; the
    ``_MissingProvider`` base is documented to raise ``RuntimeError`` on
    instantiation to signal the missing dependency.
    """

    # Provider tag is pinned to "qdrant" for this placeholder.
    provider: Literal["qdrant"] = field(default="qdrant")
|
||||
|
||||
@@ -3,6 +3,6 @@
|
||||
from typing import Annotated, Literal
|
||||
|
||||
# Closed set of provider identifiers accepted by the RAG configuration.
# The Literal must be the single first argument of Annotated: anything after
# it is treated as metadata, so a second Literal there would not widen the
# accepted values.
SupportedProvider = Annotated[
    Literal["chromadb", "qdrant"],
    "Supported RAG provider types, add providers here as they become available",
]
|
||||
|
||||
@@ -1,12 +1,18 @@
|
||||
"""Type definitions for RAG configuration."""
|
||||
|
||||
from typing import TYPE_CHECKING, Annotated, TypeAlias
|
||||
from typing import Annotated, TypeAlias, TYPE_CHECKING
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.config.constants import DISCRIMINATOR
|
||||
|
||||
# Linter freaks out on conditional imports, assigning in the type checking fixes it
|
||||
if TYPE_CHECKING:
|
||||
from crewai.rag.chromadb.config import ChromaDBConfig
|
||||
from crewai.rag.chromadb.config import ChromaDBConfig as ChromaDBConfig_
|
||||
|
||||
ChromaDBConfig = ChromaDBConfig_
|
||||
from crewai.rag.qdrant.config import QdrantConfig as QdrantConfig_
|
||||
|
||||
QdrantConfig = QdrantConfig_
|
||||
else:
|
||||
try:
|
||||
from crewai.rag.chromadb.config import ChromaDBConfig
|
||||
@@ -15,7 +21,14 @@ else:
|
||||
MissingChromaDBConfig as ChromaDBConfig,
|
||||
)
|
||||
|
||||
SupportedProviderConfig: TypeAlias = ChromaDBConfig
|
||||
try:
|
||||
from crewai.rag.qdrant.config import QdrantConfig
|
||||
except ImportError:
|
||||
from crewai.rag.config.optional_imports.providers import (
|
||||
MissingQdrantConfig as QdrantConfig,
|
||||
)
|
||||
|
||||
SupportedProviderConfig: TypeAlias = ChromaDBConfig | QdrantConfig
|
||||
RagConfigType: TypeAlias = Annotated[
|
||||
SupportedProviderConfig, Field(discriminator=DISCRIMINATOR)
|
||||
]
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
"""Protocol for vector database client implementations."""
|
||||
|
||||
from abc import abstractmethod
|
||||
from typing import Any, Protocol, runtime_checkable, TypedDict, Annotated
|
||||
from typing_extensions import Unpack, Required
|
||||
from typing import Any, Protocol, runtime_checkable, Annotated
|
||||
from typing_extensions import Unpack, Required, TypedDict
|
||||
from pydantic import GetCoreSchemaHandler
|
||||
from pydantic_core import CoreSchema, core_schema
|
||||
|
||||
|
||||
54
src/crewai/rag/qdrant/config.py
Normal file
54
src/crewai/rag/qdrant/config.py
Normal file
@@ -0,0 +1,54 @@
|
||||
"""Qdrant configuration model."""
|
||||
|
||||
from dataclasses import field
|
||||
from typing import Literal, cast
|
||||
from pydantic.dataclasses import dataclass as pyd_dataclass
|
||||
|
||||
from crewai.rag.config.base import BaseRagConfig
|
||||
from crewai.rag.qdrant.types import QdrantClientParams, QdrantEmbeddingFunctionWrapper
|
||||
from crewai.rag.qdrant.constants import DEFAULT_EMBEDDING_MODEL, DEFAULT_STORAGE_PATH
|
||||
|
||||
|
||||
def _default_options() -> QdrantClientParams:
    """Create default Qdrant client options.

    Returns:
        Default options with file-based storage, pointing ``path`` at
        ``DEFAULT_STORAGE_PATH``.
    """
    return QdrantClientParams(path=DEFAULT_STORAGE_PATH)
|
||||
|
||||
|
||||
def _default_embedding_function() -> QdrantEmbeddingFunctionWrapper:
    """Create default Qdrant embedding function.

    Returns:
        Default embedding function using fastembed with all-MiniLM-L6-v2.
    """
    # Imported inside the function so fastembed is only needed when the
    # default embedding function is actually built.
    from fastembed import TextEmbedding

    # The model is created once here and captured by the closure below.
    model = TextEmbedding(model_name=DEFAULT_EMBEDDING_MODEL)

    def embed_fn(text: str) -> list[float]:
        """Embed a single text string.

        Args:
            text: Text to embed.

        Returns:
            Embedding vector as list of floats; an empty list when the
            model yields no embedding.
        """
        # embed() is given a one-element batch; materialize its results so
        # the first (only) vector can be picked out.
        embeddings = list(model.embed([text]))
        return embeddings[0].tolist() if embeddings else []

    # The plain function is cast to the wrapper type used by QdrantConfig.
    return cast(QdrantEmbeddingFunctionWrapper, embed_fn)
|
||||
|
||||
|
||||
@pyd_dataclass(frozen=True)
class QdrantConfig(BaseRagConfig):
    """Configuration for Qdrant client.

    Frozen dataclass: fields cannot be reassigned after construction.
    """

    # Provider tag fixed to "qdrant"; init=False keeps it out of __init__.
    provider: Literal["qdrant"] = field(default="qdrant", init=False)
    # Parameters for the underlying Qdrant client; defaults to file-based
    # storage via _default_options().
    options: QdrantClientParams = field(default_factory=_default_options)
    # Embedding function; defaults to the one built by
    # _default_embedding_function().
    embedding_function: QdrantEmbeddingFunctionWrapper = field(
        default_factory=_default_embedding_function
    )
|
||||
@@ -1,7 +1,12 @@
|
||||
"""Constants for Qdrant implementation."""
|
||||
|
||||
import os
|
||||
from typing import Final
|
||||
|
||||
from qdrant_client.models import Distance, VectorParams
|
||||
|
||||
from crewai.utilities.paths import db_storage_path
|
||||
|
||||
# Default vector settings: 384-dimensional vectors compared by cosine distance.
# NOTE(review): 384 presumably matches DEFAULT_EMBEDDING_MODEL's output size;
# confirm before changing either value.
DEFAULT_VECTOR_PARAMS: Final = VectorParams(size=384, distance=Distance.COSINE)

# Model name handed to the default embedding function.
DEFAULT_EMBEDDING_MODEL: Final[str] = "sentence-transformers/all-MiniLM-L6-v2"

# Storage directory for Qdrant: a "qdrant" sub-directory of the storage root
# returned by db_storage_path(), resolved once at import time.
DEFAULT_STORAGE_PATH: Final[str] = os.path.join(db_storage_path(), "qdrant")
|
||||
|
||||
21
src/crewai/rag/qdrant/factory.py
Normal file
21
src/crewai/rag/qdrant/factory.py
Normal file
@@ -0,0 +1,21 @@
|
||||
"""Factory functions for creating Qdrant clients from configuration."""
|
||||
|
||||
from qdrant_client import QdrantClient as SyncQdrantClientBase
|
||||
from crewai.rag.qdrant.client import QdrantClient
|
||||
from crewai.rag.qdrant.config import QdrantConfig
|
||||
|
||||
|
||||
def create_client(config: QdrantConfig) -> QdrantClient:
    """Create a Qdrant client from configuration.

    Args:
        config: The Qdrant configuration.

    Returns:
        A configured QdrantClient instance.
    """
    # config.options is expanded into keyword arguments for the underlying
    # qdrant_client client.
    qdrant_client = SyncQdrantClientBase(**config.options)

    # Wrap the raw client together with the configured embedding function.
    return QdrantClient(
        client=qdrant_client, embedding_function=config.embedding_function
    )
|
||||
@@ -1,10 +1,12 @@
|
||||
"""Type definitions specific to Qdrant implementation."""
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Annotated, Any, Protocol, TypeAlias, TypedDict
|
||||
from typing_extensions import NotRequired
|
||||
from typing import Annotated, Any, Protocol, TypeAlias
|
||||
from typing_extensions import NotRequired, TypedDict
|
||||
|
||||
import numpy as np
|
||||
from pydantic import GetCoreSchemaHandler
|
||||
from pydantic_core import CoreSchema, core_schema
|
||||
from qdrant_client import AsyncQdrantClient, QdrantClient as SyncQdrantClient
|
||||
from qdrant_client.models import (
|
||||
FieldCondition,
|
||||
@@ -53,6 +55,21 @@ class EmbeddingFunction(Protocol):
|
||||
...
|
||||
|
||||
|
||||
class QdrantEmbeddingFunctionWrapper(EmbeddingFunction):
    """Base class for Qdrant EmbeddingFunction to work with Pydantic validation."""

    @classmethod
    def __get_pydantic_core_schema__(
        cls, _source_type: Any, _handler: GetCoreSchemaHandler
    ) -> CoreSchema:
        """Generate Pydantic core schema for Qdrant EmbeddingFunction.

        This allows Pydantic to handle Qdrant's EmbeddingFunction type
        without requiring arbitrary_types_allowed=True.

        Args:
            _source_type: Type being validated (unused).
            _handler: Pydantic schema handler (unused).

        Returns:
            An "any" schema, so values of this type are accepted without
            further validation.
        """
        return core_schema.any_schema()
|
||||
|
||||
|
||||
class AsyncEmbeddingFunction(Protocol):
|
||||
"""Protocol for async embedding functions that convert text to vectors."""
|
||||
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
"""Type definitions for RAG (Retrieval-Augmented Generation) systems."""
|
||||
|
||||
from collections.abc import Callable, Mapping
|
||||
from typing import TypeAlias, TypedDict, Any
|
||||
from typing import TypeAlias, Any
|
||||
|
||||
from typing_extensions import Required
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class BaseRecord(TypedDict, total=False):
|
||||
|
||||
Reference in New Issue
Block a user