mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-01 07:13:00 +00:00
feat: store trigger in checkpoint data, editable inputs/outputs in TUI
Write the actual triggering event type into checkpoint JSON at write time instead of inferring from the event record. Adds editable input fields and task output overrides to the TUI for what-if exploration. Unique fork branch names prevent collisions on repeated forks.
This commit is contained in:
@@ -6,12 +6,16 @@ from datetime import datetime
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import sqlite3
|
||||
from typing import Any
|
||||
|
||||
import click
|
||||
|
||||
|
||||
_PLACEHOLDER_RE = re.compile(r"\{([A-Za-z_][A-Za-z0-9_\-]*)}")
|
||||
|
||||
|
||||
_SQLITE_MAGIC = b"SQLite format 3\x00"
|
||||
|
||||
_SELECT_ALL = """
|
||||
@@ -71,13 +75,7 @@ def _parse_checkpoint_json(raw: str, source: str) -> dict[str, Any]:
|
||||
nodes = data.get("event_record", {}).get("nodes", {})
|
||||
event_count = len(nodes)
|
||||
|
||||
trigger_event = None
|
||||
if nodes:
|
||||
last_node = max(
|
||||
nodes.values(),
|
||||
key=lambda n: n.get("event", {}).get("emission_sequence") or 0,
|
||||
)
|
||||
trigger_event = last_node.get("event", {}).get("type")
|
||||
trigger_event = data.get("trigger")
|
||||
|
||||
parsed_entities: list[dict[str, Any]] = []
|
||||
for entity in entities:
|
||||
@@ -95,21 +93,51 @@ def _parse_checkpoint_json(raw: str, source: str) -> dict[str, Any]:
|
||||
{
|
||||
"description": t.get("description", ""),
|
||||
"completed": t.get("output") is not None,
|
||||
"output": (t.get("output") or {}).get("raw", ""),
|
||||
}
|
||||
for t in tasks
|
||||
]
|
||||
parsed_entities.append(info)
|
||||
|
||||
inputs: dict[str, Any] = {}
|
||||
for entity in entities:
|
||||
cp_inputs = entity.get("checkpoint_inputs")
|
||||
if isinstance(cp_inputs, dict) and cp_inputs:
|
||||
inputs = dict(cp_inputs)
|
||||
break
|
||||
|
||||
for entity in entities:
|
||||
for task in entity.get("tasks", []):
|
||||
for field in (
|
||||
"checkpoint_original_description",
|
||||
"checkpoint_original_expected_output",
|
||||
):
|
||||
text = task.get(field) or ""
|
||||
for match in _PLACEHOLDER_RE.findall(text):
|
||||
if match not in inputs:
|
||||
inputs[match] = ""
|
||||
for agent in entity.get("agents", []):
|
||||
for field in ("role", "goal", "backstory"):
|
||||
text = agent.get(field) or ""
|
||||
for match in _PLACEHOLDER_RE.findall(text):
|
||||
if match not in inputs:
|
||||
inputs[match] = ""
|
||||
|
||||
branch = data.get("branch", "main")
|
||||
parent_id = data.get("parent_id")
|
||||
|
||||
return {
|
||||
"source": source,
|
||||
"event_count": event_count,
|
||||
"trigger": trigger_event,
|
||||
"entities": parsed_entities,
|
||||
"branch": data.get("branch", "main"),
|
||||
"parent_id": data.get("parent_id"),
|
||||
"branch": branch,
|
||||
"parent_id": parent_id,
|
||||
"inputs": inputs,
|
||||
}
|
||||
|
||||
|
||||
|
||||
def _format_size(size: int) -> str:
|
||||
if size < 1024:
|
||||
return f"{size}B"
|
||||
|
||||
@@ -10,10 +10,11 @@ from textual.binding import Binding
|
||||
from textual.containers import Horizontal, Vertical, VerticalScroll
|
||||
from textual.widgets import (
|
||||
Button,
|
||||
Collapsible,
|
||||
Footer,
|
||||
Header,
|
||||
Input,
|
||||
Static,
|
||||
TextArea,
|
||||
Tree,
|
||||
)
|
||||
|
||||
@@ -39,6 +40,7 @@ def _load_entries(location: str) -> list[dict[str, Any]]:
|
||||
return _list_json(location)
|
||||
|
||||
|
||||
|
||||
def _short_id(name: str) -> str:
|
||||
"""Shorten a checkpoint name for tree display."""
|
||||
if len(name) > 30:
|
||||
@@ -46,14 +48,9 @@ def _short_id(name: str) -> str:
|
||||
return name
|
||||
|
||||
|
||||
def _build_entity_detail(ent: dict[str, Any]) -> str:
|
||||
"""Build rich text for a single entity."""
|
||||
def _build_entity_header(ent: dict[str, Any]) -> str:
|
||||
"""Build rich text header for an entity (progress bar only)."""
|
||||
lines: list[str] = []
|
||||
eid = str(ent.get("id", ""))[:8]
|
||||
etype = ent.get("type", "unknown")
|
||||
ename = ent.get("name", "unnamed")
|
||||
lines.append(f"[bold {_SECONDARY}]{etype}[/]: {ename} [{_DIM}]{eid}[/]")
|
||||
|
||||
tasks = ent.get("tasks")
|
||||
if isinstance(tasks, list):
|
||||
completed = ent.get("tasks_completed", 0)
|
||||
@@ -63,22 +60,19 @@ def _build_entity_detail(ent: dict[str, Any]) -> str:
|
||||
filled = int(bar_len * completed / total) if total else 0
|
||||
bar = f"[{_PRIMARY}]{'█' * filled}[/][{_DIM}]{'░' * (bar_len - filled)}[/]"
|
||||
lines.append(f"{bar} {completed}/{total} tasks ({pct}%)")
|
||||
lines.append("")
|
||||
for i, task in enumerate(tasks):
|
||||
icon = "[green]✓[/]" if task.get("completed") else "[yellow]○[/]"
|
||||
desc = str(task.get("description", ""))
|
||||
if len(desc) > 55:
|
||||
desc = desc[:52] + "..."
|
||||
lines.append(f" {icon} {i + 1}. {desc}")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
class CheckpointTUI(App[tuple[str, str] | None]):
|
||||
# Return type: (location, action, inputs, task_output_overrides)
|
||||
_TuiResult = tuple[str, str, dict[str, Any] | None, dict[int, str] | None] | None
|
||||
|
||||
|
||||
class CheckpointTUI(App[_TuiResult]):
|
||||
"""TUI to browse and inspect checkpoints.
|
||||
|
||||
Returns ``(location, action)`` where action is ``"resume"`` or
|
||||
``"fork"``, or ``None`` if the user quit without selecting.
|
||||
Returns ``(location, action, inputs)`` where action is ``"resume"`` or
|
||||
``"fork"`` and inputs is a parsed dict or ``None``,
|
||||
or ``None`` if the user quit without selecting.
|
||||
"""
|
||||
|
||||
TITLE = "CrewAI Checkpoints"
|
||||
@@ -114,6 +108,7 @@ class CheckpointTUI(App[tuple[str, str] | None]):
|
||||
}}
|
||||
#detail-container {{
|
||||
width: 55%;
|
||||
height: 1fr;
|
||||
}}
|
||||
#detail-scroll {{
|
||||
height: 1fr;
|
||||
@@ -133,9 +128,41 @@ class CheckpointTUI(App[tuple[str, str] | None]):
|
||||
padding: 0 2;
|
||||
color: {_DIM};
|
||||
}}
|
||||
#inputs-section {{
|
||||
display: none;
|
||||
height: auto;
|
||||
max-height: 8;
|
||||
padding: 0 1;
|
||||
}}
|
||||
#inputs-section.visible {{
|
||||
display: block;
|
||||
}}
|
||||
#inputs-label {{
|
||||
height: 1;
|
||||
color: {_DIM};
|
||||
padding: 0 1;
|
||||
}}
|
||||
.input-row {{
|
||||
height: 3;
|
||||
padding: 0 1;
|
||||
}}
|
||||
.input-row Static {{
|
||||
width: auto;
|
||||
min-width: 12;
|
||||
padding: 1 1 0 0;
|
||||
color: {_TERTIARY};
|
||||
}}
|
||||
.input-row Input {{
|
||||
width: 1fr;
|
||||
}}
|
||||
#no-inputs-label {{
|
||||
height: 1;
|
||||
color: {_DIM};
|
||||
padding: 0 1;
|
||||
}}
|
||||
#action-buttons {{
|
||||
height: 3;
|
||||
align: center middle;
|
||||
align: right middle;
|
||||
padding: 0 1;
|
||||
display: none;
|
||||
}}
|
||||
@@ -143,8 +170,8 @@ class CheckpointTUI(App[tuple[str, str] | None]):
|
||||
display: block;
|
||||
}}
|
||||
#action-buttons Button {{
|
||||
margin: 0 1;
|
||||
min-width: 14;
|
||||
margin: 0 0 0 1;
|
||||
min-width: 10;
|
||||
}}
|
||||
#btn-resume {{
|
||||
background: {_SECONDARY};
|
||||
@@ -160,13 +187,24 @@ class CheckpointTUI(App[tuple[str, str] | None]):
|
||||
#btn-fork:hover {{
|
||||
background: {_SECONDARY};
|
||||
}}
|
||||
Collapsible {{
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
.entity-title {{
|
||||
padding: 1 1 0 1;
|
||||
}}
|
||||
.entity-detail {{
|
||||
padding: 0 1;
|
||||
}}
|
||||
.task-output-editor {{
|
||||
height: auto;
|
||||
max-height: 10;
|
||||
margin: 0 1 1 1;
|
||||
border: round {_DIM};
|
||||
}}
|
||||
.task-output-editor:focus {{
|
||||
border: round {_PRIMARY};
|
||||
}}
|
||||
.task-label {{
|
||||
padding: 0 1;
|
||||
}}
|
||||
Tree {{
|
||||
background: {_BG_PANEL};
|
||||
}}
|
||||
@@ -186,6 +224,7 @@ class CheckpointTUI(App[tuple[str, str] | None]):
|
||||
self._location = location
|
||||
self._entries: list[dict[str, Any]] = []
|
||||
self._selected_entry: dict[str, Any] | None = None
|
||||
self._input_keys: list[str] = []
|
||||
|
||||
def compose(self) -> ComposeResult:
|
||||
yield Header(show_clock=False)
|
||||
@@ -201,6 +240,8 @@ class CheckpointTUI(App[tuple[str, str] | None]):
|
||||
f"[{_DIM}]Select a checkpoint from the tree[/]", # noqa: S608
|
||||
id="detail-header",
|
||||
)
|
||||
with Vertical(id="inputs-section"):
|
||||
yield Static("Inputs", id="inputs-label")
|
||||
with Horizontal(id="action-buttons"):
|
||||
yield Button("Resume", id="btn-resume")
|
||||
yield Button("Fork", id="btn-fork")
|
||||
@@ -295,6 +336,7 @@ class CheckpointTUI(App[tuple[str, str] | None]):
|
||||
if first_parent
|
||||
else tree.root
|
||||
)
|
||||
# Drop fork-initial checkpoint — it's just a lineage marker
|
||||
branch_label = (
|
||||
f"[bold {_SECONDARY}]{branch_name}[/] [{_DIM}]({len(entries)})[/]"
|
||||
)
|
||||
@@ -308,16 +350,17 @@ class CheckpointTUI(App[tuple[str, str] | None]):
|
||||
self.sub_title = self._location
|
||||
self.query_one("#status", Static).update(f" {count} checkpoint(s) | {storage}")
|
||||
|
||||
def _show_detail(self, entry: dict[str, Any]) -> None:
|
||||
async def _show_detail(self, entry: dict[str, Any]) -> None:
|
||||
"""Update the detail panel for a checkpoint entry."""
|
||||
self._selected_entry = entry
|
||||
self.query_one("#action-buttons").add_class("visible")
|
||||
|
||||
detail_scroll = self.query_one("#detail-scroll", VerticalScroll)
|
||||
|
||||
# Remove old collapsibles
|
||||
for widget in list(detail_scroll.query("Collapsible")):
|
||||
widget.remove()
|
||||
# Remove all dynamic children except the header — await so IDs are freed
|
||||
to_remove = [c for c in detail_scroll.children if c.id != "detail-header"]
|
||||
for child in to_remove:
|
||||
await child.remove()
|
||||
|
||||
# Header
|
||||
name = entry.get("name", "")
|
||||
@@ -347,19 +390,102 @@ class CheckpointTUI(App[tuple[str, str] | None]):
|
||||
|
||||
self.query_one("#detail-header", Static).update("\n".join(header_lines))
|
||||
|
||||
# Entity collapsibles
|
||||
for ent in entry.get("entities", []):
|
||||
# Entity details and editable task outputs — mounted flat for scrolling
|
||||
self._task_output_ids = []
|
||||
for ent_idx, ent in enumerate(entry.get("entities", [])):
|
||||
etype = ent.get("type", "unknown")
|
||||
ename = ent.get("name", "unnamed")
|
||||
completed = ent.get("tasks_completed")
|
||||
total = ent.get("tasks_total")
|
||||
title = f"{etype}: {ename}"
|
||||
entity_title = f"[bold {_SECONDARY}]{etype}: {ename}[/]"
|
||||
if completed is not None and total is not None:
|
||||
title += f" [{completed}/{total} tasks]"
|
||||
entity_title += f" [{_DIM}]{completed}/{total} tasks[/]"
|
||||
detail_scroll.mount(Static(entity_title, classes="entity-title"))
|
||||
detail_scroll.mount(
|
||||
Static(_build_entity_header(ent), classes="entity-detail")
|
||||
)
|
||||
|
||||
content = Static(_build_entity_detail(ent), classes="entity-detail")
|
||||
collapsible = Collapsible(content, title=title, collapsed=False)
|
||||
detail_scroll.mount(collapsible)
|
||||
tasks = ent.get("tasks", [])
|
||||
for i, task in enumerate(tasks):
|
||||
desc = str(task.get("description", ""))
|
||||
if len(desc) > 55:
|
||||
desc = desc[:52] + "..."
|
||||
if task.get("completed"):
|
||||
icon = "[green]✓[/]"
|
||||
detail_scroll.mount(
|
||||
Static(f" {icon} {i + 1}. {desc}", classes="task-label")
|
||||
)
|
||||
output_text = task.get("output", "")
|
||||
editor_id = f"task-output-{ent_idx}-{i}"
|
||||
detail_scroll.mount(
|
||||
TextArea(
|
||||
str(output_text),
|
||||
classes="task-output-editor",
|
||||
id=editor_id,
|
||||
)
|
||||
)
|
||||
self._task_output_ids.append((i, editor_id))
|
||||
else:
|
||||
icon = "[yellow]○[/]"
|
||||
detail_scroll.mount(
|
||||
Static(f" {icon} {i + 1}. {desc}", classes="task-label")
|
||||
)
|
||||
|
||||
# Build input fields
|
||||
await self._build_input_fields(entry.get("inputs", {}))
|
||||
|
||||
async def _build_input_fields(self, inputs: dict[str, Any]) -> None:
|
||||
"""Rebuild the inputs section with one field per input key."""
|
||||
section = self.query_one("#inputs-section")
|
||||
|
||||
# Remove old dynamic children — await so IDs are freed
|
||||
for widget in list(section.query(".input-row, .no-inputs")):
|
||||
await widget.remove()
|
||||
|
||||
self._input_keys = []
|
||||
|
||||
if not inputs:
|
||||
section.mount(Static(f"[{_DIM}]No inputs[/]", classes="no-inputs"))
|
||||
section.add_class("visible")
|
||||
return
|
||||
|
||||
for key, value in inputs.items():
|
||||
self._input_keys.append(key)
|
||||
row = Horizontal(classes="input-row")
|
||||
row.compose_add_child(Static(f"[bold]{key}[/]"))
|
||||
row.compose_add_child(
|
||||
Input(value=str(value), placeholder=key, id=f"input-{key}")
|
||||
)
|
||||
section.mount(row)
|
||||
|
||||
section.add_class("visible")
|
||||
|
||||
def _collect_inputs(self) -> dict[str, Any] | None:
|
||||
"""Collect current values from input fields."""
|
||||
if not self._input_keys:
|
||||
return None
|
||||
result: dict[str, Any] = {}
|
||||
for key in self._input_keys:
|
||||
widget = self.query_one(f"#input-{key}", Input)
|
||||
result[key] = widget.value
|
||||
return result
|
||||
|
||||
def _collect_task_overrides(self) -> dict[int, str] | None:
|
||||
"""Collect edited task outputs. Returns only changed values."""
|
||||
if not self._task_output_ids or self._selected_entry is None:
|
||||
return None
|
||||
overrides: dict[int, str] = {}
|
||||
tasks = []
|
||||
for ent in self._selected_entry.get("entities", []):
|
||||
tasks = ent.get("tasks", [])
|
||||
if tasks:
|
||||
break
|
||||
for task_idx, editor_id in self._task_output_ids:
|
||||
editor = self.query_one(f"#{editor_id}", TextArea)
|
||||
original = str(tasks[task_idx].get("output", "")) if task_idx < len(tasks) else ""
|
||||
if editor.text != original:
|
||||
overrides[task_idx] = editor.text
|
||||
return overrides or None
|
||||
|
||||
def _resolve_location(self, entry: dict[str, Any]) -> str:
|
||||
"""Get the restore location string for a checkpoint entry."""
|
||||
@@ -373,16 +499,18 @@ class CheckpointTUI(App[tuple[str, str] | None]):
|
||||
self, event: Tree.NodeHighlighted[dict[str, Any]]
|
||||
) -> None:
|
||||
if event.node.data is not None:
|
||||
self._show_detail(event.node.data)
|
||||
await self._show_detail(event.node.data)
|
||||
|
||||
def on_button_pressed(self, event: Button.Pressed) -> None:
|
||||
if self._selected_entry is None:
|
||||
return
|
||||
inputs = self._collect_inputs()
|
||||
overrides = self._collect_task_overrides()
|
||||
loc = self._resolve_location(self._selected_entry)
|
||||
if event.button.id == "btn-resume":
|
||||
self.exit((loc, "resume"))
|
||||
self.exit((loc, "resume", inputs, overrides))
|
||||
elif event.button.id == "btn-fork":
|
||||
self.exit((loc, "fork"))
|
||||
self.exit((loc, "fork", inputs, overrides))
|
||||
|
||||
def action_refresh(self) -> None:
|
||||
self._refresh_tree()
|
||||
@@ -398,7 +526,7 @@ async def _run_checkpoint_tui_async(location: str) -> None:
|
||||
if selection is None:
|
||||
return
|
||||
|
||||
selected, action = selection
|
||||
selected, action, inputs, task_overrides = selection
|
||||
|
||||
from crewai.crew import Crew
|
||||
from crewai.state.checkpoint_config import CheckpointConfig
|
||||
@@ -412,7 +540,44 @@ async def _run_checkpoint_tui_async(location: str) -> None:
|
||||
click.echo(f"\nResuming from: {selected}\n")
|
||||
crew = Crew.from_checkpoint(config)
|
||||
|
||||
result = await crew.akickoff()
|
||||
# Apply task output overrides before kickoff
|
||||
if task_overrides:
|
||||
click.echo("Modifications:")
|
||||
for task_idx, new_output in task_overrides.items():
|
||||
if task_idx < len(crew.tasks) and crew.tasks[task_idx].output is not None:
|
||||
desc = crew.tasks[task_idx].description or f"Task {task_idx + 1}"
|
||||
if len(desc) > 60:
|
||||
desc = desc[:57] + "..."
|
||||
crew.tasks[task_idx].output.raw = new_output # type: ignore[union-attr]
|
||||
preview = new_output.replace("\n", " ")
|
||||
if len(preview) > 80:
|
||||
preview = preview[:77] + "..."
|
||||
click.echo(f" Task {task_idx + 1}: {desc}")
|
||||
click.echo(f" -> {preview}")
|
||||
# Update the assistant message in the executor's history
|
||||
# so the LLM sees the override in its conversation
|
||||
agent = crew.tasks[task_idx].agent
|
||||
if agent and agent.agent_executor:
|
||||
for msg in reversed(agent.agent_executor.messages):
|
||||
if msg.get("role") == "assistant":
|
||||
msg["content"] = new_output
|
||||
break
|
||||
# Invalidate all subsequent tasks so they re-run with
|
||||
# the modified context instead of using cached results
|
||||
for subsequent in crew.tasks[task_idx + 1 :]:
|
||||
subsequent.output = None
|
||||
if subsequent.agent and subsequent.agent.agent_executor:
|
||||
subsequent.agent.agent_executor._resuming = False
|
||||
subsequent.agent.agent_executor.messages = []
|
||||
click.echo()
|
||||
|
||||
if inputs:
|
||||
click.echo("Inputs:")
|
||||
for k, v in inputs.items():
|
||||
click.echo(f" {k}: {v}")
|
||||
click.echo()
|
||||
|
||||
result = await crew.akickoff(inputs=inputs)
|
||||
click.echo(f"\nResult: {getattr(result, 'raw', result)}")
|
||||
|
||||
|
||||
|
||||
@@ -436,6 +436,11 @@ class Crew(FlowTrackable, BaseModel):
|
||||
if agent.agent_executor is not None and task.output is None:
|
||||
agent.agent_executor.task = task
|
||||
break
|
||||
for task in self.tasks:
|
||||
if task.checkpoint_original_description is not None:
|
||||
task._original_description = task.checkpoint_original_description
|
||||
if task.checkpoint_original_expected_output is not None:
|
||||
task._original_expected_output = task.checkpoint_original_expected_output
|
||||
if self.checkpoint_inputs is not None:
|
||||
self._inputs = self.checkpoint_inputs
|
||||
if self.checkpoint_kickoff_event_id is not None:
|
||||
|
||||
@@ -102,8 +102,12 @@ def _find_checkpoint(source: Any) -> CheckpointConfig | None:
|
||||
return None
|
||||
|
||||
|
||||
def _do_checkpoint(state: RuntimeState, cfg: CheckpointConfig) -> None:
|
||||
def _do_checkpoint(
|
||||
state: RuntimeState, cfg: CheckpointConfig, event: BaseEvent | None = None
|
||||
) -> None:
|
||||
"""Write a checkpoint and prune old ones if configured."""
|
||||
if event is not None:
|
||||
state._trigger = event.type
|
||||
_prepare_entities(state.root)
|
||||
data = state.model_dump_json()
|
||||
location = cfg.provider.checkpoint(
|
||||
@@ -134,7 +138,7 @@ def _on_any_event(source: Any, event: BaseEvent, state: Any) -> None:
|
||||
if cfg is None:
|
||||
return
|
||||
try:
|
||||
_do_checkpoint(state, cfg)
|
||||
_do_checkpoint(state, cfg, event)
|
||||
except Exception:
|
||||
logger.warning("Auto-checkpoint failed for event %s", event.type, exc_info=True)
|
||||
|
||||
|
||||
@@ -80,6 +80,9 @@ def _sync_checkpoint_fields(entity: object) -> None:
|
||||
entity.checkpoint_inputs = entity._inputs
|
||||
entity.checkpoint_train = entity._train
|
||||
entity.checkpoint_kickoff_event_id = entity._kickoff_event_id
|
||||
for task in entity.tasks:
|
||||
task.checkpoint_original_description = task._original_description
|
||||
task.checkpoint_original_expected_output = task._original_expected_output
|
||||
|
||||
|
||||
def _migrate(data: dict[str, Any]) -> dict[str, Any]:
|
||||
@@ -123,6 +126,7 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
|
||||
_parent_id: str | None = PrivateAttr(default=None)
|
||||
_branch: str = PrivateAttr(default="main")
|
||||
_location: str | None = PrivateAttr(default=None)
|
||||
_trigger: str | None = PrivateAttr(default=None)
|
||||
|
||||
@property
|
||||
def event_record(self) -> EventRecord:
|
||||
@@ -131,13 +135,16 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
|
||||
|
||||
@model_serializer(mode="plain")
|
||||
def _serialize(self) -> dict[str, Any]:
|
||||
return {
|
||||
d: dict[str, Any] = {
|
||||
"crewai_version": get_crewai_version(),
|
||||
"parent_id": self._parent_id,
|
||||
"branch": self._branch,
|
||||
"entities": [e.model_dump(mode="json") for e in self.root],
|
||||
"event_record": self._event_record.model_dump(),
|
||||
}
|
||||
if self._trigger:
|
||||
d["trigger"] = self._trigger
|
||||
return d
|
||||
|
||||
@model_validator(mode="wrap")
|
||||
@classmethod
|
||||
@@ -222,13 +229,10 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
|
||||
if branch:
|
||||
self._branch = branch
|
||||
elif self._checkpoint_id:
|
||||
self._branch = f"fork/{self._checkpoint_id}"
|
||||
self._branch = f"fork/{self._checkpoint_id}_{uuid.uuid4().hex[:6]}"
|
||||
else:
|
||||
self._branch = f"fork/{uuid.uuid4().hex[:8]}"
|
||||
|
||||
if self._location is not None:
|
||||
self.checkpoint(self._location)
|
||||
|
||||
@classmethod
|
||||
def from_checkpoint(cls, config: CheckpointConfig, **kwargs: Any) -> RuntimeState:
|
||||
"""Restore a RuntimeState from a checkpoint.
|
||||
|
||||
@@ -230,6 +230,8 @@ class Task(BaseModel):
|
||||
_original_description: str | None = PrivateAttr(default=None)
|
||||
_original_expected_output: str | None = PrivateAttr(default=None)
|
||||
_original_output_file: str | None = PrivateAttr(default=None)
|
||||
checkpoint_original_description: str | None = Field(default=None, exclude=False)
|
||||
checkpoint_original_expected_output: str | None = Field(default=None, exclude=False)
|
||||
_thread: threading.Thread | None = PrivateAttr(default=None)
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user