fix: inject trigger locally in _do_checkpoint to avoid shared state race

This commit is contained in:
Greyson LaLonde
2026-04-10 19:43:26 +08:00
parent 1c9c78823d
commit 6fa9904e6b
2 changed files with 15 additions and 20 deletions

View File

@@ -7,6 +7,7 @@ avoids per-event overhead when no entity uses checkpointing.
from __future__ import annotations from __future__ import annotations
import json
import logging import logging
import threading import threading
from typing import Any from typing import Any
@@ -106,23 +107,21 @@ def _do_checkpoint(
state: RuntimeState, cfg: CheckpointConfig, event: BaseEvent | None = None state: RuntimeState, cfg: CheckpointConfig, event: BaseEvent | None = None
) -> None: ) -> None:
"""Write a checkpoint and prune old ones if configured.""" """Write a checkpoint and prune old ones if configured."""
_prepare_entities(state.root)
payload = state.model_dump(mode="json")
if event is not None: if event is not None:
state._trigger = event.type payload["trigger"] = event.type
try: data = json.dumps(payload)
_prepare_entities(state.root) location = cfg.provider.checkpoint(
data = state.model_dump_json() data,
location = cfg.provider.checkpoint( cfg.location,
data, parent_id=state._parent_id,
cfg.location, branch=state._branch,
parent_id=state._parent_id, )
branch=state._branch, state._chain_lineage(cfg.provider, location)
)
state._chain_lineage(cfg.provider, location)
if cfg.max_checkpoints is not None: if cfg.max_checkpoints is not None:
cfg.provider.prune(cfg.location, cfg.max_checkpoints, branch=state._branch) cfg.provider.prune(cfg.location, cfg.max_checkpoints, branch=state._branch)
finally:
state._trigger = None
def _should_checkpoint(source: Any, event: BaseEvent) -> CheckpointConfig | None: def _should_checkpoint(source: Any, event: BaseEvent) -> CheckpointConfig | None:

View File

@@ -111,7 +111,6 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
_checkpoint_id: str | None = PrivateAttr(default=None) _checkpoint_id: str | None = PrivateAttr(default=None)
_parent_id: str | None = PrivateAttr(default=None) _parent_id: str | None = PrivateAttr(default=None)
_branch: str = PrivateAttr(default="main") _branch: str = PrivateAttr(default="main")
_trigger: str | None = PrivateAttr(default=None)
@property @property
def event_record(self) -> EventRecord: def event_record(self) -> EventRecord:
@@ -120,16 +119,13 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
@model_serializer(mode="plain") @model_serializer(mode="plain")
def _serialize(self) -> dict[str, Any]: def _serialize(self) -> dict[str, Any]:
d: dict[str, Any] = { return {
"crewai_version": get_crewai_version(), "crewai_version": get_crewai_version(),
"parent_id": self._parent_id, "parent_id": self._parent_id,
"branch": self._branch, "branch": self._branch,
"entities": [e.model_dump(mode="json") for e in self.root], "entities": [e.model_dump(mode="json") for e in self.root],
"event_record": self._event_record.model_dump(), "event_record": self._event_record.model_dump(),
} }
if self._trigger:
d["trigger"] = self._trigger
return d
@model_validator(mode="wrap") @model_validator(mode="wrap")
@classmethod @classmethod