mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 15:48:29 +00:00
'add typings to embedding configurator input arg'
This commit is contained in:
@@ -3,6 +3,31 @@ from typing import Any, Dict, Optional, cast
|
||||
|
||||
from chromadb import Documents, EmbeddingFunction, Embeddings
|
||||
from chromadb.api.types import validate_embedding_function
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class EmbeddingProviderConfig(BaseModel):
|
||||
model: str | None = None
|
||||
url: str | None = None
|
||||
project_id: str | None = None
|
||||
region: str | None = None
|
||||
task_type: str | None = None
|
||||
session: str | None = None
|
||||
api_url: str | None = None
|
||||
embedder: str | callable | None = None
|
||||
api_key: str | None = None
|
||||
api_base: str | None = None
|
||||
api_type: str | None = None
|
||||
api_version: str | None = None
|
||||
default_headers: str | None = None
|
||||
dimensions: str | None = None
|
||||
deployment_id: str | None = None
|
||||
organization_id: str | None = None
|
||||
|
||||
|
||||
class EmbeddingConfig(BaseModel):
|
||||
provider: str
|
||||
config: EmbeddingProviderConfig | None = None
|
||||
|
||||
|
||||
class EmbeddingConfigurator:
|
||||
@@ -23,15 +48,19 @@ class EmbeddingConfigurator:
|
||||
|
||||
def configure_embedder(
|
||||
self,
|
||||
embedder_config: Optional[Dict[str, Any]] = None,
|
||||
embedder_config: EmbeddingConfig | None = None,
|
||||
) -> EmbeddingFunction:
|
||||
"""Configures and returns an embedding function based on the provided config."""
|
||||
if embedder_config is None:
|
||||
return self._create_default_embedding_function()
|
||||
|
||||
provider = embedder_config.get("provider")
|
||||
config = embedder_config.get("config", {})
|
||||
model_name = config.get("model") if provider != "custom" else None
|
||||
provider = embedder_config.provider
|
||||
config = (
|
||||
embedder_config.config
|
||||
if embedder_config.config
|
||||
else EmbeddingProviderConfig()
|
||||
)
|
||||
model_name = config.model if provider != "custom" else None
|
||||
|
||||
if provider not in self.embedding_functions:
|
||||
raise Exception(
|
||||
@@ -56,123 +85,123 @@ class EmbeddingConfigurator:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _configure_openai(config, model_name):
|
||||
def _configure_openai(config: EmbeddingProviderConfig, model_name: str):
|
||||
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
||||
OpenAIEmbeddingFunction,
|
||||
)
|
||||
|
||||
return OpenAIEmbeddingFunction(
|
||||
api_key=config.get("api_key") or os.getenv("OPENAI_API_KEY"),
|
||||
api_key=config.api_key or os.getenv("OPENAI_API_KEY"),
|
||||
model_name=model_name,
|
||||
api_base=config.get("api_base", None),
|
||||
api_type=config.get("api_type", None),
|
||||
api_version=config.get("api_version", None),
|
||||
default_headers=config.get("default_headers", None),
|
||||
dimensions=config.get("dimensions", None),
|
||||
deployment_id=config.get("deployment_id", None),
|
||||
organization_id=config.get("organization_id", None),
|
||||
api_base=config.api_base,
|
||||
api_type=config.api_type,
|
||||
api_version=config.api_version,
|
||||
default_headers=config.default_headers,
|
||||
dimensions=config.dimensions,
|
||||
deployment_id=config.deployment_id,
|
||||
organization_id=config.organization_id,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _configure_azure(config, model_name):
|
||||
def _configure_azure(config: EmbeddingProviderConfig, model_name: str):
|
||||
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
||||
OpenAIEmbeddingFunction,
|
||||
)
|
||||
|
||||
return OpenAIEmbeddingFunction(
|
||||
api_key=config.get("api_key"),
|
||||
api_base=config.get("api_base"),
|
||||
api_type=config.get("api_type", "azure"),
|
||||
api_version=config.get("api_version"),
|
||||
api_key=config.api_key,
|
||||
api_base=config.api_base,
|
||||
api_type=config.api_type if config.api_type else "azure",
|
||||
api_version=config.api_version,
|
||||
model_name=model_name,
|
||||
default_headers=config.get("default_headers"),
|
||||
dimensions=config.get("dimensions"),
|
||||
deployment_id=config.get("deployment_id"),
|
||||
organization_id=config.get("organization_id"),
|
||||
default_headers=config.default_headers,
|
||||
dimensions=config.dimensions,
|
||||
deployment_id=config.deployment_id,
|
||||
organization_id=config.organization_id,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _configure_ollama(config, model_name):
|
||||
def _configure_ollama(config: EmbeddingProviderConfig, model_name: str):
|
||||
from chromadb.utils.embedding_functions.ollama_embedding_function import (
|
||||
OllamaEmbeddingFunction,
|
||||
)
|
||||
|
||||
return OllamaEmbeddingFunction(
|
||||
url=config.get("url", "http://localhost:11434/api/embeddings"),
|
||||
url=config.url if config.url else "http://localhost:11434/api/embeddings",
|
||||
model_name=model_name,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _configure_vertexai(config, model_name):
|
||||
def _configure_vertexai(config: EmbeddingProviderConfig, model_name: str):
|
||||
from chromadb.utils.embedding_functions.google_embedding_function import (
|
||||
GoogleVertexEmbeddingFunction,
|
||||
)
|
||||
|
||||
return GoogleVertexEmbeddingFunction(
|
||||
model_name=model_name,
|
||||
api_key=config.get("api_key"),
|
||||
project_id=config.get("project_id"),
|
||||
region=config.get("region"),
|
||||
api_key=config.api_key,
|
||||
project_id=config.project_id,
|
||||
region=config.region,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _configure_google(config, model_name):
|
||||
def _configure_google(config: EmbeddingProviderConfig, model_name: str):
|
||||
from chromadb.utils.embedding_functions.google_embedding_function import (
|
||||
GoogleGenerativeAiEmbeddingFunction,
|
||||
)
|
||||
|
||||
return GoogleGenerativeAiEmbeddingFunction(
|
||||
model_name=model_name,
|
||||
api_key=config.get("api_key"),
|
||||
task_type=config.get("task_type"),
|
||||
api_key=config.api_key,
|
||||
task_type=config.task_type,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _configure_cohere(config, model_name):
|
||||
def _configure_cohere(config: EmbeddingProviderConfig, model_name: str):
|
||||
from chromadb.utils.embedding_functions.cohere_embedding_function import (
|
||||
CohereEmbeddingFunction,
|
||||
)
|
||||
|
||||
return CohereEmbeddingFunction(
|
||||
model_name=model_name,
|
||||
api_key=config.get("api_key"),
|
||||
api_key=config.api_key,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _configure_voyageai(config, model_name):
|
||||
def _configure_voyageai(config: EmbeddingProviderConfig, model_name: str):
|
||||
from chromadb.utils.embedding_functions.voyageai_embedding_function import (
|
||||
VoyageAIEmbeddingFunction,
|
||||
)
|
||||
|
||||
return VoyageAIEmbeddingFunction(
|
||||
model_name=model_name,
|
||||
api_key=config.get("api_key"),
|
||||
api_key=config.api_key,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _configure_bedrock(config, model_name):
|
||||
def _configure_bedrock(config: EmbeddingProviderConfig, model_name: str):
|
||||
from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
|
||||
AmazonBedrockEmbeddingFunction,
|
||||
)
|
||||
|
||||
# Allow custom model_name override with backwards compatibility
|
||||
kwargs = {"session": config.get("session")}
|
||||
kwargs = {"session": config.session}
|
||||
if model_name is not None:
|
||||
kwargs["model_name"] = model_name
|
||||
return AmazonBedrockEmbeddingFunction(**kwargs)
|
||||
|
||||
@staticmethod
|
||||
def _configure_huggingface(config, model_name):
|
||||
def _configure_huggingface(config: EmbeddingProviderConfig, model_name: str):
|
||||
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
|
||||
HuggingFaceEmbeddingServer,
|
||||
)
|
||||
|
||||
return HuggingFaceEmbeddingServer(
|
||||
url=config.get("api_url"),
|
||||
url=config.api_url,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _configure_watson(config, model_name):
|
||||
def _configure_watson(config: EmbeddingProviderConfig, model_name: str):
|
||||
try:
|
||||
import ibm_watsonx_ai.foundation_models as watson_models
|
||||
from ibm_watsonx_ai import Credentials
|
||||
@@ -193,12 +222,10 @@ class EmbeddingConfigurator:
|
||||
}
|
||||
|
||||
embedding = watson_models.Embeddings(
|
||||
model_id=config.get("model"),
|
||||
model_id=config.model,
|
||||
params=embed_params,
|
||||
credentials=Credentials(
|
||||
api_key=config.get("api_key"), url=config.get("api_url")
|
||||
),
|
||||
project_id=config.get("project_id"),
|
||||
credentials=Credentials(api_key=config.api_key, url=config.api_url),
|
||||
project_id=config.project_id,
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -211,8 +238,8 @@ class EmbeddingConfigurator:
|
||||
return WatsonEmbeddingFunction()
|
||||
|
||||
@staticmethod
|
||||
def _configure_custom(config):
|
||||
custom_embedder = config.get("embedder")
|
||||
def _configure_custom(config: EmbeddingProviderConfig):
|
||||
custom_embedder = config.embedder
|
||||
if isinstance(custom_embedder, EmbeddingFunction):
|
||||
try:
|
||||
validate_embedding_function(custom_embedder)
|
||||
|
||||
Reference in New Issue
Block a user