From 7e01c5a03048596dc6dd43b51b3ca2c41e3bc8fe Mon Sep 17 00:00:00 2001 From: Greyson LaLonde Date: Fri, 17 Apr 2026 01:34:06 +0800 Subject: [PATCH] fix: dispatch Flow checkpoints through Flow APIs in TUI --- lib/crewai/src/crewai/cli/checkpoint_tui.py | 170 ++++++++++++++------ lib/crewai/src/crewai/flow/flow.py | 4 +- 2 files changed, 120 insertions(+), 54 deletions(-) diff --git a/lib/crewai/src/crewai/cli/checkpoint_tui.py b/lib/crewai/src/crewai/cli/checkpoint_tui.py index e0d10f813..26791af23 100644 --- a/lib/crewai/src/crewai/cli/checkpoint_tui.py +++ b/lib/crewai/src/crewai/cli/checkpoint_tui.py @@ -3,7 +3,7 @@ from __future__ import annotations from collections import defaultdict -from typing import Any, ClassVar +from typing import Any, ClassVar, Literal from textual.app import App, ComposeResult from textual.binding import Binding @@ -78,15 +78,25 @@ def _build_entity_header(ent: dict[str, Any]) -> str: return "\n".join(lines) -# Return type: (location, action, inputs, task_output_overrides) -_TuiResult = tuple[str, str, dict[str, Any] | None, dict[int, str] | None] | None +# Return type: (location, action, inputs, task_output_overrides, entity_type) +_TuiResult = ( + tuple[ + str, + str, + dict[str, Any] | None, + dict[int, str] | None, + Literal["crew", "flow"], + ] + | None +) class CheckpointTUI(App[_TuiResult]): """TUI to browse and inspect checkpoints. - Returns ``(location, action, inputs)`` where action is ``"resume"`` or - ``"fork"`` and inputs is a parsed dict or ``None``, + Returns ``(location, action, inputs, task_overrides, entity_type)`` + where action is ``"resume"`` or ``"fork"``, inputs is a parsed dict + or ``None``, and entity_type is ``"crew"`` or ``"flow"``; or ``None`` if the user quit without selecting. """ @@ -506,6 +516,13 @@ class CheckpointTUI(App[_TuiResult]): overrides[task_idx] = editor.text return overrides or None + def _detect_entity_type(self, entry: dict[str, Any]) -> Literal["crew", "flow"]: + """Infer the top-level entity type from checkpoint entities.""" + for ent in entry.get("entities", []): + if ent.get("type") == "flow": + return "flow" + return "crew" + def _resolve_location(self, entry: dict[str, Any]) -> str: """Get the restore location string for a checkpoint entry.""" if "path" in entry: @@ -526,15 +543,64 @@ class CheckpointTUI(App[_TuiResult]): inputs = self._collect_inputs() overrides = self._collect_task_overrides() loc = self._resolve_location(self._selected_entry) + etype = self._detect_entity_type(self._selected_entry) if event.button.id == "btn-resume": - self.exit((loc, "resume", inputs, overrides)) + self.exit((loc, "resume", inputs, overrides, etype)) elif event.button.id == "btn-fork": - self.exit((loc, "fork", inputs, overrides)) + self.exit((loc, "fork", inputs, overrides, etype)) def action_refresh(self) -> None: self._refresh_tree() +def _apply_task_overrides(crew: Any, task_overrides: dict[int, str]) -> None: + """Apply task output overrides to a restored Crew and print modifications.""" + import click + + click.echo("Modifications:") + overridden_agents: set[int] = set() + for task_idx, new_output in task_overrides.items(): + if task_idx < len(crew.tasks) and crew.tasks[task_idx].output is not None: + desc = crew.tasks[task_idx].description or f"Task {task_idx + 1}" + if len(desc) > 60: + desc = desc[:57] + "..." + crew.tasks[task_idx].output.raw = new_output + preview = new_output.replace("\n", " ") + if len(preview) > 80: + preview = preview[:77] + "..." + click.echo(f" Task {task_idx + 1}: {desc}") + click.echo(f" -> {preview}") + agent = crew.tasks[task_idx].agent + if agent and agent.agent_executor: + nth = sum(1 for t in crew.tasks[:task_idx] if t.agent is agent) + messages = agent.agent_executor.messages + system_positions = [ + i for i, m in enumerate(messages) if m.get("role") == "system" + ] + if nth < len(system_positions): + seg_start = system_positions[nth] + seg_end = ( + system_positions[nth + 1] + if nth + 1 < len(system_positions) + else len(messages) + ) + for j in range(seg_end - 1, seg_start, -1): + if messages[j].get("role") == "assistant": + messages[j]["content"] = new_output + break + overridden_agents.add(id(agent)) + + earliest = min(task_overrides) + for offset, subsequent in enumerate(crew.tasks[earliest + 1 :], start=earliest + 1): + if subsequent.output and offset not in task_overrides: + subsequent.output = None + if subsequent.agent and subsequent.agent.agent_executor: + subsequent.agent.agent_executor._resuming = False + if id(subsequent.agent) not in overridden_agents: + subsequent.agent.agent_executor.messages = [] + click.echo() + + async def _run_checkpoint_tui_async(location: str) -> None: """Async implementation of the checkpoint TUI flow.""" import click @@ -545,13 +611,54 @@ async def _run_checkpoint_tui_async(location: str) -> None: if selection is None: return - selected, action, inputs, task_overrides = selection + selected, action, inputs, task_overrides, entity_type = selection - from crewai.crew import Crew from crewai.state.checkpoint_config import CheckpointConfig config = CheckpointConfig(restore_from=selected) + if entity_type == "flow": + from crewai.events.event_bus import crewai_event_bus + from crewai.flow.flow import Flow + + if action == "fork": + click.echo(f"\nForking flow from: {selected}\n") + flow = Flow.fork(config) + else: + click.echo(f"\nResuming flow from: {selected}\n") + flow = Flow.from_checkpoint(config) + + if task_overrides: + from crewai.crew import Crew as CrewCls + + state = crewai_event_bus._runtime_state + if state is not None: + flat_offset = 0 + for entity in state.root: + if not isinstance(entity, CrewCls) or not entity.tasks: + continue + n = len(entity.tasks) + local = { + idx - flat_offset: out + for idx, out in task_overrides.items() + if flat_offset <= idx < flat_offset + n + } + if local: + _apply_task_overrides(entity, local) + flat_offset += n + + if inputs: + click.echo("Inputs:") + for k, v in inputs.items(): + click.echo(f" {k}: {v}") + click.echo() + + result = await flow.kickoff_async(inputs=inputs) + click.echo(f"\nResult: {getattr(result, 'raw', result)}") + return + + from crewai.crew import Crew + if action == "fork": click.echo(f"\nForking from: {selected}\n") crew = Crew.fork(config) @@ -560,50 +667,7 @@ async def _run_checkpoint_tui_async(location: str) -> None: crew = Crew.from_checkpoint(config) if task_overrides: - click.echo("Modifications:") - overridden_agents: set[int] = set() - for task_idx, new_output in task_overrides.items(): - if task_idx < len(crew.tasks) and crew.tasks[task_idx].output is not None: - desc = crew.tasks[task_idx].description or f"Task {task_idx + 1}" - if len(desc) > 60: - desc = desc[:57] + "..." - crew.tasks[task_idx].output.raw = new_output # type: ignore[union-attr] - preview = new_output.replace("\n", " ") - if len(preview) > 80: - preview = preview[:77] + "..." - click.echo(f" Task {task_idx + 1}: {desc}") - click.echo(f" -> {preview}") - agent = crew.tasks[task_idx].agent - if agent and agent.agent_executor: - nth = sum(1 for t in crew.tasks[:task_idx] if t.agent is agent) - messages = agent.agent_executor.messages - system_positions = [ - i for i, m in enumerate(messages) if m.get("role") == "system" - ] - if nth < len(system_positions): - seg_start = system_positions[nth] - seg_end = ( - system_positions[nth + 1] - if nth + 1 < len(system_positions) - else len(messages) - ) - for j in range(seg_end - 1, seg_start, -1): - if messages[j].get("role") == "assistant": - messages[j]["content"] = new_output - break - overridden_agents.add(id(agent)) - - earliest = min(task_overrides) - for offset, subsequent in enumerate( - crew.tasks[earliest + 1 :], start=earliest + 1 - ): - if subsequent.output and offset not in task_overrides: - subsequent.output = None - if subsequent.agent and subsequent.agent.agent_executor: - subsequent.agent.agent_executor._resuming = False - if id(subsequent.agent) not in overridden_agents: - subsequent.agent.agent_executor.messages = [] - click.echo() + _apply_task_overrides(crew, task_overrides) if inputs: click.echo("Inputs:") diff --git a/lib/crewai/src/crewai/flow/flow.py b/lib/crewai/src/crewai/flow/flow.py index 057f60ffb..88457f7aa 100644 --- a/lib/crewai/src/crewai/flow/flow.py +++ b/lib/crewai/src/crewai/flow/flow.py @@ -2138,7 +2138,9 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): try: # Reset flow state for fresh execution unless restoring from persistence - is_restoring = inputs and "id" in inputs and self.persistence is not None + is_restoring = ( + inputs and "id" in inputs and self.persistence is not None + ) or self.checkpoint_completed_methods is not None if not is_restoring: # Clear completed methods and outputs for a fresh start self._completed_methods.clear()