mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-04-30 14:52:36 +00:00
fix: handle concurrent table creation race and increase SQLite connect timeout
This commit is contained in:
@@ -72,7 +72,7 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
|
||||
def init_db(self) -> None:
|
||||
"""Create the necessary tables if they don't exist."""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
# Main state table
|
||||
conn.execute(
|
||||
@@ -137,7 +137,7 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
f"state_data must be either a Pydantic BaseModel or dict, got {type(state_data)}"
|
||||
)
|
||||
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO flow_states (
|
||||
@@ -164,7 +164,7 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
Returns:
|
||||
The most recent state as a dictionary, or None if no state exists
|
||||
"""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
cursor = conn.execute(
|
||||
"""
|
||||
SELECT state_json
|
||||
@@ -214,7 +214,7 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
self.save_state(flow_uuid, context.method_name, state_data)
|
||||
|
||||
# Save pending feedback context
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
# Use INSERT OR REPLACE to handle re-triggering feedback on same flow
|
||||
conn.execute(
|
||||
"""
|
||||
@@ -249,7 +249,7 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
# Import here to avoid circular imports
|
||||
from crewai.flow.async_feedback.types import PendingFeedbackContext
|
||||
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
cursor = conn.execute(
|
||||
"""
|
||||
SELECT state_json, context_json
|
||||
@@ -273,7 +273,7 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
Args:
|
||||
flow_uuid: Unique identifier for the flow instance
|
||||
"""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
DELETE FROM pending_feedback
|
||||
|
||||
@@ -38,7 +38,7 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
DatabaseOperationError: If database initialization fails due to SQLite errors.
|
||||
"""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
@@ -83,7 +83,7 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
"""
|
||||
inputs = inputs or {}
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
conn.execute("BEGIN TRANSACTION")
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
@@ -126,7 +126,7 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
DatabaseOperationError: If updating the task output fails due to SQLite errors.
|
||||
"""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
conn.execute("BEGIN TRANSACTION")
|
||||
cursor = conn.cursor()
|
||||
|
||||
@@ -167,7 +167,7 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
DatabaseOperationError: If loading task outputs fails due to SQLite errors.
|
||||
"""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
SELECT *
|
||||
@@ -206,7 +206,7 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
DatabaseOperationError: If deleting task outputs fails due to SQLite errors.
|
||||
"""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
conn.execute("BEGIN TRANSACTION")
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("DELETE FROM latest_kickoff_task_outputs")
|
||||
|
||||
@@ -207,8 +207,11 @@ class LanceDBStorage:
|
||||
"vector": [0.0] * vector_dim,
|
||||
}
|
||||
]
|
||||
table = self._db.create_table(self._table_name, placeholder)
|
||||
table.delete("id = '__schema_placeholder__'")
|
||||
try:
|
||||
table = self._db.create_table(self._table_name, placeholder)
|
||||
table.delete("id = '__schema_placeholder__'")
|
||||
except ValueError:
|
||||
table = self._db.open_table(self._table_name)
|
||||
return table
|
||||
|
||||
def _ensure_scope_index(self) -> None:
|
||||
|
||||
268
lib/crewai/tests/memory/test_concurrent_storage.py
Normal file
268
lib/crewai/tests/memory/test_concurrent_storage.py
Normal file
@@ -0,0 +1,268 @@
|
||||
"""Stress tests for concurrent multi-process storage access.
|
||||
|
||||
Simulates the Airflow pattern: N worker processes each writing to the
|
||||
same storage directory simultaneously. Verifies no LockException and
|
||||
data integrity after all writes complete.
|
||||
|
||||
Uses temp files for IPC instead of multiprocessing.Manager (which uses
|
||||
sockets blocked by pytest_recording).
|
||||
"""
|
||||
|
||||
import json
|
||||
import multiprocessing
|
||||
import os
|
||||
import sqlite3
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# File-based IPC helpers (avoids Manager sockets)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _write_result(result_dir: str, worker_id: int, success: bool, error: str = ""):
|
||||
path = os.path.join(result_dir, f"worker-{worker_id}.json")
|
||||
with open(path, "w") as f:
|
||||
json.dump({"success": success, "error": error}, f)
|
||||
|
||||
|
||||
def _collect_results(result_dir: str, n_workers: int):
|
||||
errors = {}
|
||||
successes = 0
|
||||
for wid in range(n_workers):
|
||||
path = os.path.join(result_dir, f"worker-{wid}.json")
|
||||
if not os.path.exists(path):
|
||||
errors[wid] = "Process produced no output (crashed or timed out)"
|
||||
continue
|
||||
with open(path) as f:
|
||||
data = json.load(f)
|
||||
if data["success"]:
|
||||
successes += 1
|
||||
else:
|
||||
errors[wid] = data["error"]
|
||||
return successes, errors
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Worker functions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _lancedb_worker(path: str, worker_id: int, n_records: int, result_dir: str):
    """Process target: save *n_records* memory records into a shared LanceDB.

    Reports success/failure through a JSON file via ``_write_result``.
    """
    try:
        from crewai.memory.storage.lancedb_storage import LanceDBStorage
        from crewai.memory.types import MemoryRecord

        store = LanceDBStorage(path=path, table_name="memories", vector_dim=8)
        batch = []
        for idx in range(n_records):
            batch.append(
                MemoryRecord(
                    id=f"worker-{worker_id}-record-{idx}",
                    content=f"content from worker {worker_id} record {idx}",
                    scope=f"/test/worker-{worker_id}",
                    categories=["test"],
                    metadata={"worker": worker_id},
                    importance=0.5,
                    embedding=[float(worker_id)] * 8,
                )
            )
        store.save(batch)
        _write_result(result_dir, worker_id, True)
    except Exception as exc:
        _write_result(result_dir, worker_id, False, f"{type(exc).__name__}: {exc}")
