fix: Make model_name optional for custom embedders

Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
Devin AI
2025-02-10 00:01:24 +00:00
parent 5566c587a4
commit a2e1b3896e
2 changed files with 15 additions and 9 deletions

View File

@@ -66,11 +66,9 @@ class EmbeddingConfigurator:
) )
embedding_function = self.embedding_functions[provider] embedding_function = self.embedding_functions[provider]
return ( if provider == "custom":
embedding_function(config) return embedding_function(config)
if provider == "custom" return embedding_function(config, model_name)
else embedding_function(config, model_name)
)
@staticmethod @staticmethod
def _create_default_embedding_function(): def _create_default_embedding_function():
@@ -210,13 +208,13 @@ class EmbeddingConfigurator:
) )
@staticmethod @staticmethod
def _configure_custom(config, model_name): def _configure_custom(config, model_name=None):
"""Configure a custom embedding function. """Configure a custom embedding function.
Args: Args:
config: Configuration dictionary containing: config: Configuration dictionary containing:
- embedder: Custom EmbeddingFunction instance - embedder: Custom EmbeddingFunction instance
model_name: Not used for custom embedders model_name: Not used for custom embedders, defaults to None
Returns: Returns:
EmbeddingFunction: The validated custom embedding function EmbeddingFunction: The validated custom embedding function

View File

@@ -74,7 +74,11 @@ def test_memory_reset_with_custom_provider(temp_db_dir):
return [[0.5] * 10] * len(input) return [[0.5] * 10] * len(input)
memory = ShortTermMemory( memory = ShortTermMemory(
path=str(temp_db_dir), embedder_config={"provider": CustomEmbedder()} path=str(temp_db_dir),
embedder_config={
"provider": "custom",
"config": {"embedder": CustomEmbedder()}
}
) )
memory.reset() # Should work with custom embedder memory.reset() # Should work with custom embedder
@@ -132,7 +136,11 @@ def test_memory_reset_cleans_up_files(temp_db_dir):
if "CREWAI_EMBEDDING_PROVIDER" in os.environ: if "CREWAI_EMBEDDING_PROVIDER" in os.environ:
del os.environ["CREWAI_EMBEDDING_PROVIDER"] del os.environ["CREWAI_EMBEDDING_PROVIDER"]
memory = ShortTermMemory( memory = ShortTermMemory(
path=str(temp_db_dir), embedder_config={"provider": TestEmbedder()} path=str(temp_db_dir),
embedder_config={
"provider": "custom",
"config": {"embedder": TestEmbedder()}
}
) )
memory.save("test memory", {"test": "metadata"}) memory.save("test memory", {"test": "metadata"})
assert any(temp_db_dir.iterdir()) # Directory should have files assert any(temp_db_dir.iterdir()) # Directory should have files