feat: centralize embedding types and create base client (#3246)

feat: add RAG system foundation with generic vector store support

- Add BaseClient protocol for vector stores
- Move BaseRAGStorage to rag/core
- Centralize embedding types in embeddings/types.py
- Remove unused storage models
This commit is contained in:
Greyson LaLonde
2025-08-20 09:35:27 -04:00
committed by GitHub
parent 2773996b49
commit ed187b495b
6 changed files with 724 additions and 0 deletions

View File

@@ -0,0 +1 @@
"""Core abstract base classes and protocols for RAG systems."""

View File

@@ -0,0 +1,433 @@
"""Protocol for vector database client implementations."""
from abc import abstractmethod
from typing import Any, Protocol, runtime_checkable, TypedDict, Annotated
from typing_extensions import Unpack, Required
from crewai.rag.types import (
EmbeddingFunction,
BaseRecord,
SearchResult,
)
class BaseCollectionParams(TypedDict):
"""Base parameters for collection operations.
Attributes:
collection_name: The name of the collection/index to operate on.
"""
collection_name: Required[
Annotated[
str,
"Name of the collection/index. Implementations may have specific constraints (e.g., character limits, allowed characters, case sensitivity).",
]
]
class BaseCollectionAddParams(BaseCollectionParams):
    """Keyword arguments for inserting documents into a collection.

    Inherits ``collection_name`` from BaseCollectionParams and adds the
    payload to store.

    Attributes:
        collection_name: Name of the collection receiving the documents.
        documents: The BaseRecord dictionaries to add.
    """

    documents: list[BaseRecord]
class BaseCollectionSearchParams(BaseCollectionParams, total=False):
    """Keyword arguments for querying a collection.

    Declared with ``total=False`` so that only the inherited
    ``collection_name`` and the explicitly ``Required`` ``query`` key are
    mandatory; every other key may be omitted.

    Attributes:
        query: Text to search for (required).
        limit: Upper bound on the number of results returned.
        metadata_filter: Restricts results to records with matching metadata.
        score_threshold: Minimum similarity score for results (0-1).
    """

    query: Required[str]
    limit: int
    metadata_filter: dict[str, Any]
    score_threshold: float
@runtime_checkable
class BaseClient(Protocol):
    """Structural interface that every vector store client must satisfy.

    The protocol gives callers one consistent API for storing documents and
    retrieving them by vector similarity, independent of the backing database
    (e.g., Qdrant, ChromaDB, Weaviate). Connection management, persistence and
    the similarity search itself are the responsibility of each implementation.

    Every operation is declared twice: a blocking method and an ``a``-prefixed
    coroutine with the same keyword arguments.

    Implementation Guidelines:
        Implementations should let callers inject a pre-configured backend
        client through the constructor and fall back to building a default::

            class MyVectorClient:
                def __init__(self, client: Any | None = None, **kwargs):
                    if client:
                        self.client = client
                    else:
                        self.client = self._create_default_client(**kwargs)

    Notes:
        This protocol is intended to supersede the earlier BaseRAGStorage
        abstraction with a leaner interface for vector store operations.

    Attributes:
        client: The underlying vector database client instance, either
            supplied at construction time or created internally.
        embedding_function: Callable that turns a list of text strings into a
            list of embedding vectors. Implementations should always provide
            a default.
    """

    client: Any
    embedding_function: EmbeddingFunction

    @abstractmethod
    def create_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
        """Create a new collection/index in the vector database.

        Keyword Args:
            collection_name: Name for the new collection; must be unique
                within the vector database instance.

        Raises:
            ValueError: If a collection with that name already exists.
            ConnectionError: If the vector database backend is unreachable.
        """

    @abstractmethod
    async def acreate_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
        """Asynchronously create a new collection/index in the vector database.

        Keyword Args:
            collection_name: Name for the new collection; must be unique
                within the vector database instance.

        Raises:
            ValueError: If a collection with that name already exists.
            ConnectionError: If the vector database backend is unreachable.
        """

    @abstractmethod
    def get_or_create_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> Any:
        """Return the named collection, creating it first when it is missing.

        Saves callers from checking for existence before using a collection.

        Keyword Args:
            collection_name: Name of the collection to fetch or create.

        Returns:
            A backend-specific collection object; depending on the
            implementation this may be a collection reference, an ID, or a
            client object.

        Raises:
            ValueError: If the collection cannot be created.
            ConnectionError: If the vector database backend is unreachable.
        """

    @abstractmethod
    async def aget_or_create_collection(
        self, **kwargs: Unpack[BaseCollectionParams]
    ) -> Any:
        """Asynchronously return the named collection, creating it if missing.

        Keyword Args:
            collection_name: Name of the collection to fetch or create.

        Returns:
            A backend-specific collection object.

        Raises:
            ValueError: If the collection cannot be created.
            ConnectionError: If the vector database backend is unreachable.
        """

    @abstractmethod
    def add_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
        """Store documents, together with their embeddings, in a collection.

        The operation is an upsert: a document whose ID already exists is
        overwritten with the new content and metadata. Embeddings are
        expected to be generated by the implementation using its configured
        embedding function.

        Keyword Args:
            collection_name: Name of the collection to add documents to.
            documents: List of BaseRecord dicts, each holding:
                - content: The text content (required)
                - doc_id: Optional unique identifier (auto-generated from
                  a content hash when missing)
                - metadata: Optional metadata dictionary

        Raises:
            ValueError: If the collection doesn't exist or the documents
                list is empty.
            TypeError: If the documents are not BaseRecord dict instances.
            ConnectionError: If the vector database backend is unreachable.

        Example:
            >>> from crewai.rag.chromadb.client import ChromaDBClient
            >>> from crewai.rag.types import BaseRecord
            >>> client = ChromaDBClient()
            >>>
            >>> records: list[BaseRecord] = [
            ...     {
            ...         "content": "Machine learning basics",
            ...         "metadata": {"source": "file3", "topic": "ML"}
            ...     },
            ...     {
            ...         "doc_id": "custom_id",
            ...         "content": "Deep learning fundamentals",
            ...         "metadata": {"source": "file4", "topic": "DL"}
            ...     }
            ... ]
            >>> client.add_documents(collection_name="my_docs", documents=records)
            >>>
            >>> records_with_id: list[BaseRecord] = [
            ...     {
            ...         "doc_id": "nlp_001",
            ...         "content": "Advanced NLP techniques",
            ...         "metadata": {"source": "file5", "topic": "NLP"}
            ...     }
            ... ]
            >>> client.add_documents(collection_name="my_docs", documents=records_with_id)
        """

    @abstractmethod
    async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
        """Asynchronously store documents and their embeddings in a collection.

        Embeddings are expected to be generated by the implementation using
        its configured embedding function.

        Keyword Args:
            collection_name: Name of the collection to add documents to.
            documents: List of BaseRecord dicts, each holding:
                - content: The text content (required)
                - doc_id: Optional unique identifier (auto-generated from
                  a content hash when missing)
                - metadata: Optional metadata dictionary

        Raises:
            ValueError: If the collection doesn't exist or the documents
                list is empty.
            TypeError: If the documents are not BaseRecord dict instances.
            ConnectionError: If the vector database backend is unreachable.

        Example:
            >>> import asyncio
            >>> from crewai.rag.chromadb.client import ChromaDBClient
            >>> from crewai.rag.types import BaseRecord
            >>>
            >>> async def add_documents():
            ...     client = ChromaDBClient()
            ...
            ...     records: list[BaseRecord] = [
            ...         {
            ...             "doc_id": "doc2",
            ...             "content": "Async operations in Python",
            ...             "metadata": {"source": "file2", "topic": "async"}
            ...         }
            ...     ]
            ...     await client.aadd_documents(collection_name="my_docs", documents=records)
            ...
            >>> asyncio.run(add_documents())
        """

    @abstractmethod
    def search(
        self, **kwargs: Unpack[BaseCollectionSearchParams]
    ) -> list[SearchResult]:
        """Find the documents most similar to a text query.

        Runs a vector similarity search against the named collection.

        Keyword Args:
            collection_name: Name of the collection to search in.
            query: Text to search for; the implementation embeds it
                internally.
            limit: Maximum number of results to return. Defaults to 10.
            metadata_filter: Optional metadata filter. The exact format is
                backend-dependent, but equality and range queries on
                metadata fields are typically supported.
            score_threshold: Optional minimum similarity score; only results
                scoring >= this value are returned. How a score is
                interpreted depends on the distance metric in use.

        Returns:
            SearchResult dictionaries sorted by similarity score, highest
            first. Each result contains:
                - id: Document ID
                - content: Document text content
                - metadata: Document metadata
                - score: Similarity score (0-1, higher is better)

        Raises:
            ValueError: If the collection doesn't exist.
            ConnectionError: If the vector database backend is unreachable.

        Example:
            >>> from crewai.rag.chromadb.client import ChromaDBClient
            >>> client = ChromaDBClient()
            >>>
            >>> results = client.search(
            ...     collection_name="my_docs",
            ...     query="What is machine learning?",
            ...     limit=5,
            ...     metadata_filter={"source": "file1"},
            ...     score_threshold=0.7
            ... )
            >>> for result in results:
            ...     print(f"{result['id']}: {result['score']:.2f}")
        """

    @abstractmethod
    async def asearch(
        self, **kwargs: Unpack[BaseCollectionSearchParams]
    ) -> list[SearchResult]:
        """Asynchronously find the documents most similar to a text query.

        Keyword Args:
            collection_name: Name of the collection to search in.
            query: Text to search for; the implementation embeds it
                internally.
            limit: Maximum number of results to return. Defaults to 10.
            metadata_filter: Optional metadata filter to apply to the search.
            score_threshold: Optional minimum similarity score threshold.

        Returns:
            SearchResult dictionaries ordered by similarity score.

        Raises:
            ValueError: If the collection doesn't exist.
            ConnectionError: If the vector database backend is unreachable.

        Example:
            >>> import asyncio
            >>> from crewai.rag.chromadb.client import ChromaDBClient
            >>>
            >>> async def search_documents():
            ...     client = ChromaDBClient()
            ...     results = await client.asearch(
            ...         collection_name="my_docs",
            ...         query="Python programming best practices",
            ...         limit=5,
            ...         metadata_filter={"source": "file1"},
            ...         score_threshold=0.7
            ...     )
            ...     for result in results:
            ...         print(f"{result['id']}: {result['score']:.2f}")
            ...
            >>> asyncio.run(search_documents())
        """

    @abstractmethod
    def delete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
        """Permanently remove a collection and everything stored in it.

        The deletion is irreversible: all documents, embeddings and metadata
        belonging to the collection are lost.

        Keyword Args:
            collection_name: Name of the collection to delete.

        Raises:
            ValueError: If the collection doesn't exist.
            ConnectionError: If the vector database backend is unreachable.

        Example:
            >>> from crewai.rag.chromadb.client import ChromaDBClient
            >>> client = ChromaDBClient()
            >>> client.delete_collection(collection_name="old_docs")
            >>> print("Collection 'old_docs' deleted successfully")
        """

    @abstractmethod
    async def adelete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
        """Asynchronously remove a collection and everything stored in it.

        Keyword Args:
            collection_name: Name of the collection to delete.

        Raises:
            ValueError: If the collection doesn't exist.
            ConnectionError: If the vector database backend is unreachable.

        Example:
            >>> import asyncio
            >>> from crewai.rag.chromadb.client import ChromaDBClient
            >>>
            >>> async def delete_old_collection():
            ...     client = ChromaDBClient()
            ...     await client.adelete_collection(collection_name="old_docs")
            ...     print("Collection 'old_docs' deleted successfully")
            ...
            >>> asyncio.run(delete_old_collection())
        """

    @abstractmethod
    def reset(self) -> None:
        """Wipe the vector database, deleting all collections and their data.

        Use with caution: the operation clears the whole database and is
        irreversible.

        Raises:
            ConnectionError: If the vector database backend is unreachable.
            PermissionError: If the backend does not allow the operation.

        Example:
            >>> from crewai.rag.chromadb.client import ChromaDBClient
            >>> client = ChromaDBClient()
            >>> client.reset()
            >>> print("Vector database completely reset - all data deleted")
        """

    @abstractmethod
    async def areset(self) -> None:
        """Asynchronously wipe the vector database, deleting all collections.

        Raises:
            ConnectionError: If the vector database backend is unreachable.
            PermissionError: If the backend does not allow the operation.

        Example:
            >>> import asyncio
            >>> from crewai.rag.chromadb.client import ChromaDBClient
            >>>
            >>> async def reset_database():
            ...     client = ChromaDBClient()
            ...     await client.areset()
            ...     print("Vector database completely reset - all data deleted")
            ...
            >>> asyncio.run(reset_database())
        """

View File

@@ -0,0 +1,30 @@
"""Base provider protocol for vector database client creation."""
from abc import ABC
from typing import Any, Protocol, runtime_checkable, Union
from pydantic import BaseModel, Field
from crewai.rag.types import EmbeddingFunction
from crewai.rag.embeddings.types import EmbeddingOptions
class BaseProviderOptions(BaseModel, ABC):
    """Common settings model that provider-specific option classes extend.

    Attributes:
        client_type: Identifies which kind of client to create (required).
        embedding_config: Either EmbeddingOptions for a built-in provider,
            a custom embedding callable, or None when unset.
        options: Free-form, provider-specific extras.
    """

    client_type: str = Field(..., description="Type of client to create")
    embedding_config: Union[EmbeddingOptions, EmbeddingFunction, None] = Field(
        None,
        description="Embedding configuration - either options for built-in providers or a custom function",
    )
    options: Any = Field(None, description="Additional provider-specific options")
@runtime_checkable
class BaseProvider(Protocol):
    """Callable protocol for factories that build vector database clients."""

    def __call__(self, options: BaseProviderOptions) -> Any:
        """Build a client from ``options`` and return the configured instance."""

View File

@@ -0,0 +1,148 @@
"""Minimal embedding function factory for CrewAI."""
import os
from chromadb import EmbeddingFunction
from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
AmazonBedrockEmbeddingFunction,
)
from chromadb.utils.embedding_functions.cohere_embedding_function import (
CohereEmbeddingFunction,
)
from chromadb.utils.embedding_functions.google_embedding_function import (
GooglePalmEmbeddingFunction,
GoogleGenerativeAiEmbeddingFunction,
GoogleVertexEmbeddingFunction,
)
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
HuggingFaceEmbeddingFunction,
)
from chromadb.utils.embedding_functions.instructor_embedding_function import (
InstructorEmbeddingFunction,
)
from chromadb.utils.embedding_functions.jina_embedding_function import (
JinaEmbeddingFunction,
)
from chromadb.utils.embedding_functions.ollama_embedding_function import (
OllamaEmbeddingFunction,
)
from chromadb.utils.embedding_functions.onnx_mini_lm_l6_v2 import ONNXMiniLM_L6_V2
from chromadb.utils.embedding_functions.open_clip_embedding_function import (
OpenCLIPEmbeddingFunction,
)
from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction,
)
from chromadb.utils.embedding_functions.roboflow_embedding_function import (
RoboflowEmbeddingFunction,
)
from chromadb.utils.embedding_functions.sentence_transformer_embedding_function import (
SentenceTransformerEmbeddingFunction,
)
from chromadb.utils.embedding_functions.text2vec_embedding_function import (
Text2VecEmbeddingFunction,
)
from crewai.rag.embeddings.types import EmbeddingOptions
def get_embedding_function(
    config: EmbeddingOptions | dict | None = None,
) -> EmbeddingFunction:
    """Get embedding function - delegates to ChromaDB.

    Args:
        config: Optional configuration - either an EmbeddingOptions object or a dict with:
            - provider: The embedding provider to use (default: "openai")
            - Any other provider-specific parameters, forwarded unchanged as
              keyword arguments to the provider's embedding function class
            When omitted, an OpenAI embedding function is built using the
            ``OPENAI_API_KEY`` environment variable and the
            ``text-embedding-3-small`` model.

    Returns:
        EmbeddingFunction instance ready for use with ChromaDB

    Raises:
        ValueError: If the requested provider is not one of the supported providers.

    Supported providers:
        - openai: OpenAI embeddings (default)
        - cohere: Cohere embeddings
        - ollama: Ollama local embeddings
        - huggingface: HuggingFace embeddings
        - sentence-transformer: Local sentence transformers
        - instructor: Instructor embeddings for specialized tasks
        - google-palm: Google PaLM embeddings
        - google-generativeai: Google Generative AI embeddings
        - google-vertex: Google Vertex AI embeddings
        - amazon-bedrock: AWS Bedrock embeddings
        - jina: Jina AI embeddings
        - roboflow: Roboflow embeddings for vision tasks
        - openclip: OpenCLIP embeddings for multimodal tasks
        - text2vec: Text2Vec embeddings
        - onnx: ONNX MiniLM-L6-v2 (no API key needed, included with ChromaDB)

    Examples:
        # Use default OpenAI
        >>> embedder = get_embedding_function()

        # Use Cohere with dict
        >>> embedder = get_embedding_function({
        ...     "provider": "cohere",
        ...     "api_key": "your-key",
        ...     "model_name": "embed-english-v3.0"
        ... })

        # Use with EmbeddingOptions
        >>> embedder = get_embedding_function(
        ...     EmbeddingOptions(provider="sentence-transformer", model_name="all-MiniLM-L6-v2")
        ... )

        # Use local sentence transformers (no API key needed)
        >>> embedder = get_embedding_function({
        ...     "provider": "sentence-transformer",
        ...     "model_name": "all-MiniLM-L6-v2"
        ... })

        # Use Ollama for local embeddings
        >>> embedder = get_embedding_function({
        ...     "provider": "ollama",
        ...     "model_name": "nomic-embed-text"
        ... })

        # Use ONNX (no API key needed)
        >>> embedder = get_embedding_function({
        ...     "provider": "onnx"
        ... })
    """
    if config is None:
        return OpenAIEmbeddingFunction(
            api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
        )

    # Normalise both accepted input shapes into a private dict we may mutate.
    if isinstance(config, EmbeddingOptions):
        config_dict = config.model_dump(exclude_none=True)
    else:
        config_dict = config.copy()

    # EmbeddingOptions declares api_key as a pydantic SecretStr and
    # model_dump() keeps it wrapped; provider constructors need the raw
    # string, so unwrap it here. Duck-typed so a SecretStr supplied through
    # a plain dict is handled as well.
    api_key = config_dict.get("api_key")
    if hasattr(api_key, "get_secret_value"):
        config_dict["api_key"] = api_key.get_secret_value()

    provider = config_dict.pop("provider", "openai")

    embedding_functions = {
        "openai": OpenAIEmbeddingFunction,
        "cohere": CohereEmbeddingFunction,
        "ollama": OllamaEmbeddingFunction,
        "huggingface": HuggingFaceEmbeddingFunction,
        "sentence-transformer": SentenceTransformerEmbeddingFunction,
        "instructor": InstructorEmbeddingFunction,
        "google-palm": GooglePalmEmbeddingFunction,
        "google-generativeai": GoogleGenerativeAiEmbeddingFunction,
        "google-vertex": GoogleVertexEmbeddingFunction,
        "amazon-bedrock": AmazonBedrockEmbeddingFunction,
        "jina": JinaEmbeddingFunction,
        "roboflow": RoboflowEmbeddingFunction,
        "openclip": OpenCLIPEmbeddingFunction,
        "text2vec": Text2VecEmbeddingFunction,
        "onnx": ONNXMiniLM_L6_V2,
    }

    if provider not in embedding_functions:
        raise ValueError(
            f"Unsupported provider: {provider}. "
            f"Available providers: {list(embedding_functions.keys())}"
        )

    # Everything left in config_dict is passed straight to the provider class.
    return embedding_functions[provider](**config_dict)

