import os

from typing import Any, Dict, cast

from chromadb import Documents, EmbeddingFunction, Embeddings
from chromadb.api.types import validate_embedding_function


class EmbeddingConfigurator:
    """Builds ChromaDB embedding functions from a provider-keyed config."""

    def __init__(self):
        self.embedding_functions = {
            "openai": self._configure_openai,
            "azure": self._configure_azure,
            "ollama": self._configure_ollama,
            "vertexai": self._configure_vertexai,
            "google": self._configure_google,
            "cohere": self._configure_cohere,
            "bedrock": self._configure_bedrock,
            "huggingface": self._configure_huggingface,
            "watson": self._configure_watson,
        }

    def configure_embedder(
        self,
        embedder_config: Dict[str, Any] | None = None,
    ) -> EmbeddingFunction:
        """Configures and returns an embedding function based on the provided config."""
        if embedder_config is None:
            return self._create_default_embedding_function()

        provider = embedder_config.get("provider")
        config = embedder_config.get("config", {})
        model_name = config.get("model")

        if isinstance(provider, EmbeddingFunction):
            try:
                validate_embedding_function(provider)
                return provider
            except Exception as e:
                raise ValueError(f"Invalid custom embedding function: {str(e)}")

        if provider not in self.embedding_functions:
            raise Exception(
                f"Unsupported embedding provider: {provider}, supported providers: {list(self.embedding_functions.keys())}"
            )

        return self.embedding_functions[provider](config, model_name)

    @staticmethod
    def _create_default_embedding_function():
        from chromadb.utils.embedding_functions.openai_embedding_function import (
            OpenAIEmbeddingFunction,
        )

        return OpenAIEmbeddingFunction(
            api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
        )

    @staticmethod
    def _configure_openai(config, model_name):
        from chromadb.utils.embedding_functions.openai_embedding_function import (
            OpenAIEmbeddingFunction,
        )

        return OpenAIEmbeddingFunction(
            api_key=config.get("api_key") or os.getenv("OPENAI_API_KEY"),
            model_name=model_name,
        )

    @staticmethod
    def _configure_azure(config, model_name):
        from chromadb.utils.embedding_functions.openai_embedding_function import (
            OpenAIEmbeddingFunction,
        )

        return OpenAIEmbeddingFunction(
            api_key=config.get("api_key"),
            api_base=config.get("api_base"),
            api_type=config.get("api_type", "azure"),
            api_version=config.get("api_version"),
            model_name=model_name,
        )

    @staticmethod
    def _configure_ollama(config, model_name):
        from chromadb.utils.embedding_functions.ollama_embedding_function import (
            OllamaEmbeddingFunction,
        )

        return OllamaEmbeddingFunction(
            url=config.get("url", "http://localhost:11434/api/embeddings"),
            model_name=model_name,
        )

    @staticmethod
    def _configure_vertexai(config, model_name):
        from chromadb.utils.embedding_functions.google_embedding_function import (
            GoogleVertexEmbeddingFunction,
        )

        return GoogleVertexEmbeddingFunction(
            model_name=model_name,
            api_key=config.get("api_key"),
        )

    @staticmethod
    def _configure_google(config, model_name):
        from chromadb.utils.embedding_functions.google_embedding_function import (
            GoogleGenerativeAiEmbeddingFunction,
        )

        return GoogleGenerativeAiEmbeddingFunction(
            model_name=model_name,
            api_key=config.get("api_key"),
        )

    @staticmethod
    def _configure_cohere(config, model_name):
        from chromadb.utils.embedding_functions.cohere_embedding_function import (
            CohereEmbeddingFunction,
        )

        return CohereEmbeddingFunction(
            model_name=model_name,
            api_key=config.get("api_key"),
        )

    @staticmethod
    def _configure_bedrock(config, model_name):
        from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
            AmazonBedrockEmbeddingFunction,
        )

        return AmazonBedrockEmbeddingFunction(
            session=config.get("session"),
        )

    @staticmethod
    def _configure_huggingface(config, model_name):
        from chromadb.utils.embedding_functions.huggingface_embedding_function import (
            HuggingFaceEmbeddingServer,
        )

        return HuggingFaceEmbeddingServer(
            url=config.get("api_url"),
        )

    @staticmethod
    def _configure_watson(config, model_name):
        try:
            import ibm_watsonx_ai.foundation_models as watson_models
            from ibm_watsonx_ai import Credentials
            from ibm_watsonx_ai.metanames import EmbedTextParamsMetaNames as EmbedParams
        except ImportError as e:
            raise ImportError(
                "IBM Watson dependencies are not installed. Please install them to use Watson embedding."
            ) from e

        class WatsonEmbeddingFunction(EmbeddingFunction):
            def __call__(self, input: Documents) -> Embeddings:
                if isinstance(input, str):
                    input = [input]

                embed_params = {
                    EmbedParams.TRUNCATE_INPUT_TOKENS: 3,
                    EmbedParams.RETURN_OPTIONS: {"input_text": True},
                }

                embedding = watson_models.Embeddings(
                    model_id=config.get("model"),
                    params=embed_params,
                    credentials=Credentials(
                        api_key=config.get("api_key"), url=config.get("api_url")
                    ),
                    project_id=config.get("project_id"),
                )

                try:
                    embeddings = embedding.embed_documents(input)
                    return cast(Embeddings, embeddings)
                except Exception as e:
                    print("Error during Watson embedding:", e)
                    raise e

        return WatsonEmbeddingFunction()
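

# Illustrative usage sketch (not part of the original module). It shows the two
# paths configure_embedder supports: a provider-keyed dict routed through the
# matching _configure_* helper, and a custom EmbeddingFunction instance that is
# validated and returned as-is. The Ollama model name and URL below are assumed
# example values, not defaults required by crewAI.
if __name__ == "__main__":
    configurator = EmbeddingConfigurator()

    # 1) Provider-keyed config: "provider" selects the helper and "config" is
    #    passed through to it (here, Ollama's local REST embedding endpoint).
    ollama_embedder = configurator.configure_embedder(
        {
            "provider": "ollama",
            "config": {
                "model": "nomic-embed-text",  # assumed example model name
                "url": "http://localhost:11434/api/embeddings",
            },
        }
    )
    print(type(ollama_embedder).__name__)

    # 2) Custom embedding function: any EmbeddingFunction whose __call__
    #    signature matches the chromadb protocol passes validation and is
    #    returned unchanged.
    class FakeEmbeddingFunction(EmbeddingFunction):
        def __call__(self, input: Documents) -> Embeddings:
            # Dummy fixed-size vectors, purely for illustration.
            return [[0.0, 0.0, 0.0] for _ in input]

    custom_embedder = configurator.configure_embedder(
        {"provider": FakeEmbeddingFunction(), "config": {}}
    )
    print(type(custom_embedder).__name__)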