fix(flow): gate restore on a flag so live snapshots don't replay as resume

Checkpoint serialization stamps checkpoint_completed_methods onto every live
Flow in RuntimeState.root, including the agent executor reused across a crew's
tasks. kickoff_async read that stamp as a restore signal, so the second task
replayed the first task's completed methods and never reached a final answer.

Gate is_restoring on _restored_from_checkpoint, set only by
_restore_from_checkpoint, and consume it single-shot.
This commit is contained in:
Greyson LaLonde
2026-06-10 20:40:08 -07:00
committed by GitHub
parent 5267c059f5
commit fbafe1f0d3
2 changed files with 67 additions and 1 deletions

View File

@@ -862,6 +862,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
self._completed_methods = {
FlowMethodName(m) for m in self.checkpoint_completed_methods
}
self._restored_from_checkpoint = True
if self.checkpoint_method_outputs is not None:
self._method_outputs = list(self.checkpoint_method_outputs)
if self.checkpoint_method_counts is not None:
@@ -897,6 +898,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
_completed_methods: set[FlowMethodName] = PrivateAttr(default_factory=set)
_method_call_counts: dict[FlowMethodName, int] = PrivateAttr(default_factory=dict)
_is_execution_resuming: bool = PrivateAttr(default=False)
_restored_from_checkpoint: bool = PrivateAttr(default=False)
_event_futures: list[Future[None]] = PrivateAttr(default_factory=list)
_pending_feedback_context: PendingFeedbackContext | None = PrivateAttr(default=None)
_human_feedback_method_outputs: dict[str, Any] = PrivateAttr(default_factory=dict)
@@ -2058,7 +2060,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
# Reset flow state for fresh execution unless restoring from persistence
is_restoring = (
inputs and "id" in inputs and self.persistence is not None
) or self.checkpoint_completed_methods is not None
) or self._restored_from_checkpoint
if not is_restoring:
# Clear completed methods and outputs for a fresh start
self._completed_methods.clear()
@@ -2075,6 +2077,10 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
if self._completed_methods:
self._is_execution_resuming = True
# Restore is single-shot: a later kickoff on the same instance
# starts fresh.
self._restored_from_checkpoint = False
# Fork hydration: when restore_from_state_id is set and persistence is
# available, hydrate self._state from the source UUID's latest snapshot
# and reassign state.id to a fresh value so subsequent @persist writes

View File

@@ -16,6 +16,7 @@ from pydantic import BaseModel
from crewai.agent.core import Agent
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.crew import Crew
from crewai.llms.base_llm import BaseLLM
from crewai.flow.flow import _INITIAL_STATE_CLASS_MARKER, Flow, start
from crewai.state.checkpoint_config import CheckpointConfig
from crewai.state.checkpoint_listener import (
@@ -682,3 +683,62 @@ class TestAgentCheckpoint:
cfg = CheckpointConfig(restore_from=loc)
restored = Agent.from_checkpoint(cfg)
assert restored._kickoff_event_id == "evt-456"
class _FinalAnswerLLM(BaseLLM):
"""Stub LLM that always returns a final answer without any API calls."""
def __init__(self) -> None:
super().__init__(model="stub")
def call(
self,
messages,
tools=None,
callbacks=None,
available_functions=None,
from_task=None,
from_agent=None,
response_model=None,
):
return "Final Answer: done."
def supports_function_calling(self) -> bool:
return False
def supports_stop_words(self) -> bool:
return False
def get_context_window_size(self) -> int:
return 4096
async def acall(self, *args, **kwargs):
raise NotImplementedError
class TestCheckpointReusedExecutor:
"""Checkpoint serialization stamps every live Flow's completed methods.
The agent executor is a Flow reused across a crew's tasks, so the stamp
must not be read back as a restore signal on the next task — otherwise the
second task replays as a resume and never reaches a final answer.
"""
def test_second_task_runs_with_checkpointing_enabled(self) -> None:
agent = Agent(role="r", goal="g", backstory="b", llm=_FinalAnswerLLM())
task1 = Task(description="first", expected_output="x", agent=agent)
task2 = Task(description="second", expected_output="y", agent=agent)
with tempfile.TemporaryDirectory() as d:
crew = Crew(
agents=[agent],
tasks=[task1, task2],
verbose=False,
checkpoint=CheckpointConfig(
provider=JsonProvider(location=d),
on_events=["task_started", "task_completed"],
),
)
result = crew.kickoff()
assert len(result.tasks_output) == 2
assert result.tasks_output[1].raw