From c5192b970c498010d9008263249ca5598c2b2459 Mon Sep 17 00:00:00 2001 From: Greyson LaLonde Date: Fri, 17 Apr 2026 04:50:15 +0800 Subject: [PATCH] feat: add checkpoint resume, diff, prune commands and save discoverability Add three new CLI subcommands to improve checkpoint UX: - `crewai checkpoint resume [id]` skips the TUI and resumes from the latest or specified checkpoint directly - `crewai checkpoint diff ` compares two checkpoints showing changes in metadata, inputs, task status, and outputs - `crewai checkpoint prune --keep N --older-than Xd` removes old checkpoints from JSON dirs or SQLite databases Also writes a resume hint to stderr after every checkpoint save so users discover the command without needing to know it exists. --- lib/crewai/src/crewai/cli/checkpoint_cli.py | 308 +++++++++++++- lib/crewai/src/crewai/cli/cli.py | 43 ++ .../src/crewai/state/checkpoint_listener.py | 6 + lib/crewai/tests/test_checkpoint_cli.py | 402 ++++++++++++++++++ 4 files changed, 758 insertions(+), 1 deletion(-) create mode 100644 lib/crewai/tests/test_checkpoint_cli.py diff --git a/lib/crewai/src/crewai/cli/checkpoint_cli.py b/lib/crewai/src/crewai/cli/checkpoint_cli.py index 1745c224f..9db783e0e 100644 --- a/lib/crewai/src/crewai/cli/checkpoint_cli.py +++ b/lib/crewai/src/crewai/cli/checkpoint_cli.py @@ -2,7 +2,7 @@ from __future__ import annotations -from datetime import datetime +from datetime import datetime, timedelta, timezone import glob import json import os @@ -37,6 +37,26 @@ ORDER BY rowid DESC LIMIT 1 """ +_DELETE_OLDER_THAN = """ +DELETE FROM checkpoints +WHERE created_at < ? +""" + +_DELETE_KEEP_N = """ +DELETE FROM checkpoints WHERE rowid NOT IN ( + SELECT rowid FROM checkpoints ORDER BY rowid DESC LIMIT ? +) +""" + +_COUNT_CHECKPOINTS = "SELECT COUNT(*) FROM checkpoints" + +_SELECT_LIKE = """ +SELECT id, created_at, json(data) +FROM checkpoints +WHERE id LIKE ? +ORDER BY rowid DESC +""" + _DEFAULT_DIR = "./.checkpoints" _DEFAULT_DB = "./.checkpoints.db" @@ -262,6 +282,8 @@ def _info_sqlite_latest(db_path: str) -> dict[str, Any] | None: def _info_sqlite_id(db_path: str, checkpoint_id: str) -> dict[str, Any] | None: with sqlite3.connect(db_path) as conn: row = conn.execute(_SELECT_ONE, (checkpoint_id,)).fetchone() + if not row: + row = conn.execute(_SELECT_LIKE, (f"%{checkpoint_id}%",)).fetchone() if not row: return None cid, created_at, raw = row @@ -384,3 +406,287 @@ def _print_info(meta: dict[str, Any]) -> None: if len(desc) > 70: desc = desc[:67] + "..." click.echo(f" {i + 1}. [{status}] {desc}") + + +def _resolve_checkpoint( + location: str, checkpoint_id: str | None +) -> dict[str, Any] | None: + if _is_sqlite(location): + if checkpoint_id: + return _info_sqlite_id(location, checkpoint_id) + return _info_sqlite_latest(location) + if os.path.isdir(location): + if checkpoint_id: + from crewai.state.provider.json_provider import JsonProvider + + _json_provider: JsonProvider = JsonProvider() + pattern: str = os.path.join(location, "**", "*.json") + all_files: list[str] = glob.glob(pattern, recursive=True) + matches: list[str] = [ + f for f in all_files if checkpoint_id in _json_provider.extract_id(f) + ] + matches.sort(key=os.path.getmtime, reverse=True) + if matches: + return _info_json_file(matches[0]) + return None + return _info_json_latest(location) + if os.path.isfile(location): + return _info_json_file(location) + return None + + +def _entity_type_from_meta(meta: dict[str, Any]) -> str: + for ent in meta.get("entities", []): + if ent.get("type") == "flow": + return "flow" + return "crew" + + +def resume_checkpoint(location: str, checkpoint_id: str | None) -> None: + import asyncio + + meta: dict[str, Any] | None = _resolve_checkpoint(location, checkpoint_id) + if meta is None: + if checkpoint_id: + click.echo(f"Checkpoint not found: {checkpoint_id}") + else: + click.echo(f"No checkpoints found in {location}") + return + + restore_path: str = meta.get("path") or meta.get("source", "") + if meta.get("db"): + restore_path = f"{meta['db']}#{meta['name']}" + + click.echo(f"Resuming from: {meta.get('name', restore_path)}") + _print_info(meta) + click.echo() + + from crewai.state.checkpoint_config import CheckpointConfig + + config: CheckpointConfig = CheckpointConfig(restore_from=restore_path) + entity_type: str = _entity_type_from_meta(meta) + inputs: dict[str, Any] | None = meta.get("inputs") or None + + if entity_type == "flow": + from crewai.flow.flow import Flow + + flow = Flow.from_checkpoint(config) + result = asyncio.run(flow.kickoff_async(inputs=inputs)) + else: + from crewai.crew import Crew + + crew = Crew.from_checkpoint(config) + result = asyncio.run(crew.akickoff(inputs=inputs)) + + click.echo(f"\nResult: {getattr(result, 'raw', result)}") + + +def _task_list_from_meta(meta: dict[str, Any]) -> list[dict[str, Any]]: + tasks: list[dict[str, Any]] = [] + for ent in meta.get("entities", []): + tasks.extend( + { + "entity": ent.get("name", "unnamed"), + "description": t.get("description", ""), + "completed": t.get("completed", False), + "output": t.get("output", ""), + } + for t in ent.get("tasks", []) + ) + return tasks + + +def diff_checkpoints(location: str, id1: str, id2: str) -> None: + meta1: dict[str, Any] | None = _resolve_checkpoint(location, id1) + meta2: dict[str, Any] | None = _resolve_checkpoint(location, id2) + + if meta1 is None: + click.echo(f"Checkpoint not found: {id1}") + return + if meta2 is None: + click.echo(f"Checkpoint not found: {id2}") + return + + name1: str = meta1.get("name", id1) + name2: str = meta2.get("name", id2) + + click.echo(f"--- {name1}") + click.echo(f"+++ {name2}") + click.echo() + + fields: list[tuple[str, str]] = [ + ("Time", "ts"), + ("Branch", "branch"), + ("Trigger", "trigger"), + ("Events", "event_count"), + ] + for label, key in fields: + v1: str = str(meta1.get(key, "")) + v2: str = str(meta2.get(key, "")) + if v1 != v2: + click.echo(f" {label}:") + click.echo(f" - {v1}") + click.echo(f" + {v2}") + + inputs1: dict[str, Any] = meta1.get("inputs", {}) + inputs2: dict[str, Any] = meta2.get("inputs", {}) + all_keys: list[str] = sorted(set(list(inputs1.keys()) + list(inputs2.keys()))) + changed_inputs: list[tuple[str, Any, Any]] = [ + (k, inputs1.get(k, ""), inputs2.get(k, "")) + for k in all_keys + if inputs1.get(k) != inputs2.get(k) + ] + if changed_inputs: + click.echo("\n Inputs:") + for key, v1, v2 in changed_inputs: + click.echo(f" {key}:") + click.echo(f" - {v1}") + click.echo(f" + {v2}") + + tasks1: list[dict[str, Any]] = _task_list_from_meta(meta1) + tasks2: list[dict[str, Any]] = _task_list_from_meta(meta2) + + max_tasks: int = max(len(tasks1), len(tasks2)) + if max_tasks == 0: + return + + click.echo("\n Tasks:") + for i in range(max_tasks): + t1: dict[str, Any] | None = tasks1[i] if i < len(tasks1) else None + t2: dict[str, Any] | None = tasks2[i] if i < len(tasks2) else None + + if t1 is None: + desc: str = t2["description"][:60] if t2 else "" + click.echo(f" + {i + 1}. [new] {desc}") + continue + if t2 is None: + desc = t1["description"][:60] + click.echo(f" - {i + 1}. [removed] {desc}") + continue + + desc = str(t1["description"][:60]) + s1: str = "done" if t1["completed"] else "pending" + s2: str = "done" if t2["completed"] else "pending" + + if s1 != s2: + click.echo(f" {i + 1}. {desc}") + click.echo(f" status: {s1} -> {s2}") + + out1: str = (t1.get("output") or "").strip() + out2: str = (t2.get("output") or "").strip() + if out1 != out2: + if s1 == s2: + click.echo(f" {i + 1}. {desc}") + preview1: str = ( + out1[:80] + ("..." if len(out1) > 80 else "") if out1 else "(empty)" + ) + preview2: str = ( + out2[:80] + ("..." if len(out2) > 80 else "") if out2 else "(empty)" + ) + click.echo(" output:") + click.echo(f" - {preview1}") + click.echo(f" + {preview2}") + + +def _parse_duration(value: str) -> timedelta: + match: re.Match[str] | None = re.match(r"^(\d+)([dhm])$", value.strip()) + if not match: + raise click.BadParameter( + f"Invalid duration: {value!r}. Use format like '7d', '24h', or '30m'." + ) + amount: int = int(match.group(1)) + unit: str = match.group(2) + if unit == "d": + return timedelta(days=amount) + if unit == "h": + return timedelta(hours=amount) + return timedelta(minutes=amount) + + +def _prune_json(location: str, keep: int | None, older_than: timedelta | None) -> int: + pattern: str = os.path.join(location, "**", "*.json") + files: list[str] = sorted( + glob.glob(pattern, recursive=True), key=os.path.getmtime, reverse=True + ) + if not files: + return 0 + + to_delete: set[str] = set() + + if keep is not None and len(files) > keep: + to_delete.update(files[keep:]) + + if older_than is not None: + cutoff: datetime = datetime.now(timezone.utc) - older_than + for path in files: + mtime: datetime = datetime.fromtimestamp( + os.path.getmtime(path), tz=timezone.utc + ) + if mtime < cutoff: + to_delete.add(path) + + deleted: int = 0 + for path in to_delete: + try: + os.remove(path) + deleted += 1 + except OSError: # noqa: PERF203 + pass + + for dirpath, dirnames, filenames in os.walk(location, topdown=False): + if dirpath != location and not filenames and not dirnames: + try: + os.rmdir(dirpath) + except OSError: + pass + + return deleted + + +def _prune_sqlite(db_path: str, keep: int | None, older_than: timedelta | None) -> int: + deleted: int = 0 + with sqlite3.connect(db_path) as conn: + if older_than is not None: + cutoff: str = (datetime.now(timezone.utc) - older_than).strftime( + "%Y%m%dT%H%M%S" + ) + cursor: sqlite3.Cursor = conn.execute(_DELETE_OLDER_THAN, (cutoff,)) + deleted += cursor.rowcount + + if keep is not None: + cursor = conn.execute(_DELETE_KEEP_N, (keep,)) + deleted += cursor.rowcount + + conn.commit() + return deleted + + +def prune_checkpoints( + location: str, keep: int | None, older_than: str | None, dry_run: bool = False +) -> None: + if keep is None and older_than is None: + click.echo("Specify --keep N and/or --older-than DURATION (e.g. 7d, 24h)") + return + + duration: timedelta | None = _parse_duration(older_than) if older_than else None + + deleted: int + if _is_sqlite(location): + if dry_run: + with sqlite3.connect(location) as conn: + total: int = conn.execute(_COUNT_CHECKPOINTS).fetchone()[0] + click.echo(f"Would prune from {total} checkpoint(s) in {location}") + return + deleted = _prune_sqlite(location, keep, duration) + elif os.path.isdir(location): + if dry_run: + files: list[str] = glob.glob( + os.path.join(location, "**", "*.json"), recursive=True + ) + click.echo(f"Would prune from {len(files)} checkpoint(s) in {location}") + return + deleted = _prune_json(location, keep, duration) + else: + click.echo(f"Not a directory or SQLite database: {location}") + return + click.echo(f"Pruned {deleted} checkpoint(s) from {location}") diff --git a/lib/crewai/src/crewai/cli/cli.py b/lib/crewai/src/crewai/cli/cli.py index bc2a9ee26..dc4284677 100644 --- a/lib/crewai/src/crewai/cli/cli.py +++ b/lib/crewai/src/crewai/cli/cli.py @@ -873,5 +873,48 @@ def checkpoint_info(path: str) -> None: info_checkpoint(_detect_location(path)) +@checkpoint.command("resume") +@click.argument("checkpoint_id", required=False, default=None) +@click.pass_context +def checkpoint_resume(ctx: click.Context, checkpoint_id: str | None) -> None: + """Resume from a checkpoint. Defaults to the most recent.""" + from crewai.cli.checkpoint_cli import resume_checkpoint + + resume_checkpoint(ctx.obj["location"], checkpoint_id) + + +@checkpoint.command("diff") +@click.argument("id1") +@click.argument("id2") +@click.pass_context +def checkpoint_diff(ctx: click.Context, id1: str, id2: str) -> None: + """Compare two checkpoints side-by-side.""" + from crewai.cli.checkpoint_cli import diff_checkpoints + + diff_checkpoints(ctx.obj["location"], id1, id2) + + +@checkpoint.command("prune") +@click.option( + "--keep", type=int, default=None, help="Keep the N most recent checkpoints." +) +@click.option( + "--older-than", + default=None, + help="Remove checkpoints older than duration (e.g. 7d, 24h, 30m).", +) +@click.option( + "--dry-run", is_flag=True, help="Show what would be pruned without deleting." +) +@click.pass_context +def checkpoint_prune( + ctx: click.Context, keep: int | None, older_than: str | None, dry_run: bool +) -> None: + """Remove old checkpoints.""" + from crewai.cli.checkpoint_cli import prune_checkpoints + + prune_checkpoints(ctx.obj["location"], keep, older_than, dry_run) + + if __name__ == "__main__": crewai() diff --git a/lib/crewai/src/crewai/state/checkpoint_listener.py b/lib/crewai/src/crewai/state/checkpoint_listener.py index 2408e88e3..674a8436a 100644 --- a/lib/crewai/src/crewai/state/checkpoint_listener.py +++ b/lib/crewai/src/crewai/state/checkpoint_listener.py @@ -120,6 +120,12 @@ def _do_checkpoint( ) state._chain_lineage(cfg.provider, location) + checkpoint_id: str = cfg.provider.extract_id(location) + msg: str = ( + f"Checkpoint saved. Resume with: crewai checkpoint resume {checkpoint_id}" + ) + logger.info(msg) + if cfg.max_checkpoints is not None: cfg.provider.prune(cfg.location, cfg.max_checkpoints, branch=state._branch) diff --git a/lib/crewai/tests/test_checkpoint_cli.py b/lib/crewai/tests/test_checkpoint_cli.py new file mode 100644 index 000000000..38e105cce --- /dev/null +++ b/lib/crewai/tests/test_checkpoint_cli.py @@ -0,0 +1,402 @@ +"""Tests for checkpoint CLI commands.""" + +from __future__ import annotations + +import json +import os +import sqlite3 +import tempfile +import time +from datetime import datetime, timedelta, timezone +from typing import Any +from unittest.mock import MagicMock, patch + +import pytest +from crewai.cli.checkpoint_cli import ( + _parse_checkpoint_json, + _parse_duration, + _prune_json, + _prune_sqlite, + _resolve_checkpoint, + _task_list_from_meta, + diff_checkpoints, + prune_checkpoints, + resume_checkpoint, +) + + +def _make_checkpoint_data( + tasks_completed: int = 2, + tasks_total: int = 4, + trigger: str = "task_completed", + branch: str = "main", + parent_id: str | None = None, + entity_type: str = "crew", + name: str = "test_crew", + inputs: dict[str, Any] | None = None, +) -> str: + tasks: list[dict[str, Any]] = [] + for i in range(tasks_total): + t: dict[str, Any] = { + "description": f"Task {i + 1} description", + "expected_output": f"Output {i + 1}", + } + if i < tasks_completed: + t["output"] = {"raw": f"Result of task {i + 1}"} + else: + t["output"] = None + tasks.append(t) + + data: dict[str, Any] = { + "entities": [ + { + "entity_type": entity_type, + "name": name, + "id": "abc12345-1234-1234-1234-abcdef012345", + "tasks": tasks, + "agents": [], + "checkpoint_inputs": inputs or {}, + } + ], + "event_record": {"nodes": {f"node_{i}": {} for i in range(3)}}, + "trigger": trigger, + "branch": branch, + "parent_id": parent_id, + } + return json.dumps(data) + + +def _write_json_checkpoint( + base_dir: str, + branch: str = "main", + name: str | None = None, + data: str | None = None, + tasks_completed: int = 2, + inputs: dict[str, Any] | None = None, +) -> str: + branch_dir = os.path.join(base_dir, branch) + os.makedirs(branch_dir, exist_ok=True) + if name is None: + ts = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%S") + name = f"{ts}_abcd1234_p-none.json" + path = os.path.join(branch_dir, name) + if data is None: + data = _make_checkpoint_data(tasks_completed=tasks_completed, inputs=inputs) + with open(path, "w") as f: + f.write(data) + return path + + +def _create_sqlite_checkpoint( + db_path: str, + checkpoint_id: str | None = None, + data: str | None = None, + tasks_completed: int = 2, + branch: str = "main", + inputs: dict[str, Any] | None = None, +) -> str: + if checkpoint_id is None: + ts = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%S") + checkpoint_id = f"{ts}_abcd1234" + if data is None: + data = _make_checkpoint_data( + tasks_completed=tasks_completed, branch=branch, inputs=inputs + ) + with sqlite3.connect(db_path) as conn: + conn.execute( + """CREATE TABLE IF NOT EXISTS checkpoints ( + id TEXT PRIMARY KEY, + created_at TEXT NOT NULL, + parent_id TEXT, + branch TEXT NOT NULL DEFAULT 'main', + data JSONB NOT NULL + )""" + ) + conn.execute( + "INSERT INTO checkpoints (id, created_at, parent_id, branch, data) " + "VALUES (?, ?, ?, ?, jsonb(?))", + (checkpoint_id, checkpoint_id.split("_")[0], None, branch, data), + ) + conn.commit() + return checkpoint_id + + +class TestParseDuration: + def test_days(self) -> None: + assert _parse_duration("7d") == timedelta(days=7) + + def test_hours(self) -> None: + assert _parse_duration("24h") == timedelta(hours=24) + + def test_minutes(self) -> None: + assert _parse_duration("30m") == timedelta(minutes=30) + + def test_invalid_raises(self) -> None: + with pytest.raises(Exception): + _parse_duration("abc") + + def test_no_unit_raises(self) -> None: + with pytest.raises(Exception): + _parse_duration("7") + + +class TestResolveCheckpoint: + def test_json_latest(self) -> None: + with tempfile.TemporaryDirectory() as d: + _write_json_checkpoint(d, name="20260101T000000_aaaa1111_p-none.json") + time.sleep(0.01) + path2 = _write_json_checkpoint( + d, name="20260102T000000_bbbb2222_p-none.json", tasks_completed=3 + ) + meta = _resolve_checkpoint(d, None) + assert meta is not None + assert meta["path"] == path2 + + def test_json_by_id(self) -> None: + with tempfile.TemporaryDirectory() as d: + _write_json_checkpoint(d, name="20260101T000000_aaaa1111_p-none.json") + _write_json_checkpoint(d, name="20260102T000000_bbbb2222_p-none.json") + meta = _resolve_checkpoint(d, "aaaa1111") + assert meta is not None + assert "aaaa1111" in meta["name"] + + def test_json_not_found(self) -> None: + with tempfile.TemporaryDirectory() as d: + _write_json_checkpoint(d) + assert _resolve_checkpoint(d, "nonexistent") is None + + def test_sqlite_latest(self) -> None: + with tempfile.TemporaryDirectory() as d: + db_path = os.path.join(d, "test.db") + _create_sqlite_checkpoint(db_path, "20260101T000000_aaaa1111") + _create_sqlite_checkpoint( + db_path, "20260102T000000_bbbb2222", tasks_completed=3 + ) + meta = _resolve_checkpoint(db_path, None) + assert meta is not None + assert "bbbb2222" in meta["name"] + + def test_sqlite_by_id(self) -> None: + with tempfile.TemporaryDirectory() as d: + db_path = os.path.join(d, "test.db") + _create_sqlite_checkpoint(db_path, "20260101T000000_aaaa1111") + _create_sqlite_checkpoint(db_path, "20260102T000000_bbbb2222") + meta = _resolve_checkpoint(db_path, "20260101T000000_aaaa1111") + assert meta is not None + assert "aaaa1111" in meta["name"] + + def test_sqlite_partial_id(self) -> None: + with tempfile.TemporaryDirectory() as d: + db_path = os.path.join(d, "test.db") + _create_sqlite_checkpoint(db_path, "20260101T000000_aaaa1111") + _create_sqlite_checkpoint(db_path, "20260102T000000_bbbb2222") + meta = _resolve_checkpoint(db_path, "aaaa1111") + assert meta is not None + assert "aaaa1111" in meta["name"] + + def test_nonexistent(self) -> None: + assert _resolve_checkpoint("/nonexistent/path", None) is None + + +class TestTaskListFromMeta: + def test_flattens_tasks(self) -> None: + data = _make_checkpoint_data(tasks_completed=2, tasks_total=3) + meta = _parse_checkpoint_json(data, "test") + tasks = _task_list_from_meta(meta) + assert len(tasks) == 3 + assert tasks[0]["completed"] is True + assert tasks[2]["completed"] is False + + def test_empty_entities(self) -> None: + assert _task_list_from_meta({"entities": []}) == [] + + +class TestDiffCheckpoints: + def test_diff_shows_status_change(self, capsys: pytest.CaptureFixture[str]) -> None: + with tempfile.TemporaryDirectory() as d: + _write_json_checkpoint( + d, name="20260101T000000_aaaa1111_p-none.json", tasks_completed=1 + ) + _write_json_checkpoint( + d, name="20260102T000000_bbbb2222_p-none.json", tasks_completed=3 + ) + diff_checkpoints(d, "aaaa1111", "bbbb2222") + out = capsys.readouterr().out + assert "---" in out + assert "+++" in out + assert "status:" in out or "pending -> done" in out + + def test_diff_shows_output_change(self, capsys: pytest.CaptureFixture[str]) -> None: + with tempfile.TemporaryDirectory() as d: + data1 = _make_checkpoint_data(tasks_completed=2) + data2 = json.loads(data1) + data2["entities"][0]["tasks"][0]["output"]["raw"] = "Updated result" + _write_json_checkpoint( + d, + name="20260101T000000_aaaa1111_p-none.json", + data=json.dumps(json.loads(data1)), + ) + _write_json_checkpoint( + d, + name="20260102T000000_bbbb2222_p-none.json", + data=json.dumps(data2), + ) + diff_checkpoints(d, "aaaa1111", "bbbb2222") + out = capsys.readouterr().out + assert "output:" in out + + def test_diff_not_found(self, capsys: pytest.CaptureFixture[str]) -> None: + with tempfile.TemporaryDirectory() as d: + _write_json_checkpoint(d, name="20260101T000000_aaaa1111_p-none.json") + diff_checkpoints(d, "aaaa1111", "nonexistent") + out = capsys.readouterr().out + assert "not found" in out + + def test_diff_input_change(self, capsys: pytest.CaptureFixture[str]) -> None: + with tempfile.TemporaryDirectory() as d: + _write_json_checkpoint( + d, + name="20260101T000000_aaaa1111_p-none.json", + inputs={"topic": "AI"}, + ) + _write_json_checkpoint( + d, + name="20260102T000000_bbbb2222_p-none.json", + inputs={"topic": "ML"}, + ) + diff_checkpoints(d, "aaaa1111", "bbbb2222") + out = capsys.readouterr().out + assert "Inputs:" in out + assert "AI" in out + assert "ML" in out + + +class TestPruneJson: + def test_keep_n(self) -> None: + with tempfile.TemporaryDirectory() as d: + for i in range(5): + _write_json_checkpoint( + d, name=f"2026010{i + 1}T000000_aaa{i}1111_p-none.json" + ) + time.sleep(0.01) + deleted = _prune_json(d, keep=2, older_than=None) + assert deleted == 3 + remaining = [] + for root, _, files in os.walk(d): + remaining.extend(files) + assert len(remaining) == 2 + + def test_older_than(self) -> None: + with tempfile.TemporaryDirectory() as d: + old_path = _write_json_checkpoint( + d, name="20250101T000000_old01111_p-none.json" + ) + os.utime(old_path, (0, 0)) + _write_json_checkpoint(d, name="20260417T000000_new01111_p-none.json") + deleted = _prune_json(d, keep=None, older_than=timedelta(days=1)) + assert deleted == 1 + + def test_empty_dir(self) -> None: + with tempfile.TemporaryDirectory() as d: + assert _prune_json(d, keep=2, older_than=None) == 0 + + def test_removes_empty_branch_dirs(self) -> None: + with tempfile.TemporaryDirectory() as d: + path = _write_json_checkpoint( + d, + branch="feature", + name="20260101T000000_aaaa1111_p-none.json", + ) + os.utime(path, (0, 0)) + _prune_json(d, keep=None, older_than=timedelta(days=1)) + assert not os.path.exists(os.path.join(d, "feature")) + + +class TestPruneSqlite: + def test_keep_n(self) -> None: + with tempfile.TemporaryDirectory() as d: + db_path = os.path.join(d, "test.db") + for i in range(5): + _create_sqlite_checkpoint( + db_path, f"2026010{i + 1}T000000_aaa{i}1111" + ) + deleted = _prune_sqlite(db_path, keep=2, older_than=None) + assert deleted == 3 + with sqlite3.connect(db_path) as conn: + count = conn.execute("SELECT COUNT(*) FROM checkpoints").fetchone()[0] + assert count == 2 + + def test_older_than(self) -> None: + with tempfile.TemporaryDirectory() as d: + db_path = os.path.join(d, "test.db") + _create_sqlite_checkpoint(db_path, "20200101T000000_old01111") + _create_sqlite_checkpoint(db_path, "20260417T000000_new01111") + deleted = _prune_sqlite(db_path, keep=None, older_than=timedelta(days=1)) + assert deleted >= 1 + with sqlite3.connect(db_path) as conn: + count = conn.execute("SELECT COUNT(*) FROM checkpoints").fetchone()[0] + assert count >= 1 + + +class TestPruneCommand: + def test_no_options_shows_help(self, capsys: pytest.CaptureFixture[str]) -> None: + with tempfile.TemporaryDirectory() as d: + prune_checkpoints(d, keep=None, older_than=None) + out = capsys.readouterr().out + assert "Specify" in out + + def test_dry_run_json(self, capsys: pytest.CaptureFixture[str]) -> None: + with tempfile.TemporaryDirectory() as d: + _write_json_checkpoint(d) + prune_checkpoints(d, keep=1, older_than=None, dry_run=True) + out = capsys.readouterr().out + assert "Would prune" in out + + def test_not_found(self, capsys: pytest.CaptureFixture[str]) -> None: + prune_checkpoints("/nonexistent", keep=1, older_than=None) + out = capsys.readouterr().out + assert "Not a directory" in out + + +class TestResumeCheckpoint: + def test_not_found(self, capsys: pytest.CaptureFixture[str]) -> None: + with tempfile.TemporaryDirectory() as d: + resume_checkpoint(d, "nonexistent") + out = capsys.readouterr().out + assert "not found" in out + + def test_no_checkpoints(self, capsys: pytest.CaptureFixture[str]) -> None: + with tempfile.TemporaryDirectory() as d: + resume_checkpoint(d, None) + out = capsys.readouterr().out + assert "No checkpoints" in out + + +class TestDiscoverabilityMessage: + def test_checkpoint_listener_logs_resume_hint(self) -> None: + from crewai.state.checkpoint_listener import _do_checkpoint + from crewai.state.runtime import RuntimeState + + state = MagicMock(spec=RuntimeState) + state.root = [] + state.model_dump.return_value = {"entities": [], "event_record": {"nodes": {}}} + state._parent_id = None + state._branch = "main" + + cfg = MagicMock() + cfg.location = "/tmp/cp" + cfg.max_checkpoints = None + cfg.provider.checkpoint.return_value = "/tmp/cp/main/20260101T000000_test1234_p-none.json" + cfg.provider.extract_id.return_value = "20260101T000000_test1234" + + with ( + patch("crewai.state.checkpoint_listener._prepare_entities"), + patch("crewai.state.checkpoint_listener.logger") as mock_logger, + ): + _do_checkpoint(state, cfg) + + cfg.provider.extract_id.assert_called_once() + mock_logger.info.assert_called_once() + logged: str = mock_logger.info.call_args[0][0] + assert "crewai checkpoint resume" in logged + assert "20260101T000000_test1234" in logged