mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-03-11 06:18:19 +00:00
Compare commits
4 Commits
main
...
gl/fix/con
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a15aa0fb97 | ||
|
|
67bc64e82c | ||
|
|
a037ade1ca | ||
|
|
1bc92ebb5f |
@@ -72,7 +72,8 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
|
||||
def init_db(self) -> None:
|
||||
"""Create the necessary tables if they don't exist."""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
# Main state table
|
||||
conn.execute(
|
||||
"""
|
||||
@@ -136,7 +137,7 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
f"state_data must be either a Pydantic BaseModel or dict, got {type(state_data)}"
|
||||
)
|
||||
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO flow_states (
|
||||
@@ -163,7 +164,7 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
Returns:
|
||||
The most recent state as a dictionary, or None if no state exists
|
||||
"""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
cursor = conn.execute(
|
||||
"""
|
||||
SELECT state_json
|
||||
@@ -213,7 +214,7 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
self.save_state(flow_uuid, context.method_name, state_data)
|
||||
|
||||
# Save pending feedback context
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
# Use INSERT OR REPLACE to handle re-triggering feedback on same flow
|
||||
conn.execute(
|
||||
"""
|
||||
@@ -248,7 +249,7 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
# Import here to avoid circular imports
|
||||
from crewai.flow.async_feedback.types import PendingFeedbackContext
|
||||
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
cursor = conn.execute(
|
||||
"""
|
||||
SELECT state_json, context_json
|
||||
@@ -272,7 +273,7 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
Args:
|
||||
flow_uuid: Unique identifier for the flow instance
|
||||
"""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
DELETE FROM pending_feedback
|
||||
|
||||
@@ -38,7 +38,8 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
DatabaseOperationError: If database initialization fails due to SQLite errors.
|
||||
"""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
"""
|
||||
@@ -82,7 +83,7 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
"""
|
||||
inputs = inputs or {}
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
conn.execute("BEGIN TRANSACTION")
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(
|
||||
@@ -125,7 +126,7 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
DatabaseOperationError: If updating the task output fails due to SQLite errors.
|
||||
"""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
conn.execute("BEGIN TRANSACTION")
|
||||
cursor = conn.cursor()
|
||||
|
||||
@@ -166,7 +167,7 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
DatabaseOperationError: If loading task outputs fails due to SQLite errors.
|
||||
"""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("""
|
||||
SELECT *
|
||||
@@ -205,7 +206,7 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
DatabaseOperationError: If deleting task outputs fails due to SQLite errors.
|
||||
"""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
conn.execute("BEGIN TRANSACTION")
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("DELETE FROM latest_kickoff_task_outputs")
|
||||
|
||||
@@ -12,6 +12,7 @@ import time
|
||||
from typing import Any, ClassVar
|
||||
|
||||
import lancedb
|
||||
import portalocker
|
||||
|
||||
from crewai.memory.types import MemoryRecord, ScopeInfo
|
||||
|
||||
@@ -90,6 +91,7 @@ class LanceDBStorage:
|
||||
# Raise it proactively so scans on large tables never hit OS error 24.
|
||||
try:
|
||||
import resource
|
||||
|
||||
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
|
||||
if soft < 4096:
|
||||
resource.setrlimit(resource.RLIMIT_NOFILE, (min(hard, 4096), hard))
|
||||
@@ -99,7 +101,8 @@ class LanceDBStorage:
|
||||
self._compact_every = compact_every
|
||||
self._save_count = 0
|
||||
|
||||
# Get or create a shared write lock for this database path.
|
||||
self._lockfile = str(self._path / ".lance_write.lock")
|
||||
|
||||
resolved = str(self._path.resolve())
|
||||
with LanceDBStorage._path_locks_guard:
|
||||
if resolved not in LanceDBStorage._path_locks:
|
||||
@@ -110,10 +113,13 @@ class LanceDBStorage:
|
||||
# If no table exists yet, defer creation until the first save so the
|
||||
# dimension can be auto-detected from the embedder's actual output.
|
||||
try:
|
||||
self._table: lancedb.table.Table | None = self._db.open_table(self._table_name)
|
||||
self._table: lancedb.table.Table | None = self._db.open_table(
|
||||
self._table_name
|
||||
)
|
||||
self._vector_dim: int = self._infer_dim_from_table(self._table)
|
||||
# Best-effort: create the scope index if it doesn't exist yet.
|
||||
self._ensure_scope_index()
|
||||
with self._file_lock():
|
||||
self._ensure_scope_index()
|
||||
# Compact in the background if the table has accumulated many
|
||||
# fragments from previous runs (each save() creates one).
|
||||
self._compact_if_needed()
|
||||
@@ -124,7 +130,8 @@ class LanceDBStorage:
|
||||
# Explicit dim provided: create the table immediately if it doesn't exist.
|
||||
if self._table is None and vector_dim is not None:
|
||||
self._vector_dim = vector_dim
|
||||
self._table = self._create_table(vector_dim)
|
||||
with self._file_lock():
|
||||
self._table = self._create_table(vector_dim)
|
||||
|
||||
@property
|
||||
def write_lock(self) -> threading.RLock:
|
||||
@@ -149,18 +156,14 @@ class LanceDBStorage:
|
||||
break
|
||||
return DEFAULT_VECTOR_DIM
|
||||
|
||||
def _retry_write(self, op: str, *args: Any, **kwargs: Any) -> Any:
|
||||
"""Execute a table operation with retry on LanceDB commit conflicts.
|
||||
def _file_lock(self) -> portalocker.Lock:
|
||||
"""Return a cross-process file lock for serialising writes."""
|
||||
return portalocker.Lock(self._lockfile, timeout=120)
|
||||
|
||||
Args:
|
||||
op: Method name on the table object (e.g. "add", "delete").
|
||||
*args, **kwargs: Passed to the table method.
|
||||
def _do_write(self, op: str, *args: Any, **kwargs: Any) -> Any:
|
||||
"""Execute a single table write with retry on commit conflicts.
|
||||
|
||||
LanceDB uses optimistic concurrency: if two transactions overlap,
|
||||
the second to commit fails with an ``OSError`` containing
|
||||
"Commit conflict". This helper retries with exponential backoff,
|
||||
refreshing the table reference before each retry so the retried
|
||||
call uses the latest committed version (not a stale reference).
|
||||
Caller must already hold the cross-process file lock.
|
||||
"""
|
||||
delay = _RETRY_BASE_DELAY
|
||||
for attempt in range(_MAX_RETRIES + 1):
|
||||
@@ -171,20 +174,24 @@ class LanceDBStorage:
|
||||
raise
|
||||
_logger.debug(
|
||||
"LanceDB commit conflict on %s (attempt %d/%d), retrying in %.1fs",
|
||||
op, attempt + 1, _MAX_RETRIES, delay,
|
||||
op,
|
||||
attempt + 1,
|
||||
_MAX_RETRIES,
|
||||
delay,
|
||||
)
|
||||
# Refresh table to pick up the latest version before retrying.
|
||||
# The next getattr(self._table, op) will use the fresh table.
|
||||
try:
|
||||
self._table = self._db.open_table(self._table_name)
|
||||
except Exception: # noqa: S110
|
||||
pass # table refresh is best-effort
|
||||
pass
|
||||
time.sleep(delay)
|
||||
delay *= 2
|
||||
return None # unreachable, but satisfies type checker
|
||||
|
||||
def _create_table(self, vector_dim: int) -> lancedb.table.Table:
|
||||
"""Create a new table with the given vector dimension."""
|
||||
"""Create a new table with the given vector dimension.
|
||||
|
||||
Caller must already hold the cross-process file lock.
|
||||
"""
|
||||
placeholder = [
|
||||
{
|
||||
"id": "__schema_placeholder__",
|
||||
@@ -200,8 +207,11 @@ class LanceDBStorage:
|
||||
"vector": [0.0] * vector_dim,
|
||||
}
|
||||
]
|
||||
table = self._db.create_table(self._table_name, placeholder)
|
||||
table.delete("id = '__schema_placeholder__'")
|
||||
try:
|
||||
table = self._db.create_table(self._table_name, placeholder)
|
||||
table.delete("id = '__schema_placeholder__'")
|
||||
except ValueError:
|
||||
table = self._db.open_table(self._table_name)
|
||||
return table
|
||||
|
||||
def _ensure_scope_index(self) -> None:
|
||||
@@ -248,9 +258,9 @@ class LanceDBStorage:
|
||||
"""Run ``table.optimize()`` in a background thread, absorbing errors."""
|
||||
try:
|
||||
if self._table is not None:
|
||||
self._table.optimize()
|
||||
# Refresh the scope index so new fragments are covered.
|
||||
self._ensure_scope_index()
|
||||
with self._file_lock():
|
||||
self._table.optimize()
|
||||
self._ensure_scope_index()
|
||||
except Exception:
|
||||
_logger.debug("LanceDB background compaction failed", exc_info=True)
|
||||
|
||||
@@ -280,7 +290,9 @@ class LanceDBStorage:
|
||||
"last_accessed": record.last_accessed.isoformat(),
|
||||
"source": record.source or "",
|
||||
"private": record.private,
|
||||
"vector": record.embedding if record.embedding else [0.0] * self._vector_dim,
|
||||
"vector": record.embedding
|
||||
if record.embedding
|
||||
else [0.0] * self._vector_dim,
|
||||
}
|
||||
|
||||
def _row_to_record(self, row: dict[str, Any]) -> MemoryRecord:
|
||||
@@ -296,7 +308,9 @@ class LanceDBStorage:
|
||||
id=str(row["id"]),
|
||||
content=str(row["content"]),
|
||||
scope=str(row["scope"]),
|
||||
categories=json.loads(row["categories_str"]) if row.get("categories_str") else [],
|
||||
categories=json.loads(row["categories_str"])
|
||||
if row.get("categories_str")
|
||||
else [],
|
||||
metadata=json.loads(row["metadata_str"]) if row.get("metadata_str") else {},
|
||||
importance=float(row.get("importance", 0.5)),
|
||||
created_at=_parse_dt(row.get("created_at")),
|
||||
@@ -316,16 +330,15 @@ class LanceDBStorage:
|
||||
dim = len(r.embedding)
|
||||
break
|
||||
is_new_table = self._table is None
|
||||
with self._write_lock:
|
||||
with self._write_lock, self._file_lock():
|
||||
self._ensure_table(vector_dim=dim)
|
||||
rows = [self._record_to_row(r) for r in records]
|
||||
for r in rows:
|
||||
if r["vector"] is None or len(r["vector"]) != self._vector_dim:
|
||||
r["vector"] = [0.0] * self._vector_dim
|
||||
self._retry_write("add", rows)
|
||||
# Create the scope index on the first save so it covers the initial dataset.
|
||||
if is_new_table:
|
||||
self._ensure_scope_index()
|
||||
self._do_write("add", rows)
|
||||
if is_new_table:
|
||||
self._ensure_scope_index()
|
||||
# Auto-compact every N saves so fragment files don't pile up.
|
||||
self._save_count += 1
|
||||
if self._compact_every > 0 and self._save_count % self._compact_every == 0:
|
||||
@@ -333,14 +346,14 @@ class LanceDBStorage:
|
||||
|
||||
def update(self, record: MemoryRecord) -> None:
|
||||
"""Update a record by ID. Preserves created_at, updates last_accessed."""
|
||||
with self._write_lock:
|
||||
with self._write_lock, self._file_lock():
|
||||
self._ensure_table()
|
||||
safe_id = str(record.id).replace("'", "''")
|
||||
self._retry_write("delete", f"id = '{safe_id}'")
|
||||
self._do_write("delete", f"id = '{safe_id}'")
|
||||
row = self._record_to_row(record)
|
||||
if row["vector"] is None or len(row["vector"]) != self._vector_dim:
|
||||
row["vector"] = [0.0] * self._vector_dim
|
||||
self._retry_write("add", [row])
|
||||
self._do_write("add", [row])
|
||||
|
||||
def touch_records(self, record_ids: list[str]) -> None:
|
||||
"""Update last_accessed to now for the given record IDs.
|
||||
@@ -354,11 +367,11 @@ class LanceDBStorage:
|
||||
"""
|
||||
if not record_ids or self._table is None:
|
||||
return
|
||||
with self._write_lock:
|
||||
with self._write_lock, self._file_lock():
|
||||
now = datetime.utcnow().isoformat()
|
||||
safe_ids = [str(rid).replace("'", "''") for rid in record_ids]
|
||||
ids_expr = ", ".join(f"'{rid}'" for rid in safe_ids)
|
||||
self._retry_write(
|
||||
self._do_write(
|
||||
"update",
|
||||
where=f"id IN ({ids_expr})",
|
||||
values={"last_accessed": now},
|
||||
@@ -390,13 +403,17 @@ class LanceDBStorage:
|
||||
prefix = scope_prefix.rstrip("/")
|
||||
like_val = prefix + "%"
|
||||
query = query.where(f"scope LIKE '{like_val}'")
|
||||
results = query.limit(limit * 3 if (categories or metadata_filter) else limit).to_list()
|
||||
results = query.limit(
|
||||
limit * 3 if (categories or metadata_filter) else limit
|
||||
).to_list()
|
||||
out: list[tuple[MemoryRecord, float]] = []
|
||||
for row in results:
|
||||
record = self._row_to_record(row)
|
||||
if categories and not any(c in record.categories for c in categories):
|
||||
continue
|
||||
if metadata_filter and not all(record.metadata.get(k) == v for k, v in metadata_filter.items()):
|
||||
if metadata_filter and not all(
|
||||
record.metadata.get(k) == v for k, v in metadata_filter.items()
|
||||
):
|
||||
continue
|
||||
distance = row.get("_distance", 0.0)
|
||||
score = 1.0 / (1.0 + float(distance)) if distance is not None else 1.0
|
||||
@@ -416,20 +433,24 @@ class LanceDBStorage:
|
||||
) -> int:
|
||||
if self._table is None:
|
||||
return 0
|
||||
with self._write_lock:
|
||||
with self._write_lock, self._file_lock():
|
||||
if record_ids and not (categories or metadata_filter):
|
||||
before = self._table.count_rows()
|
||||
ids_expr = ", ".join(f"'{rid}'" for rid in record_ids)
|
||||
self._retry_write("delete", f"id IN ({ids_expr})")
|
||||
self._do_write("delete", f"id IN ({ids_expr})")
|
||||
return before - self._table.count_rows()
|
||||
if categories or metadata_filter:
|
||||
rows = self._scan_rows(scope_prefix)
|
||||
to_delete: list[str] = []
|
||||
for row in rows:
|
||||
record = self._row_to_record(row)
|
||||
if categories and not any(c in record.categories for c in categories):
|
||||
if categories and not any(
|
||||
c in record.categories for c in categories
|
||||
):
|
||||
continue
|
||||
if metadata_filter and not all(record.metadata.get(k) == v for k, v in metadata_filter.items()):
|
||||
if metadata_filter and not all(
|
||||
record.metadata.get(k) == v for k, v in metadata_filter.items()
|
||||
):
|
||||
continue
|
||||
if older_than and record.created_at >= older_than:
|
||||
continue
|
||||
@@ -438,7 +459,7 @@ class LanceDBStorage:
|
||||
return 0
|
||||
before = self._table.count_rows()
|
||||
ids_expr = ", ".join(f"'{rid}'" for rid in to_delete)
|
||||
self._retry_write("delete", f"id IN ({ids_expr})")
|
||||
self._do_write("delete", f"id IN ({ids_expr})")
|
||||
return before - self._table.count_rows()
|
||||
conditions = []
|
||||
if scope_prefix is not None and scope_prefix.strip("/"):
|
||||
@@ -450,11 +471,11 @@ class LanceDBStorage:
|
||||
conditions.append(f"created_at < '{older_than.isoformat()}'")
|
||||
if not conditions:
|
||||
before = self._table.count_rows()
|
||||
self._retry_write("delete", "id != ''")
|
||||
self._do_write("delete", "id != ''")
|
||||
return before - self._table.count_rows()
|
||||
where_expr = " AND ".join(conditions)
|
||||
before = self._table.count_rows()
|
||||
self._retry_write("delete", where_expr)
|
||||
self._do_write("delete", where_expr)
|
||||
return before - self._table.count_rows()
|
||||
|
||||
def _scan_rows(
|
||||
@@ -528,7 +549,7 @@ class LanceDBStorage:
|
||||
for row in rows:
|
||||
sc = str(row.get("scope", ""))
|
||||
if child_prefix and sc.startswith(child_prefix):
|
||||
rest = sc[len(child_prefix):]
|
||||
rest = sc[len(child_prefix) :]
|
||||
first_component = rest.split("/", 1)[0]
|
||||
if first_component:
|
||||
children.add(child_prefix + first_component)
|
||||
@@ -539,7 +560,11 @@ class LanceDBStorage:
|
||||
pass
|
||||
created = row.get("created_at")
|
||||
if created:
|
||||
dt = datetime.fromisoformat(str(created).replace("Z", "+00:00")) if isinstance(created, str) else created
|
||||
dt = (
|
||||
datetime.fromisoformat(str(created).replace("Z", "+00:00"))
|
||||
if isinstance(created, str)
|
||||
else created
|
||||
)
|
||||
if isinstance(dt, datetime):
|
||||
if oldest is None or dt < oldest:
|
||||
oldest = dt
|
||||
@@ -562,7 +587,7 @@ class LanceDBStorage:
|
||||
for row in rows:
|
||||
sc = str(row.get("scope", ""))
|
||||
if sc.startswith(prefix) and sc != (prefix.rstrip("/") or "/"):
|
||||
rest = sc[len(prefix):]
|
||||
rest = sc[len(prefix) :]
|
||||
first_component = rest.split("/", 1)[0]
|
||||
if first_component:
|
||||
children.add(prefix + first_component)
|
||||
@@ -590,17 +615,17 @@ class LanceDBStorage:
|
||||
return info.record_count
|
||||
|
||||
def reset(self, scope_prefix: str | None = None) -> None:
|
||||
if scope_prefix is None or scope_prefix.strip("/") == "":
|
||||
if self._table is not None:
|
||||
self._db.drop_table(self._table_name)
|
||||
self._table = None
|
||||
# Dimension is preserved; table will be recreated on next save.
|
||||
return
|
||||
if self._table is None:
|
||||
return
|
||||
prefix = scope_prefix.rstrip("/")
|
||||
if prefix:
|
||||
self._table.delete(f"scope >= '{prefix}' AND scope < '{prefix}/\uFFFF'")
|
||||
with self._write_lock, self._file_lock():
|
||||
if scope_prefix is None or scope_prefix.strip("/") == "":
|
||||
if self._table is not None:
|
||||
self._db.drop_table(self._table_name)
|
||||
self._table = None
|
||||
return
|
||||
if self._table is None:
|
||||
return
|
||||
prefix = scope_prefix.rstrip("/")
|
||||
if prefix:
|
||||
self._do_write("delete", f"scope >= '{prefix}' AND scope < '{prefix}/\uffff'")
|
||||
|
||||
def optimize(self) -> None:
|
||||
"""Compact the table synchronously and refresh the scope index.
|
||||
@@ -614,8 +639,9 @@ class LanceDBStorage:
|
||||
"""
|
||||
if self._table is None:
|
||||
return
|
||||
self._table.optimize()
|
||||
self._ensure_scope_index()
|
||||
with self._write_lock, self._file_lock():
|
||||
self._table.optimize()
|
||||
self._ensure_scope_index()
|
||||
|
||||
async def asave(self, records: list[MemoryRecord]) -> None:
|
||||
self.save(records)
|
||||
|
||||
@@ -28,7 +28,7 @@ def create_client(config: ChromaDBConfig) -> ChromaDBClient:
|
||||
lock_id = md5(persist_dir.encode(), usedforsecurity=False).hexdigest()
|
||||
lockfile = os.path.join(persist_dir, f"chromadb-{lock_id}.lock")
|
||||
|
||||
with portalocker.Lock(lockfile):
|
||||
with portalocker.Lock(lockfile, timeout=120):
|
||||
client = PersistentClient(
|
||||
path=persist_dir,
|
||||
settings=config.settings,
|
||||
|
||||
268
lib/crewai/tests/memory/test_concurrent_storage.py
Normal file
268
lib/crewai/tests/memory/test_concurrent_storage.py
Normal file
@@ -0,0 +1,268 @@
|
||||
"""Stress tests for concurrent multi-process storage access.
|
||||
|
||||
Simulates the Airflow pattern: N worker processes each writing to the
|
||||
same storage directory simultaneously. Verifies no LockException and
|
||||
data integrity after all writes complete.
|
||||
|
||||
Uses temp files for IPC instead of multiprocessing.Manager (which uses
|
||||
sockets blocked by pytest_recording).
|
||||
"""
|
||||
|
||||
import json
|
||||
import multiprocessing
|
||||
import os
|
||||
import sqlite3
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# File-based IPC helpers (avoids Manager sockets)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _write_result(result_dir: str, worker_id: int, success: bool, error: str = ""):
|
||||
path = os.path.join(result_dir, f"worker-{worker_id}.json")
|
||||
with open(path, "w") as f:
|
||||
json.dump({"success": success, "error": error}, f)
|
||||
|
||||
|
||||
def _collect_results(result_dir: str, n_workers: int):
|
||||
errors = {}
|
||||
successes = 0
|
||||
for wid in range(n_workers):
|
||||
path = os.path.join(result_dir, f"worker-{wid}.json")
|
||||
if not os.path.exists(path):
|
||||
errors[wid] = "Process produced no output (crashed or timed out)"
|
||||
continue
|
||||
with open(path) as f:
|
||||
data = json.load(f)
|
||||
if data["success"]:
|
||||
successes += 1
|
||||
else:
|
||||
errors[wid] = data["error"]
|
||||
return successes, errors
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Worker functions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _lancedb_worker(path: str, worker_id: int, n_records: int, result_dir: str):
    """Worker entry point: save *n_records* records through ``LanceDBStorage``.

    Opens the storage at *path* (table ``"memories"``, 8-dimensional vectors),
    saves one batch of records tagged with *worker_id*, and writes the outcome
    to *result_dir*.  Any exception is captured and reported as a failure
    rather than propagated.
    """
    try:
        # Imported inside the worker so import failures are reported too.
        from crewai.memory.storage.lancedb_storage import LanceDBStorage
        from crewai.memory.types import MemoryRecord

        store = LanceDBStorage(path=path, table_name="memories", vector_dim=8)

        batch = []
        for idx in range(n_records):
            batch.append(
                MemoryRecord(
                    id=f"worker-{worker_id}-record-{idx}",
                    content=f"content from worker {worker_id} record {idx}",
                    scope=f"/test/worker-{worker_id}",
                    categories=["test"],
                    metadata={"worker": worker_id},
                    importance=0.5,
                    embedding=[float(worker_id)] * 8,
                )
            )

        store.save(batch)
        _write_result(result_dir, worker_id, True)
    except Exception as exc:
        failure = f"{type(exc).__name__}: {exc}"
        _write_result(result_dir, worker_id, False, failure)
|
||||
|
||||
|
||||
def _sqlite_kickoff_worker(db_path: str, worker_id: int, n_writes: int, result_dir: str):
    """Worker entry point: insert *n_writes* rows into the kickoff outputs table.

    Instantiates ``KickoffTaskOutputsSQLiteStorage`` for *db_path*, then opens
    a fresh SQLite connection per write and upserts one row into
    ``latest_kickoff_task_outputs``.  The outcome is written to *result_dir*;
    any exception is captured and reported as a failure.
    """
    try:
        from contextlib import closing

        from crewai.memory.storage.kickoff_task_outputs_storage import (
            KickoffTaskOutputsSQLiteStorage,
        )

        KickoffTaskOutputsSQLiteStorage(db_path=db_path)
        for i in range(n_writes):
            # sqlite3's own context manager only commits/rolls back; wrap it
            # in closing() so each per-write connection is actually released
            # instead of leaking one handle per iteration.
            with closing(sqlite3.connect(db_path, timeout=30)) as conn:
                with conn:
                    conn.execute("PRAGMA journal_mode=WAL")
                    conn.execute(
                        """INSERT OR REPLACE INTO latest_kickoff_task_outputs
                        (task_id, expected_output, output, task_index, inputs, was_replayed)
                        VALUES (?, ?, ?, ?, ?, ?)""",
                        (
                            f"worker-{worker_id}-task-{i}",
                            "expected output",
                            '{"result": "ok"}',
                            # Keeps task_index distinct across workers.
                            worker_id * 1000 + i,
                            "{}",
                            False,
                        ),
                    )
        _write_result(result_dir, worker_id, True)
    except Exception as e:
        _write_result(result_dir, worker_id, False, f"{type(e).__name__}: {e}")
|
||||
|
||||
|
||||
def _sqlite_flow_worker(db_path: str, worker_id: int, n_writes: int, result_dir: str):
    """Worker entry point: persist *n_writes* flow states via ``SQLiteFlowPersistence``.

    Each state is saved under a distinct ``flow-<worker_id>-<i>`` UUID.  The
    outcome is written to *result_dir*; any exception is captured and
    reported as a failure.
    """
    try:
        from crewai.flow.persistence.sqlite import SQLiteFlowPersistence

        store = SQLiteFlowPersistence(db_path=db_path)
        for iteration in range(n_writes):
            flow_id = f"flow-{worker_id}-{iteration}"
            payload = {"worker": worker_id, "iteration": iteration}
            store.save_state(
                flow_uuid=flow_id,
                method_name="test_method",
                state_data=payload,
            )
        _write_result(result_dir, worker_id, True)
    except Exception as exc:
        failure = f"{type(exc).__name__}: {exc}"
        _write_result(result_dir, worker_id, False, failure)
|
||||
|
||||
|
||||
def _chromadb_worker(persist_dir: str, worker_id: int, result_dir: str):
    """Worker entry point: create a ChromaDB ``PersistentClient`` under a file lock.

    Builds the lock file name from an MD5 digest of *persist_dir*, takes the
    lock with a 120 second timeout, and constructs the client while holding
    it.  The outcome is written to *result_dir*; any exception is captured
    and reported as a failure.
    """
    try:
        from hashlib import md5

        from chromadb import PersistentClient
        from chromadb.config import Settings
        import portalocker

        client_settings = Settings(
            persist_directory=persist_dir,
            anonymized_telemetry=False,
            is_persistent=True,
        )

        # Test the actual locking path directly (same as factory.py)
        digest = md5(persist_dir.encode(), usedforsecurity=False).hexdigest()
        lock_path = os.path.join(persist_dir, f"chromadb-{digest}.lock")
        with portalocker.Lock(lock_path, timeout=120):
            PersistentClient(path=persist_dir, settings=client_settings)

        _write_result(result_dir, worker_id, True)
    except Exception as exc:
        failure = f"{type(exc).__name__}: {exc}"
        _write_result(result_dir, worker_id, False, failure)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Number of worker processes spawned by each test.
N_WORKERS = 6
# Number of records / rows each worker writes.
N_RECORDS = 20


def _run_workers(target, args_fn, n_workers=N_WORKERS, timeout=120):
    """Spawn n_workers processes and collect results via temp files.

    Args:
        target: Worker function executed in each child process.
        args_fn: Callable ``(worker_id, result_dir) -> tuple`` producing the
            positional arguments for *target*.
        n_workers: Number of processes to spawn.
        timeout: Seconds to wait in ``join()`` for each process.

    Returns:
        The ``(successes, errors)`` pair produced by ``_collect_results``.
    """
    with tempfile.TemporaryDirectory() as result_dir:
        procs = [
            multiprocessing.Process(target=target, args=args_fn(wid, result_dir))
            for wid in range(n_workers)
        ]

        try:
            for p in procs:
                p.start()
            for p in procs:
                p.join(timeout=timeout)
        finally:
            # join(timeout) returns silently when a worker is still running.
            # Stop such stragglers so they neither outlive the test nor keep
            # writing into result_dir while it is being deleted; they show up
            # in the results as workers without an output file.
            for p in procs:
                if p.is_alive():
                    p.terminate()
                    p.join(timeout=5)
                    if p.is_alive():
                        p.kill()
                        p.join()

        successes, errors = _collect_results(result_dir, n_workers)
    return successes, errors
|
||||
|
||||
|
||||
class TestConcurrentLanceDB:
    """Concurrent multi-process writes to LanceDB."""

    def test_concurrent_saves_no_lock_exception(self, tmp_path):
        """All workers must finish saving without reporting an error."""
        target_dir = str(tmp_path / "lancedb_concurrent")

        def worker_args(wid, result_dir):
            return (target_dir, wid, N_RECORDS, result_dir)

        ok_count, failures = _run_workers(_lancedb_worker, worker_args)

        assert not failures, f"Workers failed: {failures}"
        assert ok_count == N_WORKERS

    def test_data_integrity_after_concurrent_saves(self, tmp_path):
        """After all workers finish, the table holds every saved record."""
        target_dir = str(tmp_path / "lancedb_integrity")

        def worker_args(wid, result_dir):
            return (target_dir, wid, N_RECORDS, result_dir)

        _, failures = _run_workers(_lancedb_worker, worker_args)

        assert not failures, f"Workers failed: {failures}"

        from crewai.memory.storage.lancedb_storage import LanceDBStorage

        # Re-open the same storage in this process to count what was written.
        reopened = LanceDBStorage(path=target_dir, table_name="memories", vector_dim=8)
        found = reopened.count()
        wanted = N_WORKERS * N_RECORDS
        assert found == wanted, f"Expected {wanted} records, got {found}"
|
||||
|
||||
|
||||
class TestConcurrentSQLiteKickoff:
    """Concurrent multi-process writes to kickoff task outputs SQLite."""

    def test_concurrent_writes_no_error(self, tmp_path):
        """All workers succeed and every row they inserted is present."""
        db_path = str(tmp_path / "kickoff.db")

        from crewai.memory.storage.kickoff_task_outputs_storage import (
            KickoffTaskOutputsSQLiteStorage,
        )

        # Instantiate the storage once in the parent before any worker runs
        # (presumably this initialises the database file and schema).
        KickoffTaskOutputsSQLiteStorage(db_path=db_path)

        successes, errors = _run_workers(
            _sqlite_kickoff_worker,
            lambda wid, rd: (db_path, wid, N_RECORDS, rd),
            timeout=60,
        )

        assert not errors, f"Workers failed: {errors}"
        assert successes == N_WORKERS

        # sqlite3's connection context manager does not close the connection,
        # so close it explicitly once the count has been read.
        conn = sqlite3.connect(db_path, timeout=30)
        try:
            count = conn.execute(
                "SELECT COUNT(*) FROM latest_kickoff_task_outputs"
            ).fetchone()[0]
        finally:
            conn.close()

        expected = N_WORKERS * N_RECORDS
        assert count == expected, f"Expected {expected} rows, got {count}"
|
||||
|
||||
|
||||
class TestConcurrentSQLiteFlow:
    """Concurrent multi-process writes to flow persistence SQLite."""

    def test_concurrent_writes_no_error(self, tmp_path):
        """All workers succeed and every state they saved is present."""
        db_path = str(tmp_path / "flow_states.db")

        successes, errors = _run_workers(
            _sqlite_flow_worker,
            lambda wid, rd: (db_path, wid, N_RECORDS, rd),
            timeout=60,
        )

        assert not errors, f"Workers failed: {errors}"
        assert successes == N_WORKERS

        # sqlite3's connection context manager does not close the connection,
        # so close it explicitly once the count has been read.
        conn = sqlite3.connect(db_path, timeout=30)
        try:
            count = conn.execute("SELECT COUNT(*) FROM flow_states").fetchone()[0]
        finally:
            conn.close()

        expected = N_WORKERS * N_RECORDS
        assert count == expected, f"Expected {expected} rows, got {count}"
|
||||
|
||||
|
||||
class TestConcurrentChromaDB:
    """Concurrent multi-process ChromaDB client creation."""

    def test_concurrent_client_creation_no_lock_exception(self, tmp_path):
        """Every worker must create its client without reporting an error."""
        chroma_dir = tmp_path / "chromadb_concurrent"
        # The lock file lives inside the directory, so it must exist first.
        chroma_dir.mkdir(parents=True, exist_ok=True)
        persist_dir = str(chroma_dir)

        def worker_args(wid, result_dir):
            return (persist_dir, wid, result_dir)

        ok_count, failures = _run_workers(_chromadb_worker, worker_args)

        assert not failures, f"Workers failed: {failures}"
        assert ok_count == N_WORKERS
|
||||
Reference in New Issue
Block a user