'add typings to embedding configurator input arg'

Author: Nick Fujita
Date: 2025-02-20 17:52:13 +09:00
parent 00c2f5043e
commit f4642f11cc


@@ -3,6 +3,31 @@ from typing import Any, Dict, Optional, cast
 from chromadb import Documents, EmbeddingFunction, Embeddings
 from chromadb.api.types import validate_embedding_function
+from pydantic import BaseModel


+class EmbeddingProviderConfig(BaseModel):
+    model: str | None = None
+    url: str | None = None
+    project_id: str | None = None
+    region: str | None = None
+    task_type: str | None = None
+    session: str | None = None
+    api_url: str | None = None
+    embedder: str | callable | None = None
+    api_key: str | None = None
+    api_base: str | None = None
+    api_type: str | None = None
+    api_version: str | None = None
+    default_headers: str | None = None
+    dimensions: str | None = None
+    deployment_id: str | None = None
+    organization_id: str | None = None


+class EmbeddingConfig(BaseModel):
+    provider: str
+    config: EmbeddingProviderConfig | None = None


 class EmbeddingConfigurator:
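With these Pydantic models in place, the embedder configuration is validated up front instead of travelling through the code as a bare dict. A minimal sketch of what a caller can now build (the provider key, model name, and key value below are illustrative placeholders, not part of this commit):

    embedder_config = EmbeddingConfig(
        provider="openai",
        config=EmbeddingProviderConfig(
            model="text-embedding-3-small",  # placeholder model name
            api_key="sk-...",                # or rely on OPENAI_API_KEY
        ),
    )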
@@ -23,15 +48,19 @@ class EmbeddingConfigurator:
     def configure_embedder(
         self,
-        embedder_config: Optional[Dict[str, Any]] = None,
+        embedder_config: EmbeddingConfig | None = None,
     ) -> EmbeddingFunction:
         """Configures and returns an embedding function based on the provided config."""
         if embedder_config is None:
             return self._create_default_embedding_function()

-        provider = embedder_config.get("provider")
-        config = embedder_config.get("config", {})
-        model_name = config.get("model") if provider != "custom" else None
+        provider = embedder_config.provider
+        config = (
+            embedder_config.config
+            if embedder_config.config
+            else EmbeddingProviderConfig()
+        )
+        model_name = config.model if provider != "custom" else None

         if provider not in self.embedding_functions:
             raise Exception(
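configure_embedder now takes the typed model instead of Optional[Dict[str, Any]]; passing None still falls back to the default embedding function, and a provider given without a config block gets an empty EmbeddingProviderConfig() so the provider helpers can read attributes safely. A hedged usage sketch (constructing EmbeddingConfigurator directly here is for illustration only):

    configurator = EmbeddingConfigurator()

    # no config at all -> default embedding function
    default_fn = configurator.configure_embedder(None)

    # provider without a config block -> attributes resolve to None / env defaults
    openai_fn = configurator.configure_embedder(EmbeddingConfig(provider="openai"))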
@@ -56,123 +85,123 @@ class EmbeddingConfigurator:
         )

     @staticmethod
-    def _configure_openai(config, model_name):
+    def _configure_openai(config: EmbeddingProviderConfig, model_name: str):
         from chromadb.utils.embedding_functions.openai_embedding_function import (
             OpenAIEmbeddingFunction,
         )

         return OpenAIEmbeddingFunction(
-            api_key=config.get("api_key") or os.getenv("OPENAI_API_KEY"),
+            api_key=config.api_key or os.getenv("OPENAI_API_KEY"),
             model_name=model_name,
-            api_base=config.get("api_base", None),
-            api_type=config.get("api_type", None),
-            api_version=config.get("api_version", None),
-            default_headers=config.get("default_headers", None),
-            dimensions=config.get("dimensions", None),
-            deployment_id=config.get("deployment_id", None),
-            organization_id=config.get("organization_id", None),
+            api_base=config.api_base,
+            api_type=config.api_type,
+            api_version=config.api_version,
+            default_headers=config.default_headers,
+            dimensions=config.dimensions,
+            deployment_id=config.deployment_id,
+            organization_id=config.organization_id,
         )

     @staticmethod
-    def _configure_azure(config, model_name):
+    def _configure_azure(config: EmbeddingProviderConfig, model_name: str):
         from chromadb.utils.embedding_functions.openai_embedding_function import (
             OpenAIEmbeddingFunction,
         )

         return OpenAIEmbeddingFunction(
-            api_key=config.get("api_key"),
-            api_base=config.get("api_base"),
-            api_type=config.get("api_type", "azure"),
-            api_version=config.get("api_version"),
+            api_key=config.api_key,
+            api_base=config.api_base,
+            api_type=config.api_type if config.api_type else "azure",
+            api_version=config.api_version,
             model_name=model_name,
-            default_headers=config.get("default_headers"),
-            dimensions=config.get("dimensions"),
-            deployment_id=config.get("deployment_id"),
-            organization_id=config.get("organization_id"),
+            default_headers=config.default_headers,
+            dimensions=config.dimensions,
+            deployment_id=config.deployment_id,
+            organization_id=config.organization_id,
         )

     @staticmethod
-    def _configure_ollama(config, model_name):
+    def _configure_ollama(config: EmbeddingProviderConfig, model_name: str):
         from chromadb.utils.embedding_functions.ollama_embedding_function import (
             OllamaEmbeddingFunction,
         )

         return OllamaEmbeddingFunction(
-            url=config.get("url", "http://localhost:11434/api/embeddings"),
+            url=config.url if config.url else "http://localhost:11434/api/embeddings",
             model_name=model_name,
         )

     @staticmethod
-    def _configure_vertexai(config, model_name):
+    def _configure_vertexai(config: EmbeddingProviderConfig, model_name: str):
         from chromadb.utils.embedding_functions.google_embedding_function import (
             GoogleVertexEmbeddingFunction,
         )

         return GoogleVertexEmbeddingFunction(
             model_name=model_name,
-            api_key=config.get("api_key"),
-            project_id=config.get("project_id"),
-            region=config.get("region"),
+            api_key=config.api_key,
+            project_id=config.project_id,
+            region=config.region,
         )

     @staticmethod
-    def _configure_google(config, model_name):
+    def _configure_google(config: EmbeddingProviderConfig, model_name: str):
         from chromadb.utils.embedding_functions.google_embedding_function import (
             GoogleGenerativeAiEmbeddingFunction,
         )

         return GoogleGenerativeAiEmbeddingFunction(
             model_name=model_name,
-            api_key=config.get("api_key"),
-            task_type=config.get("task_type"),
+            api_key=config.api_key,
+            task_type=config.task_type,
         )

     @staticmethod
-    def _configure_cohere(config, model_name):
+    def _configure_cohere(config: EmbeddingProviderConfig, model_name: str):
         from chromadb.utils.embedding_functions.cohere_embedding_function import (
             CohereEmbeddingFunction,
         )

         return CohereEmbeddingFunction(
             model_name=model_name,
-            api_key=config.get("api_key"),
+            api_key=config.api_key,
         )

     @staticmethod
-    def _configure_voyageai(config, model_name):
+    def _configure_voyageai(config: EmbeddingProviderConfig, model_name: str):
         from chromadb.utils.embedding_functions.voyageai_embedding_function import (
             VoyageAIEmbeddingFunction,
         )

         return VoyageAIEmbeddingFunction(
             model_name=model_name,
-            api_key=config.get("api_key"),
+            api_key=config.api_key,
         )

     @staticmethod
-    def _configure_bedrock(config, model_name):
+    def _configure_bedrock(config: EmbeddingProviderConfig, model_name: str):
         from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
             AmazonBedrockEmbeddingFunction,
         )

         # Allow custom model_name override with backwards compatibility
-        kwargs = {"session": config.get("session")}
+        kwargs = {"session": config.session}
         if model_name is not None:
             kwargs["model_name"] = model_name
         return AmazonBedrockEmbeddingFunction(**kwargs)

     @staticmethod
-    def _configure_huggingface(config, model_name):
+    def _configure_huggingface(config: EmbeddingProviderConfig, model_name: str):
         from chromadb.utils.embedding_functions.huggingface_embedding_function import (
             HuggingFaceEmbeddingServer,
         )

         return HuggingFaceEmbeddingServer(
-            url=config.get("api_url"),
+            url=config.api_url,
         )

     @staticmethod
-    def _configure_watson(config, model_name):
+    def _configure_watson(config: EmbeddingProviderConfig, model_name: str):
         try:
             import ibm_watsonx_ai.foundation_models as watson_models
             from ibm_watsonx_ai import Credentials
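Each provider helper above now reads attributes from EmbeddingProviderConfig rather than dict keys, with explicit fallbacks where the old .get() calls carried defaults (Azure's api_type, Ollama's local URL). An illustrative Ollama config relying on that fallback (the model tag is an assumption, not taken from this commit):

    ollama_config = EmbeddingConfig(
        provider="ollama",
        config=EmbeddingProviderConfig(model="nomic-embed-text"),
    )
    # config.url is None here, so _configure_ollama falls back to
    # http://localhost:11434/api/embeddings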
@@ -193,12 +222,10 @@ class EmbeddingConfigurator:
                 }

                 embedding = watson_models.Embeddings(
-                    model_id=config.get("model"),
+                    model_id=config.model,
                     params=embed_params,
-                    credentials=Credentials(
-                        api_key=config.get("api_key"), url=config.get("api_url")
-                    ),
-                    project_id=config.get("project_id"),
+                    credentials=Credentials(api_key=config.api_key, url=config.api_url),
+                    project_id=config.project_id,
                 )

                 try:
@@ -211,8 +238,8 @@ class EmbeddingConfigurator:
         return WatsonEmbeddingFunction()

     @staticmethod
-    def _configure_custom(config):
-        custom_embedder = config.get("embedder")
+    def _configure_custom(config: EmbeddingProviderConfig):
+        custom_embedder = config.embedder
         if isinstance(custom_embedder, EmbeddingFunction):
             try:
                 validate_embedding_function(custom_embedder)
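For the custom provider, the embedder now comes from the typed embedder field instead of config.get("embedder"). A hedged sketch of wiring a custom EmbeddingFunction through it, assuming the field accepts an EmbeddingFunction instance at validation time; the zero-vector body is purely illustrative:

    from chromadb import Documents, EmbeddingFunction, Embeddings

    class MyEmbedder(EmbeddingFunction):
        def __call__(self, input: Documents) -> Embeddings:
            # placeholder: one fixed-size vector per input document
            return [[0.0] * 384 for _ in input]

    custom_config = EmbeddingConfig(
        provider="custom",
        config=EmbeddingProviderConfig(embedder=MyEmbedder()),
    )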