refactor: move rag package to lib/core and extract standalone utilities

This commit is contained in:
Greyson Lalonde
2025-09-27 16:31:54 -04:00
parent 1d1f5f455c
commit aca826c553
120 changed files with 360 additions and 87 deletions

1
lib/core/.python-version Normal file
View File

@@ -0,0 +1 @@
3.13

0
lib/core/README.md Normal file
View File

56
lib/core/pyproject.toml Normal file
View File

@@ -0,0 +1,56 @@
# Packaging metadata for the standalone crewai-core distribution.
[project]
name = "crewai-core"
# Version is resolved at build time by hatch (see [tool.hatch.version] below).
dynamic = ["version"]
description = ""
readme = "README.md"
authors = [
    { name = "Greyson Lalonde", email = "greyson.r.lalonde@gmail.com" }
]
keywords = [
    "crewai",
    "ai",
    "agents",
    "framework",
    "orchestration",
    "llm",
    "core",
    "typed",
]
classifiers = [
    "Development Status :: 3 - Alpha",
    "Intended Audience :: Developers",
    "Operating System :: OS Independent",
    "Programming Language :: Python :: 3",
    "Programming Language :: Python :: 3.12",
    "Programming Language :: Python :: 3.13",
    "Topic :: Software Development :: Libraries :: Python Modules",
    "Typing :: Typed",
]
requires-python = ">=3.10, <3.14"
dependencies = [
    "chromadb~=1.1.0",
    "pydantic-settings>=2.10.1",
    "uv>=0.4.25",
]

[project.urls]
Homepage = "https://crewai.com"
Documentation = "https://docs.crewai.com"
Repository = "https://github.com/crewAIInc/crewAI"

[tool.pytest.ini_options]
testpaths = ["tests"]
asyncio_mode = "strict"
asyncio_default_fixture_loop_scope = "function"

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

# Single source of truth for the package version.
[tool.hatch.version]
path = "src/crewai/core/__init__.py"

[tool.hatch.build.targets.wheel]
packages = ["src/crewai"]

View File

@@ -0,0 +1 @@
__version__ = "1.0.0a0"

View File

View File

@@ -0,0 +1,60 @@
"""RAG (Retrieval-Augmented Generation) infrastructure for CrewAI."""
import importlib
import sys
from types import ModuleType
from typing import Any
from crewai.rag.config.types import RagConfigType
from crewai.rag.config.utils import set_rag_config
_module_path = __path__
_module_file = __file__
class _RagModule(ModuleType):
"""Module wrapper to intercept attribute setting for config."""
__path__ = _module_path
__file__ = _module_file
def __init__(self, module_name: str):
"""Initialize the module wrapper.
Args:
module_name: Name of the module.
"""
super().__init__(module_name)
def __setattr__(self, name: str, value: RagConfigType) -> None:
"""Set module attributes.
Args:
name: Attribute name.
value: Attribute value.
"""
if name == "config":
return set_rag_config(value)
raise AttributeError(f"Setting attribute '{name}' is not allowed.")
def __getattr__(self, name: str) -> Any:
"""Get module attributes.
Args:
name: Attribute name.
Returns:
The requested attribute.
Raises:
AttributeError: If attribute doesn't exist.
"""
try:
return importlib.import_module(f"{self.__name__}.{name}")
except ImportError as e:
raise AttributeError(
f"module '{self.__name__}' has no attribute '{name}'"
) from e
sys.modules[__name__] = _RagModule(__name__)

View File

@@ -0,0 +1,617 @@
"""ChromaDB client implementation."""
import logging
from typing import Any
from chromadb.api.types import (
EmbeddingFunction as ChromaEmbeddingFunction,
)
from chromadb.api.types import (
QueryResult,
)
from typing_extensions import Unpack
from crewai.rag.chromadb.types import (
ChromaDBClientType,
ChromaDBCollectionCreateParams,
ChromaDBCollectionSearchParams,
)
from crewai.rag.chromadb.utils import (
_create_batch_slice,
_extract_search_params,
_is_async_client,
_is_sync_client,
_prepare_documents_for_chromadb,
_process_query_results,
_sanitize_collection_name,
)
from crewai.rag.core.base_client import (
BaseClient,
BaseCollectionAddParams,
BaseCollectionParams,
)
from crewai.rag.types import SearchResult
from crewai.core.utilities.logger_utils import suppress_logging
class ChromaDBClient(BaseClient):
    """ChromaDB implementation of the BaseClient protocol.

    Provides vector database operations for ChromaDB, supporting both
    synchronous and asynchronous clients. Every operation exists in a
    sync/async pair; each guards against being called with the wrong
    client flavor.

    Attributes:
        client: ChromaDB client instance (ClientAPI or AsyncClientAPI).
        embedding_function: Function to generate embeddings for documents.
        default_limit: Default number of results to return in searches.
        default_score_threshold: Default minimum score for search results.
        default_batch_size: Default batch size for adding documents.
    """

    def __init__(
        self,
        client: ChromaDBClientType,
        embedding_function: ChromaEmbeddingFunction,
        default_limit: int = 5,
        default_score_threshold: float = 0.6,
        default_batch_size: int = 100,
    ) -> None:
        """Initialize ChromaDBClient with client and embedding function.

        Args:
            client: Pre-configured ChromaDB client instance.
            embedding_function: Embedding function for text to vector conversion.
            default_limit: Default number of results to return in searches.
            default_score_threshold: Default minimum score for search results.
            default_batch_size: Default batch size for adding documents.
        """
        self.client = client
        self.embedding_function = embedding_function
        self.default_limit = default_limit
        self.default_score_threshold = default_score_threshold
        self.default_batch_size = default_batch_size

    def create_collection(
        self, **kwargs: Unpack[ChromaDBCollectionCreateParams]
    ) -> None:
        """Create a new collection in ChromaDB.

        Uses the client's default embedding function if none provided.

        Keyword Args:
            collection_name: Name of the collection to create. Must be unique.
            configuration: Optional collection configuration specifying distance metrics,
                HNSW parameters, or other backend-specific settings.
            metadata: Optional metadata dictionary to attach to the collection.
            embedding_function: Optional custom embedding function. If not provided,
                uses the client's default embedding function.
            data_loader: Optional data loader for batch loading data into the collection.
            get_or_create: If True, returns existing collection if it already exists
                instead of raising an error. Defaults to False.

        Raises:
            TypeError: If AsyncClientAPI is used instead of ClientAPI for sync operations.
            ValueError: If collection with the same name already exists and get_or_create
                is False.
            ConnectionError: If unable to connect to ChromaDB server.

        Example:
            >>> client = ChromaDBClient()
            >>> client.create_collection(
            ...     collection_name="documents",
            ...     metadata={"description": "Product documentation"},
            ...     get_or_create=True
            ... )
        """
        if not _is_sync_client(self.client):
            raise TypeError(
                "Synchronous method create_collection() requires a ClientAPI. "
                "Use acreate_collection() for AsyncClientAPI."
            )
        metadata = kwargs.get("metadata", {})
        # Default the collection's distance metric to cosine unless the
        # caller already chose one via the "hnsw:space" metadata key.
        if "hnsw:space" not in metadata:
            metadata["hnsw:space"] = "cosine"
        self.client.create_collection(
            name=_sanitize_collection_name(kwargs["collection_name"]),
            configuration=kwargs.get("configuration"),
            metadata=metadata,
            embedding_function=kwargs.get(
                "embedding_function", self.embedding_function
            ),
            data_loader=kwargs.get("data_loader"),
            get_or_create=kwargs.get("get_or_create", False),
        )

    async def acreate_collection(
        self, **kwargs: Unpack[ChromaDBCollectionCreateParams]
    ) -> None:
        """Create a new collection in ChromaDB asynchronously.

        Creates a new collection with the specified name and optional configuration.
        If an embedding function is not provided, uses the client's default embedding function.

        Keyword Args:
            collection_name: Name of the collection to create. Must be unique.
            configuration: Optional collection configuration specifying distance metrics,
                HNSW parameters, or other backend-specific settings.
            metadata: Optional metadata dictionary to attach to the collection.
            embedding_function: Optional custom embedding function. If not provided,
                uses the client's default embedding function.
            data_loader: Optional data loader for batch loading data into the collection.
            get_or_create: If True, returns existing collection if it already exists
                instead of raising an error. Defaults to False.

        Raises:
            TypeError: If ClientAPI is used instead of AsyncClientAPI for async operations.
            ValueError: If collection with the same name already exists and get_or_create
                is False.
            ConnectionError: If unable to connect to ChromaDB server.

        Example:
            >>> import asyncio
            >>> async def main():
            ...     client = ChromaDBClient()
            ...     await client.acreate_collection(
            ...         collection_name="documents",
            ...         metadata={"description": "Product documentation"},
            ...         get_or_create=True
            ...     )
            >>> asyncio.run(main())
        """
        if not _is_async_client(self.client):
            raise TypeError(
                "Asynchronous method acreate_collection() requires an AsyncClientAPI. "
                "Use create_collection() for ClientAPI."
            )
        metadata = kwargs.get("metadata", {})
        # Same cosine default as the sync variant.
        if "hnsw:space" not in metadata:
            metadata["hnsw:space"] = "cosine"
        await self.client.create_collection(
            name=_sanitize_collection_name(kwargs["collection_name"]),
            configuration=kwargs.get("configuration"),
            metadata=metadata,
            embedding_function=kwargs.get(
                "embedding_function", self.embedding_function
            ),
            data_loader=kwargs.get("data_loader"),
            get_or_create=kwargs.get("get_or_create", False),
        )

    def get_or_create_collection(
        self, **kwargs: Unpack[ChromaDBCollectionCreateParams]
    ) -> Any:
        """Get an existing collection or create it if it doesn't exist.

        Returns existing collection if found, otherwise creates a new one.

        Keyword Args:
            collection_name: Name of the collection to get or create.
            configuration: Optional collection configuration specifying distance metrics,
                HNSW parameters, or other backend-specific settings.
            metadata: Optional metadata dictionary to attach to the collection.
            embedding_function: Optional custom embedding function. If not provided,
                uses the client's default embedding function.
            data_loader: Optional data loader for batch loading data into the collection.

        Returns:
            A ChromaDB Collection object.

        Raises:
            TypeError: If AsyncClientAPI is used instead of ClientAPI for sync operations.
            ConnectionError: If unable to connect to ChromaDB server.

        Example:
            >>> client = ChromaDBClient()
            >>> collection = client.get_or_create_collection(
            ...     collection_name="documents",
            ...     metadata={"description": "Product documentation"}
            ... )
        """
        if not _is_sync_client(self.client):
            raise TypeError(
                "Synchronous method get_or_create_collection() requires a ClientAPI. "
                "Use aget_or_create_collection() for AsyncClientAPI."
            )
        metadata = kwargs.get("metadata", {})
        if "hnsw:space" not in metadata:
            metadata["hnsw:space"] = "cosine"
        return self.client.get_or_create_collection(
            name=_sanitize_collection_name(kwargs["collection_name"]),
            configuration=kwargs.get("configuration"),
            metadata=metadata,
            embedding_function=kwargs.get(
                "embedding_function", self.embedding_function
            ),
            data_loader=kwargs.get("data_loader"),
        )

    async def aget_or_create_collection(
        self, **kwargs: Unpack[ChromaDBCollectionCreateParams]
    ) -> Any:
        """Get an existing collection or create it if it doesn't exist asynchronously.

        Returns existing collection if found, otherwise creates a new one.

        Keyword Args:
            collection_name: Name of the collection to get or create.
            configuration: Optional collection configuration specifying distance metrics,
                HNSW parameters, or other backend-specific settings.
            metadata: Optional metadata dictionary to attach to the collection.
            embedding_function: Optional custom embedding function. If not provided,
                uses the client's default embedding function.
            data_loader: Optional data loader for batch loading data into the collection.

        Returns:
            A ChromaDB AsyncCollection object.

        Raises:
            TypeError: If ClientAPI is used instead of AsyncClientAPI for async operations.
            ConnectionError: If unable to connect to ChromaDB server.

        Example:
            >>> import asyncio
            >>> async def main():
            ...     client = ChromaDBClient()
            ...     collection = await client.aget_or_create_collection(
            ...         collection_name="documents",
            ...         metadata={"description": "Product documentation"}
            ...     )
            >>> asyncio.run(main())
        """
        if not _is_async_client(self.client):
            raise TypeError(
                "Asynchronous method aget_or_create_collection() requires an AsyncClientAPI. "
                "Use get_or_create_collection() for ClientAPI."
            )
        metadata = kwargs.get("metadata", {})
        if "hnsw:space" not in metadata:
            metadata["hnsw:space"] = "cosine"
        return await self.client.get_or_create_collection(
            name=_sanitize_collection_name(kwargs["collection_name"]),
            configuration=kwargs.get("configuration"),
            metadata=metadata,
            embedding_function=kwargs.get(
                "embedding_function", self.embedding_function
            ),
            data_loader=kwargs.get("data_loader"),
        )

    def add_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
        """Add documents with their embeddings to a collection.

        Performs an upsert operation - documents with existing IDs are updated.
        Generates embeddings automatically using the configured embedding function.

        Keyword Args:
            collection_name: The name of the collection to add documents to.
            documents: List of BaseRecord dicts containing:
                - content: The text content (required)
                - doc_id: Optional unique identifier (auto-generated if missing)
                - metadata: Optional metadata dictionary
            batch_size: Optional batch size for processing documents (default: 100)

        Raises:
            TypeError: If AsyncClientAPI is used instead of ClientAPI for sync operations.
            ValueError: If collection doesn't exist or documents list is empty.
            ConnectionError: If unable to connect to ChromaDB server.
        """
        if not _is_sync_client(self.client):
            raise TypeError(
                "Synchronous method add_documents() requires a ClientAPI. "
                "Use aadd_documents() for AsyncClientAPI."
            )
        collection_name = kwargs["collection_name"]
        documents = kwargs["documents"]
        batch_size = kwargs.get("batch_size", self.default_batch_size)
        if not documents:
            raise ValueError("Documents list cannot be empty")
        collection = self.client.get_or_create_collection(
            name=_sanitize_collection_name(collection_name),
            embedding_function=self.embedding_function,
        )
        prepared = _prepare_documents_for_chromadb(documents)
        # Upsert in fixed-size batches to bound per-request payload size.
        for i in range(0, len(prepared.ids), batch_size):
            batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
                prepared=prepared, start_index=i, batch_size=batch_size
            )
            collection.upsert(
                ids=batch_ids,
                documents=batch_texts,
                metadatas=batch_metadatas,
            )

    async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
        """Add documents with their embeddings to a collection asynchronously.

        Performs an upsert operation - documents with existing IDs are updated.
        Generates embeddings automatically using the configured embedding function.

        Keyword Args:
            collection_name: The name of the collection to add documents to.
            documents: List of BaseRecord dicts containing:
                - content: The text content (required)
                - doc_id: Optional unique identifier (auto-generated if missing)
                - metadata: Optional metadata dictionary
            batch_size: Optional batch size for processing documents (default: 100)

        Raises:
            TypeError: If ClientAPI is used instead of AsyncClientAPI for async operations.
            ValueError: If collection doesn't exist or documents list is empty.
            ConnectionError: If unable to connect to ChromaDB server.
        """
        if not _is_async_client(self.client):
            raise TypeError(
                "Asynchronous method aadd_documents() requires an AsyncClientAPI. "
                "Use add_documents() for ClientAPI."
            )
        collection_name = kwargs["collection_name"]
        documents = kwargs["documents"]
        batch_size = kwargs.get("batch_size", self.default_batch_size)
        if not documents:
            raise ValueError("Documents list cannot be empty")
        collection = await self.client.get_or_create_collection(
            name=_sanitize_collection_name(collection_name),
            embedding_function=self.embedding_function,
        )
        prepared = _prepare_documents_for_chromadb(documents)
        for i in range(0, len(prepared.ids), batch_size):
            batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
                prepared=prepared, start_index=i, batch_size=batch_size
            )
            await collection.upsert(
                ids=batch_ids,
                documents=batch_texts,
                metadatas=batch_metadatas,
            )

    def search(
        self, **kwargs: Unpack[ChromaDBCollectionSearchParams]
    ) -> list[SearchResult]:
        """Search for similar documents using a query.

        Performs semantic search to find documents similar to the query text.
        Uses the configured embedding function to generate query embeddings.

        Keyword Args:
            collection_name: Name of the collection to search in.
            query: The text query to search for.
            limit: Maximum number of results to return (default: 10).
            metadata_filter: Optional filter for metadata fields.
            score_threshold: Optional minimum similarity score (0-1) for results.
            where: Optional ChromaDB where clause for metadata filtering.
            where_document: Optional ChromaDB where clause for document content filtering.
            include: Optional list of fields to include in results.

        Returns:
            List of SearchResult dicts containing id, content, metadata, and score.

        Raises:
            TypeError: If AsyncClientAPI is used instead of ClientAPI for sync operations.
            ValueError: If collection doesn't exist.
            ConnectionError: If unable to connect to ChromaDB server.
        """
        if not _is_sync_client(self.client):
            raise TypeError(
                "Synchronous method search() requires a ClientAPI. "
                "Use asearch() for AsyncClientAPI."
            )
        # Fill in instance-level defaults before extracting the parameters.
        if "limit" not in kwargs:
            kwargs["limit"] = self.default_limit
        if "score_threshold" not in kwargs:
            kwargs["score_threshold"] = self.default_score_threshold
        params = _extract_search_params(kwargs)
        collection = self.client.get_or_create_collection(
            name=_sanitize_collection_name(params.collection_name),
            embedding_function=self.embedding_function,
        )
        # An explicit ChromaDB `where` clause takes precedence over the
        # backend-agnostic metadata_filter.
        where = params.where if params.where is not None else params.metadata_filter
        # Suppress error-level log noise from ChromaDB's local persistent
        # HNSW segment while querying.
        with suppress_logging(
            "chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR
        ):
            results: QueryResult = collection.query(
                query_texts=[params.query],
                n_results=params.limit,
                where=where,
                where_document=params.where_document,
                include=params.include,
            )
        return _process_query_results(
            collection=collection,
            results=results,
            params=params,
        )

    async def asearch(
        self, **kwargs: Unpack[ChromaDBCollectionSearchParams]
    ) -> list[SearchResult]:
        """Search for similar documents using a query asynchronously.

        Performs semantic search to find documents similar to the query text.
        Uses the configured embedding function to generate query embeddings.

        Keyword Args:
            collection_name: Name of the collection to search in.
            query: The text query to search for.
            limit: Maximum number of results to return (default: 10).
            metadata_filter: Optional filter for metadata fields.
            score_threshold: Optional minimum similarity score (0-1) for results.
            where: Optional ChromaDB where clause for metadata filtering.
            where_document: Optional ChromaDB where clause for document content filtering.
            include: Optional list of fields to include in results.

        Returns:
            List of SearchResult dicts containing id, content, metadata, and score.

        Raises:
            TypeError: If ClientAPI is used instead of AsyncClientAPI for async operations.
            ValueError: If collection doesn't exist.
            ConnectionError: If unable to connect to ChromaDB server.
        """
        if not _is_async_client(self.client):
            raise TypeError(
                "Asynchronous method asearch() requires an AsyncClientAPI. "
                "Use search() for ClientAPI."
            )
        if "limit" not in kwargs:
            kwargs["limit"] = self.default_limit
        if "score_threshold" not in kwargs:
            kwargs["score_threshold"] = self.default_score_threshold
        params = _extract_search_params(kwargs)
        collection = await self.client.get_or_create_collection(
            name=_sanitize_collection_name(params.collection_name),
            embedding_function=self.embedding_function,
        )
        where = params.where if params.where is not None else params.metadata_filter
        with suppress_logging(
            "chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR
        ):
            results: QueryResult = await collection.query(
                query_texts=[params.query],
                n_results=params.limit,
                where=where,
                where_document=params.where_document,
                include=params.include,
            )
        return _process_query_results(
            collection=collection,
            results=results,
            params=params,
        )

    def delete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
        """Delete a collection and all its data.

        Permanently removes a collection and all documents, embeddings, and metadata it contains.
        This operation cannot be undone.

        Keyword Args:
            collection_name: Name of the collection to delete.

        Raises:
            TypeError: If AsyncClientAPI is used instead of ClientAPI for sync operations.
            ValueError: If collection doesn't exist.
            ConnectionError: If unable to connect to ChromaDB server.

        Example:
            >>> client = ChromaDBClient()
            >>> client.delete_collection(collection_name="old_documents")
        """
        if not _is_sync_client(self.client):
            raise TypeError(
                "Synchronous method delete_collection() requires a ClientAPI. "
                "Use adelete_collection() for AsyncClientAPI."
            )
        collection_name = kwargs["collection_name"]
        self.client.delete_collection(name=_sanitize_collection_name(collection_name))

    async def adelete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
        """Delete a collection and all its data asynchronously.

        Permanently removes a collection and all documents, embeddings, and metadata it contains.
        This operation cannot be undone.

        Keyword Args:
            collection_name: Name of the collection to delete.

        Raises:
            TypeError: If ClientAPI is used instead of AsyncClientAPI for async operations.
            ValueError: If collection doesn't exist.
            ConnectionError: If unable to connect to ChromaDB server.

        Example:
            >>> import asyncio
            >>> async def main():
            ...     client = ChromaDBClient()
            ...     await client.adelete_collection(collection_name="old_documents")
            >>> asyncio.run(main())
        """
        if not _is_async_client(self.client):
            raise TypeError(
                "Asynchronous method adelete_collection() requires an AsyncClientAPI. "
                "Use delete_collection() for ClientAPI."
            )
        collection_name = kwargs["collection_name"]
        await self.client.delete_collection(
            name=_sanitize_collection_name(collection_name)
        )

    def reset(self) -> None:
        """Reset the vector database by deleting all collections and data.

        Completely clears the ChromaDB instance, removing all collections,
        documents, embeddings, and metadata. This operation cannot be undone.
        Use with extreme caution in production environments.

        Raises:
            TypeError: If AsyncClientAPI is used instead of ClientAPI for sync operations.
            ConnectionError: If unable to connect to ChromaDB server.

        Example:
            >>> client = ChromaDBClient()
            >>> client.reset()  # Removes ALL data from ChromaDB
        """
        if not _is_sync_client(self.client):
            raise TypeError(
                "Synchronous method reset() requires a ClientAPI. "
                "Use areset() for AsyncClientAPI."
            )
        self.client.reset()

    async def areset(self) -> None:
        """Reset the vector database by deleting all collections and data asynchronously.

        Completely clears the ChromaDB instance, removing all collections,
        documents, embeddings, and metadata. This operation cannot be undone.
        Use with extreme caution in production environments.

        Raises:
            TypeError: If ClientAPI is used instead of AsyncClientAPI for async operations.
            ConnectionError: If unable to connect to ChromaDB server.

        Example:
            >>> import asyncio
            >>> async def main():
            ...     client = ChromaDBClient()
            ...     await client.areset()  # Removes ALL data from ChromaDB
            >>> asyncio.run(main())
        """
        if not _is_async_client(self.client):
            raise TypeError(
                "Asynchronous method areset() requires an AsyncClientAPI. "
                "Use reset() for ClientAPI."
            )
        await self.client.reset()

