Compare commits

..

4 Commits

5 changed files with 369 additions and 73 deletions

View File

@@ -72,7 +72,8 @@ class SQLiteFlowPersistence(FlowPersistence):
def init_db(self) -> None:
"""Create the necessary tables if they don't exist."""
with sqlite3.connect(self.db_path) as conn:
with sqlite3.connect(self.db_path, timeout=30) as conn:
conn.execute("PRAGMA journal_mode=WAL")
# Main state table
conn.execute(
"""
@@ -136,7 +137,7 @@ class SQLiteFlowPersistence(FlowPersistence):
f"state_data must be either a Pydantic BaseModel or dict, got {type(state_data)}"
)
with sqlite3.connect(self.db_path) as conn:
with sqlite3.connect(self.db_path, timeout=30) as conn:
conn.execute(
"""
INSERT INTO flow_states (
@@ -163,7 +164,7 @@ class SQLiteFlowPersistence(FlowPersistence):
Returns:
The most recent state as a dictionary, or None if no state exists
"""
with sqlite3.connect(self.db_path) as conn:
with sqlite3.connect(self.db_path, timeout=30) as conn:
cursor = conn.execute(
"""
SELECT state_json
@@ -213,7 +214,7 @@ class SQLiteFlowPersistence(FlowPersistence):
self.save_state(flow_uuid, context.method_name, state_data)
# Save pending feedback context
with sqlite3.connect(self.db_path) as conn:
with sqlite3.connect(self.db_path, timeout=30) as conn:
# Use INSERT OR REPLACE to handle re-triggering feedback on same flow
conn.execute(
"""
@@ -248,7 +249,7 @@ class SQLiteFlowPersistence(FlowPersistence):
# Import here to avoid circular imports
from crewai.flow.async_feedback.types import PendingFeedbackContext
with sqlite3.connect(self.db_path) as conn:
with sqlite3.connect(self.db_path, timeout=30) as conn:
cursor = conn.execute(
"""
SELECT state_json, context_json
@@ -272,7 +273,7 @@ class SQLiteFlowPersistence(FlowPersistence):
Args:
flow_uuid: Unique identifier for the flow instance
"""
with sqlite3.connect(self.db_path) as conn:
with sqlite3.connect(self.db_path, timeout=30) as conn:
conn.execute(
"""
DELETE FROM pending_feedback

View File

@@ -38,7 +38,8 @@ class KickoffTaskOutputsSQLiteStorage:
DatabaseOperationError: If database initialization fails due to SQLite errors.
"""
try:
with sqlite3.connect(self.db_path) as conn:
with sqlite3.connect(self.db_path, timeout=30) as conn:
conn.execute("PRAGMA journal_mode=WAL")
cursor = conn.cursor()
cursor.execute(
"""
@@ -82,7 +83,7 @@ class KickoffTaskOutputsSQLiteStorage:
"""
inputs = inputs or {}
try:
with sqlite3.connect(self.db_path) as conn:
with sqlite3.connect(self.db_path, timeout=30) as conn:
conn.execute("BEGIN TRANSACTION")
cursor = conn.cursor()
cursor.execute(
@@ -125,7 +126,7 @@ class KickoffTaskOutputsSQLiteStorage:
DatabaseOperationError: If updating the task output fails due to SQLite errors.
"""
try:
with sqlite3.connect(self.db_path) as conn:
with sqlite3.connect(self.db_path, timeout=30) as conn:
conn.execute("BEGIN TRANSACTION")
cursor = conn.cursor()
@@ -166,7 +167,7 @@ class KickoffTaskOutputsSQLiteStorage:
DatabaseOperationError: If loading task outputs fails due to SQLite errors.
"""
try:
with sqlite3.connect(self.db_path) as conn:
with sqlite3.connect(self.db_path, timeout=30) as conn:
cursor = conn.cursor()
cursor.execute("""
SELECT *
@@ -205,7 +206,7 @@ class KickoffTaskOutputsSQLiteStorage:
DatabaseOperationError: If deleting task outputs fails due to SQLite errors.
"""
try:
with sqlite3.connect(self.db_path) as conn:
with sqlite3.connect(self.db_path, timeout=30) as conn:
conn.execute("BEGIN TRANSACTION")
cursor = conn.cursor()
cursor.execute("DELETE FROM latest_kickoff_task_outputs")

View File

@@ -12,6 +12,7 @@ import time
from typing import Any, ClassVar
import lancedb
import portalocker
from crewai.memory.types import MemoryRecord, ScopeInfo
@@ -90,6 +91,7 @@ class LanceDBStorage:
# Raise it proactively so scans on large tables never hit OS error 24.
try:
import resource
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
if soft < 4096:
resource.setrlimit(resource.RLIMIT_NOFILE, (min(hard, 4096), hard))
@@ -99,7 +101,8 @@ class LanceDBStorage:
self._compact_every = compact_every
self._save_count = 0
# Get or create a shared write lock for this database path.
self._lockfile = str(self._path / ".lance_write.lock")
resolved = str(self._path.resolve())
with LanceDBStorage._path_locks_guard:
if resolved not in LanceDBStorage._path_locks:
@@ -110,10 +113,13 @@ class LanceDBStorage:
# If no table exists yet, defer creation until the first save so the
# dimension can be auto-detected from the embedder's actual output.
try:
self._table: lancedb.table.Table | None = self._db.open_table(self._table_name)
self._table: lancedb.table.Table | None = self._db.open_table(
self._table_name
)
self._vector_dim: int = self._infer_dim_from_table(self._table)
# Best-effort: create the scope index if it doesn't exist yet.
self._ensure_scope_index()
with self._file_lock():
self._ensure_scope_index()
# Compact in the background if the table has accumulated many
# fragments from previous runs (each save() creates one).
self._compact_if_needed()
@@ -124,7 +130,8 @@ class LanceDBStorage:
# Explicit dim provided: create the table immediately if it doesn't exist.
if self._table is None and vector_dim is not None:
self._vector_dim = vector_dim
self._table = self._create_table(vector_dim)
with self._file_lock():
self._table = self._create_table(vector_dim)
@property
def write_lock(self) -> threading.RLock:
@@ -149,18 +156,14 @@ class LanceDBStorage:
break
return DEFAULT_VECTOR_DIM
def _retry_write(self, op: str, *args: Any, **kwargs: Any) -> Any:
"""Execute a table operation with retry on LanceDB commit conflicts.
def _file_lock(self) -> portalocker.Lock:
"""Return a cross-process file lock for serialising writes."""
return portalocker.Lock(self._lockfile, timeout=120)
Args:
op: Method name on the table object (e.g. "add", "delete").
*args, **kwargs: Passed to the table method.
def _do_write(self, op: str, *args: Any, **kwargs: Any) -> Any:
"""Execute a single table write with retry on commit conflicts.
LanceDB uses optimistic concurrency: if two transactions overlap,
the second to commit fails with an ``OSError`` containing
"Commit conflict". This helper retries with exponential backoff,
refreshing the table reference before each retry so the retried
call uses the latest committed version (not a stale reference).
Caller must already hold the cross-process file lock.
"""
delay = _RETRY_BASE_DELAY
for attempt in range(_MAX_RETRIES + 1):
@@ -171,20 +174,24 @@ class LanceDBStorage:
raise
_logger.debug(
"LanceDB commit conflict on %s (attempt %d/%d), retrying in %.1fs",
op, attempt + 1, _MAX_RETRIES, delay,
op,
attempt + 1,
_MAX_RETRIES,
delay,
)
# Refresh table to pick up the latest version before retrying.
# The next getattr(self._table, op) will use the fresh table.
try:
self._table = self._db.open_table(self._table_name)
except Exception: # noqa: S110
pass # table refresh is best-effort
pass
time.sleep(delay)
delay *= 2
return None # unreachable, but satisfies type checker
def _create_table(self, vector_dim: int) -> lancedb.table.Table:
"""Create a new table with the given vector dimension."""
"""Create a new table with the given vector dimension.
Caller must already hold the cross-process file lock.
"""
placeholder = [
{
"id": "__schema_placeholder__",
@@ -200,8 +207,11 @@ class LanceDBStorage:
"vector": [0.0] * vector_dim,
}
]
table = self._db.create_table(self._table_name, placeholder)
table.delete("id = '__schema_placeholder__'")
try:
table = self._db.create_table(self._table_name, placeholder)
table.delete("id = '__schema_placeholder__'")
except ValueError:
table = self._db.open_table(self._table_name)
return table
def _ensure_scope_index(self) -> None:
@@ -248,9 +258,9 @@ class LanceDBStorage:
"""Run ``table.optimize()`` in a background thread, absorbing errors."""
try:
if self._table is not None:
self._table.optimize()
# Refresh the scope index so new fragments are covered.
self._ensure_scope_index()
with self._file_lock():
self._table.optimize()
self._ensure_scope_index()
except Exception:
_logger.debug("LanceDB background compaction failed", exc_info=True)
@@ -280,7 +290,9 @@ class LanceDBStorage:
"last_accessed": record.last_accessed.isoformat(),
"source": record.source or "",
"private": record.private,
"vector": record.embedding if record.embedding else [0.0] * self._vector_dim,
"vector": record.embedding
if record.embedding
else [0.0] * self._vector_dim,
}
def _row_to_record(self, row: dict[str, Any]) -> MemoryRecord:
@@ -296,7 +308,9 @@ class LanceDBStorage:
id=str(row["id"]),
content=str(row["content"]),
scope=str(row["scope"]),
categories=json.loads(row["categories_str"]) if row.get("categories_str") else [],
categories=json.loads(row["categories_str"])
if row.get("categories_str")
else [],
metadata=json.loads(row["metadata_str"]) if row.get("metadata_str") else {},
importance=float(row.get("importance", 0.5)),
created_at=_parse_dt(row.get("created_at")),
@@ -316,16 +330,15 @@ class LanceDBStorage:
dim = len(r.embedding)
break
is_new_table = self._table is None
with self._write_lock:
with self._write_lock, self._file_lock():
self._ensure_table(vector_dim=dim)
rows = [self._record_to_row(r) for r in records]
for r in rows:
if r["vector"] is None or len(r["vector"]) != self._vector_dim:
r["vector"] = [0.0] * self._vector_dim
self._retry_write("add", rows)
# Create the scope index on the first save so it covers the initial dataset.
if is_new_table:
self._ensure_scope_index()
self._do_write("add", rows)
if is_new_table:
self._ensure_scope_index()
# Auto-compact every N saves so fragment files don't pile up.
self._save_count += 1
if self._compact_every > 0 and self._save_count % self._compact_every == 0:
@@ -333,14 +346,14 @@ class LanceDBStorage:
def update(self, record: MemoryRecord) -> None:
"""Update a record by ID. Preserves created_at, updates last_accessed."""
with self._write_lock:
with self._write_lock, self._file_lock():
self._ensure_table()
safe_id = str(record.id).replace("'", "''")
self._retry_write("delete", f"id = '{safe_id}'")
self._do_write("delete", f"id = '{safe_id}'")
row = self._record_to_row(record)
if row["vector"] is None or len(row["vector"]) != self._vector_dim:
row["vector"] = [0.0] * self._vector_dim
self._retry_write("add", [row])
self._do_write("add", [row])
def touch_records(self, record_ids: list[str]) -> None:
"""Update last_accessed to now for the given record IDs.
@@ -354,11 +367,11 @@ class LanceDBStorage:
"""
if not record_ids or self._table is None:
return
with self._write_lock:
with self._write_lock, self._file_lock():
now = datetime.utcnow().isoformat()
safe_ids = [str(rid).replace("'", "''") for rid in record_ids]
ids_expr = ", ".join(f"'{rid}'" for rid in safe_ids)
self._retry_write(
self._do_write(
"update",
where=f"id IN ({ids_expr})",
values={"last_accessed": now},
@@ -390,13 +403,17 @@ class LanceDBStorage:
prefix = scope_prefix.rstrip("/")
like_val = prefix + "%"
query = query.where(f"scope LIKE '{like_val}'")
results = query.limit(limit * 3 if (categories or metadata_filter) else limit).to_list()
results = query.limit(
limit * 3 if (categories or metadata_filter) else limit
).to_list()
out: list[tuple[MemoryRecord, float]] = []
for row in results:
record = self._row_to_record(row)
if categories and not any(c in record.categories for c in categories):
continue
if metadata_filter and not all(record.metadata.get(k) == v for k, v in metadata_filter.items()):
if metadata_filter and not all(
record.metadata.get(k) == v for k, v in metadata_filter.items()
):
continue
distance = row.get("_distance", 0.0)
score = 1.0 / (1.0 + float(distance)) if distance is not None else 1.0
@@ -416,20 +433,24 @@ class LanceDBStorage:
) -> int:
if self._table is None:
return 0
with self._write_lock:
with self._write_lock, self._file_lock():
if record_ids and not (categories or metadata_filter):
before = self._table.count_rows()
ids_expr = ", ".join(f"'{rid}'" for rid in record_ids)
self._retry_write("delete", f"id IN ({ids_expr})")
self._do_write("delete", f"id IN ({ids_expr})")
return before - self._table.count_rows()
if categories or metadata_filter:
rows = self._scan_rows(scope_prefix)
to_delete: list[str] = []
for row in rows:
record = self._row_to_record(row)
if categories and not any(c in record.categories for c in categories):
if categories and not any(
c in record.categories for c in categories
):
continue
if metadata_filter and not all(record.metadata.get(k) == v for k, v in metadata_filter.items()):
if metadata_filter and not all(
record.metadata.get(k) == v for k, v in metadata_filter.items()
):
continue
if older_than and record.created_at >= older_than:
continue
@@ -438,7 +459,7 @@ class LanceDBStorage:
return 0
before = self._table.count_rows()
ids_expr = ", ".join(f"'{rid}'" for rid in to_delete)
self._retry_write("delete", f"id IN ({ids_expr})")
self._do_write("delete", f"id IN ({ids_expr})")
return before - self._table.count_rows()
conditions = []
if scope_prefix is not None and scope_prefix.strip("/"):
@@ -450,11 +471,11 @@ class LanceDBStorage:
conditions.append(f"created_at < '{older_than.isoformat()}'")
if not conditions:
before = self._table.count_rows()
self._retry_write("delete", "id != ''")
self._do_write("delete", "id != ''")
return before - self._table.count_rows()
where_expr = " AND ".join(conditions)
before = self._table.count_rows()
self._retry_write("delete", where_expr)
self._do_write("delete", where_expr)
return before - self._table.count_rows()
def _scan_rows(
@@ -528,7 +549,7 @@ class LanceDBStorage:
for row in rows:
sc = str(row.get("scope", ""))
if child_prefix and sc.startswith(child_prefix):
rest = sc[len(child_prefix):]
rest = sc[len(child_prefix) :]
first_component = rest.split("/", 1)[0]
if first_component:
children.add(child_prefix + first_component)
@@ -539,7 +560,11 @@ class LanceDBStorage:
pass
created = row.get("created_at")
if created:
dt = datetime.fromisoformat(str(created).replace("Z", "+00:00")) if isinstance(created, str) else created
dt = (
datetime.fromisoformat(str(created).replace("Z", "+00:00"))
if isinstance(created, str)
else created
)
if isinstance(dt, datetime):
if oldest is None or dt < oldest:
oldest = dt
@@ -562,7 +587,7 @@ class LanceDBStorage:
for row in rows:
sc = str(row.get("scope", ""))
if sc.startswith(prefix) and sc != (prefix.rstrip("/") or "/"):
rest = sc[len(prefix):]
rest = sc[len(prefix) :]
first_component = rest.split("/", 1)[0]
if first_component:
children.add(prefix + first_component)
@@ -590,17 +615,17 @@ class LanceDBStorage:
return info.record_count
def reset(self, scope_prefix: str | None = None) -> None:
if scope_prefix is None or scope_prefix.strip("/") == "":
if self._table is not None:
self._db.drop_table(self._table_name)
self._table = None
# Dimension is preserved; table will be recreated on next save.
return
if self._table is None:
return
prefix = scope_prefix.rstrip("/")
if prefix:
self._table.delete(f"scope >= '{prefix}' AND scope < '{prefix}/\uFFFF'")
with self._write_lock, self._file_lock():
if scope_prefix is None or scope_prefix.strip("/") == "":
if self._table is not None:
self._db.drop_table(self._table_name)
self._table = None
return
if self._table is None:
return
prefix = scope_prefix.rstrip("/")
if prefix:
self._do_write("delete", f"scope >= '{prefix}' AND scope < '{prefix}/\uffff'")
def optimize(self) -> None:
"""Compact the table synchronously and refresh the scope index.
@@ -614,8 +639,9 @@ class LanceDBStorage:
"""
if self._table is None:
return
self._table.optimize()
self._ensure_scope_index()
with self._write_lock, self._file_lock():
self._table.optimize()
self._ensure_scope_index()
async def asave(self, records: list[MemoryRecord]) -> None:
self.save(records)

View File

@@ -28,7 +28,7 @@ def create_client(config: ChromaDBConfig) -> ChromaDBClient:
lock_id = md5(persist_dir.encode(), usedforsecurity=False).hexdigest()
lockfile = os.path.join(persist_dir, f"chromadb-{lock_id}.lock")
with portalocker.Lock(lockfile):
with portalocker.Lock(lockfile, timeout=120):
client = PersistentClient(
path=persist_dir,
settings=config.settings,

View File

@@ -0,0 +1,268 @@
"""Stress tests for concurrent multi-process storage access.
Simulates the Airflow pattern: N worker processes each writing to the
same storage directory simultaneously. Verifies no LockException and
data integrity after all writes complete.
Uses temp files for IPC instead of multiprocessing.Manager (which uses
sockets blocked by pytest_recording).
"""
import json
import multiprocessing
import os
import sqlite3
import tempfile
from pathlib import Path
import pytest
# ---------------------------------------------------------------------------
# File-based IPC helpers (avoids Manager sockets)
# ---------------------------------------------------------------------------
def _write_result(result_dir: str, worker_id: int, success: bool, error: str = ""):
path = os.path.join(result_dir, f"worker-{worker_id}.json")
with open(path, "w") as f:
json.dump({"success": success, "error": error}, f)
def _collect_results(result_dir: str, n_workers: int):
errors = {}
successes = 0
for wid in range(n_workers):
path = os.path.join(result_dir, f"worker-{wid}.json")
if not os.path.exists(path):
errors[wid] = "Process produced no output (crashed or timed out)"
continue
with open(path) as f:
data = json.load(f)
if data["success"]:
successes += 1
else:
errors[wid] = data["error"]
return successes, errors
# ---------------------------------------------------------------------------
# Worker functions
# ---------------------------------------------------------------------------
def _lancedb_worker(path: str, worker_id: int, n_records: int, result_dir: str):
    """Write *n_records* memory records to a shared LanceDB table.

    Success/failure is reported through ``_write_result`` so the parent
    process can collect it without sockets.
    """
    try:
        from crewai.memory.storage.lancedb_storage import LanceDBStorage
        from crewai.memory.types import MemoryRecord

        store = LanceDBStorage(path=path, table_name="memories", vector_dim=8)
        batch = []
        for idx in range(n_records):
            batch.append(
                MemoryRecord(
                    id=f"worker-{worker_id}-record-{idx}",
                    content=f"content from worker {worker_id} record {idx}",
                    scope=f"/test/worker-{worker_id}",
                    categories=["test"],
                    metadata={"worker": worker_id},
                    importance=0.5,
                    embedding=[float(worker_id)] * 8,
                )
            )
        store.save(batch)
        _write_result(result_dir, worker_id, True)
    except Exception as exc:
        _write_result(result_dir, worker_id, False, f"{type(exc).__name__}: {exc}")
def _sqlite_kickoff_worker(db_path: str, worker_id: int, n_writes: int, result_dir: str):
    """Hammer the kickoff-task-outputs SQLite DB with *n_writes* inserts.

    Each iteration opens a fresh connection to mimic many independent
    processes contending for the same database file.
    """
    try:
        from crewai.memory.storage.kickoff_task_outputs_storage import (
            KickoffTaskOutputsSQLiteStorage,
        )

        # Instantiating the storage creates the table if it doesn't exist.
        KickoffTaskOutputsSQLiteStorage(db_path=db_path)
        for i in range(n_writes):
            # sqlite3's connection context manager commits/rolls back but
            # does NOT close the connection; close explicitly so per-iteration
            # connections don't leak file descriptors until GC.
            conn = sqlite3.connect(db_path, timeout=30)
            try:
                with conn:
                    conn.execute("PRAGMA journal_mode=WAL")
                    conn.execute(
                        """INSERT OR REPLACE INTO latest_kickoff_task_outputs
                        (task_id, expected_output, output, task_index, inputs, was_replayed)
                        VALUES (?, ?, ?, ?, ?, ?)""",
                        (
                            f"worker-{worker_id}-task-{i}",
                            "expected output",
                            '{"result": "ok"}',
                            worker_id * 1000 + i,
                            "{}",
                            False,
                        ),
                    )
            finally:
                conn.close()
        _write_result(result_dir, worker_id, True)
    except Exception as e:
        _write_result(result_dir, worker_id, False, f"{type(e).__name__}: {e}")
def _sqlite_flow_worker(db_path: str, worker_id: int, n_writes: int, result_dir: str):
    """Persist *n_writes* flow states through SQLiteFlowPersistence."""
    try:
        from crewai.flow.persistence.sqlite import SQLiteFlowPersistence

        store = SQLiteFlowPersistence(db_path=db_path)
        for iteration in range(n_writes):
            store.save_state(
                flow_uuid=f"flow-{worker_id}-{iteration}",
                method_name="test_method",
                state_data={"worker": worker_id, "iteration": iteration},
            )
        _write_result(result_dir, worker_id, True)
    except Exception as exc:
        _write_result(result_dir, worker_id, False, f"{type(exc).__name__}: {exc}")
def _chromadb_worker(persist_dir: str, worker_id: int, result_dir: str):
    """Create a ChromaDB PersistentClient under the factory's file lock."""
    try:
        from hashlib import md5

        import portalocker
        from chromadb import PersistentClient
        from chromadb.config import Settings

        settings = Settings(
            persist_directory=persist_dir,
            anonymized_telemetry=False,
            is_persistent=True,
        )
        # Exercise the exact locking path factory.py uses.
        digest = md5(persist_dir.encode(), usedforsecurity=False).hexdigest()
        lockfile = os.path.join(persist_dir, f"chromadb-{digest}.lock")
        with portalocker.Lock(lockfile, timeout=120):
            PersistentClient(path=persist_dir, settings=settings)
        _write_result(result_dir, worker_id, True)
    except Exception as exc:
        _write_result(result_dir, worker_id, False, f"{type(exc).__name__}: {exc}")
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
# Number of concurrent worker processes each stress test spawns.
N_WORKERS = 6
# Records/rows each worker writes per run.
N_RECORDS = 20
def _run_workers(target, args_fn, n_workers=N_WORKERS, timeout=120):
    """Spawn *n_workers* processes and collect results via temp files.

    Args:
        target: Worker function to run in each child process.
        args_fn: Callable ``(worker_id, result_dir) -> tuple`` building the
            argument tuple for each worker.
        n_workers: Number of processes to spawn.
        timeout: Per-process join timeout in seconds.

    Returns:
        ``(successes, errors)`` as produced by ``_collect_results``.
    """
    with tempfile.TemporaryDirectory() as result_dir:
        procs = [
            multiprocessing.Process(target=target, args=args_fn(wid, result_dir))
            for wid in range(n_workers)
        ]
        for p in procs:
            p.start()
        for p in procs:
            p.join(timeout=timeout)
        # A worker that blew past the timeout is still running; terminate and
        # reap it so it can't outlive the test or keep writing to the temp
        # dir during cleanup. It shows up as "no output" in the results.
        for p in procs:
            if p.is_alive():
                p.terminate()
                p.join()
        successes, errors = _collect_results(result_dir, n_workers)
        return successes, errors
class TestConcurrentLanceDB:
    """Concurrent multi-process writes to LanceDB."""

    def test_concurrent_saves_no_lock_exception(self, tmp_path):
        storage_dir = str(tmp_path / "lancedb_concurrent")
        successes, errors = _run_workers(
            _lancedb_worker,
            lambda wid, rd: (storage_dir, wid, N_RECORDS, rd),
        )
        assert not errors, f"Workers failed: {errors}"
        assert successes == N_WORKERS

    def test_data_integrity_after_concurrent_saves(self, tmp_path):
        storage_dir = str(tmp_path / "lancedb_integrity")
        successes, errors = _run_workers(
            _lancedb_worker,
            lambda wid, rd: (storage_dir, wid, N_RECORDS, rd),
        )
        assert not errors, f"Workers failed: {errors}"

        from crewai.memory.storage.lancedb_storage import LanceDBStorage

        reader = LanceDBStorage(path=storage_dir, table_name="memories", vector_dim=8)
        total = reader.count()
        expected = N_WORKERS * N_RECORDS
        assert total == expected, f"Expected {expected} records, got {total}"
class TestConcurrentSQLiteKickoff:
    """Concurrent multi-process writes to kickoff task outputs SQLite."""

    def test_concurrent_writes_no_error(self, tmp_path):
        database = str(tmp_path / "kickoff.db")
        from crewai.memory.storage.kickoff_task_outputs_storage import (
            KickoffTaskOutputsSQLiteStorage,
        )

        # Create the schema up front so workers only contend on writes.
        KickoffTaskOutputsSQLiteStorage(db_path=database)
        successes, errors = _run_workers(
            _sqlite_kickoff_worker,
            lambda wid, rd: (database, wid, N_RECORDS, rd),
            timeout=60,
        )
        assert not errors, f"Workers failed: {errors}"
        assert successes == N_WORKERS

        with sqlite3.connect(database, timeout=30) as conn:
            count = conn.execute(
                "SELECT COUNT(*) FROM latest_kickoff_task_outputs"
            ).fetchone()[0]
        expected = N_WORKERS * N_RECORDS
        assert count == expected, f"Expected {expected} rows, got {count}"
class TestConcurrentSQLiteFlow:
    """Concurrent multi-process writes to flow persistence SQLite."""

    def test_concurrent_writes_no_error(self, tmp_path):
        database = str(tmp_path / "flow_states.db")
        successes, errors = _run_workers(
            _sqlite_flow_worker,
            lambda wid, rd: (database, wid, N_RECORDS, rd),
            timeout=60,
        )
        assert not errors, f"Workers failed: {errors}"
        assert successes == N_WORKERS

        with sqlite3.connect(database, timeout=30) as conn:
            count = conn.execute("SELECT COUNT(*) FROM flow_states").fetchone()[0]
        expected = N_WORKERS * N_RECORDS
        assert count == expected, f"Expected {expected} rows, got {count}"
class TestConcurrentChromaDB:
    """Concurrent multi-process ChromaDB client creation."""

    def test_concurrent_client_creation_no_lock_exception(self, tmp_path):
        chroma_dir = str(tmp_path / "chromadb_concurrent")
        os.makedirs(chroma_dir, exist_ok=True)
        successes, errors = _run_workers(
            _chromadb_worker,
            lambda wid, rd: (chroma_dir, wid, rd),
        )
        assert not errors, f"Workers failed: {errors}"
        assert successes == N_WORKERS