fix: handle concurrent table creation race and increase SQLite connect timeout

Greyson LaLonde
2026-03-10 23:07:19 -04:00
parent a037ade1ca
commit 67bc64e82c
4 changed files with 284 additions and 13 deletions
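The SQLite half of the fix is the same one-line change in every storage module: Python's sqlite3.connect() waits at most timeout seconds (default 5.0) for a competing writer to release the database lock before raising sqlite3.OperationalError ("database is locked"). With several worker processes writing to the same file, five seconds is easily exceeded, so every connect call now passes timeout=30. A minimal standalone sketch of the resulting pattern (the file name and schema here are illustrative, not the real tables):

import sqlite3

# Wait up to 30 seconds for a competing writer to release the lock,
# instead of failing after the default 5 seconds with "database is locked".
with sqlite3.connect("flow_states.db", timeout=30) as conn:
    # WAL mode lets readers proceed while a single writer holds the lock.
    conn.execute("PRAGMA journal_mode=WAL")
    # Illustrative schema; the real tables are created in init_db() below.
    conn.execute("CREATE TABLE IF NOT EXISTS flow_states (id INTEGER PRIMARY KEY)")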

View File: crewai/flow/persistence/sqlite.py

@@ -72,7 +72,7 @@ class SQLiteFlowPersistence(FlowPersistence):
def init_db(self) -> None:
"""Create the necessary tables if they don't exist."""
- with sqlite3.connect(self.db_path) as conn:
+ with sqlite3.connect(self.db_path, timeout=30) as conn:
conn.execute("PRAGMA journal_mode=WAL")
# Main state table
conn.execute(
@@ -137,7 +137,7 @@ class SQLiteFlowPersistence(FlowPersistence):
f"state_data must be either a Pydantic BaseModel or dict, got {type(state_data)}"
)
- with sqlite3.connect(self.db_path) as conn:
+ with sqlite3.connect(self.db_path, timeout=30) as conn:
conn.execute(
"""
INSERT INTO flow_states (
@@ -164,7 +164,7 @@ class SQLiteFlowPersistence(FlowPersistence):
Returns:
The most recent state as a dictionary, or None if no state exists
"""
- with sqlite3.connect(self.db_path) as conn:
+ with sqlite3.connect(self.db_path, timeout=30) as conn:
cursor = conn.execute(
"""
SELECT state_json
@@ -214,7 +214,7 @@ class SQLiteFlowPersistence(FlowPersistence):
self.save_state(flow_uuid, context.method_name, state_data)
# Save pending feedback context
- with sqlite3.connect(self.db_path) as conn:
+ with sqlite3.connect(self.db_path, timeout=30) as conn:
# Use INSERT OR REPLACE to handle re-triggering feedback on same flow
conn.execute(
"""
@@ -249,7 +249,7 @@ class SQLiteFlowPersistence(FlowPersistence):
# Import here to avoid circular imports
from crewai.flow.async_feedback.types import PendingFeedbackContext
- with sqlite3.connect(self.db_path) as conn:
+ with sqlite3.connect(self.db_path, timeout=30) as conn:
cursor = conn.execute(
"""
SELECT state_json, context_json
@@ -273,7 +273,7 @@ class SQLiteFlowPersistence(FlowPersistence):
Args:
flow_uuid: Unique identifier for the flow instance
"""
- with sqlite3.connect(self.db_path) as conn:
+ with sqlite3.connect(self.db_path, timeout=30) as conn:
conn.execute(
"""
DELETE FROM pending_feedback
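
The pending_feedback write above leans on SQLite's INSERT OR REPLACE upsert: if a row with the same primary key already exists (feedback re-triggered for the same flow), SQLite deletes the old row and inserts the new one in a single statement instead of raising an IntegrityError. A minimal sketch with an illustrative two-column schema, not the real table layout:

import sqlite3

with sqlite3.connect(":memory:") as conn:
    conn.execute(
        "CREATE TABLE pending_feedback (flow_uuid TEXT PRIMARY KEY, context_json TEXT)"
    )
    conn.execute(
        "INSERT OR REPLACE INTO pending_feedback VALUES (?, ?)", ("flow-1", "{}")
    )
    # Re-triggering feedback for the same flow replaces the row in place.
    conn.execute(
        "INSERT OR REPLACE INTO pending_feedback VALUES (?, ?)", ("flow-1", '{"v": 2}')
    )
    assert conn.execute("SELECT COUNT(*) FROM pending_feedback").fetchone()[0] == 1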

View File: crewai/memory/storage/kickoff_task_outputs_storage.py

@@ -38,7 +38,7 @@ class KickoffTaskOutputsSQLiteStorage:
DatabaseOperationError: If database initialization fails due to SQLite errors.
"""
try:
- with sqlite3.connect(self.db_path) as conn:
+ with sqlite3.connect(self.db_path, timeout=30) as conn:
conn.execute("PRAGMA journal_mode=WAL")
cursor = conn.cursor()
cursor.execute(
@@ -83,7 +83,7 @@ class KickoffTaskOutputsSQLiteStorage:
"""
inputs = inputs or {}
try:
- with sqlite3.connect(self.db_path) as conn:
+ with sqlite3.connect(self.db_path, timeout=30) as conn:
conn.execute("BEGIN TRANSACTION")
cursor = conn.cursor()
cursor.execute(
@@ -126,7 +126,7 @@ class KickoffTaskOutputsSQLiteStorage:
DatabaseOperationError: If updating the task output fails due to SQLite errors.
"""
try:
- with sqlite3.connect(self.db_path) as conn:
+ with sqlite3.connect(self.db_path, timeout=30) as conn:
conn.execute("BEGIN TRANSACTION")
cursor = conn.cursor()
@@ -167,7 +167,7 @@ class KickoffTaskOutputsSQLiteStorage:
DatabaseOperationError: If loading task outputs fails due to SQLite errors.
"""
try:
- with sqlite3.connect(self.db_path) as conn:
+ with sqlite3.connect(self.db_path, timeout=30) as conn:
cursor = conn.cursor()
cursor.execute("""
SELECT *
@@ -206,7 +206,7 @@ class KickoffTaskOutputsSQLiteStorage:
DatabaseOperationError: If deleting task outputs fails due to SQLite errors.
"""
try:
- with sqlite3.connect(self.db_path) as conn:
+ with sqlite3.connect(self.db_path, timeout=30) as conn:
conn.execute("BEGIN TRANSACTION")
cursor = conn.cursor()
cursor.execute("DELETE FROM latest_kickoff_task_outputs")

View File: crewai/memory/storage/lancedb_storage.py

@@ -207,8 +207,11 @@ class LanceDBStorage:
"vector": [0.0] * vector_dim,
}
]
- table = self._db.create_table(self._table_name, placeholder)
- table.delete("id = '__schema_placeholder__'")
+ try:
+     table = self._db.create_table(self._table_name, placeholder)
+     table.delete("id = '__schema_placeholder__'")
+ except ValueError:
+     table = self._db.open_table(self._table_name)
return table
def _ensure_scope_index(self) -> None:
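
The except ValueError branch is the race handler: when two processes call create_table for the same table at once, the loser gets the ValueError that lancedb raises for an already-existing table and falls back to opening it. A standalone sketch of the create-or-open pattern (connection path, table name, and vector size are illustrative):

import lancedb

db = lancedb.connect("/tmp/lancedb_demo")
placeholder = [{"id": "__schema_placeholder__", "vector": [0.0] * 8}]
try:
    # First process to get here creates the table; the placeholder row
    # exists only so the schema can be inferred, and is deleted right away.
    table = db.create_table("memories", placeholder)
    table.delete("id = '__schema_placeholder__'")
except ValueError:
    # A concurrent process won the race; open the table it created.
    table = db.open_table("memories")

The new stress tests below exercise exactly this path with six worker processes writing to one LanceDB directory.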

View File: new test module (concurrent storage stress tests)

@@ -0,0 +1,268 @@
"""Stress tests for concurrent multi-process storage access.
Simulates the Airflow pattern: N worker processes each writing to the
same storage directory simultaneously. Verifies no LockException and
data integrity after all writes complete.
Uses temp files for IPC instead of multiprocessing.Manager (which uses
sockets blocked by pytest_recording).
"""
import json
import multiprocessing
import os
import sqlite3
import tempfile
from pathlib import Path
import pytest
# ---------------------------------------------------------------------------
# File-based IPC helpers (avoids Manager sockets)
# ---------------------------------------------------------------------------
def _write_result(result_dir: str, worker_id: int, success: bool, error: str = ""):
path = os.path.join(result_dir, f"worker-{worker_id}.json")
with open(path, "w") as f:
json.dump({"success": success, "error": error}, f)
def _collect_results(result_dir: str, n_workers: int):
errors = {}
successes = 0
for wid in range(n_workers):
path = os.path.join(result_dir, f"worker-{wid}.json")
if not os.path.exists(path):
errors[wid] = "Process produced no output (crashed or timed out)"
continue
with open(path) as f:
data = json.load(f)
if data["success"]:
successes += 1
else:
errors[wid] = data["error"]
return successes, errors
# ---------------------------------------------------------------------------
# Worker functions
# ---------------------------------------------------------------------------
def _lancedb_worker(path: str, worker_id: int, n_records: int, result_dir: str):
try:
from crewai.memory.storage.lancedb_storage import LanceDBStorage
from crewai.memory.types import MemoryRecord
storage = LanceDBStorage(path=path, table_name="memories", vector_dim=8)
records = [
MemoryRecord(
id=f"worker-{worker_id}-record-{i}",
content=f"content from worker {worker_id} record {i}",
scope=f"/test/worker-{worker_id}",
categories=["test"],
metadata={"worker": worker_id},
importance=0.5,
embedding=[float(worker_id)] * 8,
)
for i in range(n_records)
]
storage.save(records)
_write_result(result_dir, worker_id, True)
except Exception as e:
_write_result(result_dir, worker_id, False, f"{type(e).__name__}: {e}")
def _sqlite_kickoff_worker(db_path: str, worker_id: int, n_writes: int, result_dir: str):
try:
from crewai.memory.storage.kickoff_task_outputs_storage import (
KickoffTaskOutputsSQLiteStorage,
)
KickoffTaskOutputsSQLiteStorage(db_path=db_path)
for i in range(n_writes):
with sqlite3.connect(db_path, timeout=30) as conn:
conn.execute("PRAGMA journal_mode=WAL")
conn.execute(
"""INSERT OR REPLACE INTO latest_kickoff_task_outputs
(task_id, expected_output, output, task_index, inputs, was_replayed)
VALUES (?, ?, ?, ?, ?, ?)""",
(
f"worker-{worker_id}-task-{i}",
"expected output",
'{"result": "ok"}',
worker_id * 1000 + i,
"{}",
False,
),
)
_write_result(result_dir, worker_id, True)
except Exception as e:
_write_result(result_dir, worker_id, False, f"{type(e).__name__}: {e}")
def _sqlite_flow_worker(db_path: str, worker_id: int, n_writes: int, result_dir: str):
try:
from crewai.flow.persistence.sqlite import SQLiteFlowPersistence
persistence = SQLiteFlowPersistence(db_path=db_path)
for i in range(n_writes):
persistence.save_state(
flow_uuid=f"flow-{worker_id}-{i}",
method_name="test_method",
state_data={"worker": worker_id, "iteration": i},
)
_write_result(result_dir, worker_id, True)
except Exception as e:
_write_result(result_dir, worker_id, False, f"{type(e).__name__}: {e}")
def _chromadb_worker(persist_dir: str, worker_id: int, result_dir: str):
try:
from hashlib import md5
from chromadb import PersistentClient
from chromadb.config import Settings
import portalocker
settings = Settings(
persist_directory=persist_dir,
anonymized_telemetry=False,
is_persistent=True,
)
# Test the actual locking path directly (same as factory.py)
lock_id = md5(persist_dir.encode(), usedforsecurity=False).hexdigest()
lockfile = os.path.join(persist_dir, f"chromadb-{lock_id}.lock")
with portalocker.Lock(lockfile, timeout=120):
PersistentClient(path=persist_dir, settings=settings)
_write_result(result_dir, worker_id, True)
except Exception as e:
_write_result(result_dir, worker_id, False, f"{type(e).__name__}: {e}")
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
N_WORKERS = 6
N_RECORDS = 20
def _run_workers(target, args_fn, n_workers=N_WORKERS, timeout=120):
"""Spawn n_workers processes and collect results via temp files."""
with tempfile.TemporaryDirectory() as result_dir:
procs = []
for wid in range(n_workers):
p = multiprocessing.Process(
target=target,
args=args_fn(wid, result_dir),
)
procs.append(p)
for p in procs:
p.start()
for p in procs:
p.join(timeout=timeout)
successes, errors = _collect_results(result_dir, n_workers)
return successes, errors
class TestConcurrentLanceDB:
"""Concurrent multi-process writes to LanceDB."""
def test_concurrent_saves_no_lock_exception(self, tmp_path):
db_path = str(tmp_path / "lancedb_concurrent")
successes, errors = _run_workers(
_lancedb_worker,
lambda wid, rd: (db_path, wid, N_RECORDS, rd),
)
assert not errors, f"Workers failed: {errors}"
assert successes == N_WORKERS
def test_data_integrity_after_concurrent_saves(self, tmp_path):
db_path = str(tmp_path / "lancedb_integrity")
successes, errors = _run_workers(
_lancedb_worker,
lambda wid, rd: (db_path, wid, N_RECORDS, rd),
)
assert not errors, f"Workers failed: {errors}"
from crewai.memory.storage.lancedb_storage import LanceDBStorage
storage = LanceDBStorage(path=db_path, table_name="memories", vector_dim=8)
total = storage.count()
expected = N_WORKERS * N_RECORDS
assert total == expected, f"Expected {expected} records, got {total}"
class TestConcurrentSQLiteKickoff:
"""Concurrent multi-process writes to kickoff task outputs SQLite."""
def test_concurrent_writes_no_error(self, tmp_path):
db_path = str(tmp_path / "kickoff.db")
from crewai.memory.storage.kickoff_task_outputs_storage import (
KickoffTaskOutputsSQLiteStorage,
)
KickoffTaskOutputsSQLiteStorage(db_path=db_path)
successes, errors = _run_workers(
_sqlite_kickoff_worker,
lambda wid, rd: (db_path, wid, N_RECORDS, rd),
timeout=60,
)
assert not errors, f"Workers failed: {errors}"
assert successes == N_WORKERS
with sqlite3.connect(db_path, timeout=30) as conn:
count = conn.execute(
"SELECT COUNT(*) FROM latest_kickoff_task_outputs"
).fetchone()[0]
expected = N_WORKERS * N_RECORDS
assert count == expected, f"Expected {expected} rows, got {count}"
class TestConcurrentSQLiteFlow:
"""Concurrent multi-process writes to flow persistence SQLite."""
def test_concurrent_writes_no_error(self, tmp_path):
db_path = str(tmp_path / "flow_states.db")
successes, errors = _run_workers(
_sqlite_flow_worker,
lambda wid, rd: (db_path, wid, N_RECORDS, rd),
timeout=60,
)
assert not errors, f"Workers failed: {errors}"
assert successes == N_WORKERS
with sqlite3.connect(db_path, timeout=30) as conn:
count = conn.execute("SELECT COUNT(*) FROM flow_states").fetchone()[0]
expected = N_WORKERS * N_RECORDS
assert count == expected, f"Expected {expected} rows, got {count}"
class TestConcurrentChromaDB:
"""Concurrent multi-process ChromaDB client creation."""
def test_concurrent_client_creation_no_lock_exception(self, tmp_path):
persist_dir = str(tmp_path / "chromadb_concurrent")
os.makedirs(persist_dir, exist_ok=True)
successes, errors = _run_workers(
_chromadb_worker,
lambda wid, rd: (persist_dir, wid, rd),
)
assert not errors, f"Workers failed: {errors}"
assert successes == N_WORKERS