mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-06-03 15:28:10 +00:00
Compare commits
4 Commits
main
...
matcha/ove
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f5589ae2b9 | ||
|
|
809a057896 | ||
|
|
522e596009 | ||
|
|
4ffa7acffa |
@@ -1,8 +1,20 @@
|
||||
"""Centralised lock factory.
|
||||
|
||||
If ``REDIS_URL`` is set and the ``redis`` package is installed, locks are
|
||||
distributed via ``portalocker.RedisLock``. Otherwise, falls back to the
|
||||
standard file-based ``portalocker.Lock`` in the system temp dir.
|
||||
The locking backend is resolved in this order of precedence:
|
||||
|
||||
1. A backend registered in-process via :func:`set_lock_backend`. Best for
|
||||
tests and runtime wiring.
|
||||
2. A backend named by the ``CREWAI_LOCK_FACTORY`` environment variable, in
|
||||
``"module:callable"`` form (e.g. ``"my_pkg.locks:lock"``). The import path
|
||||
is resolved lazily and cached. Best for deployment-driven selection, since
|
||||
it requires no code changes and rolls back with an env unset.
|
||||
3. The built-in default: if ``REDIS_URL`` is set and the ``redis`` package is
|
||||
installed, locks are distributed via ``portalocker.RedisLock``; otherwise
|
||||
they fall back to a file-based ``portalocker.Lock`` in the system temp dir.
|
||||
|
||||
A custom backend is any callable matching :class:`LockBackend`. It receives the
|
||||
raw lock ``name`` (not the ``crewai:<hash>`` channel) and owns its own
|
||||
namespacing.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -11,16 +23,19 @@ from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
from functools import lru_cache
|
||||
from hashlib import md5
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from typing import TYPE_CHECKING, Final
|
||||
from typing import TYPE_CHECKING, Final, Protocol, runtime_checkable
|
||||
|
||||
import portalocker
|
||||
import portalocker.exceptions
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from contextlib import AbstractContextManager
|
||||
|
||||
import redis
|
||||
|
||||
|
||||
@@ -28,9 +43,35 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
_REDIS_URL: str | None = os.environ.get("REDIS_URL")
|
||||
|
||||
# Optional "module:callable" import path for a custom lock backend. Read once at
|
||||
# import time, mirroring ``_REDIS_URL``; the env must be set before the process
|
||||
# starts.
|
||||
_LOCK_FACTORY_SPEC: str | None = os.environ.get("CREWAI_LOCK_FACTORY")
|
||||
|
||||
_DEFAULT_TIMEOUT: Final[int] = 120
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class LockBackend(Protocol):
|
||||
"""A pluggable locking backend.
|
||||
|
||||
A backend is any callable that, given a raw lock ``name`` and a
|
||||
``timeout``, returns a context manager that holds the lock for the
|
||||
duration of the ``with`` block and releases it on exit. The ``name`` is
|
||||
passed through verbatim (e.g. ``"chromadb_init"``); the backend owns its
|
||||
own namespacing.
|
||||
"""
|
||||
|
||||
def __call__(
|
||||
self, name: str, *, timeout: float
|
||||
) -> AbstractContextManager[None]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
# Active backend override; ``None`` means use the built-in default selection.
|
||||
_backend: LockBackend | None = None
|
||||
|
||||
|
||||
def _redis_available() -> bool:
|
||||
"""Return True if redis is installed and REDIS_URL is set."""
|
||||
if not _REDIS_URL:
|
||||
@@ -53,16 +94,59 @@ def _redis_connection() -> redis.Redis[bytes]:
|
||||
return Redis.from_url(_REDIS_URL)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def lock(name: str, *, timeout: float = _DEFAULT_TIMEOUT) -> Iterator[None]:
|
||||
"""Acquire a named lock, yielding while it is held.
|
||||
@lru_cache(maxsize=1)
|
||||
def _env_lock_factory() -> LockBackend | None:
|
||||
"""Resolve the ``CREWAI_LOCK_FACTORY`` import path to a callable.
|
||||
|
||||
Args:
|
||||
name: A human-readable lock name (e.g. ``"chromadb_init"``).
|
||||
Automatically namespaced to avoid collisions.
|
||||
timeout: Maximum seconds to wait for the lock before raising.
|
||||
Returns ``None`` when the env var is unset. Resolution is cached, so the
|
||||
import happens at most once per process.
|
||||
|
||||
Raises:
|
||||
ValueError: if the spec is not in ``"module:callable"`` form.
|
||||
ImportError / AttributeError: if the module or attribute is missing.
|
||||
TypeError: if the resolved attribute is not callable.
|
||||
"""
|
||||
channel = f"crewai:{md5(name.encode(), usedforsecurity=False).hexdigest()}"
|
||||
if not _LOCK_FACTORY_SPEC:
|
||||
return None
|
||||
|
||||
module_path, sep, attr = _LOCK_FACTORY_SPEC.partition(":")
|
||||
if not sep or not module_path or not attr:
|
||||
raise ValueError(
|
||||
"CREWAI_LOCK_FACTORY must be in 'module:callable' form, "
|
||||
f"got {_LOCK_FACTORY_SPEC!r}"
|
||||
)
|
||||
|
||||
module = importlib.import_module(module_path)
|
||||
factory: LockBackend = getattr(module, attr)
|
||||
if not callable(factory):
|
||||
raise TypeError(
|
||||
f"CREWAI_LOCK_FACTORY={_LOCK_FACTORY_SPEC!r} resolved to a "
|
||||
f"non-callable {type(factory).__name__}; expected a callable "
|
||||
"matching LockBackend (name, *, timeout) -> context manager."
|
||||
)
|
||||
logger.debug("Using custom lock backend from %s", _LOCK_FACTORY_SPEC)
|
||||
return factory
|
||||
|
||||
|
||||
def _active_backend() -> LockBackend:
|
||||
"""Return the backend to use, honouring override > env > default."""
|
||||
if _backend is not None:
|
||||
return _backend
|
||||
env_factory = _env_lock_factory()
|
||||
if env_factory is not None:
|
||||
return env_factory
|
||||
return _default_lock
|
||||
|
||||
|
||||
def _namespaced_channel(name: str) -> str:
|
||||
"""Return the collision-resistant, namespaced channel for ``name``."""
|
||||
return f"crewai:{md5(name.encode(), usedforsecurity=False).hexdigest()}"
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _default_lock(name: str, *, timeout: float = _DEFAULT_TIMEOUT) -> Iterator[None]:
|
||||
"""The built-in backend: Redis when available, else a temp-dir file lock."""
|
||||
channel = _namespaced_channel(name)
|
||||
|
||||
if _redis_available():
|
||||
with portalocker.RedisLock(
|
||||
@@ -87,3 +171,42 @@ def lock(name: str, *, timeout: float = _DEFAULT_TIMEOUT) -> Iterator[None]:
|
||||
yield
|
||||
finally:
|
||||
pl.release() # type: ignore[no-untyped-call]
|
||||
|
||||
|
||||
def set_lock_backend(backend: LockBackend | None) -> None:
|
||||
"""Override the locking backend used by :func:`lock`.
|
||||
|
||||
Args:
|
||||
backend: A callable matching the :class:`LockBackend` protocol, i.e.
|
||||
``backend(name, *, timeout) -> contextmanager``. Pass ``None`` to
|
||||
clear the override, falling back to the ``CREWAI_LOCK_FACTORY``
|
||||
env path if set, otherwise the built-in Redis/file default.
|
||||
"""
|
||||
global _backend
|
||||
_backend = backend
|
||||
|
||||
|
||||
def get_lock_backend() -> LockBackend:
|
||||
"""Return the currently active locking backend.
|
||||
|
||||
Honours the override > ``CREWAI_LOCK_FACTORY`` env > built-in default
|
||||
precedence.
|
||||
"""
|
||||
return _active_backend()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def lock(name: str, *, timeout: float = _DEFAULT_TIMEOUT) -> Iterator[None]:
|
||||
"""Acquire a named lock, yielding while it is held.
|
||||
|
||||
Delegates to the active backend, resolved as override >
|
||||
``CREWAI_LOCK_FACTORY`` env > built-in Redis/file selection.
|
||||
|
||||
Args:
|
||||
name: A human-readable lock name (e.g. ``"chromadb_init"``). The
|
||||
built-in default namespaces it to avoid collisions; custom
|
||||
backends receive it verbatim.
|
||||
timeout: Maximum seconds to wait for the lock before raising.
|
||||
"""
|
||||
with _active_backend()(name, timeout=timeout):
|
||||
yield
|
||||
|
||||
@@ -6,7 +6,9 @@ backend is selected. We trust portalocker to handle actual locking mechanics.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
import sys
|
||||
import types
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
@@ -20,6 +22,17 @@ def no_redis_url(monkeypatch):
|
||||
monkeypatch.setattr(lock_store, "_REDIS_URL", None)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_backend(monkeypatch):
|
||||
"""Ensure backend overrides never leak across tests."""
|
||||
monkeypatch.setattr(lock_store, "_LOCK_FACTORY_SPEC", None)
|
||||
lock_store._env_lock_factory.cache_clear()
|
||||
lock_store.set_lock_backend(None)
|
||||
yield
|
||||
lock_store.set_lock_backend(None)
|
||||
lock_store._env_lock_factory.cache_clear()
|
||||
|
||||
|
||||
# _redis_available
|
||||
|
||||
|
||||
@@ -64,3 +77,166 @@ def test_uses_redis_lock_when_redis_available(monkeypatch):
|
||||
kwargs = mock_redis_lock.call_args.kwargs
|
||||
assert kwargs["channel"].startswith("crewai:")
|
||||
assert kwargs["connection"] is fake_conn
|
||||
|
||||
|
||||
# backend override
|
||||
|
||||
|
||||
def test_override_backend_is_used():
|
||||
calls = []
|
||||
|
||||
@contextmanager
|
||||
def fake_backend(name, *, timeout):
|
||||
calls.append((name, timeout))
|
||||
yield
|
||||
|
||||
lock_store.set_lock_backend(fake_backend)
|
||||
|
||||
# The default file/redis path must not be touched when overridden.
|
||||
with mock.patch("portalocker.Lock") as mock_lock:
|
||||
with lock("override_test", timeout=5):
|
||||
pass
|
||||
|
||||
mock_lock.assert_not_called()
|
||||
assert calls == [("override_test", 5)]
|
||||
|
||||
|
||||
def test_reset_restores_default_backend():
|
||||
@contextmanager
|
||||
def fake_backend(name, *, timeout):
|
||||
yield
|
||||
|
||||
lock_store.set_lock_backend(fake_backend)
|
||||
lock_store.set_lock_backend(None)
|
||||
|
||||
with mock.patch("portalocker.Lock") as mock_lock:
|
||||
with lock("after_reset"):
|
||||
pass
|
||||
|
||||
mock_lock.assert_called_once()
|
||||
|
||||
|
||||
def test_get_lock_backend_reflects_override():
|
||||
assert lock_store.get_lock_backend() is lock_store._default_lock
|
||||
|
||||
@contextmanager
|
||||
def fake_backend(name, *, timeout):
|
||||
yield
|
||||
|
||||
lock_store.set_lock_backend(fake_backend)
|
||||
assert lock_store.get_lock_backend() is fake_backend
|
||||
|
||||
|
||||
# CREWAI_LOCK_FACTORY env import-path
|
||||
|
||||
|
||||
def _install_env_factory(monkeypatch, factory, modname="fakelocks", attr="lock"):
|
||||
"""Point CREWAI_LOCK_FACTORY at ``factory`` via a registered fake module."""
|
||||
module = types.ModuleType(modname)
|
||||
setattr(module, attr, factory)
|
||||
monkeypatch.setitem(sys.modules, modname, module)
|
||||
monkeypatch.setattr(lock_store, "_LOCK_FACTORY_SPEC", f"{modname}:{attr}")
|
||||
lock_store._env_lock_factory.cache_clear()
|
||||
|
||||
|
||||
def test_env_factory_used_when_spec_set(monkeypatch):
|
||||
calls = []
|
||||
|
||||
@contextmanager
|
||||
def fake_backend(name, *, timeout):
|
||||
calls.append((name, timeout))
|
||||
yield
|
||||
|
||||
_install_env_factory(monkeypatch, fake_backend)
|
||||
|
||||
with mock.patch("portalocker.Lock") as mock_lock:
|
||||
with lock("env_test", timeout=7):
|
||||
pass
|
||||
|
||||
mock_lock.assert_not_called()
|
||||
assert calls == [("env_test", 7)]
|
||||
assert lock_store.get_lock_backend() is fake_backend
|
||||
|
||||
|
||||
def test_programmatic_override_takes_precedence_over_env(monkeypatch):
|
||||
@contextmanager
|
||||
def env_backend(name, *, timeout):
|
||||
raise AssertionError("env backend should not be used")
|
||||
yield # pragma: no cover
|
||||
|
||||
used = []
|
||||
|
||||
@contextmanager
|
||||
def code_backend(name, *, timeout):
|
||||
used.append(name)
|
||||
yield
|
||||
|
||||
_install_env_factory(monkeypatch, env_backend)
|
||||
lock_store.set_lock_backend(code_backend)
|
||||
|
||||
with lock("precedence_test"):
|
||||
pass
|
||||
|
||||
assert used == ["precedence_test"]
|
||||
assert lock_store.get_lock_backend() is code_backend
|
||||
|
||||
|
||||
def test_env_factory_is_cached(monkeypatch):
|
||||
@contextmanager
|
||||
def fake_backend(name, *, timeout):
|
||||
yield
|
||||
|
||||
_install_env_factory(monkeypatch, fake_backend)
|
||||
|
||||
with lock("a"):
|
||||
pass
|
||||
|
||||
# Remove the module: a cached factory must keep working without re-importing.
|
||||
monkeypatch.delitem(sys.modules, "fakelocks")
|
||||
with lock("b"):
|
||||
pass
|
||||
|
||||
assert lock_store.get_lock_backend() is fake_backend
|
||||
|
||||
|
||||
def test_invalid_spec_raises(monkeypatch):
|
||||
monkeypatch.setattr(lock_store, "_LOCK_FACTORY_SPEC", "no_colon_here")
|
||||
lock_store._env_lock_factory.cache_clear()
|
||||
|
||||
with pytest.raises(ValueError, match="module:callable"):
|
||||
with lock("bad_spec"):
|
||||
pass
|
||||
|
||||
|
||||
def test_non_callable_factory_raises_with_context(monkeypatch):
|
||||
# Resolve the spec to a non-callable attribute.
|
||||
_install_env_factory(monkeypatch, "not a callable", attr="lock")
|
||||
|
||||
with pytest.raises(TypeError, match="CREWAI_LOCK_FACTORY"):
|
||||
with lock("bad_factory"):
|
||||
pass
|
||||
|
||||
|
||||
def test_env_factory_used_after_reset(monkeypatch):
|
||||
"""Clearing the in-process override falls back to the env factory."""
|
||||
seen = []
|
||||
|
||||
@contextmanager
|
||||
def env_backend(name, *, timeout):
|
||||
seen.append(name)
|
||||
yield
|
||||
|
||||
@contextmanager
|
||||
def code_backend(name, *, timeout):
|
||||
raise AssertionError("override should have been cleared")
|
||||
yield # pragma: no cover
|
||||
|
||||
_install_env_factory(monkeypatch, env_backend)
|
||||
lock_store.set_lock_backend(code_backend)
|
||||
lock_store.set_lock_backend(None)
|
||||
|
||||
with lock("after_reset_env"):
|
||||
pass
|
||||
|
||||
assert seen == ["after_reset_env"]
|
||||
assert lock_store.get_lock_backend() is env_backend
|
||||
|
||||
Reference in New Issue
Block a user