From fc041354b143f8a557a2721baf6bb546112f8496 Mon Sep 17 00:00:00 2001 From: Greyson LaLonde Date: Fri, 10 Apr 2026 07:22:17 +0800 Subject: [PATCH] feat: store trigger in checkpoint data, editable inputs/outputs in TUI Write the actual triggering event type into checkpoint JSON at write time instead of inferring from the event record. Adds editable input fields and task output overrides to the TUI for what-if exploration. Unique fork branch names prevent collisions on repeated forks. --- docs/en/concepts/checkpointing.mdx | 9 + lib/crewai/src/crewai/cli/checkpoint_cli.py | 46 +++- lib/crewai/src/crewai/cli/checkpoint_tui.py | 247 +++++++++++++++--- lib/crewai/src/crewai/crew.py | 5 + .../src/crewai/state/checkpoint_listener.py | 8 +- lib/crewai/src/crewai/state/runtime.py | 14 +- lib/crewai/src/crewai/task.py | 2 + 7 files changed, 274 insertions(+), 57 deletions(-) diff --git a/docs/en/concepts/checkpointing.mdx b/docs/en/concepts/checkpointing.mdx index d1f5fecda..d6430eb6f 100644 --- a/docs/en/concepts/checkpointing.mdx +++ b/docs/en/concepts/checkpointing.mdx @@ -282,6 +282,15 @@ crewai checkpoint --location ./.checkpoints.db The left panel is a tree view. Checkpoints are grouped by branch, and forks nest under the checkpoint they diverged from. Select a checkpoint to see its metadata, entity state, and task progress in the detail panel. Hit **Resume** to pick up where it left off, or **Fork** to start a new branch from that point. +### Editing inputs and task outputs + +When a checkpoint is selected, the detail panel shows: + +- **Inputs** — if the original kickoff had inputs (e.g. `{topic}`), they appear as editable fields pre-filled with the original values. Change them before resuming or forking. +- **Task outputs** — completed tasks show their output in editable text areas. Edit a task's output to change the context that downstream tasks receive. When you modify a task output and hit Fork, all subsequent tasks are invalidated and re-run with the new context. + +This is useful for "what if" exploration — fork from a checkpoint, tweak a task's result, and see how it changes downstream behavior. + ### Subcommands ```bash diff --git a/lib/crewai/src/crewai/cli/checkpoint_cli.py b/lib/crewai/src/crewai/cli/checkpoint_cli.py index 9a6fcd7ea..8469483ad 100644 --- a/lib/crewai/src/crewai/cli/checkpoint_cli.py +++ b/lib/crewai/src/crewai/cli/checkpoint_cli.py @@ -6,12 +6,16 @@ from datetime import datetime import glob import json import os +import re import sqlite3 from typing import Any import click +_PLACEHOLDER_RE = re.compile(r"\{([A-Za-z_][A-Za-z0-9_\-]*)}") + + _SQLITE_MAGIC = b"SQLite format 3\x00" _SELECT_ALL = """ @@ -71,13 +75,7 @@ def _parse_checkpoint_json(raw: str, source: str) -> dict[str, Any]: nodes = data.get("event_record", {}).get("nodes", {}) event_count = len(nodes) - trigger_event = None - if nodes: - last_node = max( - nodes.values(), - key=lambda n: n.get("event", {}).get("emission_sequence") or 0, - ) - trigger_event = last_node.get("event", {}).get("type") + trigger_event = data.get("trigger") parsed_entities: list[dict[str, Any]] = [] for entity in entities: @@ -95,21 +93,51 @@ def _parse_checkpoint_json(raw: str, source: str) -> dict[str, Any]: { "description": t.get("description", ""), "completed": t.get("output") is not None, + "output": (t.get("output") or {}).get("raw", ""), } for t in tasks ] parsed_entities.append(info) + inputs: dict[str, Any] = {} + for entity in entities: + cp_inputs = entity.get("checkpoint_inputs") + if isinstance(cp_inputs, dict) and cp_inputs: + inputs = dict(cp_inputs) + break + + for entity in entities: + for task in entity.get("tasks", []): + for field in ( + "checkpoint_original_description", + "checkpoint_original_expected_output", + ): + text = task.get(field) or "" + for match in _PLACEHOLDER_RE.findall(text): + if match not in inputs: + inputs[match] = "" + for agent in entity.get("agents", []): + for field in ("role", "goal", "backstory"): + text = agent.get(field) or "" + for match in _PLACEHOLDER_RE.findall(text): + if match not in inputs: + inputs[match] = "" + + branch = data.get("branch", "main") + parent_id = data.get("parent_id") + return { "source": source, "event_count": event_count, "trigger": trigger_event, "entities": parsed_entities, - "branch": data.get("branch", "main"), - "parent_id": data.get("parent_id"), + "branch": branch, + "parent_id": parent_id, + "inputs": inputs, } + def _format_size(size: int) -> str: if size < 1024: return f"{size}B" diff --git a/lib/crewai/src/crewai/cli/checkpoint_tui.py b/lib/crewai/src/crewai/cli/checkpoint_tui.py index 50ac4195f..cf2cdb61e 100644 --- a/lib/crewai/src/crewai/cli/checkpoint_tui.py +++ b/lib/crewai/src/crewai/cli/checkpoint_tui.py @@ -10,10 +10,11 @@ from textual.binding import Binding from textual.containers import Horizontal, Vertical, VerticalScroll from textual.widgets import ( Button, - Collapsible, Footer, Header, + Input, Static, + TextArea, Tree, ) @@ -39,6 +40,7 @@ def _load_entries(location: str) -> list[dict[str, Any]]: return _list_json(location) + def _short_id(name: str) -> str: """Shorten a checkpoint name for tree display.""" if len(name) > 30: @@ -46,14 +48,9 @@ def _short_id(name: str) -> str: return name -def _build_entity_detail(ent: dict[str, Any]) -> str: - """Build rich text for a single entity.""" +def _build_entity_header(ent: dict[str, Any]) -> str: + """Build rich text header for an entity (progress bar only).""" lines: list[str] = [] - eid = str(ent.get("id", ""))[:8] - etype = ent.get("type", "unknown") - ename = ent.get("name", "unnamed") - lines.append(f"[bold {_SECONDARY}]{etype}[/]: {ename} [{_DIM}]{eid}[/]") - tasks = ent.get("tasks") if isinstance(tasks, list): completed = ent.get("tasks_completed", 0) @@ -63,22 +60,19 @@ def _build_entity_detail(ent: dict[str, Any]) -> str: filled = int(bar_len * completed / total) if total else 0 bar = f"[{_PRIMARY}]{'█' * filled}[/][{_DIM}]{'░' * (bar_len - filled)}[/]" lines.append(f"{bar} {completed}/{total} tasks ({pct}%)") - lines.append("") - for i, task in enumerate(tasks): - icon = "[green]✓[/]" if task.get("completed") else "[yellow]○[/]" - desc = str(task.get("description", "")) - if len(desc) > 55: - desc = desc[:52] + "..." - lines.append(f" {icon} {i + 1}. {desc}") - return "\n".join(lines) -class CheckpointTUI(App[tuple[str, str] | None]): +# Return type: (location, action, inputs, task_output_overrides) +_TuiResult = tuple[str, str, dict[str, Any] | None, dict[int, str] | None] | None + + +class CheckpointTUI(App[_TuiResult]): """TUI to browse and inspect checkpoints. - Returns ``(location, action)`` where action is ``"resume"`` or - ``"fork"``, or ``None`` if the user quit without selecting. + Returns ``(location, action, inputs)`` where action is ``"resume"`` or + ``"fork"`` and inputs is a parsed dict or ``None``, + or ``None`` if the user quit without selecting. """ TITLE = "CrewAI Checkpoints" @@ -114,6 +108,7 @@ class CheckpointTUI(App[tuple[str, str] | None]): }} #detail-container {{ width: 55%; + height: 1fr; }} #detail-scroll {{ height: 1fr; @@ -133,9 +128,41 @@ class CheckpointTUI(App[tuple[str, str] | None]): padding: 0 2; color: {_DIM}; }} + #inputs-section {{ + display: none; + height: auto; + max-height: 8; + padding: 0 1; + }} + #inputs-section.visible {{ + display: block; + }} + #inputs-label {{ + height: 1; + color: {_DIM}; + padding: 0 1; + }} + .input-row {{ + height: 3; + padding: 0 1; + }} + .input-row Static {{ + width: auto; + min-width: 12; + padding: 1 1 0 0; + color: {_TERTIARY}; + }} + .input-row Input {{ + width: 1fr; + }} + #no-inputs-label {{ + height: 1; + color: {_DIM}; + padding: 0 1; + }} #action-buttons {{ height: 3; - align: center middle; + align: right middle; padding: 0 1; display: none; }} @@ -143,8 +170,8 @@ class CheckpointTUI(App[tuple[str, str] | None]): display: block; }} #action-buttons Button {{ - margin: 0 1; - min-width: 14; + margin: 0 0 0 1; + min-width: 10; }} #btn-resume {{ background: {_SECONDARY}; @@ -160,13 +187,24 @@ class CheckpointTUI(App[tuple[str, str] | None]): #btn-fork:hover {{ background: {_SECONDARY}; }} - Collapsible {{ - margin: 0; - padding: 0; + .entity-title {{ + padding: 1 1 0 1; }} .entity-detail {{ padding: 0 1; }} + .task-output-editor {{ + height: auto; + max-height: 10; + margin: 0 1 1 1; + border: round {_DIM}; + }} + .task-output-editor:focus {{ + border: round {_PRIMARY}; + }} + .task-label {{ + padding: 0 1; + }} Tree {{ background: {_BG_PANEL}; }} @@ -186,6 +224,7 @@ class CheckpointTUI(App[tuple[str, str] | None]): self._location = location self._entries: list[dict[str, Any]] = [] self._selected_entry: dict[str, Any] | None = None + self._input_keys: list[str] = [] def compose(self) -> ComposeResult: yield Header(show_clock=False) @@ -201,6 +240,8 @@ class CheckpointTUI(App[tuple[str, str] | None]): f"[{_DIM}]Select a checkpoint from the tree[/]", # noqa: S608 id="detail-header", ) + with Vertical(id="inputs-section"): + yield Static("Inputs", id="inputs-label") with Horizontal(id="action-buttons"): yield Button("Resume", id="btn-resume") yield Button("Fork", id="btn-fork") @@ -295,6 +336,7 @@ class CheckpointTUI(App[tuple[str, str] | None]): if first_parent else tree.root ) + # Drop fork-initial checkpoint — it's just a lineage marker branch_label = ( f"[bold {_SECONDARY}]{branch_name}[/] [{_DIM}]({len(entries)})[/]" ) @@ -308,16 +350,17 @@ class CheckpointTUI(App[tuple[str, str] | None]): self.sub_title = self._location self.query_one("#status", Static).update(f" {count} checkpoint(s) | {storage}") - def _show_detail(self, entry: dict[str, Any]) -> None: + async def _show_detail(self, entry: dict[str, Any]) -> None: """Update the detail panel for a checkpoint entry.""" self._selected_entry = entry self.query_one("#action-buttons").add_class("visible") detail_scroll = self.query_one("#detail-scroll", VerticalScroll) - # Remove old collapsibles - for widget in list(detail_scroll.query("Collapsible")): - widget.remove() + # Remove all dynamic children except the header — await so IDs are freed + to_remove = [c for c in detail_scroll.children if c.id != "detail-header"] + for child in to_remove: + await child.remove() # Header name = entry.get("name", "") @@ -347,19 +390,102 @@ class CheckpointTUI(App[tuple[str, str] | None]): self.query_one("#detail-header", Static).update("\n".join(header_lines)) - # Entity collapsibles - for ent in entry.get("entities", []): + # Entity details and editable task outputs — mounted flat for scrolling + self._task_output_ids = [] + for ent_idx, ent in enumerate(entry.get("entities", [])): etype = ent.get("type", "unknown") ename = ent.get("name", "unnamed") completed = ent.get("tasks_completed") total = ent.get("tasks_total") - title = f"{etype}: {ename}" + entity_title = f"[bold {_SECONDARY}]{etype}: {ename}[/]" if completed is not None and total is not None: - title += f" [{completed}/{total} tasks]" + entity_title += f" [{_DIM}]{completed}/{total} tasks[/]" + detail_scroll.mount(Static(entity_title, classes="entity-title")) + detail_scroll.mount( + Static(_build_entity_header(ent), classes="entity-detail") + ) - content = Static(_build_entity_detail(ent), classes="entity-detail") - collapsible = Collapsible(content, title=title, collapsed=False) - detail_scroll.mount(collapsible) + tasks = ent.get("tasks", []) + for i, task in enumerate(tasks): + desc = str(task.get("description", "")) + if len(desc) > 55: + desc = desc[:52] + "..." + if task.get("completed"): + icon = "[green]✓[/]" + detail_scroll.mount( + Static(f" {icon} {i + 1}. {desc}", classes="task-label") + ) + output_text = task.get("output", "") + editor_id = f"task-output-{ent_idx}-{i}" + detail_scroll.mount( + TextArea( + str(output_text), + classes="task-output-editor", + id=editor_id, + ) + ) + self._task_output_ids.append((i, editor_id)) + else: + icon = "[yellow]○[/]" + detail_scroll.mount( + Static(f" {icon} {i + 1}. {desc}", classes="task-label") + ) + + # Build input fields + await self._build_input_fields(entry.get("inputs", {})) + + async def _build_input_fields(self, inputs: dict[str, Any]) -> None: + """Rebuild the inputs section with one field per input key.""" + section = self.query_one("#inputs-section") + + # Remove old dynamic children — await so IDs are freed + for widget in list(section.query(".input-row, .no-inputs")): + await widget.remove() + + self._input_keys = [] + + if not inputs: + section.mount(Static(f"[{_DIM}]No inputs[/]", classes="no-inputs")) + section.add_class("visible") + return + + for key, value in inputs.items(): + self._input_keys.append(key) + row = Horizontal(classes="input-row") + row.compose_add_child(Static(f"[bold]{key}[/]")) + row.compose_add_child( + Input(value=str(value), placeholder=key, id=f"input-{key}") + ) + section.mount(row) + + section.add_class("visible") + + def _collect_inputs(self) -> dict[str, Any] | None: + """Collect current values from input fields.""" + if not self._input_keys: + return None + result: dict[str, Any] = {} + for key in self._input_keys: + widget = self.query_one(f"#input-{key}", Input) + result[key] = widget.value + return result + + def _collect_task_overrides(self) -> dict[int, str] | None: + """Collect edited task outputs. Returns only changed values.""" + if not self._task_output_ids or self._selected_entry is None: + return None + overrides: dict[int, str] = {} + tasks = [] + for ent in self._selected_entry.get("entities", []): + tasks = ent.get("tasks", []) + if tasks: + break + for task_idx, editor_id in self._task_output_ids: + editor = self.query_one(f"#{editor_id}", TextArea) + original = str(tasks[task_idx].get("output", "")) if task_idx < len(tasks) else "" + if editor.text != original: + overrides[task_idx] = editor.text + return overrides or None def _resolve_location(self, entry: dict[str, Any]) -> str: """Get the restore location string for a checkpoint entry.""" @@ -373,16 +499,18 @@ class CheckpointTUI(App[tuple[str, str] | None]): self, event: Tree.NodeHighlighted[dict[str, Any]] ) -> None: if event.node.data is not None: - self._show_detail(event.node.data) + await self._show_detail(event.node.data) def on_button_pressed(self, event: Button.Pressed) -> None: if self._selected_entry is None: return + inputs = self._collect_inputs() + overrides = self._collect_task_overrides() loc = self._resolve_location(self._selected_entry) if event.button.id == "btn-resume": - self.exit((loc, "resume")) + self.exit((loc, "resume", inputs, overrides)) elif event.button.id == "btn-fork": - self.exit((loc, "fork")) + self.exit((loc, "fork", inputs, overrides)) def action_refresh(self) -> None: self._refresh_tree() @@ -398,7 +526,7 @@ async def _run_checkpoint_tui_async(location: str) -> None: if selection is None: return - selected, action = selection + selected, action, inputs, task_overrides = selection from crewai.crew import Crew from crewai.state.checkpoint_config import CheckpointConfig @@ -412,7 +540,44 @@ async def _run_checkpoint_tui_async(location: str) -> None: click.echo(f"\nResuming from: {selected}\n") crew = Crew.from_checkpoint(config) - result = await crew.akickoff() + # Apply task output overrides before kickoff + if task_overrides: + click.echo("Modifications:") + for task_idx, new_output in task_overrides.items(): + if task_idx < len(crew.tasks) and crew.tasks[task_idx].output is not None: + desc = crew.tasks[task_idx].description or f"Task {task_idx + 1}" + if len(desc) > 60: + desc = desc[:57] + "..." + crew.tasks[task_idx].output.raw = new_output # type: ignore[union-attr] + preview = new_output.replace("\n", " ") + if len(preview) > 80: + preview = preview[:77] + "..." + click.echo(f" Task {task_idx + 1}: {desc}") + click.echo(f" -> {preview}") + # Update the assistant message in the executor's history + # so the LLM sees the override in its conversation + agent = crew.tasks[task_idx].agent + if agent and agent.agent_executor: + for msg in reversed(agent.agent_executor.messages): + if msg.get("role") == "assistant": + msg["content"] = new_output + break + # Invalidate all subsequent tasks so they re-run with + # the modified context instead of using cached results + for subsequent in crew.tasks[task_idx + 1 :]: + subsequent.output = None + if subsequent.agent and subsequent.agent.agent_executor: + subsequent.agent.agent_executor._resuming = False + subsequent.agent.agent_executor.messages = [] + click.echo() + + if inputs: + click.echo("Inputs:") + for k, v in inputs.items(): + click.echo(f" {k}: {v}") + click.echo() + + result = await crew.akickoff(inputs=inputs) click.echo(f"\nResult: {getattr(result, 'raw', result)}") diff --git a/lib/crewai/src/crewai/crew.py b/lib/crewai/src/crewai/crew.py index 111babbf3..3b7f7a7d5 100644 --- a/lib/crewai/src/crewai/crew.py +++ b/lib/crewai/src/crewai/crew.py @@ -436,6 +436,11 @@ class Crew(FlowTrackable, BaseModel): if agent.agent_executor is not None and task.output is None: agent.agent_executor.task = task break + for task in self.tasks: + if task.checkpoint_original_description is not None: + task._original_description = task.checkpoint_original_description + if task.checkpoint_original_expected_output is not None: + task._original_expected_output = task.checkpoint_original_expected_output if self.checkpoint_inputs is not None: self._inputs = self.checkpoint_inputs if self.checkpoint_kickoff_event_id is not None: diff --git a/lib/crewai/src/crewai/state/checkpoint_listener.py b/lib/crewai/src/crewai/state/checkpoint_listener.py index c2ac728a8..e956a4395 100644 --- a/lib/crewai/src/crewai/state/checkpoint_listener.py +++ b/lib/crewai/src/crewai/state/checkpoint_listener.py @@ -102,8 +102,12 @@ def _find_checkpoint(source: Any) -> CheckpointConfig | None: return None -def _do_checkpoint(state: RuntimeState, cfg: CheckpointConfig) -> None: +def _do_checkpoint( + state: RuntimeState, cfg: CheckpointConfig, event: BaseEvent | None = None +) -> None: """Write a checkpoint and prune old ones if configured.""" + if event is not None: + state._trigger = event.type _prepare_entities(state.root) data = state.model_dump_json() location = cfg.provider.checkpoint( @@ -134,7 +138,7 @@ def _on_any_event(source: Any, event: BaseEvent, state: Any) -> None: if cfg is None: return try: - _do_checkpoint(state, cfg) + _do_checkpoint(state, cfg, event) except Exception: logger.warning("Auto-checkpoint failed for event %s", event.type, exc_info=True) diff --git a/lib/crewai/src/crewai/state/runtime.py b/lib/crewai/src/crewai/state/runtime.py index 4dd1a4555..3e8e4f4bd 100644 --- a/lib/crewai/src/crewai/state/runtime.py +++ b/lib/crewai/src/crewai/state/runtime.py @@ -80,6 +80,9 @@ def _sync_checkpoint_fields(entity: object) -> None: entity.checkpoint_inputs = entity._inputs entity.checkpoint_train = entity._train entity.checkpoint_kickoff_event_id = entity._kickoff_event_id + for task in entity.tasks: + task.checkpoint_original_description = task._original_description + task.checkpoint_original_expected_output = task._original_expected_output def _migrate(data: dict[str, Any]) -> dict[str, Any]: @@ -123,6 +126,7 @@ class RuntimeState(RootModel): # type: ignore[type-arg] _parent_id: str | None = PrivateAttr(default=None) _branch: str = PrivateAttr(default="main") _location: str | None = PrivateAttr(default=None) + _trigger: str | None = PrivateAttr(default=None) @property def event_record(self) -> EventRecord: @@ -131,13 +135,16 @@ class RuntimeState(RootModel): # type: ignore[type-arg] @model_serializer(mode="plain") def _serialize(self) -> dict[str, Any]: - return { + d: dict[str, Any] = { "crewai_version": get_crewai_version(), "parent_id": self._parent_id, "branch": self._branch, "entities": [e.model_dump(mode="json") for e in self.root], "event_record": self._event_record.model_dump(), } + if self._trigger: + d["trigger"] = self._trigger + return d @model_validator(mode="wrap") @classmethod @@ -222,13 +229,10 @@ class RuntimeState(RootModel): # type: ignore[type-arg] if branch: self._branch = branch elif self._checkpoint_id: - self._branch = f"fork/{self._checkpoint_id}" + self._branch = f"fork/{self._checkpoint_id}_{uuid.uuid4().hex[:6]}" else: self._branch = f"fork/{uuid.uuid4().hex[:8]}" - if self._location is not None: - self.checkpoint(self._location) - @classmethod def from_checkpoint(cls, config: CheckpointConfig, **kwargs: Any) -> RuntimeState: """Restore a RuntimeState from a checkpoint. diff --git a/lib/crewai/src/crewai/task.py b/lib/crewai/src/crewai/task.py index 4224828e3..0a159cb0e 100644 --- a/lib/crewai/src/crewai/task.py +++ b/lib/crewai/src/crewai/task.py @@ -230,6 +230,8 @@ class Task(BaseModel): _original_description: str | None = PrivateAttr(default=None) _original_expected_output: str | None = PrivateAttr(default=None) _original_output_file: str | None = PrivateAttr(default=None) + checkpoint_original_description: str | None = Field(default=None, exclude=False) + checkpoint_original_expected_output: str | None = Field(default=None, exclude=False) _thread: threading.Thread | None = PrivateAttr(default=None) model_config = {"arbitrary_types_allowed": True}