Files
crewAI/lib/cli/src/crewai_cli/checkpoint_cli.py
2026-05-06 20:46:46 +08:00

733 lines
23 KiB
Python

"""CLI commands for inspecting checkpoint files."""
from __future__ import annotations
from datetime import datetime, timedelta, timezone
import glob
import json
import os
import re
import sqlite3
from typing import Any
import click
# Matches ``{placeholder}`` tokens inside task/agent template text.
_PLACEHOLDER_RE = re.compile(r"\{([A-Za-z_][A-Za-z0-9_\-]*)}")
# First 16 bytes of every SQLite database file (format magic header).
_SQLITE_MAGIC = b"SQLite format 3\x00"
# All checkpoint rows, newest first (rowid order tracks insertion order).
_SELECT_ALL = """
SELECT id, created_at, json(data)
FROM checkpoints
ORDER BY rowid DESC
"""
# One checkpoint row by exact ID.
_SELECT_ONE = """
SELECT id, created_at, json(data)
FROM checkpoints
WHERE id = ?
"""
# Most recently inserted checkpoint row.
_SELECT_LATEST = """
SELECT id, created_at, json(data)
FROM checkpoints
ORDER BY rowid DESC
LIMIT 1
"""
# Age-based pruning; compares against the textual created_at column.
_DELETE_OLDER_THAN = """
DELETE FROM checkpoints
WHERE created_at < ?
"""
# Count-based pruning: delete everything except the N newest rows.
_DELETE_KEEP_N = """
DELETE FROM checkpoints WHERE rowid NOT IN (
SELECT rowid FROM checkpoints ORDER BY rowid DESC LIMIT ?
)
"""
_COUNT_CHECKPOINTS = "SELECT COUNT(*) FROM checkpoints"
# Substring ID lookup, newest match first (supports short/partial IDs).
_SELECT_LIKE = """
SELECT id, created_at, json(data)
FROM checkpoints
WHERE id LIKE ?
ORDER BY rowid DESC
"""
# Conventional default locations: JSON directory store and SQLite store.
_DEFAULT_DIR = "./.checkpoints"
_DEFAULT_DB = "./.checkpoints.db"
def _detect_location(location: str) -> str:
    """Resolve the default checkpoint location.

    When the caller passes the default directory path but it does not exist,
    prefer the conventional SQLite ``.db`` path if that file is present.
    """
    if location != _DEFAULT_DIR:
        # Explicit, non-default locations are taken as-is.
        return location
    if os.path.exists(_DEFAULT_DIR) or not os.path.exists(_DEFAULT_DB):
        return location
    return _DEFAULT_DB
def _is_sqlite(path: str) -> bool:
    """Check if a file is a SQLite database by reading its magic bytes."""
    if not os.path.isfile(path):
        return False
    try:
        with open(path, "rb") as fh:
            header = fh.read(16)
    except OSError:
        # Unreadable file: treat as "not a database" rather than erroring.
        return False
    return header == _SQLITE_MAGIC
