From 2d44356c815614194f29f898de1d94a22dd5e4f5 Mon Sep 17 00:00:00 2001 From: Lorenze Jay Date: Fri, 7 Feb 2025 13:41:36 -0800 Subject: [PATCH] Refine custom embedder configuration support - Update custom embedder configuration method to handle custom embedding functions - Modify type hints for embedder configuration - Remove unused model_name parameter in custom embedder configuration --- src/crewai/utilities/embedding_configurator.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/crewai/utilities/embedding_configurator.py b/src/crewai/utilities/embedding_configurator.py index 245a70c54..e523b60f0 100644 --- a/src/crewai/utilities/embedding_configurator.py +++ b/src/crewai/utilities/embedding_configurator.py @@ -1,5 +1,5 @@ import os -from typing import Any, Dict, Union, cast +from typing import Any, Dict, Optional, cast from chromadb import Documents, EmbeddingFunction, Embeddings from chromadb.api.types import validate_embedding_function @@ -23,7 +23,7 @@ class EmbeddingConfigurator: def configure_embedder( self, - embedder_config: Union[Dict[str, Any], None] = None, + embedder_config: Optional[Dict[str, Any]] = None, ) -> EmbeddingFunction: """Configures and returns an embedding function based on the provided config.""" if embedder_config is None: @@ -38,7 +38,12 @@ class EmbeddingConfigurator: f"Unsupported embedding provider: {provider}, supported providers: {list(self.embedding_functions.keys())}" ) - return self.embedding_functions[provider](config, model_name) + embedding_function = self.embedding_functions[provider] + return ( + embedding_function(config) + if provider == "custom" + else embedding_function(config, model_name) + ) @staticmethod def _create_default_embedding_function(): @@ -206,7 +211,7 @@ class EmbeddingConfigurator: return WatsonEmbeddingFunction() @staticmethod - def _configure_custom(config, model_name): + def _configure_custom(config): custom_embedder = config.get("embedder") if isinstance(custom_embedder, EmbeddingFunction): try: