mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-07-03 22:19:27 +00:00
feat: add checkpoint resume, diff, prune commands and save discoverability
Add three new CLI subcommands to improve checkpoint UX:
- `crewai checkpoint resume [id]` skips the TUI and resumes directly from the latest or the specified checkpoint
- `crewai checkpoint diff <id1> <id2>` compares two checkpoints, showing changes in metadata, inputs, task status, and outputs
- `crewai checkpoint prune --keep N --older-than Xd` removes old checkpoints from JSON directories or SQLite databases

Also writes a resume hint to stderr after every checkpoint save, so users discover the command without needing to know it exists.
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta, timezone
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
@@ -37,6 +37,26 @@ ORDER BY rowid DESC
|
||||
LIMIT 1
|
||||
"""
|
||||
|
||||
_DELETE_OLDER_THAN = """
|
||||
DELETE FROM checkpoints
|
||||
WHERE created_at < ?
|
||||
"""
|
||||
|
||||
_DELETE_KEEP_N = """
|
||||
DELETE FROM checkpoints WHERE rowid NOT IN (
|
||||
SELECT rowid FROM checkpoints ORDER BY rowid DESC LIMIT ?
|
||||
)
|
||||
"""
|
||||
|
||||
_COUNT_CHECKPOINTS = "SELECT COUNT(*) FROM checkpoints"
|
||||
|
||||
_SELECT_LIKE = """
|
||||
SELECT id, created_at, json(data)
|
||||
FROM checkpoints
|
||||
WHERE id LIKE ?
|
||||
ORDER BY rowid DESC
|
||||
"""
|
||||
|
||||
|
||||
_DEFAULT_DIR = "./.checkpoints"
|
||||
_DEFAULT_DB = "./.checkpoints.db"
|
||||
@@ -262,6 +282,8 @@ def _info_sqlite_latest(db_path: str) -> dict[str, Any] | None:
|
||||
def _info_sqlite_id(db_path: str, checkpoint_id: str) -> dict[str, Any] | None:
|
||||
with sqlite3.connect(db_path) as conn:
|
||||
row = conn.execute(_SELECT_ONE, (checkpoint_id,)).fetchone()
|
||||
if not row:
|
||||
row = conn.execute(_SELECT_LIKE, (f"%{checkpoint_id}%",)).fetchone()
|
||||
if not row:
|
||||
return None
|
||||
cid, created_at, raw = row
|
||||
@@ -384,3 +406,287 @@ def _print_info(meta: dict[str, Any]) -> None:
|
||||
if len(desc) > 70:
|
||||
desc = desc[:67] + "..."
|
||||
click.echo(f" {i + 1}. [{status}] {desc}")
|
||||
|
||||
|
||||
def _resolve_checkpoint(
    location: str, checkpoint_id: str | None
) -> dict[str, Any] | None:
    """Find one checkpoint at *location* and return its metadata.

    Args:
        location: A SQLite database (as decided by ``_is_sqlite``), a
            directory of JSON checkpoints, or a single JSON checkpoint file.
        checkpoint_id: Identifier to look for. When falsy, the latest
            checkpoint of the store is used. Ignored when *location* is a
            single file.

    Returns:
        Whatever the matching ``_info_*`` helper returns, or ``None`` when
        the location does not exist or nothing matches.
    """
    if _is_sqlite(location):
        if not checkpoint_id:
            return _info_sqlite_latest(location)
        return _info_sqlite_id(location, checkpoint_id)

    if os.path.isfile(location):
        return _info_json_file(location)
    if not os.path.isdir(location):
        return None
    if not checkpoint_id:
        return _info_json_latest(location)

    from crewai.state.provider.json_provider import JsonProvider

    provider = JsonProvider()
    json_files = glob.glob(os.path.join(location, "**", "*.json"), recursive=True)
    # A file qualifies when the requested id is contained in the id the
    # provider extracts from its path.
    candidates = [
        path for path in json_files if checkpoint_id in provider.extract_id(path)
    ]
    if not candidates:
        return None
    # Several files may match; the most recently modified one wins.
    newest = max(candidates, key=os.path.getmtime)
    return _info_json_file(newest)
|
||||
|
||||
|
||||
def _entity_type_from_meta(meta: dict[str, Any]) -> str:
|
||||
for ent in meta.get("entities", []):
|
||||
if ent.get("type") == "flow":
|
||||
return "flow"
|
||||
return "crew"
|
||||
|
||||
|
||||
def resume_checkpoint(location: str, checkpoint_id: str | None) -> None:
    """Resolve a checkpoint, restore the saved crew or flow, and run it.

    Prints a message and returns without running anything when no
    checkpoint can be resolved. Otherwise the checkpoint summary is printed,
    the entity is rebuilt through ``from_checkpoint`` and kicked off, and the
    result is echoed.

    Args:
        location: Checkpoint store passed on to ``_resolve_checkpoint``.
        checkpoint_id: Identifier to resume from; falsy means "latest".
    """
    import asyncio

    meta = _resolve_checkpoint(location, checkpoint_id)
    if meta is None:
        message = (
            f"Checkpoint not found: {checkpoint_id}"
            if checkpoint_id
            else f"No checkpoints found in {location}"
        )
        click.echo(message)
        return

    # Metadata carrying a "db" entry is addressed as "<db>#<name>";
    # everything else is addressed by its path (or "source" as a fallback).
    if meta.get("db"):
        restore_path = f"{meta['db']}#{meta['name']}"
    else:
        restore_path = meta.get("path") or meta.get("source", "")

    click.echo(f"Resuming from: {meta.get('name', restore_path)}")
    _print_info(meta)
    click.echo()

    from crewai.state.checkpoint_config import CheckpointConfig

    config = CheckpointConfig(restore_from=restore_path)
    # An empty inputs mapping is passed on as None.
    inputs = meta.get("inputs") or None

    if _entity_type_from_meta(meta) == "flow":
        from crewai.flow.flow import Flow

        pending = Flow.from_checkpoint(config).kickoff_async(inputs=inputs)
    else:
        from crewai.crew import Crew

        pending = Crew.from_checkpoint(config).akickoff(inputs=inputs)

    result = asyncio.run(pending)
    click.echo(f"\nResult: {getattr(result, 'raw', result)}")
|
||||
|
||||
|
||||
def _task_list_from_meta(meta: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
tasks: list[dict[str, Any]] = []
|
||||
for ent in meta.get("entities", []):
|
||||
tasks.extend(
|
||||
{
|
||||
"entity": ent.get("name", "unnamed"),
|
||||
"description": t.get("description", ""),
|
||||
"completed": t.get("completed", False),
|
||||
"output": t.get("output", ""),
|
||||
}
|
||||
for t in ent.get("tasks", [])
|
||||
)
|
||||
return tasks
|
||||
|
||||
|
||||
def diff_checkpoints(location: str, id1: str, id2: str) -> None:
    """Print the differences between two checkpoints.

    Both checkpoints are resolved through ``_resolve_checkpoint``. The output
    starts with ``---`` / ``+++`` header lines, followed by the metadata
    fields that differ, the inputs that differ, and a per-position task
    comparison (added, removed, status changed, output changed).

    Args:
        location: Checkpoint store holding both checkpoints.
        id1: Identifier of the "before" checkpoint.
        id2: Identifier of the "after" checkpoint.
    """
    meta1 = _resolve_checkpoint(location, id1)
    meta2 = _resolve_checkpoint(location, id2)

    if meta1 is None:
        click.echo(f"Checkpoint not found: {id1}")
        return
    if meta2 is None:
        click.echo(f"Checkpoint not found: {id2}")
        return

    click.echo(f"--- {meta1.get('name', id1)}")
    click.echo(f"+++ {meta2.get('name', id2)}")
    click.echo()

    fields: list[tuple[str, str]] = [
        ("Time", "ts"),
        ("Branch", "branch"),
        ("Trigger", "trigger"),
        ("Events", "event_count"),
    ]
    for label, key in fields:
        before = str(meta1.get(key, ""))
        after = str(meta2.get(key, ""))
        if before != after:
            click.echo(f" {label}:")
            click.echo(f" - {before}")
            click.echo(f" + {after}")

    # "inputs" may be present but None; treat that like an empty mapping
    # instead of failing on .keys().
    inputs1: dict[str, Any] = meta1.get("inputs") or {}
    inputs2: dict[str, Any] = meta2.get("inputs") or {}
    changed_inputs = [
        (key, inputs1.get(key, ""), inputs2.get(key, ""))
        for key in sorted(inputs1.keys() | inputs2.keys())
        if inputs1.get(key) != inputs2.get(key)
    ]
    if changed_inputs:
        click.echo("\n Inputs:")
        for key, before, after in changed_inputs:
            click.echo(f" {key}:")
            click.echo(f" - {before}")
            click.echo(f" + {after}")

    tasks1 = _task_list_from_meta(meta1)
    tasks2 = _task_list_from_meta(meta2)
    max_tasks = max(len(tasks1), len(tasks2))
    if max_tasks == 0:
        return

    click.echo("\n Tasks:")
    # Tasks are compared by position, not by description.
    for i in range(max_tasks):
        t1 = tasks1[i] if i < len(tasks1) else None
        t2 = tasks2[i] if i < len(tasks2) else None

        if t1 is None:
            desc = (t2["description"] or "")[:60] if t2 else ""
            click.echo(f" + {i + 1}. [new] {desc}")
            continue
        if t2 is None:
            desc = (t1["description"] or "")[:60]
            click.echo(f" - {i + 1}. [removed] {desc}")
            continue

        desc = str((t1["description"] or "")[:60])
        s1 = "done" if t1["completed"] else "pending"
        s2 = "done" if t2["completed"] else "pending"
        out1 = (t1.get("output") or "").strip()
        out2 = (t2.get("output") or "").strip()

        # One heading per task, printed only when something differs.
        if s1 != s2 or out1 != out2:
            click.echo(f" {i + 1}. {desc}")
        if s1 != s2:
            click.echo(f" status: {s1} -> {s2}")
        if out1 != out2:
            click.echo(" output:")
            click.echo(f" - {_diff_output_preview(out1)}")
            click.echo(f" + {_diff_output_preview(out2)}")


def _diff_output_preview(text: str) -> str:
    """Return *text* cut to 80 characters plus ``...``, or ``"(empty)"`` if blank."""
    if not text:
        return "(empty)"
    return text[:80] + ("..." if len(text) > 80 else "")
|
||||
|
||||
|
||||
def _parse_duration(value: str) -> timedelta:
|
||||
match: re.Match[str] | None = re.match(r"^(\d+)([dhm])$", value.strip())
|
||||
if not match:
|
||||
raise click.BadParameter(
|
||||
f"Invalid duration: {value!r}. Use format like '7d', '24h', or '30m'."
|
||||
)
|
||||
amount: int = int(match.group(1))
|
||||
unit: str = match.group(2)
|
||||
if unit == "d":
|
||||
return timedelta(days=amount)
|
||||
if unit == "h":
|
||||
return timedelta(hours=amount)
|
||||
return timedelta(minutes=amount)
|
||||
|
||||
|
||||
def _prune_json(location: str, keep: int | None, older_than: timedelta | None) -> int:
|
||||
pattern: str = os.path.join(location, "**", "*.json")
|
||||
files: list[str] = sorted(
|
||||
glob.glob(pattern, recursive=True), key=os.path.getmtime, reverse=True
|
||||
)
|
||||
if not files:
|
||||
return 0
|
||||
|
||||
to_delete: set[str] = set()
|
||||
|
||||
if keep is not None and len(files) > keep:
|
||||
to_delete.update(files[keep:])
|
||||
|
||||
if older_than is not None:
|
||||
cutoff: datetime = datetime.now(timezone.utc) - older_than
|
||||
for path in files:
|
||||
mtime: datetime = datetime.fromtimestamp(
|
||||
os.path.getmtime(path), tz=timezone.utc
|
||||
)
|
||||
if mtime < cutoff:
|
||||
to_delete.add(path)
|
||||
|
||||
deleted: int = 0
|
||||
for path in to_delete:
|
||||
try:
|
||||
os.remove(path)
|
||||
deleted += 1
|
||||
except OSError: # noqa: PERF203
|
||||
pass
|
||||
|
||||
for dirpath, dirnames, filenames in os.walk(location, topdown=False):
|
||||
if dirpath != location and not filenames and not dirnames:
|
||||
try:
|
||||
os.rmdir(dirpath)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
return deleted
|
||||
|
||||
|
||||
def _prune_sqlite(db_path: str, keep: int | None, older_than: timedelta | None) -> int:
    """Delete checkpoint rows from the SQLite database at *db_path*.

    The age criterion runs first, then the keep-N criterion is applied to
    whatever rows remain. Both deletions are committed together.

    Args:
        db_path: Path of the SQLite database file.
        keep: Number of rows with the highest rowids to retain, or ``None``
            to skip this criterion.
        older_than: Rows whose ``created_at`` sorts before "now minus this
            duration" (UTC, formatted ``%Y%m%dT%H%M%S``) are deleted;
            ``None`` skips this criterion.

    Returns:
        Total number of rows deleted.
    """
    from contextlib import closing

    deleted = 0
    # sqlite3's own context manager only commits or rolls back; closing()
    # is what actually releases the connection when the block ends.
    with closing(sqlite3.connect(db_path)) as conn, conn:
        if older_than is not None:
            cutoff = (datetime.now(timezone.utc) - older_than).strftime(
                "%Y%m%dT%H%M%S"
            )
            deleted += conn.execute(_DELETE_OLDER_THAN, (cutoff,)).rowcount

        if keep is not None:
            deleted += conn.execute(_DELETE_KEEP_N, (keep,)).rowcount
    return deleted
|
||||
|
||||
|
||||
def prune_checkpoints(
    location: str, keep: int | None, older_than: str | None, dry_run: bool = False
) -> None:
    """Prune checkpoints at *location* and report the outcome on stdout.

    Args:
        location: SQLite database (as decided by ``_is_sqlite``) or a
            directory of JSON checkpoints.
        keep: Retain this many of the most recent checkpoints.
        older_than: Duration string such as ``"7d"``; see ``_parse_duration``.
        dry_run: When true, nothing is deleted; only the number of
            checkpoints currently stored is reported.

    Raises:
        click.BadParameter: If *older_than* is malformed or *keep* is
            negative.
    """
    from contextlib import closing

    if keep is None and older_than is None:
        click.echo("Specify --keep N and/or --older-than DURATION (e.g. 7d, 24h)")
        return
    if keep is not None and keep < 0:
        # A negative value would mean different things per backend (the JSON
        # pruner slices from the end, SQLite treats a negative LIMIT as
        # unbounded), so reject it up front.
        raise click.BadParameter(f"--keep must be zero or greater, got {keep}.")

    duration: timedelta | None = _parse_duration(older_than) if older_than else None

    deleted: int
    if _is_sqlite(location):
        if dry_run:
            # closing() releases the connection; sqlite3's own context
            # manager would leave it open.
            with closing(sqlite3.connect(location)) as conn:
                total: int = conn.execute(_COUNT_CHECKPOINTS).fetchone()[0]
            click.echo(f"Would prune from {total} checkpoint(s) in {location}")
            return
        deleted = _prune_sqlite(location, keep, duration)
    elif os.path.isdir(location):
        if dry_run:
            files: list[str] = glob.glob(
                os.path.join(location, "**", "*.json"), recursive=True
            )
            click.echo(f"Would prune from {len(files)} checkpoint(s) in {location}")
            return
        deleted = _prune_json(location, keep, duration)
    else:
        click.echo(f"Not a directory or SQLite database: {location}")
        return

    click.echo(f"Pruned {deleted} checkpoint(s) from {location}")
|
||||
|
||||
@@ -873,5 +873,48 @@ def checkpoint_info(path: str) -> None:
|
||||
info_checkpoint(_detect_location(path))
|
||||
|
||||
|
||||
@checkpoint.command("resume")
@click.argument("checkpoint_id", required=False, default=None)
@click.pass_context
def checkpoint_resume(ctx: click.Context, checkpoint_id: str | None) -> None:
    """Resume from a checkpoint. Defaults to the most recent."""
    # Imported lazily so the checkpoint module loads only when this
    # subcommand actually runs.
    from crewai.cli.checkpoint_cli import resume_checkpoint

    # The checkpoint location is read from the context object.
    location = ctx.obj["location"]
    resume_checkpoint(location, checkpoint_id)
|
||||
|
||||
|
||||
@checkpoint.command("diff")
@click.argument("id1")
@click.argument("id2")
@click.pass_context
def checkpoint_diff(ctx: click.Context, id1: str, id2: str) -> None:
    """Compare two checkpoints side-by-side."""
    # Imported lazily so the checkpoint module loads only when this
    # subcommand actually runs.
    from crewai.cli.checkpoint_cli import diff_checkpoints

    # Both ids are looked up in the location stored on the context object.
    location = ctx.obj["location"]
    diff_checkpoints(location, id1, id2)
|
||||
|
||||
|
||||
@checkpoint.command("prune")
@click.option(
    "--keep", type=int, default=None, help="Keep the N most recent checkpoints."
)
@click.option(
    "--older-than",
    default=None,
    help="Remove checkpoints older than duration (e.g. 7d, 24h, 30m).",
)
@click.option(
    "--dry-run", is_flag=True, help="Show what would be pruned without deleting."
)
@click.pass_context
def checkpoint_prune(
    ctx: click.Context, keep: int | None, older_than: str | None, dry_run: bool
) -> None:
    """Remove old checkpoints."""
    # Imported lazily so the checkpoint module loads only when this
    # subcommand actually runs.
    from crewai.cli.checkpoint_cli import prune_checkpoints

    # Option values are forwarded unchanged; the duration string is handed
    # over as-is, not parsed here.
    location = ctx.obj["location"]
    prune_checkpoints(location, keep, older_than, dry_run)
|
||||
|
||||
|
||||
# Run the top-level command group when this file is executed as a script.
if __name__ == "__main__":
    crewai()
|
||||
|
||||
@@ -120,6 +120,12 @@ def _do_checkpoint(
|
||||
)
|
||||
state._chain_lineage(cfg.provider, location)
|
||||
|
||||
checkpoint_id: str = cfg.provider.extract_id(location)
|
||||
msg: str = (
|
||||
f"Checkpoint saved. Resume with: crewai checkpoint resume {checkpoint_id}"
|
||||
)
|
||||
logger.info(msg)
|
||||
|
||||
if cfg.max_checkpoints is not None:
|
||||
cfg.provider.prune(cfg.location, cfg.max_checkpoints, branch=state._branch)
|
||||
|
||||
|
||||
402
lib/crewai/tests/test_checkpoint_cli.py
Normal file
402
lib/crewai/tests/test_checkpoint_cli.py
Normal file
@@ -0,0 +1,402 @@
|
||||
"""Tests for checkpoint CLI commands."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import sqlite3
|
||||
import tempfile
|
||||
import time
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from crewai.cli.checkpoint_cli import (
|
||||
_parse_checkpoint_json,
|
||||
_parse_duration,
|
||||
_prune_json,
|
||||
_prune_sqlite,
|
||||
_resolve_checkpoint,
|
||||
_task_list_from_meta,
|
||||
diff_checkpoints,
|
||||
prune_checkpoints,
|
||||
resume_checkpoint,
|
||||
)
|
||||
|
||||
|
||||
def _make_checkpoint_data(
|
||||
tasks_completed: int = 2,
|
||||
tasks_total: int = 4,
|
||||
trigger: str = "task_completed",
|
||||
branch: str = "main",
|
||||
parent_id: str | None = None,
|
||||
entity_type: str = "crew",
|
||||
name: str = "test_crew",
|
||||
inputs: dict[str, Any] | None = None,
|
||||
) -> str:
|
||||
tasks: list[dict[str, Any]] = []
|
||||
for i in range(tasks_total):
|
||||
t: dict[str, Any] = {
|
||||
"description": f"Task {i + 1} description",
|
||||
"expected_output": f"Output {i + 1}",
|
||||
}
|
||||
if i < tasks_completed:
|
||||
t["output"] = {"raw": f"Result of task {i + 1}"}
|
||||
else:
|
||||
t["output"] = None
|
||||
tasks.append(t)
|
||||
|
||||
data: dict[str, Any] = {
|
||||
"entities": [
|
||||
{
|
||||
"entity_type": entity_type,
|
||||
"name": name,
|
||||
"id": "abc12345-1234-1234-1234-abcdef012345",
|
||||
"tasks": tasks,
|
||||
"agents": [],
|
||||
"checkpoint_inputs": inputs or {},
|
||||
}
|
||||
],
|
||||
"event_record": {"nodes": {f"node_{i}": {} for i in range(3)}},
|
||||
"trigger": trigger,
|
||||
"branch": branch,
|
||||
"parent_id": parent_id,
|
||||
}
|
||||
return json.dumps(data)
|
||||
|
||||
|
||||
def _write_json_checkpoint(
|
||||
base_dir: str,
|
||||
branch: str = "main",
|
||||
name: str | None = None,
|
||||
data: str | None = None,
|
||||
tasks_completed: int = 2,
|
||||
inputs: dict[str, Any] | None = None,
|
||||
) -> str:
|
||||
branch_dir = os.path.join(base_dir, branch)
|
||||
os.makedirs(branch_dir, exist_ok=True)
|
||||
if name is None:
|
||||
ts = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%S")
|
||||
name = f"{ts}_abcd1234_p-none.json"
|
||||
path = os.path.join(branch_dir, name)
|
||||
if data is None:
|
||||
data = _make_checkpoint_data(tasks_completed=tasks_completed, inputs=inputs)
|
||||
with open(path, "w") as f:
|
||||
f.write(data)
|
||||
return path
|
||||
|
||||
|
||||
def _create_sqlite_checkpoint(
|
||||
db_path: str,
|
||||
checkpoint_id: str | None = None,
|
||||
data: str | None = None,
|
||||
tasks_completed: int = 2,
|
||||
branch: str = "main",
|
||||
inputs: dict[str, Any] | None = None,
|
||||
) -> str:
|
||||
if checkpoint_id is None:
|
||||
ts = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%S")
|
||||
checkpoint_id = f"{ts}_abcd1234"
|
||||
if data is None:
|
||||
data = _make_checkpoint_data(
|
||||
tasks_completed=tasks_completed, branch=branch, inputs=inputs
|
||||
)
|
||||
with sqlite3.connect(db_path) as conn:
|
||||
conn.execute(
|
||||
"""CREATE TABLE IF NOT EXISTS checkpoints (
|
||||
id TEXT PRIMARY KEY,
|
||||
created_at TEXT NOT NULL,
|
||||
parent_id TEXT,
|
||||
branch TEXT NOT NULL DEFAULT 'main',
|
||||
data JSONB NOT NULL
|
||||
)"""
|
||||
)
|
||||
conn.execute(
|
||||
"INSERT INTO checkpoints (id, created_at, parent_id, branch, data) "
|
||||
"VALUES (?, ?, ?, ?, jsonb(?))",
|
||||
(checkpoint_id, checkpoint_id.split("_")[0], None, branch, data),
|
||||
)
|
||||
conn.commit()
|
||||
return checkpoint_id
|
||||
|
||||
|
||||
class TestParseDuration:
    """Unit tests for ``_parse_duration``."""

    def test_days(self) -> None:
        assert _parse_duration("7d") == timedelta(days=7)

    def test_hours(self) -> None:
        assert _parse_duration("24h") == timedelta(hours=24)

    def test_minutes(self) -> None:
        assert _parse_duration("30m") == timedelta(minutes=30)

    def test_invalid_raises(self) -> None:
        import click

        # Pin the documented error type; a bare Exception would also accept
        # an accidental crash inside the parser.
        with pytest.raises(click.BadParameter):
            _parse_duration("abc")

    def test_no_unit_raises(self) -> None:
        import click

        with pytest.raises(click.BadParameter):
            _parse_duration("7")
