fix: replace dual-lock with single cross-process lock in LanceDB storage
@@ -434,40 +434,36 @@ class EncodingFlow(Flow[EncodingState]):
             )
         )
 
-        # All storage mutations under one lock so no other pipeline can
-        # interleave and cause version conflicts. The lock is reentrant
-        # (RLock) so the individual storage methods re-acquire it safely.
         updated_records: dict[str, MemoryRecord] = {}
-        with self._storage.write_lock:
-            if dedup_deletes:
-                self._storage.delete(record_ids=list(dedup_deletes))
-                self.state.records_deleted += len(dedup_deletes)
+        if dedup_deletes:
+            self._storage.delete(record_ids=list(dedup_deletes))
+            self.state.records_deleted += len(dedup_deletes)
 
-            for rid, (_item_idx, new_content) in dedup_updates.items():
-                existing = all_similar.get(rid)
-                if existing is not None:
-                    new_emb = update_emb_map.get(rid, [])
-                    updated = MemoryRecord(
-                        id=existing.id,
-                        content=new_content,
-                        scope=existing.scope,
-                        categories=existing.categories,
-                        metadata=existing.metadata,
-                        importance=existing.importance,
-                        created_at=existing.created_at,
-                        last_accessed=now,
-                        embedding=new_emb if new_emb else existing.embedding,
-                    )
-                    self._storage.update(updated)
-                    self.state.records_updated += 1
-                    updated_records[rid] = updated
+        for rid, (_item_idx, new_content) in dedup_updates.items():
+            existing = all_similar.get(rid)
+            if existing is not None:
+                new_emb = update_emb_map.get(rid, [])
+                updated = MemoryRecord(
+                    id=existing.id,
+                    content=new_content,
+                    scope=existing.scope,
+                    categories=existing.categories,
+                    metadata=existing.metadata,
+                    importance=existing.importance,
+                    created_at=existing.created_at,
+                    last_accessed=now,
+                    embedding=new_emb if new_emb else existing.embedding,
+                )
+                self._storage.update(updated)
+                self.state.records_updated += 1
+                updated_records[rid] = updated
 
-            if to_insert:
-                records = [r for _, r in to_insert]
-                self._storage.save(records)
-                self.state.records_inserted += len(records)
-                for idx, record in to_insert:
-                    items[idx].result_record = record
+        if to_insert:
+            records = [r for _, r in to_insert]
+            self._storage.save(records)
+            self.state.records_inserted += len(records)
+            for idx, record in to_insert:
+                items[idx].result_record = record
 
         # Set result_record for non-insert items (after lock, using updated_records)
         for _i, item in enumerate(items):
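Note: this hunk drops the `with self._storage.write_lock:` block because serialisation now happens inside the storage methods themselves via `crewai.utilities.lock_store.lock`. That lock's implementation is not part of this diff; purely as a sketch (assuming a POSIX `fcntl`-based file lock, which may differ from the real `lock_store`), a named cross-process lock can look like this:

```python
# Hypothetical sketch of a named cross-process lock; the real
# crewai.utilities.lock_store.lock may be implemented differently.
# POSIX-only (fcntl).
import fcntl
import hashlib
import tempfile
from contextlib import contextmanager
from pathlib import Path


@contextmanager
def cross_process_lock(name: str):
    # One lock file per lock name, shared by every process on the host.
    digest = hashlib.sha256(name.encode()).hexdigest()[:16]
    lock_path = Path(tempfile.gettempdir()) / f"lancedb-{digest}.lock"
    with open(lock_path, "w") as fh:
        fcntl.flock(fh, fcntl.LOCK_EX)  # blocks until no other holder
        try:
            yield
        finally:
            fcntl.flock(fh, fcntl.LOCK_UN)


# Usage mirroring the old flow code: one lock held across a batch of writes.
# with cross_process_lock(f"lancedb:{db_path}"):
#     storage.delete(...); storage.update(...); storage.save(...)
```

A file-based lock like this is what lets two separate processes writing to the same LanceDB directory serialise their commits, which a process-local `threading.RLock` cannot do.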
@@ -2,7 +2,6 @@
 
 from __future__ import annotations
 
-from contextlib import AbstractContextManager
 import contextvars
 from datetime import datetime
 import json
@@ -11,9 +10,9 @@ import os
 from pathlib import Path
 import threading
 import time
-from typing import Any, ClassVar
+from typing import Any
 
-import lancedb
+import lancedb  # type: ignore[import-untyped]
 
 from crewai.memory.types import MemoryRecord, ScopeInfo
 from crewai.utilities.lock_store import lock as store_lock
@@ -42,15 +41,6 @@ _RETRY_BASE_DELAY = 0.2  # seconds; doubles on each retry
 class LanceDBStorage:
     """LanceDB-backed storage for the unified memory system."""
 
-    # Class-level registry: maps resolved database path -> shared write lock.
-    # When multiple Memory instances (e.g. agent + crew) independently create
-    # LanceDBStorage pointing at the same directory, they share one lock so
-    # their writes don't conflict.
-    # Uses RLock (reentrant) so callers can hold the lock for a batch of
-    # operations while the individual methods re-acquire it without deadlocking.
-    _path_locks: ClassVar[dict[str, threading.RLock]] = {}
-    _path_locks_guard: ClassVar[threading.Lock] = threading.Lock()
-
     def __init__(
         self,
         path: str | Path | None = None,
@@ -86,44 +76,19 @@ class LanceDBStorage:
         self._table_name = table_name
         self._db = lancedb.connect(str(self._path))
 
-        # On macOS and Linux the default per-process open-file limit is 256.
-        # A LanceDB table stores one file per fragment (one fragment per save()
-        # call by default). With hundreds of fragments, a single full-table
-        # scan opens all of them simultaneously, exhausting the limit.
-        # Raise it proactively so scans on large tables never hit OS error 24.
-        try:
-            import resource
-
-            soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
-            if soft < 4096:
-                resource.setrlimit(resource.RLIMIT_NOFILE, (min(hard, 4096), hard))
-        except Exception:  # noqa: S110
-            pass  # Windows or already at the max hard limit — safe to ignore
-
         self._compact_every = compact_every
         self._save_count = 0
 
         self._lock_name = f"lancedb:{self._path.resolve()}"
 
-        resolved = str(self._path.resolve())
-        with LanceDBStorage._path_locks_guard:
-            if resolved not in LanceDBStorage._path_locks:
-                LanceDBStorage._path_locks[resolved] = threading.RLock()
-        self._write_lock = LanceDBStorage._path_locks[resolved]
-
         # Try to open an existing table and infer dimension from its schema.
         # If no table exists yet, defer creation until the first save so the
         # dimension can be auto-detected from the embedder's actual output.
         try:
-            self._table: lancedb.table.Table | None = self._db.open_table(
-                self._table_name
-            )
+            self._table: Any = self._db.open_table(self._table_name)
             self._vector_dim: int = self._infer_dim_from_table(self._table)
             # Best-effort: create the scope index if it doesn't exist yet.
-            with self._file_lock():
+            with store_lock(self._lock_name):
                 self._ensure_scope_index()
             # Compact in the background if the table has accumulated many
             # fragments from previous runs (each save() creates one).
             self._compact_if_needed()
         except Exception:
             self._table = None
@@ -132,40 +97,25 @@ class LanceDBStorage:
         # Explicit dim provided: create the table immediately if it doesn't exist.
         if self._table is None and vector_dim is not None:
             self._vector_dim = vector_dim
-            with self._file_lock():
+            with store_lock(self._lock_name):
                 self._table = self._create_table(vector_dim)
 
-    @property
-    def write_lock(self) -> threading.RLock:
-        """The shared reentrant write lock for this database path.
-
-        Callers can acquire this to hold the lock across multiple storage
-        operations (e.g. delete + update + save as one atomic batch).
-        Individual methods also acquire it internally, but since it's
-        reentrant (RLock), the same thread won't deadlock.
-        """
-        return self._write_lock
-
     @staticmethod
-    def _infer_dim_from_table(table: lancedb.table.Table) -> int:
+    def _infer_dim_from_table(table: Any) -> int:
         """Read vector dimension from an existing table's schema."""
         schema = table.schema
         for field in schema:
             if field.name == "vector":
                 try:
-                    return field.type.list_size
+                    return int(field.type.list_size)
                 except Exception:
                     break
         return DEFAULT_VECTOR_DIM
 
-    def _file_lock(self) -> AbstractContextManager[None]:
-        """Return a cross-process lock for serialising writes."""
-        return store_lock(self._lock_name)
-
     def _do_write(self, op: str, *args: Any, **kwargs: Any) -> Any:
         """Execute a single table write with retry on commit conflicts.
 
-        Caller must already hold the cross-process file lock.
+        Caller must already hold ``store_lock(self._lock_name)``.
         """
         delay = _RETRY_BASE_DELAY
         for attempt in range(_MAX_RETRIES + 1):
@@ -189,10 +139,10 @@ class LanceDBStorage:
                 delay *= 2
         return None  # unreachable, but satisfies type checker
 
-    def _create_table(self, vector_dim: int) -> lancedb.table.Table:
+    def _create_table(self, vector_dim: int) -> Any:
         """Create a new table with the given vector dimension.
 
-        Caller must already hold the cross-process file lock.
+        Caller must already hold ``store_lock(self._lock_name)``.
         """
         placeholder = [
             {
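The `_do_write` hunks only show the retry scaffolding (`delay = _RETRY_BASE_DELAY`, the attempt loop, `delay *= 2`, and the unreachable `return None`). A minimal sketch of the shape that scaffolding implies, with the conflict detection and the `_MAX_RETRIES` value assumed rather than taken from this diff:

```python
# Sketch only: the actual conflict detection inside _do_write is not shown
# in this commit, so the "conflict" string check below is an assumption.
import time

_MAX_RETRIES = 3  # assumed value; the real constant is defined in the module
_RETRY_BASE_DELAY = 0.2  # seconds; doubles on each retry


def do_write_with_retry(table, op: str, *args, **kwargs):
    delay = _RETRY_BASE_DELAY
    for attempt in range(_MAX_RETRIES + 1):
        try:
            # e.g. table.add(rows) or table.delete(where_expr)
            return getattr(table, op)(*args, **kwargs)
        except Exception as exc:
            is_conflict = "conflict" in str(exc).lower()
            if not is_conflict or attempt == _MAX_RETRIES:
                raise
            time.sleep(delay)
            delay *= 2  # exponential backoff between attempts
    return None  # unreachable, but satisfies type checkers
```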
@@ -263,13 +213,13 @@ class LanceDBStorage:
         """Run ``table.optimize()`` in a background thread, absorbing errors."""
         try:
             if self._table is not None:
-                with self._file_lock():
+                with store_lock(self._lock_name):
                     self._table.optimize()
                     self._ensure_scope_index()
         except Exception:
             _logger.debug("LanceDB background compaction failed", exc_info=True)
 
-    def _ensure_table(self, vector_dim: int | None = None) -> lancedb.table.Table:
+    def _ensure_table(self, vector_dim: int | None = None) -> Any:
         """Return the table, creating it lazily if needed.
 
         Args:
@@ -335,12 +285,12 @@ class LanceDBStorage:
                 dim = len(r.embedding)
                 break
         is_new_table = self._table is None
-        with self._write_lock, self._file_lock():
+        with store_lock(self._lock_name):
             self._ensure_table(vector_dim=dim)
-            rows = [self._record_to_row(r) for r in records]
-            for r in rows:
-                if r["vector"] is None or len(r["vector"]) != self._vector_dim:
-                    r["vector"] = [0.0] * self._vector_dim
+            rows = [self._record_to_row(rec) for rec in records]
+            for row in rows:
+                if row["vector"] is None or len(row["vector"]) != self._vector_dim:
+                    row["vector"] = [0.0] * self._vector_dim
             self._do_write("add", rows)
             if is_new_table:
                 self._ensure_scope_index()
@@ -351,7 +301,7 @@ class LanceDBStorage:
 
     def update(self, record: MemoryRecord) -> None:
         """Update a record by ID. Preserves created_at, updates last_accessed."""
-        with self._write_lock, self._file_lock():
+        with store_lock(self._lock_name):
             self._ensure_table()
             safe_id = str(record.id).replace("'", "''")
             self._do_write("delete", f"id = '{safe_id}'")
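`update()` deletes the old row by ID before re-adding it, and the ID is spliced into a SQL-style filter string, so embedded single quotes are doubled first. A tiny illustration (the helper name is made up for the example):

```python
def sql_quote(value: str) -> str:
    # Double any embedded single quotes so the value is safe inside '...'
    return "'" + value.replace("'", "''") + "'"


record_id = "rec-o'brien"
where_expr = f"id = {sql_quote(record_id)}"
# -> "id = 'rec-o''brien'"
```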
@@ -372,7 +322,7 @@ class LanceDBStorage:
         """
         if not record_ids or self._table is None:
             return
-        with self._write_lock, self._file_lock():
+        with store_lock(self._lock_name):
             now = datetime.utcnow().isoformat()
             safe_ids = [str(rid).replace("'", "''") for rid in record_ids]
             ids_expr = ", ".join(f"'{rid}'" for rid in safe_ids)
@@ -386,7 +336,7 @@ class LanceDBStorage:
         """Return a single record by ID, or None if not found."""
         if self._table is None:
             return None
-        with self._write_lock:
+        with store_lock(self._lock_name):
             safe_id = str(record_id).replace("'", "''")
             rows = self._table.search().where(f"id = '{safe_id}'").limit(1).to_list()
             if not rows:
@@ -404,7 +354,7 @@ class LanceDBStorage:
     ) -> list[tuple[MemoryRecord, float]]:
         if self._table is None:
             return []
-        with self._write_lock:
+        with store_lock(self._lock_name):
             query = self._table.search(query_embedding)
             if scope_prefix is not None and scope_prefix.strip("/"):
                 prefix = scope_prefix.rstrip("/")
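Both the vector search and `_scan_rows` restrict results to a scope subtree with a `LIKE 'prefix%'` predicate on the `scope` column, skipping the filter for the root scope. Roughly (helper name and sample scope are illustrative only):

```python
def scope_filter(scope_prefix: str | None) -> str | None:
    # Root scope ("/", "" or None) means no filter at all.
    if scope_prefix is None or not scope_prefix.strip("/"):
        return None
    return f"scope LIKE '{scope_prefix.rstrip('/')}%'"


assert scope_filter("/crew/agent-1/") == "scope LIKE '/crew/agent-1%'"
assert scope_filter("/") is None
```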
@@ -440,12 +390,12 @@ class LanceDBStorage:
     ) -> int:
         if self._table is None:
             return 0
-        with self._write_lock, self._file_lock():
+        with store_lock(self._lock_name):
             if record_ids and not (categories or metadata_filter):
-                before = self._table.count_rows()
+                before = int(self._table.count_rows())
                 ids_expr = ", ".join(f"'{rid}'" for rid in record_ids)
                 self._do_write("delete", f"id IN ({ids_expr})")
-                return before - self._table.count_rows()
+                return before - int(self._table.count_rows())
             if categories or metadata_filter:
                 rows = self._scan_rows(scope_prefix)
                 to_delete: list[str] = []
@@ -464,10 +414,10 @@ class LanceDBStorage:
                         to_delete.append(record.id)
                 if not to_delete:
                     return 0
-                before = self._table.count_rows()
+                before = int(self._table.count_rows())
                 ids_expr = ", ".join(f"'{rid}'" for rid in to_delete)
                 self._do_write("delete", f"id IN ({ids_expr})")
-                return before - self._table.count_rows()
+                return before - int(self._table.count_rows())
             conditions = []
             if scope_prefix is not None and scope_prefix.strip("/"):
                 prefix = scope_prefix.rstrip("/")
@@ -477,13 +427,13 @@ class LanceDBStorage:
             if older_than is not None:
                 conditions.append(f"created_at < '{older_than.isoformat()}'")
             if not conditions:
-                before = self._table.count_rows()
+                before = int(self._table.count_rows())
                 self._do_write("delete", "id != ''")
-                return before - self._table.count_rows()
+                return before - int(self._table.count_rows())
             where_expr = " AND ".join(conditions)
-            before = self._table.count_rows()
+            before = int(self._table.count_rows())
             self._do_write("delete", where_expr)
-            return before - self._table.count_rows()
+            return before - int(self._table.count_rows())
 
     def _scan_rows(
         self,
@@ -496,6 +446,8 @@ class LanceDBStorage:
         Uses a full table scan (no vector query) so the limit is applied after
         the scope filter, not to ANN candidates before filtering.
 
+        Caller must hold ``store_lock(self._lock_name)``.
+
         Args:
             scope_prefix: Optional scope path prefix to filter by.
             limit: Maximum number of rows to return (applied after filtering).
@@ -505,13 +457,13 @@ class LanceDBStorage:
         """
         if self._table is None:
             return []
-        with self._write_lock:
-            q = self._table.search()
-            if scope_prefix is not None and scope_prefix.strip("/"):
-                q = q.where(f"scope LIKE '{scope_prefix.rstrip('/')}%'")
-            if columns is not None:
-                q = q.select(columns)
-            return q.limit(limit).to_list()
+        q = self._table.search()
+        if scope_prefix is not None and scope_prefix.strip("/"):
+            q = q.where(f"scope LIKE '{scope_prefix.rstrip('/')}%'")
+        if columns is not None:
+            q = q.select(columns)
+        result: list[dict[str, Any]] = q.limit(limit).to_list()
+        return result
 
     def list_records(
         self, scope_prefix: str | None = None, limit: int = 200, offset: int = 0
@@ -526,7 +478,8 @@ class LanceDBStorage:
         Returns:
             List of MemoryRecord, ordered by created_at descending.
         """
-        rows = self._scan_rows(scope_prefix, limit=limit + offset)
+        with store_lock(self._lock_name):
+            rows = self._scan_rows(scope_prefix, limit=limit + offset)
         records = [self._row_to_record(r) for r in rows]
         records.sort(key=lambda r: r.created_at, reverse=True)
         return records[offset : offset + limit]
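Since `list_records` sorts by `created_at` in Python after the scan, paging over-fetches: it asks `_scan_rows` for `limit + offset` rows, sorts newest-first, then slices. A toy version of that slicing logic, with plain dicts standing in for rows:

```python
# Toy illustration only; real rows come from _scan_rows and are converted
# to MemoryRecord before sorting.
rows = [{"id": i, "created_at": f"2025-01-{i:02d}"} for i in range(1, 8)]

offset, limit = 2, 3
candidates = rows[: limit + offset]           # what the limited scan returns
candidates.sort(key=lambda r: r["created_at"], reverse=True)
page = candidates[offset : offset + limit]    # newest-first page
```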
@@ -536,10 +489,11 @@ class LanceDBStorage:
         prefix = scope if scope != "/" else ""
         if prefix and not prefix.startswith("/"):
             prefix = "/" + prefix
-        rows = self._scan_rows(
-            prefix or None,
-            columns=["scope", "categories_str", "created_at"],
-        )
+        with store_lock(self._lock_name):
+            rows = self._scan_rows(
+                prefix or None,
+                columns=["scope", "categories_str", "created_at"],
+            )
         if not rows:
             return ScopeInfo(
                 path=scope or "/",
@@ -590,7 +544,8 @@ class LanceDBStorage:
     def list_scopes(self, parent: str = "/") -> list[str]:
         parent = parent.rstrip("/") or ""
         prefix = (parent + "/") if parent else "/"
-        rows = self._scan_rows(prefix if prefix != "/" else None, columns=["scope"])
+        with store_lock(self._lock_name):
+            rows = self._scan_rows(prefix if prefix != "/" else None, columns=["scope"])
         children: set[str] = set()
         for row in rows:
             sc = str(row.get("scope", ""))
@@ -602,7 +557,8 @@ class LanceDBStorage:
         return sorted(children)
 
     def list_categories(self, scope_prefix: str | None = None) -> dict[str, int]:
-        rows = self._scan_rows(scope_prefix, columns=["categories_str"])
+        with store_lock(self._lock_name):
+            rows = self._scan_rows(scope_prefix, columns=["categories_str"])
         counts: dict[str, int] = {}
         for row in rows:
             cat_str = row.get("categories_str") or "[]"
@@ -618,12 +574,13 @@ class LanceDBStorage:
         if self._table is None:
             return 0
         if scope_prefix is None or scope_prefix.strip("/") == "":
-            return self._table.count_rows()
+            with store_lock(self._lock_name):
+                return int(self._table.count_rows())
         info = self.get_scope_info(scope_prefix)
         return info.record_count
 
     def reset(self, scope_prefix: str | None = None) -> None:
-        with self._write_lock, self._file_lock():
+        with store_lock(self._lock_name):
             if scope_prefix is None or scope_prefix.strip("/") == "":
                 if self._table is not None:
                     self._db.drop_table(self._table_name)
@@ -649,7 +606,7 @@ class LanceDBStorage:
         """
         if self._table is None:
             return
-        with self._write_lock, self._file_lock():
+        with store_lock(self._lock_name):
             self._table.optimize()
             self._ensure_scope_index()