feat: implement file-based locking decorator for concurrent RAG client access

Greyson LaLonde
2025-08-27 12:03:12 -04:00
parent 109de91d08
commit d4c27d22cf


@@ -0,0 +1,622 @@
from __future__ import annotations
import asyncio
import contextvars
import functools
import hashlib
import inspect
import os
import re
import sys
import threading
import time
import weakref
from pathlib import Path
from collections.abc import Callable, Coroutine
from typing import Any, TypeVar, ParamSpec, Concatenate, TypedDict
import portalocker
from portalocker import constants
from typing_extensions import NotRequired, Self
P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T")
_STATE: dict[str, int] = {"pid": os.getpid()}
def _reset_after_fork() -> None:
"""Reset in-process state in the child after a fork.
Resets all locks and thread-local storage after a process fork
to prevent lock contamination across processes.
"""
global _sync_rlocks, _async_locks_by_loop, _tls, _task_depths_var, _STATE
_sync_rlocks = {}
_async_locks_by_loop = weakref.WeakKeyDictionary()
_tls = threading.local()
# Reset task-local depths for async
_task_depths_var = contextvars.ContextVar("locked_task_depths", default=None)
_STATE["pid"] = os.getpid()
def _ensure_same_process() -> None:
"""Ensure we're in the same process, reset if forked.
Checks if the current PID matches the stored PID and resets
state if a fork has occurred.
"""
if _STATE["pid"] != os.getpid():
_reset_after_fork()
# Automatically reset in a forked child on POSIX
_register_at_fork = getattr(os, "register_at_fork", None)
if _register_at_fork is not None:
_register_at_fork(after_in_child=_reset_after_fork)
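# Illustrative fork-safety check (POSIX-only, not part of the module): after a
# fork, register_at_fork has already re-run _reset_after_fork in the child, so
# the child sees fresh lock state instead of inheriting held locks:
#
#     pid = os.fork()
#     if pid == 0:  # child
#         assert _STATE["pid"] == os.getpid()
#         os._exit(0)
#     os.waitpid(pid, 0)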
class LockConfig(TypedDict):
"""Configuration for portalocker locks.
Attributes:
mode: File open mode.
timeout: Optional lock timeout.
check_interval: Optional check interval.
fail_when_locked: Whether to fail if already locked.
flags: Portalocker lock flags.
"""
mode: str
timeout: NotRequired[float]
check_interval: NotRequired[float]
fail_when_locked: bool
flags: portalocker.LockFlags
def _get_platform_lock_flags() -> portalocker.LockFlags:
"""Get platform-appropriate lock flags.
Returns:
LockFlags.EXCLUSIVE for exclusive file locking.
"""
# Use EXCLUSIVE flag only - let portalocker handle blocking/non-blocking internally
return constants.LockFlags.EXCLUSIVE
def _get_lock_config() -> LockConfig:
"""Get lock configuration appropriate for the platform.
Returns:
LockConfig dict with mode, flags, and fail_when_locked settings.
"""
config: LockConfig = {
"mode": "a+",
"fail_when_locked": False,
"flags": _get_platform_lock_flags(),
}
return config
LOCK_CONFIG: LockConfig = _get_lock_config()
LOCK_STALE_SECONDS = 120
def _default_lock_dir() -> Path:
"""Get or create the default lock directory.
Returns:
Path to ~/.crewai/locks directory, created if necessary.
"""
lock_dir = Path.home() / ".crewai" / "locks"
lock_dir.mkdir(parents=True, exist_ok=True)
# Best-effort: restrict perms on POSIX
try:
if os.name == "posix":
lock_dir.chmod(0o700)
except Exception:
pass
# Clean up old lock files
_cleanup_stale_locks(lock_dir)
return lock_dir
def _cleanup_stale_locks(lock_dir: Path, max_age_seconds: int = 86400) -> None:
"""Remove lock files older than max_age_seconds.
Args:
lock_dir: Directory containing lock files.
max_age_seconds: Maximum age before considering a lock stale (default 24 hours).
"""
try:
current_time = time.time()
for lock_file in lock_dir.glob("*.lock"):
try:
# Check if file is old and not currently locked
file_age = current_time - lock_file.stat().st_mtime
if file_age > max_age_seconds:
# Try to acquire exclusive lock - if successful, file is not in use
try:
with portalocker.Lock(
str(lock_file),
mode="a+",
timeout=0.01, # Very short timeout
fail_when_locked=True,
flags=constants.LockFlags.EXCLUSIVE,
):
pass # We got the lock, file is not in use
# Safe to remove
lock_file.unlink(missing_ok=True)
except (portalocker.LockException, OSError):
# File is locked or can't be accessed, skip it
pass
except (OSError, IOError):
# Skip files we can't stat or process
pass
except Exception:
# Cleanup is best-effort, don't fail on errors
pass
def _hash_str(value: str) -> str:
"""Generate a short hash of a string.
Args:
value: String to hash.
Returns:
First 10 characters of SHA256 hash.
"""
return hashlib.sha256(value.encode()).hexdigest()[:10]
def _qualname_for(func: Callable[..., Any], owner: object | None = None) -> str:
"""Get qualified name for a function.
Args:
func: Function to get qualified name for.
owner: Optional owner object for the function.
Returns:
Fully qualified name including module and class.
"""
target = inspect.unwrap(func)
if inspect.ismethod(func) and getattr(func, "__self__", None) is not None:
owner_obj = func.__self__
cls = owner_obj if inspect.isclass(owner_obj) else owner_obj.__class__
return f"{target.__module__}.{cls.__qualname__}.{getattr(target, '__name__', '<?>')}"
if owner is not None:
cls = owner if inspect.isclass(owner) else owner.__class__
return f"{target.__module__}.{cls.__qualname__}.{getattr(target, '__name__', '<?>')}"
qn = getattr(target, "__qualname__", None)
if qn is not None:
return f"{getattr(target, '__module__', target.__class__.__module__)}.{qn}"
if isinstance(target, functools.partial):
f = inspect.unwrap(target.func)
return f"{getattr(f, '__module__', 'builtins')}.{getattr(f, '__qualname__', getattr(f, '__name__', '<?>'))}"
cls = target.__class__
return f"{cls.__module__}.{cls.__qualname__}.__call__"
def _get_lock_context(
instance: Any | None,
func: Callable[..., Any],
kwargs: dict[str, Any],
) -> tuple[Path, str | None]:
"""Extract lock context from function call.
Args:
instance: Instance the function is called on.
func: Function being called.
kwargs: Keyword arguments passed to function.
Returns:
Tuple of (lock_file_path, collection_name).
"""
collection_name = (
str(kwargs.get("collection_name")) if "collection_name" in kwargs else None
)
lock_dir = _default_lock_dir()
base = _qualname_for(func, owner=instance)
safe_base = re.sub(r"[^\w.\-]+", "_", base)
suffix = f"_{_hash_str(collection_name)}" if collection_name else ""
path = lock_dir / f"{safe_base}{suffix}.lock"
return path, collection_name
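# Illustrative result of the derivation above (module/class names and the hash
# are hypothetical): RAGClient.add_documents(collection_name="docs") would
# lock on
#
#     ~/.crewai/locks/mypkg.rag.RAGClient.add_documents_0f1e2d3c4b.lock
#
# where the suffix is _hash_str("docs"), the first 10 hex characters of the
# SHA-256 digest.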
def _write_lock_metadata(lock_file_path: Path) -> None:
"""Write metadata to lock file for staleness detection.
Args:
lock_file_path: Path to the lock file.
"""
try:
with open(lock_file_path, "w") as f:
f.write(f"{os.getpid()}\n{time.time()}\n")
f.flush()
os.fsync(f.fileno()) # Ensure data is written to disk
# Set restrictive permissions on lock file (Unix only)
if sys.platform not in ("win32", "cygwin"):
try:
lock_file_path.chmod(0o600)
except Exception:
pass
except Exception:
# Best effort - don't fail if we can't write metadata
pass
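# The metadata written above is two newline-terminated fields: the holder's
# PID and a UNIX timestamp. An illustrative file body (values hypothetical):
#
#     41273
#     1724774592.31
#
# _check_lock_staleness below parses exactly these two lines.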
def _check_lock_staleness(lock_file_path: Path) -> bool:
"""Check if a lock file is stale.
Args:
lock_file_path: Path to the lock file.
Returns:
True if lock is stale, False otherwise.
"""
try:
if not lock_file_path.exists():
return False
with open(lock_file_path) as f:
lines = f.readlines()
if len(lines) < 2:
return True # unreadable metadata
pid = int(lines[0].strip())
ts = float(lines[1].strip())
# If the process is alive, do NOT treat as stale based on time alone.
if sys.platform not in ("win32", "cygwin"):
            try:
                os.kill(pid, 0)
                return False  # alive → not stale
            except PermissionError:
                return False  # process exists but is owned by another user → alive
            except ProcessLookupError:
                pass  # dead process → proceed to time check
        # Process is dead (or liveness can't be checked on this platform):
        # a short grace window suffices; essentially "dead means stale"
        return (time.time() - ts) > 1.0
except Exception:
return True
_sync_rlocks: dict[Path, threading.RLock] = {}
_sync_rlocks_guard = threading.Lock()
_tls = threading.local()
def _get_sync_rlock(path: Path) -> threading.RLock:
"""Get or create a reentrant lock for a path.
Args:
path: Path to get lock for.
Returns:
Threading RLock for the given path.
"""
with _sync_rlocks_guard:
lk = _sync_rlocks.get(path)
if lk is None:
lk = threading.RLock()
_sync_rlocks[path] = lk
return lk
class _SyncDepthManager:
"""Context manager for sync depth tracking.
Tracks reentrancy depth for synchronous locks to determine
when to acquire/release file locks.
"""
def __init__(self, path: Path):
"""Initialize depth manager.
Args:
path: Path to track depth for.
"""
self.path = path
self.depth = 0
def __enter__(self) -> int:
"""Enter context and increment depth.
Returns:
Current depth after increment.
"""
depths = getattr(_tls, "depths", None)
if depths is None:
depths = {}
_tls.depths = depths
self.depth = depths.get(self.path, 0) + 1
depths[self.path] = self.depth
return self.depth
def __exit__(self, *args: Any) -> None:
"""Exit context and decrement depth.
Args:
*args: Exception information if any.
"""
depths = getattr(_tls, "depths", {})
v = depths.get(self.path, 1) - 1
if v <= 0:
depths.pop(self.path, None)
else:
depths[self.path] = v
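# Illustrative reentrancy behavior (names hypothetical): a @with_lock method
# that recurses on the same collection hits the same lock path, so the inner
# call sees depth > 1 and skips the file lock instead of deadlocking:
#
#     class Client:
#         @with_lock
#         def walk(self, *, collection_name: str, depth: int = 0) -> None:
#             if depth < 2:
#                 self.walk(collection_name=collection_name, depth=depth + 1)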
def _safe_to_delete(path: Path) -> bool:
"""Check if a lock file can be safely deleted.
Args:
path: Path to the lock file.
Returns:
True if file can be deleted safely, False otherwise.
"""
try:
with portalocker.Lock(
str(path),
mode="a+",
            timeout=0.01,  # very short, effectively non-blocking
fail_when_locked=True, # fail if someone holds it
flags=constants.LockFlags.EXCLUSIVE,
):
return True
except Exception:
return False
def with_lock(func: Callable[Concatenate[T, P], R]) -> Callable[Concatenate[T, P], R]:
"""Decorator for file-based cross-process locking.
Args:
func: Function to wrap with locking.
Returns:
Wrapped function with locking behavior.
"""
@functools.wraps(func)
    def wrapper(self: T, *args: P.args, **kwargs: P.kwargs) -> R:
_ensure_same_process()
path, _ = _get_lock_context(self, func, kwargs)
local_lock = _get_sync_rlock(path)
prune_after = False
with local_lock:
with _SyncDepthManager(path) as depth:
if depth == 1:
# stale handling
if _check_lock_staleness(path) and _safe_to_delete(path):
try:
path.unlink(missing_ok=True)
except Exception:
pass
# acquire file lock
lock_config = LockConfig(
mode=LOCK_CONFIG["mode"],
fail_when_locked=LOCK_CONFIG["fail_when_locked"],
flags=LOCK_CONFIG["flags"],
)
with portalocker.Lock(str(path), **lock_config) as _fh:
_write_lock_metadata(path)
result = func(self, *args, **kwargs)
try:
path.unlink(missing_ok=True)
except Exception:
pass
prune_after = True
else:
result = func(self, *args, **kwargs)
        # Now it's safe to remove the per-path RLock entry
if prune_after:
with _sync_rlocks_guard:
_sync_rlocks.pop(path, None)
return result
return wrapper
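# Illustrative sync usage (class and method names are hypothetical). Note that
# _get_lock_context only inspects kwargs, so collection_name must be passed as
# a keyword argument to scope the lock file per collection:
#
#     class RAGClient:
#         @with_lock
#         def add_documents(self, *, collection_name: str, docs: list[str]) -> None:
#             ...  # runs under the per-path RLock and the cross-process file lock
#
#     RAGClient().add_documents(collection_name="docs", docs=["hello"])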
# Use weak references to avoid keeping event loops alive
_async_locks_by_loop: weakref.WeakKeyDictionary[
asyncio.AbstractEventLoop, dict[Path, asyncio.Lock]
] = weakref.WeakKeyDictionary()
_async_locks_guard = threading.Lock()
_task_depths_var: contextvars.ContextVar[dict[Path, int] | None] = (
contextvars.ContextVar("locked_task_depths", default=None)
)
def _get_async_lock(path: Path) -> asyncio.Lock:
"""Get or create an async lock for the current event loop.
Args:
path: Path to get lock for.
Returns:
Asyncio Lock for the given path in current event loop.
"""
loop = asyncio.get_running_loop()
with _async_locks_guard:
# Get locks dict for this event loop
loop_locks = _async_locks_by_loop.get(loop)
if loop_locks is None:
loop_locks = {}
_async_locks_by_loop[loop] = loop_locks
# Get or create lock for this path
lk = loop_locks.get(path)
if lk is None:
# Create lock in the context of the running loop
lk = asyncio.Lock()
loop_locks[path] = lk
return lk
class _AsyncDepthManager:
"""Context manager for async task-local depth tracking.
Tracks reentrancy depth for async locks to determine
when to acquire/release file locks.
"""
def __init__(self, path: Path):
"""Initialize async depth manager.
Args:
path: Path to track depth for.
"""
self.path = path
self.depths: dict[Path, int] | None = None
self.is_reentrant = False
def __enter__(self) -> Self:
"""Enter context and track async task depth.
Returns:
Self for context management.
"""
d = _task_depths_var.get()
if d is None:
d = {}
_task_depths_var.set(d)
self.depths = d
cur_depth = self.depths.get(self.path, 0)
if cur_depth > 0:
self.is_reentrant = True
self.depths[self.path] = cur_depth + 1
else:
self.depths[self.path] = 1
return self
def __exit__(self, *args: Any) -> None:
"""Exit context and update task depth.
Args:
*args: Exception information if any.
"""
if self.depths is not None:
new_depth = self.depths.get(self.path, 1) - 1
if new_depth <= 0:
self.depths.pop(self.path, None)
else:
self.depths[self.path] = new_depth
async def _safe_to_delete_async(path: Path) -> bool:
"""Check if a lock file can be safely deleted (async).
Args:
path: Path to the lock file.
Returns:
True if file can be deleted safely, False otherwise.
"""
def _try_lock() -> bool:
try:
with portalocker.Lock(
str(path),
mode="a+",
timeout=0.01, # very short, effectively non-blocking
fail_when_locked=True, # fail if another process holds it
flags=constants.LockFlags.EXCLUSIVE,
):
return True
except Exception:
return False
return await asyncio.to_thread(_try_lock)
def async_with_lock(
func: Callable[Concatenate[T, P], Coroutine[Any, Any, R]],
) -> Callable[Concatenate[T, P], Coroutine[Any, Any, R]]:
"""Async decorator for file-based cross-process locking.
Args:
func: Async function to wrap with locking.
Returns:
Wrapped async function with locking behavior.
"""
@functools.wraps(func)
async def wrapper(self: T, *args: P.args, **kwargs: P.kwargs) -> R:
_ensure_same_process()
path, _ = _get_lock_context(self, func, kwargs)
with _AsyncDepthManager(path) as depth_mgr:
if depth_mgr.is_reentrant:
# Re-entrant within the same task: skip file lock
return await func(self, *args, **kwargs)
# Safer stale handling: only unlink if we can lock it first
if _check_lock_staleness(path) and await _safe_to_delete_async(path):
try:
await asyncio.to_thread(lambda: path.unlink(missing_ok=True))
except Exception:
pass
# Acquire per-loop async lock to serialize within this loop
async_lock = _get_async_lock(path)
await async_lock.acquire()
try:
# Acquire cross-process file lock in a thread
lock_config = LockConfig(
mode=LOCK_CONFIG["mode"],
fail_when_locked=LOCK_CONFIG["fail_when_locked"],
flags=LOCK_CONFIG["flags"],
)
file_lock = portalocker.Lock(str(path), **lock_config)
await asyncio.to_thread(file_lock.acquire)
try:
# Write/refresh metadata while lock is held
await asyncio.to_thread(lambda: _write_lock_metadata(path))
result = await func(self, *args, **kwargs)
finally:
# Release file lock before unlink to avoid inode race
try:
await asyncio.to_thread(file_lock.release)
finally:
# Now it's safe to unlink the path
try:
await asyncio.to_thread(
lambda: path.unlink(missing_ok=True)
)
except Exception:
pass
return result
finally:
async_lock.release()
with _async_locks_guard:
loop = asyncio.get_running_loop()
loop_locks = _async_locks_by_loop.get(loop)
if loop_locks is not None:
loop_locks.pop(path, None)
return wrapper
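# Illustrative async usage (names hypothetical): within one event loop, calls
# serialize on a per-loop asyncio.Lock; across processes, on the portalocker
# file lock, which is acquired via asyncio.to_thread so the loop never blocks:
#
#     class AsyncRAGClient:
#         @async_with_lock
#         async def search(self, *, collection_name: str, query: str) -> list[str]:
#             ...  # re-entrant awaits in the same task skip the file lock
#
#     # asyncio.run(AsyncRAGClient().search(collection_name="docs", query="q"))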