|
||||
|
||||
|
||||
class TestResolveCheckpoint:
    """Tests for ``_resolve_checkpoint`` against JSON and SQLite stores."""

    def test_json_latest(self) -> None:
        with tempfile.TemporaryDirectory() as d:
            path1 = _write_json_checkpoint(
                d, name="20260101T000000_aaaa1111_p-none.json"
            )
            # Age the first file explicitly instead of sleeping: a short
            # sleep does not guarantee distinct mtimes on filesystems with
            # coarse timestamp resolution.
            stale = time.time() - 60
            os.utime(path1, (stale, stale))
            path2 = _write_json_checkpoint(
                d, name="20260102T000000_bbbb2222_p-none.json", tasks_completed=3
            )
            meta = _resolve_checkpoint(d, None)
            assert meta is not None
            assert meta["path"] == path2

    def test_json_by_id(self) -> None:
        with tempfile.TemporaryDirectory() as d:
            _write_json_checkpoint(d, name="20260101T000000_aaaa1111_p-none.json")
            _write_json_checkpoint(d, name="20260102T000000_bbbb2222_p-none.json")
            meta = _resolve_checkpoint(d, "aaaa1111")
            assert meta is not None
            assert "aaaa1111" in meta["name"]

    def test_json_not_found(self) -> None:
        with tempfile.TemporaryDirectory() as d:
            _write_json_checkpoint(d)
            assert _resolve_checkpoint(d, "nonexistent") is None

    def test_sqlite_latest(self) -> None:
        with tempfile.TemporaryDirectory() as d:
            db_path = os.path.join(d, "test.db")
            _create_sqlite_checkpoint(db_path, "20260101T000000_aaaa1111")
            _create_sqlite_checkpoint(
                db_path, "20260102T000000_bbbb2222", tasks_completed=3
            )
            meta = _resolve_checkpoint(db_path, None)
            assert meta is not None
            assert "bbbb2222" in meta["name"]

    def test_sqlite_by_id(self) -> None:
        with tempfile.TemporaryDirectory() as d:
            db_path = os.path.join(d, "test.db")
            _create_sqlite_checkpoint(db_path, "20260101T000000_aaaa1111")
            _create_sqlite_checkpoint(db_path, "20260102T000000_bbbb2222")
            meta = _resolve_checkpoint(db_path, "20260101T000000_aaaa1111")
            assert meta is not None
            assert "aaaa1111" in meta["name"]

    def test_sqlite_partial_id(self) -> None:
        with tempfile.TemporaryDirectory() as d:
            db_path = os.path.join(d, "test.db")
            _create_sqlite_checkpoint(db_path, "20260101T000000_aaaa1111")
            _create_sqlite_checkpoint(db_path, "20260102T000000_bbbb2222")
            meta = _resolve_checkpoint(db_path, "aaaa1111")
            assert meta is not None
            assert "aaaa1111" in meta["name"]

    def test_nonexistent(self) -> None:
        assert _resolve_checkpoint("/nonexistent/path", None) is None
