mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-11 09:08:31 +00:00
feat: qdrant generic client (#3377)
Some checks failed
Notify Downstream / notify-downstream (push) Has been cancelled
Some checks failed
Notify Downstream / notify-downstream (push) Has been cancelled
### Qdrant Client * Add core client with collection, search, and document APIs (sync + async) * Refactor utilities, types, and vector params (default 384-dim) * Improve error handling with `ClientMethodMismatchError` * Add score normalization, async embeddings, and optional `qdrant-client` dep * Expand tests and type safety throughout
This commit is contained in:
26
src/crewai/rag/core/exceptions.py
Normal file
26
src/crewai/rag/core/exceptions.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""Core exceptions for RAG module."""
|
||||
|
||||
|
||||
class ClientMethodMismatchError(TypeError):
|
||||
"""Raised when a method is called with the wrong client type.
|
||||
|
||||
Typically used when a sync method is called with an async client,
|
||||
or vice versa.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, method_name: str, expected_client: str, alt_method: str, alt_client: str
|
||||
) -> None:
|
||||
"""Create a ClientMethodMismatchError.
|
||||
|
||||
Args:
|
||||
method_name: Method that was called incorrectly.
|
||||
expected_client: Required client type.
|
||||
alt_method: Suggested alternative method.
|
||||
alt_client: Client type for the alternative method.
|
||||
"""
|
||||
message = (
|
||||
f"Method {method_name}() requires a {expected_client}. "
|
||||
f"Use {alt_method}() for {alt_client}."
|
||||
)
|
||||
super().__init__(message)
|
||||
1
src/crewai/rag/qdrant/__init__.py
Normal file
1
src/crewai/rag/qdrant/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Qdrant vector database client implementation."""
|
||||
527
src/crewai/rag/qdrant/client.py
Normal file
527
src/crewai/rag/qdrant/client.py
Normal file
@@ -0,0 +1,527 @@
|
||||
"""Qdrant client implementation."""
|
||||
|
||||
from typing import Any, cast
|
||||
|
||||
from fastembed import TextEmbedding
|
||||
from qdrant_client import QdrantClient as SyncQdrantClientBase
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from crewai.rag.core.base_client import (
|
||||
BaseClient,
|
||||
BaseCollectionParams,
|
||||
BaseCollectionAddParams,
|
||||
BaseCollectionSearchParams,
|
||||
)
|
||||
from crewai.rag.core.exceptions import ClientMethodMismatchError
|
||||
from crewai.rag.qdrant.types import (
|
||||
AsyncEmbeddingFunction,
|
||||
EmbeddingFunction,
|
||||
QdrantClientParams,
|
||||
QdrantClientType,
|
||||
QdrantCollectionCreateParams,
|
||||
)
|
||||
from crewai.rag.qdrant.utils import (
|
||||
_is_async_client,
|
||||
_is_async_embedding_function,
|
||||
_is_sync_client,
|
||||
_create_point_from_document,
|
||||
_get_collection_params,
|
||||
_prepare_search_params,
|
||||
_process_search_results,
|
||||
)
|
||||
from crewai.rag.types import SearchResult
|
||||
|
||||
|
||||
class QdrantClient(BaseClient):
|
||||
"""Qdrant implementation of the BaseClient protocol.
|
||||
|
||||
Provides vector database operations for Qdrant, supporting both
|
||||
synchronous and asynchronous clients.
|
||||
|
||||
Attributes:
|
||||
client: Qdrant client instance (QdrantClient or AsyncQdrantClient).
|
||||
embedding_function: Function to generate embeddings for documents.
|
||||
"""
|
||||
|
||||
client: QdrantClientType
|
||||
embedding_function: EmbeddingFunction | AsyncEmbeddingFunction
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: QdrantClientType | None = None,
|
||||
embedding_function: EmbeddingFunction | AsyncEmbeddingFunction | None = None,
|
||||
**kwargs: Unpack[QdrantClientParams],
|
||||
) -> None:
|
||||
"""Initialize QdrantClient with optional client and embedding function.
|
||||
|
||||
Args:
|
||||
client: Optional pre-configured Qdrant client instance.
|
||||
embedding_function: Optional embedding function. If not provided,
|
||||
uses FastEmbed's BAAI/bge-small-en-v1.5 model.
|
||||
**kwargs: Additional arguments for QdrantClient creation.
|
||||
"""
|
||||
if client is not None:
|
||||
self.client = client
|
||||
else:
|
||||
location = kwargs.get("location", ":memory:")
|
||||
client_kwargs = {k: v for k, v in kwargs.items() if k != "location"}
|
||||
self.client = SyncQdrantClientBase(location, **cast(Any, client_kwargs))
|
||||
|
||||
if embedding_function is not None:
|
||||
self.embedding_function = embedding_function
|
||||
else:
|
||||
_embedder = TextEmbedding("BAAI/bge-small-en-v1.5")
|
||||
|
||||
def _embed_fn(text: str) -> list[float]:
|
||||
embeddings = list(_embedder.embed([text]))
|
||||
return [float(x) for x in embeddings[0]] if embeddings else []
|
||||
|
||||
self.embedding_function = _embed_fn
|
||||
|
||||
def create_collection(self, **kwargs: Unpack[QdrantCollectionCreateParams]) -> None:
|
||||
"""Create a new collection in Qdrant.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the collection to create. Must be unique.
|
||||
vectors_config: Optional vector configuration. Defaults to 1536 dimensions with cosine distance.
|
||||
sparse_vectors_config: Optional sparse vector configuration.
|
||||
shard_number: Optional number of shards.
|
||||
replication_factor: Optional replication factor.
|
||||
write_consistency_factor: Optional write consistency factor.
|
||||
on_disk_payload: Optional flag to store payload on disk.
|
||||
hnsw_config: Optional HNSW index configuration.
|
||||
optimizers_config: Optional optimizer configuration.
|
||||
wal_config: Optional write-ahead log configuration.
|
||||
quantization_config: Optional quantization configuration.
|
||||
init_from: Optional collection to initialize from.
|
||||
timeout: Optional timeout for the operation.
|
||||
|
||||
Raises:
|
||||
ValueError: If collection with the same name already exists.
|
||||
ConnectionError: If unable to connect to Qdrant server.
|
||||
"""
|
||||
if not _is_sync_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="create_collection",
|
||||
expected_client="QdrantClient",
|
||||
alt_method="acreate_collection",
|
||||
alt_client="AsyncQdrantClient",
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
|
||||
if self.client.collection_exists(collection_name):
|
||||
raise ValueError(f"Collection '{collection_name}' already exists")
|
||||
|
||||
params = _get_collection_params(kwargs)
|
||||
self.client.create_collection(**params)
|
||||
|
||||
async def acreate_collection(
|
||||
self, **kwargs: Unpack[QdrantCollectionCreateParams]
|
||||
) -> None:
|
||||
"""Create a new collection in Qdrant asynchronously.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the collection to create. Must be unique.
|
||||
vectors_config: Optional vector configuration. Defaults to 1536 dimensions with cosine distance.
|
||||
sparse_vectors_config: Optional sparse vector configuration.
|
||||
shard_number: Optional number of shards.
|
||||
replication_factor: Optional replication factor.
|
||||
write_consistency_factor: Optional write consistency factor.
|
||||
on_disk_payload: Optional flag to store payload on disk.
|
||||
hnsw_config: Optional HNSW index configuration.
|
||||
optimizers_config: Optional optimizer configuration.
|
||||
wal_config: Optional write-ahead log configuration.
|
||||
quantization_config: Optional quantization configuration.
|
||||
init_from: Optional collection to initialize from.
|
||||
timeout: Optional timeout for the operation.
|
||||
|
||||
Raises:
|
||||
ValueError: If collection with the same name already exists.
|
||||
ConnectionError: If unable to connect to Qdrant server.
|
||||
"""
|
||||
if not _is_async_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="acreate_collection",
|
||||
expected_client="AsyncQdrantClient",
|
||||
alt_method="create_collection",
|
||||
alt_client="QdrantClient",
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
|
||||
if await self.client.collection_exists(collection_name):
|
||||
raise ValueError(f"Collection '{collection_name}' already exists")
|
||||
|
||||
params = _get_collection_params(kwargs)
|
||||
await self.client.create_collection(**params)
|
||||
|
||||
def get_or_create_collection(
|
||||
self, **kwargs: Unpack[QdrantCollectionCreateParams]
|
||||
) -> Any:
|
||||
"""Get an existing collection or create it if it doesn't exist.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the collection to get or create.
|
||||
vectors_config: Optional vector configuration. Defaults to 1536 dimensions with cosine distance.
|
||||
sparse_vectors_config: Optional sparse vector configuration.
|
||||
shard_number: Optional number of shards.
|
||||
replication_factor: Optional replication factor.
|
||||
write_consistency_factor: Optional write consistency factor.
|
||||
on_disk_payload: Optional flag to store payload on disk.
|
||||
hnsw_config: Optional HNSW index configuration.
|
||||
optimizers_config: Optional optimizer configuration.
|
||||
wal_config: Optional write-ahead log configuration.
|
||||
quantization_config: Optional quantization configuration.
|
||||
init_from: Optional collection to initialize from.
|
||||
timeout: Optional timeout for the operation.
|
||||
|
||||
Returns:
|
||||
Collection info dict with name and other metadata.
|
||||
|
||||
Raises:
|
||||
ConnectionError: If unable to connect to Qdrant server.
|
||||
"""
|
||||
if not _is_sync_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="get_or_create_collection",
|
||||
expected_client="QdrantClient",
|
||||
alt_method="aget_or_create_collection",
|
||||
alt_client="AsyncQdrantClient",
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
|
||||
if self.client.collection_exists(collection_name):
|
||||
return self.client.get_collection(collection_name)
|
||||
|
||||
params = _get_collection_params(kwargs)
|
||||
self.client.create_collection(**params)
|
||||
|
||||
return self.client.get_collection(collection_name)
|
||||
|
||||
async def aget_or_create_collection(
|
||||
self, **kwargs: Unpack[QdrantCollectionCreateParams]
|
||||
) -> Any:
|
||||
"""Get an existing collection or create it if it doesn't exist asynchronously.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the collection to get or create.
|
||||
vectors_config: Optional vector configuration. Defaults to 1536 dimensions with cosine distance.
|
||||
sparse_vectors_config: Optional sparse vector configuration.
|
||||
shard_number: Optional number of shards.
|
||||
replication_factor: Optional replication factor.
|
||||
write_consistency_factor: Optional write consistency factor.
|
||||
on_disk_payload: Optional flag to store payload on disk.
|
||||
hnsw_config: Optional HNSW index configuration.
|
||||
optimizers_config: Optional optimizer configuration.
|
||||
wal_config: Optional write-ahead log configuration.
|
||||
quantization_config: Optional quantization configuration.
|
||||
init_from: Optional collection to initialize from.
|
||||
timeout: Optional timeout for the operation.
|
||||
|
||||
Returns:
|
||||
Collection info dict with name and other metadata.
|
||||
|
||||
Raises:
|
||||
ConnectionError: If unable to connect to Qdrant server.
|
||||
"""
|
||||
if not _is_async_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="aget_or_create_collection",
|
||||
expected_client="AsyncQdrantClient",
|
||||
alt_method="get_or_create_collection",
|
||||
alt_client="QdrantClient",
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
|
||||
if await self.client.collection_exists(collection_name):
|
||||
return await self.client.get_collection(collection_name)
|
||||
|
||||
params = _get_collection_params(kwargs)
|
||||
await self.client.create_collection(**params)
|
||||
|
||||
return await self.client.get_collection(collection_name)
|
||||
|
||||
def add_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
|
||||
"""Add documents with their embeddings to a collection.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: The name of the collection to add documents to.
|
||||
documents: List of BaseRecord dicts containing document data.
|
||||
|
||||
Raises:
|
||||
ValueError: If collection doesn't exist or documents list is empty.
|
||||
ConnectionError: If unable to connect to Qdrant server.
|
||||
"""
|
||||
if not _is_sync_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="add_documents",
|
||||
expected_client="QdrantClient",
|
||||
alt_method="aadd_documents",
|
||||
alt_client="AsyncQdrantClient",
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
documents = kwargs["documents"]
|
||||
|
||||
if not documents:
|
||||
raise ValueError("Documents list cannot be empty")
|
||||
|
||||
if not self.client.collection_exists(collection_name):
|
||||
raise ValueError(f"Collection '{collection_name}' does not exist")
|
||||
|
||||
points = []
|
||||
for doc in documents:
|
||||
if _is_async_embedding_function(self.embedding_function):
|
||||
raise TypeError(
|
||||
"Async embedding function cannot be used with sync add_documents. "
|
||||
"Use aadd_documents instead."
|
||||
)
|
||||
sync_fn = cast(EmbeddingFunction, self.embedding_function)
|
||||
embedding = sync_fn(doc["content"])
|
||||
point = _create_point_from_document(doc, embedding)
|
||||
points.append(point)
|
||||
|
||||
self.client.upsert(collection_name=collection_name, points=points, wait=True)
|
||||
|
||||
async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
|
||||
"""Add documents with their embeddings to a collection asynchronously.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: The name of the collection to add documents to.
|
||||
documents: List of BaseRecord dicts containing document data.
|
||||
|
||||
Raises:
|
||||
ValueError: If collection doesn't exist or documents list is empty.
|
||||
ConnectionError: If unable to connect to Qdrant server.
|
||||
"""
|
||||
if not _is_async_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="aadd_documents",
|
||||
expected_client="AsyncQdrantClient",
|
||||
alt_method="add_documents",
|
||||
alt_client="QdrantClient",
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
documents = kwargs["documents"]
|
||||
|
||||
if not documents:
|
||||
raise ValueError("Documents list cannot be empty")
|
||||
|
||||
if not await self.client.collection_exists(collection_name):
|
||||
raise ValueError(f"Collection '{collection_name}' does not exist")
|
||||
|
||||
points = []
|
||||
for doc in documents:
|
||||
if _is_async_embedding_function(self.embedding_function):
|
||||
async_fn = cast(AsyncEmbeddingFunction, self.embedding_function)
|
||||
embedding = await async_fn(doc["content"])
|
||||
else:
|
||||
sync_fn = cast(EmbeddingFunction, self.embedding_function)
|
||||
embedding = sync_fn(doc["content"])
|
||||
point = _create_point_from_document(doc, embedding)
|
||||
points.append(point)
|
||||
|
||||
await self.client.upsert(
|
||||
collection_name=collection_name, points=points, wait=True
|
||||
)
|
||||
|
||||
def search(
|
||||
self, **kwargs: Unpack[BaseCollectionSearchParams]
|
||||
) -> list[SearchResult]:
|
||||
"""Search for similar documents using a query.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the collection to search in.
|
||||
query: The text query to search for.
|
||||
limit: Maximum number of results to return (default: 10).
|
||||
metadata_filter: Optional filter for metadata fields.
|
||||
score_threshold: Optional minimum similarity score (0-1) for results.
|
||||
|
||||
Returns:
|
||||
List of SearchResult dicts containing id, content, metadata, and score.
|
||||
|
||||
Raises:
|
||||
ValueError: If collection doesn't exist.
|
||||
ConnectionError: If unable to connect to Qdrant server.
|
||||
"""
|
||||
if not _is_sync_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="search",
|
||||
expected_client="QdrantClient",
|
||||
alt_method="asearch",
|
||||
alt_client="AsyncQdrantClient",
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
query = kwargs["query"]
|
||||
limit = kwargs.get("limit", 10)
|
||||
metadata_filter = kwargs.get("metadata_filter")
|
||||
score_threshold = kwargs.get("score_threshold")
|
||||
|
||||
if not self.client.collection_exists(collection_name):
|
||||
raise ValueError(f"Collection '{collection_name}' does not exist")
|
||||
|
||||
if _is_async_embedding_function(self.embedding_function):
|
||||
raise TypeError(
|
||||
"Async embedding function cannot be used with sync search. "
|
||||
"Use asearch instead."
|
||||
)
|
||||
sync_fn = cast(EmbeddingFunction, self.embedding_function)
|
||||
query_embedding = sync_fn(query)
|
||||
|
||||
search_kwargs = _prepare_search_params(
|
||||
collection_name=collection_name,
|
||||
query_embedding=query_embedding,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
metadata_filter=metadata_filter,
|
||||
)
|
||||
|
||||
response = self.client.query_points(**search_kwargs)
|
||||
return _process_search_results(response)
|
||||
|
||||
async def asearch(
|
||||
self, **kwargs: Unpack[BaseCollectionSearchParams]
|
||||
) -> list[SearchResult]:
|
||||
"""Search for similar documents using a query asynchronously.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the collection to search in.
|
||||
query: The text query to search for.
|
||||
limit: Maximum number of results to return (default: 10).
|
||||
metadata_filter: Optional filter for metadata fields.
|
||||
score_threshold: Optional minimum similarity score (0-1) for results.
|
||||
|
||||
Returns:
|
||||
List of SearchResult dicts containing id, content, metadata, and score.
|
||||
|
||||
Raises:
|
||||
ValueError: If collection doesn't exist.
|
||||
ConnectionError: If unable to connect to Qdrant server.
|
||||
"""
|
||||
if not _is_async_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="asearch",
|
||||
expected_client="AsyncQdrantClient",
|
||||
alt_method="search",
|
||||
alt_client="QdrantClient",
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
query = kwargs["query"]
|
||||
limit = kwargs.get("limit", 10)
|
||||
metadata_filter = kwargs.get("metadata_filter")
|
||||
score_threshold = kwargs.get("score_threshold")
|
||||
|
||||
if not await self.client.collection_exists(collection_name):
|
||||
raise ValueError(f"Collection '{collection_name}' does not exist")
|
||||
|
||||
if _is_async_embedding_function(self.embedding_function):
|
||||
async_fn = cast(AsyncEmbeddingFunction, self.embedding_function)
|
||||
query_embedding = await async_fn(query)
|
||||
else:
|
||||
sync_fn = cast(EmbeddingFunction, self.embedding_function)
|
||||
query_embedding = sync_fn(query)
|
||||
|
||||
search_kwargs = _prepare_search_params(
|
||||
collection_name=collection_name,
|
||||
query_embedding=query_embedding,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
metadata_filter=metadata_filter,
|
||||
)
|
||||
|
||||
response = await self.client.query_points(**search_kwargs)
|
||||
return _process_search_results(response)
|
||||
|
||||
def delete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
|
||||
"""Delete a collection and all its data.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the collection to delete.
|
||||
|
||||
Raises:
|
||||
ValueError: If collection doesn't exist.
|
||||
ConnectionError: If unable to connect to Qdrant server.
|
||||
"""
|
||||
if not _is_sync_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="delete_collection",
|
||||
expected_client="QdrantClient",
|
||||
alt_method="adelete_collection",
|
||||
alt_client="AsyncQdrantClient",
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
|
||||
if not self.client.collection_exists(collection_name):
|
||||
raise ValueError(f"Collection '{collection_name}' does not exist")
|
||||
|
||||
self.client.delete_collection(collection_name=collection_name)
|
||||
|
||||
async def adelete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
|
||||
"""Delete a collection and all its data asynchronously.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the collection to delete.
|
||||
|
||||
Raises:
|
||||
ValueError: If collection doesn't exist.
|
||||
ConnectionError: If unable to connect to Qdrant server.
|
||||
"""
|
||||
if not _is_async_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="adelete_collection",
|
||||
expected_client="AsyncQdrantClient",
|
||||
alt_method="delete_collection",
|
||||
alt_client="QdrantClient",
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
|
||||
if not await self.client.collection_exists(collection_name):
|
||||
raise ValueError(f"Collection '{collection_name}' does not exist")
|
||||
|
||||
await self.client.delete_collection(collection_name=collection_name)
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset the vector database by deleting all collections and data.
|
||||
|
||||
Raises:
|
||||
ConnectionError: If unable to connect to Qdrant server.
|
||||
"""
|
||||
if not _is_sync_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="reset",
|
||||
expected_client="QdrantClient",
|
||||
alt_method="areset",
|
||||
alt_client="AsyncQdrantClient",
|
||||
)
|
||||
|
||||
collections_response = self.client.get_collections()
|
||||
|
||||
for collection in collections_response.collections:
|
||||
self.client.delete_collection(collection_name=collection.name)
|
||||
|
||||
async def areset(self) -> None:
|
||||
"""Reset the vector database by deleting all collections and data asynchronously.
|
||||
|
||||
Raises:
|
||||
ConnectionError: If unable to connect to Qdrant server.
|
||||
"""
|
||||
if not _is_async_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="areset",
|
||||
expected_client="AsyncQdrantClient",
|
||||
alt_method="reset",
|
||||
alt_client="QdrantClient",
|
||||
)
|
||||
|
||||
collections_response = await self.client.get_collections()
|
||||
|
||||
for collection in collections_response.collections:
|
||||
await self.client.delete_collection(collection_name=collection.name)
|
||||
7
src/crewai/rag/qdrant/constants.py
Normal file
7
src/crewai/rag/qdrant/constants.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""Constants for Qdrant implementation."""
|
||||
|
||||
from typing import Final
|
||||
|
||||
from qdrant_client.models import Distance, VectorParams
|
||||
|
||||
DEFAULT_VECTOR_PARAMS: Final = VectorParams(size=384, distance=Distance.COSINE)
|
||||
134
src/crewai/rag/qdrant/types.py
Normal file
134
src/crewai/rag/qdrant/types.py
Normal file
@@ -0,0 +1,134 @@
|
||||
"""Type definitions specific to Qdrant implementation."""
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Annotated, Any, Protocol, TypeAlias, TypedDict
|
||||
from typing_extensions import NotRequired
|
||||
|
||||
import numpy as np
|
||||
from qdrant_client import AsyncQdrantClient, QdrantClient as SyncQdrantClient
|
||||
from qdrant_client.models import (
|
||||
FieldCondition,
|
||||
Filter,
|
||||
HasIdCondition,
|
||||
HasVectorCondition,
|
||||
HnswConfigDiff,
|
||||
InitFrom,
|
||||
IsEmptyCondition,
|
||||
IsNullCondition,
|
||||
NestedCondition,
|
||||
OptimizersConfigDiff,
|
||||
QuantizationConfig,
|
||||
ShardingMethod,
|
||||
SparseVectorsConfig,
|
||||
VectorsConfig,
|
||||
WalConfigDiff,
|
||||
)
|
||||
|
||||
from crewai.rag.core.base_client import BaseCollectionParams
|
||||
|
||||
QdrantClientType = SyncQdrantClient | AsyncQdrantClient
|
||||
|
||||
QueryEmbedding: TypeAlias = list[float] | np.ndarray[Any, np.dtype[np.floating[Any]]]
|
||||
|
||||
BasicConditions = FieldCondition | IsEmptyCondition | IsNullCondition
|
||||
StructuralConditions = HasIdCondition | HasVectorCondition | NestedCondition
|
||||
FilterCondition = BasicConditions | StructuralConditions | Filter
|
||||
|
||||
MetadataFilterValue = bool | int | str
|
||||
MetadataFilter = dict[str, MetadataFilterValue]
|
||||
|
||||
|
||||
class EmbeddingFunction(Protocol):
|
||||
"""Protocol for embedding functions that convert text to vectors."""
|
||||
|
||||
def __call__(self, text: str) -> QueryEmbedding:
|
||||
"""Convert text to embedding vector.
|
||||
|
||||
Args:
|
||||
text: Input text to embed.
|
||||
|
||||
Returns:
|
||||
Embedding vector as list of floats or numpy array.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class AsyncEmbeddingFunction(Protocol):
|
||||
"""Protocol for async embedding functions that convert text to vectors."""
|
||||
|
||||
async def __call__(self, text: str) -> QueryEmbedding:
|
||||
"""Convert text to embedding vector asynchronously.
|
||||
|
||||
Args:
|
||||
text: Input text to embed.
|
||||
|
||||
Returns:
|
||||
Embedding vector as list of floats or numpy array.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class QdrantClientParams(TypedDict, total=False):
|
||||
"""Parameters for QdrantClient initialization."""
|
||||
|
||||
location: str | None
|
||||
url: str | None
|
||||
port: int
|
||||
grpc_port: int
|
||||
prefer_grpc: bool
|
||||
https: bool | None
|
||||
api_key: str | None
|
||||
prefix: str | None
|
||||
timeout: int | None
|
||||
host: str | None
|
||||
path: str | None
|
||||
force_disable_check_same_thread: bool
|
||||
grpc_options: dict[str, Any] | None
|
||||
auth_token_provider: Callable[[], str] | Callable[[], Awaitable[str]] | None
|
||||
cloud_inference: bool
|
||||
local_inference_batch_size: int | None
|
||||
check_compatibility: bool
|
||||
|
||||
|
||||
class CommonCreateFields(TypedDict, total=False):
|
||||
"""Fields shared between high-level and direct create_collection params."""
|
||||
|
||||
vectors_config: VectorsConfig
|
||||
sparse_vectors_config: SparseVectorsConfig
|
||||
shard_number: Annotated[int, "Number of shards (default: 1)"]
|
||||
sharding_method: ShardingMethod
|
||||
replication_factor: Annotated[int, "Number of replicas per shard (default: 1)"]
|
||||
write_consistency_factor: Annotated[int, "Await N replicas on write (default: 1)"]
|
||||
on_disk_payload: Annotated[bool, "Store payload on disk instead of RAM"]
|
||||
hnsw_config: HnswConfigDiff
|
||||
optimizers_config: OptimizersConfigDiff
|
||||
wal_config: WalConfigDiff
|
||||
quantization_config: QuantizationConfig
|
||||
init_from: InitFrom | str
|
||||
timeout: Annotated[int, "Operation timeout in seconds"]
|
||||
|
||||
|
||||
class QdrantCollectionCreateParams(
|
||||
BaseCollectionParams, CommonCreateFields, total=False
|
||||
):
|
||||
"""High-level parameters for creating a Qdrant collection."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class CreateCollectionParams(CommonCreateFields, total=False):
|
||||
"""Parameters for qdrant_client.create_collection."""
|
||||
|
||||
collection_name: str
|
||||
|
||||
|
||||
class PreparedSearchParams(TypedDict):
|
||||
"""Type definition for prepared Qdrant search parameters."""
|
||||
|
||||
collection_name: str
|
||||
query: list[float]
|
||||
limit: Annotated[int, "Max results to return"]
|
||||
with_payload: Annotated[bool, "Include payload in results"]
|
||||
with_vectors: Annotated[bool, "Include vectors in results"]
|
||||
score_threshold: NotRequired[Annotated[float, "Min similarity score (0-1)"]]
|
||||
query_filter: NotRequired[Filter]
|
||||
228
src/crewai/rag/qdrant/utils.py
Normal file
228
src/crewai/rag/qdrant/utils.py
Normal file
@@ -0,0 +1,228 @@
|
||||
"""Utility functions for Qdrant operations."""
|
||||
|
||||
import asyncio
|
||||
from typing import TypeGuard
|
||||
from uuid import uuid4
|
||||
|
||||
from qdrant_client import AsyncQdrantClient, QdrantClient as SyncQdrantClient
|
||||
from qdrant_client.models import (
|
||||
FieldCondition,
|
||||
Filter,
|
||||
MatchValue,
|
||||
PointStruct,
|
||||
QueryResponse,
|
||||
)
|
||||
|
||||
from crewai.rag.qdrant.constants import DEFAULT_VECTOR_PARAMS
|
||||
from crewai.rag.qdrant.types import (
|
||||
AsyncEmbeddingFunction,
|
||||
CreateCollectionParams,
|
||||
EmbeddingFunction,
|
||||
FilterCondition,
|
||||
MetadataFilter,
|
||||
PreparedSearchParams,
|
||||
QdrantClientType,
|
||||
QdrantCollectionCreateParams,
|
||||
QueryEmbedding,
|
||||
)
|
||||
from crewai.rag.types import SearchResult, BaseRecord
|
||||
|
||||
|
||||
def _ensure_list_embedding(embedding: QueryEmbedding) -> list[float]:
|
||||
"""Convert embedding to list[float] format if needed.
|
||||
|
||||
Args:
|
||||
embedding: Embedding vector as list or numpy array.
|
||||
|
||||
Returns:
|
||||
Embedding as list[float].
|
||||
"""
|
||||
if not isinstance(embedding, list):
|
||||
return embedding.tolist()
|
||||
return embedding
|
||||
|
||||
|
||||
def _is_sync_client(client: QdrantClientType) -> TypeGuard[SyncQdrantClient]:
|
||||
"""Type guard to check if the client is a synchronous QdrantClient.
|
||||
|
||||
Args:
|
||||
client: The client to check.
|
||||
|
||||
Returns:
|
||||
True if the client is a QdrantClient, False otherwise.
|
||||
"""
|
||||
return isinstance(client, SyncQdrantClient)
|
||||
|
||||
|
||||
def _is_async_client(client: QdrantClientType) -> TypeGuard[AsyncQdrantClient]:
|
||||
"""Type guard to check if the client is an asynchronous AsyncQdrantClient.
|
||||
|
||||
Args:
|
||||
client: The client to check.
|
||||
|
||||
Returns:
|
||||
True if the client is an AsyncQdrantClient, False otherwise.
|
||||
"""
|
||||
return isinstance(client, AsyncQdrantClient)
|
||||
|
||||
|
||||
def _is_async_embedding_function(
|
||||
func: EmbeddingFunction | AsyncEmbeddingFunction,
|
||||
) -> TypeGuard[AsyncEmbeddingFunction]:
|
||||
"""Type guard to check if the embedding function is async.
|
||||
|
||||
Args:
|
||||
func: The embedding function to check.
|
||||
|
||||
Returns:
|
||||
True if the function is async, False otherwise.
|
||||
"""
|
||||
return asyncio.iscoroutinefunction(func)
|
||||
|
||||
|
||||
def _get_collection_params(
|
||||
kwargs: QdrantCollectionCreateParams,
|
||||
) -> CreateCollectionParams:
|
||||
"""Extract collection creation parameters from kwargs."""
|
||||
params: CreateCollectionParams = {
|
||||
"collection_name": kwargs["collection_name"],
|
||||
"vectors_config": kwargs.get("vectors_config", DEFAULT_VECTOR_PARAMS),
|
||||
}
|
||||
|
||||
if "sparse_vectors_config" in kwargs:
|
||||
params["sparse_vectors_config"] = kwargs["sparse_vectors_config"]
|
||||
if "shard_number" in kwargs:
|
||||
params["shard_number"] = kwargs["shard_number"]
|
||||
if "sharding_method" in kwargs:
|
||||
params["sharding_method"] = kwargs["sharding_method"]
|
||||
if "replication_factor" in kwargs:
|
||||
params["replication_factor"] = kwargs["replication_factor"]
|
||||
if "write_consistency_factor" in kwargs:
|
||||
params["write_consistency_factor"] = kwargs["write_consistency_factor"]
|
||||
if "on_disk_payload" in kwargs:
|
||||
params["on_disk_payload"] = kwargs["on_disk_payload"]
|
||||
if "hnsw_config" in kwargs:
|
||||
params["hnsw_config"] = kwargs["hnsw_config"]
|
||||
if "optimizers_config" in kwargs:
|
||||
params["optimizers_config"] = kwargs["optimizers_config"]
|
||||
if "wal_config" in kwargs:
|
||||
params["wal_config"] = kwargs["wal_config"]
|
||||
if "quantization_config" in kwargs:
|
||||
params["quantization_config"] = kwargs["quantization_config"]
|
||||
if "init_from" in kwargs:
|
||||
params["init_from"] = kwargs["init_from"]
|
||||
if "timeout" in kwargs:
|
||||
params["timeout"] = kwargs["timeout"]
|
||||
|
||||
return params
|
||||
|
||||
|
||||
def _prepare_search_params(
|
||||
collection_name: str,
|
||||
query_embedding: QueryEmbedding,
|
||||
limit: int,
|
||||
score_threshold: float | None,
|
||||
metadata_filter: MetadataFilter | None,
|
||||
) -> PreparedSearchParams:
|
||||
"""Prepare search parameters for Qdrant query_points.
|
||||
|
||||
Args:
|
||||
collection_name: Name of the collection to search.
|
||||
query_embedding: Embedding vector for the query.
|
||||
limit: Maximum number of results.
|
||||
score_threshold: Optional minimum similarity score.
|
||||
metadata_filter: Optional metadata filters.
|
||||
|
||||
Returns:
|
||||
Dictionary of parameters for query_points method.
|
||||
"""
|
||||
query_vector = _ensure_list_embedding(query_embedding)
|
||||
|
||||
search_kwargs: PreparedSearchParams = {
|
||||
"collection_name": collection_name,
|
||||
"query": query_vector,
|
||||
"limit": limit,
|
||||
"with_payload": True,
|
||||
"with_vectors": False,
|
||||
}
|
||||
|
||||
if score_threshold is not None:
|
||||
search_kwargs["score_threshold"] = score_threshold
|
||||
|
||||
if metadata_filter:
|
||||
filter_conditions: list[FilterCondition] = []
|
||||
for key, value in metadata_filter.items():
|
||||
filter_conditions.append(
|
||||
FieldCondition(key=key, match=MatchValue(value=value))
|
||||
)
|
||||
|
||||
search_kwargs["query_filter"] = Filter(must=filter_conditions)
|
||||
|
||||
return search_kwargs
|
||||
|
||||
|
||||
def _normalize_qdrant_score(score: float) -> float:
|
||||
"""Normalize Qdrant cosine similarity score to [0, 1] range.
|
||||
|
||||
Converts from Qdrant's [-1, 1] cosine similarity range to [0, 1] range for standardization across clients.
|
||||
|
||||
Args:
|
||||
score: Raw cosine similarity score from Qdrant [-1, 1].
|
||||
|
||||
Returns:
|
||||
Normalized score in [0, 1] range where 1 is most similar.
|
||||
"""
|
||||
normalized = (score + 1.0) / 2.0
|
||||
return max(0.0, min(1.0, normalized))
|
||||
|
||||
|
||||
def _process_search_results(response: QueryResponse) -> list[SearchResult]:
|
||||
"""Process Qdrant search response into SearchResult format.
|
||||
|
||||
Args:
|
||||
response: Response from Qdrant query_points method.
|
||||
|
||||
Returns:
|
||||
List of SearchResult dictionaries.
|
||||
"""
|
||||
results: list[SearchResult] = []
|
||||
for point in response.points:
|
||||
payload = point.payload or {}
|
||||
score = _normalize_qdrant_score(score=point.score)
|
||||
result: SearchResult = {
|
||||
"id": str(point.id),
|
||||
"content": payload.get("content", ""),
|
||||
"metadata": {k: v for k, v in payload.items() if k != "content"},
|
||||
"score": score,
|
||||
}
|
||||
results.append(result)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def _create_point_from_document(
|
||||
doc: BaseRecord, embedding: QueryEmbedding
|
||||
) -> PointStruct:
|
||||
"""Create a PointStruct from a document and its embedding.
|
||||
|
||||
Args:
|
||||
doc: Document dictionary containing content, metadata, and optional doc_id.
|
||||
embedding: The embedding vector for the document content.
|
||||
|
||||
Returns:
|
||||
PointStruct ready to be upserted to Qdrant.
|
||||
"""
|
||||
doc_id = doc.get("doc_id", str(uuid4()))
|
||||
vector = _ensure_list_embedding(embedding)
|
||||
|
||||
metadata = doc.get("metadata", {})
|
||||
if isinstance(metadata, list):
|
||||
metadata = metadata[0] if metadata else {}
|
||||
elif not isinstance(metadata, dict):
|
||||
metadata = dict(metadata) if metadata else {}
|
||||
|
||||
return PointStruct(
|
||||
id=doc_id,
|
||||
vector=vector,
|
||||
payload={"content": doc["content"], **metadata},
|
||||
)
|
||||
Reference in New Issue
Block a user