View File

@@ -0,0 +1,75 @@
"""ChromaDB configuration model."""
import os
import warnings
from dataclasses import field
from typing import Literal, cast
from chromadb.config import Settings
from pydantic.dataclasses import dataclass as pyd_dataclass
from crewai.rag.chromadb.constants import (
DEFAULT_DATABASE,
DEFAULT_STORAGE_PATH,
DEFAULT_TENANT,
)
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
from crewai.rag.config.base import BaseRagConfig
# Silence known-noisy warnings triggered by chromadb's pydantic usage:
# mixing V1/V2 models during schema generation, and its access to the
# deprecated `model_fields` attribute.
warnings.filterwarnings(
    "ignore",
    message=".*Mixing V1 models and V2 models.*",
    category=UserWarning,
    module="pydantic._internal._generate_schema",
)
warnings.filterwarnings(
    "ignore",
    message=r".*'model_fields'.*is deprecated.*",
    module=r"^chromadb(\.|$)",
)
def _default_settings() -> Settings:
    """Build the ChromaDB settings used when no explicit settings are supplied.

    Returns:
        Settings with persistent storage and reset enabled.
    """
    settings = Settings(
        is_persistent=True,
        persist_directory=DEFAULT_STORAGE_PATH,
        allow_reset=True,
    )
    return settings
def _default_embedding_function() -> ChromaEmbeddingFunctionWrapper:
    """Create default ChromaDB embedding function.

    Uses OpenAI's ``text-embedding-3-small`` model, reading the API key
    from the ``OPENAI_API_KEY`` environment variable.

    Returns:
        Default embedding function backed by OpenAI text-embedding-3-small.
    """
    # Imported lazily so the OpenAI dependency is only touched when the
    # default embedding function is actually constructed.
    from chromadb.utils.embedding_functions.openai_embedding_function import (
        OpenAIEmbeddingFunction,
    )

    return cast(
        ChromaEmbeddingFunctionWrapper,
        OpenAIEmbeddingFunction(
            api_key=os.getenv("OPENAI_API_KEY"),
            model_name="text-embedding-3-small",
        ),
    )
@pyd_dataclass(frozen=True)
class ChromaDBConfig(BaseRagConfig):
    """Configuration for ChromaDB client."""

    # Discriminator for RAG provider selection; fixed and excluded from __init__.
    provider: Literal["chromadb"] = field(default="chromadb", init=False)
    # ChromaDB multi-tenancy coordinates.
    tenant: str = DEFAULT_TENANT
    database: str = DEFAULT_DATABASE
    # Client settings; defaults to persistent local storage with reset allowed.
    settings: Settings = field(default_factory=_default_settings)
    # Embedding function applied to documents and queries.
    embedding_function: ChromaEmbeddingFunctionWrapper = field(
        default_factory=_default_embedding_function
    )

View File

@@ -0,0 +1,17 @@
"""Constants for ChromaDB configuration."""
import re
from typing import Final
from crewai.core.utilities.paths import db_storage_path
# ChromaDB multi-tenancy defaults.
DEFAULT_TENANT: Final[str] = "default_tenant"
DEFAULT_DATABASE: Final[str] = "default_database"
# Local persistent storage location, derived from the shared db_storage_path helper.
DEFAULT_STORAGE_PATH: Final[str] = db_storage_path()
# Collection-name length bounds and fallback name; presumably enforced by the
# collection-name sanitizer in the chromadb utils module — confirm there.
MIN_COLLECTION_LENGTH: Final[int] = 3
MAX_COLLECTION_LENGTH: Final[int] = 63
DEFAULT_COLLECTION: Final[str] = "default_collection"
# Matches characters outside [a-zA-Z0-9_-].
INVALID_CHARS_PATTERN: Final[re.Pattern[str]] = re.compile(r"[^a-zA-Z0-9_-]")
# Matches dotted-quad IPv4-looking strings.
IPV4_PATTERN: Final[re.Pattern[str]] = re.compile(r"^(\d{1,3}\.){3}\d{1,3}$")

View File

@@ -0,0 +1,45 @@
"""Factory functions for creating ChromaDB clients."""
import os
from hashlib import md5
import portalocker
from chromadb import PersistentClient
from crewai.rag.chromadb.client import ChromaDBClient
from crewai.rag.chromadb.config import ChromaDBConfig
def create_client(config: ChromaDBConfig) -> ChromaDBClient:
    """Create a ChromaDBClient from configuration.

    Args:
        config: ChromaDB configuration object.

    Returns:
        Configured ChromaDBClient instance.

    Notes:
        Need to update to use chromadb.Client to support more client types in the near future.
    """
    storage_dir = config.settings.persist_directory
    os.makedirs(storage_dir, exist_ok=True)

    # Serialize PersistentClient creation across processes that share the
    # same storage path; the lock name is derived from the path itself.
    digest = md5(storage_dir.encode(), usedforsecurity=False).hexdigest()
    lock_path = os.path.join(storage_dir, f"chromadb-{digest}.lock")
    with portalocker.Lock(lock_path):
        persistent_client = PersistentClient(
            path=storage_dir,
            settings=config.settings,
            tenant=config.tenant,
            database=config.database,
        )

    return ChromaDBClient(
        client=persistent_client,
        embedding_function=config.embedding_function,
        default_limit=config.limit,
        default_score_threshold=config.score_threshold,
        default_batch_size=config.batch_size,
    )

View File

@@ -0,0 +1,103 @@
"""Type definitions specific to ChromaDB implementation."""
from collections.abc import Mapping
from typing import Any, NamedTuple
from chromadb.api import AsyncClientAPI, ClientAPI
from chromadb.api.configuration import CollectionConfigurationInterface
from chromadb.api.types import (
CollectionMetadata,
DataLoader,
Include,
Loadable,
Where,
WhereDocument,
)
from chromadb.api.types import (
EmbeddingFunction as ChromaEmbeddingFunction,
)
from pydantic import GetCoreSchemaHandler
from pydantic_core import CoreSchema, core_schema
from crewai.rag.core.base_client import BaseCollectionParams, BaseCollectionSearchParams
ChromaDBClientType = ClientAPI | AsyncClientAPI
class ChromaEmbeddingFunctionWrapper(ChromaEmbeddingFunction):
    """Base class for ChromaDB EmbeddingFunction to work with Pydantic validation."""

    @classmethod
    def __get_pydantic_core_schema__(
        cls, _source_type: Any, _handler: GetCoreSchemaHandler
    ) -> CoreSchema:
        """Generate Pydantic core schema for ChromaDB EmbeddingFunction.

        This allows Pydantic to handle ChromaDB's EmbeddingFunction type
        without requiring arbitrary_types_allowed=True.
        """
        # Accept any value: validation of the embedding function itself is
        # deliberately skipped.
        return core_schema.any_schema()
class PreparedDocuments(NamedTuple):
    """Prepared documents ready for ChromaDB insertion.

    The three lists are parallel: index i describes one document.

    Attributes:
        ids: List of document IDs
        texts: List of document texts
        metadatas: List of document metadata mappings (empty dict for no metadata)
    """

    ids: list[str]
    texts: list[str]
    metadatas: list[Mapping[str, str | int | float | bool]]
class ExtractedSearchParams(NamedTuple):
    """Extracted search parameters for ChromaDB queries.

    Attributes:
        collection_name: Name of the collection to search
        query: Search query text
        limit: Maximum number of results
        metadata_filter: Optional metadata filter
        score_threshold: Optional minimum similarity score
        where: Optional ChromaDB where clause
        where_document: Optional ChromaDB document filter
        include: Fields to include in results
    """

    collection_name: str
    query: str
    limit: int
    metadata_filter: dict[str, Any] | None
    score_threshold: float | None
    where: Where | None
    where_document: WhereDocument | None
    include: Include
class ChromaDBCollectionCreateParams(BaseCollectionParams, total=False):
    """Parameters for creating a ChromaDB collection.

    This class extends BaseCollectionParams to include any additional
    parameters specific to ChromaDB collection creation. All keys declared
    here are optional (total=False).
    """

    configuration: CollectionConfigurationInterface
    metadata: CollectionMetadata
    embedding_function: ChromaEmbeddingFunction
    data_loader: DataLoader[Loadable]
    get_or_create: bool
class ChromaDBCollectionSearchParams(BaseCollectionSearchParams, total=False):
    """Parameters for searching a ChromaDB collection.

    This class extends BaseCollectionSearchParams to include ChromaDB-specific
    search parameters like where clauses and include options. All keys
    declared here are optional (total=False).
    """

    where: Where
    where_document: WhereDocument
    include: Include

View File

@@ -0,0 +1,323 @@
"""Utility functions for ChromaDB client implementation."""
import hashlib
import json
from collections.abc import Mapping
from typing import Literal, TypeGuard, cast
from chromadb.api import AsyncClientAPI, ClientAPI
from chromadb.api.models.AsyncCollection import AsyncCollection
from chromadb.api.models.Collection import Collection
from chromadb.api.types import (
Include,
QueryResult,
)
from crewai.rag.chromadb.constants import (
DEFAULT_COLLECTION,
INVALID_CHARS_PATTERN,
IPV4_PATTERN,
MAX_COLLECTION_LENGTH,
MIN_COLLECTION_LENGTH,
)
from crewai.rag.chromadb.types import (
ChromaDBClientType,
ChromaDBCollectionSearchParams,
ExtractedSearchParams,
PreparedDocuments,
)
from crewai.rag.types import BaseRecord, SearchResult
def _is_sync_client(client: ChromaDBClientType) -> TypeGuard[ClientAPI]:
    """Type guard narrowing a ChromaDB client to the synchronous ClientAPI.

    Args:
        client: The client to check.

    Returns:
        True if the client is a ClientAPI, False otherwise.
    """
    is_sync = isinstance(client, ClientAPI)
    return is_sync
def _is_async_client(client: ChromaDBClientType) -> TypeGuard[AsyncClientAPI]:
    """Type guard narrowing a ChromaDB client to the asynchronous AsyncClientAPI.

    Args:
        client: The client to check.

    Returns:
        True if the client is an AsyncClientAPI, False otherwise.
    """
    is_async = isinstance(client, AsyncClientAPI)
    return is_async
def _prepare_documents_for_chromadb(
    documents: list[BaseRecord],
) -> PreparedDocuments:
    """Normalize BaseRecord documents into parallel id/text/metadata lists.

    Args:
        documents: List of BaseRecord documents to prepare.

    Returns:
        PreparedDocuments with ids, texts, and metadatas ready for ChromaDB.
    """

    def _derive_id(record: BaseRecord) -> str:
        """Hash content (plus metadata, when present) into a stable document ID."""
        payload = record["content"]
        meta = record.get("metadata")
        if meta:
            payload = f"{payload}|{json.dumps(meta, sort_keys=True)}"
        return hashlib.blake2b(payload.encode(), digest_size=32).hexdigest()

    def _normalize_meta(record: BaseRecord) -> Mapping[str, str | int | float | bool]:
        """Coerce a record's metadata to a single mapping ({} when absent)."""
        meta = record.get("metadata")
        if not meta:
            return {}
        if isinstance(meta, list):
            # Only the first metadata entry is kept for list-valued metadata.
            return meta[0] if meta and meta[0] else {}
        return meta

    ids = [
        doc["doc_id"] if "doc_id" in doc else _derive_id(doc) for doc in documents
    ]
    texts = [doc["content"] for doc in documents]
    metadatas = [_normalize_meta(doc) for doc in documents]
    return PreparedDocuments(ids, texts, metadatas)
def _create_batch_slice(
prepared: PreparedDocuments, start_index: int, batch_size: int
) -> tuple[list[str], list[str], list[Mapping[str, str | int | float | bool]] | None]:
"""Create a batch slice from prepared documents.
Args:
prepared: PreparedDocuments containing ids, texts, and metadatas.
start_index: Starting index for the batch.
batch_size: Size of the batch.
Returns:
Tuple of (batch_ids, batch_texts, batch_metadatas).
"""
batch_end = min(start_index + batch_size, len(prepared.ids))
batch_ids = prepared.ids[start_index:batch_end]
batch_texts = prepared.texts[start_index:batch_end]
batch_metadatas = (
prepared.metadatas[start_index:batch_end] if prepared.metadatas else None
)
if batch_metadatas and not any(m for m in batch_metadatas):
batch_metadatas = None
return batch_ids, batch_texts, batch_metadatas
def _extract_search_params(
    kwargs: ChromaDBCollectionSearchParams,
) -> ExtractedSearchParams:
    """Pull individual search parameters out of a kwargs mapping.

    Args:
        kwargs: Mapping of search parameters; ``collection_name`` and
            ``query`` are required, everything else is optional.

    Returns:
        ExtractedSearchParams populated with explicit values or defaults.
    """
    # Default mirrors ChromaDB's useful fields; embeddings are excluded.
    default_include = ["metadatas", "documents", "distances"]
    return ExtractedSearchParams(
        collection_name=kwargs["collection_name"],
        query=kwargs["query"],
        limit=kwargs.get("limit", 10),
        metadata_filter=kwargs.get("metadata_filter"),
        score_threshold=kwargs.get("score_threshold"),
        where=kwargs.get("where"),
        where_document=kwargs.get("where_document"),
        include=cast(Include, kwargs.get("include", default_include)),
    )
def _convert_distance_to_score(
distance: float,
distance_metric: Literal["l2", "cosine", "ip"],
) -> float:
"""Convert ChromaDB distance to similarity score.
Notes:
Assuming all embedding are unit-normalized for now, including custom embeddings.
Args:
distance: The distance value from ChromaDB.
distance_metric: The distance metric used ("l2", "cosine", or "ip").
Returns:
Similarity score in range [0, 1] where 1 is most similar.
"""
if distance_metric == "cosine":
score = 1.0 - 0.5 * distance
return max(0.0, min(1.0, score))
if distance_metric == "l2":
score = 1.0 / (1.0 + distance)
return max(0.0, min(1.0, score))
raise ValueError(f"Unsupported distance metric: {distance_metric}")
def _convert_chromadb_results_to_search_results(
    results: QueryResult,
    include: Include,
    distance_metric: Literal["l2", "cosine", "ip"],
    score_threshold: float | None = None,
) -> list[SearchResult]:
    """Reshape raw ChromaDB query output into SearchResult dicts.

    Args:
        results: ChromaDB query results (only the first query's hits are used).
        include: Fields that were requested in the query.
        distance_metric: Metric used by the collection, for score conversion.
        score_threshold: Optional minimum similarity score (0-1) for results.

    Returns:
        List of SearchResult dicts containing id, content, metadata, and score.
    """
    requested = set(include or [])

    def _first_hits(field: str) -> list:
        # Return the first query's rows for a field, but only if it was
        # actually requested; otherwise treat it as absent.
        rows = results.get(field)
        return rows[0] if rows and field in requested else []

    ids = results["ids"][0] if results.get("ids") else []
    documents = _first_hits("documents")
    metadatas = _first_hits("metadatas")
    distances = _first_hits("distances")

    converted: list[SearchResult] = []
    for index, doc_id in enumerate(ids):
        # A hit without a distance cannot be scored, so it is dropped.
        if index >= len(distances):
            continue
        score = _convert_distance_to_score(
            distance=distances[index], distance_metric=distance_metric
        )
        if score_threshold and score < score_threshold:
            continue
        entry: SearchResult = {
            "id": doc_id,
            "content": documents[index] if index < len(documents) else "",
            "metadata": dict(metadatas[index])
            if index < len(metadatas) and metadatas[index] is not None
            else {},
            "score": score,
        }
        converted.append(entry)
    return converted
def _process_query_results(
    collection: Collection | AsyncCollection,
    results: QueryResult,
    params: ExtractedSearchParams,
) -> list[SearchResult]:
    """Convert raw query output using the collection's configured metric.

    Args:
        collection: The ChromaDB collection (sync or async) that was queried.
        results: Raw query results from ChromaDB.
        params: The search parameters used for the query.

    Returns:
        List of SearchResult dicts containing id, content, metadata, and score.
    """
    # ChromaDB records the metric under "hnsw:space"; L2 is its default.
    metadata = collection.metadata or {}
    metric = cast(Literal["l2", "cosine", "ip"], metadata.get("hnsw:space", "l2"))
    return _convert_chromadb_results_to_search_results(
        results=results,
        include=params.include,
        distance_metric=metric,
        score_threshold=params.score_threshold,
    )
def _is_ipv4_pattern(name: str) -> bool:
    """Report whether a string looks like an IPv4 address.

    Args:
        name: The string to check.

    Returns:
        True if the string matches an IPv4 pattern, False otherwise.
    """
    return IPV4_PATTERN.match(name) is not None
def _sanitize_collection_name(
    name: str | None, max_collection_length: int = MAX_COLLECTION_LENGTH
) -> str:
    """Rewrite a collection name so it satisfies ChromaDB's constraints.

    Requirements enforced:
        1. 3-63 characters long.
        2. Starts and ends with an alphanumeric character.
        3. Contains only alphanumeric characters, underscores, or hyphens.
        4. Is not a bare IPv4 address.

    Args:
        name: The original collection name to sanitize.
        max_collection_length: Maximum allowed length for the collection name.

    Returns:
        A sanitized collection name that meets ChromaDB requirements; the
        default collection name when ``name`` is empty or None.
    """
    if not name:
        return DEFAULT_COLLECTION

    # IPv4-looking names are disallowed outright, so prefix them first.
    candidate = f"ip_{name}" if _is_ipv4_pattern(name) else name
    candidate = INVALID_CHARS_PATTERN.sub("_", candidate)

    if not candidate[0].isalnum():
        candidate = "a" + candidate
    if not candidate[-1].isalnum():
        candidate = candidate[:-1] + "z"
    if len(candidate) < MIN_COLLECTION_LENGTH:
        candidate = candidate + "x" * (MIN_COLLECTION_LENGTH - len(candidate))
    if len(candidate) > max_collection_length:
        candidate = candidate[:max_collection_length]
        # Truncation may expose a trailing non-alphanumeric character.
        if not candidate[-1].isalnum():
            candidate = candidate[:-1] + "z"
    return candidate

View File

@@ -0,0 +1 @@
"""RAG client configuration management using ContextVars for thread-safe provider switching."""

View File

@@ -0,0 +1,19 @@
"""Base configuration class for RAG providers."""
from dataclasses import field
from typing import Any
from pydantic.dataclasses import dataclass as pyd_dataclass
from crewai.rag.config.optional_imports.types import SupportedProvider
@pyd_dataclass(frozen=True)
class BaseRagConfig:
    """Base class for RAG configuration with Pydantic serialization support."""

    # Backend discriminator; set by each concrete subclass, never by callers
    # (init=False keeps it out of __init__).
    provider: SupportedProvider = field(init=False)
    # Optional embedding function override; None defers to the provider default.
    embedding_function: Any | None = field(default=None)
    # Default maximum number of search results.
    limit: int = field(default=5)
    # Default minimum similarity score (0-1) applied to search results.
    score_threshold: float = field(default=0.6)
    # Default number of documents written per batch.
    batch_size: int = field(default=100)

View File

@@ -0,0 +1,8 @@
"""Constants for RAG configuration."""
from typing import Final
# Field name used as the Pydantic discriminator across provider configs.
DISCRIMINATOR: Final[str] = "provider"
# Module path and class name of the fallback config used when none is set.
DEFAULT_RAG_CONFIG_PATH: Final[str] = "crewai.rag.chromadb.config"
DEFAULT_RAG_CONFIG_CLASS: Final[str] = "ChromaDBConfig"

View File

@@ -0,0 +1 @@
"""Optional imports for RAG configuration providers."""

View File

@@ -0,0 +1,26 @@
"""Base classes for missing provider configurations."""
from dataclasses import field
from typing import Literal
from pydantic import ConfigDict
from pydantic.dataclasses import dataclass as pyd_dataclass
@pyd_dataclass(config=ConfigDict(extra="forbid"))
class _MissingProvider:
    """Base class for missing provider configurations.

    Raises RuntimeError when instantiated to indicate missing dependencies.
    """

    # Subclasses narrow this Literal to their own provider name.
    provider: Literal["chromadb", "qdrant", "__missing__"] = field(
        default="__missing__"
    )

    def __post_init__(self) -> None:
        """Raises error indicating the provider is not installed."""
        # Instantiation always fails: this class exists only as a stand-in
        # for a provider whose optional dependency is not installed.
        raise RuntimeError(
            f"provider '{self.provider}' requested but not installed. "
            f"Install the extra: `uv add crewai'[{self.provider}]'`."
        )

View File

@@ -0,0 +1,27 @@
"""Protocol definitions for RAG factory modules."""
from __future__ import annotations
from typing import TYPE_CHECKING, Protocol
if TYPE_CHECKING:
from crewai.rag.chromadb.client import ChromaDBClient
from crewai.rag.chromadb.config import ChromaDBConfig
from crewai.rag.qdrant.client import QdrantClient
from crewai.rag.qdrant.config import QdrantConfig
class ChromaFactoryModule(Protocol):
    """Protocol for the ChromaDB factory module.

    Matched structurally (typing.Protocol): any module object exposing a
    compatible ``create_client`` satisfies it.
    """

    def create_client(self, config: ChromaDBConfig) -> ChromaDBClient:
        """Creates a ChromaDB client from configuration."""
        ...
class QdrantFactoryModule(Protocol):
    """Protocol for the Qdrant factory module.

    Matched structurally (typing.Protocol): any module object exposing a
    compatible ``create_client`` satisfies it.
    """

    def create_client(self, config: QdrantConfig) -> QdrantClient:
        """Creates a Qdrant client from configuration."""
        ...

View File

@@ -0,0 +1,23 @@
"""Provider-specific missing configuration classes."""
from dataclasses import field
from typing import Literal
from pydantic import ConfigDict
from pydantic.dataclasses import dataclass as pyd_dataclass
from crewai.rag.config.optional_imports.base import _MissingProvider
@pyd_dataclass(config=ConfigDict(extra="forbid"))
class MissingChromaDBConfig(_MissingProvider):
    """Placeholder for missing ChromaDB configuration; raises on instantiation."""

    provider: Literal["chromadb"] = field(default="chromadb")
@pyd_dataclass(config=ConfigDict(extra="forbid"))
class MissingQdrantConfig(_MissingProvider):
    """Placeholder for missing Qdrant configuration; raises on instantiation."""

    provider: Literal["qdrant"] = field(default="qdrant")

View File

@@ -0,0 +1,8 @@
"""Type definitions for optional imports."""
from typing import Annotated, Literal
# Closed set of provider identifiers used to discriminate RAG configs.
SupportedProvider = Annotated[
    Literal["chromadb", "qdrant"],
    "Supported RAG provider types, add providers here as they become available",
]

View File