|
||||
|
||||
|
||||
class TestTaskListFromMeta:
    """Tests for ``_task_list_from_meta``."""

    def test_flattens_tasks(self) -> None:
        """Three tasks, the first two completed, come back in order."""
        raw = _make_checkpoint_data(tasks_completed=2, tasks_total=3)
        parsed = _parse_checkpoint_json(raw, "test")
        flattened = _task_list_from_meta(parsed)
        assert len(flattened) == 3
        first, last = flattened[0], flattened[2]
        assert first["completed"] is True
        assert last["completed"] is False

    def test_empty_entities(self) -> None:
        """No entities means no tasks."""
        result = _task_list_from_meta({"entities": []})
        assert result == []
|
||||
|
||||
|
||||
class TestDiffCheckpoints:
    """Tests for the output printed by ``diff_checkpoints``."""

    def test_diff_shows_status_change(self, capsys: pytest.CaptureFixture[str]) -> None:
        with tempfile.TemporaryDirectory() as d:
            _write_json_checkpoint(
                d, name="20260101T000000_aaaa1111_p-none.json", tasks_completed=1
            )
            _write_json_checkpoint(
                d, name="20260102T000000_bbbb2222_p-none.json", tasks_completed=3
            )
            diff_checkpoints(d, "aaaa1111", "bbbb2222")
            out = capsys.readouterr().out
            assert "---" in out
            assert "+++" in out
            # Tasks 2 and 3 are completed only in the second checkpoint, so
            # the full transition line must be present, not just one half.
            assert "status: pending -> done" in out

    def test_diff_shows_output_change(self, capsys: pytest.CaptureFixture[str]) -> None:
        with tempfile.TemporaryDirectory() as d:
            data1 = _make_checkpoint_data(tasks_completed=2)
            data2 = json.loads(data1)
            data2["entities"][0]["tasks"][0]["output"]["raw"] = "Updated result"
            _write_json_checkpoint(
                d,
                name="20260101T000000_aaaa1111_p-none.json",
                data=data1,
            )
            _write_json_checkpoint(
                d,
                name="20260102T000000_bbbb2222_p-none.json",
                data=json.dumps(data2),
            )
            diff_checkpoints(d, "aaaa1111", "bbbb2222")
            out = capsys.readouterr().out
            assert "output:" in out

    def test_diff_not_found(self, capsys: pytest.CaptureFixture[str]) -> None:
        with tempfile.TemporaryDirectory() as d:
            _write_json_checkpoint(d, name="20260101T000000_aaaa1111_p-none.json")
            diff_checkpoints(d, "aaaa1111", "nonexistent")
            out = capsys.readouterr().out
            assert "not found" in out

    def test_diff_input_change(self, capsys: pytest.CaptureFixture[str]) -> None:
        with tempfile.TemporaryDirectory() as d:
            _write_json_checkpoint(
                d,
                name="20260101T000000_aaaa1111_p-none.json",
                inputs={"topic": "AI"},
            )
            _write_json_checkpoint(
                d,
                name="20260102T000000_bbbb2222_p-none.json",
                inputs={"topic": "ML"},
            )
            diff_checkpoints(d, "aaaa1111", "bbbb2222")
            out = capsys.readouterr().out
            assert "Inputs:" in out
            assert "AI" in out
            assert "ML" in out
