fix: atomic writes and locking for JsonProvider checkpoints (#6125)

- JsonProvider.checkpoint() and acheckpoint() now write to a temp file
  then os.replace() for atomic file updates (no partial writes)
- Added threading.Lock to JsonProvider for sync path serialization
- Added asyncio.Lock to JsonProvider for async path serialization
- Added threading.Lock (_lineage_lock) to RuntimeState to protect
  read-update of _parent_id across concurrent checkpoint calls
- Added asyncio.Lock (_async_lineage_lock) to RuntimeState for the async path
- checkpoint_listener._do_checkpoint() now acquires _lineage_lock for
  the full read-write-update cycle
- Temp files are cleaned up on write failure
- Added 8 new tests: atomic write correctness, temp file cleanup on
  success/failure, async equivalents, and concurrent sync/async
  lineage preservation with 5 writers × 10 writes each

Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
Devin AI
2026-06-11 19:27:05 +00:00
parent 21fa8e32d9
commit 1b1639b862
4 changed files with 353 additions and 82 deletions

View File

@@ -113,7 +113,11 @@ def _find_checkpoint(source: Any) -> CheckpointConfig | None:
def _do_checkpoint(
state: RuntimeState, cfg: CheckpointConfig, event: BaseEvent | None = None
) -> None:
"""Write a checkpoint and prune old ones if configured."""
"""Write a checkpoint and prune old ones if configured.
The state's lineage lock is held for the entire read-write-update cycle
so concurrent callers see consistent ``_parent_id`` values.
"""
provider_name: str = type(cfg.provider).__name__
trigger: str | None = event.type if event is not None else None
context: dict[str, Any] = {
@@ -123,36 +127,37 @@ def _do_checkpoint(
"agent_role": event.agent_role if event is not None else None,
}
parent_id_snapshot: str | None = state._parent_id
branch_snapshot: str = state._branch
crewai_event_bus.emit(
cfg,
CheckpointStartedEvent(
location=cfg.location,
provider=provider_name,
trigger=trigger,
branch=branch_snapshot,
parent_id=parent_id_snapshot,
**context,
),
)
start: float = time.perf_counter()
try:
_prepare_entities(state.root)
payload = state.model_dump(mode="json")
if event is not None:
payload["trigger"] = event.type
data = json.dumps(payload)
location = cfg.provider.checkpoint(
data,
cfg.location,
parent_id=parent_id_snapshot,
branch=branch_snapshot,
)
state._chain_lineage(cfg.provider, location)
checkpoint_id: str = cfg.provider.extract_id(location)
with state._lineage_lock:
parent_id_snapshot: str | None = state._parent_id
branch_snapshot: str = state._branch
crewai_event_bus.emit(
cfg,
CheckpointStartedEvent(
location=cfg.location,
provider=provider_name,
trigger=trigger,
branch=branch_snapshot,
parent_id=parent_id_snapshot,
**context,
),
)
_prepare_entities(state.root)
payload = state.model_dump(mode="json")
if event is not None:
payload["trigger"] = event.type
data = json.dumps(payload)
location = cfg.provider.checkpoint(
data,
cfg.location,
parent_id=parent_id_snapshot,
branch=branch_snapshot,
)
state._chain_lineage(cfg.provider, location)
checkpoint_id: str = cfg.provider.extract_id(location)
except Exception as exc:
crewai_event_bus.emit(
cfg,

View File

@@ -2,11 +2,14 @@
from __future__ import annotations
import asyncio
from datetime import datetime, timezone
import glob
import logging
import os
from pathlib import Path
import tempfile
import threading
from typing import Literal
import uuid
@@ -35,10 +38,19 @@ def _safe_branch(base: str, branch: str) -> None:
class JsonProvider(BaseProvider):
"""Persists runtime state checkpoints as JSON files on the local filesystem."""
"""Persists runtime state checkpoints as JSON files on the local filesystem.
File writes are atomic (write-to-temp then ``os.replace``) and serialized
by an internal lock so concurrent callers cannot create diverging lineage.
"""
provider_type: Literal["json"] = "json"
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._lock = threading.Lock()
self._async_lock = asyncio.Lock()
def checkpoint(
self,
data: str,
@@ -47,7 +59,11 @@ class JsonProvider(BaseProvider):
parent_id: str | None = None,
branch: str = "main",
) -> str:
"""Write a JSON checkpoint file.
"""Write a JSON checkpoint file atomically.
Uses write-to-temp + ``os.replace()`` to guarantee the checkpoint
file is never partially written. A threading lock serializes
concurrent writes to prevent lineage divergence.
Args:
data: The serialized JSON string to persist.
@@ -60,12 +76,25 @@ class JsonProvider(BaseProvider):
Returns:
The path to the written checkpoint file.
"""
file_path = _build_path(location, branch, parent_id)
file_path.parent.mkdir(parents=True, exist_ok=True)
with self._lock:
file_path = _build_path(location, branch, parent_id)
file_path.parent.mkdir(parents=True, exist_ok=True)
with open(file_path, "w") as f:
f.write(data)
return str(file_path)
fd, tmp_path = tempfile.mkstemp(dir=str(file_path.parent), suffix=".tmp")
try:
with os.fdopen(fd, "w") as f:
f.write(data)
f.flush()
os.fsync(f.fileno())
os.replace(tmp_path, str(file_path))
except BaseException:
# Clean up temp file on failure
try:
os.unlink(tmp_path)
except OSError:
pass
raise
return str(file_path)
async def acheckpoint(
self,
@@ -75,7 +104,11 @@ class JsonProvider(BaseProvider):
parent_id: str | None = None,
branch: str = "main",
) -> str:
"""Write a JSON checkpoint file asynchronously.
"""Write a JSON checkpoint file atomically and asynchronously.
Uses write-to-temp + ``os.replace()`` to guarantee the checkpoint
file is never partially written. An asyncio lock serializes
concurrent writes to prevent lineage divergence.
Args:
data: The serialized JSON string to persist.
@@ -88,12 +121,24 @@ class JsonProvider(BaseProvider):
Returns:
The path to the written checkpoint file.
"""
file_path = _build_path(location, branch, parent_id)
await aiofiles.os.makedirs(str(file_path.parent), exist_ok=True)
async with self._async_lock:
file_path = _build_path(location, branch, parent_id)
await aiofiles.os.makedirs(str(file_path.parent), exist_ok=True)
async with aiofiles.open(file_path, "w") as f:
await f.write(data)
return str(file_path)
fd, tmp_path = tempfile.mkstemp(dir=str(file_path.parent), suffix=".tmp")
try:
with os.fdopen(fd, "w") as f:
f.write(data)
f.flush()
os.fsync(f.fileno())
os.replace(tmp_path, str(file_path))
except BaseException:
try:
os.unlink(tmp_path)
except OSError:
pass
raise
return str(file_path)
def prune(self, location: str, max_keep: int, *, branch: str = "main") -> int:
"""Remove oldest checkpoint files beyond *max_keep* on a branch."""

View File

@@ -9,7 +9,9 @@ via ``RuntimeState.model_rebuild()``.
from __future__ import annotations
import asyncio
import logging
import threading
import time
from typing import TYPE_CHECKING, Any
import uuid
@@ -181,6 +183,8 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
_checkpoint_id: str | None = PrivateAttr(default=None)
_parent_id: str | None = PrivateAttr(default=None)
_branch: str = PrivateAttr(default="main")
_lineage_lock: threading.Lock = PrivateAttr(default_factory=threading.Lock)
_async_lineage_lock: asyncio.Lock = PrivateAttr(default_factory=asyncio.Lock)
@property
def event_record(self) -> EventRecord:
@@ -286,6 +290,10 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
def checkpoint(self, location: str) -> str:
"""Write a checkpoint.
The lineage lock ensures that concurrent callers serialize their
read of ``_parent_id``, write, and update — preventing lineage
divergence.
Args:
location: Storage destination. For JsonProvider this is a directory
path; for SqliteProvider it is a database file path.
@@ -293,32 +301,37 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
Returns:
A location identifier for the saved checkpoint.
"""
provider_name, parent_id_snapshot, branch_snapshot, start = (
self._begin_checkpoint(location)
)
try:
_prepare_entities(self.root)
result = self._provider.checkpoint(
self.model_dump_json(),
location,
parent_id=parent_id_snapshot,
branch=branch_snapshot,
with self._lineage_lock:
provider_name, parent_id_snapshot, branch_snapshot, start = (
self._begin_checkpoint(location)
)
self._chain_lineage(self._provider, result)
except Exception as exc:
self._emit_checkpoint_failed(
location, provider_name, branch_snapshot, parent_id_snapshot, exc
)
raise
try:
_prepare_entities(self.root)
result = self._provider.checkpoint(
self.model_dump_json(),
location,
parent_id=parent_id_snapshot,
branch=branch_snapshot,
)
self._chain_lineage(self._provider, result)
except Exception as exc:
self._emit_checkpoint_failed(
location, provider_name, branch_snapshot, parent_id_snapshot, exc
)
raise
self._emit_checkpoint_completed(
result, provider_name, branch_snapshot, parent_id_snapshot, start
)
return result
self._emit_checkpoint_completed(
result, provider_name, branch_snapshot, parent_id_snapshot, start
)
return result
async def acheckpoint(self, location: str) -> str:
"""Async version of :meth:`checkpoint`.
The async lineage lock ensures that concurrent coroutines serialize
their read of ``_parent_id``, write, and update — preventing lineage
divergence.
Args:
location: Storage destination. For JsonProvider this is a directory
path; for SqliteProvider it is a database file path.
@@ -326,28 +339,29 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
Returns:
A location identifier for the saved checkpoint.
"""
provider_name, parent_id_snapshot, branch_snapshot, start = (
self._begin_checkpoint(location)
)
try:
_prepare_entities(self.root)
result = await self._provider.acheckpoint(
self.model_dump_json(),
location,
parent_id=parent_id_snapshot,
branch=branch_snapshot,
async with self._async_lineage_lock:
provider_name, parent_id_snapshot, branch_snapshot, start = (
self._begin_checkpoint(location)
)
self._chain_lineage(self._provider, result)
except Exception as exc:
self._emit_checkpoint_failed(
location, provider_name, branch_snapshot, parent_id_snapshot, exc
)
raise
try:
_prepare_entities(self.root)
result = await self._provider.acheckpoint(
self.model_dump_json(),
location,
parent_id=parent_id_snapshot,
branch=branch_snapshot,
)
self._chain_lineage(self._provider, result)
except Exception as exc:
self._emit_checkpoint_failed(
location, provider_name, branch_snapshot, parent_id_snapshot, exc
)
raise
self._emit_checkpoint_completed(
result, provider_name, branch_snapshot, parent_id_snapshot, start
)
return result
self._emit_checkpoint_completed(
result, provider_name, branch_snapshot, parent_id_snapshot, start
)
return result
def fork(self, branch: str | None = None) -> None:
"""Create a new execution branch and write an initial checkpoint.

View File

@@ -2,6 +2,8 @@
from __future__ import annotations
import asyncio
import concurrent.futures
import inspect
import json
import os
@@ -766,3 +768,208 @@ class TestCustomLLMCheckpointRestore:
assert isinstance(llm, BaseLLM)
assert not inspect.isabstract(type(llm))
assert llm.model == "stub"
class TestJsonProviderAtomicWrites:
    """Check that JsonProvider persists complete files and leaves no temp residue."""

    @staticmethod
    def _leftover_tmp_files(directory: str) -> list[str]:
        """Return the entries of *directory* whose names end with ``.tmp``."""
        return [name for name in os.listdir(directory) if name.endswith(".tmp")]

    def test_checkpoint_file_is_complete(self) -> None:
        """The file returned by checkpoint() holds exactly the payload passed in."""
        payload = json.dumps({"key": "value", "nested": {"a": 1}})
        provider = JsonProvider()
        with tempfile.TemporaryDirectory() as root:
            written = provider.checkpoint(payload, root, branch="main")
            with open(written) as handle:
                contents = handle.read()
            assert contents == payload

    def test_no_temp_files_left_on_success(self) -> None:
        """A successful write leaves no ``.tmp`` entries in the branch directory."""
        provider = JsonProvider()
        with tempfile.TemporaryDirectory() as root:
            provider.checkpoint("{}", root, branch="main")
            leftovers = self._leftover_tmp_files(os.path.join(root, "main"))
            assert leftovers == []

    def test_no_temp_files_left_on_failure(self) -> None:
        """A write that raises must not leave its temp file behind."""
        from pathlib import Path

        from crewai.state.provider import json_provider as jp

        provider = JsonProvider()
        with tempfile.TemporaryDirectory() as root:
            branch_dir = os.path.join(root, "main")
            # The destination is a directory, so the write into it is expected
            # to raise OSError; makedirs also creates branch_dir on the way.
            blocker = os.path.join(branch_dir, "blocker")
            os.makedirs(blocker)

            # Temporarily redirect path construction to the blocker directory.
            real_build_path = jp._build_path
            jp._build_path = lambda *args, **kwargs: Path(blocker)
            try:
                with pytest.raises(OSError):
                    provider.checkpoint("{}", root, branch="main")
            finally:
                jp._build_path = real_build_path

            assert self._leftover_tmp_files(branch_dir) == []

    @pytest.mark.asyncio
    async def test_acheckpoint_file_is_complete(self) -> None:
        """The file returned by acheckpoint() holds exactly the payload passed in."""
        payload = json.dumps({"async": True, "data": list(range(100))})
        provider = JsonProvider()
        with tempfile.TemporaryDirectory() as root:
            written = await provider.acheckpoint(payload, root, branch="main")
            with open(written) as handle:
                contents = handle.read()
            assert contents == payload

    @pytest.mark.asyncio
    async def test_acheckpoint_no_temp_files(self) -> None:
        """A successful async write leaves no ``.tmp`` entries behind."""
        provider = JsonProvider()
        with tempfile.TemporaryDirectory() as root:
            await provider.acheckpoint("{}", root, branch="main")
            leftovers = self._leftover_tmp_files(os.path.join(root, "main"))
            assert leftovers == []
class TestJsonProviderConcurrency:
    """Verify that concurrent checkpoint writes do not lose data or diverge lineage."""

    def _make_state(self) -> RuntimeState:
        """Build a RuntimeState wrapping a minimal single-agent crew."""
        from crewai import Agent, Crew

        agent = Agent(role="r", goal="g", backstory="b", llm="gpt-4o-mini")
        crew = Crew(agents=[agent], tasks=[], verbose=False)
        return RuntimeState(root=[crew])

    @staticmethod
    def _run_threaded_writers(
        state: RuntimeState,
        location: str,
        num_writers: int,
        writes_per_thread: int,
    ) -> None:
        """Call ``state.checkpoint(location)`` from several threads at once.

        Args:
            state: The shared state every thread checkpoints.
            location: Directory handed to ``state.checkpoint``.
            num_writers: Number of worker threads to start.
            writes_per_thread: Checkpoints written by each worker.

        Raises:
            Exception: Whatever a worker raised, re-raised via ``Future.result()``.
        """

        def writer() -> None:
            for _ in range(writes_per_thread):
                state.checkpoint(location)

        with concurrent.futures.ThreadPoolExecutor(max_workers=num_writers) as pool:
            futures = [pool.submit(writer) for _ in range(num_writers)]
            for future in futures:
                # result() surfaces worker exceptions instead of hiding them.
                future.result()

    @staticmethod
    def _parent_id_from_filename(filename: str) -> str:
        """Return the text after the first ``_p-`` in the stem, or ``"none"``."""
        stem = os.path.splitext(filename)[0]
        _, marker, parent = stem.partition("_p-")
        return parent if marker else "none"

    def _assert_linear_lineage(self, branch_dir: str, expected_total: int) -> None:
        """Assert the checkpoints in *branch_dir* form one unforked chain.

        Checks that the expected number of ``.json`` files exists, that exactly
        one has parent ``"none"``, that every other parent is the id of a file
        in the directory, and that no parent id is used by two checkpoints.
        The last check is what detects a fork: two writers that both read the
        same ``_parent_id`` produce two files naming the same parent.

        Args:
            branch_dir: Directory holding the branch's checkpoint files.
            expected_total: Number of checkpoint files that must be present.
        """
        files = sorted(f for f in os.listdir(branch_dir) if f.endswith(".json"))
        assert len(files) == expected_total

        provider = JsonProvider()
        id_set = {provider.extract_id(os.path.join(branch_dir, f)) for f in files}
        parent_ids = [self._parent_id_from_filename(f) for f in files]

        # Exactly one root checkpoint (parent "none")
        root_count = parent_ids.count("none")
        assert root_count == 1, (
            f"Expected exactly 1 root checkpoint, got {root_count}"
        )

        # Every non-root checkpoint must reference a known checkpoint id
        for i, parent in enumerate(parent_ids):
            if parent == "none":
                continue
            assert parent in id_set, (
                f"Checkpoint {i} has parent {parent!r} which is not "
                f"a known checkpoint id — lineage diverged"
            )

        # A linear chain gives each checkpoint at most one child, so a parent
        # id that appears twice means the lineage forked.
        seen: set[str] = set()
        forked: set[str] = set()
        for parent in parent_ids:
            if parent == "none":
                continue
            if parent in seen:
                forked.add(parent)
            seen.add(parent)
        assert not forked, (
            f"Parent ids shared by multiple checkpoints: {sorted(forked)!r} "
            f"— lineage forked"
        )

    def test_concurrent_sync_checkpoints_preserve_lineage(self) -> None:
        """Multiple threads writing checkpoints must form a linear chain."""
        state = self._make_state()
        state._provider = JsonProvider()
        num_writers = 5
        writes_per_thread = 10

        with tempfile.TemporaryDirectory() as d:
            self._run_threaded_writers(state, d, num_writers, writes_per_thread)
            self._assert_linear_lineage(
                os.path.join(d, "main"), num_writers * writes_per_thread
            )

    @pytest.mark.asyncio
    async def test_concurrent_async_checkpoints_preserve_lineage(self) -> None:
        """Multiple async tasks writing checkpoints must form a linear chain."""
        state = self._make_state()
        state._provider = JsonProvider()
        num_tasks = 5
        writes_per_task = 10

        with tempfile.TemporaryDirectory() as d:

            async def writer() -> None:
                for _ in range(writes_per_task):
                    await state.acheckpoint(d)

            await asyncio.gather(*(writer() for _ in range(num_tasks)))
            self._assert_linear_lineage(
                os.path.join(d, "main"), num_tasks * writes_per_task
            )

    def test_concurrent_checkpoints_all_files_valid_json(self) -> None:
        """Every checkpoint file produced by concurrent writers must be valid JSON."""
        state = self._make_state()
        state._provider = JsonProvider()
        num_writers = 5
        writes_per_thread = 10

        with tempfile.TemporaryDirectory() as d:
            self._run_threaded_writers(state, d, num_writers, writes_per_thread)

            branch_dir = os.path.join(d, "main")
            for filename in os.listdir(branch_dir):
                if not filename.endswith(".json"):
                    continue
                filepath = os.path.join(branch_dir, filename)
                with open(filepath) as fh:
                    content = fh.read()
                try:
                    json.loads(content)
                except json.JSONDecodeError:
                    pytest.fail(
                        f"Checkpoint {filename} contains invalid JSON "
                        f"(partial write?): {content[:200]!r}"
                    )