Merge branch 'main' into luzk/propagate-is-litellm

This commit is contained in:
Lucas Gomide
2026-04-10 10:33:39 -03:00
committed by GitHub
11 changed files with 652 additions and 242 deletions

View File

@@ -6,12 +6,16 @@ from datetime import datetime
import glob
import json
import os
import re
import sqlite3
from typing import Any
import click
_PLACEHOLDER_RE = re.compile(r"\{([A-Za-z_][A-Za-z0-9_\-]*)}")
_SQLITE_MAGIC = b"SQLite format 3\x00"
_SELECT_ALL = """
@@ -34,6 +38,25 @@ LIMIT 1
"""
_DEFAULT_DIR = "./.checkpoints"
_DEFAULT_DB = "./.checkpoints.db"
def _detect_location(location: str) -> str:
"""Resolve the default checkpoint location.
When the caller passes the default directory path, check whether a
SQLite database exists at the conventional ``.db`` path and prefer it.
"""
if (
location == _DEFAULT_DIR
and not os.path.exists(_DEFAULT_DIR)
and os.path.exists(_DEFAULT_DB)
):
return _DEFAULT_DB
return location
def _is_sqlite(path: str) -> bool:
"""Check if a file is a SQLite database by reading its magic bytes."""
if not os.path.isfile(path):
@@ -52,13 +75,7 @@ def _parse_checkpoint_json(raw: str, source: str) -> dict[str, Any]:
nodes = data.get("event_record", {}).get("nodes", {})
event_count = len(nodes)
trigger_event = None
if nodes:
last_node = max(
nodes.values(),
key=lambda n: n.get("event", {}).get("emission_sequence") or 0,
)
trigger_event = last_node.get("event", {}).get("type")
trigger_event = data.get("trigger")
parsed_entities: list[dict[str, Any]] = []
for entity in entities:
@@ -76,16 +93,47 @@ def _parse_checkpoint_json(raw: str, source: str) -> dict[str, Any]:
{
"description": t.get("description", ""),
"completed": t.get("output") is not None,
"output": (t.get("output") or {}).get("raw", ""),
}
for t in tasks
]
parsed_entities.append(info)
inputs: dict[str, Any] = {}
for entity in entities:
cp_inputs = entity.get("checkpoint_inputs")
if isinstance(cp_inputs, dict) and cp_inputs:
inputs = dict(cp_inputs)
break
for entity in entities:
for task in entity.get("tasks", []):
for field in (
"checkpoint_original_description",
"checkpoint_original_expected_output",
):
text = task.get(field) or ""
for match in _PLACEHOLDER_RE.findall(text):
if match not in inputs:
inputs[match] = ""
for agent in entity.get("agents", []):
for field in ("role", "goal", "backstory"):
text = agent.get(field) or ""
for match in _PLACEHOLDER_RE.findall(text):
if match not in inputs:
inputs[match] = ""
branch = data.get("branch", "main")
parent_id = data.get("parent_id")
return {
"source": source,
"event_count": event_count,
"trigger": trigger_event,
"entities": parsed_entities,
"branch": branch,
"parent_id": parent_id,
"inputs": inputs,
}
@@ -189,6 +237,7 @@ def _list_sqlite(db_path: str) -> list[dict[str, Any]]:
"entities": [],
"source": checkpoint_id,
}
meta["db"] = db_path
results.append(meta)
return results
@@ -311,6 +360,10 @@ def _print_info(meta: dict[str, Any]) -> None:
trigger = meta.get("trigger")
if trigger:
click.echo(f"Trigger: {trigger}")
click.echo(f"Branch: {meta.get('branch', 'main')}")
parent_id = meta.get("parent_id")
if parent_id:
click.echo(f"Parent: {parent_id}")
for ent in meta.get("entities", []):
eid = str(ent.get("id", ""))[:8]

View File

