From 1b1639b862244049061917d0e48daeee2c786479 Mon Sep 17 00:00:00 2001 From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Thu, 11 Jun 2026 19:27:05 +0000 Subject: [PATCH] fix: atomic writes and locking for JsonProvider checkpoints (#6125) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - JsonProvider.checkpoint() and acheckpoint() now write to a temp file then os.replace() for atomic file updates (no partial writes); acheckpoint() performs this write with synchronous file calls instead of aiofiles - Added threading.Lock to JsonProvider for sync path serialization - Added asyncio.Lock to JsonProvider for async path serialization - Added threading.Lock (_lineage_lock) to RuntimeState to protect the read-write-update of _parent_id across concurrent checkpoint calls - Added asyncio.Lock (_async_lineage_lock) to RuntimeState for the async checkpoint path - checkpoint_listener._do_checkpoint() now acquires _lineage_lock for the full read-write-update cycle - Temp files are cleaned up on write failure - Added 8 new tests: atomic write correctness, temp file cleanup on success/failure, async equivalents, and concurrent sync/async lineage preservation with 5 writers × 10 writes each Co-Authored-By: João --- .../src/crewai/state/checkpoint_listener.py | 63 +++--- .../crewai/state/provider/json_provider.py | 71 ++++-- lib/crewai/src/crewai/state/runtime.py | 94 ++++---- lib/crewai/tests/test_checkpoint.py | 207 ++++++++++++++++++ 4 files changed, 353 insertions(+), 82 deletions(-) diff --git a/lib/crewai/src/crewai/state/checkpoint_listener.py b/lib/crewai/src/crewai/state/checkpoint_listener.py index 53ae0b494..2055c9586 100644 --- a/lib/crewai/src/crewai/state/checkpoint_listener.py +++ b/lib/crewai/src/crewai/state/checkpoint_listener.py @@ -113,7 +113,11 @@ def _find_checkpoint(source: Any) -> CheckpointConfig | None: def _do_checkpoint( state: RuntimeState, cfg: CheckpointConfig, event: BaseEvent | None = None ) -> None: - """Write a checkpoint and 
prune old ones if configured. + + The state's lineage lock is held for the entire read-write-update cycle + so concurrent callers see consistent ``_parent_id`` values. + """ provider_name: str = type(cfg.provider).__name__ trigger: str | None = event.type if event is not None else None context: dict[str, Any] = { @@ -123,36 +127,37 @@ def _do_checkpoint( "agent_role": event.agent_role if event is not None else None, } - parent_id_snapshot: str | None = state._parent_id - branch_snapshot: str = state._branch - - crewai_event_bus.emit( - cfg, - CheckpointStartedEvent( - location=cfg.location, - provider=provider_name, - trigger=trigger, - branch=branch_snapshot, - parent_id=parent_id_snapshot, - **context, - ), - ) - start: float = time.perf_counter() try: - _prepare_entities(state.root) - payload = state.model_dump(mode="json") - if event is not None: - payload["trigger"] = event.type - data = json.dumps(payload) - location = cfg.provider.checkpoint( - data, - cfg.location, - parent_id=parent_id_snapshot, - branch=branch_snapshot, - ) - state._chain_lineage(cfg.provider, location) - checkpoint_id: str = cfg.provider.extract_id(location) + with state._lineage_lock: + parent_id_snapshot: str | None = state._parent_id + branch_snapshot: str = state._branch + + crewai_event_bus.emit( + cfg, + CheckpointStartedEvent( + location=cfg.location, + provider=provider_name, + trigger=trigger, + branch=branch_snapshot, + parent_id=parent_id_snapshot, + **context, + ), + ) + + _prepare_entities(state.root) + payload = state.model_dump(mode="json") + if event is not None: + payload["trigger"] = event.type + data = json.dumps(payload) + location = cfg.provider.checkpoint( + data, + cfg.location, + parent_id=parent_id_snapshot, + branch=branch_snapshot, + ) + state._chain_lineage(cfg.provider, location) + checkpoint_id: str = cfg.provider.extract_id(location) except Exception as exc: crewai_event_bus.emit( cfg, diff --git a/lib/crewai/src/crewai/state/provider/json_provider.py 
b/lib/crewai/src/crewai/state/provider/json_provider.py index 904526292..f234180d5 100644 --- a/lib/crewai/src/crewai/state/provider/json_provider.py +++ b/lib/crewai/src/crewai/state/provider/json_provider.py @@ -2,11 +2,14 @@ from __future__ import annotations +import asyncio from datetime import datetime, timezone import glob import logging import os from pathlib import Path +import tempfile +import threading from typing import Literal import uuid @@ -35,10 +38,19 @@ def _safe_branch(base: str, branch: str) -> None: class JsonProvider(BaseProvider): - """Persists runtime state checkpoints as JSON files on the local filesystem.""" + """Persists runtime state checkpoints as JSON files on the local filesystem. + + File writes are atomic (write-to-temp then ``os.replace``) and serialized + by an internal lock so concurrent callers cannot create diverging lineage. + """ provider_type: Literal["json"] = "json" + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._lock = threading.Lock() + self._async_lock = asyncio.Lock() + def checkpoint( self, data: str, @@ -47,7 +59,11 @@ class JsonProvider(BaseProvider): parent_id: str | None = None, branch: str = "main", ) -> str: - """Write a JSON checkpoint file. + """Write a JSON checkpoint file atomically. + + Uses write-to-temp + ``os.replace()`` to guarantee the checkpoint + file is never partially written. A threading lock serializes + concurrent writes to prevent lineage divergence. Args: data: The serialized JSON string to persist. @@ -60,12 +76,25 @@ class JsonProvider(BaseProvider): Returns: The path to the written checkpoint file. 
""" - file_path = _build_path(location, branch, parent_id) - file_path.parent.mkdir(parents=True, exist_ok=True) + with self._lock: + file_path = _build_path(location, branch, parent_id) + file_path.parent.mkdir(parents=True, exist_ok=True) - with open(file_path, "w") as f: - f.write(data) - return str(file_path) + fd, tmp_path = tempfile.mkstemp(dir=str(file_path.parent), suffix=".tmp") + try: + with os.fdopen(fd, "w") as f: + f.write(data) + f.flush() + os.fsync(f.fileno()) + os.replace(tmp_path, str(file_path)) + except BaseException: + # Clean up temp file on failure + try: + os.unlink(tmp_path) + except OSError: + pass + raise + return str(file_path) async def acheckpoint( self, @@ -75,7 +104,11 @@ class JsonProvider(BaseProvider): parent_id: str | None = None, branch: str = "main", ) -> str: - """Write a JSON checkpoint file asynchronously. + """Write a JSON checkpoint file atomically and asynchronously. + + Uses write-to-temp + ``os.replace()`` to guarantee the checkpoint + file is never partially written. An asyncio lock serializes + concurrent writes to prevent lineage divergence. Args: data: The serialized JSON string to persist. @@ -88,12 +121,24 @@ class JsonProvider(BaseProvider): Returns: The path to the written checkpoint file. 
""" - file_path = _build_path(location, branch, parent_id) - await aiofiles.os.makedirs(str(file_path.parent), exist_ok=True) + async with self._async_lock: + file_path = _build_path(location, branch, parent_id) + await aiofiles.os.makedirs(str(file_path.parent), exist_ok=True) - async with aiofiles.open(file_path, "w") as f: - await f.write(data) - return str(file_path) + fd, tmp_path = tempfile.mkstemp(dir=str(file_path.parent), suffix=".tmp") + try: + with os.fdopen(fd, "w") as f: + f.write(data) + f.flush() + os.fsync(f.fileno()) + os.replace(tmp_path, str(file_path)) + except BaseException: + try: + os.unlink(tmp_path) + except OSError: + pass + raise + return str(file_path) def prune(self, location: str, max_keep: int, *, branch: str = "main") -> int: """Remove oldest checkpoint files beyond *max_keep* on a branch.""" diff --git a/lib/crewai/src/crewai/state/runtime.py b/lib/crewai/src/crewai/state/runtime.py index 7c6f78643..bb2eea6fe 100644 --- a/lib/crewai/src/crewai/state/runtime.py +++ b/lib/crewai/src/crewai/state/runtime.py @@ -9,7 +9,9 @@ via ``RuntimeState.model_rebuild()``. from __future__ import annotations +import asyncio import logging +import threading import time from typing import TYPE_CHECKING, Any import uuid @@ -181,6 +183,8 @@ class RuntimeState(RootModel): # type: ignore[type-arg] _checkpoint_id: str | None = PrivateAttr(default=None) _parent_id: str | None = PrivateAttr(default=None) _branch: str = PrivateAttr(default="main") + _lineage_lock: threading.Lock = PrivateAttr(default_factory=threading.Lock) + _async_lineage_lock: asyncio.Lock = PrivateAttr(default_factory=asyncio.Lock) @property def event_record(self) -> EventRecord: @@ -286,6 +290,10 @@ class RuntimeState(RootModel): # type: ignore[type-arg] def checkpoint(self, location: str) -> str: """Write a checkpoint. + The lineage lock ensures that concurrent callers serialize their + read of ``_parent_id``, write, and update — preventing lineage + divergence. 
+ Args: location: Storage destination. For JsonProvider this is a directory path; for SqliteProvider it is a database file path. @@ -293,32 +301,37 @@ class RuntimeState(RootModel): # type: ignore[type-arg] Returns: A location identifier for the saved checkpoint. """ - provider_name, parent_id_snapshot, branch_snapshot, start = ( - self._begin_checkpoint(location) - ) - try: - _prepare_entities(self.root) - result = self._provider.checkpoint( - self.model_dump_json(), - location, - parent_id=parent_id_snapshot, - branch=branch_snapshot, + with self._lineage_lock: + provider_name, parent_id_snapshot, branch_snapshot, start = ( + self._begin_checkpoint(location) ) - self._chain_lineage(self._provider, result) - except Exception as exc: - self._emit_checkpoint_failed( - location, provider_name, branch_snapshot, parent_id_snapshot, exc - ) - raise + try: + _prepare_entities(self.root) + result = self._provider.checkpoint( + self.model_dump_json(), + location, + parent_id=parent_id_snapshot, + branch=branch_snapshot, + ) + self._chain_lineage(self._provider, result) + except Exception as exc: + self._emit_checkpoint_failed( + location, provider_name, branch_snapshot, parent_id_snapshot, exc + ) + raise - self._emit_checkpoint_completed( - result, provider_name, branch_snapshot, parent_id_snapshot, start - ) - return result + self._emit_checkpoint_completed( + result, provider_name, branch_snapshot, parent_id_snapshot, start + ) + return result async def acheckpoint(self, location: str) -> str: """Async version of :meth:`checkpoint`. + The async lineage lock ensures that concurrent coroutines serialize + their read of ``_parent_id``, write, and update — preventing lineage + divergence. + Args: location: Storage destination. For JsonProvider this is a directory path; for SqliteProvider it is a database file path. @@ -326,28 +339,29 @@ class RuntimeState(RootModel): # type: ignore[type-arg] Returns: A location identifier for the saved checkpoint. 
""" - provider_name, parent_id_snapshot, branch_snapshot, start = ( - self._begin_checkpoint(location) - ) - try: - _prepare_entities(self.root) - result = await self._provider.acheckpoint( - self.model_dump_json(), - location, - parent_id=parent_id_snapshot, - branch=branch_snapshot, + async with self._async_lineage_lock: + provider_name, parent_id_snapshot, branch_snapshot, start = ( + self._begin_checkpoint(location) ) - self._chain_lineage(self._provider, result) - except Exception as exc: - self._emit_checkpoint_failed( - location, provider_name, branch_snapshot, parent_id_snapshot, exc - ) - raise + try: + _prepare_entities(self.root) + result = await self._provider.acheckpoint( + self.model_dump_json(), + location, + parent_id=parent_id_snapshot, + branch=branch_snapshot, + ) + self._chain_lineage(self._provider, result) + except Exception as exc: + self._emit_checkpoint_failed( + location, provider_name, branch_snapshot, parent_id_snapshot, exc + ) + raise - self._emit_checkpoint_completed( - result, provider_name, branch_snapshot, parent_id_snapshot, start - ) - return result + self._emit_checkpoint_completed( + result, provider_name, branch_snapshot, parent_id_snapshot, start + ) + return result def fork(self, branch: str | None = None) -> None: """Create a new execution branch and write an initial checkpoint. 
diff --git a/lib/crewai/tests/test_checkpoint.py b/lib/crewai/tests/test_checkpoint.py index 8cd7cf399..7bc787a83 100644 --- a/lib/crewai/tests/test_checkpoint.py +++ b/lib/crewai/tests/test_checkpoint.py @@ -2,6 +2,8 @@ from __future__ import annotations +import asyncio +import concurrent.futures import inspect import json import os @@ -766,3 +768,208 @@ class TestCustomLLMCheckpointRestore: assert isinstance(llm, BaseLLM) assert not inspect.isabstract(type(llm)) assert llm.model == "stub" + + +class TestJsonProviderAtomicWrites: + """Verify that JsonProvider writes are atomic and no partial files appear.""" + + def test_checkpoint_file_is_complete(self) -> None: + """After checkpoint(), the file must contain exactly the data written.""" + provider = JsonProvider() + payload = json.dumps({"key": "value", "nested": {"a": 1}}) + with tempfile.TemporaryDirectory() as d: + path = provider.checkpoint(payload, d, branch="main") + with open(path) as f: + assert f.read() == payload + + def test_no_temp_files_left_on_success(self) -> None: + """Successful writes must not leave .tmp files behind.""" + provider = JsonProvider() + with tempfile.TemporaryDirectory() as d: + provider.checkpoint("{}", d, branch="main") + branch_dir = os.path.join(d, "main") + tmp_files = [f for f in os.listdir(branch_dir) if f.endswith(".tmp")] + assert tmp_files == [] + + def test_no_temp_files_left_on_failure(self) -> None: + """A failed write must clean up its temp file.""" + provider = JsonProvider() + with tempfile.TemporaryDirectory() as d: + branch_dir = os.path.join(d, "main") + os.makedirs(branch_dir) + # Make the target file path a directory to force os.replace to fail + fake_target = os.path.join(branch_dir, "blocker") + os.makedirs(fake_target) + # Patch _build_path to return our blocker path + from crewai.state.provider import json_provider as jp + original_build = jp._build_path + from pathlib import Path + def bad_build(*a, **kw): + return Path(fake_target) + jp._build_path = 
bad_build + try: + with pytest.raises(OSError): + provider.checkpoint("{}", d, branch="main") + finally: + jp._build_path = original_build + # No leftover .tmp files + tmp_files = [f for f in os.listdir(branch_dir) if f.endswith(".tmp")] + assert tmp_files == [] + + @pytest.mark.asyncio + async def test_acheckpoint_file_is_complete(self) -> None: + """Async checkpoint must produce a complete file.""" + provider = JsonProvider() + payload = json.dumps({"async": True, "data": list(range(100))}) + with tempfile.TemporaryDirectory() as d: + path = await provider.acheckpoint(payload, d, branch="main") + with open(path) as f: + assert f.read() == payload + + @pytest.mark.asyncio + async def test_acheckpoint_no_temp_files(self) -> None: + """Async writes must not leave .tmp files.""" + provider = JsonProvider() + with tempfile.TemporaryDirectory() as d: + await provider.acheckpoint("{}", d, branch="main") + branch_dir = os.path.join(d, "main") + tmp_files = [f for f in os.listdir(branch_dir) if f.endswith(".tmp")] + assert tmp_files == [] + + +class TestJsonProviderConcurrency: + """Verify that concurrent checkpoint writes do not lose data or diverge lineage.""" + + def _make_state(self) -> RuntimeState: + from crewai import Agent, Crew + agent = Agent(role="r", goal="g", backstory="b", llm="gpt-4o-mini") + crew = Crew(agents=[agent], tasks=[], verbose=False) + return RuntimeState(root=[crew]) + + def test_concurrent_sync_checkpoints_preserve_lineage(self) -> None: + """Multiple threads writing checkpoints must form a linear chain.""" + state = self._make_state() + state._provider = JsonProvider() + num_writers = 5 + writes_per_thread = 10 + + with tempfile.TemporaryDirectory() as d: + def writer(): + for _ in range(writes_per_thread): + state.checkpoint(d) + + with concurrent.futures.ThreadPoolExecutor(max_workers=num_writers) as pool: + futures = [pool.submit(writer) for _ in range(num_writers)] + for f in futures: + f.result() + + # All checkpoint files should exist + 
branch_dir = os.path.join(d, "main") + # Sort by filename which encodes timestamp + uuid + files = sorted( + [f for f in os.listdir(branch_dir) if f.endswith(".json")], + ) + total_expected = num_writers * writes_per_thread + assert len(files) == total_expected + + # Verify lineage: each checkpoint's parent_id must refer to an + # existing earlier checkpoint (or "none" for the root). With the + # lock, exactly one checkpoint should have parent "none" and all + # others must form a single linear chain. + provider = JsonProvider() + ids = [provider.extract_id(os.path.join(branch_dir, f)) for f in files] + parent_ids = [] + for f in files: + stem = os.path.splitext(f)[0] + idx = stem.find("_p-") + parent_ids.append(stem[idx + 3:] if idx != -1 else "none") + + # Exactly one root checkpoint (parent "none") + assert parent_ids.count("none") == 1, ( + f"Expected exactly 1 root checkpoint, got {parent_ids.count('none')}" + ) + # Every non-root checkpoint must reference a valid earlier checkpoint id + id_set = set(ids) + for i, parent in enumerate(parent_ids): + if parent == "none": + continue + assert parent in id_set, ( + f"Checkpoint {i} has parent {parent!r} which is not " + f"a known checkpoint id — lineage diverged" + ) + + @pytest.mark.asyncio + async def test_concurrent_async_checkpoints_preserve_lineage(self) -> None: + """Multiple async tasks writing checkpoints must form a linear chain.""" + state = self._make_state() + state._provider = JsonProvider() + num_tasks = 5 + writes_per_task = 10 + + with tempfile.TemporaryDirectory() as d: + async def writer(): + for _ in range(writes_per_task): + await state.acheckpoint(d) + + await asyncio.gather(*(writer() for _ in range(num_tasks))) + + branch_dir = os.path.join(d, "main") + files = sorted( + [f for f in os.listdir(branch_dir) if f.endswith(".json")], + ) + total_expected = num_tasks * writes_per_task + assert len(files) == total_expected + + provider = JsonProvider() + ids = 
[provider.extract_id(os.path.join(branch_dir, f)) for f in files] + parent_ids = [] + for f in files: + stem = os.path.splitext(f)[0] + idx = stem.find("_p-") + parent_ids.append(stem[idx + 3:] if idx != -1 else "none") + + # Exactly one root checkpoint (parent "none") + assert parent_ids.count("none") == 1, ( + f"Expected exactly 1 root checkpoint, got {parent_ids.count('none')}" + ) + # Every non-root checkpoint must reference a valid checkpoint id + id_set = set(ids) + for i, parent in enumerate(parent_ids): + if parent == "none": + continue + assert parent in id_set, ( + f"Checkpoint {i} has parent {parent!r} which is not " + f"a known checkpoint id — lineage diverged" + ) + + def test_concurrent_checkpoints_all_files_valid_json(self) -> None: + """Every checkpoint file produced by concurrent writers must be valid JSON.""" + state = self._make_state() + state._provider = JsonProvider() + num_writers = 5 + writes_per_thread = 10 + + with tempfile.TemporaryDirectory() as d: + def writer(): + for _ in range(writes_per_thread): + state.checkpoint(d) + + with concurrent.futures.ThreadPoolExecutor(max_workers=num_writers) as pool: + futures = [pool.submit(writer) for _ in range(num_writers)] + for f in futures: + f.result() + + branch_dir = os.path.join(d, "main") + for filename in os.listdir(branch_dir): + if not filename.endswith(".json"): + continue + filepath = os.path.join(branch_dir, filename) + with open(filepath) as fh: + content = fh.read() + try: + json.loads(content) + except json.JSONDecodeError: + pytest.fail( + f"Checkpoint {filename} contains invalid JSON " + f"(partial write?): {content[:200]!r}" + )