fix(flow): replay recorded method events on checkpoint resume

This commit is contained in:
Greyson LaLonde
2026-04-24 03:41:55 +08:00
committed by GitHub
parent 77b2835a1d
commit 69d777ca50
5 changed files with 335 additions and 2 deletions

View File

@@ -64,6 +64,22 @@ P = ParamSpec("P")
R = TypeVar("R")
# Context-local flag that is True only while a replayed event is being
# dispatched. It defaults to False and is read through is_replaying().
_replaying: contextvars.ContextVar[bool] = contextvars.ContextVar(
"crewai_event_replaying", default=False
)
def is_replaying() -> bool:
    """Report whether a replayed event is being dispatched in this context.

    Listeners that cause side effects (checkpoint writes, external API calls
    that must not be repeated) are expected to return early when this is
    true. Listeners that exist to rebuild timeline state (trace batch,
    console formatter) should disregard the flag and handle replayed events
    like any other event.

    Returns:
        The current value of the context-local replay flag.
    """
    replaying_now = _replaying.get()
    return replaying_now
class CrewAIEventsBus:
"""Singleton event bus for handling events in CrewAI.
@@ -261,6 +277,11 @@ class CrewAIEventsBus:
self._runtime_state = state
self._registered_entity_ids = {id(e) for e in state.root}
@property
def runtime_state(self) -> RuntimeState | None:
    """Expose the RuntimeState attached to this bus.

    Returns:
        The attached RuntimeState, or None when no state is attached.
    """
    attached = self._runtime_state
    return attached
def register_entity(self, entity: Any) -> None:
"""Add an entity to the RuntimeState, creating it if needed.
@@ -568,6 +589,87 @@ class CrewAIEventsBus:
return None
async def _acall_handlers_replaying(
    self,
    source: Any,
    event: BaseEvent,
    handlers: AsyncHandlerSet,
) -> None:
    """Run async handlers with the replaying flag set on the loop thread.

    Args:
        source: The emitting object.
        event: The replayed event handed to each handler.
        handlers: The async handlers to invoke.
    """
    flag_token = _replaying.set(True)
    try:
        await self._acall_handlers(source, event, handlers)
    finally:
        # Restore the previous flag value even when a handler raises.
        _replaying.reset(flag_token)
async def _emit_with_dependencies_replaying(
    self, source: Any, event: BaseEvent
) -> None:
    """Run the dependency-aware dispatch with the replaying flag set.

    Args:
        source: The emitting object.
        event: The replayed event to dispatch.
    """
    flag_token = _replaying.set(True)
    try:
        await self._emit_with_dependencies(source, event)
    finally:
        # Restore the previous flag value even when dispatch raises.
        _replaying.reset(flag_token)
def replay(self, source: Any, event: BaseEvent) -> Future[None] | None:
    """Re-dispatch a recorded event to its handlers, leaving it untouched.

    In contrast to :meth:`emit`, ``_prepare_event`` is skipped, so the stored
    event ids and ``emission_sequence`` are preserved, and the event is not
    recorded a second time. Listeners can call :func:`is_replaying` to opt
    out of side-effectful processing.

    Args:
        source: The emitting object.
        event: The previously-recorded event to dispatch.

    Returns:
        Future that completes when handlers finish, or None when the bus is
        shutting down or no handlers are registered for the event type.
    """
    kind = type(event)

    with self._rwlock.r_locked():
        if self._shutting_down:
            return None
        ordered_dispatch = kind in self._handler_dependencies
        blocking_handlers = self._sync_handlers.get(kind, frozenset())
        coroutine_handlers = self._async_handlers.get(kind, frozenset())

    if not (blocking_handlers or coroutine_handlers):
        return None

    self._ensure_executor_initialized()
    self._has_pending_events = True

    flag_token = _replaying.set(True)
    try:
        if ordered_dispatch:
            dependency_run = asyncio.run_coroutine_threadsafe(
                self._emit_with_dependencies_replaying(source, event),
                self._loop,
            )
            return self._track_future(dependency_run)

        if blocking_handlers:
            # Copy the context while the flag is set so the handlers
            # submitted to the sync executor run with it.
            snapshot = contextvars.copy_context()
            sync_future = self._sync_executor.submit(
                snapshot.run,
                self._call_handlers,
                source,
                event,
                blocking_handlers,
            )
            self._track_future(sync_future)
            if not coroutine_handlers:
                return sync_future

        async_run = asyncio.run_coroutine_threadsafe(
            self._acall_handlers_replaying(source, event, coroutine_handlers),
            self._loop,
        )
        return self._track_future(async_run)
    finally:
        # The flag must not leak into the caller's context after dispatch.
        _replaying.reset(flag_token)
def flush(self, timeout: float | None = 30.0) -> bool:
"""Block until all pending event handlers complete.

View File

@@ -59,6 +59,7 @@ from crewai.events.event_bus import crewai_event_bus
from crewai.events.event_context import (
get_current_parent_id,
reset_last_event_id,
restore_event_scope,
triggered_by_scope,
)
from crewai.events.listeners.tracing.trace_listener import (
@@ -1016,13 +1017,18 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
A Flow instance on the new branch. Call kickoff() to run.
"""
flow = cls.from_checkpoint(config)
state = crewai_event_bus._runtime_state
state = crewai_event_bus.runtime_state
if state is None:
raise RuntimeError(
"Cannot fork: no runtime state on the event bus. "
"Ensure from_checkpoint() succeeded before calling fork()."
)
state.fork(branch)
new_id = str(uuid4())
if isinstance(flow._state, dict):
flow._state["id"] = new_id
else:
object.__setattr__(flow._state, "id", new_id)
return flow
checkpoint_completed_methods: set[str] | None = Field(default=None)
@@ -1044,6 +1050,8 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
}
if self.checkpoint_state is not None:
self._restore_state(self.checkpoint_state)
restore_event_scope(())
reset_last_event_id()
_methods: dict[FlowMethodName, FlowMethod[Any, Any]] = PrivateAttr(
default_factory=dict
@@ -2250,6 +2258,9 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
if inputs is not None and "id" not in inputs:
self._initialize_state(inputs)
if self._is_execution_resuming:
await self._replay_recorded_events()
try:
# Determine which start methods to execute at kickoff
# Conditional start methods (with __trigger_methods__) are only triggered by their conditions
@@ -2397,6 +2408,44 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
"""
return await self.kickoff_async(inputs, input_files, from_checkpoint)
async def _replay_recorded_events(self) -> None:
    """Re-dispatch recorded ``MethodExecution*`` events from the event record.

    Only started/finished/failed method events that carry this flow's name
    and whose method is in ``_completed_methods`` are replayed. They are
    dispatched in ``emission_sequence`` order, one at a time, and a failing
    handler is logged rather than propagated.
    """
    runtime = crewai_event_bus.runtime_state
    if runtime is None:
        return

    history = runtime.event_record
    if len(history) == 0:
        return

    method_event_types = (
        MethodExecutionStartedEvent,
        MethodExecutionFinishedEvent,
        MethodExecutionFailedEvent,
    )
    own_name = self.name or self.__class__.__name__

    selected = [
        candidate
        for candidate in history.all_nodes()
        if isinstance(candidate.event, method_event_types)
        and candidate.event.flow_name == own_name
        and candidate.event.method_name in self._completed_methods
    ]
    # Events without a sequence number sort as 0.
    selected.sort(key=lambda entry: entry.event.emission_sequence or 0)

    for entry in selected:
        pending = crewai_event_bus.replay(self, entry.event)
        if pending is None:
            continue
        try:
            # Wait for this dispatch before replaying the next event.
            await asyncio.wrap_future(pending)
        except Exception:
            logger.warning(
                "Replayed event handler failed: %s",
                entry.event.type,
                exc_info=True,
            )
async def _execute_start_method(self, start_method_name: FlowMethodName) -> None:
"""Executes a flow's start method and its triggered listeners.

View File

@@ -16,7 +16,7 @@ from typing import Any
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.crew import Crew
from crewai.events.base_events import BaseEvent
from crewai.events.event_bus import CrewAIEventsBus, crewai_event_bus
from crewai.events.event_bus import CrewAIEventsBus, crewai_event_bus, is_replaying
from crewai.events.types.checkpoint_events import (
CheckpointBaseEvent,
CheckpointCompletedEvent,
@@ -228,6 +228,8 @@ def _should_checkpoint(source: Any, event: BaseEvent) -> CheckpointConfig | None
def _on_any_event(source: Any, event: BaseEvent, state: Any) -> None:
"""Sync handler registered on every event class."""
if is_replaying():
return
if isinstance(
event,
(CheckpointBaseEvent, CheckpointForkBaseEvent, CheckpointRestoreBaseEvent),

View File

@@ -197,6 +197,21 @@ class EventRecord(BaseModel):
node for node in self.nodes.values() if not node.neighbors("parent")
]
def all_nodes(self) -> list[EventNode]:
    """Copy out every node while holding the read lock.

    Returns:
        A new list of the nodes present at call time; it can be iterated
        without holding the lock.
    """
    with self._lock.r_locked():
        snapshot = [*self.nodes.values()]
    return snapshot
def clear(self) -> None:
    """Drop every node from the record while holding the write lock."""
    guard = self._lock.w_locked()
    with guard:
        self.nodes.clear()
def __len__(self) -> int:
    """Return the number of nodes, read while holding the read lock."""
    with self._lock.r_locked():
        node_count = len(self.nodes)
    return node_count

View File

@@ -0,0 +1,165 @@
"""Tests for event bus replay dispatch and is_replaying flag."""
from __future__ import annotations
from typing import Any
from unittest.mock import patch
from crewai.events.event_bus import _replaying, crewai_event_bus, is_replaying
from crewai.events.types.flow_events import (
MethodExecutionFinishedEvent,
MethodExecutionStartedEvent,
)
def _make_started(method: str, event_id: str, sequence: int) -> MethodExecutionStartedEvent:
    """Create a MethodExecutionStartedEvent carrying the given id and sequence.

    Args:
        method: Value for the event's ``method_name``.
        event_id: Value assigned to ``event_id`` after construction.
        sequence: Value assigned to ``emission_sequence`` after construction.

    Returns:
        The populated event, with ``flow_name`` fixed to ``"F"``.
    """
    started = MethodExecutionStartedEvent(
        method_name=method,
        flow_name="F",
        params={},
        state={},
    )
    started.event_id = event_id
    started.emission_sequence = sequence
    return started
class TestReplayPreservesFields:
    """replay() must leave event_id, parent_event_id and emission_sequence intact."""

    def test_preserves_ids_and_sequence(self) -> None:
        """A replayed event reaches the handler with its stored fields unchanged."""
        received: list[MethodExecutionStartedEvent] = []

        with crewai_event_bus.scoped_handlers():

            @crewai_event_bus.on(MethodExecutionStartedEvent)
            def _capture(_: Any, event: MethodExecutionStartedEvent) -> None:
                received.append(event)

            recorded = _make_started("outline", "orig-id-1", 42)
            recorded.parent_event_id = "parent-abc"

            pending = crewai_event_bus.replay(object(), recorded)
            if pending is not None:
                pending.result(timeout=5.0)

        assert len(received) == 1
        delivered = received[0]
        assert delivered.event_id == "orig-id-1"
        assert delivered.parent_event_id == "parent-abc"
        assert delivered.emission_sequence == 42
class TestIsReplayingFlag:
    """is_replaying() must be True inside handlers dispatched via replay()."""

    def test_flag_true_during_replay(self) -> None:
        """Handlers reached through replay() observe the flag as True."""
        observed: list[bool] = []

        with crewai_event_bus.scoped_handlers():

            @crewai_event_bus.on(MethodExecutionStartedEvent)
            def _capture(_: Any, __: MethodExecutionStartedEvent) -> None:
                observed.append(is_replaying())

            recorded = _make_started("m", "id-1", 1)
            pending = crewai_event_bus.replay(object(), recorded)
            if pending is not None:
                pending.result(timeout=5.0)

        assert observed == [True]
        # The flag must not remain set in the caller's context.
        assert is_replaying() is False

    def test_flag_false_during_emit(self) -> None:
        """Handlers reached through emit() observe the flag as False."""
        observed: list[bool] = []

        with crewai_event_bus.scoped_handlers():

            @crewai_event_bus.on(MethodExecutionStartedEvent)
            def _capture(_: Any, __: MethodExecutionStartedEvent) -> None:
                observed.append(is_replaying())

            fresh = _make_started("m", "id-1", 1)
            pending = crewai_event_bus.emit(object(), fresh)
            if pending is not None:
                pending.result(timeout=5.0)

        assert observed == [False]
class TestCheckpointListenerOptsOut:
    """CheckpointListener must early-return during replay."""

    def test_checkpoint_not_written_on_replay(self) -> None:
        """_on_any_event does not reach _do_checkpoint while the flag is set."""
        from crewai.state.checkpoint_config import CheckpointConfig
        from crewai.state.checkpoint_listener import _on_any_event

        class FlowLike:
            entity_type = "flow"
            checkpoint = CheckpointConfig(trigger_all=True)

        recorded = _make_started("m", "id-1", 1)
        target = "crewai.state.checkpoint_listener._do_checkpoint"

        with patch(target) as do_checkpoint:
            flag_token = _replaying.set(True)
            try:
                _on_any_event(FlowLike(), recorded, state=None)
            finally:
                _replaying.reset(flag_token)

        do_checkpoint.assert_not_called()
class TestFlowResumeReplaysEvents:
    """End-to-end: a resumed flow emits MethodExecution* events for completed methods."""

    def test_resume_dispatches_completed_method_events(self, tmp_path) -> None:
        """Each step is seen exactly once as started and once as finished on resume."""
        from crewai.flow.flow import Flow, listen, start
        from crewai.flow.persistence.sqlite import SQLiteFlowPersistence

        persistence = SQLiteFlowPersistence(str(tmp_path / "flows.db"))

        class ThreeStepFlow(Flow[dict]):
            @start()
            def step_a(self) -> str:
                return "a"

            @listen(step_a)
            def step_b(self) -> str:
                return "b"

            @listen(step_b)
            def step_c(self) -> str:
                return "c"

        # Start from an empty event record when a runtime state already exists.
        existing_state = crewai_event_bus.runtime_state
        if existing_state is not None:
            existing_state.event_record.clear()

        first_run = ThreeStepFlow(persistence=persistence)
        first_run.kickoff()
        flow_id = first_run.state["id"]

        started_names: list[str] = []
        finished_names: list[str] = []

        resumed = ThreeStepFlow(persistence=persistence)
        resumed._completed_methods = {"step_a", "step_b"}

        with crewai_event_bus.scoped_handlers():

            @crewai_event_bus.on(MethodExecutionStartedEvent)
            def _cs(_: Any, event: MethodExecutionStartedEvent) -> None:
                started_names.append(event.method_name)

            @crewai_event_bus.on(MethodExecutionFinishedEvent)
            def _cf(_: Any, event: MethodExecutionFinishedEvent) -> None:
                finished_names.append(event.method_name)

            resumed.kickoff(inputs={"id": flow_id})

        steps = ("step_a", "step_b", "step_c")
        for step in steps:
            assert started_names.count(step) == 1
        for step in steps:
            assert finished_names.count(step) == 1