mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-01 07:13:00 +00:00
fix: Make model_name optional for custom embedders
Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
@@ -66,11 +66,9 @@ class EmbeddingConfigurator:
|
|||||||
)
|
)
|
||||||
|
|
||||||
embedding_function = self.embedding_functions[provider]
|
embedding_function = self.embedding_functions[provider]
|
||||||
return (
|
if provider == "custom":
|
||||||
embedding_function(config)
|
return embedding_function(config)
|
||||||
if provider == "custom"
|
return embedding_function(config, model_name)
|
||||||
else embedding_function(config, model_name)
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _create_default_embedding_function():
|
def _create_default_embedding_function():
|
||||||
@@ -210,13 +208,13 @@ class EmbeddingConfigurator:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _configure_custom(config, model_name):
|
def _configure_custom(config, model_name=None):
|
||||||
"""Configure a custom embedding function.
|
"""Configure a custom embedding function.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
config: Configuration dictionary containing:
|
config: Configuration dictionary containing:
|
||||||
- embedder: Custom EmbeddingFunction instance
|
- embedder: Custom EmbeddingFunction instance
|
||||||
model_name: Not used for custom embedders
|
model_name: Not used for custom embedders, defaults to None
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
EmbeddingFunction: The validated custom embedding function
|
EmbeddingFunction: The validated custom embedding function
|
||||||
|
|||||||
@@ -74,7 +74,11 @@ def test_memory_reset_with_custom_provider(temp_db_dir):
|
|||||||
return [[0.5] * 10] * len(input)
|
return [[0.5] * 10] * len(input)
|
||||||
|
|
||||||
memory = ShortTermMemory(
|
memory = ShortTermMemory(
|
||||||
path=str(temp_db_dir), embedder_config={"provider": CustomEmbedder()}
|
path=str(temp_db_dir),
|
||||||
|
embedder_config={
|
||||||
|
"provider": "custom",
|
||||||
|
"config": {"embedder": CustomEmbedder()}
|
||||||
|
}
|
||||||
)
|
)
|
||||||
memory.reset() # Should work with custom embedder
|
memory.reset() # Should work with custom embedder
|
||||||
|
|
||||||
@@ -132,7 +136,11 @@ def test_memory_reset_cleans_up_files(temp_db_dir):
|
|||||||
if "CREWAI_EMBEDDING_PROVIDER" in os.environ:
|
if "CREWAI_EMBEDDING_PROVIDER" in os.environ:
|
||||||
del os.environ["CREWAI_EMBEDDING_PROVIDER"]
|
del os.environ["CREWAI_EMBEDDING_PROVIDER"]
|
||||||
memory = ShortTermMemory(
|
memory = ShortTermMemory(
|
||||||
path=str(temp_db_dir), embedder_config={"provider": TestEmbedder()}
|
path=str(temp_db_dir),
|
||||||
|
embedder_config={
|
||||||
|
"provider": "custom",
|
||||||
|
"config": {"embedder": TestEmbedder()}
|
||||||
|
}
|
||||||
)
|
)
|
||||||
memory.save("test memory", {"test": "metadata"})
|
memory.save("test memory", {"test": "metadata"})
|
||||||
assert any(temp_db_dir.iterdir()) # Directory should have files
|
assert any(temp_db_dir.iterdir()) # Directory should have files
|
||||||
|
|||||||
Reference in New Issue
Block a user