|
||||
|
||||
|
||||
def _sqlite_kickoff_worker(db_path: str, worker_id: int, n_writes: int, result_dir: str):
    """Process target: insert rows into the kickoff-outputs SQLite DB.

    Reports success/failure through a JSON file via ``_write_result``.
    """
    try:
        from crewai.memory.storage.kickoff_task_outputs_storage import (
            KickoffTaskOutputsSQLiteStorage,
        )

        # Instantiating the storage ensures the schema exists in this process.
        KickoffTaskOutputsSQLiteStorage(db_path=db_path)
        for seq in range(n_writes):
            row = (
                f"worker-{worker_id}-task-{seq}",
                "expected output",
                '{"result": "ok"}',
                worker_id * 1000 + seq,
                "{}",
                False,
            )
            with sqlite3.connect(db_path, timeout=30) as conn:
                conn.execute("PRAGMA journal_mode=WAL")
                conn.execute(
                    """INSERT OR REPLACE INTO latest_kickoff_task_outputs
                    (task_id, expected_output, output, task_index, inputs, was_replayed)
                    VALUES (?, ?, ?, ?, ?, ?)""",
                    row,
                )
        _write_result(result_dir, worker_id, True)
    except Exception as exc:
        _write_result(result_dir, worker_id, False, f"{type(exc).__name__}: {exc}")
|
||||
|
||||
|
||||
def _sqlite_flow_worker(db_path: str, worker_id: int, n_writes: int, result_dir: str):
    """Process target: save flow states concurrently via SQLiteFlowPersistence.

    Reports success/failure through a JSON file via ``_write_result``.
    """
    try:
        from crewai.flow.persistence.sqlite import SQLiteFlowPersistence

        store = SQLiteFlowPersistence(db_path=db_path)
        for seq in range(n_writes):
            store.save_state(
                flow_uuid=f"flow-{worker_id}-{seq}",
                method_name="test_method",
                state_data={"worker": worker_id, "iteration": seq},
            )
        _write_result(result_dir, worker_id, True)
    except Exception as exc:
        _write_result(result_dir, worker_id, False, f"{type(exc).__name__}: {exc}")
