From f4642f11cc27ac2b25d97cb912f3cac43d549b0c Mon Sep 17 00:00:00 2001 From: Nick Fujita Date: Thu, 20 Feb 2025 17:52:13 +0900 Subject: [PATCH] 'add typings to embedding configurator input arg' --- .../utilities/embedding_configurator.py | 121 +++++++++++------- 1 file changed, 74 insertions(+), 47 deletions(-) diff --git a/src/crewai/utilities/embedding_configurator.py b/src/crewai/utilities/embedding_configurator.py index e523b60f0..2b71ba09f 100644 --- a/src/crewai/utilities/embedding_configurator.py +++ b/src/crewai/utilities/embedding_configurator.py @@ -3,6 +3,31 @@ from typing import Any, Dict, Optional, cast from chromadb import Documents, EmbeddingFunction, Embeddings from chromadb.api.types import validate_embedding_function +from pydantic import BaseModel + + +class EmbeddingProviderConfig(BaseModel): + model: str | None = None + url: str | None = None + project_id: str | None = None + region: str | None = None + task_type: str | None = None + session: str | None = None + api_url: str | None = None + embedder: str | callable | None = None + api_key: str | None = None + api_base: str | None = None + api_type: str | None = None + api_version: str | None = None + default_headers: str | None = None + dimensions: str | None = None + deployment_id: str | None = None + organization_id: str | None = None + + +class EmbeddingConfig(BaseModel): + provider: str + config: EmbeddingProviderConfig | None = None class EmbeddingConfigurator: @@ -23,15 +48,19 @@ class EmbeddingConfigurator: def configure_embedder( self, - embedder_config: Optional[Dict[str, Any]] = None, + embedder_config: EmbeddingConfig | None = None, ) -> EmbeddingFunction: """Configures and returns an embedding function based on the provided config.""" if embedder_config is None: return self._create_default_embedding_function() - provider = embedder_config.get("provider") - config = embedder_config.get("config", {}) - model_name = config.get("model") if provider != "custom" else None + provider = embedder_config.provider + config = ( + embedder_config.config + if embedder_config.config + else EmbeddingProviderConfig() + ) + model_name = config.model if provider != "custom" else None if provider not in self.embedding_functions: raise Exception( @@ -56,123 +85,123 @@ class EmbeddingConfigurator: ) @staticmethod - def _configure_openai(config, model_name): + def _configure_openai(config: EmbeddingProviderConfig, model_name: str): from chromadb.utils.embedding_functions.openai_embedding_function import ( OpenAIEmbeddingFunction, ) return OpenAIEmbeddingFunction( - api_key=config.get("api_key") or os.getenv("OPENAI_API_KEY"), + api_key=config.api_key or os.getenv("OPENAI_API_KEY"), model_name=model_name, - api_base=config.get("api_base", None), - api_type=config.get("api_type", None), - api_version=config.get("api_version", None), - default_headers=config.get("default_headers", None), - dimensions=config.get("dimensions", None), - deployment_id=config.get("deployment_id", None), - organization_id=config.get("organization_id", None), + api_base=config.api_base, + api_type=config.api_type, + api_version=config.api_version, + default_headers=config.default_headers, + dimensions=config.dimensions, + deployment_id=config.deployment_id, + organization_id=config.organization_id, ) @staticmethod - def _configure_azure(config, model_name): + def _configure_azure(config: EmbeddingProviderConfig, model_name: str): from chromadb.utils.embedding_functions.openai_embedding_function import ( OpenAIEmbeddingFunction, ) return OpenAIEmbeddingFunction( - api_key=config.get("api_key"), - api_base=config.get("api_base"), - api_type=config.get("api_type", "azure"), - api_version=config.get("api_version"), + api_key=config.api_key, + api_base=config.api_base, + api_type=config.api_type if config.api_type else "azure", + api_version=config.api_version, model_name=model_name, - default_headers=config.get("default_headers"), - dimensions=config.get("dimensions"), - deployment_id=config.get("deployment_id"), - organization_id=config.get("organization_id"), + default_headers=config.default_headers, + dimensions=config.dimensions, + deployment_id=config.deployment_id, + organization_id=config.organization_id, ) @staticmethod - def _configure_ollama(config, model_name): + def _configure_ollama(config: EmbeddingProviderConfig, model_name: str): from chromadb.utils.embedding_functions.ollama_embedding_function import ( OllamaEmbeddingFunction, ) return OllamaEmbeddingFunction( - url=config.get("url", "http://localhost:11434/api/embeddings"), + url=config.url if config.url else "http://localhost:11434/api/embeddings", model_name=model_name, ) @staticmethod - def _configure_vertexai(config, model_name): + def _configure_vertexai(config: EmbeddingProviderConfig, model_name: str): from chromadb.utils.embedding_functions.google_embedding_function import ( GoogleVertexEmbeddingFunction, ) return GoogleVertexEmbeddingFunction( model_name=model_name, - api_key=config.get("api_key"), - project_id=config.get("project_id"), - region=config.get("region"), + api_key=config.api_key, + project_id=config.project_id, + region=config.region, ) @staticmethod - def _configure_google(config, model_name): + def _configure_google(config: EmbeddingProviderConfig, model_name: str): from chromadb.utils.embedding_functions.google_embedding_function import ( GoogleGenerativeAiEmbeddingFunction, ) return GoogleGenerativeAiEmbeddingFunction( model_name=model_name, - api_key=config.get("api_key"), - task_type=config.get("task_type"), + api_key=config.api_key, + task_type=config.task_type, ) @staticmethod - def _configure_cohere(config, model_name): + def _configure_cohere(config: EmbeddingProviderConfig, model_name: str): from chromadb.utils.embedding_functions.cohere_embedding_function import ( CohereEmbeddingFunction, ) return CohereEmbeddingFunction( model_name=model_name, - api_key=config.get("api_key"), + api_key=config.api_key, ) @staticmethod - def _configure_voyageai(config, model_name): + def _configure_voyageai(config: EmbeddingProviderConfig, model_name: str): from chromadb.utils.embedding_functions.voyageai_embedding_function import ( VoyageAIEmbeddingFunction, ) return VoyageAIEmbeddingFunction( model_name=model_name, - api_key=config.get("api_key"), + api_key=config.api_key, ) @staticmethod - def _configure_bedrock(config, model_name): + def _configure_bedrock(config: EmbeddingProviderConfig, model_name: str): from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import ( AmazonBedrockEmbeddingFunction, ) # Allow custom model_name override with backwards compatibility - kwargs = {"session": config.get("session")} + kwargs = {"session": config.session} if model_name is not None: kwargs["model_name"] = model_name return AmazonBedrockEmbeddingFunction(**kwargs) @staticmethod - def _configure_huggingface(config, model_name): + def _configure_huggingface(config: EmbeddingProviderConfig, model_name: str): from chromadb.utils.embedding_functions.huggingface_embedding_function import ( HuggingFaceEmbeddingServer, ) return HuggingFaceEmbeddingServer( - url=config.get("api_url"), + url=config.api_url, ) @staticmethod - def _configure_watson(config, model_name): + def _configure_watson(config: EmbeddingProviderConfig, model_name: str): try: import ibm_watsonx_ai.foundation_models as watson_models from ibm_watsonx_ai import Credentials @@ -193,12 +222,10 @@ class EmbeddingConfigurator: } embedding = watson_models.Embeddings( - model_id=config.get("model"), + model_id=config.model, params=embed_params, - credentials=Credentials( - api_key=config.get("api_key"), url=config.get("api_url") - ), - project_id=config.get("project_id"), + credentials=Credentials(api_key=config.api_key, url=config.api_url), + project_id=config.project_id, ) try: @@ -211,8 +238,8 @@ class EmbeddingConfigurator: return WatsonEmbeddingFunction() @staticmethod - def _configure_custom(config): - custom_embedder = config.get("embedder") + def _configure_custom(config: EmbeddingProviderConfig): + custom_embedder = config.embedder if isinstance(custom_embedder, EmbeddingFunction): try: validate_embedding_function(custom_embedder)