@@ -2,17 +2,23 @@
from __future__ import annotations
from collections import defaultdict
from typing import Any, ClassVar
from textual.app import App, ComposeResult
from textual.binding import Binding
from textual.containers import Horizontal, Vertical
from textual.screen import ModalScreen
from textual.widgets import Button, Footer, Header, OptionList, Static
from textual.widgets.option_list import Option
from textual.containers import Horizontal, Vertical, VerticalScroll
from textual.widgets import (
Button,
Footer,
Header,
Input,
Static,
TextArea,
Tree,
)
from crewai.cli.checkpoint_cli import (
_entity_summary,
_format_size,
_is_sqlite,
_list_json,
@@ -34,151 +40,54 @@ def _load_entries(location: str) -> list[dict[str, Any]]:
return _list_json(location)
def _format_list_label(entry: dict[str, Any]) -> str:
"""Format a checkpoint entry for the list panel."""
name = entry.get("name", "")
ts = entry.get("ts") or ""
trigger = entry.get("trigger") or ""
summary = _entity_summary(entry.get("entities", []))
line1 = f"[bold]{name}[/]"
parts = []
if ts:
parts.append(f"[dim]{ts}[/]")
if "size" in entry:
parts.append(f"[dim]{_format_size(entry['size'])}[/]")
if trigger:
parts.append(f"[{_PRIMARY}]{trigger}[/]")
line2 = " ".join(parts)
line3 = f" [{_DIM}]{summary}[/]"
return f"{line1}\n{line2}\n{line3}"
def _short_id(name: str) -> str:
"""Shorten a checkpoint name for tree display."""
if len(name) > 30:
return name[:27] + "..."
return name
def _format_detail(entry: dict[str, Any]) -> str:
"""Format checkpoint details for the right panel."""
def _entry_id(entry: dict[str, Any]) -> str:
"""Normalize an entry's name into its checkpoint ID.
JSON filenames are ``{ts}_{uuid}_p-{parent}.json``; SQLite IDs
are already ``{ts}_{uuid}``. This strips the JSON suffix so
fork-parent lookups work in both providers.
"""
name = str(entry.get("name", ""))
if name.endswith(".json"):
name = name[: -len(".json")]
idx = name.find("_p-")
if idx != -1:
name = name[:idx]
return name
def _build_entity_header(ent: dict[str, Any]) -> str:
"""Build rich text header for an entity (progress bar only)."""
lines: list[str] = []
# Header
name = entry.get("name", "")
lines.append(f"[bold {_PRIMARY}]{name}[/]")
lines.append(f"[{_DIM}]{'' * 50}[/]")
lines.append("")
# Metadata table
ts = entry.get("ts") or "unknown"
trigger = entry.get("trigger") or ""
lines.append(f" [bold]Time[/] {ts}")
if "size" in entry:
lines.append(f" [bold]Size[/] {_format_size(entry['size'])}")
lines.append(f" [bold]Events[/] {entry.get('event_count', 0)}")
if trigger:
lines.append(f" [bold]Trigger[/] [{_PRIMARY}]{trigger}[/]")
if "path" in entry:
lines.append(f" [bold]Path[/] [{_DIM}]{entry['path']}[/]")
if "db" in entry:
lines.append(f" [bold]Database[/] [{_DIM}]{entry['db']}[/]")
# Entities
for ent in entry.get("entities", []):
eid = str(ent.get("id", ""))[:8]
etype = ent.get("type", "unknown")
ename = ent.get("name", "unnamed")
lines.append("")
lines.append(f" [{_DIM}]{'' * 50}[/]")
lines.append(f" [bold {_SECONDARY}]{etype}[/]: {ename} [{_DIM}]{eid}[/]")
tasks = ent.get("tasks")
if isinstance(tasks, list):
completed = ent.get("tasks_completed", 0)
total = ent.get("tasks_total", 0)
pct = int(completed / total * 100) if total else 0
bar_len = 20
filled = int(bar_len * completed / total) if total else 0
bar = f"[{_PRIMARY}]{'' * filled}[/][{_DIM}]{'' * (bar_len - filled)}[/]"
lines.append(f" {bar} {completed}/{total} tasks ({pct}%)")
lines.append("")
for i, task in enumerate(tasks):
if task.get("completed"):
icon = "[green]✓[/]"
else:
icon = "[yellow]○[/]"
desc = str(task.get("description", ""))
if len(desc) > 55:
desc = desc[:52] + "..."
lines.append(f" {icon} {i + 1}. {desc}")
tasks = ent.get("tasks")
if isinstance(tasks, list):
completed = ent.get("tasks_completed", 0)
total = ent.get("tasks_total", 0)
pct = int(completed / total * 100) if total else 0
bar_len = 20
filled = int(bar_len * completed / total) if total else 0
bar = f"[{_PRIMARY}]{'' * filled}[/][{_DIM}]{'' * (bar_len - filled)}[/]"
lines.append(f"{bar} {completed}/{total} tasks ({pct}%)")
return "\n".join(lines)
class ConfirmResumeScreen(ModalScreen[bool]):
"""Modal confirmation before resuming from a checkpoint."""
CSS = f"""
ConfirmResumeScreen {{
align: center middle;
}}
#confirm-dialog {{
width: 60;
height: auto;
padding: 1 2;
background: {_BG_PANEL};
border: round {_PRIMARY};
}}
#confirm-label {{
width: 100%;
content-align: center middle;
margin-bottom: 1;
}}
#confirm-name {{
width: 100%;
content-align: center middle;
color: {_PRIMARY};
text-style: bold;
margin-bottom: 1;
}}
#confirm-buttons {{
width: 100%;
height: 3;
layout: horizontal;
align: center middle;
}}
Button {{
margin: 0 2;
min-width: 12;
}}
"""
def __init__(self, checkpoint_name: str) -> None:
super().__init__()
self._checkpoint_name = checkpoint_name
def compose(self) -> ComposeResult:
with Vertical(id="confirm-dialog"):
yield Static("Resume from this checkpoint?", id="confirm-label")
yield Static(self._checkpoint_name, id="confirm-name")
with Horizontal(id="confirm-buttons"):
yield Button("Resume", variant="success", id="btn-yes")
yield Button("Cancel", variant="default", id="btn-no")
def on_button_pressed(self, event: Button.Pressed) -> None:
self.dismiss(event.button.id == "btn-yes")
def on_key(self, event: Any) -> None:
if event.key == "y":
self.dismiss(True)
elif event.key in ("n", "escape"):
self.dismiss(False)
# Return type: (location, action, inputs, task_output_overrides)
_TuiResult = tuple[str, str, dict[str, Any] | None, dict[int, str] | None] | None
class CheckpointTUI(App[str | None]):
class CheckpointTUI(App[_TuiResult]):
"""TUI to browse and inspect checkpoints.
Returns the checkpoint location string to resume from, or None if
the user quit without selecting.
Returns ``(location, action, inputs)`` where action is ``"resume"`` or
``"fork"`` and inputs is a parsed dict or ``None``,
or ``None`` if the user quit without selecting.
"""
TITLE = "CrewAI Checkpoints"
@@ -199,145 +108,431 @@ class CheckpointTUI(App[str | None]):
background: {_PRIMARY};
color: {_TERTIARY};
}}
Horizontal {{
#main-layout {{
height: 1fr;
}}
#cp-list {{
width: 38%;
#tree-panel {{
width: 45%;
background: {_BG_PANEL};
border: round {_SECONDARY};
padding: 0 1;
scrollbar-color: {_PRIMARY};
}}
#cp-list:focus {{
#tree-panel:focus-within {{
border: round {_PRIMARY};
}}
#cp-list > .option-list--option-highlighted {{
background: {_SECONDARY};
color: {_TERTIARY};
text-style: none;
}}
#cp-list > .option-list--option-highlighted * {{
color: {_TERTIARY};
}}
#detail-container {{
width: 62%;
padding: 0 1;
width: 55%;
height: 1fr;
}}
#detail {{
#detail-scroll {{
height: 1fr;
background: {_BG_PANEL};
border: round {_SECONDARY};
padding: 1 2;
overflow-y: auto;
scrollbar-color: {_PRIMARY};
}}
#detail:focus {{
#detail-scroll:focus-within {{
border: round {_PRIMARY};
}}
#detail-header {{
margin-bottom: 1;
}}
#status {{
height: 1;
padding: 0 2;
color: {_DIM};
}}
#inputs-section {{
display: none;
height: auto;
max-height: 8;
padding: 0 1;
}}
#inputs-section.visible {{
display: block;
}}
#inputs-label {{
height: 1;
color: {_DIM};
padding: 0 1;
}}
.input-row {{
height: 3;
padding: 0 1;
}}
.input-row Static {{
width: auto;
min-width: 12;
padding: 1 1 0 0;
color: {_TERTIARY};
}}
.input-row Input {{
width: 1fr;
}}
#no-inputs-label {{
height: 1;
color: {_DIM};
padding: 0 1;
}}
#action-buttons {{
height: 3;
align: right middle;
padding: 0 1;
display: none;
}}
#action-buttons.visible {{
display: block;
}}
#action-buttons Button {{
margin: 0 0 0 1;
min-width: 10;
}}
#btn-resume {{
background: {_SECONDARY};
color: {_TERTIARY};
}}
#btn-resume:hover {{
background: {_PRIMARY};
}}
#btn-fork {{
background: {_PRIMARY};
color: {_TERTIARY};
}}
#btn-fork:hover {{
background: {_SECONDARY};
}}
.entity-title {{
padding: 1 1 0 1;
}}
.entity-detail {{
padding: 0 1;
}}
.task-output-editor {{
height: auto;
max-height: 10;
margin: 0 1 1 1;
border: round {_DIM};
}}
.task-output-editor:focus {{
border: round {_PRIMARY};
}}
.task-label {{
padding: 0 1;
}}
Tree {{
background: {_BG_PANEL};
}}
Tree > .tree--cursor {{
background: {_SECONDARY};
color: {_TERTIARY};
}}
"""
BINDINGS: ClassVar[list[Binding | tuple[str, str] | tuple[str, str, str]]] = [
("q", "quit", "Quit"),
("r", "refresh", "Refresh"),
("j", "cursor_down", "Down"),
("k", "cursor_up", "Up"),
]
def __init__(self, location: str = "./.checkpoints") -> None:
super().__init__()
self._location = location
self._entries: list[dict[str, Any]] = []
self._selected_idx: int = 0
self._pending_location: str = ""
self._selected_entry: dict[str, Any] | None = None
self._input_keys: list[str] = []
self._task_output_ids: list[tuple[int, str, str]] = []
def compose(self) -> ComposeResult:
yield Header(show_clock=False)
with Horizontal():
yield OptionList(id="cp-list")
with Horizontal(id="main-layout"):
tree: Tree[dict[str, Any]] = Tree("Checkpoints", id="tree-panel")
tree.show_root = True
tree.guide_depth = 3
yield tree
with Vertical(id="detail-container"):
yield Static("", id="status")
yield Static(
f"\n [{_DIM}]Select a checkpoint from the list[/]", # noqa: S608
id="detail",
)
with VerticalScroll(id="detail-scroll"):
yield Static(
f"[{_DIM}]Select a checkpoint from the tree[/]", # noqa: S608
id="detail-header",
)
with Vertical(id="inputs-section"):
yield Static("Inputs", id="inputs-label")
with Horizontal(id="action-buttons"):
yield Button("Resume", id="btn-resume")
yield Button("Fork", id="btn-fork")
yield Footer()
async def on_mount(self) -> None:
self.query_one("#cp-list", OptionList).border_title = "Checkpoints"
self.query_one("#detail", Static).border_title = "Detail"
self._refresh_list()
self._refresh_tree()
self.query_one("#tree-panel", Tree).root.expand()
def _refresh_list(self) -> None:
def _refresh_tree(self) -> None:
self._entries = _load_entries(self._location)
option_list = self.query_one("#cp-list", OptionList)
option_list.clear_options()
self._selected_entry = None
tree = self.query_one("#tree-panel", Tree)
tree.clear()
if not self._entries:
self.query_one("#detail", Static).update(
f"\n [{_DIM}]No checkpoints in {self._location}[/]"
self.query_one("#detail-header", Static).update(
f"[{_DIM}]No checkpoints in {self._location}[/]"
)
self.query_one("#status", Static).update("")
self.sub_title = self._location
return
# Group by branch
branches: dict[str, list[dict[str, Any]]] = defaultdict(list)
for entry in self._entries:
option_list.add_option(Option(_format_list_label(entry)))
branch = entry.get("branch", "main")
branches[branch].append(entry)
# Index checkpoint names to tree nodes so forks can attach
node_by_name: dict[str, Any] = {}
def _make_label(e: dict[str, Any]) -> str:
name = e.get("name", "")
ts = e.get("ts") or ""
trigger = e.get("trigger") or ""
parts = [f"[bold]{_short_id(name)}[/]"]
if ts:
time_part = ts.split(" ")[-1] if " " in ts else ts
parts.append(f"[{_DIM}]{time_part}[/]")
if trigger:
parts.append(f"[{_PRIMARY}]{trigger}[/]")
return " ".join(parts)
fork_parents: set[str] = set()
for branch_name, entries in branches.items():
if branch_name == "main" or not entries:
continue
oldest = min(entries, key=lambda e: str(e.get("name", "")))
first_parent = oldest.get("parent_id")
if first_parent:
fork_parents.add(str(first_parent))
def _add_checkpoint(parent_node: Any, e: dict[str, Any]) -> None:
"""Add a checkpoint node — expandable only if a fork attaches to it."""
cp_id = _entry_id(e)
if cp_id in fork_parents:
node = parent_node.add(
_make_label(e), data=e, expand=False, allow_expand=True
)
else:
node = parent_node.add_leaf(_make_label(e), data=e)
node_by_name[cp_id] = node
if "main" in branches:
for entry in reversed(branches["main"]):
_add_checkpoint(tree.root, entry)
fork_branches = [
(name, sorted(entries, key=lambda e: str(e.get("name", ""))))
for name, entries in branches.items()
if name != "main"
]
remaining = fork_branches
max_passes = len(remaining) + 1
while remaining and max_passes > 0:
max_passes -= 1
deferred = []
made_progress = False
for branch_name, entries in remaining:
first_parent = entries[0].get("parent_id") if entries else None
if first_parent and str(first_parent) not in node_by_name:
deferred.append((branch_name, entries))
continue
attach_to: Any = tree.root
if first_parent:
attach_to = node_by_name.get(str(first_parent), tree.root)
branch_label = (
f"[bold {_SECONDARY}]{branch_name}[/] [{_DIM}]({len(entries)})[/]"
)
branch_node = attach_to.add(branch_label, expand=False)
for entry in entries:
_add_checkpoint(branch_node, entry)
made_progress = True
remaining = deferred
if not made_progress:
break
for branch_name, entries in remaining:
branch_label = (
f"[bold {_SECONDARY}]{branch_name}[/] "
f"[{_DIM}]({len(entries)})[/] [{_DIM}](orphaned)[/]"
)
branch_node = tree.root.add(branch_label, expand=False)
for entry in entries:
_add_checkpoint(branch_node, entry)
count = len(self._entries)
storage = "SQLite" if _is_sqlite(self._location) else "JSON"
self.sub_title = f"{self._location}"
self.sub_title = self._location
self.query_one("#status", Static).update(f" {count} checkpoint(s) | {storage}")
async def on_option_list_option_highlighted(
self,
event: OptionList.OptionHighlighted,
) -> None:
idx = event.option_index
if idx is None:
return
if idx < len(self._entries):
self._selected_idx = idx
entry = self._entries[idx]
self.query_one("#detail", Static).update(_format_detail(entry))
async def _show_detail(self, entry: dict[str, Any]) -> None:
"""Update the detail panel for a checkpoint entry."""
self._selected_entry = entry
self.query_one("#action-buttons").add_class("visible")
def action_cursor_down(self) -> None:
self.query_one("#cp-list", OptionList).action_cursor_down()
detail_scroll = self.query_one("#detail-scroll", VerticalScroll)
def action_cursor_up(self) -> None:
self.query_one("#cp-list", OptionList).action_cursor_up()
# Remove all dynamic children except the header — await so IDs are freed
to_remove = [c for c in detail_scroll.children if c.id != "detail-header"]
for child in to_remove:
await child.remove()
async def on_option_list_option_selected(
self,
event: OptionList.OptionSelected,
) -> None:
idx = event.option_index
if idx is None or idx >= len(self._entries):
return
entry = self._entries[idx]
# Header
name = entry.get("name", "")
ts = entry.get("ts") or "unknown"
trigger = entry.get("trigger") or ""
branch = entry.get("branch", "main")
parent_id = entry.get("parent_id")
header_lines = [
f"[bold {_PRIMARY}]{name}[/]",
f"[{_DIM}]{'' * 50}[/]",
"",
f" [bold]Time[/] {ts}",
]
if "size" in entry:
header_lines.append(f" [bold]Size[/] {_format_size(entry['size'])}")
header_lines.append(f" [bold]Events[/] {entry.get('event_count', 0)}")
if trigger:
header_lines.append(f" [bold]Trigger[/] [{_PRIMARY}]{trigger}[/]")
header_lines.append(f" [bold]Branch[/] [{_SECONDARY}]{branch}[/]")
if parent_id:
header_lines.append(f" [bold]Parent[/] [{_DIM}]{parent_id}[/]")
if "path" in entry:
loc = entry["path"]
elif _is_sqlite(self._location):
loc = f"{self._location}#{entry['name']}"
else:
loc = entry.get("name", "")
self._pending_location = loc
name = entry.get("name", loc)
self.push_screen(ConfirmResumeScreen(name), self._on_confirm)
header_lines.append(f" [bold]Path[/] [{_DIM}]{entry['path']}[/]")
if "db" in entry:
header_lines.append(f" [bold]Database[/] [{_DIM}]{entry['db']}[/]")
def _on_confirm(self, confirmed: bool | None) -> None:
if confirmed:
self.exit(self._pending_location)
else:
self._pending_location = ""
self.query_one("#detail-header", Static).update("\n".join(header_lines))
# Entity details and editable task outputs — mounted flat for scrolling
self._task_output_ids = []
flat_task_idx = 0
for ent_idx, ent in enumerate(entry.get("entities", [])):
etype = ent.get("type", "unknown")
ename = ent.get("name", "unnamed")
completed = ent.get("tasks_completed")
total = ent.get("tasks_total")
entity_title = f"[bold {_SECONDARY}]{etype}: {ename}[/]"
if completed is not None and total is not None:
entity_title += f" [{_DIM}]{completed}/{total} tasks[/]"
await detail_scroll.mount(Static(entity_title, classes="entity-title"))
await detail_scroll.mount(
Static(_build_entity_header(ent), classes="entity-detail")
)
tasks = ent.get("tasks", [])
for i, task in enumerate(tasks):
desc = str(task.get("description", ""))
if len(desc) > 55:
desc = desc[:52] + "..."
if task.get("completed"):
icon = "[green]✓[/]"
await detail_scroll.mount(
Static(f" {icon} {i + 1}. {desc}", classes="task-label")
)
output_text = task.get("output", "")
editor_id = f"task-output-{ent_idx}-{i}"
await detail_scroll.mount(
TextArea(
str(output_text),
classes="task-output-editor",
id=editor_id,
)
)
self._task_output_ids.append(
(flat_task_idx, editor_id, str(output_text))
)
else:
icon = "[yellow]○[/]"
await detail_scroll.mount(
Static(f" {icon} {i + 1}. {desc}", classes="task-label")
)
flat_task_idx += 1
# Build input fields
await self._build_input_fields(entry.get("inputs", {}))
async def _build_input_fields(self, inputs: dict[str, Any]) -> None:
"""Rebuild the inputs section with one field per input key."""
section = self.query_one("#inputs-section")
# Remove old dynamic children — await so IDs are freed
for widget in list(section.query(".input-row, .no-inputs")):
await widget.remove()
self._input_keys = []
if not inputs:
await section.mount(Static(f"[{_DIM}]No inputs[/]", classes="no-inputs"))
section.add_class("visible")
return
for key, value in inputs.items():
self._input_keys.append(key)
row = Horizontal(classes="input-row")
row.compose_add_child(Static(f"[bold]{key}[/]"))
row.compose_add_child(
Input(value=str(value), placeholder=key, id=f"input-{key}")
)
await section.mount(row)
section.add_class("visible")
def _collect_inputs(self) -> dict[str, Any] | None:
"""Collect current values from input fields."""
if not self._input_keys:
return None
result: dict[str, Any] = {}
for key in self._input_keys:
widget = self.query_one(f"#input-{key}", Input)
result[key] = widget.value
return result
def _collect_task_overrides(self) -> dict[int, str] | None:
"""Collect edited task outputs. Returns only changed values."""
if not self._task_output_ids or self._selected_entry is None:
return None
overrides: dict[int, str] = {}
for task_idx, editor_id, original in self._task_output_ids:
editor = self.query_one(f"#{editor_id}", TextArea)
if editor.text != original:
overrides[task_idx] = editor.text
return overrides or None
def _resolve_location(self, entry: dict[str, Any]) -> str:
"""Get the restore location string for a checkpoint entry."""
if "path" in entry:
return str(entry["path"])
if _is_sqlite(self._location):
return f"{self._location}#{entry['name']}"
return str(entry.get("name", ""))
async def on_tree_node_highlighted(
self, event: Tree.NodeHighlighted[dict[str, Any]]
) -> None:
if event.node.data is not None:
await self._show_detail(event.node.data)
def on_button_pressed(self, event: Button.Pressed) -> None:
if self._selected_entry is None:
return
inputs = self._collect_inputs()
overrides = self._collect_task_overrides()
loc = self._resolve_location(self._selected_entry)
if event.button.id == "btn-resume":
self.exit((loc, "resume", inputs, overrides))
elif event.button.id == "btn-fork":
self.exit((loc, "fork", inputs, overrides))
def action_refresh(self) -> None:
self._refresh_list()
self._refresh_tree()
async def _run_checkpoint_tui_async(location: str) -> None:
@@ -345,18 +540,78 @@ async def _run_checkpoint_tui_async(location: str) -> None:
import click
app = CheckpointTUI(location=location)
selected = await app.run_async()
selection = await app.run_async()
if selected is None:
if selection is None:
return
click.echo(f"\nResuming from: {selected}\n")
selected, action, inputs, task_overrides = selection
from crewai.crew import Crew
from crewai.state.checkpoint_config import CheckpointConfig
crew = Crew.from_checkpoint(CheckpointConfig(restore_from=selected))
result = await crew.akickoff()
config = CheckpointConfig(restore_from=selected)
if action == "fork":
click.echo(f"\nForking from: {selected}\n")
crew = Crew.fork(config)
else:
click.echo(f"\nResuming from: {selected}\n")
crew = Crew.from_checkpoint(config)
if task_overrides:
click.echo("Modifications:")
overridden_agents: set[int] = set()
for task_idx, new_output in task_overrides.items():
if task_idx < len(crew.tasks) and crew.tasks[task_idx].output is not None:
desc = crew.tasks[task_idx].description or f"Task {task_idx + 1}"
if len(desc) > 60:
desc = desc[:57] + "..."
crew.tasks[task_idx].output.raw = new_output # type: ignore[union-attr]
preview = new_output.replace("\n", " ")
if len(preview) > 80:
preview = preview[:77] + "..."
click.echo(f" Task {task_idx + 1}: {desc}")
click.echo(f" -> {preview}")
agent = crew.tasks[task_idx].agent
if agent and agent.agent_executor:
nth = sum(1 for t in crew.tasks[:task_idx] if t.agent is agent)
messages = agent.agent_executor.messages
system_positions = [
i for i, m in enumerate(messages) if m.get("role") == "system"
]
if nth < len(system_positions):
seg_start = system_positions[nth]
seg_end = (
system_positions[nth + 1]
if nth + 1 < len(system_positions)
else len(messages)
)
for j in range(seg_end - 1, seg_start, -1):
if messages[j].get("role") == "assistant":
messages[j]["content"] = new_output
break
overridden_agents.add(id(agent))
earliest = min(task_overrides)
for offset, subsequent in enumerate(
crew.tasks[earliest + 1 :], start=earliest + 1
):
if subsequent.output and offset not in task_overrides:
subsequent.output = None
if subsequent.agent and subsequent.agent.agent_executor:
subsequent.agent.agent_executor._resuming = False
if id(subsequent.agent) not in overridden_agents:
subsequent.agent.agent_executor.messages = []
click.echo()
if inputs:
click.echo("Inputs:")
for k, v in inputs.items():
click.echo(f" {k}: {v}")
click.echo()
result = await crew.akickoff(inputs=inputs)
click.echo(f"\nResult: {getattr(result, 'raw', result)}")

View File

@@ -793,6 +793,9 @@ def traces_status() -> None:
@click.pass_context
def checkpoint(ctx: click.Context, location: str) -> None:
"""Browse and inspect checkpoints. Launches a TUI when called without a subcommand."""
from crewai.cli.checkpoint_cli import _detect_location
location = _detect_location(location)
ctx.ensure_object(dict)
ctx.obj["location"] = location
if ctx.invoked_subcommand is None:
@@ -805,18 +808,18 @@ def checkpoint(ctx: click.Context, location: str) -> None:
@click.argument("location", default="./.checkpoints")
def checkpoint_list(location: str) -> None:
"""List checkpoints in a directory."""
from crewai.cli.checkpoint_cli import list_checkpoints
from crewai.cli.checkpoint_cli import _detect_location, list_checkpoints
list_checkpoints(location)
list_checkpoints(_detect_location(location))
@checkpoint.command("info")
@click.argument("path", default="./.checkpoints")
def checkpoint_info(path: str) -> None:
"""Show details of a checkpoint. Pass a file or directory for latest."""
from crewai.cli.checkpoint_cli import info_checkpoint
from crewai.cli.checkpoint_cli import _detect_location, info_checkpoint
info_checkpoint(path)
info_checkpoint(_detect_location(path))
if __name__ == "__main__":

View File

@@ -436,6 +436,13 @@ class Crew(FlowTrackable, BaseModel):
if agent.agent_executor is not None and task.output is None:
agent.agent_executor.task = task
break
for task in self.tasks:
if task.checkpoint_original_description is not None:
task._original_description = task.checkpoint_original_description
if task.checkpoint_original_expected_output is not None:
task._original_expected_output = (
task.checkpoint_original_expected_output
)
if self.checkpoint_inputs is not None:
self._inputs = self.checkpoint_inputs
if self.checkpoint_kickoff_event_id is not None:

View File

@@ -213,6 +213,9 @@ class CheckpointConfig(BaseModel):
def _register_handlers(self) -> CheckpointConfig:
from crewai.state.checkpoint_listener import _ensure_handlers_registered
if isinstance(self.provider, SqliteProvider) and not Path(self.location).suffix:
self.location = f"{self.location}.db"
_ensure_handlers_registered()
return self

View File

@@ -7,6 +7,7 @@ avoids per-event overhead when no entity uses checkpointing.
from __future__ import annotations
import json
import logging
import threading
from typing import Any
@@ -102,10 +103,15 @@ def _find_checkpoint(source: Any) -> CheckpointConfig | None:
return None
def _do_checkpoint(state: RuntimeState, cfg: CheckpointConfig) -> None:
def _do_checkpoint(
state: RuntimeState, cfg: CheckpointConfig, event: BaseEvent | None = None
) -> None:
"""Write a checkpoint and prune old ones if configured."""
_prepare_entities(state.root)
data = state.model_dump_json()
payload = state.model_dump(mode="json")
if event is not None:
payload["trigger"] = event.type
data = json.dumps(payload)
location = cfg.provider.checkpoint(
data,
cfg.location,
@@ -134,7 +140,7 @@ def _on_any_event(source: Any, event: BaseEvent, state: Any) -> None:
if cfg is None:
return
try:
_do_checkpoint(state, cfg)
_do_checkpoint(state, cfg, event)
except Exception:
logger.warning("Auto-checkpoint failed for event %s", event.type, exc_info=True)

View File

@@ -66,6 +66,9 @@ def _sync_checkpoint_fields(entity: object) -> None:
entity.checkpoint_inputs = entity._inputs
entity.checkpoint_train = entity._train
entity.checkpoint_kickoff_event_id = entity._kickoff_event_id
for task in entity.tasks:
task.checkpoint_original_description = task._original_description
task.checkpoint_original_expected_output = task._original_expected_output
def _migrate(data: dict[str, Any]) -> dict[str, Any]:
@@ -121,7 +124,7 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
"parent_id": self._parent_id,
"branch": self._branch,
"entities": [e.model_dump(mode="json") for e in self.root],
"event_record": self._event_record.model_dump(),
"event_record": self._event_record.model_dump(mode="json"),
}
@model_validator(mode="wrap")
@@ -194,7 +197,10 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
return result
def fork(self, branch: str | None = None) -> None:
"""Mark this state as a fork for subsequent checkpoints.
"""Create a new execution branch and write an initial checkpoint.
If this state was restored from a checkpoint, an initial checkpoint
is written on the new branch so the fork point is recorded.
Args:
branch: Branch label. Auto-generated from the current checkpoint
@@ -204,7 +210,7 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
if branch:
self._branch = branch
elif self._checkpoint_id:
self._branch = f"fork/{self._checkpoint_id}"
self._branch = f"fork/{self._checkpoint_id}_{uuid.uuid4().hex[:6]}"
else:
self._branch = f"fork/{uuid.uuid4().hex[:8]}"
@@ -227,6 +233,7 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
provider = detect_provider(location)
raw = provider.from_checkpoint(location)
state = cls.model_validate_json(raw, **kwargs)
state._provider = provider
checkpoint_id = provider.extract_id(location)
state._checkpoint_id = checkpoint_id
state._parent_id = checkpoint_id
@@ -253,6 +260,7 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
provider = detect_provider(location)
raw = await provider.afrom_checkpoint(location)
state = cls.model_validate_json(raw, **kwargs)
state._provider = provider
checkpoint_id = provider.extract_id(location)
state._checkpoint_id = checkpoint_id
state._parent_id = checkpoint_id

View File

@@ -230,6 +230,8 @@ class Task(BaseModel):
_original_description: str | None = PrivateAttr(default=None)
_original_expected_output: str | None = PrivateAttr(default=None)
_original_output_file: str | None = PrivateAttr(default=None)
checkpoint_original_description: str | None = Field(default=None, exclude=False)
checkpoint_original_expected_output: str | None = Field(default=None, exclude=False)
_thread: threading.Thread | None = PrivateAttr(default=None)
model_config = {"arbitrary_types_allowed": True}

View File

@@ -296,7 +296,8 @@ class TestRuntimeStateLineage:
state = self._make_state()
state._checkpoint_id = "20260409T120000_abc12345"
state.fork()
assert state._branch == "fork/20260409T120000_abc12345"
assert state._branch.startswith("fork/20260409T120000_abc12345_")
assert len(state._branch) == len("fork/20260409T120000_abc12345_") + 6
def test_fork_no_checkpoint_id_unique(self) -> None:
state = self._make_state()