mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-14 18:48:29 +00:00
feat: add custom embedding types and migrate providers
- introduce BaseEmbeddingsProvider and a helper for embedding functions
- add core embedding types and migrate providers, factory, and storage modules
- remove unused type aliases and fix a pydantic schema error
- update providers with env var support and related fixes
This commit is contained in:
@@ -8,7 +8,9 @@ from crewai.rag.chromadb.config import ChromaDBConfig
|
||||
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
|
||||
from crewai.rag.config.utils import get_rag_client
|
||||
from crewai.rag.core.base_client import BaseClient
|
||||
from crewai.rag.embeddings.factory import EmbedderConfig, get_embedding_function
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
from crewai.rag.embeddings.factory import build_embedder
|
||||
from crewai.rag.embeddings.types import ProviderSpec
|
||||
from crewai.rag.factory import create_client
|
||||
from crewai.rag.types import BaseRecord, SearchResult
|
||||
from crewai.utilities.logger import Logger
|
||||
@@ -22,12 +24,11 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedder: dict[str, Any] | None = None,
|
||||
embedder: ProviderSpec | BaseEmbeddingsProvider | None = None,
|
||||
collection_name: str | None = None,
|
||||
) -> None:
|
||||
self.collection_name = collection_name
|
||||
self._client: BaseClient | None = None
|
||||
self._embedder_config = embedder # Store embedder config
|
||||
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
@@ -36,29 +37,12 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
)
|
||||
|
||||
if embedder:
|
||||
# Cast to EmbedderConfig for type checking
|
||||
embedder_typed = cast(EmbedderConfig, embedder)
|
||||
embedding_function = get_embedding_function(embedder_typed)
|
||||
batch_size = None
|
||||
if isinstance(embedder, dict) and "config" in embedder:
|
||||
nested_config = embedder["config"]
|
||||
if isinstance(nested_config, dict):
|
||||
batch_size = nested_config.get("batch_size")
|
||||
|
||||
# Create config with batch_size if provided
|
||||
if batch_size is not None:
|
||||
config = ChromaDBConfig(
|
||||
embedding_function=cast(
|
||||
ChromaEmbeddingFunctionWrapper, embedding_function
|
||||
),
|
||||
batch_size=batch_size,
|
||||
)
|
||||
else:
|
||||
config = ChromaDBConfig(
|
||||
embedding_function=cast(
|
||||
ChromaEmbeddingFunctionWrapper, embedding_function
|
||||
)
|
||||
embedding_function = build_embedder(embedder)
|
||||
config = ChromaDBConfig(
|
||||
embedding_function=cast(
|
||||
ChromaEmbeddingFunctionWrapper, embedding_function
|
||||
)
|
||||
)
|
||||
self._client = create_client(config)
|
||||
|
||||
def _get_client(self) -> BaseClient:
|
||||
@@ -123,23 +107,9 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
|
||||
rag_documents: list[BaseRecord] = [{"content": doc} for doc in documents]
|
||||
|
||||
batch_size = None
|
||||
if self._embedder_config and isinstance(self._embedder_config, dict):
|
||||
if "config" in self._embedder_config:
|
||||
nested_config = self._embedder_config["config"]
|
||||
if isinstance(nested_config, dict):
|
||||
batch_size = nested_config.get("batch_size")
|
||||
|
||||
if batch_size is not None:
|
||||
client.add_documents(
|
||||
collection_name=collection_name,
|
||||
documents=rag_documents,
|
||||
batch_size=batch_size,
|
||||
)
|
||||
else:
|
||||
client.add_documents(
|
||||
collection_name=collection_name, documents=rag_documents
|
||||
)
|
||||
client.add_documents(
|
||||
collection_name=collection_name, documents=rag_documents
|
||||
)
|
||||
except Exception as e:
|
||||
if "dimension mismatch" in str(e).lower():
|
||||
Logger(verbose=True).log(
|
||||
|
||||
@@ -7,8 +7,9 @@ from crewai.rag.chromadb.config import ChromaDBConfig
|
||||
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
|
||||
from crewai.rag.config.utils import get_rag_client
|
||||
from crewai.rag.core.base_client import BaseClient
|
||||
from crewai.rag.embeddings.factory import EmbedderConfig, get_embedding_function
|
||||
from crewai.rag.embeddings.types import EmbeddingOptions
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
from crewai.rag.embeddings.factory import build_embedder
|
||||
from crewai.rag.embeddings.types import ProviderSpec
|
||||
from crewai.rag.factory import create_client
|
||||
from crewai.rag.storage.base_rag_storage import BaseRAGStorage
|
||||
from crewai.rag.types import BaseRecord
|
||||
@@ -26,7 +27,7 @@ class RAGStorage(BaseRAGStorage):
|
||||
self,
|
||||
type: str,
|
||||
allow_reset: bool = True,
|
||||
embedder_config: EmbeddingOptions | EmbedderConfig | None = None,
|
||||
embedder_config: ProviderSpec | BaseEmbeddingsProvider | None = None,
|
||||
crew: Any = None,
|
||||
path: str | None = None,
|
||||
) -> None:
|
||||
@@ -50,15 +51,17 @@ class RAGStorage(BaseRAGStorage):
|
||||
)
|
||||
|
||||
if self.embedder_config:
|
||||
embedding_function = get_embedding_function(self.embedder_config)
|
||||
embedding_function = build_embedder(self.embedder_config)
|
||||
|
||||
try:
|
||||
_ = embedding_function(["test"])
|
||||
except Exception as e:
|
||||
provider = (
|
||||
self.embedder_config.provider
|
||||
if isinstance(self.embedder_config, EmbeddingOptions)
|
||||
else self.embedder_config.get("provider", "unknown")
|
||||
self.embedder_config["provider"]
|
||||
if isinstance(self.embedder_config, dict)
|
||||
else self.embedder_config.__class__.__name__.replace(
|
||||
"Provider", ""
|
||||
).lower()
|
||||
)
|
||||
raise ValueError(
|
||||
f"Failed to initialize embedder. Please check your configuration or connection.\n"
|
||||
@@ -80,7 +83,7 @@ class RAGStorage(BaseRAGStorage):
|
||||
embedding_function=cast(
|
||||
ChromaEmbeddingFunctionWrapper, embedding_function
|
||||
),
|
||||
batch_size=batch_size,
|
||||
batch_size=cast(int, batch_size),
|
||||
)
|
||||
else:
|
||||
config = ChromaDBConfig(
|
||||
@@ -142,7 +145,7 @@ class RAGStorage(BaseRAGStorage):
|
||||
client.add_documents(
|
||||
collection_name=collection_name,
|
||||
documents=[document],
|
||||
batch_size=batch_size,
|
||||
batch_size=cast(int, batch_size),
|
||||
)
|
||||
else:
|
||||
client.add_documents(
|
||||
|
||||
142
src/crewai/rag/core/base_embeddings_callable.py
Normal file
142
src/crewai/rag/core/base_embeddings_callable.py
Normal file
@@ -0,0 +1,142 @@
|
||||
"""Base embeddings callable utilities for RAG systems."""
|
||||
|
||||
from typing import Protocol, TypeVar, runtime_checkable
|
||||
|
||||
import numpy as np
|
||||
|
||||
from crewai.rag.core.types import (
|
||||
Embeddable,
|
||||
Embedding,
|
||||
Embeddings,
|
||||
PyEmbedding,
|
||||
)
|
||||
|
||||
T = TypeVar("T")
|
||||
D = TypeVar("D", bound=Embeddable, contravariant=True)
|
||||
|
||||
|
||||
def normalize_embeddings(
    target: Embedding | list[Embedding] | PyEmbedding | list[PyEmbedding],
) -> Embeddings | None:
    """Normalize various embedding formats to a standard list of numpy arrays.

    Args:
        target: Input embeddings in one of the supported shapes: a flat
            sequence of numbers (one embedding), a sequence of lists/tuples
            (many embeddings), a 1-D or 2-D numpy array, or a sequence of
            numpy arrays.

    Returns:
        A list of ``float32`` numpy arrays, one per embedding, or None if
        the input is None.

    Raises:
        ValueError: If the input is empty, is a numpy array with more than
            two dimensions, or its first element is of an unsupported type.
    """
    if target is None:
        return None

    if isinstance(target, np.ndarray):
        if target.ndim == 1:
            # A single vector becomes a one-element list.
            return [target.astype(np.float32)]
        if target.ndim == 2:
            # Each row of a matrix is one embedding.
            return [row.astype(np.float32) for row in target]
        raise ValueError(f"Unsupported numpy array shape: {target.shape}")

    # Guard before indexing so an empty input surfaces as the documented
    # ValueError rather than an IndexError from ``target[0]``.
    if len(target) == 0:
        raise ValueError("Expected at least one embedding, got an empty sequence")

    # The type of the first element decides how the whole input is read.
    first = target[0]
    if isinstance(first, (int, float, np.integer, np.floating)) and not isinstance(
        first, bool
    ):
        # Flat sequence of numbers: the input itself is a single embedding.
        return [np.array(target, dtype=np.float32)]
    if isinstance(first, (list, tuple)):
        return [np.array(emb, dtype=np.float32) for emb in target]
    if isinstance(first, np.ndarray):
        return [emb.astype(np.float32) for emb in target]  # type: ignore[union-attr]

    raise ValueError(f"Unsupported embeddings format: {type(first)}")
|
||||
|
||||
|
||||
def maybe_cast_one_to_many(target: T | list[T] | None) -> list[T] | None:
    """Wrap a lone item in a list; pass lists and None through untouched.

    Args:
        target: A single item, a list of items, or None.

    Returns:
        None when ``target`` is None, ``target`` itself when it is already a
        list, otherwise a new one-element list containing ``target``.
    """
    if target is None:
        return None
    if isinstance(target, list):
        # Already "many": hand back the same list object.
        return target
    return [target]
|
||||
|
||||
|
||||
def validate_embeddings(embeddings: Embeddings) -> Embeddings:
    """Validate embeddings format and content.

    Args:
        embeddings: List of numpy arrays to validate.

    Returns:
        The same ``embeddings`` list, unchanged, once every check passes.

    Raises:
        ValueError: If ``embeddings`` is not a non-empty list of numpy
            arrays, or if any array is 0-dimensional, empty, not
            1-dimensional, or holds values that are not integers or floats.
    """
    if not isinstance(embeddings, list):
        raise ValueError(
            f"Expected embeddings to be a list, got {type(embeddings).__name__}"
        )
    if len(embeddings) == 0:
        raise ValueError(
            f"Expected embeddings to be a list with at least one item, got {len(embeddings)} embeddings"
        )
    if not all(isinstance(e, np.ndarray) for e in embeddings):
        raise ValueError(
            "Expected each embedding in the embeddings to be a numpy array"
        )

    for i, embedding in enumerate(embeddings):
        if embedding.ndim == 0:
            raise ValueError(
                f"Expected a 1-dimensional array, got a 0-dimensional array {embedding}"
            )
        if embedding.size == 0:
            raise ValueError(
                f"Expected each embedding to be a 1-dimensional numpy array with at least 1 value. "
                f"Got an array with no values at position {i}"
            )
        if embedding.ndim != 1:
            # Without this check a matrix would only be rejected because
            # iterating it yields row arrays, giving a misleading message.
            raise ValueError(
                f"Expected each embedding to be a 1-dimensional numpy array, "
                f"got a {embedding.ndim}-dimensional array at position {i}"
            )

        if embedding.dtype.kind == "O":
            # Object arrays can mix Python types, so each value is inspected.
            is_numeric = all(
                isinstance(value, (np.integer, float, np.floating))
                and not isinstance(value, bool)
                for value in embedding
            )
        else:
            # Homogeneous arrays: one dtype check covers every element and
            # avoids a Python-level loop over the whole vector.
            is_numeric = bool(
                np.issubdtype(embedding.dtype, np.integer)
                or np.issubdtype(embedding.dtype, np.floating)
            )
        if not is_numeric:
            raise ValueError(
                f"Expected embedding to contain numeric values, got non-numeric values at position {i}"
            )
    return embeddings
|
||||
|
||||
|
||||
@runtime_checkable
class EmbeddingFunction(Protocol[D]):
    """Structural interface for callables that produce embeddings.

    An implementation is called with input data and returns embeddings.
    Every subclass has its ``__call__`` replaced at class-creation time by a
    wrapper that normalizes and validates whatever the original returns.
    """

    def __call__(self, input: D) -> Embeddings:
        """Embed the given input.

        Args:
            input: The data to embed.

        Returns:
            The embeddings for ``input`` as a list of numpy arrays.
        """
        ...

    def __init_subclass__(cls) -> None:
        """Install a normalizing, validating wrapper around ``cls.__call__``."""
        super().__init_subclass__()
        # Capture whatever __call__ the new class resolves to before it is
        # overwritten below.
        raw_call = cls.__call__

        def wrapped_call(self: EmbeddingFunction[D], input: D) -> Embeddings:
            produced = raw_call(self, input)
            if produced is None:
                raise ValueError("Embedding function returned None")

            embeddings = normalize_embeddings(produced)
            if embeddings is None:
                raise ValueError("Normalization returned None for non-None input")

            return validate_embeddings(embeddings)

        cls.__call__ = wrapped_call  # type: ignore[method-assign]
|
||||
23
src/crewai/rag/core/base_embeddings_provider.py
Normal file
23
src/crewai/rag/core/base_embeddings_provider.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""Base class for embedding providers."""
|
||||
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
|
||||
|
||||
T = TypeVar("T", bound=EmbeddingFunction)
|
||||
|
||||
|
||||
class BaseEmbeddingsProvider(BaseSettings, Generic[T]):
    """Settings-based base class for embedding providers.

    A provider is a settings model, parameterized by the embedding function
    type ``T``, that names the embedding function class it is meant to build
    through the required ``embedding_callable`` field.
    """

    # Settings options: extra="allow" and populate_by_name=True.
    # NOTE(review): presumably so provider-specific fields and field-name
    # (non-alias) population are accepted — confirm against pydantic-settings.
    model_config = SettingsConfigDict(populate_by_name=True, extra="allow")

    # Required: the embedding function class (not an instance) to use.
    embedding_callable: type[T] = Field(
        default=...,
        description="The embedding function class to use",
    )
|
||||
28
src/crewai/rag/core/types.py
Normal file
28
src/crewai/rag/core/types.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""Core type definitions for RAG systems."""
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import TypeVar
|
||||
|
||||
import numpy as np
|
||||
from numpy import floating, integer, number
|
||||
from numpy.typing import NDArray
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
PyEmbedding = Sequence[float] | Sequence[int]
|
||||
PyEmbeddings = list[PyEmbedding]
|
||||
Embedding = NDArray[np.int32 | np.float32]
|
||||
Embeddings = list[Embedding]
|
||||
|
||||
Documents = list[str]
|
||||
Images = list[np.ndarray]
|
||||
Embeddable = Documents | Images
|
||||
|
||||
ScalarType = TypeVar("ScalarType", bound=np.generic)
|
||||
IntegerType = TypeVar("IntegerType", bound=integer)
|
||||
FloatingType = TypeVar("FloatingType", bound=floating)
|
||||
NumberType = TypeVar("NumberType", bound=number)
|
||||
|
||||
DType32 = TypeVar("DType32", np.int32, np.float32)
|
||||
DType64 = TypeVar("DType64", np.int64, np.float64)
|
||||
DTypeCommon = TypeVar("DTypeCommon", np.int32, np.int64, np.float32, np.float64)
|
||||
@@ -1,245 +0,0 @@
|
||||
import os
|
||||
from typing import Any, cast
|
||||
|
||||
from chromadb import Documents, EmbeddingFunction, Embeddings
|
||||
from chromadb.api.types import validate_embedding_function
|
||||
|
||||
|
||||
class EmbeddingConfigurator:
    """Build ChromaDB embedding functions from a provider configuration dict.

    The configuration has the shape ``{"provider": <name>, "config": {...}}``.
    Each provider's ChromaDB class is imported lazily inside its factory, so
    an optional dependency is only needed when that provider is selected.
    """

    def __init__(self) -> None:
        # Provider name -> factory. Every factory takes (config, model_name)
        # except "custom", which takes only (config).
        self.embedding_functions = {
            "openai": self._configure_openai,
            "azure": self._configure_azure,
            "ollama": self._configure_ollama,
            "vertexai": self._configure_vertexai,
            "google": self._configure_google,
            "cohere": self._configure_cohere,
            "voyageai": self._configure_voyageai,
            "bedrock": self._configure_bedrock,
            "huggingface": self._configure_huggingface,
            "watson": self._configure_watson,
            "custom": self._configure_custom,
        }

    def configure_embedder(
        self,
        embedder_config: dict[str, Any] | None = None,
    ) -> EmbeddingFunction:
        """Configure and return an embedding function for the given config.

        Args:
            embedder_config: Mapping with a ``"provider"`` key and an optional
                ``"config"`` mapping. When None, the default OpenAI embedding
                function is returned.

        Returns:
            The embedding function built by the selected provider factory.

        Raises:
            ValueError: If the provider is not one of the supported names.
            ImportError: If a module needed by the provider cannot be imported.
        """
        if embedder_config is None:
            return self._create_default_embedding_function()

        provider = embedder_config.get("provider")
        # Treat a missing or explicit None "config" as an empty mapping.
        config = embedder_config.get("config") or {}

        if provider not in self.embedding_functions:
            raise ValueError(
                f"Unsupported embedding provider: {provider}, supported providers: {list(self.embedding_functions.keys())}"
            )
        factory = self.embedding_functions[provider]

        # The lazy imports run when the factory is *called*, so the call is
        # what must sit inside the ImportError handler.
        try:
            if provider == "custom":
                return factory(config)
            return factory(config, config.get("model"))
        except ImportError as e:
            if not e.name:
                # Raised deliberately by a factory with its own message
                # (e.g. the Watson factory); pass it through untouched.
                raise
            missing_package = e.name.split(".")[0]
            raise ImportError(
                f"{missing_package} is not installed. Please install it with: pip install {missing_package}",
                name=missing_package,
            ) from e

    @staticmethod
    def _create_default_embedding_function():
        """Return the default OpenAI function keyed from ``OPENAI_API_KEY``."""
        from chromadb.utils.embedding_functions.openai_embedding_function import (
            OpenAIEmbeddingFunction,
        )

        return OpenAIEmbeddingFunction(
            api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
        )

    @staticmethod
    def _configure_openai(config, model_name):
        """Build an OpenAI function; the API key falls back to the environment."""
        from chromadb.utils.embedding_functions.openai_embedding_function import (
            OpenAIEmbeddingFunction,
        )

        return OpenAIEmbeddingFunction(
            api_key=config.get("api_key") or os.getenv("OPENAI_API_KEY"),
            model_name=model_name,
            api_base=config.get("api_base", None),
            api_type=config.get("api_type", None),
            api_version=config.get("api_version", None),
            default_headers=config.get("default_headers", None),
            dimensions=config.get("dimensions", None),
            deployment_id=config.get("deployment_id", None),
            organization_id=config.get("organization_id", None),
        )

    @staticmethod
    def _configure_azure(config, model_name):
        """Build an OpenAI function whose ``api_type`` defaults to "azure"."""
        from chromadb.utils.embedding_functions.openai_embedding_function import (
            OpenAIEmbeddingFunction,
        )

        return OpenAIEmbeddingFunction(
            api_key=config.get("api_key"),
            api_base=config.get("api_base"),
            api_type=config.get("api_type", "azure"),
            api_version=config.get("api_version"),
            model_name=model_name,
            default_headers=config.get("default_headers"),
            dimensions=config.get("dimensions"),
            deployment_id=config.get("deployment_id"),
            organization_id=config.get("organization_id"),
        )

    @staticmethod
    def _configure_ollama(config, model_name):
        """Build an Ollama function; the URL defaults to a localhost endpoint."""
        from chromadb.utils.embedding_functions.ollama_embedding_function import (
            OllamaEmbeddingFunction,
        )

        return OllamaEmbeddingFunction(
            url=config.get("url", "http://localhost:11434/api/embeddings"),
            model_name=model_name,
        )

    @staticmethod
    def _configure_vertexai(config, model_name):
        """Build a Google Vertex function from api_key, project_id and region."""
        from chromadb.utils.embedding_functions.google_embedding_function import (
            GoogleVertexEmbeddingFunction,
        )

        return GoogleVertexEmbeddingFunction(
            model_name=model_name,
            api_key=config.get("api_key"),
            project_id=config.get("project_id"),
            region=config.get("region"),
        )

    @staticmethod
    def _configure_google(config, model_name):
        """Build a Google Generative AI function from api_key and task_type."""
        from chromadb.utils.embedding_functions.google_embedding_function import (
            GoogleGenerativeAiEmbeddingFunction,
        )

        return GoogleGenerativeAiEmbeddingFunction(
            model_name=model_name,
            api_key=config.get("api_key"),
            task_type=config.get("task_type"),
        )

    @staticmethod
    def _configure_cohere(config, model_name):
        """Build a Cohere function from the configured api_key."""
        from chromadb.utils.embedding_functions.cohere_embedding_function import (
            CohereEmbeddingFunction,
        )

        return CohereEmbeddingFunction(
            model_name=model_name,
            api_key=config.get("api_key"),
        )

    @staticmethod
    def _configure_voyageai(config, model_name):
        """Build a VoyageAI function from the configured api_key."""
        from chromadb.utils.embedding_functions.voyageai_embedding_function import (  # type: ignore[import-not-found]
            VoyageAIEmbeddingFunction,
        )

        return VoyageAIEmbeddingFunction(
            model_name=model_name,
            api_key=config.get("api_key"),
        )

    @staticmethod
    def _configure_bedrock(config, model_name):
        """Build an Amazon Bedrock function from the configured session."""
        from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
            AmazonBedrockEmbeddingFunction,
        )

        # Only pass model_name when one was given, so the library's own
        # default applies otherwise.
        kwargs = {"session": config.get("session")}
        if model_name is not None:
            kwargs["model_name"] = model_name
        return AmazonBedrockEmbeddingFunction(**kwargs)

    @staticmethod
    def _configure_huggingface(config, model_name):
        """Build a HuggingFace server function; ``model_name`` is not used."""
        from chromadb.utils.embedding_functions.huggingface_embedding_function import (
            HuggingFaceEmbeddingServer,
        )

        return HuggingFaceEmbeddingServer(
            url=config.get("api_url"),
        )

    @staticmethod
    def _configure_watson(config, model_name):
        """Build an embedding function backed by IBM watsonx.

        The model id is read from ``config["model"]``; ``model_name`` is not
        used.

        Raises:
            ImportError: If the ``ibm_watsonx_ai`` package cannot be imported.
        """
        try:
            import ibm_watsonx_ai.foundation_models as watson_models  # type: ignore[import-not-found]
            from ibm_watsonx_ai import Credentials  # type: ignore[import-not-found]
            from ibm_watsonx_ai.metanames import (  # type: ignore[import-not-found]
                EmbedTextParamsMetaNames as EmbedParams,
            )
        except ImportError as e:
            raise ImportError(
                "IBM Watson dependencies are not installed. Please install them to use Watson embedding."
            ) from e

        class WatsonEmbeddingFunction(EmbeddingFunction):
            def __call__(self, input: Documents) -> Embeddings:
                # Accept a bare string as a single document.
                if isinstance(input, str):
                    input = [input]

                # NOTE(review): TRUNCATE_INPUT_TOKENS is 3, which looks very
                # small — confirm this is intentional.
                embed_params = {
                    EmbedParams.TRUNCATE_INPUT_TOKENS: 3,
                    EmbedParams.RETURN_OPTIONS: {"input_text": True},
                }

                # A new client is constructed on every call.
                embedding = watson_models.Embeddings(
                    model_id=config.get("model"),
                    params=embed_params,
                    credentials=Credentials(
                        api_key=config.get("api_key"), url=config.get("api_url")
                    ),
                    project_id=config.get("project_id"),
                )

                try:
                    embeddings = embedding.embed_documents(input)
                    return cast(Embeddings, embeddings)
                except Exception as e:
                    print("Error during Watson embedding:", e)
                    raise

        return WatsonEmbeddingFunction()

    @staticmethod
    def _configure_custom(config):
        """Return the user-supplied embedder from ``config["embedder"]``.

        The value may be an ``EmbeddingFunction`` instance, or a zero-argument
        callable that creates one.

        Raises:
            ValueError: If the embedder is missing, fails validation, cannot
                be instantiated, or does not yield an ``EmbeddingFunction``.
        """
        custom_embedder = config.get("embedder")
        if isinstance(custom_embedder, EmbeddingFunction):
            try:
                validate_embedding_function(custom_embedder)
                return custom_embedder
            except Exception as e:
                raise ValueError(f"Invalid custom embedding function: {e!s}") from e
        elif callable(custom_embedder):
            try:
                instance = custom_embedder()
                if isinstance(instance, EmbeddingFunction):
                    validate_embedding_function(instance)
                    return instance
                # This ValueError is caught by the handler below and re-raised
                # with the "Error instantiating" prefix.
                raise ValueError(
                    "Custom embedder does not create an EmbeddingFunction instance"
                )
            except Exception as e:
                raise ValueError(f"Error instantiating custom embedder: {e!s}") from e
        else:
            raise ValueError(
                "Custom embedder must be an instance of `EmbeddingFunction` or a callable that creates one"
            )
|
||||
@@ -1,249 +1,363 @@
|
||||
"""Minimal embedding function factory for CrewAI."""
|
||||
"""Factory functions for creating embedding providers and functions."""
|
||||
|
||||
import os
|
||||
from collections.abc import Callable, MutableMapping
|
||||
from typing import Any, Final, Literal, TypedDict
|
||||
from __future__ import annotations
|
||||
|
||||
from chromadb import EmbeddingFunction
|
||||
from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
|
||||
AmazonBedrockEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.cohere_embedding_function import (
|
||||
CohereEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.google_embedding_function import (
|
||||
GoogleGenerativeAiEmbeddingFunction,
|
||||
GooglePalmEmbeddingFunction,
|
||||
GoogleVertexEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
|
||||
HuggingFaceEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.instructor_embedding_function import (
|
||||
InstructorEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.jina_embedding_function import (
|
||||
JinaEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.ollama_embedding_function import (
|
||||
OllamaEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.onnx_mini_lm_l6_v2 import ONNXMiniLM_L6_V2
|
||||
from chromadb.utils.embedding_functions.open_clip_embedding_function import (
|
||||
OpenCLIPEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
||||
OpenAIEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.roboflow_embedding_function import (
|
||||
RoboflowEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.sentence_transformer_embedding_function import (
|
||||
SentenceTransformerEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.text2vec_embedding_function import (
|
||||
Text2VecEmbeddingFunction,
|
||||
)
|
||||
from typing_extensions import NotRequired
|
||||
from typing import TYPE_CHECKING, TypeVar, overload
|
||||
|
||||
from crewai.rag.embeddings.types import EmbeddingOptions
|
||||
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
from crewai.utilities.import_utils import import_and_validate_definition
|
||||
|
||||
AllowedEmbeddingProviders = Literal[
|
||||
"openai",
|
||||
"cohere",
|
||||
"ollama",
|
||||
"huggingface",
|
||||
"sentence-transformer",
|
||||
"instructor",
|
||||
"google-palm",
|
||||
"google-generativeai",
|
||||
"google-vertex",
|
||||
"amazon-bedrock",
|
||||
"jina",
|
||||
"roboflow",
|
||||
"openclip",
|
||||
"text2vec",
|
||||
"onnx",
|
||||
]
|
||||
if TYPE_CHECKING:
|
||||
from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
|
||||
AmazonBedrockEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.cohere_embedding_function import (
|
||||
CohereEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.google_embedding_function import (
|
||||
GoogleGenerativeAiEmbeddingFunction,
|
||||
GoogleVertexEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
|
||||
HuggingFaceEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.instructor_embedding_function import (
|
||||
InstructorEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.jina_embedding_function import (
|
||||
JinaEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.ollama_embedding_function import (
|
||||
OllamaEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.onnx_mini_lm_l6_v2 import ONNXMiniLM_L6_V2
|
||||
from chromadb.utils.embedding_functions.open_clip_embedding_function import (
|
||||
OpenCLIPEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
||||
OpenAIEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.roboflow_embedding_function import (
|
||||
RoboflowEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.sentence_transformer_embedding_function import (
|
||||
SentenceTransformerEmbeddingFunction,
|
||||
)
|
||||
from chromadb.utils.embedding_functions.text2vec_embedding_function import (
|
||||
Text2VecEmbeddingFunction,
|
||||
)
|
||||
|
||||
from crewai.rag.embeddings.providers.aws.types import BedrockProviderSpec
|
||||
from crewai.rag.embeddings.providers.cohere.types import CohereProviderSpec
|
||||
from crewai.rag.embeddings.providers.custom.types import CustomProviderSpec
|
||||
from crewai.rag.embeddings.providers.google.types import (
|
||||
GenerativeAiProviderSpec,
|
||||
VertexAIProviderSpec,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.huggingface.types import (
|
||||
HuggingFaceProviderSpec,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.ibm.embedding_callable import (
|
||||
WatsonEmbeddingFunction,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.ibm.types import WatsonProviderSpec
|
||||
from crewai.rag.embeddings.providers.instructor.types import InstructorProviderSpec
|
||||
from crewai.rag.embeddings.providers.jina.types import JinaProviderSpec
|
||||
from crewai.rag.embeddings.providers.microsoft.types import AzureProviderSpec
|
||||
from crewai.rag.embeddings.providers.ollama.types import OllamaProviderSpec
|
||||
from crewai.rag.embeddings.providers.onnx.types import ONNXProviderSpec
|
||||
from crewai.rag.embeddings.providers.openai.types import OpenAIProviderSpec
|
||||
from crewai.rag.embeddings.providers.openclip.types import OpenCLIPProviderSpec
|
||||
from crewai.rag.embeddings.providers.roboflow.types import RoboflowProviderSpec
|
||||
from crewai.rag.embeddings.providers.sentence_transformer.types import (
|
||||
SentenceTransformerProviderSpec,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.text2vec.types import Text2VecProviderSpec
|
||||
from crewai.rag.embeddings.providers.voyageai.embedding_callable import (
|
||||
VoyageAIEmbeddingFunction,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.voyageai.types import VoyageAIProviderSpec
|
||||
|
||||
T = TypeVar("T", bound=EmbeddingFunction)
|
||||
|
||||
|
||||
class EmbedderConfig(TypedDict):
|
||||
"""Configuration for embedding functions with nested format."""
|
||||
|
||||
provider: AllowedEmbeddingProviders
|
||||
config: NotRequired[dict[str, Any]]
|
||||
|
||||
|
||||
EMBEDDING_PROVIDERS: Final[
|
||||
dict[AllowedEmbeddingProviders, Callable[..., EmbeddingFunction]]
|
||||
] = {
|
||||
"openai": OpenAIEmbeddingFunction,
|
||||
"cohere": CohereEmbeddingFunction,
|
||||
"ollama": OllamaEmbeddingFunction,
|
||||
"huggingface": HuggingFaceEmbeddingFunction,
|
||||
"sentence-transformer": SentenceTransformerEmbeddingFunction,
|
||||
"instructor": InstructorEmbeddingFunction,
|
||||
"google-palm": GooglePalmEmbeddingFunction,
|
||||
"google-generativeai": GoogleGenerativeAiEmbeddingFunction,
|
||||
"google-vertex": GoogleVertexEmbeddingFunction,
|
||||
"amazon-bedrock": AmazonBedrockEmbeddingFunction,
|
||||
"jina": JinaEmbeddingFunction,
|
||||
"roboflow": RoboflowEmbeddingFunction,
|
||||
"openclip": OpenCLIPEmbeddingFunction,
|
||||
"text2vec": Text2VecEmbeddingFunction,
|
||||
"onnx": ONNXMiniLM_L6_V2,
|
||||
}
|
||||
|
||||
PROVIDER_ENV_MAPPING: Final[dict[AllowedEmbeddingProviders, tuple[str, str]]] = {
|
||||
"openai": ("OPENAI_API_KEY", "api_key"),
|
||||
"cohere": ("COHERE_API_KEY", "api_key"),
|
||||
"huggingface": ("HUGGINGFACE_API_KEY", "api_key"),
|
||||
"google-palm": ("GOOGLE_API_KEY", "api_key"),
|
||||
"google-generativeai": ("GOOGLE_API_KEY", "api_key"),
|
||||
"google-vertex": ("GOOGLE_API_KEY", "api_key"),
|
||||
"jina": ("JINA_API_KEY", "api_key"),
|
||||
"roboflow": ("ROBOFLOW_API_KEY", "api_key"),
|
||||
PROVIDER_PATHS = {
|
||||
"azure": "crewai.rag.embeddings.providers.microsoft.azure.AzureProvider",
|
||||
"amazon-bedrock": "crewai.rag.embeddings.providers.aws.bedrock.BedrockProvider",
|
||||
"cohere": "crewai.rag.embeddings.providers.cohere.cohere_provider.CohereProvider",
|
||||
"custom": "crewai.rag.embeddings.providers.custom.custom_provider.CustomProvider",
|
||||
"google-generativeai": "crewai.rag.embeddings.providers.google.generative_ai.GenerativeAiProvider",
|
||||
"google-vertex": "crewai.rag.embeddings.providers.google.vertex.VertexAIProvider",
|
||||
"huggingface": "crewai.rag.embeddings.providers.huggingface.huggingface_provider.HuggingFaceProvider",
|
||||
"instructor": "crewai.rag.embeddings.providers.instructor.instructor_provider.InstructorProvider",
|
||||
"jina": "crewai.rag.embeddings.providers.jina.jina_provider.JinaProvider",
|
||||
"ollama": "crewai.rag.embeddings.providers.ollama.ollama_provider.OllamaProvider",
|
||||
"onnx": "crewai.rag.embeddings.providers.onnx.onnx_provider.ONNXProvider",
|
||||
"openai": "crewai.rag.embeddings.providers.openai.openai_provider.OpenAIProvider",
|
||||
"openclip": "crewai.rag.embeddings.providers.openclip.openclip_provider.OpenCLIPProvider",
|
||||
"roboflow": "crewai.rag.embeddings.providers.roboflow.roboflow_provider.RoboflowProvider",
|
||||
"sentence-transformer": "crewai.rag.embeddings.providers.sentence_transformer.sentence_transformer_provider.SentenceTransformerProvider",
|
||||
"text2vec": "crewai.rag.embeddings.providers.text2vec.text2vec_provider.Text2VecProvider",
|
||||
"voyageai": "crewai.rag.embeddings.providers.voyageai.voyageai_provider.VoyageAIProvider",
|
||||
"watson": "crewai.rag.embeddings.providers.ibm.watson.WatsonProvider",
|
||||
}
|
||||
|
||||
|
||||
def _inject_api_key_from_env(
|
||||
provider: AllowedEmbeddingProviders, config_dict: MutableMapping[str, Any]
|
||||
) -> None:
|
||||
"""Inject API key or other required configuration from environment if not explicitly provided.
|
||||
def build_embedder_from_provider(provider: BaseEmbeddingsProvider[T]) -> T:
|
||||
"""Build an embedding function instance from a provider.
|
||||
|
||||
Args:
|
||||
provider: The embedding provider name
|
||||
config_dict: The configuration dictionary to modify in-place
|
||||
|
||||
Raises:
|
||||
ImportError: If required libraries for certain providers are not installed
|
||||
ValueError: If AWS session creation fails for amazon-bedrock
|
||||
"""
|
||||
if provider in PROVIDER_ENV_MAPPING:
|
||||
env_var_name, config_key = PROVIDER_ENV_MAPPING[provider]
|
||||
if config_key not in config_dict:
|
||||
env_value = os.getenv(env_var_name)
|
||||
if env_value:
|
||||
config_dict[config_key] = env_value
|
||||
|
||||
if provider == "amazon-bedrock":
|
||||
if "session" not in config_dict:
|
||||
try:
|
||||
import boto3 # type: ignore[import]
|
||||
|
||||
config_dict["session"] = boto3.Session()
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"boto3 is required for amazon-bedrock embeddings. "
|
||||
"Install it with: uv add boto3"
|
||||
) from e
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"Failed to create AWS session for amazon-bedrock. "
|
||||
f"Ensure AWS credentials are configured. Error: {e}"
|
||||
) from e
|
||||
|
||||
|
||||
def get_embedding_function(
|
||||
config: EmbeddingOptions | EmbedderConfig | None = None,
|
||||
) -> EmbeddingFunction:
|
||||
"""Get embedding function - delegates to ChromaDB.
|
||||
|
||||
Args:
|
||||
config: Optional configuration - either:
|
||||
- EmbeddingOptions: Pydantic model with flat configuration
|
||||
- EmbedderConfig: TypedDict with nested format {"provider": str, "config": dict}
|
||||
- None: Uses default OpenAI configuration
|
||||
provider: The embedding provider configuration.
|
||||
|
||||
Returns:
|
||||
EmbeddingFunction instance ready for use with ChromaDB
|
||||
An instance of the specified embedding function type.
|
||||
"""
|
||||
return provider.embedding_callable(
|
||||
**provider.model_dump(exclude={"embedding_callable"})
|
||||
)
|
||||
|
||||
Supported providers:
|
||||
- openai: OpenAI embeddings
|
||||
- cohere: Cohere embeddings
|
||||
- ollama: Ollama local embeddings
|
||||
- huggingface: HuggingFace embeddings
|
||||
- sentence-transformer: Local sentence transformers
|
||||
- instructor: Instructor embeddings for specialized tasks
|
||||
- google-palm: Google PaLM embeddings
|
||||
- google-generativeai: Google Generative AI embeddings
|
||||
- google-vertex: Google Vertex AI embeddings
|
||||
- amazon-bedrock: AWS Bedrock embeddings
|
||||
- jina: Jina AI embeddings
|
||||
- roboflow: Roboflow embeddings for vision tasks
|
||||
- openclip: OpenCLIP embeddings for multimodal tasks
|
||||
- text2vec: Text2Vec embeddings
|
||||
- onnx: ONNX MiniLM-L6-v2 (no API key needed, included with ChromaDB)
|
||||
|
||||
# Typing-only overloads: each one pairs a provider spec TypedDict with the
# embedding function type that build_embedder_from_dict returns for it.
@overload
def build_embedder_from_dict(
    spec: AzureProviderSpec,
) -> OpenAIEmbeddingFunction: ...


@overload
def build_embedder_from_dict(
    spec: BedrockProviderSpec,
) -> AmazonBedrockEmbeddingFunction: ...


@overload
def build_embedder_from_dict(
    spec: CohereProviderSpec,
) -> CohereEmbeddingFunction: ...


@overload
def build_embedder_from_dict(
    spec: CustomProviderSpec,
) -> EmbeddingFunction: ...


@overload
def build_embedder_from_dict(
    spec: GenerativeAiProviderSpec,
) -> GoogleGenerativeAiEmbeddingFunction: ...


# NOTE(review): HuggingFaceProvider is declared elsewhere in this change as
# BaseEmbeddingsProvider[HuggingFaceEmbeddingServer]; confirm that
# HuggingFaceEmbeddingFunction is the intended return type here.
@overload
def build_embedder_from_dict(
    spec: HuggingFaceProviderSpec,
) -> HuggingFaceEmbeddingFunction: ...


@overload
def build_embedder_from_dict(
    spec: OllamaProviderSpec,
) -> OllamaEmbeddingFunction: ...


@overload
def build_embedder_from_dict(
    spec: OpenAIProviderSpec,
) -> OpenAIEmbeddingFunction: ...


@overload
def build_embedder_from_dict(
    spec: VertexAIProviderSpec,
) -> GoogleVertexEmbeddingFunction: ...


@overload
def build_embedder_from_dict(
    spec: VoyageAIProviderSpec,
) -> VoyageAIEmbeddingFunction: ...


@overload
def build_embedder_from_dict(
    spec: WatsonProviderSpec,
) -> WatsonEmbeddingFunction: ...


@overload
def build_embedder_from_dict(
    spec: SentenceTransformerProviderSpec,
) -> SentenceTransformerEmbeddingFunction: ...


@overload
def build_embedder_from_dict(
    spec: InstructorProviderSpec,
) -> InstructorEmbeddingFunction: ...


@overload
def build_embedder_from_dict(
    spec: JinaProviderSpec,
) -> JinaEmbeddingFunction: ...


@overload
def build_embedder_from_dict(
    spec: RoboflowProviderSpec,
) -> RoboflowEmbeddingFunction: ...


@overload
def build_embedder_from_dict(
    spec: OpenCLIPProviderSpec,
) -> OpenCLIPEmbeddingFunction: ...


@overload
def build_embedder_from_dict(
    spec: Text2VecProviderSpec,
) -> Text2VecEmbeddingFunction: ...


@overload
def build_embedder_from_dict(
    spec: ONNXProviderSpec,
) -> ONNXMiniLM_L6_V2: ...
|
||||
|
||||
|
||||
def build_embedder_from_dict(spec):
    """Build an embedding function instance from a dictionary specification.

    Args:
        spec: A dictionary with 'provider' and 'config' keys.
            Example: {
                "provider": "openai",
                "config": {
                    "api_key": "sk-...",
                    "model_name": "text-embedding-3-small"
                }
            }

    Returns:
        An instance of the appropriate embedding function.

    Raises:
        ValueError: If the 'provider' key is missing or empty, if the provider
            is not recognized, or if the custom provider is requested without
            an 'embedding_callable' in its config.
        ImportError: If the provider class cannot be imported.
    """
    # Use .get() so a spec without the key reaches the ValueError below
    # instead of escaping as a bare KeyError.
    provider_name = spec.get("provider")
    if not provider_name:
        raise ValueError("Missing 'provider' key in specification")

    if provider_name not in PROVIDER_PATHS:
        raise ValueError(
            f"Unknown provider: {provider_name}. Available providers: {list(PROVIDER_PATHS.keys())}"
        )

    provider_path = PROVIDER_PATHS[provider_name]
    try:
        provider_class = import_and_validate_definition(provider_path)
    except (ImportError, AttributeError, ValueError) as e:
        raise ImportError(f"Failed to import provider {provider_name}: {e}") from e

    # Treat an absent or explicit-None config the same as an empty one so the
    # membership test and the ** expansion below cannot fail on None.
    provider_config = spec.get("config") or {}

    if provider_name == "custom" and "embedding_callable" not in provider_config:
        raise ValueError("Custom provider requires 'embedding_callable' in config")

    provider = provider_class(**provider_config)
    return build_embedder_from_provider(provider)
|
||||
|
||||
|
||||
# Typing-only overloads: a provider instance yields the embedding function type
# it is parameterised with; each spec TypedDict yields its provider's type.
@overload
def build_embedder(
    spec: BaseEmbeddingsProvider[T],
) -> T: ...


@overload
def build_embedder(
    spec: AzureProviderSpec,
) -> OpenAIEmbeddingFunction: ...


@overload
def build_embedder(
    spec: BedrockProviderSpec,
) -> AmazonBedrockEmbeddingFunction: ...


@overload
def build_embedder(
    spec: CohereProviderSpec,
) -> CohereEmbeddingFunction: ...


@overload
def build_embedder(
    spec: CustomProviderSpec,
) -> EmbeddingFunction: ...


@overload
def build_embedder(
    spec: GenerativeAiProviderSpec,
) -> GoogleGenerativeAiEmbeddingFunction: ...


# NOTE(review): HuggingFaceProvider is declared elsewhere in this change as
# BaseEmbeddingsProvider[HuggingFaceEmbeddingServer]; confirm that
# HuggingFaceEmbeddingFunction is the intended return type here.
@overload
def build_embedder(
    spec: HuggingFaceProviderSpec,
) -> HuggingFaceEmbeddingFunction: ...


@overload
def build_embedder(
    spec: OllamaProviderSpec,
) -> OllamaEmbeddingFunction: ...


@overload
def build_embedder(
    spec: OpenAIProviderSpec,
) -> OpenAIEmbeddingFunction: ...


@overload
def build_embedder(
    spec: VertexAIProviderSpec,
) -> GoogleVertexEmbeddingFunction: ...


@overload
def build_embedder(
    spec: VoyageAIProviderSpec,
) -> VoyageAIEmbeddingFunction: ...


@overload
def build_embedder(
    spec: WatsonProviderSpec,
) -> WatsonEmbeddingFunction: ...


@overload
def build_embedder(
    spec: SentenceTransformerProviderSpec,
) -> SentenceTransformerEmbeddingFunction: ...


@overload
def build_embedder(
    spec: InstructorProviderSpec,
) -> InstructorEmbeddingFunction: ...


@overload
def build_embedder(
    spec: JinaProviderSpec,
) -> JinaEmbeddingFunction: ...


@overload
def build_embedder(
    spec: RoboflowProviderSpec,
) -> RoboflowEmbeddingFunction: ...


@overload
def build_embedder(
    spec: OpenCLIPProviderSpec,
) -> OpenCLIPEmbeddingFunction: ...


@overload
def build_embedder(
    spec: Text2VecProviderSpec,
) -> Text2VecEmbeddingFunction: ...


@overload
def build_embedder(
    spec: ONNXProviderSpec,
) -> ONNXMiniLM_L6_V2: ...
|
||||
|
||||
|
||||
def build_embedder(spec):
|
||||
"""Build an embedding function from either a provider spec or a provider instance.
|
||||
|
||||
Args:
|
||||
spec: Either a provider specification dictionary or a provider instance.
|
||||
|
||||
Returns:
|
||||
An embedding function instance. If a typed provider is passed, returns
|
||||
the specific embedding function type.
|
||||
|
||||
Examples:
|
||||
# Use default OpenAI embedding
|
||||
>>> embedder = get_embedding_function()
|
||||
# From dictionary specification
|
||||
embedder = build_embedder({
|
||||
"provider": "openai",
|
||||
"config": {"api_key": "sk-..."}
|
||||
})
|
||||
|
||||
# Use Cohere with dict
|
||||
>>> embedder = get_embedding_function(EmbedderConfig(**{
|
||||
... "provider": "cohere",
|
||||
... "config": {
|
||||
... "api_key": "your-key",
|
||||
... "model_name": "embed-english-v3.0"
|
||||
... }
|
||||
... }))
|
||||
|
||||
# Use with EmbeddingOptions
|
||||
>>> embedder = get_embedding_function(
|
||||
... EmbeddingOptions(provider="sentence-transformer", model_name="all-MiniLM-L6-v2")
|
||||
... )
|
||||
|
||||
# Use Azure OpenAI
|
||||
>>> embedder = get_embedding_function(EmbedderConfig(**{
|
||||
... "provider": "openai",
|
||||
... "config": {
|
||||
... "api_key": "your-azure-key",
|
||||
... "api_base": "https://your-resource.openai.azure.com/",
|
||||
... "api_type": "azure",
|
||||
... "api_version": "2023-05-15",
|
||||
... "model": "text-embedding-3-small",
|
||||
... "deployment_id": "your-deployment-name"
|
||||
... }
|
||||
... })
|
||||
|
||||
>>> embedder = get_embedding_function(EmbedderConfig(**{
|
||||
... "provider": "onnx"
|
||||
... })
|
||||
# From provider instance
|
||||
provider = OpenAIProvider(api_key="sk-...")
|
||||
embedder = build_embedder(provider)
|
||||
"""
|
||||
if config is None:
|
||||
return OpenAIEmbeddingFunction(
|
||||
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
|
||||
)
|
||||
if isinstance(spec, BaseEmbeddingsProvider):
|
||||
return build_embedder_from_provider(spec)
|
||||
return build_embedder_from_dict(spec)
|
||||
|
||||
provider: AllowedEmbeddingProviders
|
||||
config_dict: dict[str, Any]
|
||||
|
||||
if isinstance(config, EmbeddingOptions):
|
||||
config_dict = config.model_dump(exclude_none=True)
|
||||
provider = config_dict["provider"]
|
||||
else:
|
||||
provider = config["provider"]
|
||||
nested: dict[str, Any] = config.get("config", {})
|
||||
|
||||
if not nested and len(config) > 1:
|
||||
raise ValueError(
|
||||
"Invalid embedder configuration format. "
|
||||
"Configuration must be nested under a 'config' key. "
|
||||
"Example: {'provider': 'openai', 'config': {'api_key': '...', 'model': '...'}}"
|
||||
)
|
||||
|
||||
config_dict = dict(nested)
|
||||
if "model" in config_dict and "model_name" not in config_dict:
|
||||
config_dict["model_name"] = config_dict.pop("model")
|
||||
|
||||
if provider not in EMBEDDING_PROVIDERS:
|
||||
raise ValueError(
|
||||
f"Unsupported provider: {provider}. "
|
||||
f"Available providers: {list(EMBEDDING_PROVIDERS.keys())}"
|
||||
)
|
||||
|
||||
_inject_api_key_from_env(provider, config_dict)
|
||||
|
||||
config_dict.pop("batch_size", None)
|
||||
|
||||
return EMBEDDING_PROVIDERS[provider](**config_dict)
|
||||
# Backward compatibility alias
|
||||
get_embedding_function = build_embedder
|
||||
|
||||
1
src/crewai/rag/embeddings/providers/__init__.py
Normal file
1
src/crewai/rag/embeddings/providers/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Embedding provider implementations."""
|
||||
13
src/crewai/rag/embeddings/providers/aws/__init__.py
Normal file
13
src/crewai/rag/embeddings/providers/aws/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""AWS embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.aws.bedrock import BedrockProvider
|
||||
from crewai.rag.embeddings.providers.aws.types import (
|
||||
BedrockProviderConfig,
|
||||
BedrockProviderSpec,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BedrockProvider",
|
||||
"BedrockProviderConfig",
|
||||
"BedrockProviderSpec",
|
||||
]
|
||||
58
src/crewai/rag/embeddings/providers/aws/bedrock.py
Normal file
58
src/crewai/rag/embeddings/providers/aws/bedrock.py
Normal file
@@ -0,0 +1,58 @@
|
||||
"""Amazon Bedrock embeddings provider."""
|
||||
|
||||
from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
|
||||
AmazonBedrockEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
try:
|
||||
from boto3.session import Session # type: ignore[import-untyped]
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"boto3 is required for amazon-bedrock embeddings. Install it with: uv add boto3"
|
||||
) from exc
|
||||
|
||||
|
||||
def create_aws_session() -> Session:
    """Return a new default boto3 session for use with Bedrock.

    Returns:
        The session produced by ``boto3.Session()`` called with no arguments.

    Raises:
        ImportError: If boto3 cannot be imported.
        ValueError: If any other error occurs while creating the session.
    """
    missing_dependency = (
        "boto3 is required for amazon-bedrock embeddings. "
        "Install it with: uv add boto3"
    )
    try:
        import boto3  # type: ignore[import]

        session = boto3.Session()
    except ImportError as import_error:
        raise ImportError(missing_dependency) from import_error
    except Exception as session_error:
        # Any non-import failure is reported as a configuration problem.
        raise ValueError(
            "Failed to create AWS session for amazon-bedrock. "
            f"Ensure AWS credentials are configured. Error: {session_error}"
        ) from session_error
    return session
|
||||
|
||||
|
||||
class BedrockProvider(BaseEmbeddingsProvider[AmazonBedrockEmbeddingFunction]):
    """Settings for building an Amazon Bedrock embedding function.

    Declares the embedding function class together with the model name and
    AWS session that configure it.
    """

    # Class used to produce embeddings; defaults to ChromaDB's Bedrock function.
    embedding_callable: type[AmazonBedrockEmbeddingFunction] = Field(
        default=AmazonBedrockEmbeddingFunction,
        description="Amazon Bedrock embedding function class",
    )
    # Validated under the alias BEDROCK_MODEL_NAME.
    model_name: str = Field(
        default="amazon.titan-embed-text-v1",
        validation_alias="BEDROCK_MODEL_NAME",
        description="Model name to use for embeddings",
    )
    # When no session is supplied, create_aws_session() builds one.
    session: Session = Field(
        default_factory=create_aws_session,
        description="AWS session object",
    )
|
||||
17
src/crewai/rag/embeddings/providers/aws/types.py
Normal file
17
src/crewai/rag/embeddings/providers/aws/types.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""Type definitions for AWS embedding providers."""
|
||||
|
||||
from typing import Annotated, Any, Literal, TypedDict
|
||||
|
||||
|
||||
class BedrockProviderConfig(TypedDict, total=False):
    """Optional settings accepted under the 'config' key of a Bedrock spec.

    Every key may be omitted (``total=False``).
    """

    # The Annotated metadata carries the string "amazon.titan-embed-text-v1".
    model_name: Annotated[str, "amazon.titan-embed-text-v1"]
    # Typed as Any here; presumably a boto3 session - confirm against the provider.
    session: Any
|
||||
|
||||
|
||||
class BedrockProviderSpec(TypedDict):
    """Dictionary specification selecting the Amazon Bedrock provider."""

    # Discriminator: must be exactly "amazon-bedrock".
    provider: Literal["amazon-bedrock"]
    # Provider settings; see BedrockProviderConfig for the accepted keys.
    config: BedrockProviderConfig
|
||||
13
src/crewai/rag/embeddings/providers/cohere/__init__.py
Normal file
13
src/crewai/rag/embeddings/providers/cohere/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""Cohere embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.cohere.cohere_provider import CohereProvider
|
||||
from crewai.rag.embeddings.providers.cohere.types import (
|
||||
CohereProviderConfig,
|
||||
CohereProviderSpec,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"CohereProvider",
|
||||
"CohereProviderConfig",
|
||||
"CohereProviderSpec",
|
||||
]
|
||||
@@ -0,0 +1,24 @@
|
||||
"""Cohere embeddings provider."""
|
||||
|
||||
from chromadb.utils.embedding_functions.cohere_embedding_function import (
|
||||
CohereEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
|
||||
class CohereProvider(BaseEmbeddingsProvider[CohereEmbeddingFunction]):
    """Settings for building a Cohere embedding function.

    Declares the embedding function class together with the API key and
    model name that configure it.
    """

    # Class used to produce embeddings; defaults to ChromaDB's Cohere function.
    embedding_callable: type[CohereEmbeddingFunction] = Field(
        default=CohereEmbeddingFunction,
        description="Cohere embedding function class",
    )
    # Required (no default); validated under the alias COHERE_API_KEY.
    api_key: str = Field(
        validation_alias="COHERE_API_KEY",
        description="Cohere API key",
    )
    # Validated under the alias COHERE_MODEL_NAME.
    model_name: str = Field(
        default="large",
        validation_alias="COHERE_MODEL_NAME",
        description="Model name to use for embeddings",
    )
|
||||
17
src/crewai/rag/embeddings/providers/cohere/types.py
Normal file
17
src/crewai/rag/embeddings/providers/cohere/types.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""Type definitions for Cohere embedding providers."""
|
||||
|
||||
from typing import Annotated, Literal, TypedDict
|
||||
|
||||
|
||||
class CohereProviderConfig(TypedDict, total=False):
    """Optional settings accepted under the 'config' key of a Cohere spec.

    Every key may be omitted (``total=False``).
    """

    api_key: str
    # The Annotated metadata carries the string "large".
    model_name: Annotated[str, "large"]
|
||||
|
||||
|
||||
class CohereProviderSpec(TypedDict):
    """Dictionary specification selecting the Cohere provider."""

    # Discriminator: must be exactly "cohere".
    provider: Literal["cohere"]
    # Provider settings; see CohereProviderConfig for the accepted keys.
    config: CohereProviderConfig
|
||||
13
src/crewai/rag/embeddings/providers/custom/__init__.py
Normal file
13
src/crewai/rag/embeddings/providers/custom/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""Custom embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.custom.custom_provider import CustomProvider
|
||||
from crewai.rag.embeddings.providers.custom.types import (
|
||||
CustomProviderConfig,
|
||||
CustomProviderSpec,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"CustomProvider",
|
||||
"CustomProviderConfig",
|
||||
"CustomProviderSpec",
|
||||
]
|
||||
@@ -0,0 +1,19 @@
|
||||
"""Custom embeddings provider for user-defined embedding functions."""
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic_settings import SettingsConfigDict
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
from crewai.rag.embeddings.providers.custom.embedding_callable import (
|
||||
CustomEmbeddingFunction,
|
||||
)
|
||||
|
||||
|
||||
class CustomProvider(BaseEmbeddingsProvider[CustomEmbeddingFunction]):
    """Settings for a user-supplied embedding function class.

    Unlike the built-in providers there is no default callable: the caller
    must pass the embedding function class explicitly.
    """

    # Required (Field(...)): the user's embedding function class.
    embedding_callable: type[CustomEmbeddingFunction] = Field(
        ...,
        description="Custom embedding function class",
    )

    # extra="allow" accepts fields that are not declared on this model.
    model_config = SettingsConfigDict(extra="allow")
|
||||
@@ -0,0 +1,22 @@
|
||||
"""Custom embedding function base implementation."""
|
||||
|
||||
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
|
||||
from crewai.rag.core.types import Documents, Embeddings
|
||||
|
||||
|
||||
class CustomEmbeddingFunction(EmbeddingFunction[Documents]):
    """Starting point for user-defined embedding functions.

    Subclass this and override ``__call__``; the implementation here only
    raises ``NotImplementedError``.
    """

    def __call__(self, input: Documents) -> Embeddings:
        """Embed the given documents.

        Args:
            input: Documents to embed.

        Returns:
            Embeddings for the documents (provided by subclasses).

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError("Subclasses must implement __call__ method")
|
||||
18
src/crewai/rag/embeddings/providers/custom/types.py
Normal file
18
src/crewai/rag/embeddings/providers/custom/types.py
Normal file
@@ -0,0 +1,18 @@
|
||||
"""Type definitions for custom embedding providers."""
|
||||
|
||||
from typing import Literal, TypedDict
|
||||
|
||||
from chromadb.api.types import EmbeddingFunction
|
||||
|
||||
|
||||
class CustomProviderConfig(TypedDict, total=False):
    """Optional settings accepted under the 'config' key of a custom spec.

    Every key may be omitted (``total=False``).
    """

    # The embedding function class itself, not an instance.
    embedding_callable: type[EmbeddingFunction]
|
||||
|
||||
|
||||
class CustomProviderSpec(TypedDict):
    """Dictionary specification selecting the custom provider."""

    # Discriminator: must be exactly "custom".
    provider: Literal["custom"]
    # Provider settings; see CustomProviderConfig for the accepted keys.
    config: CustomProviderConfig
|
||||
23
src/crewai/rag/embeddings/providers/google/__init__.py
Normal file
23
src/crewai/rag/embeddings/providers/google/__init__.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""Google embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.google.generative_ai import (
|
||||
GenerativeAiProvider,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.google.types import (
|
||||
GenerativeAiProviderConfig,
|
||||
GenerativeAiProviderSpec,
|
||||
VertexAIProviderConfig,
|
||||
VertexAIProviderSpec,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.google.vertex import (
|
||||
VertexAIProvider,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"GenerativeAiProvider",
|
||||
"GenerativeAiProviderConfig",
|
||||
"GenerativeAiProviderSpec",
|
||||
"VertexAIProvider",
|
||||
"VertexAIProviderConfig",
|
||||
"VertexAIProviderSpec",
|
||||
]
|
||||
30
src/crewai/rag/embeddings/providers/google/generative_ai.py
Normal file
30
src/crewai/rag/embeddings/providers/google/generative_ai.py
Normal file
@@ -0,0 +1,30 @@
|
||||
"""Google Generative AI embeddings provider."""
|
||||
|
||||
from chromadb.utils.embedding_functions.google_embedding_function import (
|
||||
GoogleGenerativeAiEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
|
||||
class GenerativeAiProvider(BaseEmbeddingsProvider[GoogleGenerativeAiEmbeddingFunction]):
    """Settings for building a Google Generative AI embedding function.

    Declares the embedding function class together with the model name,
    API key and task type that configure it.
    """

    # Class used to produce embeddings; defaults to ChromaDB's Google function.
    embedding_callable: type[GoogleGenerativeAiEmbeddingFunction] = Field(
        default=GoogleGenerativeAiEmbeddingFunction,
        description="Google Generative AI embedding function class",
    )
    # Validated under the alias GOOGLE_GENERATIVE_AI_MODEL_NAME.
    model_name: str = Field(
        default="models/embedding-001",
        validation_alias="GOOGLE_GENERATIVE_AI_MODEL_NAME",
        description="Model name to use for embeddings",
    )
    # Required (no default); validated under the alias GOOGLE_API_KEY.
    api_key: str = Field(
        validation_alias="GOOGLE_API_KEY",
        description="Google API key",
    )
    # Validated under the alias GOOGLE_GENERATIVE_AI_TASK_TYPE.
    task_type: str = Field(
        default="RETRIEVAL_DOCUMENT",
        validation_alias="GOOGLE_GENERATIVE_AI_TASK_TYPE",
        description="Task type for embeddings",
    )
|
||||
34
src/crewai/rag/embeddings/providers/google/types.py
Normal file
34
src/crewai/rag/embeddings/providers/google/types.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""Type definitions for Google embedding providers."""
|
||||
|
||||
from typing import Annotated, Literal, TypedDict
|
||||
|
||||
|
||||
class GenerativeAiProviderConfig(TypedDict, total=False):
    """Optional settings under the 'config' key of a Google Generative AI spec.

    Every key may be omitted (``total=False``).
    """

    api_key: str
    # The Annotated metadata carries the string "models/embedding-001".
    model_name: Annotated[str, "models/embedding-001"]
    # The Annotated metadata carries the string "RETRIEVAL_DOCUMENT".
    task_type: Annotated[str, "RETRIEVAL_DOCUMENT"]
|
||||
|
||||
|
||||
class GenerativeAiProviderSpec(TypedDict):
    """Dictionary specification selecting the Google Generative AI provider."""

    # Discriminator: must be exactly "google-generativeai".
    provider: Literal["google-generativeai"]
    # Provider settings; see GenerativeAiProviderConfig for the accepted keys.
    config: GenerativeAiProviderConfig
|
||||
|
||||
|
||||
class VertexAIProviderConfig(TypedDict, total=False):
    """Optional settings accepted under the 'config' key of a Vertex AI spec.

    Every key may be omitted (``total=False``).
    """

    api_key: str
    # The Annotated metadata carries the string "textembedding-gecko".
    model_name: Annotated[str, "textembedding-gecko"]
    # The Annotated metadata carries the string "cloud-large-language-models".
    project_id: Annotated[str, "cloud-large-language-models"]
    # The Annotated metadata carries the string "us-central1".
    region: Annotated[str, "us-central1"]
|
||||
|
||||
|
||||
class VertexAIProviderSpec(TypedDict):
    """Dictionary specification selecting the Google Vertex AI provider."""

    # Discriminator: must be exactly "google-vertex".
    provider: Literal["google-vertex"]
    # Provider settings; see VertexAIProviderConfig for the accepted keys.
    config: VertexAIProviderConfig
|
||||
35
src/crewai/rag/embeddings/providers/google/vertex.py
Normal file
35
src/crewai/rag/embeddings/providers/google/vertex.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""Google Vertex AI embeddings provider."""
|
||||
|
||||
from chromadb.utils.embedding_functions.google_embedding_function import (
|
||||
GoogleVertexEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
|
||||
class VertexAIProvider(BaseEmbeddingsProvider[GoogleVertexEmbeddingFunction]):
    """Settings for building a Google Vertex AI embedding function.

    Declares the embedding function class together with the model name,
    API key, project and region that configure it.
    """

    # Class used to produce embeddings; defaults to ChromaDB's Vertex function.
    embedding_callable: type[GoogleVertexEmbeddingFunction] = Field(
        default=GoogleVertexEmbeddingFunction,
        description="Vertex AI embedding function class",
    )
    # Validated under the alias GOOGLE_VERTEX_MODEL_NAME.
    model_name: str = Field(
        default="textembedding-gecko",
        validation_alias="GOOGLE_VERTEX_MODEL_NAME",
        description="Model name to use for embeddings",
    )
    # Required (no default); validated under the alias GOOGLE_CLOUD_API_KEY.
    api_key: str = Field(
        validation_alias="GOOGLE_CLOUD_API_KEY",
        description="Google API key",
    )
    # Validated under the alias GOOGLE_CLOUD_PROJECT.
    project_id: str = Field(
        default="cloud-large-language-models",
        validation_alias="GOOGLE_CLOUD_PROJECT",
        description="GCP project ID",
    )
    # Validated under the alias GOOGLE_CLOUD_REGION.
    region: str = Field(
        default="us-central1",
        validation_alias="GOOGLE_CLOUD_REGION",
        description="GCP region",
    )
|
||||
15
src/crewai/rag/embeddings/providers/huggingface/__init__.py
Normal file
15
src/crewai/rag/embeddings/providers/huggingface/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""HuggingFace embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.huggingface.huggingface_provider import (
|
||||
HuggingFaceProvider,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.huggingface.types import (
|
||||
HuggingFaceProviderConfig,
|
||||
HuggingFaceProviderSpec,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"HuggingFaceProvider",
|
||||
"HuggingFaceProviderConfig",
|
||||
"HuggingFaceProviderSpec",
|
||||
]
|
||||
@@ -0,0 +1,20 @@
|
||||
"""HuggingFace embeddings provider."""
|
||||
|
||||
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
|
||||
HuggingFaceEmbeddingServer,
|
||||
)
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
|
||||
class HuggingFaceProvider(BaseEmbeddingsProvider[HuggingFaceEmbeddingServer]):
    """Settings for building a HuggingFace embedding function.

    Declares the embedding function class together with the URL that
    configures it.
    """

    # Class used to produce embeddings; defaults to HuggingFaceEmbeddingServer.
    embedding_callable: type[HuggingFaceEmbeddingServer] = Field(
        default=HuggingFaceEmbeddingServer,
        description="HuggingFace embedding function class",
    )
    # Required (no default); validated under the alias HUGGINGFACE_URL.
    url: str = Field(
        validation_alias="HUGGINGFACE_URL",
        description="HuggingFace API URL",
    )
|
||||
16
src/crewai/rag/embeddings/providers/huggingface/types.py
Normal file
16
src/crewai/rag/embeddings/providers/huggingface/types.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""Type definitions for HuggingFace embedding providers."""
|
||||
|
||||
from typing import Literal, TypedDict
|
||||
|
||||
|
||||
class HuggingFaceProviderConfig(TypedDict, total=False):
    """Optional settings accepted under the 'config' key of a HuggingFace spec.

    Every key may be omitted (``total=False``).
    """

    url: str
|
||||
|
||||
|
||||
class HuggingFaceProviderSpec(TypedDict):
    """Dictionary specification selecting the HuggingFace provider."""

    # Discriminator: must be exactly "huggingface".
    provider: Literal["huggingface"]
    # Provider settings; see HuggingFaceProviderConfig for the accepted keys.
    config: HuggingFaceProviderConfig
|
||||
15
src/crewai/rag/embeddings/providers/ibm/__init__.py
Normal file
15
src/crewai/rag/embeddings/providers/ibm/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""IBM embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.ibm.types import (
|
||||
WatsonProviderConfig,
|
||||
WatsonProviderSpec,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.ibm.watson import (
|
||||
WatsonProvider,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"WatsonProvider",
|
||||
"WatsonProviderConfig",
|
||||
"WatsonProviderSpec",
|
||||
]
|
||||
144
src/crewai/rag/embeddings/providers/ibm/embedding_callable.py
Normal file
144
src/crewai/rag/embeddings/providers/ibm/embedding_callable.py
Normal file
@@ -0,0 +1,144 @@
|
||||
"""IBM Watson embedding function implementation."""
|
||||
|
||||
from typing import cast
|
||||
|
||||
import ibm_watsonx_ai.foundation_models as watson_models # type: ignore[import-not-found, import-untyped]
|
||||
from ibm_watsonx_ai import Credentials # type: ignore[import-not-found, import-untyped]
|
||||
from ibm_watsonx_ai.metanames import ( # type: ignore[import-not-found, import-untyped]
|
||||
EmbedTextParamsMetaNames as EmbedParams,
|
||||
)
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
|
||||
from crewai.rag.core.types import Documents, Embeddings
|
||||
from crewai.rag.embeddings.providers.ibm.types import WatsonProviderConfig
|
||||
|
||||
|
||||
class WatsonEmbeddingFunction(EmbeddingFunction[Documents]):
    """Embedding function for IBM Watson models.

    Stores the keyword configuration given at construction and, on every call,
    builds a ``watson_models.Embeddings`` client from it to embed the input.
    """

    # Client options forwarded only when supplied with a non-None value.
    _SKIP_NONE_KEYS = (
        "params",
        "project_id",
        "space_id",
        "api_client",
        "verify",
        "max_retries",
        "delay_time",
        "retry_status_codes",
    )
    # Client options forwarded whenever the key is present, whatever its value.
    _PASSTHROUGH_KEYS = (
        "persistent_connection",
        "batch_size",
        "concurrency_limit",
    )
    # Options used to build ``Credentials`` when no ready-made object is given.
    _CREDENTIAL_KEYS = (
        "url",
        "api_key",
        "name",
        "iam_serviceid_crn",
        "trusted_profile_id",
        "token",
        "projects_token",
        "username",
        "password",
        "instance_id",
        "version",
        "bedrock_url",
        "platform_url",
        "proxies",
    )

    def __init__(self, **kwargs: Unpack[WatsonProviderConfig]) -> None:
        """Initialize Watson embedding function.

        Args:
            **kwargs: Configuration parameters for Watson Embeddings and Credentials.
        """
        self._config = kwargs

    def __call__(self, input: Documents) -> Embeddings:
        """Generate embeddings for input documents.

        Args:
            input: List of documents to embed. A bare string is wrapped into a
                one-element list.

        Returns:
            List of embedding vectors.

        Raises:
            KeyError: If ``model_id`` was not supplied at construction.
            Exception: Any error raised while embedding is printed and re-raised.
        """
        if isinstance(input, str):
            input = [input]

        embedding = watson_models.Embeddings(**self._build_embeddings_config())

        try:
            embeddings = embedding.embed_documents(input)
            return cast(Embeddings, embeddings)
        except Exception as e:
            print(f"Error during Watson embedding: {e}")
            raise

    @staticmethod
    def _pick_set(config: dict, keys: tuple) -> dict:
        """Return the subset of ``config`` for ``keys`` whose value is not None."""
        return {key: config[key] for key in keys if config.get(key) is not None}

    def _build_embeddings_config(self) -> dict:
        """Translate the stored configuration into ``Embeddings`` keyword arguments."""
        config = cast(dict, self._config)

        # ``model_id`` is mandatory: a missing key raises KeyError here.
        embeddings_config: dict = {"model_id": config["model_id"]}
        embeddings_config.update(self._pick_set(config, self._SKIP_NONE_KEYS))
        embeddings_config.update(
            {key: config[key] for key in self._PASSTHROUGH_KEYS if key in config}
        )

        if config.get("credentials") is not None:
            # A ready-made credentials object wins over individual fields.
            embeddings_config["credentials"] = config["credentials"]
        else:
            # ``verify`` is deliberately absent from the credential keys: a
            # non-None value is always forwarded to the client above, so the
            # former "copy verify into credentials" branch could never run.
            cred_config = self._pick_set(config, self._CREDENTIAL_KEYS)
            if cred_config:
                embeddings_config["credentials"] = Credentials(**cred_config)

        if "params" not in embeddings_config:
            # Fallback parameters used when the caller supplied none.
            embeddings_config["params"] = {
                EmbedParams.TRUNCATE_INPUT_TOKENS: 3,
                EmbedParams.RETURN_OPTIONS: {"input_text": True},
            }

        return embeddings_config
|
||||
42
src/crewai/rag/embeddings/providers/ibm/types.py
Normal file
42
src/crewai/rag/embeddings/providers/ibm/types.py
Normal file
@@ -0,0 +1,42 @@
|
||||
"""Type definitions for IBM Watson embedding providers."""
|
||||
|
||||
from typing import Annotated, Any, Literal, TypedDict
|
||||
|
||||
|
||||
class WatsonProviderConfig(TypedDict, total=False):
|
||||
"""Configuration for Watson provider."""
|
||||
|
||||
model_id: str
|
||||
url: str
|
||||
params: dict[str, str | dict[str, str]]
|
||||
credentials: Any
|
||||
project_id: str
|
||||
space_id: str
|
||||
api_client: Any
|
||||
verify: bool | str
|
||||
persistent_connection: Annotated[bool, True]
|
||||
batch_size: Annotated[int, 100]
|
||||
concurrency_limit: Annotated[int, 10]
|
||||
max_retries: int
|
||||
delay_time: float
|
||||
retry_status_codes: list[int]
|
||||
api_key: str
|
||||
name: str
|
||||
iam_serviceid_crn: str
|
||||
trusted_profile_id: str
|
||||
token: str
|
||||
projects_token: str
|
||||
username: str
|
||||
password: str
|
||||
instance_id: str
|
||||
version: str
|
||||
bedrock_url: str
|
||||
platform_url: str
|
||||
proxies: dict
|
||||
|
||||
|
||||
class WatsonProviderSpec(TypedDict):
    """Watson provider specification.

    Pairs the literal provider tag with its configuration mapping; both keys
    are required.
    """

    # Discriminator: always the string "watson".
    provider: Literal["watson"]
    # Watson-specific settings.
    config: WatsonProviderConfig
|
||||
126
src/crewai/rag/embeddings/providers/ibm/watson.py
Normal file
126
src/crewai/rag/embeddings/providers/ibm/watson.py
Normal file
@@ -0,0 +1,126 @@
|
||||
"""IBM Watson embeddings provider."""
|
||||
|
||||
from ibm_watsonx_ai import ( # type: ignore[import-not-found,import-untyped]
|
||||
APIClient,
|
||||
Credentials,
|
||||
)
|
||||
from pydantic import Field, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
from crewai.rag.embeddings.providers.ibm.embedding_callable import (
|
||||
WatsonEmbeddingFunction,
|
||||
)
|
||||
|
||||
|
||||
class WatsonProvider(BaseEmbeddingsProvider[WatsonEmbeddingFunction]):
    """IBM Watson embeddings provider.

    Note: Requires custom implementation as Watson uses a different interface.

    Fields that declare a ``validation_alias`` can presumably also be populated
    from the like-named environment variable; that resolution lives in the base
    provider, which is not shown here (confirm against BaseEmbeddingsProvider).
    """

    embedding_callable: type[WatsonEmbeddingFunction] = Field(
        default=WatsonEmbeddingFunction,
        description="Watson embedding function class",
    )

    # --- Embeddings client options ---
    model_id: str = Field(
        description="Watson model ID",
        validation_alias="WATSON_MODEL_ID",
    )
    # Values are typed ``object`` rather than ``str``: the embedding callable's
    # own fallback params contain an int and a nested dict of bools, which the
    # previous ``dict[str, str | dict[str, str]]`` annotation could not express.
    params: dict[str, object] | None = Field(
        default=None,
        description="Additional parameters",
    )
    credentials: Credentials | None = Field(
        default=None,
        description="Watson credentials",
    )
    project_id: str | None = Field(
        default=None,
        description="Watson project ID",
        validation_alias="WATSON_PROJECT_ID",
    )
    space_id: str | None = Field(
        default=None,
        description="Watson space ID",
        validation_alias="WATSON_SPACE_ID",
    )
    api_client: APIClient | None = Field(
        default=None,
        description="Watson API client",
    )
    verify: bool | str | None = Field(
        default=None,
        description="SSL verification",
        validation_alias="WATSON_VERIFY",
    )
    persistent_connection: bool = Field(
        default=True,
        description="Use persistent connection",
        validation_alias="WATSON_PERSISTENT_CONNECTION",
    )
    batch_size: int = Field(
        default=100,
        description="Batch size for processing",
        validation_alias="WATSON_BATCH_SIZE",
    )
    concurrency_limit: int = Field(
        default=10,
        description="Concurrency limit",
        validation_alias="WATSON_CONCURRENCY_LIMIT",
    )
    max_retries: int | None = Field(
        default=None,
        description="Maximum retries",
        validation_alias="WATSON_MAX_RETRIES",
    )
    delay_time: float | None = Field(
        default=None,
        description="Delay time between retries",
        validation_alias="WATSON_DELAY_TIME",
    )
    retry_status_codes: list[int] | None = Field(
        default=None,
        description="HTTP status codes to retry on",
    )

    # --- Credential fields ---
    url: str = Field(
        description="Watson API URL",
        validation_alias="WATSON_URL",
    )
    api_key: str = Field(
        description="Watson API key",
        validation_alias="WATSON_API_KEY",
    )
    name: str | None = Field(
        default=None,
        description="Service name",
        validation_alias="WATSON_NAME",
    )
    iam_serviceid_crn: str | None = Field(
        default=None,
        description="IAM service ID CRN",
        validation_alias="WATSON_IAM_SERVICEID_CRN",
    )
    trusted_profile_id: str | None = Field(
        default=None,
        description="Trusted profile ID",
        validation_alias="WATSON_TRUSTED_PROFILE_ID",
    )
    token: str | None = Field(
        default=None,
        description="Bearer token",
        validation_alias="WATSON_TOKEN",
    )
    projects_token: str | None = Field(
        default=None,
        description="Projects token",
        validation_alias="WATSON_PROJECTS_TOKEN",
    )
    username: str | None = Field(
        default=None,
        description="Username",
        validation_alias="WATSON_USERNAME",
    )
    password: str | None = Field(
        default=None,
        description="Password",
        validation_alias="WATSON_PASSWORD",
    )
    instance_id: str | None = Field(
        default=None,
        description="Service instance ID",
        validation_alias="WATSON_INSTANCE_ID",
    )
    version: str | None = Field(
        default=None,
        description="API version",
        validation_alias="WATSON_VERSION",
    )
    bedrock_url: str | None = Field(
        default=None,
        description="Bedrock URL",
        validation_alias="WATSON_BEDROCK_URL",
    )
    platform_url: str | None = Field(
        default=None,
        description="Platform URL",
        validation_alias="WATSON_PLATFORM_URL",
    )
    proxies: dict | None = Field(
        default=None,
        description="Proxy configuration",
    )

    @model_validator(mode="after")
    def validate_space_or_project(self) -> Self:
        """Validate that either space_id or project_id is provided.

        Returns:
            The validated provider instance.

        Raises:
            ValueError: If both ``space_id`` and ``project_id`` are empty or None.
        """
        if not self.space_id and not self.project_id:
            raise ValueError("One of 'space_id' or 'project_id' must be provided")
        return self
|
||||
15
src/crewai/rag/embeddings/providers/instructor/__init__.py
Normal file
15
src/crewai/rag/embeddings/providers/instructor/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""Instructor embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.instructor.instructor_provider import (
|
||||
InstructorProvider,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.instructor.types import (
|
||||
InstructorProviderConfig,
|
||||
InstructorProviderSpec,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"InstructorProvider",
|
||||
"InstructorProviderConfig",
|
||||
"InstructorProviderSpec",
|
||||
]
|
||||
@@ -0,0 +1,32 @@
|
||||
"""Instructor embeddings provider."""
|
||||
|
||||
from chromadb.utils.embedding_functions.instructor_embedding_function import (
|
||||
InstructorEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
|
||||
class InstructorProvider(BaseEmbeddingsProvider[InstructorEmbeddingFunction]):
    """Instructor embeddings provider.

    Declares the settings handed to ``InstructorEmbeddingFunction``. Each
    ``validation_alias`` presumably doubles as an environment-variable name;
    confirm against BaseEmbeddingsProvider, which is not shown here.
    """

    # Callable class this provider is bound to.
    embedding_callable: type[InstructorEmbeddingFunction] = Field(
        description="Instructor embedding function class",
        default=InstructorEmbeddingFunction,
    )
    model_name: str = Field(
        validation_alias="INSTRUCTOR_MODEL_NAME",
        description="Model name to use",
        default="hkunlp/instructor-base",
    )
    device: str = Field(
        validation_alias="INSTRUCTOR_DEVICE",
        description="Device to run model on (cpu or cuda)",
        default="cpu",
    )
    # Optional; no instruction is set unless one is supplied.
    instruction: str | None = Field(
        validation_alias="INSTRUCTOR_INSTRUCTION",
        description="Instruction for embeddings",
        default=None,
    )
|
||||
18
src/crewai/rag/embeddings/providers/instructor/types.py
Normal file
18
src/crewai/rag/embeddings/providers/instructor/types.py
Normal file
@@ -0,0 +1,18 @@
|
||||
"""Type definitions for Instructor embedding providers."""
|
||||
|
||||
from typing import Annotated, Literal, TypedDict
|
||||
|
||||
|
||||
class InstructorProviderConfig(TypedDict, total=False):
    """Configuration for Instructor provider.

    ``total=False`` makes every key optional. The ``Annotated`` metadata
    mirrors the defaults declared on ``InstructorProvider``.
    """

    # Metadata: "hkunlp/instructor-base".
    model_name: Annotated[str, "hkunlp/instructor-base"]
    # Metadata: "cpu".
    device: Annotated[str, "cpu"]
    instruction: str
|
||||
|
||||
|
||||
class InstructorProviderSpec(TypedDict):
    """Instructor provider specification.

    Pairs the literal provider tag with its configuration mapping; both keys
    are required.
    """

    # Discriminator: always the string "instructor".
    provider: Literal["instructor"]
    # Instructor-specific settings.
    config: InstructorProviderConfig
|
||||
13
src/crewai/rag/embeddings/providers/jina/__init__.py
Normal file
13
src/crewai/rag/embeddings/providers/jina/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""Jina embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.jina.jina_provider import JinaProvider
|
||||
from crewai.rag.embeddings.providers.jina.types import (
|
||||
JinaProviderConfig,
|
||||
JinaProviderSpec,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"JinaProvider",
|
||||
"JinaProviderConfig",
|
||||
"JinaProviderSpec",
|
||||
]
|
||||
22
src/crewai/rag/embeddings/providers/jina/jina_provider.py
Normal file
22
src/crewai/rag/embeddings/providers/jina/jina_provider.py
Normal file
@@ -0,0 +1,22 @@
|
||||
"""Jina embeddings provider."""
|
||||
|
||||
from chromadb.utils.embedding_functions.jina_embedding_function import (
|
||||
JinaEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
|
||||
class JinaProvider(BaseEmbeddingsProvider[JinaEmbeddingFunction]):
    """Jina embeddings provider.

    Declares the settings handed to ``JinaEmbeddingFunction``. Each
    ``validation_alias`` presumably doubles as an environment-variable name;
    confirm against BaseEmbeddingsProvider, which is not shown here.
    """

    # Callable class this provider is bound to.
    embedding_callable: type[JinaEmbeddingFunction] = Field(
        description="Jina embedding function class",
        default=JinaEmbeddingFunction,
    )
    # Required: no default is declared.
    api_key: str = Field(
        validation_alias="JINA_API_KEY",
        description="Jina API key",
    )
    model_name: str = Field(
        validation_alias="JINA_MODEL_NAME",
        description="Model name to use for embeddings",
        default="jina-embeddings-v2-base-en",
    )
|
||||
17
src/crewai/rag/embeddings/providers/jina/types.py
Normal file
17
src/crewai/rag/embeddings/providers/jina/types.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""Type definitions for Jina embedding providers."""
|
||||
|
||||
from typing import Annotated, Literal, TypedDict
|
||||
|
||||
|
||||
class JinaProviderConfig(TypedDict, total=False):
    """Configuration for Jina provider.

    ``total=False`` makes every key optional. The ``Annotated`` metadata
    mirrors the default declared on ``JinaProvider``.
    """

    api_key: str
    # Metadata: "jina-embeddings-v2-base-en".
    model_name: Annotated[str, "jina-embeddings-v2-base-en"]
|
||||
|
||||
|
||||
class JinaProviderSpec(TypedDict):
    """Jina provider specification.

    Pairs the literal provider tag with its configuration mapping; both keys
    are required.
    """

    # Discriminator: always the string "jina".
    provider: Literal["jina"]
    # Jina-specific settings.
    config: JinaProviderConfig
|
||||
15
src/crewai/rag/embeddings/providers/microsoft/__init__.py
Normal file
15
src/crewai/rag/embeddings/providers/microsoft/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""Microsoft embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.microsoft.azure import (
|
||||
AzureProvider,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.microsoft.types import (
|
||||
AzureProviderConfig,
|
||||
AzureProviderSpec,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AzureProvider",
|
||||
"AzureProviderConfig",
|
||||
"AzureProviderSpec",
|
||||
]
|
||||
58
src/crewai/rag/embeddings/providers/microsoft/azure.py
Normal file
58
src/crewai/rag/embeddings/providers/microsoft/azure.py
Normal file
@@ -0,0 +1,58 @@
|
||||
"""Azure OpenAI embeddings provider."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
||||
OpenAIEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
|
||||
class AzureProvider(BaseEmbeddingsProvider[OpenAIEmbeddingFunction]):
    """Azure OpenAI embeddings provider.

    Reuses ``OpenAIEmbeddingFunction`` with ``api_type`` defaulting to
    ``"azure"``. The aliases use the ``OPENAI_*`` names; they presumably double
    as environment-variable names (confirm against BaseEmbeddingsProvider,
    which is not shown here).
    """

    # Callable class this provider is bound to.
    embedding_callable: type[OpenAIEmbeddingFunction] = Field(
        description="Azure OpenAI embedding function class",
        default=OpenAIEmbeddingFunction,
    )
    # Required: no default is declared.
    api_key: str = Field(
        validation_alias="OPENAI_API_KEY",
        description="Azure API key",
    )
    api_base: str | None = Field(
        validation_alias="OPENAI_API_BASE",
        description="Azure endpoint URL",
        default=None,
    )
    api_type: str = Field(
        validation_alias="OPENAI_API_TYPE",
        description="API type for Azure",
        default="azure",
    )
    api_version: str | None = Field(
        validation_alias="OPENAI_API_VERSION",
        description="Azure API version",
        default=None,
    )
    model_name: str = Field(
        validation_alias="OPENAI_MODEL_NAME",
        description="Model name to use for embeddings",
        default="text-embedding-ada-002",
    )
    # No alias is declared for the headers mapping.
    default_headers: dict[str, Any] | None = Field(
        description="Default headers for API requests",
        default=None,
    )
    dimensions: int | None = Field(
        validation_alias="OPENAI_DIMENSIONS",
        description="Embedding dimensions",
        default=None,
    )
    deployment_id: str | None = Field(
        validation_alias="OPENAI_DEPLOYMENT_ID",
        description="Azure deployment ID",
        default=None,
    )
    organization_id: str | None = Field(
        validation_alias="OPENAI_ORGANIZATION_ID",
        description="Organization ID",
        default=None,
    )
|
||||
24
src/crewai/rag/embeddings/providers/microsoft/types.py
Normal file
24
src/crewai/rag/embeddings/providers/microsoft/types.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""Type definitions for Microsoft Azure embedding providers."""
|
||||
|
||||
from typing import Annotated, Any, Literal, TypedDict
|
||||
|
||||
|
||||
class AzureProviderConfig(TypedDict, total=False):
    """Configuration for Azure provider.

    ``total=False`` makes every key optional. The ``Annotated`` metadata
    mirrors the defaults declared on ``AzureProvider``.
    """

    api_key: str
    api_base: str
    # Metadata: "azure".
    api_type: Annotated[str, "azure"]
    api_version: str
    # Metadata: "text-embedding-ada-002".
    model_name: Annotated[str, "text-embedding-ada-002"]
    default_headers: dict[str, Any]
    dimensions: int
    deployment_id: str
    organization_id: str
|
||||
|
||||
|
||||
class AzureProviderSpec(TypedDict):
    """Azure provider specification.

    Pairs the literal provider tag with its configuration mapping; both keys
    are required.
    """

    # Discriminator: always the string "azure".
    provider: Literal["azure"]
    # Azure-specific settings.
    config: AzureProviderConfig
|
||||
15
src/crewai/rag/embeddings/providers/ollama/__init__.py
Normal file
15
src/crewai/rag/embeddings/providers/ollama/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""Ollama embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.ollama.ollama_provider import (
|
||||
OllamaProvider,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.ollama.types import (
|
||||
OllamaProviderConfig,
|
||||
OllamaProviderSpec,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"OllamaProvider",
|
||||
"OllamaProviderConfig",
|
||||
"OllamaProviderSpec",
|
||||
]
|
||||
@@ -0,0 +1,25 @@
|
||||
"""Ollama embeddings provider."""
|
||||
|
||||
from chromadb.utils.embedding_functions.ollama_embedding_function import (
|
||||
OllamaEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
|
||||
class OllamaProvider(BaseEmbeddingsProvider[OllamaEmbeddingFunction]):
    """Ollama embeddings provider.

    Declares the settings handed to ``OllamaEmbeddingFunction``. Each
    ``validation_alias`` presumably doubles as an environment-variable name;
    confirm against BaseEmbeddingsProvider, which is not shown here.
    """

    # Callable class this provider is bound to.
    embedding_callable: type[OllamaEmbeddingFunction] = Field(
        description="Ollama embedding function class",
        default=OllamaEmbeddingFunction,
    )
    url: str = Field(
        validation_alias="OLLAMA_URL",
        description="Ollama API endpoint URL",
        default="http://localhost:11434/api/embeddings",
    )
    # Required: no default is declared.
    model_name: str = Field(
        validation_alias="OLLAMA_MODEL_NAME",
        description="Model name to use for embeddings",
    )
|
||||
17
src/crewai/rag/embeddings/providers/ollama/types.py
Normal file
17
src/crewai/rag/embeddings/providers/ollama/types.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""Type definitions for Ollama embedding providers."""
|
||||
|
||||
from typing import Annotated, Literal, TypedDict
|
||||
|
||||
|
||||
class OllamaProviderConfig(TypedDict, total=False):
    """Configuration for Ollama provider.

    ``total=False`` makes every key optional. The ``Annotated`` metadata
    mirrors the default declared on ``OllamaProvider``.
    """

    # Metadata: "http://localhost:11434/api/embeddings".
    url: Annotated[str, "http://localhost:11434/api/embeddings"]
    model_name: str
|
||||
|
||||
|
||||
class OllamaProviderSpec(TypedDict):
    """Ollama provider specification.

    Pairs the literal provider tag with its configuration mapping; both keys
    are required.
    """

    # Discriminator: always the string "ollama".
    provider: Literal["ollama"]
    # Ollama-specific settings.
    config: OllamaProviderConfig
|
||||
13
src/crewai/rag/embeddings/providers/onnx/__init__.py
Normal file
13
src/crewai/rag/embeddings/providers/onnx/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""ONNX embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.onnx.onnx_provider import ONNXProvider
|
||||
from crewai.rag.embeddings.providers.onnx.types import (
|
||||
ONNXProviderConfig,
|
||||
ONNXProviderSpec,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ONNXProvider",
|
||||
"ONNXProviderConfig",
|
||||
"ONNXProviderSpec",
|
||||
]
|
||||
19
src/crewai/rag/embeddings/providers/onnx/onnx_provider.py
Normal file
19
src/crewai/rag/embeddings/providers/onnx/onnx_provider.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""ONNX embeddings provider."""
|
||||
|
||||
from chromadb.utils.embedding_functions.onnx_mini_lm_l6_v2 import ONNXMiniLM_L6_V2
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
|
||||
class ONNXProvider(BaseEmbeddingsProvider[ONNXMiniLM_L6_V2]):
    """ONNX embeddings provider.

    Declares the settings handed to ``ONNXMiniLM_L6_V2``. The
    ``validation_alias`` presumably doubles as an environment-variable name;
    confirm against BaseEmbeddingsProvider, which is not shown here.
    """

    # Callable class this provider is bound to.
    embedding_callable: type[ONNXMiniLM_L6_V2] = Field(
        description="ONNX MiniLM embedding function class",
        default=ONNXMiniLM_L6_V2,
    )
    # Optional; defaults to None.
    preferred_providers: list[str] | None = Field(
        validation_alias="ONNX_PREFERRED_PROVIDERS",
        description="Preferred ONNX execution providers",
        default=None,
    )
|
||||
16
src/crewai/rag/embeddings/providers/onnx/types.py
Normal file
16
src/crewai/rag/embeddings/providers/onnx/types.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""Type definitions for ONNX embedding providers."""
|
||||
|
||||
from typing import Literal, TypedDict
|
||||
|
||||
|
||||
class ONNXProviderConfig(TypedDict, total=False):
    """Configuration for ONNX provider.

    ``total=False`` makes the single key optional.
    """

    # Names of the preferred ONNX execution providers.
    preferred_providers: list[str]
|
||||
|
||||
|
||||
class ONNXProviderSpec(TypedDict):
    """ONNX provider specification.

    Pairs the literal provider tag with its configuration mapping; both keys
    are required.
    """

    # Discriminator: always the string "onnx".
    provider: Literal["onnx"]
    # ONNX-specific settings.
    config: ONNXProviderConfig
|
||||
15
src/crewai/rag/embeddings/providers/openai/__init__.py
Normal file
15
src/crewai/rag/embeddings/providers/openai/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""OpenAI embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.openai.openai_provider import (
|
||||
OpenAIProvider,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.openai.types import (
|
||||
OpenAIProviderConfig,
|
||||
OpenAIProviderSpec,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"OpenAIProvider",
|
||||
"OpenAIProviderConfig",
|
||||
"OpenAIProviderSpec",
|
||||
]
|
||||
@@ -0,0 +1,58 @@
|
||||
"""OpenAI embeddings provider."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
||||
OpenAIEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
|
||||
class OpenAIProvider(BaseEmbeddingsProvider[OpenAIEmbeddingFunction]):
    """OpenAI embeddings provider.

    Declares the settings handed to ``OpenAIEmbeddingFunction``. Each
    ``validation_alias`` presumably doubles as an environment-variable name;
    confirm against BaseEmbeddingsProvider, which is not shown here.
    """

    # Callable class this provider is bound to.
    embedding_callable: type[OpenAIEmbeddingFunction] = Field(
        description="OpenAI embedding function class",
        default=OpenAIEmbeddingFunction,
    )
    # Required: no default is declared.
    api_key: str = Field(
        validation_alias="OPENAI_API_KEY",
        description="OpenAI API key",
    )
    model_name: str = Field(
        validation_alias="OPENAI_MODEL_NAME",
        description="Model name to use for embeddings",
        default="text-embedding-ada-002",
    )
    api_base: str | None = Field(
        validation_alias="OPENAI_API_BASE",
        description="Base URL for API requests",
        default=None,
    )
    api_type: str | None = Field(
        validation_alias="OPENAI_API_TYPE",
        description="API type (e.g., 'azure')",
        default=None,
    )
    api_version: str | None = Field(
        validation_alias="OPENAI_API_VERSION",
        description="API version",
        default=None,
    )
    # No alias is declared for the headers mapping.
    default_headers: dict[str, Any] | None = Field(
        description="Default headers for API requests",
        default=None,
    )
    dimensions: int | None = Field(
        validation_alias="OPENAI_DIMENSIONS",
        description="Embedding dimensions",
        default=None,
    )
    deployment_id: str | None = Field(
        validation_alias="OPENAI_DEPLOYMENT_ID",
        description="Azure deployment ID",
        default=None,
    )
    organization_id: str | None = Field(
        validation_alias="OPENAI_ORGANIZATION_ID",
        description="OpenAI organization ID",
        default=None,
    )
|
||||
24
src/crewai/rag/embeddings/providers/openai/types.py
Normal file
24
src/crewai/rag/embeddings/providers/openai/types.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""Type definitions for OpenAI embedding providers."""
|
||||
|
||||
from typing import Annotated, Any, Literal, TypedDict
|
||||
|
||||
|
||||
class OpenAIProviderConfig(TypedDict, total=False):
    """Configuration for OpenAI provider.

    ``total=False`` makes every key optional. The ``Annotated`` metadata
    mirrors the default declared on ``OpenAIProvider``.
    """

    api_key: str
    # Metadata: "text-embedding-ada-002".
    model_name: Annotated[str, "text-embedding-ada-002"]
    api_base: str
    api_type: str
    api_version: str
    default_headers: dict[str, Any]
    dimensions: int
    deployment_id: str
    organization_id: str
|
||||
|
||||
|
||||
class OpenAIProviderSpec(TypedDict):
    """OpenAI provider specification.

    Pairs the literal provider tag with its configuration mapping; both keys
    are required.
    """

    # Discriminator: always the string "openai".
    provider: Literal["openai"]
    # OpenAI-specific settings.
    config: OpenAIProviderConfig
|
||||
15
src/crewai/rag/embeddings/providers/openclip/__init__.py
Normal file
15
src/crewai/rag/embeddings/providers/openclip/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""OpenCLIP embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.openclip.openclip_provider import (
|
||||
OpenCLIPProvider,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.openclip.types import (
|
||||
OpenCLIPProviderConfig,
|
||||
OpenCLIPProviderSpec,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"OpenCLIPProvider",
|
||||
"OpenCLIPProviderConfig",
|
||||
"OpenCLIPProviderSpec",
|
||||
]
|
||||
@@ -0,0 +1,32 @@
|
||||
"""OpenCLIP embeddings provider."""
|
||||
|
||||
from chromadb.utils.embedding_functions.open_clip_embedding_function import (
|
||||
OpenCLIPEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
|
||||
class OpenCLIPProvider(BaseEmbeddingsProvider[OpenCLIPEmbeddingFunction]):
    """OpenCLIP embeddings provider.

    Declares the settings handed to ``OpenCLIPEmbeddingFunction``. Each
    ``validation_alias`` presumably doubles as an environment-variable name;
    confirm against BaseEmbeddingsProvider, which is not shown here.
    """

    # Callable class this provider is bound to.
    embedding_callable: type[OpenCLIPEmbeddingFunction] = Field(
        description="OpenCLIP embedding function class",
        default=OpenCLIPEmbeddingFunction,
    )
    model_name: str = Field(
        validation_alias="OPENCLIP_MODEL_NAME",
        description="Model name to use",
        default="ViT-B-32",
    )
    checkpoint: str = Field(
        validation_alias="OPENCLIP_CHECKPOINT",
        description="Model checkpoint",
        default="laion2b_s34b_b79k",
    )
    # None is accepted by the annotation, but the default is "cpu".
    device: str | None = Field(
        validation_alias="OPENCLIP_DEVICE",
        description="Device to run model on",
        default="cpu",
    )
|
||||
18
src/crewai/rag/embeddings/providers/openclip/types.py
Normal file
18
src/crewai/rag/embeddings/providers/openclip/types.py
Normal file
@@ -0,0 +1,18 @@
|
||||
"""Type definitions for OpenCLIP embedding providers."""
|
||||
|
||||
from typing import Annotated, Literal, TypedDict
|
||||
|
||||
|
||||
class OpenCLIPProviderConfig(TypedDict, total=False):
    """Configuration for OpenCLIP provider.

    ``total=False`` makes every key optional. The ``Annotated`` metadata
    mirrors the defaults declared on ``OpenCLIPProvider``.
    """

    # Metadata: "ViT-B-32".
    model_name: Annotated[str, "ViT-B-32"]
    # Metadata: "laion2b_s34b_b79k".
    checkpoint: Annotated[str, "laion2b_s34b_b79k"]
    # Metadata: "cpu".
    device: Annotated[str, "cpu"]
|
||||
|
||||
|
||||
class OpenCLIPProviderSpec(TypedDict):
    """OpenCLIP provider specification.

    Pairs the literal provider tag with its configuration mapping; both keys
    are required.
    """

    # Discriminator: always the string "openclip".
    provider: Literal["openclip"]
    # OpenCLIP-specific settings.
    config: OpenCLIPProviderConfig
|
||||
15
src/crewai/rag/embeddings/providers/roboflow/__init__.py
Normal file
15
src/crewai/rag/embeddings/providers/roboflow/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""Roboflow embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.roboflow.roboflow_provider import (
|
||||
RoboflowProvider,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.roboflow.types import (
|
||||
RoboflowProviderConfig,
|
||||
RoboflowProviderSpec,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"RoboflowProvider",
|
||||
"RoboflowProviderConfig",
|
||||
"RoboflowProviderSpec",
|
||||
]
|
||||
@@ -0,0 +1,25 @@
|
||||
"""Roboflow embeddings provider."""
|
||||
|
||||
from chromadb.utils.embedding_functions.roboflow_embedding_function import (
|
||||
RoboflowEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
|
||||
class RoboflowProvider(BaseEmbeddingsProvider[RoboflowEmbeddingFunction]):
    """Roboflow embeddings provider.

    Declares the settings handed to ``RoboflowEmbeddingFunction``. Each
    ``validation_alias`` presumably doubles as an environment-variable name;
    confirm against BaseEmbeddingsProvider, which is not shown here.
    """

    # Callable class this provider is bound to.
    embedding_callable: type[RoboflowEmbeddingFunction] = Field(
        description="Roboflow embedding function class",
        default=RoboflowEmbeddingFunction,
    )
    # Defaults to an empty string rather than being required.
    api_key: str = Field(
        validation_alias="ROBOFLOW_API_KEY",
        description="Roboflow API key",
        default="",
    )
    api_url: str = Field(
        validation_alias="ROBOFLOW_API_URL",
        description="Roboflow API URL",
        default="https://infer.roboflow.com",
    )
|
||||
17
src/crewai/rag/embeddings/providers/roboflow/types.py
Normal file
17
src/crewai/rag/embeddings/providers/roboflow/types.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""Type definitions for Roboflow embedding providers."""
|
||||
|
||||
from typing import Annotated, Literal, TypedDict
|
||||
|
||||
|
||||
class RoboflowProviderConfig(TypedDict, total=False):
    """Configuration for Roboflow provider.

    ``total=False`` makes every key optional. The ``Annotated`` metadata
    mirrors the defaults declared on ``RoboflowProvider``.
    """

    # Metadata: the empty string.
    api_key: Annotated[str, ""]
    # Metadata: "https://infer.roboflow.com".
    api_url: Annotated[str, "https://infer.roboflow.com"]
|
||||
|
||||
|
||||
class RoboflowProviderSpec(TypedDict):
    """Dictionary shape that selects and configures the Roboflow provider.

    Both keys are required because the TypedDict is total.
    """

    # Discriminator: always the literal string "roboflow".
    provider: Literal["roboflow"]
    # Provider-specific settings; every key inside is optional.
    config: RoboflowProviderConfig
|
||||
@@ -0,0 +1,15 @@
|
||||
"""SentenceTransformer embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.sentence_transformer.sentence_transformer_provider import (
|
||||
SentenceTransformerProvider,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.sentence_transformer.types import (
|
||||
SentenceTransformerProviderConfig,
|
||||
SentenceTransformerProviderSpec,
|
||||
)
|
||||
|
||||
# Names re-exported as the public API of the sentence_transformer provider package.
__all__ = [
    "SentenceTransformerProvider",
    "SentenceTransformerProviderConfig",
    "SentenceTransformerProviderSpec",
]
|
||||
@@ -0,0 +1,34 @@
|
||||
"""SentenceTransformer embeddings provider."""
|
||||
|
||||
from chromadb.utils.embedding_functions.sentence_transformer_embedding_function import (
|
||||
SentenceTransformerEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
|
||||
class SentenceTransformerProvider(
    BaseEmbeddingsProvider[SentenceTransformerEmbeddingFunction]
):
    """Settings model describing how to build a SentenceTransformer embedder.

    Every configurable field declares a ``SENTENCE_TRANSFORMER_*``
    ``validation_alias``.
    NOTE(review): the aliases are presumably resolved from environment
    variables by ``BaseEmbeddingsProvider`` -- confirm in the base class.
    """

    # The embedding function class associated with this provider.
    embedding_callable: type[SentenceTransformerEmbeddingFunction] = Field(
        description="SentenceTransformer embedding function class",
        default=SentenceTransformerEmbeddingFunction,
    )
    model_name: str = Field(
        validation_alias="SENTENCE_TRANSFORMER_MODEL_NAME",
        description="Model name to use",
        default="all-MiniLM-L6-v2",
    )
    device: str = Field(
        validation_alias="SENTENCE_TRANSFORMER_DEVICE",
        description="Device to run model on (cpu or cuda)",
        default="cpu",
    )
    normalize_embeddings: bool = Field(
        validation_alias="SENTENCE_TRANSFORMER_NORMALIZE_EMBEDDINGS",
        description="Whether to normalize embeddings",
        default=False,
    )
|
||||
@@ -0,0 +1,18 @@
|
||||
"""Type definitions for SentenceTransformer embedding providers."""
|
||||
|
||||
from typing import Annotated, Literal, TypedDict
|
||||
|
||||
|
||||
class SentenceTransformerProviderConfig(TypedDict, total=False):
    """Optional settings accepted under ``config`` for SentenceTransformer.

    ``total=False`` makes every key optional. The second ``Annotated``
    argument on each field is metadata recording that field's default value.
    """

    # Name of the SentenceTransformer model to load.
    model_name: Annotated[str, "all-MiniLM-L6-v2"]
    # Device identifier the model runs on; recorded default is "cpu".
    device: Annotated[str, "cpu"]
    # Whether embeddings are normalized; recorded default is False.
    normalize_embeddings: Annotated[bool, False]
|
||||
|
||||
|
||||
class SentenceTransformerProviderSpec(TypedDict):
    """Dictionary shape that selects and configures SentenceTransformer.

    Both keys are required because the TypedDict is total.
    """

    # Discriminator: always the literal string "sentence-transformer".
    provider: Literal["sentence-transformer"]
    # Provider-specific settings; every key inside is optional.
    config: SentenceTransformerProviderConfig
|
||||
15
src/crewai/rag/embeddings/providers/text2vec/__init__.py
Normal file
15
src/crewai/rag/embeddings/providers/text2vec/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""Text2Vec embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.text2vec.text2vec_provider import (
|
||||
Text2VecProvider,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.text2vec.types import (
|
||||
Text2VecProviderConfig,
|
||||
Text2VecProviderSpec,
|
||||
)
|
||||
|
||||
# Names re-exported as the public API of the text2vec provider package.
__all__ = [
    "Text2VecProvider",
    "Text2VecProviderConfig",
    "Text2VecProviderSpec",
]
|
||||
@@ -0,0 +1,22 @@
|
||||
"""Text2Vec embeddings provider."""
|
||||
|
||||
from chromadb.utils.embedding_functions.text2vec_embedding_function import (
|
||||
Text2VecEmbeddingFunction,
|
||||
)
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
|
||||
class Text2VecProvider(BaseEmbeddingsProvider[Text2VecEmbeddingFunction]):
    """Settings model describing how to build a Text2Vec embedding function.

    ``model_name`` declares the ``TEXT2VEC_MODEL_NAME`` ``validation_alias``.
    NOTE(review): the alias is presumably resolved from an environment
    variable by ``BaseEmbeddingsProvider`` -- confirm in the base class.
    """

    # The embedding function class associated with this provider.
    embedding_callable: type[Text2VecEmbeddingFunction] = Field(
        description="Text2Vec embedding function class",
        default=Text2VecEmbeddingFunction,
    )
    model_name: str = Field(
        validation_alias="TEXT2VEC_MODEL_NAME",
        description="Model name to use",
        default="shibing624/text2vec-base-chinese",
    )
|
||||
16
src/crewai/rag/embeddings/providers/text2vec/types.py
Normal file
16
src/crewai/rag/embeddings/providers/text2vec/types.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""Type definitions for Text2Vec embedding providers."""
|
||||
|
||||
from typing import Annotated, Literal, TypedDict
|
||||
|
||||
|
||||
class Text2VecProviderConfig(TypedDict, total=False):
    """Optional settings accepted under ``config`` for the Text2Vec provider.

    ``total=False`` makes the key optional. The second ``Annotated`` argument
    is metadata recording the field's default value.
    """

    # Name of the Text2Vec model to load.
    model_name: Annotated[str, "shibing624/text2vec-base-chinese"]
|
||||
|
||||
|
||||
class Text2VecProviderSpec(TypedDict):
    """Dictionary shape that selects and configures the Text2Vec provider.

    Both keys are required because the TypedDict is total.
    """

    # Discriminator: always the literal string "text2vec".
    provider: Literal["text2vec"]
    # Provider-specific settings; every key inside is optional.
    config: Text2VecProviderConfig
|
||||
15
src/crewai/rag/embeddings/providers/voyageai/__init__.py
Normal file
15
src/crewai/rag/embeddings/providers/voyageai/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""VoyageAI embedding providers."""
|
||||
|
||||
from crewai.rag.embeddings.providers.voyageai.types import (
|
||||
VoyageAIProviderConfig,
|
||||
VoyageAIProviderSpec,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.voyageai.voyageai_provider import (
|
||||
VoyageAIProvider,
|
||||
)
|
||||
|
||||
# Names re-exported as the public API of the voyageai provider package.
__all__ = [
    "VoyageAIProvider",
    "VoyageAIProviderConfig",
    "VoyageAIProviderSpec",
]
|
||||
@@ -0,0 +1,50 @@
|
||||
"""VoyageAI embedding function implementation."""
|
||||
|
||||
from typing import cast
|
||||
|
||||
import voyageai
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from crewai.rag.core.base_embeddings_callable import EmbeddingFunction
|
||||
from crewai.rag.core.types import Documents, Embeddings
|
||||
from crewai.rag.embeddings.providers.voyageai.types import VoyageAIProviderConfig
|
||||
|
||||
|
||||
class VoyageAIEmbeddingFunction(EmbeddingFunction[Documents]):
    """Callable that turns documents into embeddings via a VoyageAI client."""

    def __init__(self, **kwargs: Unpack[VoyageAIProviderConfig]) -> None:
        """Keep the configuration and create the VoyageAI client.

        Args:
            **kwargs: VoyageAI settings. ``api_key`` is read by direct
                lookup and must be present; ``max_retries`` falls back to
                ``0`` and ``timeout`` to ``None`` when omitted.

        Raises:
            KeyError: If ``api_key`` is not supplied.
        """
        # The full mapping is retained: __call__ reads the per-request
        # options (model, input_type, ...) from it later.
        self._config = kwargs
        client_options = {
            "api_key": kwargs["api_key"],
            "max_retries": kwargs.get("max_retries", 0),
            "timeout": kwargs.get("timeout"),
        }
        self._client = voyageai.Client(**client_options)

    def __call__(self, input: Documents) -> Embeddings:
        """Embed the given documents.

        Args:
            input: Documents to embed. A bare string is treated as a
                single-document list.

        Returns:
            The ``embeddings`` attribute of the client's response.
        """
        texts = [input] if isinstance(input, str) else input
        settings = self._config

        response = self._client.embed(
            texts=texts,
            model=settings.get("model", "voyage-2"),
            input_type=settings.get("input_type"),
            truncation=settings.get("truncation", True),
            output_dtype=settings.get("output_dtype"),
            output_dimension=settings.get("output_dimension"),
        )
        return cast(Embeddings, response.embeddings)
|
||||
23
src/crewai/rag/embeddings/providers/voyageai/types.py
Normal file
23
src/crewai/rag/embeddings/providers/voyageai/types.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""Type definitions for VoyageAI embedding providers."""
|
||||
|
||||
from typing import Annotated, Literal, TypedDict
|
||||
|
||||
|
||||
class VoyageAIProviderConfig(TypedDict, total=False):
    """Optional settings accepted under ``config`` for the VoyageAI provider.

    ``total=False`` makes every key optional at the type level. Where a field
    is ``Annotated``, the second argument is metadata recording its default.
    """

    # VoyageAI API key.
    api_key: str
    # Embedding model name; recorded default is "voyage-2".
    model: Annotated[str, "voyage-2"]
    input_type: str
    # Whether inputs are truncated; recorded default is True.
    truncation: Annotated[bool, True]
    output_dtype: str
    output_dimension: int
    # Client retry count; recorded default is 0.
    max_retries: Annotated[int, 0]
    timeout: float
|
||||
|
||||
|
||||
class VoyageAIProviderSpec(TypedDict):
    """Dictionary shape that selects and configures the VoyageAI provider.

    Both keys are required because the TypedDict is total.
    """

    # Discriminator: always the literal string "voyageai".
    provider: Literal["voyageai"]
    # Provider-specific settings; every key inside is optional.
    config: VoyageAIProviderConfig
|
||||
@@ -0,0 +1,55 @@
|
||||
"""Voyage AI embeddings provider."""
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
from crewai.rag.embeddings.providers.voyageai.embedding_callable import (
|
||||
VoyageAIEmbeddingFunction,
|
||||
)
|
||||
|
||||
|
||||
class VoyageAIProvider(BaseEmbeddingsProvider[VoyageAIEmbeddingFunction]):
    """Settings model describing how to build a Voyage AI embedding function.

    Every configurable field declares a ``VOYAGEAI_*`` ``validation_alias``.
    ``api_key`` is the only field without a default, so it must be supplied.
    NOTE(review): the aliases are presumably resolved from environment
    variables by ``BaseEmbeddingsProvider`` -- confirm in the base class.
    """

    # The embedding function class associated with this provider.
    embedding_callable: type[VoyageAIEmbeddingFunction] = Field(
        description="Voyage AI embedding function class",
        default=VoyageAIEmbeddingFunction,
    )
    model: str = Field(
        validation_alias="VOYAGEAI_MODEL",
        description="Model to use for embeddings",
        default="voyage-2",
    )
    # Required: no default is declared for the key.
    api_key: str = Field(
        validation_alias="VOYAGEAI_API_KEY",
        description="Voyage AI API key",
    )
    input_type: str | None = Field(
        validation_alias="VOYAGEAI_INPUT_TYPE",
        description="Input type for embeddings",
        default=None,
    )
    truncation: bool = Field(
        validation_alias="VOYAGEAI_TRUNCATION",
        description="Whether to truncate inputs",
        default=True,
    )
    output_dtype: str | None = Field(
        validation_alias="VOYAGEAI_OUTPUT_DTYPE",
        description="Output data type",
        default=None,
    )
    output_dimension: int | None = Field(
        validation_alias="VOYAGEAI_OUTPUT_DIMENSION",
        description="Output dimension",
        default=None,
    )
    max_retries: int = Field(
        validation_alias="VOYAGEAI_MAX_RETRIES",
        description="Maximum retries for API calls",
        default=0,
    )
    timeout: float | None = Field(
        validation_alias="VOYAGEAI_TIMEOUT",
        description="Timeout for API calls",
        default=None,
    )
|
||||
@@ -2,61 +2,67 @@
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field, SecretStr
|
||||
from crewai.rag.embeddings.providers.aws.types import BedrockProviderSpec
|
||||
from crewai.rag.embeddings.providers.cohere.types import CohereProviderSpec
|
||||
from crewai.rag.embeddings.providers.custom.types import CustomProviderSpec
|
||||
from crewai.rag.embeddings.providers.google.types import (
|
||||
GenerativeAiProviderSpec,
|
||||
VertexAIProviderSpec,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.huggingface.types import HuggingFaceProviderSpec
|
||||
from crewai.rag.embeddings.providers.ibm.types import WatsonProviderSpec
|
||||
from crewai.rag.embeddings.providers.instructor.types import InstructorProviderSpec
|
||||
from crewai.rag.embeddings.providers.jina.types import JinaProviderSpec
|
||||
from crewai.rag.embeddings.providers.microsoft.types import AzureProviderSpec
|
||||
from crewai.rag.embeddings.providers.ollama.types import OllamaProviderSpec
|
||||
from crewai.rag.embeddings.providers.onnx.types import ONNXProviderSpec
|
||||
from crewai.rag.embeddings.providers.openai.types import OpenAIProviderSpec
|
||||
from crewai.rag.embeddings.providers.openclip.types import OpenCLIPProviderSpec
|
||||
from crewai.rag.embeddings.providers.roboflow.types import RoboflowProviderSpec
|
||||
from crewai.rag.embeddings.providers.sentence_transformer.types import (
|
||||
SentenceTransformerProviderSpec,
|
||||
)
|
||||
from crewai.rag.embeddings.providers.text2vec.types import Text2VecProviderSpec
|
||||
from crewai.rag.embeddings.providers.voyageai.types import VoyageAIProviderSpec
|
||||
|
||||
from crewai.rag.types import EmbeddingFunction
|
||||
# Union of every provider-specific spec TypedDict; a value of this type is a
# {"provider": ..., "config": ...} dictionary for exactly one provider.
ProviderSpec = (
    AzureProviderSpec
    | BedrockProviderSpec
    | CohereProviderSpec
    | CustomProviderSpec
    | GenerativeAiProviderSpec
    | HuggingFaceProviderSpec
    | InstructorProviderSpec
    | JinaProviderSpec
    | OllamaProviderSpec
    | ONNXProviderSpec
    | OpenAIProviderSpec
    | OpenCLIPProviderSpec
    | RoboflowProviderSpec
    | SentenceTransformerProviderSpec
    | Text2VecProviderSpec
    | VertexAIProviderSpec
    | VoyageAIProviderSpec
    | WatsonProviderSpec
)
|
||||
|
||||
EmbeddingProvider = Literal[
|
||||
"openai",
|
||||
AllowedEmbeddingProviders = Literal[
|
||||
"azure",
|
||||
"amazon-bedrock",
|
||||
"cohere",
|
||||
"ollama",
|
||||
"huggingface",
|
||||
"sentence-transformer",
|
||||
"instructor",
|
||||
"google-palm",
|
||||
"custom",
|
||||
"google-generativeai",
|
||||
"google-vertex",
|
||||
"amazon-bedrock",
|
||||
"huggingface",
|
||||
"instructor",
|
||||
"jina",
|
||||
"roboflow",
|
||||
"openclip",
|
||||
"text2vec",
|
||||
"ollama",
|
||||
"onnx",
|
||||
"openai",
|
||||
"openclip",
|
||||
"roboflow",
|
||||
"sentence-transformer",
|
||||
"text2vec",
|
||||
"voyageai",
|
||||
"watson",
|
||||
]
|
||||
"""Supported embedding providers.
|
||||
|
||||
These correspond to the embedding functions available in ChromaDB's
|
||||
embedding_functions module. Each provider has specific requirements
|
||||
and configuration options.
|
||||
"""
|
||||
|
||||
|
||||
class EmbeddingOptions(BaseModel):
    """Generic options used to configure an embedding provider.

    Only ``provider`` is required; ``model_name`` and ``api_key`` default to
    ``None``.
    """

    # Required field: Ellipsis as the default marks it as mandatory.
    provider: EmbeddingProvider = Field(
        default=...,
        description="Embedding provider name (e.g., 'openai', 'cohere', 'onnx')",
    )
    model_name: str | None = Field(
        description="Model name for the embedding provider",
        default=None,
    )
    # Held as SecretStr rather than a plain str.
    api_key: SecretStr | None = Field(
        description="API key for the embedding provider",
        default=None,
    )
|
||||
|
||||
|
||||
class EmbeddingConfig(BaseModel):
    """Wrapper holding either a ready embedding function or its options.

    Attributes:
        function: A callable ``EmbeddingFunction``, or the
            ``EmbeddingOptions`` from which one can be created.
    """

    function: EmbeddingFunction | EmbeddingOptions
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from crewai.rag.embeddings.factory import EmbedderConfig
|
||||
from crewai.rag.embeddings.types import EmbeddingOptions
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
from crewai.rag.embeddings.types import ProviderSpec
|
||||
|
||||
|
||||
class BaseRAGStorage(ABC):
|
||||
@@ -16,7 +16,7 @@ class BaseRAGStorage(ABC):
|
||||
self,
|
||||
type: str,
|
||||
allow_reset: bool = True,
|
||||
embedder_config: EmbeddingOptions | EmbedderConfig | None = None,
|
||||
embedder_config: ProviderSpec | BaseEmbeddingsProvider | None = None,
|
||||
crew: Any = None,
|
||||
):
|
||||
self.type = type
|
||||
|
||||
@@ -24,8 +24,7 @@ class BaseRecord(TypedDict, total=False):
|
||||
)
|
||||
|
||||
|
||||
DenseVector: TypeAlias = list[float]
|
||||
IntVector: TypeAlias = list[int]
|
||||
Embeddings: TypeAlias = list[list[float]]
|
||||
|
||||
EmbeddingFunction: TypeAlias = Callable[..., Any]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user