From 69d777ca50a2a9f241b971b7de7f757f631e8bd8 Mon Sep 17 00:00:00 2001 From: Greyson LaLonde Date: Fri, 24 Apr 2026 03:41:55 +0800 Subject: [PATCH] fix(flow): replay recorded method events on checkpoint resume --- lib/crewai/src/crewai/events/event_bus.py | 102 +++++++++++ lib/crewai/src/crewai/flow/flow.py | 51 +++++- .../src/crewai/state/checkpoint_listener.py | 4 +- lib/crewai/src/crewai/state/event_record.py | 15 ++ lib/crewai/tests/events/test_event_replay.py | 165 ++++++++++++++++++ 5 files changed, 335 insertions(+), 2 deletions(-) create mode 100644 lib/crewai/tests/events/test_event_replay.py diff --git a/lib/crewai/src/crewai/events/event_bus.py b/lib/crewai/src/crewai/events/event_bus.py index c2a2956a7..821f97768 100644 --- a/lib/crewai/src/crewai/events/event_bus.py +++ b/lib/crewai/src/crewai/events/event_bus.py @@ -64,6 +64,22 @@ P = ParamSpec("P") R = TypeVar("R") +_replaying: contextvars.ContextVar[bool] = contextvars.ContextVar( + "crewai_event_replaying", default=False +) + + +def is_replaying() -> bool: + """Return True if the current context is dispatching a replayed event. + + Listeners with side effects (checkpoint writes, external API calls that + should not be repeated) should early-return when this is true. Listeners + whose purpose is reconstructing timeline state (trace batch, console + formatter) should ignore the flag and process replayed events normally. + """ + return _replaying.get() + + class CrewAIEventsBus: """Singleton event bus for handling events in CrewAI. @@ -261,6 +277,11 @@ class CrewAIEventsBus: self._runtime_state = state self._registered_entity_ids = {id(e) for e in state.root} + @property + def runtime_state(self) -> RuntimeState | None: + """The RuntimeState currently attached to the bus, if any.""" + return self._runtime_state + def register_entity(self, entity: Any) -> None: """Add an entity to the RuntimeState, creating it if needed. @@ -568,6 +589,87 @@ class CrewAIEventsBus: return None + async def _acall_handlers_replaying( + self, + source: Any, + event: BaseEvent, + handlers: AsyncHandlerSet, + ) -> None: + """Call async handlers with the replaying flag set on the loop thread.""" + token = _replaying.set(True) + try: + await self._acall_handlers(source, event, handlers) + finally: + _replaying.reset(token) + + async def _emit_with_dependencies_replaying( + self, source: Any, event: BaseEvent + ) -> None: + """Dependency-aware dispatch with the replaying flag set.""" + token = _replaying.set(True) + try: + await self._emit_with_dependencies(source, event) + finally: + _replaying.reset(token) + + def replay(self, source: Any, event: BaseEvent) -> Future[None] | None: + """Dispatch a previously-recorded event without mutating its fields. + + Unlike :meth:`emit`, this does not run ``_prepare_event`` (so stored + event ids and ``emission_sequence`` are preserved) and does not + re-record the event. Listeners can call :func:`is_replaying` to + opt out of side-effectful processing. + + Args: + source: The emitting object. + event: The previously-recorded event to dispatch. + + Returns: + Future that completes when handlers finish, or None if no handlers. + """ + event_type = type(event) + + with self._rwlock.r_locked(): + if self._shutting_down: + return None + has_dependencies = event_type in self._handler_dependencies + sync_handlers = self._sync_handlers.get(event_type, frozenset()) + async_handlers = self._async_handlers.get(event_type, frozenset()) + + if not sync_handlers and not async_handlers: + return None + + self._ensure_executor_initialized() + self._has_pending_events = True + + token = _replaying.set(True) + try: + if has_dependencies: + return self._track_future( + asyncio.run_coroutine_threadsafe( + self._emit_with_dependencies_replaying(source, event), + self._loop, + ) + ) + + if sync_handlers: + ctx = contextvars.copy_context() + sync_future = self._sync_executor.submit( + ctx.run, self._call_handlers, source, event, sync_handlers + ) + self._track_future(sync_future) + if not async_handlers: + return sync_future + + return self._track_future( + asyncio.run_coroutine_threadsafe( + self._acall_handlers_replaying(source, event, async_handlers), + self._loop, + ) + ) + finally: + _replaying.reset(token) + def flush(self, timeout: float | None = 30.0) -> bool: """Block until all pending event handlers complete. diff --git a/lib/crewai/src/crewai/flow/flow.py b/lib/crewai/src/crewai/flow/flow.py index 439a40524..8172e7a70 100644 --- a/lib/crewai/src/crewai/flow/flow.py +++ b/lib/crewai/src/crewai/flow/flow.py @@ -59,6 +59,7 @@ from crewai.events.event_bus import crewai_event_bus from crewai.events.event_context import ( get_current_parent_id, reset_last_event_id, + restore_event_scope, triggered_by_scope, ) from crewai.events.listeners.tracing.trace_listener import ( @@ -1016,13 +1017,18 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): A Flow instance on the new branch. Call kickoff() to run. """ flow = cls.from_checkpoint(config) - state = crewai_event_bus._runtime_state + state = crewai_event_bus.runtime_state if state is None: raise RuntimeError( "Cannot fork: no runtime state on the event bus. " "Ensure from_checkpoint() succeeded before calling fork()." ) state.fork(branch) + new_id = str(uuid4()) + if isinstance(flow._state, dict): + flow._state["id"] = new_id + else: + object.__setattr__(flow._state, "id", new_id) return flow checkpoint_completed_methods: set[str] | None = Field(default=None) @@ -1044,6 +1050,8 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): } if self.checkpoint_state is not None: self._restore_state(self.checkpoint_state) + restore_event_scope(()) + reset_last_event_id() _methods: dict[FlowMethodName, FlowMethod[Any, Any]] = PrivateAttr( default_factory=dict @@ -2250,6 +2258,9 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): if inputs is not None and "id" not in inputs: self._initialize_state(inputs) + if self._is_execution_resuming: + await self._replay_recorded_events() + try: # Determine which start methods to execute at kickoff # Conditional start methods (with __trigger_methods__) are only triggered by their conditions @@ -2397,6 +2408,44 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): """ return await self.kickoff_async(inputs, input_files, from_checkpoint) + async def _replay_recorded_events(self) -> None: + """Dispatch recorded ``MethodExecution*`` events from the event record.""" + state = crewai_event_bus.runtime_state + if state is None: + return + record = state.event_record + if len(record) == 0: + return + + replayable = ( + MethodExecutionStartedEvent, + MethodExecutionFinishedEvent, + MethodExecutionFailedEvent, + ) + flow_name = self.name or self.__class__.__name__ + nodes = sorted( + ( + n + for n in record.all_nodes() + if isinstance(n.event, replayable) + and n.event.flow_name == flow_name + and n.event.method_name in self._completed_methods + ), + key=lambda n: n.event.emission_sequence or 0, + ) + + for node in nodes: + future = crewai_event_bus.replay(self, node.event) + if future is not None: + try: + await asyncio.wrap_future(future) + except Exception: + logger.warning( + "Replayed event handler failed: %s", + node.event.type, + exc_info=True, + ) + async def _execute_start_method(self, start_method_name: FlowMethodName) -> None: """Executes a flow's start method and its triggered listeners. diff --git a/lib/crewai/src/crewai/state/checkpoint_listener.py b/lib/crewai/src/crewai/state/checkpoint_listener.py index 0c2adc127..53ae0b494 100644 --- a/lib/crewai/src/crewai/state/checkpoint_listener.py +++ b/lib/crewai/src/crewai/state/checkpoint_listener.py @@ -16,7 +16,7 @@ from typing import Any from crewai.agents.agent_builder.base_agent import BaseAgent from crewai.crew import Crew from crewai.events.base_events import BaseEvent -from crewai.events.event_bus import CrewAIEventsBus, crewai_event_bus +from crewai.events.event_bus import CrewAIEventsBus, crewai_event_bus, is_replaying from crewai.events.types.checkpoint_events import ( CheckpointBaseEvent, CheckpointCompletedEvent, @@ -228,6 +228,8 @@ def _should_checkpoint(source: Any, event: BaseEvent) -> CheckpointConfig | None def _on_any_event(source: Any, event: BaseEvent, state: Any) -> None: """Sync handler registered on every event class.""" + if is_replaying(): + return if isinstance( event, (CheckpointBaseEvent, CheckpointForkBaseEvent, CheckpointRestoreBaseEvent), diff --git a/lib/crewai/src/crewai/state/event_record.py b/lib/crewai/src/crewai/state/event_record.py index 866398e0a..f0b15b48f 100644 --- a/lib/crewai/src/crewai/state/event_record.py +++ b/lib/crewai/src/crewai/state/event_record.py @@ -197,6 +197,21 @@ class EventRecord(BaseModel): node for node in self.nodes.values() if not node.neighbors("parent") ] + def all_nodes(self) -> list[EventNode]: + """Return a snapshot of every node under the read lock. + + Returns: + A list copy of the current nodes, safe to iterate without holding + the lock. + """ + with self._lock.r_locked(): + return list(self.nodes.values()) + + def clear(self) -> None: + """Remove all nodes from the record under the write lock.""" + with self._lock.w_locked(): + self.nodes.clear() + def __len__(self) -> int: with self._lock.r_locked(): return len(self.nodes) diff --git a/lib/crewai/tests/events/test_event_replay.py b/lib/crewai/tests/events/test_event_replay.py new file mode 100644 index 000000000..d141385ca --- /dev/null +++ b/lib/crewai/tests/events/test_event_replay.py @@ -0,0 +1,165 @@ +"""Tests for event bus replay dispatch and is_replaying flag.""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import patch + +from crewai.events.event_bus import _replaying, crewai_event_bus, is_replaying +from crewai.events.types.flow_events import ( + MethodExecutionFinishedEvent, + MethodExecutionStartedEvent, +) + + +def _make_started(method: str, event_id: str, sequence: int) -> MethodExecutionStartedEvent: + """Build a MethodExecutionStartedEvent with explicit ids/sequence.""" + ev = MethodExecutionStartedEvent( + method_name=method, + flow_name="F", + params={}, + state={}, + ) + ev.event_id = event_id + ev.emission_sequence = sequence + return ev + + +class TestReplayPreservesFields: + """replay() must not overwrite event_id, parent_event_id, or emission_sequence.""" + + def test_preserves_ids_and_sequence(self) -> None: + captured: list[MethodExecutionStartedEvent] = [] + + with crewai_event_bus.scoped_handlers(): + + @crewai_event_bus.on(MethodExecutionStartedEvent) + def _capture(_: Any, event: MethodExecutionStartedEvent) -> None: + captured.append(event) + + ev = _make_started("outline", "orig-id-1", 42) + ev.parent_event_id = "parent-abc" + + future = crewai_event_bus.replay(object(), ev) + if future is not None: + future.result(timeout=5.0) + + assert len(captured) == 1 + assert captured[0].event_id == "orig-id-1" + assert captured[0].parent_event_id == "parent-abc" + assert captured[0].emission_sequence == 42 + + +class TestIsReplayingFlag: + """is_replaying() must be True inside handlers dispatched via replay().""" + + def test_flag_true_during_replay(self) -> None: + seen: list[bool] = [] + + with crewai_event_bus.scoped_handlers(): + + @crewai_event_bus.on(MethodExecutionStartedEvent) + def _capture(_: Any, __: MethodExecutionStartedEvent) -> None: + seen.append(is_replaying()) + + ev = _make_started("m", "id-1", 1) + future = crewai_event_bus.replay(object(), ev) + if future is not None: + future.result(timeout=5.0) + + assert seen == [True] + assert is_replaying() is False + + def test_flag_false_during_emit(self) -> None: + seen: list[bool] = [] + + with crewai_event_bus.scoped_handlers(): + + @crewai_event_bus.on(MethodExecutionStartedEvent) + def _capture(_: Any, __: MethodExecutionStartedEvent) -> None: + seen.append(is_replaying()) + + ev = _make_started("m", "id-1", 1) + future = crewai_event_bus.emit(object(), ev) + if future is not None: + future.result(timeout=5.0) + + assert seen == [False] + + +class TestCheckpointListenerOptsOut: + """CheckpointListener must early-return during replay.""" + + def test_checkpoint_not_written_on_replay(self) -> None: + from crewai.state.checkpoint_config import CheckpointConfig + from crewai.state.checkpoint_listener import _on_any_event + + class FlowLike: + entity_type = "flow" + checkpoint = CheckpointConfig(trigger_all=True) + + ev = _make_started("m", "id-1", 1) + + with patch("crewai.state.checkpoint_listener._do_checkpoint") as do_cp: + token = _replaying.set(True) + try: + _on_any_event(FlowLike(), ev, state=None) + finally: + _replaying.reset(token) + assert do_cp.call_count == 0 + + +class TestFlowResumeReplaysEvents: + """End-to-end: a resumed flow emits MethodExecution* events for completed methods.""" + + def test_resume_dispatches_completed_method_events(self, tmp_path) -> None: + from crewai.flow.flow import Flow, listen, start + from crewai.flow.persistence.sqlite import SQLiteFlowPersistence + + db_path = tmp_path / "flows.db" + persistence = SQLiteFlowPersistence(str(db_path)) + + class ThreeStepFlow(Flow[dict]): + @start() + def step_a(self) -> str: + return "a" + + @listen(step_a) + def step_b(self) -> str: + return "b" + + @listen(step_b) + def step_c(self) -> str: + return "c" + + if crewai_event_bus.runtime_state is not None: + crewai_event_bus.runtime_state.event_record.clear() + + flow1 = ThreeStepFlow(persistence=persistence) + flow1.kickoff() + flow_id = flow1.state["id"] + + captured_started: list[str] = [] + captured_finished: list[str] = [] + + flow2 = ThreeStepFlow(persistence=persistence) + flow2._completed_methods = {"step_a", "step_b"} + + with crewai_event_bus.scoped_handlers(): + + @crewai_event_bus.on(MethodExecutionStartedEvent) + def _cs(_: Any, event: MethodExecutionStartedEvent) -> None: + captured_started.append(event.method_name) + + @crewai_event_bus.on(MethodExecutionFinishedEvent) + def _cf(_: Any, event: MethodExecutionFinishedEvent) -> None: + captured_finished.append(event.method_name) + + flow2.kickoff(inputs={"id": flow_id}) + + assert captured_started.count("step_a") == 1 + assert captured_started.count("step_b") == 1 + assert captured_started.count("step_c") == 1 + assert captured_finished.count("step_a") == 1 + assert captured_finished.count("step_b") == 1 + assert captured_finished.count("step_c") == 1