@@ -0,0 +1,35 @@
"""Type definitions for RAG configuration."""
from typing import TYPE_CHECKING, Annotated, TypeAlias
from pydantic import Field
from crewai.rag.config.constants import DISCRIMINATOR
# NOTE: linters mishandle conditional imports; re-assigning the names inside
# the TYPE_CHECKING branch keeps static analysis happy.
if TYPE_CHECKING:
    from crewai.rag.chromadb.config import ChromaDBConfig as ChromaDBConfig_

    ChromaDBConfig = ChromaDBConfig_
    from crewai.rag.qdrant.config import QdrantConfig as QdrantConfig_

    QdrantConfig = QdrantConfig_
else:
    # At runtime, fall back to placeholder classes (which raise on
    # instantiation) when a provider's optional dependency is absent.
    try:
        from crewai.rag.chromadb.config import ChromaDBConfig
    except ImportError:
        from crewai.rag.config.optional_imports.providers import (
            MissingChromaDBConfig as ChromaDBConfig,
        )
    try:
        from crewai.rag.qdrant.config import QdrantConfig
    except ImportError:
        from crewai.rag.config.optional_imports.providers import (
            MissingQdrantConfig as QdrantConfig,
        )

# Union of all provider configs, discriminated by the "provider" field.
SupportedProviderConfig: TypeAlias = ChromaDBConfig | QdrantConfig
RagConfigType: TypeAlias = Annotated[
    SupportedProviderConfig, Field(discriminator=DISCRIMINATOR)
]

View File

@@ -0,0 +1,86 @@
"""RAG client configuration utilities."""
from contextvars import ContextVar
from pydantic import BaseModel, Field
from crewai.rag.config.constants import (
DEFAULT_RAG_CONFIG_CLASS,
DEFAULT_RAG_CONFIG_PATH,
)
from crewai.rag.config.types import RagConfigType
from crewai.rag.core.base_client import BaseClient
from crewai.rag.factory import create_client
from crewai.core.utilities.import_utils import require
class RagContext(BaseModel):
    """Context holding RAG configuration and client instance."""

    config: RagConfigType = Field(..., description="RAG provider configuration")
    # May remain None until a client is first requested.
    client: BaseClient | None = Field(
        default=None, description="Instantiated RAG client"
    )
_rag_context: ContextVar[RagContext | None] = ContextVar("_rag_context", default=None)
def set_rag_config(config: RagConfigType) -> None:
    """Install a RAG configuration and eagerly build its client.

    Args:
        config: The RAG client configuration to activate for this context.
    """
    _rag_context.set(RagContext(config=config, client=create_client(config)))
def get_rag_config() -> RagConfigType:
    """Return the active RAG configuration, bootstrapping a default if unset.

    Returns:
        The current RAG configuration object.

    Raises:
        ValueError: If no configuration could be established.
    """
    context = _rag_context.get()
    if context is None:
        # Lazily import and instantiate the default configuration.
        module = require(DEFAULT_RAG_CONFIG_PATH, purpose="RAG configuration")
        default_config = getattr(module, DEFAULT_RAG_CONFIG_CLASS)()
        set_rag_config(default_config)
        context = _rag_context.get()
    if context is None or context.config is None:
        raise ValueError(
            "RAG configuration is not set. Please set the RAG config first."
        )
    return context.config
def get_rag_client() -> BaseClient:
    """Return the active RAG client, materializing one when necessary.

    Returns:
        The current RAG client instance.

    Raises:
        ValueError: If no client could be established.
    """
    context = _rag_context.get()
    if context is None:
        # Bootstraps the default configuration (and client) as a side effect.
        get_rag_config()
        context = _rag_context.get()
    if context is not None and context.client is None:
        context.client = create_client(context.config)
    if context is None or context.client is None:
        raise ValueError(
            "RAG client is not configured. Please set the RAG config first."
        )
    return context.client
def clear_rag_config() -> None:
    """Clear the current RAG configuration and client, reverting to defaults."""
    # Dropping the context makes the next getter call rebuild the defaults.
    _rag_context.set(None)

View File

@@ -0,0 +1 @@
"""Core abstract base classes and protocols for RAG systems."""

View File

@@ -0,0 +1,448 @@
"""Protocol for vector database client implementations."""
from abc import abstractmethod
from typing import Annotated, Any, Protocol, runtime_checkable
from pydantic import GetCoreSchemaHandler
from pydantic_core import CoreSchema, core_schema
from typing_extensions import Required, TypedDict, Unpack
from crewai.rag.types import (
BaseRecord,
EmbeddingFunction,
SearchResult,
)
class BaseCollectionParams(TypedDict):
    """Base parameters for collection operations.

    Attributes:
        collection_name: The name of the collection/index to operate on.
    """

    collection_name: Required[
        Annotated[
            str,
            "Name of the collection/index. Implementations may have specific constraints (e.g., character limits, allowed characters, case sensitivity).",
        ]
    ]
class BaseCollectionAddParams(BaseCollectionParams, total=False):
    """Parameters for adding documents to a collection.

    Extends BaseCollectionParams with document-specific fields.

    Attributes:
        collection_name: The name of the collection to add documents to.
        documents: List of BaseRecord dictionaries containing document data.
        batch_size: Optional batch size for processing documents to avoid token limits.
    """

    documents: Required[list[BaseRecord]]
    batch_size: int
class BaseCollectionSearchParams(BaseCollectionParams, total=False):
    """Parameters for searching within a collection.

    Extends BaseCollectionParams with search-specific optional fields.
    All fields except collection_name and query are optional.

    Attributes:
        query: The text query to search for (required).
        limit: Maximum number of results to return.
        metadata_filter: Filter results by metadata fields.
        score_threshold: Minimum similarity score for results (0-1).
    """

    query: Required[str]
    limit: int
    metadata_filter: dict[str, Any] | None
    score_threshold: float
@runtime_checkable
class BaseClient(Protocol):
    """Protocol for vector store client implementations.

    This protocol defines the interface that all vector store client implementations
    must follow. It provides a consistent API for storing and retrieving
    documents with their vector embeddings across different vector database
    backends (e.g., Qdrant, ChromaDB, Weaviate). Implementing classes should
    handle connection management, data persistence, and vector similarity
    search operations specific to their backend.

    Implementation Guidelines:
        Implementations should accept BaseClientParams in their constructor to allow
        passing pre-configured client instances::

            class MyVectorClient:
                def __init__(self, client: Any | None = None, **kwargs):
                    if client:
                        self.client = client
                    else:
                        self.client = self._create_default_client(**kwargs)

    Notes:
        This protocol replaces the former BaseRAGStorage abstraction,
        providing a cleaner interface for vector store operations.

    Attributes:
        embedding_function: Callable that takes a list of text strings
            and returns a list of embedding vectors. Implementations
            should always provide a default embedding function.
        client: The underlying vector database client instance. This could be
            passed via BaseClientParams during initialization or created internally.
    """

    client: Any
    embedding_function: EmbeddingFunction

    @classmethod
    def __get_pydantic_core_schema__(
        cls, _source_type: Any, _handler: GetCoreSchemaHandler
    ) -> CoreSchema:
        """Generate Pydantic core schema for BaseClient Protocol.

        This allows the Protocol to be used in Pydantic models without
        requiring arbitrary_types_allowed=True.
        """
        # Accept any value; structural validation is not enforced by Pydantic.
        return core_schema.any_schema()

    @abstractmethod
    def create_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
        """Create a new collection/index in the vector database.

        Keyword Args:
            collection_name: The name of the collection to create. Must be unique within
                the vector database instance.

        Raises:
            ValueError: If collection name already exists.
            ConnectionError: If unable to connect to the vector database backend.
        """
        ...

    @abstractmethod
    async def acreate_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
        """Create a new collection/index in the vector database asynchronously.

        Keyword Args:
            collection_name: The name of the collection to create. Must be unique within
                the vector database instance.

        Raises:
            ValueError: If collection name already exists.
            ConnectionError: If unable to connect to the vector database backend.
        """
        ...

    @abstractmethod
    def get_or_create_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> Any:
        """Get an existing collection or create it if it doesn't exist.

        This method provides a convenient way to ensure a collection exists
        without having to check for its existence first.

        Keyword Args:
            collection_name: The name of the collection to get or create.

        Returns:
            A collection object whose type depends on the backend implementation.
            This could be a collection reference, ID, or client object.

        Raises:
            ValueError: If unable to create the collection.
            ConnectionError: If unable to connect to the vector database backend.
        """
        ...

    @abstractmethod
    async def aget_or_create_collection(
        self, **kwargs: Unpack[BaseCollectionParams]
    ) -> Any:
        """Get an existing collection or create it if it doesn't exist asynchronously.

        Keyword Args:
            collection_name: The name of the collection to get or create.

        Returns:
            A collection object whose type depends on the backend implementation.

        Raises:
            ValueError: If unable to create the collection.
            ConnectionError: If unable to connect to the vector database backend.
        """
        ...

    @abstractmethod
    def add_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
        """Add documents with their embeddings to a collection.

        This method performs an upsert operation - if a document with the same ID
        already exists, it will be updated with the new content and metadata.
        Implementations should handle embedding generation internally based on
        the configured embedding function.

        Keyword Args:
            collection_name: The name of the collection to add documents to.
            documents: List of BaseRecord dicts containing:
                - content: The text content (required)
                - doc_id: Optional unique identifier (auto-generated from content hash if missing)
                - metadata: Optional metadata dictionary
                Embeddings will be generated automatically.

        Raises:
            ValueError: If collection doesn't exist or documents list is empty.
            TypeError: If documents are not BaseRecord dict instances.
            ConnectionError: If unable to connect to the vector database backend.

        Example:
            >>> from crewai.rag.chromadb.client import ChromaDBClient
            >>> from crewai.rag.types import BaseRecord
            >>> client = ChromaDBClient()
            >>>
            >>> records: list[BaseRecord] = [
            ...     {
            ...         "content": "Machine learning basics",
            ...         "metadata": {"source": "file3", "topic": "ML"}
            ...     },
            ...     {
            ...         "doc_id": "custom_id",
            ...         "content": "Deep learning fundamentals",
            ...         "metadata": {"source": "file4", "topic": "DL"}
            ...     }
            ... ]
            >>> client.add_documents(collection_name="my_docs", documents=records)
            >>>
            >>> records_with_id: list[BaseRecord] = [
            ...     {
            ...         "doc_id": "nlp_001",
            ...         "content": "Advanced NLP techniques",
            ...         "metadata": {"source": "file5", "topic": "NLP"}
            ...     }
            ... ]
            >>> client.add_documents(collection_name="my_docs", documents=records_with_id)
        """
        ...

    @abstractmethod
    async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
        """Add documents with their embeddings to a collection asynchronously.

        Implementations should handle embedding generation internally based on
        the configured embedding function.

        Keyword Args:
            collection_name: The name of the collection to add documents to.
            documents: List of BaseRecord dicts containing:
                - content: The text content (required)
                - doc_id: Optional unique identifier (auto-generated from content hash if missing)
                - metadata: Optional metadata dictionary
                Embeddings will be generated automatically.

        Raises:
            ValueError: If collection doesn't exist or documents list is empty.
            TypeError: If documents are not BaseRecord dict instances.
            ConnectionError: If unable to connect to the vector database backend.

        Example:
            >>> import asyncio
            >>> from crewai.rag.chromadb.client import ChromaDBClient
            >>> from crewai.rag.types import BaseRecord
            >>>
            >>> async def add_documents():
            ...     client = ChromaDBClient()
            ...
            ...     records: list[BaseRecord] = [
            ...         {
            ...             "doc_id": "doc2",
            ...             "content": "Async operations in Python",
            ...             "metadata": {"source": "file2", "topic": "async"}
            ...         }
            ...     ]
            ...     await client.aadd_documents(collection_name="my_docs", documents=records)
            ...
            >>> asyncio.run(add_documents())
        """
        ...

    @abstractmethod
    def search(
        self, **kwargs: Unpack[BaseCollectionSearchParams]
    ) -> list[SearchResult]:
        """Search for similar documents using a query.

        Performs a vector similarity search to find the most similar documents
        to the provided query.

        Keyword Args:
            collection_name: The name of the collection to search in.
            query: The text query to search for. The implementation handles
                embedding generation internally.
            limit: Maximum number of results to return. Defaults to 10.
            metadata_filter: Optional metadata filter to apply to the search. The exact
                format depends on the backend, but typically supports equality
                and range queries on metadata fields.
            score_threshold: Optional minimum similarity score threshold. Only
                results with scores >= this threshold will be returned. The
                score interpretation depends on the distance metric used.

        Returns:
            A list of SearchResult dictionaries ordered by similarity score in
            descending order. Each result contains:
                - id: Document ID
                - content: Document text content
                - metadata: Document metadata
                - score: Similarity score (0-1, higher is better)

        Raises:
            ValueError: If collection doesn't exist.
            ConnectionError: If unable to connect to the vector database backend.

        Example:
            >>> from crewai.rag.chromadb.client import ChromaDBClient
            >>> client = ChromaDBClient()
            >>>
            >>> results = client.search(
            ...     collection_name="my_docs",
            ...     query="What is machine learning?",
            ...     limit=5,
            ...     metadata_filter={"source": "file1"},
            ...     score_threshold=0.7
            ... )
            >>> for result in results:
            ...     print(f"{result['id']}: {result['score']:.2f}")
        """
        ...

    @abstractmethod
    async def asearch(
        self, **kwargs: Unpack[BaseCollectionSearchParams]
    ) -> list[SearchResult]:
        """Search for similar documents using a query asynchronously.

        Keyword Args:
            collection_name: The name of the collection to search in.
            query: The text query to search for. The implementation handles
                embedding generation internally.
            limit: Maximum number of results to return. Defaults to 10.
            metadata_filter: Optional metadata filter to apply to the search.
            score_threshold: Optional minimum similarity score threshold.

        Returns:
            A list of SearchResult dictionaries ordered by similarity score.

        Raises:
            ValueError: If collection doesn't exist.
            ConnectionError: If unable to connect to the vector database backend.

        Example:
            >>> import asyncio
            >>> from crewai.rag.chromadb.client import ChromaDBClient
            >>>
            >>> async def search_documents():
            ...     client = ChromaDBClient()
            ...     results = await client.asearch(
            ...         collection_name="my_docs",
            ...         query="Python programming best practices",
            ...         limit=5,
            ...         metadata_filter={"source": "file1"},
            ...         score_threshold=0.7
            ...     )
            ...     for result in results:
            ...         print(f"{result['id']}: {result['score']:.2f}")
            ...
            >>> asyncio.run(search_documents())
        """
        ...

    @abstractmethod
    def delete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
        """Delete a collection and all its data.

        This operation is irreversible and will permanently remove all documents,
        embeddings, and metadata associated with the collection.

        Keyword Args:
            collection_name: The name of the collection to delete.

        Raises:
            ValueError: If the collection doesn't exist.
            ConnectionError: If unable to connect to the vector database backend.

        Example:
            >>> from crewai.rag.chromadb.client import ChromaDBClient
            >>> client = ChromaDBClient()
            >>> client.delete_collection(collection_name="old_docs")
            >>> print("Collection 'old_docs' deleted successfully")
        """
        ...

    @abstractmethod
    async def adelete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
        """Delete a collection and all its data asynchronously.

        Keyword Args:
            collection_name: The name of the collection to delete.

        Raises:
            ValueError: If the collection doesn't exist.
            ConnectionError: If unable to connect to the vector database backend.

        Example:
            >>> import asyncio
            >>> from crewai.rag.chromadb.client import ChromaDBClient
            >>>
            >>> async def delete_old_collection():
            ...     client = ChromaDBClient()
            ...     await client.adelete_collection(collection_name="old_docs")
            ...     print("Collection 'old_docs' deleted successfully")
            ...
            >>> asyncio.run(delete_old_collection())
        """
        ...

    @abstractmethod
    def reset(self) -> None:
        """Reset the vector database by deleting all collections and data.

        This method provides a way to completely clear the vector database,
        removing all collections and their contents. Use with caution as
        this operation is irreversible.

        Raises:
            ConnectionError: If unable to connect to the vector database backend.
            PermissionError: If the operation is not allowed by the backend.

        Example:
            >>> from crewai.rag.chromadb.client import ChromaDBClient
            >>> client = ChromaDBClient()
            >>> client.reset()
            >>> print("Vector database completely reset - all data deleted")
        """
        ...

    @abstractmethod
    async def areset(self) -> None:
        """Reset the vector database by deleting all collections and data asynchronously.

        Raises:
            ConnectionError: If unable to connect to the vector database backend.
            PermissionError: If the operation is not allowed by the backend.

        Example:
            >>> import asyncio
            >>> from crewai.rag.chromadb.client import ChromaDBClient
            >>>
            >>> async def reset_database():
            ...     client = ChromaDBClient()
            ...     await client.areset()
            ...     print("Vector database completely reset - all data deleted")
            ...
            >>> asyncio.run(reset_database())
        """
        ...

View File

@@ -0,0 +1,149 @@
"""Base embeddings callable utilities for RAG systems."""
from typing import Protocol, TypeVar, runtime_checkable
import numpy as np
from crewai.rag.core.types import (
Embeddable,
Embedding,
Embeddings,
PyEmbedding,
)
T = TypeVar("T")
D = TypeVar("D", bound=Embeddable, contravariant=True)
def normalize_embeddings(
    target: Embedding | list[Embedding] | PyEmbedding | list[PyEmbedding],
) -> Embeddings | None:
    """Normalize various embedding formats to a standard list of numpy arrays.

    Args:
        target: Input embeddings in various formats (list of floats, list of lists,
            numpy array, or list of numpy arrays), or None.

    Returns:
        Normalized embeddings as a list of float32 numpy arrays, or None if
        input is None.

    Raises:
        ValueError: If embeddings are empty or in an unsupported format.
    """
    # The docstring has always promised None-passthrough; make it true instead
    # of failing with a TypeError on target[0].
    if target is None:
        return None
    if isinstance(target, np.ndarray):
        if target.ndim == 1:
            return [target.astype(np.float32)]
        if target.ndim == 2:
            return [row.astype(np.float32) for row in target]
        raise ValueError(f"Unsupported numpy array shape: {target.shape}")
    # Reject empty lists explicitly: previously this fell through to
    # target[0] and raised an undocumented IndexError.
    if len(target) == 0:
        raise ValueError("Embeddings cannot be empty")
    first = target[0]
    if isinstance(first, (int, float)) and not isinstance(first, bool):
        # A flat list of numbers is a single embedding.
        return [np.array(target, dtype=np.float32)]
    if isinstance(first, list):
        return [np.array(emb, dtype=np.float32) for emb in target]
    if isinstance(first, np.ndarray):
        return [emb.astype(np.float32) for emb in target]  # type: ignore[union-attr]
    raise ValueError(f"Unsupported embeddings format: {type(first)}")
def maybe_cast_one_to_many(target: T | list[T] | None) -> list[T] | None:
    """Wrap a lone item in a list, passing lists and None through unchanged.

    Args:
        target: A single item, a list of items, or None.

    Returns:
        A list of items, or None when ``target`` is None.
    """
    if isinstance(target, list):
        return target
    return None if target is None else [target]
def validate_embeddings(embeddings: Embeddings) -> Embeddings:
    """Validate embeddings format and content.

    Args:
        embeddings: List of numpy arrays to validate.

    Returns:
        The validated embeddings, unchanged.

    Raises:
        ValueError: If embeddings format or content is invalid.
    """
    if not isinstance(embeddings, list):
        raise ValueError(
            f"Expected embeddings to be a list, got {type(embeddings).__name__}"
        )
    if len(embeddings) == 0:
        raise ValueError(
            f"Expected embeddings to be a list with at least one item, got {len(embeddings)} embeddings"
        )
    if not all(isinstance(e, np.ndarray) for e in embeddings):
        raise ValueError(
            "Expected each embedding in the embeddings to be a numpy array"
        )
    for position, vector in enumerate(embeddings):
        if vector.ndim == 0:
            raise ValueError(
                f"Expected a 1-dimensional array, got a 0-dimensional array {vector}"
            )
        if vector.size == 0:
            raise ValueError(
                f"Expected each embedding to be a 1-dimensional numpy array with at least 1 value. "
                f"Got an array with no values at position {position}"
            )
        # Booleans are rejected even though bool subclasses int.
        numeric_only = all(
            isinstance(value, (np.integer, float, np.floating))
            and not isinstance(value, bool)
            for value in vector
        )
        if not numeric_only:
            raise ValueError(
                f"Expected embedding to contain numeric values, got non-numeric values at position {position}"
            )
    return embeddings
@runtime_checkable
class EmbeddingFunction(Protocol[D]):
    """Protocol for embedding functions.

    Embedding functions convert input data (documents or images) into vector
    embeddings. Concrete subclasses only implement ``__call__``; the returned
    value is normalized and validated automatically (see __init_subclass__).
    """

    def __call__(self, input: D) -> Embeddings:
        """Convert input data to embeddings.

        Args:
            input: Input data to embed (documents or images).

        Returns:
            List of numpy arrays representing the embeddings.
        """
        ...

    def __init_subclass__(cls) -> None:
        """Wrap __call__ method to normalize and validate embeddings."""
        super().__init_subclass__()
        original_call = cls.__call__

        # Every subclass's __call__ is transparently wrapped so implementors
        # may return any raw format accepted by normalize_embeddings while
        # callers always receive validated numpy arrays.
        def wrapped_call(self: EmbeddingFunction[D], input: D) -> Embeddings:
            result = original_call(self, input)
            if result is None:
                raise ValueError("Embedding function returned None")
            normalized = normalize_embeddings(result)
            if normalized is None:
                raise ValueError("Normalization returned None for non-None input")
            return validate_embeddings(normalized)

        cls.__call__ = wrapped_call  # type: ignore[method-assign]

    def embed_query(self, input: D) -> Embeddings:
        """Get the embeddings for a query input.

        This method is optional, and if not implemented, the default behavior
        is to call __call__.
        """
        return self.__call__(input=input)

View File

@@ -0,0 +1,23 @@
"""Base class for embedding providers."""
from typing import Generic, TypeVar
from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
T = TypeVar("T", bound=EmbeddingFunction)
class BaseEmbeddingsProvider(BaseSettings, Generic[T]):
    """Abstract base class for embedding providers.

    This class provides a common interface for dynamically loading and building
    embedding functions from various providers.
    """

    # extra="allow" lets provider-specific settings pass through unmodeled;
    # populate_by_name accepts both field names and their aliases.
    model_config = SettingsConfigDict(extra="allow", populate_by_name=True)

    # Class (not instance) implementing the EmbeddingFunction protocol.
    embedding_callable: type[T] = Field(
        ..., description="The embedding function class to use"
    )

View File

@@ -0,0 +1,26 @@
"""Core exceptions for RAG module."""
class ClientMethodMismatchError(TypeError):
    """Raised when a method is called with the wrong client type.

    Typically used when a sync method is called with an async client,
    or vice versa.
    """

    def __init__(
        self, method_name: str, expected_client: str, alt_method: str, alt_client: str
    ) -> None:
        """Create a ClientMethodMismatchError.

        Args:
            method_name: Method that was called incorrectly.
            expected_client: Required client type.
            alt_method: Suggested alternative method.
            alt_client: Client type for the alternative method.
        """
        super().__init__(
            f"Method {method_name}() requires a {expected_client}. "
            f"Use {alt_method}() for {alt_client}."
        )

