From e529766391eac91f7a26517b3b8f9728a4a14544 Mon Sep 17 00:00:00 2001 From: Lorenze Jay <63378463+lorenzejay@users.noreply.github.com> Date: Fri, 7 Feb 2025 13:49:46 -0800 Subject: [PATCH] Enhance embedding configuration with custom embedder support (#2060) * Enhance embedding configuration with custom embedder support - Add support for custom embedding functions in EmbeddingConfigurator - Update type hints for embedder configuration - Extend configuration options for various embedding providers - Add optional embedder configuration to Memory class * added docs * Refine custom embedder configuration support - Update custom embedder configuration method to handle custom embedding functions - Modify type hints for embedder configuration - Remove unused model_name parameter in custom embedder configuration --- docs/concepts/memory.mdx | 27 ++++++++ src/crewai/memory/memory.py | 2 + src/crewai/memory/storage/base_rag_storage.py | 2 +- .../utilities/embedding_configurator.py | 61 +++++++++++++++---- 4 files changed, 80 insertions(+), 12 deletions(-) diff --git a/docs/concepts/memory.mdx b/docs/concepts/memory.mdx index a725c41e7..33df47b82 100644 --- a/docs/concepts/memory.mdx +++ b/docs/concepts/memory.mdx @@ -368,6 +368,33 @@ my_crew = Crew( ) ``` +### Adding Custom Embedding Function + +```python Code +from crewai import Crew, Agent, Task, Process +from chromadb import Documents, EmbeddingFunction, Embeddings + +# Create a custom embedding function +class CustomEmbedder(EmbeddingFunction): + def __call__(self, input: Documents) -> Embeddings: + # generate embeddings + return [1, 2, 3] # this is a dummy embedding + +my_crew = Crew( + agents=[...], + tasks=[...], + process=Process.sequential, + memory=True, + verbose=True, + embedder={ + "provider": "custom", + "config": { + "embedder": CustomEmbedder() + } + } +) +``` + ### Resetting Memory ```shell diff --git a/src/crewai/memory/memory.py b/src/crewai/memory/memory.py index 51a700323..4387ebd64 100644 --- 
a/src/crewai/memory/memory.py +++ b/src/crewai/memory/memory.py @@ -10,6 +10,8 @@ class Memory(BaseModel): Base class for memory, now supporting agent tags and generic metadata. """ + embedder_config: Optional[Dict[str, Any]] = None + storage: Any def __init__(self, storage: Any, **data: Any): diff --git a/src/crewai/memory/storage/base_rag_storage.py b/src/crewai/memory/storage/base_rag_storage.py index 10b82ebff..4ab9acb99 100644 --- a/src/crewai/memory/storage/base_rag_storage.py +++ b/src/crewai/memory/storage/base_rag_storage.py @@ -13,7 +13,7 @@ class BaseRAGStorage(ABC): self, type: str, allow_reset: bool = True, - embedder_config: Optional[Any] = None, + embedder_config: Optional[Dict[str, Any]] = None, crew: Any = None, ): self.type = type diff --git a/src/crewai/utilities/embedding_configurator.py b/src/crewai/utilities/embedding_configurator.py index ef07c8ebf..e523b60f0 100644 --- a/src/crewai/utilities/embedding_configurator.py +++ b/src/crewai/utilities/embedding_configurator.py @@ -1,5 +1,5 @@ import os -from typing import Any, Dict, cast +from typing import Any, Dict, Optional, cast from chromadb import Documents, EmbeddingFunction, Embeddings from chromadb.api.types import validate_embedding_function @@ -18,11 +18,12 @@ class EmbeddingConfigurator: "bedrock": self._configure_bedrock, "huggingface": self._configure_huggingface, "watson": self._configure_watson, + "custom": self._configure_custom, } def configure_embedder( self, - embedder_config: Dict[str, Any] | None = None, + embedder_config: Optional[Dict[str, Any]] = None, ) -> EmbeddingFunction: """Configures and returns an embedding function based on the provided config.""" if embedder_config is None: @@ -30,20 +31,19 @@ class EmbeddingConfigurator: provider = embedder_config.get("provider") config = embedder_config.get("config", {}) - model_name = config.get("model") - - if isinstance(provider, EmbeddingFunction): - try: - validate_embedding_function(provider) - return provider - except 
Exception as e: - raise ValueError(f"Invalid custom embedding function: {str(e)}") + model_name = config.get("model") if provider != "custom" else None if provider not in self.embedding_functions: raise Exception( f"Unsupported embedding provider: {provider}, supported providers: {list(self.embedding_functions.keys())}" ) - return self.embedding_functions[provider](config, model_name) + + embedding_function = self.embedding_functions[provider] + return ( + embedding_function(config) + if provider == "custom" + else embedding_function(config, model_name) + ) @staticmethod def _create_default_embedding_function(): @@ -64,6 +64,13 @@ class EmbeddingConfigurator: return OpenAIEmbeddingFunction( api_key=config.get("api_key") or os.getenv("OPENAI_API_KEY"), model_name=model_name, + api_base=config.get("api_base", None), + api_type=config.get("api_type", None), + api_version=config.get("api_version", None), + default_headers=config.get("default_headers", None), + dimensions=config.get("dimensions", None), + deployment_id=config.get("deployment_id", None), + organization_id=config.get("organization_id", None), ) @staticmethod @@ -78,6 +85,10 @@ class EmbeddingConfigurator: api_type=config.get("api_type", "azure"), api_version=config.get("api_version"), model_name=model_name, + default_headers=config.get("default_headers"), + dimensions=config.get("dimensions"), + deployment_id=config.get("deployment_id"), + organization_id=config.get("organization_id"), ) @staticmethod @@ -100,6 +111,8 @@ class EmbeddingConfigurator: return GoogleVertexEmbeddingFunction( model_name=model_name, api_key=config.get("api_key"), + project_id=config.get("project_id"), + region=config.get("region"), ) @staticmethod @@ -111,6 +124,7 @@ class EmbeddingConfigurator: return GoogleGenerativeAiEmbeddingFunction( model_name=model_name, api_key=config.get("api_key"), + task_type=config.get("task_type"), ) @staticmethod @@ -195,3 +209,28 @@ class EmbeddingConfigurator: raise e return 
WatsonEmbeddingFunction()
+
+    @staticmethod
+    def _configure_custom(config):
+        """Return the validated embedder found under config["embedder"].
+
+        Accepts an `EmbeddingFunction` instance or a zero-argument callable that
+        creates one; raises ValueError for anything else or on failed validation.
+        """
+        custom_embedder = config.get("embedder")
+        if not isinstance(custom_embedder, EmbeddingFunction):
+            if not callable(custom_embedder):
+                raise ValueError("Custom embedder must be an instance of `EmbeddingFunction` or a callable that creates one")
+            # A factory/class was supplied: build the instance before checking it.
+            try:
+                custom_embedder = custom_embedder()
+            except Exception as e:
+                raise ValueError(f"Error instantiating custom embedder: {str(e)}") from e
+            # Raised outside the try so it is not re-wrapped as an instantiation error.
+            if not isinstance(custom_embedder, EmbeddingFunction):
+                raise ValueError("Custom embedder does not create an EmbeddingFunction instance")
+        try:
+            validate_embedding_function(custom_embedder)
+        except Exception as e:
+            raise ValueError(f"Invalid custom embedding function: {str(e)}") from e
+        return custom_embedder