def _parse_checkpoint_json(raw: str, source: str) -> dict[str, Any]:
"""Parse checkpoint JSON into metadata dict."""
data = json.loads(raw)
entities = data.get("entities", [])
nodes = data.get("event_record", {}).get("nodes", {})
event_count = len(nodes)
trigger_event = data.get("trigger")
parsed_entities: list[dict[str, Any]] = []
for entity in entities:
tasks = entity.get("tasks", [])
completed = sum(1 for t in tasks if t.get("output") is not None)
info: dict[str, Any] = {
"type": entity.get("entity_type", "unknown"),
"name": entity.get("name"),
"id": entity.get("id"),
}
raw_agents = entity.get("agents", [])
agents_by_id: dict[str, dict[str, Any]] = {}
parsed_agents: list[dict[str, Any]] = []
for ag in raw_agents:
agent_info: dict[str, Any] = {
"id": ag.get("id", ""),
"role": ag.get("role", ""),
"goal": ag.get("goal", ""),
}
parsed_agents.append(agent_info)
if ag.get("id"):
agents_by_id[str(ag["id"])] = agent_info
if parsed_agents:
info["agents"] = parsed_agents
if tasks:
info["tasks_completed"] = completed
info["tasks_total"] = len(tasks)
parsed_tasks: list[dict[str, Any]] = []
for t in tasks:
task_info: dict[str, Any] = {
"description": t.get("description", ""),
"completed": t.get("output") is not None,
"output": (t.get("output") or {}).get("raw", ""),
}
task_agent = t.get("agent")
if isinstance(task_agent, dict):
task_info["agent_role"] = task_agent.get("role", "")
task_info["agent_id"] = task_agent.get("id", "")
elif isinstance(task_agent, str) and task_agent in agents_by_id:
task_info["agent_role"] = agents_by_id[task_agent].get("role", "")
task_info["agent_id"] = task_agent
parsed_tasks.append(task_info)
info["tasks"] = parsed_tasks
if entity.get("entity_type") == "flow":
completed_methods = entity.get("checkpoint_completed_methods")
if completed_methods:
info["completed_methods"] = sorted(completed_methods)
state = entity.get("checkpoint_state")
if isinstance(state, dict):
info["flow_state"] = state
parsed_entities.append(info)
inputs: dict[str, Any] = {}
for entity in entities:
cp_inputs = entity.get("checkpoint_inputs")
if isinstance(cp_inputs, dict) and cp_inputs:
inputs = dict(cp_inputs)
break
for entity in entities:
for task in entity.get("tasks", []):
for field in (
"checkpoint_original_description",
"checkpoint_original_expected_output",
):
text = task.get(field) or ""
for match in _PLACEHOLDER_RE.findall(text):
if match not in inputs:
inputs[match] = ""
for agent in entity.get("agents", []):
for field in ("role", "goal", "backstory"):
text = agent.get(field) or ""
for match in _PLACEHOLDER_RE.findall(text):
if match not in inputs:
inputs[match] = ""
branch = data.get("branch", "main")
parent_id = data.get("parent_id")
return {
"source": source,
"event_count": event_count,
"trigger": trigger_event,
"entities": parsed_entities,
"branch": branch,
"parent_id": parent_id,
"inputs": inputs,
}
def _format_size(size: int) -> str:
if size < 1024:
return f"{size}B"
if size < 1024 * 1024:
return f"{size / 1024:.1f}KB"
return f"{size / 1024 / 1024:.1f}MB"
def _ts_from_name(name: str) -> str | None:
"""Extract timestamp from checkpoint ID or filename."""
stem = os.path.basename(name).split("_")[0].removesuffix(".json")
try:
dt = datetime.strptime(stem, "%Y%m%dT%H%M%S")
except ValueError:
return None
return dt.strftime("%Y-%m-%d %H:%M:%S")
def _entity_summary(entities: list[dict[str, Any]]) -> str:
parts = []
for ent in entities:
etype = ent.get("type", "unknown")
ename = ent.get("name", "")
completed = ent.get("tasks_completed")
total = ent.get("tasks_total")
if completed is not None and total is not None:
parts.append(f"{etype}:{ename} [{completed}/{total} tasks]")
else:
parts.append(f"{etype}:{ename}")
return ", ".join(parts) if parts else "empty"
# --- JSON directory ---
def _list_json(location: str) -> list[dict[str, Any]]:
    """List metadata for every ``*.json`` checkpoint under ``location``.

    Files are ordered newest-first by mtime.  Unreadable or malformed files
    are still listed with placeholder metadata rather than aborting the scan.
    """
    pattern = os.path.join(location, "**", "*.json")
    results: list[dict[str, Any]] = []
    for path in sorted(
        glob.glob(pattern, recursive=True), key=os.path.getmtime, reverse=True
    ):
        name = os.path.basename(path)
        try:
            with open(path) as f:
                raw = f.read()
            meta = _parse_checkpoint_json(raw, source=name)
            meta["name"] = name
            meta["ts"] = _ts_from_name(name)
            meta["size"] = os.path.getsize(path)
            meta["path"] = path
        except Exception:
            # Best-effort: keep the broken entry visible.  Record "path" here
            # too (the success branch already does) so callers can locate the
            # corrupt file.
            meta = {
                "name": name,
                "ts": None,
                "size": 0,
                "entities": [],
                "source": name,
                "path": path,
            }
        results.append(meta)
    return results
def _info_json_latest(location: str) -> dict[str, Any] | None:
    """Load metadata for the most recently modified JSON checkpoint, or None."""
    candidates = glob.glob(
        os.path.join(location, "**", "*.json"), recursive=True
    )
    if not candidates:
        return None
    # Newest by mtime; delegate the actual load/parse to the single-file path.
    newest = max(candidates, key=os.path.getmtime)
    return _info_json_file(newest)
