mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-27 17:18:13 +00:00
Qdrant RAG Provider Support (#3400)
* Added Qdrant provider support with factory, config, and protocols * Improved default embeddings and type definitions * Fixed ChromaDB factory embedding assignment
This commit is contained in:
@@ -1,6 +1,5 @@
|
|||||||
"""ChromaDB configuration model."""
|
"""ChromaDB configuration model."""
|
||||||
|
|
||||||
import os
|
|
||||||
import warnings
|
import warnings
|
||||||
from dataclasses import field
|
from dataclasses import field
|
||||||
from typing import Literal, cast
|
from typing import Literal, cast
|
||||||
@@ -9,9 +8,12 @@ from chromadb.config import Settings
|
|||||||
from chromadb.utils.embedding_functions import DefaultEmbeddingFunction
|
from chromadb.utils.embedding_functions import DefaultEmbeddingFunction
|
||||||
|
|
||||||
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
|
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
|
||||||
from crewai.utilities.paths import db_storage_path
|
|
||||||
from crewai.rag.config.base import BaseRagConfig
|
from crewai.rag.config.base import BaseRagConfig
|
||||||
from crewai.rag.chromadb.constants import DEFAULT_TENANT, DEFAULT_DATABASE
|
from crewai.rag.chromadb.constants import (
|
||||||
|
DEFAULT_TENANT,
|
||||||
|
DEFAULT_DATABASE,
|
||||||
|
DEFAULT_STORAGE_PATH,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
warnings.filterwarnings(
|
warnings.filterwarnings(
|
||||||
@@ -29,7 +31,7 @@ def _default_settings() -> Settings:
|
|||||||
Settings with persistent storage and reset enabled.
|
Settings with persistent storage and reset enabled.
|
||||||
"""
|
"""
|
||||||
return Settings(
|
return Settings(
|
||||||
persist_directory=os.path.join(db_storage_path(), "chromadb"),
|
persist_directory=DEFAULT_STORAGE_PATH,
|
||||||
allow_reset=True,
|
allow_reset=True,
|
||||||
is_persistent=True,
|
is_persistent=True,
|
||||||
)
|
)
|
||||||
@@ -39,7 +41,7 @@ def _default_embedding_function() -> ChromaEmbeddingFunctionWrapper:
|
|||||||
"""Create default ChromaDB embedding function.
|
"""Create default ChromaDB embedding function.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Default embedding function cast to proper type.
|
Default embedding function using all-MiniLM-L6-v2 via ONNX.
|
||||||
"""
|
"""
|
||||||
return cast(ChromaEmbeddingFunctionWrapper, DefaultEmbeddingFunction())
|
return cast(ChromaEmbeddingFunctionWrapper, DefaultEmbeddingFunction())
|
||||||
|
|
||||||
@@ -52,6 +54,6 @@ class ChromaDBConfig(BaseRagConfig):
|
|||||||
tenant: str = DEFAULT_TENANT
|
tenant: str = DEFAULT_TENANT
|
||||||
database: str = DEFAULT_DATABASE
|
database: str = DEFAULT_DATABASE
|
||||||
settings: Settings = field(default_factory=_default_settings)
|
settings: Settings = field(default_factory=_default_settings)
|
||||||
embedding_function: ChromaEmbeddingFunctionWrapper | None = field(
|
embedding_function: ChromaEmbeddingFunctionWrapper = field(
|
||||||
default_factory=_default_embedding_function
|
default_factory=_default_embedding_function
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,6 +1,10 @@
|
|||||||
"""Constants for ChromaDB configuration."""
|
"""Constants for ChromaDB configuration."""
|
||||||
|
|
||||||
|
import os
|
||||||
from typing import Final
|
from typing import Final
|
||||||
|
|
||||||
|
from crewai.utilities.paths import db_storage_path
|
||||||
|
|
||||||
DEFAULT_TENANT: Final[str] = "default_tenant"
|
DEFAULT_TENANT: Final[str] = "default_tenant"
|
||||||
DEFAULT_DATABASE: Final[str] = "default_database"
|
DEFAULT_DATABASE: Final[str] = "default_database"
|
||||||
|
DEFAULT_STORAGE_PATH: Final[str] = os.path.join(db_storage_path(), "chromadb")
|
||||||
|
|||||||
@@ -2,7 +2,10 @@
|
|||||||
|
|
||||||
from typing import cast
|
from typing import cast
|
||||||
|
|
||||||
from crewai.rag.config.optional_imports.protocols import ChromaFactoryModule
|
from crewai.rag.config.optional_imports.protocols import (
|
||||||
|
ChromaFactoryModule,
|
||||||
|
QdrantFactoryModule,
|
||||||
|
)
|
||||||
from crewai.rag.core.base_client import BaseClient
|
from crewai.rag.core.base_client import BaseClient
|
||||||
from crewai.rag.config.types import RagConfigType
|
from crewai.rag.config.types import RagConfigType
|
||||||
from crewai.utilities.import_utils import require
|
from crewai.utilities.import_utils import require
|
||||||
@@ -22,11 +25,21 @@ def create_client(config: RagConfigType) -> BaseClient:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
if config.provider == "chromadb":
|
if config.provider == "chromadb":
|
||||||
mod = cast(
|
chromadb_mod = cast(
|
||||||
ChromaFactoryModule,
|
ChromaFactoryModule,
|
||||||
require(
|
require(
|
||||||
"crewai.rag.chromadb.factory",
|
"crewai.rag.chromadb.factory",
|
||||||
purpose="The 'chromadb' provider",
|
purpose="The 'chromadb' provider",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
return mod.create_client(config)
|
return chromadb_mod.create_client(config)
|
||||||
|
|
||||||
|
if config.provider == "qdrant":
|
||||||
|
qdrant_mod = cast(
|
||||||
|
QdrantFactoryModule,
|
||||||
|
require(
|
||||||
|
"crewai.rag.qdrant.factory",
|
||||||
|
purpose="The 'qdrant' provider",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return qdrant_mod.create_client(config)
|
||||||
|
|||||||
@@ -14,7 +14,9 @@ class _MissingProvider:
|
|||||||
Raises RuntimeError when instantiated to indicate missing dependencies.
|
Raises RuntimeError when instantiated to indicate missing dependencies.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
provider: Literal["chromadb", "__missing__"] = field(default="__missing__")
|
provider: Literal["chromadb", "qdrant", "__missing__"] = field(
|
||||||
|
default="__missing__"
|
||||||
|
)
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
"""Raises error indicating the provider is not installed."""
|
"""Raises error indicating the provider is not installed."""
|
||||||
|
|||||||
@@ -1,14 +1,27 @@
|
|||||||
"""Protocol definitions for RAG factory modules."""
|
"""Protocol definitions for RAG factory modules."""
|
||||||
|
|
||||||
from typing import Protocol
|
from __future__ import annotations
|
||||||
|
|
||||||
from crewai.rag.config.types import RagConfigType
|
from typing import Protocol, TYPE_CHECKING
|
||||||
from crewai.rag.core.base_client import BaseClient
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from crewai.rag.chromadb.client import ChromaDBClient
|
||||||
|
from crewai.rag.chromadb.config import ChromaDBConfig
|
||||||
|
from crewai.rag.qdrant.client import QdrantClient
|
||||||
|
from crewai.rag.qdrant.config import QdrantConfig
|
||||||
|
|
||||||
|
|
||||||
class ChromaFactoryModule(Protocol):
|
class ChromaFactoryModule(Protocol):
|
||||||
"""Protocol for ChromaDB factory module."""
|
"""Protocol for ChromaDB factory module."""
|
||||||
|
|
||||||
def create_client(self, config: RagConfigType) -> BaseClient:
|
def create_client(self, config: ChromaDBConfig) -> ChromaDBClient:
|
||||||
"""Creates a ChromaDB client from configuration."""
|
"""Creates a ChromaDB client from configuration."""
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class QdrantFactoryModule(Protocol):
|
||||||
|
"""Protocol for Qdrant factory module."""
|
||||||
|
|
||||||
|
def create_client(self, config: QdrantConfig) -> QdrantClient:
|
||||||
|
"""Creates a Qdrant client from configuration."""
|
||||||
|
...
|
||||||
|
|||||||
@@ -13,3 +13,10 @@ class MissingChromaDBConfig(_MissingProvider):
|
|||||||
"""Placeholder for missing ChromaDB configuration."""
|
"""Placeholder for missing ChromaDB configuration."""
|
||||||
|
|
||||||
provider: Literal["chromadb"] = field(default="chromadb")
|
provider: Literal["chromadb"] = field(default="chromadb")
|
||||||
|
|
||||||
|
|
||||||
|
@pyd_dataclass(config=ConfigDict(extra="forbid"))
|
||||||
|
class MissingQdrantConfig(_MissingProvider):
|
||||||
|
"""Placeholder for missing Qdrant configuration."""
|
||||||
|
|
||||||
|
provider: Literal["qdrant"] = field(default="qdrant")
|
||||||
|
|||||||
@@ -3,6 +3,6 @@
|
|||||||
from typing import Annotated, Literal
|
from typing import Annotated, Literal
|
||||||
|
|
||||||
SupportedProvider = Annotated[
|
SupportedProvider = Annotated[
|
||||||
Literal["chromadb"],
|
Literal["chromadb", "qdrant"],
|
||||||
"Supported RAG provider types, add providers here as they become available",
|
"Supported RAG provider types, add providers here as they become available",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,12 +1,18 @@
|
|||||||
"""Type definitions for RAG configuration."""
|
"""Type definitions for RAG configuration."""
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, Annotated, TypeAlias
|
from typing import Annotated, TypeAlias, TYPE_CHECKING
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
from crewai.rag.config.constants import DISCRIMINATOR
|
from crewai.rag.config.constants import DISCRIMINATOR
|
||||||
|
|
||||||
|
# Linter freaks out on conditional imports, assigning in the type checking fixes it
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from crewai.rag.chromadb.config import ChromaDBConfig
|
from crewai.rag.chromadb.config import ChromaDBConfig as ChromaDBConfig_
|
||||||
|
|
||||||
|
ChromaDBConfig = ChromaDBConfig_
|
||||||
|
from crewai.rag.qdrant.config import QdrantConfig as QdrantConfig_
|
||||||
|
|
||||||
|
QdrantConfig = QdrantConfig_
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
from crewai.rag.chromadb.config import ChromaDBConfig
|
from crewai.rag.chromadb.config import ChromaDBConfig
|
||||||
@@ -15,7 +21,14 @@ else:
|
|||||||
MissingChromaDBConfig as ChromaDBConfig,
|
MissingChromaDBConfig as ChromaDBConfig,
|
||||||
)
|
)
|
||||||
|
|
||||||
SupportedProviderConfig: TypeAlias = ChromaDBConfig
|
try:
|
||||||
|
from crewai.rag.qdrant.config import QdrantConfig
|
||||||
|
except ImportError:
|
||||||
|
from crewai.rag.config.optional_imports.providers import (
|
||||||
|
MissingQdrantConfig as QdrantConfig,
|
||||||
|
)
|
||||||
|
|
||||||
|
SupportedProviderConfig: TypeAlias = ChromaDBConfig | QdrantConfig
|
||||||
RagConfigType: TypeAlias = Annotated[
|
RagConfigType: TypeAlias = Annotated[
|
||||||
SupportedProviderConfig, Field(discriminator=DISCRIMINATOR)
|
SupportedProviderConfig, Field(discriminator=DISCRIMINATOR)
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
"""Protocol for vector database client implementations."""
|
"""Protocol for vector database client implementations."""
|
||||||
|
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from typing import Any, Protocol, runtime_checkable, TypedDict, Annotated
|
from typing import Any, Protocol, runtime_checkable, Annotated
|
||||||
from typing_extensions import Unpack, Required
|
from typing_extensions import Unpack, Required, TypedDict
|
||||||
from pydantic import GetCoreSchemaHandler
|
from pydantic import GetCoreSchemaHandler
|
||||||
from pydantic_core import CoreSchema, core_schema
|
from pydantic_core import CoreSchema, core_schema
|
||||||
|
|
||||||
|
|||||||
54
src/crewai/rag/qdrant/config.py
Normal file
54
src/crewai/rag/qdrant/config.py
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
"""Qdrant configuration model."""
|
||||||
|
|
||||||
|
from dataclasses import field
|
||||||
|
from typing import Literal, cast
|
||||||
|
from pydantic.dataclasses import dataclass as pyd_dataclass
|
||||||
|
|
||||||
|
from crewai.rag.config.base import BaseRagConfig
|
||||||
|
from crewai.rag.qdrant.types import QdrantClientParams, QdrantEmbeddingFunctionWrapper
|
||||||
|
from crewai.rag.qdrant.constants import DEFAULT_EMBEDDING_MODEL, DEFAULT_STORAGE_PATH
|
||||||
|
|
||||||
|
|
||||||
|
def _default_options() -> QdrantClientParams:
|
||||||
|
"""Create default Qdrant client options.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Default options with file-based storage.
|
||||||
|
"""
|
||||||
|
return QdrantClientParams(path=DEFAULT_STORAGE_PATH)
|
||||||
|
|
||||||
|
|
||||||
|
def _default_embedding_function() -> QdrantEmbeddingFunctionWrapper:
|
||||||
|
"""Create default Qdrant embedding function.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Default embedding function using fastembed with all-MiniLM-L6-v2.
|
||||||
|
"""
|
||||||
|
from fastembed import TextEmbedding
|
||||||
|
|
||||||
|
model = TextEmbedding(model_name=DEFAULT_EMBEDDING_MODEL)
|
||||||
|
|
||||||
|
def embed_fn(text: str) -> list[float]:
|
||||||
|
"""Embed a single text string.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Text to embed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Embedding vector as list of floats.
|
||||||
|
"""
|
||||||
|
embeddings = list(model.embed([text]))
|
||||||
|
return embeddings[0].tolist() if embeddings else []
|
||||||
|
|
||||||
|
return cast(QdrantEmbeddingFunctionWrapper, embed_fn)
|
||||||
|
|
||||||
|
|
||||||
|
@pyd_dataclass(frozen=True)
|
||||||
|
class QdrantConfig(BaseRagConfig):
|
||||||
|
"""Configuration for Qdrant client."""
|
||||||
|
|
||||||
|
provider: Literal["qdrant"] = field(default="qdrant", init=False)
|
||||||
|
options: QdrantClientParams = field(default_factory=_default_options)
|
||||||
|
embedding_function: QdrantEmbeddingFunctionWrapper = field(
|
||||||
|
default_factory=_default_embedding_function
|
||||||
|
)
|
||||||
@@ -1,7 +1,12 @@
|
|||||||
"""Constants for Qdrant implementation."""
|
"""Constants for Qdrant implementation."""
|
||||||
|
|
||||||
|
import os
|
||||||
from typing import Final
|
from typing import Final
|
||||||
|
|
||||||
from qdrant_client.models import Distance, VectorParams
|
from qdrant_client.models import Distance, VectorParams
|
||||||
|
|
||||||
|
from crewai.utilities.paths import db_storage_path
|
||||||
|
|
||||||
DEFAULT_VECTOR_PARAMS: Final = VectorParams(size=384, distance=Distance.COSINE)
|
DEFAULT_VECTOR_PARAMS: Final = VectorParams(size=384, distance=Distance.COSINE)
|
||||||
|
DEFAULT_EMBEDDING_MODEL: Final[str] = "sentence-transformers/all-MiniLM-L6-v2"
|
||||||
|
DEFAULT_STORAGE_PATH: Final[str] = os.path.join(db_storage_path(), "qdrant")
|
||||||
|
|||||||
21
src/crewai/rag/qdrant/factory.py
Normal file
21
src/crewai/rag/qdrant/factory.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
"""Factory functions for creating Qdrant clients from configuration."""
|
||||||
|
|
||||||
|
from qdrant_client import QdrantClient as SyncQdrantClientBase
|
||||||
|
from crewai.rag.qdrant.client import QdrantClient
|
||||||
|
from crewai.rag.qdrant.config import QdrantConfig
|
||||||
|
|
||||||
|
|
||||||
|
def create_client(config: QdrantConfig) -> QdrantClient:
|
||||||
|
"""Create a Qdrant client from configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: The Qdrant configuration.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A configured QdrantClient instance.
|
||||||
|
"""
|
||||||
|
|
||||||
|
qdrant_client = SyncQdrantClientBase(**config.options)
|
||||||
|
return QdrantClient(
|
||||||
|
client=qdrant_client, embedding_function=config.embedding_function
|
||||||
|
)
|
||||||
@@ -1,10 +1,12 @@
|
|||||||
"""Type definitions specific to Qdrant implementation."""
|
"""Type definitions specific to Qdrant implementation."""
|
||||||
|
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable
|
||||||
from typing import Annotated, Any, Protocol, TypeAlias, TypedDict
|
from typing import Annotated, Any, Protocol, TypeAlias
|
||||||
from typing_extensions import NotRequired
|
from typing_extensions import NotRequired, TypedDict
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from pydantic import GetCoreSchemaHandler
|
||||||
|
from pydantic_core import CoreSchema, core_schema
|
||||||
from qdrant_client import AsyncQdrantClient, QdrantClient as SyncQdrantClient
|
from qdrant_client import AsyncQdrantClient, QdrantClient as SyncQdrantClient
|
||||||
from qdrant_client.models import (
|
from qdrant_client.models import (
|
||||||
FieldCondition,
|
FieldCondition,
|
||||||
@@ -53,6 +55,21 @@ class EmbeddingFunction(Protocol):
|
|||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class QdrantEmbeddingFunctionWrapper(EmbeddingFunction):
|
||||||
|
"""Base class for Qdrant EmbeddingFunction to work with Pydantic validation."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def __get_pydantic_core_schema__(
|
||||||
|
cls, _source_type: Any, _handler: GetCoreSchemaHandler
|
||||||
|
) -> CoreSchema:
|
||||||
|
"""Generate Pydantic core schema for Qdrant EmbeddingFunction.
|
||||||
|
|
||||||
|
This allows Pydantic to handle Qdrant's EmbeddingFunction type
|
||||||
|
without requiring arbitrary_types_allowed=True.
|
||||||
|
"""
|
||||||
|
return core_schema.any_schema()
|
||||||
|
|
||||||
|
|
||||||
class AsyncEmbeddingFunction(Protocol):
|
class AsyncEmbeddingFunction(Protocol):
|
||||||
"""Protocol for async embedding functions that convert text to vectors."""
|
"""Protocol for async embedding functions that convert text to vectors."""
|
||||||
|
|
||||||
|
|||||||
@@ -1,9 +1,9 @@
|
|||||||
"""Type definitions for RAG (Retrieval-Augmented Generation) systems."""
|
"""Type definitions for RAG (Retrieval-Augmented Generation) systems."""
|
||||||
|
|
||||||
from collections.abc import Callable, Mapping
|
from collections.abc import Callable, Mapping
|
||||||
from typing import TypeAlias, TypedDict, Any
|
from typing import TypeAlias, Any
|
||||||
|
|
||||||
from typing_extensions import Required
|
from typing_extensions import Required, TypedDict
|
||||||
|
|
||||||
|
|
||||||
class BaseRecord(TypedDict, total=False):
|
class BaseRecord(TypedDict, total=False):
|
||||||
|
|||||||
Reference in New Issue
Block a user