|
||||
|
||||
|
||||
def _chromadb_worker(persist_dir: str, worker_id: int, result_dir: str):
    """Process target: create a ChromaDB PersistentClient under a file lock.

    Reports success/failure through a JSON file via ``_write_result``.
    """
    try:
        from hashlib import md5

        from chromadb import PersistentClient
        from chromadb.config import Settings
        import portalocker

        client_settings = Settings(
            persist_directory=persist_dir,
            anonymized_telemetry=False,
            is_persistent=True,
        )

        # Exercise the same locking path as factory.py: serialize client
        # creation behind a per-directory portalocker lock file.
        digest = md5(persist_dir.encode(), usedforsecurity=False).hexdigest()
        lock_path = os.path.join(persist_dir, f"chromadb-{digest}.lock")
        with portalocker.Lock(lock_path, timeout=120):
            PersistentClient(path=persist_dir, settings=client_settings)

        _write_result(result_dir, worker_id, True)
    except Exception as exc:
        _write_result(result_dir, worker_id, False, f"{type(exc).__name__}: {exc}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Number of concurrent worker processes — enough to provoke lock contention
# without making the stress tests slow.
N_WORKERS = 6
# Records each worker writes into the shared store.
N_RECORDS = 20
|
||||
|
||||
|
||||
def _run_workers(target, args_fn, n_workers=N_WORKERS, timeout=120):
    """Spawn n_workers processes and collect results via temp files.

    Args:
        target: Picklable callable executed in each child process.
        args_fn: Callable ``(worker_id, result_dir) -> tuple`` producing the
            argument tuple for each worker.
        n_workers: Number of processes to spawn.
        timeout: Per-process join timeout in seconds.

    Returns:
        ``(successes, errors)`` as produced by ``_collect_results``.
    """
    with tempfile.TemporaryDirectory() as result_dir:
        procs = []
        for wid in range(n_workers):
            p = multiprocessing.Process(
                target=target,
                args=args_fn(wid, result_dir),
            )
            procs.append(p)

        for p in procs:
            p.start()
        for p in procs:
            p.join(timeout=timeout)

        # Bug fix: a worker that outlives its join timeout used to be left
        # running, leaking the child process and letting TemporaryDirectory
        # be removed underneath a live writer. Kill stragglers before
        # collecting results; they will show up as "no output" errors.
        for p in procs:
            if p.is_alive():
                p.terminate()
                p.join(timeout=5)

        successes, errors = _collect_results(result_dir, n_workers)
        return successes, errors
|
||||
|
||||
|
||||
class TestConcurrentLanceDB:
    """Concurrent multi-process writes to LanceDB."""

    def test_concurrent_saves_no_lock_exception(self, tmp_path):
        """All workers saving at once must finish without a LockException."""
        storage_path = str(tmp_path / "lancedb_concurrent")

        successes, errors = _run_workers(
            _lancedb_worker,
            lambda wid, rd: (storage_path, wid, N_RECORDS, rd),
        )

        assert not errors, f"Workers failed: {errors}"
        assert successes == N_WORKERS

    def test_data_integrity_after_concurrent_saves(self, tmp_path):
        """Every record from every worker must survive concurrent writes."""
        storage_path = str(tmp_path / "lancedb_integrity")

        successes, errors = _run_workers(
            _lancedb_worker,
            lambda wid, rd: (storage_path, wid, N_RECORDS, rd),
        )

        assert not errors, f"Workers failed: {errors}"

        from crewai.memory.storage.lancedb_storage import LanceDBStorage

        storage = LanceDBStorage(path=storage_path, table_name="memories", vector_dim=8)
        total = storage.count()
        expected = N_WORKERS * N_RECORDS
        assert total == expected, f"Expected {expected} records, got {total}"
|
||||
|
||||
|
||||
class TestConcurrentSQLiteKickoff:
    """Concurrent multi-process writes to kickoff task outputs SQLite."""

    def test_concurrent_writes_no_error(self, tmp_path):
        """Parallel inserts must neither raise nor lose rows."""
        db_path = str(tmp_path / "kickoff.db")

        from crewai.memory.storage.kickoff_task_outputs_storage import (
            KickoffTaskOutputsSQLiteStorage,
        )

        # Create the schema up front so workers only contend on writes.
        KickoffTaskOutputsSQLiteStorage(db_path=db_path)

        successes, errors = _run_workers(
            _sqlite_kickoff_worker,
            lambda wid, rd: (db_path, wid, N_RECORDS, rd),
            timeout=60,
        )

        assert not errors, f"Workers failed: {errors}"
        assert successes == N_WORKERS

        expected = N_WORKERS * N_RECORDS
        with sqlite3.connect(db_path, timeout=30) as conn:
            (count,) = conn.execute(
                "SELECT COUNT(*) FROM latest_kickoff_task_outputs"
            ).fetchone()
        assert count == expected, f"Expected {expected} rows, got {count}"
|
||||
|
||||
|
||||
class TestConcurrentSQLiteFlow:
    """Concurrent multi-process writes to flow persistence SQLite."""

    def test_concurrent_writes_no_error(self, tmp_path):
        """Parallel flow-state saves must neither raise nor lose rows."""
        db_path = str(tmp_path / "flow_states.db")

        successes, errors = _run_workers(
            _sqlite_flow_worker,
            lambda wid, rd: (db_path, wid, N_RECORDS, rd),
            timeout=60,
        )

        assert not errors, f"Workers failed: {errors}"
        assert successes == N_WORKERS

        expected = N_WORKERS * N_RECORDS
        with sqlite3.connect(db_path, timeout=30) as conn:
            (count,) = conn.execute("SELECT COUNT(*) FROM flow_states").fetchone()
        assert count == expected, f"Expected {expected} rows, got {count}"
|
||||
|
||||
|
||||
class TestConcurrentChromaDB:
    """Concurrent multi-process ChromaDB client creation."""

    def test_concurrent_client_creation_no_lock_exception(self, tmp_path):
        """Simultaneous client creation must serialize cleanly on the lock."""
        chroma_dir = str(tmp_path / "chromadb_concurrent")
        os.makedirs(chroma_dir, exist_ok=True)

        successes, errors = _run_workers(
            _chromadb_worker,
            lambda wid, rd: (chroma_dir, wid, rd),
        )

        assert not errors, f"Workers failed: {errors}"
        assert successes == N_WORKERS
|
||||
Reference in New Issue
Block a user