mirror of
https://github.com/crewAIInc/crewAI.git
synced 2025-12-15 11:58:31 +00:00
fix: ensure custom rag store persist path is set if passed
Co-authored-by: Lorenze Jay <63378463+lorenzejay@users.noreply.github.com>
This commit is contained in:
@@ -16,6 +16,7 @@ from crewai.utilities.paths import db_storage_path
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.crew import Crew
|
||||
from crewai.rag.core.base_client import BaseClient
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
from crewai.rag.embeddings.types import ProviderSpec
|
||||
@@ -32,16 +33,16 @@ class RAGStorage(BaseRAGStorage):
|
||||
self,
|
||||
type: str,
|
||||
allow_reset: bool = True,
|
||||
embedder_config: ProviderSpec | BaseEmbeddingsProvider | None = None,
|
||||
crew: Any = None,
|
||||
embedder_config: ProviderSpec | BaseEmbeddingsProvider[Any] | None = None,
|
||||
crew: Crew | None = None,
|
||||
path: str | None = None,
|
||||
) -> None:
|
||||
super().__init__(type, allow_reset, embedder_config, crew)
|
||||
agents = crew.agents if crew else []
|
||||
agents = [self._sanitize_role(agent.role) for agent in agents]
|
||||
agents = "_".join(agents)
|
||||
self.agents = agents
|
||||
self.storage_file_name = self._build_storage_file_name(type, agents)
|
||||
crew_agents = crew.agents if crew else []
|
||||
sanitized_roles = [self._sanitize_role(agent.role) for agent in crew_agents]
|
||||
agents_str = "_".join(sanitized_roles)
|
||||
self.agents = agents_str
|
||||
self.storage_file_name = self._build_storage_file_name(type, agents_str)
|
||||
|
||||
self.type = type
|
||||
self._client: BaseClient | None = None
|
||||
@@ -96,6 +97,10 @@ class RAGStorage(BaseRAGStorage):
|
||||
ChromaEmbeddingFunctionWrapper, embedding_function
|
||||
)
|
||||
)
|
||||
|
||||
if self.path:
|
||||
config.settings.persist_directory = self.path
|
||||
|
||||
self._client = create_client(config)
|
||||
|
||||
def _get_client(self) -> BaseClient:
|
||||
|
||||
82
lib/crewai/tests/rag/test_rag_storage_path.py
Normal file
82
lib/crewai/tests/rag/test_rag_storage_path.py
Normal file
@@ -0,0 +1,82 @@
|
||||
"""Tests for RAGStorage custom path functionality."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from crewai.memory.storage.rag_storage import RAGStorage
|
||||
|
||||
|
||||
@patch("crewai.memory.storage.rag_storage.create_client")
|
||||
@patch("crewai.memory.storage.rag_storage.build_embedder")
|
||||
def test_rag_storage_custom_path(
|
||||
mock_build_embedder: MagicMock,
|
||||
mock_create_client: MagicMock,
|
||||
) -> None:
|
||||
"""Test RAGStorage uses custom path when provided."""
|
||||
mock_build_embedder.return_value = MagicMock(return_value=[[0.1, 0.2, 0.3]])
|
||||
mock_create_client.return_value = MagicMock()
|
||||
|
||||
custom_path = "/custom/memory/path"
|
||||
embedder_config = {"provider": "openai", "config": {"model": "text-embedding-3-small"}}
|
||||
|
||||
RAGStorage(
|
||||
type="short_term",
|
||||
crew=None,
|
||||
path=custom_path,
|
||||
embedder_config=embedder_config,
|
||||
)
|
||||
|
||||
mock_create_client.assert_called_once()
|
||||
config_arg = mock_create_client.call_args[0][0]
|
||||
assert config_arg.settings.persist_directory == custom_path
|
||||
|
||||
|
||||
@patch("crewai.memory.storage.rag_storage.create_client")
|
||||
@patch("crewai.memory.storage.rag_storage.build_embedder")
|
||||
def test_rag_storage_default_path_when_none(
|
||||
mock_build_embedder: MagicMock,
|
||||
mock_create_client: MagicMock,
|
||||
) -> None:
|
||||
"""Test RAGStorage uses default path when no custom path is provided."""
|
||||
mock_build_embedder.return_value = MagicMock(return_value=[[0.1, 0.2, 0.3]])
|
||||
mock_create_client.return_value = MagicMock()
|
||||
|
||||
embedder_config = {"provider": "openai", "config": {"model": "text-embedding-3-small"}}
|
||||
|
||||
storage = RAGStorage(
|
||||
type="short_term",
|
||||
crew=None,
|
||||
path=None,
|
||||
embedder_config=embedder_config,
|
||||
)
|
||||
|
||||
mock_create_client.assert_called_once()
|
||||
assert storage.path is None
|
||||
|
||||
|
||||
@patch("crewai.memory.storage.rag_storage.create_client")
|
||||
@patch("crewai.memory.storage.rag_storage.build_embedder")
|
||||
def test_rag_storage_custom_path_with_batch_size(
|
||||
mock_build_embedder: MagicMock,
|
||||
mock_create_client: MagicMock,
|
||||
) -> None:
|
||||
"""Test RAGStorage uses custom path with batch_size in config."""
|
||||
mock_build_embedder.return_value = MagicMock(return_value=[[0.1, 0.2, 0.3]])
|
||||
mock_create_client.return_value = MagicMock()
|
||||
|
||||
custom_path = "/custom/batch/path"
|
||||
embedder_config = {
|
||||
"provider": "openai",
|
||||
"config": {"model": "text-embedding-3-small", "batch_size": 100},
|
||||
}
|
||||
|
||||
RAGStorage(
|
||||
type="long_term",
|
||||
crew=None,
|
||||
path=custom_path,
|
||||
embedder_config=embedder_config,
|
||||
)
|
||||
|
||||
mock_create_client.assert_called_once()
|
||||
config_arg = mock_create_client.call_args[0][0]
|
||||
assert config_arg.settings.persist_directory == custom_path
|
||||
assert config_arg.batch_size == 100
|
||||
Reference in New Issue
Block a user