View File

@@ -0,0 +1,28 @@
"""Core type definitions for RAG systems."""
from collections.abc import Sequence
from typing import TypeVar
import numpy as np
from numpy import floating, integer, number
from numpy.typing import NDArray
T = TypeVar("T")

# Pure-Python embedding representations (no numpy required by callers).
PyEmbedding = Sequence[float] | Sequence[int]
PyEmbeddings = list[PyEmbedding]
# Numpy-backed embedding representations used internally.
Embedding = NDArray[np.int32 | np.float32]
Embeddings = list[Embedding]
# Inputs that can be embedded: text documents or images.
Documents = list[str]
Images = list[np.ndarray]
Embeddable = Documents | Images
# Bounded TypeVars for functions generic over numpy scalar families.
ScalarType = TypeVar("ScalarType", bound=np.generic)
IntegerType = TypeVar("IntegerType", bound=integer)
FloatingType = TypeVar("FloatingType", bound=floating)
NumberType = TypeVar("NumberType", bound=number)
# Constrained TypeVars for 32-bit, 64-bit, and common numeric dtypes.
DType32 = TypeVar("DType32", np.int32, np.float32)
DType64 = TypeVar("DType64", np.int64, np.float64)
DTypeCommon = TypeVar("DTypeCommon", np.int32, np.int64, np.float32, np.float64)

View File

@@ -0,0 +1 @@
"""Embedding components for RAG infrastructure."""

View File

@@ -0,0 +1,392 @@
"""Factory functions for creating embedding providers and functions."""
from __future__ import annotations
import warnings
from typing import TYPE_CHECKING, TypeVar, overload
from typing_extensions import deprecated
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
from crewai.core.utilities.import_utils import import_and_validate_definition
if TYPE_CHECKING:
from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
AmazonBedrockEmbeddingFunction,
)
from chromadb.utils.embedding_functions.cohere_embedding_function import (
CohereEmbeddingFunction,
)
from chromadb.utils.embedding_functions.google_embedding_function import (
GoogleGenerativeAiEmbeddingFunction,
GoogleVertexEmbeddingFunction,
)
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
HuggingFaceEmbeddingFunction,
)
from chromadb.utils.embedding_functions.instructor_embedding_function import (
InstructorEmbeddingFunction,
)
from chromadb.utils.embedding_functions.jina_embedding_function import (
JinaEmbeddingFunction,
)
from chromadb.utils.embedding_functions.ollama_embedding_function import (
OllamaEmbeddingFunction,
)
from chromadb.utils.embedding_functions.onnx_mini_lm_l6_v2 import ONNXMiniLM_L6_V2
from chromadb.utils.embedding_functions.open_clip_embedding_function import (
OpenCLIPEmbeddingFunction,
)
from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction,
)
from chromadb.utils.embedding_functions.roboflow_embedding_function import (
RoboflowEmbeddingFunction,
)
from chromadb.utils.embedding_functions.sentence_transformer_embedding_function import (
SentenceTransformerEmbeddingFunction,
)
from chromadb.utils.embedding_functions.text2vec_embedding_function import (
Text2VecEmbeddingFunction,
)
from crewai.rag.embeddings.providers.aws.types import BedrockProviderSpec
from crewai.rag.embeddings.providers.cohere.types import CohereProviderSpec
from crewai.rag.embeddings.providers.custom.types import CustomProviderSpec
from crewai.rag.embeddings.providers.google.types import (
GenerativeAiProviderSpec,
VertexAIProviderSpec,
)
from crewai.rag.embeddings.providers.huggingface.types import (
HuggingFaceProviderSpec,
)
from crewai.rag.embeddings.providers.ibm.embedding_callable import (
WatsonXEmbeddingFunction,
)
from crewai.rag.embeddings.providers.ibm.types import (
WatsonProviderSpec,
WatsonXProviderSpec,
)
from crewai.rag.embeddings.providers.instructor.types import InstructorProviderSpec
from crewai.rag.embeddings.providers.jina.types import JinaProviderSpec
from crewai.rag.embeddings.providers.microsoft.types import AzureProviderSpec
from crewai.rag.embeddings.providers.ollama.types import OllamaProviderSpec
from crewai.rag.embeddings.providers.onnx.types import ONNXProviderSpec
from crewai.rag.embeddings.providers.openai.types import OpenAIProviderSpec
from crewai.rag.embeddings.providers.openclip.types import OpenCLIPProviderSpec
from crewai.rag.embeddings.providers.roboflow.types import RoboflowProviderSpec
from crewai.rag.embeddings.providers.sentence_transformer.types import (
SentenceTransformerProviderSpec,
)
from crewai.rag.embeddings.providers.text2vec.types import Text2VecProviderSpec
from crewai.rag.embeddings.providers.voyageai.embedding_callable import (
VoyageAIEmbeddingFunction,
)
from crewai.rag.embeddings.providers.voyageai.types import VoyageAIProviderSpec
T = TypeVar("T", bound=EmbeddingFunction)
# Maps the public provider key (as used in spec dicts) to the dotted import
# path of its BaseEmbeddingsProvider subclass; resolved lazily by
# import_and_validate_definition so optional dependencies load on demand.
PROVIDER_PATHS = {
    "azure": "crewai.rag.embeddings.providers.microsoft.azure.AzureProvider",
    "amazon-bedrock": "crewai.rag.embeddings.providers.aws.bedrock.BedrockProvider",
    "cohere": "crewai.rag.embeddings.providers.cohere.cohere_provider.CohereProvider",
    "custom": "crewai.rag.embeddings.providers.custom.custom_provider.CustomProvider",
    "google-generativeai": "crewai.rag.embeddings.providers.google.generative_ai.GenerativeAiProvider",
    "google-vertex": "crewai.rag.embeddings.providers.google.vertex.VertexAIProvider",
    "huggingface": "crewai.rag.embeddings.providers.huggingface.huggingface_provider.HuggingFaceProvider",
    "instructor": "crewai.rag.embeddings.providers.instructor.instructor_provider.InstructorProvider",
    "jina": "crewai.rag.embeddings.providers.jina.jina_provider.JinaProvider",
    "ollama": "crewai.rag.embeddings.providers.ollama.ollama_provider.OllamaProvider",
    "onnx": "crewai.rag.embeddings.providers.onnx.onnx_provider.ONNXProvider",
    "openai": "crewai.rag.embeddings.providers.openai.openai_provider.OpenAIProvider",
    "openclip": "crewai.rag.embeddings.providers.openclip.openclip_provider.OpenCLIPProvider",
    "roboflow": "crewai.rag.embeddings.providers.roboflow.roboflow_provider.RoboflowProvider",
    "sentence-transformer": "crewai.rag.embeddings.providers.sentence_transformer.sentence_transformer_provider.SentenceTransformerProvider",
    "text2vec": "crewai.rag.embeddings.providers.text2vec.text2vec_provider.Text2VecProvider",
    "voyageai": "crewai.rag.embeddings.providers.voyageai.voyageai_provider.VoyageAIProvider",
    "watson": "crewai.rag.embeddings.providers.ibm.watsonx.WatsonXProvider",  # Deprecated alias
    "watsonx": "crewai.rag.embeddings.providers.ibm.watsonx.WatsonXProvider",
}
def build_embedder_from_provider(provider: BaseEmbeddingsProvider[T]) -> T:
"""Build an embedding function instance from a provider.
Args:
provider: The embedding provider configuration.
Returns:
An instance of the specified embedding function type.
"""
return provider.embedding_callable(
**provider.model_dump(exclude={"embedding_callable"})
)
# Typed overloads: map each provider spec TypedDict to the concrete embedding
# function class the factory returns, so callers get precise static types.
@overload
def build_embedder_from_dict(spec: AzureProviderSpec) -> OpenAIEmbeddingFunction: ...
@overload
def build_embedder_from_dict(
    spec: BedrockProviderSpec,
) -> AmazonBedrockEmbeddingFunction: ...
@overload
def build_embedder_from_dict(spec: CohereProviderSpec) -> CohereEmbeddingFunction: ...
@overload
def build_embedder_from_dict(spec: CustomProviderSpec) -> EmbeddingFunction: ...
@overload
def build_embedder_from_dict(
    spec: GenerativeAiProviderSpec,
) -> GoogleGenerativeAiEmbeddingFunction: ...
@overload
def build_embedder_from_dict(
    spec: HuggingFaceProviderSpec,
) -> HuggingFaceEmbeddingFunction: ...
@overload
def build_embedder_from_dict(spec: OllamaProviderSpec) -> OllamaEmbeddingFunction: ...
@overload
def build_embedder_from_dict(spec: OpenAIProviderSpec) -> OpenAIEmbeddingFunction: ...
@overload
def build_embedder_from_dict(
    spec: VertexAIProviderSpec,
) -> GoogleVertexEmbeddingFunction: ...
@overload
def build_embedder_from_dict(
    spec: VoyageAIProviderSpec,
) -> VoyageAIEmbeddingFunction: ...
@overload
def build_embedder_from_dict(spec: WatsonXProviderSpec) -> WatsonXEmbeddingFunction: ...
@overload
@deprecated(
    'The "WatsonProviderSpec" provider spec is deprecated and will be removed in v1.0.0. Use "WatsonXProviderSpec" instead.'
)
def build_embedder_from_dict(spec: WatsonProviderSpec) -> WatsonXEmbeddingFunction: ...
@overload
def build_embedder_from_dict(
    spec: SentenceTransformerProviderSpec,
) -> SentenceTransformerEmbeddingFunction: ...
@overload
def build_embedder_from_dict(
    spec: InstructorProviderSpec,
) -> InstructorEmbeddingFunction: ...
@overload
def build_embedder_from_dict(spec: JinaProviderSpec) -> JinaEmbeddingFunction: ...
@overload
def build_embedder_from_dict(
    spec: RoboflowProviderSpec,
) -> RoboflowEmbeddingFunction: ...
@overload
def build_embedder_from_dict(
    spec: OpenCLIPProviderSpec,
) -> OpenCLIPEmbeddingFunction: ...
@overload
def build_embedder_from_dict(
    spec: Text2VecProviderSpec,
) -> Text2VecEmbeddingFunction: ...
@overload
def build_embedder_from_dict(spec: ONNXProviderSpec) -> ONNXMiniLM_L6_V2: ...
def build_embedder_from_dict(spec):
    """Build an embedding function instance from a dictionary specification.

    Args:
        spec: A dictionary with 'provider' and 'config' keys.
            Example: {
                "provider": "openai",
                "config": {
                    "api_key": "sk-...",
                    "model_name": "text-embedding-3-small"
                }
            }

    Returns:
        An instance of the appropriate embedding function.

    Raises:
        ValueError: If the provider key is missing or not recognized, or if a
            custom provider spec lacks an embedding callable.
        ImportError: If the provider implementation cannot be imported.
    """
    # Use .get() so a spec without a "provider" key surfaces the explicit
    # ValueError below instead of an opaque KeyError from subscripting.
    provider_name = spec.get("provider")
    if not provider_name:
        raise ValueError("Missing 'provider' key in specification")
    if provider_name == "watson":
        warnings.warn(
            'The "watson" provider key is deprecated and will be removed in v1.0.0. '
            'Use "watsonx" instead.',
            DeprecationWarning,
            stacklevel=2,
        )
    if provider_name not in PROVIDER_PATHS:
        raise ValueError(
            f"Unknown provider: {provider_name}. Available providers: {list(PROVIDER_PATHS.keys())}"
        )
    provider_path = PROVIDER_PATHS[provider_name]
    try:
        provider_class = import_and_validate_definition(provider_path)
    except (ImportError, AttributeError, ValueError) as e:
        raise ImportError(f"Failed to import provider {provider_name}: {e}") from e
    provider_config = spec.get("config", {})
    # Custom providers cannot fall back to a default callable.
    if provider_name == "custom" and "embedding_callable" not in provider_config:
        raise ValueError("Custom provider requires 'embedding_callable' in config")
    provider = provider_class(**provider_config)
    return build_embedder_from_provider(provider)
# Typed overloads mirroring build_embedder_from_dict, plus a generic overload
# that maps a provider instance to its declared embedding function type.
@overload
def build_embedder(spec: BaseEmbeddingsProvider[T]) -> T: ...
@overload
def build_embedder(spec: AzureProviderSpec) -> OpenAIEmbeddingFunction: ...
@overload
def build_embedder(spec: BedrockProviderSpec) -> AmazonBedrockEmbeddingFunction: ...
@overload
def build_embedder(spec: CohereProviderSpec) -> CohereEmbeddingFunction: ...
@overload
def build_embedder(spec: CustomProviderSpec) -> EmbeddingFunction: ...
@overload
def build_embedder(
    spec: GenerativeAiProviderSpec,
) -> GoogleGenerativeAiEmbeddingFunction: ...
@overload
def build_embedder(spec: HuggingFaceProviderSpec) -> HuggingFaceEmbeddingFunction: ...
@overload
def build_embedder(spec: OllamaProviderSpec) -> OllamaEmbeddingFunction: ...
@overload
def build_embedder(spec: OpenAIProviderSpec) -> OpenAIEmbeddingFunction: ...
@overload
def build_embedder(spec: VertexAIProviderSpec) -> GoogleVertexEmbeddingFunction: ...
@overload
def build_embedder(spec: VoyageAIProviderSpec) -> VoyageAIEmbeddingFunction: ...
@overload
def build_embedder(spec: WatsonXProviderSpec) -> WatsonXEmbeddingFunction: ...
@overload
@deprecated(
    'The "WatsonProviderSpec" provider spec is deprecated and will be removed in v1.0.0. Use "WatsonXProviderSpec" instead.'
)
def build_embedder(spec: WatsonProviderSpec) -> WatsonXEmbeddingFunction: ...
@overload
def build_embedder(
    spec: SentenceTransformerProviderSpec,
) -> SentenceTransformerEmbeddingFunction: ...
@overload
def build_embedder(spec: InstructorProviderSpec) -> InstructorEmbeddingFunction: ...
@overload
def build_embedder(spec: JinaProviderSpec) -> JinaEmbeddingFunction: ...
@overload
def build_embedder(spec: RoboflowProviderSpec) -> RoboflowEmbeddingFunction: ...
@overload
def build_embedder(spec: OpenCLIPProviderSpec) -> OpenCLIPEmbeddingFunction: ...
@overload
def build_embedder(spec: Text2VecProviderSpec) -> Text2VecEmbeddingFunction: ...
@overload
def build_embedder(spec: ONNXProviderSpec) -> ONNXMiniLM_L6_V2: ...
def build_embedder(spec):
    """Build an embedding function from either a provider spec or a provider instance.

    Args:
        spec: Either a provider specification dictionary or a provider instance.

    Returns:
        An embedding function instance. If a typed provider is passed, returns
        the specific embedding function type.

    Examples:
        # From dictionary specification
        embedder = build_embedder({
            "provider": "openai",
            "config": {"api_key": "sk-..."}
        })

        # From provider instance
        provider = OpenAIProvider(api_key="sk-...")
        embedder = build_embedder(provider)
    """
    # Dispatch on input shape: provider instances are built directly,
    # anything else is treated as a spec dictionary.
    builder = (
        build_embedder_from_provider
        if isinstance(spec, BaseEmbeddingsProvider)
        else build_embedder_from_dict
    )
    return builder(spec)


# Backward compatibility alias
get_embedding_function = build_embedder

View File

@@ -0,0 +1 @@
"""Embedding provider implementations."""

View File

@@ -0,0 +1,13 @@
"""AWS embedding providers."""
from crewai.rag.embeddings.providers.aws.bedrock import BedrockProvider
from crewai.rag.embeddings.providers.aws.types import (
BedrockProviderConfig,
BedrockProviderSpec,
)
__all__ = [
"BedrockProvider",
"BedrockProviderConfig",
"BedrockProviderSpec",
]

View File

@@ -0,0 +1,53 @@
"""Amazon Bedrock embeddings provider."""
from typing import Any
from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
AmazonBedrockEmbeddingFunction,
)
from pydantic import Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
def create_aws_session() -> Any:
    """Create an AWS session for Bedrock.

    Returns:
        boto3.Session: AWS session object

    Raises:
        ImportError: If boto3 is not installed
        ValueError: If AWS session creation fails
    """
    # Keep each try body minimal so errors are attributed correctly: an
    # ImportError raised while constructing the session (e.g. a broken
    # transitive dependency) should not be reported as "boto3 is missing".
    try:
        import boto3  # type: ignore[import]
    except ImportError as e:
        raise ImportError(
            "boto3 is required for amazon-bedrock embeddings. "
            "Install it with: uv add boto3"
        ) from e
    try:
        # Session() resolves credentials/region via the standard AWS chain.
        return boto3.Session()
    except Exception as e:
        raise ValueError(
            f"Failed to create AWS session for amazon-bedrock. "
            f"Ensure AWS credentials are configured. Error: {e}"
        ) from e
class BedrockProvider(BaseEmbeddingsProvider[AmazonBedrockEmbeddingFunction]):
    """Amazon Bedrock embeddings provider."""

    # Concrete embedding function the factory instantiates for this provider.
    embedding_callable: type[AmazonBedrockEmbeddingFunction] = Field(
        default=AmazonBedrockEmbeddingFunction,
        description="Amazon Bedrock embedding function class",
    )
    model_name: str = Field(
        default="amazon.titan-embed-text-v1",
        description="Model name to use for embeddings",
        validation_alias="EMBEDDINGS_BEDROCK_MODEL_NAME",
    )
    # Built eagerly at provider construction; raises if boto3 is missing or
    # AWS credentials cannot be resolved (see create_aws_session).
    session: Any = Field(
        default_factory=create_aws_session, description="AWS session object"
    )

View File

@@ -0,0 +1,19 @@
"""Type definitions for AWS embedding providers."""
from typing import Annotated, Any, Literal
from typing_extensions import Required, TypedDict
class BedrockProviderConfig(TypedDict, total=False):
    """Configuration for Bedrock provider."""

    # Annotated metadata documents the default value applied by the provider.
    model_name: Annotated[str, "amazon.titan-embed-text-v1"]
    # boto3 session object; typed Any to avoid a hard boto3 dependency here.
    session: Any


class BedrockProviderSpec(TypedDict, total=False):
    """Bedrock provider specification."""

    provider: Required[Literal["amazon-bedrock"]]
    config: BedrockProviderConfig

View File

@@ -0,0 +1,13 @@
"""Cohere embedding providers."""
from crewai.rag.embeddings.providers.cohere.cohere_provider import CohereProvider
from crewai.rag.embeddings.providers.cohere.types import (
CohereProviderConfig,
CohereProviderSpec,
)
__all__ = [
"CohereProvider",
"CohereProviderConfig",
"CohereProviderSpec",
]

View File

@@ -0,0 +1,24 @@
"""Cohere embeddings provider."""
from chromadb.utils.embedding_functions.cohere_embedding_function import (
CohereEmbeddingFunction,
)
from pydantic import Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
class CohereProvider(BaseEmbeddingsProvider[CohereEmbeddingFunction]):
    """Cohere embeddings provider."""

    embedding_callable: type[CohereEmbeddingFunction] = Field(
        default=CohereEmbeddingFunction, description="Cohere embedding function class"
    )
    # Required; may also be supplied via the EMBEDDINGS_COHERE_API_KEY env var.
    api_key: str = Field(
        description="Cohere API key", validation_alias="EMBEDDINGS_COHERE_API_KEY"
    )
    model_name: str = Field(
        default="large",
        description="Model name to use for embeddings",
        validation_alias="EMBEDDINGS_COHERE_MODEL_NAME",
    )

View File

@@ -0,0 +1,19 @@
"""Type definitions for Cohere embedding providers."""
from typing import Annotated, Literal
from typing_extensions import Required, TypedDict
class CohereProviderConfig(TypedDict, total=False):
    """Configuration for Cohere provider."""

    api_key: str
    # Annotated metadata documents the default applied by CohereProvider.
    model_name: Annotated[str, "large"]


class CohereProviderSpec(TypedDict, total=False):
    """Cohere provider specification."""

    provider: Required[Literal["cohere"]]
    config: CohereProviderConfig

View File

@@ -0,0 +1,13 @@
"""Custom embedding providers."""
from crewai.rag.embeddings.providers.custom.custom_provider import CustomProvider
from crewai.rag.embeddings.providers.custom.types import (
CustomProviderConfig,
CustomProviderSpec,
)
__all__ = [
"CustomProvider",
"CustomProviderConfig",
"CustomProviderSpec",
]

View File

@@ -0,0 +1,19 @@
"""Custom embeddings provider for user-defined embedding functions."""
from pydantic import Field
from pydantic_settings import SettingsConfigDict
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
from crewai.rag.embeddings.providers.custom.embedding_callable import (
CustomEmbeddingFunction,
)
class CustomProvider(BaseEmbeddingsProvider[CustomEmbeddingFunction]):
    """Custom embeddings provider for user-defined embedding functions."""

    # No default: the factory enforces that the caller supplies a callable.
    embedding_callable: type[CustomEmbeddingFunction] = Field(
        ..., description="Custom embedding function class"
    )
    # extra="allow" forwards arbitrary user config to the custom callable.
    model_config = SettingsConfigDict(extra="allow")

View File

@@ -0,0 +1,22 @@
"""Custom embedding function base implementation."""
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
from crewai.rag.core.types import Documents, Embeddings
class CustomEmbeddingFunction(EmbeddingFunction[Documents]):
    """Base class for custom embedding functions.

    Subclass this and override :meth:`__call__` to plug a user-defined
    embedding implementation into the RAG pipeline.
    """

    def __call__(self, input: Documents) -> Embeddings:
        """Convert input documents to embeddings.

        Args:
            input: List of documents to embed.

        Returns:
            List of numpy arrays representing the embeddings.

        Raises:
            NotImplementedError: Always; subclasses supply the implementation.
        """
        raise NotImplementedError("Subclasses must implement __call__ method")

View File

@@ -0,0 +1,19 @@
"""Type definitions for custom embedding providers."""
from typing import Literal
from chromadb.api.types import EmbeddingFunction
from typing_extensions import Required, TypedDict
class CustomProviderConfig(TypedDict, total=False):
    """Configuration for Custom provider."""

    # User-supplied EmbeddingFunction subclass; required by the factory even
    # though the TypedDict marks it optional.
    embedding_callable: type[EmbeddingFunction]


class CustomProviderSpec(TypedDict, total=False):
    """Custom provider specification."""

    provider: Required[Literal["custom"]]
    config: CustomProviderConfig

View File

@@ -0,0 +1,23 @@
"""Google embedding providers."""
from crewai.rag.embeddings.providers.google.generative_ai import (
GenerativeAiProvider,
)
from crewai.rag.embeddings.providers.google.types import (
GenerativeAiProviderConfig,
GenerativeAiProviderSpec,
VertexAIProviderConfig,
VertexAIProviderSpec,
)
from crewai.rag.embeddings.providers.google.vertex import (
VertexAIProvider,
)
__all__ = [
"GenerativeAiProvider",
"GenerativeAiProviderConfig",
"GenerativeAiProviderSpec",
"VertexAIProvider",
"VertexAIProviderConfig",
"VertexAIProviderSpec",
]