|
||||
|
||||
|
||||
class TestPruneJson:
    """Tests for ``_prune_json``."""

    def test_keep_n(self) -> None:
        with tempfile.TemporaryDirectory() as d:
            names = [f"2026010{i + 1}T000000_aaa{i}1111_p-none.json" for i in range(5)]
            base = time.time() - 100
            for offset, name in enumerate(names):
                path = _write_json_checkpoint(d, name=name)
                # Explicit, strictly increasing mtimes make "most recent"
                # unambiguous without sleeping between writes.
                os.utime(path, (base + offset, base + offset))
            deleted = _prune_json(d, keep=2, older_than=None)
            assert deleted == 3
            remaining: list[str] = []
            for _root, _dirs, files in os.walk(d):
                remaining.extend(files)
            # The two newest files are the ones that must survive.
            assert sorted(remaining) == names[3:]

    def test_older_than(self) -> None:
        with tempfile.TemporaryDirectory() as d:
            old_path = _write_json_checkpoint(
                d, name="20250101T000000_old01111_p-none.json"
            )
            os.utime(old_path, (0, 0))
            _write_json_checkpoint(d, name="20260417T000000_new01111_p-none.json")
            deleted = _prune_json(d, keep=None, older_than=timedelta(days=1))
            assert deleted == 1

    def test_empty_dir(self) -> None:
        with tempfile.TemporaryDirectory() as d:
            assert _prune_json(d, keep=2, older_than=None) == 0

    def test_removes_empty_branch_dirs(self) -> None:
        with tempfile.TemporaryDirectory() as d:
            path = _write_json_checkpoint(
                d,
                branch="feature",
                name="20260101T000000_aaaa1111_p-none.json",
            )
            os.utime(path, (0, 0))
            _prune_json(d, keep=None, older_than=timedelta(days=1))
            assert not os.path.exists(os.path.join(d, "feature"))