View File

@@ -0,0 +1,62 @@
"""Type definitions for the embeddings module."""
from typing import Literal
from pydantic import BaseModel, Field, SecretStr
from crewai.rag.types import EmbeddingFunction
# Closed set of provider identifiers; member order is kept as originally declared.
EmbeddingProvider = Literal[
    "openai", "cohere", "ollama", "huggingface", "sentence-transformer",
    "instructor", "google-palm", "google-generativeai", "google-vertex",
    "amazon-bedrock", "jina", "roboflow", "openclip", "text2vec", "onnx",
]
"""Supported embedding providers.
These correspond to the embedding functions available in ChromaDB's
embedding_functions module. Each provider has specific requirements
and configuration options.
"""
class EmbeddingOptions(BaseModel):
    """Configuration options for embedding providers.

    Generic attributes that can be passed to get_embedding_function
    to configure various embedding providers.

    Attributes:
        provider: Which embedding provider to use (required).
        model_name: Optional model name for the chosen provider.
        api_key: Optional API key, held as a pydantic SecretStr so it is
            masked when the model is printed or repr'd.
    """

    # The `model_name` field starts with pydantic's protected "model_" prefix,
    # which makes pydantic versions that protect the whole prefix emit a
    # UserWarning when the class is created. Clearing the protected
    # namespaces keeps the public field name and silences the warning.
    model_config = {"protected_namespaces": ()}

    provider: EmbeddingProvider = Field(
        ..., description="Embedding provider name (e.g., 'openai', 'cohere', 'onnx')"
    )
    model_name: str | None = Field(
        default=None, description="Model name for the embedding provider"
    )
    api_key: SecretStr | None = Field(
        default=None, description="API key for the embedding provider"
    )
