Files
crewAI/src/crewai/rag/core/base_client.py
Greyson LaLonde 869bb115c8 Qdrant RAG Provider Support (#3400)
* Added Qdrant provider support with factory, config, and protocols
* Improved default embeddings and type definitions
* Fixed ChromaDB factory embedding assignment
2025-08-26 08:44:02 -04:00

447 lines
17 KiB
Python

"""Protocol for vector database client implementations."""
from abc import abstractmethod
from typing import Any, Protocol, runtime_checkable, Annotated
from typing_extensions import Unpack, Required, TypedDict
from pydantic import GetCoreSchemaHandler
from pydantic_core import CoreSchema, core_schema
from crewai.rag.types import (
EmbeddingFunction,
BaseRecord,
SearchResult,
)
class BaseCollectionParams(TypedDict):
"""Base parameters for collection operations.
Attributes:
collection_name: The name of the collection/index to operate on.
"""
collection_name: Required[
Annotated[
str,
"Name of the collection/index. Implementations may have specific constraints (e.g., character limits, allowed characters, case sensitivity).",
]
]
class BaseCollectionAddParams(BaseCollectionParams):
    """Keyword parameters accepted when adding documents to a collection.

    Builds on BaseCollectionParams by adding the documents payload.

    Attributes:
        collection_name: Name of the collection that receives the documents.
        documents: BaseRecord dictionaries holding the document data to add.
    """

    documents: list[BaseRecord]
class BaseCollectionSearchParams(BaseCollectionParams, total=False):
    """Keyword parameters accepted when searching a collection.

    Builds on BaseCollectionParams with search-specific fields. Because the
    class is declared with ``total=False``, only ``collection_name`` (inherited)
    and ``query`` (explicitly marked Required) are mandatory.

    Attributes:
        query: Text query to search for (required).
        limit: Maximum number of results to return.
        metadata_filter: Restricts results by metadata fields.
        score_threshold: Minimum similarity score for results (0-1).
    """

    # Mandatory despite total=False.
    query: Required[str]

    # Optional search refinements.
    limit: int
    metadata_filter: dict[str, Any]
    score_threshold: float
@runtime_checkable
class BaseClient(Protocol):
    """Structural interface that every vector store client must satisfy.

    The protocol gives callers one consistent API for storing documents
    together with their vector embeddings and retrieving them again, whatever
    vector database backend is in use (e.g., Qdrant, ChromaDB, Weaviate).
    Connection management, data persistence, and backend-specific vector
    similarity search are the responsibility of the implementing class.

    Implementation Guidelines:
        Constructors should accept BaseClientParams so that callers can pass
        in a pre-configured client instance:

        class MyVectorClient:
            def __init__(self, client: Any | None = None, **kwargs):
                if client:
                    self.client = client
                else:
                    self.client = self._create_default_client(**kwargs)

    Notes:
        This protocol replaces the former BaseRAGStorage abstraction and
        offers a cleaner interface for vector store operations.

    Attributes:
        embedding_function: Callable that maps a list of text strings to a
            list of embedding vectors. Implementations should always provide
            a default embedding function.
        client: The underlying vector database client instance, either passed
            in via BaseClientParams at initialization or created internally.
    """

    client: Any
    embedding_function: EmbeddingFunction

    @classmethod
    def __get_pydantic_core_schema__(
        cls,
        _source_type: Any,
        _handler: GetCoreSchemaHandler,
    ) -> CoreSchema:
        """Return the Pydantic core schema for the BaseClient Protocol.

        An "any" schema is returned so that the Protocol can be used in
        Pydantic models without requiring arbitrary_types_allowed=True.
        Neither argument is consulted.
        """
        permissive_schema: CoreSchema = core_schema.any_schema()
        return permissive_schema

    @abstractmethod
    def create_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
        """Create a new collection/index in the vector database.

        Keyword Args:
            collection_name: Name of the collection to create. It must be
                unique within the vector database instance.

        Raises:
            ValueError: If the collection name already exists.
            ConnectionError: If the vector database backend cannot be reached.
        """
        ...

    @abstractmethod
    async def acreate_collection(
        self, **kwargs: Unpack[BaseCollectionParams]
    ) -> None:
        """Asynchronously create a new collection/index in the vector database.

        Keyword Args:
            collection_name: Name of the collection to create. It must be
                unique within the vector database instance.

        Raises:
            ValueError: If the collection name already exists.
            ConnectionError: If the vector database backend cannot be reached.
        """
        ...

    @abstractmethod
    def get_or_create_collection(
        self, **kwargs: Unpack[BaseCollectionParams]
    ) -> Any:
        """Return an existing collection, creating it first when it is missing.

        A convenience for guaranteeing that a collection exists without a
        separate existence check.

        Keyword Args:
            collection_name: Name of the collection to get or create.

        Returns:
            A collection object whose type depends on the backend
            implementation; it could be a collection reference, an ID, or a
            client object.

        Raises:
            ValueError: If the collection cannot be created.
            ConnectionError: If the vector database backend cannot be reached.
        """
        ...

    @abstractmethod
    async def aget_or_create_collection(
        self, **kwargs: Unpack[BaseCollectionParams]
    ) -> Any:
        """Asynchronously return an existing collection, creating it if missing.

        Keyword Args:
            collection_name: Name of the collection to get or create.

        Returns:
            A collection object whose type depends on the backend
            implementation.

        Raises:
            ValueError: If the collection cannot be created.
            ConnectionError: If the vector database backend cannot be reached.
        """
        ...

    @abstractmethod
    def add_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
        """Add documents, together with their embeddings, to a collection.

        The operation is an upsert: a document whose ID already exists is
        updated with the new content and metadata. Implementations should
        generate embeddings internally using the configured embedding
        function.

        Keyword Args:
            collection_name: Name of the collection that receives the documents.
            documents: List of BaseRecord dicts containing:
                - content: The text content (required)
                - doc_id: Optional unique identifier (auto-generated from
                  content hash if missing)
                - metadata: Optional metadata dictionary
                Embeddings will be generated automatically.

        Raises:
            ValueError: If the collection doesn't exist or the documents list
                is empty.
            TypeError: If documents are not BaseRecord dict instances.
            ConnectionError: If the vector database backend cannot be reached.

        Example:
            >>> from crewai.rag.chromadb.client import ChromaDBClient
            >>> from crewai.rag.types import BaseRecord
            >>> client = ChromaDBClient()
            >>>
            >>> records: list[BaseRecord] = [
            ...     {
            ...         "content": "Machine learning basics",
            ...         "metadata": {"source": "file3", "topic": "ML"}
            ...     },
            ...     {
            ...         "doc_id": "custom_id",
            ...         "content": "Deep learning fundamentals",
            ...         "metadata": {"source": "file4", "topic": "DL"}
            ...     }
            ... ]
            >>> client.add_documents(collection_name="my_docs", documents=records)
            >>>
            >>> records_with_id: list[BaseRecord] = [
            ...     {
            ...         "doc_id": "nlp_001",
            ...         "content": "Advanced NLP techniques",
            ...         "metadata": {"source": "file5", "topic": "NLP"}
            ...     }
            ... ]
            >>> client.add_documents(collection_name="my_docs", documents=records_with_id)
        """
        ...

    @abstractmethod
    async def aadd_documents(
        self, **kwargs: Unpack[BaseCollectionAddParams]
    ) -> None:
        """Asynchronously add documents, with their embeddings, to a collection.

        Implementations should generate embeddings internally using the
        configured embedding function.

        Keyword Args:
            collection_name: Name of the collection that receives the documents.
            documents: List of BaseRecord dicts containing:
                - content: The text content (required)
                - doc_id: Optional unique identifier (auto-generated from
                  content hash if missing)
                - metadata: Optional metadata dictionary
                Embeddings will be generated automatically.

        Raises:
            ValueError: If the collection doesn't exist or the documents list
                is empty.
            TypeError: If documents are not BaseRecord dict instances.
            ConnectionError: If the vector database backend cannot be reached.

        Example:
            >>> import asyncio
            >>> from crewai.rag.chromadb.client import ChromaDBClient
            >>> from crewai.rag.types import BaseRecord
            >>>
            >>> async def add_documents():
            ...     client = ChromaDBClient()
            ...
            ...     records: list[BaseRecord] = [
            ...         {
            ...             "doc_id": "doc2",
            ...             "content": "Async operations in Python",
            ...             "metadata": {"source": "file2", "topic": "async"}
            ...         }
            ...     ]
            ...     await client.aadd_documents(collection_name="my_docs", documents=records)
            ...
            >>> asyncio.run(add_documents())
        """
        ...

    @abstractmethod
    def search(
        self, **kwargs: Unpack[BaseCollectionSearchParams]
    ) -> list[SearchResult]:
        """Find the documents most similar to a text query.

        Runs a vector similarity search against the named collection.

        Keyword Args:
            collection_name: Name of the collection to search in.
            query: Text query to search for. Embedding generation is handled
                internally by the implementation.
            limit: Maximum number of results to return. Defaults to 10.
            metadata_filter: Optional metadata filter applied to the search.
                The exact format depends on the backend, but it typically
                supports equality and range queries on metadata fields.
            score_threshold: Optional minimum similarity score. Only results
                with scores >= this threshold are returned. How a score is
                interpreted depends on the distance metric used.

        Returns:
            A list of SearchResult dictionaries ordered by similarity score in
            descending order. Each result contains:
                - id: Document ID
                - content: Document text content
                - metadata: Document metadata
                - score: Similarity score (0-1, higher is better)

        Raises:
            ValueError: If the collection doesn't exist.
            ConnectionError: If the vector database backend cannot be reached.

        Example:
            >>> from crewai.rag.chromadb.client import ChromaDBClient
            >>> client = ChromaDBClient()
            >>>
            >>> results = client.search(
            ...     collection_name="my_docs",
            ...     query="What is machine learning?",
            ...     limit=5,
            ...     metadata_filter={"source": "file1"},
            ...     score_threshold=0.7
            ... )
            >>> for result in results:
            ...     print(f"{result['id']}: {result['score']:.2f}")
        """
        ...

    @abstractmethod
    async def asearch(
        self, **kwargs: Unpack[BaseCollectionSearchParams]
    ) -> list[SearchResult]:
        """Asynchronously find the documents most similar to a text query.

        Keyword Args:
            collection_name: Name of the collection to search in.
            query: Text query to search for. Embedding generation is handled
                internally by the implementation.
            limit: Maximum number of results to return. Defaults to 10.
            metadata_filter: Optional metadata filter applied to the search.
            score_threshold: Optional minimum similarity score threshold.

        Returns:
            A list of SearchResult dictionaries ordered by similarity score.

        Raises:
            ValueError: If the collection doesn't exist.
            ConnectionError: If the vector database backend cannot be reached.

        Example:
            >>> import asyncio
            >>> from crewai.rag.chromadb.client import ChromaDBClient
            >>>
            >>> async def search_documents():
            ...     client = ChromaDBClient()
            ...     results = await client.asearch(
            ...         collection_name="my_docs",
            ...         query="Python programming best practices",
            ...         limit=5,
            ...         metadata_filter={"source": "file1"},
            ...         score_threshold=0.7
            ...     )
            ...     for result in results:
            ...         print(f"{result['id']}: {result['score']:.2f}")
            ...
            >>> asyncio.run(search_documents())
        """
        ...

    @abstractmethod
    def delete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
        """Delete a collection along with all of its data.

        The operation is irreversible: every document, embedding, and piece
        of metadata associated with the collection is permanently removed.

        Keyword Args:
            collection_name: Name of the collection to delete.

        Raises:
            ValueError: If the collection doesn't exist.
            ConnectionError: If the vector database backend cannot be reached.

        Example:
            >>> from crewai.rag.chromadb.client import ChromaDBClient
            >>> client = ChromaDBClient()
            >>> client.delete_collection(collection_name="old_docs")
            >>> print("Collection 'old_docs' deleted successfully")
        """
        ...

    @abstractmethod
    async def adelete_collection(
        self, **kwargs: Unpack[BaseCollectionParams]
    ) -> None:
        """Asynchronously delete a collection along with all of its data.

        Keyword Args:
            collection_name: Name of the collection to delete.

        Raises:
            ValueError: If the collection doesn't exist.
            ConnectionError: If the vector database backend cannot be reached.

        Example:
            >>> import asyncio
            >>> from crewai.rag.chromadb.client import ChromaDBClient
            >>>
            >>> async def delete_old_collection():
            ...     client = ChromaDBClient()
            ...     await client.adelete_collection(collection_name="old_docs")
            ...     print("Collection 'old_docs' deleted successfully")
            ...
            >>> asyncio.run(delete_old_collection())
        """
        ...

    @abstractmethod
    def reset(self) -> None:
        """Reset the vector database by deleting all collections and data.

        Clears the vector database completely, removing every collection and
        its contents. Use with caution: the operation is irreversible.

        Raises:
            ConnectionError: If the vector database backend cannot be reached.
            PermissionError: If the backend does not allow the operation.

        Example:
            >>> from crewai.rag.chromadb.client import ChromaDBClient
            >>> client = ChromaDBClient()
            >>> client.reset()
            >>> print("Vector database completely reset - all data deleted")
        """
        ...

    @abstractmethod
    async def areset(self) -> None:
        """Asynchronously reset the vector database, deleting all collections and data.

        Raises:
            ConnectionError: If the vector database backend cannot be reached.
            PermissionError: If the backend does not allow the operation.

        Example:
            >>> import asyncio
            >>> from crewai.rag.chromadb.client import ChromaDBClient
            >>>
            >>> async def reset_database():
            ...     client = ChromaDBClient()
            ...     await client.areset()
            ...     print("Vector database completely reset - all data deleted")
            ...
            >>> asyncio.run(reset_database())
        """
        ...