diff --git a/docs/en/concepts/checkpointing.mdx b/docs/en/concepts/checkpointing.mdx
index 21ed13905..d6430eb6f 100644
--- a/docs/en/concepts/checkpointing.mdx
+++ b/docs/en/concepts/checkpointing.mdx
@@ -54,6 +54,7 @@ crew = Crew(
| `on_events` | `list[str]` | `["task_completed"]` | Event types that trigger a checkpoint |
| `provider` | `BaseProvider` | `JsonProvider()` | Storage backend |
| `max_checkpoints` | `int \| None` | `None` | Max checkpoints to keep. Oldest are pruned after each write. Pruning is handled by the provider. |
+| `restore_from` | `Path \| str \| None` | `None` | Path to a checkpoint to restore from. Used when passing config via a kickoff method's `from_checkpoint` parameter. |
### Inheritance and Opt-Out
@@ -79,13 +80,42 @@ crew = Crew(
## Resuming from a Checkpoint
+Pass a `CheckpointConfig` with `restore_from` to any kickoff method. The crew restores from that checkpoint, skips completed tasks, and resumes.
+
```python
-# Restore and resume
-crew = Crew.from_checkpoint("./my_checkpoints/20260407T120000_abc123.json")
-result = crew.kickoff() # picks up from last completed task
+from crewai import Crew, CheckpointConfig
+
+crew = Crew(agents=[...], tasks=[...])
+result = crew.kickoff(
+ from_checkpoint=CheckpointConfig(
+ restore_from="./my_checkpoints/20260407T120000_abc123.json",
+ ),
+)
```
-The restored crew skips already-completed tasks and resumes from the first incomplete one.
+Remaining `CheckpointConfig` fields apply to the new run, so checkpointing continues after the restore.
+
+You can also use the classmethod directly:
+
+```python
+config = CheckpointConfig(restore_from="./my_checkpoints/20260407T120000_abc123.json")
+crew = Crew.from_checkpoint(config)
+result = crew.kickoff()
+```
+
+## Forking from a Checkpoint
+
+`fork()` restores a checkpoint and starts a new execution branch. This is useful for exploring alternative paths from the same point.
+
+```python
+from crewai import Crew, CheckpointConfig
+
+config = CheckpointConfig(restore_from="./my_checkpoints/20260407T120000_abc123.json")
+crew = Crew.fork(config, branch="experiment-a")
+result = crew.kickoff(inputs={"strategy": "aggressive"})
+```
+
+Each fork gets a unique lineage ID so checkpoints from different branches don't collide. The `branch` label is optional and auto-generated if omitted.
## Works on Crew, Flow, and Agent
@@ -125,7 +155,8 @@ flow = MyFlow(
result = flow.kickoff()
# Resume
-flow = MyFlow.from_checkpoint("./flow_cp/20260407T120000_abc123.json")
+config = CheckpointConfig(restore_from="./flow_cp/20260407T120000_abc123.json")
+flow = MyFlow.from_checkpoint(config)
result = flow.kickoff()
```
@@ -231,3 +262,44 @@ async def on_llm_done_async(source, event, state):
The `state` argument is the `RuntimeState` passed automatically by the event bus when your handler accepts 3 parameters. You can register handlers on any event type listed in the [Event Listeners](/en/concepts/event-listener) documentation.
Checkpointing is best-effort: if a checkpoint write fails, the error is logged but execution continues uninterrupted.
+
+## CLI
+
+The `crewai checkpoint` command gives you a TUI for browsing, inspecting, resuming, and forking checkpoints. It auto-detects whether your checkpoints are JSON files or a SQLite database.
+
+```bash
+# Launch the TUI — auto-detects .checkpoints/ or .checkpoints.db
+crewai checkpoint
+
+# Point at a specific location
+crewai checkpoint --location ./my_checkpoints
+crewai checkpoint --location ./.checkpoints.db
+```
+
+
+
+
+
+The left panel is a tree view. Checkpoints are grouped by branch, and forks nest under the checkpoint they diverged from. Select a checkpoint to see its metadata, entity state, and task progress in the detail panel. Hit **Resume** to pick up where it left off, or **Fork** to start a new branch from that point.
+
+### Editing inputs and task outputs
+
+When a checkpoint is selected, the detail panel shows:
+
+- **Inputs** — if the original kickoff had inputs (e.g. `{topic}`), they appear as editable fields pre-filled with the original values. Change them before resuming or forking.
+- **Task outputs** — completed tasks show their output in editable text areas. Edit a task's output to change the context that downstream tasks receive. When you modify a task output and hit Fork, all subsequent tasks are invalidated and re-run with the new context.
+
+This is useful for "what if" exploration — fork from a checkpoint, tweak a task's result, and see how it changes downstream behavior.
+
+### Subcommands
+
+```bash
+# List all checkpoints
+crewai checkpoint list ./my_checkpoints
+
+# Inspect a specific checkpoint
+crewai checkpoint info ./my_checkpoints/20260407T120000_abc123.json
+
+# Inspect the latest checkpoint in a SQLite database
+crewai checkpoint info ./.checkpoints.db
+```
diff --git a/docs/images/checkpointing.png b/docs/images/checkpointing.png
new file mode 100644
index 000000000..de1f4776a
Binary files /dev/null and b/docs/images/checkpointing.png differ
diff --git a/lib/crewai/src/crewai/cli/checkpoint_cli.py b/lib/crewai/src/crewai/cli/checkpoint_cli.py
index c61500b20..fa6e003aa 100644
--- a/lib/crewai/src/crewai/cli/checkpoint_cli.py
+++ b/lib/crewai/src/crewai/cli/checkpoint_cli.py
@@ -6,12 +6,16 @@ from datetime import datetime
import glob
import json
import os
+import re
import sqlite3
from typing import Any
import click
+_PLACEHOLDER_RE = re.compile(r"\{([A-Za-z_][A-Za-z0-9_\-]*)}")
+
+
_SQLITE_MAGIC = b"SQLite format 3\x00"
_SELECT_ALL = """
@@ -34,6 +38,25 @@ LIMIT 1
"""
+_DEFAULT_DIR = "./.checkpoints"
+_DEFAULT_DB = "./.checkpoints.db"
+
+
+def _detect_location(location: str) -> str:
+ """Resolve the default checkpoint location.
+
+ When the caller passes the default directory path and that directory does
+ not exist, fall back to the conventional ``.db`` SQLite path if it exists.
+ """
+ if (
+ location == _DEFAULT_DIR
+ and not os.path.exists(_DEFAULT_DIR)
+ and os.path.exists(_DEFAULT_DB)
+ ):
+ return _DEFAULT_DB
+ return location
+
+
def _is_sqlite(path: str) -> bool:
"""Check if a file is a SQLite database by reading its magic bytes."""
if not os.path.isfile(path):
@@ -52,13 +75,7 @@ def _parse_checkpoint_json(raw: str, source: str) -> dict[str, Any]:
nodes = data.get("event_record", {}).get("nodes", {})
event_count = len(nodes)
- trigger_event = None
- if nodes:
- last_node = max(
- nodes.values(),
- key=lambda n: n.get("event", {}).get("emission_sequence") or 0,
- )
- trigger_event = last_node.get("event", {}).get("type")
+ trigger_event = data.get("trigger")
parsed_entities: list[dict[str, Any]] = []
for entity in entities:
@@ -76,16 +93,47 @@ def _parse_checkpoint_json(raw: str, source: str) -> dict[str, Any]:
{
"description": t.get("description", ""),
"completed": t.get("output") is not None,
+ "output": (t.get("output") or {}).get("raw", ""),
}
for t in tasks
]
parsed_entities.append(info)
+ inputs: dict[str, Any] = {}
+ for entity in entities:
+ cp_inputs = entity.get("checkpoint_inputs")
+ if isinstance(cp_inputs, dict) and cp_inputs:
+ inputs = dict(cp_inputs)
+ break
+
+ for entity in entities:
+ for task in entity.get("tasks", []):
+ for field in (
+ "checkpoint_original_description",
+ "checkpoint_original_expected_output",
+ ):
+ text = task.get(field) or ""
+ for match in _PLACEHOLDER_RE.findall(text):
+ if match not in inputs:
+ inputs[match] = ""
+ for agent in entity.get("agents", []):
+ for field in ("role", "goal", "backstory"):
+ text = agent.get(field) or ""
+ for match in _PLACEHOLDER_RE.findall(text):
+ if match not in inputs:
+ inputs[match] = ""
+
+ branch = data.get("branch", "main")
+ parent_id = data.get("parent_id")
+
return {
"source": source,
"event_count": event_count,
"trigger": trigger_event,
"entities": parsed_entities,
+ "branch": branch,
+ "parent_id": parent_id,
+ "inputs": inputs,
}
@@ -189,6 +237,7 @@ def _list_sqlite(db_path: str) -> list[dict[str, Any]]:
"entities": [],
"source": checkpoint_id,
}
+ meta["db"] = db_path
results.append(meta)
return results
@@ -311,6 +360,10 @@ def _print_info(meta: dict[str, Any]) -> None:
trigger = meta.get("trigger")
if trigger:
click.echo(f"Trigger: {trigger}")
+ click.echo(f"Branch: {meta.get('branch', 'main')}")
+ parent_id = meta.get("parent_id")
+ if parent_id:
+ click.echo(f"Parent: {parent_id}")
for ent in meta.get("entities", []):
eid = str(ent.get("id", ""))[:8]
diff --git a/lib/crewai/src/crewai/cli/checkpoint_tui.py b/lib/crewai/src/crewai/cli/checkpoint_tui.py
index c58ed2ea6..e0d10f813 100644
--- a/lib/crewai/src/crewai/cli/checkpoint_tui.py
+++ b/lib/crewai/src/crewai/cli/checkpoint_tui.py
@@ -2,17 +2,23 @@
from __future__ import annotations
+from collections import defaultdict
from typing import Any, ClassVar
from textual.app import App, ComposeResult
from textual.binding import Binding
-from textual.containers import Horizontal, Vertical
-from textual.screen import ModalScreen
-from textual.widgets import Button, Footer, Header, OptionList, Static
-from textual.widgets.option_list import Option
+from textual.containers import Horizontal, Vertical, VerticalScroll
+from textual.widgets import (
+ Button,
+ Footer,
+ Header,
+ Input,
+ Static,
+ TextArea,
+ Tree,
+)
from crewai.cli.checkpoint_cli import (
- _entity_summary,
_format_size,
_is_sqlite,
_list_json,
@@ -34,151 +40,54 @@ def _load_entries(location: str) -> list[dict[str, Any]]:
return _list_json(location)
-def _format_list_label(entry: dict[str, Any]) -> str:
- """Format a checkpoint entry for the list panel."""
- name = entry.get("name", "")
- ts = entry.get("ts") or ""
- trigger = entry.get("trigger") or ""
- summary = _entity_summary(entry.get("entities", []))
-
- line1 = f"[bold]{name}[/]"
- parts = []
- if ts:
- parts.append(f"[dim]{ts}[/]")
- if "size" in entry:
- parts.append(f"[dim]{_format_size(entry['size'])}[/]")
- if trigger:
- parts.append(f"[{_PRIMARY}]{trigger}[/]")
- line2 = " ".join(parts)
- line3 = f" [{_DIM}]{summary}[/]"
-
- return f"{line1}\n{line2}\n{line3}"
+def _short_id(name: str) -> str:
+ """Shorten a checkpoint name for tree display."""
+ if len(name) > 30:
+ return name[:27] + "..."
+ return name
-def _format_detail(entry: dict[str, Any]) -> str:
- """Format checkpoint details for the right panel."""
+def _entry_id(entry: dict[str, Any]) -> str:
+ """Normalize an entry's name into its checkpoint ID.
+
+ JSON filenames are ``{ts}_{uuid}_p-{parent}.json``; SQLite IDs
+ are already ``{ts}_{uuid}``. This strips the JSON suffix so
+ fork-parent lookups work in both providers.
+ """
+ name = str(entry.get("name", ""))
+ if name.endswith(".json"):
+ name = name[: -len(".json")]
+ idx = name.find("_p-")
+ if idx != -1:
+ name = name[:idx]
+ return name
+
+
+def _build_entity_header(ent: dict[str, Any]) -> str:
+ """Build rich text header for an entity (progress bar only)."""
lines: list[str] = []
-
- # Header
- name = entry.get("name", "")
- lines.append(f"[bold {_PRIMARY}]{name}[/]")
- lines.append(f"[{_DIM}]{'─' * 50}[/]")
- lines.append("")
-
- # Metadata table
- ts = entry.get("ts") or "unknown"
- trigger = entry.get("trigger") or ""
- lines.append(f" [bold]Time[/] {ts}")
- if "size" in entry:
- lines.append(f" [bold]Size[/] {_format_size(entry['size'])}")
- lines.append(f" [bold]Events[/] {entry.get('event_count', 0)}")
- if trigger:
- lines.append(f" [bold]Trigger[/] [{_PRIMARY}]{trigger}[/]")
- if "path" in entry:
- lines.append(f" [bold]Path[/] [{_DIM}]{entry['path']}[/]")
- if "db" in entry:
- lines.append(f" [bold]Database[/] [{_DIM}]{entry['db']}[/]")
-
- # Entities
- for ent in entry.get("entities", []):
- eid = str(ent.get("id", ""))[:8]
- etype = ent.get("type", "unknown")
- ename = ent.get("name", "unnamed")
-
- lines.append("")
- lines.append(f" [{_DIM}]{'─' * 50}[/]")
- lines.append(f" [bold {_SECONDARY}]{etype}[/]: {ename} [{_DIM}]{eid}[/]")
-
- tasks = ent.get("tasks")
- if isinstance(tasks, list):
- completed = ent.get("tasks_completed", 0)
- total = ent.get("tasks_total", 0)
- pct = int(completed / total * 100) if total else 0
- bar_len = 20
- filled = int(bar_len * completed / total) if total else 0
- bar = f"[{_PRIMARY}]{'█' * filled}[/][{_DIM}]{'░' * (bar_len - filled)}[/]"
-
- lines.append(f" {bar} {completed}/{total} tasks ({pct}%)")
- lines.append("")
-
- for i, task in enumerate(tasks):
- if task.get("completed"):
- icon = "[green]✓[/]"
- else:
- icon = "[yellow]○[/]"
- desc = str(task.get("description", ""))
- if len(desc) > 55:
- desc = desc[:52] + "..."
- lines.append(f" {icon} {i + 1}. {desc}")
-
+ tasks = ent.get("tasks")
+ if isinstance(tasks, list):
+ completed = ent.get("tasks_completed", 0)
+ total = ent.get("tasks_total", 0)
+ pct = int(completed / total * 100) if total else 0
+ bar_len = 20
+ filled = int(bar_len * completed / total) if total else 0
+ bar = f"[{_PRIMARY}]{'█' * filled}[/][{_DIM}]{'░' * (bar_len - filled)}[/]"
+ lines.append(f"{bar} {completed}/{total} tasks ({pct}%)")
return "\n".join(lines)
-class ConfirmResumeScreen(ModalScreen[bool]):
- """Modal confirmation before resuming from a checkpoint."""
-
- CSS = f"""
- ConfirmResumeScreen {{
- align: center middle;
- }}
- #confirm-dialog {{
- width: 60;
- height: auto;
- padding: 1 2;
- background: {_BG_PANEL};
- border: round {_PRIMARY};
- }}
- #confirm-label {{
- width: 100%;
- content-align: center middle;
- margin-bottom: 1;
- }}
- #confirm-name {{
- width: 100%;
- content-align: center middle;
- color: {_PRIMARY};
- text-style: bold;
- margin-bottom: 1;
- }}
- #confirm-buttons {{
- width: 100%;
- height: 3;
- layout: horizontal;
- align: center middle;
- }}
- Button {{
- margin: 0 2;
- min-width: 12;
- }}
- """
-
- def __init__(self, checkpoint_name: str) -> None:
- super().__init__()
- self._checkpoint_name = checkpoint_name
-
- def compose(self) -> ComposeResult:
- with Vertical(id="confirm-dialog"):
- yield Static("Resume from this checkpoint?", id="confirm-label")
- yield Static(self._checkpoint_name, id="confirm-name")
- with Horizontal(id="confirm-buttons"):
- yield Button("Resume", variant="success", id="btn-yes")
- yield Button("Cancel", variant="default", id="btn-no")
-
- def on_button_pressed(self, event: Button.Pressed) -> None:
- self.dismiss(event.button.id == "btn-yes")
-
- def on_key(self, event: Any) -> None:
- if event.key == "y":
- self.dismiss(True)
- elif event.key in ("n", "escape"):
- self.dismiss(False)
+# Return type: (location, action, inputs, task_output_overrides)
+_TuiResult = tuple[str, str, dict[str, Any] | None, dict[int, str] | None] | None
-class CheckpointTUI(App[str | None]):
+class CheckpointTUI(App[_TuiResult]):
"""TUI to browse and inspect checkpoints.
- Returns the checkpoint location string to resume from, or None if
- the user quit without selecting.
+ Returns ``(location, action, inputs, task_overrides)`` where action is
+ ``"resume"`` or ``"fork"``, and inputs and task_overrides are each a dict
+ or ``None``; returns ``None`` if the user quit without selecting.
"""
TITLE = "CrewAI Checkpoints"
@@ -199,145 +108,431 @@ class CheckpointTUI(App[str | None]):
background: {_PRIMARY};
color: {_TERTIARY};
}}
- Horizontal {{
+ #main-layout {{
height: 1fr;
}}
- #cp-list {{
- width: 38%;
+ #tree-panel {{
+ width: 45%;
background: {_BG_PANEL};
border: round {_SECONDARY};
padding: 0 1;
scrollbar-color: {_PRIMARY};
}}
- #cp-list:focus {{
+ #tree-panel:focus-within {{
border: round {_PRIMARY};
}}
- #cp-list > .option-list--option-highlighted {{
- background: {_SECONDARY};
- color: {_TERTIARY};
- text-style: none;
- }}
- #cp-list > .option-list--option-highlighted * {{
- color: {_TERTIARY};
- }}
#detail-container {{
- width: 62%;
- padding: 0 1;
+ width: 55%;
+ height: 1fr;
}}
- #detail {{
+ #detail-scroll {{
height: 1fr;
background: {_BG_PANEL};
border: round {_SECONDARY};
padding: 1 2;
- overflow-y: auto;
scrollbar-color: {_PRIMARY};
}}
- #detail:focus {{
+ #detail-scroll:focus-within {{
border: round {_PRIMARY};
}}
+ #detail-header {{
+ margin-bottom: 1;
+ }}
#status {{
height: 1;
padding: 0 2;
color: {_DIM};
}}
+ #inputs-section {{
+ display: none;
+ height: auto;
+ max-height: 8;
+ padding: 0 1;
+ }}
+ #inputs-section.visible {{
+ display: block;
+ }}
+ #inputs-label {{
+ height: 1;
+ color: {_DIM};
+ padding: 0 1;
+ }}
+ .input-row {{
+ height: 3;
+ padding: 0 1;
+ }}
+ .input-row Static {{
+ width: auto;
+ min-width: 12;
+ padding: 1 1 0 0;
+ color: {_TERTIARY};
+ }}
+ .input-row Input {{
+ width: 1fr;
+ }}
+ #no-inputs-label {{
+ height: 1;
+ color: {_DIM};
+ padding: 0 1;
+ }}
+ #action-buttons {{
+ height: 3;
+ align: right middle;
+ padding: 0 1;
+ display: none;
+ }}
+ #action-buttons.visible {{
+ display: block;
+ }}
+ #action-buttons Button {{
+ margin: 0 0 0 1;
+ min-width: 10;
+ }}
+ #btn-resume {{
+ background: {_SECONDARY};
+ color: {_TERTIARY};
+ }}
+ #btn-resume:hover {{
+ background: {_PRIMARY};
+ }}
+ #btn-fork {{
+ background: {_PRIMARY};
+ color: {_TERTIARY};
+ }}
+ #btn-fork:hover {{
+ background: {_SECONDARY};
+ }}
+ .entity-title {{
+ padding: 1 1 0 1;
+ }}
+ .entity-detail {{
+ padding: 0 1;
+ }}
+ .task-output-editor {{
+ height: auto;
+ max-height: 10;
+ margin: 0 1 1 1;
+ border: round {_DIM};
+ }}
+ .task-output-editor:focus {{
+ border: round {_PRIMARY};
+ }}
+ .task-label {{
+ padding: 0 1;
+ }}
+ Tree {{
+ background: {_BG_PANEL};
+ }}
+ Tree > .tree--cursor {{
+ background: {_SECONDARY};
+ color: {_TERTIARY};
+ }}
"""
BINDINGS: ClassVar[list[Binding | tuple[str, str] | tuple[str, str, str]]] = [
("q", "quit", "Quit"),
("r", "refresh", "Refresh"),
- ("j", "cursor_down", "Down"),
- ("k", "cursor_up", "Up"),
]
def __init__(self, location: str = "./.checkpoints") -> None:
super().__init__()
self._location = location
self._entries: list[dict[str, Any]] = []
- self._selected_idx: int = 0
- self._pending_location: str = ""
+ self._selected_entry: dict[str, Any] | None = None
+ self._input_keys: list[str] = []
+ self._task_output_ids: list[tuple[int, str, str]] = []
def compose(self) -> ComposeResult:
yield Header(show_clock=False)
- with Horizontal():
- yield OptionList(id="cp-list")
+ with Horizontal(id="main-layout"):
+ tree: Tree[dict[str, Any]] = Tree("Checkpoints", id="tree-panel")
+ tree.show_root = True
+ tree.guide_depth = 3
+ yield tree
with Vertical(id="detail-container"):
yield Static("", id="status")
- yield Static(
- f"\n [{_DIM}]Select a checkpoint from the list[/]", # noqa: S608
- id="detail",
- )
+ with VerticalScroll(id="detail-scroll"):
+ yield Static(
+ f"[{_DIM}]Select a checkpoint from the tree[/]", # noqa: S608
+ id="detail-header",
+ )
+ with Vertical(id="inputs-section"):
+ yield Static("Inputs", id="inputs-label")
+ with Horizontal(id="action-buttons"):
+ yield Button("Resume", id="btn-resume")
+ yield Button("Fork", id="btn-fork")
yield Footer()
async def on_mount(self) -> None:
- self.query_one("#cp-list", OptionList).border_title = "Checkpoints"
- self.query_one("#detail", Static).border_title = "Detail"
- self._refresh_list()
+ self._refresh_tree()
+ self.query_one("#tree-panel", Tree).root.expand()
- def _refresh_list(self) -> None:
+ def _refresh_tree(self) -> None:
self._entries = _load_entries(self._location)
- option_list = self.query_one("#cp-list", OptionList)
- option_list.clear_options()
+ self._selected_entry = None
+
+ tree = self.query_one("#tree-panel", Tree)
+ tree.clear()
if not self._entries:
- self.query_one("#detail", Static).update(
- f"\n [{_DIM}]No checkpoints in {self._location}[/]"
+ self.query_one("#detail-header", Static).update(
+ f"[{_DIM}]No checkpoints in {self._location}[/]"
)
self.query_one("#status", Static).update("")
self.sub_title = self._location
return
+ # Group by branch
+ branches: dict[str, list[dict[str, Any]]] = defaultdict(list)
for entry in self._entries:
- option_list.add_option(Option(_format_list_label(entry)))
+ branch = entry.get("branch", "main")
+ branches[branch].append(entry)
+
+ # Index checkpoint names to tree nodes so forks can attach
+ node_by_name: dict[str, Any] = {}
+
+ def _make_label(e: dict[str, Any]) -> str:
+ name = e.get("name", "")
+ ts = e.get("ts") or ""
+ trigger = e.get("trigger") or ""
+ parts = [f"[bold]{_short_id(name)}[/]"]
+ if ts:
+ time_part = ts.split(" ")[-1] if " " in ts else ts
+ parts.append(f"[{_DIM}]{time_part}[/]")
+ if trigger:
+ parts.append(f"[{_PRIMARY}]{trigger}[/]")
+ return " ".join(parts)
+
+ fork_parents: set[str] = set()
+ for branch_name, entries in branches.items():
+ if branch_name == "main" or not entries:
+ continue
+ oldest = min(entries, key=lambda e: str(e.get("name", "")))
+ first_parent = oldest.get("parent_id")
+ if first_parent:
+ fork_parents.add(str(first_parent))
+
+ def _add_checkpoint(parent_node: Any, e: dict[str, Any]) -> None:
+ """Add a checkpoint node — expandable only if a fork attaches to it."""
+ cp_id = _entry_id(e)
+ if cp_id in fork_parents:
+ node = parent_node.add(
+ _make_label(e), data=e, expand=False, allow_expand=True
+ )
+ else:
+ node = parent_node.add_leaf(_make_label(e), data=e)
+ node_by_name[cp_id] = node
+
+ if "main" in branches:
+ for entry in reversed(branches["main"]):
+ _add_checkpoint(tree.root, entry)
+
+ fork_branches = [
+ (name, sorted(entries, key=lambda e: str(e.get("name", ""))))
+ for name, entries in branches.items()
+ if name != "main"
+ ]
+ remaining = fork_branches
+ max_passes = len(remaining) + 1
+ while remaining and max_passes > 0:
+ max_passes -= 1
+ deferred = []
+ made_progress = False
+ for branch_name, entries in remaining:
+ first_parent = entries[0].get("parent_id") if entries else None
+ if first_parent and str(first_parent) not in node_by_name:
+ deferred.append((branch_name, entries))
+ continue
+ attach_to: Any = tree.root
+ if first_parent:
+ attach_to = node_by_name.get(str(first_parent), tree.root)
+ branch_label = (
+ f"[bold {_SECONDARY}]{branch_name}[/] [{_DIM}]({len(entries)})[/]"
+ )
+ branch_node = attach_to.add(branch_label, expand=False)
+ for entry in entries:
+ _add_checkpoint(branch_node, entry)
+ made_progress = True
+ remaining = deferred
+ if not made_progress:
+ break
+
+ for branch_name, entries in remaining:
+ branch_label = (
+ f"[bold {_SECONDARY}]{branch_name}[/] "
+ f"[{_DIM}]({len(entries)})[/] [{_DIM}](orphaned)[/]"
+ )
+ branch_node = tree.root.add(branch_label, expand=False)
+ for entry in entries:
+ _add_checkpoint(branch_node, entry)
count = len(self._entries)
storage = "SQLite" if _is_sqlite(self._location) else "JSON"
- self.sub_title = f"{self._location}"
+ self.sub_title = self._location
self.query_one("#status", Static).update(f" {count} checkpoint(s) | {storage}")
- async def on_option_list_option_highlighted(
- self,
- event: OptionList.OptionHighlighted,
- ) -> None:
- idx = event.option_index
- if idx is None:
- return
- if idx < len(self._entries):
- self._selected_idx = idx
- entry = self._entries[idx]
- self.query_one("#detail", Static).update(_format_detail(entry))
+ async def _show_detail(self, entry: dict[str, Any]) -> None:
+ """Update the detail panel for a checkpoint entry."""
+ self._selected_entry = entry
+ self.query_one("#action-buttons").add_class("visible")
- def action_cursor_down(self) -> None:
- self.query_one("#cp-list", OptionList).action_cursor_down()
+ detail_scroll = self.query_one("#detail-scroll", VerticalScroll)
- def action_cursor_up(self) -> None:
- self.query_one("#cp-list", OptionList).action_cursor_up()
+ # Remove all dynamic children except the header — await so IDs are freed
+ to_remove = [c for c in detail_scroll.children if c.id != "detail-header"]
+ for child in to_remove:
+ await child.remove()
- async def on_option_list_option_selected(
- self,
- event: OptionList.OptionSelected,
- ) -> None:
- idx = event.option_index
- if idx is None or idx >= len(self._entries):
- return
- entry = self._entries[idx]
+ # Header
+ name = entry.get("name", "")
+ ts = entry.get("ts") or "unknown"
+ trigger = entry.get("trigger") or ""
+ branch = entry.get("branch", "main")
+ parent_id = entry.get("parent_id")
+
+ header_lines = [
+ f"[bold {_PRIMARY}]{name}[/]",
+ f"[{_DIM}]{'─' * 50}[/]",
+ "",
+ f" [bold]Time[/] {ts}",
+ ]
+ if "size" in entry:
+ header_lines.append(f" [bold]Size[/] {_format_size(entry['size'])}")
+ header_lines.append(f" [bold]Events[/] {entry.get('event_count', 0)}")
+ if trigger:
+ header_lines.append(f" [bold]Trigger[/] [{_PRIMARY}]{trigger}[/]")
+ header_lines.append(f" [bold]Branch[/] [{_SECONDARY}]{branch}[/]")
+ if parent_id:
+ header_lines.append(f" [bold]Parent[/] [{_DIM}]{parent_id}[/]")
if "path" in entry:
- loc = entry["path"]
- elif _is_sqlite(self._location):
- loc = f"{self._location}#{entry['name']}"
- else:
- loc = entry.get("name", "")
- self._pending_location = loc
- name = entry.get("name", loc)
- self.push_screen(ConfirmResumeScreen(name), self._on_confirm)
+ header_lines.append(f" [bold]Path[/] [{_DIM}]{entry['path']}[/]")
+ if "db" in entry:
+ header_lines.append(f" [bold]Database[/] [{_DIM}]{entry['db']}[/]")
- def _on_confirm(self, confirmed: bool | None) -> None:
- if confirmed:
- self.exit(self._pending_location)
- else:
- self._pending_location = ""
+ self.query_one("#detail-header", Static).update("\n".join(header_lines))
+
+ # Entity details and editable task outputs — mounted flat for scrolling
+ self._task_output_ids = []
+ flat_task_idx = 0
+ for ent_idx, ent in enumerate(entry.get("entities", [])):
+ etype = ent.get("type", "unknown")
+ ename = ent.get("name", "unnamed")
+ completed = ent.get("tasks_completed")
+ total = ent.get("tasks_total")
+ entity_title = f"[bold {_SECONDARY}]{etype}: {ename}[/]"
+ if completed is not None and total is not None:
+ entity_title += f" [{_DIM}]{completed}/{total} tasks[/]"
+ await detail_scroll.mount(Static(entity_title, classes="entity-title"))
+ await detail_scroll.mount(
+ Static(_build_entity_header(ent), classes="entity-detail")
+ )
+
+ tasks = ent.get("tasks", [])
+ for i, task in enumerate(tasks):
+ desc = str(task.get("description", ""))
+ if len(desc) > 55:
+ desc = desc[:52] + "..."
+ if task.get("completed"):
+ icon = "[green]✓[/]"
+ await detail_scroll.mount(
+ Static(f" {icon} {i + 1}. {desc}", classes="task-label")
+ )
+ output_text = task.get("output", "")
+ editor_id = f"task-output-{ent_idx}-{i}"
+ await detail_scroll.mount(
+ TextArea(
+ str(output_text),
+ classes="task-output-editor",
+ id=editor_id,
+ )
+ )
+ self._task_output_ids.append(
+ (flat_task_idx, editor_id, str(output_text))
+ )
+ else:
+ icon = "[yellow]○[/]"
+ await detail_scroll.mount(
+ Static(f" {icon} {i + 1}. {desc}", classes="task-label")
+ )
+ flat_task_idx += 1
+
+ # Build input fields
+ await self._build_input_fields(entry.get("inputs", {}))
+
+ async def _build_input_fields(self, inputs: dict[str, Any]) -> None:
+ """Rebuild the inputs section with one field per input key."""
+ section = self.query_one("#inputs-section")
+
+ # Remove old dynamic children — await so IDs are freed
+ for widget in list(section.query(".input-row, .no-inputs")):
+ await widget.remove()
+
+ self._input_keys = []
+
+ if not inputs:
+ await section.mount(Static(f"[{_DIM}]No inputs[/]", classes="no-inputs"))
+ section.add_class("visible")
+ return
+
+ for key, value in inputs.items():
+ self._input_keys.append(key)
+ row = Horizontal(classes="input-row")
+ row.compose_add_child(Static(f"[bold]{key}[/]"))
+ row.compose_add_child(
+ Input(value=str(value), placeholder=key, id=f"input-{key}")
+ )
+ await section.mount(row)
+
+ section.add_class("visible")
+
+ def _collect_inputs(self) -> dict[str, Any] | None:
+ """Collect current values from input fields."""
+ if not self._input_keys:
+ return None
+ result: dict[str, Any] = {}
+ for key in self._input_keys:
+ widget = self.query_one(f"#input-{key}", Input)
+ result[key] = widget.value
+ return result
+
+ def _collect_task_overrides(self) -> dict[int, str] | None:
+ """Collect edited task outputs. Returns only changed values."""
+ if not self._task_output_ids or self._selected_entry is None:
+ return None
+ overrides: dict[int, str] = {}
+ for task_idx, editor_id, original in self._task_output_ids:
+ editor = self.query_one(f"#{editor_id}", TextArea)
+ if editor.text != original:
+ overrides[task_idx] = editor.text
+ return overrides or None
+
+ def _resolve_location(self, entry: dict[str, Any]) -> str:
+ """Get the restore location string for a checkpoint entry."""
+ if "path" in entry:
+ return str(entry["path"])
+ if _is_sqlite(self._location):
+ return f"{self._location}#{entry['name']}"
+ return str(entry.get("name", ""))
+
+ async def on_tree_node_highlighted(
+ self, event: Tree.NodeHighlighted[dict[str, Any]]
+ ) -> None:
+ if event.node.data is not None:
+ await self._show_detail(event.node.data)
+
+ def on_button_pressed(self, event: Button.Pressed) -> None:
+ if self._selected_entry is None:
+ return
+ inputs = self._collect_inputs()
+ overrides = self._collect_task_overrides()
+ loc = self._resolve_location(self._selected_entry)
+ if event.button.id == "btn-resume":
+ self.exit((loc, "resume", inputs, overrides))
+ elif event.button.id == "btn-fork":
+ self.exit((loc, "fork", inputs, overrides))
def action_refresh(self) -> None:
- self._refresh_list()
+ self._refresh_tree()
async def _run_checkpoint_tui_async(location: str) -> None:
@@ -345,18 +540,78 @@ async def _run_checkpoint_tui_async(location: str) -> None:
import click
app = CheckpointTUI(location=location)
- selected = await app.run_async()
+ selection = await app.run_async()
- if selected is None:
+ if selection is None:
return
- click.echo(f"\nResuming from: {selected}\n")
+ selected, action, inputs, task_overrides = selection
from crewai.crew import Crew
from crewai.state.checkpoint_config import CheckpointConfig
- crew = Crew.from_checkpoint(CheckpointConfig(restore_from=selected))
- result = await crew.akickoff()
+ config = CheckpointConfig(restore_from=selected)
+
+ if action == "fork":
+ click.echo(f"\nForking from: {selected}\n")
+ crew = Crew.fork(config)
+ else:
+ click.echo(f"\nResuming from: {selected}\n")
+ crew = Crew.from_checkpoint(config)
+
+ if task_overrides:
+ click.echo("Modifications:")
+ overridden_agents: set[int] = set()
+ for task_idx, new_output in task_overrides.items():
+ if task_idx < len(crew.tasks) and crew.tasks[task_idx].output is not None:
+ desc = crew.tasks[task_idx].description or f"Task {task_idx + 1}"
+ if len(desc) > 60:
+ desc = desc[:57] + "..."
+ crew.tasks[task_idx].output.raw = new_output # type: ignore[union-attr]
+ preview = new_output.replace("\n", " ")
+ if len(preview) > 80:
+ preview = preview[:77] + "..."
+ click.echo(f" Task {task_idx + 1}: {desc}")
+ click.echo(f" -> {preview}")
+ agent = crew.tasks[task_idx].agent
+ if agent and agent.agent_executor:
+ nth = sum(1 for t in crew.tasks[:task_idx] if t.agent is agent)
+ messages = agent.agent_executor.messages
+ system_positions = [
+ i for i, m in enumerate(messages) if m.get("role") == "system"
+ ]
+ if nth < len(system_positions):
+ seg_start = system_positions[nth]
+ seg_end = (
+ system_positions[nth + 1]
+ if nth + 1 < len(system_positions)
+ else len(messages)
+ )
+ for j in range(seg_end - 1, seg_start, -1):
+ if messages[j].get("role") == "assistant":
+ messages[j]["content"] = new_output
+ break
+ overridden_agents.add(id(agent))
+
+ earliest = min(task_overrides)
+ for offset, subsequent in enumerate(
+ crew.tasks[earliest + 1 :], start=earliest + 1
+ ):
+ if subsequent.output and offset not in task_overrides:
+ subsequent.output = None
+ if subsequent.agent and subsequent.agent.agent_executor:
+ subsequent.agent.agent_executor._resuming = False
+ if id(subsequent.agent) not in overridden_agents:
+ subsequent.agent.agent_executor.messages = []
+ click.echo()
+
+ if inputs:
+ click.echo("Inputs:")
+ for k, v in inputs.items():
+ click.echo(f" {k}: {v}")
+ click.echo()
+
+ result = await crew.akickoff(inputs=inputs)
click.echo(f"\nResult: {getattr(result, 'raw', result)}")
diff --git a/lib/crewai/src/crewai/cli/cli.py b/lib/crewai/src/crewai/cli/cli.py
index 20a65dbe1..03bdcd502 100644
--- a/lib/crewai/src/crewai/cli/cli.py
+++ b/lib/crewai/src/crewai/cli/cli.py
@@ -793,6 +793,9 @@ def traces_status() -> None:
@click.pass_context
def checkpoint(ctx: click.Context, location: str) -> None:
"""Browse and inspect checkpoints. Launches a TUI when called without a subcommand."""
+ from crewai.cli.checkpoint_cli import _detect_location
+
+ location = _detect_location(location)
ctx.ensure_object(dict)
ctx.obj["location"] = location
if ctx.invoked_subcommand is None:
@@ -805,18 +808,18 @@ def checkpoint(ctx: click.Context, location: str) -> None:
@click.argument("location", default="./.checkpoints")
def checkpoint_list(location: str) -> None:
"""List checkpoints in a directory."""
- from crewai.cli.checkpoint_cli import list_checkpoints
+ from crewai.cli.checkpoint_cli import _detect_location, list_checkpoints
- list_checkpoints(location)
+ list_checkpoints(_detect_location(location))
@checkpoint.command("info")
@click.argument("path", default="./.checkpoints")
def checkpoint_info(path: str) -> None:
"""Show details of a checkpoint. Pass a file or directory for latest."""
- from crewai.cli.checkpoint_cli import info_checkpoint
+ from crewai.cli.checkpoint_cli import _detect_location, info_checkpoint
- info_checkpoint(path)
+ info_checkpoint(_detect_location(path))
if __name__ == "__main__":
diff --git a/lib/crewai/src/crewai/crew.py b/lib/crewai/src/crewai/crew.py
index 111babbf3..de9a8f73d 100644
--- a/lib/crewai/src/crewai/crew.py
+++ b/lib/crewai/src/crewai/crew.py
@@ -436,6 +436,13 @@ class Crew(FlowTrackable, BaseModel):
if agent.agent_executor is not None and task.output is None:
agent.agent_executor.task = task
break
+ for task in self.tasks:
+ if task.checkpoint_original_description is not None:
+ task._original_description = task.checkpoint_original_description
+ if task.checkpoint_original_expected_output is not None:
+ task._original_expected_output = (
+ task.checkpoint_original_expected_output
+ )
if self.checkpoint_inputs is not None:
self._inputs = self.checkpoint_inputs
if self.checkpoint_kickoff_event_id is not None:
diff --git a/lib/crewai/src/crewai/state/checkpoint_config.py b/lib/crewai/src/crewai/state/checkpoint_config.py
index bd24ef8eb..e03964c05 100644
--- a/lib/crewai/src/crewai/state/checkpoint_config.py
+++ b/lib/crewai/src/crewai/state/checkpoint_config.py
@@ -213,6 +213,9 @@ class CheckpointConfig(BaseModel):
def _register_handlers(self) -> CheckpointConfig:
from crewai.state.checkpoint_listener import _ensure_handlers_registered
+ if isinstance(self.provider, SqliteProvider) and not Path(self.location).suffix:
+ self.location = f"{self.location}.db"
+
_ensure_handlers_registered()
return self
diff --git a/lib/crewai/src/crewai/state/checkpoint_listener.py b/lib/crewai/src/crewai/state/checkpoint_listener.py
index c2ac728a8..2408e88e3 100644
--- a/lib/crewai/src/crewai/state/checkpoint_listener.py
+++ b/lib/crewai/src/crewai/state/checkpoint_listener.py
@@ -7,6 +7,7 @@ avoids per-event overhead when no entity uses checkpointing.
from __future__ import annotations
+import json
import logging
import threading
from typing import Any
@@ -102,10 +103,15 @@ def _find_checkpoint(source: Any) -> CheckpointConfig | None:
return None
-def _do_checkpoint(state: RuntimeState, cfg: CheckpointConfig) -> None:
+def _do_checkpoint(
+ state: RuntimeState, cfg: CheckpointConfig, event: BaseEvent | None = None
+) -> None:
"""Write a checkpoint and prune old ones if configured."""
_prepare_entities(state.root)
- data = state.model_dump_json()
+ payload = state.model_dump(mode="json")
+ if event is not None:
+ payload["trigger"] = event.type
+ data = json.dumps(payload)
location = cfg.provider.checkpoint(
data,
cfg.location,
@@ -134,7 +140,7 @@ def _on_any_event(source: Any, event: BaseEvent, state: Any) -> None:
if cfg is None:
return
try:
- _do_checkpoint(state, cfg)
+ _do_checkpoint(state, cfg, event)
except Exception:
logger.warning("Auto-checkpoint failed for event %s", event.type, exc_info=True)
diff --git a/lib/crewai/src/crewai/state/runtime.py b/lib/crewai/src/crewai/state/runtime.py
index 3f32457bb..daae0620e 100644
--- a/lib/crewai/src/crewai/state/runtime.py
+++ b/lib/crewai/src/crewai/state/runtime.py
@@ -66,6 +66,9 @@ def _sync_checkpoint_fields(entity: object) -> None:
entity.checkpoint_inputs = entity._inputs
entity.checkpoint_train = entity._train
entity.checkpoint_kickoff_event_id = entity._kickoff_event_id
+ for task in entity.tasks:
+ task.checkpoint_original_description = task._original_description
+ task.checkpoint_original_expected_output = task._original_expected_output
def _migrate(data: dict[str, Any]) -> dict[str, Any]:
@@ -121,7 +124,7 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
"parent_id": self._parent_id,
"branch": self._branch,
"entities": [e.model_dump(mode="json") for e in self.root],
- "event_record": self._event_record.model_dump(),
+ "event_record": self._event_record.model_dump(mode="json"),
}
@model_validator(mode="wrap")
@@ -194,7 +197,10 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
return result
def fork(self, branch: str | None = None) -> None:
- """Mark this state as a fork for subsequent checkpoints.
+ """Create a new execution branch and, if restored, checkpoint the fork point.
+
+ If this state was restored from a checkpoint, an initial checkpoint
+ is written on the new branch so the fork point is recorded.
Args:
branch: Branch label. Auto-generated from the current checkpoint
@@ -204,7 +210,7 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
if branch:
self._branch = branch
elif self._checkpoint_id:
- self._branch = f"fork/{self._checkpoint_id}"
+ self._branch = f"fork/{self._checkpoint_id}_{uuid.uuid4().hex[:6]}"
else:
self._branch = f"fork/{uuid.uuid4().hex[:8]}"
@@ -227,6 +233,7 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
provider = detect_provider(location)
raw = provider.from_checkpoint(location)
state = cls.model_validate_json(raw, **kwargs)
+ state._provider = provider
checkpoint_id = provider.extract_id(location)
state._checkpoint_id = checkpoint_id
state._parent_id = checkpoint_id
@@ -253,6 +260,7 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
provider = detect_provider(location)
raw = await provider.afrom_checkpoint(location)
state = cls.model_validate_json(raw, **kwargs)
+ state._provider = provider
checkpoint_id = provider.extract_id(location)
state._checkpoint_id = checkpoint_id
state._parent_id = checkpoint_id
diff --git a/lib/crewai/src/crewai/task.py b/lib/crewai/src/crewai/task.py
index 4224828e3..0a159cb0e 100644
--- a/lib/crewai/src/crewai/task.py
+++ b/lib/crewai/src/crewai/task.py
@@ -230,6 +230,8 @@ class Task(BaseModel):
_original_description: str | None = PrivateAttr(default=None)
_original_expected_output: str | None = PrivateAttr(default=None)
_original_output_file: str | None = PrivateAttr(default=None)
+ checkpoint_original_description: str | None = Field(default=None, exclude=False)
+ checkpoint_original_expected_output: str | None = Field(default=None, exclude=False)
_thread: threading.Thread | None = PrivateAttr(default=None)
model_config = {"arbitrary_types_allowed": True}
diff --git a/lib/crewai/tests/test_checkpoint.py b/lib/crewai/tests/test_checkpoint.py
index 1f1541790..f645541a4 100644
--- a/lib/crewai/tests/test_checkpoint.py
+++ b/lib/crewai/tests/test_checkpoint.py
@@ -296,7 +296,8 @@ class TestRuntimeStateLineage:
state = self._make_state()
state._checkpoint_id = "20260409T120000_abc12345"
state.fork()
- assert state._branch == "fork/20260409T120000_abc12345"
+ assert state._branch.startswith("fork/20260409T120000_abc12345_")
+ assert len(state._branch) == len("fork/20260409T120000_abc12345_") + 6
def test_fork_no_checkpoint_id_unique(self) -> None:
state = self._make_state()