feat: add custom embedding types and migrate providers

- introduce baseembeddingsprovider and helper for embedding functions  
- add core embedding types and migrate providers, factory, and storage modules  
- remove unused type aliases and fix pydantic schema error  
- update providers with env var support and related fixes
This commit is contained in:
Greyson LaLonde
2025-09-25 18:28:39 -04:00
committed by GitHub
parent e070c1400c
commit ce5ea9be6f
74 changed files with 2767 additions and 1308 deletions

View File

@@ -8,7 +8,9 @@ from crewai.rag.chromadb.config import ChromaDBConfig
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
from crewai.rag.config.utils import get_rag_client
from crewai.rag.core.base_client import BaseClient
from crewai.rag.embeddings.factory import EmbedderConfig, get_embedding_function
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
from crewai.rag.embeddings.factory import build_embedder
from crewai.rag.embeddings.types import ProviderSpec
from crewai.rag.factory import create_client
from crewai.rag.types import BaseRecord, SearchResult
from crewai.utilities.logger import Logger
@@ -22,12 +24,11 @@ class KnowledgeStorage(BaseKnowledgeStorage):
def __init__(
self,
embedder: dict[str, Any] | None = None,
embedder: ProviderSpec | BaseEmbeddingsProvider | None = None,
collection_name: str | None = None,
) -> None:
self.collection_name = collection_name
self._client: BaseClient | None = None
self._embedder_config = embedder # Store embedder config
warnings.filterwarnings(
"ignore",
@@ -36,29 +37,12 @@ class KnowledgeStorage(BaseKnowledgeStorage):
)
if embedder:
# Cast to EmbedderConfig for type checking
embedder_typed = cast(EmbedderConfig, embedder)
embedding_function = get_embedding_function(embedder_typed)
batch_size = None
if isinstance(embedder, dict) and "config" in embedder:
nested_config = embedder["config"]
if isinstance(nested_config, dict):
batch_size = nested_config.get("batch_size")
# Create config with batch_size if provided
if batch_size is not None:
config = ChromaDBConfig(
embedding_function=cast(
ChromaEmbeddingFunctionWrapper, embedding_function
),
batch_size=batch_size,
)
else:
config = ChromaDBConfig(
embedding_function=cast(
ChromaEmbeddingFunctionWrapper, embedding_function
)
embedding_function = build_embedder(embedder)
config = ChromaDBConfig(
embedding_function=cast(
ChromaEmbeddingFunctionWrapper, embedding_function
)
)
self._client = create_client(config)
def _get_client(self) -> BaseClient:
@@ -123,23 +107,9 @@ class KnowledgeStorage(BaseKnowledgeStorage):
rag_documents: list[BaseRecord] = [{"content": doc} for doc in documents]
batch_size = None
if self._embedder_config and isinstance(self._embedder_config, dict):
if "config" in self._embedder_config:
nested_config = self._embedder_config["config"]
if isinstance(nested_config, dict):
batch_size = nested_config.get("batch_size")
if batch_size is not None:
client.add_documents(
collection_name=collection_name,
documents=rag_documents,
batch_size=batch_size,
)
else:
client.add_documents(
collection_name=collection_name, documents=rag_documents
)
client.add_documents(
collection_name=collection_name, documents=rag_documents
)
except Exception as e:
if "dimension mismatch" in str(e).lower():
Logger(verbose=True).log(

View File

@@ -7,8 +7,9 @@ from crewai.rag.chromadb.config import ChromaDBConfig
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
from crewai.rag.config.utils import get_rag_client
from crewai.rag.core.base_client import BaseClient
from crewai.rag.embeddings.factory import EmbedderConfig, get_embedding_function
from crewai.rag.embeddings.types import EmbeddingOptions
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
from crewai.rag.embeddings.factory import build_embedder
from crewai.rag.embeddings.types import ProviderSpec
from crewai.rag.factory import create_client
from crewai.rag.storage.base_rag_storage import BaseRAGStorage
from crewai.rag.types import BaseRecord
@@ -26,7 +27,7 @@ class RAGStorage(BaseRAGStorage):
self,
type: str,
allow_reset: bool = True,
embedder_config: EmbeddingOptions | EmbedderConfig | None = None,
embedder_config: ProviderSpec | BaseEmbeddingsProvider | None = None,
crew: Any = None,
path: str | None = None,
) -> None:
@@ -50,15 +51,17 @@ class RAGStorage(BaseRAGStorage):
)
if self.embedder_config:
embedding_function = get_embedding_function(self.embedder_config)
embedding_function = build_embedder(self.embedder_config)
try:
_ = embedding_function(["test"])
except Exception as e:
provider = (
self.embedder_config.provider
if isinstance(self.embedder_config, EmbeddingOptions)
else self.embedder_config.get("provider", "unknown")
self.embedder_config["provider"]
if isinstance(self.embedder_config, dict)
else self.embedder_config.__class__.__name__.replace(
"Provider", ""
).lower()
)
raise ValueError(
f"Failed to initialize embedder. Please check your configuration or connection.\n"
@@ -80,7 +83,7 @@ class RAGStorage(BaseRAGStorage):
embedding_function=cast(
ChromaEmbeddingFunctionWrapper, embedding_function
),
batch_size=batch_size,
batch_size=cast(int, batch_size),
)
else:
config = ChromaDBConfig(
@@ -142,7 +145,7 @@ class RAGStorage(BaseRAGStorage):
client.add_documents(
collection_name=collection_name,
documents=[document],
batch_size=batch_size,
batch_size=cast(int, batch_size),
)
else:
client.add_documents(

View File

@@ -0,0 +1,142 @@
"""Base embeddings callable utilities for RAG systems."""
from typing import Protocol, TypeVar, runtime_checkable
import numpy as np
from crewai.rag.core.types import (
Embeddable,
Embedding,
Embeddings,
PyEmbedding,
)
T = TypeVar("T")
D = TypeVar("D", bound=Embeddable, contravariant=True)
def normalize_embeddings(
target: Embedding | list[Embedding] | PyEmbedding | list[PyEmbedding],
) -> Embeddings | None:
"""Normalize various embedding formats to a standard list of numpy arrays.
Args:
target: Input embeddings in various formats (list of floats, list of lists,
numpy array, or list of numpy arrays).
Returns:
Normalized embeddings as a list of numpy arrays, or None if input is None.
Raises:
ValueError: If embeddings are empty or in an unsupported format.
"""
if isinstance(target, np.ndarray):
if target.ndim == 1:
return [target.astype(np.float32)]
if target.ndim == 2:
return [row.astype(np.float32) for row in target]
raise ValueError(f"Unsupported numpy array shape: {target.shape}")
first = target[0]
if isinstance(first, (int, float)) and not isinstance(first, bool):
return [np.array(target, dtype=np.float32)]
if isinstance(first, list):
return [np.array(emb, dtype=np.float32) for emb in target]
if isinstance(first, np.ndarray):
return [emb.astype(np.float32) for emb in target] # type: ignore[union-attr]
raise ValueError(f"Unsupported embeddings format: {type(first)}")
def maybe_cast_one_to_many(target: T | list[T] | None) -> list[T] | None:
"""Cast a single item to a list if needed.
Args:
target: A single item or list of items.
Returns:
A list of items or None if input is None.
"""
if target is None:
return None
return target if isinstance(target, list) else [target]
def validate_embeddings(embeddings: Embeddings) -> Embeddings:
"""Validate embeddings format and content.
Args:
embeddings: List of numpy arrays to validate.
Returns:
Validated embeddings.
Raises:
ValueError: If embeddings format or content is invalid.
"""
if not isinstance(embeddings, list):
raise ValueError(
f"Expected embeddings to be a list, got {type(embeddings).__name__}"
)
if len(embeddings) == 0:
raise ValueError(
f"Expected embeddings to be a list with at least one item, got {len(embeddings)} embeddings"
)
if not all(isinstance(e, np.ndarray) for e in embeddings):
raise ValueError(
"Expected each embedding in the embeddings to be a numpy array"
)
for i, embedding in enumerate(embeddings):
if embedding.ndim == 0:
raise ValueError(
f"Expected a 1-dimensional array, got a 0-dimensional array {embedding}"
)
if embedding.size == 0:
raise ValueError(
f"Expected each embedding to be a 1-dimensional numpy array with at least 1 value. "
f"Got an array with no values at position {i}"
)
if not all(
isinstance(value, (np.integer, float, np.floating))
and not isinstance(value, bool)
for value in embedding
):
raise ValueError(
f"Expected embedding to contain numeric values, got non-numeric values at position {i}"
)
return embeddings
@runtime_checkable
class EmbeddingFunction(Protocol[D]):
"""Protocol for embedding functions.
Embedding functions convert input data (documents or images) into vector embeddings.
"""
def __call__(self, input: D) -> Embeddings:
"""Convert input data to embeddings.
Args:
input: Input data to embed (documents or images).
Returns:
List of numpy arrays representing the embeddings.
"""
...
def __init_subclass__(cls) -> None:
"""Wrap __call__ method to normalize and validate embeddings."""
super().__init_subclass__()
original_call = cls.__call__
def wrapped_call(self: EmbeddingFunction[D], input: D) -> Embeddings:
result = original_call(self, input)
if result is None:
raise ValueError("Embedding function returned None")
normalized = normalize_embeddings(result)
if normalized is None:
raise ValueError("Normalization returned None for non-None input")
return validate_embeddings(normalized)
cls.__call__ = wrapped_call # type: ignore[method-assign]

View File

@@ -0,0 +1,23 @@
"""Base class for embedding providers."""
from typing import Generic, TypeVar
from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
T = TypeVar("T", bound=EmbeddingFunction)
class BaseEmbeddingsProvider(BaseSettings, Generic[T]):
"""Abstract base class for embedding providers.
This class provides a common interface for dynamically loading and building
embedding functions from various providers.
"""
model_config = SettingsConfigDict(extra="allow", populate_by_name=True)
embedding_callable: type[T] = Field(
..., description="The embedding function class to use"
)

View File

@@ -0,0 +1,28 @@
"""Core type definitions for RAG systems."""
from collections.abc import Sequence
from typing import TypeVar
import numpy as np
from numpy import floating, integer, number
from numpy.typing import NDArray
T = TypeVar("T")
PyEmbedding = Sequence[float] | Sequence[int]
PyEmbeddings = list[PyEmbedding]
Embedding = NDArray[np.int32 | np.float32]
Embeddings = list[Embedding]
Documents = list[str]
Images = list[np.ndarray]
Embeddable = Documents | Images
ScalarType = TypeVar("ScalarType", bound=np.generic)
IntegerType = TypeVar("IntegerType", bound=integer)
FloatingType = TypeVar("FloatingType", bound=floating)
NumberType = TypeVar("NumberType", bound=number)
DType32 = TypeVar("DType32", np.int32, np.float32)
DType64 = TypeVar("DType64", np.int64, np.float64)
DTypeCommon = TypeVar("DTypeCommon", np.int32, np.int64, np.float32, np.float64)

View File

@@ -1,245 +0,0 @@
import os
from typing import Any, cast
from chromadb import Documents, EmbeddingFunction, Embeddings
from chromadb.api.types import validate_embedding_function
class EmbeddingConfigurator:
def __init__(self):
self.embedding_functions = {
"openai": self._configure_openai,
"azure": self._configure_azure,
"ollama": self._configure_ollama,
"vertexai": self._configure_vertexai,
"google": self._configure_google,
"cohere": self._configure_cohere,
"voyageai": self._configure_voyageai,
"bedrock": self._configure_bedrock,
"huggingface": self._configure_huggingface,
"watson": self._configure_watson,
"custom": self._configure_custom,
}
def configure_embedder(
self,
embedder_config: dict[str, Any] | None = None,
) -> EmbeddingFunction:
"""Configures and returns an embedding function based on the provided config."""
if embedder_config is None:
return self._create_default_embedding_function()
provider = embedder_config.get("provider")
config = embedder_config.get("config", {})
model_name = config.get("model") if provider != "custom" else None
if provider not in self.embedding_functions:
raise Exception(
f"Unsupported embedding provider: {provider}, supported providers: {list(self.embedding_functions.keys())}"
)
try:
embedding_function = self.embedding_functions[provider]
except ImportError as e:
missing_package = str(e).split()[-1]
raise ImportError(
f"{missing_package} is not installed. Please install it with: pip install {missing_package}"
) from e
return (
embedding_function(config)
if provider == "custom"
else embedding_function(config, model_name)
)
@staticmethod
def _create_default_embedding_function():
from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction,
)
return OpenAIEmbeddingFunction(
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
)
@staticmethod
def _configure_openai(config, model_name):
from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction,
)
return OpenAIEmbeddingFunction(
api_key=config.get("api_key") or os.getenv("OPENAI_API_KEY"),
model_name=model_name,
api_base=config.get("api_base", None),
api_type=config.get("api_type", None),
api_version=config.get("api_version", None),
default_headers=config.get("default_headers", None),
dimensions=config.get("dimensions", None),
deployment_id=config.get("deployment_id", None),
organization_id=config.get("organization_id", None),
)
@staticmethod
def _configure_azure(config, model_name):
from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction,
)
return OpenAIEmbeddingFunction(
api_key=config.get("api_key"),
api_base=config.get("api_base"),
api_type=config.get("api_type", "azure"),
api_version=config.get("api_version"),
model_name=model_name,
default_headers=config.get("default_headers"),
dimensions=config.get("dimensions"),
deployment_id=config.get("deployment_id"),
organization_id=config.get("organization_id"),
)
@staticmethod
def _configure_ollama(config, model_name):
from chromadb.utils.embedding_functions.ollama_embedding_function import (
OllamaEmbeddingFunction,
)
return OllamaEmbeddingFunction(
url=config.get("url", "http://localhost:11434/api/embeddings"),
model_name=model_name,
)
@staticmethod
def _configure_vertexai(config, model_name):
from chromadb.utils.embedding_functions.google_embedding_function import (
GoogleVertexEmbeddingFunction,
)
return GoogleVertexEmbeddingFunction(
model_name=model_name,
api_key=config.get("api_key"),
project_id=config.get("project_id"),
region=config.get("region"),
)
@staticmethod
def _configure_google(config, model_name):
from chromadb.utils.embedding_functions.google_embedding_function import (
GoogleGenerativeAiEmbeddingFunction,
)
return GoogleGenerativeAiEmbeddingFunction(
model_name=model_name,
api_key=config.get("api_key"),
task_type=config.get("task_type"),
)
@staticmethod
def _configure_cohere(config, model_name):
from chromadb.utils.embedding_functions.cohere_embedding_function import (
CohereEmbeddingFunction,
)
return CohereEmbeddingFunction(
model_name=model_name,
api_key=config.get("api_key"),
)
@staticmethod
def _configure_voyageai(config, model_name):
from chromadb.utils.embedding_functions.voyageai_embedding_function import ( # type: ignore[import-not-found]
VoyageAIEmbeddingFunction,
)
return VoyageAIEmbeddingFunction(
model_name=model_name,
api_key=config.get("api_key"),
)
@staticmethod
def _configure_bedrock(config, model_name):
from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
AmazonBedrockEmbeddingFunction,
)
# Allow custom model_name override with backwards compatibility
kwargs = {"session": config.get("session")}
if model_name is not None:
kwargs["model_name"] = model_name
return AmazonBedrockEmbeddingFunction(**kwargs)
@staticmethod
def _configure_huggingface(config, model_name):
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
HuggingFaceEmbeddingServer,
)
return HuggingFaceEmbeddingServer(
url=config.get("api_url"),
)
@staticmethod
def _configure_watson(config, model_name):
try:
import ibm_watsonx_ai.foundation_models as watson_models # type: ignore[import-not-found]
from ibm_watsonx_ai import Credentials # type: ignore[import-not-found]
from ibm_watsonx_ai.metanames import ( # type: ignore[import-not-found]
EmbedTextParamsMetaNames as EmbedParams,
)
except ImportError as e:
raise ImportError(
"IBM Watson dependencies are not installed. Please install them to use Watson embedding."
) from e
class WatsonEmbeddingFunction(EmbeddingFunction):
def __call__(self, input: Documents) -> Embeddings:
if isinstance(input, str):
input = [input]
embed_params = {
EmbedParams.TRUNCATE_INPUT_TOKENS: 3,
EmbedParams.RETURN_OPTIONS: {"input_text": True},
}
embedding = watson_models.Embeddings(
model_id=config.get("model"),
params=embed_params,
credentials=Credentials(
api_key=config.get("api_key"), url=config.get("api_url")
),
project_id=config.get("project_id"),
)
try:
embeddings = embedding.embed_documents(input)
return cast(Embeddings, embeddings)
except Exception as e:
print("Error during Watson embedding:", e)
raise e
return WatsonEmbeddingFunction()
@staticmethod
def _configure_custom(config):
custom_embedder = config.get("embedder")
if isinstance(custom_embedder, EmbeddingFunction):
try:
validate_embedding_function(custom_embedder)
return custom_embedder
except Exception as e:
raise ValueError(f"Invalid custom embedding function: {e!s}") from e
elif callable(custom_embedder):
try:
instance = custom_embedder()
if isinstance(instance, EmbeddingFunction):
validate_embedding_function(instance)
return instance
raise ValueError(
"Custom embedder does not create an EmbeddingFunction instance"
)
except Exception as e:
raise ValueError(f"Error instantiating custom embedder: {e!s}") from e
else:
raise ValueError(
"Custom embedder must be an instance of `EmbeddingFunction` or a callable that creates one"
)

