feat: store trigger in checkpoint data, editable inputs/outputs in TUI

Write the actual triggering event type into checkpoint JSON at write
time instead of inferring from the event record. Adds editable input
fields and task output overrides to the TUI for what-if exploration.
Unique fork branch names prevent collisions on repeated forks.
This commit is contained in:
Greyson LaLonde
2026-04-10 07:22:17 +08:00
parent 2647f73150
commit fc041354b1
7 changed files with 274 additions and 57 deletions

View File

@@ -6,12 +6,16 @@ from datetime import datetime
import glob
import json
import os
import re
import sqlite3
from typing import Any
import click
_PLACEHOLDER_RE = re.compile(r"\{([A-Za-z_][A-Za-z0-9_\-]*)}")
_SQLITE_MAGIC = b"SQLite format 3\x00"
_SELECT_ALL = """
@@ -71,13 +75,7 @@ def _parse_checkpoint_json(raw: str, source: str) -> dict[str, Any]:
nodes = data.get("event_record", {}).get("nodes", {})
event_count = len(nodes)
trigger_event = None
if nodes:
last_node = max(
nodes.values(),
key=lambda n: n.get("event", {}).get("emission_sequence") or 0,
)
trigger_event = last_node.get("event", {}).get("type")
trigger_event = data.get("trigger")
parsed_entities: list[dict[str, Any]] = []
for entity in entities:
@@ -95,21 +93,51 @@ def _parse_checkpoint_json(raw: str, source: str) -> dict[str, Any]:
{
"description": t.get("description", ""),
"completed": t.get("output") is not None,
"output": (t.get("output") or {}).get("raw", ""),
}
for t in tasks
]
parsed_entities.append(info)
inputs: dict[str, Any] = {}
for entity in entities:
cp_inputs = entity.get("checkpoint_inputs")
if isinstance(cp_inputs, dict) and cp_inputs:
inputs = dict(cp_inputs)
break
for entity in entities:
for task in entity.get("tasks", []):
for field in (
"checkpoint_original_description",
"checkpoint_original_expected_output",
):
text = task.get(field) or ""
for match in _PLACEHOLDER_RE.findall(text):
if match not in inputs:
inputs[match] = ""
for agent in entity.get("agents", []):
for field in ("role", "goal", "backstory"):
text = agent.get(field) or ""
for match in _PLACEHOLDER_RE.findall(text):
if match not in inputs:
inputs[match] = ""
branch = data.get("branch", "main")
parent_id = data.get("parent_id")
return {
"source": source,
"event_count": event_count,
"trigger": trigger_event,
"entities": parsed_entities,
"branch": data.get("branch", "main"),
"parent_id": data.get("parent_id"),
"branch": branch,
"parent_id": parent_id,
"inputs": inputs,
}
def _format_size(size: int) -> str:
if size < 1024:
return f"{size}B"

View File

@@ -10,10 +10,11 @@ from textual.binding import Binding
from textual.containers import Horizontal, Vertical, VerticalScroll
from textual.widgets import (
Button,
Collapsible,
Footer,
Header,
Input,
Static,
TextArea,
Tree,
)
@@ -39,6 +40,7 @@ def _load_entries(location: str) -> list[dict[str, Any]]:
return _list_json(location)
def _short_id(name: str) -> str:
"""Shorten a checkpoint name for tree display."""
if len(name) > 30:
@@ -46,14 +48,9 @@ def _short_id(name: str) -> str:
return name
def _build_entity_detail(ent: dict[str, Any]) -> str:
"""Build rich text for a single entity."""
def _build_entity_header(ent: dict[str, Any]) -> str:
"""Build rich text header for an entity (progress bar only)."""
lines: list[str] = []
eid = str(ent.get("id", ""))[:8]
etype = ent.get("type", "unknown")
ename = ent.get("name", "unnamed")
lines.append(f"[bold {_SECONDARY}]{etype}[/]: {ename} [{_DIM}]{eid}[/]")
tasks = ent.get("tasks")
if isinstance(tasks, list):
completed = ent.get("tasks_completed", 0)
@@ -63,22 +60,19 @@ def _build_entity_detail(ent: dict[str, Any]) -> str:
filled = int(bar_len * completed / total) if total else 0
bar = f"[{_PRIMARY}]{'' * filled}[/][{_DIM}]{'' * (bar_len - filled)}[/]"
lines.append(f"{bar} {completed}/{total} tasks ({pct}%)")
lines.append("")
for i, task in enumerate(tasks):
icon = "[green]✓[/]" if task.get("completed") else "[yellow]○[/]"
desc = str(task.get("description", ""))
if len(desc) > 55:
desc = desc[:52] + "..."
lines.append(f" {icon} {i + 1}. {desc}")
return "\n".join(lines)
class CheckpointTUI(App[tuple[str, str] | None]):
# Return type: (location, action, inputs, task_output_overrides)
_TuiResult = tuple[str, str, dict[str, Any] | None, dict[int, str] | None] | None
class CheckpointTUI(App[_TuiResult]):
"""TUI to browse and inspect checkpoints.
Returns ``(location, action)`` where action is ``"resume"`` or
``"fork"``, or ``None`` if the user quit without selecting.
Returns ``(location, action, inputs)`` where action is ``"resume"`` or
``"fork"`` and inputs is a parsed dict or ``None``,
or ``None`` if the user quit without selecting.
"""
TITLE = "CrewAI Checkpoints"
@@ -114,6 +108,7 @@ class CheckpointTUI(App[tuple[str, str] | None]):
}}
#detail-container {{
width: 55%;
height: 1fr;
}}
#detail-scroll {{
height: 1fr;
@@ -133,9 +128,41 @@ class CheckpointTUI(App[tuple[str, str] | None]):
padding: 0 2;
color: {_DIM};
}}
#inputs-section {{
display: none;
height: auto;
max-height: 8;
padding: 0 1;
}}
#inputs-section.visible {{
display: block;
}}
#inputs-label {{
height: 1;
color: {_DIM};
padding: 0 1;
}}
.input-row {{
height: 3;
padding: 0 1;
}}
.input-row Static {{
width: auto;
min-width: 12;
padding: 1 1 0 0;
color: {_TERTIARY};
}}
.input-row Input {{
width: 1fr;
}}
#no-inputs-label {{
height: 1;
color: {_DIM};
padding: 0 1;
}}
#action-buttons {{
height: 3;
align: center middle;
align: right middle;
padding: 0 1;
display: none;
}}
@@ -143,8 +170,8 @@ class CheckpointTUI(App[tuple[str, str] | None]):
display: block;
}}
#action-buttons Button {{
margin: 0 1;
min-width: 14;
margin: 0 0 0 1;
min-width: 10;
}}
#btn-resume {{
background: {_SECONDARY};
@@ -160,13 +187,24 @@ class CheckpointTUI(App[tuple[str, str] | None]):
#btn-fork:hover {{
background: {_SECONDARY};
}}
Collapsible {{
margin: 0;
padding: 0;
.entity-title {{
padding: 1 1 0 1;
}}
.entity-detail {{
padding: 0 1;
}}
.task-output-editor {{
height: auto;
max-height: 10;
margin: 0 1 1 1;
border: round {_DIM};
}}
.task-output-editor:focus {{
border: round {_PRIMARY};
}}
.task-label {{
padding: 0 1;
}}
Tree {{
background: {_BG_PANEL};
}}
@@ -186,6 +224,7 @@ class CheckpointTUI(App[tuple[str, str] | None]):
self._location = location
self._entries: list[dict[str, Any]] = []
self._selected_entry: dict[str, Any] | None = None
self._input_keys: list[str] = []
def compose(self) -> ComposeResult:
yield Header(show_clock=False)
@@ -201,6 +240,8 @@ class CheckpointTUI(App[tuple[str, str] | None]):
f"[{_DIM}]Select a checkpoint from the tree[/]", # noqa: S608
id="detail-header",
)
with Vertical(id="inputs-section"):
yield Static("Inputs", id="inputs-label")
with Horizontal(id="action-buttons"):
yield Button("Resume", id="btn-resume")
yield Button("Fork", id="btn-fork")
@@ -295,6 +336,7 @@ class CheckpointTUI(App[tuple[str, str] | None]):
if first_parent
else tree.root
)
# Drop fork-initial checkpoint — it's just a lineage marker
branch_label = (
f"[bold {_SECONDARY}]{branch_name}[/] [{_DIM}]({len(entries)})[/]"
)
@@ -308,16 +350,17 @@ class CheckpointTUI(App[tuple[str, str] | None]):
self.sub_title = self._location
self.query_one("#status", Static).update(f" {count} checkpoint(s) | {storage}")
def _show_detail(self, entry: dict[str, Any]) -> None:
async def _show_detail(self, entry: dict[str, Any]) -> None:
"""Update the detail panel for a checkpoint entry."""
self._selected_entry = entry
self.query_one("#action-buttons").add_class("visible")
detail_scroll = self.query_one("#detail-scroll", VerticalScroll)
# Remove old collapsibles
for widget in list(detail_scroll.query("Collapsible")):
widget.remove()
# Remove all dynamic children except the header — await so IDs are freed
to_remove = [c for c in detail_scroll.children if c.id != "detail-header"]
for child in to_remove:
await child.remove()
# Header
name = entry.get("name", "")
@@ -347,19 +390,102 @@ class CheckpointTUI(App[tuple[str, str] | None]):
self.query_one("#detail-header", Static).update("\n".join(header_lines))
# Entity collapsibles
for ent in entry.get("entities", []):
# Entity details and editable task outputs — mounted flat for scrolling
self._task_output_ids = []
for ent_idx, ent in enumerate(entry.get("entities", [])):
etype = ent.get("type", "unknown")
ename = ent.get("name", "unnamed")
completed = ent.get("tasks_completed")
total = ent.get("tasks_total")
title = f"{etype}: {ename}"
entity_title = f"[bold {_SECONDARY}]{etype}: {ename}[/]"
if completed is not None and total is not None:
title += f" [{completed}/{total} tasks]"
entity_title += f" [{_DIM}]{completed}/{total} tasks[/]"
detail_scroll.mount(Static(entity_title, classes="entity-title"))
detail_scroll.mount(
Static(_build_entity_header(ent), classes="entity-detail")
)
content = Static(_build_entity_detail(ent), classes="entity-detail")
collapsible = Collapsible(content, title=title, collapsed=False)
detail_scroll.mount(collapsible)
tasks = ent.get("tasks", [])
for i, task in enumerate(tasks):
desc = str(task.get("description", ""))
if len(desc) > 55:
desc = desc[:52] + "..."
if task.get("completed"):
icon = "[green]✓[/]"
detail_scroll.mount(
Static(f" {icon} {i + 1}. {desc}", classes="task-label")
)
output_text = task.get("output", "")
editor_id = f"task-output-{ent_idx}-{i}"
detail_scroll.mount(
TextArea(
str(output_text),
classes="task-output-editor",
id=editor_id,
)
)
self._task_output_ids.append((i, editor_id))
else:
icon = "[yellow]○[/]"
detail_scroll.mount(
Static(f" {icon} {i + 1}. {desc}", classes="task-label")
)
# Build input fields
await self._build_input_fields(entry.get("inputs", {}))
async def _build_input_fields(self, inputs: dict[str, Any]) -> None:
"""Rebuild the inputs section with one field per input key."""
section = self.query_one("#inputs-section")
# Remove old dynamic children — await so IDs are freed
for widget in list(section.query(".input-row, .no-inputs")):
await widget.remove()
self._input_keys = []
if not inputs:
section.mount(Static(f"[{_DIM}]No inputs[/]", classes="no-inputs"))
section.add_class("visible")
return
for key, value in inputs.items():
self._input_keys.append(key)
row = Horizontal(classes="input-row")
row.compose_add_child(Static(f"[bold]{key}[/]"))
row.compose_add_child(
Input(value=str(value), placeholder=key, id=f"input-{key}")
)
section.mount(row)
section.add_class("visible")
def _collect_inputs(self) -> dict[str, Any] | None:
"""Collect current values from input fields."""
if not self._input_keys:
return None
result: dict[str, Any] = {}
for key in self._input_keys:
widget = self.query_one(f"#input-{key}", Input)
result[key] = widget.value
return result
def _collect_task_overrides(self) -> dict[int, str] | None:
"""Collect edited task outputs. Returns only changed values."""
if not self._task_output_ids or self._selected_entry is None:
return None
overrides: dict[int, str] = {}
tasks = []
for ent in self._selected_entry.get("entities", []):
tasks = ent.get("tasks", [])
if tasks:
break
for task_idx, editor_id in self._task_output_ids:
editor = self.query_one(f"#{editor_id}", TextArea)
original = str(tasks[task_idx].get("output", "")) if task_idx < len(tasks) else ""
if editor.text != original:
overrides[task_idx] = editor.text
return overrides or None
def _resolve_location(self, entry: dict[str, Any]) -> str:
"""Get the restore location string for a checkpoint entry."""
@@ -373,16 +499,18 @@ class CheckpointTUI(App[tuple[str, str] | None]):
self, event: Tree.NodeHighlighted[dict[str, Any]]
) -> None:
if event.node.data is not None:
self._show_detail(event.node.data)
await self._show_detail(event.node.data)
def on_button_pressed(self, event: Button.Pressed) -> None:
if self._selected_entry is None:
return
inputs = self._collect_inputs()
overrides = self._collect_task_overrides()
loc = self._resolve_location(self._selected_entry)
if event.button.id == "btn-resume":
self.exit((loc, "resume"))
self.exit((loc, "resume", inputs, overrides))
elif event.button.id == "btn-fork":
self.exit((loc, "fork"))
self.exit((loc, "fork", inputs, overrides))
def action_refresh(self) -> None:
self._refresh_tree()
@@ -398,7 +526,7 @@ async def _run_checkpoint_tui_async(location: str) -> None:
if selection is None:
return
selected, action = selection
selected, action, inputs, task_overrides = selection
from crewai.crew import Crew
from crewai.state.checkpoint_config import CheckpointConfig
@@ -412,7 +540,44 @@ async def _run_checkpoint_tui_async(location: str) -> None:
click.echo(f"\nResuming from: {selected}\n")
crew = Crew.from_checkpoint(config)
result = await crew.akickoff()
# Apply task output overrides before kickoff
if task_overrides:
click.echo("Modifications:")
for task_idx, new_output in task_overrides.items():
if task_idx < len(crew.tasks) and crew.tasks[task_idx].output is not None:
desc = crew.tasks[task_idx].description or f"Task {task_idx + 1}"
if len(desc) > 60:
desc = desc[:57] + "..."
crew.tasks[task_idx].output.raw = new_output # type: ignore[union-attr]
preview = new_output.replace("\n", " ")
if len(preview) > 80:
preview = preview[:77] + "..."
click.echo(f" Task {task_idx + 1}: {desc}")
click.echo(f" -> {preview}")
# Update the assistant message in the executor's history
# so the LLM sees the override in its conversation
agent = crew.tasks[task_idx].agent
if agent and agent.agent_executor:
for msg in reversed(agent.agent_executor.messages):
if msg.get("role") == "assistant":
msg["content"] = new_output
break
# Invalidate all subsequent tasks so they re-run with
# the modified context instead of using cached results
for subsequent in crew.tasks[task_idx + 1 :]:
subsequent.output = None
if subsequent.agent and subsequent.agent.agent_executor:
subsequent.agent.agent_executor._resuming = False
subsequent.agent.agent_executor.messages = []
click.echo()
if inputs:
click.echo("Inputs:")
for k, v in inputs.items():
click.echo(f" {k}: {v}")
click.echo()
result = await crew.akickoff(inputs=inputs)
click.echo(f"\nResult: {getattr(result, 'raw', result)}")

View File

@@ -436,6 +436,11 @@ class Crew(FlowTrackable, BaseModel):
if agent.agent_executor is not None and task.output is None:
agent.agent_executor.task = task
break
for task in self.tasks:
if task.checkpoint_original_description is not None:
task._original_description = task.checkpoint_original_description
if task.checkpoint_original_expected_output is not None:
task._original_expected_output = task.checkpoint_original_expected_output
if self.checkpoint_inputs is not None:
self._inputs = self.checkpoint_inputs
if self.checkpoint_kickoff_event_id is not None:

View File

@@ -102,8 +102,12 @@ def _find_checkpoint(source: Any) -> CheckpointConfig | None:
return None
def _do_checkpoint(state: RuntimeState, cfg: CheckpointConfig) -> None:
def _do_checkpoint(
state: RuntimeState, cfg: CheckpointConfig, event: BaseEvent | None = None
) -> None:
"""Write a checkpoint and prune old ones if configured."""
if event is not None:
state._trigger = event.type
_prepare_entities(state.root)
data = state.model_dump_json()
location = cfg.provider.checkpoint(
@@ -134,7 +138,7 @@ def _on_any_event(source: Any, event: BaseEvent, state: Any) -> None:
if cfg is None:
return
try:
_do_checkpoint(state, cfg)
_do_checkpoint(state, cfg, event)
except Exception:
logger.warning("Auto-checkpoint failed for event %s", event.type, exc_info=True)

View File

@@ -80,6 +80,9 @@ def _sync_checkpoint_fields(entity: object) -> None:
entity.checkpoint_inputs = entity._inputs
entity.checkpoint_train = entity._train
entity.checkpoint_kickoff_event_id = entity._kickoff_event_id
for task in entity.tasks:
task.checkpoint_original_description = task._original_description
task.checkpoint_original_expected_output = task._original_expected_output
def _migrate(data: dict[str, Any]) -> dict[str, Any]:
@@ -123,6 +126,7 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
_parent_id: str | None = PrivateAttr(default=None)
_branch: str = PrivateAttr(default="main")
_location: str | None = PrivateAttr(default=None)
_trigger: str | None = PrivateAttr(default=None)
@property
def event_record(self) -> EventRecord:
@@ -131,13 +135,16 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
@model_serializer(mode="plain")
def _serialize(self) -> dict[str, Any]:
return {
d: dict[str, Any] = {
"crewai_version": get_crewai_version(),
"parent_id": self._parent_id,
"branch": self._branch,
"entities": [e.model_dump(mode="json") for e in self.root],
"event_record": self._event_record.model_dump(),
}
if self._trigger:
d["trigger"] = self._trigger
return d
@model_validator(mode="wrap")
@classmethod
@@ -222,13 +229,10 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
if branch:
self._branch = branch
elif self._checkpoint_id:
self._branch = f"fork/{self._checkpoint_id}"
self._branch = f"fork/{self._checkpoint_id}_{uuid.uuid4().hex[:6]}"
else:
self._branch = f"fork/{uuid.uuid4().hex[:8]}"
if self._location is not None:
self.checkpoint(self._location)
@classmethod
def from_checkpoint(cls, config: CheckpointConfig, **kwargs: Any) -> RuntimeState:
"""Restore a RuntimeState from a checkpoint.

View File

@@ -230,6 +230,8 @@ class Task(BaseModel):
_original_description: str | None = PrivateAttr(default=None)
_original_expected_output: str | None = PrivateAttr(default=None)
_original_output_file: str | None = PrivateAttr(default=None)
checkpoint_original_description: str | None = Field(default=None, exclude=False)
checkpoint_original_expected_output: str | None = Field(default=None, exclude=False)
_thread: threading.Thread | None = PrivateAttr(default=None)
model_config = {"arbitrary_types_allowed": True}