mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 08:08:32 +00:00
fix: Make model_name optional for custom embedders
Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
@@ -66,11 +66,9 @@ class EmbeddingConfigurator:
|
||||
)
|
||||
|
||||
embedding_function = self.embedding_functions[provider]
|
||||
return (
|
||||
embedding_function(config)
|
||||
if provider == "custom"
|
||||
else embedding_function(config, model_name)
|
||||
)
|
||||
if provider == "custom":
|
||||
return embedding_function(config)
|
||||
return embedding_function(config, model_name)
|
||||
|
||||
@staticmethod
|
||||
def _create_default_embedding_function():
|
||||
@@ -210,13 +208,13 @@ class EmbeddingConfigurator:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _configure_custom(config, model_name):
|
||||
def _configure_custom(config, model_name=None):
|
||||
"""Configure a custom embedding function.
|
||||
|
||||
Args:
|
||||
config: Configuration dictionary containing:
|
||||
- embedder: Custom EmbeddingFunction instance
|
||||
model_name: Not used for custom embedders
|
||||
model_name: Not used for custom embedders, defaults to None
|
||||
|
||||
Returns:
|
||||
EmbeddingFunction: The validated custom embedding function
|
||||
|
||||
@@ -74,7 +74,11 @@ def test_memory_reset_with_custom_provider(temp_db_dir):
|
||||
return [[0.5] * 10] * len(input)
|
||||
|
||||
memory = ShortTermMemory(
|
||||
path=str(temp_db_dir), embedder_config={"provider": CustomEmbedder()}
|
||||
path=str(temp_db_dir),
|
||||
embedder_config={
|
||||
"provider": "custom",
|
||||
"config": {"embedder": CustomEmbedder()}
|
||||
}
|
||||
)
|
||||
memory.reset() # Should work with custom embedder
|
||||
|
||||
@@ -132,7 +136,11 @@ def test_memory_reset_cleans_up_files(temp_db_dir):
|
||||
if "CREWAI_EMBEDDING_PROVIDER" in os.environ:
|
||||
del os.environ["CREWAI_EMBEDDING_PROVIDER"]
|
||||
memory = ShortTermMemory(
|
||||
path=str(temp_db_dir), embedder_config={"provider": TestEmbedder()}
|
||||
path=str(temp_db_dir),
|
||||
embedder_config={
|
||||
"provider": "custom",
|
||||
"config": {"embedder": TestEmbedder()}
|
||||
}
|
||||
)
|
||||
memory.save("test memory", {"test": "metadata"})
|
||||
assert any(temp_db_dir.iterdir()) # Directory should have files
|
||||
|
||||
Reference in New Issue
Block a user