"""Tests for checkpoint CLI commands.""" from __future__ import annotations import json import os import sqlite3 import tempfile import time from datetime import datetime, timedelta, timezone from typing import Any from unittest.mock import MagicMock, patch import pytest from crewai.cli.checkpoint_cli import ( _parse_checkpoint_json, _parse_duration, _prune_json, _prune_sqlite, _resolve_checkpoint, _task_list_from_meta, diff_checkpoints, prune_checkpoints, resume_checkpoint, ) def _make_checkpoint_data( tasks_completed: int = 2, tasks_total: int = 4, trigger: str = "task_completed", branch: str = "main", parent_id: str | None = None, entity_type: str = "crew", name: str = "test_crew", inputs: dict[str, Any] | None = None, ) -> str: tasks: list[dict[str, Any]] = [] for i in range(tasks_total): t: dict[str, Any] = { "description": f"Task {i + 1} description", "expected_output": f"Output {i + 1}", } if i < tasks_completed: t["output"] = {"raw": f"Result of task {i + 1}"} else: t["output"] = None tasks.append(t) data: dict[str, Any] = { "entities": [ { "entity_type": entity_type, "name": name, "id": "abc12345-1234-1234-1234-abcdef012345", "tasks": tasks, "agents": [], "checkpoint_inputs": inputs or {}, } ], "event_record": {"nodes": {f"node_{i}": {} for i in range(3)}}, "trigger": trigger, "branch": branch, "parent_id": parent_id, } return json.dumps(data) def _write_json_checkpoint( base_dir: str, branch: str = "main", name: str | None = None, data: str | None = None, tasks_completed: int = 2, inputs: dict[str, Any] | None = None, ) -> str: branch_dir = os.path.join(base_dir, branch) os.makedirs(branch_dir, exist_ok=True) if name is None: ts = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%S") name = f"{ts}_abcd1234_p-none.json" path = os.path.join(branch_dir, name) if data is None: data = _make_checkpoint_data(tasks_completed=tasks_completed, inputs=inputs) with open(path, "w") as f: f.write(data) return path def _create_sqlite_checkpoint( db_path: str, checkpoint_id: str | None = None, data: str | None = None, tasks_completed: int = 2, branch: str = "main", inputs: dict[str, Any] | None = None, ) -> str: if checkpoint_id is None: ts = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%S") checkpoint_id = f"{ts}_abcd1234" if data is None: data = _make_checkpoint_data( tasks_completed=tasks_completed, branch=branch, inputs=inputs ) with sqlite3.connect(db_path) as conn: conn.execute( """CREATE TABLE IF NOT EXISTS checkpoints ( id TEXT PRIMARY KEY, created_at TEXT NOT NULL, parent_id TEXT, branch TEXT NOT NULL DEFAULT 'main', data JSONB NOT NULL )""" ) conn.execute( "INSERT INTO checkpoints (id, created_at, parent_id, branch, data) " "VALUES (?, ?, ?, ?, jsonb(?))", (checkpoint_id, checkpoint_id.split("_")[0], None, branch, data), ) conn.commit() return checkpoint_id class TestParseDuration: def test_days(self) -> None: assert _parse_duration("7d") == timedelta(days=7) def test_hours(self) -> None: assert _parse_duration("24h") == timedelta(hours=24) def test_minutes(self) -> None: assert _parse_duration("30m") == timedelta(minutes=30) def test_invalid_raises(self) -> None: with pytest.raises(Exception): _parse_duration("abc") def test_no_unit_raises(self) -> None: with pytest.raises(Exception): _parse_duration("7") class TestResolveCheckpoint: def test_json_latest(self) -> None: with tempfile.TemporaryDirectory() as d: _write_json_checkpoint(d, name="20260101T000000_aaaa1111_p-none.json") time.sleep(0.01) path2 = _write_json_checkpoint( d, name="20260102T000000_bbbb2222_p-none.json", tasks_completed=3 ) meta = _resolve_checkpoint(d, None) assert meta is not None assert meta["path"] == path2 def test_json_by_id(self) -> None: with tempfile.TemporaryDirectory() as d: _write_json_checkpoint(d, name="20260101T000000_aaaa1111_p-none.json") _write_json_checkpoint(d, name="20260102T000000_bbbb2222_p-none.json") meta = _resolve_checkpoint(d, "aaaa1111") assert meta is not None assert "aaaa1111" in meta["name"] def test_json_not_found(self) -> None: with tempfile.TemporaryDirectory() as d: _write_json_checkpoint(d) assert _resolve_checkpoint(d, "nonexistent") is None def test_sqlite_latest(self) -> None: with tempfile.TemporaryDirectory() as d: db_path = os.path.join(d, "test.db") _create_sqlite_checkpoint(db_path, "20260101T000000_aaaa1111") _create_sqlite_checkpoint( db_path, "20260102T000000_bbbb2222", tasks_completed=3 ) meta = _resolve_checkpoint(db_path, None) assert meta is not None assert "bbbb2222" in meta["name"] def test_sqlite_by_id(self) -> None: with tempfile.TemporaryDirectory() as d: db_path = os.path.join(d, "test.db") _create_sqlite_checkpoint(db_path, "20260101T000000_aaaa1111") _create_sqlite_checkpoint(db_path, "20260102T000000_bbbb2222") meta = _resolve_checkpoint(db_path, "20260101T000000_aaaa1111") assert meta is not None assert "aaaa1111" in meta["name"] def test_sqlite_partial_id(self) -> None: with tempfile.TemporaryDirectory() as d: db_path = os.path.join(d, "test.db") _create_sqlite_checkpoint(db_path, "20260101T000000_aaaa1111") _create_sqlite_checkpoint(db_path, "20260102T000000_bbbb2222") meta = _resolve_checkpoint(db_path, "aaaa1111") assert meta is not None assert "aaaa1111" in meta["name"] def test_nonexistent(self) -> None: assert _resolve_checkpoint("/nonexistent/path", None) is None class TestTaskListFromMeta: def test_flattens_tasks(self) -> None: data = _make_checkpoint_data(tasks_completed=2, tasks_total=3) meta = _parse_checkpoint_json(data, "test") tasks = _task_list_from_meta(meta) assert len(tasks) == 3 assert tasks[0]["completed"] is True assert tasks[2]["completed"] is False def test_empty_entities(self) -> None: assert _task_list_from_meta({"entities": []}) == [] class TestDiffCheckpoints: def test_diff_shows_status_change(self, capsys: pytest.CaptureFixture[str]) -> None: with tempfile.TemporaryDirectory() as d: _write_json_checkpoint( d, name="20260101T000000_aaaa1111_p-none.json", tasks_completed=1 ) _write_json_checkpoint( d, name="20260102T000000_bbbb2222_p-none.json", tasks_completed=3 ) diff_checkpoints(d, "aaaa1111", "bbbb2222") out = capsys.readouterr().out assert "---" in out assert "+++" in out assert "status:" in out or "pending -> done" in out def test_diff_shows_output_change(self, capsys: pytest.CaptureFixture[str]) -> None: with tempfile.TemporaryDirectory() as d: data1 = _make_checkpoint_data(tasks_completed=2) data2 = json.loads(data1) data2["entities"][0]["tasks"][0]["output"]["raw"] = "Updated result" _write_json_checkpoint( d, name="20260101T000000_aaaa1111_p-none.json", data=json.dumps(json.loads(data1)), ) _write_json_checkpoint( d, name="20260102T000000_bbbb2222_p-none.json", data=json.dumps(data2), ) diff_checkpoints(d, "aaaa1111", "bbbb2222") out = capsys.readouterr().out assert "output:" in out def test_diff_not_found(self, capsys: pytest.CaptureFixture[str]) -> None: with tempfile.TemporaryDirectory() as d: _write_json_checkpoint(d, name="20260101T000000_aaaa1111_p-none.json") diff_checkpoints(d, "aaaa1111", "nonexistent") out = capsys.readouterr().out assert "not found" in out def test_diff_input_change(self, capsys: pytest.CaptureFixture[str]) -> None: with tempfile.TemporaryDirectory() as d: _write_json_checkpoint( d, name="20260101T000000_aaaa1111_p-none.json", inputs={"topic": "AI"}, ) _write_json_checkpoint( d, name="20260102T000000_bbbb2222_p-none.json", inputs={"topic": "ML"}, ) diff_checkpoints(d, "aaaa1111", "bbbb2222") out = capsys.readouterr().out assert "Inputs:" in out assert "AI" in out assert "ML" in out class TestPruneJson: def test_keep_n(self) -> None: with tempfile.TemporaryDirectory() as d: for i in range(5): _write_json_checkpoint( d, name=f"2026010{i + 1}T000000_aaa{i}1111_p-none.json" ) time.sleep(0.01) deleted = _prune_json(d, keep=2, older_than=None) assert deleted == 3 remaining = [] for root, _, files in os.walk(d): remaining.extend(files) assert len(remaining) == 2 def test_older_than(self) -> None: with tempfile.TemporaryDirectory() as d: old_path = _write_json_checkpoint( d, name="20250101T000000_old01111_p-none.json" ) os.utime(old_path, (0, 0)) _write_json_checkpoint(d, name="20990101T000000_new01111_p-none.json") deleted = _prune_json(d, keep=None, older_than=timedelta(days=1)) assert deleted == 1 def test_empty_dir(self) -> None: with tempfile.TemporaryDirectory() as d: assert _prune_json(d, keep=2, older_than=None) == 0 def test_removes_empty_branch_dirs(self) -> None: with tempfile.TemporaryDirectory() as d: path = _write_json_checkpoint( d, branch="feature", name="20260101T000000_aaaa1111_p-none.json", ) os.utime(path, (0, 0)) _prune_json(d, keep=None, older_than=timedelta(days=1)) assert not os.path.exists(os.path.join(d, "feature")) class TestPruneSqlite: def test_keep_n(self) -> None: with tempfile.TemporaryDirectory() as d: db_path = os.path.join(d, "test.db") for i in range(5): _create_sqlite_checkpoint( db_path, f"2026010{i + 1}T000000_aaa{i}1111" ) deleted = _prune_sqlite(db_path, keep=2, older_than=None) assert deleted == 3 with sqlite3.connect(db_path) as conn: count = conn.execute("SELECT COUNT(*) FROM checkpoints").fetchone()[0] assert count == 2 def test_older_than(self) -> None: with tempfile.TemporaryDirectory() as d: db_path = os.path.join(d, "test.db") _create_sqlite_checkpoint(db_path, "20200101T000000_old01111") _create_sqlite_checkpoint(db_path, "20990101T000000_new01111") deleted = _prune_sqlite(db_path, keep=None, older_than=timedelta(days=1)) assert deleted >= 1 with sqlite3.connect(db_path) as conn: count = conn.execute("SELECT COUNT(*) FROM checkpoints").fetchone()[0] assert count >= 1 class TestPruneCommand: def test_no_options_shows_help(self, capsys: pytest.CaptureFixture[str]) -> None: with tempfile.TemporaryDirectory() as d: prune_checkpoints(d, keep=None, older_than=None) out = capsys.readouterr().out assert "Specify" in out def test_dry_run_json(self, capsys: pytest.CaptureFixture[str]) -> None: with tempfile.TemporaryDirectory() as d: _write_json_checkpoint(d) prune_checkpoints(d, keep=1, older_than=None, dry_run=True) out = capsys.readouterr().out assert "Would prune" in out def test_not_found(self, capsys: pytest.CaptureFixture[str]) -> None: prune_checkpoints("/nonexistent", keep=1, older_than=None) out = capsys.readouterr().out assert "Not a directory" in out class TestResumeCheckpoint: def test_not_found(self, capsys: pytest.CaptureFixture[str]) -> None: with tempfile.TemporaryDirectory() as d: resume_checkpoint(d, "nonexistent") out = capsys.readouterr().out assert "not found" in out def test_no_checkpoints(self, capsys: pytest.CaptureFixture[str]) -> None: with tempfile.TemporaryDirectory() as d: resume_checkpoint(d, None) out = capsys.readouterr().out assert "No checkpoints" in out class TestDiscoverabilityMessage: def test_checkpoint_listener_logs_resume_hint(self) -> None: from crewai.state.checkpoint_listener import _do_checkpoint from crewai.state.runtime import RuntimeState state = MagicMock(spec=RuntimeState) state.root = [] state.model_dump.return_value = {"entities": [], "event_record": {"nodes": {}}} state._parent_id = None state._branch = "main" cfg = MagicMock() cfg.location = "/tmp/cp" cfg.max_checkpoints = None cfg.provider.checkpoint.return_value = "/tmp/cp/main/20260101T000000_test1234_p-none.json" cfg.provider.extract_id.return_value = "20260101T000000_test1234" with ( patch("crewai.state.checkpoint_listener._prepare_entities"), patch("crewai.state.checkpoint_listener.logger") as mock_logger, ): _do_checkpoint(state, cfg) cfg.provider.extract_id.assert_called_once() mock_logger.info.assert_called_once() logged: str = mock_logger.info.call_args[0][0] assert "crewai checkpoint resume" in logged assert "20260101T000000_test1234" in logged