diff --git a/docs/en/concepts/checkpointing.mdx b/docs/en/concepts/checkpointing.mdx index 21ed13905..d6430eb6f 100644 --- a/docs/en/concepts/checkpointing.mdx +++ b/docs/en/concepts/checkpointing.mdx @@ -54,6 +54,7 @@ crew = Crew( | `on_events` | `list[str]` | `["task_completed"]` | Event types that trigger a checkpoint | | `provider` | `BaseProvider` | `JsonProvider()` | Storage backend | | `max_checkpoints` | `int \| None` | `None` | Max checkpoints to keep. Oldest are pruned after each write. Pruning is handled by the provider. | +| `restore_from` | `Path \| str \| None` | `None` | Path to a checkpoint to restore from. Used when passing config via a kickoff method's `from_checkpoint` parameter. | ### Inheritance and Opt-Out @@ -79,13 +80,42 @@ crew = Crew( ## Resuming from a Checkpoint +Pass a `CheckpointConfig` with `restore_from` to any kickoff method. The crew restores from that checkpoint, skips completed tasks, and resumes. + ```python -# Restore and resume -crew = Crew.from_checkpoint("./my_checkpoints/20260407T120000_abc123.json") -result = crew.kickoff() # picks up from last completed task +from crewai import Crew, CheckpointConfig + +crew = Crew(agents=[...], tasks=[...]) +result = crew.kickoff( + from_checkpoint=CheckpointConfig( + restore_from="./my_checkpoints/20260407T120000_abc123.json", + ), +) ``` -The restored crew skips already-completed tasks and resumes from the first incomplete one. +Remaining `CheckpointConfig` fields apply to the new run, so checkpointing continues after the restore. + +You can also use the classmethod directly: + +```python +config = CheckpointConfig(restore_from="./my_checkpoints/20260407T120000_abc123.json") +crew = Crew.from_checkpoint(config) +result = crew.kickoff() +``` + +## Forking from a Checkpoint + +`fork()` restores a checkpoint and starts a new execution branch. Useful for exploring alternative paths from the same point. + +```python +from crewai import Crew, CheckpointConfig + +config = CheckpointConfig(restore_from="./my_checkpoints/20260407T120000_abc123.json") +crew = Crew.fork(config, branch="experiment-a") +result = crew.kickoff(inputs={"strategy": "aggressive"}) +``` + +Each fork gets a unique lineage ID so checkpoints from different branches don't collide. The `branch` label is optional and auto-generated if omitted. ## Works on Crew, Flow, and Agent @@ -125,7 +155,8 @@ flow = MyFlow( result = flow.kickoff() # Resume -flow = MyFlow.from_checkpoint("./flow_cp/20260407T120000_abc123.json") +config = CheckpointConfig(restore_from="./flow_cp/20260407T120000_abc123.json") +flow = MyFlow.from_checkpoint(config) result = flow.kickoff() ``` @@ -231,3 +262,44 @@ async def on_llm_done_async(source, event, state): The `state` argument is the `RuntimeState` passed automatically by the event bus when your handler accepts 3 parameters. You can register handlers on any event type listed in the [Event Listeners](/en/concepts/event-listener) documentation. Checkpointing is best-effort: if a checkpoint write fails, the error is logged but execution continues uninterrupted. + +## CLI + +The `crewai checkpoint` command gives you a TUI for browsing, inspecting, resuming, and forking checkpoints. It auto-detects whether your checkpoints are JSON files or a SQLite database. + +```bash +# Launch the TUI — auto-detects .checkpoints/ or .checkpoints.db +crewai checkpoint + +# Point at a specific location +crewai checkpoint --location ./my_checkpoints +crewai checkpoint --location ./.checkpoints.db +``` + + + Checkpoint TUI + + +The left panel is a tree view. Checkpoints are grouped by branch, and forks nest under the checkpoint they diverged from. Select a checkpoint to see its metadata, entity state, and task progress in the detail panel. Hit **Resume** to pick up where it left off, or **Fork** to start a new branch from that point. + +### Editing inputs and task outputs + +When a checkpoint is selected, the detail panel shows: + +- **Inputs** — if the original kickoff had inputs (e.g. `{topic}`), they appear as editable fields pre-filled with the original values. Change them before resuming or forking. +- **Task outputs** — completed tasks show their output in editable text areas. Edit a task's output to change the context that downstream tasks receive. When you modify a task output and hit Fork, all subsequent tasks are invalidated and re-run with the new context. + +This is useful for "what if" exploration — fork from a checkpoint, tweak a task's result, and see how it changes downstream behavior. + +### Subcommands + +```bash +# List all checkpoints +crewai checkpoint list ./my_checkpoints + +# Inspect a specific checkpoint +crewai checkpoint info ./my_checkpoints/20260407T120000_abc123.json + +# Inspect latest in a SQLite database +crewai checkpoint info ./.checkpoints.db +``` diff --git a/docs/images/checkpointing.png b/docs/images/checkpointing.png new file mode 100644 index 000000000..de1f4776a Binary files /dev/null and b/docs/images/checkpointing.png differ diff --git a/lib/crewai/src/crewai/cli/checkpoint_cli.py b/lib/crewai/src/crewai/cli/checkpoint_cli.py index c61500b20..fa6e003aa 100644 --- a/lib/crewai/src/crewai/cli/checkpoint_cli.py +++ b/lib/crewai/src/crewai/cli/checkpoint_cli.py @@ -6,12 +6,16 @@ from datetime import datetime import glob import json import os +import re import sqlite3 from typing import Any import click +_PLACEHOLDER_RE = re.compile(r"\{([A-Za-z_][A-Za-z0-9_\-]*)}") + + _SQLITE_MAGIC = b"SQLite format 3\x00" _SELECT_ALL = """ @@ -34,6 +38,25 @@ LIMIT 1 """ +_DEFAULT_DIR = "./.checkpoints" +_DEFAULT_DB = "./.checkpoints.db" + + +def _detect_location(location: str) -> str: + """Resolve the default checkpoint location. + + When the caller passes the default directory path, check whether a + SQLite database exists at the conventional ``.db`` path and prefer it. + """ + if ( + location == _DEFAULT_DIR + and not os.path.exists(_DEFAULT_DIR) + and os.path.exists(_DEFAULT_DB) + ): + return _DEFAULT_DB + return location + + def _is_sqlite(path: str) -> bool: """Check if a file is a SQLite database by reading its magic bytes.""" if not os.path.isfile(path): @@ -52,13 +75,7 @@ def _parse_checkpoint_json(raw: str, source: str) -> dict[str, Any]: nodes = data.get("event_record", {}).get("nodes", {}) event_count = len(nodes) - trigger_event = None - if nodes: - last_node = max( - nodes.values(), - key=lambda n: n.get("event", {}).get("emission_sequence") or 0, - ) - trigger_event = last_node.get("event", {}).get("type") + trigger_event = data.get("trigger") parsed_entities: list[dict[str, Any]] = [] for entity in entities: @@ -76,16 +93,47 @@ def _parse_checkpoint_json(raw: str, source: str) -> dict[str, Any]: { "description": t.get("description", ""), "completed": t.get("output") is not None, + "output": (t.get("output") or {}).get("raw", ""), } for t in tasks ] parsed_entities.append(info) + inputs: dict[str, Any] = {} + for entity in entities: + cp_inputs = entity.get("checkpoint_inputs") + if isinstance(cp_inputs, dict) and cp_inputs: + inputs = dict(cp_inputs) + break + + for entity in entities: + for task in entity.get("tasks", []): + for field in ( + "checkpoint_original_description", + "checkpoint_original_expected_output", + ): + text = task.get(field) or "" + for match in _PLACEHOLDER_RE.findall(text): + if match not in inputs: + inputs[match] = "" + for agent in entity.get("agents", []): + for field in ("role", "goal", "backstory"): + text = agent.get(field) or "" + for match in _PLACEHOLDER_RE.findall(text): + if match not in inputs: + inputs[match] = "" + + branch = data.get("branch", "main") + parent_id = data.get("parent_id") + return { "source": source, "event_count": event_count, "trigger": trigger_event, "entities": parsed_entities, + "branch": branch, + "parent_id": parent_id, + "inputs": inputs, } @@ -189,6 +237,7 @@ def _list_sqlite(db_path: str) -> list[dict[str, Any]]: "entities": [], "source": checkpoint_id, } + meta["db"] = db_path results.append(meta) return results @@ -311,6 +360,10 @@ def _print_info(meta: dict[str, Any]) -> None: trigger = meta.get("trigger") if trigger: click.echo(f"Trigger: {trigger}") + click.echo(f"Branch: {meta.get('branch', 'main')}") + parent_id = meta.get("parent_id") + if parent_id: + click.echo(f"Parent: {parent_id}") for ent in meta.get("entities", []): eid = str(ent.get("id", ""))[:8] diff --git a/lib/crewai/src/crewai/cli/checkpoint_tui.py b/lib/crewai/src/crewai/cli/checkpoint_tui.py index c58ed2ea6..e0d10f813 100644 --- a/lib/crewai/src/crewai/cli/checkpoint_tui.py +++ b/lib/crewai/src/crewai/cli/checkpoint_tui.py @@ -2,17 +2,23 @@ from __future__ import annotations +from collections import defaultdict from typing import Any, ClassVar from textual.app import App, ComposeResult from textual.binding import Binding -from textual.containers import Horizontal, Vertical -from textual.screen import ModalScreen -from textual.widgets import Button, Footer, Header, OptionList, Static -from textual.widgets.option_list import Option +from textual.containers import Horizontal, Vertical, VerticalScroll +from textual.widgets import ( + Button, + Footer, + Header, + Input, + Static, + TextArea, + Tree, +) from crewai.cli.checkpoint_cli import ( - _entity_summary, _format_size, _is_sqlite, _list_json, @@ -34,151 +40,54 @@ def _load_entries(location: str) -> list[dict[str, Any]]: return _list_json(location) -def _format_list_label(entry: dict[str, Any]) -> str: - """Format a checkpoint entry for the list panel.""" - name = entry.get("name", "") - ts = entry.get("ts") or "" - trigger = entry.get("trigger") or "" - summary = _entity_summary(entry.get("entities", [])) - - line1 = f"[bold]{name}[/]" - parts = [] - if ts: - parts.append(f"[dim]{ts}[/]") - if "size" in entry: - parts.append(f"[dim]{_format_size(entry['size'])}[/]") - if trigger: - parts.append(f"[{_PRIMARY}]{trigger}[/]") - line2 = " ".join(parts) - line3 = f" [{_DIM}]{summary}[/]" - - return f"{line1}\n{line2}\n{line3}" +def _short_id(name: str) -> str: + """Shorten a checkpoint name for tree display.""" + if len(name) > 30: + return name[:27] + "..." + return name -def _format_detail(entry: dict[str, Any]) -> str: - """Format checkpoint details for the right panel.""" +def _entry_id(entry: dict[str, Any]) -> str: + """Normalize an entry's name into its checkpoint ID. + + JSON filenames are ``{ts}_{uuid}_p-{parent}.json``; SQLite IDs + are already ``{ts}_{uuid}``. This strips the JSON suffix so + fork-parent lookups work in both providers. + """ + name = str(entry.get("name", "")) + if name.endswith(".json"): + name = name[: -len(".json")] + idx = name.find("_p-") + if idx != -1: + name = name[:idx] + return name + + +def _build_entity_header(ent: dict[str, Any]) -> str: + """Build rich text header for an entity (progress bar only).""" lines: list[str] = [] - - # Header - name = entry.get("name", "") - lines.append(f"[bold {_PRIMARY}]{name}[/]") - lines.append(f"[{_DIM}]{'─' * 50}[/]") - lines.append("") - - # Metadata table - ts = entry.get("ts") or "unknown" - trigger = entry.get("trigger") or "" - lines.append(f" [bold]Time[/] {ts}") - if "size" in entry: - lines.append(f" [bold]Size[/] {_format_size(entry['size'])}") - lines.append(f" [bold]Events[/] {entry.get('event_count', 0)}") - if trigger: - lines.append(f" [bold]Trigger[/] [{_PRIMARY}]{trigger}[/]") - if "path" in entry: - lines.append(f" [bold]Path[/] [{_DIM}]{entry['path']}[/]") - if "db" in entry: - lines.append(f" [bold]Database[/] [{_DIM}]{entry['db']}[/]") - - # Entities - for ent in entry.get("entities", []): - eid = str(ent.get("id", ""))[:8] - etype = ent.get("type", "unknown") - ename = ent.get("name", "unnamed") - - lines.append("") - lines.append(f" [{_DIM}]{'─' * 50}[/]") - lines.append(f" [bold {_SECONDARY}]{etype}[/]: {ename} [{_DIM}]{eid}[/]") - - tasks = ent.get("tasks") - if isinstance(tasks, list): - completed = ent.get("tasks_completed", 0) - total = ent.get("tasks_total", 0) - pct = int(completed / total * 100) if total else 0 - bar_len = 20 - filled = int(bar_len * completed / total) if total else 0 - bar = f"[{_PRIMARY}]{'█' * filled}[/][{_DIM}]{'░' * (bar_len - filled)}[/]" - - lines.append(f" {bar} {completed}/{total} tasks ({pct}%)") - lines.append("") - - for i, task in enumerate(tasks): - if task.get("completed"): - icon = "[green]✓[/]" - else: - icon = "[yellow]○[/]" - desc = str(task.get("description", "")) - if len(desc) > 55: - desc = desc[:52] + "..." - lines.append(f" {icon} {i + 1}. {desc}") - + tasks = ent.get("tasks") + if isinstance(tasks, list): + completed = ent.get("tasks_completed", 0) + total = ent.get("tasks_total", 0) + pct = int(completed / total * 100) if total else 0 + bar_len = 20 + filled = int(bar_len * completed / total) if total else 0 + bar = f"[{_PRIMARY}]{'█' * filled}[/][{_DIM}]{'░' * (bar_len - filled)}[/]" + lines.append(f"{bar} {completed}/{total} tasks ({pct}%)") return "\n".join(lines) -class ConfirmResumeScreen(ModalScreen[bool]): - """Modal confirmation before resuming from a checkpoint.""" - - CSS = f""" - ConfirmResumeScreen {{ - align: center middle; - }} - #confirm-dialog {{ - width: 60; - height: auto; - padding: 1 2; - background: {_BG_PANEL}; - border: round {_PRIMARY}; - }} - #confirm-label {{ - width: 100%; - content-align: center middle; - margin-bottom: 1; - }} - #confirm-name {{ - width: 100%; - content-align: center middle; - color: {_PRIMARY}; - text-style: bold; - margin-bottom: 1; - }} - #confirm-buttons {{ - width: 100%; - height: 3; - layout: horizontal; - align: center middle; - }} - Button {{ - margin: 0 2; - min-width: 12; - }} - """ - - def __init__(self, checkpoint_name: str) -> None: - super().__init__() - self._checkpoint_name = checkpoint_name - - def compose(self) -> ComposeResult: - with Vertical(id="confirm-dialog"): - yield Static("Resume from this checkpoint?", id="confirm-label") - yield Static(self._checkpoint_name, id="confirm-name") - with Horizontal(id="confirm-buttons"): - yield Button("Resume", variant="success", id="btn-yes") - yield Button("Cancel", variant="default", id="btn-no") - - def on_button_pressed(self, event: Button.Pressed) -> None: - self.dismiss(event.button.id == "btn-yes") - - def on_key(self, event: Any) -> None: - if event.key == "y": - self.dismiss(True) - elif event.key in ("n", "escape"): - self.dismiss(False) +# Return type: (location, action, inputs, task_output_overrides) +_TuiResult = tuple[str, str, dict[str, Any] | None, dict[int, str] | None] | None -class CheckpointTUI(App[str | None]): +class CheckpointTUI(App[_TuiResult]): """TUI to browse and inspect checkpoints. - Returns the checkpoint location string to resume from, or None if - the user quit without selecting. + Returns ``(location, action, inputs)`` where action is ``"resume"`` or + ``"fork"`` and inputs is a parsed dict or ``None``, + or ``None`` if the user quit without selecting. """ TITLE = "CrewAI Checkpoints" @@ -199,145 +108,431 @@ class CheckpointTUI(App[str | None]): background: {_PRIMARY}; color: {_TERTIARY}; }} - Horizontal {{ + #main-layout {{ height: 1fr; }} - #cp-list {{ - width: 38%; + #tree-panel {{ + width: 45%; background: {_BG_PANEL}; border: round {_SECONDARY}; padding: 0 1; scrollbar-color: {_PRIMARY}; }} - #cp-list:focus {{ + #tree-panel:focus-within {{ border: round {_PRIMARY}; }} - #cp-list > .option-list--option-highlighted {{ - background: {_SECONDARY}; - color: {_TERTIARY}; - text-style: none; - }} - #cp-list > .option-list--option-highlighted * {{ - color: {_TERTIARY}; - }} #detail-container {{ - width: 62%; - padding: 0 1; + width: 55%; + height: 1fr; }} - #detail {{ + #detail-scroll {{ height: 1fr; background: {_BG_PANEL}; border: round {_SECONDARY}; padding: 1 2; - overflow-y: auto; scrollbar-color: {_PRIMARY}; }} - #detail:focus {{ + #detail-scroll:focus-within {{ border: round {_PRIMARY}; }} + #detail-header {{ + margin-bottom: 1; + }} #status {{ height: 1; padding: 0 2; color: {_DIM}; }} + #inputs-section {{ + display: none; + height: auto; + max-height: 8; + padding: 0 1; + }} + #inputs-section.visible {{ + display: block; + }} + #inputs-label {{ + height: 1; + color: {_DIM}; + padding: 0 1; + }} + .input-row {{ + height: 3; + padding: 0 1; + }} + .input-row Static {{ + width: auto; + min-width: 12; + padding: 1 1 0 0; + color: {_TERTIARY}; + }} + .input-row Input {{ + width: 1fr; + }} + #no-inputs-label {{ + height: 1; + color: {_DIM}; + padding: 0 1; + }} + #action-buttons {{ + height: 3; + align: right middle; + padding: 0 1; + display: none; + }} + #action-buttons.visible {{ + display: block; + }} + #action-buttons Button {{ + margin: 0 0 0 1; + min-width: 10; + }} + #btn-resume {{ + background: {_SECONDARY}; + color: {_TERTIARY}; + }} + #btn-resume:hover {{ + background: {_PRIMARY}; + }} + #btn-fork {{ + background: {_PRIMARY}; + color: {_TERTIARY}; + }} + #btn-fork:hover {{ + background: {_SECONDARY}; + }} + .entity-title {{ + padding: 1 1 0 1; + }} + .entity-detail {{ + padding: 0 1; + }} + .task-output-editor {{ + height: auto; + max-height: 10; + margin: 0 1 1 1; + border: round {_DIM}; + }} + .task-output-editor:focus {{ + border: round {_PRIMARY}; + }} + .task-label {{ + padding: 0 1; + }} + Tree {{ + background: {_BG_PANEL}; + }} + Tree > .tree--cursor {{ + background: {_SECONDARY}; + color: {_TERTIARY}; + }} """ BINDINGS: ClassVar[list[Binding | tuple[str, str] | tuple[str, str, str]]] = [ ("q", "quit", "Quit"), ("r", "refresh", "Refresh"), - ("j", "cursor_down", "Down"), - ("k", "cursor_up", "Up"), ] def __init__(self, location: str = "./.checkpoints") -> None: super().__init__() self._location = location self._entries: list[dict[str, Any]] = [] - self._selected_idx: int = 0 - self._pending_location: str = "" + self._selected_entry: dict[str, Any] | None = None + self._input_keys: list[str] = [] + self._task_output_ids: list[tuple[int, str, str]] = [] def compose(self) -> ComposeResult: yield Header(show_clock=False) - with Horizontal(): - yield OptionList(id="cp-list") + with Horizontal(id="main-layout"): + tree: Tree[dict[str, Any]] = Tree("Checkpoints", id="tree-panel") + tree.show_root = True + tree.guide_depth = 3 + yield tree with Vertical(id="detail-container"): yield Static("", id="status") - yield Static( - f"\n [{_DIM}]Select a checkpoint from the list[/]", # noqa: S608 - id="detail", - ) + with VerticalScroll(id="detail-scroll"): + yield Static( + f"[{_DIM}]Select a checkpoint from the tree[/]", # noqa: S608 + id="detail-header", + ) + with Vertical(id="inputs-section"): + yield Static("Inputs", id="inputs-label") + with Horizontal(id="action-buttons"): + yield Button("Resume", id="btn-resume") + yield Button("Fork", id="btn-fork") yield Footer() async def on_mount(self) -> None: - self.query_one("#cp-list", OptionList).border_title = "Checkpoints" - self.query_one("#detail", Static).border_title = "Detail" - self._refresh_list() + self._refresh_tree() + self.query_one("#tree-panel", Tree).root.expand() - def _refresh_list(self) -> None: + def _refresh_tree(self) -> None: self._entries = _load_entries(self._location) - option_list = self.query_one("#cp-list", OptionList) - option_list.clear_options() + self._selected_entry = None + + tree = self.query_one("#tree-panel", Tree) + tree.clear() if not self._entries: - self.query_one("#detail", Static).update( - f"\n [{_DIM}]No checkpoints in {self._location}[/]" + self.query_one("#detail-header", Static).update( + f"[{_DIM}]No checkpoints in {self._location}[/]" ) self.query_one("#status", Static).update("") self.sub_title = self._location return + # Group by branch + branches: dict[str, list[dict[str, Any]]] = defaultdict(list) for entry in self._entries: - option_list.add_option(Option(_format_list_label(entry))) + branch = entry.get("branch", "main") + branches[branch].append(entry) + + # Index checkpoint names to tree nodes so forks can attach + node_by_name: dict[str, Any] = {} + + def _make_label(e: dict[str, Any]) -> str: + name = e.get("name", "") + ts = e.get("ts") or "" + trigger = e.get("trigger") or "" + parts = [f"[bold]{_short_id(name)}[/]"] + if ts: + time_part = ts.split(" ")[-1] if " " in ts else ts + parts.append(f"[{_DIM}]{time_part}[/]") + if trigger: + parts.append(f"[{_PRIMARY}]{trigger}[/]") + return " ".join(parts) + + fork_parents: set[str] = set() + for branch_name, entries in branches.items(): + if branch_name == "main" or not entries: + continue + oldest = min(entries, key=lambda e: str(e.get("name", ""))) + first_parent = oldest.get("parent_id") + if first_parent: + fork_parents.add(str(first_parent)) + + def _add_checkpoint(parent_node: Any, e: dict[str, Any]) -> None: + """Add a checkpoint node — expandable only if a fork attaches to it.""" + cp_id = _entry_id(e) + if cp_id in fork_parents: + node = parent_node.add( + _make_label(e), data=e, expand=False, allow_expand=True + ) + else: + node = parent_node.add_leaf(_make_label(e), data=e) + node_by_name[cp_id] = node + + if "main" in branches: + for entry in reversed(branches["main"]): + _add_checkpoint(tree.root, entry) + + fork_branches = [ + (name, sorted(entries, key=lambda e: str(e.get("name", "")))) + for name, entries in branches.items() + if name != "main" + ] + remaining = fork_branches + max_passes = len(remaining) + 1 + while remaining and max_passes > 0: + max_passes -= 1 + deferred = [] + made_progress = False + for branch_name, entries in remaining: + first_parent = entries[0].get("parent_id") if entries else None + if first_parent and str(first_parent) not in node_by_name: + deferred.append((branch_name, entries)) + continue + attach_to: Any = tree.root + if first_parent: + attach_to = node_by_name.get(str(first_parent), tree.root) + branch_label = ( + f"[bold {_SECONDARY}]{branch_name}[/] [{_DIM}]({len(entries)})[/]" + ) + branch_node = attach_to.add(branch_label, expand=False) + for entry in entries: + _add_checkpoint(branch_node, entry) + made_progress = True + remaining = deferred + if not made_progress: + break + + for branch_name, entries in remaining: + branch_label = ( + f"[bold {_SECONDARY}]{branch_name}[/] " + f"[{_DIM}]({len(entries)})[/] [{_DIM}](orphaned)[/]" + ) + branch_node = tree.root.add(branch_label, expand=False) + for entry in entries: + _add_checkpoint(branch_node, entry) count = len(self._entries) storage = "SQLite" if _is_sqlite(self._location) else "JSON" - self.sub_title = f"{self._location}" + self.sub_title = self._location self.query_one("#status", Static).update(f" {count} checkpoint(s) | {storage}") - async def on_option_list_option_highlighted( - self, - event: OptionList.OptionHighlighted, - ) -> None: - idx = event.option_index - if idx is None: - return - if idx < len(self._entries): - self._selected_idx = idx - entry = self._entries[idx] - self.query_one("#detail", Static).update(_format_detail(entry)) + async def _show_detail(self, entry: dict[str, Any]) -> None: + """Update the detail panel for a checkpoint entry.""" + self._selected_entry = entry + self.query_one("#action-buttons").add_class("visible") - def action_cursor_down(self) -> None: - self.query_one("#cp-list", OptionList).action_cursor_down() + detail_scroll = self.query_one("#detail-scroll", VerticalScroll) - def action_cursor_up(self) -> None: - self.query_one("#cp-list", OptionList).action_cursor_up() + # Remove all dynamic children except the header — await so IDs are freed + to_remove = [c for c in detail_scroll.children if c.id != "detail-header"] + for child in to_remove: + await child.remove() - async def on_option_list_option_selected( - self, - event: OptionList.OptionSelected, - ) -> None: - idx = event.option_index - if idx is None or idx >= len(self._entries): - return - entry = self._entries[idx] + # Header + name = entry.get("name", "") + ts = entry.get("ts") or "unknown" + trigger = entry.get("trigger") or "" + branch = entry.get("branch", "main") + parent_id = entry.get("parent_id") + + header_lines = [ + f"[bold {_PRIMARY}]{name}[/]", + f"[{_DIM}]{'─' * 50}[/]", + "", + f" [bold]Time[/] {ts}", + ] + if "size" in entry: + header_lines.append(f" [bold]Size[/] {_format_size(entry['size'])}") + header_lines.append(f" [bold]Events[/] {entry.get('event_count', 0)}") + if trigger: + header_lines.append(f" [bold]Trigger[/] [{_PRIMARY}]{trigger}[/]") + header_lines.append(f" [bold]Branch[/] [{_SECONDARY}]{branch}[/]") + if parent_id: + header_lines.append(f" [bold]Parent[/] [{_DIM}]{parent_id}[/]") if "path" in entry: - loc = entry["path"] - elif _is_sqlite(self._location): - loc = f"{self._location}#{entry['name']}" - else: - loc = entry.get("name", "") - self._pending_location = loc - name = entry.get("name", loc) - self.push_screen(ConfirmResumeScreen(name), self._on_confirm) + header_lines.append(f" [bold]Path[/] [{_DIM}]{entry['path']}[/]") + if "db" in entry: + header_lines.append(f" [bold]Database[/] [{_DIM}]{entry['db']}[/]") - def _on_confirm(self, confirmed: bool | None) -> None: - if confirmed: - self.exit(self._pending_location) - else: - self._pending_location = "" + self.query_one("#detail-header", Static).update("\n".join(header_lines)) + + # Entity details and editable task outputs — mounted flat for scrolling + self._task_output_ids = [] + flat_task_idx = 0 + for ent_idx, ent in enumerate(entry.get("entities", [])): + etype = ent.get("type", "unknown") + ename = ent.get("name", "unnamed") + completed = ent.get("tasks_completed") + total = ent.get("tasks_total") + entity_title = f"[bold {_SECONDARY}]{etype}: {ename}[/]" + if completed is not None and total is not None: + entity_title += f" [{_DIM}]{completed}/{total} tasks[/]" + await detail_scroll.mount(Static(entity_title, classes="entity-title")) + await detail_scroll.mount( + Static(_build_entity_header(ent), classes="entity-detail") + ) + + tasks = ent.get("tasks", []) + for i, task in enumerate(tasks): + desc = str(task.get("description", "")) + if len(desc) > 55: + desc = desc[:52] + "..." + if task.get("completed"): + icon = "[green]✓[/]" + await detail_scroll.mount( + Static(f" {icon} {i + 1}. {desc}", classes="task-label") + ) + output_text = task.get("output", "") + editor_id = f"task-output-{ent_idx}-{i}" + await detail_scroll.mount( + TextArea( + str(output_text), + classes="task-output-editor", + id=editor_id, + ) + ) + self._task_output_ids.append( + (flat_task_idx, editor_id, str(output_text)) + ) + else: + icon = "[yellow]○[/]" + await detail_scroll.mount( + Static(f" {icon} {i + 1}. {desc}", classes="task-label") + ) + flat_task_idx += 1 + + # Build input fields + await self._build_input_fields(entry.get("inputs", {})) + + async def _build_input_fields(self, inputs: dict[str, Any]) -> None: + """Rebuild the inputs section with one field per input key.""" + section = self.query_one("#inputs-section") + + # Remove old dynamic children — await so IDs are freed + for widget in list(section.query(".input-row, .no-inputs")): + await widget.remove() + + self._input_keys = [] + + if not inputs: + await section.mount(Static(f"[{_DIM}]No inputs[/]", classes="no-inputs")) + section.add_class("visible") + return + + for key, value in inputs.items(): + self._input_keys.append(key) + row = Horizontal(classes="input-row") + row.compose_add_child(Static(f"[bold]{key}[/]")) + row.compose_add_child( + Input(value=str(value), placeholder=key, id=f"input-{key}") + ) + await section.mount(row) + + section.add_class("visible") + + def _collect_inputs(self) -> dict[str, Any] | None: + """Collect current values from input fields.""" + if not self._input_keys: + return None + result: dict[str, Any] = {} + for key in self._input_keys: + widget = self.query_one(f"#input-{key}", Input) + result[key] = widget.value + return result + + def _collect_task_overrides(self) -> dict[int, str] | None: + """Collect edited task outputs. Returns only changed values.""" + if not self._task_output_ids or self._selected_entry is None: + return None + overrides: dict[int, str] = {} + for task_idx, editor_id, original in self._task_output_ids: + editor = self.query_one(f"#{editor_id}", TextArea) + if editor.text != original: + overrides[task_idx] = editor.text + return overrides or None + + def _resolve_location(self, entry: dict[str, Any]) -> str: + """Get the restore location string for a checkpoint entry.""" + if "path" in entry: + return str(entry["path"]) + if _is_sqlite(self._location): + return f"{self._location}#{entry['name']}" + return str(entry.get("name", "")) + + async def on_tree_node_highlighted( + self, event: Tree.NodeHighlighted[dict[str, Any]] + ) -> None: + if event.node.data is not None: + await self._show_detail(event.node.data) + + def on_button_pressed(self, event: Button.Pressed) -> None: + if self._selected_entry is None: + return + inputs = self._collect_inputs() + overrides = self._collect_task_overrides() + loc = self._resolve_location(self._selected_entry) + if event.button.id == "btn-resume": + self.exit((loc, "resume", inputs, overrides)) + elif event.button.id == "btn-fork": + self.exit((loc, "fork", inputs, overrides)) def action_refresh(self) -> None: - self._refresh_list() + self._refresh_tree() async def _run_checkpoint_tui_async(location: str) -> None: @@ -345,18 +540,78 @@ async def _run_checkpoint_tui_async(location: str) -> None: import click app = CheckpointTUI(location=location) - selected = await app.run_async() + selection = await app.run_async() - if selected is None: + if selection is None: return - click.echo(f"\nResuming from: {selected}\n") + selected, action, inputs, task_overrides = selection from crewai.crew import Crew from crewai.state.checkpoint_config import CheckpointConfig - crew = Crew.from_checkpoint(CheckpointConfig(restore_from=selected)) - result = await crew.akickoff() + config = CheckpointConfig(restore_from=selected) + + if action == "fork": + click.echo(f"\nForking from: {selected}\n") + crew = Crew.fork(config) + else: + click.echo(f"\nResuming from: {selected}\n") + crew = Crew.from_checkpoint(config) + + if task_overrides: + click.echo("Modifications:") + overridden_agents: set[int] = set() + for task_idx, new_output in task_overrides.items(): + if task_idx < len(crew.tasks) and crew.tasks[task_idx].output is not None: + desc = crew.tasks[task_idx].description or f"Task {task_idx + 1}" + if len(desc) > 60: + desc = desc[:57] + "..." + crew.tasks[task_idx].output.raw = new_output # type: ignore[union-attr] + preview = new_output.replace("\n", " ") + if len(preview) > 80: + preview = preview[:77] + "..." + click.echo(f" Task {task_idx + 1}: {desc}") + click.echo(f" -> {preview}") + agent = crew.tasks[task_idx].agent + if agent and agent.agent_executor: + nth = sum(1 for t in crew.tasks[:task_idx] if t.agent is agent) + messages = agent.agent_executor.messages + system_positions = [ + i for i, m in enumerate(messages) if m.get("role") == "system" + ] + if nth < len(system_positions): + seg_start = system_positions[nth] + seg_end = ( + system_positions[nth + 1] + if nth + 1 < len(system_positions) + else len(messages) + ) + for j in range(seg_end - 1, seg_start, -1): + if messages[j].get("role") == "assistant": + messages[j]["content"] = new_output + break + overridden_agents.add(id(agent)) + + earliest = min(task_overrides) + for offset, subsequent in enumerate( + crew.tasks[earliest + 1 :], start=earliest + 1 + ): + if subsequent.output and offset not in task_overrides: + subsequent.output = None + if subsequent.agent and subsequent.agent.agent_executor: + subsequent.agent.agent_executor._resuming = False + if id(subsequent.agent) not in overridden_agents: + subsequent.agent.agent_executor.messages = [] + click.echo() + + if inputs: + click.echo("Inputs:") + for k, v in inputs.items(): + click.echo(f" {k}: {v}") + click.echo() + + result = await crew.akickoff(inputs=inputs) click.echo(f"\nResult: {getattr(result, 'raw', result)}") diff --git a/lib/crewai/src/crewai/cli/cli.py b/lib/crewai/src/crewai/cli/cli.py index 20a65dbe1..03bdcd502 100644 --- a/lib/crewai/src/crewai/cli/cli.py +++ b/lib/crewai/src/crewai/cli/cli.py @@ -793,6 +793,9 @@ def traces_status() -> None: @click.pass_context def checkpoint(ctx: click.Context, location: str) -> None: """Browse and inspect checkpoints. Launches a TUI when called without a subcommand.""" + from crewai.cli.checkpoint_cli import _detect_location + + location = _detect_location(location) ctx.ensure_object(dict) ctx.obj["location"] = location if ctx.invoked_subcommand is None: @@ -805,18 +808,18 @@ def checkpoint(ctx: click.Context, location: str) -> None: @click.argument("location", default="./.checkpoints") def checkpoint_list(location: str) -> None: """List checkpoints in a directory.""" - from crewai.cli.checkpoint_cli import list_checkpoints + from crewai.cli.checkpoint_cli import _detect_location, list_checkpoints - list_checkpoints(location) + list_checkpoints(_detect_location(location)) @checkpoint.command("info") @click.argument("path", default="./.checkpoints") def checkpoint_info(path: str) -> None: """Show details of a checkpoint. Pass a file or directory for latest.""" - from crewai.cli.checkpoint_cli import info_checkpoint + from crewai.cli.checkpoint_cli import _detect_location, info_checkpoint - info_checkpoint(path) + info_checkpoint(_detect_location(path)) if __name__ == "__main__": diff --git a/lib/crewai/src/crewai/crew.py b/lib/crewai/src/crewai/crew.py index 111babbf3..de9a8f73d 100644 --- a/lib/crewai/src/crewai/crew.py +++ b/lib/crewai/src/crewai/crew.py @@ -436,6 +436,13 @@ class Crew(FlowTrackable, BaseModel): if agent.agent_executor is not None and task.output is None: agent.agent_executor.task = task break + for task in self.tasks: + if task.checkpoint_original_description is not None: + task._original_description = task.checkpoint_original_description + if task.checkpoint_original_expected_output is not None: + task._original_expected_output = ( + task.checkpoint_original_expected_output + ) if self.checkpoint_inputs is not None: self._inputs = self.checkpoint_inputs if self.checkpoint_kickoff_event_id is not None: diff --git a/lib/crewai/src/crewai/state/checkpoint_config.py b/lib/crewai/src/crewai/state/checkpoint_config.py index bd24ef8eb..e03964c05 100644 --- a/lib/crewai/src/crewai/state/checkpoint_config.py +++ b/lib/crewai/src/crewai/state/checkpoint_config.py @@ -213,6 +213,9 @@ class CheckpointConfig(BaseModel): def _register_handlers(self) -> CheckpointConfig: from crewai.state.checkpoint_listener import _ensure_handlers_registered + if isinstance(self.provider, SqliteProvider) and not Path(self.location).suffix: + self.location = f"{self.location}.db" + _ensure_handlers_registered() return self diff --git a/lib/crewai/src/crewai/state/checkpoint_listener.py b/lib/crewai/src/crewai/state/checkpoint_listener.py index c2ac728a8..2408e88e3 100644 --- a/lib/crewai/src/crewai/state/checkpoint_listener.py +++ b/lib/crewai/src/crewai/state/checkpoint_listener.py @@ -7,6 +7,7 @@ avoids per-event overhead when no entity uses checkpointing. from __future__ import annotations +import json import logging import threading from typing import Any @@ -102,10 +103,15 @@ def _find_checkpoint(source: Any) -> CheckpointConfig | None: return None -def _do_checkpoint(state: RuntimeState, cfg: CheckpointConfig) -> None: +def _do_checkpoint( + state: RuntimeState, cfg: CheckpointConfig, event: BaseEvent | None = None +) -> None: """Write a checkpoint and prune old ones if configured.""" _prepare_entities(state.root) - data = state.model_dump_json() + payload = state.model_dump(mode="json") + if event is not None: + payload["trigger"] = event.type + data = json.dumps(payload) location = cfg.provider.checkpoint( data, cfg.location, @@ -134,7 +140,7 @@ def _on_any_event(source: Any, event: BaseEvent, state: Any) -> None: if cfg is None: return try: - _do_checkpoint(state, cfg) + _do_checkpoint(state, cfg, event) except Exception: logger.warning("Auto-checkpoint failed for event %s", event.type, exc_info=True) diff --git a/lib/crewai/src/crewai/state/runtime.py b/lib/crewai/src/crewai/state/runtime.py index 3f32457bb..daae0620e 100644 --- a/lib/crewai/src/crewai/state/runtime.py +++ b/lib/crewai/src/crewai/state/runtime.py @@ -66,6 +66,9 @@ def _sync_checkpoint_fields(entity: object) -> None: entity.checkpoint_inputs = entity._inputs entity.checkpoint_train = entity._train entity.checkpoint_kickoff_event_id = entity._kickoff_event_id + for task in entity.tasks: + task.checkpoint_original_description = task._original_description + task.checkpoint_original_expected_output = task._original_expected_output def _migrate(data: dict[str, Any]) -> dict[str, Any]: @@ -121,7 +124,7 @@ class RuntimeState(RootModel): # type: ignore[type-arg] "parent_id": self._parent_id, "branch": self._branch, "entities": [e.model_dump(mode="json") for e in self.root], - "event_record": self._event_record.model_dump(), + "event_record": self._event_record.model_dump(mode="json"), } @model_validator(mode="wrap") @@ -194,7 +197,10 @@ class RuntimeState(RootModel): # type: ignore[type-arg] return result def fork(self, branch: str | None = None) -> None: - """Mark this state as a fork for subsequent checkpoints. + """Create a new execution branch and write an initial checkpoint. + + If this state was restored from a checkpoint, an initial checkpoint + is written on the new branch so the fork point is recorded. Args: branch: Branch label. Auto-generated from the current checkpoint @@ -204,7 +210,7 @@ class RuntimeState(RootModel): # type: ignore[type-arg] if branch: self._branch = branch elif self._checkpoint_id: - self._branch = f"fork/{self._checkpoint_id}" + self._branch = f"fork/{self._checkpoint_id}_{uuid.uuid4().hex[:6]}" else: self._branch = f"fork/{uuid.uuid4().hex[:8]}" @@ -227,6 +233,7 @@ class RuntimeState(RootModel): # type: ignore[type-arg] provider = detect_provider(location) raw = provider.from_checkpoint(location) state = cls.model_validate_json(raw, **kwargs) + state._provider = provider checkpoint_id = provider.extract_id(location) state._checkpoint_id = checkpoint_id state._parent_id = checkpoint_id @@ -253,6 +260,7 @@ class RuntimeState(RootModel): # type: ignore[type-arg] provider = detect_provider(location) raw = await provider.afrom_checkpoint(location) state = cls.model_validate_json(raw, **kwargs) + state._provider = provider checkpoint_id = provider.extract_id(location) state._checkpoint_id = checkpoint_id state._parent_id = checkpoint_id diff --git a/lib/crewai/src/crewai/task.py b/lib/crewai/src/crewai/task.py index 4224828e3..0a159cb0e 100644 --- a/lib/crewai/src/crewai/task.py +++ b/lib/crewai/src/crewai/task.py @@ -230,6 +230,8 @@ class Task(BaseModel): _original_description: str | None = PrivateAttr(default=None) _original_expected_output: str | None = PrivateAttr(default=None) _original_output_file: str | None = PrivateAttr(default=None) + checkpoint_original_description: str | None = Field(default=None, exclude=False) + checkpoint_original_expected_output: str | None = Field(default=None, exclude=False) _thread: threading.Thread | None = PrivateAttr(default=None) model_config = {"arbitrary_types_allowed": True} diff --git a/lib/crewai/tests/test_checkpoint.py b/lib/crewai/tests/test_checkpoint.py index 1f1541790..f645541a4 100644 --- a/lib/crewai/tests/test_checkpoint.py +++ b/lib/crewai/tests/test_checkpoint.py @@ -296,7 +296,8 @@ class TestRuntimeStateLineage: state = self._make_state() state._checkpoint_id = "20260409T120000_abc12345" state.fork() - assert state._branch == "fork/20260409T120000_abc12345" + assert state._branch.startswith("fork/20260409T120000_abc12345_") + assert len(state._branch) == len("fork/20260409T120000_abc12345_") + 6 def test_fork_no_checkpoint_id_unique(self) -> None: state = self._make_state()