class EmbeddingConfig(BaseModel):
    """Wrapper model holding one embedding configuration value.

    The single field takes either a ready-to-use embedding callable or an
    EmbeddingOptions instance describing how to build one, so callers can
    configure embeddings in whichever form is more convenient.

    Attributes:
        function: A callable EmbeddingFunction, or EmbeddingOptions from
            which one can be created.
    """

    function: EmbeddingFunction | EmbeddingOptions

50
src/crewai/rag/types.py Normal file
View File

@@ -0,0 +1,50 @@
"""Type definitions for RAG (Retrieval-Augmented Generation) systems."""
from collections.abc import Callable, Mapping
from typing import TypeAlias, TypedDict, Any
from typing_extensions import Required
class BaseRecord(TypedDict, total=False):
"""A typed dictionary representing a document record.
Attributes:
doc_id: Optional unique identifier for the document. If not provided,
a content-based ID will be generated using SHA256 hash.
content: The text content of the document (required)
metadata: Optional metadata associated with the document
"""
doc_id: str
content: Required[str]
metadata: (
Mapping[str, str | int | float | bool]
| list[Mapping[str, str | int | float | bool]]
)
DenseVector: TypeAlias = list[float]
IntVector: TypeAlias = list[int]
EmbeddingFunction: TypeAlias = Callable[..., Any]
class SearchResult(TypedDict):
    """Uniform shape of a single hit returned by a vector store query.

    Vector store implementations are expected to translate their native
    result objects into this dictionary so that callers see the same
    structure whatever the backend.

    Attributes:
        id: Unique identifier of the matched document.
        content: Text content of the matched document.
        metadata: Metadata associated with the document.
        score: Similarity score (higher is better, typically between 0 and 1).
    """

    id: str
    content: str
    metadata: dict[str, Any]
    score: float