From a2e1b3896ebb2ac097faf9d7c22943496bbfaa77 Mon Sep 17 00:00:00 2001 From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Mon, 10 Feb 2025 00:01:24 +0000 Subject: [PATCH] fix: Make model_name optional for custom embedders Co-Authored-By: Joe Moura --- src/crewai/utilities/embedding_configurator.py | 12 +++++------- tests/memory/test_memory_reset.py | 12 ++++++++++-- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/src/crewai/utilities/embedding_configurator.py b/src/crewai/utilities/embedding_configurator.py index f04eaa036..3e9648444 100644 --- a/src/crewai/utilities/embedding_configurator.py +++ b/src/crewai/utilities/embedding_configurator.py @@ -66,11 +66,9 @@ class EmbeddingConfigurator: ) embedding_function = self.embedding_functions[provider] - return ( - embedding_function(config) - if provider == "custom" - else embedding_function(config, model_name) - ) + if provider == "custom": + return embedding_function(config) + return embedding_function(config, model_name) @staticmethod def _create_default_embedding_function(): @@ -210,13 +208,13 @@ class EmbeddingConfigurator: ) @staticmethod - def _configure_custom(config, model_name): + def _configure_custom(config, model_name=None): """Configure a custom embedding function. Args: config: Configuration dictionary containing: - embedder: Custom EmbeddingFunction instance - model_name: Not used for custom embedders + model_name: Not used for custom embedders, defaults to None Returns: EmbeddingFunction: The validated custom embedding function diff --git a/tests/memory/test_memory_reset.py b/tests/memory/test_memory_reset.py index 0855ec351..807ce99d8 100644 --- a/tests/memory/test_memory_reset.py +++ b/tests/memory/test_memory_reset.py @@ -74,7 +74,11 @@ def test_memory_reset_with_custom_provider(temp_db_dir): return [[0.5] * 10] * len(input) memory = ShortTermMemory( - path=str(temp_db_dir), embedder_config={"provider": CustomEmbedder()} + path=str(temp_db_dir), + embedder_config={ + "provider": "custom", + "config": {"embedder": CustomEmbedder()} + } ) memory.reset() # Should work with custom embedder @@ -132,7 +136,11 @@ def test_memory_reset_cleans_up_files(temp_db_dir): if "CREWAI_EMBEDDING_PROVIDER" in os.environ: del os.environ["CREWAI_EMBEDDING_PROVIDER"] memory = ShortTermMemory( - path=str(temp_db_dir), embedder_config={"provider": TestEmbedder()} + path=str(temp_db_dir), + embedder_config={ + "provider": "custom", + "config": {"embedder": TestEmbedder()} + } ) memory.save("test memory", {"test": "metadata"}) assert any(temp_db_dir.iterdir()) # Directory should have files