|
||||
|
||||
|
||||
class TestPruneSqlite:
    """Tests for ``_prune_sqlite``."""

    def test_keep_n(self) -> None:
        from contextlib import closing

        with tempfile.TemporaryDirectory() as d:
            db_path = os.path.join(d, "test.db")
            for i in range(5):
                _create_sqlite_checkpoint(
                    db_path, f"2026010{i + 1}T000000_aaa{i}1111"
                )
            deleted = _prune_sqlite(db_path, keep=2, older_than=None)
            assert deleted == 3
            # closing() releases the file so the temp directory can be removed.
            with closing(sqlite3.connect(db_path)) as conn:
                count = conn.execute("SELECT COUNT(*) FROM checkpoints").fetchone()[0]
            assert count == 2

    def test_older_than(self) -> None:
        from contextlib import closing

        with tempfile.TemporaryDirectory() as d:
            db_path = os.path.join(d, "test.db")
            _create_sqlite_checkpoint(db_path, "20200101T000000_old01111")
            # The helper stamps a default id with the current time, so this
            # row stays "fresh" whenever the test runs. A hard-coded date
            # turns stale as soon as the calendar moves past it.
            fresh_id = _create_sqlite_checkpoint(db_path)
            deleted = _prune_sqlite(db_path, keep=None, older_than=timedelta(days=1))
            assert deleted == 1
            with closing(sqlite3.connect(db_path)) as conn:
                ids = [row[0] for row in conn.execute("SELECT id FROM checkpoints")]
            assert ids == [fresh_id]
