mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-02 07:42:40 +00:00
'add typings to embedding configurator input arg'
This commit is contained in:
@@ -3,6 +3,31 @@ from typing import Any, Dict, Optional, cast
|
|||||||
|
|
||||||
from chromadb import Documents, EmbeddingFunction, Embeddings
|
from chromadb import Documents, EmbeddingFunction, Embeddings
|
||||||
from chromadb.api.types import validate_embedding_function
|
from chromadb.api.types import validate_embedding_function
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingProviderConfig(BaseModel):
|
||||||
|
model: str | None = None
|
||||||
|
url: str | None = None
|
||||||
|
project_id: str | None = None
|
||||||
|
region: str | None = None
|
||||||
|
task_type: str | None = None
|
||||||
|
session: str | None = None
|
||||||
|
api_url: str | None = None
|
||||||
|
embedder: str | callable | None = None
|
||||||
|
api_key: str | None = None
|
||||||
|
api_base: str | None = None
|
||||||
|
api_type: str | None = None
|
||||||
|
api_version: str | None = None
|
||||||
|
default_headers: str | None = None
|
||||||
|
dimensions: str | None = None
|
||||||
|
deployment_id: str | None = None
|
||||||
|
organization_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingConfig(BaseModel):
|
||||||
|
provider: str
|
||||||
|
config: EmbeddingProviderConfig | None = None
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingConfigurator:
|
class EmbeddingConfigurator:
|
||||||
@@ -23,15 +48,19 @@ class EmbeddingConfigurator:
|
|||||||
|
|
||||||
def configure_embedder(
|
def configure_embedder(
|
||||||
self,
|
self,
|
||||||
embedder_config: Optional[Dict[str, Any]] = None,
|
embedder_config: EmbeddingConfig | None = None,
|
||||||
) -> EmbeddingFunction:
|
) -> EmbeddingFunction:
|
||||||
"""Configures and returns an embedding function based on the provided config."""
|
"""Configures and returns an embedding function based on the provided config."""
|
||||||
if embedder_config is None:
|
if embedder_config is None:
|
||||||
return self._create_default_embedding_function()
|
return self._create_default_embedding_function()
|
||||||
|
|
||||||
provider = embedder_config.get("provider")
|
provider = embedder_config.provider
|
||||||
config = embedder_config.get("config", {})
|
config = (
|
||||||
model_name = config.get("model") if provider != "custom" else None
|
embedder_config.config
|
||||||
|
if embedder_config.config
|
||||||
|
else EmbeddingProviderConfig()
|
||||||
|
)
|
||||||
|
model_name = config.model if provider != "custom" else None
|
||||||
|
|
||||||
if provider not in self.embedding_functions:
|
if provider not in self.embedding_functions:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
@@ -56,123 +85,123 @@ class EmbeddingConfigurator:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _configure_openai(config, model_name):
|
def _configure_openai(config: EmbeddingProviderConfig, model_name: str):
|
||||||
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
||||||
OpenAIEmbeddingFunction,
|
OpenAIEmbeddingFunction,
|
||||||
)
|
)
|
||||||
|
|
||||||
return OpenAIEmbeddingFunction(
|
return OpenAIEmbeddingFunction(
|
||||||
api_key=config.get("api_key") or os.getenv("OPENAI_API_KEY"),
|
api_key=config.api_key or os.getenv("OPENAI_API_KEY"),
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
api_base=config.get("api_base", None),
|
api_base=config.api_base,
|
||||||
api_type=config.get("api_type", None),
|
api_type=config.api_type,
|
||||||
api_version=config.get("api_version", None),
|
api_version=config.api_version,
|
||||||
default_headers=config.get("default_headers", None),
|
default_headers=config.default_headers,
|
||||||
dimensions=config.get("dimensions", None),
|
dimensions=config.dimensions,
|
||||||
deployment_id=config.get("deployment_id", None),
|
deployment_id=config.deployment_id,
|
||||||
organization_id=config.get("organization_id", None),
|
organization_id=config.organization_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _configure_azure(config, model_name):
|
def _configure_azure(config: EmbeddingProviderConfig, model_name: str):
|
||||||
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
from chromadb.utils.embedding_functions.openai_embedding_function import (
|
||||||
OpenAIEmbeddingFunction,
|
OpenAIEmbeddingFunction,
|
||||||
)
|
)
|
||||||
|
|
||||||
return OpenAIEmbeddingFunction(
|
return OpenAIEmbeddingFunction(
|
||||||
api_key=config.get("api_key"),
|
api_key=config.api_key,
|
||||||
api_base=config.get("api_base"),
|
api_base=config.api_base,
|
||||||
api_type=config.get("api_type", "azure"),
|
api_type=config.api_type if config.api_type else "azure",
|
||||||
api_version=config.get("api_version"),
|
api_version=config.api_version,
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
default_headers=config.get("default_headers"),
|
default_headers=config.default_headers,
|
||||||
dimensions=config.get("dimensions"),
|
dimensions=config.dimensions,
|
||||||
deployment_id=config.get("deployment_id"),
|
deployment_id=config.deployment_id,
|
||||||
organization_id=config.get("organization_id"),
|
organization_id=config.organization_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _configure_ollama(config, model_name):
|
def _configure_ollama(config: EmbeddingProviderConfig, model_name: str):
|
||||||
from chromadb.utils.embedding_functions.ollama_embedding_function import (
|
from chromadb.utils.embedding_functions.ollama_embedding_function import (
|
||||||
OllamaEmbeddingFunction,
|
OllamaEmbeddingFunction,
|
||||||
)
|
)
|
||||||
|
|
||||||
return OllamaEmbeddingFunction(
|
return OllamaEmbeddingFunction(
|
||||||
url=config.get("url", "http://localhost:11434/api/embeddings"),
|
url=config.url if config.url else "http://localhost:11434/api/embeddings",
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _configure_vertexai(config, model_name):
|
def _configure_vertexai(config: EmbeddingProviderConfig, model_name: str):
|
||||||
from chromadb.utils.embedding_functions.google_embedding_function import (
|
from chromadb.utils.embedding_functions.google_embedding_function import (
|
||||||
GoogleVertexEmbeddingFunction,
|
GoogleVertexEmbeddingFunction,
|
||||||
)
|
)
|
||||||
|
|
||||||
return GoogleVertexEmbeddingFunction(
|
return GoogleVertexEmbeddingFunction(
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
api_key=config.get("api_key"),
|
api_key=config.api_key,
|
||||||
project_id=config.get("project_id"),
|
project_id=config.project_id,
|
||||||
region=config.get("region"),
|
region=config.region,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _configure_google(config, model_name):
|
def _configure_google(config: EmbeddingProviderConfig, model_name: str):
|
||||||
from chromadb.utils.embedding_functions.google_embedding_function import (
|
from chromadb.utils.embedding_functions.google_embedding_function import (
|
||||||
GoogleGenerativeAiEmbeddingFunction,
|
GoogleGenerativeAiEmbeddingFunction,
|
||||||
)
|
)
|
||||||
|
|
||||||
return GoogleGenerativeAiEmbeddingFunction(
|
return GoogleGenerativeAiEmbeddingFunction(
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
api_key=config.get("api_key"),
|
api_key=config.api_key,
|
||||||
task_type=config.get("task_type"),
|
task_type=config.task_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _configure_cohere(config, model_name):
|
def _configure_cohere(config: EmbeddingProviderConfig, model_name: str):
|
||||||
from chromadb.utils.embedding_functions.cohere_embedding_function import (
|
from chromadb.utils.embedding_functions.cohere_embedding_function import (
|
||||||
CohereEmbeddingFunction,
|
CohereEmbeddingFunction,
|
||||||
)
|
)
|
||||||
|
|
||||||
return CohereEmbeddingFunction(
|
return CohereEmbeddingFunction(
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
api_key=config.get("api_key"),
|
api_key=config.api_key,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _configure_voyageai(config, model_name):
|
def _configure_voyageai(config: EmbeddingProviderConfig, model_name: str):
|
||||||
from chromadb.utils.embedding_functions.voyageai_embedding_function import (
|
from chromadb.utils.embedding_functions.voyageai_embedding_function import (
|
||||||
VoyageAIEmbeddingFunction,
|
VoyageAIEmbeddingFunction,
|
||||||
)
|
)
|
||||||
|
|
||||||
return VoyageAIEmbeddingFunction(
|
return VoyageAIEmbeddingFunction(
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
api_key=config.get("api_key"),
|
api_key=config.api_key,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _configure_bedrock(config, model_name):
|
def _configure_bedrock(config: EmbeddingProviderConfig, model_name: str):
|
||||||
from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
|
from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
|
||||||
AmazonBedrockEmbeddingFunction,
|
AmazonBedrockEmbeddingFunction,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Allow custom model_name override with backwards compatibility
|
# Allow custom model_name override with backwards compatibility
|
||||||
kwargs = {"session": config.get("session")}
|
kwargs = {"session": config.session}
|
||||||
if model_name is not None:
|
if model_name is not None:
|
||||||
kwargs["model_name"] = model_name
|
kwargs["model_name"] = model_name
|
||||||
return AmazonBedrockEmbeddingFunction(**kwargs)
|
return AmazonBedrockEmbeddingFunction(**kwargs)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _configure_huggingface(config, model_name):
|
def _configure_huggingface(config: EmbeddingProviderConfig, model_name: str):
|
||||||
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
|
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
|
||||||
HuggingFaceEmbeddingServer,
|
HuggingFaceEmbeddingServer,
|
||||||
)
|
)
|
||||||
|
|
||||||
return HuggingFaceEmbeddingServer(
|
return HuggingFaceEmbeddingServer(
|
||||||
url=config.get("api_url"),
|
url=config.api_url,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _configure_watson(config, model_name):
|
def _configure_watson(config: EmbeddingProviderConfig, model_name: str):
|
||||||
try:
|
try:
|
||||||
import ibm_watsonx_ai.foundation_models as watson_models
|
import ibm_watsonx_ai.foundation_models as watson_models
|
||||||
from ibm_watsonx_ai import Credentials
|
from ibm_watsonx_ai import Credentials
|
||||||
@@ -193,12 +222,10 @@ class EmbeddingConfigurator:
|
|||||||
}
|
}
|
||||||
|
|
||||||
embedding = watson_models.Embeddings(
|
embedding = watson_models.Embeddings(
|
||||||
model_id=config.get("model"),
|
model_id=config.model,
|
||||||
params=embed_params,
|
params=embed_params,
|
||||||
credentials=Credentials(
|
credentials=Credentials(api_key=config.api_key, url=config.api_url),
|
||||||
api_key=config.get("api_key"), url=config.get("api_url")
|
project_id=config.project_id,
|
||||||
),
|
|
||||||
project_id=config.get("project_id"),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -211,8 +238,8 @@ class EmbeddingConfigurator:
|
|||||||
return WatsonEmbeddingFunction()
|
return WatsonEmbeddingFunction()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _configure_custom(config):
|
def _configure_custom(config: EmbeddingProviderConfig):
|
||||||
custom_embedder = config.get("embedder")
|
custom_embedder = config.embedder
|
||||||
if isinstance(custom_embedder, EmbeddingFunction):
|
if isinstance(custom_embedder, EmbeddingFunction):
|
||||||
try:
|
try:
|
||||||
validate_embedding_function(custom_embedder)
|
validate_embedding_function(custom_embedder)
|
||||||
|
|||||||
Reference in New Issue
Block a user