mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 15:48:29 +00:00
* Added Qdrant provider support with factory, config, and protocols * Improved default embeddings and type definitions * Fixed ChromaDB factory embedding assignment
447 lines
17 KiB
Python
447 lines
17 KiB
Python
"""Protocol for vector database client implementations."""
|
|
|
|
from abc import abstractmethod
|
|
from typing import Any, Protocol, runtime_checkable, Annotated
|
|
from typing_extensions import Unpack, Required, TypedDict
|
|
from pydantic import GetCoreSchemaHandler
|
|
from pydantic_core import CoreSchema, core_schema
|
|
|
|
|
|
from crewai.rag.types import (
|
|
EmbeddingFunction,
|
|
BaseRecord,
|
|
SearchResult,
|
|
)
|
|
|
|
|
|
class BaseCollectionParams(TypedDict):
|
|
"""Base parameters for collection operations.
|
|
|
|
Attributes:
|
|
collection_name: The name of the collection/index to operate on.
|
|
"""
|
|
|
|
collection_name: Required[
|
|
Annotated[
|
|
str,
|
|
"Name of the collection/index. Implementations may have specific constraints (e.g., character limits, allowed characters, case sensitivity).",
|
|
]
|
|
]
|
|
|
|
|
|
class BaseCollectionAddParams(BaseCollectionParams):
    """Parameters for adding documents to a collection.

    Extends BaseCollectionParams with document-specific fields. The class is
    declared with the default ``total=True``, so ``documents`` is a required
    key alongside the inherited ``collection_name``.

    Attributes:
        collection_name: The name of the collection to add documents to.
        documents: List of BaseRecord dictionaries containing document data.
    """

    documents: list[BaseRecord]
|
|
|
|
|
|
class BaseCollectionSearchParams(BaseCollectionParams, total=False):
    """Parameters for searching within a collection.

    Extends BaseCollectionParams with search-specific optional fields.
    All fields except collection_name and query are optional: the class is
    declared with ``total=False``, and ``query`` is re-marked as mandatory
    with ``Required``.

    Attributes:
        query: The text query to search for (required).
        limit: Maximum number of results to return.
        metadata_filter: Filter results by metadata fields.
        score_threshold: Minimum similarity score for results (0-1).
    """

    query: Required[str]
    limit: int
    metadata_filter: dict[str, Any]
    score_threshold: float