def _info_json_file(path: str) -> dict[str, Any]:
    """Load and summarize a single JSON checkpoint file."""
    with open(path) as fh:
        payload = fh.read()
    base = os.path.basename(path)
    meta = _parse_checkpoint_json(payload, source=base)
    meta.update(
        name=base,
        ts=_ts_from_name(path),
        size=os.path.getsize(path),
        path=path,
    )
    return meta
# --- SQLite ---
def _list_sqlite(db_path: str) -> list[dict[str, Any]]:
    """Summarize every checkpoint row in a SQLite store, newest first."""
    entries: list[dict[str, Any]] = []
    with sqlite3.connect(db_path) as conn:
        for checkpoint_id, created_at, raw in conn.execute(_SELECT_ALL):
            try:
                entry = _parse_checkpoint_json(raw, source=checkpoint_id)
                entry["name"] = checkpoint_id
                # Prefer the timestamp encoded in the ID; fall back to the column.
                entry["ts"] = _ts_from_name(checkpoint_id) or created_at
            except Exception:
                # Malformed rows stay listed with placeholder metadata.
                entry = {
                    "name": checkpoint_id,
                    "ts": created_at,
                    "entities": [],
                    "source": checkpoint_id,
                }
            entry["db"] = db_path
            entries.append(entry)
    return entries
def _info_sqlite_latest(db_path: str) -> dict[str, Any] | None:
    """Metadata for the newest checkpoint row, or None if the table is empty."""
    with sqlite3.connect(db_path) as conn:
        row = conn.execute(_SELECT_LATEST).fetchone()
    if row is None:
        return None
    checkpoint_id, created_at, raw = row
    meta = _parse_checkpoint_json(raw, source=checkpoint_id)
    meta["name"] = checkpoint_id
    meta["ts"] = _ts_from_name(checkpoint_id) or created_at
    meta["db"] = db_path
    return meta
def _info_sqlite_id(db_path: str, checkpoint_id: str) -> dict[str, Any] | None:
    """Look up a checkpoint row by exact ID, falling back to substring match."""
    with sqlite3.connect(db_path) as conn:
        row = conn.execute(_SELECT_ONE, (checkpoint_id,)).fetchone()
        if row is None:
            # Allow short/partial IDs: newest row whose ID contains the query.
            row = conn.execute(_SELECT_LIKE, (f"%{checkpoint_id}%",)).fetchone()
    if row is None:
        return None
    cid, created_at, raw = row
    meta = _parse_checkpoint_json(raw, source=cid)
    meta["name"] = cid
    meta["ts"] = _ts_from_name(cid) or created_at
    meta["db"] = db_path
    return meta
# --- Public API ---
def list_checkpoints(location: str) -> None:
    """List all checkpoints at a location."""
    if _is_sqlite(location):
        label = f"SQLite: {location}"
        entries = _list_sqlite(location)
    elif os.path.isdir(location):
        label = location
        entries = _list_json(location)
    else:
        click.echo(f"Not a directory or SQLite database: {location}")
        return
    if not entries:
        click.echo(f"No checkpoints found in {label}")
        return
    click.echo(f"Found {len(entries)} checkpoint(s) in {label}\n")
    for entry in entries:
        # Columns: name, timestamp, then optional size/trigger, then summary.
        columns = [entry.get("name", ""), entry.get("ts") or "unknown"]
        if "size" in entry:
            columns.append(_format_size(entry["size"]))
        if entry.get("trigger"):
            columns.append(entry["trigger"])
        columns.append(_entity_summary(entry.get("entities", [])))
        click.echo(f" {' '.join(columns)}")
def info_checkpoint(path: str) -> None:
    """Show details of a single checkpoint."""
    meta: dict[str, Any] | None = None
    # "db_path#checkpoint_id" addresses one row inside a SQLite store.
    if "#" in path:
        db_path, checkpoint_id = path.rsplit("#", 1)
        if _is_sqlite(db_path):
            meta = _info_sqlite_id(db_path, checkpoint_id)
            if not meta:
                click.echo(f"Checkpoint not found: {checkpoint_id}")
                return
    if meta is None:
        if _is_sqlite(path):
            # Bare SQLite path: show its most recent checkpoint.
            meta = _info_sqlite_latest(path)
            if not meta:
                click.echo(f"No checkpoints in database: {path}")
                return
            click.echo(f"Latest checkpoint: {meta['name']}\n")
        elif os.path.isdir(path):
            # Directory of JSON checkpoints: show the newest one.
            meta = _info_json_latest(path)
            if not meta:
                click.echo(f"No checkpoints found in {path}")
                return
            click.echo(f"Latest checkpoint: {meta['name']}\n")
        elif os.path.isfile(path):
            # Specific JSON checkpoint file.
            try:
                meta = _info_json_file(path)
            except Exception as exc:
                click.echo(f"Failed to read checkpoint: {exc}")
                return
    if meta is None:
        click.echo(f"Not found: {path}")
        return
    _print_info(meta)