View File

@@ -1,249 +1,363 @@
"""Minimal embedding function factory for CrewAI."""
"""Factory functions for creating embedding providers and functions."""
import os
from collections.abc import Callable, MutableMapping
from typing import Any, Final, Literal, TypedDict
from __future__ import annotations
from chromadb import EmbeddingFunction
from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
AmazonBedrockEmbeddingFunction,
)
from chromadb.utils.embedding_functions.cohere_embedding_function import (
CohereEmbeddingFunction,
)
from chromadb.utils.embedding_functions.google_embedding_function import (
GoogleGenerativeAiEmbeddingFunction,
GooglePalmEmbeddingFunction,
GoogleVertexEmbeddingFunction,
)
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
HuggingFaceEmbeddingFunction,
)
from chromadb.utils.embedding_functions.instructor_embedding_function import (
InstructorEmbeddingFunction,
)
from chromadb.utils.embedding_functions.jina_embedding_function import (
JinaEmbeddingFunction,
)
from chromadb.utils.embedding_functions.ollama_embedding_function import (
OllamaEmbeddingFunction,
)
from chromadb.utils.embedding_functions.onnx_mini_lm_l6_v2 import ONNXMiniLM_L6_V2
from chromadb.utils.embedding_functions.open_clip_embedding_function import (
OpenCLIPEmbeddingFunction,
)
from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction,
)
from chromadb.utils.embedding_functions.roboflow_embedding_function import (
RoboflowEmbeddingFunction,
)
from chromadb.utils.embedding_functions.sentence_transformer_embedding_function import (
SentenceTransformerEmbeddingFunction,
)
from chromadb.utils.embedding_functions.text2vec_embedding_function import (
Text2VecEmbeddingFunction,
)
from typing_extensions import NotRequired
from typing import TYPE_CHECKING, TypeVar, overload
from crewai.rag.embeddings.types import EmbeddingOptions
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
from crewai.utilities.import_utils import import_and_validate_definition
AllowedEmbeddingProviders = Literal[
"openai",
"cohere",
"ollama",
"huggingface",
"sentence-transformer",
"instructor",
"google-palm",
"google-generativeai",
"google-vertex",
"amazon-bedrock",
"jina",
"roboflow",
"openclip",
"text2vec",
"onnx",
]
if TYPE_CHECKING:
from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
AmazonBedrockEmbeddingFunction,
)
from chromadb.utils.embedding_functions.cohere_embedding_function import (
CohereEmbeddingFunction,
)
from chromadb.utils.embedding_functions.google_embedding_function import (
GoogleGenerativeAiEmbeddingFunction,
GoogleVertexEmbeddingFunction,
)
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
HuggingFaceEmbeddingFunction,
)
from chromadb.utils.embedding_functions.instructor_embedding_function import (
InstructorEmbeddingFunction,
)
from chromadb.utils.embedding_functions.jina_embedding_function import (
JinaEmbeddingFunction,
)
from chromadb.utils.embedding_functions.ollama_embedding_function import (
OllamaEmbeddingFunction,
)
from chromadb.utils.embedding_functions.onnx_mini_lm_l6_v2 import ONNXMiniLM_L6_V2
from chromadb.utils.embedding_functions.open_clip_embedding_function import (
OpenCLIPEmbeddingFunction,
)
from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction,
)
from chromadb.utils.embedding_functions.roboflow_embedding_function import (
RoboflowEmbeddingFunction,
)
from chromadb.utils.embedding_functions.sentence_transformer_embedding_function import (
SentenceTransformerEmbeddingFunction,
)
from chromadb.utils.embedding_functions.text2vec_embedding_function import (
Text2VecEmbeddingFunction,
)
from crewai.rag.embeddings.providers.aws.types import BedrockProviderSpec
from crewai.rag.embeddings.providers.cohere.types import CohereProviderSpec
from crewai.rag.embeddings.providers.custom.types import CustomProviderSpec
from crewai.rag.embeddings.providers.google.types import (
GenerativeAiProviderSpec,
VertexAIProviderSpec,
)
from crewai.rag.embeddings.providers.huggingface.types import (
HuggingFaceProviderSpec,
)
from crewai.rag.embeddings.providers.ibm.embedding_callable import (
WatsonEmbeddingFunction,
)
from crewai.rag.embeddings.providers.ibm.types import WatsonProviderSpec
from crewai.rag.embeddings.providers.instructor.types import InstructorProviderSpec
from crewai.rag.embeddings.providers.jina.types import JinaProviderSpec
from crewai.rag.embeddings.providers.microsoft.types import AzureProviderSpec
from crewai.rag.embeddings.providers.ollama.types import OllamaProviderSpec
from crewai.rag.embeddings.providers.onnx.types import ONNXProviderSpec
from crewai.rag.embeddings.providers.openai.types import OpenAIProviderSpec
from crewai.rag.embeddings.providers.openclip.types import OpenCLIPProviderSpec
from crewai.rag.embeddings.providers.roboflow.types import RoboflowProviderSpec
from crewai.rag.embeddings.providers.sentence_transformer.types import (
SentenceTransformerProviderSpec,
)
from crewai.rag.embeddings.providers.text2vec.types import Text2VecProviderSpec
from crewai.rag.embeddings.providers.voyageai.embedding_callable import (
VoyageAIEmbeddingFunction,
)
from crewai.rag.embeddings.providers.voyageai.types import VoyageAIProviderSpec
T = TypeVar("T", bound=EmbeddingFunction)
class EmbedderConfig(TypedDict):
"""Configuration for embedding functions with nested format."""
provider: AllowedEmbeddingProviders
config: NotRequired[dict[str, Any]]
EMBEDDING_PROVIDERS: Final[
dict[AllowedEmbeddingProviders, Callable[..., EmbeddingFunction]]
] = {
"openai": OpenAIEmbeddingFunction,
"cohere": CohereEmbeddingFunction,
"ollama": OllamaEmbeddingFunction,
"huggingface": HuggingFaceEmbeddingFunction,
"sentence-transformer": SentenceTransformerEmbeddingFunction,
"instructor": InstructorEmbeddingFunction,
"google-palm": GooglePalmEmbeddingFunction,
"google-generativeai": GoogleGenerativeAiEmbeddingFunction,
"google-vertex": GoogleVertexEmbeddingFunction,
"amazon-bedrock": AmazonBedrockEmbeddingFunction,
"jina": JinaEmbeddingFunction,
"roboflow": RoboflowEmbeddingFunction,
"openclip": OpenCLIPEmbeddingFunction,
"text2vec": Text2VecEmbeddingFunction,
"onnx": ONNXMiniLM_L6_V2,
}
PROVIDER_ENV_MAPPING: Final[dict[AllowedEmbeddingProviders, tuple[str, str]]] = {
"openai": ("OPENAI_API_KEY", "api_key"),
"cohere": ("COHERE_API_KEY", "api_key"),
"huggingface": ("HUGGINGFACE_API_KEY", "api_key"),
"google-palm": ("GOOGLE_API_KEY", "api_key"),
"google-generativeai": ("GOOGLE_API_KEY", "api_key"),
"google-vertex": ("GOOGLE_API_KEY", "api_key"),
"jina": ("JINA_API_KEY", "api_key"),
"roboflow": ("ROBOFLOW_API_KEY", "api_key"),
PROVIDER_PATHS = {
"azure": "crewai.rag.embeddings.providers.microsoft.azure.AzureProvider",
"amazon-bedrock": "crewai.rag.embeddings.providers.aws.bedrock.BedrockProvider",
"cohere": "crewai.rag.embeddings.providers.cohere.cohere_provider.CohereProvider",
"custom": "crewai.rag.embeddings.providers.custom.custom_provider.CustomProvider",
"google-generativeai": "crewai.rag.embeddings.providers.google.generative_ai.GenerativeAiProvider",
"google-vertex": "crewai.rag.embeddings.providers.google.vertex.VertexAIProvider",
"huggingface": "crewai.rag.embeddings.providers.huggingface.huggingface_provider.HuggingFaceProvider",
"instructor": "crewai.rag.embeddings.providers.instructor.instructor_provider.InstructorProvider",
"jina": "crewai.rag.embeddings.providers.jina.jina_provider.JinaProvider",
"ollama": "crewai.rag.embeddings.providers.ollama.ollama_provider.OllamaProvider",
"onnx": "crewai.rag.embeddings.providers.onnx.onnx_provider.ONNXProvider",
"openai": "crewai.rag.embeddings.providers.openai.openai_provider.OpenAIProvider",
"openclip": "crewai.rag.embeddings.providers.openclip.openclip_provider.OpenCLIPProvider",
"roboflow": "crewai.rag.embeddings.providers.roboflow.roboflow_provider.RoboflowProvider",
"sentence-transformer": "crewai.rag.embeddings.providers.sentence_transformer.sentence_transformer_provider.SentenceTransformerProvider",
"text2vec": "crewai.rag.embeddings.providers.text2vec.text2vec_provider.Text2VecProvider",
"voyageai": "crewai.rag.embeddings.providers.voyageai.voyageai_provider.VoyageAIProvider",
"watson": "crewai.rag.embeddings.providers.ibm.watson.WatsonProvider",
}
def _inject_api_key_from_env(
provider: AllowedEmbeddingProviders, config_dict: MutableMapping[str, Any]
) -> None:
"""Inject API key or other required configuration from environment if not explicitly provided.
def build_embedder_from_provider(provider: BaseEmbeddingsProvider[T]) -> T:
"""Build an embedding function instance from a provider.
Args:
provider: The embedding provider name
config_dict: The configuration dictionary to modify in-place
Raises:
ImportError: If required libraries for certain providers are not installed
ValueError: If AWS session creation fails for amazon-bedrock
"""
if provider in PROVIDER_ENV_MAPPING:
env_var_name, config_key = PROVIDER_ENV_MAPPING[provider]
if config_key not in config_dict:
env_value = os.getenv(env_var_name)
if env_value:
config_dict[config_key] = env_value
if provider == "amazon-bedrock":
if "session" not in config_dict:
try:
import boto3 # type: ignore[import]
config_dict["session"] = boto3.Session()
except ImportError as e:
raise ImportError(
"boto3 is required for amazon-bedrock embeddings. "
"Install it with: uv add boto3"
) from e
except Exception as e:
raise ValueError(
f"Failed to create AWS session for amazon-bedrock. "
f"Ensure AWS credentials are configured. Error: {e}"
) from e
def get_embedding_function(
config: EmbeddingOptions | EmbedderConfig | None = None,
) -> EmbeddingFunction:
"""Get embedding function - delegates to ChromaDB.
Args:
config: Optional configuration - either:
- EmbeddingOptions: Pydantic model with flat configuration
- EmbedderConfig: TypedDict with nested format {"provider": str, "config": dict}
- None: Uses default OpenAI configuration
provider: The embedding provider configuration.
Returns:
EmbeddingFunction instance ready for use with ChromaDB
An instance of the specified embedding function type.
"""
return provider.embedding_callable(
**provider.model_dump(exclude={"embedding_callable"})
)
Supported providers:
- openai: OpenAI embeddings
- cohere: Cohere embeddings
- ollama: Ollama local embeddings
- huggingface: HuggingFace embeddings
- sentence-transformer: Local sentence transformers
- instructor: Instructor embeddings for specialized tasks
- google-palm: Google PaLM embeddings
- google-generativeai: Google Generative AI embeddings
- google-vertex: Google Vertex AI embeddings
- amazon-bedrock: AWS Bedrock embeddings
- jina: Jina AI embeddings
- roboflow: Roboflow embeddings for vision tasks
- openclip: OpenCLIP embeddings for multimodal tasks
- text2vec: Text2Vec embeddings
- onnx: ONNX MiniLM-L6-v2 (no API key needed, included with ChromaDB)
@overload
def build_embedder_from_dict(spec: AzureProviderSpec) -> OpenAIEmbeddingFunction: ...
@overload
def build_embedder_from_dict(
spec: BedrockProviderSpec,
) -> AmazonBedrockEmbeddingFunction: ...
@overload
def build_embedder_from_dict(spec: CohereProviderSpec) -> CohereEmbeddingFunction: ...
@overload
def build_embedder_from_dict(spec: CustomProviderSpec) -> EmbeddingFunction: ...
@overload
def build_embedder_from_dict(
spec: GenerativeAiProviderSpec,
) -> GoogleGenerativeAiEmbeddingFunction: ...
@overload
def build_embedder_from_dict(
spec: HuggingFaceProviderSpec,
) -> HuggingFaceEmbeddingFunction: ...
@overload
def build_embedder_from_dict(spec: OllamaProviderSpec) -> OllamaEmbeddingFunction: ...
@overload
def build_embedder_from_dict(spec: OpenAIProviderSpec) -> OpenAIEmbeddingFunction: ...
@overload
def build_embedder_from_dict(
spec: VertexAIProviderSpec,
) -> GoogleVertexEmbeddingFunction: ...
@overload
def build_embedder_from_dict(
spec: VoyageAIProviderSpec,
) -> VoyageAIEmbeddingFunction: ...
@overload
def build_embedder_from_dict(spec: WatsonProviderSpec) -> WatsonEmbeddingFunction: ...
@overload
def build_embedder_from_dict(
spec: SentenceTransformerProviderSpec,
) -> SentenceTransformerEmbeddingFunction: ...
@overload
def build_embedder_from_dict(
spec: InstructorProviderSpec,
) -> InstructorEmbeddingFunction: ...
@overload
def build_embedder_from_dict(spec: JinaProviderSpec) -> JinaEmbeddingFunction: ...
@overload
def build_embedder_from_dict(
spec: RoboflowProviderSpec,
) -> RoboflowEmbeddingFunction: ...
@overload
def build_embedder_from_dict(
spec: OpenCLIPProviderSpec,
) -> OpenCLIPEmbeddingFunction: ...
@overload
def build_embedder_from_dict(
spec: Text2VecProviderSpec,
) -> Text2VecEmbeddingFunction: ...
@overload
def build_embedder_from_dict(spec: ONNXProviderSpec) -> ONNXMiniLM_L6_V2: ...
def build_embedder_from_dict(spec):
"""Build an embedding function instance from a dictionary specification.
Args:
spec: A dictionary with 'provider' and 'config' keys.
Example: {
"provider": "openai",
"config": {
"api_key": "sk-...",
"model_name": "text-embedding-3-small"
}
}
Returns:
An instance of the appropriate embedding function.
Raises:
ValueError: If the provider is not recognized.
"""
provider_name = spec["provider"]
if not provider_name:
raise ValueError("Missing 'provider' key in specification")
if provider_name not in PROVIDER_PATHS:
raise ValueError(
f"Unknown provider: {provider_name}. Available providers: {list(PROVIDER_PATHS.keys())}"
)
provider_path = PROVIDER_PATHS[provider_name]
try:
provider_class = import_and_validate_definition(provider_path)
except (ImportError, AttributeError, ValueError) as e:
raise ImportError(f"Failed to import provider {provider_name}: {e}") from e
provider_config = spec.get("config", {})
if provider_name == "custom" and "embedding_callable" not in provider_config:
raise ValueError("Custom provider requires 'embedding_callable' in config")
provider = provider_class(**provider_config)
return build_embedder_from_provider(provider)
@overload
def build_embedder(spec: BaseEmbeddingsProvider[T]) -> T: ...
@overload
def build_embedder(spec: AzureProviderSpec) -> OpenAIEmbeddingFunction: ...
@overload
def build_embedder(spec: BedrockProviderSpec) -> AmazonBedrockEmbeddingFunction: ...
@overload
def build_embedder(spec: CohereProviderSpec) -> CohereEmbeddingFunction: ...
@overload
def build_embedder(spec: CustomProviderSpec) -> EmbeddingFunction: ...
@overload
def build_embedder(
spec: GenerativeAiProviderSpec,
) -> GoogleGenerativeAiEmbeddingFunction: ...
@overload
def build_embedder(spec: HuggingFaceProviderSpec) -> HuggingFaceEmbeddingFunction: ...
@overload
def build_embedder(spec: OllamaProviderSpec) -> OllamaEmbeddingFunction: ...
@overload
def build_embedder(spec: OpenAIProviderSpec) -> OpenAIEmbeddingFunction: ...
@overload
def build_embedder(spec: VertexAIProviderSpec) -> GoogleVertexEmbeddingFunction: ...
@overload
def build_embedder(spec: VoyageAIProviderSpec) -> VoyageAIEmbeddingFunction: ...
@overload
def build_embedder(spec: WatsonProviderSpec) -> WatsonEmbeddingFunction: ...
@overload
def build_embedder(
spec: SentenceTransformerProviderSpec,
) -> SentenceTransformerEmbeddingFunction: ...
@overload
def build_embedder(spec: InstructorProviderSpec) -> InstructorEmbeddingFunction: ...
@overload
def build_embedder(spec: JinaProviderSpec) -> JinaEmbeddingFunction: ...
@overload
def build_embedder(spec: RoboflowProviderSpec) -> RoboflowEmbeddingFunction: ...
@overload
def build_embedder(spec: OpenCLIPProviderSpec) -> OpenCLIPEmbeddingFunction: ...
@overload
def build_embedder(spec: Text2VecProviderSpec) -> Text2VecEmbeddingFunction: ...
@overload
def build_embedder(spec: ONNXProviderSpec) -> ONNXMiniLM_L6_V2: ...
def build_embedder(spec):
"""Build an embedding function from either a provider spec or a provider instance.
Args:
spec: Either a provider specification dictionary or a provider instance.
Returns:
An embedding function instance. If a typed provider is passed, returns
the specific embedding function type.
Examples:
# Use default OpenAI embedding
>>> embedder = get_embedding_function()
# From dictionary specification
embedder = build_embedder({
"provider": "openai",
"config": {"api_key": "sk-..."}
})
# Use Cohere with dict
>>> embedder = get_embedding_function(EmbedderConfig(**{
... "provider": "cohere",
... "config": {
... "api_key": "your-key",
... "model_name": "embed-english-v3.0"
... }
... }))
# Use with EmbeddingOptions
>>> embedder = get_embedding_function(
... EmbeddingOptions(provider="sentence-transformer", model_name="all-MiniLM-L6-v2")
... )
# Use Azure OpenAI
>>> embedder = get_embedding_function(EmbedderConfig(**{
... "provider": "openai",
... "config": {
... "api_key": "your-azure-key",
... "api_base": "https://your-resource.openai.azure.com/",
... "api_type": "azure",
... "api_version": "2023-05-15",
... "model": "text-embedding-3-small",
... "deployment_id": "your-deployment-name"
... }
... })
>>> embedder = get_embedding_function(EmbedderConfig(**{
... "provider": "onnx"
... })
# From provider instance
provider = OpenAIProvider(api_key="sk-...")
embedder = build_embedder(provider)
"""
if config is None:
return OpenAIEmbeddingFunction(
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
)
if isinstance(spec, BaseEmbeddingsProvider):
return build_embedder_from_provider(spec)
return build_embedder_from_dict(spec)
provider: AllowedEmbeddingProviders
config_dict: dict[str, Any]
if isinstance(config, EmbeddingOptions):
config_dict = config.model_dump(exclude_none=True)
provider = config_dict["provider"]
else:
provider = config["provider"]
nested: dict[str, Any] = config.get("config", {})
if not nested and len(config) > 1:
raise ValueError(
"Invalid embedder configuration format. "
"Configuration must be nested under a 'config' key. "
"Example: {'provider': 'openai', 'config': {'api_key': '...', 'model': '...'}}"
)
config_dict = dict(nested)
if "model" in config_dict and "model_name" not in config_dict:
config_dict["model_name"] = config_dict.pop("model")
if provider not in EMBEDDING_PROVIDERS:
raise ValueError(
f"Unsupported provider: {provider}. "
f"Available providers: {list(EMBEDDING_PROVIDERS.keys())}"
)
_inject_api_key_from_env(provider, config_dict)
config_dict.pop("batch_size", None)
return EMBEDDING_PROVIDERS[provider](**config_dict)
# Backward compatibility alias
get_embedding_function = build_embedder

View File

@@ -0,0 +1 @@
"""Embedding provider implementations."""

View File

@@ -0,0 +1,13 @@
"""AWS embedding providers."""
from crewai.rag.embeddings.providers.aws.bedrock import BedrockProvider
from crewai.rag.embeddings.providers.aws.types import (
BedrockProviderConfig,
BedrockProviderSpec,
)
__all__ = [
"BedrockProvider",
"BedrockProviderConfig",
"BedrockProviderSpec",
]

View File

@@ -0,0 +1,58 @@
"""Amazon Bedrock embeddings provider."""
from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
AmazonBedrockEmbeddingFunction,
)
from pydantic import Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
try:
from boto3.session import Session # type: ignore[import-untyped]
except ImportError as exc:
raise ImportError(
"boto3 is required for amazon-bedrock embeddings. Install it with: uv add boto3"
) from exc
def create_aws_session() -> Session:
"""Create an AWS session for Bedrock.
Returns:
boto3.Session: AWS session object
Raises:
ImportError: If boto3 is not installed
ValueError: If AWS session creation fails
"""
try:
import boto3 # type: ignore[import]
return boto3.Session()
except ImportError as e:
raise ImportError(
"boto3 is required for amazon-bedrock embeddings. "
"Install it with: uv add boto3"
) from e
except Exception as e:
raise ValueError(
f"Failed to create AWS session for amazon-bedrock. "
f"Ensure AWS credentials are configured. Error: {e}"
) from e
class BedrockProvider(BaseEmbeddingsProvider[AmazonBedrockEmbeddingFunction]):
"""Amazon Bedrock embeddings provider."""
embedding_callable: type[AmazonBedrockEmbeddingFunction] = Field(
default=AmazonBedrockEmbeddingFunction,
description="Amazon Bedrock embedding function class",
)
model_name: str = Field(
default="amazon.titan-embed-text-v1",
description="Model name to use for embeddings",
validation_alias="BEDROCK_MODEL_NAME",
)
session: Session = Field(
default_factory=create_aws_session, description="AWS session object"
)

View File

@@ -0,0 +1,17 @@
"""Type definitions for AWS embedding providers."""
from typing import Annotated, Any, Literal, TypedDict
class BedrockProviderConfig(TypedDict, total=False):
"""Configuration for Bedrock provider."""
model_name: Annotated[str, "amazon.titan-embed-text-v1"]
session: Any
class BedrockProviderSpec(TypedDict):
"""Bedrock provider specification."""
provider: Literal["amazon-bedrock"]
config: BedrockProviderConfig

View File

@@ -0,0 +1,13 @@
"""Cohere embedding providers."""
from crewai.rag.embeddings.providers.cohere.cohere_provider import CohereProvider
from crewai.rag.embeddings.providers.cohere.types import (
CohereProviderConfig,
CohereProviderSpec,
)
__all__ = [
"CohereProvider",
"CohereProviderConfig",
"CohereProviderSpec",
]

View File

@@ -0,0 +1,24 @@
"""Cohere embeddings provider."""
from chromadb.utils.embedding_functions.cohere_embedding_function import (
CohereEmbeddingFunction,
)
from pydantic import Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
class CohereProvider(BaseEmbeddingsProvider[CohereEmbeddingFunction]):
"""Cohere embeddings provider."""
embedding_callable: type[CohereEmbeddingFunction] = Field(
default=CohereEmbeddingFunction, description="Cohere embedding function class"
)
api_key: str = Field(
description="Cohere API key", validation_alias="COHERE_API_KEY"
)
model_name: str = Field(
default="large",
description="Model name to use for embeddings",
validation_alias="COHERE_MODEL_NAME",
)

View File

@@ -0,0 +1,17 @@
"""Type definitions for Cohere embedding providers."""
from typing import Annotated, Literal, TypedDict
class CohereProviderConfig(TypedDict, total=False):
"""Configuration for Cohere provider."""
api_key: str
model_name: Annotated[str, "large"]
class CohereProviderSpec(TypedDict):
"""Cohere provider specification."""
provider: Literal["cohere"]
config: CohereProviderConfig

View File

@@ -0,0 +1,13 @@
"""Custom embedding providers."""
from crewai.rag.embeddings.providers.custom.custom_provider import CustomProvider
from crewai.rag.embeddings.providers.custom.types import (
CustomProviderConfig,
CustomProviderSpec,
)
__all__ = [
"CustomProvider",
"CustomProviderConfig",
"CustomProviderSpec",
]

View File

@@ -0,0 +1,19 @@
"""Custom embeddings provider for user-defined embedding functions."""
from pydantic import Field
from pydantic_settings import SettingsConfigDict
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
from crewai.rag.embeddings.providers.custom.embedding_callable import (
CustomEmbeddingFunction,
)
class CustomProvider(BaseEmbeddingsProvider[CustomEmbeddingFunction]):
"""Custom embeddings provider for user-defined embedding functions."""
embedding_callable: type[CustomEmbeddingFunction] = Field(
..., description="Custom embedding function class"
)
model_config = SettingsConfigDict(extra="allow")

View File

@@ -0,0 +1,22 @@
"""Custom embedding function base implementation."""
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
from crewai.rag.core.types import Documents, Embeddings
class CustomEmbeddingFunction(EmbeddingFunction[Documents]):
"""Base class for custom embedding functions.
This provides a concrete implementation that can be subclassed for custom embeddings.
"""
def __call__(self, input: Documents) -> Embeddings:
"""Convert input documents to embeddings.
Args:
input: List of documents to embed.
Returns:
List of numpy arrays representing the embeddings.
"""
raise NotImplementedError("Subclasses must implement __call__ method")

View File

@@ -0,0 +1,18 @@
"""Type definitions for custom embedding providers."""
from typing import Literal, TypedDict
from chromadb.api.types import EmbeddingFunction
class CustomProviderConfig(TypedDict, total=False):
"""Configuration for Custom provider."""
embedding_callable: type[EmbeddingFunction]
class CustomProviderSpec(TypedDict):
"""Custom provider specification."""
provider: Literal["custom"]
config: CustomProviderConfig

View File

@@ -0,0 +1,23 @@
"""Google embedding providers."""
from crewai.rag.embeddings.providers.google.generative_ai import (
GenerativeAiProvider,
)
from crewai.rag.embeddings.providers.google.types import (
GenerativeAiProviderConfig,
GenerativeAiProviderSpec,
VertexAIProviderConfig,
VertexAIProviderSpec,
)
from crewai.rag.embeddings.providers.google.vertex import (
VertexAIProvider,
)
__all__ = [
"GenerativeAiProvider",
"GenerativeAiProviderConfig",
"GenerativeAiProviderSpec",
"VertexAIProvider",
"VertexAIProviderConfig",
"VertexAIProviderSpec",
]

View File

@@ -0,0 +1,30 @@
"""Google Generative AI embeddings provider."""
from chromadb.utils.embedding_functions.google_embedding_function import (
GoogleGenerativeAiEmbeddingFunction,
)
from pydantic import Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
class GenerativeAiProvider(BaseEmbeddingsProvider[GoogleGenerativeAiEmbeddingFunction]):
"""Google Generative AI embeddings provider."""
embedding_callable: type[GoogleGenerativeAiEmbeddingFunction] = Field(
default=GoogleGenerativeAiEmbeddingFunction,
description="Google Generative AI embedding function class",
)
model_name: str = Field(
default="models/embedding-001",
description="Model name to use for embeddings",
validation_alias="GOOGLE_GENERATIVE_AI_MODEL_NAME",
)
api_key: str = Field(
description="Google API key", validation_alias="GOOGLE_API_KEY"
)
task_type: str = Field(
default="RETRIEVAL_DOCUMENT",
description="Task type for embeddings",
validation_alias="GOOGLE_GENERATIVE_AI_TASK_TYPE",
)

View File

@@ -0,0 +1,34 @@
"""Type definitions for Google embedding providers."""
from typing import Annotated, Literal, TypedDict
class GenerativeAiProviderConfig(TypedDict, total=False):
"""Configuration for Google Generative AI provider."""
api_key: str
model_name: Annotated[str, "models/embedding-001"]
task_type: Annotated[str, "RETRIEVAL_DOCUMENT"]
class GenerativeAiProviderSpec(TypedDict):
"""Google Generative AI provider specification."""
provider: Literal["google-generativeai"]
config: GenerativeAiProviderConfig
class VertexAIProviderConfig(TypedDict, total=False):
"""Configuration for Vertex AI provider."""
api_key: str
model_name: Annotated[str, "textembedding-gecko"]
project_id: Annotated[str, "cloud-large-language-models"]
region: Annotated[str, "us-central1"]
class VertexAIProviderSpec(TypedDict):
"""Vertex AI provider specification."""
provider: Literal["google-vertex"]
config: VertexAIProviderConfig

View File

@@ -0,0 +1,35 @@
"""Google Vertex AI embeddings provider."""
from chromadb.utils.embedding_functions.google_embedding_function import (
GoogleVertexEmbeddingFunction,
)
from pydantic import Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
class VertexAIProvider(BaseEmbeddingsProvider[GoogleVertexEmbeddingFunction]):
"""Google Vertex AI embeddings provider."""
embedding_callable: type[GoogleVertexEmbeddingFunction] = Field(
default=GoogleVertexEmbeddingFunction,
description="Vertex AI embedding function class",
)
model_name: str = Field(
default="textembedding-gecko",
description="Model name to use for embeddings",
validation_alias="GOOGLE_VERTEX_MODEL_NAME",
)
api_key: str = Field(
description="Google API key", validation_alias="GOOGLE_CLOUD_API_KEY"
)
project_id: str = Field(
default="cloud-large-language-models",
description="GCP project ID",
validation_alias="GOOGLE_CLOUD_PROJECT",
)
region: str = Field(
default="us-central1",
description="GCP region",
validation_alias="GOOGLE_CLOUD_REGION",
)

View File

@@ -0,0 +1,15 @@
"""HuggingFace embedding providers."""
from crewai.rag.embeddings.providers.huggingface.huggingface_provider import (
HuggingFaceProvider,
)
from crewai.rag.embeddings.providers.huggingface.types import (
HuggingFaceProviderConfig,
HuggingFaceProviderSpec,
)
__all__ = [
"HuggingFaceProvider",
"HuggingFaceProviderConfig",
"HuggingFaceProviderSpec",
]

View File

@@ -0,0 +1,20 @@
"""HuggingFace embeddings provider."""
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
HuggingFaceEmbeddingServer,
)
from pydantic import Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
class HuggingFaceProvider(BaseEmbeddingsProvider[HuggingFaceEmbeddingServer]):
"""HuggingFace embeddings provider."""
embedding_callable: type[HuggingFaceEmbeddingServer] = Field(
default=HuggingFaceEmbeddingServer,
description="HuggingFace embedding function class",
)
url: str = Field(
description="HuggingFace API URL", validation_alias="HUGGINGFACE_URL"
)

View File

@@ -0,0 +1,16 @@
"""Type definitions for HuggingFace embedding providers."""
from typing import Literal, TypedDict
class HuggingFaceProviderConfig(TypedDict, total=False):
"""Configuration for HuggingFace provider."""
url: str
class HuggingFaceProviderSpec(TypedDict):
"""HuggingFace provider specification."""
provider: Literal["huggingface"]
config: HuggingFaceProviderConfig

View File

@@ -0,0 +1,15 @@
"""IBM embedding providers."""
from crewai.rag.embeddings.providers.ibm.types import (
WatsonProviderConfig,
WatsonProviderSpec,
)
from crewai.rag.embeddings.providers.ibm.watson import (
WatsonProvider,
)
__all__ = [
"WatsonProvider",
"WatsonProviderConfig",
"WatsonProviderSpec",
]

View File

@@ -0,0 +1,144 @@
"""IBM Watson embedding function implementation."""
from typing import cast
import ibm_watsonx_ai.foundation_models as watson_models # type: ignore[import-not-found, import-untyped]
from ibm_watsonx_ai import Credentials # type: ignore[import-not-found, import-untyped]
from ibm_watsonx_ai.metanames import ( # type: ignore[import-not-found, import-untyped]
EmbedTextParamsMetaNames as EmbedParams,
)
from typing_extensions import Unpack
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
from crewai.rag.core.types import Documents, Embeddings
from crewai.rag.embeddings.providers.ibm.types import WatsonProviderConfig
class WatsonEmbeddingFunction(EmbeddingFunction[Documents]):
"""Embedding function for IBM Watson models."""
def __init__(self, **kwargs: Unpack[WatsonProviderConfig]) -> None:
"""Initialize Watson embedding function.
Args:
**kwargs: Configuration parameters for Watson Embeddings and Credentials.
"""
self._config = kwargs
def __call__(self, input: Documents) -> Embeddings:
"""Generate embeddings for input documents.
Args:
input: List of documents to embed.
Returns:
List of embedding vectors.
"""
if isinstance(input, str):
input = [input]
embeddings_config: dict = {
"model_id": self._config["model_id"],
}
if "params" in self._config and self._config["params"] is not None:
embeddings_config["params"] = self._config["params"]
if "project_id" in self._config and self._config["project_id"] is not None:
embeddings_config["project_id"] = self._config["project_id"]
if "space_id" in self._config and self._config["space_id"] is not None:
embeddings_config["space_id"] = self._config["space_id"]
if "api_client" in self._config and self._config["api_client"] is not None:
embeddings_config["api_client"] = self._config["api_client"]
if "verify" in self._config and self._config["verify"] is not None:
embeddings_config["verify"] = self._config["verify"]
if "persistent_connection" in self._config:
embeddings_config["persistent_connection"] = self._config[
"persistent_connection"
]
if "batch_size" in self._config:
embeddings_config["batch_size"] = self._config["batch_size"]
if "concurrency_limit" in self._config:
embeddings_config["concurrency_limit"] = self._config["concurrency_limit"]
if "max_retries" in self._config and self._config["max_retries"] is not None:
embeddings_config["max_retries"] = self._config["max_retries"]
if "delay_time" in self._config and self._config["delay_time"] is not None:
embeddings_config["delay_time"] = self._config["delay_time"]
if (
"retry_status_codes" in self._config
and self._config["retry_status_codes"] is not None
):
embeddings_config["retry_status_codes"] = self._config["retry_status_codes"]
if "credentials" in self._config and self._config["credentials"] is not None:
embeddings_config["credentials"] = self._config["credentials"]
else:
cred_config: dict = {}
if "url" in self._config and self._config["url"] is not None:
cred_config["url"] = self._config["url"]
if "api_key" in self._config and self._config["api_key"] is not None:
cred_config["api_key"] = self._config["api_key"]
if "name" in self._config and self._config["name"] is not None:
cred_config["name"] = self._config["name"]
if (
"iam_serviceid_crn" in self._config
and self._config["iam_serviceid_crn"] is not None
):
cred_config["iam_serviceid_crn"] = self._config["iam_serviceid_crn"]
if (
"trusted_profile_id" in self._config
and self._config["trusted_profile_id"] is not None
):
cred_config["trusted_profile_id"] = self._config["trusted_profile_id"]
if "token" in self._config and self._config["token"] is not None:
cred_config["token"] = self._config["token"]
if (
"projects_token" in self._config
and self._config["projects_token"] is not None
):
cred_config["projects_token"] = self._config["projects_token"]
if "username" in self._config and self._config["username"] is not None:
cred_config["username"] = self._config["username"]
if "password" in self._config and self._config["password"] is not None:
cred_config["password"] = self._config["password"]
if (
"instance_id" in self._config
and self._config["instance_id"] is not None
):
cred_config["instance_id"] = self._config["instance_id"]
if "version" in self._config and self._config["version"] is not None:
cred_config["version"] = self._config["version"]
if (
"bedrock_url" in self._config
and self._config["bedrock_url"] is not None
):
cred_config["bedrock_url"] = self._config["bedrock_url"]
if (
"platform_url" in self._config
and self._config["platform_url"] is not None
):
cred_config["platform_url"] = self._config["platform_url"]
if "proxies" in self._config and self._config["proxies"] is not None:
cred_config["proxies"] = self._config["proxies"]
if (
"verify" not in embeddings_config
and "verify" in self._config
and self._config["verify"] is not None
):
cred_config["verify"] = self._config["verify"]
if cred_config:
embeddings_config["credentials"] = Credentials(**cred_config)
if "params" not in embeddings_config:
embeddings_config["params"] = {
EmbedParams.TRUNCATE_INPUT_TOKENS: 3,
EmbedParams.RETURN_OPTIONS: {"input_text": True},
}
embedding = watson_models.Embeddings(**embeddings_config)
try:
embeddings = embedding.embed_documents(input)
return cast(Embeddings, embeddings)
except Exception as e:
print(f"Error during Watson embedding: {e}")
raise

View File

@@ -0,0 +1,42 @@
"""Type definitions for IBM Watson embedding providers."""
from typing import Annotated, Any, Literal, TypedDict
class WatsonProviderConfig(TypedDict, total=False):
"""Configuration for Watson provider."""
model_id: str
url: str
params: dict[str, str | dict[str, str]]
credentials: Any
project_id: str
space_id: str
api_client: Any
verify: bool | str
persistent_connection: Annotated[bool, True]
batch_size: Annotated[int, 100]
concurrency_limit: Annotated[int, 10]
max_retries: int
delay_time: float
retry_status_codes: list[int]
api_key: str
name: str
iam_serviceid_crn: str
trusted_profile_id: str
token: str
projects_token: str
username: str
password: str
instance_id: str
version: str
bedrock_url: str
platform_url: str
proxies: dict
class WatsonProviderSpec(TypedDict):
"""Watson provider specification."""
provider: Literal["watson"]
config: WatsonProviderConfig

View File

@@ -0,0 +1,126 @@
"""IBM Watson embeddings provider."""
from ibm_watsonx_ai import ( # type: ignore[import-not-found,import-untyped]
APIClient,
Credentials,
)
from pydantic import Field, model_validator
from typing_extensions import Self
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
from crewai.rag.embeddings.providers.ibm.embedding_callable import (
WatsonEmbeddingFunction,
)
class WatsonProvider(BaseEmbeddingsProvider[WatsonEmbeddingFunction]):
"""IBM Watson embeddings provider.
Note: Requires custom implementation as Watson uses a different interface.
"""
embedding_callable: type[WatsonEmbeddingFunction] = Field(
default=WatsonEmbeddingFunction, description="Watson embedding function class"
)
model_id: str = Field(
description="Watson model ID", validation_alias="WATSON_MODEL_ID"
)
params: dict[str, str | dict[str, str]] | None = Field(
default=None, description="Additional parameters"
)
credentials: Credentials | None = Field(
default=None, description="Watson credentials"
)
project_id: str | None = Field(
default=None,
description="Watson project ID",
validation_alias="WATSON_PROJECT_ID",
)
space_id: str | None = Field(
default=None, description="Watson space ID", validation_alias="WATSON_SPACE_ID"
)
api_client: APIClient | None = Field(default=None, description="Watson API client")
verify: bool | str | None = Field(
default=None, description="SSL verification", validation_alias="WATSON_VERIFY"
)
persistent_connection: bool = Field(
default=True,
description="Use persistent connection",
validation_alias="WATSON_PERSISTENT_CONNECTION",
)
batch_size: int = Field(
default=100,
description="Batch size for processing",
validation_alias="WATSON_BATCH_SIZE",
)
concurrency_limit: int = Field(
default=10,
description="Concurrency limit",
validation_alias="WATSON_CONCURRENCY_LIMIT",
)
max_retries: int | None = Field(
default=None,
description="Maximum retries",
validation_alias="WATSON_MAX_RETRIES",
)
delay_time: float | None = Field(
default=None,
description="Delay time between retries",
validation_alias="WATSON_DELAY_TIME",
)
retry_status_codes: list[int] | None = Field(
default=None, description="HTTP status codes to retry on"
)
url: str = Field(description="Watson API URL", validation_alias="WATSON_URL")
api_key: str = Field(
description="Watson API key", validation_alias="WATSON_API_KEY"
)
name: str | None = Field(
default=None, description="Service name", validation_alias="WATSON_NAME"
)
iam_serviceid_crn: str | None = Field(
default=None,
description="IAM service ID CRN",
validation_alias="WATSON_IAM_SERVICEID_CRN",
)
trusted_profile_id: str | None = Field(
default=None,
description="Trusted profile ID",
validation_alias="WATSON_TRUSTED_PROFILE_ID",
)
token: str | None = Field(
default=None, description="Bearer token", validation_alias="WATSON_TOKEN"
)
projects_token: str | None = Field(
default=None,
description="Projects token",
validation_alias="WATSON_PROJECTS_TOKEN",
)
username: str | None = Field(
default=None, description="Username", validation_alias="WATSON_USERNAME"
)
password: str | None = Field(
default=None, description="Password", validation_alias="WATSON_PASSWORD"
)
instance_id: str | None = Field(
default=None,
description="Service instance ID",
validation_alias="WATSON_INSTANCE_ID",
)
version: str | None = Field(
default=None, description="API version", validation_alias="WATSON_VERSION"
)
bedrock_url: str | None = Field(
default=None, description="Bedrock URL", validation_alias="WATSON_BEDROCK_URL"
)
platform_url: str | None = Field(
default=None, description="Platform URL", validation_alias="WATSON_PLATFORM_URL"
)
proxies: dict | None = Field(default=None, description="Proxy configuration")
@model_validator(mode="after")
def validate_space_or_project(self) -> Self:
"""Validate that either space_id or project_id is provided."""
if not self.space_id and not self.project_id:
raise ValueError("One of 'space_id' or 'project_id' must be provided")
return self

View File

@@ -0,0 +1,15 @@
"""Instructor embedding providers."""
from crewai.rag.embeddings.providers.instructor.instructor_provider import (
InstructorProvider,
)
from crewai.rag.embeddings.providers.instructor.types import (
InstructorProviderConfig,
InstructorProviderSpec,
)
__all__ = [
"InstructorProvider",
"InstructorProviderConfig",
"InstructorProviderSpec",
]

View File

@@ -0,0 +1,32 @@
"""Instructor embeddings provider."""
from chromadb.utils.embedding_functions.instructor_embedding_function import (
InstructorEmbeddingFunction,
)
from pydantic import Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
class InstructorProvider(BaseEmbeddingsProvider[InstructorEmbeddingFunction]):
"""Instructor embeddings provider."""
embedding_callable: type[InstructorEmbeddingFunction] = Field(
default=InstructorEmbeddingFunction,
description="Instructor embedding function class",
)
model_name: str = Field(
default="hkunlp/instructor-base",
description="Model name to use",
validation_alias="INSTRUCTOR_MODEL_NAME",
)
device: str = Field(
default="cpu",
description="Device to run model on (cpu or cuda)",
validation_alias="INSTRUCTOR_DEVICE",
)
instruction: str | None = Field(
default=None,
description="Instruction for embeddings",
validation_alias="INSTRUCTOR_INSTRUCTION",
)

View File

@@ -0,0 +1,18 @@
"""Type definitions for Instructor embedding providers."""
from typing import Annotated, Literal, TypedDict
class InstructorProviderConfig(TypedDict, total=False):
"""Configuration for Instructor provider."""
model_name: Annotated[str, "hkunlp/instructor-base"]
device: Annotated[str, "cpu"]
instruction: str
class InstructorProviderSpec(TypedDict):
"""Instructor provider specification."""
provider: Literal["instructor"]
config: InstructorProviderConfig

View File

@@ -0,0 +1,13 @@
"""Jina embedding providers."""
from crewai.rag.embeddings.providers.jina.jina_provider import JinaProvider
from crewai.rag.embeddings.providers.jina.types import (
JinaProviderConfig,
JinaProviderSpec,
)
__all__ = [
"JinaProvider",
"JinaProviderConfig",
"JinaProviderSpec",
]

View File

@@ -0,0 +1,22 @@
"""Jina embeddings provider."""
from chromadb.utils.embedding_functions.jina_embedding_function import (
JinaEmbeddingFunction,
)
from pydantic import Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
class JinaProvider(BaseEmbeddingsProvider[JinaEmbeddingFunction]):
"""Jina embeddings provider."""
embedding_callable: type[JinaEmbeddingFunction] = Field(
default=JinaEmbeddingFunction, description="Jina embedding function class"
)
api_key: str = Field(description="Jina API key", validation_alias="JINA_API_KEY")
model_name: str = Field(
default="jina-embeddings-v2-base-en",
description="Model name to use for embeddings",
validation_alias="JINA_MODEL_NAME",
)

View File

@@ -0,0 +1,17 @@
"""Type definitions for Jina embedding providers."""
from typing import Annotated, Literal, TypedDict
class JinaProviderConfig(TypedDict, total=False):
"""Configuration for Jina provider."""
api_key: str
model_name: Annotated[str, "jina-embeddings-v2-base-en"]
class JinaProviderSpec(TypedDict):
"""Jina provider specification."""
provider: Literal["jina"]
config: JinaProviderConfig

View File

@@ -0,0 +1,15 @@
"""Microsoft embedding providers."""
from crewai.rag.embeddings.providers.microsoft.azure import (
AzureProvider,
)
from crewai.rag.embeddings.providers.microsoft.types import (
AzureProviderConfig,
AzureProviderSpec,
)
__all__ = [
"AzureProvider",
"AzureProviderConfig",
"AzureProviderSpec",
]

View File

@@ -0,0 +1,58 @@
"""Azure OpenAI embeddings provider."""
from typing import Any
from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction,
)
from pydantic import Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
class AzureProvider(BaseEmbeddingsProvider[OpenAIEmbeddingFunction]):
"""Azure OpenAI embeddings provider."""
embedding_callable: type[OpenAIEmbeddingFunction] = Field(
default=OpenAIEmbeddingFunction,
description="Azure OpenAI embedding function class",
)
api_key: str = Field(description="Azure API key", validation_alias="OPENAI_API_KEY")
api_base: str | None = Field(
default=None,
description="Azure endpoint URL",
validation_alias="OPENAI_API_BASE",
)
api_type: str = Field(
default="azure",
description="API type for Azure",
validation_alias="OPENAI_API_TYPE",
)
api_version: str | None = Field(
default=None,
description="Azure API version",
validation_alias="OPENAI_API_VERSION",
)
model_name: str = Field(
default="text-embedding-ada-002",
description="Model name to use for embeddings",
validation_alias="OPENAI_MODEL_NAME",
)
default_headers: dict[str, Any] | None = Field(
default=None, description="Default headers for API requests"
)
dimensions: int | None = Field(
default=None,
description="Embedding dimensions",
validation_alias="OPENAI_DIMENSIONS",
)
deployment_id: str | None = Field(
default=None,
description="Azure deployment ID",
validation_alias="OPENAI_DEPLOYMENT_ID",
)
organization_id: str | None = Field(
default=None,
description="Organization ID",
validation_alias="OPENAI_ORGANIZATION_ID",
)

View File

@@ -0,0 +1,24 @@
"""Type definitions for Microsoft Azure embedding providers."""
from typing import Annotated, Any, Literal, TypedDict
class AzureProviderConfig(TypedDict, total=False):
"""Configuration for Azure provider."""
api_key: str
api_base: str
api_type: Annotated[str, "azure"]
api_version: str
model_name: Annotated[str, "text-embedding-ada-002"]
default_headers: dict[str, Any]
dimensions: int
deployment_id: str
organization_id: str
class AzureProviderSpec(TypedDict):
"""Azure provider specification."""
provider: Literal["azure"]
config: AzureProviderConfig

View File

@@ -0,0 +1,15 @@
"""Ollama embedding providers."""
from crewai.rag.embeddings.providers.ollama.ollama_provider import (
OllamaProvider,
)
from crewai.rag.embeddings.providers.ollama.types import (
OllamaProviderConfig,
OllamaProviderSpec,
)
__all__ = [
"OllamaProvider",
"OllamaProviderConfig",
"OllamaProviderSpec",
]

View File

@@ -0,0 +1,25 @@
"""Ollama embeddings provider."""
from chromadb.utils.embedding_functions.ollama_embedding_function import (
OllamaEmbeddingFunction,
)
from pydantic import Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
class OllamaProvider(BaseEmbeddingsProvider[OllamaEmbeddingFunction]):
"""Ollama embeddings provider."""
embedding_callable: type[OllamaEmbeddingFunction] = Field(
default=OllamaEmbeddingFunction, description="Ollama embedding function class"
)
url: str = Field(
default="http://localhost:11434/api/embeddings",
description="Ollama API endpoint URL",
validation_alias="OLLAMA_URL",
)
model_name: str = Field(
description="Model name to use for embeddings",
validation_alias="OLLAMA_MODEL_NAME",
)

View File

@@ -0,0 +1,17 @@
"""Type definitions for Ollama embedding providers."""
from typing import Annotated, Literal, TypedDict
class OllamaProviderConfig(TypedDict, total=False):
"""Configuration for Ollama provider."""
url: Annotated[str, "http://localhost:11434/api/embeddings"]
model_name: str
class OllamaProviderSpec(TypedDict):
"""Ollama provider specification."""
provider: Literal["ollama"]
config: OllamaProviderConfig

View File

@@ -0,0 +1,13 @@
"""ONNX embedding providers."""
from crewai.rag.embeddings.providers.onnx.onnx_provider import ONNXProvider
from crewai.rag.embeddings.providers.onnx.types import (
ONNXProviderConfig,
ONNXProviderSpec,
)
__all__ = [
"ONNXProvider",
"ONNXProviderConfig",
"ONNXProviderSpec",
]

View File

@@ -0,0 +1,19 @@
"""ONNX embeddings provider."""
from chromadb.utils.embedding_functions.onnx_mini_lm_l6_v2 import ONNXMiniLM_L6_V2
from pydantic import Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
class ONNXProvider(BaseEmbeddingsProvider[ONNXMiniLM_L6_V2]):
"""ONNX embeddings provider."""
embedding_callable: type[ONNXMiniLM_L6_V2] = Field(
default=ONNXMiniLM_L6_V2, description="ONNX MiniLM embedding function class"
)
preferred_providers: list[str] | None = Field(
default=None,
description="Preferred ONNX execution providers",
validation_alias="ONNX_PREFERRED_PROVIDERS",
)

View File

@@ -0,0 +1,16 @@
"""Type definitions for ONNX embedding providers."""
from typing import Literal, TypedDict
class ONNXProviderConfig(TypedDict, total=False):
"""Configuration for ONNX provider."""
preferred_providers: list[str]
class ONNXProviderSpec(TypedDict):
"""ONNX provider specification."""
provider: Literal["onnx"]
config: ONNXProviderConfig

View File

@@ -0,0 +1,15 @@
"""OpenAI embedding providers."""
from crewai.rag.embeddings.providers.openai.openai_provider import (
OpenAIProvider,
)
from crewai.rag.embeddings.providers.openai.types import (
OpenAIProviderConfig,
OpenAIProviderSpec,
)
__all__ = [
"OpenAIProvider",
"OpenAIProviderConfig",
"OpenAIProviderSpec",
]

View File

@@ -0,0 +1,58 @@
"""OpenAI embeddings provider."""
from typing import Any
from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction,
)
from pydantic import Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
class OpenAIProvider(BaseEmbeddingsProvider[OpenAIEmbeddingFunction]):
"""OpenAI embeddings provider."""
embedding_callable: type[OpenAIEmbeddingFunction] = Field(
default=OpenAIEmbeddingFunction,
description="OpenAI embedding function class",
)
api_key: str = Field(
description="OpenAI API key", validation_alias="OPENAI_API_KEY"
)
model_name: str = Field(
default="text-embedding-ada-002",
description="Model name to use for embeddings",
validation_alias="OPENAI_MODEL_NAME",
)
api_base: str | None = Field(
default=None,
description="Base URL for API requests",
validation_alias="OPENAI_API_BASE",
)
api_type: str | None = Field(
default=None,
description="API type (e.g., 'azure')",
validation_alias="OPENAI_API_TYPE",
)
api_version: str | None = Field(
default=None, description="API version", validation_alias="OPENAI_API_VERSION"
)
default_headers: dict[str, Any] | None = Field(
default=None, description="Default headers for API requests"
)
dimensions: int | None = Field(
default=None,
description="Embedding dimensions",
validation_alias="OPENAI_DIMENSIONS",
)
deployment_id: str | None = Field(
default=None,
description="Azure deployment ID",
validation_alias="OPENAI_DEPLOYMENT_ID",
)
organization_id: str | None = Field(
default=None,
description="OpenAI organization ID",
validation_alias="OPENAI_ORGANIZATION_ID",
)

View File

@@ -0,0 +1,24 @@
"""Type definitions for OpenAI embedding providers."""
from typing import Annotated, Any, Literal, TypedDict
class OpenAIProviderConfig(TypedDict, total=False):
"""Configuration for OpenAI provider."""
api_key: str
model_name: Annotated[str, "text-embedding-ada-002"]
api_base: str
api_type: str
api_version: str
default_headers: dict[str, Any]
dimensions: int
deployment_id: str
organization_id: str
class OpenAIProviderSpec(TypedDict):
"""OpenAI provider specification."""
provider: Literal["openai"]
config: OpenAIProviderConfig

View File

@@ -0,0 +1,15 @@
"""OpenCLIP embedding providers."""
from crewai.rag.embeddings.providers.openclip.openclip_provider import (
OpenCLIPProvider,
)
from crewai.rag.embeddings.providers.openclip.types import (
OpenCLIPProviderConfig,
OpenCLIPProviderSpec,
)
__all__ = [
"OpenCLIPProvider",
"OpenCLIPProviderConfig",
"OpenCLIPProviderSpec",
]

View File

@@ -0,0 +1,32 @@
"""OpenCLIP embeddings provider."""
from chromadb.utils.embedding_functions.open_clip_embedding_function import (
OpenCLIPEmbeddingFunction,
)
from pydantic import Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
class OpenCLIPProvider(BaseEmbeddingsProvider[OpenCLIPEmbeddingFunction]):
"""OpenCLIP embeddings provider."""
embedding_callable: type[OpenCLIPEmbeddingFunction] = Field(
default=OpenCLIPEmbeddingFunction,
description="OpenCLIP embedding function class",
)
model_name: str = Field(
default="ViT-B-32",
description="Model name to use",
validation_alias="OPENCLIP_MODEL_NAME",
)
checkpoint: str = Field(
default="laion2b_s34b_b79k",
description="Model checkpoint",
validation_alias="OPENCLIP_CHECKPOINT",
)
device: str | None = Field(
default="cpu",
description="Device to run model on",
validation_alias="OPENCLIP_DEVICE",
)

View File

@@ -0,0 +1,18 @@
"""Type definitions for OpenCLIP embedding providers."""
from typing import Annotated, Literal, TypedDict
class OpenCLIPProviderConfig(TypedDict, total=False):
"""Configuration for OpenCLIP provider."""
model_name: Annotated[str, "ViT-B-32"]
checkpoint: Annotated[str, "laion2b_s34b_b79k"]
device: Annotated[str, "cpu"]
class OpenCLIPProviderSpec(TypedDict):
"""OpenCLIP provider specification."""
provider: Literal["openclip"]
config: OpenCLIPProviderConfig

View File

@@ -0,0 +1,15 @@
"""Roboflow embedding providers."""
from crewai.rag.embeddings.providers.roboflow.roboflow_provider import (
RoboflowProvider,
)
from crewai.rag.embeddings.providers.roboflow.types import (
RoboflowProviderConfig,
RoboflowProviderSpec,
)
__all__ = [
"RoboflowProvider",
"RoboflowProviderConfig",
"RoboflowProviderSpec",
]

View File

@@ -0,0 +1,25 @@
"""Roboflow embeddings provider."""
from chromadb.utils.embedding_functions.roboflow_embedding_function import (
RoboflowEmbeddingFunction,
)
from pydantic import Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
class RoboflowProvider(BaseEmbeddingsProvider[RoboflowEmbeddingFunction]):
"""Roboflow embeddings provider."""
embedding_callable: type[RoboflowEmbeddingFunction] = Field(
default=RoboflowEmbeddingFunction,
description="Roboflow embedding function class",
)
api_key: str = Field(
default="", description="Roboflow API key", validation_alias="ROBOFLOW_API_KEY"
)
api_url: str = Field(
default="https://infer.roboflow.com",
description="Roboflow API URL",
validation_alias="ROBOFLOW_API_URL",
)

View File

@@ -0,0 +1,17 @@
"""Type definitions for Roboflow embedding providers."""
from typing import Annotated, Literal, TypedDict
class RoboflowProviderConfig(TypedDict, total=False):
"""Configuration for Roboflow provider."""
api_key: Annotated[str, ""]
api_url: Annotated[str, "https://infer.roboflow.com"]
class RoboflowProviderSpec(TypedDict):
"""Roboflow provider specification."""
provider: Literal["roboflow"]
config: RoboflowProviderConfig

View File

@@ -0,0 +1,15 @@
"""SentenceTransformer embedding providers."""
from crewai.rag.embeddings.providers.sentence_transformer.sentence_transformer_provider import (
SentenceTransformerProvider,
)
from crewai.rag.embeddings.providers.sentence_transformer.types import (
SentenceTransformerProviderConfig,
SentenceTransformerProviderSpec,
)
__all__ = [
"SentenceTransformerProvider",
"SentenceTransformerProviderConfig",
"SentenceTransformerProviderSpec",
]

View File

@@ -0,0 +1,34 @@
"""SentenceTransformer embeddings provider."""
from chromadb.utils.embedding_functions.sentence_transformer_embedding_function import (
SentenceTransformerEmbeddingFunction,
)
from pydantic import Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
class SentenceTransformerProvider(
BaseEmbeddingsProvider[SentenceTransformerEmbeddingFunction]
):
"""SentenceTransformer embeddings provider."""
embedding_callable: type[SentenceTransformerEmbeddingFunction] = Field(
default=SentenceTransformerEmbeddingFunction,
description="SentenceTransformer embedding function class",
)
model_name: str = Field(
default="all-MiniLM-L6-v2",
description="Model name to use",
validation_alias="SENTENCE_TRANSFORMER_MODEL_NAME",
)
device: str = Field(
default="cpu",
description="Device to run model on (cpu or cuda)",
validation_alias="SENTENCE_TRANSFORMER_DEVICE",
)
normalize_embeddings: bool = Field(
default=False,
description="Whether to normalize embeddings",
validation_alias="SENTENCE_TRANSFORMER_NORMALIZE_EMBEDDINGS",
)

View File

@@ -0,0 +1,18 @@
"""Type definitions for SentenceTransformer embedding providers."""
from typing import Annotated, Literal, TypedDict
class SentenceTransformerProviderConfig(TypedDict, total=False):
"""Configuration for SentenceTransformer provider."""
model_name: Annotated[str, "all-MiniLM-L6-v2"]
device: Annotated[str, "cpu"]
normalize_embeddings: Annotated[bool, False]
class SentenceTransformerProviderSpec(TypedDict):
"""SentenceTransformer provider specification."""
provider: Literal["sentence-transformer"]
config: SentenceTransformerProviderConfig

View File

@@ -0,0 +1,15 @@
"""Text2Vec embedding providers."""
from crewai.rag.embeddings.providers.text2vec.text2vec_provider import (
Text2VecProvider,
)
from crewai.rag.embeddings.providers.text2vec.types import (
Text2VecProviderConfig,
Text2VecProviderSpec,
)
__all__ = [
"Text2VecProvider",
"Text2VecProviderConfig",
"Text2VecProviderSpec",
]

View File

@@ -0,0 +1,22 @@
"""Text2Vec embeddings provider."""
from chromadb.utils.embedding_functions.text2vec_embedding_function import (
Text2VecEmbeddingFunction,
)
from pydantic import Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
class Text2VecProvider(BaseEmbeddingsProvider[Text2VecEmbeddingFunction]):
"""Text2Vec embeddings provider."""
embedding_callable: type[Text2VecEmbeddingFunction] = Field(
default=Text2VecEmbeddingFunction,
description="Text2Vec embedding function class",
)
model_name: str = Field(
default="shibing624/text2vec-base-chinese",
description="Model name to use",
validation_alias="TEXT2VEC_MODEL_NAME",
)

View File

@@ -0,0 +1,16 @@
"""Type definitions for Text2Vec embedding providers."""
from typing import Annotated, Literal, TypedDict
class Text2VecProviderConfig(TypedDict, total=False):
"""Configuration for Text2Vec provider."""
model_name: Annotated[str, "shibing624/text2vec-base-chinese"]
class Text2VecProviderSpec(TypedDict):
"""Text2Vec provider specification."""
provider: Literal["text2vec"]
config: Text2VecProviderConfig

View File

@@ -0,0 +1,15 @@
"""VoyageAI embedding providers."""
from crewai.rag.embeddings.providers.voyageai.types import (
VoyageAIProviderConfig,
VoyageAIProviderSpec,
)
from crewai.rag.embeddings.providers.voyageai.voyageai_provider import (
VoyageAIProvider,
)
__all__ = [
"VoyageAIProvider",
"VoyageAIProviderConfig",
"VoyageAIProviderSpec",
]

View File

@@ -0,0 +1,50 @@
"""VoyageAI embedding function implementation."""
from typing import cast
import voyageai
from typing_extensions import Unpack
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
from crewai.rag.core.types import Documents, Embeddings
from crewai.rag.embeddings.providers.voyageai.types import VoyageAIProviderConfig
class VoyageAIEmbeddingFunction(EmbeddingFunction[Documents]):
"""Embedding function for VoyageAI models."""
def __init__(self, **kwargs: Unpack[VoyageAIProviderConfig]) -> None:
"""Initialize VoyageAI embedding function.
Args:
**kwargs: Configuration parameters for VoyageAI.
"""
self._config = kwargs
self._client = voyageai.Client(
api_key=kwargs["api_key"],
max_retries=kwargs.get("max_retries", 0),
timeout=kwargs.get("timeout"),
)
def __call__(self, input: Documents) -> Embeddings:
"""Generate embeddings for input documents.
Args:
input: List of documents to embed.
Returns:
List of embedding vectors.
"""
if isinstance(input, str):
input = [input]
result = self._client.embed(
texts=input,
model=self._config.get("model", "voyage-2"),
input_type=self._config.get("input_type"),
truncation=self._config.get("truncation", True),
output_dtype=self._config.get("output_dtype"),
output_dimension=self._config.get("output_dimension"),
)
return cast(Embeddings, result.embeddings)

View File

@@ -0,0 +1,23 @@
"""Type definitions for VoyageAI embedding providers."""
from typing import Annotated, Literal, TypedDict
class VoyageAIProviderConfig(TypedDict, total=False):
"""Configuration for VoyageAI provider."""
api_key: str
model: Annotated[str, "voyage-2"]
input_type: str
truncation: Annotated[bool, True]
output_dtype: str
output_dimension: int
max_retries: Annotated[int, 0]
timeout: float
class VoyageAIProviderSpec(TypedDict):
"""VoyageAI provider specification."""
provider: Literal["voyageai"]
config: VoyageAIProviderConfig

View File

@@ -0,0 +1,55 @@
"""Voyage AI embeddings provider."""
from pydantic import Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
from crewai.rag.embeddings.providers.voyageai.embedding_callable import (
VoyageAIEmbeddingFunction,
)
class VoyageAIProvider(BaseEmbeddingsProvider[VoyageAIEmbeddingFunction]):
"""Voyage AI embeddings provider."""
embedding_callable: type[VoyageAIEmbeddingFunction] = Field(
default=VoyageAIEmbeddingFunction,
description="Voyage AI embedding function class",
)
model: str = Field(
default="voyage-2",
description="Model to use for embeddings",
validation_alias="VOYAGEAI_MODEL",
)
api_key: str = Field(
description="Voyage AI API key", validation_alias="VOYAGEAI_API_KEY"
)
input_type: str | None = Field(
default=None,
description="Input type for embeddings",
validation_alias="VOYAGEAI_INPUT_TYPE",
)
truncation: bool = Field(
default=True,
description="Whether to truncate inputs",
validation_alias="VOYAGEAI_TRUNCATION",
)
output_dtype: str | None = Field(
default=None,
description="Output data type",
validation_alias="VOYAGEAI_OUTPUT_DTYPE",
)
output_dimension: int | None = Field(
default=None,
description="Output dimension",
validation_alias="VOYAGEAI_OUTPUT_DIMENSION",
)
max_retries: int = Field(
default=0,
description="Maximum retries for API calls",
validation_alias="VOYAGEAI_MAX_RETRIES",
)
timeout: float | None = Field(
default=None,
description="Timeout for API calls",
validation_alias="VOYAGEAI_TIMEOUT",
)

View File

@@ -2,61 +2,67 @@
from typing import Literal
from pydantic import BaseModel, Field, SecretStr
from crewai.rag.embeddings.providers.aws.types import BedrockProviderSpec
from crewai.rag.embeddings.providers.cohere.types import CohereProviderSpec
from crewai.rag.embeddings.providers.custom.types import CustomProviderSpec
from crewai.rag.embeddings.providers.google.types import (
GenerativeAiProviderSpec,
VertexAIProviderSpec,
)
from crewai.rag.embeddings.providers.huggingface.types import HuggingFaceProviderSpec
from crewai.rag.embeddings.providers.ibm.types import WatsonProviderSpec
from crewai.rag.embeddings.providers.instructor.types import InstructorProviderSpec
from crewai.rag.embeddings.providers.jina.types import JinaProviderSpec
from crewai.rag.embeddings.providers.microsoft.types import AzureProviderSpec
from crewai.rag.embeddings.providers.ollama.types import OllamaProviderSpec
from crewai.rag.embeddings.providers.onnx.types import ONNXProviderSpec
from crewai.rag.embeddings.providers.openai.types import OpenAIProviderSpec
from crewai.rag.embeddings.providers.openclip.types import OpenCLIPProviderSpec
from crewai.rag.embeddings.providers.roboflow.types import RoboflowProviderSpec
from crewai.rag.embeddings.providers.sentence_transformer.types import (
SentenceTransformerProviderSpec,
)
from crewai.rag.embeddings.providers.text2vec.types import Text2VecProviderSpec
from crewai.rag.embeddings.providers.voyageai.types import VoyageAIProviderSpec
from crewai.rag.types import EmbeddingFunction
ProviderSpec = (
AzureProviderSpec
| BedrockProviderSpec
| CohereProviderSpec
| CustomProviderSpec
| GenerativeAiProviderSpec
| HuggingFaceProviderSpec
| InstructorProviderSpec
| JinaProviderSpec
| OllamaProviderSpec
| ONNXProviderSpec
| OpenAIProviderSpec
| OpenCLIPProviderSpec
| RoboflowProviderSpec
| SentenceTransformerProviderSpec
| Text2VecProviderSpec
| VertexAIProviderSpec
| VoyageAIProviderSpec
| WatsonProviderSpec
)
EmbeddingProvider = Literal[
"openai",
AllowedEmbeddingProviders = Literal[
"azure",
"amazon-bedrock",
"cohere",
"ollama",
"huggingface",
"sentence-transformer",
"instructor",
"google-palm",
"custom",
"google-generativeai",
"google-vertex",
"amazon-bedrock",
"huggingface",
"instructor",
"jina",
"roboflow",
"openclip",
"text2vec",
"ollama",
"onnx",
"openai",
"openclip",
"roboflow",
"sentence-transformer",
"text2vec",
"voyageai",
"watson",
]
"""Supported embedding providers.
These correspond to the embedding functions available in ChromaDB's
embedding_functions module. Each provider has specific requirements
and configuration options.
"""
class EmbeddingOptions(BaseModel):
"""Configuration options for embedding providers.
Generic attributes that can be passed to get_embedding_function
to configure various embedding providers.
"""
provider: EmbeddingProvider = Field(
..., description="Embedding provider name (e.g., 'openai', 'cohere', 'onnx')"
)
model_name: str | None = Field(
default=None, description="Model name for the embedding provider"
)
api_key: SecretStr | None = Field(
default=None, description="API key for the embedding provider"
)
class EmbeddingConfig(BaseModel):
"""Configuration wrapper for embedding functions.
Accepts either a pre-configured EmbeddingFunction or EmbeddingOptions
to create one. This provides flexibility in how embeddings are configured.
Attributes:
function: Either a callable EmbeddingFunction or EmbeddingOptions to create one
"""
function: EmbeddingFunction | EmbeddingOptions

View File

@@ -1,8 +1,8 @@
from abc import ABC, abstractmethod
from typing import Any
from crewai.rag.embeddings.factory import EmbedderConfig
from crewai.rag.embeddings.types import EmbeddingOptions
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
from crewai.rag.embeddings.types import ProviderSpec
class BaseRAGStorage(ABC):
@@ -16,7 +16,7 @@ class BaseRAGStorage(ABC):
self,
type: str,
allow_reset: bool = True,
embedder_config: EmbeddingOptions | EmbedderConfig | None = None,
embedder_config: ProviderSpec | BaseEmbeddingsProvider | None = None,
crew: Any = None,
):
self.type = type

View File

@@ -24,8 +24,7 @@ class BaseRecord(TypedDict, total=False):
)
DenseVector: TypeAlias = list[float]
IntVector: TypeAlias = list[int]
Embeddings: TypeAlias = list[list[float]]
EmbeddingFunction: TypeAlias = Callable[..., Any]