mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-07-03 22:19:27 +00:00
fix: atomic writes and locking for JsonProvider checkpoints (#6125)
- JsonProvider.checkpoint() and acheckpoint() now write to a temp file and then call os.replace(), so file updates are atomic (no partial writes)
- Added threading.Lock to JsonProvider to serialize the sync path
- Added asyncio.Lock to JsonProvider to serialize the async path
- Added threading.Lock (_lineage_lock) to RuntimeState to protect the read-update of _parent_id across concurrent checkpoint calls
- Added asyncio.Lock (_async_lineage_lock) to RuntimeState for the async path
- checkpoint_listener._do_checkpoint() now acquires _lineage_lock for the full read-write-update cycle
- Temp files are cleaned up on write failure
- Added 8 new tests: atomic write correctness, temp file cleanup on success/failure, async equivalents, and concurrent sync/async lineage preservation with 5 writers × 10 writes each

Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
@@ -113,7 +113,11 @@ def _find_checkpoint(source: Any) -> CheckpointConfig | None:
|
||||
def _do_checkpoint(
|
||||
state: RuntimeState, cfg: CheckpointConfig, event: BaseEvent | None = None
|
||||
) -> None:
|
||||
"""Write a checkpoint and prune old ones if configured."""
|
||||
"""Write a checkpoint and prune old ones if configured.
|
||||
|
||||
The state's lineage lock is held for the entire read-write-update cycle
|
||||
so concurrent callers see consistent ``_parent_id`` values.
|
||||
"""
|
||||
provider_name: str = type(cfg.provider).__name__
|
||||
trigger: str | None = event.type if event is not None else None
|
||||
context: dict[str, Any] = {
|
||||
@@ -123,36 +127,37 @@ def _do_checkpoint(
|
||||
"agent_role": event.agent_role if event is not None else None,
|
||||
}
|
||||
|
||||
parent_id_snapshot: str | None = state._parent_id
|
||||
branch_snapshot: str = state._branch
|
||||
|
||||
crewai_event_bus.emit(
|
||||
cfg,
|
||||
CheckpointStartedEvent(
|
||||
location=cfg.location,
|
||||
provider=provider_name,
|
||||
trigger=trigger,
|
||||
branch=branch_snapshot,
|
||||
parent_id=parent_id_snapshot,
|
||||
**context,
|
||||
),
|
||||
)
|
||||
|
||||
start: float = time.perf_counter()
|
||||
try:
|
||||
_prepare_entities(state.root)
|
||||
payload = state.model_dump(mode="json")
|
||||
if event is not None:
|
||||
payload["trigger"] = event.type
|
||||
data = json.dumps(payload)
|
||||
location = cfg.provider.checkpoint(
|
||||
data,
|
||||
cfg.location,
|
||||
parent_id=parent_id_snapshot,
|
||||
branch=branch_snapshot,
|
||||
)
|
||||
state._chain_lineage(cfg.provider, location)
|
||||
checkpoint_id: str = cfg.provider.extract_id(location)
|
||||
with state._lineage_lock:
|
||||
parent_id_snapshot: str | None = state._parent_id
|
||||
branch_snapshot: str = state._branch
|
||||
|
||||
crewai_event_bus.emit(
|
||||
cfg,
|
||||
CheckpointStartedEvent(
|
||||
location=cfg.location,
|
||||
provider=provider_name,
|
||||
trigger=trigger,
|
||||
branch=branch_snapshot,
|
||||
parent_id=parent_id_snapshot,
|
||||
**context,
|
||||
),
|
||||
)
|
||||
|
||||
_prepare_entities(state.root)
|
||||
payload = state.model_dump(mode="json")
|
||||
if event is not None:
|
||||
payload["trigger"] = event.type
|
||||
data = json.dumps(payload)
|
||||
location = cfg.provider.checkpoint(
|
||||
data,
|
||||
cfg.location,
|
||||
parent_id=parent_id_snapshot,
|
||||
branch=branch_snapshot,
|
||||
)
|
||||
state._chain_lineage(cfg.provider, location)
|
||||
checkpoint_id: str = cfg.provider.extract_id(location)
|
||||
except Exception as exc:
|
||||
crewai_event_bus.emit(
|
||||
cfg,
|
||||
|
||||
@@ -2,11 +2,14 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime, timezone
|
||||
import glob
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
import threading
|
||||
from typing import Literal
|
||||
import uuid
|
||||
|
||||
@@ -35,10 +38,19 @@ def _safe_branch(base: str, branch: str) -> None:
|
||||
|
||||
|
||||
class JsonProvider(BaseProvider):
|
||||
"""Persists runtime state checkpoints as JSON files on the local filesystem."""
|
||||
"""Persists runtime state checkpoints as JSON files on the local filesystem.
|
||||
|
||||
File writes are atomic (write-to-temp then ``os.replace``) and serialized
|
||||
by an internal lock so concurrent callers cannot create diverging lineage.
|
||||
"""
|
||||
|
||||
provider_type: Literal["json"] = "json"
|
||||
|
||||
def __init__(self, **kwargs) -> None:
    """Initialize the provider and create its write-serialization locks.

    Args:
        **kwargs: Forwarded unchanged to the base provider initializer.
    """
    super().__init__(**kwargs)
    # Held by checkpoint() around path construction and the temp-file write.
    self._lock = threading.Lock()
    # Held by acheckpoint() around the same steps on the async path.
    self._async_lock = asyncio.Lock()
|
||||
|
||||
def checkpoint(
|
||||
self,
|
||||
data: str,
|
||||
@@ -47,7 +59,11 @@ class JsonProvider(BaseProvider):
|
||||
parent_id: str | None = None,
|
||||
branch: str = "main",
|
||||
) -> str:
|
||||
"""Write a JSON checkpoint file.
|
||||
"""Write a JSON checkpoint file atomically.
|
||||
|
||||
Uses write-to-temp + ``os.replace()`` to guarantee the checkpoint
|
||||
file is never partially written. A threading lock serializes
|
||||
concurrent writes to prevent lineage divergence.
|
||||
|
||||
Args:
|
||||
data: The serialized JSON string to persist.
|
||||
@@ -60,12 +76,25 @@ class JsonProvider(BaseProvider):
|
||||
Returns:
|
||||
The path to the written checkpoint file.
|
||||
"""
|
||||
file_path = _build_path(location, branch, parent_id)
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with self._lock:
|
||||
file_path = _build_path(location, branch, parent_id)
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(file_path, "w") as f:
|
||||
f.write(data)
|
||||
return str(file_path)
|
||||
fd, tmp_path = tempfile.mkstemp(dir=str(file_path.parent), suffix=".tmp")
|
||||
try:
|
||||
with os.fdopen(fd, "w") as f:
|
||||
f.write(data)
|
||||
f.flush()
|
||||
os.fsync(f.fileno())
|
||||
os.replace(tmp_path, str(file_path))
|
||||
except BaseException:
|
||||
# Clean up temp file on failure
|
||||
try:
|
||||
os.unlink(tmp_path)
|
||||
except OSError:
|
||||
pass
|
||||
raise
|
||||
return str(file_path)
|
||||
|
||||
async def acheckpoint(
|
||||
self,
|
||||
@@ -75,7 +104,11 @@ class JsonProvider(BaseProvider):
|
||||
parent_id: str | None = None,
|
||||
branch: str = "main",
|
||||
) -> str:
|
||||
"""Write a JSON checkpoint file asynchronously.
|
||||
"""Write a JSON checkpoint file atomically and asynchronously.
|
||||
|
||||
Uses write-to-temp + ``os.replace()`` to guarantee the checkpoint
|
||||
file is never partially written. An asyncio lock serializes
|
||||
concurrent writes to prevent lineage divergence.
|
||||
|
||||
Args:
|
||||
data: The serialized JSON string to persist.
|
||||
@@ -88,12 +121,24 @@ class JsonProvider(BaseProvider):
|
||||
Returns:
|
||||
The path to the written checkpoint file.
|
||||
"""
|
||||
file_path = _build_path(location, branch, parent_id)
|
||||
await aiofiles.os.makedirs(str(file_path.parent), exist_ok=True)
|
||||
async with self._async_lock:
|
||||
file_path = _build_path(location, branch, parent_id)
|
||||
await aiofiles.os.makedirs(str(file_path.parent), exist_ok=True)
|
||||
|
||||
async with aiofiles.open(file_path, "w") as f:
|
||||
await f.write(data)
|
||||
return str(file_path)
|
||||
fd, tmp_path = tempfile.mkstemp(dir=str(file_path.parent), suffix=".tmp")
|
||||
try:
|
||||
with os.fdopen(fd, "w") as f:
|
||||
f.write(data)
|
||||
f.flush()
|
||||
os.fsync(f.fileno())
|
||||
os.replace(tmp_path, str(file_path))
|
||||
except BaseException:
|
||||
try:
|
||||
os.unlink(tmp_path)
|
||||
except OSError:
|
||||
pass
|
||||
raise
|
||||
return str(file_path)
|
||||
|
||||
def prune(self, location: str, max_keep: int, *, branch: str = "main") -> int:
|
||||
"""Remove oldest checkpoint files beyond *max_keep* on a branch."""
|
||||
|
||||
@@ -9,7 +9,9 @@ via ``RuntimeState.model_rebuild()``.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any
|
||||
import uuid
|
||||
@@ -181,6 +183,8 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
|
||||
_checkpoint_id: str | None = PrivateAttr(default=None)
|
||||
_parent_id: str | None = PrivateAttr(default=None)
|
||||
_branch: str = PrivateAttr(default="main")
|
||||
_lineage_lock: threading.Lock = PrivateAttr(default_factory=threading.Lock)
|
||||
_async_lineage_lock: asyncio.Lock = PrivateAttr(default_factory=asyncio.Lock)
|
||||
|
||||
@property
|
||||
def event_record(self) -> EventRecord:
|
||||
@@ -286,6 +290,10 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
|
||||
def checkpoint(self, location: str) -> str:
|
||||
"""Write a checkpoint.
|
||||
|
||||
The lineage lock ensures that concurrent callers serialize their
|
||||
read of ``_parent_id``, write, and update — preventing lineage
|
||||
divergence.
|
||||
|
||||
Args:
|
||||
location: Storage destination. For JsonProvider this is a directory
|
||||
path; for SqliteProvider it is a database file path.
|
||||
@@ -293,32 +301,37 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
|
||||
Returns:
|
||||
A location identifier for the saved checkpoint.
|
||||
"""
|
||||
provider_name, parent_id_snapshot, branch_snapshot, start = (
|
||||
self._begin_checkpoint(location)
|
||||
)
|
||||
try:
|
||||
_prepare_entities(self.root)
|
||||
result = self._provider.checkpoint(
|
||||
self.model_dump_json(),
|
||||
location,
|
||||
parent_id=parent_id_snapshot,
|
||||
branch=branch_snapshot,
|
||||
with self._lineage_lock:
|
||||
provider_name, parent_id_snapshot, branch_snapshot, start = (
|
||||
self._begin_checkpoint(location)
|
||||
)
|
||||
self._chain_lineage(self._provider, result)
|
||||
except Exception as exc:
|
||||
self._emit_checkpoint_failed(
|
||||
location, provider_name, branch_snapshot, parent_id_snapshot, exc
|
||||
)
|
||||
raise
|
||||
try:
|
||||
_prepare_entities(self.root)
|
||||
result = self._provider.checkpoint(
|
||||
self.model_dump_json(),
|
||||
location,
|
||||
parent_id=parent_id_snapshot,
|
||||
branch=branch_snapshot,
|
||||
)
|
||||
self._chain_lineage(self._provider, result)
|
||||
except Exception as exc:
|
||||
self._emit_checkpoint_failed(
|
||||
location, provider_name, branch_snapshot, parent_id_snapshot, exc
|
||||
)
|
||||
raise
|
||||
|
||||
self._emit_checkpoint_completed(
|
||||
result, provider_name, branch_snapshot, parent_id_snapshot, start
|
||||
)
|
||||
return result
|
||||
self._emit_checkpoint_completed(
|
||||
result, provider_name, branch_snapshot, parent_id_snapshot, start
|
||||
)
|
||||
return result
|
||||
|
||||
async def acheckpoint(self, location: str) -> str:
|
||||
"""Async version of :meth:`checkpoint`.
|
||||
|
||||
The async lineage lock ensures that concurrent coroutines serialize
|
||||
their read of ``_parent_id``, write, and update — preventing lineage
|
||||
divergence.
|
||||
|
||||
Args:
|
||||
location: Storage destination. For JsonProvider this is a directory
|
||||
path; for SqliteProvider it is a database file path.
|
||||
@@ -326,28 +339,29 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
|
||||
Returns:
|
||||
A location identifier for the saved checkpoint.
|
||||
"""
|
||||
provider_name, parent_id_snapshot, branch_snapshot, start = (
|
||||
self._begin_checkpoint(location)
|
||||
)
|
||||
try:
|
||||
_prepare_entities(self.root)
|
||||
result = await self._provider.acheckpoint(
|
||||
self.model_dump_json(),
|
||||
location,
|
||||
parent_id=parent_id_snapshot,
|
||||
branch=branch_snapshot,
|
||||
async with self._async_lineage_lock:
|
||||
provider_name, parent_id_snapshot, branch_snapshot, start = (
|
||||
self._begin_checkpoint(location)
|
||||
)
|
||||
self._chain_lineage(self._provider, result)
|
||||
except Exception as exc:
|
||||
self._emit_checkpoint_failed(
|
||||
location, provider_name, branch_snapshot, parent_id_snapshot, exc
|
||||
)
|
||||
raise
|
||||
try:
|
||||
_prepare_entities(self.root)
|
||||
result = await self._provider.acheckpoint(
|
||||
self.model_dump_json(),
|
||||
location,
|
||||
parent_id=parent_id_snapshot,
|
||||
branch=branch_snapshot,
|
||||
)
|
||||
self._chain_lineage(self._provider, result)
|
||||
except Exception as exc:
|
||||
self._emit_checkpoint_failed(
|
||||
location, provider_name, branch_snapshot, parent_id_snapshot, exc
|
||||
)
|
||||
raise
|
||||
|
||||
self._emit_checkpoint_completed(
|
||||
result, provider_name, branch_snapshot, parent_id_snapshot, start
|
||||
)
|
||||
return result
|
||||
self._emit_checkpoint_completed(
|
||||
result, provider_name, branch_snapshot, parent_id_snapshot, start
|
||||
)
|
||||
return result
|
||||
|
||||
def fork(self, branch: str | None = None) -> None:
|
||||
"""Create a new execution branch and write an initial checkpoint.
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
@@ -766,3 +768,208 @@ class TestCustomLLMCheckpointRestore:
|
||||
assert isinstance(llm, BaseLLM)
|
||||
assert not inspect.isabstract(type(llm))
|
||||
assert llm.model == "stub"
|
||||
|
||||
|
||||
class TestJsonProviderAtomicWrites:
    """Verify that JsonProvider writes are atomic and no partial files appear."""

    @staticmethod
    def _tmp_leftovers(directory: str) -> list[str]:
        """Return the names of stray ``.tmp`` files found in *directory*."""
        return [name for name in os.listdir(directory) if name.endswith(".tmp")]

    @staticmethod
    def _block_target(monkeypatch: pytest.MonkeyPatch, root: str) -> str:
        """Force the final ``os.replace`` to fail and return the branch directory.

        A directory is created where the checkpoint file should land and
        ``_build_path`` is patched to point at it, so replacing the temp file
        onto that path raises ``OSError``. ``monkeypatch`` restores the real
        ``_build_path`` at test teardown, even if the test body fails.
        """
        from pathlib import Path

        from crewai.state.provider import json_provider as jp

        branch_dir = os.path.join(root, "main")
        os.makedirs(branch_dir)
        blocker = os.path.join(branch_dir, "blocker")
        os.makedirs(blocker)
        monkeypatch.setattr(jp, "_build_path", lambda *a, **kw: Path(blocker))
        return branch_dir

    def test_checkpoint_file_is_complete(self) -> None:
        """After checkpoint(), the file must contain exactly the data written."""
        provider = JsonProvider()
        payload = json.dumps({"key": "value", "nested": {"a": 1}})
        with tempfile.TemporaryDirectory() as d:
            path = provider.checkpoint(payload, d, branch="main")
            with open(path) as f:
                assert f.read() == payload

    def test_no_temp_files_left_on_success(self) -> None:
        """Successful writes must not leave .tmp files behind."""
        provider = JsonProvider()
        with tempfile.TemporaryDirectory() as d:
            provider.checkpoint("{}", d, branch="main")
            assert self._tmp_leftovers(os.path.join(d, "main")) == []

    def test_no_temp_files_left_on_failure(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """A failed write must clean up its temp file."""
        provider = JsonProvider()
        with tempfile.TemporaryDirectory() as d:
            branch_dir = self._block_target(monkeypatch, d)
            with pytest.raises(OSError):
                provider.checkpoint("{}", d, branch="main")
            assert self._tmp_leftovers(branch_dir) == []

    @pytest.mark.asyncio
    async def test_acheckpoint_file_is_complete(self) -> None:
        """Async checkpoint must produce a complete file."""
        provider = JsonProvider()
        payload = json.dumps({"async": True, "data": list(range(100))})
        with tempfile.TemporaryDirectory() as d:
            path = await provider.acheckpoint(payload, d, branch="main")
            with open(path) as f:
                assert f.read() == payload

    @pytest.mark.asyncio
    async def test_acheckpoint_no_temp_files(self) -> None:
        """Async writes must not leave .tmp files."""
        provider = JsonProvider()
        with tempfile.TemporaryDirectory() as d:
            await provider.acheckpoint("{}", d, branch="main")
            assert self._tmp_leftovers(os.path.join(d, "main")) == []

    @pytest.mark.asyncio
    async def test_acheckpoint_no_temp_files_left_on_failure(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """A failed async write must clean up its temp file, like the sync path."""
        provider = JsonProvider()
        with tempfile.TemporaryDirectory() as d:
            branch_dir = self._block_target(monkeypatch, d)
            with pytest.raises(OSError):
                await provider.acheckpoint("{}", d, branch="main")
            assert self._tmp_leftovers(branch_dir) == []
|
||||
|
||||
|
||||
class TestJsonProviderConcurrency:
    """Verify that concurrent checkpoint writes do not lose data or diverge lineage."""

    def _make_state(self) -> RuntimeState:
        """Build a minimal RuntimeState wrapping a single-agent, task-less crew."""
        from crewai import Agent, Crew

        agent = Agent(role="r", goal="g", backstory="b", llm="gpt-4o-mini")
        crew = Crew(agents=[agent], tasks=[], verbose=False)
        return RuntimeState(root=[crew])

    @staticmethod
    def _write_from_threads(
        state: RuntimeState, location: str, num_writers: int, writes_per_thread: int
    ) -> None:
        """Run *num_writers* threads that each checkpoint *state* repeatedly.

        ``future.result()`` is called on every worker so an exception raised
        inside a thread fails the test instead of being silently dropped.
        """

        def writer() -> None:
            for _ in range(writes_per_thread):
                state.checkpoint(location)

        with concurrent.futures.ThreadPoolExecutor(max_workers=num_writers) as pool:
            futures = [pool.submit(writer) for _ in range(num_writers)]
            for future in futures:
                future.result()

    @staticmethod
    def _assert_linear_lineage(branch_dir: str, expected_count: int) -> None:
        """Assert the checkpoints in *branch_dir* form one unbranched chain.

        The parent id is read from the ``_p-<parent>`` suffix of each file
        stem. A single chain requires exactly one root (parent ``"none"``),
        every other parent to be a known checkpoint id, and no parent to be
        referenced twice: two checkpoints sharing a parent is a fork.
        """
        files = sorted(f for f in os.listdir(branch_dir) if f.endswith(".json"))
        assert len(files) == expected_count

        provider = JsonProvider()
        id_set = {provider.extract_id(os.path.join(branch_dir, f)) for f in files}

        parent_ids = []
        for f in files:
            stem = os.path.splitext(f)[0]
            idx = stem.find("_p-")
            parent_ids.append(stem[idx + 3:] if idx != -1 else "none")

        # Exactly one root checkpoint (parent "none")
        assert parent_ids.count("none") == 1, (
            f"Expected exactly 1 root checkpoint, got {parent_ids.count('none')}"
        )

        seen_parents: set[str] = set()
        for i, parent in enumerate(parent_ids):
            if parent == "none":
                continue
            assert parent in id_set, (
                f"Checkpoint {i} has parent {parent!r} which is not "
                f"a known checkpoint id — lineage diverged"
            )
            # Existence alone would still accept a fork; require each
            # checkpoint to have at most one child.
            assert parent not in seen_parents, (
                f"Checkpoint {i} shares parent {parent!r} with another "
                f"checkpoint — lineage forked"
            )
            seen_parents.add(parent)

    def test_concurrent_sync_checkpoints_preserve_lineage(self) -> None:
        """Multiple threads writing checkpoints must form a linear chain."""
        state = self._make_state()
        state._provider = JsonProvider()
        num_writers = 5
        writes_per_thread = 10

        with tempfile.TemporaryDirectory() as d:
            self._write_from_threads(state, d, num_writers, writes_per_thread)
            self._assert_linear_lineage(
                os.path.join(d, "main"), num_writers * writes_per_thread
            )

    @pytest.mark.asyncio
    async def test_concurrent_async_checkpoints_preserve_lineage(self) -> None:
        """Multiple async tasks writing checkpoints must form a linear chain."""
        state = self._make_state()
        state._provider = JsonProvider()
        num_tasks = 5
        writes_per_task = 10

        with tempfile.TemporaryDirectory() as d:

            async def writer() -> None:
                for _ in range(writes_per_task):
                    await state.acheckpoint(d)

            await asyncio.gather(*(writer() for _ in range(num_tasks)))

            self._assert_linear_lineage(
                os.path.join(d, "main"), num_tasks * writes_per_task
            )

    def test_concurrent_checkpoints_all_files_valid_json(self) -> None:
        """Every checkpoint file produced by concurrent writers must be valid JSON."""
        state = self._make_state()
        state._provider = JsonProvider()
        num_writers = 5
        writes_per_thread = 10

        with tempfile.TemporaryDirectory() as d:
            self._write_from_threads(state, d, num_writers, writes_per_thread)

            branch_dir = os.path.join(d, "main")
            for filename in os.listdir(branch_dir):
                if not filename.endswith(".json"):
                    continue
                filepath = os.path.join(branch_dir, filename)
                with open(filepath) as fh:
                    content = fh.read()
                try:
                    json.loads(content)
                except json.JSONDecodeError:
                    pytest.fail(
                        f"Checkpoint {filename} contains invalid JSON "
                        f"(partial write?): {content[:200]!r}"
                    )
|
||||
|
||||
Reference in New Issue
Block a user