mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-01 07:13:00 +00:00
fix: inject trigger locally in _do_checkpoint to avoid shared state race
This commit is contained in:
@@ -7,6 +7,7 @@ avoids per-event overhead when no entity uses checkpointing.
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@@ -106,23 +107,21 @@ def _do_checkpoint(
|
|||||||
state: RuntimeState, cfg: CheckpointConfig, event: BaseEvent | None = None
|
state: RuntimeState, cfg: CheckpointConfig, event: BaseEvent | None = None
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Write a checkpoint and prune old ones if configured."""
|
"""Write a checkpoint and prune old ones if configured."""
|
||||||
|
_prepare_entities(state.root)
|
||||||
|
payload = state.model_dump(mode="json")
|
||||||
if event is not None:
|
if event is not None:
|
||||||
state._trigger = event.type
|
payload["trigger"] = event.type
|
||||||
try:
|
data = json.dumps(payload)
|
||||||
_prepare_entities(state.root)
|
location = cfg.provider.checkpoint(
|
||||||
data = state.model_dump_json()
|
data,
|
||||||
location = cfg.provider.checkpoint(
|
cfg.location,
|
||||||
data,
|
parent_id=state._parent_id,
|
||||||
cfg.location,
|
branch=state._branch,
|
||||||
parent_id=state._parent_id,
|
)
|
||||||
branch=state._branch,
|
state._chain_lineage(cfg.provider, location)
|
||||||
)
|
|
||||||
state._chain_lineage(cfg.provider, location)
|
|
||||||
|
|
||||||
if cfg.max_checkpoints is not None:
|
if cfg.max_checkpoints is not None:
|
||||||
cfg.provider.prune(cfg.location, cfg.max_checkpoints, branch=state._branch)
|
cfg.provider.prune(cfg.location, cfg.max_checkpoints, branch=state._branch)
|
||||||
finally:
|
|
||||||
state._trigger = None
|
|
||||||
|
|
||||||
|
|
||||||
def _should_checkpoint(source: Any, event: BaseEvent) -> CheckpointConfig | None:
|
def _should_checkpoint(source: Any, event: BaseEvent) -> CheckpointConfig | None:
|
||||||
|
|||||||
@@ -111,7 +111,6 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
|
|||||||
_checkpoint_id: str | None = PrivateAttr(default=None)
|
_checkpoint_id: str | None = PrivateAttr(default=None)
|
||||||
_parent_id: str | None = PrivateAttr(default=None)
|
_parent_id: str | None = PrivateAttr(default=None)
|
||||||
_branch: str = PrivateAttr(default="main")
|
_branch: str = PrivateAttr(default="main")
|
||||||
_trigger: str | None = PrivateAttr(default=None)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def event_record(self) -> EventRecord:
|
def event_record(self) -> EventRecord:
|
||||||
@@ -120,16 +119,13 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
|
|||||||
|
|
||||||
@model_serializer(mode="plain")
|
@model_serializer(mode="plain")
|
||||||
def _serialize(self) -> dict[str, Any]:
|
def _serialize(self) -> dict[str, Any]:
|
||||||
d: dict[str, Any] = {
|
return {
|
||||||
"crewai_version": get_crewai_version(),
|
"crewai_version": get_crewai_version(),
|
||||||
"parent_id": self._parent_id,
|
"parent_id": self._parent_id,
|
||||||
"branch": self._branch,
|
"branch": self._branch,
|
||||||
"entities": [e.model_dump(mode="json") for e in self.root],
|
"entities": [e.model_dump(mode="json") for e in self.root],
|
||||||
"event_record": self._event_record.model_dump(),
|
"event_record": self._event_record.model_dump(),
|
||||||
}
|
}
|
||||||
if self._trigger:
|
|
||||||
d["trigger"] = self._trigger
|
|
||||||
return d
|
|
||||||
|
|
||||||
@model_validator(mode="wrap")
|
@model_validator(mode="wrap")
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
Reference in New Issue
Block a user