"""Textual TUI for browsing checkpoint files.""" from __future__ import annotations from collections import defaultdict from datetime import datetime from typing import Any, ClassVar, Literal from textual.app import App, ComposeResult from textual.binding import Binding from textual.containers import Horizontal, Vertical, VerticalScroll from textual.widgets import ( Collapsible, Footer, Header, Input, Static, TabPane, TabbedContent, TextArea, Tree, ) from crewai_cli.checkpoint_cli import ( _format_size, _is_sqlite, _list_json, _list_sqlite, ) _PRIMARY = "#eb6658" _SECONDARY = "#1F7982" _TERTIARY = "#ffffff" _DIM = "#888888" _BG_DARK = "#0d1117" _BG_PANEL = "#161b22" _ACCENT = "#c9a227" _SUCCESS = "#3fb950" _PENDING = "#e3b341" _ENTITY_ICONS: dict[str, str] = { "flow": "◆", "crew": "●", "agent": "◈", "unknown": "○", } _ENTITY_COLORS: dict[str, str] = { "flow": _ACCENT, "crew": _SECONDARY, "agent": _PRIMARY, "unknown": _DIM, } def _load_entries(location: str) -> list[dict[str, Any]]: if _is_sqlite(location): return _list_sqlite(location) return _list_json(location) def _human_ts(ts: str) -> str: """Turn '2026-04-17 17:05:00' into a short relative label.""" try: dt = datetime.strptime(ts, "%Y-%m-%d %H:%M:%S") except ValueError: return ts now = datetime.now() delta = now.date() - dt.date() hour = dt.hour % 12 or 12 ampm = "am" if dt.hour < 12 else "pm" time_str = f"{hour}:{dt.minute:02d}{ampm}" if delta.days == 0: return time_str if delta.days == 1: return f"yest {time_str}" if delta.days < 7: return f"{dt.strftime('%a').lower()} {time_str}" return f"{dt.strftime('%b')} {dt.day}" def _short_id(name: str) -> str: if len(name) > 30: return name[:27] + "..." return name def _entry_id(entry: dict[str, Any]) -> str: """Normalize an entry's name into its checkpoint ID. JSON filenames are ``{ts}_{uuid}_p-{parent}.json``; SQLite IDs are already ``{ts}_{uuid}``. This strips the JSON suffix so fork-parent lookups work in both providers. """ name = str(entry.get("name", "")) if name.endswith(".json"): name = name[: -len(".json")] idx = name.find("_p-") if idx != -1: name = name[:idx] return name def _build_progress_bar(completed: int, total: int, width: int = 20) -> str: if total == 0: return f"[{_DIM}]{'░' * width}[/] 0/0" pct = int(completed / total * 100) filled = int(width * completed / total) color = _SUCCESS if completed == total else _PRIMARY bar = f"[{color}]{'█' * filled}[/][{_DIM}]{'░' * (width - filled)}[/]" return f"{bar} {completed}/{total} ({pct}%)" def _entity_icon(etype: str) -> str: icon = _ENTITY_ICONS.get(etype, _ENTITY_ICONS["unknown"]) color = _ENTITY_COLORS.get(etype, _DIM) return f"[{color}]{icon}[/]" _TuiResult = ( tuple[ str, str, dict[str, Any] | None, dict[int, str] | None, Literal["crew", "flow", "agent"], ] | None ) class CheckpointTUI(App[_TuiResult]): """TUI to browse and inspect checkpoints. Returns ``(location, action, inputs, task_overrides, entity_type)`` where action is ``"resume"`` or ``"fork"``, inputs is a parsed dict or ``None``, and entity_type is ``"crew"`` or ``"flow"``; or ``None`` if the user quit without selecting. """ TITLE = "CrewAI Checkpoints" CSS = f""" Screen {{ background: {_BG_DARK}; }} Header {{ background: {_PRIMARY}; color: {_TERTIARY}; }} Footer {{ background: {_SECONDARY}; color: {_TERTIARY}; }} Footer > .footer-key--key {{ background: {_PRIMARY}; color: {_TERTIARY}; }} #main-layout {{ height: 1fr; }} #tree-panel {{ width: 40%; background: {_BG_PANEL}; border: round {_SECONDARY}; padding: 0 1; scrollbar-color: {_PRIMARY}; }} #tree-panel:focus-within {{ border: round {_PRIMARY}; }} #detail-container {{ width: 60%; height: 1fr; }} #status {{ height: 1; padding: 0 2; color: {_DIM}; }} #detail-tabs {{ height: 1fr; }} TabbedContent > ContentSwitcher {{ background: {_BG_PANEL}; height: 1fr; }} TabPane {{ padding: 0; }} Tabs {{ background: {_BG_DARK}; }} Tab {{ background: {_BG_DARK}; color: {_DIM}; padding: 0 2; }} Tab.-active {{ background: {_BG_PANEL}; color: {_PRIMARY}; }} Tab:hover {{ color: {_TERTIARY}; }} Underline > .underline--bar {{ color: {_SECONDARY}; background: {_BG_DARK}; }} .tab-scroll {{ background: {_BG_PANEL}; height: 1fr; padding: 1 2; scrollbar-color: {_PRIMARY}; }} .section-header {{ padding: 0 0 0 1; margin: 1 0 0 0; }} .detail-line {{ padding: 0 0 0 1; }} .task-label {{ padding: 0 1; }} .task-output-editor {{ height: auto; max-height: 10; margin: 0 1 1 3; border: round {_DIM}; }} .task-output-editor:focus {{ border: round {_PRIMARY}; }} Collapsible {{ background: {_BG_PANEL}; padding: 0; margin: 0 0 1 1; }} CollapsibleTitle {{ background: {_BG_DARK}; color: {_TERTIARY}; padding: 0 1; }} CollapsibleTitle:hover {{ background: {_SECONDARY}; }} .input-row {{ height: 3; padding: 0 1; }} .input-row Static {{ width: auto; min-width: 12; padding: 1 1 0 0; color: {_TERTIARY}; }} .input-row Input {{ width: 1fr; }} .empty-state {{ color: {_DIM}; padding: 1; }} Tree {{ background: {_BG_PANEL}; }} Tree > .tree--cursor {{ background: {_SECONDARY}; color: {_TERTIARY}; }} """ BINDINGS: ClassVar[list[Binding | tuple[str, str] | tuple[str, str, str]]] = [ ("q", "quit", "Quit"), ("r", "refresh", "Refresh"), ("e", "resume", "Resume"), ("f", "fork", "Fork"), ] def __init__(self, location: str = "./.checkpoints") -> None: super().__init__() self._location = location self._entries: list[dict[str, Any]] = [] self._selected_entry: dict[str, Any] | None = None self._input_keys: list[str] = [] self._task_output_ids: list[tuple[int, str, str]] = [] def compose(self) -> ComposeResult: yield Header(show_clock=False) with Horizontal(id="main-layout"): tree: Tree[dict[str, Any]] = Tree("Checkpoints", id="tree-panel") tree.show_root = False tree.guide_depth = 3 yield tree with Vertical(id="detail-container"): yield Static("", id="status") with TabbedContent(id="detail-tabs"): with TabPane("Overview", id="tab-overview"): with VerticalScroll(classes="tab-scroll"): yield Static( f"[{_DIM}]Select a checkpoint from the tree[/]", # noqa: S608 id="overview-empty", ) with TabPane("Tasks", id="tab-tasks"): with VerticalScroll(classes="tab-scroll"): yield Static( f"[{_DIM}]Select a checkpoint to view tasks[/]", id="tasks-empty", ) with TabPane("Inputs", id="tab-inputs"): with VerticalScroll(classes="tab-scroll"): yield Static( f"[{_DIM}]Select a checkpoint to view inputs[/]", id="inputs-empty", ) yield Footer() async def on_mount(self) -> None: self._refresh_tree() self.query_one("#tree-panel", Tree).root.expand() # ── Tree building ────────────────────────────────────────────── @staticmethod def _top_level_entity(entry: dict[str, Any]) -> tuple[str, str]: etype, ename = "unknown", "" for ent in entry.get("entities", []): t = ent.get("type", "unknown") if t == "flow": return "flow", ent.get("name") or "" if t == "crew" and etype != "crew": etype, ename = "crew", ent.get("name") or "" return etype, ename def _refresh_tree(self) -> None: self._entries = _load_entries(self._location) self._selected_entry = None tree = self.query_one("#tree-panel", Tree) tree.clear() if not self._entries: self.sub_title = self._location self.query_one("#status", Static).update("") return grouped: dict[tuple[str, str], dict[str, list[dict[str, Any]]]] = defaultdict( lambda: defaultdict(list) ) for entry in self._entries: key = self._top_level_entity(entry) branch = entry.get("branch", "main") grouped[key][branch].append(entry) def _make_label(e: dict[str, Any]) -> str: ts = e.get("ts") or "" trigger = e.get("trigger") or "" time_part = ts.split(" ")[-1] if " " in ts else ts total_c, total_t = 0, 0 for ent in e.get("entities", []): c = ent.get("tasks_completed") t = ent.get("tasks_total") if c is not None and t is not None: total_c += c total_t += t parts: list[str] = [] if time_part: parts.append(f"[{_DIM}]{time_part}[/]") if trigger: parts.append(f"[{_PRIMARY}]{trigger}[/]") if total_t: display_c = total_c if trigger == "task_started" and total_c < total_t: display_c = total_c + 1 color = _SUCCESS if total_c == total_t else _DIM parts.append(f"[{color}]{display_c}/{total_t}[/]") return " ".join(parts) if parts else _short_id(e.get("name", "")) fork_parents: set[str] = set() for branches in grouped.values(): for branch_name, entries in branches.items(): if branch_name == "main" or not entries: continue oldest = min(entries, key=lambda e: str(e.get("name", ""))) first_parent = oldest.get("parent_id") if first_parent: fork_parents.add(str(first_parent)) node_by_name: dict[str, Any] = {} def _add_checkpoint(parent_node: Any, e: dict[str, Any]) -> None: cp_id = _entry_id(e) if cp_id in fork_parents: node = parent_node.add( _make_label(e), data=e, expand=False, allow_expand=True ) else: node = parent_node.add_leaf(_make_label(e), data=e) node_by_name[cp_id] = node type_order = {"flow": 0, "crew": 1} sorted_keys = sorted( grouped.keys(), key=lambda k: (type_order.get(k[0], 9), k[1]) ) for etype, ename in sorted_keys: branches = grouped[(etype, ename)] icon = _entity_icon(etype) color = _ENTITY_COLORS.get(etype, _DIM) total = sum(len(v) for v in branches.values()) label_parts = [f"{icon} [bold {color}]{etype.upper()}[/]"] if ename: label_parts.append(f"[bold]{ename}[/]") label_parts.append(f"[{_DIM}]({total})[/]") all_entries = [e for bl in branches.values() for e in bl] timestamps = [str(e.get("ts", "")) for e in all_entries if e.get("ts")] if timestamps: latest = max(timestamps) label_parts.append(f"[{_DIM}]{_human_ts(latest)}[/]") entity_label = " ".join(label_parts) entity_node = tree.root.add(entity_label, expand=True) if "main" in branches: for entry in reversed(branches["main"]): _add_checkpoint(entity_node, entry) fork_branches = [ (name, sorted(entries, key=lambda e: str(e.get("name", "")))) for name, entries in branches.items() if name != "main" ] remaining = fork_branches max_passes = len(remaining) + 1 while remaining and max_passes > 0: max_passes -= 1 deferred = [] made_progress = False for branch_name, entries in remaining: first_parent = entries[0].get("parent_id") if entries else None if first_parent and str(first_parent) not in node_by_name: deferred.append((branch_name, entries)) continue attach_to: Any = entity_node if first_parent: attach_to = node_by_name.get(str(first_parent), entity_node) branch_label = ( f"[bold {_SECONDARY}]{branch_name}[/] " f"[{_DIM}]({len(entries)})[/]" ) branch_node = attach_to.add(branch_label, expand=False) for entry in entries: _add_checkpoint(branch_node, entry) made_progress = True remaining = deferred if not made_progress: break for branch_name, entries in remaining: branch_label = ( f"[bold {_SECONDARY}]{branch_name}[/] " f"[{_DIM}]({len(entries)})[/] [{_DIM}](orphaned)[/]" ) branch_node = entity_node.add(branch_label, expand=False) for entry in entries: _add_checkpoint(branch_node, entry) count = len(self._entries) storage = "SQLite" if _is_sqlite(self._location) else "JSON" self.sub_title = self._location self.query_one("#status", Static).update(f" {count} checkpoint(s) | {storage}") # ── Detail panel ─────────────────────────────────────────────── async def _clear_scroll(self, tab_id: str) -> VerticalScroll: tab = self.query_one(f"#{tab_id}", TabPane) scroll = tab.query_one(VerticalScroll) for child in list(scroll.children): await child.remove() return scroll async def _show_detail(self, entry: dict[str, Any]) -> None: self._selected_entry = entry await self._render_overview(entry) await self._render_tasks(entry) await self._render_inputs(entry.get("inputs", {})) async def _render_overview(self, entry: dict[str, Any]) -> None: scroll = await self._clear_scroll("tab-overview") name = entry.get("name", "") ts = entry.get("ts") or "unknown" trigger = entry.get("trigger") or "" branch = entry.get("branch", "main") parent_id = entry.get("parent_id") header_lines = [ f"[bold {_PRIMARY}]{name}[/]", f"[{_DIM}]{'─' * 50}[/]", "", f" [bold]Time[/] {ts}", ] if "size" in entry: header_lines.append(f" [bold]Size[/] {_format_size(entry['size'])}") header_lines.append(f" [bold]Events[/] {entry.get('event_count', 0)}") if trigger: header_lines.append(f" [bold]Trigger[/] [{_PRIMARY}]{trigger}[/]") header_lines.append(f" [bold]Branch[/] [{_SECONDARY}]{branch}[/]") if parent_id: header_lines.append(f" [bold]Parent[/] [{_DIM}]{parent_id}[/]") await scroll.mount(Static("\n".join(header_lines))) for ent in entry.get("entities", []): etype = ent.get("type", "unknown") ename = ent.get("name", "unnamed") icon = _entity_icon(etype) color = _ENTITY_COLORS.get(etype, _DIM) eid = str(ent.get("id", ""))[:8] entity_title = ( f"\n{icon} [bold {color}]{etype.upper()}[/] [bold]{ename}[/]" ) if eid: entity_title += f" [{_DIM}]{eid}…[/]" await scroll.mount(Static(entity_title, classes="section-header")) await scroll.mount(Static(f"[{_DIM}]{'─' * 46}[/]", classes="detail-line")) if etype == "flow": methods = ent.get("completed_methods", []) if methods: method_list = ", ".join(f"[{_SUCCESS}]{m}[/]" for m in methods) await scroll.mount( Static( f" [bold]Methods[/] {method_list}", classes="detail-line", ) ) flow_state = ent.get("flow_state") if isinstance(flow_state, dict) and flow_state: state_parts: list[str] = [] for k, v in list(flow_state.items())[:5]: sv = str(v) if len(sv) > 40: sv = sv[:37] + "..." state_parts.append(f"[{_DIM}]{k}[/]={sv}") await scroll.mount( Static( f" [bold]State[/] {', '.join(state_parts)}", classes="detail-line", ) ) agents = ent.get("agents", []) if agents: agent_lines: list[Static] = [] for ag in agents: role = ag.get("role", "unnamed") goal = ag.get("goal", "") if len(goal) > 60: goal = goal[:57] + "..." agent_line = f" {_entity_icon('agent')} [bold]{role}[/]" if goal: agent_line += f"\n [{_DIM}]{goal}[/]" agent_lines.append(Static(agent_line)) collapsible = Collapsible( *agent_lines, title=f"Agents ({len(agents)})", collapsed=len(agents) > 3, ) await scroll.mount(collapsible) async def _render_tasks(self, entry: dict[str, Any]) -> None: scroll = await self._clear_scroll("tab-tasks") self._task_output_ids = [] flat_task_idx = 0 has_tasks = False for ent_idx, ent in enumerate(entry.get("entities", [])): etype = ent.get("type", "unknown") ename = ent.get("name", "unnamed") icon = _entity_icon(etype) color = _ENTITY_COLORS.get(etype, _DIM) tasks = ent.get("tasks", []) if not tasks: continue has_tasks = True completed = ent.get("tasks_completed", 0) total = ent.get("tasks_total", 0) await scroll.mount( Static( f"{icon} [bold {color}]{ename}[/] " f"{_build_progress_bar(completed, total, width=16)}", classes="section-header", ) ) for i, task in enumerate(tasks): desc = str(task.get("description", "")) if len(desc) > 50: desc = desc[:47] + "..." agent_role = task.get("agent_role", "") if task.get("completed"): status_icon = f"[{_SUCCESS}]✓[/]" task_line = f" {status_icon} {i + 1}. {desc}" if agent_role: task_line += ( f" [{_DIM}]→ {_entity_icon('agent')} {agent_role}[/]" ) await scroll.mount(Static(task_line, classes="task-label")) output_text = task.get("output", "") editor_id = f"task-output-{ent_idx}-{i}" await scroll.mount( TextArea( str(output_text), classes="task-output-editor", id=editor_id, ) ) self._task_output_ids.append( (flat_task_idx, editor_id, str(output_text)) ) else: status_icon = f"[{_PENDING}]○[/]" task_line = f" {status_icon} {i + 1}. {desc}" if agent_role: task_line += ( f" [{_DIM}]→ {_entity_icon('agent')} {agent_role}[/]" ) await scroll.mount(Static(task_line, classes="task-label")) flat_task_idx += 1 if not has_tasks: await scroll.mount(Static(f"[{_DIM}]No tasks[/]", classes="empty-state")) async def _render_inputs(self, inputs: dict[str, Any]) -> None: scroll = await self._clear_scroll("tab-inputs") self._input_keys = [] if not inputs: await scroll.mount(Static(f"[{_DIM}]No inputs[/]", classes="empty-state")) return for key, value in inputs.items(): self._input_keys.append(key) row = Horizontal(classes="input-row") row.compose_add_child(Static(f"[bold]{key}[/]")) row.compose_add_child( Input(value=str(value), placeholder=key, id=f"input-{key}") ) await scroll.mount(row) # ── Data collection ──────────────────────────────────────────── def _collect_inputs(self) -> dict[str, Any] | None: if not self._input_keys: return None result: dict[str, Any] = {} for key in self._input_keys: widget = self.query_one(f"#input-{key}", Input) result[key] = widget.value return result def _collect_task_overrides(self) -> dict[int, str] | None: if not self._task_output_ids or self._selected_entry is None: return None overrides: dict[int, str] = {} for task_idx, editor_id, original in self._task_output_ids: editor = self.query_one(f"#{editor_id}", TextArea) if editor.text != original: overrides[task_idx] = editor.text return overrides or None def _detect_entity_type( self, entry: dict[str, Any] ) -> Literal["crew", "flow", "agent"]: for ent in entry.get("entities", []): if ent.get("type") == "flow": return "flow" if ent.get("type") == "agent": return "agent" return "crew" def _resolve_location(self, entry: dict[str, Any]) -> str: if "path" in entry: return str(entry["path"]) if _is_sqlite(self._location): return f"{self._location}#{entry['name']}" return str(entry.get("name", "")) # ── Events ───────────────────────────────────────────────────── async def on_tree_node_highlighted( self, event: Tree.NodeHighlighted[dict[str, Any]] ) -> None: if event.node.data is not None: await self._show_detail(event.node.data) def _exit_with_action(self, action: str) -> None: if self._selected_entry is None: self.notify("No checkpoint selected", severity="warning") return inputs = self._collect_inputs() overrides = self._collect_task_overrides() loc = self._resolve_location(self._selected_entry) etype = self._detect_entity_type(self._selected_entry) name = self._selected_entry.get("name", "")[:30] self.notify(f"{action.title()}: {name}") self.exit((loc, action, inputs, overrides, etype)) def action_resume(self) -> None: self._exit_with_action("resume") def action_fork(self) -> None: self._exit_with_action("fork") def action_refresh(self) -> None: self._refresh_tree() def _apply_task_overrides(crew: Any, task_overrides: dict[int, str]) -> None: """Apply task output overrides to a restored Crew and print modifications.""" import click click.echo("Modifications:") overridden_agents: set[int] = set() for task_idx, new_output in task_overrides.items(): if task_idx < len(crew.tasks) and crew.tasks[task_idx].output is not None: desc = crew.tasks[task_idx].description or f"Task {task_idx + 1}" if len(desc) > 60: desc = desc[:57] + "..." crew.tasks[task_idx].output.raw = new_output preview = new_output.replace("\n", " ") if len(preview) > 80: preview = preview[:77] + "..." click.echo(f" Task {task_idx + 1}: {desc}") click.echo(f" -> {preview}") agent = crew.tasks[task_idx].agent if agent and agent.agent_executor: nth = sum(1 for t in crew.tasks[:task_idx] if t.agent is agent) messages = agent.agent_executor.messages system_positions = [ i for i, m in enumerate(messages) if m.get("role") == "system" ] if nth < len(system_positions): seg_start = system_positions[nth] seg_end = ( system_positions[nth + 1] if nth + 1 < len(system_positions) else len(messages) ) for j in range(seg_end - 1, seg_start, -1): if messages[j].get("role") == "assistant": messages[j]["content"] = new_output break overridden_agents.add(id(agent)) earliest = min(task_overrides) for offset, subsequent in enumerate(crew.tasks[earliest + 1 :], start=earliest + 1): if subsequent.output and offset not in task_overrides: subsequent.output = None if subsequent.agent and subsequent.agent.agent_executor: subsequent.agent.agent_executor._resuming = False if id(subsequent.agent) not in overridden_agents: subsequent.agent.agent_executor.messages = [] click.echo() async def _run_checkpoint_tui_async(location: str) -> None: """Async implementation of the checkpoint TUI flow.""" import click app = CheckpointTUI(location=location) selection = await app.run_async() if selection is None: return selected, action, inputs, task_overrides, entity_type = selection from crewai.state.checkpoint_config import CheckpointConfig config = CheckpointConfig(restore_from=selected) if entity_type == "flow": from crewai.events.event_bus import crewai_event_bus from crewai.flow.flow import Flow if action == "fork": click.echo(f"\nForking flow from: {selected}\n") flow = Flow.fork(config) else: click.echo(f"\nResuming flow from: {selected}\n") flow = Flow.from_checkpoint(config) if task_overrides: from crewai.crew import Crew as CrewCls state = crewai_event_bus._runtime_state if state is not None: flat_offset = 0 for entity in state.root: if not isinstance(entity, CrewCls) or not entity.tasks: continue n = len(entity.tasks) local = { idx - flat_offset: out for idx, out in task_overrides.items() if flat_offset <= idx < flat_offset + n } if local: _apply_task_overrides(entity, local) flat_offset += n if inputs: click.echo("Inputs:") for k, v in inputs.items(): click.echo(f" {k}: {v}") click.echo() result = await flow.kickoff_async(inputs=inputs) click.echo(f"\nResult: {getattr(result, 'raw', result)}") return if entity_type == "agent": from crewai.agent import Agent if action == "fork": click.echo(f"\nForking agent from: {selected}\n") agent = Agent.fork(config) else: click.echo(f"\nResuming agent from: {selected}\n") agent = Agent.from_checkpoint(config) click.echo() result = await agent.akickoff(messages="Resume execution.") click.echo(f"\nResult: {getattr(result, 'raw', result)}") return from crewai.crew import Crew if action == "fork": click.echo(f"\nForking from: {selected}\n") crew = Crew.fork(config) else: click.echo(f"\nResuming from: {selected}\n") crew = Crew.from_checkpoint(config) if task_overrides: _apply_task_overrides(crew, task_overrides) if inputs: click.echo("Inputs:") for k, v in inputs.items(): click.echo(f" {k}: {v}") click.echo() result = await crew.akickoff(inputs=inputs) click.echo(f"\nResult: {getattr(result, 'raw', result)}") def run_checkpoint_tui(location: str = "./.checkpoints") -> None: """Launch the checkpoint browser TUI.""" import asyncio asyncio.run(_run_checkpoint_tui_async(location))