From 50508297c99bf714c1b79dc8c2d0f72059e34581 Mon Sep 17 00:00:00 2001 From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Sat, 28 Dec 2024 01:36:03 +0000 Subject: [PATCH] feat: centralize default memory path logic & add path validation tests Co-Authored-By: Joe Moura --- .../memory/long_term/long_term_memory.py | 11 ++- src/crewai/memory/storage/base_rag_storage.py | 4 +- .../storage/kickoff_task_outputs_storage.py | 8 +- .../memory/storage/ltm_sqlite_storage.py | 4 +- src/crewai/utilities/paths.py | 23 +++++ tests/memory/test_storage_paths.py | 83 +++++++++++++++++++ uv.lock | 68 ++++++++------- 7 files changed, 159 insertions(+), 42 deletions(-) create mode 100644 tests/memory/test_storage_paths.py diff --git a/src/crewai/memory/long_term/long_term_memory.py b/src/crewai/memory/long_term/long_term_memory.py index 656709ac9..a856ea67b 100644 --- a/src/crewai/memory/long_term/long_term_memory.py +++ b/src/crewai/memory/long_term/long_term_memory.py @@ -15,8 +15,17 @@ class LongTermMemory(Memory): """ def __init__(self, storage=None, path=None): + """Initialize long term memory. + + Args: + storage: Optional custom storage instance + path: Optional custom path for storage location + + Note: + If both storage and path are provided, storage takes precedence + """ if not storage: - storage = LTMSQLiteStorage(db_path=path) if path else LTMSQLiteStorage() + storage = LTMSQLiteStorage(storage_path=path) if path else LTMSQLiteStorage() super().__init__(storage) def save(self, item: LongTermMemoryItem) -> None: # type: ignore # BUG?: Signature of "save" incompatible with supertype "Memory" diff --git a/src/crewai/memory/storage/base_rag_storage.py b/src/crewai/memory/storage/base_rag_storage.py index 9ff827484..f799cf696 100644 --- a/src/crewai/memory/storage/base_rag_storage.py +++ b/src/crewai/memory/storage/base_rag_storage.py @@ -5,7 +5,7 @@ from typing import Any, Dict, List, Optional, TypeVar from abc import ABC, abstractmethod from pathlib import Path -from crewai.utilities.paths import db_storage_path +from crewai.utilities.paths import get_default_storage_path class BaseRAGStorage(ABC): @@ -37,7 +37,7 @@ class BaseRAGStorage(ABC): OSError: If storage path cannot be created """ self.type = type - self.storage_path = storage_path if storage_path else db_storage_path() + self.storage_path = storage_path if storage_path else get_default_storage_path('rag') # Validate storage path try: diff --git a/src/crewai/memory/storage/kickoff_task_outputs_storage.py b/src/crewai/memory/storage/kickoff_task_outputs_storage.py index 284e9639c..bb593cfbf 100644 --- a/src/crewai/memory/storage/kickoff_task_outputs_storage.py +++ b/src/crewai/memory/storage/kickoff_task_outputs_storage.py @@ -7,7 +7,7 @@ from typing import Any, Dict, List, Optional from crewai.task import Task from crewai.utilities import Printer from crewai.utilities.crew_json_encoder import CrewJSONEncoder -from crewai.utilities.paths import db_storage_path +from crewai.utilities.paths import get_default_storage_path class KickoffTaskOutputsSQLiteStorage: @@ -25,11 +25,7 @@ class KickoffTaskOutputsSQLiteStorage: PermissionError: If storage path is not writable OSError: If storage path cannot be created """ - self.storage_path = ( - storage_path - if storage_path - else Path(f"{db_storage_path()}/latest_kickoff_task_outputs.db") - ) + self.storage_path = storage_path if storage_path else get_default_storage_path('kickoff') # Validate storage path try: diff --git a/src/crewai/memory/storage/ltm_sqlite_storage.py b/src/crewai/memory/storage/ltm_sqlite_storage.py index 8a61cdfc1..3c8153669 100644 --- a/src/crewai/memory/storage/ltm_sqlite_storage.py +++ b/src/crewai/memory/storage/ltm_sqlite_storage.py @@ -5,7 +5,7 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Union from crewai.utilities import Printer -from crewai.utilities.paths import db_storage_path +from crewai.utilities.paths import get_default_storage_path class LTMSQLiteStorage: @@ -23,7 +23,7 @@ class LTMSQLiteStorage: PermissionError: If storage path is not writable OSError: If storage path cannot be created """ - self.storage_path = storage_path if storage_path else Path(f"{db_storage_path()}/latest_long_term_memories.db") + self.storage_path = storage_path if storage_path else get_default_storage_path('ltm') # Validate storage path try: diff --git a/src/crewai/utilities/paths.py b/src/crewai/utilities/paths.py index 51cf8b4e4..48381f1af 100644 --- a/src/crewai/utilities/paths.py +++ b/src/crewai/utilities/paths.py @@ -22,3 +22,26 @@ def get_project_directory_name(): cwd = Path.cwd() project_directory_name = cwd.name return project_directory_name + +def get_default_storage_path(storage_type: str) -> Path: + """Returns the default storage path for a given storage type. + + Args: + storage_type: Type of storage ('ltm', 'kickoff', 'rag') + + Returns: + Path: Default storage path for the specified type + + Raises: + ValueError: If storage_type is not recognized + """ + base_path = db_storage_path() + + if storage_type == 'ltm': + return base_path / 'latest_long_term_memories.db' + elif storage_type == 'kickoff': + return base_path / 'latest_kickoff_task_outputs.db' + elif storage_type == 'rag': + return base_path + else: + raise ValueError(f"Unknown storage type: {storage_type}") diff --git a/tests/memory/test_storage_paths.py b/tests/memory/test_storage_paths.py new file mode 100644 index 000000000..bcfa32474 --- /dev/null +++ b/tests/memory/test_storage_paths.py @@ -0,0 +1,83 @@ +import os +import tempfile +from pathlib import Path +import pytest +from unittest.mock import patch + +from crewai.memory.storage.ltm_sqlite_storage import LTMSQLiteStorage +from crewai.memory.storage.kickoff_task_outputs_storage import KickoffTaskOutputsSQLiteStorage +from crewai.memory.storage.base_rag_storage import BaseRAGStorage +from crewai.utilities.paths import get_default_storage_path + +class MockRAGStorage(BaseRAGStorage): + """Mock implementation of BaseRAGStorage for testing.""" + def _sanitize_role(self, role: str) -> str: + return role.lower() + + def save(self, value, metadata): + pass + + def search(self, query, limit=3, filter=None, score_threshold=0.35): + return [] + + def reset(self): + pass + + def _generate_embedding(self, text, metadata=None): + return [] + + def _initialize_app(self): + pass + +def test_default_storage_paths(): + """Test that default storage paths are created correctly.""" + ltm_path = get_default_storage_path('ltm') + kickoff_path = get_default_storage_path('kickoff') + rag_path = get_default_storage_path('rag') + + assert str(ltm_path).endswith('latest_long_term_memories.db') + assert str(kickoff_path).endswith('latest_kickoff_task_outputs.db') + assert isinstance(rag_path, Path) + +def test_custom_storage_paths(): + """Test that custom storage paths are respected.""" + with tempfile.TemporaryDirectory() as temp_dir: + custom_path = Path(temp_dir) / 'custom.db' + + ltm = LTMSQLiteStorage(storage_path=custom_path) + assert ltm.storage_path == custom_path + + kickoff = KickoffTaskOutputsSQLiteStorage(storage_path=custom_path) + assert kickoff.storage_path == custom_path + + rag = MockRAGStorage('test', storage_path=custom_path) + assert rag.storage_path == custom_path + +def test_directory_creation(): + """Test that storage directories are created automatically.""" + with tempfile.TemporaryDirectory() as temp_dir: + test_dir = Path(temp_dir) / 'test_storage' + storage_path = test_dir / 'test.db' + + assert not test_dir.exists() + LTMSQLiteStorage(storage_path=storage_path) + assert test_dir.exists() + +def test_permission_error(): + """Test that permission errors are handled correctly.""" + with tempfile.TemporaryDirectory() as temp_dir: + test_dir = Path(temp_dir) / 'readonly' + test_dir.mkdir() + os.chmod(test_dir, 0o444) # Read-only + + storage_path = test_dir / 'test.db' + with pytest.raises((PermissionError, OSError)) as exc_info: + LTMSQLiteStorage(storage_path=storage_path) + # Verify that the error message mentions permission + assert "permission" in str(exc_info.value).lower() + +def test_invalid_path(): + """Test that invalid paths raise appropriate errors.""" + with pytest.raises(OSError): + # Try to create storage in a non-existent root directory + LTMSQLiteStorage(storage_path=Path('/nonexistent/dir/test.db')) diff --git a/uv.lock b/uv.lock index c37a1fa4e..506998441 100644 --- a/uv.lock +++ b/uv.lock @@ -1,10 +1,18 @@ version = 1 requires-python = ">=3.10, <3.13" resolution-markers = [ - "python_full_version < '3.11'", - "python_full_version == '3.11.*'", - "python_full_version >= '3.12' and python_full_version < '3.12.4'", - "python_full_version >= '3.12.4'", + "python_full_version < '3.11' and sys_platform == 'darwin'", + "python_full_version < '3.11' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "(python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux')", + "python_full_version == '3.11.*' and sys_platform == 'darwin'", + "python_full_version == '3.11.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.11.*' and sys_platform != 'darwin' and sys_platform != 'linux')", + "python_full_version >= '3.12' and python_full_version < '3.12.4' and sys_platform == 'darwin'", + "python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "(python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.12' and python_full_version < '3.12.4' and sys_platform != 'darwin' and sys_platform != 'linux')", + "python_full_version >= '3.12.4' and sys_platform == 'darwin'", + "python_full_version >= '3.12.4' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "(python_full_version >= '3.12.4' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.12.4' and sys_platform != 'darwin' and sys_platform != 'linux')", ] [[package]] @@ -300,7 +308,7 @@ name = "build" version = "1.2.2.post1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "os_name == 'nt'" }, + { name = "colorama", marker = "(os_name == 'nt' and platform_machine != 'aarch64' and sys_platform == 'linux') or (os_name == 'nt' and sys_platform != 'darwin' and sys_platform != 'linux')" }, { name = "importlib-metadata", marker = "python_full_version < '3.10.2'" }, { name = "packaging" }, { name = "pyproject-hooks" }, @@ -535,7 +543,7 @@ name = "click" version = "8.1.7" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "platform_system == 'Windows'" }, + { name = "colorama", marker = "sys_platform == 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/96/d3/f04c7bfcf5c1862a2a5b845c6b2b360488cf47af55dfa79c98f6a6bf98b5/click-8.1.7.tar.gz", hash = "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de", size = 336121 } wheels = [ @@ -642,7 +650,6 @@ tools = [ [package.dev-dependencies] dev = [ { name = "cairosvg" }, - { name = "crewai-tools" }, { name = "mkdocs" }, { name = "mkdocs-material" }, { name = "mkdocs-material-extensions" }, @@ -696,7 +703,6 @@ requires-dist = [ [package.metadata.requires-dev] dev = [ { name = "cairosvg", specifier = ">=2.7.1" }, - { name = "crewai-tools", specifier = ">=0.17.0" }, { name = "mkdocs", specifier = ">=1.4.3" }, { name = "mkdocs-material", specifier = ">=9.5.7" }, { name = "mkdocs-material-extensions", specifier = ">=1.3.1" }, @@ -2462,7 +2468,7 @@ version = "1.6.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "click" }, - { name = "colorama", marker = "platform_system == 'Windows'" }, + { name = "colorama", marker = "sys_platform == 'win32'" }, { name = "ghp-import" }, { name = "jinja2" }, { name = "markdown" }, @@ -2643,7 +2649,7 @@ version = "2.10.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "pygments" }, - { name = "pywin32", marker = "platform_system == 'Windows'" }, + { name = "pywin32", marker = "sys_platform == 'win32'" }, { name = "tqdm" }, ] sdist = { url = "https://files.pythonhosted.org/packages/3a/93/80ac75c20ce54c785648b4ed363c88f148bf22637e10c9863db4fbe73e74/mpire-2.10.2.tar.gz", hash = "sha256:f66a321e93fadff34585a4bfa05e95bd946cf714b442f51c529038eb45773d97", size = 271270 } @@ -2890,7 +2896,7 @@ name = "nvidia-cudnn-cu12" version = "9.1.0.70" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, + { name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/9f/fd/713452cd72343f682b1c7b9321e23829f00b842ceaedcda96e742ea0b0b3/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl", hash = "sha256:165764f44ef8c61fcdfdfdbe769d687e06374059fbb388b6c89ecb0e28793a6f", size = 664752741 }, @@ -2917,9 +2923,9 @@ name = "nvidia-cusolver-cu12" version = "11.4.5.107" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, - { name = "nvidia-cusparse-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, - { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, + { name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "nvidia-cusparse-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/bc/1d/8de1e5c67099015c834315e333911273a8c6aaba78923dd1d1e25fc5f217/nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl", hash = "sha256:8a7ec542f0412294b15072fa7dab71d31334014a69f953004ea7a118206fe0dd", size = 124161928 }, @@ -2930,7 +2936,7 @@ name = "nvidia-cusparse-cu12" version = "12.1.0.106" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, + { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/65/5b/cfaeebf25cd9fdec14338ccb16f6b2c4c7fa9163aefcf057d86b9cc248bb/nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:f3b50f42cf363f86ab21f720998517a659a48131e8d538dc02f8768237bd884c", size = 195958278 }, @@ -3480,7 +3486,7 @@ name = "portalocker" version = "2.10.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "pywin32", marker = "platform_system == 'Windows'" }, + { name = "pywin32", marker = "sys_platform == 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/ed/d3/c6c64067759e87af98cc668c1cc75171347d0f1577fab7ca3749134e3cd4/portalocker-2.10.1.tar.gz", hash = "sha256:ef1bf844e878ab08aee7e40184156e1151f228f103aa5c6bd0724cc330960f8f", size = 40891 } wheels = [ @@ -5022,19 +5028,19 @@ dependencies = [ { name = "fsspec" }, { name = "jinja2" }, { name = "networkx" }, - { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, + { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "sympy" }, - { name = "triton", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, + { name = "triton", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "typing-extensions" }, ] wheels = [ @@ -5081,7 +5087,7 @@ name = "tqdm" version = "4.66.5" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "platform_system == 'Windows'" }, + { name = "colorama", marker = "sys_platform == 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/58/83/6ba9844a41128c62e810fddddd72473201f3eacde02046066142a2d96cc5/tqdm-4.66.5.tar.gz", hash = "sha256:e1020aef2e5096702d8a025ac7d16b1577279c9d63f8375b63083e9a5f0fcbad", size = 169504 } wheels = [ @@ -5124,7 +5130,7 @@ version = "0.27.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "attrs" }, - { name = "cffi", marker = "implementation_name != 'pypy' and os_name == 'nt'" }, + { name = "cffi", marker = "(implementation_name != 'pypy' and os_name == 'nt' and platform_machine != 'aarch64' and sys_platform == 'linux') or (implementation_name != 'pypy' and os_name == 'nt' and sys_platform != 'darwin' and sys_platform != 'linux')" }, { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, { name = "idna" }, { name = "outcome" }, @@ -5155,7 +5161,7 @@ name = "triton" version = "3.0.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "filelock", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" }, + { name = "filelock", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/45/27/14cc3101409b9b4b9241d2ba7deaa93535a217a211c86c4cc7151fb12181/triton-3.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e1efef76935b2febc365bfadf74bcb65a6f959a9872e5bddf44cc9e0adce1e1a", size = 209376304 },