def _print_info(meta: dict[str, Any]) -> None:
    """Pretty-print one checkpoint's metadata: header fields, then entities."""
    source = meta.get("path") or meta.get("db") or meta.get("source", "")
    click.echo(f"Source: {source}")
    click.echo(f"Name: {meta.get('name', '')}")
    click.echo(f"Time: {meta.get('ts') or 'unknown'}")
    if "size" in meta:
        click.echo(f"Size: {_format_size(meta['size'])}")
    click.echo(f"Events: {meta.get('event_count', 0)}")
    if meta.get("trigger"):
        click.echo(f"Trigger: {meta['trigger']}")
    click.echo(f"Branch: {meta.get('branch', 'main')}")
    if meta.get("parent_id"):
        click.echo(f"Parent: {meta['parent_id']}")
    for ent in meta.get("entities", []):
        eid = str(ent.get("id", ""))[:8]
        click.echo(f"\n {ent['type']}: {ent.get('name', 'unnamed')} ({eid}...)")
        tasks = ent.get("tasks")
        if not isinstance(tasks, list):
            continue
        click.echo(
            f" Tasks: {ent['tasks_completed']}/{ent['tasks_total']} completed"
        )
        for idx, task in enumerate(tasks, start=1):
            marker = "done" if task.get("completed") else "pending"
            desc = str(task.get("description", ""))
            # Truncate long descriptions to keep the listing on one line.
            if len(desc) > 70:
                desc = desc[:67] + "..."
            click.echo(f" {idx}. [{marker}] {desc}")
def _resolve_checkpoint(
    location: str, checkpoint_id: str | None
) -> dict[str, Any] | None:
    """Resolve ``checkpoint_id`` (or the latest checkpoint) at ``location``.

    Handles SQLite stores, JSON directories and direct JSON file paths;
    returns parsed metadata or None when nothing matches.
    """
    if _is_sqlite(location):
        if checkpoint_id:
            return _info_sqlite_id(location, checkpoint_id)
        return _info_sqlite_latest(location)
    if os.path.isdir(location):
        if not checkpoint_id:
            return _info_json_latest(location)
        from crewai.state.provider.json_provider import JsonProvider

        provider: JsonProvider = JsonProvider()
        all_files: list[str] = glob.glob(
            os.path.join(location, "**", "*.json"), recursive=True
        )
        candidates: list[str] = [
            f for f in all_files if checkpoint_id in provider.extract_id(f)
        ]
        if not candidates:
            return None
        # Prefer the newest file when several IDs contain the query string.
        return _info_json_file(max(candidates, key=os.path.getmtime))
    if os.path.isfile(location):
        return _info_json_file(location)
    return None
def _entity_type_from_meta(meta: dict[str, Any]) -> str:
for ent in meta.get("entities", []):
if ent.get("type") == "flow":
return "flow"
if ent.get("type") == "agent":
return "agent"
return "crew"
def resume_checkpoint(location: str, checkpoint_id: str | None) -> None:
    """Resume execution from a checkpoint at ``location``.

    Resolves the checkpoint, prints its summary, then reconstructs the
    matching entity (flow, agent or crew) and kicks it off again with the
    recovered inputs.
    """
    import asyncio

    meta: dict[str, Any] | None = _resolve_checkpoint(location, checkpoint_id)
    if meta is None:
        message = (
            f"Checkpoint not found: {checkpoint_id}"
            if checkpoint_id
            else f"No checkpoints found in {location}"
        )
        click.echo(message)
        return
    restore_path: str = meta.get("path") or meta.get("source", "")
    if meta.get("db"):
        # SQLite rows are addressed as "db_path#checkpoint_id".
        restore_path = f"{meta['db']}#{meta['name']}"
    click.echo(f"Resuming from: {meta.get('name', restore_path)}")
    _print_info(meta)
    click.echo()
    from crewai.state.checkpoint_config import CheckpointConfig

    config: CheckpointConfig = CheckpointConfig(restore_from=restore_path)
    inputs: dict[str, Any] | None = meta.get("inputs") or None
    kind: str = _entity_type_from_meta(meta)
    if kind == "flow":
        from crewai.flow.flow import Flow

        result = asyncio.run(Flow.from_checkpoint(config).kickoff_async(inputs=inputs))
    elif kind == "agent":
        from crewai.agent import Agent

        result = asyncio.run(
            Agent.from_checkpoint(config).akickoff(messages="Resume execution.")
        )
    else:
        from crewai.crew import Crew

        result = asyncio.run(Crew.from_checkpoint(config).akickoff(inputs=inputs))
    click.echo(f"\nResult: {getattr(result, 'raw', result)}")
