feat: add checkpoint resume, diff, prune commands and save discoverability

Add three new CLI subcommands to improve checkpoint UX:

- `crewai checkpoint resume [id]` skips the TUI and resumes from the
  latest or specified checkpoint directly
- `crewai checkpoint diff <id1> <id2>` compares two checkpoints showing
  changes in metadata, inputs, task status, and outputs
- `crewai checkpoint prune --keep N --older-than Xd` removes old
  checkpoints from JSON dirs or SQLite databases

Also writes a resume hint to stderr after every checkpoint save so
users discover the command without needing to know it exists.
This commit is contained in:
Greyson LaLonde
2026-04-17 04:50:15 +08:00
committed by GitHub
parent 54391fdbdf
commit c5192b970c
4 changed files with 758 additions and 1 deletions

View File

@@ -2,7 +2,7 @@
from __future__ import annotations
from datetime import datetime
from datetime import datetime, timedelta, timezone
import glob
import json
import os
@@ -37,6 +37,26 @@ ORDER BY rowid DESC
LIMIT 1
"""
_DELETE_OLDER_THAN = """
DELETE FROM checkpoints
WHERE created_at < ?
"""
_DELETE_KEEP_N = """
DELETE FROM checkpoints WHERE rowid NOT IN (
SELECT rowid FROM checkpoints ORDER BY rowid DESC LIMIT ?
)
"""
_COUNT_CHECKPOINTS = "SELECT COUNT(*) FROM checkpoints"
_SELECT_LIKE = """
SELECT id, created_at, json(data)
FROM checkpoints
WHERE id LIKE ?
ORDER BY rowid DESC
"""
_DEFAULT_DIR = "./.checkpoints"
_DEFAULT_DB = "./.checkpoints.db"
@@ -262,6 +282,8 @@ def _info_sqlite_latest(db_path: str) -> dict[str, Any] | None:
def _info_sqlite_id(db_path: str, checkpoint_id: str) -> dict[str, Any] | None:
with sqlite3.connect(db_path) as conn:
row = conn.execute(_SELECT_ONE, (checkpoint_id,)).fetchone()
if not row:
row = conn.execute(_SELECT_LIKE, (f"%{checkpoint_id}%",)).fetchone()
if not row:
return None
cid, created_at, raw = row
@@ -384,3 +406,287 @@ def _print_info(meta: dict[str, Any]) -> None:
if len(desc) > 70:
desc = desc[:67] + "..."
click.echo(f" {i + 1}. [{status}] {desc}")
def _resolve_checkpoint(
location: str, checkpoint_id: str | None
) -> dict[str, Any] | None:
if _is_sqlite(location):
if checkpoint_id:
return _info_sqlite_id(location, checkpoint_id)
return _info_sqlite_latest(location)
if os.path.isdir(location):
if checkpoint_id:
from crewai.state.provider.json_provider import JsonProvider
_json_provider: JsonProvider = JsonProvider()
pattern: str = os.path.join(location, "**", "*.json")
all_files: list[str] = glob.glob(pattern, recursive=True)
matches: list[str] = [
f for f in all_files if checkpoint_id in _json_provider.extract_id(f)
]
matches.sort(key=os.path.getmtime, reverse=True)
if matches:
return _info_json_file(matches[0])
return None
return _info_json_latest(location)
if os.path.isfile(location):
return _info_json_file(location)
return None
def _entity_type_from_meta(meta: dict[str, Any]) -> str:
for ent in meta.get("entities", []):
if ent.get("type") == "flow":
return "flow"
return "crew"
def resume_checkpoint(location: str, checkpoint_id: str | None) -> None:
import asyncio
meta: dict[str, Any] | None = _resolve_checkpoint(location, checkpoint_id)
if meta is None:
if checkpoint_id:
click.echo(f"Checkpoint not found: {checkpoint_id}")
else:
click.echo(f"No checkpoints found in {location}")
return
restore_path: str = meta.get("path") or meta.get("source", "")
if meta.get("db"):
restore_path = f"{meta['db']}#{meta['name']}"
click.echo(f"Resuming from: {meta.get('name', restore_path)}")
_print_info(meta)
click.echo()
from crewai.state.checkpoint_config import CheckpointConfig
config: CheckpointConfig = CheckpointConfig(restore_from=restore_path)
entity_type: str = _entity_type_from_meta(meta)
inputs: dict[str, Any] | None = meta.get("inputs") or None
if entity_type == "flow":
from crewai.flow.flow import Flow
flow = Flow.from_checkpoint(config)
result = asyncio.run(flow.kickoff_async(inputs=inputs))
else:
from crewai.crew import Crew
crew = Crew.from_checkpoint(config)
result = asyncio.run(crew.akickoff(inputs=inputs))
click.echo(f"\nResult: {getattr(result, 'raw', result)}")
def _task_list_from_meta(meta: dict[str, Any]) -> list[dict[str, Any]]:
tasks: list[dict[str, Any]] = []
for ent in meta.get("entities", []):
tasks.extend(
{
"entity": ent.get("name", "unnamed"),
"description": t.get("description", ""),
"completed": t.get("completed", False),
"output": t.get("output", ""),
}
for t in ent.get("tasks", [])
)
return tasks
def diff_checkpoints(location: str, id1: str, id2: str) -> None:
meta1: dict[str, Any] | None = _resolve_checkpoint(location, id1)
meta2: dict[str, Any] | None = _resolve_checkpoint(location, id2)
if meta1 is None:
click.echo(f"Checkpoint not found: {id1}")
return
if meta2 is None:
click.echo(f"Checkpoint not found: {id2}")
return
name1: str = meta1.get("name", id1)
name2: str = meta2.get("name", id2)
click.echo(f"--- {name1}")
click.echo(f"+++ {name2}")
click.echo()
fields: list[tuple[str, str]] = [
("Time", "ts"),
("Branch", "branch"),
("Trigger", "trigger"),
("Events", "event_count"),
]
for label, key in fields:
v1: str = str(meta1.get(key, ""))
v2: str = str(meta2.get(key, ""))
if v1 != v2:
click.echo(f" {label}:")
click.echo(f" - {v1}")
click.echo(f" + {v2}")
inputs1: dict[str, Any] = meta1.get("inputs", {})
inputs2: dict[str, Any] = meta2.get("inputs", {})
all_keys: list[str] = sorted(set(list(inputs1.keys()) + list(inputs2.keys())))
changed_inputs: list[tuple[str, Any, Any]] = [
(k, inputs1.get(k, ""), inputs2.get(k, ""))
for k in all_keys
if inputs1.get(k) != inputs2.get(k)
]
if changed_inputs:
click.echo("\n Inputs:")
for key, v1, v2 in changed_inputs:
click.echo(f" {key}:")
click.echo(f" - {v1}")
click.echo(f" + {v2}")
tasks1: list[dict[str, Any]] = _task_list_from_meta(meta1)
tasks2: list[dict[str, Any]] = _task_list_from_meta(meta2)
max_tasks: int = max(len(tasks1), len(tasks2))
if max_tasks == 0:
return
click.echo("\n Tasks:")
for i in range(max_tasks):
t1: dict[str, Any] | None = tasks1[i] if i < len(tasks1) else None
t2: dict[str, Any] | None = tasks2[i] if i < len(tasks2) else None
if t1 is None:
desc: str = t2["description"][:60] if t2 else ""
click.echo(f" + {i + 1}. [new] {desc}")
continue
if t2 is None:
desc = t1["description"][:60]
click.echo(f" - {i + 1}. [removed] {desc}")
continue
desc = str(t1["description"][:60])
s1: str = "done" if t1["completed"] else "pending"
s2: str = "done" if t2["completed"] else "pending"
if s1 != s2:
click.echo(f" {i + 1}. {desc}")
click.echo(f" status: {s1} -> {s2}")
out1: str = (t1.get("output") or "").strip()
out2: str = (t2.get("output") or "").strip()
if out1 != out2:
if s1 == s2:
click.echo(f" {i + 1}. {desc}")
preview1: str = (
out1[:80] + ("..." if len(out1) > 80 else "") if out1 else "(empty)"
)
preview2: str = (
out2[:80] + ("..." if len(out2) > 80 else "") if out2 else "(empty)"
)
click.echo(" output:")
click.echo(f" - {preview1}")
click.echo(f" + {preview2}")
def _parse_duration(value: str) -> timedelta:
match: re.Match[str] | None = re.match(r"^(\d+)([dhm])$", value.strip())
if not match:
raise click.BadParameter(
f"Invalid duration: {value!r}. Use format like '7d', '24h', or '30m'."
)
amount: int = int(match.group(1))
unit: str = match.group(2)
if unit == "d":
return timedelta(days=amount)
if unit == "h":
return timedelta(hours=amount)
return timedelta(minutes=amount)
def _prune_json(location: str, keep: int | None, older_than: timedelta | None) -> int:
pattern: str = os.path.join(location, "**", "*.json")
files: list[str] = sorted(
glob.glob(pattern, recursive=True), key=os.path.getmtime, reverse=True
)
if not files:
return 0
to_delete: set[str] = set()
if keep is not None and len(files) > keep:
to_delete.update(files[keep:])
if older_than is not None:
cutoff: datetime = datetime.now(timezone.utc) - older_than
for path in files:
mtime: datetime = datetime.fromtimestamp(
os.path.getmtime(path), tz=timezone.utc
)
if mtime < cutoff:
to_delete.add(path)
deleted: int = 0
for path in to_delete:
try:
os.remove(path)
deleted += 1
except OSError: # noqa: PERF203
pass
for dirpath, dirnames, filenames in os.walk(location, topdown=False):
if dirpath != location and not filenames and not dirnames:
try:
os.rmdir(dirpath)
except OSError:
pass
return deleted
def _prune_sqlite(db_path: str, keep: int | None, older_than: timedelta | None) -> int:
deleted: int = 0
with sqlite3.connect(db_path) as conn:
if older_than is not None:
cutoff: str = (datetime.now(timezone.utc) - older_than).strftime(
"%Y%m%dT%H%M%S"
)
cursor: sqlite3.Cursor = conn.execute(_DELETE_OLDER_THAN, (cutoff,))
deleted += cursor.rowcount
if keep is not None:
cursor = conn.execute(_DELETE_KEEP_N, (keep,))
deleted += cursor.rowcount
conn.commit()
return deleted
def prune_checkpoints(
location: str, keep: int | None, older_than: str | None, dry_run: bool = False
) -> None:
if keep is None and older_than is None:
click.echo("Specify --keep N and/or --older-than DURATION (e.g. 7d, 24h)")
return
duration: timedelta | None = _parse_duration(older_than) if older_than else None
deleted: int
if _is_sqlite(location):
if dry_run:
with sqlite3.connect(location) as conn:
total: int = conn.execute(_COUNT_CHECKPOINTS).fetchone()[0]
click.echo(f"Would prune from {total} checkpoint(s) in {location}")
return
deleted = _prune_sqlite(location, keep, duration)
elif os.path.isdir(location):
if dry_run:
files: list[str] = glob.glob(
os.path.join(location, "**", "*.json"), recursive=True
)
click.echo(f"Would prune from {len(files)} checkpoint(s) in {location}")
return
deleted = _prune_json(location, keep, duration)
else:
click.echo(f"Not a directory or SQLite database: {location}")
return
click.echo(f"Pruned {deleted} checkpoint(s) from {location}")

View File

@@ -873,5 +873,48 @@ def checkpoint_info(path: str) -> None:
info_checkpoint(_detect_location(path))
@checkpoint.command("resume")
@click.argument("checkpoint_id", required=False, default=None)
@click.pass_context
def checkpoint_resume(ctx: click.Context, checkpoint_id: str | None) -> None:
"""Resume from a checkpoint. Defaults to the most recent."""
from crewai.cli.checkpoint_cli import resume_checkpoint
resume_checkpoint(ctx.obj["location"], checkpoint_id)
@checkpoint.command("diff")
@click.argument("id1")
@click.argument("id2")
@click.pass_context
def checkpoint_diff(ctx: click.Context, id1: str, id2: str) -> None:
"""Compare two checkpoints side-by-side."""
from crewai.cli.checkpoint_cli import diff_checkpoints
diff_checkpoints(ctx.obj["location"], id1, id2)
@checkpoint.command("prune")
@click.option(
"--keep", type=int, default=None, help="Keep the N most recent checkpoints."
)
@click.option(
"--older-than",
default=None,
help="Remove checkpoints older than duration (e.g. 7d, 24h, 30m).",
)
@click.option(
"--dry-run", is_flag=True, help="Show what would be pruned without deleting."
)
@click.pass_context
def checkpoint_prune(
ctx: click.Context, keep: int | None, older_than: str | None, dry_run: bool
) -> None:
"""Remove old checkpoints."""
from crewai.cli.checkpoint_cli import prune_checkpoints
prune_checkpoints(ctx.obj["location"], keep, older_than, dry_run)
if __name__ == "__main__":
crewai()

View File

@@ -120,6 +120,12 @@ def _do_checkpoint(
)
state._chain_lineage(cfg.provider, location)
checkpoint_id: str = cfg.provider.extract_id(location)
msg: str = (
f"Checkpoint saved. Resume with: crewai checkpoint resume {checkpoint_id}"
)
logger.info(msg)
if cfg.max_checkpoints is not None:
cfg.provider.prune(cfg.location, cfg.max_checkpoints, branch=state._branch)

View File

@@ -0,0 +1,402 @@
"""Tests for checkpoint CLI commands."""
from __future__ import annotations
import json
import os
import sqlite3
import tempfile
import time
from datetime import datetime, timedelta, timezone
from typing import Any
from unittest.mock import MagicMock, patch
import pytest
from crewai.cli.checkpoint_cli import (
_parse_checkpoint_json,
_parse_duration,
_prune_json,
_prune_sqlite,
_resolve_checkpoint,
_task_list_from_meta,
diff_checkpoints,
prune_checkpoints,
resume_checkpoint,
)
def _make_checkpoint_data(
tasks_completed: int = 2,
tasks_total: int = 4,
trigger: str = "task_completed",
branch: str = "main",
parent_id: str | None = None,
entity_type: str = "crew",
name: str = "test_crew",
inputs: dict[str, Any] | None = None,
) -> str:
tasks: list[dict[str, Any]] = []
for i in range(tasks_total):
t: dict[str, Any] = {
"description": f"Task {i + 1} description",
"expected_output": f"Output {i + 1}",
}
if i < tasks_completed:
t["output"] = {"raw": f"Result of task {i + 1}"}
else:
t["output"] = None
tasks.append(t)
data: dict[str, Any] = {
"entities": [
{
"entity_type": entity_type,
"name": name,
"id": "abc12345-1234-1234-1234-abcdef012345",
"tasks": tasks,
"agents": [],
"checkpoint_inputs": inputs or {},
}
],
"event_record": {"nodes": {f"node_{i}": {} for i in range(3)}},
"trigger": trigger,
"branch": branch,
"parent_id": parent_id,
}
return json.dumps(data)
def _write_json_checkpoint(
base_dir: str,
branch: str = "main",
name: str | None = None,
data: str | None = None,
tasks_completed: int = 2,
inputs: dict[str, Any] | None = None,
) -> str:
branch_dir = os.path.join(base_dir, branch)
os.makedirs(branch_dir, exist_ok=True)
if name is None:
ts = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%S")
name = f"{ts}_abcd1234_p-none.json"
path = os.path.join(branch_dir, name)
if data is None:
data = _make_checkpoint_data(tasks_completed=tasks_completed, inputs=inputs)
with open(path, "w") as f:
f.write(data)
return path
def _create_sqlite_checkpoint(
db_path: str,
checkpoint_id: str | None = None,
data: str | None = None,
tasks_completed: int = 2,
branch: str = "main",
inputs: dict[str, Any] | None = None,
) -> str:
if checkpoint_id is None:
ts = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%S")
checkpoint_id = f"{ts}_abcd1234"
if data is None:
data = _make_checkpoint_data(
tasks_completed=tasks_completed, branch=branch, inputs=inputs
)
with sqlite3.connect(db_path) as conn:
conn.execute(
"""CREATE TABLE IF NOT EXISTS checkpoints (
id TEXT PRIMARY KEY,
created_at TEXT NOT NULL,
parent_id TEXT,
branch TEXT NOT NULL DEFAULT 'main',
data JSONB NOT NULL
)"""
)
conn.execute(
"INSERT INTO checkpoints (id, created_at, parent_id, branch, data) "
"VALUES (?, ?, ?, ?, jsonb(?))",
(checkpoint_id, checkpoint_id.split("_")[0], None, branch, data),
)
conn.commit()
return checkpoint_id
class TestParseDuration:
def test_days(self) -> None:
assert _parse_duration("7d") == timedelta(days=7)
def test_hours(self) -> None:
assert _parse_duration("24h") == timedelta(hours=24)
def test_minutes(self) -> None:
assert _parse_duration("30m") == timedelta(minutes=30)
def test_invalid_raises(self) -> None:
with pytest.raises(Exception):
_parse_duration("abc")
def test_no_unit_raises(self) -> None:
with pytest.raises(Exception):
_parse_duration("7")
class TestResolveCheckpoint:
def test_json_latest(self) -> None:
with tempfile.TemporaryDirectory() as d:
_write_json_checkpoint(d, name="20260101T000000_aaaa1111_p-none.json")
time.sleep(0.01)
path2 = _write_json_checkpoint(
d, name="20260102T000000_bbbb2222_p-none.json", tasks_completed=3
)
meta = _resolve_checkpoint(d, None)
assert meta is not None
assert meta["path"] == path2
def test_json_by_id(self) -> None:
with tempfile.TemporaryDirectory() as d:
_write_json_checkpoint(d, name="20260101T000000_aaaa1111_p-none.json")
_write_json_checkpoint(d, name="20260102T000000_bbbb2222_p-none.json")
meta = _resolve_checkpoint(d, "aaaa1111")
assert meta is not None
assert "aaaa1111" in meta["name"]
def test_json_not_found(self) -> None:
with tempfile.TemporaryDirectory() as d:
_write_json_checkpoint(d)
assert _resolve_checkpoint(d, "nonexistent") is None
def test_sqlite_latest(self) -> None:
with tempfile.TemporaryDirectory() as d:
db_path = os.path.join(d, "test.db")
_create_sqlite_checkpoint(db_path, "20260101T000000_aaaa1111")
_create_sqlite_checkpoint(
db_path, "20260102T000000_bbbb2222", tasks_completed=3
)
meta = _resolve_checkpoint(db_path, None)
assert meta is not None
assert "bbbb2222" in meta["name"]
def test_sqlite_by_id(self) -> None:
with tempfile.TemporaryDirectory() as d:
db_path = os.path.join(d, "test.db")
_create_sqlite_checkpoint(db_path, "20260101T000000_aaaa1111")
_create_sqlite_checkpoint(db_path, "20260102T000000_bbbb2222")
meta = _resolve_checkpoint(db_path, "20260101T000000_aaaa1111")
assert meta is not None
assert "aaaa1111" in meta["name"]
def test_sqlite_partial_id(self) -> None:
with tempfile.TemporaryDirectory() as d:
db_path = os.path.join(d, "test.db")
_create_sqlite_checkpoint(db_path, "20260101T000000_aaaa1111")
_create_sqlite_checkpoint(db_path, "20260102T000000_bbbb2222")
meta = _resolve_checkpoint(db_path, "aaaa1111")
assert meta is not None
assert "aaaa1111" in meta["name"]
def test_nonexistent(self) -> None:
assert _resolve_checkpoint("/nonexistent/path", None) is None
class TestTaskListFromMeta:
def test_flattens_tasks(self) -> None:
data = _make_checkpoint_data(tasks_completed=2, tasks_total=3)
meta = _parse_checkpoint_json(data, "test")
tasks = _task_list_from_meta(meta)
assert len(tasks) == 3
assert tasks[0]["completed"] is True
assert tasks[2]["completed"] is False
def test_empty_entities(self) -> None:
assert _task_list_from_meta({"entities": []}) == []
class TestDiffCheckpoints:
def test_diff_shows_status_change(self, capsys: pytest.CaptureFixture[str]) -> None:
with tempfile.TemporaryDirectory() as d:
_write_json_checkpoint(
d, name="20260101T000000_aaaa1111_p-none.json", tasks_completed=1
)
_write_json_checkpoint(
d, name="20260102T000000_bbbb2222_p-none.json", tasks_completed=3
)
diff_checkpoints(d, "aaaa1111", "bbbb2222")
out = capsys.readouterr().out
assert "---" in out
assert "+++" in out
assert "status:" in out or "pending -> done" in out
def test_diff_shows_output_change(self, capsys: pytest.CaptureFixture[str]) -> None:
with tempfile.TemporaryDirectory() as d:
data1 = _make_checkpoint_data(tasks_completed=2)
data2 = json.loads(data1)
data2["entities"][0]["tasks"][0]["output"]["raw"] = "Updated result"
_write_json_checkpoint(
d,
name="20260101T000000_aaaa1111_p-none.json",
data=json.dumps(json.loads(data1)),
)
_write_json_checkpoint(
d,
name="20260102T000000_bbbb2222_p-none.json",
data=json.dumps(data2),
)
diff_checkpoints(d, "aaaa1111", "bbbb2222")
out = capsys.readouterr().out
assert "output:" in out
def test_diff_not_found(self, capsys: pytest.CaptureFixture[str]) -> None:
with tempfile.TemporaryDirectory() as d:
_write_json_checkpoint(d, name="20260101T000000_aaaa1111_p-none.json")
diff_checkpoints(d, "aaaa1111", "nonexistent")
out = capsys.readouterr().out
assert "not found" in out
def test_diff_input_change(self, capsys: pytest.CaptureFixture[str]) -> None:
with tempfile.TemporaryDirectory() as d:
_write_json_checkpoint(
d,
name="20260101T000000_aaaa1111_p-none.json",
inputs={"topic": "AI"},
)
_write_json_checkpoint(
d,
name="20260102T000000_bbbb2222_p-none.json",
inputs={"topic": "ML"},
)
diff_checkpoints(d, "aaaa1111", "bbbb2222")
out = capsys.readouterr().out
assert "Inputs:" in out
assert "AI" in out
assert "ML" in out
class TestPruneJson:
def test_keep_n(self) -> None:
with tempfile.TemporaryDirectory() as d:
for i in range(5):
_write_json_checkpoint(
d, name=f"2026010{i + 1}T000000_aaa{i}1111_p-none.json"
)
time.sleep(0.01)
deleted = _prune_json(d, keep=2, older_than=None)
assert deleted == 3
remaining = []
for root, _, files in os.walk(d):
remaining.extend(files)
assert len(remaining) == 2
def test_older_than(self) -> None:
with tempfile.TemporaryDirectory() as d:
old_path = _write_json_checkpoint(
d, name="20250101T000000_old01111_p-none.json"
)
os.utime(old_path, (0, 0))
_write_json_checkpoint(d, name="20260417T000000_new01111_p-none.json")
deleted = _prune_json(d, keep=None, older_than=timedelta(days=1))
assert deleted == 1
def test_empty_dir(self) -> None:
with tempfile.TemporaryDirectory() as d:
assert _prune_json(d, keep=2, older_than=None) == 0
def test_removes_empty_branch_dirs(self) -> None:
with tempfile.TemporaryDirectory() as d:
path = _write_json_checkpoint(
d,
branch="feature",
name="20260101T000000_aaaa1111_p-none.json",
)
os.utime(path, (0, 0))
_prune_json(d, keep=None, older_than=timedelta(days=1))
assert not os.path.exists(os.path.join(d, "feature"))
class TestPruneSqlite:
def test_keep_n(self) -> None:
with tempfile.TemporaryDirectory() as d:
db_path = os.path.join(d, "test.db")
for i in range(5):
_create_sqlite_checkpoint(
db_path, f"2026010{i + 1}T000000_aaa{i}1111"
)
deleted = _prune_sqlite(db_path, keep=2, older_than=None)
assert deleted == 3
with sqlite3.connect(db_path) as conn:
count = conn.execute("SELECT COUNT(*) FROM checkpoints").fetchone()[0]
assert count == 2
def test_older_than(self) -> None:
with tempfile.TemporaryDirectory() as d:
db_path = os.path.join(d, "test.db")
_create_sqlite_checkpoint(db_path, "20200101T000000_old01111")
_create_sqlite_checkpoint(db_path, "20260417T000000_new01111")
deleted = _prune_sqlite(db_path, keep=None, older_than=timedelta(days=1))
assert deleted >= 1
with sqlite3.connect(db_path) as conn:
count = conn.execute("SELECT COUNT(*) FROM checkpoints").fetchone()[0]
assert count >= 1
class TestPruneCommand:
def test_no_options_shows_help(self, capsys: pytest.CaptureFixture[str]) -> None:
with tempfile.TemporaryDirectory() as d:
prune_checkpoints(d, keep=None, older_than=None)
out = capsys.readouterr().out
assert "Specify" in out
def test_dry_run_json(self, capsys: pytest.CaptureFixture[str]) -> None:
with tempfile.TemporaryDirectory() as d:
_write_json_checkpoint(d)
prune_checkpoints(d, keep=1, older_than=None, dry_run=True)
out = capsys.readouterr().out
assert "Would prune" in out
def test_not_found(self, capsys: pytest.CaptureFixture[str]) -> None:
prune_checkpoints("/nonexistent", keep=1, older_than=None)
out = capsys.readouterr().out
assert "Not a directory" in out
class TestResumeCheckpoint:
def test_not_found(self, capsys: pytest.CaptureFixture[str]) -> None:
with tempfile.TemporaryDirectory() as d:
resume_checkpoint(d, "nonexistent")
out = capsys.readouterr().out
assert "not found" in out
def test_no_checkpoints(self, capsys: pytest.CaptureFixture[str]) -> None:
with tempfile.TemporaryDirectory() as d:
resume_checkpoint(d, None)
out = capsys.readouterr().out
assert "No checkpoints" in out
class TestDiscoverabilityMessage:
def test_checkpoint_listener_logs_resume_hint(self) -> None:
from crewai.state.checkpoint_listener import _do_checkpoint
from crewai.state.runtime import RuntimeState
state = MagicMock(spec=RuntimeState)
state.root = []
state.model_dump.return_value = {"entities": [], "event_record": {"nodes": {}}}
state._parent_id = None
state._branch = "main"
cfg = MagicMock()
cfg.location = "/tmp/cp"
cfg.max_checkpoints = None
cfg.provider.checkpoint.return_value = "/tmp/cp/main/20260101T000000_test1234_p-none.json"
cfg.provider.extract_id.return_value = "20260101T000000_test1234"
with (
patch("crewai.state.checkpoint_listener._prepare_entities"),
patch("crewai.state.checkpoint_listener.logger") as mock_logger,
):
_do_checkpoint(state, cfg)
cfg.provider.extract_id.assert_called_once()
mock_logger.info.assert_called_once()
logged: str = mock_logger.info.call_args[0][0]
assert "crewai checkpoint resume" in logged
assert "20260101T000000_test1234" in logged