Enhance embedding configuration with custom embedder support

- Add support for custom embedding functions in EmbeddingConfigurator
- Update type hints for embedder configuration
- Extend configuration options for various embedding providers
- Add optional embedder configuration to Memory class
This commit is contained in:
Lorenze Jay
2025-02-07 12:41:57 -08:00
parent abee94d056
commit cafac13447
3 changed files with 47 additions and 11 deletions

View File

@@ -8,6 +8,8 @@ class Memory:
Base class for memory, now supporting agent tags and generic metadata. Base class for memory, now supporting agent tags and generic metadata.
""" """
embedder_config: Optional[Dict[str, Any]] = None
def __init__(self, storage: RAGStorage): def __init__(self, storage: RAGStorage):
self.storage = storage self.storage = storage

View File

@@ -13,7 +13,7 @@ class BaseRAGStorage(ABC):
self, self,
type: str, type: str,
allow_reset: bool = True, allow_reset: bool = True,
embedder_config: Optional[Any] = None, embedder_config: Optional[Dict[str, Any]] = None,
crew: Any = None, crew: Any = None,
): ):
self.type = type self.type = type

View File

@@ -1,5 +1,5 @@
import os import os
from typing import Any, Dict, cast from typing import Any, Dict, Union, cast
from chromadb import Documents, EmbeddingFunction, Embeddings from chromadb import Documents, EmbeddingFunction, Embeddings
from chromadb.api.types import validate_embedding_function from chromadb.api.types import validate_embedding_function
@@ -18,11 +18,12 @@ class EmbeddingConfigurator:
"bedrock": self._configure_bedrock, "bedrock": self._configure_bedrock,
"huggingface": self._configure_huggingface, "huggingface": self._configure_huggingface,
"watson": self._configure_watson, "watson": self._configure_watson,
"custom": self._configure_custom,
} }
def configure_embedder( def configure_embedder(
self, self,
embedder_config: Dict[str, Any] | None = None, embedder_config: Union[Dict[str, Any], None] = None,
) -> EmbeddingFunction: ) -> EmbeddingFunction:
"""Configures and returns an embedding function based on the provided config.""" """Configures and returns an embedding function based on the provided config."""
if embedder_config is None: if embedder_config is None:
@@ -30,19 +31,13 @@ class EmbeddingConfigurator:
provider = embedder_config.get("provider") provider = embedder_config.get("provider")
config = embedder_config.get("config", {}) config = embedder_config.get("config", {})
model_name = config.get("model") model_name = config.get("model") if provider != "custom" else None
if isinstance(provider, EmbeddingFunction):
try:
validate_embedding_function(provider)
return provider
except Exception as e:
raise ValueError(f"Invalid custom embedding function: {str(e)}")
if provider not in self.embedding_functions: if provider not in self.embedding_functions:
raise Exception( raise Exception(
f"Unsupported embedding provider: {provider}, supported providers: {list(self.embedding_functions.keys())}" f"Unsupported embedding provider: {provider}, supported providers: {list(self.embedding_functions.keys())}"
) )
return self.embedding_functions[provider](config, model_name) return self.embedding_functions[provider](config, model_name)
@staticmethod @staticmethod
@@ -64,6 +59,13 @@ class EmbeddingConfigurator:
return OpenAIEmbeddingFunction( return OpenAIEmbeddingFunction(
api_key=config.get("api_key") or os.getenv("OPENAI_API_KEY"), api_key=config.get("api_key") or os.getenv("OPENAI_API_KEY"),
model_name=model_name, model_name=model_name,
api_base=config.get("api_base", None),
api_type=config.get("api_type", None),
api_version=config.get("api_version", None),
default_headers=config.get("default_headers", None),
dimensions=config.get("dimensions", None),
deployment_id=config.get("deployment_id", None),
organization_id=config.get("organization_id", None),
) )
@staticmethod @staticmethod
@@ -78,6 +80,10 @@ class EmbeddingConfigurator:
api_type=config.get("api_type", "azure"), api_type=config.get("api_type", "azure"),
api_version=config.get("api_version"), api_version=config.get("api_version"),
model_name=model_name, model_name=model_name,
default_headers=config.get("default_headers"),
dimensions=config.get("dimensions"),
deployment_id=config.get("deployment_id"),
organization_id=config.get("organization_id"),
) )
@staticmethod @staticmethod
@@ -100,6 +106,8 @@ class EmbeddingConfigurator:
return GoogleVertexEmbeddingFunction( return GoogleVertexEmbeddingFunction(
model_name=model_name, model_name=model_name,
api_key=config.get("api_key"), api_key=config.get("api_key"),
project_id=config.get("project_id"),
region=config.get("region"),
) )
@staticmethod @staticmethod
@@ -111,6 +119,7 @@ class EmbeddingConfigurator:
return GoogleGenerativeAiEmbeddingFunction( return GoogleGenerativeAiEmbeddingFunction(
model_name=model_name, model_name=model_name,
api_key=config.get("api_key"), api_key=config.get("api_key"),
task_type=config.get("task_type"),
) )
@staticmethod @staticmethod
@@ -195,3 +204,28 @@ class EmbeddingConfigurator:
raise e raise e
return WatsonEmbeddingFunction() return WatsonEmbeddingFunction()
@staticmethod
def _configure_custom(config, model_name):
    """Return the user-supplied embedding function from ``config["embedder"]``.

    The ``embedder`` entry may be either an ``EmbeddingFunction`` instance,
    which is used as-is, or a zero-argument callable that is invoked to
    build one. In both cases the resulting object is passed through
    ``validate_embedding_function`` before it is returned.

    Args:
        config: Provider configuration mapping; only the ``"embedder"`` key
            is read here.
        model_name: Unused. Accepted so this method has the same
            ``(config, model_name)`` call shape as the other provider
            configurators.

    Returns:
        The validated embedding function instance.

    Raises:
        ValueError: If ``embedder`` is missing or neither an
            ``EmbeddingFunction`` nor a callable, if calling it fails, if it
            does not produce an ``EmbeddingFunction``, or if validation of
            the resulting object fails.
    """
    custom_embedder = config.get("embedder")

    if isinstance(custom_embedder, EmbeddingFunction):
        embedder = custom_embedder
    elif callable(custom_embedder):
        # Keep the try body limited to the factory call so that only real
        # instantiation failures are reported as such.
        try:
            embedder = custom_embedder()
        except Exception as e:
            raise ValueError(
                f"Error instantiating custom embedder: {str(e)}"
            ) from e
        if not isinstance(embedder, EmbeddingFunction):
            raise ValueError(
                "Custom embedder does not create an EmbeddingFunction instance"
            )
    else:
        raise ValueError(
            "Custom embedder must be an instance of `EmbeddingFunction` or a callable that creates one"
        )

    # Single validation path for both a passed-in and a freshly built instance.
    try:
        validate_embedding_function(embedder)
    except Exception as e:
        raise ValueError(f"Invalid custom embedding function: {str(e)}") from e
    return embedder