mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-01 07:13:00 +00:00
refactor: move rag package to lib/core and extract standalone utilities
This commit is contained in:
1
lib/core/.python-version
Normal file
1
lib/core/.python-version
Normal file
@@ -0,0 +1 @@
|
||||
3.13
|
||||
0
lib/core/README.md
Normal file
0
lib/core/README.md
Normal file
56
lib/core/pyproject.toml
Normal file
56
lib/core/pyproject.toml
Normal file
@@ -0,0 +1,56 @@
|
||||
[project]
|
||||
name = "crewai-core"
|
||||
dynamic = ["version"]
|
||||
description = ""
|
||||
readme = "README.md"
|
||||
authors = [
|
||||
{ name = "Greyson Lalonde", email = "greyson.r.lalonde@gmail.com" }
|
||||
]
|
||||
keywords = [
|
||||
"crewai",
|
||||
"ai",
|
||||
"agents",
|
||||
"framework",
|
||||
"orchestration",
|
||||
"llm",
|
||||
"core",
|
||||
"typed",
|
||||
]
|
||||
classifiers = [
|
||||
"Development Status :: 3 - Alpha",
|
||||
"Intended Audience :: Developers",
|
||||
"Operating System :: OS Independent",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
"Programming Language :: Python :: 3.13",
|
||||
"Topic :: Software Development :: Libraries :: Python Modules",
|
||||
"Typing :: Typed",
|
||||
]
|
||||
requires-python = ">=3.10, <3.14"
|
||||
dependencies = [
|
||||
"chromadb~=1.1.0",
|
||||
"pydantic-settings>=2.10.1",
|
||||
"uv>=0.4.25",
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
Homepage = "https://crewai.com"
|
||||
Documentation = "https://docs.crewai.com"
|
||||
Repository = "https://github.com/crewAIInc/crewAI"
|
||||
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = ["tests"]
|
||||
asyncio_mode = "strict"
|
||||
asyncio_default_fixture_loop_scope = "function"
|
||||
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[tool.hatch.version]
|
||||
path = "src/crewai/core/__init__.py"
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["src/crewai"]
|
||||
1
lib/core/src/crewai/core/__init__.py
Normal file
1
lib/core/src/crewai/core/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
__version__ = "1.0.0a0"
|
||||
0
lib/core/src/crewai/core/py.typed
Normal file
0
lib/core/src/crewai/core/py.typed
Normal file
60
lib/core/src/crewai/core/rag/__init__.py
Normal file
60
lib/core/src/crewai/core/rag/__init__.py
Normal file
@@ -0,0 +1,60 @@
|
||||
"""RAG (Retrieval-Augmented Generation) infrastructure for CrewAI."""
|
||||
|
||||
import importlib
|
||||
import sys
|
||||
from types import ModuleType
|
||||
from typing import Any
|
||||
|
||||
from crewai.rag.config.types import RagConfigType
|
||||
from crewai.rag.config.utils import set_rag_config
|
||||
|
||||
_module_path = __path__
|
||||
_module_file = __file__
|
||||
|
||||
|
||||
class _RagModule(ModuleType):
|
||||
"""Module wrapper to intercept attribute setting for config."""
|
||||
|
||||
__path__ = _module_path
|
||||
__file__ = _module_file
|
||||
|
||||
def __init__(self, module_name: str):
|
||||
"""Initialize the module wrapper.
|
||||
|
||||
Args:
|
||||
module_name: Name of the module.
|
||||
"""
|
||||
super().__init__(module_name)
|
||||
|
||||
def __setattr__(self, name: str, value: RagConfigType) -> None:
|
||||
"""Set module attributes.
|
||||
|
||||
Args:
|
||||
name: Attribute name.
|
||||
value: Attribute value.
|
||||
"""
|
||||
if name == "config":
|
||||
return set_rag_config(value)
|
||||
raise AttributeError(f"Setting attribute '{name}' is not allowed.")
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
"""Get module attributes.
|
||||
|
||||
Args:
|
||||
name: Attribute name.
|
||||
|
||||
Returns:
|
||||
The requested attribute.
|
||||
|
||||
Raises:
|
||||
AttributeError: If attribute doesn't exist.
|
||||
"""
|
||||
try:
|
||||
return importlib.import_module(f"{self.__name__}.{name}")
|
||||
except ImportError as e:
|
||||
raise AttributeError(
|
||||
f"module '{self.__name__}' has no attribute '{name}'"
|
||||
) from e
|
||||
|
||||
|
||||
sys.modules[__name__] = _RagModule(__name__)
|
||||
0
lib/core/src/crewai/core/rag/chromadb/__init__.py
Normal file
0
lib/core/src/crewai/core/rag/chromadb/__init__.py
Normal file
617
lib/core/src/crewai/core/rag/chromadb/client.py
Normal file
617
lib/core/src/crewai/core/rag/chromadb/client.py
Normal file
@@ -0,0 +1,617 @@
|
||||
"""ChromaDB client implementation."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from chromadb.api.types import (
|
||||
EmbeddingFunction as ChromaEmbeddingFunction,
|
||||
)
|
||||
from chromadb.api.types import (
|
||||
QueryResult,
|
||||
)
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from crewai.rag.chromadb.types import (
|
||||
ChromaDBClientType,
|
||||
ChromaDBCollectionCreateParams,
|
||||
ChromaDBCollectionSearchParams,
|
||||
)
|
||||
from crewai.rag.chromadb.utils import (
|
||||
_create_batch_slice,
|
||||
_extract_search_params,
|
||||
_is_async_client,
|
||||
_is_sync_client,
|
||||
_prepare_documents_for_chromadb,
|
||||
_process_query_results,
|
||||
_sanitize_collection_name,
|
||||
)
|
||||
from crewai.rag.core.base_client import (
|
||||
BaseClient,
|
||||
BaseCollectionAddParams,
|
||||
BaseCollectionParams,
|
||||
)
|
||||
from crewai.rag.types import SearchResult
|
||||
from crewai.core.utilities.logger_utils import suppress_logging
|
||||
|
||||
|
||||
class ChromaDBClient(BaseClient):
|
||||
"""ChromaDB implementation of the BaseClient protocol.
|
||||
|
||||
Provides vector database operations for ChromaDB, supporting both
|
||||
synchronous and asynchronous clients.
|
||||
|
||||
Attributes:
|
||||
client: ChromaDB client instance (ClientAPI or AsyncClientAPI).
|
||||
embedding_function: Function to generate embeddings for documents.
|
||||
default_limit: Default number of results to return in searches.
|
||||
default_score_threshold: Default minimum score for search results.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: ChromaDBClientType,
|
||||
embedding_function: ChromaEmbeddingFunction,
|
||||
default_limit: int = 5,
|
||||
default_score_threshold: float = 0.6,
|
||||
default_batch_size: int = 100,
|
||||
) -> None:
|
||||
"""Initialize ChromaDBClient with client and embedding function.
|
||||
|
||||
Args:
|
||||
client: Pre-configured ChromaDB client instance.
|
||||
embedding_function: Embedding function for text to vector conversion.
|
||||
default_limit: Default number of results to return in searches.
|
||||
default_score_threshold: Default minimum score for search results.
|
||||
default_batch_size: Default batch size for adding documents.
|
||||
"""
|
||||
self.client = client
|
||||
self.embedding_function = embedding_function
|
||||
self.default_limit = default_limit
|
||||
self.default_score_threshold = default_score_threshold
|
||||
self.default_batch_size = default_batch_size
|
||||
|
||||
def create_collection(
|
||||
self, **kwargs: Unpack[ChromaDBCollectionCreateParams]
|
||||
) -> None:
|
||||
"""Create a new collection in ChromaDB.
|
||||
|
||||
Uses the client's default embedding function if none provided.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the collection to create. Must be unique.
|
||||
configuration: Optional collection configuration specifying distance metrics,
|
||||
HNSW parameters, or other backend-specific settings.
|
||||
metadata: Optional metadata dictionary to attach to the collection.
|
||||
embedding_function: Optional custom embedding function. If not provided,
|
||||
uses the client's default embedding function.
|
||||
data_loader: Optional data loader for batch loading data into the collection.
|
||||
get_or_create: If True, returns existing collection if it already exists
|
||||
instead of raising an error. Defaults to False.
|
||||
|
||||
Raises:
|
||||
TypeError: If AsyncClientAPI is used instead of ClientAPI for sync operations.
|
||||
ValueError: If collection with the same name already exists and get_or_create
|
||||
is False.
|
||||
ConnectionError: If unable to connect to ChromaDB server.
|
||||
|
||||
Example:
|
||||
>>> client = ChromaDBClient()
|
||||
>>> client.create_collection(
|
||||
... collection_name="documents",
|
||||
... metadata={"description": "Product documentation"},
|
||||
... get_or_create=True
|
||||
... )
|
||||
"""
|
||||
if not _is_sync_client(self.client):
|
||||
raise TypeError(
|
||||
"Synchronous method create_collection() requires a ClientAPI. "
|
||||
"Use acreate_collection() for AsyncClientAPI."
|
||||
)
|
||||
|
||||
metadata = kwargs.get("metadata", {})
|
||||
if "hnsw:space" not in metadata:
|
||||
metadata["hnsw:space"] = "cosine"
|
||||
|
||||
self.client.create_collection(
|
||||
name=_sanitize_collection_name(kwargs["collection_name"]),
|
||||
configuration=kwargs.get("configuration"),
|
||||
metadata=metadata,
|
||||
embedding_function=kwargs.get(
|
||||
"embedding_function", self.embedding_function
|
||||
),
|
||||
data_loader=kwargs.get("data_loader"),
|
||||
get_or_create=kwargs.get("get_or_create", False),
|
||||
)
|
||||
|
||||
async def acreate_collection(
|
||||
self, **kwargs: Unpack[ChromaDBCollectionCreateParams]
|
||||
) -> None:
|
||||
"""Create a new collection in ChromaDB asynchronously.
|
||||
|
||||
Creates a new collection with the specified name and optional configuration.
|
||||
If an embedding function is not provided, uses the client's default embedding function.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the collection to create. Must be unique.
|
||||
configuration: Optional collection configuration specifying distance metrics,
|
||||
HNSW parameters, or other backend-specific settings.
|
||||
metadata: Optional metadata dictionary to attach to the collection.
|
||||
embedding_function: Optional custom embedding function. If not provided,
|
||||
uses the client's default embedding function.
|
||||
data_loader: Optional data loader for batch loading data into the collection.
|
||||
get_or_create: If True, returns existing collection if it already exists
|
||||
instead of raising an error. Defaults to False.
|
||||
|
||||
Raises:
|
||||
TypeError: If ClientAPI is used instead of AsyncClientAPI for async operations.
|
||||
ValueError: If collection with the same name already exists and get_or_create
|
||||
is False.
|
||||
ConnectionError: If unable to connect to ChromaDB server.
|
||||
|
||||
Example:
|
||||
>>> import asyncio
|
||||
>>> async def main():
|
||||
... client = ChromaDBClient()
|
||||
... await client.acreate_collection(
|
||||
... collection_name="documents",
|
||||
... metadata={"description": "Product documentation"},
|
||||
... get_or_create=True
|
||||
... )
|
||||
>>> asyncio.run(main())
|
||||
"""
|
||||
if not _is_async_client(self.client):
|
||||
raise TypeError(
|
||||
"Asynchronous method acreate_collection() requires an AsyncClientAPI. "
|
||||
"Use create_collection() for ClientAPI."
|
||||
)
|
||||
|
||||
metadata = kwargs.get("metadata", {})
|
||||
if "hnsw:space" not in metadata:
|
||||
metadata["hnsw:space"] = "cosine"
|
||||
|
||||
await self.client.create_collection(
|
||||
name=_sanitize_collection_name(kwargs["collection_name"]),
|
||||
configuration=kwargs.get("configuration"),
|
||||
metadata=metadata,
|
||||
embedding_function=kwargs.get(
|
||||
"embedding_function", self.embedding_function
|
||||
),
|
||||
data_loader=kwargs.get("data_loader"),
|
||||
get_or_create=kwargs.get("get_or_create", False),
|
||||
)
|
||||
|
||||
def get_or_create_collection(
|
||||
self, **kwargs: Unpack[ChromaDBCollectionCreateParams]
|
||||
) -> Any:
|
||||
"""Get an existing collection or create it if it doesn't exist.
|
||||
|
||||
Returns existing collection if found, otherwise creates a new one.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the collection to get or create.
|
||||
configuration: Optional collection configuration specifying distance metrics,
|
||||
HNSW parameters, or other backend-specific settings.
|
||||
metadata: Optional metadata dictionary to attach to the collection.
|
||||
embedding_function: Optional custom embedding function. If not provided,
|
||||
uses the client's default embedding function.
|
||||
data_loader: Optional data loader for batch loading data into the collection.
|
||||
|
||||
Returns:
|
||||
A ChromaDB Collection object.
|
||||
|
||||
Raises:
|
||||
TypeError: If AsyncClientAPI is used instead of ClientAPI for sync operations.
|
||||
ConnectionError: If unable to connect to ChromaDB server.
|
||||
|
||||
Example:
|
||||
>>> client = ChromaDBClient()
|
||||
>>> collection = client.get_or_create_collection(
|
||||
... collection_name="documents",
|
||||
... metadata={"description": "Product documentation"}
|
||||
... )
|
||||
"""
|
||||
if not _is_sync_client(self.client):
|
||||
raise TypeError(
|
||||
"Synchronous method get_or_create_collection() requires a ClientAPI. "
|
||||
"Use aget_or_create_collection() for AsyncClientAPI."
|
||||
)
|
||||
|
||||
metadata = kwargs.get("metadata", {})
|
||||
if "hnsw:space" not in metadata:
|
||||
metadata["hnsw:space"] = "cosine"
|
||||
|
||||
return self.client.get_or_create_collection(
|
||||
name=_sanitize_collection_name(kwargs["collection_name"]),
|
||||
configuration=kwargs.get("configuration"),
|
||||
metadata=metadata,
|
||||
embedding_function=kwargs.get(
|
||||
"embedding_function", self.embedding_function
|
||||
),
|
||||
data_loader=kwargs.get("data_loader"),
|
||||
)
|
||||
|
||||
async def aget_or_create_collection(
|
||||
self, **kwargs: Unpack[ChromaDBCollectionCreateParams]
|
||||
) -> Any:
|
||||
"""Get an existing collection or create it if it doesn't exist asynchronously.
|
||||
|
||||
Returns existing collection if found, otherwise creates a new one.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the collection to get or create.
|
||||
configuration: Optional collection configuration specifying distance metrics,
|
||||
HNSW parameters, or other backend-specific settings.
|
||||
metadata: Optional metadata dictionary to attach to the collection.
|
||||
embedding_function: Optional custom embedding function. If not provided,
|
||||
uses the client's default embedding function.
|
||||
data_loader: Optional data loader for batch loading data into the collection.
|
||||
|
||||
Returns:
|
||||
A ChromaDB AsyncCollection object.
|
||||
|
||||
Raises:
|
||||
TypeError: If ClientAPI is used instead of AsyncClientAPI for async operations.
|
||||
ConnectionError: If unable to connect to ChromaDB server.
|
||||
|
||||
Example:
|
||||
>>> import asyncio
|
||||
>>> async def main():
|
||||
... client = ChromaDBClient()
|
||||
... collection = await client.aget_or_create_collection(
|
||||
... collection_name="documents",
|
||||
... metadata={"description": "Product documentation"}
|
||||
... )
|
||||
>>> asyncio.run(main())
|
||||
"""
|
||||
if not _is_async_client(self.client):
|
||||
raise TypeError(
|
||||
"Asynchronous method aget_or_create_collection() requires an AsyncClientAPI. "
|
||||
"Use get_or_create_collection() for ClientAPI."
|
||||
)
|
||||
|
||||
metadata = kwargs.get("metadata", {})
|
||||
if "hnsw:space" not in metadata:
|
||||
metadata["hnsw:space"] = "cosine"
|
||||
|
||||
return await self.client.get_or_create_collection(
|
||||
name=_sanitize_collection_name(kwargs["collection_name"]),
|
||||
configuration=kwargs.get("configuration"),
|
||||
metadata=metadata,
|
||||
embedding_function=kwargs.get(
|
||||
"embedding_function", self.embedding_function
|
||||
),
|
||||
data_loader=kwargs.get("data_loader"),
|
||||
)
|
||||
|
||||
def add_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
|
||||
"""Add documents with their embeddings to a collection.
|
||||
|
||||
Performs an upsert operation - documents with existing IDs are updated.
|
||||
Generates embeddings automatically using the configured embedding function.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: The name of the collection to add documents to.
|
||||
documents: List of BaseRecord dicts containing:
|
||||
- content: The text content (required)
|
||||
- doc_id: Optional unique identifier (auto-generated if missing)
|
||||
- metadata: Optional metadata dictionary
|
||||
batch_size: Optional batch size for processing documents (default: 100)
|
||||
|
||||
Raises:
|
||||
TypeError: If AsyncClientAPI is used instead of ClientAPI for sync operations.
|
||||
ValueError: If collection doesn't exist or documents list is empty.
|
||||
ConnectionError: If unable to connect to ChromaDB server.
|
||||
"""
|
||||
if not _is_sync_client(self.client):
|
||||
raise TypeError(
|
||||
"Synchronous method add_documents() requires a ClientAPI. "
|
||||
"Use aadd_documents() for AsyncClientAPI."
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
documents = kwargs["documents"]
|
||||
batch_size = kwargs.get("batch_size", self.default_batch_size)
|
||||
|
||||
if not documents:
|
||||
raise ValueError("Documents list cannot be empty")
|
||||
|
||||
collection = self.client.get_or_create_collection(
|
||||
name=_sanitize_collection_name(collection_name),
|
||||
embedding_function=self.embedding_function,
|
||||
)
|
||||
|
||||
prepared = _prepare_documents_for_chromadb(documents)
|
||||
|
||||
for i in range(0, len(prepared.ids), batch_size):
|
||||
batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
|
||||
prepared=prepared, start_index=i, batch_size=batch_size
|
||||
)
|
||||
|
||||
collection.upsert(
|
||||
ids=batch_ids,
|
||||
documents=batch_texts,
|
||||
metadatas=batch_metadatas,
|
||||
)
|
||||
|
||||
async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
|
||||
"""Add documents with their embeddings to a collection asynchronously.
|
||||
|
||||
Performs an upsert operation - documents with existing IDs are updated.
|
||||
Generates embeddings automatically using the configured embedding function.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: The name of the collection to add documents to.
|
||||
documents: List of BaseRecord dicts containing:
|
||||
- content: The text content (required)
|
||||
- doc_id: Optional unique identifier (auto-generated if missing)
|
||||
- metadata: Optional metadata dictionary
|
||||
batch_size: Optional batch size for processing documents (default: 100)
|
||||
|
||||
Raises:
|
||||
TypeError: If ClientAPI is used instead of AsyncClientAPI for async operations.
|
||||
ValueError: If collection doesn't exist or documents list is empty.
|
||||
ConnectionError: If unable to connect to ChromaDB server.
|
||||
"""
|
||||
if not _is_async_client(self.client):
|
||||
raise TypeError(
|
||||
"Asynchronous method aadd_documents() requires an AsyncClientAPI. "
|
||||
"Use add_documents() for ClientAPI."
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
documents = kwargs["documents"]
|
||||
batch_size = kwargs.get("batch_size", self.default_batch_size)
|
||||
|
||||
if not documents:
|
||||
raise ValueError("Documents list cannot be empty")
|
||||
|
||||
collection = await self.client.get_or_create_collection(
|
||||
name=_sanitize_collection_name(collection_name),
|
||||
embedding_function=self.embedding_function,
|
||||
)
|
||||
prepared = _prepare_documents_for_chromadb(documents)
|
||||
|
||||
for i in range(0, len(prepared.ids), batch_size):
|
||||
batch_ids, batch_texts, batch_metadatas = _create_batch_slice(
|
||||
prepared=prepared, start_index=i, batch_size=batch_size
|
||||
)
|
||||
|
||||
await collection.upsert(
|
||||
ids=batch_ids,
|
||||
documents=batch_texts,
|
||||
metadatas=batch_metadatas,
|
||||
)
|
||||
|
||||
def search(
|
||||
self, **kwargs: Unpack[ChromaDBCollectionSearchParams]
|
||||
) -> list[SearchResult]:
|
||||
"""Search for similar documents using a query.
|
||||
|
||||
Performs semantic search to find documents similar to the query text.
|
||||
Uses the configured embedding function to generate query embeddings.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the collection to search in.
|
||||
query: The text query to search for.
|
||||
limit: Maximum number of results to return (default: 10).
|
||||
metadata_filter: Optional filter for metadata fields.
|
||||
score_threshold: Optional minimum similarity score (0-1) for results.
|
||||
where: Optional ChromaDB where clause for metadata filtering.
|
||||
where_document: Optional ChromaDB where clause for document content filtering.
|
||||
include: Optional list of fields to include in results.
|
||||
|
||||
Returns:
|
||||
List of SearchResult dicts containing id, content, metadata, and score.
|
||||
|
||||
Raises:
|
||||
TypeError: If AsyncClientAPI is used instead of ClientAPI for sync operations.
|
||||
ValueError: If collection doesn't exist.
|
||||
ConnectionError: If unable to connect to ChromaDB server.
|
||||
"""
|
||||
if not _is_sync_client(self.client):
|
||||
raise TypeError(
|
||||
"Synchronous method search() requires a ClientAPI. "
|
||||
"Use asearch() for AsyncClientAPI."
|
||||
)
|
||||
|
||||
if "limit" not in kwargs:
|
||||
kwargs["limit"] = self.default_limit
|
||||
if "score_threshold" not in kwargs:
|
||||
kwargs["score_threshold"] = self.default_score_threshold
|
||||
|
||||
params = _extract_search_params(kwargs)
|
||||
|
||||
collection = self.client.get_or_create_collection(
|
||||
name=_sanitize_collection_name(params.collection_name),
|
||||
embedding_function=self.embedding_function,
|
||||
)
|
||||
|
||||
where = params.where if params.where is not None else params.metadata_filter
|
||||
|
||||
with suppress_logging(
|
||||
"chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR
|
||||
):
|
||||
results: QueryResult = collection.query(
|
||||
query_texts=[params.query],
|
||||
n_results=params.limit,
|
||||
where=where,
|
||||
where_document=params.where_document,
|
||||
include=params.include,
|
||||
)
|
||||
|
||||
return _process_query_results(
|
||||
collection=collection,
|
||||
results=results,
|
||||
params=params,
|
||||
)
|
||||
|
||||
async def asearch(
|
||||
self, **kwargs: Unpack[ChromaDBCollectionSearchParams]
|
||||
) -> list[SearchResult]:
|
||||
"""Search for similar documents using a query asynchronously.
|
||||
|
||||
Performs semantic search to find documents similar to the query text.
|
||||
Uses the configured embedding function to generate query embeddings.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the collection to search in.
|
||||
query: The text query to search for.
|
||||
limit: Maximum number of results to return (default: 10).
|
||||
metadata_filter: Optional filter for metadata fields.
|
||||
score_threshold: Optional minimum similarity score (0-1) for results.
|
||||
where: Optional ChromaDB where clause for metadata filtering.
|
||||
where_document: Optional ChromaDB where clause for document content filtering.
|
||||
include: Optional list of fields to include in results.
|
||||
|
||||
Returns:
|
||||
List of SearchResult dicts containing id, content, metadata, and score.
|
||||
|
||||
Raises:
|
||||
TypeError: If ClientAPI is used instead of AsyncClientAPI for async operations.
|
||||
ValueError: If collection doesn't exist.
|
||||
ConnectionError: If unable to connect to ChromaDB server.
|
||||
"""
|
||||
if not _is_async_client(self.client):
|
||||
raise TypeError(
|
||||
"Asynchronous method asearch() requires an AsyncClientAPI. "
|
||||
"Use search() for ClientAPI."
|
||||
)
|
||||
|
||||
if "limit" not in kwargs:
|
||||
kwargs["limit"] = self.default_limit
|
||||
if "score_threshold" not in kwargs:
|
||||
kwargs["score_threshold"] = self.default_score_threshold
|
||||
|
||||
params = _extract_search_params(kwargs)
|
||||
|
||||
collection = await self.client.get_or_create_collection(
|
||||
name=_sanitize_collection_name(params.collection_name),
|
||||
embedding_function=self.embedding_function,
|
||||
)
|
||||
|
||||
where = params.where if params.where is not None else params.metadata_filter
|
||||
|
||||
with suppress_logging(
|
||||
"chromadb.segment.impl.vector.local_persistent_hnsw", logging.ERROR
|
||||
):
|
||||
results: QueryResult = await collection.query(
|
||||
query_texts=[params.query],
|
||||
n_results=params.limit,
|
||||
where=where,
|
||||
where_document=params.where_document,
|
||||
include=params.include,
|
||||
)
|
||||
|
||||
return _process_query_results(
|
||||
collection=collection,
|
||||
results=results,
|
||||
params=params,
|
||||
)
|
||||
|
||||
def delete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
|
||||
"""Delete a collection and all its data.
|
||||
|
||||
Permanently removes a collection and all documents, embeddings, and metadata it contains.
|
||||
This operation cannot be undone.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the collection to delete.
|
||||
|
||||
Raises:
|
||||
TypeError: If AsyncClientAPI is used instead of ClientAPI for sync operations.
|
||||
ValueError: If collection doesn't exist.
|
||||
ConnectionError: If unable to connect to ChromaDB server.
|
||||
|
||||
Example:
|
||||
>>> client = ChromaDBClient()
|
||||
>>> client.delete_collection(collection_name="old_documents")
|
||||
"""
|
||||
if not _is_sync_client(self.client):
|
||||
raise TypeError(
|
||||
"Synchronous method delete_collection() requires a ClientAPI. "
|
||||
"Use adelete_collection() for AsyncClientAPI."
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
self.client.delete_collection(name=_sanitize_collection_name(collection_name))
|
||||
|
||||
async def adelete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
|
||||
"""Delete a collection and all its data asynchronously.
|
||||
|
||||
Permanently removes a collection and all documents, embeddings, and metadata it contains.
|
||||
This operation cannot be undone.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the collection to delete.
|
||||
|
||||
Raises:
|
||||
TypeError: If ClientAPI is used instead of AsyncClientAPI for async operations.
|
||||
ValueError: If collection doesn't exist.
|
||||
ConnectionError: If unable to connect to ChromaDB server.
|
||||
|
||||
Example:
|
||||
>>> import asyncio
|
||||
>>> async def main():
|
||||
... client = ChromaDBClient()
|
||||
... await client.adelete_collection(collection_name="old_documents")
|
||||
>>> asyncio.run(main())
|
||||
"""
|
||||
if not _is_async_client(self.client):
|
||||
raise TypeError(
|
||||
"Asynchronous method adelete_collection() requires an AsyncClientAPI. "
|
||||
"Use delete_collection() for ClientAPI."
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
await self.client.delete_collection(
|
||||
name=_sanitize_collection_name(collection_name)
|
||||
)
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset the vector database by deleting all collections and data.
|
||||
|
||||
Completely clears the ChromaDB instance, removing all collections,
|
||||
documents, embeddings, and metadata. This operation cannot be undone.
|
||||
Use with extreme caution in production environments.
|
||||
|
||||
Raises:
|
||||
TypeError: If AsyncClientAPI is used instead of ClientAPI for sync operations.
|
||||
ConnectionError: If unable to connect to ChromaDB server.
|
||||
|
||||
Example:
|
||||
>>> client = ChromaDBClient()
|
||||
>>> client.reset() # Removes ALL data from ChromaDB
|
||||
"""
|
||||
if not _is_sync_client(self.client):
|
||||
raise TypeError(
|
||||
"Synchronous method reset() requires a ClientAPI. "
|
||||
"Use areset() for AsyncClientAPI."
|
||||
)
|
||||
|
||||
self.client.reset()
|
||||
|
||||
async def areset(self) -> None:
|
||||
"""Reset the vector database by deleting all collections and data asynchronously.
|
||||
|
||||
Completely clears the ChromaDB instance, removing all collections,
|
||||
documents, embeddings, and metadata. This operation cannot be undone.
|
||||
Use with extreme caution in production environments.
|
||||
|
||||
Raises:
|
||||
TypeError: If ClientAPI is used instead of AsyncClientAPI for async operations.
|
||||
ConnectionError: If unable to connect to ChromaDB server.
|
||||
|
||||
Example:
|
||||
>>> import asyncio
|
||||
>>> async def main():
|
||||
... client = ChromaDBClient()
|
||||
... await client.areset() # Removes ALL data from ChromaDB
|
||||
>>> asyncio.run(main())
|
||||
"""
|
||||
if not _is_async_client(self.client):
|
||||
raise TypeError(
|
||||
"Asynchronous method areset() requires an AsyncClientAPI. "
|
||||
"Use reset() for ClientAPI."
|
||||
)
|
||||
|
||||
await self.client.reset()
|
||||
75
lib/core/src/crewai/core/rag/chromadb/config.py
Normal file
75
lib/core/src/crewai/core/rag/chromadb/config.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""ChromaDB configuration model."""
|
||||
|
||||
import os
|
||||
import warnings
|
||||
from dataclasses import field
|
||||
from typing import Literal, cast
|
||||
|
||||
from chromadb.config import Settings
|
||||
from pydantic.dataclasses import dataclass as pyd_dataclass
|
||||
|
||||
from crewai.rag.chromadb.constants import (
|
||||
DEFAULT_DATABASE,
|
||||
DEFAULT_STORAGE_PATH,
|
||||
DEFAULT_TENANT,
|
||||
)
|
||||
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
|
||||
from crewai.rag.config.base import BaseRagConfig
|
||||
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
message=".*Mixing V1 models and V2 models.*",
|
||||
category=UserWarning,
|
||||
module="pydantic._internal._generate_schema",
|
||||
)
|
||||
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
message=r".*'model_fields'.*is deprecated.*",
|
||||
module=r"^chromadb(\.|$)",
|
||||
)
|
||||
|
||||
|
||||
def _default_settings() -> Settings:
|
||||
"""Create default ChromaDB settings.
|
||||
|
||||
Returns:
|
||||
Settings with persistent storage and reset enabled.
|
||||
"""
|
||||
return Settings(
|
||||
persist_directory=DEFAULT_STORAGE_PATH,
|
||||
allow_reset=True,
|
||||
is_persistent=True,
|
||||
)
|
||||
|
||||
|
||||
def _default_embedding_function() -> ChromaEmbeddingFunctionWrapper:
|
||||
"""Create default ChromaDB embedding function.
|
||||
|
||||
Returns:
|
||||
Default embedding function using all-MiniLM-L6-v2 via ONNX.
|
||||
"""
|
||||
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
||||
OpenAIEmbeddingFunction,
|
||||
)
|
||||
|
||||
return cast(
|
||||
ChromaEmbeddingFunctionWrapper,
|
||||
OpenAIEmbeddingFunction(
|
||||
api_key=os.getenv("OPENAI_API_KEY"),
|
||||
model_name="text-embedding-3-small",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@pyd_dataclass(frozen=True)
|
||||
class ChromaDBConfig(BaseRagConfig):
|
||||
"""Configuration for ChromaDB client."""
|
||||
|
||||
provider: Literal["chromadb"] = field(default="chromadb", init=False)
|
||||
tenant: str = DEFAULT_TENANT
|
||||
database: str = DEFAULT_DATABASE
|
||||
settings: Settings = field(default_factory=_default_settings)
|
||||
embedding_function: ChromaEmbeddingFunctionWrapper = field(
|
||||
default_factory=_default_embedding_function
|
||||
)
|
||||
17
lib/core/src/crewai/core/rag/chromadb/constants.py
Normal file
17
lib/core/src/crewai/core/rag/chromadb/constants.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""Constants for ChromaDB configuration."""
|
||||
|
||||
import re
|
||||
from typing import Final
|
||||
|
||||
from crewai.core.utilities.paths import db_storage_path
|
||||
|
||||
DEFAULT_TENANT: Final[str] = "default_tenant"
|
||||
DEFAULT_DATABASE: Final[str] = "default_database"
|
||||
DEFAULT_STORAGE_PATH: Final[str] = db_storage_path()
|
||||
|
||||
MIN_COLLECTION_LENGTH: Final[int] = 3
|
||||
MAX_COLLECTION_LENGTH: Final[int] = 63
|
||||
DEFAULT_COLLECTION: Final[str] = "default_collection"
|
||||
|
||||
INVALID_CHARS_PATTERN: Final[re.Pattern[str]] = re.compile(r"[^a-zA-Z0-9_-]")
|
||||
IPV4_PATTERN: Final[re.Pattern[str]] = re.compile(r"^(\d{1,3}\.){3}\d{1,3}$")
|
||||
45
lib/core/src/crewai/core/rag/chromadb/factory.py
Normal file
45
lib/core/src/crewai/core/rag/chromadb/factory.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""Factory functions for creating ChromaDB clients."""
|
||||
|
||||
import os
|
||||
from hashlib import md5
|
||||
|
||||
import portalocker
|
||||
from chromadb import PersistentClient
|
||||
|
||||
from crewai.rag.chromadb.client import ChromaDBClient
|
||||
from crewai.rag.chromadb.config import ChromaDBConfig
|
||||
|
||||
|
||||
def create_client(config: ChromaDBConfig) -> ChromaDBClient:
|
||||
"""Create a ChromaDBClient from configuration.
|
||||
|
||||
Args:
|
||||
config: ChromaDB configuration object.
|
||||
|
||||
Returns:
|
||||
Configured ChromaDBClient instance.
|
||||
|
||||
Notes:
|
||||
Need to update to use chromadb.Client to support more client types in the near future.
|
||||
"""
|
||||
|
||||
persist_dir = config.settings.persist_directory
|
||||
os.makedirs(persist_dir, exist_ok=True)
|
||||
lock_id = md5(persist_dir.encode(), usedforsecurity=False).hexdigest()
|
||||
lockfile = os.path.join(persist_dir, f"chromadb-{lock_id}.lock")
|
||||
|
||||
with portalocker.Lock(lockfile):
|
||||
client = PersistentClient(
|
||||
path=persist_dir,
|
||||
settings=config.settings,
|
||||
tenant=config.tenant,
|
||||
database=config.database,
|
||||
)
|
||||
|
||||
return ChromaDBClient(
|
||||
client=client,
|
||||
embedding_function=config.embedding_function,
|
||||
default_limit=config.limit,
|
||||
default_score_threshold=config.score_threshold,
|
||||
default_batch_size=config.batch_size,
|
||||
)
|
||||
103
lib/core/src/crewai/core/rag/chromadb/types.py
Normal file
103
lib/core/src/crewai/core/rag/chromadb/types.py
Normal file
@@ -0,0 +1,103 @@
|
||||
"""Type definitions specific to ChromaDB implementation."""
|
||||
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
from chromadb.api import AsyncClientAPI, ClientAPI
|
||||
from chromadb.api.configuration import CollectionConfigurationInterface
|
||||
from chromadb.api.types import (
|
||||
CollectionMetadata,
|
||||
DataLoader,
|
||||
Include,
|
||||
Loadable,
|
||||
Where,
|
||||
WhereDocument,
|
||||
)
|
||||
from chromadb.api.types import (
|
||||
EmbeddingFunction as ChromaEmbeddingFunction,
|
||||
)
|
||||
from pydantic import GetCoreSchemaHandler
|
||||
from pydantic_core import CoreSchema, core_schema
|
||||
|
||||
from crewai.rag.core.base_client import BaseCollectionParams, BaseCollectionSearchParams
|
||||
|
||||
ChromaDBClientType = ClientAPI | AsyncClientAPI
|
||||
|
||||
|
||||
class ChromaEmbeddingFunctionWrapper(ChromaEmbeddingFunction):
|
||||
"""Base class for ChromaDB EmbeddingFunction to work with Pydantic validation."""
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(
|
||||
cls, _source_type: Any, _handler: GetCoreSchemaHandler
|
||||
) -> CoreSchema:
|
||||
"""Generate Pydantic core schema for ChromaDB EmbeddingFunction.
|
||||
|
||||
This allows Pydantic to handle ChromaDB's EmbeddingFunction type
|
||||
without requiring arbitrary_types_allowed=True.
|
||||
"""
|
||||
return core_schema.any_schema()
|
||||
|
||||
|
||||
class PreparedDocuments(NamedTuple):
|
||||
"""Prepared documents ready for ChromaDB insertion.
|
||||
|
||||
Attributes:
|
||||
ids: List of document IDs
|
||||
texts: List of document texts
|
||||
metadatas: List of document metadata mappings (empty dict for no metadata)
|
||||
"""
|
||||
|
||||
ids: list[str]
|
||||
texts: list[str]
|
||||
metadatas: list[Mapping[str, str | int | float | bool]]
|
||||
|
||||
|
||||
class ExtractedSearchParams(NamedTuple):
|
||||
"""Extracted search parameters for ChromaDB queries.
|
||||
|
||||
Attributes:
|
||||
collection_name: Name of the collection to search
|
||||
query: Search query text
|
||||
limit: Maximum number of results
|
||||
metadata_filter: Optional metadata filter
|
||||
score_threshold: Optional minimum similarity score
|
||||
where: Optional ChromaDB where clause
|
||||
where_document: Optional ChromaDB document filter
|
||||
include: Fields to include in results
|
||||
"""
|
||||
|
||||
collection_name: str
|
||||
query: str
|
||||
limit: int
|
||||
metadata_filter: dict[str, Any] | None
|
||||
score_threshold: float | None
|
||||
where: Where | None
|
||||
where_document: WhereDocument | None
|
||||
include: Include
|
||||
|
||||
|
||||
class ChromaDBCollectionCreateParams(BaseCollectionParams, total=False):
|
||||
"""Parameters for creating a ChromaDB collection.
|
||||
|
||||
This class extends BaseCollectionParams to include any additional
|
||||
parameters specific to ChromaDB collection creation.
|
||||
"""
|
||||
|
||||
configuration: CollectionConfigurationInterface
|
||||
metadata: CollectionMetadata
|
||||
embedding_function: ChromaEmbeddingFunction
|
||||
data_loader: DataLoader[Loadable]
|
||||
get_or_create: bool
|
||||
|
||||
|
||||
class ChromaDBCollectionSearchParams(BaseCollectionSearchParams, total=False):
|
||||
"""Parameters for searching a ChromaDB collection.
|
||||
|
||||
This class extends BaseCollectionSearchParams to include ChromaDB-specific
|
||||
search parameters like where clauses and include options.
|
||||
"""
|
||||
|
||||
where: Where
|
||||
where_document: WhereDocument
|
||||
include: Include
|
||||
323
lib/core/src/crewai/core/rag/chromadb/utils.py
Normal file
323
lib/core/src/crewai/core/rag/chromadb/utils.py
Normal file
@@ -0,0 +1,323 @@
|
||||
"""Utility functions for ChromaDB client implementation."""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
from collections.abc import Mapping
|
||||
from typing import Literal, TypeGuard, cast
|
||||
|
||||
from chromadb.api import AsyncClientAPI, ClientAPI
|
||||
from chromadb.api.models.AsyncCollection import AsyncCollection
|
||||
from chromadb.api.models.Collection import Collection
|
||||
from chromadb.api.types import (
|
||||
Include,
|
||||
QueryResult,
|
||||
)
|
||||
|
||||
from crewai.rag.chromadb.constants import (
|
||||
DEFAULT_COLLECTION,
|
||||
INVALID_CHARS_PATTERN,
|
||||
IPV4_PATTERN,
|
||||
MAX_COLLECTION_LENGTH,
|
||||
MIN_COLLECTION_LENGTH,
|
||||
)
|
||||
from crewai.rag.chromadb.types import (
|
||||
ChromaDBClientType,
|
||||
ChromaDBCollectionSearchParams,
|
||||
ExtractedSearchParams,
|
||||
PreparedDocuments,
|
||||
)
|
||||
from crewai.rag.types import BaseRecord, SearchResult
|
||||
|
||||
|
||||
def _is_sync_client(client: ChromaDBClientType) -> TypeGuard[ClientAPI]:
|
||||
"""Type guard to check if the client is a synchronous ClientAPI.
|
||||
|
||||
Args:
|
||||
client: The client to check.
|
||||
|
||||
Returns:
|
||||
True if the client is a ClientAPI, False otherwise.
|
||||
"""
|
||||
return isinstance(client, ClientAPI)
|
||||
|
||||
|
||||
def _is_async_client(client: ChromaDBClientType) -> TypeGuard[AsyncClientAPI]:
|
||||
"""Type guard to check if the client is an asynchronous AsyncClientAPI.
|
||||
|
||||
Args:
|
||||
client: The client to check.
|
||||
|
||||
Returns:
|
||||
True if the client is an AsyncClientAPI, False otherwise.
|
||||
"""
|
||||
return isinstance(client, AsyncClientAPI)
|
||||
|
||||
|
||||
def _prepare_documents_for_chromadb(
|
||||
documents: list[BaseRecord],
|
||||
) -> PreparedDocuments:
|
||||
"""Prepare documents for ChromaDB by extracting IDs, texts, and metadata.
|
||||
|
||||
Args:
|
||||
documents: List of BaseRecord documents to prepare.
|
||||
|
||||
Returns:
|
||||
PreparedDocuments with ids, texts, and metadatas ready for ChromaDB.
|
||||
"""
|
||||
ids: list[str] = []
|
||||
texts: list[str] = []
|
||||
metadatas: list[Mapping[str, str | int | float | bool]] = []
|
||||
|
||||
for doc in documents:
|
||||
if "doc_id" in doc:
|
||||
ids.append(doc["doc_id"])
|
||||
else:
|
||||
content_for_hash = doc["content"]
|
||||
metadata = doc.get("metadata")
|
||||
if metadata:
|
||||
metadata_str = json.dumps(metadata, sort_keys=True)
|
||||
content_for_hash = f"{content_for_hash}|{metadata_str}"
|
||||
|
||||
content_hash = hashlib.blake2b(
|
||||
content_for_hash.encode(), digest_size=32
|
||||
).hexdigest()
|
||||
ids.append(content_hash)
|
||||
|
||||
texts.append(doc["content"])
|
||||
metadata = doc.get("metadata")
|
||||
if metadata:
|
||||
if isinstance(metadata, list):
|
||||
metadatas.append(metadata[0] if metadata and metadata[0] else {})
|
||||
else:
|
||||
metadatas.append(metadata)
|
||||
else:
|
||||
metadatas.append({})
|
||||
|
||||
return PreparedDocuments(ids, texts, metadatas)
|
||||
|
||||
|
||||
def _create_batch_slice(
|
||||
prepared: PreparedDocuments, start_index: int, batch_size: int
|
||||
) -> tuple[list[str], list[str], list[Mapping[str, str | int | float | bool]] | None]:
|
||||
"""Create a batch slice from prepared documents.
|
||||
|
||||
Args:
|
||||
prepared: PreparedDocuments containing ids, texts, and metadatas.
|
||||
start_index: Starting index for the batch.
|
||||
batch_size: Size of the batch.
|
||||
|
||||
Returns:
|
||||
Tuple of (batch_ids, batch_texts, batch_metadatas).
|
||||
"""
|
||||
batch_end = min(start_index + batch_size, len(prepared.ids))
|
||||
batch_ids = prepared.ids[start_index:batch_end]
|
||||
batch_texts = prepared.texts[start_index:batch_end]
|
||||
batch_metadatas = (
|
||||
prepared.metadatas[start_index:batch_end] if prepared.metadatas else None
|
||||
)
|
||||
|
||||
if batch_metadatas and not any(m for m in batch_metadatas):
|
||||
batch_metadatas = None
|
||||
|
||||
return batch_ids, batch_texts, batch_metadatas
|
||||
|
||||
|
||||
def _extract_search_params(
|
||||
kwargs: ChromaDBCollectionSearchParams,
|
||||
) -> ExtractedSearchParams:
|
||||
"""Extract search parameters from kwargs.
|
||||
|
||||
Args:
|
||||
kwargs: Keyword arguments containing search parameters.
|
||||
|
||||
Returns:
|
||||
ExtractedSearchParams with all extracted parameters.
|
||||
"""
|
||||
return ExtractedSearchParams(
|
||||
collection_name=kwargs["collection_name"],
|
||||
query=kwargs["query"],
|
||||
limit=kwargs.get("limit", 10),
|
||||
metadata_filter=kwargs.get("metadata_filter"),
|
||||
score_threshold=kwargs.get("score_threshold"),
|
||||
where=kwargs.get("where"),
|
||||
where_document=kwargs.get("where_document"),
|
||||
include=cast(
|
||||
Include,
|
||||
kwargs.get(
|
||||
"include",
|
||||
["metadatas", "documents", "distances"],
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _convert_distance_to_score(
|
||||
distance: float,
|
||||
distance_metric: Literal["l2", "cosine", "ip"],
|
||||
) -> float:
|
||||
"""Convert ChromaDB distance to similarity score.
|
||||
|
||||
Notes:
|
||||
Assuming all embedding are unit-normalized for now, including custom embeddings.
|
||||
|
||||
Args:
|
||||
distance: The distance value from ChromaDB.
|
||||
distance_metric: The distance metric used ("l2", "cosine", or "ip").
|
||||
|
||||
Returns:
|
||||
Similarity score in range [0, 1] where 1 is most similar.
|
||||
"""
|
||||
if distance_metric == "cosine":
|
||||
score = 1.0 - 0.5 * distance
|
||||
return max(0.0, min(1.0, score))
|
||||
if distance_metric == "l2":
|
||||
score = 1.0 / (1.0 + distance)
|
||||
return max(0.0, min(1.0, score))
|
||||
raise ValueError(f"Unsupported distance metric: {distance_metric}")
|
||||
|
||||
|
||||
def _convert_chromadb_results_to_search_results(
|
||||
results: QueryResult,
|
||||
include: Include,
|
||||
distance_metric: Literal["l2", "cosine", "ip"],
|
||||
score_threshold: float | None = None,
|
||||
) -> list[SearchResult]:
|
||||
"""Convert ChromaDB query results to SearchResult format.
|
||||
|
||||
Args:
|
||||
results: ChromaDB query results.
|
||||
include: List of fields that were included in the query.
|
||||
distance_metric: The distance metric used by the collection.
|
||||
score_threshold: Optional minimum similarity score (0-1) for results.
|
||||
|
||||
Returns:
|
||||
List of SearchResult dicts containing id, content, metadata, and score.
|
||||
"""
|
||||
search_results: list[SearchResult] = []
|
||||
|
||||
include_strings = list(include) if include else []
|
||||
|
||||
ids = results["ids"][0] if results.get("ids") else []
|
||||
|
||||
documents_list = results.get("documents")
|
||||
documents = (
|
||||
documents_list[0] if documents_list and "documents" in include_strings else []
|
||||
)
|
||||
|
||||
metadatas_list = results.get("metadatas")
|
||||
metadatas = (
|
||||
metadatas_list[0] if metadatas_list and "metadatas" in include_strings else []
|
||||
)
|
||||
|
||||
distances_list = results.get("distances")
|
||||
distances = (
|
||||
distances_list[0] if distances_list and "distances" in include_strings else []
|
||||
)
|
||||
|
||||
for i, doc_id in enumerate(ids):
|
||||
if not distances or i >= len(distances):
|
||||
continue
|
||||
|
||||
distance = distances[i]
|
||||
score = _convert_distance_to_score(
|
||||
distance=distance, distance_metric=distance_metric
|
||||
)
|
||||
|
||||
if score_threshold and score < score_threshold:
|
||||
continue
|
||||
|
||||
result: SearchResult = {
|
||||
"id": doc_id,
|
||||
"content": documents[i] if documents and i < len(documents) else "",
|
||||
"metadata": dict(metadatas[i])
|
||||
if metadatas and i < len(metadatas) and metadatas[i] is not None
|
||||
else {},
|
||||
"score": score,
|
||||
}
|
||||
search_results.append(result)
|
||||
|
||||
return search_results
|
||||
|
||||
|
||||
def _process_query_results(
|
||||
collection: Collection | AsyncCollection,
|
||||
results: QueryResult,
|
||||
params: ExtractedSearchParams,
|
||||
) -> list[SearchResult]:
|
||||
"""Process ChromaDB query results and convert to SearchResult format.
|
||||
|
||||
Args:
|
||||
collection: The ChromaDB collection (sync or async) that was queried.
|
||||
results: Raw query results from ChromaDB.
|
||||
params: The search parameters used for the query.
|
||||
|
||||
Returns:
|
||||
List of SearchResult dicts containing id, content, metadata, and score.
|
||||
"""
|
||||
|
||||
distance_metric = cast(
|
||||
Literal["l2", "cosine", "ip"],
|
||||
collection.metadata.get("hnsw:space", "l2") if collection.metadata else "l2",
|
||||
)
|
||||
|
||||
return _convert_chromadb_results_to_search_results(
|
||||
results=results,
|
||||
include=params.include,
|
||||
distance_metric=distance_metric,
|
||||
score_threshold=params.score_threshold,
|
||||
)
|
||||
|
||||
|
||||
def _is_ipv4_pattern(name: str) -> bool:
|
||||
"""Check if a string matches an IPv4 address pattern.
|
||||
|
||||
Args:
|
||||
name: The string to check
|
||||
|
||||
Returns:
|
||||
True if the string matches an IPv4 pattern, False otherwise
|
||||
"""
|
||||
return bool(IPV4_PATTERN.match(name))
|
||||
|
||||
|
||||
def _sanitize_collection_name(
|
||||
name: str | None, max_collection_length: int = MAX_COLLECTION_LENGTH
|
||||
) -> str:
|
||||
"""Sanitize a collection name to meet ChromaDB requirements.
|
||||
|
||||
Requirements:
|
||||
1. 3-63 characters long
|
||||
2. Starts and ends with alphanumeric character
|
||||
3. Contains only alphanumeric characters, underscores, or hyphens
|
||||
4. No consecutive periods
|
||||
5. Not a valid IPv4 address
|
||||
|
||||
Args:
|
||||
name: The original collection name to sanitize
|
||||
max_collection_length: Maximum allowed length for the collection name
|
||||
|
||||
Returns:
|
||||
A sanitized collection name that meets ChromaDB requirements
|
||||
"""
|
||||
if not name:
|
||||
return DEFAULT_COLLECTION
|
||||
|
||||
if _is_ipv4_pattern(name):
|
||||
name = f"ip_{name}"
|
||||
|
||||
sanitized = INVALID_CHARS_PATTERN.sub("_", name)
|
||||
|
||||
if not sanitized[0].isalnum():
|
||||
sanitized = "a" + sanitized
|
||||
|
||||
if not sanitized[-1].isalnum():
|
||||
sanitized = sanitized[:-1] + "z"
|
||||
|
||||
if len(sanitized) < MIN_COLLECTION_LENGTH:
|
||||
sanitized += "x" * (MIN_COLLECTION_LENGTH - len(sanitized))
|
||||
if len(sanitized) > max_collection_length:
|
||||
sanitized = sanitized[:max_collection_length]
|
||||
if not sanitized[-1].isalnum():
|
||||
sanitized = sanitized[:-1] + "z"
|
||||
|
||||
return sanitized
|
||||
1
lib/core/src/crewai/core/rag/config/__init__.py
Normal file
1
lib/core/src/crewai/core/rag/config/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""RAG client configuration management using ContextVars for thread-safe provider switching."""
|
||||
19
lib/core/src/crewai/core/rag/config/base.py
Normal file
19
lib/core/src/crewai/core/rag/config/base.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""Base configuration class for RAG providers."""
|
||||
|
||||
from dataclasses import field
|
||||
from typing import Any
|
||||
|
||||
from pydantic.dataclasses import dataclass as pyd_dataclass
|
||||
|
||||
from crewai.rag.config.optional_imports.types import SupportedProvider
|
||||
|
||||
|
||||
@pyd_dataclass(frozen=True)
|
||||
class BaseRagConfig:
|
||||
"""Base class for RAG configuration with Pydantic serialization support."""
|
||||
|
||||
provider: SupportedProvider = field(init=False)
|
||||
embedding_function: Any | None = field(default=None)
|
||||
limit: int = field(default=5)
|
||||
score_threshold: float = field(default=0.6)
|
||||
batch_size: int = field(default=100)
|
||||
8
lib/core/src/crewai/core/rag/config/constants.py
Normal file
8
lib/core/src/crewai/core/rag/config/constants.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""Constants for RAG configuration."""
|
||||
|
||||
from typing import Final
|
||||
|
||||
DISCRIMINATOR: Final[str] = "provider"
|
||||
|
||||
DEFAULT_RAG_CONFIG_PATH: Final[str] = "crewai.rag.chromadb.config"
|
||||
DEFAULT_RAG_CONFIG_CLASS: Final[str] = "ChromaDBConfig"
|
||||
@@ -0,0 +1 @@
|
||||
"""Optional imports for RAG configuration providers."""
|
||||
26
lib/core/src/crewai/core/rag/config/optional_imports/base.py
Normal file
26
lib/core/src/crewai/core/rag/config/optional_imports/base.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""Base classes for missing provider configurations."""
|
||||
|
||||
from dataclasses import field
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import ConfigDict
|
||||
from pydantic.dataclasses import dataclass as pyd_dataclass
|
||||
|
||||
|
||||
@pyd_dataclass(config=ConfigDict(extra="forbid"))
|
||||
class _MissingProvider:
|
||||
"""Base class for missing provider configurations.
|
||||
|
||||
Raises RuntimeError when instantiated to indicate missing dependencies.
|
||||
"""
|
||||
|
||||
provider: Literal["chromadb", "qdrant", "__missing__"] = field(
|
||||
default="__missing__"
|
||||
)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Raises error indicating the provider is not installed."""
|
||||
raise RuntimeError(
|
||||
f"provider '{self.provider}' requested but not installed. "
|
||||
f"Install the extra: `uv add crewai'[{self.provider}]'`."
|
||||
)
|
||||
@@ -0,0 +1,27 @@
|
||||
"""Protocol definitions for RAG factory modules."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Protocol
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.rag.chromadb.client import ChromaDBClient
|
||||
from crewai.rag.chromadb.config import ChromaDBConfig
|
||||
from crewai.rag.qdrant.client import QdrantClient
|
||||
from crewai.rag.qdrant.config import QdrantConfig
|
||||
|
||||
|
||||
class ChromaFactoryModule(Protocol):
|
||||
"""Protocol for ChromaDB factory module."""
|
||||
|
||||
def create_client(self, config: ChromaDBConfig) -> ChromaDBClient:
|
||||
"""Creates a ChromaDB client from configuration."""
|
||||
...
|
||||
|
||||
|
||||
class QdrantFactoryModule(Protocol):
|
||||
"""Protocol for Qdrant factory module."""
|
||||
|
||||
def create_client(self, config: QdrantConfig) -> QdrantClient:
|
||||
"""Creates a Qdrant client from configuration."""
|
||||
...
|
||||
@@ -0,0 +1,23 @@
|
||||
"""Provider-specific missing configuration classes."""
|
||||
|
||||
from dataclasses import field
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import ConfigDict
|
||||
from pydantic.dataclasses import dataclass as pyd_dataclass
|
||||
|
||||
from crewai.rag.config.optional_imports.base import _MissingProvider
|
||||
|
||||
|
||||
@pyd_dataclass(config=ConfigDict(extra="forbid"))
|
||||
class MissingChromaDBConfig(_MissingProvider):
|
||||
"""Placeholder for missing ChromaDB configuration."""
|
||||
|
||||
provider: Literal["chromadb"] = field(default="chromadb")
|
||||
|
||||
|
||||
@pyd_dataclass(config=ConfigDict(extra="forbid"))
|
||||
class MissingQdrantConfig(_MissingProvider):
|
||||
"""Placeholder for missing Qdrant configuration."""
|
||||
|
||||
provider: Literal["qdrant"] = field(default="qdrant")
|
||||
@@ -0,0 +1,8 @@
|
||||
"""Type definitions for optional imports."""
|
||||
|
||||
from typing import Annotated, Literal
|
||||
|
||||
SupportedProvider = Annotated[
|
||||
Literal["chromadb", "qdrant"],
|
||||
"Supported RAG provider types, add providers here as they become available",
|
||||
]
|
||||
35
lib/core/src/crewai/core/rag/config/types.py
Normal file
35
lib/core/src/crewai/core/rag/config/types.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""Type definitions for RAG configuration."""
|
||||
|
||||
from typing import TYPE_CHECKING, Annotated, TypeAlias
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.config.constants import DISCRIMINATOR
|
||||
|
||||
# Linter freaks out on conditional imports, assigning in the type checking fixes it
|
||||
if TYPE_CHECKING:
|
||||
from crewai.rag.chromadb.config import ChromaDBConfig as ChromaDBConfig_
|
||||
|
||||
ChromaDBConfig = ChromaDBConfig_
|
||||
from crewai.rag.qdrant.config import QdrantConfig as QdrantConfig_
|
||||
|
||||
QdrantConfig = QdrantConfig_
|
||||
else:
|
||||
try:
|
||||
from crewai.rag.chromadb.config import ChromaDBConfig
|
||||
except ImportError:
|
||||
from crewai.rag.config.optional_imports.providers import (
|
||||
MissingChromaDBConfig as ChromaDBConfig,
|
||||
)
|
||||
|
||||
try:
|
||||
from crewai.rag.qdrant.config import QdrantConfig
|
||||
except ImportError:
|
||||
from crewai.rag.config.optional_imports.providers import (
|
||||
MissingQdrantConfig as QdrantConfig,
|
||||
)
|
||||
|
||||
SupportedProviderConfig: TypeAlias = ChromaDBConfig | QdrantConfig
|
||||
RagConfigType: TypeAlias = Annotated[
|
||||
SupportedProviderConfig, Field(discriminator=DISCRIMINATOR)
|
||||
]
|
||||
86
lib/core/src/crewai/core/rag/config/utils.py
Normal file
86
lib/core/src/crewai/core/rag/config/utils.py
Normal file
@@ -0,0 +1,86 @@
|
||||
"""RAG client configuration utilities."""
|
||||
|
||||
from contextvars import ContextVar
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.rag.config.constants import (
|
||||
DEFAULT_RAG_CONFIG_CLASS,
|
||||
DEFAULT_RAG_CONFIG_PATH,
|
||||
)
|
||||
from crewai.rag.config.types import RagConfigType
|
||||
from crewai.rag.core.base_client import BaseClient
|
||||
from crewai.rag.factory import create_client
|
||||
from crewai.core.utilities.import_utils import require
|
||||
|
||||
|
||||
class RagContext(BaseModel):
|
||||
"""Context holding RAG configuration and client instance."""
|
||||
|
||||
config: RagConfigType = Field(..., description="RAG provider configuration")
|
||||
client: BaseClient | None = Field(
|
||||
default=None, description="Instantiated RAG client"
|
||||
)
|
||||
|
||||
|
||||
_rag_context: ContextVar[RagContext | None] = ContextVar("_rag_context", default=None)
|
||||
|
||||
|
||||
def set_rag_config(config: RagConfigType) -> None:
|
||||
"""Set global RAG client configuration and instantiate the client.
|
||||
|
||||
Args:
|
||||
config: The RAG client configuration (ChromaDBConfig).
|
||||
"""
|
||||
client = create_client(config)
|
||||
context = RagContext(config=config, client=client)
|
||||
_rag_context.set(context)
|
||||
|
||||
|
||||
def get_rag_config() -> RagConfigType:
|
||||
"""Get current RAG configuration.
|
||||
|
||||
Returns:
|
||||
The current RAG configuration object.
|
||||
"""
|
||||
context = _rag_context.get()
|
||||
if context is None:
|
||||
module = require(DEFAULT_RAG_CONFIG_PATH, purpose="RAG configuration")
|
||||
config_class = getattr(module, DEFAULT_RAG_CONFIG_CLASS)
|
||||
default_config = config_class()
|
||||
set_rag_config(default_config)
|
||||
context = _rag_context.get()
|
||||
|
||||
if context is None or context.config is None:
|
||||
raise ValueError(
|
||||
"RAG configuration is not set. Please set the RAG config first."
|
||||
)
|
||||
|
||||
return context.config
|
||||
|
||||
|
||||
def get_rag_client() -> BaseClient:
|
||||
"""Get the current RAG client instance.
|
||||
|
||||
Returns:
|
||||
The current RAG client, creating one if needed.
|
||||
"""
|
||||
context = _rag_context.get()
|
||||
if context is None:
|
||||
get_rag_config()
|
||||
context = _rag_context.get()
|
||||
|
||||
if context and context.client is None:
|
||||
context.client = create_client(context.config)
|
||||
|
||||
if context is None or context.client is None:
|
||||
raise ValueError(
|
||||
"RAG client is not configured. Please set the RAG config first."
|
||||
)
|
||||
|
||||
return context.client
|
||||
|
||||
|
||||
def clear_rag_config() -> None:
|
||||
"""Clear the current RAG configuration and client, reverting to defaults."""
|
||||
_rag_context.set(None)
|
||||
1
lib/core/src/crewai/core/rag/core/__init__.py
Normal file
1
lib/core/src/crewai/core/rag/core/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Core abstract base classes and protocols for RAG systems."""
|
||||
448
lib/core/src/crewai/core/rag/core/base_client.py
Normal file
448
lib/core/src/crewai/core/rag/core/base_client.py
Normal file
@@ -0,0 +1,448 @@
|
||||
"""Protocol for vector database client implementations."""
|
||||
|
||||
from abc import abstractmethod
|
||||
from typing import Annotated, Any, Protocol, runtime_checkable
|
||||
|
||||
from pydantic import GetCoreSchemaHandler
|
||||
from pydantic_core import CoreSchema, core_schema
|
||||
from typing_extensions import Required, TypedDict, Unpack
|
||||
|
||||
from crewai.rag.types import (
|
||||
BaseRecord,
|
||||
EmbeddingFunction,
|
||||
SearchResult,
|
||||
)
|
||||
|
||||
|
||||
class BaseCollectionParams(TypedDict):
|
||||
"""Base parameters for collection operations.
|
||||
|
||||
Attributes:
|
||||
collection_name: The name of the collection/index to operate on.
|
||||
"""
|
||||
|
||||
collection_name: Required[
|
||||
Annotated[
|
||||
str,
|
||||
"Name of the collection/index. Implementations may have specific constraints (e.g., character limits, allowed characters, case sensitivity).",
|
||||
]
|
||||
]
|
||||
|
||||
|
||||
class BaseCollectionAddParams(BaseCollectionParams, total=False):
|
||||
"""Parameters for adding documents to a collection.
|
||||
|
||||
Extends BaseCollectionParams with document-specific fields.
|
||||
|
||||
Attributes:
|
||||
collection_name: The name of the collection to add documents to.
|
||||
documents: List of BaseRecord dictionaries containing document data.
|
||||
batch_size: Optional batch size for processing documents to avoid token limits.
|
||||
"""
|
||||
|
||||
documents: Required[list[BaseRecord]]
|
||||
batch_size: int
|
||||
|
||||
|
||||
class BaseCollectionSearchParams(BaseCollectionParams, total=False):
|
||||
"""Parameters for searching within a collection.
|
||||
|
||||
Extends BaseCollectionParams with search-specific optional fields.
|
||||
All fields except collection_name and query are optional.
|
||||
|
||||
Attributes:
|
||||
query: The text query to search for (required).
|
||||
limit: Maximum number of results to return.
|
||||
metadata_filter: Filter results by metadata fields.
|
||||
score_threshold: Minimum similarity score for results (0-1).
|
||||
"""
|
||||
|
||||
query: Required[str]
|
||||
limit: int
|
||||
metadata_filter: dict[str, Any] | None
|
||||
score_threshold: float
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class BaseClient(Protocol):
|
||||
"""Protocol for vector store client implementations.
|
||||
|
||||
This protocol defines the interface that all vector store client implementations
|
||||
must follow. It provides a consistent API for storing and retrieving
|
||||
documents with their vector embeddings across different vector database
|
||||
backends (e.g., Qdrant, ChromaDB, Weaviate). Implementing classes should
|
||||
handle connection management, data persistence, and vector similarity
|
||||
search operations specific to their backend.
|
||||
|
||||
Implementation Guidelines:
|
||||
Implementations should accept BaseClientParams in their constructor to allow
|
||||
passing pre-configured client instances:
|
||||
|
||||
class MyVectorClient:
|
||||
def __init__(self, client: Any | None = None, **kwargs):
|
||||
if client:
|
||||
self.client = client
|
||||
else:
|
||||
self.client = self._create_default_client(**kwargs)
|
||||
|
||||
Notes:
|
||||
This protocol replaces the former BaseRAGStorage abstraction,
|
||||
providing a cleaner interface for vector store operations.
|
||||
|
||||
Attributes:
|
||||
embedding_function: Callable that takes a list of text strings
|
||||
and returns a list of embedding vectors. Implementations
|
||||
should always provide a default embedding function.
|
||||
client: The underlying vector database client instance. This could be
|
||||
passed via BaseClientParams during initialization or created internally.
|
||||
"""
|
||||
|
||||
client: Any
|
||||
embedding_function: EmbeddingFunction
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(
|
||||
cls, _source_type: Any, _handler: GetCoreSchemaHandler
|
||||
) -> CoreSchema:
|
||||
"""Generate Pydantic core schema for BaseClient Protocol.
|
||||
|
||||
This allows the Protocol to be used in Pydantic models without
|
||||
requiring arbitrary_types_allowed=True.
|
||||
"""
|
||||
return core_schema.any_schema()
|
||||
|
||||
@abstractmethod
|
||||
def create_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
|
||||
"""Create a new collection/index in the vector database.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: The name of the collection to create. Must be unique within
|
||||
the vector database instance.
|
||||
|
||||
Raises:
|
||||
ValueError: If collection name already exists.
|
||||
ConnectionError: If unable to connect to the vector database backend.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def acreate_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
|
||||
"""Create a new collection/index in the vector database asynchronously.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: The name of the collection to create. Must be unique within
|
||||
the vector database instance.
|
||||
|
||||
Raises:
|
||||
ValueError: If collection name already exists.
|
||||
ConnectionError: If unable to connect to the vector database backend.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get_or_create_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> Any:
|
||||
"""Get an existing collection or create it if it doesn't exist.
|
||||
|
||||
This method provides a convenient way to ensure a collection exists
|
||||
without having to check for its existence first.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: The name of the collection to get or create.
|
||||
|
||||
Returns:
|
||||
A collection object whose type depends on the backend implementation.
|
||||
This could be a collection reference, ID, or client object.
|
||||
|
||||
Raises:
|
||||
ValueError: If unable to create the collection.
|
||||
ConnectionError: If unable to connect to the vector database backend.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def aget_or_create_collection(
|
||||
self, **kwargs: Unpack[BaseCollectionParams]
|
||||
) -> Any:
|
||||
"""Get an existing collection or create it if it doesn't exist asynchronously.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: The name of the collection to get or create.
|
||||
|
||||
Returns:
|
||||
A collection object whose type depends on the backend implementation.
|
||||
|
||||
Raises:
|
||||
ValueError: If unable to create the collection.
|
||||
ConnectionError: If unable to connect to the vector database backend.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def add_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
|
||||
"""Add documents with their embeddings to a collection.
|
||||
|
||||
This method performs an upsert operation - if a document with the same ID
|
||||
already exists, it will be updated with the new content and metadata.
|
||||
|
||||
Implementations should handle embedding generation internally based on
|
||||
the configured embedding function.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: The name of the collection to add documents to.
|
||||
documents: List of BaseRecord dicts containing:
|
||||
- content: The text content (required)
|
||||
- doc_id: Optional unique identifier (auto-generated from content hash if missing)
|
||||
- metadata: Optional metadata dictionary
|
||||
Embeddings will be generated automatically.
|
||||
|
||||
Raises:
|
||||
ValueError: If collection doesn't exist or documents list is empty.
|
||||
TypeError: If documents are not BaseRecord dict instances.
|
||||
ConnectionError: If unable to connect to the vector database backend.
|
||||
|
||||
Example:
|
||||
>>> from crewai.rag.chromadb.client import ChromaDBClient
|
||||
>>> from crewai.rag.types import BaseRecord
|
||||
>>> client = ChromaDBClient()
|
||||
>>>
|
||||
>>> records: list[BaseRecord] = [
|
||||
... {
|
||||
... "content": "Machine learning basics",
|
||||
... "metadata": {"source": "file3", "topic": "ML"}
|
||||
... },
|
||||
... {
|
||||
... "doc_id": "custom_id",
|
||||
... "content": "Deep learning fundamentals",
|
||||
... "metadata": {"source": "file4", "topic": "DL"}
|
||||
... }
|
||||
... ]
|
||||
>>> client.add_documents(collection_name="my_docs", documents=records)
|
||||
>>>
|
||||
>>> records_with_id: list[BaseRecord] = [
|
||||
... {
|
||||
... "doc_id": "nlp_001",
|
||||
... "content": "Advanced NLP techniques",
|
||||
... "metadata": {"source": "file5", "topic": "NLP"}
|
||||
... }
|
||||
... ]
|
||||
>>> client.add_documents(collection_name="my_docs", documents=records_with_id)
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
|
||||
"""Add documents with their embeddings to a collection asynchronously.
|
||||
|
||||
Implementations should handle embedding generation internally based on
|
||||
the configured embedding function.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: The name of the collection to add documents to.
|
||||
documents: List of BaseRecord dicts containing:
|
||||
- content: The text content (required)
|
||||
- doc_id: Optional unique identifier (auto-generated from content hash if missing)
|
||||
- metadata: Optional metadata dictionary
|
||||
Embeddings will be generated automatically.
|
||||
|
||||
Raises:
|
||||
ValueError: If collection doesn't exist or documents list is empty.
|
||||
TypeError: If documents are not BaseRecord dict instances.
|
||||
ConnectionError: If unable to connect to the vector database backend.
|
||||
|
||||
Example:
|
||||
>>> import asyncio
|
||||
>>> from crewai.rag.chromadb.client import ChromaDBClient
|
||||
>>> from crewai.rag.types import BaseRecord
|
||||
>>>
|
||||
>>> async def add_documents():
|
||||
... client = ChromaDBClient()
|
||||
...
|
||||
... records: list[BaseRecord] = [
|
||||
... {
|
||||
... "doc_id": "doc2",
|
||||
... "content": "Async operations in Python",
|
||||
... "metadata": {"source": "file2", "topic": "async"}
|
||||
... }
|
||||
... ]
|
||||
... await client.aadd_documents(collection_name="my_docs", documents=records)
|
||||
...
|
||||
>>> asyncio.run(add_documents())
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def search(
|
||||
self, **kwargs: Unpack[BaseCollectionSearchParams]
|
||||
) -> list[SearchResult]:
|
||||
"""Search for similar documents using a query.
|
||||
|
||||
Performs a vector similarity search to find the most similar documents
|
||||
to the provided query.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: The name of the collection to search in.
|
||||
query: The text query to search for. The implementation handles
|
||||
embedding generation internally.
|
||||
limit: Maximum number of results to return. Defaults to 10.
|
||||
metadata_filter: Optional metadata filter to apply to the search. The exact
|
||||
format depends on the backend, but typically supports equality
|
||||
and range queries on metadata fields.
|
||||
score_threshold: Optional minimum similarity score threshold. Only
|
||||
results with scores >= this threshold will be returned. The
|
||||
score interpretation depends on the distance metric used.
|
||||
|
||||
Returns:
|
||||
A list of SearchResult dictionaries ordered by similarity score in
|
||||
descending order. Each result contains:
|
||||
- id: Document ID
|
||||
- content: Document text content
|
||||
- metadata: Document metadata
|
||||
- score: Similarity score (0-1, higher is better)
|
||||
|
||||
Raises:
|
||||
ValueError: If collection doesn't exist.
|
||||
ConnectionError: If unable to connect to the vector database backend.
|
||||
|
||||
Example:
|
||||
>>> from crewai.rag.chromadb.client import ChromaDBClient
|
||||
>>> client = ChromaDBClient()
|
||||
>>>
|
||||
>>> results = client.search(
|
||||
... collection_name="my_docs",
|
||||
... query="What is machine learning?",
|
||||
... limit=5,
|
||||
... metadata_filter={"source": "file1"},
|
||||
... score_threshold=0.7
|
||||
... )
|
||||
>>> for result in results:
|
||||
... print(f"{result['id']}: {result['score']:.2f}")
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def asearch(
|
||||
self, **kwargs: Unpack[BaseCollectionSearchParams]
|
||||
) -> list[SearchResult]:
|
||||
"""Search for similar documents using a query asynchronously.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: The name of the collection to search in.
|
||||
query: The text query to search for. The implementation handles
|
||||
embedding generation internally.
|
||||
limit: Maximum number of results to return. Defaults to 10.
|
||||
metadata_filter: Optional metadata filter to apply to the search.
|
||||
score_threshold: Optional minimum similarity score threshold.
|
||||
|
||||
Returns:
|
||||
A list of SearchResult dictionaries ordered by similarity score.
|
||||
|
||||
Raises:
|
||||
ValueError: If collection doesn't exist.
|
||||
ConnectionError: If unable to connect to the vector database backend.
|
||||
|
||||
Example:
|
||||
>>> import asyncio
|
||||
>>> from crewai.rag.chromadb.client import ChromaDBClient
|
||||
>>>
|
||||
>>> async def search_documents():
|
||||
... client = ChromaDBClient()
|
||||
... results = await client.asearch(
|
||||
... collection_name="my_docs",
|
||||
... query="Python programming best practices",
|
||||
... limit=5,
|
||||
... metadata_filter={"source": "file1"},
|
||||
... score_threshold=0.7
|
||||
... )
|
||||
... for result in results:
|
||||
... print(f"{result['id']}: {result['score']:.2f}")
|
||||
...
|
||||
>>> asyncio.run(search_documents())
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def delete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
|
||||
"""Delete a collection and all its data.
|
||||
|
||||
This operation is irreversible and will permanently remove all documents,
|
||||
embeddings, and metadata associated with the collection.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: The name of the collection to delete.
|
||||
|
||||
Raises:
|
||||
ValueError: If the collection doesn't exist.
|
||||
ConnectionError: If unable to connect to the vector database backend.
|
||||
|
||||
Example:
|
||||
>>> from crewai.rag.chromadb.client import ChromaDBClient
|
||||
>>> client = ChromaDBClient()
|
||||
>>> client.delete_collection(collection_name="old_docs")
|
||||
>>> print("Collection 'old_docs' deleted successfully")
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def adelete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
|
||||
"""Delete a collection and all its data asynchronously.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: The name of the collection to delete.
|
||||
|
||||
Raises:
|
||||
ValueError: If the collection doesn't exist.
|
||||
ConnectionError: If unable to connect to the vector database backend.
|
||||
|
||||
Example:
|
||||
>>> import asyncio
|
||||
>>> from crewai.rag.chromadb.client import ChromaDBClient
|
||||
>>>
|
||||
>>> async def delete_old_collection():
|
||||
... client = ChromaDBClient()
|
||||
... await client.adelete_collection(collection_name="old_docs")
|
||||
... print("Collection 'old_docs' deleted successfully")
|
||||
...
|
||||
>>> asyncio.run(delete_old_collection())
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def reset(self) -> None:
|
||||
"""Reset the vector database by deleting all collections and data.
|
||||
|
||||
This method provides a way to completely clear the vector database,
|
||||
removing all collections and their contents. Use with caution as
|
||||
this operation is irreversible.
|
||||
|
||||
Raises:
|
||||
ConnectionError: If unable to connect to the vector database backend.
|
||||
PermissionError: If the operation is not allowed by the backend.
|
||||
|
||||
Example:
|
||||
>>> from crewai.rag.chromadb.client import ChromaDBClient
|
||||
>>> client = ChromaDBClient()
|
||||
>>> client.reset()
|
||||
>>> print("Vector database completely reset - all data deleted")
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def areset(self) -> None:
|
||||
"""Reset the vector database by deleting all collections and data asynchronously.
|
||||
|
||||
Raises:
|
||||
ConnectionError: If unable to connect to the vector database backend.
|
||||
PermissionError: If the operation is not allowed by the backend.
|
||||
|
||||
Example:
|
||||
>>> import asyncio
|
||||
>>> from crewai.rag.chromadb.client import ChromaDBClient
|
||||
>>>
|
||||
>>> async def reset_database():
|
||||
... client = ChromaDBClient()
|
||||
... await client.areset()
|
||||
... print("Vector database completely reset - all data deleted")
|
||||
...
|
||||
>>> asyncio.run(reset_database())
|
||||
"""
|
||||
...
|
||||
149
lib/core/src/crewai/core/rag/core/base_embeddings_callable.py
Normal file
149
lib/core/src/crewai/core/rag/core/base_embeddings_callable.py
Normal file
@@ -0,0 +1,149 @@
|
||||
"""Base embeddings callable utilities for RAG systems."""
|
||||
|
||||
from typing import Protocol, TypeVar, runtime_checkable
|
||||
|
||||
import numpy as np
|
||||
|
||||
from crewai.rag.core.types import (
|
||||
Embeddable,
|
||||
Embedding,
|
||||
Embeddings,
|
||||
PyEmbedding,
|
||||
)
|
||||
|
||||
T = TypeVar("T")
|
||||
D = TypeVar("D", bound=Embeddable, contravariant=True)
|
||||
|
||||
|
||||
def normalize_embeddings(
|
||||
target: Embedding | list[Embedding] | PyEmbedding | list[PyEmbedding],
|
||||
) -> Embeddings | None:
|
||||
"""Normalize various embedding formats to a standard list of numpy arrays.
|
||||
|
||||
Args:
|
||||
target: Input embeddings in various formats (list of floats, list of lists,
|
||||
numpy array, or list of numpy arrays).
|
||||
|
||||
Returns:
|
||||
Normalized embeddings as a list of numpy arrays, or None if input is None.
|
||||
|
||||
Raises:
|
||||
ValueError: If embeddings are empty or in an unsupported format.
|
||||
"""
|
||||
if isinstance(target, np.ndarray):
|
||||
if target.ndim == 1:
|
||||
return [target.astype(np.float32)]
|
||||
if target.ndim == 2:
|
||||
return [row.astype(np.float32) for row in target]
|
||||
raise ValueError(f"Unsupported numpy array shape: {target.shape}")
|
||||
|
||||
first = target[0]
|
||||
if isinstance(first, (int, float)) and not isinstance(first, bool):
|
||||
return [np.array(target, dtype=np.float32)]
|
||||
if isinstance(first, list):
|
||||
return [np.array(emb, dtype=np.float32) for emb in target]
|
||||
if isinstance(first, np.ndarray):
|
||||
return [emb.astype(np.float32) for emb in target] # type: ignore[union-attr]
|
||||
|
||||
raise ValueError(f"Unsupported embeddings format: {type(first)}")
|
||||
|
||||
|
||||
def maybe_cast_one_to_many(target: T | list[T] | None) -> list[T] | None:
|
||||
"""Cast a single item to a list if needed.
|
||||
|
||||
Args:
|
||||
target: A single item or list of items.
|
||||
|
||||
Returns:
|
||||
A list of items or None if input is None.
|
||||
"""
|
||||
if target is None:
|
||||
return None
|
||||
return target if isinstance(target, list) else [target]
|
||||
|
||||
|
||||
def validate_embeddings(embeddings: Embeddings) -> Embeddings:
|
||||
"""Validate embeddings format and content.
|
||||
|
||||
Args:
|
||||
embeddings: List of numpy arrays to validate.
|
||||
|
||||
Returns:
|
||||
Validated embeddings.
|
||||
|
||||
Raises:
|
||||
ValueError: If embeddings format or content is invalid.
|
||||
"""
|
||||
if not isinstance(embeddings, list):
|
||||
raise ValueError(
|
||||
f"Expected embeddings to be a list, got {type(embeddings).__name__}"
|
||||
)
|
||||
if len(embeddings) == 0:
|
||||
raise ValueError(
|
||||
f"Expected embeddings to be a list with at least one item, got {len(embeddings)} embeddings"
|
||||
)
|
||||
if not all(isinstance(e, np.ndarray) for e in embeddings):
|
||||
raise ValueError(
|
||||
"Expected each embedding in the embeddings to be a numpy array"
|
||||
)
|
||||
for i, embedding in enumerate(embeddings):
|
||||
if embedding.ndim == 0:
|
||||
raise ValueError(
|
||||
f"Expected a 1-dimensional array, got a 0-dimensional array {embedding}"
|
||||
)
|
||||
if embedding.size == 0:
|
||||
raise ValueError(
|
||||
f"Expected each embedding to be a 1-dimensional numpy array with at least 1 value. "
|
||||
f"Got an array with no values at position {i}"
|
||||
)
|
||||
if not all(
|
||||
isinstance(value, (np.integer, float, np.floating))
|
||||
and not isinstance(value, bool)
|
||||
for value in embedding
|
||||
):
|
||||
raise ValueError(
|
||||
f"Expected embedding to contain numeric values, got non-numeric values at position {i}"
|
||||
)
|
||||
return embeddings
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class EmbeddingFunction(Protocol[D]):
|
||||
"""Protocol for embedding functions.
|
||||
|
||||
Embedding functions convert input data (documents or images) into vector embeddings.
|
||||
"""
|
||||
|
||||
def __call__(self, input: D) -> Embeddings:
|
||||
"""Convert input data to embeddings.
|
||||
|
||||
Args:
|
||||
input: Input data to embed (documents or images).
|
||||
|
||||
Returns:
|
||||
List of numpy arrays representing the embeddings.
|
||||
"""
|
||||
...
|
||||
|
||||
def __init_subclass__(cls) -> None:
|
||||
"""Wrap __call__ method to normalize and validate embeddings."""
|
||||
super().__init_subclass__()
|
||||
original_call = cls.__call__
|
||||
|
||||
def wrapped_call(self: EmbeddingFunction[D], input: D) -> Embeddings:
|
||||
result = original_call(self, input)
|
||||
if result is None:
|
||||
raise ValueError("Embedding function returned None")
|
||||
normalized = normalize_embeddings(result)
|
||||
if normalized is None:
|
||||
raise ValueError("Normalization returned None for non-None input")
|
||||
return validate_embeddings(normalized)
|
||||
|
||||
cls.__call__ = wrapped_call # type: ignore[method-assign]
|
||||
|
||||
def embed_query(self, input: D) -> Embeddings:
|
||||
"""
|
||||
Get the embeddings for a query input.
|
||||
This method is optional, and if not implemented, the default behavior is to call __call__.
|
||||
"""
|
||||
return self.__call__(input=input)
|
||||
@@ -0,0 +1,23 @@
|
||||
"""Base class for embedding providers."""
|
||||
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
|
||||
|
||||
T = TypeVar("T", bound=EmbeddingFunction)
|
||||
|
||||
|
||||
class BaseEmbeddingsProvider(BaseSettings, Generic[T]):
|
||||
"""Abstract base class for embedding providers.
|
||||
|
||||
This class provides a common interface for dynamically loading and building
|
||||
embedding functions from various providers.
|
||||
"""
|
||||
|
||||
model_config = SettingsConfigDict(extra="allow", populate_by_name=True)
|
||||
embedding_callable: type[T] = Field(
|
||||
..., description="The embedding function class to use"
|
||||
)
|
||||
26
lib/core/src/crewai/core/rag/core/exceptions.py
Normal file
26
lib/core/src/crewai/core/rag/core/exceptions.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""Core exceptions for RAG module."""
|
||||
|
||||
|
||||
class ClientMethodMismatchError(TypeError):
|
||||
"""Raised when a method is called with the wrong client type.
|
||||
|
||||
Typically used when a sync method is called with an async client,
|
||||
or vice versa.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, method_name: str, expected_client: str, alt_method: str, alt_client: str
|
||||
) -> None:
|
||||
"""Create a ClientMethodMismatchError.
|
||||
|
||||
Args:
|
||||
method_name: Method that was called incorrectly.
|
||||
expected_client: Required client type.
|
||||
alt_method: Suggested alternative method.
|
||||
alt_client: Client type for the alternative method.
|
||||
"""
|
||||
message = (
|
||||
f"Method {method_name}() requires a {expected_client}. "
|
||||
f"Use {alt_method}() for {alt_client}."
|
||||
)
|
||||
super().__init__(message)
|
||||
28
lib/core/src/crewai/core/rag/core/types.py
Normal file
28
lib/core/src/crewai/core/rag/core/types.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""Core type definitions for RAG systems."""
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import TypeVar
|
||||
|
||||
import numpy as np
|
||||
from numpy import floating, integer, number
|
||||
from numpy.typing import NDArray
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
PyEmbedding = Sequence[float] | Sequence[int]
|
||||
PyEmbeddings = list[PyEmbedding]
|
||||
Embedding = NDArray[np.int32 | np.float32]
|
||||
Embeddings = list[Embedding]
|
||||
|
||||
Documents = list[str]
|
||||
Images = list[np.ndarray]
|
||||
Embeddable = Documents | Images
|
||||
|
||||
ScalarType = TypeVar("ScalarType", bound=np.generic)
|
||||
IntegerType = TypeVar("IntegerType", bound=integer)
|
||||
FloatingType = TypeVar("FloatingType", bound=floating)
|
||||
NumberType = TypeVar("NumberType", bound=number)
|
||||
|
||||
DType32 = TypeVar("DType32", np.int32, np.float32)
|
||||
DType64 = TypeVar("DType64", np.int64, np.float64)
|
||||
DTypeCommon = TypeVar("DTypeCommon", np.int32, np.int64, np.float32, np.float64)
|
||||
1
lib/core/src/crewai/core/rag/embeddings/__init__.py
Normal file
1
lib/core/src/crewai/core/rag/embeddings/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Embedding components for RAG infrastructure."""
|
||||
392
lib/core/src/crewai/core/rag/embeddings/factory.py
Normal file
392
lib/core/src/crewai/core/rag/embeddings/factory.py
Normal file
@@ -0,0 +1,392 @@
|
||||
"""Factory functions for creating embedding providers and functions."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import warnings
|
||||
from typing import TYPE_CHECKING, TypeVar, overload
|
||||
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
from crewai.core.utilities.import_utils import import_and_validate_definition
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
|
||||
AmazonBedrockEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.cohere_embedding_function import (
|
||||
CohereEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.google_embedding_function import (
|
||||
GoogleGenerativeAiEmbeddingFunction,
|
||||
GoogleVertexEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
|
||||
HuggingFaceEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.instructor_embedding_function import (
|
||||
InstructorEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.jina_embedding_function import (
|
||||
JinaEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.ollama_embedding_function import (
|
||||
OllamaEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.onnx_mini_lm_l6_v2 import ONNXMiniLM_L6_V2
|
||||
from chromadb.utils.embedding_functions.open_clip_embedding_function import (
|
||||
OpenCLIPEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
||||
OpenAIEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.roboflow_embedding_function import (
|
||||
RoboflowEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.sentence_transformer_embedding_function import (
|
||||
SentenceTransformerEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.text2vec_embedding_function import (
|
||||
Text2VecEmbeddingFunction,
|
||||
)
|
||||
|
||||
from crewai.rag.embeddings.providers.aws.types import BedrockProviderSpec
|
||||
from crewai.rag.embeddings.providers.cohere.types import CohereProviderSpec
|
||||
from crewai.rag.embeddings.providers.custom.types import CustomProviderSpec
|
||||
from crewai.rag.embeddings.providers.google.types import (
|
||||
GenerativeAiProviderSpec,
|
||||
VertexAIProviderSpec,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.huggingface.types import (
|
||||
HuggingFaceProviderSpec,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.ibm.embedding_callable import (
|
||||
WatsonXEmbeddingFunction,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.ibm.types import (
|
||||
WatsonProviderSpec,
|
||||
WatsonXProviderSpec,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.instructor.types import InstructorProviderSpec
|
||||
from crewai.rag.embeddings.providers.jina.types import JinaProviderSpec
|
||||
from crewai.rag.embeddings.providers.microsoft.types import AzureProviderSpec
|
||||
from crewai.rag.embeddings.providers.ollama.types import OllamaProviderSpec
|
||||
from crewai.rag.embeddings.providers.onnx.types import ONNXProviderSpec
|
||||
from crewai.rag.embeddings.providers.openai.types import OpenAIProviderSpec
|
||||
from crewai.rag.embeddings.providers.openclip.types import OpenCLIPProviderSpec
|
||||
from crewai.rag.embeddings.providers.roboflow.types import RoboflowProviderSpec
|
||||
from crewai.rag.embeddings.providers.sentence_transformer.types import (
|
||||
SentenceTransformerProviderSpec,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.text2vec.types import Text2VecProviderSpec
|
||||
from crewai.rag.embeddings.providers.voyageai.embedding_callable import (
|
||||
VoyageAIEmbeddingFunction,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.voyageai.types import VoyageAIProviderSpec
|
||||
|
||||
T = TypeVar("T", bound=EmbeddingFunction)
|
||||
|
||||
|
||||
PROVIDER_PATHS = {
|
||||
"azure": "crewai.rag.embeddings.providers.microsoft.azure.AzureProvider",
|
||||
"amazon-bedrock": "crewai.rag.embeddings.providers.aws.bedrock.BedrockProvider",
|
||||
"cohere": "crewai.rag.embeddings.providers.cohere.cohere_provider.CohereProvider",
|
||||
"custom": "crewai.rag.embeddings.providers.custom.custom_provider.CustomProvider",
|
||||
"google-generativeai": "crewai.rag.embeddings.providers.google.generative_ai.GenerativeAiProvider",
|
||||
"google-vertex": "crewai.rag.embeddings.providers.google.vertex.VertexAIProvider",
|
||||
"huggingface": "crewai.rag.embeddings.providers.huggingface.huggingface_provider.HuggingFaceProvider",
|
||||
"instructor": "crewai.rag.embeddings.providers.instructor.instructor_provider.InstructorProvider",
|
||||
"jina": "crewai.rag.embeddings.providers.jina.jina_provider.JinaProvider",
|
||||
"ollama": "crewai.rag.embeddings.providers.ollama.ollama_provider.OllamaProvider",
|
||||
"onnx": "crewai.rag.embeddings.providers.onnx.onnx_provider.ONNXProvider",
|
||||
"openai": "crewai.rag.embeddings.providers.openai.openai_provider.OpenAIProvider",
|
||||
"openclip": "crewai.rag.embeddings.providers.openclip.openclip_provider.OpenCLIPProvider",
|
||||
"roboflow": "crewai.rag.embeddings.providers.roboflow.roboflow_provider.RoboflowProvider",
|
||||
"sentence-transformer": "crewai.rag.embeddings.providers.sentence_transformer.sentence_transformer_provider.SentenceTransformerProvider",
|
||||
"text2vec": "crewai.rag.embeddings.providers.text2vec.text2vec_provider.Text2VecProvider",
|
||||
"voyageai": "crewai.rag.embeddings.providers.voyageai.voyageai_provider.VoyageAIProvider",
|
||||
"watson": "crewai.rag.embeddings.providers.ibm.watsonx.WatsonXProvider", # Deprecated alias
|
||||
"watsonx": "crewai.rag.embeddings.providers.ibm.watsonx.WatsonXProvider",
|
||||
}
|
||||
|
||||
|
||||
def build_embedder_from_provider(provider: BaseEmbeddingsProvider[T]) -> T:
|
||||
"""Build an embedding function instance from a provider.
|
||||
|
||||
Args:
|
||||
provider: The embedding provider configuration.
|
||||
|
||||
Returns:
|
||||
An instance of the specified embedding function type.
|
||||
"""
|
||||
return provider.embedding_callable(
|
||||
**provider.model_dump(exclude={"embedding_callable"})
|
||||
)
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder_from_dict(spec: AzureProviderSpec) -> OpenAIEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder_from_dict(
|
||||
spec: BedrockProviderSpec,
|
||||
) -> AmazonBedrockEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder_from_dict(spec: CohereProviderSpec) -> CohereEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder_from_dict(spec: CustomProviderSpec) -> EmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder_from_dict(
|
||||
spec: GenerativeAiProviderSpec,
|
||||
) -> GoogleGenerativeAiEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder_from_dict(
|
||||
spec: HuggingFaceProviderSpec,
|
||||
) -> HuggingFaceEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder_from_dict(spec: OllamaProviderSpec) -> OllamaEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder_from_dict(spec: OpenAIProviderSpec) -> OpenAIEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder_from_dict(
|
||||
spec: VertexAIProviderSpec,
|
||||
) -> GoogleVertexEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder_from_dict(
|
||||
spec: VoyageAIProviderSpec,
|
||||
) -> VoyageAIEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder_from_dict(spec: WatsonXProviderSpec) -> WatsonXEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
@deprecated(
|
||||
'The "WatsonProviderSpec" provider spec is deprecated and will be removed in v1.0.0. Use "WatsonXProviderSpec" instead.'
|
||||
)
|
||||
def build_embedder_from_dict(spec: WatsonProviderSpec) -> WatsonXEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder_from_dict(
|
||||
spec: SentenceTransformerProviderSpec,
|
||||
) -> SentenceTransformerEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder_from_dict(
|
||||
spec: InstructorProviderSpec,
|
||||
) -> InstructorEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder_from_dict(spec: JinaProviderSpec) -> JinaEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder_from_dict(
|
||||
spec: RoboflowProviderSpec,
|
||||
) -> RoboflowEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder_from_dict(
|
||||
spec: OpenCLIPProviderSpec,
|
||||
) -> OpenCLIPEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder_from_dict(
|
||||
spec: Text2VecProviderSpec,
|
||||
) -> Text2VecEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder_from_dict(spec: ONNXProviderSpec) -> ONNXMiniLM_L6_V2: ...
|
||||
|
||||
|
||||
def build_embedder_from_dict(spec):
|
||||
"""Build an embedding function instance from a dictionary specification.
|
||||
|
||||
Args:
|
||||
spec: A dictionary with 'provider' and 'config' keys.
|
||||
Example: {
|
||||
"provider": "openai",
|
||||
"config": {
|
||||
"api_key": "sk-...",
|
||||
"model_name": "text-embedding-3-small"
|
||||
}
|
||||
}
|
||||
|
||||
Returns:
|
||||
An instance of the appropriate embedding function.
|
||||
|
||||
Raises:
|
||||
ValueError: If the provider is not recognized.
|
||||
"""
|
||||
provider_name = spec["provider"]
|
||||
if not provider_name:
|
||||
raise ValueError("Missing 'provider' key in specification")
|
||||
|
||||
if provider_name == "watson":
|
||||
warnings.warn(
|
||||
'The "watson" provider key is deprecated and will be removed in v1.0.0. '
|
||||
'Use "watsonx" instead.',
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
if provider_name not in PROVIDER_PATHS:
|
||||
raise ValueError(
|
||||
f"Unknown provider: {provider_name}. Available providers: {list(PROVIDER_PATHS.keys())}"
|
||||
)
|
||||
|
||||
provider_path = PROVIDER_PATHS[provider_name]
|
||||
try:
|
||||
provider_class = import_and_validate_definition(provider_path)
|
||||
except (ImportError, AttributeError, ValueError) as e:
|
||||
raise ImportError(f"Failed to import provider {provider_name}: {e}") from e
|
||||
|
||||
provider_config = spec.get("config", {})
|
||||
|
||||
if provider_name == "custom" and "embedding_callable" not in provider_config:
|
||||
raise ValueError("Custom provider requires 'embedding_callable' in config")
|
||||
|
||||
provider = provider_class(**provider_config)
|
||||
return build_embedder_from_provider(provider)
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder(spec: BaseEmbeddingsProvider[T]) -> T: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder(spec: AzureProviderSpec) -> OpenAIEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder(spec: BedrockProviderSpec) -> AmazonBedrockEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder(spec: CohereProviderSpec) -> CohereEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder(spec: CustomProviderSpec) -> EmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder(
|
||||
spec: GenerativeAiProviderSpec,
|
||||
) -> GoogleGenerativeAiEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder(spec: HuggingFaceProviderSpec) -> HuggingFaceEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder(spec: OllamaProviderSpec) -> OllamaEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder(spec: OpenAIProviderSpec) -> OpenAIEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder(spec: VertexAIProviderSpec) -> GoogleVertexEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder(spec: VoyageAIProviderSpec) -> VoyageAIEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder(spec: WatsonXProviderSpec) -> WatsonXEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
@deprecated(
|
||||
'The "WatsonProviderSpec" provider spec is deprecated and will be removed in v1.0.0. Use "WatsonXProviderSpec" instead.'
|
||||
)
|
||||
def build_embedder(spec: WatsonProviderSpec) -> WatsonXEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder(
|
||||
spec: SentenceTransformerProviderSpec,
|
||||
) -> SentenceTransformerEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder(spec: InstructorProviderSpec) -> InstructorEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder(spec: JinaProviderSpec) -> JinaEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder(spec: RoboflowProviderSpec) -> RoboflowEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder(spec: OpenCLIPProviderSpec) -> OpenCLIPEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder(spec: Text2VecProviderSpec) -> Text2VecEmbeddingFunction: ...
|
||||
|
||||
|
||||
@overload
|
||||
def build_embedder(spec: ONNXProviderSpec) -> ONNXMiniLM_L6_V2: ...
|
||||
|
||||
|
||||
def build_embedder(spec):
|
||||
"""Build an embedding function from either a provider spec or a provider instance.
|
||||
|
||||
Args:
|
||||
spec: Either a provider specification dictionary or a provider instance.
|
||||
|
||||
Returns:
|
||||
An embedding function instance. If a typed provider is passed, returns
|
||||
the specific embedding function type.
|
||||
|
||||
Examples:
|
||||
# From dictionary specification
|
||||
embedder = build_embedder({
|
||||
"provider": "openai",
|
||||
"config": {"api_key": "sk-..."}
|
||||
})
|
||||
|
||||
# From provider instance
|
||||
provider = OpenAIProvider(api_key="sk-...")
|
||||
embedder = build_embedder(provider)
|
||||
"""
|
||||
if isinstance(spec, BaseEmbeddingsProvider):
|
||||
return build_embedder_from_provider(spec)
|
||||
return build_embedder_from_dict(spec)
|
||||
|
||||
|
||||
# Backward compatibility alias
|
||||
get_embedding_function = build_embedder
|
||||
@@ -0,0 +1 @@
|
||||
"""Embedding provider implementations."""
|
||||
@@ -0,0 +1,13 @@
|
||||
"""AWS embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.aws.bedrock import BedrockProvider
|
||||
from crewai.rag.embeddings.providers.aws.types import (
|
||||
BedrockProviderConfig,
|
||||
BedrockProviderSpec,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BedrockProvider",
|
||||
"BedrockProviderConfig",
|
||||
"BedrockProviderSpec",
|
||||
]
|
||||
@@ -0,0 +1,53 @@
|
||||
"""Amazon Bedrock embeddings provider."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
|
||||
AmazonBedrockEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
|
||||
def create_aws_session() -> Any:
|
||||
"""Create an AWS session for Bedrock.
|
||||
|
||||
Returns:
|
||||
boto3.Session: AWS session object
|
||||
|
||||
Raises:
|
||||
ImportError: If boto3 is not installed
|
||||
ValueError: If AWS session creation fails
|
||||
"""
|
||||
try:
|
||||
import boto3 # type: ignore[import]
|
||||
|
||||
return boto3.Session()
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"boto3 is required for amazon-bedrock embeddings. "
|
||||
"Install it with: uv add boto3"
|
||||
) from e
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"Failed to create AWS session for amazon-bedrock. "
|
||||
f"Ensure AWS credentials are configured. Error: {e}"
|
||||
) from e
|
||||
|
||||
|
||||
class BedrockProvider(BaseEmbeddingsProvider[AmazonBedrockEmbeddingFunction]):
|
||||
"""Amazon Bedrock embeddings provider."""
|
||||
|
||||
embedding_callable: type[AmazonBedrockEmbeddingFunction] = Field(
|
||||
default=AmazonBedrockEmbeddingFunction,
|
||||
description="Amazon Bedrock embedding function class",
|
||||
)
|
||||
model_name: str = Field(
|
||||
default="amazon.titan-embed-text-v1",
|
||||
description="Model name to use for embeddings",
|
||||
validation_alias="EMBEDDINGS_BEDROCK_MODEL_NAME",
|
||||
)
|
||||
session: Any = Field(
|
||||
default_factory=create_aws_session, description="AWS session object"
|
||||
)
|
||||
@@ -0,0 +1,19 @@
|
||||
"""Type definitions for AWS embedding providers."""
|
||||
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class BedrockProviderConfig(TypedDict, total=False):
|
||||
"""Configuration for Bedrock provider."""
|
||||
|
||||
model_name: Annotated[str, "amazon.titan-embed-text-v1"]
|
||||
session: Any
|
||||
|
||||
|
||||
class BedrockProviderSpec(TypedDict, total=False):
|
||||
"""Bedrock provider specification."""
|
||||
|
||||
provider: Required[Literal["amazon-bedrock"]]
|
||||
config: BedrockProviderConfig
|
||||
@@ -0,0 +1,13 @@
|
||||
"""Cohere embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.cohere.cohere_provider import CohereProvider
|
||||
from crewai.rag.embeddings.providers.cohere.types import (
|
||||
CohereProviderConfig,
|
||||
CohereProviderSpec,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"CohereProvider",
|
||||
"CohereProviderConfig",
|
||||
"CohereProviderSpec",
|
||||
]
|
||||
@@ -0,0 +1,24 @@
|
||||
"""Cohere embeddings provider."""
|
||||
|
||||
from chromadb.utils.embedding_functions.cohere_embedding_function import (
|
||||
CohereEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
|
||||
class CohereProvider(BaseEmbeddingsProvider[CohereEmbeddingFunction]):
|
||||
"""Cohere embeddings provider."""
|
||||
|
||||
embedding_callable: type[CohereEmbeddingFunction] = Field(
|
||||
default=CohereEmbeddingFunction, description="Cohere embedding function class"
|
||||
)
|
||||
api_key: str = Field(
|
||||
description="Cohere API key", validation_alias="EMBEDDINGS_COHERE_API_KEY"
|
||||
)
|
||||
model_name: str = Field(
|
||||
default="large",
|
||||
description="Model name to use for embeddings",
|
||||
validation_alias="EMBEDDINGS_COHERE_MODEL_NAME",
|
||||
)
|
||||
@@ -0,0 +1,19 @@
|
||||
"""Type definitions for Cohere embedding providers."""
|
||||
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class CohereProviderConfig(TypedDict, total=False):
|
||||
"""Configuration for Cohere provider."""
|
||||
|
||||
api_key: str
|
||||
model_name: Annotated[str, "large"]
|
||||
|
||||
|
||||
class CohereProviderSpec(TypedDict, total=False):
|
||||
"""Cohere provider specification."""
|
||||
|
||||
provider: Required[Literal["cohere"]]
|
||||
config: CohereProviderConfig
|
||||
@@ -0,0 +1,13 @@
|
||||
"""Custom embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.custom.custom_provider import CustomProvider
|
||||
from crewai.rag.embeddings.providers.custom.types import (
|
||||
CustomProviderConfig,
|
||||
CustomProviderSpec,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"CustomProvider",
|
||||
"CustomProviderConfig",
|
||||
"CustomProviderSpec",
|
||||
]
|
||||
@@ -0,0 +1,19 @@
|
||||
"""Custom embeddings provider for user-defined embedding functions."""
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic_settings import SettingsConfigDict
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
from crewai.rag.embeddings.providers.custom.embedding_callable import (
|
||||
CustomEmbeddingFunction,
|
||||
)
|
||||
|
||||
|
||||
class CustomProvider(BaseEmbeddingsProvider[CustomEmbeddingFunction]):
|
||||
"""Custom embeddings provider for user-defined embedding functions."""
|
||||
|
||||
embedding_callable: type[CustomEmbeddingFunction] = Field(
|
||||
..., description="Custom embedding function class"
|
||||
)
|
||||
|
||||
model_config = SettingsConfigDict(extra="allow")
|
||||
@@ -0,0 +1,22 @@
|
||||
"""Custom embedding function base implementation."""
|
||||
|
||||
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
|
||||
from crewai.rag.core.types import Documents, Embeddings
|
||||
|
||||
|
||||
class CustomEmbeddingFunction(EmbeddingFunction[Documents]):
|
||||
"""Base class for custom embedding functions.
|
||||
|
||||
This provides a concrete implementation that can be subclassed for custom embeddings.
|
||||
"""
|
||||
|
||||
def __call__(self, input: Documents) -> Embeddings:
|
||||
"""Convert input documents to embeddings.
|
||||
|
||||
Args:
|
||||
input: List of documents to embed.
|
||||
|
||||
Returns:
|
||||
List of numpy arrays representing the embeddings.
|
||||
"""
|
||||
raise NotImplementedError("Subclasses must implement __call__ method")
|
||||
@@ -0,0 +1,19 @@
|
||||
"""Type definitions for custom embedding providers."""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from chromadb.api.types import EmbeddingFunction
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class CustomProviderConfig(TypedDict, total=False):
|
||||
"""Configuration for Custom provider."""
|
||||
|
||||
embedding_callable: type[EmbeddingFunction]
|
||||
|
||||
|
||||
class CustomProviderSpec(TypedDict, total=False):
|
||||
"""Custom provider specification."""
|
||||
|
||||
provider: Required[Literal["custom"]]
|
||||
config: CustomProviderConfig
|
||||
@@ -0,0 +1,23 @@
|
||||
"""Google embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.google.generative_ai import (
|
||||
GenerativeAiProvider,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.google.types import (
|
||||
GenerativeAiProviderConfig,
|
||||
GenerativeAiProviderSpec,
|
||||
VertexAIProviderConfig,
|
||||
VertexAIProviderSpec,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.google.vertex import (
|
||||
VertexAIProvider,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"GenerativeAiProvider",
|
||||
"GenerativeAiProviderConfig",
|
||||
"GenerativeAiProviderSpec",
|
||||
"VertexAIProvider",
|
||||
"VertexAIProviderConfig",
|
||||
"VertexAIProviderSpec",
|
||||
]
|
||||
@@ -0,0 +1,30 @@
|
||||
"""Google Generative AI embeddings provider."""
|
||||
|
||||
from chromadb.utils.embedding_functions.google_embedding_function import (
|
||||
GoogleGenerativeAiEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
|
||||
class GenerativeAiProvider(BaseEmbeddingsProvider[GoogleGenerativeAiEmbeddingFunction]):
|
||||
"""Google Generative AI embeddings provider."""
|
||||
|
||||
embedding_callable: type[GoogleGenerativeAiEmbeddingFunction] = Field(
|
||||
default=GoogleGenerativeAiEmbeddingFunction,
|
||||
description="Google Generative AI embedding function class",
|
||||
)
|
||||
model_name: str = Field(
|
||||
default="models/embedding-001",
|
||||
description="Model name to use for embeddings",
|
||||
validation_alias="EMBEDDINGS_GOOGLE_GENERATIVE_AI_MODEL_NAME",
|
||||
)
|
||||
api_key: str = Field(
|
||||
description="Google API key", validation_alias="EMBEDDINGS_GOOGLE_API_KEY"
|
||||
)
|
||||
task_type: str = Field(
|
||||
default="RETRIEVAL_DOCUMENT",
|
||||
description="Task type for embeddings",
|
||||
validation_alias="EMBEDDINGS_GOOGLE_GENERATIVE_AI_TASK_TYPE",
|
||||
)
|
||||
@@ -0,0 +1,36 @@
|
||||
"""Type definitions for Google embedding providers."""
|
||||
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class GenerativeAiProviderConfig(TypedDict, total=False):
|
||||
"""Configuration for Google Generative AI provider."""
|
||||
|
||||
api_key: str
|
||||
model_name: Annotated[str, "models/embedding-001"]
|
||||
task_type: Annotated[str, "RETRIEVAL_DOCUMENT"]
|
||||
|
||||
|
||||
class GenerativeAiProviderSpec(TypedDict):
|
||||
"""Google Generative AI provider specification."""
|
||||
|
||||
provider: Literal["google-generativeai"]
|
||||
config: GenerativeAiProviderConfig
|
||||
|
||||
|
||||
class VertexAIProviderConfig(TypedDict, total=False):
|
||||
"""Configuration for Vertex AI provider."""
|
||||
|
||||
api_key: str
|
||||
model_name: Annotated[str, "textembedding-gecko"]
|
||||
project_id: Annotated[str, "cloud-large-language-models"]
|
||||
region: Annotated[str, "us-central1"]
|
||||
|
||||
|
||||
class VertexAIProviderSpec(TypedDict, total=False):
|
||||
"""Vertex AI provider specification."""
|
||||
|
||||
provider: Required[Literal["google-vertex"]]
|
||||
config: VertexAIProviderConfig
|
||||
@@ -0,0 +1,35 @@
|
||||
"""Google Vertex AI embeddings provider."""
|
||||
|
||||
from chromadb.utils.embedding_functions.google_embedding_function import (
|
||||
GoogleVertexEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
|
||||
class VertexAIProvider(BaseEmbeddingsProvider[GoogleVertexEmbeddingFunction]):
|
||||
"""Google Vertex AI embeddings provider."""
|
||||
|
||||
embedding_callable: type[GoogleVertexEmbeddingFunction] = Field(
|
||||
default=GoogleVertexEmbeddingFunction,
|
||||
description="Vertex AI embedding function class",
|
||||
)
|
||||
model_name: str = Field(
|
||||
default="textembedding-gecko",
|
||||
description="Model name to use for embeddings",
|
||||
validation_alias="EMBEDDINGS_GOOGLE_VERTEX_MODEL_NAME",
|
||||
)
|
||||
api_key: str = Field(
|
||||
description="Google API key", validation_alias="EMBEDDINGS_GOOGLE_CLOUD_API_KEY"
|
||||
)
|
||||
project_id: str = Field(
|
||||
default="cloud-large-language-models",
|
||||
description="GCP project ID",
|
||||
validation_alias="EMBEDDINGS_GOOGLE_CLOUD_PROJECT",
|
||||
)
|
||||
region: str = Field(
|
||||
default="us-central1",
|
||||
description="GCP region",
|
||||
validation_alias="EMBEDDINGS_GOOGLE_CLOUD_REGION",
|
||||
)
|
||||
@@ -0,0 +1,15 @@
|
||||
"""HuggingFace embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.huggingface.huggingface_provider import (
|
||||
HuggingFaceProvider,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.huggingface.types import (
|
||||
HuggingFaceProviderConfig,
|
||||
HuggingFaceProviderSpec,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"HuggingFaceProvider",
|
||||
"HuggingFaceProviderConfig",
|
||||
"HuggingFaceProviderSpec",
|
||||
]
|
||||
@@ -0,0 +1,20 @@
|
||||
"""HuggingFace embeddings provider."""
|
||||
|
||||
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
|
||||
HuggingFaceEmbeddingServer,
|
||||
)
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
|
||||
class HuggingFaceProvider(BaseEmbeddingsProvider[HuggingFaceEmbeddingServer]):
|
||||
"""HuggingFace embeddings provider."""
|
||||
|
||||
embedding_callable: type[HuggingFaceEmbeddingServer] = Field(
|
||||
default=HuggingFaceEmbeddingServer,
|
||||
description="HuggingFace embedding function class",
|
||||
)
|
||||
url: str = Field(
|
||||
description="HuggingFace API URL", validation_alias="EMBEDDINGS_HUGGINGFACE_URL"
|
||||
)
|
||||
@@ -0,0 +1,18 @@
|
||||
"""Type definitions for HuggingFace embedding providers."""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class HuggingFaceProviderConfig(TypedDict, total=False):
|
||||
"""Configuration for HuggingFace provider."""
|
||||
|
||||
url: str
|
||||
|
||||
|
||||
class HuggingFaceProviderSpec(TypedDict, total=False):
|
||||
"""HuggingFace provider specification."""
|
||||
|
||||
provider: Required[Literal["huggingface"]]
|
||||
config: HuggingFaceProviderConfig
|
||||
@@ -0,0 +1,17 @@
|
||||
"""IBM embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.ibm.types import (
|
||||
WatsonProviderSpec,
|
||||
WatsonXProviderConfig,
|
||||
WatsonXProviderSpec,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.ibm.watsonx import (
|
||||
WatsonXProvider,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"WatsonProviderSpec",
|
||||
"WatsonXProvider",
|
||||
"WatsonXProviderConfig",
|
||||
"WatsonXProviderSpec",
|
||||
]
|
||||
@@ -0,0 +1,159 @@
|
||||
"""IBM WatsonX embedding function implementation."""
|
||||
|
||||
from typing import cast
|
||||
|
||||
from chromadb.api.types import Documents, EmbeddingFunction, Embeddings
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from crewai.rag.embeddings.providers.ibm.types import WatsonXProviderConfig
|
||||
|
||||
|
||||
class WatsonXEmbeddingFunction(EmbeddingFunction[Documents]):
|
||||
"""Embedding function for IBM WatsonX models."""
|
||||
|
||||
def __init__(self, **kwargs: Unpack[WatsonXProviderConfig]) -> None:
|
||||
"""Initialize WatsonX embedding function.
|
||||
|
||||
Args:
|
||||
**kwargs: Configuration parameters for WatsonX Embeddings and Credentials.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self._config = kwargs
|
||||
|
||||
@staticmethod
|
||||
def name() -> str:
|
||||
"""Return the name of the embedding function for ChromaDB compatibility."""
|
||||
return "watsonx"
|
||||
|
||||
def __call__(self, input: Documents) -> Embeddings:
|
||||
"""Generate embeddings for input documents.
|
||||
|
||||
Args:
|
||||
input: List of documents to embed.
|
||||
|
||||
Returns:
|
||||
List of embedding vectors.
|
||||
"""
|
||||
try:
|
||||
import ibm_watsonx_ai.foundation_models as watson_models # type: ignore[import-not-found, import-untyped]
|
||||
from ibm_watsonx_ai import (
|
||||
Credentials, # type: ignore[import-not-found, import-untyped]
|
||||
)
|
||||
from ibm_watsonx_ai.metanames import ( # type: ignore[import-not-found, import-untyped]
|
||||
EmbedTextParamsMetaNames as EmbedParams,
|
||||
)
|
||||
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"ibm-watsonx-ai is required for watsonx embeddings. "
|
||||
"Install it with: uv add ibm-watsonx-ai"
|
||||
) from e
|
||||
|
||||
if isinstance(input, str):
|
||||
input = [input]
|
||||
|
||||
embeddings_config: dict = {
|
||||
"model_id": self._config["model_id"],
|
||||
}
|
||||
if "params" in self._config and self._config["params"] is not None:
|
||||
embeddings_config["params"] = self._config["params"]
|
||||
if "project_id" in self._config and self._config["project_id"] is not None:
|
||||
embeddings_config["project_id"] = self._config["project_id"]
|
||||
if "space_id" in self._config and self._config["space_id"] is not None:
|
||||
embeddings_config["space_id"] = self._config["space_id"]
|
||||
if "api_client" in self._config and self._config["api_client"] is not None:
|
||||
embeddings_config["api_client"] = self._config["api_client"]
|
||||
if "verify" in self._config and self._config["verify"] is not None:
|
||||
embeddings_config["verify"] = self._config["verify"]
|
||||
if "persistent_connection" in self._config:
|
||||
embeddings_config["persistent_connection"] = self._config[
|
||||
"persistent_connection"
|
||||
]
|
||||
if "batch_size" in self._config:
|
||||
embeddings_config["batch_size"] = self._config["batch_size"]
|
||||
if "concurrency_limit" in self._config:
|
||||
embeddings_config["concurrency_limit"] = self._config["concurrency_limit"]
|
||||
if "max_retries" in self._config and self._config["max_retries"] is not None:
|
||||
embeddings_config["max_retries"] = self._config["max_retries"]
|
||||
if "delay_time" in self._config and self._config["delay_time"] is not None:
|
||||
embeddings_config["delay_time"] = self._config["delay_time"]
|
||||
if (
|
||||
"retry_status_codes" in self._config
|
||||
and self._config["retry_status_codes"] is not None
|
||||
):
|
||||
embeddings_config["retry_status_codes"] = self._config["retry_status_codes"]
|
||||
|
||||
if "credentials" in self._config and self._config["credentials"] is not None:
|
||||
embeddings_config["credentials"] = self._config["credentials"]
|
||||
else:
|
||||
cred_config: dict = {}
|
||||
if "url" in self._config and self._config["url"] is not None:
|
||||
cred_config["url"] = self._config["url"]
|
||||
if "api_key" in self._config and self._config["api_key"] is not None:
|
||||
cred_config["api_key"] = self._config["api_key"]
|
||||
if "name" in self._config and self._config["name"] is not None:
|
||||
cred_config["name"] = self._config["name"]
|
||||
if (
|
||||
"iam_serviceid_crn" in self._config
|
||||
and self._config["iam_serviceid_crn"] is not None
|
||||
):
|
||||
cred_config["iam_serviceid_crn"] = self._config["iam_serviceid_crn"]
|
||||
if (
|
||||
"trusted_profile_id" in self._config
|
||||
and self._config["trusted_profile_id"] is not None
|
||||
):
|
||||
cred_config["trusted_profile_id"] = self._config["trusted_profile_id"]
|
||||
if "token" in self._config and self._config["token"] is not None:
|
||||
cred_config["token"] = self._config["token"]
|
||||
if (
|
||||
"projects_token" in self._config
|
||||
and self._config["projects_token"] is not None
|
||||
):
|
||||
cred_config["projects_token"] = self._config["projects_token"]
|
||||
if "username" in self._config and self._config["username"] is not None:
|
||||
cred_config["username"] = self._config["username"]
|
||||
if "password" in self._config and self._config["password"] is not None:
|
||||
cred_config["password"] = self._config["password"]
|
||||
if (
|
||||
"instance_id" in self._config
|
||||
and self._config["instance_id"] is not None
|
||||
):
|
||||
cred_config["instance_id"] = self._config["instance_id"]
|
||||
if "version" in self._config and self._config["version"] is not None:
|
||||
cred_config["version"] = self._config["version"]
|
||||
if (
|
||||
"bedrock_url" in self._config
|
||||
and self._config["bedrock_url"] is not None
|
||||
):
|
||||
cred_config["bedrock_url"] = self._config["bedrock_url"]
|
||||
if (
|
||||
"platform_url" in self._config
|
||||
and self._config["platform_url"] is not None
|
||||
):
|
||||
cred_config["platform_url"] = self._config["platform_url"]
|
||||
if "proxies" in self._config and self._config["proxies"] is not None:
|
||||
cred_config["proxies"] = self._config["proxies"]
|
||||
if (
|
||||
"verify" not in embeddings_config
|
||||
and "verify" in self._config
|
||||
and self._config["verify"] is not None
|
||||
):
|
||||
cred_config["verify"] = self._config["verify"]
|
||||
|
||||
if cred_config:
|
||||
embeddings_config["credentials"] = Credentials(**cred_config)
|
||||
|
||||
if "params" not in embeddings_config:
|
||||
embeddings_config["params"] = {
|
||||
EmbedParams.TRUNCATE_INPUT_TOKENS: 3,
|
||||
EmbedParams.RETURN_OPTIONS: {"input_text": True},
|
||||
}
|
||||
|
||||
embedding = watson_models.Embeddings(**embeddings_config)
|
||||
|
||||
try:
|
||||
embeddings = embedding.embed_documents(input)
|
||||
return cast(Embeddings, embeddings)
|
||||
except Exception as e:
|
||||
print(f"Error during WatsonX embedding: {e}")
|
||||
raise
|
||||
@@ -0,0 +1,58 @@
|
||||
"""Type definitions for IBM WatsonX embedding providers."""
|
||||
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict, deprecated
|
||||
|
||||
|
||||
class WatsonXProviderConfig(TypedDict, total=False):
|
||||
"""Configuration for WatsonX provider."""
|
||||
|
||||
model_id: str
|
||||
url: str
|
||||
params: dict[str, str | dict[str, str]]
|
||||
credentials: Any
|
||||
project_id: str
|
||||
space_id: str
|
||||
api_client: Any
|
||||
verify: bool | str
|
||||
persistent_connection: Annotated[bool, True]
|
||||
batch_size: Annotated[int, 100]
|
||||
concurrency_limit: Annotated[int, 10]
|
||||
max_retries: int
|
||||
delay_time: float
|
||||
retry_status_codes: list[int]
|
||||
api_key: str
|
||||
name: str
|
||||
iam_serviceid_crn: str
|
||||
trusted_profile_id: str
|
||||
token: str
|
||||
projects_token: str
|
||||
username: str
|
||||
password: str
|
||||
instance_id: str
|
||||
version: str
|
||||
bedrock_url: str
|
||||
platform_url: str
|
||||
proxies: dict
|
||||
|
||||
|
||||
class WatsonXProviderSpec(TypedDict, total=False):
|
||||
"""WatsonX provider specification."""
|
||||
|
||||
provider: Required[Literal["watsonx"]]
|
||||
config: WatsonXProviderConfig
|
||||
|
||||
|
||||
@deprecated(
|
||||
'The "WatsonProviderSpec" provider spec is deprecated and will be removed in v1.0.0. Use "WatsonXProviderSpec" instead.'
|
||||
)
|
||||
class WatsonProviderSpec(TypedDict, total=False):
|
||||
"""Watson provider specification (deprecated).
|
||||
|
||||
Notes:
|
||||
- This is deprecated. Use WatsonXProviderSpec with provider="watsonx" instead.
|
||||
"""
|
||||
|
||||
provider: Required[Literal["watson"]]
|
||||
config: WatsonXProviderConfig
|
||||
142
lib/core/src/crewai/core/rag/embeddings/providers/ibm/watsonx.py
Normal file
142
lib/core/src/crewai/core/rag/embeddings/providers/ibm/watsonx.py
Normal file
@@ -0,0 +1,142 @@
|
||||
"""IBM WatsonX embeddings provider."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
from crewai.rag.embeddings.providers.ibm.embedding_callable import (
|
||||
WatsonXEmbeddingFunction,
|
||||
)
|
||||
|
||||
|
||||
class WatsonXProvider(BaseEmbeddingsProvider[WatsonXEmbeddingFunction]):
|
||||
"""IBM WatsonX embeddings provider.
|
||||
|
||||
Note: Requires custom implementation as WatsonX uses a different interface.
|
||||
"""
|
||||
|
||||
embedding_callable: type[WatsonXEmbeddingFunction] = Field(
|
||||
default=WatsonXEmbeddingFunction, description="WatsonX embedding function class"
|
||||
)
|
||||
model_id: str = Field(
|
||||
description="WatsonX model ID", validation_alias="EMBEDDINGS_WATSONX_MODEL_ID"
|
||||
)
|
||||
params: dict[str, str | dict[str, str]] | None = Field(
|
||||
default=None, description="Additional parameters"
|
||||
)
|
||||
credentials: Any | None = Field(default=None, description="WatsonX credentials")
|
||||
project_id: str | None = Field(
|
||||
default=None,
|
||||
description="WatsonX project ID",
|
||||
validation_alias="EMBEDDINGS_WATSONX_PROJECT_ID",
|
||||
)
|
||||
space_id: str | None = Field(
|
||||
default=None,
|
||||
description="WatsonX space ID",
|
||||
validation_alias="EMBEDDINGS_WATSONX_SPACE_ID",
|
||||
)
|
||||
api_client: Any | None = Field(default=None, description="WatsonX API client")
|
||||
verify: bool | str | None = Field(
|
||||
default=None,
|
||||
description="SSL verification",
|
||||
validation_alias="EMBEDDINGS_WATSONX_VERIFY",
|
||||
)
|
||||
persistent_connection: bool = Field(
|
||||
default=True,
|
||||
description="Use persistent connection",
|
||||
validation_alias="EMBEDDINGS_WATSONX_PERSISTENT_CONNECTION",
|
||||
)
|
||||
batch_size: int = Field(
|
||||
default=100,
|
||||
description="Batch size for processing",
|
||||
validation_alias="EMBEDDINGS_WATSONX_BATCH_SIZE",
|
||||
)
|
||||
concurrency_limit: int = Field(
|
||||
default=10,
|
||||
description="Concurrency limit",
|
||||
validation_alias="EMBEDDINGS_WATSONX_CONCURRENCY_LIMIT",
|
||||
)
|
||||
max_retries: int | None = Field(
|
||||
default=None,
|
||||
description="Maximum retries",
|
||||
validation_alias="EMBEDDINGS_WATSONX_MAX_RETRIES",
|
||||
)
|
||||
delay_time: float | None = Field(
|
||||
default=None,
|
||||
description="Delay time between retries",
|
||||
validation_alias="EMBEDDINGS_WATSONX_DELAY_TIME",
|
||||
)
|
||||
retry_status_codes: list[int] | None = Field(
|
||||
default=None, description="HTTP status codes to retry on"
|
||||
)
|
||||
url: str = Field(
|
||||
description="WatsonX API URL", validation_alias="EMBEDDINGS_WATSONX_URL"
|
||||
)
|
||||
api_key: str = Field(
|
||||
description="WatsonX API key", validation_alias="EMBEDDINGS_WATSONX_API_KEY"
|
||||
)
|
||||
name: str | None = Field(
|
||||
default=None,
|
||||
description="Service name",
|
||||
validation_alias="EMBEDDINGS_WATSONX_NAME",
|
||||
)
|
||||
iam_serviceid_crn: str | None = Field(
|
||||
default=None,
|
||||
description="IAM service ID CRN",
|
||||
validation_alias="EMBEDDINGS_WATSONX_IAM_SERVICEID_CRN",
|
||||
)
|
||||
trusted_profile_id: str | None = Field(
|
||||
default=None,
|
||||
description="Trusted profile ID",
|
||||
validation_alias="EMBEDDINGS_WATSONX_TRUSTED_PROFILE_ID",
|
||||
)
|
||||
token: str | None = Field(
|
||||
default=None,
|
||||
description="Bearer token",
|
||||
validation_alias="EMBEDDINGS_WATSONX_TOKEN",
|
||||
)
|
||||
projects_token: str | None = Field(
|
||||
default=None,
|
||||
description="Projects token",
|
||||
validation_alias="EMBEDDINGS_WATSONX_PROJECTS_TOKEN",
|
||||
)
|
||||
username: str | None = Field(
|
||||
default=None,
|
||||
description="Username",
|
||||
validation_alias="EMBEDDINGS_WATSONX_USERNAME",
|
||||
)
|
||||
password: str | None = Field(
|
||||
default=None,
|
||||
description="Password",
|
||||
validation_alias="EMBEDDINGS_WATSONX_PASSWORD",
|
||||
)
|
||||
instance_id: str | None = Field(
|
||||
default=None,
|
||||
description="Service instance ID",
|
||||
validation_alias="EMBEDDINGS_WATSONX_INSTANCE_ID",
|
||||
)
|
||||
version: str | None = Field(
|
||||
default=None,
|
||||
description="API version",
|
||||
validation_alias="EMBEDDINGS_WATSONX_VERSION",
|
||||
)
|
||||
bedrock_url: str | None = Field(
|
||||
default=None,
|
||||
description="Bedrock URL",
|
||||
validation_alias="EMBEDDINGS_WATSONX_BEDROCK_URL",
|
||||
)
|
||||
platform_url: str | None = Field(
|
||||
default=None,
|
||||
description="Platform URL",
|
||||
validation_alias="EMBEDDINGS_WATSONX_PLATFORM_URL",
|
||||
)
|
||||
proxies: dict | None = Field(default=None, description="Proxy configuration")
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_space_or_project(self) -> Self:
|
||||
"""Validate that either space_id or project_id is provided."""
|
||||
if not self.space_id and not self.project_id:
|
||||
raise ValueError("One of 'space_id' or 'project_id' must be provided")
|
||||
return self
|
||||
@@ -0,0 +1,15 @@
|
||||
"""Instructor embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.instructor.instructor_provider import (
|
||||
InstructorProvider,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.instructor.types import (
|
||||
InstructorProviderConfig,
|
||||
InstructorProviderSpec,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"InstructorProvider",
|
||||
"InstructorProviderConfig",
|
||||
"InstructorProviderSpec",
|
||||
]
|
||||
@@ -0,0 +1,32 @@
|
||||
"""Instructor embeddings provider."""
|
||||
|
||||
from chromadb.utils.embedding_functions.instructor_embedding_function import (
|
||||
InstructorEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
|
||||
class InstructorProvider(BaseEmbeddingsProvider[InstructorEmbeddingFunction]):
|
||||
"""Instructor embeddings provider."""
|
||||
|
||||
embedding_callable: type[InstructorEmbeddingFunction] = Field(
|
||||
default=InstructorEmbeddingFunction,
|
||||
description="Instructor embedding function class",
|
||||
)
|
||||
model_name: str = Field(
|
||||
default="hkunlp/instructor-base",
|
||||
description="Model name to use",
|
||||
validation_alias="EMBEDDINGS_INSTRUCTOR_MODEL_NAME",
|
||||
)
|
||||
device: str = Field(
|
||||
default="cpu",
|
||||
description="Device to run model on (cpu or cuda)",
|
||||
validation_alias="EMBEDDINGS_INSTRUCTOR_DEVICE",
|
||||
)
|
||||
instruction: str | None = Field(
|
||||
default=None,
|
||||
description="Instruction for embeddings",
|
||||
validation_alias="EMBEDDINGS_INSTRUCTOR_INSTRUCTION",
|
||||
)
|
||||
@@ -0,0 +1,20 @@
|
||||
"""Type definitions for Instructor embedding providers."""
|
||||
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class InstructorProviderConfig(TypedDict, total=False):
|
||||
"""Configuration for Instructor provider."""
|
||||
|
||||
model_name: Annotated[str, "hkunlp/instructor-base"]
|
||||
device: Annotated[str, "cpu"]
|
||||
instruction: str
|
||||
|
||||
|
||||
class InstructorProviderSpec(TypedDict, total=False):
|
||||
"""Instructor provider specification."""
|
||||
|
||||
provider: Required[Literal["instructor"]]
|
||||
config: InstructorProviderConfig
|
||||
@@ -0,0 +1,13 @@
|
||||
"""Jina embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.jina.jina_provider import JinaProvider
|
||||
from crewai.rag.embeddings.providers.jina.types import (
|
||||
JinaProviderConfig,
|
||||
JinaProviderSpec,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"JinaProvider",
|
||||
"JinaProviderConfig",
|
||||
"JinaProviderSpec",
|
||||
]
|
||||
@@ -0,0 +1,24 @@
|
||||
"""Jina embeddings provider."""
|
||||
|
||||
from chromadb.utils.embedding_functions.jina_embedding_function import (
|
||||
JinaEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
|
||||
class JinaProvider(BaseEmbeddingsProvider[JinaEmbeddingFunction]):
|
||||
"""Jina embeddings provider."""
|
||||
|
||||
embedding_callable: type[JinaEmbeddingFunction] = Field(
|
||||
default=JinaEmbeddingFunction, description="Jina embedding function class"
|
||||
)
|
||||
api_key: str = Field(
|
||||
description="Jina API key", validation_alias="EMBEDDINGS_JINA_API_KEY"
|
||||
)
|
||||
model_name: str = Field(
|
||||
default="jina-embeddings-v2-base-en",
|
||||
description="Model name to use for embeddings",
|
||||
validation_alias="EMBEDDINGS_JINA_MODEL_NAME",
|
||||
)
|
||||
@@ -0,0 +1,19 @@
|
||||
"""Type definitions for Jina embedding providers."""
|
||||
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class JinaProviderConfig(TypedDict, total=False):
|
||||
"""Configuration for Jina provider."""
|
||||
|
||||
api_key: str
|
||||
model_name: Annotated[str, "jina-embeddings-v2-base-en"]
|
||||
|
||||
|
||||
class JinaProviderSpec(TypedDict, total=False):
|
||||
"""Jina provider specification."""
|
||||
|
||||
provider: Required[Literal["jina"]]
|
||||
config: JinaProviderConfig
|
||||
@@ -0,0 +1,15 @@
|
||||
"""Microsoft embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.microsoft.azure import (
|
||||
AzureProvider,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.microsoft.types import (
|
||||
AzureProviderConfig,
|
||||
AzureProviderSpec,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AzureProvider",
|
||||
"AzureProviderConfig",
|
||||
"AzureProviderSpec",
|
||||
]
|
||||
@@ -0,0 +1,60 @@
|
||||
"""Azure OpenAI embeddings provider."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
||||
OpenAIEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
|
||||
class AzureProvider(BaseEmbeddingsProvider[OpenAIEmbeddingFunction]):
|
||||
"""Azure OpenAI embeddings provider."""
|
||||
|
||||
embedding_callable: type[OpenAIEmbeddingFunction] = Field(
|
||||
default=OpenAIEmbeddingFunction,
|
||||
description="Azure OpenAI embedding function class",
|
||||
)
|
||||
api_key: str = Field(
|
||||
description="Azure API key", validation_alias="EMBEDDINGS_OPENAI_API_KEY"
|
||||
)
|
||||
api_base: str | None = Field(
|
||||
default=None,
|
||||
description="Azure endpoint URL",
|
||||
validation_alias="EMBEDDINGS_OPENAI_API_BASE",
|
||||
)
|
||||
api_type: str = Field(
|
||||
default="azure",
|
||||
description="API type for Azure",
|
||||
validation_alias="EMBEDDINGS_OPENAI_API_TYPE",
|
||||
)
|
||||
api_version: str | None = Field(
|
||||
default=None,
|
||||
description="Azure API version",
|
||||
validation_alias="EMBEDDINGS_OPENAI_API_VERSION",
|
||||
)
|
||||
model_name: str = Field(
|
||||
default="text-embedding-ada-002",
|
||||
description="Model name to use for embeddings",
|
||||
validation_alias="EMBEDDINGS_OPENAI_MODEL_NAME",
|
||||
)
|
||||
default_headers: dict[str, Any] | None = Field(
|
||||
default=None, description="Default headers for API requests"
|
||||
)
|
||||
dimensions: int | None = Field(
|
||||
default=None,
|
||||
description="Embedding dimensions",
|
||||
validation_alias="EMBEDDINGS_OPENAI_DIMENSIONS",
|
||||
)
|
||||
deployment_id: str | None = Field(
|
||||
default=None,
|
||||
description="Azure deployment ID",
|
||||
validation_alias="EMBEDDINGS_OPENAI_DEPLOYMENT_ID",
|
||||
)
|
||||
organization_id: str | None = Field(
|
||||
default=None,
|
||||
description="Organization ID",
|
||||
validation_alias="EMBEDDINGS_OPENAI_ORGANIZATION_ID",
|
||||
)
|
||||
@@ -0,0 +1,26 @@
|
||||
"""Type definitions for Microsoft Azure embedding providers."""
|
||||
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class AzureProviderConfig(TypedDict, total=False):
|
||||
"""Configuration for Azure provider."""
|
||||
|
||||
api_key: str
|
||||
api_base: str
|
||||
api_type: Annotated[str, "azure"]
|
||||
api_version: str
|
||||
model_name: Annotated[str, "text-embedding-ada-002"]
|
||||
default_headers: dict[str, Any]
|
||||
dimensions: int
|
||||
deployment_id: str
|
||||
organization_id: str
|
||||
|
||||
|
||||
class AzureProviderSpec(TypedDict, total=False):
|
||||
"""Azure provider specification."""
|
||||
|
||||
provider: Required[Literal["azure"]]
|
||||
config: AzureProviderConfig
|
||||
@@ -0,0 +1,15 @@
|
||||
"""Ollama embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.ollama.ollama_provider import (
|
||||
OllamaProvider,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.ollama.types import (
|
||||
OllamaProviderConfig,
|
||||
OllamaProviderSpec,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"OllamaProvider",
|
||||
"OllamaProviderConfig",
|
||||
"OllamaProviderSpec",
|
||||
]
|
||||
@@ -0,0 +1,25 @@
|
||||
"""Ollama embeddings provider."""
|
||||
|
||||
from chromadb.utils.embedding_functions.ollama_embedding_function import (
|
||||
OllamaEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
|
||||
class OllamaProvider(BaseEmbeddingsProvider[OllamaEmbeddingFunction]):
|
||||
"""Ollama embeddings provider."""
|
||||
|
||||
embedding_callable: type[OllamaEmbeddingFunction] = Field(
|
||||
default=OllamaEmbeddingFunction, description="Ollama embedding function class"
|
||||
)
|
||||
url: str = Field(
|
||||
default="http://localhost:11434/api/embeddings",
|
||||
description="Ollama API endpoint URL",
|
||||
validation_alias="EMBEDDINGS_OLLAMA_URL",
|
||||
)
|
||||
model_name: str = Field(
|
||||
description="Model name to use for embeddings",
|
||||
validation_alias="EMBEDDINGS_OLLAMA_MODEL_NAME",
|
||||
)
|
||||
@@ -0,0 +1,19 @@
|
||||
"""Type definitions for Ollama embedding providers."""
|
||||
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class OllamaProviderConfig(TypedDict, total=False):
|
||||
"""Configuration for Ollama provider."""
|
||||
|
||||
url: Annotated[str, "http://localhost:11434/api/embeddings"]
|
||||
model_name: str
|
||||
|
||||
|
||||
class OllamaProviderSpec(TypedDict, total=False):
|
||||
"""Ollama provider specification."""
|
||||
|
||||
provider: Required[Literal["ollama"]]
|
||||
config: OllamaProviderConfig
|
||||
@@ -0,0 +1,13 @@
|
||||
"""ONNX embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.onnx.onnx_provider import ONNXProvider
|
||||
from crewai.rag.embeddings.providers.onnx.types import (
|
||||
ONNXProviderConfig,
|
||||
ONNXProviderSpec,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ONNXProvider",
|
||||
"ONNXProviderConfig",
|
||||
"ONNXProviderSpec",
|
||||
]
|
||||
@@ -0,0 +1,19 @@
|
||||
"""ONNX embeddings provider."""
|
||||
|
||||
from chromadb.utils.embedding_functions.onnx_mini_lm_l6_v2 import ONNXMiniLM_L6_V2
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
|
||||
class ONNXProvider(BaseEmbeddingsProvider[ONNXMiniLM_L6_V2]):
|
||||
"""ONNX embeddings provider."""
|
||||
|
||||
embedding_callable: type[ONNXMiniLM_L6_V2] = Field(
|
||||
default=ONNXMiniLM_L6_V2, description="ONNX MiniLM embedding function class"
|
||||
)
|
||||
preferred_providers: list[str] | None = Field(
|
||||
default=None,
|
||||
description="Preferred ONNX execution providers",
|
||||
validation_alias="EMBEDDINGS_ONNX_PREFERRED_PROVIDERS",
|
||||
)
|
||||
@@ -0,0 +1,18 @@
|
||||
"""Type definitions for ONNX embedding providers."""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class ONNXProviderConfig(TypedDict, total=False):
|
||||
"""Configuration for ONNX provider."""
|
||||
|
||||
preferred_providers: list[str]
|
||||
|
||||
|
||||
class ONNXProviderSpec(TypedDict, total=False):
|
||||
"""ONNX provider specification."""
|
||||
|
||||
provider: Required[Literal["onnx"]]
|
||||
config: ONNXProviderConfig
|
||||
@@ -0,0 +1,15 @@
|
||||
"""OpenAI embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.openai.openai_provider import (
|
||||
OpenAIProvider,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.openai.types import (
|
||||
OpenAIProviderConfig,
|
||||
OpenAIProviderSpec,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"OpenAIProvider",
|
||||
"OpenAIProviderConfig",
|
||||
"OpenAIProviderSpec",
|
||||
]
|
||||
@@ -0,0 +1,62 @@
|
||||
"""OpenAI embeddings provider."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
||||
OpenAIEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
|
||||
class OpenAIProvider(BaseEmbeddingsProvider[OpenAIEmbeddingFunction]):
|
||||
"""OpenAI embeddings provider."""
|
||||
|
||||
embedding_callable: type[OpenAIEmbeddingFunction] = Field(
|
||||
default=OpenAIEmbeddingFunction,
|
||||
description="OpenAI embedding function class",
|
||||
)
|
||||
api_key: str | None = Field(
|
||||
default=None,
|
||||
description="OpenAI API key",
|
||||
validation_alias="EMBEDDINGS_OPENAI_API_KEY",
|
||||
)
|
||||
model_name: str = Field(
|
||||
default="text-embedding-ada-002",
|
||||
description="Model name to use for embeddings",
|
||||
validation_alias="EMBEDDINGS_OPENAI_MODEL_NAME",
|
||||
)
|
||||
api_base: str | None = Field(
|
||||
default=None,
|
||||
description="Base URL for API requests",
|
||||
validation_alias="EMBEDDINGS_OPENAI_API_BASE",
|
||||
)
|
||||
api_type: str | None = Field(
|
||||
default=None,
|
||||
description="API type (e.g., 'azure')",
|
||||
validation_alias="EMBEDDINGS_OPENAI_API_TYPE",
|
||||
)
|
||||
api_version: str | None = Field(
|
||||
default=None,
|
||||
description="API version",
|
||||
validation_alias="EMBEDDINGS_OPENAI_API_VERSION",
|
||||
)
|
||||
default_headers: dict[str, Any] | None = Field(
|
||||
default=None, description="Default headers for API requests"
|
||||
)
|
||||
dimensions: int | None = Field(
|
||||
default=None,
|
||||
description="Embedding dimensions",
|
||||
validation_alias="EMBEDDINGS_OPENAI_DIMENSIONS",
|
||||
)
|
||||
deployment_id: str | None = Field(
|
||||
default=None,
|
||||
description="Azure deployment ID",
|
||||
validation_alias="EMBEDDINGS_OPENAI_DEPLOYMENT_ID",
|
||||
)
|
||||
organization_id: str | None = Field(
|
||||
default=None,
|
||||
description="OpenAI organization ID",
|
||||
validation_alias="EMBEDDINGS_OPENAI_ORGANIZATION_ID",
|
||||
)
|
||||
@@ -0,0 +1,26 @@
|
||||
"""Type definitions for OpenAI embedding providers."""
|
||||
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class OpenAIProviderConfig(TypedDict, total=False):
|
||||
"""Configuration for OpenAI provider."""
|
||||
|
||||
api_key: str
|
||||
model_name: Annotated[str, "text-embedding-ada-002"]
|
||||
api_base: str
|
||||
api_type: str
|
||||
api_version: str
|
||||
default_headers: dict[str, Any]
|
||||
dimensions: int
|
||||
deployment_id: str
|
||||
organization_id: str
|
||||
|
||||
|
||||
class OpenAIProviderSpec(TypedDict, total=False):
|
||||
"""OpenAI provider specification."""
|
||||
|
||||
provider: Required[Literal["openai"]]
|
||||
config: OpenAIProviderConfig
|
||||
@@ -0,0 +1,15 @@
|
||||
"""OpenCLIP embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.openclip.openclip_provider import (
|
||||
OpenCLIPProvider,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.openclip.types import (
|
||||
OpenCLIPProviderConfig,
|
||||
OpenCLIPProviderSpec,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"OpenCLIPProvider",
|
||||
"OpenCLIPProviderConfig",
|
||||
"OpenCLIPProviderSpec",
|
||||
]
|
||||
@@ -0,0 +1,32 @@
|
||||
"""OpenCLIP embeddings provider."""
|
||||
|
||||
from chromadb.utils.embedding_functions.open_clip_embedding_function import (
|
||||
OpenCLIPEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
|
||||
class OpenCLIPProvider(BaseEmbeddingsProvider[OpenCLIPEmbeddingFunction]):
|
||||
"""OpenCLIP embeddings provider."""
|
||||
|
||||
embedding_callable: type[OpenCLIPEmbeddingFunction] = Field(
|
||||
default=OpenCLIPEmbeddingFunction,
|
||||
description="OpenCLIP embedding function class",
|
||||
)
|
||||
model_name: str = Field(
|
||||
default="ViT-B-32",
|
||||
description="Model name to use",
|
||||
validation_alias="EMBEDDINGS_OPENCLIP_MODEL_NAME",
|
||||
)
|
||||
checkpoint: str = Field(
|
||||
default="laion2b_s34b_b79k",
|
||||
description="Model checkpoint",
|
||||
validation_alias="EMBEDDINGS_OPENCLIP_CHECKPOINT",
|
||||
)
|
||||
device: str | None = Field(
|
||||
default="cpu",
|
||||
description="Device to run model on",
|
||||
validation_alias="EMBEDDINGS_OPENCLIP_DEVICE",
|
||||
)
|
||||
@@ -0,0 +1,20 @@
|
||||
"""Type definitions for OpenCLIP embedding providers."""
|
||||
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class OpenCLIPProviderConfig(TypedDict, total=False):
|
||||
"""Configuration for OpenCLIP provider."""
|
||||
|
||||
model_name: Annotated[str, "ViT-B-32"]
|
||||
checkpoint: Annotated[str, "laion2b_s34b_b79k"]
|
||||
device: Annotated[str, "cpu"]
|
||||
|
||||
|
||||
class OpenCLIPProviderSpec(TypedDict):
|
||||
"""OpenCLIP provider specification."""
|
||||
|
||||
provider: Required[Literal["openclip"]]
|
||||
config: OpenCLIPProviderConfig
|
||||
@@ -0,0 +1,15 @@
|
||||
"""Roboflow embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.roboflow.roboflow_provider import (
|
||||
RoboflowProvider,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.roboflow.types import (
|
||||
RoboflowProviderConfig,
|
||||
RoboflowProviderSpec,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"RoboflowProvider",
|
||||
"RoboflowProviderConfig",
|
||||
"RoboflowProviderSpec",
|
||||
]
|
||||
@@ -0,0 +1,27 @@
|
||||
"""Roboflow embeddings provider."""
|
||||
|
||||
from chromadb.utils.embedding_functions.roboflow_embedding_function import (
|
||||
RoboflowEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
|
||||
class RoboflowProvider(BaseEmbeddingsProvider[RoboflowEmbeddingFunction]):
|
||||
"""Roboflow embeddings provider."""
|
||||
|
||||
embedding_callable: type[RoboflowEmbeddingFunction] = Field(
|
||||
default=RoboflowEmbeddingFunction,
|
||||
description="Roboflow embedding function class",
|
||||
)
|
||||
api_key: str = Field(
|
||||
default="",
|
||||
description="Roboflow API key",
|
||||
validation_alias="EMBEDDINGS_ROBOFLOW_API_KEY",
|
||||
)
|
||||
api_url: str = Field(
|
||||
default="https://infer.roboflow.com",
|
||||
description="Roboflow API URL",
|
||||
validation_alias="EMBEDDINGS_ROBOFLOW_API_URL",
|
||||
)
|
||||
@@ -0,0 +1,19 @@
|
||||
"""Type definitions for Roboflow embedding providers."""
|
||||
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class RoboflowProviderConfig(TypedDict, total=False):
|
||||
"""Configuration for Roboflow provider."""
|
||||
|
||||
api_key: Annotated[str, ""]
|
||||
api_url: Annotated[str, "https://infer.roboflow.com"]
|
||||
|
||||
|
||||
class RoboflowProviderSpec(TypedDict):
|
||||
"""Roboflow provider specification."""
|
||||
|
||||
provider: Required[Literal["roboflow"]]
|
||||
config: RoboflowProviderConfig
|
||||
@@ -0,0 +1,15 @@
|
||||
"""SentenceTransformer embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.sentence_transformer.sentence_transformer_provider import (
|
||||
SentenceTransformerProvider,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.sentence_transformer.types import (
|
||||
SentenceTransformerProviderConfig,
|
||||
SentenceTransformerProviderSpec,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"SentenceTransformerProvider",
|
||||
"SentenceTransformerProviderConfig",
|
||||
"SentenceTransformerProviderSpec",
|
||||
]
|
||||
@@ -0,0 +1,34 @@
|
||||
"""SentenceTransformer embeddings provider."""
|
||||
|
||||
from chromadb.utils.embedding_functions.sentence_transformer_embedding_function import (
|
||||
SentenceTransformerEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
|
||||
class SentenceTransformerProvider(
|
||||
BaseEmbeddingsProvider[SentenceTransformerEmbeddingFunction]
|
||||
):
|
||||
"""SentenceTransformer embeddings provider."""
|
||||
|
||||
embedding_callable: type[SentenceTransformerEmbeddingFunction] = Field(
|
||||
default=SentenceTransformerEmbeddingFunction,
|
||||
description="SentenceTransformer embedding function class",
|
||||
)
|
||||
model_name: str = Field(
|
||||
default="all-MiniLM-L6-v2",
|
||||
description="Model name to use",
|
||||
validation_alias="EMBEDDINGS_SENTENCE_TRANSFORMER_MODEL_NAME",
|
||||
)
|
||||
device: str = Field(
|
||||
default="cpu",
|
||||
description="Device to run model on (cpu or cuda)",
|
||||
validation_alias="EMBEDDINGS_SENTENCE_TRANSFORMER_DEVICE",
|
||||
)
|
||||
normalize_embeddings: bool = Field(
|
||||
default=False,
|
||||
description="Whether to normalize embeddings",
|
||||
validation_alias="EMBEDDINGS_SENTENCE_TRANSFORMER_NORMALIZE_EMBEDDINGS",
|
||||
)
|
||||
@@ -0,0 +1,20 @@
|
||||
"""Type definitions for SentenceTransformer embedding providers."""
|
||||
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class SentenceTransformerProviderConfig(TypedDict, total=False):
|
||||
"""Configuration for SentenceTransformer provider."""
|
||||
|
||||
model_name: Annotated[str, "all-MiniLM-L6-v2"]
|
||||
device: Annotated[str, "cpu"]
|
||||
normalize_embeddings: Annotated[bool, False]
|
||||
|
||||
|
||||
class SentenceTransformerProviderSpec(TypedDict):
|
||||
"""SentenceTransformer provider specification."""
|
||||
|
||||
provider: Required[Literal["sentence-transformer"]]
|
||||
config: SentenceTransformerProviderConfig
|
||||
@@ -0,0 +1,15 @@
|
||||
"""Text2Vec embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.text2vec.text2vec_provider import (
|
||||
Text2VecProvider,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.text2vec.types import (
|
||||
Text2VecProviderConfig,
|
||||
Text2VecProviderSpec,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Text2VecProvider",
|
||||
"Text2VecProviderConfig",
|
||||
"Text2VecProviderSpec",
|
||||
]
|
||||
@@ -0,0 +1,22 @@
|
||||
"""Text2Vec embeddings provider."""
|
||||
|
||||
from chromadb.utils.embedding_functions.text2vec_embedding_function import (
|
||||
Text2VecEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
|
||||
class Text2VecProvider(BaseEmbeddingsProvider[Text2VecEmbeddingFunction]):
|
||||
"""Text2Vec embeddings provider."""
|
||||
|
||||
embedding_callable: type[Text2VecEmbeddingFunction] = Field(
|
||||
default=Text2VecEmbeddingFunction,
|
||||
description="Text2Vec embedding function class",
|
||||
)
|
||||
model_name: str = Field(
|
||||
default="shibing624/text2vec-base-chinese",
|
||||
description="Model name to use",
|
||||
validation_alias="EMBEDDINGS_TEXT2VEC_MODEL_NAME",
|
||||
)
|
||||
@@ -0,0 +1,18 @@
|
||||
"""Type definitions for Text2Vec embedding providers."""
|
||||
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class Text2VecProviderConfig(TypedDict, total=False):
|
||||
"""Configuration for Text2Vec provider."""
|
||||
|
||||
model_name: Annotated[str, "shibing624/text2vec-base-chinese"]
|
||||
|
||||
|
||||
class Text2VecProviderSpec(TypedDict):
|
||||
"""Text2Vec provider specification."""
|
||||
|
||||
provider: Required[Literal["text2vec"]]
|
||||
config: Text2VecProviderConfig
|
||||
@@ -0,0 +1,15 @@
|
||||
"""VoyageAI embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.voyageai.types import (
|
||||
VoyageAIProviderConfig,
|
||||
VoyageAIProviderSpec,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.voyageai.voyageai_provider import (
|
||||
VoyageAIProvider,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"VoyageAIProvider",
|
||||
"VoyageAIProviderConfig",
|
||||
"VoyageAIProviderSpec",
|
||||
]
|
||||
@@ -0,0 +1,62 @@
|
||||
"""VoyageAI embedding function implementation."""
|
||||
|
||||
from typing import cast
|
||||
|
||||
from chromadb.api.types import Documents, EmbeddingFunction, Embeddings
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from crewai.rag.embeddings.providers.voyageai.types import VoyageAIProviderConfig
|
||||
|
||||
|
||||
class VoyageAIEmbeddingFunction(EmbeddingFunction[Documents]):
|
||||
"""Embedding function for VoyageAI models."""
|
||||
|
||||
def __init__(self, **kwargs: Unpack[VoyageAIProviderConfig]) -> None:
|
||||
"""Initialize VoyageAI embedding function.
|
||||
|
||||
Args:
|
||||
**kwargs: Configuration parameters for VoyageAI.
|
||||
"""
|
||||
try:
|
||||
import voyageai # type: ignore[import-not-found]
|
||||
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"voyageai is required for voyageai embeddings. "
|
||||
"Install it with: uv add voyageai"
|
||||
) from e
|
||||
self._config = kwargs
|
||||
self._client = voyageai.Client(
|
||||
api_key=kwargs["api_key"],
|
||||
max_retries=kwargs.get("max_retries", 0),
|
||||
timeout=kwargs.get("timeout"),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def name() -> str:
|
||||
"""Return the name of the embedding function for ChromaDB compatibility."""
|
||||
return "voyageai"
|
||||
|
||||
def __call__(self, input: Documents) -> Embeddings:
|
||||
"""Generate embeddings for input documents.
|
||||
|
||||
Args:
|
||||
input: List of documents to embed.
|
||||
|
||||
Returns:
|
||||
List of embedding vectors.
|
||||
"""
|
||||
|
||||
if isinstance(input, str):
|
||||
input = [input]
|
||||
|
||||
result = self._client.embed(
|
||||
texts=input,
|
||||
model=self._config.get("model", "voyage-2"),
|
||||
input_type=self._config.get("input_type"),
|
||||
truncation=self._config.get("truncation", True),
|
||||
output_dtype=self._config.get("output_dtype"),
|
||||
output_dimension=self._config.get("output_dimension"),
|
||||
)
|
||||
|
||||
return cast(Embeddings, result.embeddings)
|
||||
@@ -0,0 +1,25 @@
|
||||
"""Type definitions for VoyageAI embedding providers."""
|
||||
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class VoyageAIProviderConfig(TypedDict, total=False):
|
||||
"""Configuration for VoyageAI provider."""
|
||||
|
||||
api_key: str
|
||||
model: Annotated[str, "voyage-2"]
|
||||
input_type: str
|
||||
truncation: Annotated[bool, True]
|
||||
output_dtype: str
|
||||
output_dimension: int
|
||||
max_retries: Annotated[int, 0]
|
||||
timeout: float
|
||||
|
||||
|
||||
class VoyageAIProviderSpec(TypedDict):
|
||||
"""VoyageAI provider specification."""
|
||||
|
||||
provider: Required[Literal["voyageai"]]
|
||||
config: VoyageAIProviderConfig
|
||||
@@ -0,0 +1,55 @@
|
||||
"""Voyage AI embeddings provider."""
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
from crewai.rag.embeddings.providers.voyageai.embedding_callable import (
|
||||
VoyageAIEmbeddingFunction,
|
||||
)
|
||||
|
||||
|
||||
class VoyageAIProvider(BaseEmbeddingsProvider[VoyageAIEmbeddingFunction]):
|
||||
"""Voyage AI embeddings provider."""
|
||||
|
||||
embedding_callable: type[VoyageAIEmbeddingFunction] = Field(
|
||||
default=VoyageAIEmbeddingFunction,
|
||||
description="Voyage AI embedding function class",
|
||||
)
|
||||
model: str = Field(
|
||||
default="voyage-2",
|
||||
description="Model to use for embeddings",
|
||||
validation_alias="EMBEDDINGS_VOYAGEAI_MODEL",
|
||||
)
|
||||
api_key: str = Field(
|
||||
description="Voyage AI API key", validation_alias="EMBEDDINGS_VOYAGEAI_API_KEY"
|
||||
)
|
||||
input_type: str | None = Field(
|
||||
default=None,
|
||||
description="Input type for embeddings",
|
||||
validation_alias="EMBEDDINGS_VOYAGEAI_INPUT_TYPE",
|
||||
)
|
||||
truncation: bool = Field(
|
||||
default=True,
|
||||
description="Whether to truncate inputs",
|
||||
validation_alias="EMBEDDINGS_VOYAGEAI_TRUNCATION",
|
||||
)
|
||||
output_dtype: str | None = Field(
|
||||
default=None,
|
||||
description="Output data type",
|
||||
validation_alias="EMBEDDINGS_VOYAGEAI_OUTPUT_DTYPE",
|
||||
)
|
||||
output_dimension: int | None = Field(
|
||||
default=None,
|
||||
description="Output dimension",
|
||||
validation_alias="EMBEDDINGS_VOYAGEAI_OUTPUT_DIMENSION",
|
||||
)
|
||||
max_retries: int = Field(
|
||||
default=0,
|
||||
description="Maximum retries for API calls",
|
||||
validation_alias="EMBEDDINGS_VOYAGEAI_MAX_RETRIES",
|
||||
)
|
||||
timeout: float | None = Field(
|
||||
default=None,
|
||||
description="Timeout for API calls",
|
||||
validation_alias="EMBEDDINGS_VOYAGEAI_TIMEOUT",
|
||||
)
|
||||
78
lib/core/src/crewai/core/rag/embeddings/types.py
Normal file
78
lib/core/src/crewai/core/rag/embeddings/types.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""Type definitions for the embeddings module."""
|
||||
|
||||
from typing import Literal, TypeAlias
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
from crewai.rag.embeddings.providers.aws.types import BedrockProviderSpec
|
||||
from crewai.rag.embeddings.providers.cohere.types import CohereProviderSpec
|
||||
from crewai.rag.embeddings.providers.custom.types import CustomProviderSpec
|
||||
from crewai.rag.embeddings.providers.google.types import (
|
||||
GenerativeAiProviderSpec,
|
||||
VertexAIProviderSpec,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.huggingface.types import HuggingFaceProviderSpec
|
||||
from crewai.rag.embeddings.providers.ibm.types import (
|
||||
WatsonProviderSpec,
|
||||
WatsonXProviderSpec,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.instructor.types import InstructorProviderSpec
|
||||
from crewai.rag.embeddings.providers.jina.types import JinaProviderSpec
|
||||
from crewai.rag.embeddings.providers.microsoft.types import AzureProviderSpec
|
||||
from crewai.rag.embeddings.providers.ollama.types import OllamaProviderSpec
|
||||
from crewai.rag.embeddings.providers.onnx.types import ONNXProviderSpec
|
||||
from crewai.rag.embeddings.providers.openai.types import OpenAIProviderSpec
|
||||
from crewai.rag.embeddings.providers.openclip.types import OpenCLIPProviderSpec
|
||||
from crewai.rag.embeddings.providers.roboflow.types import RoboflowProviderSpec
|
||||
from crewai.rag.embeddings.providers.sentence_transformer.types import (
|
||||
SentenceTransformerProviderSpec,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.text2vec.types import Text2VecProviderSpec
|
||||
from crewai.rag.embeddings.providers.voyageai.types import VoyageAIProviderSpec
|
||||
|
||||
ProviderSpec = (
|
||||
AzureProviderSpec
|
||||
| BedrockProviderSpec
|
||||
| CohereProviderSpec
|
||||
| CustomProviderSpec
|
||||
| GenerativeAiProviderSpec
|
||||
| HuggingFaceProviderSpec
|
||||
| InstructorProviderSpec
|
||||
| JinaProviderSpec
|
||||
| OllamaProviderSpec
|
||||
| ONNXProviderSpec
|
||||
| OpenAIProviderSpec
|
||||
| OpenCLIPProviderSpec
|
||||
| RoboflowProviderSpec
|
||||
| SentenceTransformerProviderSpec
|
||||
| Text2VecProviderSpec
|
||||
| VertexAIProviderSpec
|
||||
| VoyageAIProviderSpec
|
||||
| WatsonProviderSpec # Deprecated, use WatsonXProviderSpec
|
||||
| WatsonXProviderSpec
|
||||
)
|
||||
|
||||
AllowedEmbeddingProviders = Literal[
|
||||
"azure",
|
||||
"amazon-bedrock",
|
||||
"cohere",
|
||||
"custom",
|
||||
"google-generativeai",
|
||||
"google-vertex",
|
||||
"huggingface",
|
||||
"instructor",
|
||||
"jina",
|
||||
"ollama",
|
||||
"onnx",
|
||||
"openai",
|
||||
"openclip",
|
||||
"roboflow",
|
||||
"sentence-transformer",
|
||||
"text2vec",
|
||||
"voyageai",
|
||||
"watsonx",
|
||||
"watson", # for backward compatibility until v1.0.0
|
||||
]
|
||||
|
||||
EmbedderConfig: TypeAlias = (
|
||||
ProviderSpec | BaseEmbeddingsProvider | type[BaseEmbeddingsProvider]
|
||||
)
|
||||
47
lib/core/src/crewai/core/rag/factory.py
Normal file
47
lib/core/src/crewai/core/rag/factory.py
Normal file
@@ -0,0 +1,47 @@
|
||||
"""Factory functions for creating RAG clients from configuration."""
|
||||
|
||||
from typing import cast
|
||||
|
||||
from crewai.rag.config.optional_imports.protocols import (
|
||||
ChromaFactoryModule,
|
||||
QdrantFactoryModule,
|
||||
)
|
||||
from crewai.rag.config.types import RagConfigType
|
||||
from crewai.rag.core.base_client import BaseClient
|
||||
from crewai.core.utilities.import_utils import require
|
||||
|
||||
|
||||
def create_client(config: RagConfigType) -> BaseClient:
|
||||
"""Create a client from configuration using the appropriate factory.
|
||||
|
||||
Args:
|
||||
config: The RAG client configuration.
|
||||
|
||||
Returns:
|
||||
The created client instance.
|
||||
|
||||
Raises:
|
||||
ValueError: If the configuration provider is not supported.
|
||||
"""
|
||||
|
||||
if config.provider == "chromadb":
|
||||
chromadb_mod = cast(
|
||||
ChromaFactoryModule,
|
||||
require(
|
||||
"crewai.rag.chromadb.factory",
|
||||
purpose="The 'chromadb' provider",
|
||||
),
|
||||
)
|
||||
return chromadb_mod.create_client(config)
|
||||
|
||||
if config.provider == "qdrant":
|
||||
qdrant_mod = cast(
|
||||
QdrantFactoryModule,
|
||||
require(
|
||||
"crewai.rag.qdrant.factory",
|
||||
purpose="The 'qdrant' provider",
|
||||
),
|
||||
)
|
||||
return qdrant_mod.create_client(config)
|
||||
|
||||
raise ValueError(f"Unsupported provider: {config.provider}")
|
||||
1
lib/core/src/crewai/core/rag/qdrant/__init__.py
Normal file
1
lib/core/src/crewai/core/rag/qdrant/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Qdrant vector database client implementation."""
|
||||
518
lib/core/src/crewai/core/rag/qdrant/client.py
Normal file
518
lib/core/src/crewai/core/rag/qdrant/client.py
Normal file
@@ -0,0 +1,518 @@
|
||||
"""Qdrant client implementation."""
|
||||
|
||||
from typing import Any, cast
|
||||
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from crewai.rag.core.base_client import (
|
||||
BaseClient,
|
||||
BaseCollectionAddParams,
|
||||
BaseCollectionParams,
|
||||
BaseCollectionSearchParams,
|
||||
)
|
||||
from crewai.rag.core.exceptions import ClientMethodMismatchError
|
||||
from crewai.rag.qdrant.types import (
|
||||
AsyncEmbeddingFunction,
|
||||
EmbeddingFunction,
|
||||
QdrantClientType,
|
||||
QdrantCollectionCreateParams,
|
||||
)
|
||||
from crewai.rag.qdrant.utils import (
|
||||
_create_point_from_document,
|
||||
_get_collection_params,
|
||||
_is_async_client,
|
||||
_is_async_embedding_function,
|
||||
_is_sync_client,
|
||||
_prepare_search_params,
|
||||
_process_search_results,
|
||||
)
|
||||
from crewai.rag.types import SearchResult
|
||||
|
||||
|
||||
class QdrantClient(BaseClient):
|
||||
"""Qdrant implementation of the BaseClient protocol.
|
||||
|
||||
Provides vector database operations for Qdrant, supporting both
|
||||
synchronous and asynchronous clients.
|
||||
|
||||
Attributes:
|
||||
client: Qdrant client instance (QdrantClient or AsyncQdrantClient).
|
||||
embedding_function: Function to generate embeddings for documents.
|
||||
default_limit: Default number of results to return in searches.
|
||||
default_score_threshold: Default minimum score for search results.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: QdrantClientType,
|
||||
embedding_function: EmbeddingFunction | AsyncEmbeddingFunction,
|
||||
default_limit: int = 5,
|
||||
default_score_threshold: float = 0.6,
|
||||
default_batch_size: int = 100,
|
||||
) -> None:
|
||||
"""Initialize QdrantClient with client and embedding function.
|
||||
|
||||
Args:
|
||||
client: Pre-configured Qdrant client instance.
|
||||
embedding_function: Embedding function for text to vector conversion.
|
||||
default_limit: Default number of results to return in searches.
|
||||
default_score_threshold: Default minimum score for search results.
|
||||
default_batch_size: Default batch size for adding documents.
|
||||
"""
|
||||
self.client = client
|
||||
self.embedding_function = embedding_function
|
||||
self.default_limit = default_limit
|
||||
self.default_score_threshold = default_score_threshold
|
||||
self.default_batch_size = default_batch_size
|
||||
|
||||
def create_collection(self, **kwargs: Unpack[QdrantCollectionCreateParams]) -> None:
|
||||
"""Create a new collection in Qdrant.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the collection to create. Must be unique.
|
||||
vectors_config: Optional vector configuration. Defaults to 1536 dimensions with cosine distance.
|
||||
sparse_vectors_config: Optional sparse vector configuration.
|
||||
shard_number: Optional number of shards.
|
||||
replication_factor: Optional replication factor.
|
||||
write_consistency_factor: Optional write consistency factor.
|
||||
on_disk_payload: Optional flag to store payload on disk.
|
||||
hnsw_config: Optional HNSW index configuration.
|
||||
optimizers_config: Optional optimizer configuration.
|
||||
wal_config: Optional write-ahead log configuration.
|
||||
quantization_config: Optional quantization configuration.
|
||||
init_from: Optional collection to initialize from.
|
||||
timeout: Optional timeout for the operation.
|
||||
|
||||
Raises:
|
||||
ValueError: If collection with the same name already exists.
|
||||
ConnectionError: If unable to connect to Qdrant server.
|
||||
"""
|
||||
if not _is_sync_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="create_collection",
|
||||
expected_client="QdrantClient",
|
||||
alt_method="acreate_collection",
|
||||
alt_client="AsyncQdrantClient",
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
|
||||
if self.client.collection_exists(collection_name):
|
||||
raise ValueError(f"Collection '{collection_name}' already exists")
|
||||
|
||||
params = _get_collection_params(kwargs)
|
||||
self.client.create_collection(**params)
|
||||
|
||||
async def acreate_collection(
|
||||
self, **kwargs: Unpack[QdrantCollectionCreateParams]
|
||||
) -> None:
|
||||
"""Create a new collection in Qdrant asynchronously.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the collection to create. Must be unique.
|
||||
vectors_config: Optional vector configuration. Defaults to 1536 dimensions with cosine distance.
|
||||
sparse_vectors_config: Optional sparse vector configuration.
|
||||
shard_number: Optional number of shards.
|
||||
replication_factor: Optional replication factor.
|
||||
write_consistency_factor: Optional write consistency factor.
|
||||
on_disk_payload: Optional flag to store payload on disk.
|
||||
hnsw_config: Optional HNSW index configuration.
|
||||
optimizers_config: Optional optimizer configuration.
|
||||
wal_config: Optional write-ahead log configuration.
|
||||
quantization_config: Optional quantization configuration.
|
||||
init_from: Optional collection to initialize from.
|
||||
timeout: Optional timeout for the operation.
|
||||
|
||||
Raises:
|
||||
ValueError: If collection with the same name already exists.
|
||||
ConnectionError: If unable to connect to Qdrant server.
|
||||
"""
|
||||
if not _is_async_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="acreate_collection",
|
||||
expected_client="AsyncQdrantClient",
|
||||
alt_method="create_collection",
|
||||
alt_client="QdrantClient",
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
|
||||
if await self.client.collection_exists(collection_name):
|
||||
raise ValueError(f"Collection '{collection_name}' already exists")
|
||||
|
||||
params = _get_collection_params(kwargs)
|
||||
await self.client.create_collection(**params)
|
||||
|
||||
def get_or_create_collection(
|
||||
self, **kwargs: Unpack[QdrantCollectionCreateParams]
|
||||
) -> Any:
|
||||
"""Get an existing collection or create it if it doesn't exist.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the collection to get or create.
|
||||
vectors_config: Optional vector configuration. Defaults to 1536 dimensions with cosine distance.
|
||||
sparse_vectors_config: Optional sparse vector configuration.
|
||||
shard_number: Optional number of shards.
|
||||
replication_factor: Optional replication factor.
|
||||
write_consistency_factor: Optional write consistency factor.
|
||||
on_disk_payload: Optional flag to store payload on disk.
|
||||
hnsw_config: Optional HNSW index configuration.
|
||||
optimizers_config: Optional optimizer configuration.
|
||||
wal_config: Optional write-ahead log configuration.
|
||||
quantization_config: Optional quantization configuration.
|
||||
init_from: Optional collection to initialize from.
|
||||
timeout: Optional timeout for the operation.
|
||||
|
||||
Returns:
|
||||
Collection info dict with name and other metadata.
|
||||
|
||||
Raises:
|
||||
ConnectionError: If unable to connect to Qdrant server.
|
||||
"""
|
||||
if not _is_sync_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="get_or_create_collection",
|
||||
expected_client="QdrantClient",
|
||||
alt_method="aget_or_create_collection",
|
||||
alt_client="AsyncQdrantClient",
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
|
||||
if self.client.collection_exists(collection_name):
|
||||
return self.client.get_collection(collection_name)
|
||||
|
||||
params = _get_collection_params(kwargs)
|
||||
self.client.create_collection(**params)
|
||||
|
||||
return self.client.get_collection(collection_name)
|
||||
|
||||
async def aget_or_create_collection(
|
||||
self, **kwargs: Unpack[QdrantCollectionCreateParams]
|
||||
) -> Any:
|
||||
"""Get an existing collection or create it if it doesn't exist asynchronously.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the collection to get or create.
|
||||
vectors_config: Optional vector configuration. Defaults to 1536 dimensions with cosine distance.
|
||||
sparse_vectors_config: Optional sparse vector configuration.
|
||||
shard_number: Optional number of shards.
|
||||
replication_factor: Optional replication factor.
|
||||
write_consistency_factor: Optional write consistency factor.
|
||||
on_disk_payload: Optional flag to store payload on disk.
|
||||
hnsw_config: Optional HNSW index configuration.
|
||||
optimizers_config: Optional optimizer configuration.
|
||||
wal_config: Optional write-ahead log configuration.
|
||||
quantization_config: Optional quantization configuration.
|
||||
init_from: Optional collection to initialize from.
|
||||
timeout: Optional timeout for the operation.
|
||||
|
||||
Returns:
|
||||
Collection info dict with name and other metadata.
|
||||
|
||||
Raises:
|
||||
ConnectionError: If unable to connect to Qdrant server.
|
||||
"""
|
||||
if not _is_async_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="aget_or_create_collection",
|
||||
expected_client="AsyncQdrantClient",
|
||||
alt_method="get_or_create_collection",
|
||||
alt_client="QdrantClient",
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
|
||||
if await self.client.collection_exists(collection_name):
|
||||
return await self.client.get_collection(collection_name)
|
||||
|
||||
params = _get_collection_params(kwargs)
|
||||
await self.client.create_collection(**params)
|
||||
|
||||
return await self.client.get_collection(collection_name)
|
||||
|
||||
def add_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
|
||||
"""Add documents with their embeddings to a collection.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: The name of the collection to add documents to.
|
||||
documents: List of BaseRecord dicts containing document data.
|
||||
batch_size: Optional batch size for processing documents (default: 100)
|
||||
|
||||
Raises:
|
||||
ValueError: If collection doesn't exist or documents list is empty.
|
||||
ConnectionError: If unable to connect to Qdrant server.
|
||||
"""
|
||||
if not _is_sync_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="add_documents",
|
||||
expected_client="QdrantClient",
|
||||
alt_method="aadd_documents",
|
||||
alt_client="AsyncQdrantClient",
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
documents = kwargs["documents"]
|
||||
batch_size = kwargs.get("batch_size", self.default_batch_size)
|
||||
|
||||
if not documents:
|
||||
raise ValueError("Documents list cannot be empty")
|
||||
|
||||
if not self.client.collection_exists(collection_name):
|
||||
raise ValueError(f"Collection '{collection_name}' does not exist")
|
||||
|
||||
for i in range(0, len(documents), batch_size):
|
||||
batch_docs = documents[i : min(i + batch_size, len(documents))]
|
||||
points = []
|
||||
for doc in batch_docs:
|
||||
if _is_async_embedding_function(self.embedding_function):
|
||||
raise TypeError(
|
||||
"Async embedding function cannot be used with sync add_documents. "
|
||||
"Use aadd_documents instead."
|
||||
)
|
||||
sync_fn = cast(EmbeddingFunction, self.embedding_function)
|
||||
embedding = sync_fn(doc["content"])
|
||||
point = _create_point_from_document(doc, embedding)
|
||||
points.append(point)
|
||||
self.client.upsert(collection_name=collection_name, points=points)
|
||||
|
||||
async def aadd_documents(self, **kwargs: Unpack[BaseCollectionAddParams]) -> None:
|
||||
"""Add documents with their embeddings to a collection asynchronously.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: The name of the collection to add documents to.
|
||||
documents: List of BaseRecord dicts containing document data.
|
||||
batch_size: Optional batch size for processing documents (default: 100)
|
||||
|
||||
Raises:
|
||||
ValueError: If collection doesn't exist or documents list is empty.
|
||||
ConnectionError: If unable to connect to Qdrant server.
|
||||
"""
|
||||
if not _is_async_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="aadd_documents",
|
||||
expected_client="AsyncQdrantClient",
|
||||
alt_method="add_documents",
|
||||
alt_client="QdrantClient",
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
documents = kwargs["documents"]
|
||||
batch_size = kwargs.get("batch_size", self.default_batch_size)
|
||||
|
||||
if not documents:
|
||||
raise ValueError("Documents list cannot be empty")
|
||||
|
||||
if not await self.client.collection_exists(collection_name):
|
||||
raise ValueError(f"Collection '{collection_name}' does not exist")
|
||||
|
||||
for i in range(0, len(documents), batch_size):
|
||||
batch_docs = documents[i : min(i + batch_size, len(documents))]
|
||||
points = []
|
||||
for doc in batch_docs:
|
||||
if _is_async_embedding_function(self.embedding_function):
|
||||
async_fn = cast(AsyncEmbeddingFunction, self.embedding_function)
|
||||
embedding = await async_fn(doc["content"])
|
||||
else:
|
||||
sync_fn = cast(EmbeddingFunction, self.embedding_function)
|
||||
embedding = sync_fn(doc["content"])
|
||||
point = _create_point_from_document(doc, embedding)
|
||||
points.append(point)
|
||||
await self.client.upsert(collection_name=collection_name, points=points)
|
||||
|
||||
def search(
|
||||
self, **kwargs: Unpack[BaseCollectionSearchParams]
|
||||
) -> list[SearchResult]:
|
||||
"""Search for similar documents using a query.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the collection to search in.
|
||||
query: The text query to search for.
|
||||
limit: Maximum number of results to return (default: 10).
|
||||
metadata_filter: Optional filter for metadata fields.
|
||||
score_threshold: Optional minimum similarity score (0-1) for results.
|
||||
|
||||
Returns:
|
||||
List of SearchResult dicts containing id, content, metadata, and score.
|
||||
|
||||
Raises:
|
||||
ValueError: If collection doesn't exist.
|
||||
ConnectionError: If unable to connect to Qdrant server.
|
||||
"""
|
||||
if not _is_sync_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="search",
|
||||
expected_client="QdrantClient",
|
||||
alt_method="asearch",
|
||||
alt_client="AsyncQdrantClient",
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
query = kwargs["query"]
|
||||
limit = kwargs.get("limit", self.default_limit)
|
||||
metadata_filter = kwargs.get("metadata_filter")
|
||||
score_threshold = kwargs.get("score_threshold", self.default_score_threshold)
|
||||
|
||||
if not self.client.collection_exists(collection_name):
|
||||
raise ValueError(f"Collection '{collection_name}' does not exist")
|
||||
|
||||
if _is_async_embedding_function(self.embedding_function):
|
||||
raise TypeError(
|
||||
"Async embedding function cannot be used with sync search. "
|
||||
"Use asearch instead."
|
||||
)
|
||||
sync_fn = cast(EmbeddingFunction, self.embedding_function)
|
||||
query_embedding = sync_fn(query)
|
||||
|
||||
search_kwargs = _prepare_search_params(
|
||||
collection_name=collection_name,
|
||||
query_embedding=query_embedding,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
metadata_filter=metadata_filter,
|
||||
)
|
||||
|
||||
response = self.client.query_points(**search_kwargs)
|
||||
return _process_search_results(response)
|
||||
|
||||
async def asearch(
|
||||
self, **kwargs: Unpack[BaseCollectionSearchParams]
|
||||
) -> list[SearchResult]:
|
||||
"""Search for similar documents using a query asynchronously.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the collection to search in.
|
||||
query: The text query to search for.
|
||||
limit: Maximum number of results to return (default: 10).
|
||||
metadata_filter: Optional filter for metadata fields.
|
||||
score_threshold: Optional minimum similarity score (0-1) for results.
|
||||
|
||||
Returns:
|
||||
List of SearchResult dicts containing id, content, metadata, and score.
|
||||
|
||||
Raises:
|
||||
ValueError: If collection doesn't exist.
|
||||
ConnectionError: If unable to connect to Qdrant server.
|
||||
"""
|
||||
if not _is_async_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="asearch",
|
||||
expected_client="AsyncQdrantClient",
|
||||
alt_method="search",
|
||||
alt_client="QdrantClient",
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
query = kwargs["query"]
|
||||
limit = kwargs.get("limit", self.default_limit)
|
||||
metadata_filter = kwargs.get("metadata_filter")
|
||||
score_threshold = kwargs.get("score_threshold", self.default_score_threshold)
|
||||
|
||||
if not await self.client.collection_exists(collection_name):
|
||||
raise ValueError(f"Collection '{collection_name}' does not exist")
|
||||
|
||||
if _is_async_embedding_function(self.embedding_function):
|
||||
async_fn = cast(AsyncEmbeddingFunction, self.embedding_function)
|
||||
query_embedding = await async_fn(query)
|
||||
else:
|
||||
sync_fn = cast(EmbeddingFunction, self.embedding_function)
|
||||
query_embedding = sync_fn(query)
|
||||
|
||||
search_kwargs = _prepare_search_params(
|
||||
collection_name=collection_name,
|
||||
query_embedding=query_embedding,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
metadata_filter=metadata_filter,
|
||||
)
|
||||
|
||||
response = await self.client.query_points(**search_kwargs)
|
||||
return _process_search_results(response)
|
||||
|
||||
def delete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
|
||||
"""Delete a collection and all its data.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the collection to delete.
|
||||
|
||||
Raises:
|
||||
ValueError: If collection doesn't exist.
|
||||
ConnectionError: If unable to connect to Qdrant server.
|
||||
"""
|
||||
if not _is_sync_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="delete_collection",
|
||||
expected_client="QdrantClient",
|
||||
alt_method="adelete_collection",
|
||||
alt_client="AsyncQdrantClient",
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
|
||||
if not self.client.collection_exists(collection_name):
|
||||
raise ValueError(f"Collection '{collection_name}' does not exist")
|
||||
|
||||
self.client.delete_collection(collection_name=collection_name)
|
||||
|
||||
async def adelete_collection(self, **kwargs: Unpack[BaseCollectionParams]) -> None:
|
||||
"""Delete a collection and all its data asynchronously.
|
||||
|
||||
Keyword Args:
|
||||
collection_name: Name of the collection to delete.
|
||||
|
||||
Raises:
|
||||
ValueError: If collection doesn't exist.
|
||||
ConnectionError: If unable to connect to Qdrant server.
|
||||
"""
|
||||
if not _is_async_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="adelete_collection",
|
||||
expected_client="AsyncQdrantClient",
|
||||
alt_method="delete_collection",
|
||||
alt_client="QdrantClient",
|
||||
)
|
||||
|
||||
collection_name = kwargs["collection_name"]
|
||||
|
||||
if not await self.client.collection_exists(collection_name):
|
||||
raise ValueError(f"Collection '{collection_name}' does not exist")
|
||||
|
||||
await self.client.delete_collection(collection_name=collection_name)
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset the vector database by deleting all collections and data.
|
||||
|
||||
Raises:
|
||||
ConnectionError: If unable to connect to Qdrant server.
|
||||
"""
|
||||
if not _is_sync_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="reset",
|
||||
expected_client="QdrantClient",
|
||||
alt_method="areset",
|
||||
alt_client="AsyncQdrantClient",
|
||||
)
|
||||
|
||||
collections_response = self.client.get_collections()
|
||||
|
||||
for collection in collections_response.collections:
|
||||
self.client.delete_collection(collection_name=collection.name)
|
||||
|
||||
async def areset(self) -> None:
|
||||
"""Reset the vector database by deleting all collections and data asynchronously.
|
||||
|
||||
Raises:
|
||||
ConnectionError: If unable to connect to Qdrant server.
|
||||
"""
|
||||
if not _is_async_client(self.client):
|
||||
raise ClientMethodMismatchError(
|
||||
method_name="areset",
|
||||
expected_client="AsyncQdrantClient",
|
||||
alt_method="reset",
|
||||
alt_client="QdrantClient",
|
||||
)
|
||||
|
||||
collections_response = await self.client.get_collections()
|
||||
|
||||
for collection in collections_response.collections:
|
||||
await self.client.delete_collection(collection_name=collection.name)
|
||||
55
lib/core/src/crewai/core/rag/qdrant/config.py
Normal file
55
lib/core/src/crewai/core/rag/qdrant/config.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""Qdrant configuration model."""
|
||||
|
||||
from dataclasses import field
|
||||
from typing import Literal, cast
|
||||
|
||||
from pydantic.dataclasses import dataclass as pyd_dataclass
|
||||
|
||||
from crewai.rag.config.base import BaseRagConfig
|
||||
from crewai.rag.qdrant.constants import DEFAULT_EMBEDDING_MODEL, DEFAULT_STORAGE_PATH
|
||||
from crewai.rag.qdrant.types import QdrantClientParams, QdrantEmbeddingFunctionWrapper
|
||||
|
||||
|
||||
def _default_options() -> QdrantClientParams:
|
||||
"""Create default Qdrant client options.
|
||||
|
||||
Returns:
|
||||
Default options with file-based storage.
|
||||
"""
|
||||
return QdrantClientParams(path=DEFAULT_STORAGE_PATH)
|
||||
|
||||
|
||||
def _default_embedding_function() -> QdrantEmbeddingFunctionWrapper:
|
||||
"""Create default Qdrant embedding function.
|
||||
|
||||
Returns:
|
||||
Default embedding function using fastembed with all-MiniLM-L6-v2.
|
||||
"""
|
||||
from fastembed import TextEmbedding # type: ignore[import-not-found]
|
||||
|
||||
model = TextEmbedding(model_name=DEFAULT_EMBEDDING_MODEL)
|
||||
|
||||
def embed_fn(text: str) -> list[float]:
|
||||
"""Embed a single text string.
|
||||
|
||||
Args:
|
||||
text: Text to embed.
|
||||
|
||||
Returns:
|
||||
Embedding vector as list of floats.
|
||||
"""
|
||||
embeddings = list(model.embed([text]))
|
||||
return embeddings[0].tolist() if embeddings else []
|
||||
|
||||
return cast(QdrantEmbeddingFunctionWrapper, embed_fn)
|
||||
|
||||
|
||||
@pyd_dataclass(frozen=True)
|
||||
class QdrantConfig(BaseRagConfig):
|
||||
"""Configuration for Qdrant client."""
|
||||
|
||||
provider: Literal["qdrant"] = field(default="qdrant", init=False)
|
||||
options: QdrantClientParams = field(default_factory=_default_options)
|
||||
embedding_function: QdrantEmbeddingFunctionWrapper = field(
|
||||
default_factory=_default_embedding_function
|
||||
)
|
||||
12
lib/core/src/crewai/core/rag/qdrant/constants.py
Normal file
12
lib/core/src/crewai/core/rag/qdrant/constants.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""Constants for Qdrant implementation."""
|
||||
|
||||
import os
|
||||
from typing import Final
|
||||
|
||||
from qdrant_client.models import Distance, VectorParams
|
||||
|
||||
from crewai.core.utilities.paths import db_storage_path
|
||||
|
||||
DEFAULT_VECTOR_PARAMS: Final = VectorParams(size=384, distance=Distance.COSINE)
|
||||
DEFAULT_EMBEDDING_MODEL: Final[str] = "sentence-transformers/all-MiniLM-L6-v2"
|
||||
DEFAULT_STORAGE_PATH: Final[str] = os.path.join(db_storage_path(), "qdrant")
|
||||
26
lib/core/src/crewai/core/rag/qdrant/factory.py
Normal file
26
lib/core/src/crewai/core/rag/qdrant/factory.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""Factory functions for creating Qdrant clients from configuration."""
|
||||
|
||||
from qdrant_client import QdrantClient as SyncQdrantClientBase
|
||||
|
||||
from crewai.rag.qdrant.client import QdrantClient
|
||||
from crewai.rag.qdrant.config import QdrantConfig
|
||||
|
||||
|
||||
def create_client(config: QdrantConfig) -> QdrantClient:
|
||||
"""Create a Qdrant client from configuration.
|
||||
|
||||
Args:
|
||||
config: The Qdrant configuration.
|
||||
|
||||
Returns:
|
||||
A configured QdrantClient instance.
|
||||
"""
|
||||
|
||||
qdrant_client = SyncQdrantClientBase(**config.options)
|
||||
return QdrantClient(
|
||||
client=qdrant_client,
|
||||
embedding_function=config.embedding_function,
|
||||
default_limit=config.limit,
|
||||
default_score_threshold=config.score_threshold,
|
||||
default_batch_size=config.batch_size,
|
||||
)
|
||||
156
lib/core/src/crewai/core/rag/qdrant/types.py
Normal file
156
lib/core/src/crewai/core/rag/qdrant/types.py
Normal file
@@ -0,0 +1,156 @@
|
||||
"""Type definitions specific to Qdrant implementation."""
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Annotated, Any, Protocol, TypeAlias
|
||||
|
||||
import numpy as np
|
||||
from pydantic import GetCoreSchemaHandler
|
||||
from pydantic_core import CoreSchema, core_schema
|
||||
from qdrant_client import AsyncQdrantClient # type: ignore[import-not-found]
|
||||
from qdrant_client import (
|
||||
QdrantClient as SyncQdrantClient, # type: ignore[import-not-found]
|
||||
)
|
||||
from qdrant_client.models import ( # type: ignore[import-not-found]
|
||||
FieldCondition,
|
||||
Filter,
|
||||
HasIdCondition,
|
||||
HasVectorCondition,
|
||||
HnswConfigDiff,
|
||||
InitFrom,
|
||||
IsEmptyCondition,
|
||||
IsNullCondition,
|
||||
NestedCondition,
|
||||
OptimizersConfigDiff,
|
||||
QuantizationConfig,
|
||||
ShardingMethod,
|
||||
SparseVectorsConfig,
|
||||
VectorsConfig,
|
||||
WalConfigDiff,
|
||||
)
|
||||
from typing_extensions import NotRequired, TypedDict
|
||||
|
||||
from crewai.rag.core.base_client import BaseCollectionParams
|
||||
|
||||
QdrantClientType = SyncQdrantClient | AsyncQdrantClient
|
||||
|
||||
QueryEmbedding: TypeAlias = list[float] | np.ndarray[Any, np.dtype[np.floating[Any]]]
|
||||
|
||||
BasicConditions = FieldCondition | IsEmptyCondition | IsNullCondition
|
||||
StructuralConditions = HasIdCondition | HasVectorCondition | NestedCondition
|
||||
FilterCondition = BasicConditions | StructuralConditions | Filter
|
||||
|
||||
MetadataFilterValue = bool | int | str
|
||||
MetadataFilter = dict[str, MetadataFilterValue]
|
||||
|
||||
|
||||
class EmbeddingFunction(Protocol):
|
||||
"""Protocol for embedding functions that convert text to vectors."""
|
||||
|
||||
def __call__(self, text: str) -> QueryEmbedding:
|
||||
"""Convert text to embedding vector.
|
||||
|
||||
Args:
|
||||
text: Input text to embed.
|
||||
|
||||
Returns:
|
||||
Embedding vector as list of floats or numpy array.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class QdrantEmbeddingFunctionWrapper(EmbeddingFunction):
|
||||
"""Base class for Qdrant EmbeddingFunction to work with Pydantic validation."""
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(
|
||||
cls, _source_type: Any, _handler: GetCoreSchemaHandler
|
||||
) -> CoreSchema:
|
||||
"""Generate Pydantic core schema for Qdrant EmbeddingFunction.
|
||||
|
||||
This allows Pydantic to handle Qdrant's EmbeddingFunction type
|
||||
without requiring arbitrary_types_allowed=True.
|
||||
"""
|
||||
return core_schema.any_schema()
|
||||
|
||||
|
||||
class AsyncEmbeddingFunction(Protocol):
|
||||
"""Protocol for async embedding functions that convert text to vectors."""
|
||||
|
||||
async def __call__(self, text: str) -> QueryEmbedding:
|
||||
"""Convert text to embedding vector asynchronously.
|
||||
|
||||
Args:
|
||||
text: Input text to embed.
|
||||
|
||||
Returns:
|
||||
Embedding vector as list of floats or numpy array.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class QdrantClientParams(TypedDict, total=False):
|
||||
"""Parameters for QdrantClient initialization.
|
||||
|
||||
Notes:
|
||||
Need to implement in factory or remove.
|
||||
"""
|
||||
|
||||
location: str | None
|
||||
url: str | None
|
||||
port: int
|
||||
grpc_port: int
|
||||
prefer_grpc: bool
|
||||
https: bool | None
|
||||
api_key: str | None
|
||||
prefix: str | None
|
||||
timeout: int | None
|
||||
host: str | None
|
||||
path: str | None
|
||||
force_disable_check_same_thread: bool
|
||||
grpc_options: dict[str, Any] | None
|
||||
auth_token_provider: Callable[[], str] | Callable[[], Awaitable[str]] | None
|
||||
cloud_inference: bool
|
||||
local_inference_batch_size: int | None
|
||||
check_compatibility: bool
|
||||
|
||||
|
||||
class CommonCreateFields(TypedDict, total=False):
|
||||
"""Fields shared between high-level and direct create_collection params."""
|
||||
|
||||
vectors_config: VectorsConfig
|
||||
sparse_vectors_config: SparseVectorsConfig
|
||||
shard_number: Annotated[int, "Number of shards (default: 1)"]
|
||||
sharding_method: ShardingMethod
|
||||
replication_factor: Annotated[int, "Number of replicas per shard (default: 1)"]
|
||||
write_consistency_factor: Annotated[int, "Await N replicas on write (default: 1)"]
|
||||
on_disk_payload: Annotated[bool, "Store payload on disk instead of RAM"]
|
||||
hnsw_config: HnswConfigDiff
|
||||
optimizers_config: OptimizersConfigDiff
|
||||
wal_config: WalConfigDiff
|
||||
quantization_config: QuantizationConfig
|
||||
init_from: InitFrom | str
|
||||
timeout: Annotated[int, "Operation timeout in seconds"]
|
||||
|
||||
|
||||
class QdrantCollectionCreateParams(
|
||||
BaseCollectionParams, CommonCreateFields, total=False
|
||||
):
|
||||
"""High-level parameters for creating a Qdrant collection."""
|
||||
|
||||
|
||||
class CreateCollectionParams(CommonCreateFields, total=False):
|
||||
"""Parameters for qdrant_client.create_collection."""
|
||||
|
||||
collection_name: str
|
||||
|
||||
|
||||
class PreparedSearchParams(TypedDict):
|
||||
"""Type definition for prepared Qdrant search parameters."""
|
||||
|
||||
collection_name: str
|
||||
query: list[float]
|
||||
limit: Annotated[int, "Max results to return"]
|
||||
with_payload: Annotated[bool, "Include payload in results"]
|
||||
with_vectors: Annotated[bool, "Include vectors in results"]
|
||||
score_threshold: NotRequired[Annotated[float, "Min similarity score (0-1)"]]
|
||||
query_filter: NotRequired[Filter]
|
||||
232
lib/core/src/crewai/core/rag/qdrant/utils.py
Normal file
232
lib/core/src/crewai/core/rag/qdrant/utils.py
Normal file
@@ -0,0 +1,232 @@
|
||||
"""Utility functions for Qdrant operations."""
|
||||
|
||||
import asyncio
|
||||
from typing import TypeGuard
|
||||
from uuid import uuid4
|
||||
|
||||
from qdrant_client import AsyncQdrantClient # type: ignore[import-not-found]
|
||||
from qdrant_client import (
|
||||
QdrantClient as SyncQdrantClient, # type: ignore[import-not-found]
|
||||
)
|
||||
from qdrant_client.models import ( # type: ignore[import-not-found]
|
||||
FieldCondition,
|
||||
Filter,
|
||||
MatchValue,
|
||||
PointStruct,
|
||||
QueryResponse,
|
||||
)
|
||||
|
||||
from crewai.rag.qdrant.constants import DEFAULT_VECTOR_PARAMS
|
||||
from crewai.rag.qdrant.types import (
|
||||
AsyncEmbeddingFunction,
|
||||
CreateCollectionParams,
|
||||
EmbeddingFunction,
|
||||
FilterCondition,
|
||||
MetadataFilter,
|
||||
PreparedSearchParams,
|
||||
QdrantClientType,
|
||||
QdrantCollectionCreateParams,
|
||||
QueryEmbedding,
|
||||
)
|
||||
from crewai.rag.types import BaseRecord, SearchResult
|
||||
|
||||
|
||||
def _ensure_list_embedding(embedding: QueryEmbedding) -> list[float]:
|
||||
"""Convert embedding to list[float] format if needed.
|
||||
|
||||
Args:
|
||||
embedding: Embedding vector as list or numpy array.
|
||||
|
||||
Returns:
|
||||
Embedding as list[float].
|
||||
"""
|
||||
if not isinstance(embedding, list):
|
||||
result = embedding.tolist()
|
||||
return result if isinstance(result, list) else [result]
|
||||
return embedding
|
||||
|
||||
|
||||
def _is_sync_client(client: QdrantClientType) -> TypeGuard[SyncQdrantClient]:
|
||||
"""Type guard to check if the client is a synchronous QdrantClient.
|
||||
|
||||
Args:
|
||||
client: The client to check.
|
||||
|
||||
Returns:
|
||||
True if the client is a QdrantClient, False otherwise.
|
||||
"""
|
||||
return isinstance(client, SyncQdrantClient)
|
||||
|
||||
|
||||
def _is_async_client(client: QdrantClientType) -> TypeGuard[AsyncQdrantClient]:
|
||||
"""Type guard to check if the client is an asynchronous AsyncQdrantClient.
|
||||
|
||||
Args:
|
||||
client: The client to check.
|
||||
|
||||
Returns:
|
||||
True if the client is an AsyncQdrantClient, False otherwise.
|
||||
"""
|
||||
return isinstance(client, AsyncQdrantClient)
|
||||
|
||||
|
||||
def _is_async_embedding_function(
|
||||
func: EmbeddingFunction | AsyncEmbeddingFunction,
|
||||
) -> TypeGuard[AsyncEmbeddingFunction]:
|
||||
"""Type guard to check if the embedding function is async.
|
||||
|
||||
Args:
|
||||
func: The embedding function to check.
|
||||
|
||||
Returns:
|
||||
True if the function is async, False otherwise.
|
||||
"""
|
||||
return asyncio.iscoroutinefunction(func)
|
||||
|
||||
|
||||
def _get_collection_params(
|
||||
kwargs: QdrantCollectionCreateParams,
|
||||
) -> CreateCollectionParams:
|
||||
"""Extract collection creation parameters from kwargs."""
|
||||
params: CreateCollectionParams = {
|
||||
"collection_name": kwargs["collection_name"],
|
||||
"vectors_config": kwargs.get("vectors_config", DEFAULT_VECTOR_PARAMS),
|
||||
}
|
||||
|
||||
if "sparse_vectors_config" in kwargs:
|
||||
params["sparse_vectors_config"] = kwargs["sparse_vectors_config"]
|
||||
if "shard_number" in kwargs:
|
||||
params["shard_number"] = kwargs["shard_number"]
|
||||
if "sharding_method" in kwargs:
|
||||
params["sharding_method"] = kwargs["sharding_method"]
|
||||
if "replication_factor" in kwargs:
|
||||
params["replication_factor"] = kwargs["replication_factor"]
|
||||
if "write_consistency_factor" in kwargs:
|
||||
params["write_consistency_factor"] = kwargs["write_consistency_factor"]
|
||||
if "on_disk_payload" in kwargs:
|
||||
params["on_disk_payload"] = kwargs["on_disk_payload"]
|
||||
if "hnsw_config" in kwargs:
|
||||
params["hnsw_config"] = kwargs["hnsw_config"]
|
||||
if "optimizers_config" in kwargs:
|
||||
params["optimizers_config"] = kwargs["optimizers_config"]
|
||||
if "wal_config" in kwargs:
|
||||
params["wal_config"] = kwargs["wal_config"]
|
||||
if "quantization_config" in kwargs:
|
||||
params["quantization_config"] = kwargs["quantization_config"]
|
||||
if "init_from" in kwargs:
|
||||
params["init_from"] = kwargs["init_from"]
|
||||
if "timeout" in kwargs:
|
||||
params["timeout"] = kwargs["timeout"]
|
||||
|
||||
return params
|
||||
|
||||
|
||||
def _prepare_search_params(
|
||||
collection_name: str,
|
||||
query_embedding: QueryEmbedding,
|
||||
limit: int,
|
||||
score_threshold: float | None,
|
||||
metadata_filter: MetadataFilter | None,
|
||||
) -> PreparedSearchParams:
|
||||
"""Prepare search parameters for Qdrant query_points.
|
||||
|
||||
Args:
|
||||
collection_name: Name of the collection to search.
|
||||
query_embedding: Embedding vector for the query.
|
||||
limit: Maximum number of results.
|
||||
score_threshold: Optional minimum similarity score.
|
||||
metadata_filter: Optional metadata filters.
|
||||
|
||||
Returns:
|
||||
Dictionary of parameters for query_points method.
|
||||
"""
|
||||
query_vector = _ensure_list_embedding(query_embedding)
|
||||
|
||||
search_kwargs: PreparedSearchParams = {
|
||||
"collection_name": collection_name,
|
||||
"query": query_vector,
|
||||
"limit": limit,
|
||||
"with_payload": True,
|
||||
"with_vectors": False,
|
||||
}
|
||||
|
||||
if score_threshold is not None:
|
||||
search_kwargs["score_threshold"] = score_threshold
|
||||
|
||||
if metadata_filter:
|
||||
filter_conditions: list[FilterCondition] = []
|
||||
for key, value in metadata_filter.items():
|
||||
filter_conditions.append(
|
||||
FieldCondition(key=key, match=MatchValue(value=value))
|
||||
)
|
||||
|
||||
search_kwargs["query_filter"] = Filter(must=filter_conditions)
|
||||
|
||||
return search_kwargs
|
||||
|
||||
|
||||
def _normalize_qdrant_score(score: float) -> float:
|
||||
"""Normalize Qdrant cosine similarity score to [0, 1] range.
|
||||
|
||||
Converts from Qdrant's [-1, 1] cosine similarity range to [0, 1] range for standardization across clients.
|
||||
|
||||
Args:
|
||||
score: Raw cosine similarity score from Qdrant [-1, 1].
|
||||
|
||||
Returns:
|
||||
Normalized score in [0, 1] range where 1 is most similar.
|
||||
"""
|
||||
normalized = (score + 1.0) / 2.0
|
||||
return max(0.0, min(1.0, normalized))
|
||||
|
||||
|
||||
def _process_search_results(response: QueryResponse) -> list[SearchResult]:
|
||||
"""Process Qdrant search response into SearchResult format.
|
||||
|
||||
Args:
|
||||
response: Response from Qdrant query_points method.
|
||||
|
||||
Returns:
|
||||
List of SearchResult dictionaries.
|
||||
"""
|
||||
results: list[SearchResult] = []
|
||||
for point in response.points:
|
||||
payload = point.payload or {}
|
||||
score = _normalize_qdrant_score(score=point.score)
|
||||
result: SearchResult = {
|
||||
"id": str(point.id),
|
||||
"content": payload.get("content", ""),
|
||||
"metadata": {k: v for k, v in payload.items() if k != "content"},
|
||||
"score": score,
|
||||
}
|
||||
results.append(result)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def _create_point_from_document(
|
||||
doc: BaseRecord, embedding: QueryEmbedding
|
||||
) -> PointStruct:
|
||||
"""Create a PointStruct from a document and its embedding.
|
||||
|
||||
Args:
|
||||
doc: Document dictionary containing content, metadata, and optional doc_id.
|
||||
embedding: The embedding vector for the document content.
|
||||
|
||||
Returns:
|
||||
PointStruct ready to be upserted to Qdrant.
|
||||
"""
|
||||
doc_id = doc.get("doc_id", str(uuid4()))
|
||||
vector = _ensure_list_embedding(embedding)
|
||||
|
||||
metadata = doc.get("metadata", {})
|
||||
if isinstance(metadata, list):
|
||||
metadata = metadata[0] if metadata else {}
|
||||
elif not isinstance(metadata, dict):
|
||||
metadata = dict(metadata) if metadata else {}
|
||||
|
||||
return PointStruct(
|
||||
id=doc_id,
|
||||
vector=vector,
|
||||
payload={"content": doc["content"], **metadata},
|
||||
)
|
||||
1
lib/core/src/crewai/core/rag/storage/__init__.py
Normal file
1
lib/core/src/crewai/core/rag/storage/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Storage components for RAG infrastructure."""
|
||||
55
lib/core/src/crewai/core/rag/storage/base_rag_storage.py
Normal file
55
lib/core/src/crewai/core/rag/storage/base_rag_storage.py
Normal file
@@ -0,0 +1,55 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
from crewai.rag.embeddings.types import ProviderSpec
|
||||
|
||||
|
||||
class BaseRAGStorage(ABC):
|
||||
"""
|
||||
Base class for RAG-based Storage implementations.
|
||||
"""
|
||||
|
||||
app: Any | None = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
type: str,
|
||||
allow_reset: bool = True,
|
||||
embedder_config: ProviderSpec | BaseEmbeddingsProvider | None = None,
|
||||
crew: Any = None,
|
||||
):
|
||||
self.type = type
|
||||
self.allow_reset = allow_reset
|
||||
self.embedder_config = embedder_config
|
||||
self.crew = crew
|
||||
self.agents = self._initialize_agents()
|
||||
|
||||
def _initialize_agents(self) -> str:
|
||||
if self.crew:
|
||||
return "_".join(
|
||||
[self._sanitize_role(agent.role) for agent in self.crew.agents]
|
||||
)
|
||||
return ""
|
||||
|
||||
@abstractmethod
|
||||
def _sanitize_role(self, role: str) -> str:
|
||||
"""Sanitizes agent roles to ensure valid directory names."""
|
||||
|
||||
@abstractmethod
|
||||
def save(self, value: Any, metadata: dict[str, Any]) -> None:
|
||||
"""Save a value with metadata to the storage."""
|
||||
|
||||
@abstractmethod
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 5,
|
||||
filter: dict[str, Any] | None = None,
|
||||
score_threshold: float = 0.6,
|
||||
) -> list[Any]:
|
||||
"""Search for entries in the storage."""
|
||||
|
||||
@abstractmethod
|
||||
def reset(self) -> None:
|
||||
"""Reset the storage."""
|
||||
49
lib/core/src/crewai/core/rag/types.py
Normal file
49
lib/core/src/crewai/core/rag/types.py
Normal file
@@ -0,0 +1,49 @@
|
||||
"""Type definitions for RAG (Retrieval-Augmented Generation) systems."""
|
||||
|
||||
from collections.abc import Callable, Mapping
|
||||
from typing import Any, TypeAlias
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
|
||||
class BaseRecord(TypedDict, total=False):
|
||||
"""A typed dictionary representing a document record.
|
||||
|
||||
Attributes:
|
||||
doc_id: Optional unique identifier for the document. If not provided,
|
||||
a content-based ID will be generated using SHA256 hash.
|
||||
content: The text content of the document (required)
|
||||
metadata: Optional metadata associated with the document
|
||||
"""
|
||||
|
||||
doc_id: str
|
||||
content: Required[str]
|
||||
metadata: (
|
||||
Mapping[str, str | int | float | bool]
|
||||
| list[Mapping[str, str | int | float | bool]]
|
||||
)
|
||||
|
||||
|
||||
Embeddings: TypeAlias = list[list[float]]
|
||||
|
||||
EmbeddingFunction: TypeAlias = Callable[..., Any]
|
||||
|
||||
|
||||
class SearchResult(TypedDict):
|
||||
"""Standard search result format for vector store queries.
|
||||
|
||||
This provides a consistent interface for search results across different
|
||||
vector store implementations. Each implementation should convert their
|
||||
native result format to this standard format.
|
||||
|
||||
Attributes:
|
||||
id: Unique identifier of the document
|
||||
content: The text content of the document
|
||||
metadata: Optional metadata associated with the document
|
||||
score: Similarity score (higher is better, typically between 0 and 1)
|
||||
"""
|
||||
|
||||
id: str
|
||||
content: str
|
||||
metadata: dict[str, Any]
|
||||
score: float
|
||||
1
lib/core/src/crewai/core/utilities/__init__.py
Normal file
1
lib/core/src/crewai/core/utilities/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Core utilities for CrewAI."""
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user