View File

@@ -0,0 +1,30 @@
"""Google Generative AI embeddings provider."""
from chromadb.utils.embedding_functions.google_embedding_function import (
GoogleGenerativeAiEmbeddingFunction,
)
from pydantic import Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
class GenerativeAiProvider(BaseEmbeddingsProvider[GoogleGenerativeAiEmbeddingFunction]):
    """Google Generative AI embeddings provider."""

    embedding_callable: type[GoogleGenerativeAiEmbeddingFunction] = Field(
        default=GoogleGenerativeAiEmbeddingFunction,
        description="Google Generative AI embedding function class",
    )
    model_name: str = Field(
        default="models/embedding-001",
        description="Model name to use for embeddings",
        validation_alias="EMBEDDINGS_GOOGLE_GENERATIVE_AI_MODEL_NAME",
    )
    # Required; may also be supplied via the EMBEDDINGS_GOOGLE_API_KEY env var.
    api_key: str = Field(
        description="Google API key", validation_alias="EMBEDDINGS_GOOGLE_API_KEY"
    )
    task_type: str = Field(
        default="RETRIEVAL_DOCUMENT",
        description="Task type for embeddings",
        validation_alias="EMBEDDINGS_GOOGLE_GENERATIVE_AI_TASK_TYPE",
    )

View File

@@ -0,0 +1,36 @@
"""Type definitions for Google embedding providers."""
from typing import Annotated, Literal
from typing_extensions import Required, TypedDict
class GenerativeAiProviderConfig(TypedDict, total=False):
"""Configuration for Google Generative AI provider."""
api_key: str
model_name: Annotated[str, "models/embedding-001"]
task_type: Annotated[str, "RETRIEVAL_DOCUMENT"]
class GenerativeAiProviderSpec(TypedDict):
"""Google Generative AI provider specification."""
provider: Literal["google-generativeai"]
config: GenerativeAiProviderConfig
class VertexAIProviderConfig(TypedDict, total=False):
"""Configuration for Vertex AI provider."""
api_key: str
model_name: Annotated[str, "textembedding-gecko"]
project_id: Annotated[str, "cloud-large-language-models"]
region: Annotated[str, "us-central1"]
class VertexAIProviderSpec(TypedDict, total=False):
"""Vertex AI provider specification."""
provider: Required[Literal["google-vertex"]]
config: VertexAIProviderConfig

View File

@@ -0,0 +1,35 @@
"""Google Vertex AI embeddings provider."""
from chromadb.utils.embedding_functions.google_embedding_function import (
GoogleVertexEmbeddingFunction,
)
from pydantic import Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
class VertexAIProvider(BaseEmbeddingsProvider[GoogleVertexEmbeddingFunction]):
    """Google Vertex AI embeddings provider."""

    embedding_callable: type[GoogleVertexEmbeddingFunction] = Field(
        default=GoogleVertexEmbeddingFunction,
        description="Vertex AI embedding function class",
    )
    model_name: str = Field(
        default="textembedding-gecko",
        description="Model name to use for embeddings",
        validation_alias="EMBEDDINGS_GOOGLE_VERTEX_MODEL_NAME",
    )
    # Required; may also come from the EMBEDDINGS_GOOGLE_CLOUD_API_KEY env var.
    api_key: str = Field(
        description="Google API key", validation_alias="EMBEDDINGS_GOOGLE_CLOUD_API_KEY"
    )
    project_id: str = Field(
        default="cloud-large-language-models",
        description="GCP project ID",
        validation_alias="EMBEDDINGS_GOOGLE_CLOUD_PROJECT",
    )
    region: str = Field(
        default="us-central1",
        description="GCP region",
        validation_alias="EMBEDDINGS_GOOGLE_CLOUD_REGION",
    )

View File

@@ -0,0 +1,15 @@
"""HuggingFace embedding providers."""
from crewai.rag.embeddings.providers.huggingface.huggingface_provider import (
HuggingFaceProvider,
)
from crewai.rag.embeddings.providers.huggingface.types import (
HuggingFaceProviderConfig,
HuggingFaceProviderSpec,
)
__all__ = [
"HuggingFaceProvider",
"HuggingFaceProviderConfig",
"HuggingFaceProviderSpec",
]

View File

@@ -0,0 +1,20 @@
"""HuggingFace embeddings provider."""
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
HuggingFaceEmbeddingServer,
)
from pydantic import Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
class HuggingFaceProvider(BaseEmbeddingsProvider[HuggingFaceEmbeddingServer]):
    """HuggingFace embeddings provider."""

    # NOTE(review): this targets HuggingFaceEmbeddingServer (URL-based,
    # self-hosted endpoint), while the factory's overloads advertise
    # HuggingFaceEmbeddingFunction — confirm which class is intended.
    embedding_callable: type[HuggingFaceEmbeddingServer] = Field(
        default=HuggingFaceEmbeddingServer,
        description="HuggingFace embedding function class",
    )
    url: str = Field(
        description="HuggingFace API URL", validation_alias="EMBEDDINGS_HUGGINGFACE_URL"
    )

View File

@@ -0,0 +1,18 @@
"""Type definitions for HuggingFace embedding providers."""
from typing import Literal
from typing_extensions import Required, TypedDict
class HuggingFaceProviderConfig(TypedDict, total=False):
    """Configuration for HuggingFace provider."""

    # Endpoint URL of the embedding server.
    url: str


class HuggingFaceProviderSpec(TypedDict, total=False):
    """HuggingFace provider specification."""

    provider: Required[Literal["huggingface"]]
    config: HuggingFaceProviderConfig

View File

@@ -0,0 +1,17 @@
"""IBM embedding providers."""
from crewai.rag.embeddings.providers.ibm.types import (
WatsonProviderSpec,
WatsonXProviderConfig,
WatsonXProviderSpec,
)
from crewai.rag.embeddings.providers.ibm.watsonx import (
WatsonXProvider,
)
__all__ = [
"WatsonProviderSpec",
"WatsonXProvider",
"WatsonXProviderConfig",
"WatsonXProviderSpec",
]

View File

@@ -0,0 +1,159 @@
"""IBM WatsonX embedding function implementation."""
from typing import cast
from chromadb.api.types import Documents, EmbeddingFunction, Embeddings
from typing_extensions import Unpack
from crewai.rag.embeddings.providers.ibm.types import WatsonXProviderConfig
class WatsonXEmbeddingFunction(EmbeddingFunction[Documents]):
    """Embedding function for IBM WatsonX models."""

    def __init__(self, **kwargs: Unpack[WatsonXProviderConfig]) -> None:
        """Initialize WatsonX embedding function.

        Args:
            **kwargs: Configuration parameters for WatsonX Embeddings and Credentials.
        """
        super().__init__(**kwargs)
        # Raw config is stored and translated lazily in __call__, deferring the
        # ibm-watsonx-ai import until embeddings are actually requested.
        self._config = kwargs

    @staticmethod
    def name() -> str:
        """Return the name of the embedding function for ChromaDB compatibility."""
        return "watsonx"

    def __call__(self, input: Documents) -> Embeddings:
        """Generate embeddings for input documents.

        Args:
            input: List of documents to embed.

        Returns:
            List of embedding vectors.

        Raises:
            ImportError: If ibm-watsonx-ai is not installed.
            KeyError: If "model_id" is absent from the configuration.
        """
        try:
            import ibm_watsonx_ai.foundation_models as watson_models  # type: ignore[import-not-found, import-untyped]
            from ibm_watsonx_ai import (
                Credentials,  # type: ignore[import-not-found, import-untyped]
            )
            from ibm_watsonx_ai.metanames import (  # type: ignore[import-not-found, import-untyped]
                EmbedTextParamsMetaNames as EmbedParams,
            )
        except ImportError as e:
            raise ImportError(
                "ibm-watsonx-ai is required for watsonx embeddings. "
                "Install it with: uv add ibm-watsonx-ai"
            ) from e
        # Accept a bare string for convenience even though the declared
        # parameter type is a list of documents.
        if isinstance(input, str):
            input = [input]
        # model_id is mandatory; a missing key raises KeyError here.
        embeddings_config: dict = {
            "model_id": self._config["model_id"],
        }
        # Optional Embeddings kwargs, forwarded only when present and not None.
        if "params" in self._config and self._config["params"] is not None:
            embeddings_config["params"] = self._config["params"]
        if "project_id" in self._config and self._config["project_id"] is not None:
            embeddings_config["project_id"] = self._config["project_id"]
        if "space_id" in self._config and self._config["space_id"] is not None:
            embeddings_config["space_id"] = self._config["space_id"]
        if "api_client" in self._config and self._config["api_client"] is not None:
            embeddings_config["api_client"] = self._config["api_client"]
        if "verify" in self._config and self._config["verify"] is not None:
            embeddings_config["verify"] = self._config["verify"]
        # These three are forwarded even when explicitly set to None.
        if "persistent_connection" in self._config:
            embeddings_config["persistent_connection"] = self._config[
                "persistent_connection"
            ]
        if "batch_size" in self._config:
            embeddings_config["batch_size"] = self._config["batch_size"]
        if "concurrency_limit" in self._config:
            embeddings_config["concurrency_limit"] = self._config["concurrency_limit"]
        if "max_retries" in self._config and self._config["max_retries"] is not None:
            embeddings_config["max_retries"] = self._config["max_retries"]
        if "delay_time" in self._config and self._config["delay_time"] is not None:
            embeddings_config["delay_time"] = self._config["delay_time"]
        if (
            "retry_status_codes" in self._config
            and self._config["retry_status_codes"] is not None
        ):
            embeddings_config["retry_status_codes"] = self._config["retry_status_codes"]
        # Use a pre-built Credentials object when supplied; otherwise assemble
        # one from the individual credential fields below.
        if "credentials" in self._config and self._config["credentials"] is not None:
            embeddings_config["credentials"] = self._config["credentials"]
        else:
            cred_config: dict = {}
            if "url" in self._config and self._config["url"] is not None:
                cred_config["url"] = self._config["url"]
            if "api_key" in self._config and self._config["api_key"] is not None:
                cred_config["api_key"] = self._config["api_key"]
            if "name" in self._config and self._config["name"] is not None:
                cred_config["name"] = self._config["name"]
            if (
                "iam_serviceid_crn" in self._config
                and self._config["iam_serviceid_crn"] is not None
            ):
                cred_config["iam_serviceid_crn"] = self._config["iam_serviceid_crn"]
            if (
                "trusted_profile_id" in self._config
                and self._config["trusted_profile_id"] is not None
            ):
                cred_config["trusted_profile_id"] = self._config["trusted_profile_id"]
            if "token" in self._config and self._config["token"] is not None:
                cred_config["token"] = self._config["token"]
            if (
                "projects_token" in self._config
                and self._config["projects_token"] is not None
            ):
                cred_config["projects_token"] = self._config["projects_token"]
            if "username" in self._config and self._config["username"] is not None:
                cred_config["username"] = self._config["username"]
            if "password" in self._config and self._config["password"] is not None:
                cred_config["password"] = self._config["password"]
            if (
                "instance_id" in self._config
                and self._config["instance_id"] is not None
            ):
                cred_config["instance_id"] = self._config["instance_id"]
            if "version" in self._config and self._config["version"] is not None:
                cred_config["version"] = self._config["version"]
            if (
                "bedrock_url" in self._config
                and self._config["bedrock_url"] is not None
            ):
                cred_config["bedrock_url"] = self._config["bedrock_url"]
            if (
                "platform_url" in self._config
                and self._config["platform_url"] is not None
            ):
                cred_config["platform_url"] = self._config["platform_url"]
            if "proxies" in self._config and self._config["proxies"] is not None:
                cred_config["proxies"] = self._config["proxies"]
            # NOTE(review): this branch appears unreachable — a non-None
            # "verify" was already copied into embeddings_config above, so the
            # first clause can never hold together with the third. Confirm
            # whether credentials-level verify was intended.
            if (
                "verify" not in embeddings_config
                and "verify" in self._config
                and self._config["verify"] is not None
            ):
                cred_config["verify"] = self._config["verify"]
            if cred_config:
                embeddings_config["credentials"] = Credentials(**cred_config)
        # Default embedding params when the caller supplied none.
        if "params" not in embeddings_config:
            embeddings_config["params"] = {
                EmbedParams.TRUNCATE_INPUT_TOKENS: 3,
                EmbedParams.RETURN_OPTIONS: {"input_text": True},
            }
        embedding = watson_models.Embeddings(**embeddings_config)
        try:
            embeddings = embedding.embed_documents(input)
            return cast(Embeddings, embeddings)
        except Exception as e:
            # NOTE(review): uses print rather than the logging module —
            # consider logger.exception; left unchanged here.
            print(f"Error during WatsonX embedding: {e}")
            raise

View File

@@ -0,0 +1,58 @@
"""Type definitions for IBM WatsonX embedding providers."""
from typing import Annotated, Any, Literal
from typing_extensions import Required, TypedDict, deprecated
class WatsonXProviderConfig(TypedDict, total=False):
    """Configuration for WatsonX provider."""

    # -- Embeddings options (forwarded to watson_models.Embeddings) --
    model_id: str
    url: str
    params: dict[str, str | dict[str, str]]
    credentials: Any
    project_id: str
    space_id: str
    api_client: Any
    verify: bool | str
    # Annotated metadata documents the defaults applied by WatsonXProvider.
    persistent_connection: Annotated[bool, True]
    batch_size: Annotated[int, 100]
    concurrency_limit: Annotated[int, 10]
    max_retries: int
    delay_time: float
    retry_status_codes: list[int]
    # -- Credential fields (used to assemble ibm_watsonx_ai.Credentials
    #    when no pre-built "credentials" object is supplied) --
    api_key: str
    name: str
    iam_serviceid_crn: str
    trusted_profile_id: str
    token: str
    projects_token: str
    username: str
    password: str
    instance_id: str
    version: str
    bedrock_url: str
    platform_url: str
    proxies: dict


class WatsonXProviderSpec(TypedDict, total=False):
    """WatsonX provider specification."""

    provider: Required[Literal["watsonx"]]
    config: WatsonXProviderConfig


@deprecated(
    'The "WatsonProviderSpec" provider spec is deprecated and will be removed in v1.0.0. Use "WatsonXProviderSpec" instead.'
)
class WatsonProviderSpec(TypedDict, total=False):
    """Watson provider specification (deprecated).

    Notes:
        - This is deprecated. Use WatsonXProviderSpec with provider="watsonx" instead.
    """

    provider: Required[Literal["watson"]]
    config: WatsonXProviderConfig

View File

@@ -0,0 +1,142 @@
"""IBM WatsonX embeddings provider."""
from typing import Any
from pydantic import Field, model_validator
from typing_extensions import Self
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
from crewai.rag.embeddings.providers.ibm.embedding_callable import (
WatsonXEmbeddingFunction,
)
class WatsonXProvider(BaseEmbeddingsProvider[WatsonXEmbeddingFunction]):
    """IBM WatsonX embeddings provider.

    Note: Requires custom implementation as WatsonX uses a different interface.

    Fields may also be populated from the environment via each field's
    ``validation_alias`` — presumably resolved by pydantic settings machinery
    in the base provider; verify against BaseEmbeddingsProvider.
    """

    # Embedding function class instantiated by the base provider.
    embedding_callable: type[WatsonXEmbeddingFunction] = Field(
        default=WatsonXEmbeddingFunction, description="WatsonX embedding function class"
    )
    model_id: str = Field(
        description="WatsonX model ID", validation_alias="EMBEDDINGS_WATSONX_MODEL_ID"
    )
    params: dict[str, str | dict[str, str]] | None = Field(
        default=None, description="Additional parameters"
    )
    credentials: Any | None = Field(default=None, description="WatsonX credentials")
    # Scoping: at least one of project_id / space_id is required
    # (enforced by validate_space_or_project below).
    project_id: str | None = Field(
        default=None,
        description="WatsonX project ID",
        validation_alias="EMBEDDINGS_WATSONX_PROJECT_ID",
    )
    space_id: str | None = Field(
        default=None,
        description="WatsonX space ID",
        validation_alias="EMBEDDINGS_WATSONX_SPACE_ID",
    )
    api_client: Any | None = Field(default=None, description="WatsonX API client")
    verify: bool | str | None = Field(
        default=None,
        description="SSL verification",
        validation_alias="EMBEDDINGS_WATSONX_VERIFY",
    )
    # Connection and throughput tuning.
    persistent_connection: bool = Field(
        default=True,
        description="Use persistent connection",
        validation_alias="EMBEDDINGS_WATSONX_PERSISTENT_CONNECTION",
    )
    batch_size: int = Field(
        default=100,
        description="Batch size for processing",
        validation_alias="EMBEDDINGS_WATSONX_BATCH_SIZE",
    )
    concurrency_limit: int = Field(
        default=10,
        description="Concurrency limit",
        validation_alias="EMBEDDINGS_WATSONX_CONCURRENCY_LIMIT",
    )
    # Retry behaviour for transient failures.
    max_retries: int | None = Field(
        default=None,
        description="Maximum retries",
        validation_alias="EMBEDDINGS_WATSONX_MAX_RETRIES",
    )
    delay_time: float | None = Field(
        default=None,
        description="Delay time between retries",
        validation_alias="EMBEDDINGS_WATSONX_DELAY_TIME",
    )
    retry_status_codes: list[int] | None = Field(
        default=None, description="HTTP status codes to retry on"
    )
    url: str = Field(
        description="WatsonX API URL", validation_alias="EMBEDDINGS_WATSONX_URL"
    )
    # IBM Cloud authentication options; supply whichever applies.
    api_key: str = Field(
        description="WatsonX API key", validation_alias="EMBEDDINGS_WATSONX_API_KEY"
    )
    name: str | None = Field(
        default=None,
        description="Service name",
        validation_alias="EMBEDDINGS_WATSONX_NAME",
    )
    iam_serviceid_crn: str | None = Field(
        default=None,
        description="IAM service ID CRN",
        validation_alias="EMBEDDINGS_WATSONX_IAM_SERVICEID_CRN",
    )
    trusted_profile_id: str | None = Field(
        default=None,
        description="Trusted profile ID",
        validation_alias="EMBEDDINGS_WATSONX_TRUSTED_PROFILE_ID",
    )
    token: str | None = Field(
        default=None,
        description="Bearer token",
        validation_alias="EMBEDDINGS_WATSONX_TOKEN",
    )
    projects_token: str | None = Field(
        default=None,
        description="Projects token",
        validation_alias="EMBEDDINGS_WATSONX_PROJECTS_TOKEN",
    )
    username: str | None = Field(
        default=None,
        description="Username",
        validation_alias="EMBEDDINGS_WATSONX_USERNAME",
    )
    password: str | None = Field(
        default=None,
        description="Password",
        validation_alias="EMBEDDINGS_WATSONX_PASSWORD",
    )
    instance_id: str | None = Field(
        default=None,
        description="Service instance ID",
        validation_alias="EMBEDDINGS_WATSONX_INSTANCE_ID",
    )
    version: str | None = Field(
        default=None,
        description="API version",
        validation_alias="EMBEDDINGS_WATSONX_VERSION",
    )
    bedrock_url: str | None = Field(
        default=None,
        description="Bedrock URL",
        validation_alias="EMBEDDINGS_WATSONX_BEDROCK_URL",
    )
    platform_url: str | None = Field(
        default=None,
        description="Platform URL",
        validation_alias="EMBEDDINGS_WATSONX_PLATFORM_URL",
    )
    proxies: dict | None = Field(default=None, description="Proxy configuration")

    @model_validator(mode="after")
    def validate_space_or_project(self) -> Self:
        """Validate that either space_id or project_id is provided.

        Returns:
            The validated model instance.

        Raises:
            ValueError: If neither space_id nor project_id is set.
        """
        if not self.space_id and not self.project_id:
            raise ValueError("One of 'space_id' or 'project_id' must be provided")
        return self

View File

@@ -0,0 +1,15 @@
"""Instructor embedding providers."""
from crewai.rag.embeddings.providers.instructor.instructor_provider import (
InstructorProvider,
)
from crewai.rag.embeddings.providers.instructor.types import (
InstructorProviderConfig,
InstructorProviderSpec,
)
__all__ = [
"InstructorProvider",
"InstructorProviderConfig",
"InstructorProviderSpec",
]

View File

@@ -0,0 +1,32 @@
"""Instructor embeddings provider."""
from chromadb.utils.embedding_functions.instructor_embedding_function import (
InstructorEmbeddingFunction,
)
from pydantic import Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
class InstructorProvider(BaseEmbeddingsProvider[InstructorEmbeddingFunction]):
    """Instructor embeddings provider.

    Fields may also be populated from the environment via each field's
    ``validation_alias``.
    """

    # Embedding function class instantiated by the base provider.
    embedding_callable: type[InstructorEmbeddingFunction] = Field(
        default=InstructorEmbeddingFunction,
        description="Instructor embedding function class",
    )
    model_name: str = Field(
        default="hkunlp/instructor-base",
        description="Model name to use",
        validation_alias="EMBEDDINGS_INSTRUCTOR_MODEL_NAME",
    )
    device: str = Field(
        default="cpu",
        description="Device to run model on (cpu or cuda)",
        validation_alias="EMBEDDINGS_INSTRUCTOR_DEVICE",
    )
    instruction: str | None = Field(
        default=None,
        description="Instruction for embeddings",
        validation_alias="EMBEDDINGS_INSTRUCTOR_INSTRUCTION",
    )

View File

@@ -0,0 +1,20 @@
"""Type definitions for Instructor embedding providers."""
from typing import Annotated, Literal
from typing_extensions import Required, TypedDict
class InstructorProviderConfig(TypedDict, total=False):
    """Configuration for Instructor provider.

    ``Annotated`` metadata records the default applied when a key is omitted.
    """

    model_name: Annotated[str, "hkunlp/instructor-base"]
    device: Annotated[str, "cpu"]
    instruction: str
class InstructorProviderSpec(TypedDict, total=False):
    """Instructor provider specification."""

    # Discriminator used to select this provider; must be "instructor".
    provider: Required[Literal["instructor"]]
    config: InstructorProviderConfig

View File

@@ -0,0 +1,13 @@
"""Jina embedding providers."""
from crewai.rag.embeddings.providers.jina.jina_provider import JinaProvider
from crewai.rag.embeddings.providers.jina.types import (
JinaProviderConfig,
JinaProviderSpec,
)
__all__ = [
"JinaProvider",
"JinaProviderConfig",
"JinaProviderSpec",
]

View File

