mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-07 10:12:38 +00:00
Merge branch 'main' into luzk/propagate-is-litellm
This commit is contained in:
@@ -6,12 +6,16 @@ from datetime import datetime
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import sqlite3
|
||||
from typing import Any
|
||||
|
||||
import click
|
||||
|
||||
|
||||
_PLACEHOLDER_RE = re.compile(r"\{([A-Za-z_][A-Za-z0-9_\-]*)}")
|
||||
|
||||
|
||||
_SQLITE_MAGIC = b"SQLite format 3\x00"
|
||||
|
||||
_SELECT_ALL = """
|
||||
@@ -34,6 +38,25 @@ LIMIT 1
|
||||
"""
|
||||
|
||||
|
||||
_DEFAULT_DIR = "./.checkpoints"
|
||||
_DEFAULT_DB = "./.checkpoints.db"
|
||||
|
||||
|
||||
def _detect_location(location: str) -> str:
|
||||
"""Resolve the default checkpoint location.
|
||||
|
||||
When the caller passes the default directory path, check whether a
|
||||
SQLite database exists at the conventional ``.db`` path and prefer it.
|
||||
"""
|
||||
if (
|
||||
location == _DEFAULT_DIR
|
||||
and not os.path.exists(_DEFAULT_DIR)
|
||||
and os.path.exists(_DEFAULT_DB)
|
||||
):
|
||||
return _DEFAULT_DB
|
||||
return location
|
||||
|
||||
|
||||
def _is_sqlite(path: str) -> bool:
|
||||
"""Check if a file is a SQLite database by reading its magic bytes."""
|
||||
if not os.path.isfile(path):
|
||||
@@ -52,13 +75,7 @@ def _parse_checkpoint_json(raw: str, source: str) -> dict[str, Any]:
|
||||
nodes = data.get("event_record", {}).get("nodes", {})
|
||||
event_count = len(nodes)
|
||||
|
||||
trigger_event = None
|
||||
if nodes:
|
||||
last_node = max(
|
||||
nodes.values(),
|
||||
key=lambda n: n.get("event", {}).get("emission_sequence") or 0,
|
||||
)
|
||||
trigger_event = last_node.get("event", {}).get("type")
|
||||
trigger_event = data.get("trigger")
|
||||
|
||||
parsed_entities: list[dict[str, Any]] = []
|
||||
for entity in entities:
|
||||
@@ -76,16 +93,47 @@ def _parse_checkpoint_json(raw: str, source: str) -> dict[str, Any]:
|
||||
{
|
||||
"description": t.get("description", ""),
|
||||
"completed": t.get("output") is not None,
|
||||
"output": (t.get("output") or {}).get("raw", ""),
|
||||
}
|
||||
for t in tasks
|
||||
]
|
||||
parsed_entities.append(info)
|
||||
|
||||
inputs: dict[str, Any] = {}
|
||||
for entity in entities:
|
||||
cp_inputs = entity.get("checkpoint_inputs")
|
||||
if isinstance(cp_inputs, dict) and cp_inputs:
|
||||
inputs = dict(cp_inputs)
|
||||
break
|
||||
|
||||
for entity in entities:
|
||||
for task in entity.get("tasks", []):
|
||||
for field in (
|
||||
"checkpoint_original_description",
|
||||
"checkpoint_original_expected_output",
|
||||
):
|
||||
text = task.get(field) or ""
|
||||
for match in _PLACEHOLDER_RE.findall(text):
|
||||
if match not in inputs:
|
||||
inputs[match] = ""
|
||||
for agent in entity.get("agents", []):
|
||||
for field in ("role", "goal", "backstory"):
|
||||
text = agent.get(field) or ""
|
||||
for match in _PLACEHOLDER_RE.findall(text):
|
||||
if match not in inputs:
|
||||
inputs[match] = ""
|
||||
|
||||
branch = data.get("branch", "main")
|
||||
parent_id = data.get("parent_id")
|
||||
|
||||
return {
|
||||
"source": source,
|
||||
"event_count": event_count,
|
||||
"trigger": trigger_event,
|
||||
"entities": parsed_entities,
|
||||
"branch": branch,
|
||||
"parent_id": parent_id,
|
||||
"inputs": inputs,
|
||||
}
|
||||
|
||||
|
||||
@@ -189,6 +237,7 @@ def _list_sqlite(db_path: str) -> list[dict[str, Any]]:
|
||||
"entities": [],
|
||||
"source": checkpoint_id,
|
||||
}
|
||||
meta["db"] = db_path
|
||||
results.append(meta)
|
||||
return results
|
||||
|
||||
@@ -311,6 +360,10 @@ def _print_info(meta: dict[str, Any]) -> None:
|
||||
trigger = meta.get("trigger")
|
||||
if trigger:
|
||||
click.echo(f"Trigger: {trigger}")
|
||||
click.echo(f"Branch: {meta.get('branch', 'main')}")
|
||||
parent_id = meta.get("parent_id")
|
||||
if parent_id:
|
||||
click.echo(f"Parent: {parent_id}")
|
||||
|
||||
for ent in meta.get("entities", []):
|
||||
eid = str(ent.get("id", ""))[:8]
|
||||
|
||||
@@ -2,17 +2,23 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
from typing import Any, ClassVar
|
||||
|
||||
from textual.app import App, ComposeResult
|
||||
from textual.binding import Binding
|
||||
from textual.containers import Horizontal, Vertical
|
||||
from textual.screen import ModalScreen
|
||||
from textual.widgets import Button, Footer, Header, OptionList, Static
|
||||
from textual.widgets.option_list import Option
|
||||
from textual.containers import Horizontal, Vertical, VerticalScroll
|
||||
from textual.widgets import (
|
||||
Button,
|
||||
Footer,
|
||||
Header,
|
||||
Input,
|
||||
Static,
|
||||
TextArea,
|
||||
Tree,
|
||||
)
|
||||
|
||||
from crewai.cli.checkpoint_cli import (
|
||||
_entity_summary,
|
||||
_format_size,
|
||||
_is_sqlite,
|
||||
_list_json,
|
||||
@@ -34,151 +40,54 @@ def _load_entries(location: str) -> list[dict[str, Any]]:
|
||||
return _list_json(location)
|
||||
|
||||
|
||||
def _format_list_label(entry: dict[str, Any]) -> str:
|
||||
"""Format a checkpoint entry for the list panel."""
|
||||
name = entry.get("name", "")
|
||||
ts = entry.get("ts") or ""
|
||||
trigger = entry.get("trigger") or ""
|
||||
summary = _entity_summary(entry.get("entities", []))
|
||||
|
||||
line1 = f"[bold]{name}[/]"
|
||||
parts = []
|
||||
if ts:
|
||||
parts.append(f"[dim]{ts}[/]")
|
||||
if "size" in entry:
|
||||
parts.append(f"[dim]{_format_size(entry['size'])}[/]")
|
||||
if trigger:
|
||||
parts.append(f"[{_PRIMARY}]{trigger}[/]")
|
||||
line2 = " ".join(parts)
|
||||
line3 = f" [{_DIM}]{summary}[/]"
|
||||
|
||||
return f"{line1}\n{line2}\n{line3}"
|
||||
def _short_id(name: str) -> str:
|
||||
"""Shorten a checkpoint name for tree display."""
|
||||
if len(name) > 30:
|
||||
return name[:27] + "..."
|
||||
return name
|
||||
|
||||
|
||||
def _format_detail(entry: dict[str, Any]) -> str:
|
||||
"""Format checkpoint details for the right panel."""
|
||||
def _entry_id(entry: dict[str, Any]) -> str:
|
||||
"""Normalize an entry's name into its checkpoint ID.
|
||||
|
||||
JSON filenames are ``{ts}_{uuid}_p-{parent}.json``; SQLite IDs
|
||||
are already ``{ts}_{uuid}``. This strips the JSON suffix so
|
||||
fork-parent lookups work in both providers.
|
||||
"""
|
||||
name = str(entry.get("name", ""))
|
||||
if name.endswith(".json"):
|
||||
name = name[: -len(".json")]
|
||||
idx = name.find("_p-")
|
||||
if idx != -1:
|
||||
name = name[:idx]
|
||||
return name
|
||||
|
||||
|
||||
def _build_entity_header(ent: dict[str, Any]) -> str:
|
||||
"""Build rich text header for an entity (progress bar only)."""
|
||||
lines: list[str] = []
|
||||
|
||||
# Header
|
||||
name = entry.get("name", "")
|
||||
lines.append(f"[bold {_PRIMARY}]{name}[/]")
|
||||
lines.append(f"[{_DIM}]{'─' * 50}[/]")
|
||||
lines.append("")
|
||||
|
||||
# Metadata table
|
||||
ts = entry.get("ts") or "unknown"
|
||||
trigger = entry.get("trigger") or ""
|
||||
lines.append(f" [bold]Time[/] {ts}")
|
||||
if "size" in entry:
|
||||
lines.append(f" [bold]Size[/] {_format_size(entry['size'])}")
|
||||
lines.append(f" [bold]Events[/] {entry.get('event_count', 0)}")
|
||||
if trigger:
|
||||
lines.append(f" [bold]Trigger[/] [{_PRIMARY}]{trigger}[/]")
|
||||
if "path" in entry:
|
||||
lines.append(f" [bold]Path[/] [{_DIM}]{entry['path']}[/]")
|
||||
if "db" in entry:
|
||||
lines.append(f" [bold]Database[/] [{_DIM}]{entry['db']}[/]")
|
||||
|
||||
# Entities
|
||||
for ent in entry.get("entities", []):
|
||||
eid = str(ent.get("id", ""))[:8]
|
||||
etype = ent.get("type", "unknown")
|
||||
ename = ent.get("name", "unnamed")
|
||||
|
||||
lines.append("")
|
||||
lines.append(f" [{_DIM}]{'─' * 50}[/]")
|
||||
lines.append(f" [bold {_SECONDARY}]{etype}[/]: {ename} [{_DIM}]{eid}[/]")
|
||||
|
||||
tasks = ent.get("tasks")
|
||||
if isinstance(tasks, list):
|
||||
completed = ent.get("tasks_completed", 0)
|
||||
total = ent.get("tasks_total", 0)
|
||||
pct = int(completed / total * 100) if total else 0
|
||||
bar_len = 20
|
||||
filled = int(bar_len * completed / total) if total else 0
|
||||
bar = f"[{_PRIMARY}]{'█' * filled}[/][{_DIM}]{'░' * (bar_len - filled)}[/]"
|
||||
|
||||
lines.append(f" {bar} {completed}/{total} tasks ({pct}%)")
|
||||
lines.append("")
|
||||
|
||||
for i, task in enumerate(tasks):
|
||||
if task.get("completed"):
|
||||
icon = "[green]✓[/]"
|
||||
else:
|
||||
icon = "[yellow]○[/]"
|
||||
desc = str(task.get("description", ""))
|
||||
if len(desc) > 55:
|
||||
desc = desc[:52] + "..."
|
||||
lines.append(f" {icon} {i + 1}. {desc}")
|
||||
|
||||
tasks = ent.get("tasks")
|
||||
if isinstance(tasks, list):
|
||||
completed = ent.get("tasks_completed", 0)
|
||||
total = ent.get("tasks_total", 0)
|
||||
pct = int(completed / total * 100) if total else 0
|
||||
bar_len = 20
|
||||
filled = int(bar_len * completed / total) if total else 0
|
||||
bar = f"[{_PRIMARY}]{'█' * filled}[/][{_DIM}]{'░' * (bar_len - filled)}[/]"
|
||||
lines.append(f"{bar} {completed}/{total} tasks ({pct}%)")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
class ConfirmResumeScreen(ModalScreen[bool]):
|
||||
"""Modal confirmation before resuming from a checkpoint."""
|
||||
|
||||
CSS = f"""
|
||||
ConfirmResumeScreen {{
|
||||
align: center middle;
|
||||
}}
|
||||
#confirm-dialog {{
|
||||
width: 60;
|
||||
height: auto;
|
||||
padding: 1 2;
|
||||
background: {_BG_PANEL};
|
||||
border: round {_PRIMARY};
|
||||
}}
|
||||
#confirm-label {{
|
||||
width: 100%;
|
||||
content-align: center middle;
|
||||
margin-bottom: 1;
|
||||
}}
|
||||
#confirm-name {{
|
||||
width: 100%;
|
||||
content-align: center middle;
|
||||
color: {_PRIMARY};
|
||||
text-style: bold;
|
||||
margin-bottom: 1;
|
||||
}}
|
||||
#confirm-buttons {{
|
||||
width: 100%;
|
||||
height: 3;
|
||||
layout: horizontal;
|
||||
align: center middle;
|
||||
}}
|
||||
Button {{
|
||||
margin: 0 2;
|
||||
min-width: 12;
|
||||
}}
|
||||
"""
|
||||
|
||||
def __init__(self, checkpoint_name: str) -> None:
|
||||
super().__init__()
|
||||
self._checkpoint_name = checkpoint_name
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
with Vertical(id="confirm-dialog"):
|
||||
yield Static("Resume from this checkpoint?", id="confirm-label")
|
||||
yield Static(self._checkpoint_name, id="confirm-name")
|
||||
with Horizontal(id="confirm-buttons"):
|
||||
yield Button("Resume", variant="success", id="btn-yes")
|
||||
yield Button("Cancel", variant="default", id="btn-no")
|
||||
|
||||
def on_button_pressed(self, event: Button.Pressed) -> None:
|
||||
self.dismiss(event.button.id == "btn-yes")
|
||||
|
||||
def on_key(self, event: Any) -> None:
|
||||
if event.key == "y":
|
||||
self.dismiss(True)
|
||||
elif event.key in ("n", "escape"):
|
||||
self.dismiss(False)
|
||||
# Return type: (location, action, inputs, task_output_overrides)
|
||||
_TuiResult = tuple[str, str, dict[str, Any] | None, dict[int, str] | None] | None
|
||||
|
||||
|
||||
class CheckpointTUI(App[str | None]):
|
||||
class CheckpointTUI(App[_TuiResult]):
|
||||
"""TUI to browse and inspect checkpoints.
|
||||
|
||||
Returns the checkpoint location string to resume from, or None if
|
||||
the user quit without selecting.
|
||||
Returns ``(location, action, inputs)`` where action is ``"resume"`` or
|
||||
``"fork"`` and inputs is a parsed dict or ``None``,
|
||||
or ``None`` if the user quit without selecting.
|
||||
"""
|
||||
|
||||
TITLE = "CrewAI Checkpoints"
|
||||
@@ -199,145 +108,431 @@ class CheckpointTUI(App[str | None]):
|
||||
background: {_PRIMARY};
|
||||
color: {_TERTIARY};
|
||||
}}
|
||||
Horizontal {{
|
||||
#main-layout {{
|
||||
height: 1fr;
|
||||
}}
|
||||
#cp-list {{
|
||||
width: 38%;
|
||||
#tree-panel {{
|
||||
width: 45%;
|
||||
background: {_BG_PANEL};
|
||||
border: round {_SECONDARY};
|
||||
padding: 0 1;
|
||||
scrollbar-color: {_PRIMARY};
|
||||
}}
|
||||
#cp-list:focus {{
|
||||
#tree-panel:focus-within {{
|
||||
border: round {_PRIMARY};
|
||||
}}
|
||||
#cp-list > .option-list--option-highlighted {{
|
||||
background: {_SECONDARY};
|
||||
color: {_TERTIARY};
|
||||
text-style: none;
|
||||
}}
|
||||
#cp-list > .option-list--option-highlighted * {{
|
||||
color: {_TERTIARY};
|
||||
}}
|
||||
#detail-container {{
|
||||
width: 62%;
|
||||
padding: 0 1;
|
||||
width: 55%;
|
||||
height: 1fr;
|
||||
}}
|
||||
#detail {{
|
||||
#detail-scroll {{
|
||||
height: 1fr;
|
||||
background: {_BG_PANEL};
|
||||
border: round {_SECONDARY};
|
||||
padding: 1 2;
|
||||
overflow-y: auto;
|
||||
scrollbar-color: {_PRIMARY};
|
||||
}}
|
||||
#detail:focus {{
|
||||
#detail-scroll:focus-within {{
|
||||
border: round {_PRIMARY};
|
||||
}}
|
||||
#detail-header {{
|
||||
margin-bottom: 1;
|
||||
}}
|
||||
#status {{
|
||||
height: 1;
|
||||
padding: 0 2;
|
||||
color: {_DIM};
|
||||
}}
|
||||
#inputs-section {{
|
||||
display: none;
|
||||
height: auto;
|
||||
max-height: 8;
|
||||
padding: 0 1;
|
||||
}}
|
||||
#inputs-section.visible {{
|
||||
display: block;
|
||||
}}
|
||||
#inputs-label {{
|
||||
height: 1;
|
||||
color: {_DIM};
|
||||
padding: 0 1;
|
||||
}}
|
||||
.input-row {{
|
||||
height: 3;
|
||||
padding: 0 1;
|
||||
}}
|
||||
.input-row Static {{
|
||||
width: auto;
|
||||
min-width: 12;
|
||||
padding: 1 1 0 0;
|
||||
color: {_TERTIARY};
|
||||
}}
|
||||
.input-row Input {{
|
||||
width: 1fr;
|
||||
}}
|
||||
#no-inputs-label {{
|
||||
height: 1;
|
||||
color: {_DIM};
|
||||
padding: 0 1;
|
||||
}}
|
||||
#action-buttons {{
|
||||
height: 3;
|
||||
align: right middle;
|
||||
padding: 0 1;
|
||||
display: none;
|
||||
}}
|
||||
#action-buttons.visible {{
|
||||
display: block;
|
||||
}}
|
||||
#action-buttons Button {{
|
||||
margin: 0 0 0 1;
|
||||
min-width: 10;
|
||||
}}
|
||||
#btn-resume {{
|
||||
background: {_SECONDARY};
|
||||
color: {_TERTIARY};
|
||||
}}
|
||||
#btn-resume:hover {{
|
||||
background: {_PRIMARY};
|
||||
}}
|
||||
#btn-fork {{
|
||||
background: {_PRIMARY};
|
||||
color: {_TERTIARY};
|
||||
}}
|
||||
#btn-fork:hover {{
|
||||
background: {_SECONDARY};
|
||||
}}
|
||||
.entity-title {{
|
||||
padding: 1 1 0 1;
|
||||
}}
|
||||
.entity-detail {{
|
||||
padding: 0 1;
|
||||
}}
|
||||
.task-output-editor {{
|
||||
height: auto;
|
||||
max-height: 10;
|
||||
margin: 0 1 1 1;
|
||||
border: round {_DIM};
|
||||
}}
|
||||
.task-output-editor:focus {{
|
||||
border: round {_PRIMARY};
|
||||
}}
|
||||
.task-label {{
|
||||
padding: 0 1;
|
||||
}}
|
||||
Tree {{
|
||||
background: {_BG_PANEL};
|
||||
}}
|
||||
Tree > .tree--cursor {{
|
||||
background: {_SECONDARY};
|
||||
color: {_TERTIARY};
|
||||
}}
|
||||
"""
|
||||
|
||||
BINDINGS: ClassVar[list[Binding | tuple[str, str] | tuple[str, str, str]]] = [
|
||||
("q", "quit", "Quit"),
|
||||
("r", "refresh", "Refresh"),
|
||||
("j", "cursor_down", "Down"),
|
||||
("k", "cursor_up", "Up"),
|
||||
]
|
||||
|
||||
def __init__(self, location: str = "./.checkpoints") -> None:
|
||||
super().__init__()
|
||||
self._location = location
|
||||
self._entries: list[dict[str, Any]] = []
|
||||
self._selected_idx: int = 0
|
||||
self._pending_location: str = ""
|
||||
self._selected_entry: dict[str, Any] | None = None
|
||||
self._input_keys: list[str] = []
|
||||
self._task_output_ids: list[tuple[int, str, str]] = []
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
yield Header(show_clock=False)
|
||||
with Horizontal():
|
||||
yield OptionList(id="cp-list")
|
||||
with Horizontal(id="main-layout"):
|
||||
tree: Tree[dict[str, Any]] = Tree("Checkpoints", id="tree-panel")
|
||||
tree.show_root = True
|
||||
tree.guide_depth = 3
|
||||
yield tree
|
||||
with Vertical(id="detail-container"):
|
||||
yield Static("", id="status")
|
||||
yield Static(
|
||||
f"\n [{_DIM}]Select a checkpoint from the list[/]", # noqa: S608
|
||||
id="detail",
|
||||
)
|
||||
with VerticalScroll(id="detail-scroll"):
|
||||
yield Static(
|
||||
f"[{_DIM}]Select a checkpoint from the tree[/]", # noqa: S608
|
||||
id="detail-header",
|
||||
)
|
||||
with Vertical(id="inputs-section"):
|
||||
yield Static("Inputs", id="inputs-label")
|
||||
with Horizontal(id="action-buttons"):
|
||||
yield Button("Resume", id="btn-resume")
|
||||
yield Button("Fork", id="btn-fork")
|
||||
yield Footer()
|
||||
|
||||
async def on_mount(self) -> None:
|
||||
self.query_one("#cp-list", OptionList).border_title = "Checkpoints"
|
||||
self.query_one("#detail", Static).border_title = "Detail"
|
||||
self._refresh_list()
|
||||
self._refresh_tree()
|
||||
self.query_one("#tree-panel", Tree).root.expand()
|
||||
|
||||
def _refresh_list(self) -> None:
|
||||
def _refresh_tree(self) -> None:
|
||||
self._entries = _load_entries(self._location)
|
||||
option_list = self.query_one("#cp-list", OptionList)
|
||||
option_list.clear_options()
|
||||
self._selected_entry = None
|
||||
|
||||
tree = self.query_one("#tree-panel", Tree)
|
||||
tree.clear()
|
||||
|
||||
if not self._entries:
|
||||
self.query_one("#detail", Static).update(
|
||||
f"\n [{_DIM}]No checkpoints in {self._location}[/]"
|
||||
self.query_one("#detail-header", Static).update(
|
||||
f"[{_DIM}]No checkpoints in {self._location}[/]"
|
||||
)
|
||||
self.query_one("#status", Static).update("")
|
||||
self.sub_title = self._location
|
||||
return
|
||||
|
||||
# Group by branch
|
||||
branches: dict[str, list[dict[str, Any]]] = defaultdict(list)
|
||||
for entry in self._entries:
|
||||
option_list.add_option(Option(_format_list_label(entry)))
|
||||
branch = entry.get("branch", "main")
|
||||
branches[branch].append(entry)
|
||||
|
||||
# Index checkpoint names to tree nodes so forks can attach
|
||||
node_by_name: dict[str, Any] = {}
|
||||
|
||||
def _make_label(e: dict[str, Any]) -> str:
|
||||
name = e.get("name", "")
|
||||
ts = e.get("ts") or ""
|
||||
trigger = e.get("trigger") or ""
|
||||
parts = [f"[bold]{_short_id(name)}[/]"]
|
||||
if ts:
|
||||
time_part = ts.split(" ")[-1] if " " in ts else ts
|
||||
parts.append(f"[{_DIM}]{time_part}[/]")
|
||||
if trigger:
|
||||
parts.append(f"[{_PRIMARY}]{trigger}[/]")
|
||||
return " ".join(parts)
|
||||
|
||||
fork_parents: set[str] = set()
|
||||
for branch_name, entries in branches.items():
|
||||
if branch_name == "main" or not entries:
|
||||
continue
|
||||
oldest = min(entries, key=lambda e: str(e.get("name", "")))
|
||||
first_parent = oldest.get("parent_id")
|
||||
if first_parent:
|
||||
fork_parents.add(str(first_parent))
|
||||
|
||||
def _add_checkpoint(parent_node: Any, e: dict[str, Any]) -> None:
|
||||
"""Add a checkpoint node — expandable only if a fork attaches to it."""
|
||||
cp_id = _entry_id(e)
|
||||
if cp_id in fork_parents:
|
||||
node = parent_node.add(
|
||||
_make_label(e), data=e, expand=False, allow_expand=True
|
||||
)
|
||||
else:
|
||||
node = parent_node.add_leaf(_make_label(e), data=e)
|
||||
node_by_name[cp_id] = node
|
||||
|
||||
if "main" in branches:
|
||||
for entry in reversed(branches["main"]):
|
||||
_add_checkpoint(tree.root, entry)
|
||||
|
||||
fork_branches = [
|
||||
(name, sorted(entries, key=lambda e: str(e.get("name", ""))))
|
||||
for name, entries in branches.items()
|
||||
if name != "main"
|
||||
]
|
||||
remaining = fork_branches
|
||||
max_passes = len(remaining) + 1
|
||||
while remaining and max_passes > 0:
|
||||
max_passes -= 1
|
||||
deferred = []
|
||||
made_progress = False
|
||||
for branch_name, entries in remaining:
|
||||
first_parent = entries[0].get("parent_id") if entries else None
|
||||
if first_parent and str(first_parent) not in node_by_name:
|
||||
deferred.append((branch_name, entries))
|
||||
continue
|
||||
attach_to: Any = tree.root
|
||||
if first_parent:
|
||||
attach_to = node_by_name.get(str(first_parent), tree.root)
|
||||
branch_label = (
|
||||
f"[bold {_SECONDARY}]{branch_name}[/] [{_DIM}]({len(entries)})[/]"
|
||||
)
|
||||
branch_node = attach_to.add(branch_label, expand=False)
|
||||
for entry in entries:
|
||||
_add_checkpoint(branch_node, entry)
|
||||
made_progress = True
|
||||
remaining = deferred
|
||||
if not made_progress:
|
||||
break
|
||||
|
||||
for branch_name, entries in remaining:
|
||||
branch_label = (
|
||||
f"[bold {_SECONDARY}]{branch_name}[/] "
|
||||
f"[{_DIM}]({len(entries)})[/] [{_DIM}](orphaned)[/]"
|
||||
)
|
||||
branch_node = tree.root.add(branch_label, expand=False)
|
||||
for entry in entries:
|
||||
_add_checkpoint(branch_node, entry)
|
||||
|
||||
count = len(self._entries)
|
||||
storage = "SQLite" if _is_sqlite(self._location) else "JSON"
|
||||
self.sub_title = f"{self._location}"
|
||||
self.sub_title = self._location
|
||||
self.query_one("#status", Static).update(f" {count} checkpoint(s) | {storage}")
|
||||
|
||||
async def on_option_list_option_highlighted(
|
||||
self,
|
||||
event: OptionList.OptionHighlighted,
|
||||
) -> None:
|
||||
idx = event.option_index
|
||||
if idx is None:
|
||||
return
|
||||
if idx < len(self._entries):
|
||||
self._selected_idx = idx
|
||||
entry = self._entries[idx]
|
||||
self.query_one("#detail", Static).update(_format_detail(entry))
|
||||
async def _show_detail(self, entry: dict[str, Any]) -> None:
|
||||
"""Update the detail panel for a checkpoint entry."""
|
||||
self._selected_entry = entry
|
||||
self.query_one("#action-buttons").add_class("visible")
|
||||
|
||||
def action_cursor_down(self) -> None:
|
||||
self.query_one("#cp-list", OptionList).action_cursor_down()
|
||||
detail_scroll = self.query_one("#detail-scroll", VerticalScroll)
|
||||
|
||||
def action_cursor_up(self) -> None:
|
||||
self.query_one("#cp-list", OptionList).action_cursor_up()
|
||||
# Remove all dynamic children except the header — await so IDs are freed
|
||||
to_remove = [c for c in detail_scroll.children if c.id != "detail-header"]
|
||||
for child in to_remove:
|
||||
await child.remove()
|
||||
|
||||
async def on_option_list_option_selected(
|
||||
self,
|
||||
event: OptionList.OptionSelected,
|
||||
) -> None:
|
||||
idx = event.option_index
|
||||
if idx is None or idx >= len(self._entries):
|
||||
return
|
||||
entry = self._entries[idx]
|
||||
# Header
|
||||
name = entry.get("name", "")
|
||||
ts = entry.get("ts") or "unknown"
|
||||
trigger = entry.get("trigger") or ""
|
||||
branch = entry.get("branch", "main")
|
||||
parent_id = entry.get("parent_id")
|
||||
|
||||
header_lines = [
|
||||
f"[bold {_PRIMARY}]{name}[/]",
|
||||
f"[{_DIM}]{'─' * 50}[/]",
|
||||
"",
|
||||
f" [bold]Time[/] {ts}",
|
||||
]
|
||||
if "size" in entry:
|
||||
header_lines.append(f" [bold]Size[/] {_format_size(entry['size'])}")
|
||||
header_lines.append(f" [bold]Events[/] {entry.get('event_count', 0)}")
|
||||
if trigger:
|
||||
header_lines.append(f" [bold]Trigger[/] [{_PRIMARY}]{trigger}[/]")
|
||||
header_lines.append(f" [bold]Branch[/] [{_SECONDARY}]{branch}[/]")
|
||||
if parent_id:
|
||||
header_lines.append(f" [bold]Parent[/] [{_DIM}]{parent_id}[/]")
|
||||
if "path" in entry:
|
||||
loc = entry["path"]
|
||||
elif _is_sqlite(self._location):
|
||||
loc = f"{self._location}#{entry['name']}"
|
||||
else:
|
||||
loc = entry.get("name", "")
|
||||
self._pending_location = loc
|
||||
name = entry.get("name", loc)
|
||||
self.push_screen(ConfirmResumeScreen(name), self._on_confirm)
|
||||
header_lines.append(f" [bold]Path[/] [{_DIM}]{entry['path']}[/]")
|
||||
if "db" in entry:
|
||||
header_lines.append(f" [bold]Database[/] [{_DIM}]{entry['db']}[/]")
|
||||
|
||||
def _on_confirm(self, confirmed: bool | None) -> None:
|
||||
if confirmed:
|
||||
self.exit(self._pending_location)
|
||||
else:
|
||||
self._pending_location = ""
|
||||
self.query_one("#detail-header", Static).update("\n".join(header_lines))
|
||||
|
||||
# Entity details and editable task outputs — mounted flat for scrolling
|
||||
self._task_output_ids = []
|
||||
flat_task_idx = 0
|
||||
for ent_idx, ent in enumerate(entry.get("entities", [])):
|
||||
etype = ent.get("type", "unknown")
|
||||
ename = ent.get("name", "unnamed")
|
||||
completed = ent.get("tasks_completed")
|
||||
total = ent.get("tasks_total")
|
||||
entity_title = f"[bold {_SECONDARY}]{etype}: {ename}[/]"
|
||||
if completed is not None and total is not None:
|
||||
entity_title += f" [{_DIM}]{completed}/{total} tasks[/]"
|
||||
await detail_scroll.mount(Static(entity_title, classes="entity-title"))
|
||||
await detail_scroll.mount(
|
||||
Static(_build_entity_header(ent), classes="entity-detail")
|
||||
)
|
||||
|
||||
tasks = ent.get("tasks", [])
|
||||
for i, task in enumerate(tasks):
|
||||
desc = str(task.get("description", ""))
|
||||
if len(desc) > 55:
|
||||
desc = desc[:52] + "..."
|
||||
if task.get("completed"):
|
||||
icon = "[green]✓[/]"
|
||||
await detail_scroll.mount(
|
||||
Static(f" {icon} {i + 1}. {desc}", classes="task-label")
|
||||
)
|
||||
output_text = task.get("output", "")
|
||||
editor_id = f"task-output-{ent_idx}-{i}"
|
||||
await detail_scroll.mount(
|
||||
TextArea(
|
||||
str(output_text),
|
||||
classes="task-output-editor",
|
||||
id=editor_id,
|
||||
)
|
||||
)
|
||||
self._task_output_ids.append(
|
||||
(flat_task_idx, editor_id, str(output_text))
|
||||
)
|
||||
else:
|
||||
icon = "[yellow]○[/]"
|
||||
await detail_scroll.mount(
|
||||
Static(f" {icon} {i + 1}. {desc}", classes="task-label")
|
||||
)
|
||||
flat_task_idx += 1
|
||||
|
||||
# Build input fields
|
||||
await self._build_input_fields(entry.get("inputs", {}))
|
||||
|
||||
async def _build_input_fields(self, inputs: dict[str, Any]) -> None:
|
||||
"""Rebuild the inputs section with one field per input key."""
|
||||
section = self.query_one("#inputs-section")
|
||||
|
||||
# Remove old dynamic children — await so IDs are freed
|
||||
for widget in list(section.query(".input-row, .no-inputs")):
|
||||
await widget.remove()
|
||||
|
||||
self._input_keys = []
|
||||
|
||||
if not inputs:
|
||||
await section.mount(Static(f"[{_DIM}]No inputs[/]", classes="no-inputs"))
|
||||
section.add_class("visible")
|
||||
return
|
||||
|
||||
for key, value in inputs.items():
|
||||
self._input_keys.append(key)
|
||||
row = Horizontal(classes="input-row")
|
||||
row.compose_add_child(Static(f"[bold]{key}[/]"))
|
||||
row.compose_add_child(
|
||||
Input(value=str(value), placeholder=key, id=f"input-{key}")
|
||||
)
|
||||
await section.mount(row)
|
||||
|
||||
section.add_class("visible")
|
||||
|
||||
def _collect_inputs(self) -> dict[str, Any] | None:
|
||||
"""Collect current values from input fields."""
|
||||
if not self._input_keys:
|
||||
return None
|
||||
result: dict[str, Any] = {}
|
||||
for key in self._input_keys:
|
||||
widget = self.query_one(f"#input-{key}", Input)
|
||||
result[key] = widget.value
|
||||
return result
|
||||
|
||||
def _collect_task_overrides(self) -> dict[int, str] | None:
|
||||
"""Collect edited task outputs. Returns only changed values."""
|
||||
if not self._task_output_ids or self._selected_entry is None:
|
||||
return None
|
||||
overrides: dict[int, str] = {}
|
||||
for task_idx, editor_id, original in self._task_output_ids:
|
||||
editor = self.query_one(f"#{editor_id}", TextArea)
|
||||
if editor.text != original:
|
||||
overrides[task_idx] = editor.text
|
||||
return overrides or None
|
||||
|
||||
def _resolve_location(self, entry: dict[str, Any]) -> str:
|
||||
"""Get the restore location string for a checkpoint entry."""
|
||||
if "path" in entry:
|
||||
return str(entry["path"])
|
||||
if _is_sqlite(self._location):
|
||||
return f"{self._location}#{entry['name']}"
|
||||
return str(entry.get("name", ""))
|
||||
|
||||
async def on_tree_node_highlighted(
|
||||
self, event: Tree.NodeHighlighted[dict[str, Any]]
|
||||
) -> None:
|
||||
if event.node.data is not None:
|
||||
await self._show_detail(event.node.data)
|
||||
|
||||
def on_button_pressed(self, event: Button.Pressed) -> None:
|
||||
if self._selected_entry is None:
|
||||
return
|
||||
inputs = self._collect_inputs()
|
||||
overrides = self._collect_task_overrides()
|
||||
loc = self._resolve_location(self._selected_entry)
|
||||
if event.button.id == "btn-resume":
|
||||
self.exit((loc, "resume", inputs, overrides))
|
||||
elif event.button.id == "btn-fork":
|
||||
self.exit((loc, "fork", inputs, overrides))
|
||||
|
||||
def action_refresh(self) -> None:
|
||||
self._refresh_list()
|
||||
self._refresh_tree()
|
||||
|
||||
|
||||
async def _run_checkpoint_tui_async(location: str) -> None:
|
||||
@@ -345,18 +540,78 @@ async def _run_checkpoint_tui_async(location: str) -> None:
|
||||
import click
|
||||
|
||||
app = CheckpointTUI(location=location)
|
||||
selected = await app.run_async()
|
||||
selection = await app.run_async()
|
||||
|
||||
if selected is None:
|
||||
if selection is None:
|
||||
return
|
||||
|
||||
click.echo(f"\nResuming from: {selected}\n")
|
||||
selected, action, inputs, task_overrides = selection
|
||||
|
||||
from crewai.crew import Crew
|
||||
from crewai.state.checkpoint_config import CheckpointConfig
|
||||
|
||||
crew = Crew.from_checkpoint(CheckpointConfig(restore_from=selected))
|
||||
result = await crew.akickoff()
|
||||
config = CheckpointConfig(restore_from=selected)
|
||||
|
||||
if action == "fork":
|
||||
click.echo(f"\nForking from: {selected}\n")
|
||||
crew = Crew.fork(config)
|
||||
else:
|
||||
click.echo(f"\nResuming from: {selected}\n")
|
||||
crew = Crew.from_checkpoint(config)
|
||||
|
||||
if task_overrides:
|
||||
click.echo("Modifications:")
|
||||
overridden_agents: set[int] = set()
|
||||
for task_idx, new_output in task_overrides.items():
|
||||
if task_idx < len(crew.tasks) and crew.tasks[task_idx].output is not None:
|
||||
desc = crew.tasks[task_idx].description or f"Task {task_idx + 1}"
|
||||
if len(desc) > 60:
|
||||
desc = desc[:57] + "..."
|
||||
crew.tasks[task_idx].output.raw = new_output # type: ignore[union-attr]
|
||||
preview = new_output.replace("\n", " ")
|
||||
if len(preview) > 80:
|
||||
preview = preview[:77] + "..."
|
||||
click.echo(f" Task {task_idx + 1}: {desc}")
|
||||
click.echo(f" -> {preview}")
|
||||
agent = crew.tasks[task_idx].agent
|
||||
if agent and agent.agent_executor:
|
||||
nth = sum(1 for t in crew.tasks[:task_idx] if t.agent is agent)
|
||||
messages = agent.agent_executor.messages
|
||||
system_positions = [
|
||||
i for i, m in enumerate(messages) if m.get("role") == "system"
|
||||
]
|
||||
if nth < len(system_positions):
|
||||
seg_start = system_positions[nth]
|
||||
seg_end = (
|
||||
system_positions[nth + 1]
|
||||
if nth + 1 < len(system_positions)
|
||||
else len(messages)
|
||||
)
|
||||
for j in range(seg_end - 1, seg_start, -1):
|
||||
if messages[j].get("role") == "assistant":
|
||||
messages[j]["content"] = new_output
|
||||
break
|
||||
overridden_agents.add(id(agent))
|
||||
|
||||
earliest = min(task_overrides)
|
||||
for offset, subsequent in enumerate(
|
||||
crew.tasks[earliest + 1 :], start=earliest + 1
|
||||
):
|
||||
if subsequent.output and offset not in task_overrides:
|
||||
subsequent.output = None
|
||||
if subsequent.agent and subsequent.agent.agent_executor:
|
||||
subsequent.agent.agent_executor._resuming = False
|
||||
if id(subsequent.agent) not in overridden_agents:
|
||||
subsequent.agent.agent_executor.messages = []
|
||||
click.echo()
|
||||
|
||||
if inputs:
|
||||
click.echo("Inputs:")
|
||||
for k, v in inputs.items():
|
||||
click.echo(f" {k}: {v}")
|
||||
click.echo()
|
||||
|
||||
result = await crew.akickoff(inputs=inputs)
|
||||
click.echo(f"\nResult: {getattr(result, 'raw', result)}")
|
||||
|
||||
|
||||
|
||||
@@ -793,6 +793,9 @@ def traces_status() -> None:
|
||||
@click.pass_context
|
||||
def checkpoint(ctx: click.Context, location: str) -> None:
|
||||
"""Browse and inspect checkpoints. Launches a TUI when called without a subcommand."""
|
||||
from crewai.cli.checkpoint_cli import _detect_location
|
||||
|
||||
location = _detect_location(location)
|
||||
ctx.ensure_object(dict)
|
||||
ctx.obj["location"] = location
|
||||
if ctx.invoked_subcommand is None:
|
||||
@@ -805,18 +808,18 @@ def checkpoint(ctx: click.Context, location: str) -> None:
|
||||
@click.argument("location", default="./.checkpoints")
|
||||
def checkpoint_list(location: str) -> None:
|
||||
"""List checkpoints in a directory."""
|
||||
from crewai.cli.checkpoint_cli import list_checkpoints
|
||||
from crewai.cli.checkpoint_cli import _detect_location, list_checkpoints
|
||||
|
||||
list_checkpoints(location)
|
||||
list_checkpoints(_detect_location(location))
|
||||
|
||||
|
||||
@checkpoint.command("info")
|
||||
@click.argument("path", default="./.checkpoints")
|
||||
def checkpoint_info(path: str) -> None:
|
||||
"""Show details of a checkpoint. Pass a file or directory for latest."""
|
||||
from crewai.cli.checkpoint_cli import info_checkpoint
|
||||
from crewai.cli.checkpoint_cli import _detect_location, info_checkpoint
|
||||
|
||||
info_checkpoint(path)
|
||||
info_checkpoint(_detect_location(path))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -436,6 +436,13 @@ class Crew(FlowTrackable, BaseModel):
|
||||
if agent.agent_executor is not None and task.output is None:
|
||||
agent.agent_executor.task = task
|
||||
break
|
||||
for task in self.tasks:
|
||||
if task.checkpoint_original_description is not None:
|
||||
task._original_description = task.checkpoint_original_description
|
||||
if task.checkpoint_original_expected_output is not None:
|
||||
task._original_expected_output = (
|
||||
task.checkpoint_original_expected_output
|
||||
)
|
||||
if self.checkpoint_inputs is not None:
|
||||
self._inputs = self.checkpoint_inputs
|
||||
if self.checkpoint_kickoff_event_id is not None:
|
||||
|
||||
@@ -213,6 +213,9 @@ class CheckpointConfig(BaseModel):
|
||||
def _register_handlers(self) -> CheckpointConfig:
|
||||
from crewai.state.checkpoint_listener import _ensure_handlers_registered
|
||||
|
||||
if isinstance(self.provider, SqliteProvider) and not Path(self.location).suffix:
|
||||
self.location = f"{self.location}.db"
|
||||
|
||||
_ensure_handlers_registered()
|
||||
return self
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ avoids per-event overhead when no entity uses checkpointing.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
from typing import Any
|
||||
@@ -102,10 +103,15 @@ def _find_checkpoint(source: Any) -> CheckpointConfig | None:
|
||||
return None
|
||||
|
||||
|
||||
def _do_checkpoint(state: RuntimeState, cfg: CheckpointConfig) -> None:
|
||||
def _do_checkpoint(
|
||||
state: RuntimeState, cfg: CheckpointConfig, event: BaseEvent | None = None
|
||||
) -> None:
|
||||
"""Write a checkpoint and prune old ones if configured."""
|
||||
_prepare_entities(state.root)
|
||||
data = state.model_dump_json()
|
||||
payload = state.model_dump(mode="json")
|
||||
if event is not None:
|
||||
payload["trigger"] = event.type
|
||||
data = json.dumps(payload)
|
||||
location = cfg.provider.checkpoint(
|
||||
data,
|
||||
cfg.location,
|
||||
@@ -134,7 +140,7 @@ def _on_any_event(source: Any, event: BaseEvent, state: Any) -> None:
|
||||
if cfg is None:
|
||||
return
|
||||
try:
|
||||
_do_checkpoint(state, cfg)
|
||||
_do_checkpoint(state, cfg, event)
|
||||
except Exception:
|
||||
logger.warning("Auto-checkpoint failed for event %s", event.type, exc_info=True)
|
||||
|
||||
|
||||
@@ -66,6 +66,9 @@ def _sync_checkpoint_fields(entity: object) -> None:
|
||||
entity.checkpoint_inputs = entity._inputs
|
||||
entity.checkpoint_train = entity._train
|
||||
entity.checkpoint_kickoff_event_id = entity._kickoff_event_id
|
||||
for task in entity.tasks:
|
||||
task.checkpoint_original_description = task._original_description
|
||||
task.checkpoint_original_expected_output = task._original_expected_output
|
||||
|
||||
|
||||
def _migrate(data: dict[str, Any]) -> dict[str, Any]:
|
||||
@@ -121,7 +124,7 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
|
||||
"parent_id": self._parent_id,
|
||||
"branch": self._branch,
|
||||
"entities": [e.model_dump(mode="json") for e in self.root],
|
||||
"event_record": self._event_record.model_dump(),
|
||||
"event_record": self._event_record.model_dump(mode="json"),
|
||||
}
|
||||
|
||||
@model_validator(mode="wrap")
|
||||
@@ -194,7 +197,10 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
|
||||
return result
|
||||
|
||||
def fork(self, branch: str | None = None) -> None:
|
||||
"""Mark this state as a fork for subsequent checkpoints.
|
||||
"""Create a new execution branch and write an initial checkpoint.
|
||||
|
||||
If this state was restored from a checkpoint, an initial checkpoint
|
||||
is written on the new branch so the fork point is recorded.
|
||||
|
||||
Args:
|
||||
branch: Branch label. Auto-generated from the current checkpoint
|
||||
@@ -204,7 +210,7 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
|
||||
if branch:
|
||||
self._branch = branch
|
||||
elif self._checkpoint_id:
|
||||
self._branch = f"fork/{self._checkpoint_id}"
|
||||
self._branch = f"fork/{self._checkpoint_id}_{uuid.uuid4().hex[:6]}"
|
||||
else:
|
||||
self._branch = f"fork/{uuid.uuid4().hex[:8]}"
|
||||
|
||||
@@ -227,6 +233,7 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
|
||||
provider = detect_provider(location)
|
||||
raw = provider.from_checkpoint(location)
|
||||
state = cls.model_validate_json(raw, **kwargs)
|
||||
state._provider = provider
|
||||
checkpoint_id = provider.extract_id(location)
|
||||
state._checkpoint_id = checkpoint_id
|
||||
state._parent_id = checkpoint_id
|
||||
@@ -253,6 +260,7 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
|
||||
provider = detect_provider(location)
|
||||
raw = await provider.afrom_checkpoint(location)
|
||||
state = cls.model_validate_json(raw, **kwargs)
|
||||
state._provider = provider
|
||||
checkpoint_id = provider.extract_id(location)
|
||||
state._checkpoint_id = checkpoint_id
|
||||
state._parent_id = checkpoint_id
|
||||
|
||||
@@ -230,6 +230,8 @@ class Task(BaseModel):
|
||||
_original_description: str | None = PrivateAttr(default=None)
|
||||
_original_expected_output: str | None = PrivateAttr(default=None)
|
||||
_original_output_file: str | None = PrivateAttr(default=None)
|
||||
checkpoint_original_description: str | None = Field(default=None, exclude=False)
|
||||
checkpoint_original_expected_output: str | None = Field(default=None, exclude=False)
|
||||
_thread: threading.Thread | None = PrivateAttr(default=None)
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
|
||||
|
||||
@@ -296,7 +296,8 @@ class TestRuntimeStateLineage:
|
||||
state = self._make_state()
|
||||
state._checkpoint_id = "20260409T120000_abc12345"
|
||||
state.fork()
|
||||
assert state._branch == "fork/20260409T120000_abc12345"
|
||||
assert state._branch.startswith("fork/20260409T120000_abc12345_")
|
||||
assert len(state._branch) == len("fork/20260409T120000_abc12345_") + 6
|
||||
|
||||
def test_fork_no_checkpoint_id_unique(self) -> None:
|
||||
state = self._make_state()
|
||||
|
||||
Reference in New Issue
Block a user