mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-04-11 21:42:36 +00:00
fix: inject trigger locally in _do_checkpoint to avoid shared state race
This commit is contained in:
@@ -7,6 +7,7 @@ avoids per-event overhead when no entity uses checkpointing.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
from typing import Any
|
||||
@@ -106,23 +107,21 @@ def _do_checkpoint(
|
||||
state: RuntimeState, cfg: CheckpointConfig, event: BaseEvent | None = None
|
||||
) -> None:
|
||||
"""Write a checkpoint and prune old ones if configured."""
|
||||
_prepare_entities(state.root)
|
||||
payload = state.model_dump(mode="json")
|
||||
if event is not None:
|
||||
state._trigger = event.type
|
||||
try:
|
||||
_prepare_entities(state.root)
|
||||
data = state.model_dump_json()
|
||||
location = cfg.provider.checkpoint(
|
||||
data,
|
||||
cfg.location,
|
||||
parent_id=state._parent_id,
|
||||
branch=state._branch,
|
||||
)
|
||||
state._chain_lineage(cfg.provider, location)
|
||||
payload["trigger"] = event.type
|
||||
data = json.dumps(payload)
|
||||
location = cfg.provider.checkpoint(
|
||||
data,
|
||||
cfg.location,
|
||||
parent_id=state._parent_id,
|
||||
branch=state._branch,
|
||||
)
|
||||
state._chain_lineage(cfg.provider, location)
|
||||
|
||||
if cfg.max_checkpoints is not None:
|
||||
cfg.provider.prune(cfg.location, cfg.max_checkpoints, branch=state._branch)
|
||||
finally:
|
||||
state._trigger = None
|
||||
if cfg.max_checkpoints is not None:
|
||||
cfg.provider.prune(cfg.location, cfg.max_checkpoints, branch=state._branch)
|
||||
|
||||
|
||||
def _should_checkpoint(source: Any, event: BaseEvent) -> CheckpointConfig | None:
|
||||
|
||||
@@ -111,7 +111,6 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
|
||||
_checkpoint_id: str | None = PrivateAttr(default=None)
|
||||
_parent_id: str | None = PrivateAttr(default=None)
|
||||
_branch: str = PrivateAttr(default="main")
|
||||
_trigger: str | None = PrivateAttr(default=None)
|
||||
|
||||
@property
|
||||
def event_record(self) -> EventRecord:
|
||||
@@ -120,16 +119,13 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
|
||||
|
||||
@model_serializer(mode="plain")
|
||||
def _serialize(self) -> dict[str, Any]:
|
||||
d: dict[str, Any] = {
|
||||
return {
|
||||
"crewai_version": get_crewai_version(),
|
||||
"parent_id": self._parent_id,
|
||||
"branch": self._branch,
|
||||
"entities": [e.model_dump(mode="json") for e in self.root],
|
||||
"event_record": self._event_record.model_dump(),
|
||||
}
|
||||
if self._trigger:
|
||||
d["trigger"] = self._trigger
|
||||
return d
|
||||
|
||||
@model_validator(mode="wrap")
|
||||
@classmethod
|
||||
|
||||
Reference in New Issue
Block a user