@@ -0,0 +1,24 @@
"""Jina embeddings provider."""
from chromadb.utils.embedding_functions.jina_embedding_function import (
JinaEmbeddingFunction,
)
from pydantic import Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
class JinaProvider(BaseEmbeddingsProvider[JinaEmbeddingFunction]):
    """Jina embeddings provider.

    Fields may also be populated from the environment via each field's
    ``validation_alias``.
    """

    # Embedding function class instantiated by the base provider.
    embedding_callable: type[JinaEmbeddingFunction] = Field(
        default=JinaEmbeddingFunction, description="Jina embedding function class"
    )
    api_key: str = Field(
        description="Jina API key", validation_alias="EMBEDDINGS_JINA_API_KEY"
    )
    model_name: str = Field(
        default="jina-embeddings-v2-base-en",
        description="Model name to use for embeddings",
        validation_alias="EMBEDDINGS_JINA_MODEL_NAME",
    )

View File

@@ -0,0 +1,19 @@
"""Type definitions for Jina embedding providers."""
from typing import Annotated, Literal
from typing_extensions import Required, TypedDict
class JinaProviderConfig(TypedDict, total=False):
    """Configuration for Jina provider.

    ``Annotated`` metadata records the default applied when a key is omitted.
    """

    api_key: str
    model_name: Annotated[str, "jina-embeddings-v2-base-en"]
class JinaProviderSpec(TypedDict, total=False):
    """Jina provider specification."""

    # Discriminator used to select this provider; must be "jina".
    provider: Required[Literal["jina"]]
    config: JinaProviderConfig

View File

@@ -0,0 +1,15 @@
"""Microsoft embedding providers."""
from crewai.rag.embeddings.providers.microsoft.azure import (
AzureProvider,
)
from crewai.rag.embeddings.providers.microsoft.types import (
AzureProviderConfig,
AzureProviderSpec,
)
__all__ = [
"AzureProvider",
"AzureProviderConfig",
"AzureProviderSpec",
]

View File

@@ -0,0 +1,60 @@
"""Azure OpenAI embeddings provider."""
from typing import Any
from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction,
)
from pydantic import Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
class AzureProvider(BaseEmbeddingsProvider[OpenAIEmbeddingFunction]):
    """Azure OpenAI embeddings provider.

    Reuses the OpenAI embedding function with ``api_type="azure"``.

    NOTE(review): the validation aliases intentionally reuse the
    EMBEDDINGS_OPENAI_* names rather than EMBEDDINGS_AZURE_* — confirm this
    is the intended shared configuration with OpenAIProvider.
    """

    # Embedding function class instantiated by the base provider.
    embedding_callable: type[OpenAIEmbeddingFunction] = Field(
        default=OpenAIEmbeddingFunction,
        description="Azure OpenAI embedding function class",
    )
    api_key: str = Field(
        description="Azure API key", validation_alias="EMBEDDINGS_OPENAI_API_KEY"
    )
    api_base: str | None = Field(
        default=None,
        description="Azure endpoint URL",
        validation_alias="EMBEDDINGS_OPENAI_API_BASE",
    )
    api_type: str = Field(
        default="azure",
        description="API type for Azure",
        validation_alias="EMBEDDINGS_OPENAI_API_TYPE",
    )
    api_version: str | None = Field(
        default=None,
        description="Azure API version",
        validation_alias="EMBEDDINGS_OPENAI_API_VERSION",
    )
    model_name: str = Field(
        default="text-embedding-ada-002",
        description="Model name to use for embeddings",
        validation_alias="EMBEDDINGS_OPENAI_MODEL_NAME",
    )
    default_headers: dict[str, Any] | None = Field(
        default=None, description="Default headers for API requests"
    )
    dimensions: int | None = Field(
        default=None,
        description="Embedding dimensions",
        validation_alias="EMBEDDINGS_OPENAI_DIMENSIONS",
    )
    deployment_id: str | None = Field(
        default=None,
        description="Azure deployment ID",
        validation_alias="EMBEDDINGS_OPENAI_DEPLOYMENT_ID",
    )
    organization_id: str | None = Field(
        default=None,
        description="Organization ID",
        validation_alias="EMBEDDINGS_OPENAI_ORGANIZATION_ID",
    )

View File

@@ -0,0 +1,26 @@
"""Type definitions for Microsoft Azure embedding providers."""
from typing import Annotated, Any, Literal
from typing_extensions import Required, TypedDict
class AzureProviderConfig(TypedDict, total=False):
    """Configuration for Azure provider.

    ``Annotated`` metadata records the default applied when a key is omitted.
    """

    api_key: str
    api_base: str
    api_type: Annotated[str, "azure"]
    api_version: str
    model_name: Annotated[str, "text-embedding-ada-002"]
    default_headers: dict[str, Any]
    dimensions: int
    deployment_id: str
    organization_id: str
class AzureProviderSpec(TypedDict, total=False):
    """Azure provider specification."""

    # Discriminator used to select this provider; must be "azure".
    provider: Required[Literal["azure"]]
    config: AzureProviderConfig

View File

@@ -0,0 +1,15 @@
"""Ollama embedding providers."""
from crewai.rag.embeddings.providers.ollama.ollama_provider import (
OllamaProvider,
)
from crewai.rag.embeddings.providers.ollama.types import (
OllamaProviderConfig,
OllamaProviderSpec,
)
__all__ = [
"OllamaProvider",
"OllamaProviderConfig",
"OllamaProviderSpec",
]

View File

@@ -0,0 +1,25 @@
"""Ollama embeddings provider."""
from chromadb.utils.embedding_functions.ollama_embedding_function import (
OllamaEmbeddingFunction,
)
from pydantic import Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
class OllamaProvider(BaseEmbeddingsProvider[OllamaEmbeddingFunction]):
    """Ollama embeddings provider.

    Fields may also be populated from the environment via each field's
    ``validation_alias``.
    """

    # Embedding function class instantiated by the base provider.
    embedding_callable: type[OllamaEmbeddingFunction] = Field(
        default=OllamaEmbeddingFunction, description="Ollama embedding function class"
    )
    url: str = Field(
        default="http://localhost:11434/api/embeddings",
        description="Ollama API endpoint URL",
        validation_alias="EMBEDDINGS_OLLAMA_URL",
    )
    # No default: the model must be chosen explicitly.
    model_name: str = Field(
        description="Model name to use for embeddings",
        validation_alias="EMBEDDINGS_OLLAMA_MODEL_NAME",
    )

View File

@@ -0,0 +1,19 @@
"""Type definitions for Ollama embedding providers."""
from typing import Annotated, Literal
from typing_extensions import Required, TypedDict
class OllamaProviderConfig(TypedDict, total=False):
    """Configuration for Ollama provider.

    ``Annotated`` metadata records the default applied when a key is omitted.
    """

    url: Annotated[str, "http://localhost:11434/api/embeddings"]
    model_name: str
class OllamaProviderSpec(TypedDict, total=False):
    """Ollama provider specification."""

    # Discriminator used to select this provider; must be "ollama".
    provider: Required[Literal["ollama"]]
    config: OllamaProviderConfig

View File

@@ -0,0 +1,13 @@
"""ONNX embedding providers."""
from crewai.rag.embeddings.providers.onnx.onnx_provider import ONNXProvider
from crewai.rag.embeddings.providers.onnx.types import (
ONNXProviderConfig,
ONNXProviderSpec,
)
__all__ = [
"ONNXProvider",
"ONNXProviderConfig",
"ONNXProviderSpec",
]

View File

@@ -0,0 +1,19 @@
"""ONNX embeddings provider."""
from chromadb.utils.embedding_functions.onnx_mini_lm_l6_v2 import ONNXMiniLM_L6_V2
from pydantic import Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
class ONNXProvider(BaseEmbeddingsProvider[ONNXMiniLM_L6_V2]):
    """ONNX embeddings provider.

    Runs the bundled MiniLM-L6-v2 model locally via ONNX Runtime.
    """

    # Embedding function class instantiated by the base provider.
    embedding_callable: type[ONNXMiniLM_L6_V2] = Field(
        default=ONNXMiniLM_L6_V2, description="ONNX MiniLM embedding function class"
    )
    # Ordered list of ONNX Runtime execution providers to prefer
    # (e.g. CUDA vs CPU) — passed through to the embedding function.
    preferred_providers: list[str] | None = Field(
        default=None,
        description="Preferred ONNX execution providers",
        validation_alias="EMBEDDINGS_ONNX_PREFERRED_PROVIDERS",
    )

View File

@@ -0,0 +1,18 @@
"""Type definitions for ONNX embedding providers."""
from typing import Literal
from typing_extensions import Required, TypedDict
class ONNXProviderConfig(TypedDict, total=False):
    """Configuration for ONNX provider."""

    # Ordered list of ONNX Runtime execution providers to prefer.
    preferred_providers: list[str]
class ONNXProviderSpec(TypedDict, total=False):
    """ONNX provider specification."""

    # Discriminator used to select this provider; must be "onnx".
    provider: Required[Literal["onnx"]]
    config: ONNXProviderConfig

View File

@@ -0,0 +1,15 @@
"""OpenAI embedding providers."""
from crewai.rag.embeddings.providers.openai.openai_provider import (
OpenAIProvider,
)
from crewai.rag.embeddings.providers.openai.types import (
OpenAIProviderConfig,
OpenAIProviderSpec,
)
__all__ = [
"OpenAIProvider",
"OpenAIProviderConfig",
"OpenAIProviderSpec",
]

View File

@@ -0,0 +1,62 @@
"""OpenAI embeddings provider."""
from typing import Any
from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction,
)
from pydantic import Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
class OpenAIProvider(BaseEmbeddingsProvider[OpenAIEmbeddingFunction]):
    """OpenAI embeddings provider.

    Fields may also be populated from the environment via each field's
    ``validation_alias``. The api_type/api_version/deployment_id fields
    support Azure-hosted deployments through the same embedding function.
    """

    # Embedding function class instantiated by the base provider.
    embedding_callable: type[OpenAIEmbeddingFunction] = Field(
        default=OpenAIEmbeddingFunction,
        description="OpenAI embedding function class",
    )
    api_key: str | None = Field(
        default=None,
        description="OpenAI API key",
        validation_alias="EMBEDDINGS_OPENAI_API_KEY",
    )
    model_name: str = Field(
        default="text-embedding-ada-002",
        description="Model name to use for embeddings",
        validation_alias="EMBEDDINGS_OPENAI_MODEL_NAME",
    )
    api_base: str | None = Field(
        default=None,
        description="Base URL for API requests",
        validation_alias="EMBEDDINGS_OPENAI_API_BASE",
    )
    api_type: str | None = Field(
        default=None,
        description="API type (e.g., 'azure')",
        validation_alias="EMBEDDINGS_OPENAI_API_TYPE",
    )
    api_version: str | None = Field(
        default=None,
        description="API version",
        validation_alias="EMBEDDINGS_OPENAI_API_VERSION",
    )
    default_headers: dict[str, Any] | None = Field(
        default=None, description="Default headers for API requests"
    )
    dimensions: int | None = Field(
        default=None,
        description="Embedding dimensions",
        validation_alias="EMBEDDINGS_OPENAI_DIMENSIONS",
    )
    deployment_id: str | None = Field(
        default=None,
        description="Azure deployment ID",
        validation_alias="EMBEDDINGS_OPENAI_DEPLOYMENT_ID",
    )
    organization_id: str | None = Field(
        default=None,
        description="OpenAI organization ID",
        validation_alias="EMBEDDINGS_OPENAI_ORGANIZATION_ID",
    )

View File

@@ -0,0 +1,26 @@
"""Type definitions for OpenAI embedding providers."""
from typing import Annotated, Any, Literal
from typing_extensions import Required, TypedDict
class OpenAIProviderConfig(TypedDict, total=False):
    """Configuration for OpenAI provider.

    ``Annotated`` metadata records the default applied when a key is omitted.
    """

    api_key: str
    model_name: Annotated[str, "text-embedding-ada-002"]
    api_base: str
    api_type: str
    api_version: str
    default_headers: dict[str, Any]
    dimensions: int
    deployment_id: str
    organization_id: str
class OpenAIProviderSpec(TypedDict, total=False):
    """OpenAI provider specification."""

    # Discriminator used to select this provider; must be "openai".
    provider: Required[Literal["openai"]]
    config: OpenAIProviderConfig

View File

@@ -0,0 +1,15 @@
"""OpenCLIP embedding providers."""
from crewai.rag.embeddings.providers.openclip.openclip_provider import (
OpenCLIPProvider,
)
from crewai.rag.embeddings.providers.openclip.types import (
OpenCLIPProviderConfig,
OpenCLIPProviderSpec,
)
__all__ = [
"OpenCLIPProvider",
"OpenCLIPProviderConfig",
"OpenCLIPProviderSpec",
]

View File

@@ -0,0 +1,32 @@
"""OpenCLIP embeddings provider."""
from chromadb.utils.embedding_functions.open_clip_embedding_function import (
OpenCLIPEmbeddingFunction,
)
from pydantic import Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
class OpenCLIPProvider(BaseEmbeddingsProvider[OpenCLIPEmbeddingFunction]):
    """OpenCLIP embeddings provider.

    Fields may also be populated from the environment via each field's
    ``validation_alias``.
    """

    # Embedding function class instantiated by the base provider.
    embedding_callable: type[OpenCLIPEmbeddingFunction] = Field(
        default=OpenCLIPEmbeddingFunction,
        description="OpenCLIP embedding function class",
    )
    model_name: str = Field(
        default="ViT-B-32",
        description="Model name to use",
        validation_alias="EMBEDDINGS_OPENCLIP_MODEL_NAME",
    )
    checkpoint: str = Field(
        default="laion2b_s34b_b79k",
        description="Model checkpoint",
        validation_alias="EMBEDDINGS_OPENCLIP_CHECKPOINT",
    )
    # NOTE(review): typed Optional but defaults to "cpu" — confirm whether
    # an explicit None is meaningful to the embedding function.
    device: str | None = Field(
        default="cpu",
        description="Device to run model on",
        validation_alias="EMBEDDINGS_OPENCLIP_DEVICE",
    )

View File

@@ -0,0 +1,20 @@
"""Type definitions for OpenCLIP embedding providers."""
from typing import Annotated, Literal
from typing_extensions import Required, TypedDict
class OpenCLIPProviderConfig(TypedDict, total=False):
    """Configuration for OpenCLIP provider.

    ``Annotated`` metadata records the default applied when a key is omitted.
    """

    model_name: Annotated[str, "ViT-B-32"]
    checkpoint: Annotated[str, "laion2b_s34b_b79k"]
    device: Annotated[str, "cpu"]
class OpenCLIPProviderSpec(TypedDict, total=False):
    """OpenCLIP provider specification.

    ``total=False`` keeps ``config`` optional, matching the other provider
    specs in this package (e.g. OpenAIProviderSpec); ``provider`` remains
    required via ``Required``.
    """

    # Discriminator used to select this provider; must be "openclip".
    provider: Required[Literal["openclip"]]
    config: OpenCLIPProviderConfig

View File

@@ -0,0 +1,15 @@
"""Roboflow embedding providers."""
from crewai.rag.embeddings.providers.roboflow.roboflow_provider import (
RoboflowProvider,
)
from crewai.rag.embeddings.providers.roboflow.types import (
RoboflowProviderConfig,
RoboflowProviderSpec,
)
__all__ = [
"RoboflowProvider",
"RoboflowProviderConfig",
"RoboflowProviderSpec",
]

View File

@@ -0,0 +1,27 @@
"""Roboflow embeddings provider."""
from chromadb.utils.embedding_functions.roboflow_embedding_function import (
RoboflowEmbeddingFunction,
)
from pydantic import Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
class RoboflowProvider(BaseEmbeddingsProvider[RoboflowEmbeddingFunction]):
    """Roboflow embeddings provider.

    Fields may also be populated from the environment via each field's
    ``validation_alias``.
    """

    # Embedding function class instantiated by the base provider.
    embedding_callable: type[RoboflowEmbeddingFunction] = Field(
        default=RoboflowEmbeddingFunction,
        description="Roboflow embedding function class",
    )
    # Default is an empty string rather than None — the embedding function
    # receives it verbatim.
    api_key: str = Field(
        default="",
        description="Roboflow API key",
        validation_alias="EMBEDDINGS_ROBOFLOW_API_KEY",
    )
    api_url: str = Field(
        default="https://infer.roboflow.com",
        description="Roboflow API URL",
        validation_alias="EMBEDDINGS_ROBOFLOW_API_URL",
    )

View File

@@ -0,0 +1,19 @@
"""Type definitions for Roboflow embedding providers."""
from typing import Annotated, Literal
from typing_extensions import Required, TypedDict
class RoboflowProviderConfig(TypedDict, total=False):
    """Configuration for Roboflow provider.

    ``Annotated`` metadata records the default applied when a key is omitted.
    """

    api_key: Annotated[str, ""]
    api_url: Annotated[str, "https://infer.roboflow.com"]
class RoboflowProviderSpec(TypedDict, total=False):
    """Roboflow provider specification.

    ``total=False`` keeps ``config`` optional, matching the other provider
    specs in this package; ``provider`` remains required via ``Required``.
    """

    # Discriminator used to select this provider; must be "roboflow".
    provider: Required[Literal["roboflow"]]
    config: RoboflowProviderConfig

View File

@@ -0,0 +1,15 @@
"""SentenceTransformer embedding providers."""
from crewai.rag.embeddings.providers.sentence_transformer.sentence_transformer_provider import (
SentenceTransformerProvider,
)
from crewai.rag.embeddings.providers.sentence_transformer.types import (
SentenceTransformerProviderConfig,
SentenceTransformerProviderSpec,
)
__all__ = [
"SentenceTransformerProvider",
"SentenceTransformerProviderConfig",
"SentenceTransformerProviderSpec",
]

View File

@@ -0,0 +1,34 @@
"""SentenceTransformer embeddings provider."""
from chromadb.utils.embedding_functions.sentence_transformer_embedding_function import (
SentenceTransformerEmbeddingFunction,
)
from pydantic import Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
class SentenceTransformerProvider(
    BaseEmbeddingsProvider[SentenceTransformerEmbeddingFunction]
):
    """SentenceTransformer embeddings provider.

    Fields may also be populated from the environment via each field's
    ``validation_alias``.
    """

    # Embedding function class instantiated by the base provider.
    embedding_callable: type[SentenceTransformerEmbeddingFunction] = Field(
        default=SentenceTransformerEmbeddingFunction,
        description="SentenceTransformer embedding function class",
    )
    model_name: str = Field(
        default="all-MiniLM-L6-v2",
        description="Model name to use",
        validation_alias="EMBEDDINGS_SENTENCE_TRANSFORMER_MODEL_NAME",
    )
    device: str = Field(
        default="cpu",
        description="Device to run model on (cpu or cuda)",
        validation_alias="EMBEDDINGS_SENTENCE_TRANSFORMER_DEVICE",
    )
    normalize_embeddings: bool = Field(
        default=False,
        description="Whether to normalize embeddings",
        validation_alias="EMBEDDINGS_SENTENCE_TRANSFORMER_NORMALIZE_EMBEDDINGS",
    )

View File

@@ -0,0 +1,20 @@
"""Type definitions for SentenceTransformer embedding providers."""
from typing import Annotated, Literal
from typing_extensions import Required, TypedDict
class SentenceTransformerProviderConfig(TypedDict, total=False):
    """Configuration for SentenceTransformer provider.

    ``Annotated`` metadata records the default applied when a key is omitted.
    """

    model_name: Annotated[str, "all-MiniLM-L6-v2"]
    device: Annotated[str, "cpu"]
    normalize_embeddings: Annotated[bool, False]
class SentenceTransformerProviderSpec(TypedDict, total=False):
    """SentenceTransformer provider specification.

    ``total=False`` keeps ``config`` optional, matching the other provider
    specs in this package; ``provider`` remains required via ``Required``.
    """

    # Discriminator used to select this provider; must be "sentence-transformer".
    provider: Required[Literal["sentence-transformer"]]
    config: SentenceTransformerProviderConfig

View File

@@ -0,0 +1,15 @@
"""Text2Vec embedding providers."""
from crewai.rag.embeddings.providers.text2vec.text2vec_provider import (
Text2VecProvider,
)
from crewai.rag.embeddings.providers.text2vec.types import (
Text2VecProviderConfig,
Text2VecProviderSpec,
)
__all__ = [
"Text2VecProvider",
"Text2VecProviderConfig",
"Text2VecProviderSpec",
]

View File

@@ -0,0 +1,22 @@
"""Text2Vec embeddings provider."""
from chromadb.utils.embedding_functions.text2vec_embedding_function import (
Text2VecEmbeddingFunction,
)
from pydantic import Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
class Text2VecProvider(BaseEmbeddingsProvider[Text2VecEmbeddingFunction]):
    """Text2Vec embeddings provider.

    Fields may also be populated from the environment via each field's
    ``validation_alias``.
    """

    # Embedding function class instantiated by the base provider.
    embedding_callable: type[Text2VecEmbeddingFunction] = Field(
        default=Text2VecEmbeddingFunction,
        description="Text2Vec embedding function class",
    )
    model_name: str = Field(
        default="shibing624/text2vec-base-chinese",
        description="Model name to use",
        validation_alias="EMBEDDINGS_TEXT2VEC_MODEL_NAME",
    )

View File

@@ -0,0 +1,18 @@
"""Type definitions for Text2Vec embedding providers."""
from typing import Annotated, Literal
from typing_extensions import Required, TypedDict
class Text2VecProviderConfig(TypedDict, total=False):
    """Configuration for Text2Vec provider.

    ``Annotated`` metadata records the default applied when the key is omitted.
    """

    model_name: Annotated[str, "shibing624/text2vec-base-chinese"]
class Text2VecProviderSpec(TypedDict, total=False):
    """Text2Vec provider specification.

    ``total=False`` keeps ``config`` optional, matching the other provider
    specs in this package; ``provider`` remains required via ``Required``.
    """

    # Discriminator used to select this provider; must be "text2vec".
    provider: Required[Literal["text2vec"]]
    config: Text2VecProviderConfig

View File

@@ -0,0 +1,15 @@
"""VoyageAI embedding providers."""
from crewai.rag.embeddings.providers.voyageai.types import (
VoyageAIProviderConfig,
VoyageAIProviderSpec,
)
from crewai.rag.embeddings.providers.voyageai.voyageai_provider import (
VoyageAIProvider,
)
__all__ = [
"VoyageAIProvider",
"VoyageAIProviderConfig",
"VoyageAIProviderSpec",
]

View File

