fix: Make model_name optional for custom embedders

Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
Devin AI
2025-02-10 00:01:24 +00:00
parent 5566c587a4
commit a2e1b3896e
2 changed files with 15 additions and 9 deletions

View File

@@ -66,11 +66,9 @@ class EmbeddingConfigurator:
)
embedding_function = self.embedding_functions[provider]
return (
embedding_function(config)
if provider == "custom"
else embedding_function(config, model_name)
)
if provider == "custom":
return embedding_function(config)
return embedding_function(config, model_name)
@staticmethod
def _create_default_embedding_function():
@@ -210,13 +208,13 @@ class EmbeddingConfigurator:
)
@staticmethod
def _configure_custom(config, model_name):
def _configure_custom(config, model_name=None):
"""Configure a custom embedding function.
Args:
config: Configuration dictionary containing:
- embedder: Custom EmbeddingFunction instance
model_name: Not used for custom embedders
model_name: Not used for custom embedders, defaults to None
Returns:
EmbeddingFunction: The validated custom embedding function

View File

@@ -74,7 +74,11 @@ def test_memory_reset_with_custom_provider(temp_db_dir):
return [[0.5] * 10] * len(input)
memory = ShortTermMemory(
path=str(temp_db_dir), embedder_config={"provider": CustomEmbedder()}
path=str(temp_db_dir),
embedder_config={
"provider": "custom",
"config": {"embedder": CustomEmbedder()}
}
)
memory.reset() # Should work with custom embedder
@@ -132,7 +136,11 @@ def test_memory_reset_cleans_up_files(temp_db_dir):
if "CREWAI_EMBEDDING_PROVIDER" in os.environ:
del os.environ["CREWAI_EMBEDDING_PROVIDER"]
memory = ShortTermMemory(
path=str(temp_db_dir), embedder_config={"provider": TestEmbedder()}
path=str(temp_db_dir),
embedder_config={
"provider": "custom",
"config": {"embedder": TestEmbedder()}
}
)
memory.save("test memory", {"test": "metadata"})
assert any(temp_db_dir.iterdir()) # Directory should have files