diff --git a/docs/en/concepts/checkpointing.mdx b/docs/en/concepts/checkpointing.mdx index d1f5fecda..d6430eb6f 100644 --- a/docs/en/concepts/checkpointing.mdx +++ b/docs/en/concepts/checkpointing.mdx @@ -282,6 +282,15 @@ crewai checkpoint --location ./.checkpoints.db The left panel is a tree view. Checkpoints are grouped by branch, and forks nest under the checkpoint they diverged from. Select a checkpoint to see its metadata, entity state, and task progress in the detail panel. Hit **Resume** to pick up where it left off, or **Fork** to start a new branch from that point. +### Editing inputs and task outputs + +When a checkpoint is selected, the detail panel shows: + +- **Inputs** — if the original kickoff had inputs (e.g. `{topic}`), they appear as editable fields pre-filled with the original values. Change them before resuming or forking. +- **Task outputs** — completed tasks show their output in editable text areas. Edit a task's output to change the context that downstream tasks receive. When you modify a task output and then hit **Resume** or **Fork**, all subsequent tasks are invalidated and re-run with the new context. + +This is useful for "what if" exploration — fork from a checkpoint, tweak a task's result, and see how it changes downstream behavior. 
+ ### Subcommands ```bash diff --git a/lib/crewai/src/crewai/cli/checkpoint_cli.py b/lib/crewai/src/crewai/cli/checkpoint_cli.py index 9a6fcd7ea..8469483ad 100644 --- a/lib/crewai/src/crewai/cli/checkpoint_cli.py +++ b/lib/crewai/src/crewai/cli/checkpoint_cli.py @@ -6,12 +6,16 @@ from datetime import datetime import glob import json import os +import re import sqlite3 from typing import Any import click +_PLACEHOLDER_RE = re.compile(r"\{([A-Za-z_][A-Za-z0-9_\-]*)}") + + _SQLITE_MAGIC = b"SQLite format 3\x00" _SELECT_ALL = """ @@ -71,13 +75,7 @@ def _parse_checkpoint_json(raw: str, source: str) -> dict[str, Any]: nodes = data.get("event_record", {}).get("nodes", {}) event_count = len(nodes) - trigger_event = None - if nodes: - last_node = max( - nodes.values(), - key=lambda n: n.get("event", {}).get("emission_sequence") or 0, - ) - trigger_event = last_node.get("event", {}).get("type") + trigger_event = data.get("trigger") parsed_entities: list[dict[str, Any]] = [] for entity in entities: @@ -95,21 +93,51 @@ def _parse_checkpoint_json(raw: str, source: str) -> dict[str, Any]: { "description": t.get("description", ""), "completed": t.get("output") is not None, + "output": (t.get("output") or {}).get("raw", ""), } for t in tasks ] parsed_entities.append(info) + inputs: dict[str, Any] = {} + for entity in entities: + cp_inputs = entity.get("checkpoint_inputs") + if isinstance(cp_inputs, dict) and cp_inputs: + inputs = dict(cp_inputs) + break + + for entity in entities: + for task in entity.get("tasks", []): + for field in ( + "checkpoint_original_description", + "checkpoint_original_expected_output", + ): + text = task.get(field) or "" + for match in _PLACEHOLDER_RE.findall(text): + if match not in inputs: + inputs[match] = "" + for agent in entity.get("agents", []): + for field in ("role", "goal", "backstory"): + text = agent.get(field) or "" + for match in _PLACEHOLDER_RE.findall(text): + if match not in inputs: + inputs[match] = "" + + branch = data.get("branch", 
"main") + parent_id = data.get("parent_id") + return { "source": source, "event_count": event_count, "trigger": trigger_event, "entities": parsed_entities, - "branch": data.get("branch", "main"), - "parent_id": data.get("parent_id"), + "branch": branch, + "parent_id": parent_id, + "inputs": inputs, } + def _format_size(size: int) -> str: if size < 1024: return f"{size}B" diff --git a/lib/crewai/src/crewai/cli/checkpoint_tui.py b/lib/crewai/src/crewai/cli/checkpoint_tui.py index 50ac4195f..cf2cdb61e 100644 --- a/lib/crewai/src/crewai/cli/checkpoint_tui.py +++ b/lib/crewai/src/crewai/cli/checkpoint_tui.py @@ -10,10 +10,11 @@ from textual.binding import Binding from textual.containers import Horizontal, Vertical, VerticalScroll from textual.widgets import ( Button, - Collapsible, Footer, Header, + Input, Static, + TextArea, Tree, ) @@ -39,6 +40,7 @@ def _load_entries(location: str) -> list[dict[str, Any]]: return _list_json(location) + def _short_id(name: str) -> str: """Shorten a checkpoint name for tree display.""" if len(name) > 30: @@ -46,14 +48,9 @@ def _short_id(name: str) -> str: return name -def _build_entity_detail(ent: dict[str, Any]) -> str: - """Build rich text for a single entity.""" +def _build_entity_header(ent: dict[str, Any]) -> str: + """Build rich text header for an entity (progress bar only).""" lines: list[str] = [] - eid = str(ent.get("id", ""))[:8] - etype = ent.get("type", "unknown") - ename = ent.get("name", "unnamed") - lines.append(f"[bold {_SECONDARY}]{etype}[/]: {ename} [{_DIM}]{eid}[/]") - tasks = ent.get("tasks") if isinstance(tasks, list): completed = ent.get("tasks_completed", 0) @@ -63,22 +60,19 @@ def _build_entity_detail(ent: dict[str, Any]) -> str: filled = int(bar_len * completed / total) if total else 0 bar = f"[{_PRIMARY}]{'█' * filled}[/][{_DIM}]{'░' * (bar_len - filled)}[/]" lines.append(f"{bar} {completed}/{total} tasks ({pct}%)") - lines.append("") - for i, task in enumerate(tasks): - icon = "[green]✓[/]" if 
task.get("completed") else "[yellow]○[/]" - desc = str(task.get("description", "")) - if len(desc) > 55: - desc = desc[:52] + "..." - lines.append(f" {icon} {i + 1}. {desc}") - return "\n".join(lines) -class CheckpointTUI(App[tuple[str, str] | None]): +# Return type: (location, action, inputs, task_output_overrides) +_TuiResult = tuple[str, str, dict[str, Any] | None, dict[int, str] | None] | None + + +class CheckpointTUI(App[_TuiResult]): """TUI to browse and inspect checkpoints. - Returns ``(location, action)`` where action is ``"resume"`` or - ``"fork"``, or ``None`` if the user quit without selecting. + Returns ``(location, action, inputs, task_output_overrides)`` where action is + ``"resume"`` or ``"fork"`` and the last two items are dicts or ``None``, + or ``None`` if the user quit without selecting. """ TITLE = "CrewAI Checkpoints" @@ -114,6 +108,7 @@ class CheckpointTUI(App[tuple[str, str] | None]): }} #detail-container {{ width: 55%; + height: 1fr; }} #detail-scroll {{ height: 1fr; @@ -133,9 +128,41 @@ class CheckpointTUI(App[tuple[str, str] | None]): padding: 0 2; color: {_DIM}; }} + #inputs-section {{ + display: none; + height: auto; + max-height: 8; + padding: 0 1; + }} + #inputs-section.visible {{ + display: block; + }} + #inputs-label {{ + height: 1; + color: {_DIM}; + padding: 0 1; + }} + .input-row {{ + height: 3; + padding: 0 1; + }} + .input-row Static {{ + width: auto; + min-width: 12; + padding: 1 1 0 0; + color: {_TERTIARY}; + }} + .input-row Input {{ + width: 1fr; + }} + #no-inputs-label {{ + height: 1; + color: {_DIM}; + padding: 0 1; + }} #action-buttons {{ height: 3; - align: center middle; + align: right middle; padding: 0 1; display: none; }} @@ -143,8 +170,8 @@ class CheckpointTUI(App[tuple[str, str] | None]): display: block; }} #action-buttons Button {{ - margin: 0 1; - min-width: 14; + margin: 0 0 0 1; + min-width: 10; }} #btn-resume {{ background: {_SECONDARY}; @@ -160,13 +187,24 @@ class CheckpointTUI(App[tuple[str, str] | None]): #btn-fork:hover {{ background: 
{_SECONDARY}; }} - Collapsible {{ - margin: 0; - padding: 0; + .entity-title {{ + padding: 1 1 0 1; }} .entity-detail {{ padding: 0 1; }} + .task-output-editor {{ + height: auto; + max-height: 10; + margin: 0 1 1 1; + border: round {_DIM}; + }} + .task-output-editor:focus {{ + border: round {_PRIMARY}; + }} + .task-label {{ + padding: 0 1; + }} Tree {{ background: {_BG_PANEL}; }} @@ -186,6 +224,7 @@ class CheckpointTUI(App[tuple[str, str] | None]): self._location = location self._entries: list[dict[str, Any]] = [] self._selected_entry: dict[str, Any] | None = None + self._input_keys: list[str] = [] def compose(self) -> ComposeResult: yield Header(show_clock=False) @@ -201,6 +240,8 @@ class CheckpointTUI(App[tuple[str, str] | None]): f"[{_DIM}]Select a checkpoint from the tree[/]", # noqa: S608 id="detail-header", ) + with Vertical(id="inputs-section"): + yield Static("Inputs", id="inputs-label") with Horizontal(id="action-buttons"): yield Button("Resume", id="btn-resume") yield Button("Fork", id="btn-fork") @@ -295,6 +336,7 @@ class CheckpointTUI(App[tuple[str, str] | None]): if first_parent else tree.root ) + # Drop fork-initial checkpoint — it's just a lineage marker branch_label = ( f"[bold {_SECONDARY}]{branch_name}[/] [{_DIM}]({len(entries)})[/]" ) @@ -308,16 +350,17 @@ class CheckpointTUI(App[tuple[str, str] | None]): self.sub_title = self._location self.query_one("#status", Static).update(f" {count} checkpoint(s) | {storage}") - def _show_detail(self, entry: dict[str, Any]) -> None: + async def _show_detail(self, entry: dict[str, Any]) -> None: """Update the detail panel for a checkpoint entry.""" self._selected_entry = entry self.query_one("#action-buttons").add_class("visible") detail_scroll = self.query_one("#detail-scroll", VerticalScroll) - # Remove old collapsibles - for widget in list(detail_scroll.query("Collapsible")): - widget.remove() + # Remove all dynamic children except the header — await so IDs are freed + to_remove = [c for c in 
detail_scroll.children if c.id != "detail-header"] + for child in to_remove: + await child.remove() # Header name = entry.get("name", "") @@ -347,19 +390,102 @@ class CheckpointTUI(App[tuple[str, str] | None]): self.query_one("#detail-header", Static).update("\n".join(header_lines)) - # Entity collapsibles - for ent in entry.get("entities", []): + # Entity details and editable task outputs — mounted flat for scrolling + self._task_output_ids = [] + for ent_idx, ent in enumerate(entry.get("entities", [])): etype = ent.get("type", "unknown") ename = ent.get("name", "unnamed") completed = ent.get("tasks_completed") total = ent.get("tasks_total") - title = f"{etype}: {ename}" + entity_title = f"[bold {_SECONDARY}]{etype}: {ename}[/]" if completed is not None and total is not None: - title += f" [{completed}/{total} tasks]" + entity_title += f" [{_DIM}]{completed}/{total} tasks[/]" + detail_scroll.mount(Static(entity_title, classes="entity-title")) + detail_scroll.mount( + Static(_build_entity_header(ent), classes="entity-detail") + ) - content = Static(_build_entity_detail(ent), classes="entity-detail") - collapsible = Collapsible(content, title=title, collapsed=False) - detail_scroll.mount(collapsible) + tasks = ent.get("tasks", []) + for i, task in enumerate(tasks): + desc = str(task.get("description", "")) + if len(desc) > 55: + desc = desc[:52] + "..." + if task.get("completed"): + icon = "[green]✓[/]" + detail_scroll.mount( + Static(f" {icon} {i + 1}. {desc}", classes="task-label") + ) + output_text = task.get("output", "") + editor_id = f"task-output-{ent_idx}-{i}" + detail_scroll.mount( + TextArea( + str(output_text), + classes="task-output-editor", + id=editor_id, + ) + ) + self._task_output_ids.append((i, editor_id)) + else: + icon = "[yellow]○[/]" + detail_scroll.mount( + Static(f" {icon} {i + 1}. 
{desc}", classes="task-label") + ) + + # Build input fields + await self._build_input_fields(entry.get("inputs", {})) + + async def _build_input_fields(self, inputs: dict[str, Any]) -> None: + """Rebuild the inputs section with one field per input key.""" + section = self.query_one("#inputs-section") + + # Remove old dynamic children — await so IDs are freed + for widget in list(section.query(".input-row, .no-inputs")): + await widget.remove() + + self._input_keys = [] + + if not inputs: + section.mount(Static(f"[{_DIM}]No inputs[/]", classes="no-inputs")) + section.add_class("visible") + return + + for key, value in inputs.items(): + self._input_keys.append(key) + row = Horizontal(classes="input-row") + row.compose_add_child(Static(f"[bold]{key}[/]")) + row.compose_add_child( + Input(value=str(value), placeholder=key, id=f"input-{key}") + ) + section.mount(row) + + section.add_class("visible") + + def _collect_inputs(self) -> dict[str, Any] | None: + """Collect current values from input fields.""" + if not self._input_keys: + return None + result: dict[str, Any] = {} + for key in self._input_keys: + widget = self.query_one(f"#input-{key}", Input) + result[key] = widget.value + return result + + def _collect_task_overrides(self) -> dict[int, str] | None: + """Collect edited task outputs. 
Returns only changed values.""" + if not self._task_output_ids or self._selected_entry is None: + return None + overrides: dict[int, str] = {} + tasks = [] + for ent in self._selected_entry.get("entities", []): + tasks = ent.get("tasks", []) + if tasks: + break + for task_idx, editor_id in self._task_output_ids: + editor = self.query_one(f"#{editor_id}", TextArea) + original = str(tasks[task_idx].get("output", "")) if task_idx < len(tasks) else "" + if editor.text != original: + overrides[task_idx] = editor.text + return overrides or None def _resolve_location(self, entry: dict[str, Any]) -> str: """Get the restore location string for a checkpoint entry.""" @@ -373,16 +499,18 @@ class CheckpointTUI(App[tuple[str, str] | None]): self, event: Tree.NodeHighlighted[dict[str, Any]] ) -> None: if event.node.data is not None: - self._show_detail(event.node.data) + await self._show_detail(event.node.data) def on_button_pressed(self, event: Button.Pressed) -> None: if self._selected_entry is None: return + inputs = self._collect_inputs() + overrides = self._collect_task_overrides() loc = self._resolve_location(self._selected_entry) if event.button.id == "btn-resume": - self.exit((loc, "resume")) + self.exit((loc, "resume", inputs, overrides)) elif event.button.id == "btn-fork": - self.exit((loc, "fork")) + self.exit((loc, "fork", inputs, overrides)) def action_refresh(self) -> None: self._refresh_tree() @@ -398,7 +526,7 @@ async def _run_checkpoint_tui_async(location: str) -> None: if selection is None: return - selected, action = selection + selected, action, inputs, task_overrides = selection from crewai.crew import Crew from crewai.state.checkpoint_config import CheckpointConfig @@ -412,7 +540,44 @@ async def _run_checkpoint_tui_async(location: str) -> None: click.echo(f"\nResuming from: {selected}\n") crew = Crew.from_checkpoint(config) - result = await crew.akickoff() + # Apply task output overrides before kickoff + if task_overrides: + click.echo("Modifications:") + 
for task_idx, new_output in task_overrides.items(): + if task_idx < len(crew.tasks) and crew.tasks[task_idx].output is not None: + desc = crew.tasks[task_idx].description or f"Task {task_idx + 1}" + if len(desc) > 60: + desc = desc[:57] + "..." + crew.tasks[task_idx].output.raw = new_output # type: ignore[union-attr] + preview = new_output.replace("\n", " ") + if len(preview) > 80: + preview = preview[:77] + "..." + click.echo(f" Task {task_idx + 1}: {desc}") + click.echo(f" -> {preview}") + # Update the assistant message in the executor's history + # so the LLM sees the override in its conversation + agent = crew.tasks[task_idx].agent + if agent and agent.agent_executor: + for msg in reversed(agent.agent_executor.messages): + if msg.get("role") == "assistant": + msg["content"] = new_output + break + # Invalidate all subsequent tasks so they re-run with + # the modified context instead of using cached results + for subsequent in crew.tasks[task_idx + 1 :]: + subsequent.output = None + if subsequent.agent and subsequent.agent.agent_executor: + subsequent.agent.agent_executor._resuming = False + subsequent.agent.agent_executor.messages = [] + click.echo() + + if inputs: + click.echo("Inputs:") + for k, v in inputs.items(): + click.echo(f" {k}: {v}") + click.echo() + + result = await crew.akickoff(inputs=inputs) click.echo(f"\nResult: {getattr(result, 'raw', result)}") diff --git a/lib/crewai/src/crewai/crew.py b/lib/crewai/src/crewai/crew.py index 111babbf3..3b7f7a7d5 100644 --- a/lib/crewai/src/crewai/crew.py +++ b/lib/crewai/src/crewai/crew.py @@ -436,6 +436,11 @@ class Crew(FlowTrackable, BaseModel): if agent.agent_executor is not None and task.output is None: agent.agent_executor.task = task break + for task in self.tasks: + if task.checkpoint_original_description is not None: + task._original_description = task.checkpoint_original_description + if task.checkpoint_original_expected_output is not None: + task._original_expected_output = 
task.checkpoint_original_expected_output if self.checkpoint_inputs is not None: self._inputs = self.checkpoint_inputs if self.checkpoint_kickoff_event_id is not None: diff --git a/lib/crewai/src/crewai/state/checkpoint_listener.py b/lib/crewai/src/crewai/state/checkpoint_listener.py index c2ac728a8..e956a4395 100644 --- a/lib/crewai/src/crewai/state/checkpoint_listener.py +++ b/lib/crewai/src/crewai/state/checkpoint_listener.py @@ -102,8 +102,12 @@ def _find_checkpoint(source: Any) -> CheckpointConfig | None: return None -def _do_checkpoint(state: RuntimeState, cfg: CheckpointConfig) -> None: +def _do_checkpoint( + state: RuntimeState, cfg: CheckpointConfig, event: BaseEvent | None = None +) -> None: """Write a checkpoint and prune old ones if configured.""" + if event is not None: + state._trigger = event.type _prepare_entities(state.root) data = state.model_dump_json() location = cfg.provider.checkpoint( @@ -134,7 +138,7 @@ def _on_any_event(source: Any, event: BaseEvent, state: Any) -> None: if cfg is None: return try: - _do_checkpoint(state, cfg) + _do_checkpoint(state, cfg, event) except Exception: logger.warning("Auto-checkpoint failed for event %s", event.type, exc_info=True) diff --git a/lib/crewai/src/crewai/state/runtime.py b/lib/crewai/src/crewai/state/runtime.py index 4dd1a4555..3e8e4f4bd 100644 --- a/lib/crewai/src/crewai/state/runtime.py +++ b/lib/crewai/src/crewai/state/runtime.py @@ -80,6 +80,9 @@ def _sync_checkpoint_fields(entity: object) -> None: entity.checkpoint_inputs = entity._inputs entity.checkpoint_train = entity._train entity.checkpoint_kickoff_event_id = entity._kickoff_event_id + for task in entity.tasks: + task.checkpoint_original_description = task._original_description + task.checkpoint_original_expected_output = task._original_expected_output def _migrate(data: dict[str, Any]) -> dict[str, Any]: @@ -123,6 +126,7 @@ class RuntimeState(RootModel): # type: ignore[type-arg] _parent_id: str | None = PrivateAttr(default=None) _branch: str 
= PrivateAttr(default="main") _location: str | None = PrivateAttr(default=None) + _trigger: str | None = PrivateAttr(default=None) @property def event_record(self) -> EventRecord: @@ -131,13 +135,16 @@ class RuntimeState(RootModel): # type: ignore[type-arg] @model_serializer(mode="plain") def _serialize(self) -> dict[str, Any]: - return { + d: dict[str, Any] = { "crewai_version": get_crewai_version(), "parent_id": self._parent_id, "branch": self._branch, "entities": [e.model_dump(mode="json") for e in self.root], "event_record": self._event_record.model_dump(), } + if self._trigger: + d["trigger"] = self._trigger + return d @model_validator(mode="wrap") @classmethod @@ -222,13 +229,10 @@ class RuntimeState(RootModel): # type: ignore[type-arg] if branch: self._branch = branch elif self._checkpoint_id: - self._branch = f"fork/{self._checkpoint_id}" + self._branch = f"fork/{self._checkpoint_id}_{uuid.uuid4().hex[:6]}" else: self._branch = f"fork/{uuid.uuid4().hex[:8]}" - if self._location is not None: - self.checkpoint(self._location) - @classmethod def from_checkpoint(cls, config: CheckpointConfig, **kwargs: Any) -> RuntimeState: """Restore a RuntimeState from a checkpoint. diff --git a/lib/crewai/src/crewai/task.py b/lib/crewai/src/crewai/task.py index 4224828e3..0a159cb0e 100644 --- a/lib/crewai/src/crewai/task.py +++ b/lib/crewai/src/crewai/task.py @@ -230,6 +230,8 @@ class Task(BaseModel): _original_description: str | None = PrivateAttr(default=None) _original_expected_output: str | None = PrivateAttr(default=None) _original_output_file: str | None = PrivateAttr(default=None) + checkpoint_original_description: str | None = Field(default=None, exclude=False) + checkpoint_original_expected_output: str | None = Field(default=None, exclude=False) _thread: threading.Thread | None = PrivateAttr(default=None) model_config = {"arbitrary_types_allowed": True}