diff --git a/lib/crewai/src/crewai/memory/storage/lancedb_storage.py b/lib/crewai/src/crewai/memory/storage/lancedb_storage.py index e03332ebe..08ca68483 100644 --- a/lib/crewai/src/crewai/memory/storage/lancedb_storage.py +++ b/lib/crewai/src/crewai/memory/storage/lancedb_storage.py @@ -2,6 +2,7 @@ from __future__ import annotations +from contextlib import AbstractContextManager from datetime import datetime import json import logging @@ -12,9 +13,9 @@ import time from typing import Any, ClassVar import lancedb -import portalocker from crewai.memory.types import MemoryRecord, ScopeInfo +from crewai.utilities.lock_store import lock as store_lock _logger = logging.getLogger(__name__) @@ -101,7 +102,7 @@ class LanceDBStorage: self._compact_every = compact_every self._save_count = 0 - self._lockfile = str(self._path / ".lance_write.lock") + self._lock_name = f"lancedb:{self._path.resolve()}" resolved = str(self._path.resolve()) with LanceDBStorage._path_locks_guard: @@ -156,9 +157,9 @@ class LanceDBStorage: break return DEFAULT_VECTOR_DIM - def _file_lock(self) -> portalocker.Lock: - """Return a cross-process file lock for serialising writes.""" - return portalocker.Lock(self._lockfile, timeout=120) + def _file_lock(self) -> AbstractContextManager[None]: + """Return a cross-process lock for serialising writes.""" + return store_lock(self._lock_name) def _do_write(self, op: str, *args: Any, **kwargs: Any) -> Any: """Execute a single table write with retry on commit conflicts. @@ -625,7 +626,9 @@ class LanceDBStorage: return prefix = scope_prefix.rstrip("/") if prefix: - self._do_write("delete", f"scope >= '{prefix}' AND scope < '{prefix}/\uffff'") + self._do_write( + "delete", f"scope >= '{prefix}' AND scope < '{prefix}/\uffff'" + ) def optimize(self) -> None: """Compact the table synchronously and refresh the scope index. 
diff --git a/lib/crewai/src/crewai/rag/chromadb/factory.py b/lib/crewai/src/crewai/rag/chromadb/factory.py
index 6d1c4cd68..2a857e067 100644
--- a/lib/crewai/src/crewai/rag/chromadb/factory.py
+++ b/lib/crewai/src/crewai/rag/chromadb/factory.py
@@ -1,13 +1,12 @@
 """Factory functions for creating ChromaDB clients."""
 
-from hashlib import md5
 import os
 
 from chromadb import PersistentClient
-import portalocker
 
 from crewai.rag.chromadb.client import ChromaDBClient
 from crewai.rag.chromadb.config import ChromaDBConfig
+from crewai.utilities.lock_store import lock
 
 
 def create_client(config: ChromaDBConfig) -> ChromaDBClient:
@@ -25,10 +24,8 @@ def create_client(config: ChromaDBConfig) -> ChromaDBClient:
 
     persist_dir = config.settings.persist_directory
     os.makedirs(persist_dir, exist_ok=True)
-    lock_id = md5(persist_dir.encode(), usedforsecurity=False).hexdigest()
-    lockfile = os.path.join(persist_dir, f"chromadb-{lock_id}.lock")
 
-    with portalocker.Lock(lockfile, timeout=120):
+    with lock(f"chromadb:{os.path.realpath(persist_dir)}"):
         client = PersistentClient(
             path=persist_dir,
             settings=config.settings,
diff --git a/lib/crewai/src/crewai/utilities/lock_store.py b/lib/crewai/src/crewai/utilities/lock_store.py
new file mode 100644
index 000000000..91b3d742a
--- /dev/null
+++ b/lib/crewai/src/crewai/utilities/lock_store.py
@@ -0,0 +1,65 @@
+"""Centralised lock factory.
+
+If ``REDIS_URL`` is set, locks are distributed via ``portalocker.RedisLock``. Otherwise, falls
+back to the standard ``portalocker.Lock``.
+"""
+
+from __future__ import annotations
+
+from collections.abc import Iterator
+from contextlib import contextmanager
+from functools import lru_cache
+from hashlib import md5
+import os
+import tempfile
+from typing import TYPE_CHECKING, Final
+
+import portalocker
+
+
+if TYPE_CHECKING:
+    import redis
+
+
+_REDIS_URL: str | None = os.environ.get("REDIS_URL")
+
+_DEFAULT_TIMEOUT: Final[int] = 120
+
+
+@lru_cache(maxsize=1)
+def _redis_connection() -> redis.Redis:
+    """Return a cached Redis connection, creating one on first call."""
+    from redis import Redis
+
+    if _REDIS_URL is None:
+        raise ValueError("REDIS_URL environment variable is not set")
+    return Redis.from_url(_REDIS_URL)
+
+
+@contextmanager
+def lock(name: str, *, timeout: float = _DEFAULT_TIMEOUT) -> Iterator[None]:
+    """Acquire a named lock, yielding while it is held.
+
+    The name is hashed, so callers must pass the exact same string to contend
+    for the same lock (normalise filesystem paths before embedding them).
+
+    Args:
+        name: A human-readable lock name (e.g. ``"chromadb_init"``).
+            Automatically namespaced to avoid collisions.
+        timeout: Maximum seconds to wait for the lock before raising.
+    """
+    digest = md5(name.encode(), usedforsecurity=False).hexdigest()
+
+    if _REDIS_URL:
+        with portalocker.RedisLock(
+            channel=f"crewai:{digest}",
+            connection=_redis_connection(),
+            timeout=timeout,
+        ):
+            yield
+    else:
+        # ":" is not a valid file-name character on Windows, so the on-disk
+        # name uses "-" instead of the Redis-style "crewai:<digest>" channel.
+        lock_path = os.path.join(tempfile.gettempdir(), f"crewai-{digest}.lock")
+        with portalocker.Lock(lock_path, timeout=timeout):
+            yield
diff --git a/lib/crewai/tests/memory/test_concurrent_storage.py b/lib/crewai/tests/memory/test_concurrent_storage.py
index 9ac21d67d..86ef50771 100644
--- a/lib/crewai/tests/memory/test_concurrent_storage.py
+++ b/lib/crewai/tests/memory/test_concurrent_storage.py
@@ -119,11 +119,10 @@ def _sqlite_flow_worker(db_path: str, worker_id: int, n_writes: int, result_dir:
 
 def _chromadb_worker(persist_dir: str, worker_id: int, result_dir: str):
     try:
-        from hashlib import md5
-
         from chromadb import PersistentClient
         from chromadb.config import Settings
-        import portalocker
+
+        from crewai.utilities.lock_store import lock
 
         settings = Settings(
             persist_directory=persist_dir,
@@ -131,10 +130,7 @@ def _chromadb_worker(persist_dir: str, worker_id: int, result_dir: str):
             is_persistent=True,
         )
 
-        # Test the actual locking path directly (same as factory.py)
-        lock_id = md5(persist_dir.encode(), usedforsecurity=False).hexdigest()
-        lockfile = os.path.join(persist_dir, f"chromadb-{lock_id}.lock")
-        with portalocker.Lock(lockfile, timeout=120):
+        with lock(f"chromadb:{persist_dir}"):
             PersistentClient(path=persist_dir, settings=settings)
 
         _write_result(result_dir, worker_id, True)