mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-07-03 14:09:24 +00:00
fix(flow): replay recorded method events on checkpoint resume
This commit is contained in:
@@ -64,6 +64,22 @@ P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
_replaying: contextvars.ContextVar[bool] = contextvars.ContextVar(
|
||||
"crewai_event_replaying", default=False
|
||||
)
|
||||
|
||||
|
||||
def is_replaying() -> bool:
|
||||
"""Return True if the current context is dispatching a replayed event.
|
||||
|
||||
Listeners with side effects (checkpoint writes, external API calls that
|
||||
should not be repeated) should early-return when this is true. Listeners
|
||||
whose purpose is reconstructing timeline state (trace batch, console
|
||||
formatter) should ignore the flag and process replayed events normally.
|
||||
"""
|
||||
return _replaying.get()
|
||||
|
||||
|
||||
class CrewAIEventsBus:
|
||||
"""Singleton event bus for handling events in CrewAI.
|
||||
|
||||
@@ -261,6 +277,11 @@ class CrewAIEventsBus:
|
||||
self._runtime_state = state
|
||||
self._registered_entity_ids = {id(e) for e in state.root}
|
||||
|
||||
@property
|
||||
def runtime_state(self) -> RuntimeState | None:
|
||||
"""The RuntimeState currently attached to the bus, if any."""
|
||||
return self._runtime_state
|
||||
|
||||
def register_entity(self, entity: Any) -> None:
|
||||
"""Add an entity to the RuntimeState, creating it if needed.
|
||||
|
||||
@@ -568,6 +589,87 @@ class CrewAIEventsBus:
|
||||
|
||||
return None
|
||||
|
||||
async def _acall_handlers_replaying(
|
||||
self,
|
||||
source: Any,
|
||||
event: BaseEvent,
|
||||
handlers: AsyncHandlerSet,
|
||||
) -> None:
|
||||
"""Call async handlers with the replaying flag set on the loop thread."""
|
||||
token = _replaying.set(True)
|
||||
try:
|
||||
await self._acall_handlers(source, event, handlers)
|
||||
finally:
|
||||
_replaying.reset(token)
|
||||
|
||||
async def _emit_with_dependencies_replaying(
|
||||
self, source: Any, event: BaseEvent
|
||||
) -> None:
|
||||
"""Dependency-aware dispatch with the replaying flag set."""
|
||||
token = _replaying.set(True)
|
||||
try:
|
||||
await self._emit_with_dependencies(source, event)
|
||||
finally:
|
||||
_replaying.reset(token)
|
||||
|
||||
def replay(self, source: Any, event: BaseEvent) -> Future[None] | None:
|
||||
"""Dispatch a previously-recorded event without mutating its fields.
|
||||
|
||||
Unlike :meth:`emit`, this does not run ``_prepare_event`` (so stored
|
||||
event ids and ``emission_sequence`` are preserved) and does not
|
||||
re-record the event. Listeners can call :func:`is_replaying` to
|
||||
opt out of side-effectful processing.
|
||||
|
||||
Args:
|
||||
source: The emitting object.
|
||||
event: The previously-recorded event to dispatch.
|
||||
|
||||
Returns:
|
||||
Future that completes when handlers finish, or None if no handlers.
|
||||
"""
|
||||
event_type = type(event)
|
||||
|
||||
with self._rwlock.r_locked():
|
||||
if self._shutting_down:
|
||||
return None
|
||||
has_dependencies = event_type in self._handler_dependencies
|
||||
sync_handlers = self._sync_handlers.get(event_type, frozenset())
|
||||
async_handlers = self._async_handlers.get(event_type, frozenset())
|
||||
|
||||
if not sync_handlers and not async_handlers:
|
||||
return None
|
||||
|
||||
self._ensure_executor_initialized()
|
||||
self._has_pending_events = True
|
||||
|
||||
token = _replaying.set(True)
|
||||
try:
|
||||
if has_dependencies:
|
||||
return self._track_future(
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self._emit_with_dependencies_replaying(source, event),
|
||||
self._loop,
|
||||
)
|
||||
)
|
||||
|
||||
if sync_handlers:
|
||||
ctx = contextvars.copy_context()
|
||||
sync_future = self._sync_executor.submit(
|
||||
ctx.run, self._call_handlers, source, event, sync_handlers
|
||||
)
|
||||
self._track_future(sync_future)
|
||||
if not async_handlers:
|
||||
return sync_future
|
||||
|
||||
return self._track_future(
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self._acall_handlers_replaying(source, event, async_handlers),
|
||||
self._loop,
|
||||
)
|
||||
)
|
||||
finally:
|
||||
_replaying.reset(token)
|
||||
|
||||
def flush(self, timeout: float | None = 30.0) -> bool:
|
||||
"""Block until all pending event handlers complete.
|
||||
|
||||
|
||||
@@ -59,6 +59,7 @@ from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.event_context import (
|
||||
get_current_parent_id,
|
||||
reset_last_event_id,
|
||||
restore_event_scope,
|
||||
triggered_by_scope,
|
||||
)
|
||||
from crewai.events.listeners.tracing.trace_listener import (
|
||||
@@ -1016,13 +1017,18 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
A Flow instance on the new branch. Call kickoff() to run.
|
||||
"""
|
||||
flow = cls.from_checkpoint(config)
|
||||
state = crewai_event_bus._runtime_state
|
||||
state = crewai_event_bus.runtime_state
|
||||
if state is None:
|
||||
raise RuntimeError(
|
||||
"Cannot fork: no runtime state on the event bus. "
|
||||
"Ensure from_checkpoint() succeeded before calling fork()."
|
||||
)
|
||||
state.fork(branch)
|
||||
new_id = str(uuid4())
|
||||
if isinstance(flow._state, dict):
|
||||
flow._state["id"] = new_id
|
||||
else:
|
||||
object.__setattr__(flow._state, "id", new_id)
|
||||
return flow
|
||||
|
||||
checkpoint_completed_methods: set[str] | None = Field(default=None)
|
||||
@@ -1044,6 +1050,8 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
}
|
||||
if self.checkpoint_state is not None:
|
||||
self._restore_state(self.checkpoint_state)
|
||||
restore_event_scope(())
|
||||
reset_last_event_id()
|
||||
|
||||
_methods: dict[FlowMethodName, FlowMethod[Any, Any]] = PrivateAttr(
|
||||
default_factory=dict
|
||||
@@ -2250,6 +2258,9 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
if inputs is not None and "id" not in inputs:
|
||||
self._initialize_state(inputs)
|
||||
|
||||
if self._is_execution_resuming:
|
||||
await self._replay_recorded_events()
|
||||
|
||||
try:
|
||||
# Determine which start methods to execute at kickoff
|
||||
# Conditional start methods (with __trigger_methods__) are only triggered by their conditions
|
||||
@@ -2397,6 +2408,44 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
"""
|
||||
return await self.kickoff_async(inputs, input_files, from_checkpoint)
|
||||
|
||||
async def _replay_recorded_events(self) -> None:
|
||||
"""Dispatch recorded ``MethodExecution*`` events from the event record."""
|
||||
state = crewai_event_bus.runtime_state
|
||||
if state is None:
|
||||
return
|
||||
record = state.event_record
|
||||
if len(record) == 0:
|
||||
return
|
||||
|
||||
replayable = (
|
||||
MethodExecutionStartedEvent,
|
||||
MethodExecutionFinishedEvent,
|
||||
MethodExecutionFailedEvent,
|
||||
)
|
||||
flow_name = self.name or self.__class__.__name__
|
||||
nodes = sorted(
|
||||
(
|
||||
n
|
||||
for n in record.all_nodes()
|
||||
if isinstance(n.event, replayable)
|
||||
and n.event.flow_name == flow_name
|
||||
and n.event.method_name in self._completed_methods
|
||||
),
|
||||
key=lambda n: n.event.emission_sequence or 0,
|
||||
)
|
||||
|
||||
for node in nodes:
|
||||
future = crewai_event_bus.replay(self, node.event)
|
||||
if future is not None:
|
||||
try:
|
||||
await asyncio.wrap_future(future)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Replayed event handler failed: %s",
|
||||
node.event.type,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
async def _execute_start_method(self, start_method_name: FlowMethodName) -> None:
|
||||
"""Executes a flow's start method and its triggered listeners.
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ from typing import Any
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.crew import Crew
|
||||
from crewai.events.base_events import BaseEvent
|
||||
from crewai.events.event_bus import CrewAIEventsBus, crewai_event_bus
|
||||
from crewai.events.event_bus import CrewAIEventsBus, crewai_event_bus, is_replaying
|
||||
from crewai.events.types.checkpoint_events import (
|
||||
CheckpointBaseEvent,
|
||||
CheckpointCompletedEvent,
|
||||
@@ -228,6 +228,8 @@ def _should_checkpoint(source: Any, event: BaseEvent) -> CheckpointConfig | None
|
||||
|
||||
def _on_any_event(source: Any, event: BaseEvent, state: Any) -> None:
|
||||
"""Sync handler registered on every event class."""
|
||||
if is_replaying():
|
||||
return
|
||||
if isinstance(
|
||||
event,
|
||||
(CheckpointBaseEvent, CheckpointForkBaseEvent, CheckpointRestoreBaseEvent),
|
||||
|
||||
@@ -197,6 +197,21 @@ class EventRecord(BaseModel):
|
||||
node for node in self.nodes.values() if not node.neighbors("parent")
|
||||
]
|
||||
|
||||
def all_nodes(self) -> list[EventNode]:
|
||||
"""Return a snapshot of every node under the read lock.
|
||||
|
||||
Returns:
|
||||
A list copy of the current nodes, safe to iterate without holding
|
||||
the lock.
|
||||
"""
|
||||
with self._lock.r_locked():
|
||||
return list(self.nodes.values())
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Remove all nodes from the record under the write lock."""
|
||||
with self._lock.w_locked():
|
||||
self.nodes.clear()
|
||||
|
||||
def __len__(self) -> int:
|
||||
with self._lock.r_locked():
|
||||
return len(self.nodes)
|
||||
|
||||
165
lib/crewai/tests/events/test_event_replay.py
Normal file
165
lib/crewai/tests/events/test_event_replay.py
Normal file
@@ -0,0 +1,165 @@
|
||||
"""Tests for event bus replay dispatch and is_replaying flag."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
from crewai.events.event_bus import _replaying, crewai_event_bus, is_replaying
|
||||
from crewai.events.types.flow_events import (
|
||||
MethodExecutionFinishedEvent,
|
||||
MethodExecutionStartedEvent,
|
||||
)
|
||||
|
||||
|
||||
def _make_started(method: str, event_id: str, sequence: int) -> MethodExecutionStartedEvent:
|
||||
"""Build a MethodExecutionStartedEvent with explicit ids/sequence."""
|
||||
ev = MethodExecutionStartedEvent(
|
||||
method_name=method,
|
||||
flow_name="F",
|
||||
params={},
|
||||
state={},
|
||||
)
|
||||
ev.event_id = event_id
|
||||
ev.emission_sequence = sequence
|
||||
return ev
|
||||
|
||||
|
||||
class TestReplayPreservesFields:
|
||||
"""replay() must not overwrite event_id, parent_event_id, or emission_sequence."""
|
||||
|
||||
def test_preserves_ids_and_sequence(self) -> None:
|
||||
captured: list[MethodExecutionStartedEvent] = []
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
|
||||
@crewai_event_bus.on(MethodExecutionStartedEvent)
|
||||
def _capture(_: Any, event: MethodExecutionStartedEvent) -> None:
|
||||
captured.append(event)
|
||||
|
||||
ev = _make_started("outline", "orig-id-1", 42)
|
||||
ev.parent_event_id = "parent-abc"
|
||||
|
||||
future = crewai_event_bus.replay(object(), ev)
|
||||
if future is not None:
|
||||
future.result(timeout=5.0)
|
||||
|
||||
assert len(captured) == 1
|
||||
assert captured[0].event_id == "orig-id-1"
|
||||
assert captured[0].parent_event_id == "parent-abc"
|
||||
assert captured[0].emission_sequence == 42
|
||||
|
||||
|
||||
class TestIsReplayingFlag:
|
||||
"""is_replaying() must be True inside handlers dispatched via replay()."""
|
||||
|
||||
def test_flag_true_during_replay(self) -> None:
|
||||
seen: list[bool] = []
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
|
||||
@crewai_event_bus.on(MethodExecutionStartedEvent)
|
||||
def _capture(_: Any, __: MethodExecutionStartedEvent) -> None:
|
||||
seen.append(is_replaying())
|
||||
|
||||
ev = _make_started("m", "id-1", 1)
|
||||
future = crewai_event_bus.replay(object(), ev)
|
||||
if future is not None:
|
||||
future.result(timeout=5.0)
|
||||
|
||||
assert seen == [True]
|
||||
assert is_replaying() is False
|
||||
|
||||
def test_flag_false_during_emit(self) -> None:
|
||||
seen: list[bool] = []
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
|
||||
@crewai_event_bus.on(MethodExecutionStartedEvent)
|
||||
def _capture(_: Any, __: MethodExecutionStartedEvent) -> None:
|
||||
seen.append(is_replaying())
|
||||
|
||||
ev = _make_started("m", "id-1", 1)
|
||||
future = crewai_event_bus.emit(object(), ev)
|
||||
if future is not None:
|
||||
future.result(timeout=5.0)
|
||||
|
||||
assert seen == [False]
|
||||
|
||||
|
||||
class TestCheckpointListenerOptsOut:
|
||||
"""CheckpointListener must early-return during replay."""
|
||||
|
||||
def test_checkpoint_not_written_on_replay(self) -> None:
|
||||
from crewai.state.checkpoint_config import CheckpointConfig
|
||||
from crewai.state.checkpoint_listener import _on_any_event
|
||||
|
||||
class FlowLike:
|
||||
entity_type = "flow"
|
||||
checkpoint = CheckpointConfig(trigger_all=True)
|
||||
|
||||
ev = _make_started("m", "id-1", 1)
|
||||
|
||||
with patch("crewai.state.checkpoint_listener._do_checkpoint") as do_cp:
|
||||
token = _replaying.set(True)
|
||||
try:
|
||||
_on_any_event(FlowLike(), ev, state=None)
|
||||
finally:
|
||||
_replaying.reset(token)
|
||||
assert do_cp.call_count == 0
|
||||
|
||||
|
||||
class TestFlowResumeReplaysEvents:
|
||||
"""End-to-end: a resumed flow emits MethodExecution* events for completed methods."""
|
||||
|
||||
def test_resume_dispatches_completed_method_events(self, tmp_path) -> None:
|
||||
from crewai.flow.flow import Flow, listen, start
|
||||
from crewai.flow.persistence.sqlite import SQLiteFlowPersistence
|
||||
|
||||
db_path = tmp_path / "flows.db"
|
||||
persistence = SQLiteFlowPersistence(str(db_path))
|
||||
|
||||
class ThreeStepFlow(Flow[dict]):
|
||||
@start()
|
||||
def step_a(self) -> str:
|
||||
return "a"
|
||||
|
||||
@listen(step_a)
|
||||
def step_b(self) -> str:
|
||||
return "b"
|
||||
|
||||
@listen(step_b)
|
||||
def step_c(self) -> str:
|
||||
return "c"
|
||||
|
||||
if crewai_event_bus.runtime_state is not None:
|
||||
crewai_event_bus.runtime_state.event_record.clear()
|
||||
|
||||
flow1 = ThreeStepFlow(persistence=persistence)
|
||||
flow1.kickoff()
|
||||
flow_id = flow1.state["id"]
|
||||
|
||||
captured_started: list[str] = []
|
||||
captured_finished: list[str] = []
|
||||
|
||||
flow2 = ThreeStepFlow(persistence=persistence)
|
||||
flow2._completed_methods = {"step_a", "step_b"}
|
||||
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
|
||||
@crewai_event_bus.on(MethodExecutionStartedEvent)
|
||||
def _cs(_: Any, event: MethodExecutionStartedEvent) -> None:
|
||||
captured_started.append(event.method_name)
|
||||
|
||||
@crewai_event_bus.on(MethodExecutionFinishedEvent)
|
||||
def _cf(_: Any, event: MethodExecutionFinishedEvent) -> None:
|
||||
captured_finished.append(event.method_name)
|
||||
|
||||
flow2.kickoff(inputs={"id": flow_id})
|
||||
|
||||
assert captured_started.count("step_a") == 1
|
||||
assert captured_started.count("step_b") == 1
|
||||
assert captured_started.count("step_c") == 1
|
||||
assert captured_finished.count("step_a") == 1
|
||||
assert captured_finished.count("step_b") == 1
|
||||
assert captured_finished.count("step_c") == 1
|
||||
Reference in New Issue
Block a user