@@ -0,0 +1,62 @@
"""VoyageAI embedding function implementation."""
from typing import cast
from chromadb.api.types import Documents, EmbeddingFunction, Embeddings
from typing_extensions import Unpack
from crewai.rag.embeddings.providers.voyageai.types import VoyageAIProviderConfig
class VoyageAIEmbeddingFunction(EmbeddingFunction[Documents]):
    """Embedding function for VoyageAI models."""

    def __init__(self, **kwargs: Unpack[VoyageAIProviderConfig]) -> None:
        """Initialize VoyageAI embedding function.

        Args:
            **kwargs: Configuration parameters for VoyageAI.

        Raises:
            ImportError: If the optional ``voyageai`` package is not installed.
            ValueError: If no API key is provided.
        """
        try:
            import voyageai  # type: ignore[import-not-found]
        except ImportError as e:
            raise ImportError(
                "voyageai is required for voyageai embeddings. "
                "Install it with: uv add voyageai"
            ) from e

        # Fail fast with a clear message; VoyageAIProviderConfig is
        # total=False, so a bare kwargs["api_key"] would raise an opaque
        # KeyError when the key is omitted.
        api_key = kwargs.get("api_key")
        if not api_key:
            raise ValueError("An 'api_key' is required for VoyageAI embeddings.")

        self._config = kwargs
        self._client = voyageai.Client(
            api_key=api_key,
            max_retries=kwargs.get("max_retries", 0),
            timeout=kwargs.get("timeout"),
        )

    @staticmethod
    def name() -> str:
        """Return the name of the embedding function for ChromaDB compatibility."""
        return "voyageai"

    def __call__(self, input: Documents) -> Embeddings:
        """Generate embeddings for input documents.

        Args:
            input: A single document or list of documents to embed.

        Returns:
            List of embedding vectors.
        """
        # Normalize a single string to a one-element batch.
        if isinstance(input, str):
            input = [input]
        result = self._client.embed(
            texts=input,
            model=self._config.get("model", "voyage-2"),
            input_type=self._config.get("input_type"),
            truncation=self._config.get("truncation", True),
            output_dtype=self._config.get("output_dtype"),
            output_dimension=self._config.get("output_dimension"),
        )
        return cast(Embeddings, result.embeddings)

View File

@@ -0,0 +1,25 @@
"""Type definitions for VoyageAI embedding providers."""
from typing import Annotated, Literal
from typing_extensions import Required, TypedDict
class VoyageAIProviderConfig(TypedDict, total=False):
    """Configuration for VoyageAI provider.

    ``Annotated`` metadata records the default applied when a key is omitted.
    """

    api_key: str
    model: Annotated[str, "voyage-2"]
    input_type: str
    truncation: Annotated[bool, True]
    output_dtype: str
    output_dimension: int
    max_retries: Annotated[int, 0]
    timeout: float
class VoyageAIProviderSpec(TypedDict, total=False):
    """VoyageAI provider specification.

    ``total=False`` keeps ``config`` optional, matching the other provider
    specs in this package; ``provider`` remains required via ``Required``.
    """

    # Discriminator used to select this provider; must be "voyageai".
    provider: Required[Literal["voyageai"]]
    config: VoyageAIProviderConfig

View File

@@ -0,0 +1,55 @@
"""Voyage AI embeddings provider."""
from pydantic import Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
from crewai.rag.embeddings.providers.voyageai.embedding_callable import (
VoyageAIEmbeddingFunction,
)
class VoyageAIProvider(BaseEmbeddingsProvider[VoyageAIEmbeddingFunction]):
    """Voyage AI embeddings provider.

    Fields may also be populated from the environment via each field's
    ``validation_alias``.
    """

    # Embedding function class instantiated by the base provider.
    embedding_callable: type[VoyageAIEmbeddingFunction] = Field(
        default=VoyageAIEmbeddingFunction,
        description="Voyage AI embedding function class",
    )
    model: str = Field(
        default="voyage-2",
        description="Model to use for embeddings",
        validation_alias="EMBEDDINGS_VOYAGEAI_MODEL",
    )
    api_key: str = Field(
        description="Voyage AI API key", validation_alias="EMBEDDINGS_VOYAGEAI_API_KEY"
    )
    input_type: str | None = Field(
        default=None,
        description="Input type for embeddings",
        validation_alias="EMBEDDINGS_VOYAGEAI_INPUT_TYPE",
    )
    truncation: bool = Field(
        default=True,
        description="Whether to truncate inputs",
        validation_alias="EMBEDDINGS_VOYAGEAI_TRUNCATION",
    )
    output_dtype: str | None = Field(
        default=None,
        description="Output data type",
        validation_alias="EMBEDDINGS_VOYAGEAI_OUTPUT_DTYPE",
    )
    output_dimension: int | None = Field(
        default=None,
        description="Output dimension",
        validation_alias="EMBEDDINGS_VOYAGEAI_OUTPUT_DIMENSION",
    )
    max_retries: int = Field(
        default=0,
        description="Maximum retries for API calls",
        validation_alias="EMBEDDINGS_VOYAGEAI_MAX_RETRIES",
    )
    timeout: float | None = Field(
        default=None,
        description="Timeout for API calls",
        validation_alias="EMBEDDINGS_VOYAGEAI_TIMEOUT",
    )

View File

@@ -0,0 +1,78 @@
"""Type definitions for the embeddings module."""
from typing import Literal, TypeAlias
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
from crewai.rag.embeddings.providers.aws.types import BedrockProviderSpec
from crewai.rag.embeddings.providers.cohere.types import CohereProviderSpec
from crewai.rag.embeddings.providers.custom.types import CustomProviderSpec
from crewai.rag.embeddings.providers.google.types import (
GenerativeAiProviderSpec,
VertexAIProviderSpec,
)
from crewai.rag.embeddings.providers.huggingface.types import HuggingFaceProviderSpec
from crewai.rag.embeddings.providers.ibm.types import (
WatsonProviderSpec,
WatsonXProviderSpec,
)
from crewai.rag.embeddings.providers.instructor.types import InstructorProviderSpec
from crewai.rag.embeddings.providers.jina.types import JinaProviderSpec
from crewai.rag.embeddings.providers.microsoft.types import AzureProviderSpec
from crewai.rag.embeddings.providers.ollama.types import OllamaProviderSpec
from crewai.rag.embeddings.providers.onnx.types import ONNXProviderSpec
from crewai.rag.embeddings.providers.openai.types import OpenAIProviderSpec
from crewai.rag.embeddings.providers.openclip.types import OpenCLIPProviderSpec
from crewai.rag.embeddings.providers.roboflow.types import RoboflowProviderSpec
from crewai.rag.embeddings.providers.sentence_transformer.types import (
SentenceTransformerProviderSpec,
)
from crewai.rag.embeddings.providers.text2vec.types import Text2VecProviderSpec
from crewai.rag.embeddings.providers.voyageai.types import VoyageAIProviderSpec
# Union of every supported provider spec (TypedDicts discriminated by the
# "provider" key).
ProviderSpec = (
    AzureProviderSpec
    | BedrockProviderSpec
    | CohereProviderSpec
    | CustomProviderSpec
    | GenerativeAiProviderSpec
    | HuggingFaceProviderSpec
    | InstructorProviderSpec
    | JinaProviderSpec
    | OllamaProviderSpec
    | ONNXProviderSpec
    | OpenAIProviderSpec
    | OpenCLIPProviderSpec
    | RoboflowProviderSpec
    | SentenceTransformerProviderSpec
    | Text2VecProviderSpec
    | VertexAIProviderSpec
    | VoyageAIProviderSpec
    | WatsonProviderSpec  # Deprecated, use WatsonXProviderSpec
    | WatsonXProviderSpec
)

# String discriminators accepted for the "provider" key above.
AllowedEmbeddingProviders = Literal[
    "azure",
    "amazon-bedrock",
    "cohere",
    "custom",
    "google-generativeai",
    "google-vertex",
    "huggingface",
    "instructor",
    "jina",
    "ollama",
    "onnx",
    "openai",
    "openclip",
    "roboflow",
    "sentence-transformer",
    "text2vec",
    "voyageai",
    "watsonx",
    "watson",  # for backward compatibility until v1.0.0
]

# Anything accepted as an embedder configuration: a declarative spec dict,
# a configured provider instance, or a provider class.
EmbedderConfig: TypeAlias = (
    ProviderSpec | BaseEmbeddingsProvider | type[BaseEmbeddingsProvider]
)

View File

@@ -0,0 +1,47 @@
"""Factory functions for creating RAG clients from configuration."""
from typing import cast
from crewai.rag.config.optional_imports.protocols import (
ChromaFactoryModule,
QdrantFactoryModule,
)
from crewai.rag.config.types import RagConfigType
from crewai.rag.core.base_client import BaseClient
from crewai.core.utilities.import_utils import require
def create_client(config: RagConfigType) -> BaseClient:
    """Create a client from configuration using the appropriate factory.

    The provider-specific factory module is imported lazily via ``require``
    so optional backends are only loaded when selected.

    Args:
        config: The RAG client configuration.

    Returns:
        The created client instance.

    Raises:
        ValueError: If the configuration provider is not supported.
    """
    if config.provider == "chromadb":
        factory = cast(
            ChromaFactoryModule,
            require(
                "crewai.rag.chromadb.factory",
                purpose="The 'chromadb' provider",
            ),
        )
    elif config.provider == "qdrant":
        factory = cast(
            QdrantFactoryModule,
            require(
                "crewai.rag.qdrant.factory",
                purpose="The 'qdrant' provider",
            ),
        )
    else:
        raise ValueError(f"Unsupported provider: {config.provider}")
    return factory.create_client(config)

View File

@@ -0,0 +1 @@
"""Qdrant vector database client implementation."""

View File

@@ -0,0 +1,518 @@
"""Qdrant client implementation."""
from typing import Any, cast
from typing_extensions import Unpack
from crewai.rag.core.base_client import (
BaseClient,
BaseCollectionAddParams,
BaseCollectionParams,
BaseCollectionSearchParams,
)
from crewai.rag.core.exceptions import ClientMethodMismatchError
from crewai.rag.qdrant.types import (
AsyncEmbeddingFunction,
EmbeddingFunction,
QdrantClientType,
QdrantCollectionCreateParams,
)
from crewai.rag.qdrant.utils import (
_create_point_from_document,
_get_collection_params,
_is_async_client,
_is_async_embedding_function,
_is_sync_client,
_prepare_search_params,
_process_search_results,
)
from crewai.rag.types import SearchResult
class QdrantClient(BaseClient):
    """Qdrant implementation of the BaseClient protocol.

    Provides vector database operations for Qdrant, supporting both
    synchronous and asynchronous clients.

    Attributes:
        client: Qdrant client instance (QdrantClient or AsyncQdrantClient).
        embedding_function: Function to generate embeddings for documents.
        default_limit: Default number of results to return in searches.
        default_score_threshold: Default minimum score for search results.
        default_batch_size: Default batch size for adding documents.
    """

    def __init__(
        self,
        client: QdrantClientType,
        embedding_function: EmbeddingFunction | AsyncEmbeddingFunction,
        default_limit: int = 5,
        default_score_threshold: float = 0.6,
        default_batch_size: int = 100,
    ) -> None:
        """Initialize QdrantClient with client and embedding function.

        Args:
            client: Pre-configured Qdrant client instance.
            embedding_function: Embedding function for text to vector conversion.
            default_limit: Default number of results to return in searches.
            default_score_threshold: Default minimum score for search results.
            default_batch_size: Default batch size for adding documents.
        """
        self.client = client
        self.embedding_function = embedding_function
        self.default_limit = default_limit
        self.default_score_threshold = default_score_threshold
        self.default_batch_size = default_batch_size

    def create_collection(self, **kwargs: Unpack[QdrantCollectionCreateParams]) -> None:
        """Create a new collection in Qdrant.

        Keyword Args:
            collection_name: Name of the collection to create. Must be unique.
            vectors_config: Optional vector configuration. Defaults to
                DEFAULT_VECTOR_PARAMS (384 dimensions, cosine distance).
            sparse_vectors_config: Optional sparse vector configuration.
            shard_number: Optional number of shards.
            replication_factor: Optional replication factor.
            write_consistency_factor: Optional write consistency factor.
            on_disk_payload: Optional flag to store payload on disk.
            hnsw_config: Optional HNSW index configuration.
            optimizers_config: Optional optimizer configuration.
            wal_config: Optional write-ahead log configuration.
            quantization_config: Optional quantization configuration.
            init_from: Optional collection to initialize from.
            timeout: Optional timeout for the operation.

        Raises:
            ValueError: If collection with the same name already exists.
            ConnectionError: If unable to connect to Qdrant server.
        """
        if not _is_sync_client(self.client):
            raise ClientMethodMismatchError(
                method_name="create_collection",
                expected_client="QdrantClient",
                alt_method="acreate_collection",
                alt_client="AsyncQdrantClient",
            )
        collection_name = kwargs["collection_name"]
        if self.client.collection_exists(collection_name):
            raise ValueError(f"Collection '{collection_name}' already exists")
        params = _get_collection_params(kwargs)
        self.client.create_collection(**params)

    async def acreate_collection(
        self, **kwargs: Unpack[QdrantCollectionCreateParams]
    ) -> None:
        """Create a new collection in Qdrant asynchronously.

        Keyword Args:
            collection_name: Name of the collection to create. Must be unique.
            vectors_config: Optional vector configuration. Defaults to
                DEFAULT_VECTOR_PARAMS (384 dimensions, cosine distance).
            sparse_vectors_config: Optional sparse vector configuration.
            shard_number: Optional number of shards.
            replication_factor: Optional replication factor.
            write_consistency_factor: Optional write consistency factor.
            on_disk_payload: Optional flag to store payload on disk.
            hnsw_config: Optional HNSW index configuration.
            optimizers_config: Optional optimizer configuration.
            wal_config: Optional write-ahead log configuration.
            quantization_config: Optional quantization configuration.
            init_from: Optional collection to initialize from.
            timeout: Optional timeout for the operation.

        Raises:
            ValueError: If collection with the same name already exists.
            ConnectionError: If unable to connect to Qdrant server.
        """
        if not _is_async_client(self.client):
            raise ClientMethodMismatchError(
                method_name="acreate_collection",
                expected_client="AsyncQdrantClient",
                alt_method="create_collection",
                alt_client="QdrantClient",
            )
        collection_name = kwargs["collection_name"]
        if await self.client.collection_exists(collection_name):
            raise ValueError(f"Collection '{collection_name}' already exists")
        params = _get_collection_params(kwargs)
        await self.client.create_collection(**params)

    def get_or_create_collection(
        self, **kwargs: Unpack[QdrantCollectionCreateParams]
    ) -> Any:
        """Get an existing collection or create it if it doesn't exist.

        Keyword Args:
            collection_name: Name of the collection to get or create.
            vectors_config: Optional vector configuration. Defaults to
                DEFAULT_VECTOR_PARAMS (384 dimensions, cosine distance).
            sparse_vectors_config: Optional sparse vector configuration.
            shard_number: Optional number of shards.
            replication_factor: Optional replication factor.
            write_consistency_factor: Optional write consistency factor.
            on_disk_payload: Optional flag to store payload on disk.
            hnsw_config: Optional HNSW index configuration.
            optimizers_config: Optional optimizer configuration.
            wal_config: Optional write-ahead log configuration.
            quantization_config: Optional quantization configuration.
            init_from: Optional collection to initialize from.
            timeout: Optional timeout for the operation.

        Returns:
            Collection info as returned by the underlying client's
            get_collection call.

        Raises:
            ConnectionError: If unable to connect to Qdrant server.
        """
        if not _is_sync_client(self.client):
            raise ClientMethodMismatchError(
                method_name="get_or_create_collection",
                expected_client="QdrantClient",
                alt_method="aget_or_create_collection",
                alt_client="AsyncQdrantClient",
            )
        collection_name = kwargs["collection_name"]
        if self.client.collection_exists(collection_name):
            return self.client.get_collection(collection_name)
        params = _get_collection_params(kwargs)
        self.client.create_collection(**params)
        return self.client.get_collection(collection_name)

    async def aget_or_create_collection(
        self, **kwargs: Unpack[QdrantCollectionCreateParams]
    ) -> Any:
        """Get an existing collection or create it if it doesn't exist asynchronously.

        Keyword Args:
            collection_name: Name of the collection to get or create.
            vectors_config: Optional vector configuration. Defaults to
                DEFAULT_VECTOR_PARAMS (384 dimensions, cosine distance).
            sparse_vectors_config: Optional sparse vector configuration.
            shard_number: Optional number of shards.
            replication_factor: Optional replication factor.
            write_consistency_factor: Optional write consistency factor.
            on_disk_payload: Optional flag to store payload on disk.
            hnsw_config: Optional HNSW index configuration.
            optimizers_config: Optional optimizer configuration.
            wal_config: Optional write-ahead log configuration.
            quantization_config: Optional quantization configuration.
            init_from: Optional collection to initialize from.
            timeout: Optional timeout for the operation.

        Returns:
            Collection info as returned by the underlying client's
            get_collection call.

        Raises:
            ConnectionError: If unable to connect to Qdrant server.
        """
        if not _is_async_client(self.client):
            raise ClientMethodMismatchError(
                method_name="aget_or_create_collection",
                expected_client="AsyncQdrantClient",
                alt_method="get_or_create_collection",
                alt_client="QdrantClient",
            )
        collection_name = kwargs["collection_name"]
        if await self.client.collection_exists(collection_name):
            return await self.client.get_collection(collection_name)
        params = _get_collection_params(kwargs)
        await self.client.create_collection(**params)
        return await self.client.get_collection(collection_name)

    def add_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
        """Add documents with their embeddings to a collection.

        Keyword Args:
            collection_name: The name of the collection to add documents to.
            documents: List of BaseRecord dicts containing document data.
            batch_size: Optional batch size for processing documents
                (defaults to the client's default_batch_size).

        Raises:
            TypeError: If the client was configured with an async embedding function.
            ValueError: If collection doesn't exist or documents list is empty.
            ConnectionError: If unable to connect to Qdrant server.
        """
        if not _is_sync_client(self.client):
            raise ClientMethodMismatchError(
                method_name="add_documents",
                expected_client="QdrantClient",
                alt_method="aadd_documents",
                alt_client="AsyncQdrantClient",
            )
        collection_name = kwargs["collection_name"]
        documents = kwargs["documents"]
        batch_size = kwargs.get("batch_size", self.default_batch_size)
        if not documents:
            raise ValueError("Documents list cannot be empty")
        if not self.client.collection_exists(collection_name):
            raise ValueError(f"Collection '{collection_name}' does not exist")
        # The embedding-function flavor is invariant across documents, so check
        # once up front instead of inside the per-document loop.
        if _is_async_embedding_function(self.embedding_function):
            raise TypeError(
                "Async embedding function cannot be used with sync add_documents. "
                "Use aadd_documents instead."
            )
        sync_fn = cast(EmbeddingFunction, self.embedding_function)
        for start in range(0, len(documents), batch_size):
            # Slicing clamps at the sequence end, so no explicit min() is needed.
            batch = documents[start : start + batch_size]
            points = [
                _create_point_from_document(doc, sync_fn(doc["content"]))
                for doc in batch
            ]
            self.client.upsert(collection_name=collection_name, points=points)

    async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
        """Add documents with their embeddings to a collection asynchronously.

        Keyword Args:
            collection_name: The name of the collection to add documents to.
            documents: List of BaseRecord dicts containing document data.
            batch_size: Optional batch size for processing documents
                (defaults to the client's default_batch_size).

        Raises:
            ValueError: If collection doesn't exist or documents list is empty.
            ConnectionError: If unable to connect to Qdrant server.
        """
        if not _is_async_client(self.client):
            raise ClientMethodMismatchError(
                method_name="aadd_documents",
                expected_client="AsyncQdrantClient",
                alt_method="add_documents",
                alt_client="QdrantClient",
            )
        collection_name = kwargs["collection_name"]
        documents = kwargs["documents"]
        batch_size = kwargs.get("batch_size", self.default_batch_size)
        if not documents:
            raise ValueError("Documents list cannot be empty")
        if not await self.client.collection_exists(collection_name):
            raise ValueError(f"Collection '{collection_name}' does not exist")
        # Resolve the embedding-function flavor once; it cannot change between
        # documents, so checking per document is wasted work.
        if _is_async_embedding_function(self.embedding_function):
            async_fn = cast(AsyncEmbeddingFunction, self.embedding_function)
            sync_fn = None
        else:
            async_fn = None
            sync_fn = cast(EmbeddingFunction, self.embedding_function)
        for start in range(0, len(documents), batch_size):
            batch = documents[start : start + batch_size]
            points = []
            for doc in batch:
                if async_fn is not None:
                    embedding = await async_fn(doc["content"])
                else:
                    embedding = sync_fn(doc["content"])
                points.append(_create_point_from_document(doc, embedding))
            await self.client.upsert(collection_name=collection_name, points=points)

    def search(
        self, **kwargs: Unpack[BaseCollectionSearchParams]
    ) -> list[SearchResult]:
        """Search for similar documents using a query.

        Keyword Args:
            collection_name: Name of the collection to search in.
            query: The text query to search for.
            limit: Maximum number of results to return (defaults to the
                client's default_limit).
            metadata_filter: Optional filter for metadata fields.
            score_threshold: Optional minimum similarity score (0-1) for results
                (defaults to the client's default_score_threshold).

        Returns:
            List of SearchResult dicts containing id, content, metadata, and score.

        Raises:
            TypeError: If the client was configured with an async embedding function.
            ValueError: If collection doesn't exist.
            ConnectionError: If unable to connect to Qdrant server.
        """
        if not _is_sync_client(self.client):
            raise ClientMethodMismatchError(
                method_name="search",
                expected_client="QdrantClient",
                alt_method="asearch",
                alt_client="AsyncQdrantClient",
            )
        collection_name = kwargs["collection_name"]
        query = kwargs["query"]
        limit = kwargs.get("limit", self.default_limit)
        metadata_filter = kwargs.get("metadata_filter")
        score_threshold = kwargs.get("score_threshold", self.default_score_threshold)
        if not self.client.collection_exists(collection_name):
            raise ValueError(f"Collection '{collection_name}' does not exist")
        if _is_async_embedding_function(self.embedding_function):
            raise TypeError(
                "Async embedding function cannot be used with sync search. "
                "Use asearch instead."
            )
        sync_fn = cast(EmbeddingFunction, self.embedding_function)
        query_embedding = sync_fn(query)
        search_kwargs = _prepare_search_params(
            collection_name=collection_name,
            query_embedding=query_embedding,
            limit=limit,
            score_threshold=score_threshold,
            metadata_filter=metadata_filter,
        )
        response = self.client.query_points(**search_kwargs)
        return _process_search_results(response)

    async def asearch(
        self, **kwargs: Unpack[BaseCollectionSearchParams]
    ) -> list[SearchResult]:
        """Search for similar documents using a query asynchronously.

        Keyword Args:
            collection_name: Name of the collection to search in.
            query: The text query to search for.
            limit: Maximum number of results to return (defaults to the
                client's default_limit).
            metadata_filter: Optional filter for metadata fields.
            score_threshold: Optional minimum similarity score (0-1) for results
                (defaults to the client's default_score_threshold).

        Returns:
            List of SearchResult dicts containing id, content, metadata, and score.

        Raises:
            ValueError: If collection doesn't exist.
            ConnectionError: If unable to connect to Qdrant server.
        """
        if not _is_async_client(self.client):
            raise ClientMethodMismatchError(
                method_name="asearch",
                expected_client="AsyncQdrantClient",
                alt_method="search",
                alt_client="QdrantClient",
            )
        collection_name = kwargs["collection_name"]
        query = kwargs["query"]
        limit = kwargs.get("limit", self.default_limit)
        metadata_filter = kwargs.get("metadata_filter")
        score_threshold = kwargs.get("score_threshold", self.default_score_threshold)
        if not await self.client.collection_exists(collection_name):
            raise ValueError(f"Collection '{collection_name}' does not exist")
        if _is_async_embedding_function(self.embedding_function):
            async_fn = cast(AsyncEmbeddingFunction, self.embedding_function)
            query_embedding = await async_fn(query)
        else:
            sync_fn = cast(EmbeddingFunction, self.embedding_function)
            query_embedding = sync_fn(query)
        search_kwargs = _prepare_search_params(
            collection_name=collection_name,
            query_embedding=query_embedding,
            limit=limit,
            score_threshold=score_threshold,
            metadata_filter=metadata_filter,
        )
        response = await self.client.query_points(**search_kwargs)
        return _process_search_results(response)

    def delete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
        """Delete a collection and all its data.

        Keyword Args:
            collection_name: Name of the collection to delete.

        Raises:
            ValueError: If collection doesn't exist.
            ConnectionError: If unable to connect to Qdrant server.
        """
        if not _is_sync_client(self.client):
            raise ClientMethodMismatchError(
                method_name="delete_collection",
                expected_client="QdrantClient",
                alt_method="adelete_collection",
                alt_client="AsyncQdrantClient",
            )
        collection_name = kwargs["collection_name"]
        if not self.client.collection_exists(collection_name):
            raise ValueError(f"Collection '{collection_name}' does not exist")
        self.client.delete_collection(collection_name=collection_name)

    async def adelete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
        """Delete a collection and all its data asynchronously.

        Keyword Args:
            collection_name: Name of the collection to delete.

        Raises:
            ValueError: If collection doesn't exist.
            ConnectionError: If unable to connect to Qdrant server.
        """
        if not _is_async_client(self.client):
            raise ClientMethodMismatchError(
                method_name="adelete_collection",
                expected_client="AsyncQdrantClient",
                alt_method="delete_collection",
                alt_client="QdrantClient",
            )
        collection_name = kwargs["collection_name"]
        if not await self.client.collection_exists(collection_name):
            raise ValueError(f"Collection '{collection_name}' does not exist")
        await self.client.delete_collection(collection_name=collection_name)

    def reset(self) -> None:
        """Reset the vector database by deleting all collections and data.

        Raises:
            ConnectionError: If unable to connect to Qdrant server.
        """
        if not _is_sync_client(self.client):
            raise ClientMethodMismatchError(
                method_name="reset",
                expected_client="QdrantClient",
                alt_method="areset",
                alt_client="AsyncQdrantClient",
            )
        collections_response = self.client.get_collections()
        for collection in collections_response.collections:
            self.client.delete_collection(collection_name=collection.name)

    async def areset(self) -> None:
        """Reset the vector database by deleting all collections and data asynchronously.

        Raises:
            ConnectionError: If unable to connect to Qdrant server.
        """
        if not _is_async_client(self.client):
            raise ClientMethodMismatchError(
                method_name="areset",
                expected_client="AsyncQdrantClient",
                alt_method="reset",
                alt_client="QdrantClient",
            )
        collections_response = await self.client.get_collections()
        for collection in collections_response.collections:
            await self.client.delete_collection(collection_name=collection.name)

