From cafac13447804f20fe90f428a26f048933ba7dca Mon Sep 17 00:00:00 2001 From: Lorenze Jay Date: Fri, 7 Feb 2025 12:41:57 -0800 Subject: [PATCH] Enhance embedding configuration with custom embedder support - Add support for custom embedding functions in EmbeddingConfigurator - Update type hints for embedder configuration - Extend configuration options for various embedding providers - Add optional embedder configuration to Memory class --- src/crewai/memory/memory.py | 2 + src/crewai/memory/storage/base_rag_storage.py | 2 +- .../utilities/embedding_configurator.py | 54 +++++++++++++++---- 3 files changed, 47 insertions(+), 11 deletions(-) diff --git a/src/crewai/memory/memory.py b/src/crewai/memory/memory.py index 46af2c04d..a122628d5 100644 --- a/src/crewai/memory/memory.py +++ b/src/crewai/memory/memory.py @@ -8,6 +8,8 @@ class Memory: Base class for memory, now supporting agent tags and generic metadata. """ + embedder_config: Optional[Dict[str, Any]] = None + def __init__(self, storage: RAGStorage): self.storage = storage diff --git a/src/crewai/memory/storage/base_rag_storage.py b/src/crewai/memory/storage/base_rag_storage.py index 10b82ebff..4ab9acb99 100644 --- a/src/crewai/memory/storage/base_rag_storage.py +++ b/src/crewai/memory/storage/base_rag_storage.py @@ -13,7 +13,7 @@ class BaseRAGStorage(ABC): self, type: str, allow_reset: bool = True, - embedder_config: Optional[Any] = None, + embedder_config: Optional[Dict[str, Any]] = None, crew: Any = None, ): self.type = type diff --git a/src/crewai/utilities/embedding_configurator.py b/src/crewai/utilities/embedding_configurator.py index ef07c8ebf..245a70c54 100644 --- a/src/crewai/utilities/embedding_configurator.py +++ b/src/crewai/utilities/embedding_configurator.py @@ -1,5 +1,5 @@ import os -from typing import Any, Dict, cast +from typing import Any, Dict, Union, cast from chromadb import Documents, EmbeddingFunction, Embeddings from chromadb.api.types import validate_embedding_function @@ -18,11 +18,12 @@ class EmbeddingConfigurator: "bedrock": self._configure_bedrock, "huggingface": self._configure_huggingface, "watson": self._configure_watson, + "custom": self._configure_custom, } def configure_embedder( self, - embedder_config: Dict[str, Any] | None = None, + embedder_config: Union[Dict[str, Any], None] = None, ) -> EmbeddingFunction: """Configures and returns an embedding function based on the provided config.""" if embedder_config is None: @@ -30,19 +31,13 @@ class EmbeddingConfigurator: provider = embedder_config.get("provider") config = embedder_config.get("config", {}) - model_name = config.get("model") - - if isinstance(provider, EmbeddingFunction): - try: - validate_embedding_function(provider) - return provider - except Exception as e: - raise ValueError(f"Invalid custom embedding function: {str(e)}") + model_name = config.get("model") if provider != "custom" else None if provider not in self.embedding_functions: raise Exception( f"Unsupported embedding provider: {provider}, supported providers: {list(self.embedding_functions.keys())}" ) + return self.embedding_functions[provider](config, model_name) @staticmethod @@ -64,6 +59,13 @@ class EmbeddingConfigurator: return OpenAIEmbeddingFunction( api_key=config.get("api_key") or os.getenv("OPENAI_API_KEY"), model_name=model_name, + api_base=config.get("api_base", None), + api_type=config.get("api_type", None), + api_version=config.get("api_version", None), + default_headers=config.get("default_headers", None), + dimensions=config.get("dimensions", None), + deployment_id=config.get("deployment_id", None), + organization_id=config.get("organization_id", None), ) @staticmethod @@ -78,6 +80,10 @@ class EmbeddingConfigurator: api_type=config.get("api_type", "azure"), api_version=config.get("api_version"), model_name=model_name, + default_headers=config.get("default_headers"), + dimensions=config.get("dimensions"), + deployment_id=config.get("deployment_id"), + organization_id=config.get("organization_id"), ) @staticmethod @@ -100,6 +106,8 @@ class EmbeddingConfigurator: return GoogleVertexEmbeddingFunction( model_name=model_name, api_key=config.get("api_key"), + project_id=config.get("project_id"), + region=config.get("region"), ) @staticmethod @@ -111,6 +119,7 @@ class EmbeddingConfigurator: return GoogleGenerativeAiEmbeddingFunction( model_name=model_name, api_key=config.get("api_key"), + task_type=config.get("task_type"), ) @staticmethod @@ -195,3 +204,28 @@ class EmbeddingConfigurator: raise e return WatsonEmbeddingFunction() + + @staticmethod + def _configure_custom(config, model_name): + custom_embedder = config.get("embedder") + if isinstance(custom_embedder, EmbeddingFunction): + try: + validate_embedding_function(custom_embedder) + return custom_embedder + except Exception as e: + raise ValueError(f"Invalid custom embedding function: {str(e)}") + elif callable(custom_embedder): + try: + instance = custom_embedder() + if isinstance(instance, EmbeddingFunction): + validate_embedding_function(instance) + return instance + raise ValueError( + "Custom embedder does not create an EmbeddingFunction instance" + ) + except Exception as e: + raise ValueError(f"Error instantiating custom embedder: {str(e)}") + else: + raise ValueError( + "Custom embedder must be an instance of `EmbeddingFunction` or a callable that creates one" + )