fix: replace dual-lock with single cross-process lock in LanceDB storage

This commit is contained in:
Greyson Lalonde
2026-03-13 01:29:41 -04:00
parent 5a4f6956b3
commit e303ca4243
2 changed files with 81 additions and 128 deletions

View File

@@ -434,40 +434,36 @@ class EncodingFlow(Flow[EncodingState]):
)
)
# All storage mutations under one lock so no other pipeline can
# interleave and cause version conflicts. The lock is reentrant
# (RLock) so the individual storage methods re-acquire it safely.
updated_records: dict[str, MemoryRecord] = {}
with self._storage.write_lock:
if dedup_deletes:
self._storage.delete(record_ids=list(dedup_deletes))
self.state.records_deleted += len(dedup_deletes)
if dedup_deletes:
self._storage.delete(record_ids=list(dedup_deletes))
self.state.records_deleted += len(dedup_deletes)
for rid, (_item_idx, new_content) in dedup_updates.items():
existing = all_similar.get(rid)
if existing is not None:
new_emb = update_emb_map.get(rid, [])
updated = MemoryRecord(
id=existing.id,
content=new_content,
scope=existing.scope,
categories=existing.categories,
metadata=existing.metadata,
importance=existing.importance,
created_at=existing.created_at,
last_accessed=now,
embedding=new_emb if new_emb else existing.embedding,
)
self._storage.update(updated)
self.state.records_updated += 1
updated_records[rid] = updated
for rid, (_item_idx, new_content) in dedup_updates.items():
existing = all_similar.get(rid)
if existing is not None:
new_emb = update_emb_map.get(rid, [])
updated = MemoryRecord(
id=existing.id,
content=new_content,
scope=existing.scope,
categories=existing.categories,
metadata=existing.metadata,
importance=existing.importance,
created_at=existing.created_at,
last_accessed=now,
embedding=new_emb if new_emb else existing.embedding,
)
self._storage.update(updated)
self.state.records_updated += 1
updated_records[rid] = updated
if to_insert:
records = [r for _, r in to_insert]
self._storage.save(records)
self.state.records_inserted += len(records)
for idx, record in to_insert:
items[idx].result_record = record
if to_insert:
records = [r for _, r in to_insert]
self._storage.save(records)
self.state.records_inserted += len(records)
for idx, record in to_insert:
items[idx].result_record = record
# Set result_record for non-insert items (after lock, using updated_records)
for _i, item in enumerate(items):

View File

