Refine custom embedder configuration support

- Update custom embedder configuration method to handle custom embedding functions
- Modify type hints for embedder configuration
- Remove unused model_name parameter in custom embedder configuration
This commit is contained in:
Lorenze Jay
2025-02-07 13:41:36 -08:00
parent d48211f7f8
commit 2d44356c81

View File

@@ -1,5 +1,5 @@
import os import os
from typing import Any, Dict, Union, cast from typing import Any, Dict, Optional, cast
from chromadb import Documents, EmbeddingFunction, Embeddings from chromadb import Documents, EmbeddingFunction, Embeddings
from chromadb.api.types import validate_embedding_function from chromadb.api.types import validate_embedding_function
@@ -23,7 +23,7 @@ class EmbeddingConfigurator:
def configure_embedder( def configure_embedder(
self, self,
embedder_config: Union[Dict[str, Any], None] = None, embedder_config: Optional[Dict[str, Any]] = None,
) -> EmbeddingFunction: ) -> EmbeddingFunction:
"""Configures and returns an embedding function based on the provided config.""" """Configures and returns an embedding function based on the provided config."""
if embedder_config is None: if embedder_config is None:
@@ -38,7 +38,12 @@ class EmbeddingConfigurator:
f"Unsupported embedding provider: {provider}, supported providers: {list(self.embedding_functions.keys())}" f"Unsupported embedding provider: {provider}, supported providers: {list(self.embedding_functions.keys())}"
) )
return self.embedding_functions[provider](config, model_name) embedding_function = self.embedding_functions[provider]
return (
embedding_function(config)
if provider == "custom"
else embedding_function(config, model_name)
)
@staticmethod @staticmethod
def _create_default_embedding_function(): def _create_default_embedding_function():
@@ -206,7 +211,7 @@ class EmbeddingConfigurator:
return WatsonEmbeddingFunction() return WatsonEmbeddingFunction()
@staticmethod @staticmethod
def _configure_custom(config, model_name): def _configure_custom(config):
custom_embedder = config.get("embedder") custom_embedder = config.get("embedder")
if isinstance(custom_embedder, EmbeddingFunction): if isinstance(custom_embedder, EmbeddingFunction):
try: try: