fix(events): scope runtime state per run to bound growth and isolate concurrent runs

This commit is contained in:
Greyson LaLonde
2026-06-10 18:39:05 -07:00
committed by GitHub
parent 036b032ab6
commit a1f44eb272
6 changed files with 181 additions and 43 deletions

View File

@@ -117,8 +117,10 @@ def capture_execution_context(
)
def apply_execution_context(ctx: ExecutionContext) -> None:
def apply_execution_context(ctx: ExecutionContext | dict[str, Any]) -> None:
"""Write an ExecutionContext back into the ContextVars."""
if isinstance(ctx, dict):
ctx = ExecutionContext.model_validate(ctx)
_current_task_id.set(ctx.current_task_id)
current_flow_request_id.set(ctx.flow_request_id)
current_flow_id.set(ctx.flow_id)

View File

@@ -1013,6 +1013,7 @@ class Crew(FlowTrackable, BaseModel):
)
token = attach(baggage_ctx)
runtime_scope = crewai_event_bus._enter_runtime_scope()
try:
inputs = prepare_kickoff(self, inputs, input_files)
@@ -1048,6 +1049,7 @@ class Crew(FlowTrackable, BaseModel):
self._memory.drain_writes()
clear_files(self.id)
detach(token)
crewai_event_bus._exit_runtime_scope(runtime_scope)
def _post_kickoff(self, result: CrewOutput) -> CrewOutput:
return result
@@ -1223,6 +1225,7 @@ class Crew(FlowTrackable, BaseModel):
)
token = attach(baggage_ctx)
runtime_scope = crewai_event_bus._enter_runtime_scope()
try:
inputs = prepare_kickoff(self, inputs, input_files)
@@ -1256,6 +1259,7 @@ class Crew(FlowTrackable, BaseModel):
finally:
clear_files(self.id)
detach(token)
crewai_event_bus._exit_runtime_scope(runtime_scope)
async def akickoff_for_each(
self,

View File

@@ -80,6 +80,17 @@ def is_replaying() -> bool:
return _replaying.get()
_runtime_state_var: contextvars.ContextVar[RuntimeState | None] = (
contextvars.ContextVar("crewai_runtime_state", default=None)
)
_registered_entity_ids_var: contextvars.ContextVar[set[int] | None] = (
contextvars.ContextVar("crewai_registered_entity_ids", default=None)
)
_runtime_scope_depth: contextvars.ContextVar[int] = contextvars.ContextVar(
"crewai_runtime_scope_depth", default=0
)
class CrewAIEventsBus:
"""Singleton event bus for handling events in CrewAI.
@@ -116,7 +127,6 @@ class CrewAIEventsBus:
_futures_lock: threading.Lock
_executor_initialized: bool
_has_pending_events: bool
_runtime_state: RuntimeState | None
def __new__(cls) -> Self:
"""Create or return the singleton instance.
@@ -151,8 +161,6 @@ class CrewAIEventsBus:
self._console = ConsoleFormatter()
self._executor_initialized = False
self._has_pending_events = False
self._runtime_state: RuntimeState | None = None
self._registered_entity_ids: set[int] = set()
def _ensure_executor_initialized(self) -> None:
"""Lazily initialize the thread pool executor and event loop.
@@ -281,11 +289,50 @@ class CrewAIEventsBus:
"""The RuntimeState currently attached to the bus, if any."""
return self._runtime_state
@property
def _runtime_state(self) -> RuntimeState | None:
return _runtime_state_var.get()
@_runtime_state.setter
def _runtime_state(self, value: RuntimeState | None) -> None:
_runtime_state_var.set(value)
@property
def _registered_entity_ids(self) -> set[int]:
ids = _registered_entity_ids_var.get()
if ids is None:
ids = set()
_registered_entity_ids_var.set(ids)
return ids
@_registered_entity_ids.setter
def _registered_entity_ids(self, value: set[int]) -> None:
_registered_entity_ids_var.set(value)
def reset_runtime_state(self) -> None:
"""Detach the RuntimeState and clear the entity registry."""
with self._instance_lock:
self._runtime_state = None
self._registered_entity_ids = set()
self._runtime_state = None
self._registered_entity_ids = set()
def _enter_runtime_scope(self) -> bool:
depth = _runtime_scope_depth.get()
_runtime_scope_depth.set(depth + 1)
if depth != 0:
return False
if _runtime_state_var.get() is None:
from crewai import RuntimeState
if RuntimeState is not None:
_runtime_state_var.set(RuntimeState(root=[]))
_registered_entity_ids_var.set(set())
return True
def _exit_runtime_scope(self, outermost: bool) -> None:
depth = _runtime_scope_depth.get()
_runtime_scope_depth.set(depth - 1 if depth > 0 else 0)
if outermost:
_runtime_state_var.set(None)
_registered_entity_ids_var.set(None)
def register_entity(self, entity: Any) -> None:
"""Add an entity to the RuntimeState, creating it if needed.
@@ -355,6 +402,7 @@ class CrewAIEventsBus:
source: Any,
event: BaseEvent,
handlers: SyncHandlerSet,
state: RuntimeState | None,
) -> None:
"""Call provided synchronous handlers.
@@ -362,8 +410,8 @@ class CrewAIEventsBus:
source: The emitting object
event: The event instance
handlers: Frozenset of sync handlers to call
state: The RuntimeState captured on the emitting context
"""
state = self._runtime_state
errors: list[tuple[SyncHandler, Exception]] = [
(handler, error)
for handler in handlers
@@ -382,6 +430,7 @@ class CrewAIEventsBus:
source: Any,
event: BaseEvent,
handlers: AsyncHandlerSet,
state: RuntimeState | None,
) -> None:
"""Asynchronously call provided async handlers.
@@ -389,8 +438,8 @@ class CrewAIEventsBus:
source: The object that emitted the event
event: The event instance
handlers: Frozenset of async handlers to call
state: The RuntimeState captured on the emitting context
"""
state = self._runtime_state
async def _call(handler: AsyncHandler) -> Any:
if _get_param_count(handler) >= 3:
@@ -405,7 +454,9 @@ class CrewAIEventsBus:
f"[CrewAIEventsBus] Async handler error in {getattr(handler, '__name__', handler)}: {result}"
)
async def _emit_with_dependencies(self, source: Any, event: BaseEvent) -> None:
async def _emit_with_dependencies(
self, source: Any, event: BaseEvent, state: RuntimeState | None
) -> None:
"""Emit an event with dependency-aware handler execution.
Handlers are grouped into execution levels based on their dependencies.
@@ -456,18 +507,18 @@ class CrewAIEventsBus:
if level_sync:
if event_type is LLMStreamChunkEvent:
self._call_handlers(source, event, level_sync)
self._call_handlers(source, event, level_sync, state)
else:
ctx = contextvars.copy_context()
future = self._sync_executor.submit(
ctx.run, self._call_handlers, source, event, level_sync
ctx.run, self._call_handlers, source, event, level_sync, state
)
await asyncio.get_running_loop().run_in_executor(
None, future.result
)
if level_async:
await self._acall_handlers(source, event, level_async)
await self._acall_handlers(source, event, level_async, state)
def _register_source(self, source: Any) -> None:
"""Register the source entity in RuntimeState if applicable."""
@@ -562,21 +613,23 @@ class CrewAIEventsBus:
self._ensure_executor_initialized()
self._has_pending_events = True
state = self._runtime_state
if has_dependencies:
return self._track_future(
asyncio.run_coroutine_threadsafe(
self._emit_with_dependencies(source, event),
self._emit_with_dependencies(source, event, state),
self._loop,
)
)
if sync_handlers:
if event_type is LLMStreamChunkEvent:
self._call_handlers(source, event, sync_handlers)
self._call_handlers(source, event, sync_handlers, state)
else:
ctx = contextvars.copy_context()
sync_future = self._sync_executor.submit(
ctx.run, self._call_handlers, source, event, sync_handlers
ctx.run, self._call_handlers, source, event, sync_handlers, state
)
if not async_handlers:
return self._track_future(sync_future)
@@ -584,7 +637,7 @@ class CrewAIEventsBus:
if async_handlers:
return self._track_future(
asyncio.run_coroutine_threadsafe(
self._acall_handlers(source, event, async_handlers),
self._acall_handlers(source, event, async_handlers, state),
self._loop,
)
)
@@ -596,21 +649,22 @@ class CrewAIEventsBus:
source: Any,
event: BaseEvent,
handlers: AsyncHandlerSet,
state: RuntimeState | None,
) -> None:
"""Call async handlers with the replaying flag set on the loop thread."""
token = _replaying.set(True)
try:
await self._acall_handlers(source, event, handlers)
await self._acall_handlers(source, event, handlers, state)
finally:
_replaying.reset(token)
async def _emit_with_dependencies_replaying(
self, source: Any, event: BaseEvent
self, source: Any, event: BaseEvent, state: RuntimeState | None
) -> None:
"""Dependency-aware dispatch with the replaying flag set."""
token = _replaying.set(True)
try:
await self._emit_with_dependencies(source, event)
await self._emit_with_dependencies(source, event, state)
finally:
_replaying.reset(token)
@@ -644,12 +698,13 @@ class CrewAIEventsBus:
self._ensure_executor_initialized()
self._has_pending_events = True
state = self._runtime_state
token = _replaying.set(True)
try:
if has_dependencies:
return self._track_future(
asyncio.run_coroutine_threadsafe(
self._emit_with_dependencies_replaying(source, event),
self._emit_with_dependencies_replaying(source, event, state),
self._loop,
)
)
@@ -657,7 +712,7 @@ class CrewAIEventsBus:
if sync_handlers:
ctx = contextvars.copy_context()
sync_future = self._sync_executor.submit(
ctx.run, self._call_handlers, source, event, sync_handlers
ctx.run, self._call_handlers, source, event, sync_handlers, state
)
self._track_future(sync_future)
if not async_handlers:
@@ -665,7 +720,9 @@ class CrewAIEventsBus:
return self._track_future(
asyncio.run_coroutine_threadsafe(
self._acall_handlers_replaying(source, event, async_handlers),
self._acall_handlers_replaying(
source, event, async_handlers, state
),
self._loop,
)
)
@@ -733,7 +790,9 @@ class CrewAIEventsBus:
async_handlers = self._async_handlers.get(event_type, frozenset())
if async_handlers:
await self._acall_handlers(source, event, async_handlers)
await self._acall_handlers(
source, event, async_handlers, self._runtime_state
)
def register_handler(
self,

View File

@@ -1,6 +1,6 @@
from typing import Any, Literal
from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, field_serializer
from crewai.events.base_events import BaseEvent
@@ -57,6 +57,10 @@ class MethodExecutionFailedEvent(FlowEvent):
model_config = ConfigDict(arbitrary_types_allowed=True)
@field_serializer("error")
def _serialize_error(self, error: Exception) -> str:
return str(error)
class MethodExecutionPausedEvent(FlowEvent):
"""Event emitted when a flow method is paused waiting for human feedback.

View File

@@ -1935,13 +1935,17 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
restore_from_state_id=restore_from_state_id,
)
runtime_scope = crewai_event_bus._enter_runtime_scope()
try:
asyncio.get_running_loop()
ctx = contextvars.copy_context()
with ThreadPoolExecutor(max_workers=1) as pool:
return pool.submit(ctx.run, asyncio.run, _run_flow()).result()
except RuntimeError:
return asyncio.run(_run_flow())
try:
asyncio.get_running_loop()
ctx = contextvars.copy_context()
with ThreadPoolExecutor(max_workers=1) as pool:
return pool.submit(ctx.run, asyncio.run, _run_flow()).result()
except RuntimeError:
return asyncio.run(_run_flow())
finally:
crewai_event_bus._exit_runtime_scope(runtime_scope)
async def kickoff_async(
self,
@@ -2049,6 +2053,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
if current_flow_request_id.get() is None:
request_id_token = current_flow_request_id.set(self.flow_id)
runtime_scope = crewai_event_bus._enter_runtime_scope()
try:
# Reset flow state for fresh execution unless restoring from persistence
is_restoring = (
@@ -2345,6 +2350,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
if flow_id_token is not None:
current_flow_id.reset(flow_id_token)
detach(flow_token)
crewai_event_bus._exit_runtime_scope(runtime_scope)
async def akickoff(
self,

View File

@@ -2,6 +2,7 @@
from __future__ import annotations
import threading
from typing import Any
from unittest.mock import patch
@@ -109,10 +110,79 @@ class TestCheckpointListenerOptsOut:
assert do_cp.call_count == 0
class TestFlowResumeReplaysEvents:
"""End-to-end: a resumed flow emits MethodExecution* events for completed methods."""
class TestCheckpointResumeReplaysEvents:
"""A flow resumed from a checkpoint replays MethodExecution* events for
completed methods and executes the pending ones. The checkpoint persists
the event record, which is reloaded into the per-run runtime state.
def test_resume_dispatches_completed_method_events(self, tmp_path) -> None:
``step_c`` is gated on a threading.Event so the flow is frozen with exactly
``step_a`` and ``step_b`` completed when the checkpoint is written — the
mid-run snapshot is deterministic rather than dependent on write timing.
"""
def test_resume_replays_completed_and_executes_pending(self, tmp_path) -> None:
from crewai.flow.flow import Flow, listen, start
from crewai.state.checkpoint_config import CheckpointConfig
at_step_c = threading.Event()
release = threading.Event()
captured: list[Any] = []
class ThreeStepFlow(Flow[dict]):
@start()
def step_a(self) -> str:
return "a"
@listen(step_a)
def step_b(self) -> str:
return "b"
@listen(step_b)
def step_c(self) -> str:
captured.append(crewai_event_bus.runtime_state)
at_step_c.set()
release.wait(timeout=10)
return "c"
runner = threading.Thread(target=ThreeStepFlow().kickoff)
runner.start()
try:
assert at_step_c.wait(timeout=10)
location = captured[0].checkpoint(str(tmp_path / "cp"))
finally:
release.set()
runner.join(timeout=10)
captured_started: list[str] = []
captured_finished: list[str] = []
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(MethodExecutionStartedEvent)
def _cs(_: Any, event: MethodExecutionStartedEvent) -> None:
captured_started.append(event.method_name)
@crewai_event_bus.on(MethodExecutionFinishedEvent)
def _cf(_: Any, event: MethodExecutionFinishedEvent) -> None:
captured_finished.append(event.method_name)
ThreeStepFlow().kickoff(
from_checkpoint=CheckpointConfig(restore_from=location)
)
assert captured_started == ["step_a", "step_b", "step_c"]
assert captured_finished == ["step_a", "step_b", "step_c"]
class TestPersistResumeDoesNotReplayCompletedEvents:
"""A @persist resume continues from pending methods only.
@persist stores flow state, not the event record, so completed-method
events have no persisted source to replay from. Runtime state is scoped
per run, so flow1's events are not visible to flow2.
"""
def test_persist_resume_executes_only_pending_methods(self, tmp_path) -> None:
from crewai.flow.flow import Flow, listen, start
from crewai.flow.persistence.sqlite import SQLiteFlowPersistence
@@ -132,9 +202,6 @@ class TestFlowResumeReplaysEvents:
def step_c(self) -> str:
return "c"
if crewai_event_bus.runtime_state is not None:
crewai_event_bus.runtime_state.event_record.clear()
flow1 = ThreeStepFlow(persistence=persistence)
flow1.kickoff()
flow_id = flow1.state["id"]
@@ -157,9 +224,5 @@ class TestFlowResumeReplaysEvents:
flow2.kickoff(inputs={"id": flow_id})
assert captured_started.count("step_a") == 1
assert captured_started.count("step_b") == 1
assert captured_started.count("step_c") == 1
assert captured_finished.count("step_a") == 1
assert captured_finished.count("step_b") == 1
assert captured_finished.count("step_c") == 1
assert captured_started == ["step_c"]
assert captured_finished == ["step_c"]