@@ -2,7 +2,6 @@
from __future__ import annotations
from contextlib import AbstractContextManager
import contextvars
from datetime import datetime
import json
@@ -11,9 +10,9 @@ import os
from pathlib import Path
import threading
import time
from typing import Any, ClassVar
from typing import Any
import lancedb
import lancedb # type: ignore[import-untyped]
from crewai.memory.types import MemoryRecord, ScopeInfo
from crewai.utilities.lock_store import lock as store_lock
@@ -42,15 +41,6 @@ _RETRY_BASE_DELAY = 0.2 # seconds; doubles on each retry
class LanceDBStorage:
"""LanceDB-backed storage for the unified memory system."""
# Class-level registry: maps resolved database path -> shared write lock.
# When multiple Memory instances (e.g. agent + crew) independently create
# LanceDBStorage pointing at the same directory, they share one lock so
# their writes don't conflict.
# Uses RLock (reentrant) so callers can hold the lock for a batch of
# operations while the individual methods re-acquire it without deadlocking.
_path_locks: ClassVar[dict[str, threading.RLock]] = {}
_path_locks_guard: ClassVar[threading.Lock] = threading.Lock()
def __init__(
self,
path: str | Path | None = None,
@@ -86,44 +76,19 @@ class LanceDBStorage:
self._table_name = table_name
self._db = lancedb.connect(str(self._path))
# On macOS and Linux the default per-process open-file limit is 256.
# A LanceDB table stores one file per fragment (one fragment per save()
# call by default). With hundreds of fragments, a single full-table
# scan opens all of them simultaneously, exhausting the limit.
# Raise it proactively so scans on large tables never hit OS error 24.
try:
import resource
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
if soft < 4096:
resource.setrlimit(resource.RLIMIT_NOFILE, (min(hard, 4096), hard))
except Exception: # noqa: S110
pass # Windows or already at the max hard limit — safe to ignore
self._compact_every = compact_every
self._save_count = 0
self._lock_name = f"lancedb:{self._path.resolve()}"
resolved = str(self._path.resolve())
with LanceDBStorage._path_locks_guard:
if resolved not in LanceDBStorage._path_locks:
LanceDBStorage._path_locks[resolved] = threading.RLock()
self._write_lock = LanceDBStorage._path_locks[resolved]
# Try to open an existing table and infer dimension from its schema.
# If no table exists yet, defer creation until the first save so the
# dimension can be auto-detected from the embedder's actual output.
try:
self._table: lancedb.table.Table | None = self._db.open_table(
self._table_name
)
self._table: Any = self._db.open_table(self._table_name)
self._vector_dim: int = self._infer_dim_from_table(self._table)
# Best-effort: create the scope index if it doesn't exist yet.
with self._file_lock():
with store_lock(self._lock_name):
self._ensure_scope_index()
# Compact in the background if the table has accumulated many
# fragments from previous runs (each save() creates one).
self._compact_if_needed()
except Exception:
self._table = None
@@ -132,40 +97,25 @@ class LanceDBStorage:
# Explicit dim provided: create the table immediately if it doesn't exist.
if self._table is None and vector_dim is not None:
self._vector_dim = vector_dim
with self._file_lock():
with store_lock(self._lock_name):
self._table = self._create_table(vector_dim)
@property
def write_lock(self) -> threading.RLock:
"""The shared reentrant write lock for this database path.
Callers can acquire this to hold the lock across multiple storage
operations (e.g. delete + update + save as one atomic batch).
Individual methods also acquire it internally, but since it's
reentrant (RLock), the same thread won't deadlock.
"""
return self._write_lock
@staticmethod
def _infer_dim_from_table(table: lancedb.table.Table) -> int:
def _infer_dim_from_table(table: Any) -> int:
"""Read vector dimension from an existing table's schema."""
schema = table.schema
for field in schema:
if field.name == "vector":
try:
return field.type.list_size
return int(field.type.list_size)
except Exception:
break
return DEFAULT_VECTOR_DIM
def _file_lock(self) -> AbstractContextManager[None]:
"""Return a cross-process lock for serialising writes."""
return store_lock(self._lock_name)
def _do_write(self, op: str, *args: Any, **kwargs: Any) -> Any:
"""Execute a single table write with retry on commit conflicts.
Caller must already hold the cross-process file lock.
Caller must already hold ``store_lock(self._lock_name)``.
"""
delay = _RETRY_BASE_DELAY
for attempt in range(_MAX_RETRIES + 1):
@@ -189,10 +139,10 @@ class LanceDBStorage:
delay *= 2
return None # unreachable, but satisfies type checker
def _create_table(self, vector_dim: int) -> lancedb.table.Table:
def _create_table(self, vector_dim: int) -> Any:
"""Create a new table with the given vector dimension.
Caller must already hold the cross-process file lock.
Caller must already hold ``store_lock(self._lock_name)``.
"""
placeholder = [
{
@@ -263,13 +213,13 @@ class LanceDBStorage:
"""Run ``table.optimize()`` in a background thread, absorbing errors."""
try:
if self._table is not None:
with self._file_lock():
with store_lock(self._lock_name):
self._table.optimize()
self._ensure_scope_index()
except Exception:
_logger.debug("LanceDB background compaction failed", exc_info=True)
def _ensure_table(self, vector_dim: int | None = None) -> lancedb.table.Table:
def _ensure_table(self, vector_dim: int | None = None) -> Any:
"""Return the table, creating it lazily if needed.
Args:
@@ -335,12 +285,12 @@ class LanceDBStorage:
dim = len(r.embedding)
break
is_new_table = self._table is None
with self._write_lock, self._file_lock():
with store_lock(self._lock_name):
self._ensure_table(vector_dim=dim)
rows = [self._record_to_row(r) for r in records]
for r in rows:
if r["vector"] is None or len(r["vector"]) != self._vector_dim:
r["vector"] = [0.0] * self._vector_dim
rows = [self._record_to_row(rec) for rec in records]
for row in rows:
if row["vector"] is None or len(row["vector"]) != self._vector_dim:
row["vector"] = [0.0] * self._vector_dim
self._do_write("add", rows)
if is_new_table:
self._ensure_scope_index()
@@ -351,7 +301,7 @@ class LanceDBStorage:
def update(self, record: MemoryRecord) -> None:
"""Update a record by ID. Preserves created_at, updates last_accessed."""
with self._write_lock, self._file_lock():
with store_lock(self._lock_name):
self._ensure_table()
safe_id = str(record.id).replace("'", "''")
self._do_write("delete", f"id = '{safe_id}'")
@@ -372,7 +322,7 @@ class LanceDBStorage:
"""
if not record_ids or self._table is None:
return
with self._write_lock, self._file_lock():
with store_lock(self._lock_name):
now = datetime.utcnow().isoformat()
safe_ids = [str(rid).replace("'", "''") for rid in record_ids]
ids_expr = ", ".join(f"'{rid}'" for rid in safe_ids)
@@ -386,7 +336,7 @@ class LanceDBStorage:
"""Return a single record by ID, or None if not found."""
if self._table is None:
return None
with self._write_lock:
with store_lock(self._lock_name):
safe_id = str(record_id).replace("'", "''")
rows = self._table.search().where(f"id = '{safe_id}'").limit(1).to_list()
if not rows:
@@ -404,7 +354,7 @@ class LanceDBStorage:
) -> list[tuple[MemoryRecord, float]]:
if self._table is None:
return []
with self._write_lock:
with store_lock(self._lock_name):
query = self._table.search(query_embedding)
if scope_prefix is not None and scope_prefix.strip("/"):
prefix = scope_prefix.rstrip("/")
@@ -440,12 +390,12 @@ class LanceDBStorage:
) -> int:
if self._table is None:
return 0
with self._write_lock, self._file_lock():
with store_lock(self._lock_name):
if record_ids and not (categories or metadata_filter):
before = self._table.count_rows()
before = int(self._table.count_rows())
ids_expr = ", ".join(f"'{rid}'" for rid in record_ids)
self._do_write("delete", f"id IN ({ids_expr})")
return before - self._table.count_rows()
return before - int(self._table.count_rows())
if categories or metadata_filter:
rows = self._scan_rows(scope_prefix)
to_delete: list[str] = []
@@ -464,10 +414,10 @@ class LanceDBStorage:
to_delete.append(record.id)
if not to_delete:
return 0
before = self._table.count_rows()
before = int(self._table.count_rows())
ids_expr = ", ".join(f"'{rid}'" for rid in to_delete)
self._do_write("delete", f"id IN ({ids_expr})")
return before - self._table.count_rows()
return before - int(self._table.count_rows())
conditions = []
if scope_prefix is not None and scope_prefix.strip("/"):
prefix = scope_prefix.rstrip("/")
@@ -477,13 +427,13 @@ class LanceDBStorage:
if older_than is not None:
conditions.append(f"created_at < '{older_than.isoformat()}'")
if not conditions:
before = self._table.count_rows()
before = int(self._table.count_rows())
self._do_write("delete", "id != ''")
return before - self._table.count_rows()
return before - int(self._table.count_rows())
where_expr = " AND ".join(conditions)
before = self._table.count_rows()
before = int(self._table.count_rows())
self._do_write("delete", where_expr)
return before - self._table.count_rows()
return before - int(self._table.count_rows())
def _scan_rows(
self,
@@ -496,6 +446,8 @@ class LanceDBStorage:
Uses a full table scan (no vector query) so the limit is applied after
the scope filter, not to ANN candidates before filtering.
Caller must hold ``store_lock(self._lock_name)``.
Args:
scope_prefix: Optional scope path prefix to filter by.
limit: Maximum number of rows to return (applied after filtering).
@@ -505,13 +457,13 @@ class LanceDBStorage:
"""
if self._table is None:
return []
with self._write_lock:
q = self._table.search()
if scope_prefix is not None and scope_prefix.strip("/"):
q = q.where(f"scope LIKE '{scope_prefix.rstrip('/')}%'")
if columns is not None:
q = q.select(columns)
return q.limit(limit).to_list()
q = self._table.search()
if scope_prefix is not None and scope_prefix.strip("/"):
q = q.where(f"scope LIKE '{scope_prefix.rstrip('/')}%'")
if columns is not None:
q = q.select(columns)
result: list[dict[str, Any]] = q.limit(limit).to_list()
return result
def list_records(
self, scope_prefix: str | None = None, limit: int = 200, offset: int = 0
@@ -526,7 +478,8 @@ class LanceDBStorage:
Returns:
List of MemoryRecord, ordered by created_at descending.
"""
rows = self._scan_rows(scope_prefix, limit=limit + offset)
with store_lock(self._lock_name):
rows = self._scan_rows(scope_prefix, limit=limit + offset)
records = [self._row_to_record(r) for r in rows]
records.sort(key=lambda r: r.created_at, reverse=True)
return records[offset : offset + limit]
@@ -536,10 +489,11 @@ class LanceDBStorage:
prefix = scope if scope != "/" else ""
if prefix and not prefix.startswith("/"):
prefix = "/" + prefix
rows = self._scan_rows(
prefix or None,
columns=["scope", "categories_str", "created_at"],
)
with store_lock(self._lock_name):
rows = self._scan_rows(
prefix or None,
columns=["scope", "categories_str", "created_at"],
)
if not rows:
return ScopeInfo(
path=scope or "/",
@@ -590,7 +544,8 @@ class LanceDBStorage:
def list_scopes(self, parent: str = "/") -> list[str]:
parent = parent.rstrip("/") or ""
prefix = (parent + "/") if parent else "/"
rows = self._scan_rows(prefix if prefix != "/" else None, columns=["scope"])
with store_lock(self._lock_name):
rows = self._scan_rows(prefix if prefix != "/" else None, columns=["scope"])
children: set[str] = set()
for row in rows:
sc = str(row.get("scope", ""))
@@ -602,7 +557,8 @@ class LanceDBStorage:
return sorted(children)
def list_categories(self, scope_prefix: str | None = None) -> dict[str, int]:
rows = self._scan_rows(scope_prefix, columns=["categories_str"])
with store_lock(self._lock_name):
rows = self._scan_rows(scope_prefix, columns=["categories_str"])
counts: dict[str, int] = {}
for row in rows:
cat_str = row.get("categories_str") or "[]"
@@ -618,12 +574,13 @@ class LanceDBStorage:
if self._table is None:
return 0
if scope_prefix is None or scope_prefix.strip("/") == "":
return self._table.count_rows()
with store_lock(self._lock_name):
return int(self._table.count_rows())
info = self.get_scope_info(scope_prefix)
return info.record_count
def reset(self, scope_prefix: str | None = None) -> None:
with self._write_lock, self._file_lock():
with store_lock(self._lock_name):
if scope_prefix is None or scope_prefix.strip("/") == "":
if self._table is not None:
self._db.drop_table(self._table_name)
@@ -649,7 +606,7 @@ class LanceDBStorage:
"""
if self._table is None:
return
with self._write_lock, self._file_lock():
with store_lock(self._lock_name):
self._table.optimize()
self._ensure_scope_index()