def _task_list_from_meta(meta: dict[str, Any]) -> list[dict[str, Any]]:
tasks: list[dict[str, Any]] = []
for ent in meta.get("entities", []):
tasks.extend(
{
"entity": ent.get("name", "unnamed"),
"description": t.get("description", ""),
"completed": t.get("completed", False),
"output": t.get("output", ""),
}
for t in ent.get("tasks", [])
)
return tasks
def diff_checkpoints(location: str, id1: str, id2: str) -> None:
    """Print a unified-diff-style comparison of two checkpoints.

    Compares scalar header fields (time, branch, trigger, event count),
    recovered kickoff inputs, and per-task status/output.  ``id1``/``id2``
    are resolved against ``location`` via ``_resolve_checkpoint``.
    """
    meta1: dict[str, Any] | None = _resolve_checkpoint(location, id1)
    meta2: dict[str, Any] | None = _resolve_checkpoint(location, id2)
    if meta1 is None:
        click.echo(f"Checkpoint not found: {id1}")
        return
    if meta2 is None:
        click.echo(f"Checkpoint not found: {id2}")
        return
    name1: str = meta1.get("name", id1)
    name2: str = meta2.get("name", id2)
    # Diff-style header: "---" is the first checkpoint, "+++" the second.
    click.echo(f"--- {name1}")
    click.echo(f"+++ {name2}")
    click.echo()
    # Scalar fields: print a -/+ pair only where the values differ.
    fields: list[tuple[str, str]] = [
        ("Time", "ts"),
        ("Branch", "branch"),
        ("Trigger", "trigger"),
        ("Events", "event_count"),
    ]
    for label, key in fields:
        v1: str = str(meta1.get(key, ""))
        v2: str = str(meta2.get(key, ""))
        if v1 != v2:
            click.echo(f" {label}:")
            click.echo(f" - {v1}")
            click.echo(f" + {v2}")
    # Kickoff inputs: union of keys from both sides, showing changed values.
    inputs1: dict[str, Any] = meta1.get("inputs", {})
    inputs2: dict[str, Any] = meta2.get("inputs", {})
    all_keys: list[str] = sorted(set(list(inputs1.keys()) + list(inputs2.keys())))
    changed_inputs: list[tuple[str, Any, Any]] = [
        (k, inputs1.get(k, ""), inputs2.get(k, ""))
        for k in all_keys
        if inputs1.get(k) != inputs2.get(k)
    ]
    if changed_inputs:
        click.echo("\n Inputs:")
        for key, v1, v2 in changed_inputs:
            click.echo(f" {key}:")
            click.echo(f" - {v1}")
            click.echo(f" + {v2}")
    # Tasks are compared positionally (index i on each side).
    tasks1: list[dict[str, Any]] = _task_list_from_meta(meta1)
    tasks2: list[dict[str, Any]] = _task_list_from_meta(meta2)
    max_tasks: int = max(len(tasks1), len(tasks2))
    if max_tasks == 0:
        return
    click.echo("\n Tasks:")
    for i in range(max_tasks):
        t1: dict[str, Any] | None = tasks1[i] if i < len(tasks1) else None
        t2: dict[str, Any] | None = tasks2[i] if i < len(tasks2) else None
        if t1 is None:
            # Task exists only in the second checkpoint.
            desc: str = t2["description"][:60] if t2 else ""
            click.echo(f" + {i + 1}. [new] {desc}")
            continue
        if t2 is None:
            # Task exists only in the first checkpoint.
            desc = t1["description"][:60]
            click.echo(f" - {i + 1}. [removed] {desc}")
            continue
        desc = str(t1["description"][:60])
        s1: str = "done" if t1["completed"] else "pending"
        s2: str = "done" if t2["completed"] else "pending"
        if s1 != s2:
            click.echo(f" {i + 1}. {desc}")
            click.echo(f" status: {s1} -> {s2}")
        out1: str = (t1.get("output") or "").strip()
        out2: str = (t2.get("output") or "").strip()
        if out1 != out2:
            # Only print the task header once if the status diff didn't already.
            if s1 == s2:
                click.echo(f" {i + 1}. {desc}")
            preview1: str = (
                out1[:80] + ("..." if len(out1) > 80 else "") if out1 else "(empty)"
            )
            preview2: str = (
                out2[:80] + ("..." if len(out2) > 80 else "") if out2 else "(empty)"
            )
            click.echo(" output:")
            click.echo(f" - {preview1}")
            click.echo(f" + {preview2}")