View File

@@ -0,0 +1,55 @@
"""Qdrant configuration model."""
from dataclasses import field
from typing import Literal, cast
from pydantic.dataclasses import dataclass as pyd_dataclass
from crewai.rag.config.base import BaseRagConfig
from crewai.rag.qdrant.constants import DEFAULT_EMBEDDING_MODEL, DEFAULT_STORAGE_PATH
from crewai.rag.qdrant.types import QdrantClientParams, QdrantEmbeddingFunctionWrapper
def _default_options() -> QdrantClientParams:
    """Build the default Qdrant client options.

    Returns:
        Default options pointing at local file-based storage.
    """
    options = QdrantClientParams(path=DEFAULT_STORAGE_PATH)
    return options
def _default_embedding_function() -> QdrantEmbeddingFunctionWrapper:
    """Create the default Qdrant embedding function.

    Returns:
        Default embedding function using fastembed with all-MiniLM-L6-v2.
    """
    # Imported lazily so fastembed is only required when no custom
    # embedding function is configured.
    from fastembed import TextEmbedding  # type: ignore[import-not-found]

    embedder = TextEmbedding(model_name=DEFAULT_EMBEDDING_MODEL)

    def embed_fn(text: str) -> list[float]:
        """Embed a single text string.

        Args:
            text: Text to embed.

        Returns:
            Embedding vector as a list of floats.
        """
        vectors = list(embedder.embed([text]))
        if not vectors:
            return []
        return vectors[0].tolist()

    return cast(QdrantEmbeddingFunctionWrapper, embed_fn)
@pyd_dataclass(frozen=True)
class QdrantConfig(BaseRagConfig):
    """Configuration for Qdrant client.

    Frozen (immutable) pydantic dataclass consumed by the qdrant factory's
    create_client to build a QdrantClient.
    """

    # Discriminator used by the RAG factory dispatch; fixed, not caller-settable.
    provider: Literal["qdrant"] = field(default="qdrant", init=False)
    # Keyword arguments for the underlying qdrant-client constructor;
    # defaults to local file-based storage.
    options: QdrantClientParams = field(default_factory=_default_options)
    # Text-to-vector function; defaults to a fastembed all-MiniLM-L6-v2 model
    # built lazily on first config construction.
    embedding_function: QdrantEmbeddingFunctionWrapper = field(
        default_factory=_default_embedding_function
    )

View File

@@ -0,0 +1,12 @@
"""Constants for Qdrant implementation."""
import os
from typing import Final
from qdrant_client.models import Distance, VectorParams
from crewai.core.utilities.paths import db_storage_path
# Default vector schema: 384-dim cosine vectors (matches the output size of
# the default all-MiniLM-L6-v2 embedding model).
DEFAULT_VECTOR_PARAMS: Final = VectorParams(size=384, distance=Distance.COSINE)
# Embedding model used when no custom embedding function is configured.
DEFAULT_EMBEDDING_MODEL: Final[str] = "sentence-transformers/all-MiniLM-L6-v2"
# On-disk location for local (embedded) Qdrant storage.
DEFAULT_STORAGE_PATH: Final[str] = os.path.join(db_storage_path(), "qdrant")

View File

@@ -0,0 +1,26 @@
"""Factory functions for creating Qdrant clients from configuration."""
from qdrant_client import QdrantClient as SyncQdrantClientBase
from crewai.rag.qdrant.client import QdrantClient
from crewai.rag.qdrant.config import QdrantConfig
def create_client(config: QdrantConfig) -> QdrantClient:
    """Create a Qdrant client from configuration.

    Args:
        config: The Qdrant configuration.

    Returns:
        A configured QdrantClient instance.
    """
    # Build the synchronous SDK client, then wrap it in the BaseClient adapter.
    underlying = SyncQdrantClientBase(**config.options)
    wrapper = QdrantClient(
        client=underlying,
        embedding_function=config.embedding_function,
        default_limit=config.limit,
        default_score_threshold=config.score_threshold,
        default_batch_size=config.batch_size,
    )
    return wrapper

View File

@@ -0,0 +1,156 @@
"""Type definitions specific to Qdrant implementation."""
from collections.abc import Awaitable, Callable
from typing import Annotated, Any, Protocol, TypeAlias
import numpy as np
from pydantic import GetCoreSchemaHandler
from pydantic_core import CoreSchema, core_schema
from qdrant_client import AsyncQdrantClient # type: ignore[import-not-found]
from qdrant_client import (
QdrantClient as SyncQdrantClient, # type: ignore[import-not-found]
)
from qdrant_client.models import ( # type: ignore[import-not-found]
FieldCondition,
Filter,
HasIdCondition,
HasVectorCondition,
HnswConfigDiff,
InitFrom,
IsEmptyCondition,
IsNullCondition,
NestedCondition,
OptimizersConfigDiff,
QuantizationConfig,
ShardingMethod,
SparseVectorsConfig,
VectorsConfig,
WalConfigDiff,
)
from typing_extensions import NotRequired, TypedDict
from crewai.rag.core.base_client import BaseCollectionParams
# Either flavor of the underlying Qdrant SDK client.
QdrantClientType = SyncQdrantClient | AsyncQdrantClient
# Embedding output: a plain float list or a numpy floating array.
QueryEmbedding: TypeAlias = list[float] | np.ndarray[Any, np.dtype[np.floating[Any]]]
# Payload-value predicates on individual fields.
BasicConditions = FieldCondition | IsEmptyCondition | IsNullCondition
# Predicates on point identity/structure rather than payload values.
StructuralConditions = HasIdCondition | HasVectorCondition | NestedCondition
# Anything accepted inside a Qdrant Filter, including nested filters.
FilterCondition = BasicConditions | StructuralConditions | Filter
# Scalar values allowed in simple metadata equality filters.
MetadataFilterValue = bool | int | str
# Mapping of metadata field name to the value it must equal.
MetadataFilter = dict[str, MetadataFilterValue]
class EmbeddingFunction(Protocol):
    """Protocol for embedding functions that convert text to vectors.

    Structural (duck-typed) interface: any plain callable with this
    signature satisfies it.
    """

    def __call__(self, text: str) -> QueryEmbedding:
        """Convert text to embedding vector.

        Args:
            text: Input text to embed.

        Returns:
            Embedding vector as list of floats or numpy array.
        """
        ...
class QdrantEmbeddingFunctionWrapper(EmbeddingFunction):
    """Base class for Qdrant EmbeddingFunction to work with Pydantic validation."""

    @classmethod
    def __get_pydantic_core_schema__(
        cls, _source_type: Any, _handler: GetCoreSchemaHandler
    ) -> CoreSchema:
        """Generate Pydantic core schema for Qdrant EmbeddingFunction.

        This allows Pydantic to handle Qdrant's EmbeddingFunction type
        without requiring arbitrary_types_allowed=True.

        Returns:
            An any-schema, i.e. values of this type are accepted without
            validation.
        """
        return core_schema.any_schema()
class AsyncEmbeddingFunction(Protocol):
    """Protocol for async embedding functions that convert text to vectors.

    Structural (duck-typed) interface: any coroutine function with this
    signature satisfies it.
    """

    async def __call__(self, text: str) -> QueryEmbedding:
        """Convert text to embedding vector asynchronously.

        Args:
            text: Input text to embed.

        Returns:
            Embedding vector as list of floats or numpy array.
        """
        ...
class QdrantClientParams(TypedDict, total=False):
    """Parameters for QdrantClient initialization.

    All keys are optional; the dict is unpacked directly into the
    qdrant-client constructor by the factory.

    Notes:
        Need to implement in factory or remove.
    """

    location: str | None
    url: str | None
    port: int
    grpc_port: int
    prefer_grpc: bool
    https: bool | None
    api_key: str | None
    prefix: str | None
    timeout: int | None
    host: str | None
    # Filesystem path for local (embedded) storage; used by the default config.
    path: str | None
    force_disable_check_same_thread: bool
    grpc_options: dict[str, Any] | None
    # Supplier of auth tokens; may be sync or async per the union type.
    auth_token_provider: Callable[[], str] | Callable[[], Awaitable[str]] | None
    cloud_inference: bool
    local_inference_batch_size: int | None
    check_compatibility: bool
class CommonCreateFields(TypedDict, total=False):
    """Fields shared between high-level and direct create_collection params.

    All keys are optional; field semantics are annotated inline via
    typing.Annotated metadata where not self-evident.
    """

    vectors_config: VectorsConfig
    sparse_vectors_config: SparseVectorsConfig
    shard_number: Annotated[int, "Number of shards (default: 1)"]
    sharding_method: ShardingMethod
    replication_factor: Annotated[int, "Number of replicas per shard (default: 1)"]
    write_consistency_factor: Annotated[int, "Await N replicas on write (default: 1)"]
    on_disk_payload: Annotated[bool, "Store payload on disk instead of RAM"]
    hnsw_config: HnswConfigDiff
    optimizers_config: OptimizersConfigDiff
    wal_config: WalConfigDiff
    quantization_config: QuantizationConfig
    # Source collection to initialize from, as object or collection name.
    init_from: InitFrom | str
    timeout: Annotated[int, "Operation timeout in seconds"]
class QdrantCollectionCreateParams(
    BaseCollectionParams, CommonCreateFields, total=False
):
    """High-level parameters for creating a Qdrant collection.

    Combines the base collection identity fields with the shared
    creation options; accepted by QdrantClient.create_collection.
    """
class CreateCollectionParams(CommonCreateFields, total=False):
    """Parameters for qdrant_client.create_collection.

    The low-level shape produced by _get_collection_params and unpacked
    straight into the SDK call.
    """

    collection_name: str
class PreparedSearchParams(TypedDict):
    """Type definition for prepared Qdrant search parameters.

    Built by _prepare_search_params and unpacked into the SDK's
    query_points call.
    """

    collection_name: str
    # Query embedding, always coerced to a plain float list.
    query: list[float]
    limit: Annotated[int, "Max results to return"]
    with_payload: Annotated[bool, "Include payload in results"]
    with_vectors: Annotated[bool, "Include vectors in results"]
    score_threshold: NotRequired[Annotated[float, "Min similarity score (0-1)"]]
    query_filter: NotRequired[Filter]

View File

@@ -0,0 +1,232 @@
"""Utility functions for Qdrant operations."""
import asyncio
from typing import TypeGuard
from uuid import uuid4
from qdrant_client import AsyncQdrantClient # type: ignore[import-not-found]
from qdrant_client import (
QdrantClient as SyncQdrantClient, # type: ignore[import-not-found]
)
from qdrant_client.models import ( # type: ignore[import-not-found]
FieldCondition,
Filter,
MatchValue,
PointStruct,
QueryResponse,
)
from crewai.rag.qdrant.constants import DEFAULT_VECTOR_PARAMS
from crewai.rag.qdrant.types import (
AsyncEmbeddingFunction,
CreateCollectionParams,
EmbeddingFunction,
FilterCondition,
MetadataFilter,
PreparedSearchParams,
QdrantClientType,
QdrantCollectionCreateParams,
QueryEmbedding,
)
from crewai.rag.types import BaseRecord, SearchResult
def _ensure_list_embedding(embedding: QueryEmbedding) -> list[float]:
    """Coerce an embedding into a plain ``list[float]``.

    Args:
        embedding: Embedding vector as a list or numpy array.

    Returns:
        The embedding as a list of floats.
    """
    if isinstance(embedding, list):
        return embedding
    converted = embedding.tolist()
    # A 0-d array's tolist() yields a bare scalar; wrap it so callers
    # always receive a list.
    return converted if isinstance(converted, list) else [converted]
def _is_sync_client(client: QdrantClientType) -> TypeGuard[SyncQdrantClient]:
    """Type guard narrowing a client to the synchronous QdrantClient.

    Args:
        client: The client to check.

    Returns:
        True when the client is a synchronous QdrantClient instance.
    """
    return isinstance(client, SyncQdrantClient)
def _is_async_client(client: QdrantClientType) -> TypeGuard[AsyncQdrantClient]:
    """Type guard narrowing a client to the asynchronous AsyncQdrantClient.

    Args:
        client: The client to check.

    Returns:
        True when the client is an AsyncQdrantClient instance.
    """
    return isinstance(client, AsyncQdrantClient)
def _is_async_embedding_function(
    func: EmbeddingFunction | AsyncEmbeddingFunction,
) -> TypeGuard[AsyncEmbeddingFunction]:
    """Type guard narrowing an embedding function to the async variant.

    Args:
        func: The embedding function to check.

    Returns:
        True when calling the function returns a coroutine.
    """
    return asyncio.iscoroutinefunction(func)
def _get_collection_params(
    kwargs: QdrantCollectionCreateParams,
) -> CreateCollectionParams:
    """Extract collection creation parameters from kwargs.

    Args:
        kwargs: High-level collection creation parameters.

    Returns:
        The subset of parameters accepted by qdrant's create_collection,
        with vectors_config defaulted to DEFAULT_VECTOR_PARAMS.
    """
    params: CreateCollectionParams = {
        "collection_name": kwargs["collection_name"],
        "vectors_config": kwargs.get("vectors_config", DEFAULT_VECTOR_PARAMS),
    }
    # Copy only the optional keys the caller actually supplied.
    optional_keys = (
        "sparse_vectors_config",
        "shard_number",
        "sharding_method",
        "replication_factor",
        "write_consistency_factor",
        "on_disk_payload",
        "hnsw_config",
        "optimizers_config",
        "wal_config",
        "quantization_config",
        "init_from",
        "timeout",
    )
    for key in optional_keys:
        if key in kwargs:
            params[key] = kwargs[key]  # type: ignore[literal-required]
    return params
def _prepare_search_params(
    collection_name: str,
    query_embedding: QueryEmbedding,
    limit: int,
    score_threshold: float | None,
    metadata_filter: MetadataFilter | None,
) -> PreparedSearchParams:
    """Prepare search parameters for Qdrant query_points.

    Args:
        collection_name: Name of the collection to search.
        query_embedding: Embedding vector for the query.
        limit: Maximum number of results.
        score_threshold: Optional minimum similarity score.
        metadata_filter: Optional metadata filters.

    Returns:
        Dictionary of parameters for the query_points method.
    """
    prepared: PreparedSearchParams = {
        "collection_name": collection_name,
        "query": _ensure_list_embedding(query_embedding),
        "limit": limit,
        "with_payload": True,
        "with_vectors": False,
    }
    if score_threshold is not None:
        prepared["score_threshold"] = score_threshold
    if metadata_filter:
        # Each metadata entry becomes an equality condition; all must match.
        conditions: list[FilterCondition] = [
            FieldCondition(key=field_name, match=MatchValue(value=field_value))
            for field_name, field_value in metadata_filter.items()
        ]
        prepared["query_filter"] = Filter(must=conditions)
    return prepared
def _normalize_qdrant_score(score: float) -> float:
"""Normalize Qdrant cosine similarity score to [0, 1] range.
Converts from Qdrant's [-1, 1] cosine similarity range to [0, 1] range for standardization across clients.
Args:
score: Raw cosine similarity score from Qdrant [-1, 1].
Returns:
Normalized score in [0, 1] range where 1 is most similar.
"""
normalized = (score + 1.0) / 2.0
return max(0.0, min(1.0, normalized))
def _process_search_results(response: QueryResponse) -> list[SearchResult]:
    """Convert a Qdrant query response into the standard SearchResult format.

    Args:
        response: Response from Qdrant's query_points method.

    Returns:
        List of SearchResult dictionaries.
    """
    processed: list[SearchResult] = []
    for point in response.points:
        payload = point.payload or {}
        # The "content" key holds the document text; everything else is metadata.
        metadata = {key: value for key, value in payload.items() if key != "content"}
        processed.append(
            SearchResult(
                id=str(point.id),
                content=payload.get("content", ""),
                metadata=metadata,
                score=_normalize_qdrant_score(score=point.score),
            )
        )
    return processed
def _create_point_from_document(
    doc: BaseRecord, embedding: QueryEmbedding
) -> PointStruct:
    """Build a Qdrant PointStruct from a document and its embedding.

    Args:
        doc: Document dict with content, optional metadata, and optional doc_id.
        embedding: The embedding vector for the document content.

    Returns:
        PointStruct ready to be upserted to Qdrant.
    """
    raw_metadata = doc.get("metadata", {})
    # Metadata may arrive as a list of mappings; keep only the first entry.
    if isinstance(raw_metadata, list):
        metadata = raw_metadata[0] if raw_metadata else {}
    elif isinstance(raw_metadata, dict):
        metadata = raw_metadata
    else:
        metadata = dict(raw_metadata) if raw_metadata else {}
    payload = {"content": doc["content"], **metadata}
    return PointStruct(
        id=doc.get("doc_id", str(uuid4())),
        vector=_ensure_list_embedding(embedding),
        payload=payload,
    )

View File

@@ -0,0 +1 @@
"""Storage components for RAG infrastructure."""

View File

@@ -0,0 +1,55 @@
from abc import ABC, abstractmethod
from typing import Any
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
from crewai.rag.embeddings.types import ProviderSpec
class BaseRAGStorage(ABC):
    """Base class for RAG-based storage implementations."""

    # Optional handle to an underlying app/engine; set by concrete subclasses.
    app: Any | None = None

    def __init__(
        self,
        type: str,
        allow_reset: bool = True,
        embedder_config: ProviderSpec | BaseEmbeddingsProvider | None = None,
        crew: Any = None,
    ):
        self.type = type
        self.allow_reset = allow_reset
        self.embedder_config = embedder_config
        self.crew = crew
        self.agents = self._initialize_agents()

    def _initialize_agents(self) -> str:
        """Join the sanitized roles of the crew's agents into a single key."""
        if not self.crew:
            return ""
        roles = (self._sanitize_role(agent.role) for agent in self.crew.agents)
        return "_".join(roles)

    @abstractmethod
    def _sanitize_role(self, role: str) -> str:
        """Sanitizes agent roles to ensure valid directory names."""

    @abstractmethod
    def save(self, value: Any, metadata: dict[str, Any]) -> None:
        """Save a value with metadata to the storage."""

    @abstractmethod
    def search(
        self,
        query: str,
        limit: int = 5,
        filter: dict[str, Any] | None = None,
        score_threshold: float = 0.6,
    ) -> list[Any]:
        """Search for entries in the storage."""

    @abstractmethod
    def reset(self) -> None:
        """Reset the storage."""

View File

@@ -0,0 +1,49 @@
"""Type definitions for RAG (Retrieval-Augmented Generation) systems."""
from collections.abc import Callable, Mapping
from typing import Any, TypeAlias
from typing_extensions import Required, TypedDict
class BaseRecord(TypedDict, total=False):
    """A typed dictionary representing a document record.

    Attributes:
        doc_id: Optional unique identifier for the document. If not provided,
            a content-based ID will be generated using SHA256 hash.
        content: The text content of the document (required)
        metadata: Optional metadata associated with the document
    """

    # Optional unique identifier; generated from content when absent.
    doc_id: str
    # Required document text (the only mandatory key).
    content: Required[str]
    # Scalar-valued metadata mapping, or a list of such mappings.
    metadata: (
        Mapping[str, str | int | float | bool]
        | list[Mapping[str, str | int | float | bool]]
    )
# A batch of embedding vectors, one list[float] per document.
Embeddings: TypeAlias = list[list[float]]
# Loose callable signature for provider-specific embedding functions.
EmbeddingFunction: TypeAlias = Callable[..., Any]
class SearchResult(TypedDict):
    """Standard search result format for vector store queries.

    This provides a consistent interface for search results across different
    vector store implementations. Each implementation should convert their
    native result format to this standard format.

    Attributes:
        id: Unique identifier of the document
        content: The text content of the document
        metadata: Optional metadata associated with the document
        score: Similarity score (higher is better, typically between 0 and 1)
    """

    # Unique identifier of the matched document.
    id: str
    # Text content of the matched document.
    content: str
    # Metadata associated with the document (may be empty).
    metadata: dict[str, Any]
    # Normalized similarity score; higher is better.
    score: float

View File

@@ -0,0 +1 @@
"""Core utilities for CrewAI."""

Some files were not shown because too many files have changed in this diff Show More