From 6fa9904e6bb76f424b222194988ffa473eb4d90b Mon Sep 17 00:00:00 2001 From: Greyson LaLonde Date: Fri, 10 Apr 2026 19:43:26 +0800 Subject: [PATCH] fix: inject trigger locally in _do_checkpoint to avoid shared state race --- .../src/crewai/state/checkpoint_listener.py | 29 +++++++++---------- lib/crewai/src/crewai/state/runtime.py | 6 +--- 2 files changed, 15 insertions(+), 20 deletions(-) diff --git a/lib/crewai/src/crewai/state/checkpoint_listener.py b/lib/crewai/src/crewai/state/checkpoint_listener.py index df8d45ea2..2408e88e3 100644 --- a/lib/crewai/src/crewai/state/checkpoint_listener.py +++ b/lib/crewai/src/crewai/state/checkpoint_listener.py @@ -7,6 +7,7 @@ avoids per-event overhead when no entity uses checkpointing. from __future__ import annotations +import json import logging import threading from typing import Any @@ -106,23 +107,21 @@ def _do_checkpoint( state: RuntimeState, cfg: CheckpointConfig, event: BaseEvent | None = None ) -> None: """Write a checkpoint and prune old ones if configured.""" + _prepare_entities(state.root) + payload = state.model_dump(mode="json") if event is not None: - state._trigger = event.type - try: - _prepare_entities(state.root) - data = state.model_dump_json() - location = cfg.provider.checkpoint( - data, - cfg.location, - parent_id=state._parent_id, - branch=state._branch, - ) - state._chain_lineage(cfg.provider, location) + payload["trigger"] = event.type + data = json.dumps(payload) + location = cfg.provider.checkpoint( + data, + cfg.location, + parent_id=state._parent_id, + branch=state._branch, + ) + state._chain_lineage(cfg.provider, location) - if cfg.max_checkpoints is not None: - cfg.provider.prune(cfg.location, cfg.max_checkpoints, branch=state._branch) - finally: - state._trigger = None + if cfg.max_checkpoints is not None: + cfg.provider.prune(cfg.location, cfg.max_checkpoints, branch=state._branch) def _should_checkpoint(source: Any, event: BaseEvent) -> CheckpointConfig | None: diff --git a/lib/crewai/src/crewai/state/runtime.py b/lib/crewai/src/crewai/state/runtime.py index 56790005a..9151f0cfd 100644 --- a/lib/crewai/src/crewai/state/runtime.py +++ b/lib/crewai/src/crewai/state/runtime.py @@ -111,7 +111,6 @@ class RuntimeState(RootModel): # type: ignore[type-arg] _checkpoint_id: str | None = PrivateAttr(default=None) _parent_id: str | None = PrivateAttr(default=None) _branch: str = PrivateAttr(default="main") - _trigger: str | None = PrivateAttr(default=None) @property def event_record(self) -> EventRecord: @@ -120,16 +119,13 @@ class RuntimeState(RootModel): # type: ignore[type-arg] @model_serializer(mode="plain") def _serialize(self) -> dict[str, Any]: - d: dict[str, Any] = { + return { "crewai_version": get_crewai_version(), "parent_id": self._parent_id, "branch": self._branch, "entities": [e.model_dump(mode="json") for e in self.root], "event_record": self._event_record.model_dump(), } - if self._trigger: - d["trigger"] = self._trigger - return d @model_validator(mode="wrap") @classmethod