diff --git a/docs/en/concepts/checkpointing.mdx b/docs/en/concepts/checkpointing.mdx index 21ed13905..d1f5fecda 100644 --- a/docs/en/concepts/checkpointing.mdx +++ b/docs/en/concepts/checkpointing.mdx @@ -54,6 +54,7 @@ crew = Crew( | `on_events` | `list[str]` | `["task_completed"]` | Event types that trigger a checkpoint | | `provider` | `BaseProvider` | `JsonProvider()` | Storage backend | | `max_checkpoints` | `int \| None` | `None` | Max checkpoints to keep. Oldest are pruned after each write. Pruning is handled by the provider. | +| `restore_from` | `Path \| str \| None` | `None` | Path to a checkpoint to restore from. Used when passing config via a kickoff method's `from_checkpoint` parameter. | ### Inheritance and Opt-Out @@ -79,13 +80,42 @@ crew = Crew( ## Resuming from a Checkpoint +Pass a `CheckpointConfig` with `restore_from` to any kickoff method. The crew restores from that checkpoint, skips completed tasks, and resumes. + ```python -# Restore and resume -crew = Crew.from_checkpoint("./my_checkpoints/20260407T120000_abc123.json") -result = crew.kickoff() # picks up from last completed task +from crewai import Crew, CheckpointConfig + +crew = Crew(agents=[...], tasks=[...]) +result = crew.kickoff( + from_checkpoint=CheckpointConfig( + restore_from="./my_checkpoints/20260407T120000_abc123.json", + ), +) ``` -The restored crew skips already-completed tasks and resumes from the first incomplete one. +Remaining `CheckpointConfig` fields apply to the new run, so checkpointing continues after the restore. + +You can also use the classmethod directly: + +```python +config = CheckpointConfig(restore_from="./my_checkpoints/20260407T120000_abc123.json") +crew = Crew.from_checkpoint(config) +result = crew.kickoff() +``` + +## Forking from a Checkpoint + +`fork()` restores a checkpoint and starts a new execution branch. Useful for exploring alternative paths from the same point. + +```python +from crewai import Crew, CheckpointConfig + +config = CheckpointConfig(restore_from="./my_checkpoints/20260407T120000_abc123.json") +crew = Crew.fork(config, branch="experiment-a") +result = crew.kickoff(inputs={"strategy": "aggressive"}) +``` + +Each fork gets a unique lineage ID so checkpoints from different branches don't collide. The `branch` label is optional and auto-generated if omitted. ## Works on Crew, Flow, and Agent @@ -125,7 +155,8 @@ flow = MyFlow( result = flow.kickoff() # Resume -flow = MyFlow.from_checkpoint("./flow_cp/20260407T120000_abc123.json") +config = CheckpointConfig(restore_from="./flow_cp/20260407T120000_abc123.json") +flow = MyFlow.from_checkpoint(config) result = flow.kickoff() ``` @@ -231,3 +262,35 @@ async def on_llm_done_async(source, event, state): The `state` argument is the `RuntimeState` passed automatically by the event bus when your handler accepts 3 parameters. You can register handlers on any event type listed in the [Event Listeners](/en/concepts/event-listener) documentation. Checkpointing is best-effort: if a checkpoint write fails, the error is logged but execution continues uninterrupted. + +## CLI + +The `crewai checkpoint` command gives you a TUI for browsing, inspecting, resuming, and forking checkpoints. It auto-detects whether your checkpoints are JSON files or a SQLite database. + +```bash +# Launch the TUI — auto-detects .checkpoints/ or .checkpoints.db +crewai checkpoint + +# Point at a specific location +crewai checkpoint --location ./my_checkpoints +crewai checkpoint --location ./.checkpoints.db +``` + + + Checkpoint TUI + + +The left panel is a tree view. Checkpoints are grouped by branch, and forks nest under the checkpoint they diverged from. Select a checkpoint to see its metadata, entity state, and task progress in the detail panel. Hit **Resume** to pick up where it left off, or **Fork** to start a new branch from that point. + +### Subcommands + +```bash +# List all checkpoints +crewai checkpoint list ./my_checkpoints + +# Inspect a specific checkpoint +crewai checkpoint info ./my_checkpoints/20260407T120000_abc123.json + +# Inspect latest in a SQLite database +crewai checkpoint info ./.checkpoints.db +``` diff --git a/docs/images/checkpointing.png b/docs/images/checkpointing.png new file mode 100644 index 000000000..de1f4776a Binary files /dev/null and b/docs/images/checkpointing.png differ diff --git a/lib/crewai/src/crewai/cli/checkpoint_cli.py b/lib/crewai/src/crewai/cli/checkpoint_cli.py index c61500b20..9a6fcd7ea 100644 --- a/lib/crewai/src/crewai/cli/checkpoint_cli.py +++ b/lib/crewai/src/crewai/cli/checkpoint_cli.py @@ -34,6 +34,25 @@ LIMIT 1 """ +_DEFAULT_DIR = "./.checkpoints" +_DEFAULT_DB = "./.checkpoints.db" + + +def _detect_location(location: str) -> str: + """Resolve the default checkpoint location. + + When the caller passes the default directory path, check whether a + SQLite database exists at the conventional ``.db`` path and prefer it. + """ + if ( + location == _DEFAULT_DIR + and not os.path.exists(_DEFAULT_DIR) + and os.path.exists(_DEFAULT_DB) + ): + return _DEFAULT_DB + return location + + def _is_sqlite(path: str) -> bool: """Check if a file is a SQLite database by reading its magic bytes.""" if not os.path.isfile(path): @@ -86,6 +105,8 @@ def _parse_checkpoint_json(raw: str, source: str) -> dict[str, Any]: "event_count": event_count, "trigger": trigger_event, "entities": parsed_entities, + "branch": data.get("branch", "main"), + "parent_id": data.get("parent_id"), } @@ -311,6 +332,10 @@ def _print_info(meta: dict[str, Any]) -> None: trigger = meta.get("trigger") if trigger: click.echo(f"Trigger: {trigger}") + click.echo(f"Branch: {meta.get('branch', 'main')}") + parent_id = meta.get("parent_id") + if parent_id: + click.echo(f"Parent: {parent_id}") for ent in meta.get("entities", []): eid = str(ent.get("id", ""))[:8] diff --git a/lib/crewai/src/crewai/cli/checkpoint_tui.py b/lib/crewai/src/crewai/cli/checkpoint_tui.py index c58ed2ea6..50ac4195f 100644 --- a/lib/crewai/src/crewai/cli/checkpoint_tui.py +++ b/lib/crewai/src/crewai/cli/checkpoint_tui.py @@ -2,17 +2,22 @@ from __future__ import annotations +from collections import defaultdict from typing import Any, ClassVar from textual.app import App, ComposeResult from textual.binding import Binding -from textual.containers import Horizontal, Vertical -from textual.screen import ModalScreen -from textual.widgets import Button, Footer, Header, OptionList, Static -from textual.widgets.option_list import Option +from textual.containers import Horizontal, Vertical, VerticalScroll +from textual.widgets import ( + Button, + Collapsible, + Footer, + Header, + Static, + Tree, +) from crewai.cli.checkpoint_cli import ( - _entity_summary, _format_size, _is_sqlite, _list_json, @@ -34,151 +39,46 @@ def _load_entries(location: str) -> list[dict[str, Any]]: return _list_json(location) -def _format_list_label(entry: dict[str, Any]) -> str: - """Format a checkpoint entry for the list panel.""" - name = entry.get("name", "") - ts = entry.get("ts") or "" - trigger = entry.get("trigger") or "" - summary = _entity_summary(entry.get("entities", [])) - - line1 = f"[bold]{name}[/]" - parts = [] - if ts: - parts.append(f"[dim]{ts}[/]") - if "size" in entry: - parts.append(f"[dim]{_format_size(entry['size'])}[/]") - if trigger: - parts.append(f"[{_PRIMARY}]{trigger}[/]") - line2 = " ".join(parts) - line3 = f" [{_DIM}]{summary}[/]" - - return f"{line1}\n{line2}\n{line3}" +def _short_id(name: str) -> str: + """Shorten a checkpoint name for tree display.""" + if len(name) > 30: + return name[:27] + "..." + return name -def _format_detail(entry: dict[str, Any]) -> str: - """Format checkpoint details for the right panel.""" +def _build_entity_detail(ent: dict[str, Any]) -> str: + """Build rich text for a single entity.""" lines: list[str] = [] + eid = str(ent.get("id", ""))[:8] + etype = ent.get("type", "unknown") + ename = ent.get("name", "unnamed") + lines.append(f"[bold {_SECONDARY}]{etype}[/]: {ename} [{_DIM}]{eid}[/]") - # Header - name = entry.get("name", "") - lines.append(f"[bold {_PRIMARY}]{name}[/]") - lines.append(f"[{_DIM}]{'─' * 50}[/]") - lines.append("") - - # Metadata table - ts = entry.get("ts") or "unknown" - trigger = entry.get("trigger") or "" - lines.append(f" [bold]Time[/] {ts}") - if "size" in entry: - lines.append(f" [bold]Size[/] {_format_size(entry['size'])}") - lines.append(f" [bold]Events[/] {entry.get('event_count', 0)}") - if trigger: - lines.append(f" [bold]Trigger[/] [{_PRIMARY}]{trigger}[/]") - if "path" in entry: - lines.append(f" [bold]Path[/] [{_DIM}]{entry['path']}[/]") - if "db" in entry: - lines.append(f" [bold]Database[/] [{_DIM}]{entry['db']}[/]") - - # Entities - for ent in entry.get("entities", []): - eid = str(ent.get("id", ""))[:8] - etype = ent.get("type", "unknown") - ename = ent.get("name", "unnamed") - + tasks = ent.get("tasks") + if isinstance(tasks, list): + completed = ent.get("tasks_completed", 0) + total = ent.get("tasks_total", 0) + pct = int(completed / total * 100) if total else 0 + bar_len = 20 + filled = int(bar_len * completed / total) if total else 0 + bar = f"[{_PRIMARY}]{'█' * filled}[/][{_DIM}]{'░' * (bar_len - filled)}[/]" + lines.append(f"{bar} {completed}/{total} tasks ({pct}%)") lines.append("") - lines.append(f" [{_DIM}]{'─' * 50}[/]") - lines.append(f" [bold {_SECONDARY}]{etype}[/]: {ename} [{_DIM}]{eid}[/]") - - tasks = ent.get("tasks") - if isinstance(tasks, list): - completed = ent.get("tasks_completed", 0) - total = ent.get("tasks_total", 0) - pct = int(completed / total * 100) if total else 0 - bar_len = 20 - filled = int(bar_len * completed / total) if total else 0 - bar = f"[{_PRIMARY}]{'█' * filled}[/][{_DIM}]{'░' * (bar_len - filled)}[/]" - - lines.append(f" {bar} {completed}/{total} tasks ({pct}%)") - lines.append("") - - for i, task in enumerate(tasks): - if task.get("completed"): - icon = "[green]✓[/]" - else: - icon = "[yellow]○[/]" - desc = str(task.get("description", "")) - if len(desc) > 55: - desc = desc[:52] + "..." - lines.append(f" {icon} {i + 1}. {desc}") + for i, task in enumerate(tasks): + icon = "[green]✓[/]" if task.get("completed") else "[yellow]○[/]" + desc = str(task.get("description", "")) + if len(desc) > 55: + desc = desc[:52] + "..." + lines.append(f" {icon} {i + 1}. {desc}") return "\n".join(lines) -class ConfirmResumeScreen(ModalScreen[bool]): - """Modal confirmation before resuming from a checkpoint.""" - - CSS = f""" - ConfirmResumeScreen {{ - align: center middle; - }} - #confirm-dialog {{ - width: 60; - height: auto; - padding: 1 2; - background: {_BG_PANEL}; - border: round {_PRIMARY}; - }} - #confirm-label {{ - width: 100%; - content-align: center middle; - margin-bottom: 1; - }} - #confirm-name {{ - width: 100%; - content-align: center middle; - color: {_PRIMARY}; - text-style: bold; - margin-bottom: 1; - }} - #confirm-buttons {{ - width: 100%; - height: 3; - layout: horizontal; - align: center middle; - }} - Button {{ - margin: 0 2; - min-width: 12; - }} - """ - - def __init__(self, checkpoint_name: str) -> None: - super().__init__() - self._checkpoint_name = checkpoint_name - - def compose(self) -> ComposeResult: - with Vertical(id="confirm-dialog"): - yield Static("Resume from this checkpoint?", id="confirm-label") - yield Static(self._checkpoint_name, id="confirm-name") - with Horizontal(id="confirm-buttons"): - yield Button("Resume", variant="success", id="btn-yes") - yield Button("Cancel", variant="default", id="btn-no") - - def on_button_pressed(self, event: Button.Pressed) -> None: - self.dismiss(event.button.id == "btn-yes") - - def on_key(self, event: Any) -> None: - if event.key == "y": - self.dismiss(True) - elif event.key in ("n", "escape"): - self.dismiss(False) - - -class CheckpointTUI(App[str | None]): +class CheckpointTUI(App[tuple[str, str] | None]): """TUI to browse and inspect checkpoints. - Returns the checkpoint location string to resume from, or None if - the user quit without selecting. + Returns ``(location, action)`` where action is ``"resume"`` or + ``"fork"``, or ``None`` if the user quit without selecting. """ TITLE = "CrewAI Checkpoints" @@ -199,145 +99,293 @@ class CheckpointTUI(App[str | None]): background: {_PRIMARY}; color: {_TERTIARY}; }} - Horizontal {{ + #main-layout {{ height: 1fr; }} - #cp-list {{ - width: 38%; + #tree-panel {{ + width: 45%; background: {_BG_PANEL}; border: round {_SECONDARY}; padding: 0 1; scrollbar-color: {_PRIMARY}; }} - #cp-list:focus {{ + #tree-panel:focus-within {{ border: round {_PRIMARY}; }} - #cp-list > .option-list--option-highlighted {{ - background: {_SECONDARY}; - color: {_TERTIARY}; - text-style: none; - }} - #cp-list > .option-list--option-highlighted * {{ - color: {_TERTIARY}; - }} #detail-container {{ - width: 62%; - padding: 0 1; + width: 55%; }} - #detail {{ + #detail-scroll {{ height: 1fr; background: {_BG_PANEL}; border: round {_SECONDARY}; padding: 1 2; - overflow-y: auto; scrollbar-color: {_PRIMARY}; }} - #detail:focus {{ + #detail-scroll:focus-within {{ border: round {_PRIMARY}; }} + #detail-header {{ + margin-bottom: 1; + }} #status {{ height: 1; padding: 0 2; color: {_DIM}; }} + #action-buttons {{ + height: 3; + align: center middle; + padding: 0 1; + display: none; + }} + #action-buttons.visible {{ + display: block; + }} + #action-buttons Button {{ + margin: 0 1; + min-width: 14; + }} + #btn-resume {{ + background: {_SECONDARY}; + color: {_TERTIARY}; + }} + #btn-resume:hover {{ + background: {_PRIMARY}; + }} + #btn-fork {{ + background: {_PRIMARY}; + color: {_TERTIARY}; + }} + #btn-fork:hover {{ + background: {_SECONDARY}; + }} + Collapsible {{ + margin: 0; + padding: 0; + }} + .entity-detail {{ + padding: 0 1; + }} + Tree {{ + background: {_BG_PANEL}; + }} + Tree > .tree--cursor {{ + background: {_SECONDARY}; + color: {_TERTIARY}; + }} """ BINDINGS: ClassVar[list[Binding | tuple[str, str] | tuple[str, str, str]]] = [ ("q", "quit", "Quit"), ("r", "refresh", "Refresh"), - ("j", "cursor_down", "Down"), - ("k", "cursor_up", "Up"), ] def __init__(self, location: str = "./.checkpoints") -> None: super().__init__() self._location = location self._entries: list[dict[str, Any]] = [] - self._selected_idx: int = 0 - self._pending_location: str = "" + self._selected_entry: dict[str, Any] | None = None def compose(self) -> ComposeResult: yield Header(show_clock=False) - with Horizontal(): - yield OptionList(id="cp-list") + with Horizontal(id="main-layout"): + tree: Tree[dict[str, Any]] = Tree("Checkpoints", id="tree-panel") + tree.show_root = True + tree.guide_depth = 3 + yield tree with Vertical(id="detail-container"): yield Static("", id="status") - yield Static( - f"\n [{_DIM}]Select a checkpoint from the list[/]", # noqa: S608 - id="detail", - ) + with VerticalScroll(id="detail-scroll"): + yield Static( + f"[{_DIM}]Select a checkpoint from the tree[/]", # noqa: S608 + id="detail-header", + ) + with Horizontal(id="action-buttons"): + yield Button("Resume", id="btn-resume") + yield Button("Fork", id="btn-fork") yield Footer() async def on_mount(self) -> None: - self.query_one("#cp-list", OptionList).border_title = "Checkpoints" - self.query_one("#detail", Static).border_title = "Detail" - self._refresh_list() + self._refresh_tree() + self.query_one("#tree-panel", Tree).root.expand() - def _refresh_list(self) -> None: + def _refresh_tree(self) -> None: self._entries = _load_entries(self._location) - option_list = self.query_one("#cp-list", OptionList) - option_list.clear_options() + self._selected_entry = None + + tree = self.query_one("#tree-panel", Tree) + tree.clear() if not self._entries: - self.query_one("#detail", Static).update( - f"\n [{_DIM}]No checkpoints in {self._location}[/]" + self.query_one("#detail-header", Static).update( + f"[{_DIM}]No checkpoints in {self._location}[/]" ) self.query_one("#status", Static).update("") self.sub_title = self._location return + # Group by branch + branches: dict[str, list[dict[str, Any]]] = defaultdict(list) for entry in self._entries: - option_list.add_option(Option(_format_list_label(entry))) + branch = entry.get("branch", "main") + branches[branch].append(entry) + + # Index checkpoint names to tree nodes so forks can attach + node_by_name: dict[str, Any] = {} + + def _make_label(entry: dict[str, Any]) -> str: + name = entry.get("name", "") + ts = entry.get("ts") or "" + trigger = entry.get("trigger") or "" + parts = [f"[bold]{_short_id(name)}[/]"] + if ts: + time_part = ts.split(" ")[-1] if " " in ts else ts + parts.append(f"[{_DIM}]{time_part}[/]") + if trigger: + parts.append(f"[{_PRIMARY}]{trigger}[/]") + return " ".join(parts) + + # Find which checkpoints are fork parents so they get expandable nodes + fork_parents: set[str] = set() + for branch_name, entries in branches.items(): + if branch_name == "main": + continue + first_parent = ( + entries[-1].get("parent_id") if entries else None + ) # reversed later; -1 is oldest + if first_parent: + fork_parents.add(str(first_parent)) + + def _add_checkpoint(parent_node: Any, entry: dict[str, Any]) -> None: + """Add a checkpoint node — expandable only if a fork attaches to it.""" + name = entry.get("name", "") + if name in fork_parents: + node = parent_node.add( + _make_label(entry), data=entry, expand=False, allow_expand=True + ) + else: + node = parent_node.add_leaf(_make_label(entry), data=entry) + node_by_name[name] = node + + # Build main branch directly under root (oldest to newest) + if "main" in branches: + for entry in reversed(branches["main"]): + _add_checkpoint(tree.root, entry) + + # Build fork branches — sort so parent forks are built before child forks + fork_branches = [ + (name, list(reversed(entries))) + for name, entries in branches.items() + if name != "main" + ] + # Process forks whose parent is already indexed first + remaining = fork_branches + max_passes = len(remaining) + 1 + while remaining and max_passes > 0: + max_passes -= 1 + deferred = [] + for branch_name, entries in remaining: + first_parent = entries[0].get("parent_id") if entries else None + if first_parent and str(first_parent) not in node_by_name: + deferred.append((branch_name, entries)) + continue + attach_to = ( + node_by_name.get(str(first_parent), tree.root) + if first_parent + else tree.root + ) + branch_label = ( + f"[bold {_SECONDARY}]{branch_name}[/] [{_DIM}]({len(entries)})[/]" + ) + branch_node = attach_to.add(branch_label, expand=False) + for entry in entries: + _add_checkpoint(branch_node, entry) + remaining = deferred count = len(self._entries) storage = "SQLite" if _is_sqlite(self._location) else "JSON" - self.sub_title = f"{self._location}" + self.sub_title = self._location self.query_one("#status", Static).update(f" {count} checkpoint(s) | {storage}") - async def on_option_list_option_highlighted( - self, - event: OptionList.OptionHighlighted, - ) -> None: - idx = event.option_index - if idx is None: - return - if idx < len(self._entries): - self._selected_idx = idx - entry = self._entries[idx] - self.query_one("#detail", Static).update(_format_detail(entry)) + def _show_detail(self, entry: dict[str, Any]) -> None: + """Update the detail panel for a checkpoint entry.""" + self._selected_entry = entry + self.query_one("#action-buttons").add_class("visible") - def action_cursor_down(self) -> None: - self.query_one("#cp-list", OptionList).action_cursor_down() + detail_scroll = self.query_one("#detail-scroll", VerticalScroll) - def action_cursor_up(self) -> None: - self.query_one("#cp-list", OptionList).action_cursor_up() + # Remove old collapsibles + for widget in list(detail_scroll.query("Collapsible")): + widget.remove() - async def on_option_list_option_selected( - self, - event: OptionList.OptionSelected, - ) -> None: - idx = event.option_index - if idx is None or idx >= len(self._entries): - return - entry = self._entries[idx] + # Header + name = entry.get("name", "") + ts = entry.get("ts") or "unknown" + trigger = entry.get("trigger") or "" + branch = entry.get("branch", "main") + parent_id = entry.get("parent_id") + + header_lines = [ + f"[bold {_PRIMARY}]{name}[/]", + f"[{_DIM}]{'─' * 50}[/]", + "", + f" [bold]Time[/] {ts}", + ] + if "size" in entry: + header_lines.append(f" [bold]Size[/] {_format_size(entry['size'])}") + header_lines.append(f" [bold]Events[/] {entry.get('event_count', 0)}") + if trigger: + header_lines.append(f" [bold]Trigger[/] [{_PRIMARY}]{trigger}[/]") + header_lines.append(f" [bold]Branch[/] [{_SECONDARY}]{branch}[/]") + if parent_id: + header_lines.append(f" [bold]Parent[/] [{_DIM}]{parent_id}[/]") if "path" in entry: - loc = entry["path"] - elif _is_sqlite(self._location): - loc = f"{self._location}#{entry['name']}" - else: - loc = entry.get("name", "") - self._pending_location = loc - name = entry.get("name", loc) - self.push_screen(ConfirmResumeScreen(name), self._on_confirm) + header_lines.append(f" [bold]Path[/] [{_DIM}]{entry['path']}[/]") + if "db" in entry: + header_lines.append(f" [bold]Database[/] [{_DIM}]{entry['db']}[/]") - def _on_confirm(self, confirmed: bool | None) -> None: - if confirmed: - self.exit(self._pending_location) - else: - self._pending_location = "" + self.query_one("#detail-header", Static).update("\n".join(header_lines)) + + # Entity collapsibles + for ent in entry.get("entities", []): + etype = ent.get("type", "unknown") + ename = ent.get("name", "unnamed") + completed = ent.get("tasks_completed") + total = ent.get("tasks_total") + title = f"{etype}: {ename}" + if completed is not None and total is not None: + title += f" [{completed}/{total} tasks]" + + content = Static(_build_entity_detail(ent), classes="entity-detail") + collapsible = Collapsible(content, title=title, collapsed=False) + detail_scroll.mount(collapsible) + + def _resolve_location(self, entry: dict[str, Any]) -> str: + """Get the restore location string for a checkpoint entry.""" + if "path" in entry: + return str(entry["path"]) + if _is_sqlite(self._location): + return f"{self._location}#{entry['name']}" + return str(entry.get("name", "")) + + async def on_tree_node_highlighted( + self, event: Tree.NodeHighlighted[dict[str, Any]] + ) -> None: + if event.node.data is not None: + self._show_detail(event.node.data) + + def on_button_pressed(self, event: Button.Pressed) -> None: + if self._selected_entry is None: + return + loc = self._resolve_location(self._selected_entry) + if event.button.id == "btn-resume": + self.exit((loc, "resume")) + elif event.button.id == "btn-fork": + self.exit((loc, "fork")) def action_refresh(self) -> None: - self._refresh_list() + self._refresh_tree() async def _run_checkpoint_tui_async(location: str) -> None: @@ -345,17 +393,25 @@ async def _run_checkpoint_tui_async(location: str) -> None: import click app = CheckpointTUI(location=location) - selected = await app.run_async() + selection = await app.run_async() - if selected is None: + if selection is None: return - click.echo(f"\nResuming from: {selected}\n") + selected, action = selection from crewai.crew import Crew from crewai.state.checkpoint_config import CheckpointConfig - crew = Crew.from_checkpoint(CheckpointConfig(restore_from=selected)) + config = CheckpointConfig(restore_from=selected) + + if action == "fork": + click.echo(f"\nForking from: {selected}\n") + crew = Crew.fork(config) + else: + click.echo(f"\nResuming from: {selected}\n") + crew = Crew.from_checkpoint(config) + result = await crew.akickoff() click.echo(f"\nResult: {getattr(result, 'raw', result)}") diff --git a/lib/crewai/src/crewai/cli/cli.py b/lib/crewai/src/crewai/cli/cli.py index 20a65dbe1..03bdcd502 100644 --- a/lib/crewai/src/crewai/cli/cli.py +++ b/lib/crewai/src/crewai/cli/cli.py @@ -793,6 +793,9 @@ def traces_status() -> None: @click.pass_context def checkpoint(ctx: click.Context, location: str) -> None: """Browse and inspect checkpoints. Launches a TUI when called without a subcommand.""" + from crewai.cli.checkpoint_cli import _detect_location + + location = _detect_location(location) ctx.ensure_object(dict) ctx.obj["location"] = location if ctx.invoked_subcommand is None: @@ -805,18 +808,18 @@ def checkpoint(ctx: click.Context, location: str) -> None: @click.argument("location", default="./.checkpoints") def checkpoint_list(location: str) -> None: """List checkpoints in a directory.""" - from crewai.cli.checkpoint_cli import list_checkpoints + from crewai.cli.checkpoint_cli import _detect_location, list_checkpoints - list_checkpoints(location) + list_checkpoints(_detect_location(location)) @checkpoint.command("info") @click.argument("path", default="./.checkpoints") def checkpoint_info(path: str) -> None: """Show details of a checkpoint. Pass a file or directory for latest.""" - from crewai.cli.checkpoint_cli import info_checkpoint + from crewai.cli.checkpoint_cli import _detect_location, info_checkpoint - info_checkpoint(path) + info_checkpoint(_detect_location(path)) if __name__ == "__main__": diff --git a/lib/crewai/src/crewai/state/checkpoint_config.py b/lib/crewai/src/crewai/state/checkpoint_config.py index bd24ef8eb..e03964c05 100644 --- a/lib/crewai/src/crewai/state/checkpoint_config.py +++ b/lib/crewai/src/crewai/state/checkpoint_config.py @@ -213,6 +213,9 @@ class CheckpointConfig(BaseModel): def _register_handlers(self) -> CheckpointConfig: from crewai.state.checkpoint_listener import _ensure_handlers_registered + if isinstance(self.provider, SqliteProvider) and not Path(self.location).suffix: + self.location = f"{self.location}.db" + _ensure_handlers_registered() return self diff --git a/lib/crewai/src/crewai/state/runtime.py b/lib/crewai/src/crewai/state/runtime.py index 3f32457bb..4dd1a4555 100644 --- a/lib/crewai/src/crewai/state/runtime.py +++ b/lib/crewai/src/crewai/state/runtime.py @@ -10,6 +10,7 @@ via ``RuntimeState.model_rebuild()``. from __future__ import annotations import logging +from pathlib import Path from typing import TYPE_CHECKING, Any import uuid @@ -37,6 +38,19 @@ if TYPE_CHECKING: from crewai import Entity +def _base_location(location: str, provider: BaseProvider) -> str: + """Extract the base storage location from a restore path. + + For SQLite (``db_path#id``), returns ``db_path``. + For JSON (a file path), returns the parent directory. + """ + from crewai.state.provider.sqlite_provider import SqliteProvider + + if isinstance(provider, SqliteProvider): + return location.rsplit("#", 1)[0] + return str(Path(location).parent) + + def _sync_checkpoint_fields(entity: object) -> None: """Copy private runtime attrs into checkpoint fields before serializing. @@ -108,6 +122,7 @@ class RuntimeState(RootModel): # type: ignore[type-arg] _checkpoint_id: str | None = PrivateAttr(default=None) _parent_id: str | None = PrivateAttr(default=None) _branch: str = PrivateAttr(default="main") + _location: str | None = PrivateAttr(default=None) @property def event_record(self) -> EventRecord: @@ -194,7 +209,10 @@ class RuntimeState(RootModel): # type: ignore[type-arg] return result def fork(self, branch: str | None = None) -> None: - """Mark this state as a fork for subsequent checkpoints. + """Create a new execution branch and write an initial checkpoint. + + If this state was restored from a checkpoint, an initial checkpoint + is written on the new branch so the fork point is recorded. Args: branch: Branch label. Auto-generated from the current checkpoint @@ -208,6 +226,9 @@ class RuntimeState(RootModel): # type: ignore[type-arg] else: self._branch = f"fork/{uuid.uuid4().hex[:8]}" + if self._location is not None: + self.checkpoint(self._location) + @classmethod def from_checkpoint(cls, config: CheckpointConfig, **kwargs: Any) -> RuntimeState: """Restore a RuntimeState from a checkpoint. @@ -227,6 +248,8 @@ class RuntimeState(RootModel): # type: ignore[type-arg] provider = detect_provider(location) raw = provider.from_checkpoint(location) state = cls.model_validate_json(raw, **kwargs) + state._provider = provider + state._location = _base_location(location, provider) checkpoint_id = provider.extract_id(location) state._checkpoint_id = checkpoint_id state._parent_id = checkpoint_id @@ -253,6 +276,8 @@ class RuntimeState(RootModel): # type: ignore[type-arg] provider = detect_provider(location) raw = await provider.afrom_checkpoint(location) state = cls.model_validate_json(raw, **kwargs) + state._provider = provider + state._location = _base_location(location, provider) checkpoint_id = provider.extract_id(location) state._checkpoint_id = checkpoint_id state._parent_id = checkpoint_id