def _parse_duration(value: str) -> timedelta:
match: re.Match[str] | None = re.match(r"^(\d+)([dhm])$", value.strip())
if not match:
raise click.BadParameter(
f"Invalid duration: {value!r}. Use format like '7d', '24h', or '30m'."
)
amount: int = int(match.group(1))
unit: str = match.group(2)
if unit == "d":
return timedelta(days=amount)
if unit == "h":
return timedelta(hours=amount)
return timedelta(minutes=amount)
def _prune_json(location: str, keep: int | None, older_than: timedelta | None) -> int:
pattern: str = os.path.join(location, "**", "*.json")
files: list[str] = sorted(
glob.glob(pattern, recursive=True), key=os.path.getmtime, reverse=True
)
if not files:
return 0
to_delete: set[str] = set()
if keep is not None and len(files) > keep:
to_delete.update(files[keep:])
if older_than is not None:
cutoff: datetime = datetime.now(timezone.utc) - older_than
for path in files:
mtime: datetime = datetime.fromtimestamp(
os.path.getmtime(path), tz=timezone.utc
)
if mtime < cutoff:
to_delete.add(path)
deleted: int = 0
for path in to_delete:
try:
os.remove(path)
deleted += 1
except OSError: # noqa: PERF203
pass
for dirpath, dirnames, filenames in os.walk(location, topdown=False):
if dirpath != location and not filenames and not dirnames:
try:
os.rmdir(dirpath)
except OSError:
pass
return deleted
def _prune_sqlite(db_path: str, keep: int | None, older_than: timedelta | None) -> int:
    """Delete checkpoint rows from a SQLite store by retention policy.

    Args:
        db_path: Path to the checkpoint database.
        keep: If given, retain only the N newest rows (by rowid).
        older_than: If given, also delete rows whose textual ``created_at`` is
            before now - older_than (same ``%Y%m%dT%H%M%S`` stamp format).

    Returns:
        Number of rows deleted.
    """
    deleted: int = 0
    conn = sqlite3.connect(db_path)
    try:
        if older_than is not None:
            cutoff: str = (datetime.now(timezone.utc) - older_than).strftime(
                "%Y%m%dT%H%M%S"
            )
            deleted += conn.execute(_DELETE_OLDER_THAN, (cutoff,)).rowcount
        if keep is not None:
            deleted += conn.execute(_DELETE_KEEP_N, (keep,)).rowcount
        conn.commit()
    finally:
        # sqlite3's "with connection" only manages the transaction, not the
        # handle; close explicitly so the CLI doesn't leak the file handle.
        conn.close()
    return deleted
def prune_checkpoints(
    location: str, keep: int | None, older_than: str | None, dry_run: bool = False
) -> None:
    """Delete old checkpoints at ``location`` by count and/or age.

    At least one of ``keep``/``older_than`` must be given.  With ``dry_run``
    only the candidate count is reported and nothing is removed.
    """
    if keep is None and older_than is None:
        click.echo("Specify --keep N and/or --older-than DURATION (e.g. 7d, 24h)")
        return
    duration: timedelta | None = _parse_duration(older_than) if older_than else None
    if _is_sqlite(location):
        if dry_run:
            with sqlite3.connect(location) as conn:
                total: int = conn.execute(_COUNT_CHECKPOINTS).fetchone()[0]
            click.echo(f"Would prune from {total} checkpoint(s) in {location}")
            return
        deleted = _prune_sqlite(location, keep, duration)
    elif os.path.isdir(location):
        if dry_run:
            count = len(
                glob.glob(os.path.join(location, "**", "*.json"), recursive=True)
            )
            click.echo(f"Would prune from {count} checkpoint(s) in {location}")
            return
        deleted = _prune_json(location, keep, duration)
    else:
        click.echo(f"Not a directory or SQLite database: {location}")
        return
    click.echo(f"Pruned {deleted} checkpoint(s) from {location}")