|
||||
|
||||
|
||||
class TestPruneCommand:
    """Tests for the messages printed by ``prune_checkpoints``."""

    def test_no_options_shows_help(self, capsys: pytest.CaptureFixture[str]) -> None:
        """Without any criterion the user is told which options to pass."""
        with tempfile.TemporaryDirectory() as workdir:
            prune_checkpoints(workdir, keep=None, older_than=None)
            captured = capsys.readouterr()
            assert "Specify" in captured.out

    def test_dry_run_json(self, capsys: pytest.CaptureFixture[str]) -> None:
        """A dry run over a JSON directory only reports."""
        with tempfile.TemporaryDirectory() as workdir:
            _write_json_checkpoint(workdir)
            prune_checkpoints(workdir, keep=1, older_than=None, dry_run=True)
            captured = capsys.readouterr()
            assert "Would prune" in captured.out

    def test_not_found(self, capsys: pytest.CaptureFixture[str]) -> None:
        """A location that does not exist is reported."""
        prune_checkpoints("/nonexistent", keep=1, older_than=None)
        captured = capsys.readouterr()
        assert "Not a directory" in captured.out
|
||||
|
||||
|
||||
class TestResumeCheckpoint:
    """Tests for the early-exit paths of ``resume_checkpoint``."""

    def test_not_found(self, capsys: pytest.CaptureFixture[str]) -> None:
        """An unknown id in an empty directory prints a not-found message."""
        with tempfile.TemporaryDirectory() as workdir:
            resume_checkpoint(workdir, "nonexistent")
            captured = capsys.readouterr()
            assert "not found" in captured.out

    def test_no_checkpoints(self, capsys: pytest.CaptureFixture[str]) -> None:
        """No id and an empty directory prints a no-checkpoints message."""
        with tempfile.TemporaryDirectory() as workdir:
            resume_checkpoint(workdir, None)
            captured = capsys.readouterr()
            assert "No checkpoints" in captured.out