|
|
|
|
|
|
@runtime_checkable
class BaseClient(Protocol):
    """Protocol for vector store client implementations.

    This protocol defines the interface that all vector store client implementations
    must follow. It provides a consistent API for storing and retrieving
    documents with their vector embeddings across different vector database
    backends (e.g., Qdrant, ChromaDB, Weaviate). Implementing classes should
    handle connection management, data persistence, and vector similarity
    search operations specific to their backend.

    Implementation Guidelines:
        Implementations should accept BaseClientParams in their constructor to allow
        passing pre-configured client instances:

        class MyVectorClient:
            def __init__(self, client: Any | None = None, **kwargs):
                if client:
                    self.client = client
                else:
                    self.client = self._create_default_client(**kwargs)

    Notes:
        This protocol replaces the former BaseRAGStorage abstraction,
        providing a cleaner interface for vector store operations.

    Attributes:
        embedding_function: Callable that takes a list of text strings
            and returns a list of embedding vectors. Implementations
            should always provide a default embedding function.
        client: The underlying vector database client instance. This could be
            passed via BaseClientParams during initialization or created internally.
    """

    client: Any
    embedding_function: EmbeddingFunction

    @classmethod
    def __get_pydantic_core_schema__(
        cls, _source_type: Any, _handler: GetCoreSchemaHandler
    ) -> CoreSchema:
        """Generate Pydantic core schema for BaseClient Protocol.

        This allows the Protocol to be used in Pydantic models without
        requiring arbitrary_types_allowed=True.

        Args:
            _source_type: The type Pydantic is building a schema for (unused).
            _handler: Pydantic's schema handler (unused).

        Returns:
            An "any" core schema, i.e. values are accepted without validation.
        """
        return core_schema.any_schema()

    @abstractmethod
    def create_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
        """Create a new collection/index in the vector database.

        Keyword Args:
            collection_name: The name of the collection to create. Must be unique within
                the vector database instance.

        Raises:
            ValueError: If collection name already exists.
            ConnectionError: If unable to connect to the vector database backend.
        """
        ...

    @abstractmethod
    async def acreate_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
        """Create a new collection/index in the vector database asynchronously.

        Keyword Args:
            collection_name: The name of the collection to create. Must be unique within
                the vector database instance.

        Raises:
            ValueError: If collection name already exists.
            ConnectionError: If unable to connect to the vector database backend.
        """
        ...

    @abstractmethod
    def get_or_create_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> Any:
        """Get an existing collection or create it if it doesn't exist.

        This method provides a convenient way to ensure a collection exists
        without having to check for its existence first.

        Keyword Args:
            collection_name: The name of the collection to get or create.

        Returns:
            A collection object whose type depends on the backend implementation.
            This could be a collection reference, ID, or client object.

        Raises:
            ValueError: If unable to create the collection.
            ConnectionError: If unable to connect to the vector database backend.
        """
        ...

    @abstractmethod
    async def aget_or_create_collection(
        self, **kwargs: Unpack[BaseCollectionParams]
    ) -> Any:
        """Get an existing collection or create it if it doesn't exist asynchronously.

        Keyword Args:
            collection_name: The name of the collection to get or create.

        Returns:
            A collection object whose type depends on the backend implementation.

        Raises:
            ValueError: If unable to create the collection.
            ConnectionError: If unable to connect to the vector database backend.
        """
        ...

    @abstractmethod
    def add_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
        """Add documents with their embeddings to a collection.

        This method performs an upsert operation - if a document with the same ID
        already exists, it will be updated with the new content and metadata.

        Implementations should handle embedding generation internally based on
        the configured embedding function.

        Keyword Args:
            collection_name: The name of the collection to add documents to.
            documents: List of BaseRecord dicts containing:
                - content: The text content (required)
                - doc_id: Optional unique identifier (auto-generated from content hash if missing)
                - metadata: Optional metadata dictionary
                Embeddings will be generated automatically.

        Raises:
            ValueError: If collection doesn't exist or documents list is empty.
            TypeError: If documents are not BaseRecord dict instances.
            ConnectionError: If unable to connect to the vector database backend.

        Example:
            >>> from crewai.rag.chromadb.client import ChromaDBClient
            >>> from crewai.rag.types import BaseRecord
            >>> client = ChromaDBClient()
            >>>
            >>> records: list[BaseRecord] = [
            ...     {
            ...         "content": "Machine learning basics",
            ...         "metadata": {"source": "file3", "topic": "ML"}
            ...     },
            ...     {
            ...         "doc_id": "custom_id",
            ...         "content": "Deep learning fundamentals",
            ...         "metadata": {"source": "file4", "topic": "DL"}
            ...     }
            ... ]
            >>> client.add_documents(collection_name="my_docs", documents=records)
            >>>
            >>> records_with_id: list[BaseRecord] = [
            ...     {
            ...         "doc_id": "nlp_001",
            ...         "content": "Advanced NLP techniques",
            ...         "metadata": {"source": "file5", "topic": "NLP"}
            ...     }
            ... ]
            >>> client.add_documents(collection_name="my_docs", documents=records_with_id)
        """
        ...

    @abstractmethod
    async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
        """Add documents with their embeddings to a collection asynchronously.

        Implementations should handle embedding generation internally based on
        the configured embedding function.

        Keyword Args:
            collection_name: The name of the collection to add documents to.
            documents: List of BaseRecord dicts containing:
                - content: The text content (required)
                - doc_id: Optional unique identifier (auto-generated from content hash if missing)
                - metadata: Optional metadata dictionary
                Embeddings will be generated automatically.

        Raises:
            ValueError: If collection doesn't exist or documents list is empty.
            TypeError: If documents are not BaseRecord dict instances.
            ConnectionError: If unable to connect to the vector database backend.

        Example:
            >>> import asyncio
            >>> from crewai.rag.chromadb.client import ChromaDBClient
            >>> from crewai.rag.types import BaseRecord
            >>>
            >>> async def add_documents():
            ...     client = ChromaDBClient()
            ...
            ...     records: list[BaseRecord] = [
            ...         {
            ...             "doc_id": "doc2",
            ...             "content": "Async operations in Python",
            ...             "metadata": {"source": "file2", "topic": "async"}
            ...         }
            ...     ]
            ...     await client.aadd_documents(collection_name="my_docs", documents=records)
            ...
            >>> asyncio.run(add_documents())
        """
        ...

    @abstractmethod
    def search(
        self, **kwargs: Unpack[BaseCollectionSearchParams]
    ) -> list[SearchResult]:
        """Search for similar documents using a query.

        Performs a vector similarity search to find the most similar documents
        to the provided query.

        Keyword Args:
            collection_name: The name of the collection to search in.
            query: The text query to search for. The implementation handles
                embedding generation internally.
            limit: Maximum number of results to return. Defaults to 10.
            metadata_filter: Optional metadata filter to apply to the search. The exact
                format depends on the backend, but typically supports equality
                and range queries on metadata fields.
            score_threshold: Optional minimum similarity score threshold. Only
                results with scores >= this threshold will be returned. The
                score interpretation depends on the distance metric used.

        Returns:
            A list of SearchResult dictionaries ordered by similarity score in
            descending order. Each result contains:
            - id: Document ID
            - content: Document text content
            - metadata: Document metadata
            - score: Similarity score (0-1, higher is better)

        Raises:
            ValueError: If collection doesn't exist.
            ConnectionError: If unable to connect to the vector database backend.

        Example:
            >>> from crewai.rag.chromadb.client import ChromaDBClient
            >>> client = ChromaDBClient()
            >>>
            >>> results = client.search(
            ...     collection_name="my_docs",
            ...     query="What is machine learning?",
            ...     limit=5,
            ...     metadata_filter={"source": "file1"},
            ...     score_threshold=0.7
            ... )
            >>> for result in results:
            ...     print(f"{result['id']}: {result['score']:.2f}")
        """
        ...

    @abstractmethod
    async def asearch(
        self, **kwargs: Unpack[BaseCollectionSearchParams]
    ) -> list[SearchResult]:
        """Search for similar documents using a query asynchronously.

        Keyword Args:
            collection_name: The name of the collection to search in.
            query: The text query to search for. The implementation handles
                embedding generation internally.
            limit: Maximum number of results to return. Defaults to 10.
            metadata_filter: Optional metadata filter to apply to the search.
            score_threshold: Optional minimum similarity score threshold.

        Returns:
            A list of SearchResult dictionaries ordered by similarity score.

        Raises:
            ValueError: If collection doesn't exist.
            ConnectionError: If unable to connect to the vector database backend.

        Example:
            >>> import asyncio
            >>> from crewai.rag.chromadb.client import ChromaDBClient
            >>>
            >>> async def search_documents():
            ...     client = ChromaDBClient()
            ...     results = await client.asearch(
            ...         collection_name="my_docs",
            ...         query="Python programming best practices",
            ...         limit=5,
            ...         metadata_filter={"source": "file1"},
            ...         score_threshold=0.7
            ...     )
            ...     for result in results:
            ...         print(f"{result['id']}: {result['score']:.2f}")
            ...
            >>> asyncio.run(search_documents())
        """
        ...

    @abstractmethod
    def delete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
        """Delete a collection and all its data.

        This operation is irreversible and will permanently remove all documents,
        embeddings, and metadata associated with the collection.

        Keyword Args:
            collection_name: The name of the collection to delete.

        Raises:
            ValueError: If the collection doesn't exist.
            ConnectionError: If unable to connect to the vector database backend.

        Example:
            >>> from crewai.rag.chromadb.client import ChromaDBClient
            >>> client = ChromaDBClient()
            >>> client.delete_collection(collection_name="old_docs")
            >>> print("Collection 'old_docs' deleted successfully")
        """
        ...

    @abstractmethod
    async def adelete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
        """Delete a collection and all its data asynchronously.

        Keyword Args:
            collection_name: The name of the collection to delete.

        Raises:
            ValueError: If the collection doesn't exist.
            ConnectionError: If unable to connect to the vector database backend.

        Example:
            >>> import asyncio
            >>> from crewai.rag.chromadb.client import ChromaDBClient
            >>>
            >>> async def delete_old_collection():
            ...     client = ChromaDBClient()
            ...     await client.adelete_collection(collection_name="old_docs")
            ...     print("Collection 'old_docs' deleted successfully")
            ...
            >>> asyncio.run(delete_old_collection())
        """
        ...

    @abstractmethod
    def reset(self) -> None:
        """Reset the vector database by deleting all collections and data.

        This method provides a way to completely clear the vector database,
        removing all collections and their contents. Use with caution as
        this operation is irreversible.

        Raises:
            ConnectionError: If unable to connect to the vector database backend.
            PermissionError: If the operation is not allowed by the backend.

        Example:
            >>> from crewai.rag.chromadb.client import ChromaDBClient
            >>> client = ChromaDBClient()
            >>> client.reset()
            >>> print("Vector database completely reset - all data deleted")
        """
        ...

    @abstractmethod
    async def areset(self) -> None:
        """Reset the vector database by deleting all collections and data asynchronously.

        Raises:
            ConnectionError: If unable to connect to the vector database backend.
            PermissionError: If the operation is not allowed by the backend.

        Example:
            >>> import asyncio
            >>> from crewai.rag.chromadb.client import ChromaDBClient
            >>>
            >>> async def reset_database():
            ...     client = ChromaDBClient()
            ...     await client.areset()
            ...     print("Vector database completely reset - all data deleted")
            ...
            >>> asyncio.run(reset_database())
        """
        ...
|