feat: emit lifecycle events for checkpoint operations
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Vulnerability Scan / pip-audit (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled

This commit is contained in:
Greyson LaLonde
2026-04-23 18:47:50 +08:00
committed by GitHub
parent bc2fb71560
commit 55937d7523
8 changed files with 473 additions and 49 deletions

View File

@@ -21,6 +21,7 @@ from crewai.events.depends import Depends
from crewai.events.event_bus import crewai_event_bus
from crewai.events.handler_graph import CircularDependencyError
if TYPE_CHECKING:
from crewai.events.types.agent_events import (
AgentEvaluationCompletedEvent,
@@ -33,6 +34,20 @@ if TYPE_CHECKING:
LiteAgentExecutionErrorEvent,
LiteAgentExecutionStartedEvent,
)
from crewai.events.types.checkpoint_events import (
CheckpointBaseEvent,
CheckpointCompletedEvent,
CheckpointFailedEvent,
CheckpointForkBaseEvent,
CheckpointForkCompletedEvent,
CheckpointForkStartedEvent,
CheckpointPrunedEvent,
CheckpointRestoreBaseEvent,
CheckpointRestoreCompletedEvent,
CheckpointRestoreFailedEvent,
CheckpointRestoreStartedEvent,
CheckpointStartedEvent,
)
from crewai.events.types.crew_events import (
CrewKickoffCompletedEvent,
CrewKickoffFailedEvent,
@@ -141,6 +156,19 @@ _LAZY_EVENT_MAPPING: dict[str, str] = {
"LiteAgentExecutionCompletedEvent": "crewai.events.types.agent_events",
"LiteAgentExecutionErrorEvent": "crewai.events.types.agent_events",
"LiteAgentExecutionStartedEvent": "crewai.events.types.agent_events",
# checkpoint_events
"CheckpointBaseEvent": "crewai.events.types.checkpoint_events",
"CheckpointCompletedEvent": "crewai.events.types.checkpoint_events",
"CheckpointFailedEvent": "crewai.events.types.checkpoint_events",
"CheckpointForkBaseEvent": "crewai.events.types.checkpoint_events",
"CheckpointForkCompletedEvent": "crewai.events.types.checkpoint_events",
"CheckpointForkStartedEvent": "crewai.events.types.checkpoint_events",
"CheckpointPrunedEvent": "crewai.events.types.checkpoint_events",
"CheckpointRestoreBaseEvent": "crewai.events.types.checkpoint_events",
"CheckpointRestoreCompletedEvent": "crewai.events.types.checkpoint_events",
"CheckpointRestoreFailedEvent": "crewai.events.types.checkpoint_events",
"CheckpointRestoreStartedEvent": "crewai.events.types.checkpoint_events",
"CheckpointStartedEvent": "crewai.events.types.checkpoint_events",
# crew_events
"CrewKickoffCompletedEvent": "crewai.events.types.crew_events",
"CrewKickoffFailedEvent": "crewai.events.types.crew_events",
@@ -265,6 +293,18 @@ __all__ = [
"AgentReasoningFailedEvent",
"AgentReasoningStartedEvent",
"BaseEventListener",
"CheckpointBaseEvent",
"CheckpointCompletedEvent",
"CheckpointFailedEvent",
"CheckpointForkBaseEvent",
"CheckpointForkCompletedEvent",
"CheckpointForkStartedEvent",
"CheckpointPrunedEvent",
"CheckpointRestoreBaseEvent",
"CheckpointRestoreCompletedEvent",
"CheckpointRestoreFailedEvent",
"CheckpointRestoreStartedEvent",
"CheckpointStartedEvent",
"CircularDependencyError",
"CrewKickoffCompletedEvent",
"CrewKickoffFailedEvent",

View File

@@ -30,6 +30,17 @@ from crewai.events.types.agent_events import (
AgentExecutionStartedEvent,
LiteAgentExecutionCompletedEvent,
)
from crewai.events.types.checkpoint_events import (
CheckpointCompletedEvent,
CheckpointFailedEvent,
CheckpointForkCompletedEvent,
CheckpointForkStartedEvent,
CheckpointPrunedEvent,
CheckpointRestoreCompletedEvent,
CheckpointRestoreFailedEvent,
CheckpointRestoreStartedEvent,
CheckpointStartedEvent,
)
from crewai.events.types.crew_events import (
CrewKickoffCompletedEvent,
CrewKickoffFailedEvent,
@@ -183,4 +194,13 @@ EventTypes = (
| MCPToolExecutionCompletedEvent
| MCPToolExecutionFailedEvent
| MCPConfigFetchFailedEvent
| CheckpointStartedEvent
| CheckpointCompletedEvent
| CheckpointFailedEvent
| CheckpointForkStartedEvent
| CheckpointForkCompletedEvent
| CheckpointRestoreStartedEvent
| CheckpointRestoreCompletedEvent
| CheckpointRestoreFailedEvent
| CheckpointPrunedEvent
)

View File

@@ -0,0 +1,97 @@
"""Event family for automatic state checkpointing and forking."""
from typing import Literal
from crewai.events.base_events import BaseEvent
class CheckpointBaseEvent(BaseEvent):
"""Base event for checkpoint lifecycle operations."""
type: str
location: str
provider: str
trigger: str | None = None
branch: str | None = None
parent_id: str | None = None
class CheckpointStartedEvent(CheckpointBaseEvent):
"""Event emitted immediately before a checkpoint is written."""
type: Literal["checkpoint_started"] = "checkpoint_started"
class CheckpointCompletedEvent(CheckpointBaseEvent):
"""Event emitted when a checkpoint has been written successfully."""
type: Literal["checkpoint_completed"] = "checkpoint_completed"
checkpoint_id: str
duration_ms: float
class CheckpointFailedEvent(CheckpointBaseEvent):
"""Event emitted when a checkpoint write fails."""
type: Literal["checkpoint_failed"] = "checkpoint_failed"
error: str
class CheckpointPrunedEvent(CheckpointBaseEvent):
"""Event emitted after pruning old checkpoints from a branch."""
type: Literal["checkpoint_pruned"] = "checkpoint_pruned"
removed_count: int
max_checkpoints: int
class CheckpointForkBaseEvent(BaseEvent):
"""Base event for fork lifecycle operations on a RuntimeState."""
type: str
branch: str
parent_branch: str | None = None
parent_checkpoint_id: str | None = None
class CheckpointForkStartedEvent(CheckpointForkBaseEvent):
"""Event emitted immediately before a fork relabels the branch."""
type: Literal["checkpoint_fork_started"] = "checkpoint_fork_started"
class CheckpointForkCompletedEvent(CheckpointForkBaseEvent):
"""Event emitted after a fork has established the new branch."""
type: Literal["checkpoint_fork_completed"] = "checkpoint_fork_completed"
class CheckpointRestoreBaseEvent(BaseEvent):
"""Base event for checkpoint restore lifecycle operations."""
type: str
location: str
provider: str | None = None
class CheckpointRestoreStartedEvent(CheckpointRestoreBaseEvent):
"""Event emitted immediately before a checkpoint restore begins."""
type: Literal["checkpoint_restore_started"] = "checkpoint_restore_started"
class CheckpointRestoreCompletedEvent(CheckpointRestoreBaseEvent):
"""Event emitted when a checkpoint has been restored successfully."""
type: Literal["checkpoint_restore_completed"] = "checkpoint_restore_completed"
checkpoint_id: str
branch: str | None = None
parent_id: str | None = None
duration_ms: float
class CheckpointRestoreFailedEvent(CheckpointRestoreBaseEvent):
"""Event emitted when a checkpoint restore fails."""
type: Literal["checkpoint_restore_failed"] = "checkpoint_restore_failed"
error: str

View File

@@ -10,12 +10,22 @@ from __future__ import annotations
import json
import logging
import threading
import time
from typing import Any
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.crew import Crew
from crewai.events.base_events import BaseEvent
from crewai.events.event_bus import CrewAIEventsBus, crewai_event_bus
from crewai.events.types.checkpoint_events import (
CheckpointBaseEvent,
CheckpointCompletedEvent,
CheckpointFailedEvent,
CheckpointForkBaseEvent,
CheckpointPrunedEvent,
CheckpointRestoreBaseEvent,
CheckpointStartedEvent,
)
from crewai.flow.flow import Flow
from crewai.state.checkpoint_config import CheckpointConfig
from crewai.state.runtime import RuntimeState, _prepare_entities
@@ -107,27 +117,106 @@ def _do_checkpoint(
state: RuntimeState, cfg: CheckpointConfig, event: BaseEvent | None = None
) -> None:
"""Write a checkpoint and prune old ones if configured."""
_prepare_entities(state.root)
payload = state.model_dump(mode="json")
if event is not None:
payload["trigger"] = event.type
data = json.dumps(payload)
location = cfg.provider.checkpoint(
data,
cfg.location,
parent_id=state._parent_id,
branch=state._branch,
)
state._chain_lineage(cfg.provider, location)
provider_name: str = type(cfg.provider).__name__
trigger: str | None = event.type if event is not None else None
context: dict[str, Any] = {
"task_id": event.task_id if event is not None else None,
"task_name": event.task_name if event is not None else None,
"agent_id": event.agent_id if event is not None else None,
"agent_role": event.agent_role if event is not None else None,
}
checkpoint_id: str = cfg.provider.extract_id(location)
parent_id_snapshot: str | None = state._parent_id
branch_snapshot: str = state._branch
crewai_event_bus.emit(
cfg,
CheckpointStartedEvent(
location=cfg.location,
provider=provider_name,
trigger=trigger,
branch=branch_snapshot,
parent_id=parent_id_snapshot,
**context,
),
)
start: float = time.perf_counter()
try:
_prepare_entities(state.root)
payload = state.model_dump(mode="json")
if event is not None:
payload["trigger"] = event.type
data = json.dumps(payload)
location = cfg.provider.checkpoint(
data,
cfg.location,
parent_id=parent_id_snapshot,
branch=branch_snapshot,
)
state._chain_lineage(cfg.provider, location)
checkpoint_id: str = cfg.provider.extract_id(location)
except Exception as exc:
crewai_event_bus.emit(
cfg,
CheckpointFailedEvent(
location=cfg.location,
provider=provider_name,
trigger=trigger,
branch=branch_snapshot,
parent_id=parent_id_snapshot,
error=str(exc),
**context,
),
)
raise
duration_ms: float = (time.perf_counter() - start) * 1000.0
msg: str = (
f"Checkpoint saved. Resume with: crewai checkpoint resume {checkpoint_id}"
)
logger.info(msg)
crewai_event_bus.emit(
cfg,
CheckpointCompletedEvent(
location=location,
provider=provider_name,
trigger=trigger,
branch=branch_snapshot,
parent_id=parent_id_snapshot,
checkpoint_id=checkpoint_id,
duration_ms=duration_ms,
**context,
),
)
if cfg.max_checkpoints is not None:
cfg.provider.prune(cfg.location, cfg.max_checkpoints, branch=state._branch)
try:
removed_count: int = cfg.provider.prune(
cfg.location, cfg.max_checkpoints, branch=branch_snapshot
)
except Exception:
logger.warning(
"Checkpoint prune failed for %s (branch=%s)",
cfg.location,
branch_snapshot,
exc_info=True,
)
return
crewai_event_bus.emit(
cfg,
CheckpointPrunedEvent(
location=cfg.location,
provider=provider_name,
trigger=trigger,
branch=branch_snapshot,
parent_id=parent_id_snapshot,
removed_count=removed_count,
max_checkpoints=cfg.max_checkpoints,
**context,
),
)
def _should_checkpoint(source: Any, event: BaseEvent) -> CheckpointConfig | None:
@@ -142,6 +231,11 @@ def _should_checkpoint(source: Any, event: BaseEvent) -> CheckpointConfig | None
def _on_any_event(source: Any, event: BaseEvent, state: Any) -> None:
"""Sync handler registered on every event class."""
if isinstance(
event,
(CheckpointBaseEvent, CheckpointForkBaseEvent, CheckpointRestoreBaseEvent),
):
return
cfg = _should_checkpoint(source, event)
if cfg is None:
return

View File

@@ -61,13 +61,16 @@ class BaseProvider(BaseModel, ABC):
...
@abstractmethod
def prune(self, location: str, max_keep: int, *, branch: str = "main") -> None:
def prune(self, location: str, max_keep: int, *, branch: str = "main") -> int:
"""Remove old checkpoints, keeping at most *max_keep* per branch.
Args:
location: The storage destination passed to ``checkpoint``.
max_keep: Maximum number of checkpoints to retain.
branch: Only prune checkpoints on this branch.
Returns:
The number of checkpoints removed.
"""
...

View File

@@ -95,17 +95,20 @@ class JsonProvider(BaseProvider):
await f.write(data)
return str(file_path)
def prune(self, location: str, max_keep: int, *, branch: str = "main") -> None:
def prune(self, location: str, max_keep: int, *, branch: str = "main") -> int:
"""Remove oldest checkpoint files beyond *max_keep* on a branch."""
_safe_branch(location, branch)
branch_dir = os.path.join(location, branch)
pattern = os.path.join(branch_dir, "*.json")
files = sorted(glob.glob(pattern), key=os.path.getmtime)
removed = 0
for path in files if max_keep == 0 else files[:-max_keep]:
try:
os.remove(path)
removed += 1
except OSError: # noqa: PERF203
logger.debug("Failed to remove %s", path, exc_info=True)
return removed
def extract_id(self, location: str) -> str:
"""Extract the checkpoint ID from a file path.

View File

@@ -111,11 +111,13 @@ class SqliteProvider(BaseProvider):
await db.commit()
return f"{location}#{checkpoint_id}"
def prune(self, location: str, max_keep: int, *, branch: str = "main") -> None:
def prune(self, location: str, max_keep: int, *, branch: str = "main") -> int:
"""Remove oldest checkpoint rows beyond *max_keep* on a branch."""
with sqlite3.connect(location) as conn:
conn.execute(_PRUNE, (branch, branch, max_keep))
cursor = conn.execute(_PRUNE, (branch, branch, max_keep))
removed: int = cursor.rowcount
conn.commit()
return max(removed, 0)
def extract_id(self, location: str) -> str:
"""Extract the checkpoint ID from a ``db_path#id`` string."""

View File

@@ -10,6 +10,7 @@ via ``RuntimeState.model_rebuild()``.
from __future__ import annotations
import logging
import time
from typing import TYPE_CHECKING, Any
import uuid
@@ -23,6 +24,17 @@ from pydantic import (
)
from crewai.context import capture_execution_context
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.checkpoint_events import (
CheckpointCompletedEvent,
CheckpointFailedEvent,
CheckpointForkCompletedEvent,
CheckpointForkStartedEvent,
CheckpointRestoreCompletedEvent,
CheckpointRestoreFailedEvent,
CheckpointRestoreStartedEvent,
CheckpointStartedEvent,
)
from crewai.state.checkpoint_config import CheckpointConfig
from crewai.state.event_record import EventRecord
from crewai.state.provider.core import BaseProvider
@@ -169,14 +181,52 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
Returns:
A location identifier for the saved checkpoint.
"""
_prepare_entities(self.root)
result = self._provider.checkpoint(
self.model_dump_json(),
location,
parent_id=self._parent_id,
branch=self._branch,
provider_name: str = type(self._provider).__name__
parent_id_snapshot: str | None = self._parent_id
branch_snapshot: str = self._branch
crewai_event_bus.emit(
self,
CheckpointStartedEvent(
location=location,
provider=provider_name,
branch=branch_snapshot,
parent_id=parent_id_snapshot,
),
)
start: float = time.perf_counter()
try:
_prepare_entities(self.root)
result = self._provider.checkpoint(
self.model_dump_json(),
location,
parent_id=parent_id_snapshot,
branch=branch_snapshot,
)
self._chain_lineage(self._provider, result)
except Exception as exc:
crewai_event_bus.emit(
self,
CheckpointFailedEvent(
location=location,
provider=provider_name,
branch=branch_snapshot,
parent_id=parent_id_snapshot,
error=str(exc),
),
)
raise
crewai_event_bus.emit(
self,
CheckpointCompletedEvent(
location=result,
provider=provider_name,
branch=branch_snapshot,
parent_id=parent_id_snapshot,
checkpoint_id=self._provider.extract_id(result),
duration_ms=(time.perf_counter() - start) * 1000.0,
),
)
self._chain_lineage(self._provider, result)
return result
async def acheckpoint(self, location: str) -> str:
@@ -189,14 +239,52 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
Returns:
A location identifier for the saved checkpoint.
"""
_prepare_entities(self.root)
result = await self._provider.acheckpoint(
self.model_dump_json(),
location,
parent_id=self._parent_id,
branch=self._branch,
provider_name: str = type(self._provider).__name__
parent_id_snapshot: str | None = self._parent_id
branch_snapshot: str = self._branch
crewai_event_bus.emit(
self,
CheckpointStartedEvent(
location=location,
provider=provider_name,
branch=branch_snapshot,
parent_id=parent_id_snapshot,
),
)
start: float = time.perf_counter()
try:
_prepare_entities(self.root)
result = await self._provider.acheckpoint(
self.model_dump_json(),
location,
parent_id=parent_id_snapshot,
branch=branch_snapshot,
)
self._chain_lineage(self._provider, result)
except Exception as exc:
crewai_event_bus.emit(
self,
CheckpointFailedEvent(
location=location,
provider=provider_name,
branch=branch_snapshot,
parent_id=parent_id_snapshot,
error=str(exc),
),
)
raise
crewai_event_bus.emit(
self,
CheckpointCompletedEvent(
location=result,
provider=provider_name,
branch=branch_snapshot,
parent_id=parent_id_snapshot,
checkpoint_id=self._provider.extract_id(result),
duration_ms=(time.perf_counter() - start) * 1000.0,
),
)
self._chain_lineage(self._provider, result)
return result
def fork(self, branch: str | None = None) -> None:
@@ -211,11 +299,32 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
times without collisions.
"""
if branch:
self._branch = branch
new_branch = branch
elif self._checkpoint_id:
self._branch = f"fork/{self._checkpoint_id}_{uuid.uuid4().hex[:6]}"
new_branch = f"fork/{self._checkpoint_id}_{uuid.uuid4().hex[:6]}"
else:
self._branch = f"fork/{uuid.uuid4().hex[:8]}"
new_branch = f"fork/{uuid.uuid4().hex[:8]}"
parent_branch: str | None = self._branch
parent_checkpoint_id: str | None = self._checkpoint_id
crewai_event_bus.emit(
self,
CheckpointForkStartedEvent(
branch=new_branch,
parent_branch=parent_branch,
parent_checkpoint_id=parent_checkpoint_id,
),
)
self._branch = new_branch
crewai_event_bus.emit(
self,
CheckpointForkCompletedEvent(
branch=new_branch,
parent_branch=parent_branch,
parent_checkpoint_id=parent_checkpoint_id,
),
)
@classmethod
def from_checkpoint(cls, config: CheckpointConfig, **kwargs: Any) -> RuntimeState:
@@ -233,13 +342,41 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
if config.restore_from is None:
raise ValueError("CheckpointConfig.restore_from must be set")
location = str(config.restore_from)
provider = detect_provider(location)
raw = provider.from_checkpoint(location)
state = cls.model_validate_json(raw, **kwargs)
state._provider = provider
checkpoint_id = provider.extract_id(location)
state._checkpoint_id = checkpoint_id
state._parent_id = checkpoint_id
crewai_event_bus.emit(config, CheckpointRestoreStartedEvent(location=location))
start: float = time.perf_counter()
provider_name: str | None = None
try:
provider = detect_provider(location)
provider_name = type(provider).__name__
raw = provider.from_checkpoint(location)
state = cls.model_validate_json(raw, **kwargs)
state._provider = provider
checkpoint_id = provider.extract_id(location)
state._checkpoint_id = checkpoint_id
state._parent_id = checkpoint_id
except Exception as exc:
crewai_event_bus.emit(
config,
CheckpointRestoreFailedEvent(
location=location,
provider=provider_name,
error=str(exc),
),
)
raise
crewai_event_bus.emit(
config,
CheckpointRestoreCompletedEvent(
location=location,
provider=provider_name,
checkpoint_id=checkpoint_id,
branch=state._branch,
parent_id=state._parent_id,
duration_ms=(time.perf_counter() - start) * 1000.0,
),
)
return state
@classmethod
@@ -260,13 +397,41 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
if config.restore_from is None:
raise ValueError("CheckpointConfig.restore_from must be set")
location = str(config.restore_from)
provider = detect_provider(location)
raw = await provider.afrom_checkpoint(location)
state = cls.model_validate_json(raw, **kwargs)
state._provider = provider
checkpoint_id = provider.extract_id(location)
state._checkpoint_id = checkpoint_id
state._parent_id = checkpoint_id
crewai_event_bus.emit(config, CheckpointRestoreStartedEvent(location=location))
start: float = time.perf_counter()
provider_name: str | None = None
try:
provider = detect_provider(location)
provider_name = type(provider).__name__
raw = await provider.afrom_checkpoint(location)
state = cls.model_validate_json(raw, **kwargs)
state._provider = provider
checkpoint_id = provider.extract_id(location)
state._checkpoint_id = checkpoint_id
state._parent_id = checkpoint_id
except Exception as exc:
crewai_event_bus.emit(
config,
CheckpointRestoreFailedEvent(
location=location,
provider=provider_name,
error=str(exc),
),
)
raise
crewai_event_bus.emit(
config,
CheckpointRestoreCompletedEvent(
location=location,
provider=provider_name,
checkpoint_id=checkpoint_id,
branch=state._branch,
parent_id=state._parent_id,
duration_ms=(time.perf_counter() - start) * 1000.0,
),
)
return state