|
||||
|
||||
|
||||
class TestDiscoverabilityMessage:
    """Checks that saving a checkpoint logs how to resume from it."""

    def test_checkpoint_listener_logs_resume_hint(self) -> None:
        from crewai.state.checkpoint_listener import _do_checkpoint
        from crewai.state.runtime import RuntimeState

        # Minimal stand-in for the runtime state handed to _do_checkpoint.
        fake_state = MagicMock(spec=RuntimeState)
        fake_state.root = []
        fake_state.model_dump.return_value = {
            "entities": [],
            "event_record": {"nodes": {}},
        }
        fake_state._parent_id = None
        fake_state._branch = "main"

        # Config whose provider reports a fixed save path and id.
        fake_cfg = MagicMock()
        fake_cfg.location = "/tmp/cp"
        fake_cfg.max_checkpoints = None
        fake_cfg.provider.checkpoint.return_value = (
            "/tmp/cp/main/20260101T000000_test1234_p-none.json"
        )
        fake_cfg.provider.extract_id.return_value = "20260101T000000_test1234"

        with (
            patch("crewai.state.checkpoint_listener._prepare_entities"),
            patch("crewai.state.checkpoint_listener.logger") as fake_logger,
        ):
            _do_checkpoint(fake_state, fake_cfg)

        fake_cfg.provider.extract_id.assert_called_once()
        fake_logger.info.assert_called_once()
        positional_args = fake_logger.info.call_args[0]
        message: str = positional_args[0]
        assert "crewai checkpoint resume" in message
        assert "20260101T000000_test1234" in message
|
||||
Reference in New Issue
Block a user