mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-07-03 06:08:15 +00:00
feat: emit lifecycle events for checkpoint operations
Some checks failed
Some checks failed
This commit is contained in:
@@ -21,6 +21,7 @@ from crewai.events.depends import Depends
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.handler_graph import CircularDependencyError
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.events.types.agent_events import (
|
||||
AgentEvaluationCompletedEvent,
|
||||
@@ -33,6 +34,20 @@ if TYPE_CHECKING:
|
||||
LiteAgentExecutionErrorEvent,
|
||||
LiteAgentExecutionStartedEvent,
|
||||
)
|
||||
from crewai.events.types.checkpoint_events import (
|
||||
CheckpointBaseEvent,
|
||||
CheckpointCompletedEvent,
|
||||
CheckpointFailedEvent,
|
||||
CheckpointForkBaseEvent,
|
||||
CheckpointForkCompletedEvent,
|
||||
CheckpointForkStartedEvent,
|
||||
CheckpointPrunedEvent,
|
||||
CheckpointRestoreBaseEvent,
|
||||
CheckpointRestoreCompletedEvent,
|
||||
CheckpointRestoreFailedEvent,
|
||||
CheckpointRestoreStartedEvent,
|
||||
CheckpointStartedEvent,
|
||||
)
|
||||
from crewai.events.types.crew_events import (
|
||||
CrewKickoffCompletedEvent,
|
||||
CrewKickoffFailedEvent,
|
||||
@@ -141,6 +156,19 @@ _LAZY_EVENT_MAPPING: dict[str, str] = {
|
||||
"LiteAgentExecutionCompletedEvent": "crewai.events.types.agent_events",
|
||||
"LiteAgentExecutionErrorEvent": "crewai.events.types.agent_events",
|
||||
"LiteAgentExecutionStartedEvent": "crewai.events.types.agent_events",
|
||||
# checkpoint_events
|
||||
"CheckpointBaseEvent": "crewai.events.types.checkpoint_events",
|
||||
"CheckpointCompletedEvent": "crewai.events.types.checkpoint_events",
|
||||
"CheckpointFailedEvent": "crewai.events.types.checkpoint_events",
|
||||
"CheckpointForkBaseEvent": "crewai.events.types.checkpoint_events",
|
||||
"CheckpointForkCompletedEvent": "crewai.events.types.checkpoint_events",
|
||||
"CheckpointForkStartedEvent": "crewai.events.types.checkpoint_events",
|
||||
"CheckpointPrunedEvent": "crewai.events.types.checkpoint_events",
|
||||
"CheckpointRestoreBaseEvent": "crewai.events.types.checkpoint_events",
|
||||
"CheckpointRestoreCompletedEvent": "crewai.events.types.checkpoint_events",
|
||||
"CheckpointRestoreFailedEvent": "crewai.events.types.checkpoint_events",
|
||||
"CheckpointRestoreStartedEvent": "crewai.events.types.checkpoint_events",
|
||||
"CheckpointStartedEvent": "crewai.events.types.checkpoint_events",
|
||||
# crew_events
|
||||
"CrewKickoffCompletedEvent": "crewai.events.types.crew_events",
|
||||
"CrewKickoffFailedEvent": "crewai.events.types.crew_events",
|
||||
@@ -265,6 +293,18 @@ __all__ = [
|
||||
"AgentReasoningFailedEvent",
|
||||
"AgentReasoningStartedEvent",
|
||||
"BaseEventListener",
|
||||
"CheckpointBaseEvent",
|
||||
"CheckpointCompletedEvent",
|
||||
"CheckpointFailedEvent",
|
||||
"CheckpointForkBaseEvent",
|
||||
"CheckpointForkCompletedEvent",
|
||||
"CheckpointForkStartedEvent",
|
||||
"CheckpointPrunedEvent",
|
||||
"CheckpointRestoreBaseEvent",
|
||||
"CheckpointRestoreCompletedEvent",
|
||||
"CheckpointRestoreFailedEvent",
|
||||
"CheckpointRestoreStartedEvent",
|
||||
"CheckpointStartedEvent",
|
||||
"CircularDependencyError",
|
||||
"CrewKickoffCompletedEvent",
|
||||
"CrewKickoffFailedEvent",
|
||||
|
||||
@@ -30,6 +30,17 @@ from crewai.events.types.agent_events import (
|
||||
AgentExecutionStartedEvent,
|
||||
LiteAgentExecutionCompletedEvent,
|
||||
)
|
||||
from crewai.events.types.checkpoint_events import (
|
||||
CheckpointCompletedEvent,
|
||||
CheckpointFailedEvent,
|
||||
CheckpointForkCompletedEvent,
|
||||
CheckpointForkStartedEvent,
|
||||
CheckpointPrunedEvent,
|
||||
CheckpointRestoreCompletedEvent,
|
||||
CheckpointRestoreFailedEvent,
|
||||
CheckpointRestoreStartedEvent,
|
||||
CheckpointStartedEvent,
|
||||
)
|
||||
from crewai.events.types.crew_events import (
|
||||
CrewKickoffCompletedEvent,
|
||||
CrewKickoffFailedEvent,
|
||||
@@ -183,4 +194,13 @@ EventTypes = (
|
||||
| MCPToolExecutionCompletedEvent
|
||||
| MCPToolExecutionFailedEvent
|
||||
| MCPConfigFetchFailedEvent
|
||||
| CheckpointStartedEvent
|
||||
| CheckpointCompletedEvent
|
||||
| CheckpointFailedEvent
|
||||
| CheckpointForkStartedEvent
|
||||
| CheckpointForkCompletedEvent
|
||||
| CheckpointRestoreStartedEvent
|
||||
| CheckpointRestoreCompletedEvent
|
||||
| CheckpointRestoreFailedEvent
|
||||
| CheckpointPrunedEvent
|
||||
)
|
||||
|
||||
97
lib/crewai/src/crewai/events/types/checkpoint_events.py
Normal file
97
lib/crewai/src/crewai/events/types/checkpoint_events.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""Event family for automatic state checkpointing and forking."""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from crewai.events.base_events import BaseEvent
|
||||
|
||||
|
||||
class CheckpointBaseEvent(BaseEvent):
|
||||
"""Base event for checkpoint lifecycle operations."""
|
||||
|
||||
type: str
|
||||
location: str
|
||||
provider: str
|
||||
trigger: str | None = None
|
||||
branch: str | None = None
|
||||
parent_id: str | None = None
|
||||
|
||||
|
||||
class CheckpointStartedEvent(CheckpointBaseEvent):
|
||||
"""Event emitted immediately before a checkpoint is written."""
|
||||
|
||||
type: Literal["checkpoint_started"] = "checkpoint_started"
|
||||
|
||||
|
||||
class CheckpointCompletedEvent(CheckpointBaseEvent):
|
||||
"""Event emitted when a checkpoint has been written successfully."""
|
||||
|
||||
type: Literal["checkpoint_completed"] = "checkpoint_completed"
|
||||
checkpoint_id: str
|
||||
duration_ms: float
|
||||
|
||||
|
||||
class CheckpointFailedEvent(CheckpointBaseEvent):
|
||||
"""Event emitted when a checkpoint write fails."""
|
||||
|
||||
type: Literal["checkpoint_failed"] = "checkpoint_failed"
|
||||
error: str
|
||||
|
||||
|
||||
class CheckpointPrunedEvent(CheckpointBaseEvent):
|
||||
"""Event emitted after pruning old checkpoints from a branch."""
|
||||
|
||||
type: Literal["checkpoint_pruned"] = "checkpoint_pruned"
|
||||
removed_count: int
|
||||
max_checkpoints: int
|
||||
|
||||
|
||||
class CheckpointForkBaseEvent(BaseEvent):
|
||||
"""Base event for fork lifecycle operations on a RuntimeState."""
|
||||
|
||||
type: str
|
||||
branch: str
|
||||
parent_branch: str | None = None
|
||||
parent_checkpoint_id: str | None = None
|
||||
|
||||
|
||||
class CheckpointForkStartedEvent(CheckpointForkBaseEvent):
|
||||
"""Event emitted immediately before a fork relabels the branch."""
|
||||
|
||||
type: Literal["checkpoint_fork_started"] = "checkpoint_fork_started"
|
||||
|
||||
|
||||
class CheckpointForkCompletedEvent(CheckpointForkBaseEvent):
|
||||
"""Event emitted after a fork has established the new branch."""
|
||||
|
||||
type: Literal["checkpoint_fork_completed"] = "checkpoint_fork_completed"
|
||||
|
||||
|
||||
class CheckpointRestoreBaseEvent(BaseEvent):
|
||||
"""Base event for checkpoint restore lifecycle operations."""
|
||||
|
||||
type: str
|
||||
location: str
|
||||
provider: str | None = None
|
||||
|
||||
|
||||
class CheckpointRestoreStartedEvent(CheckpointRestoreBaseEvent):
|
||||
"""Event emitted immediately before a checkpoint restore begins."""
|
||||
|
||||
type: Literal["checkpoint_restore_started"] = "checkpoint_restore_started"
|
||||
|
||||
|
||||
class CheckpointRestoreCompletedEvent(CheckpointRestoreBaseEvent):
|
||||
"""Event emitted when a checkpoint has been restored successfully."""
|
||||
|
||||
type: Literal["checkpoint_restore_completed"] = "checkpoint_restore_completed"
|
||||
checkpoint_id: str
|
||||
branch: str | None = None
|
||||
parent_id: str | None = None
|
||||
duration_ms: float
|
||||
|
||||
|
||||
class CheckpointRestoreFailedEvent(CheckpointRestoreBaseEvent):
|
||||
"""Event emitted when a checkpoint restore fails."""
|
||||
|
||||
type: Literal["checkpoint_restore_failed"] = "checkpoint_restore_failed"
|
||||
error: str
|
||||
@@ -10,12 +10,22 @@ from __future__ import annotations
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.crew import Crew
|
||||
from crewai.events.base_events import BaseEvent
|
||||
from crewai.events.event_bus import CrewAIEventsBus, crewai_event_bus
|
||||
from crewai.events.types.checkpoint_events import (
|
||||
CheckpointBaseEvent,
|
||||
CheckpointCompletedEvent,
|
||||
CheckpointFailedEvent,
|
||||
CheckpointForkBaseEvent,
|
||||
CheckpointPrunedEvent,
|
||||
CheckpointRestoreBaseEvent,
|
||||
CheckpointStartedEvent,
|
||||
)
|
||||
from crewai.flow.flow import Flow
|
||||
from crewai.state.checkpoint_config import CheckpointConfig
|
||||
from crewai.state.runtime import RuntimeState, _prepare_entities
|
||||
@@ -107,27 +117,106 @@ def _do_checkpoint(
|
||||
state: RuntimeState, cfg: CheckpointConfig, event: BaseEvent | None = None
|
||||
) -> None:
|
||||
"""Write a checkpoint and prune old ones if configured."""
|
||||
_prepare_entities(state.root)
|
||||
payload = state.model_dump(mode="json")
|
||||
if event is not None:
|
||||
payload["trigger"] = event.type
|
||||
data = json.dumps(payload)
|
||||
location = cfg.provider.checkpoint(
|
||||
data,
|
||||
cfg.location,
|
||||
parent_id=state._parent_id,
|
||||
branch=state._branch,
|
||||
)
|
||||
state._chain_lineage(cfg.provider, location)
|
||||
provider_name: str = type(cfg.provider).__name__
|
||||
trigger: str | None = event.type if event is not None else None
|
||||
context: dict[str, Any] = {
|
||||
"task_id": event.task_id if event is not None else None,
|
||||
"task_name": event.task_name if event is not None else None,
|
||||
"agent_id": event.agent_id if event is not None else None,
|
||||
"agent_role": event.agent_role if event is not None else None,
|
||||
}
|
||||
|
||||
checkpoint_id: str = cfg.provider.extract_id(location)
|
||||
parent_id_snapshot: str | None = state._parent_id
|
||||
branch_snapshot: str = state._branch
|
||||
|
||||
crewai_event_bus.emit(
|
||||
cfg,
|
||||
CheckpointStartedEvent(
|
||||
location=cfg.location,
|
||||
provider=provider_name,
|
||||
trigger=trigger,
|
||||
branch=branch_snapshot,
|
||||
parent_id=parent_id_snapshot,
|
||||
**context,
|
||||
),
|
||||
)
|
||||
|
||||
start: float = time.perf_counter()
|
||||
try:
|
||||
_prepare_entities(state.root)
|
||||
payload = state.model_dump(mode="json")
|
||||
if event is not None:
|
||||
payload["trigger"] = event.type
|
||||
data = json.dumps(payload)
|
||||
location = cfg.provider.checkpoint(
|
||||
data,
|
||||
cfg.location,
|
||||
parent_id=parent_id_snapshot,
|
||||
branch=branch_snapshot,
|
||||
)
|
||||
state._chain_lineage(cfg.provider, location)
|
||||
checkpoint_id: str = cfg.provider.extract_id(location)
|
||||
except Exception as exc:
|
||||
crewai_event_bus.emit(
|
||||
cfg,
|
||||
CheckpointFailedEvent(
|
||||
location=cfg.location,
|
||||
provider=provider_name,
|
||||
trigger=trigger,
|
||||
branch=branch_snapshot,
|
||||
parent_id=parent_id_snapshot,
|
||||
error=str(exc),
|
||||
**context,
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
duration_ms: float = (time.perf_counter() - start) * 1000.0
|
||||
msg: str = (
|
||||
f"Checkpoint saved. Resume with: crewai checkpoint resume {checkpoint_id}"
|
||||
)
|
||||
logger.info(msg)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
cfg,
|
||||
CheckpointCompletedEvent(
|
||||
location=location,
|
||||
provider=provider_name,
|
||||
trigger=trigger,
|
||||
branch=branch_snapshot,
|
||||
parent_id=parent_id_snapshot,
|
||||
checkpoint_id=checkpoint_id,
|
||||
duration_ms=duration_ms,
|
||||
**context,
|
||||
),
|
||||
)
|
||||
|
||||
if cfg.max_checkpoints is not None:
|
||||
cfg.provider.prune(cfg.location, cfg.max_checkpoints, branch=state._branch)
|
||||
try:
|
||||
removed_count: int = cfg.provider.prune(
|
||||
cfg.location, cfg.max_checkpoints, branch=branch_snapshot
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Checkpoint prune failed for %s (branch=%s)",
|
||||
cfg.location,
|
||||
branch_snapshot,
|
||||
exc_info=True,
|
||||
)
|
||||
return
|
||||
crewai_event_bus.emit(
|
||||
cfg,
|
||||
CheckpointPrunedEvent(
|
||||
location=cfg.location,
|
||||
provider=provider_name,
|
||||
trigger=trigger,
|
||||
branch=branch_snapshot,
|
||||
parent_id=parent_id_snapshot,
|
||||
removed_count=removed_count,
|
||||
max_checkpoints=cfg.max_checkpoints,
|
||||
**context,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _should_checkpoint(source: Any, event: BaseEvent) -> CheckpointConfig | None:
|
||||
@@ -142,6 +231,11 @@ def _should_checkpoint(source: Any, event: BaseEvent) -> CheckpointConfig | None
|
||||
|
||||
def _on_any_event(source: Any, event: BaseEvent, state: Any) -> None:
|
||||
"""Sync handler registered on every event class."""
|
||||
if isinstance(
|
||||
event,
|
||||
(CheckpointBaseEvent, CheckpointForkBaseEvent, CheckpointRestoreBaseEvent),
|
||||
):
|
||||
return
|
||||
cfg = _should_checkpoint(source, event)
|
||||
if cfg is None:
|
||||
return
|
||||
|
||||
@@ -61,13 +61,16 @@ class BaseProvider(BaseModel, ABC):
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def prune(self, location: str, max_keep: int, *, branch: str = "main") -> None:
|
||||
def prune(self, location: str, max_keep: int, *, branch: str = "main") -> int:
|
||||
"""Remove old checkpoints, keeping at most *max_keep* per branch.
|
||||
|
||||
Args:
|
||||
location: The storage destination passed to ``checkpoint``.
|
||||
max_keep: Maximum number of checkpoints to retain.
|
||||
branch: Only prune checkpoints on this branch.
|
||||
|
||||
Returns:
|
||||
The number of checkpoints removed.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
@@ -95,17 +95,20 @@ class JsonProvider(BaseProvider):
|
||||
await f.write(data)
|
||||
return str(file_path)
|
||||
|
||||
def prune(self, location: str, max_keep: int, *, branch: str = "main") -> None:
|
||||
def prune(self, location: str, max_keep: int, *, branch: str = "main") -> int:
|
||||
"""Remove oldest checkpoint files beyond *max_keep* on a branch."""
|
||||
_safe_branch(location, branch)
|
||||
branch_dir = os.path.join(location, branch)
|
||||
pattern = os.path.join(branch_dir, "*.json")
|
||||
files = sorted(glob.glob(pattern), key=os.path.getmtime)
|
||||
removed = 0
|
||||
for path in files if max_keep == 0 else files[:-max_keep]:
|
||||
try:
|
||||
os.remove(path)
|
||||
removed += 1
|
||||
except OSError: # noqa: PERF203
|
||||
logger.debug("Failed to remove %s", path, exc_info=True)
|
||||
return removed
|
||||
|
||||
def extract_id(self, location: str) -> str:
|
||||
"""Extract the checkpoint ID from a file path.
|
||||
|
||||
@@ -111,11 +111,13 @@ class SqliteProvider(BaseProvider):
|
||||
await db.commit()
|
||||
return f"{location}#{checkpoint_id}"
|
||||
|
||||
def prune(self, location: str, max_keep: int, *, branch: str = "main") -> None:
|
||||
def prune(self, location: str, max_keep: int, *, branch: str = "main") -> int:
|
||||
"""Remove oldest checkpoint rows beyond *max_keep* on a branch."""
|
||||
with sqlite3.connect(location) as conn:
|
||||
conn.execute(_PRUNE, (branch, branch, max_keep))
|
||||
cursor = conn.execute(_PRUNE, (branch, branch, max_keep))
|
||||
removed: int = cursor.rowcount
|
||||
conn.commit()
|
||||
return max(removed, 0)
|
||||
|
||||
def extract_id(self, location: str) -> str:
|
||||
"""Extract the checkpoint ID from a ``db_path#id`` string."""
|
||||
|
||||
@@ -10,6 +10,7 @@ via ``RuntimeState.model_rebuild()``.
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any
|
||||
import uuid
|
||||
|
||||
@@ -23,6 +24,17 @@ from pydantic import (
|
||||
)
|
||||
|
||||
from crewai.context import capture_execution_context
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.checkpoint_events import (
|
||||
CheckpointCompletedEvent,
|
||||
CheckpointFailedEvent,
|
||||
CheckpointForkCompletedEvent,
|
||||
CheckpointForkStartedEvent,
|
||||
CheckpointRestoreCompletedEvent,
|
||||
CheckpointRestoreFailedEvent,
|
||||
CheckpointRestoreStartedEvent,
|
||||
CheckpointStartedEvent,
|
||||
)
|
||||
from crewai.state.checkpoint_config import CheckpointConfig
|
||||
from crewai.state.event_record import EventRecord
|
||||
from crewai.state.provider.core import BaseProvider
|
||||
@@ -169,14 +181,52 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
|
||||
Returns:
|
||||
A location identifier for the saved checkpoint.
|
||||
"""
|
||||
_prepare_entities(self.root)
|
||||
result = self._provider.checkpoint(
|
||||
self.model_dump_json(),
|
||||
location,
|
||||
parent_id=self._parent_id,
|
||||
branch=self._branch,
|
||||
provider_name: str = type(self._provider).__name__
|
||||
parent_id_snapshot: str | None = self._parent_id
|
||||
branch_snapshot: str = self._branch
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
CheckpointStartedEvent(
|
||||
location=location,
|
||||
provider=provider_name,
|
||||
branch=branch_snapshot,
|
||||
parent_id=parent_id_snapshot,
|
||||
),
|
||||
)
|
||||
start: float = time.perf_counter()
|
||||
try:
|
||||
_prepare_entities(self.root)
|
||||
result = self._provider.checkpoint(
|
||||
self.model_dump_json(),
|
||||
location,
|
||||
parent_id=parent_id_snapshot,
|
||||
branch=branch_snapshot,
|
||||
)
|
||||
self._chain_lineage(self._provider, result)
|
||||
except Exception as exc:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
CheckpointFailedEvent(
|
||||
location=location,
|
||||
provider=provider_name,
|
||||
branch=branch_snapshot,
|
||||
parent_id=parent_id_snapshot,
|
||||
error=str(exc),
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
CheckpointCompletedEvent(
|
||||
location=result,
|
||||
provider=provider_name,
|
||||
branch=branch_snapshot,
|
||||
parent_id=parent_id_snapshot,
|
||||
checkpoint_id=self._provider.extract_id(result),
|
||||
duration_ms=(time.perf_counter() - start) * 1000.0,
|
||||
),
|
||||
)
|
||||
self._chain_lineage(self._provider, result)
|
||||
return result
|
||||
|
||||
async def acheckpoint(self, location: str) -> str:
|
||||
@@ -189,14 +239,52 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
|
||||
Returns:
|
||||
A location identifier for the saved checkpoint.
|
||||
"""
|
||||
_prepare_entities(self.root)
|
||||
result = await self._provider.acheckpoint(
|
||||
self.model_dump_json(),
|
||||
location,
|
||||
parent_id=self._parent_id,
|
||||
branch=self._branch,
|
||||
provider_name: str = type(self._provider).__name__
|
||||
parent_id_snapshot: str | None = self._parent_id
|
||||
branch_snapshot: str = self._branch
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
CheckpointStartedEvent(
|
||||
location=location,
|
||||
provider=provider_name,
|
||||
branch=branch_snapshot,
|
||||
parent_id=parent_id_snapshot,
|
||||
),
|
||||
)
|
||||
start: float = time.perf_counter()
|
||||
try:
|
||||
_prepare_entities(self.root)
|
||||
result = await self._provider.acheckpoint(
|
||||
self.model_dump_json(),
|
||||
location,
|
||||
parent_id=parent_id_snapshot,
|
||||
branch=branch_snapshot,
|
||||
)
|
||||
self._chain_lineage(self._provider, result)
|
||||
except Exception as exc:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
CheckpointFailedEvent(
|
||||
location=location,
|
||||
provider=provider_name,
|
||||
branch=branch_snapshot,
|
||||
parent_id=parent_id_snapshot,
|
||||
error=str(exc),
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
CheckpointCompletedEvent(
|
||||
location=result,
|
||||
provider=provider_name,
|
||||
branch=branch_snapshot,
|
||||
parent_id=parent_id_snapshot,
|
||||
checkpoint_id=self._provider.extract_id(result),
|
||||
duration_ms=(time.perf_counter() - start) * 1000.0,
|
||||
),
|
||||
)
|
||||
self._chain_lineage(self._provider, result)
|
||||
return result
|
||||
|
||||
def fork(self, branch: str | None = None) -> None:
|
||||
@@ -211,11 +299,32 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
|
||||
times without collisions.
|
||||
"""
|
||||
if branch:
|
||||
self._branch = branch
|
||||
new_branch = branch
|
||||
elif self._checkpoint_id:
|
||||
self._branch = f"fork/{self._checkpoint_id}_{uuid.uuid4().hex[:6]}"
|
||||
new_branch = f"fork/{self._checkpoint_id}_{uuid.uuid4().hex[:6]}"
|
||||
else:
|
||||
self._branch = f"fork/{uuid.uuid4().hex[:8]}"
|
||||
new_branch = f"fork/{uuid.uuid4().hex[:8]}"
|
||||
|
||||
parent_branch: str | None = self._branch
|
||||
parent_checkpoint_id: str | None = self._checkpoint_id
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
CheckpointForkStartedEvent(
|
||||
branch=new_branch,
|
||||
parent_branch=parent_branch,
|
||||
parent_checkpoint_id=parent_checkpoint_id,
|
||||
),
|
||||
)
|
||||
self._branch = new_branch
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
CheckpointForkCompletedEvent(
|
||||
branch=new_branch,
|
||||
parent_branch=parent_branch,
|
||||
parent_checkpoint_id=parent_checkpoint_id,
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_checkpoint(cls, config: CheckpointConfig, **kwargs: Any) -> RuntimeState:
|
||||
@@ -233,13 +342,41 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
|
||||
if config.restore_from is None:
|
||||
raise ValueError("CheckpointConfig.restore_from must be set")
|
||||
location = str(config.restore_from)
|
||||
provider = detect_provider(location)
|
||||
raw = provider.from_checkpoint(location)
|
||||
state = cls.model_validate_json(raw, **kwargs)
|
||||
state._provider = provider
|
||||
checkpoint_id = provider.extract_id(location)
|
||||
state._checkpoint_id = checkpoint_id
|
||||
state._parent_id = checkpoint_id
|
||||
|
||||
crewai_event_bus.emit(config, CheckpointRestoreStartedEvent(location=location))
|
||||
start: float = time.perf_counter()
|
||||
provider_name: str | None = None
|
||||
try:
|
||||
provider = detect_provider(location)
|
||||
provider_name = type(provider).__name__
|
||||
raw = provider.from_checkpoint(location)
|
||||
state = cls.model_validate_json(raw, **kwargs)
|
||||
state._provider = provider
|
||||
checkpoint_id = provider.extract_id(location)
|
||||
state._checkpoint_id = checkpoint_id
|
||||
state._parent_id = checkpoint_id
|
||||
except Exception as exc:
|
||||
crewai_event_bus.emit(
|
||||
config,
|
||||
CheckpointRestoreFailedEvent(
|
||||
location=location,
|
||||
provider=provider_name,
|
||||
error=str(exc),
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
crewai_event_bus.emit(
|
||||
config,
|
||||
CheckpointRestoreCompletedEvent(
|
||||
location=location,
|
||||
provider=provider_name,
|
||||
checkpoint_id=checkpoint_id,
|
||||
branch=state._branch,
|
||||
parent_id=state._parent_id,
|
||||
duration_ms=(time.perf_counter() - start) * 1000.0,
|
||||
),
|
||||
)
|
||||
return state
|
||||
|
||||
@classmethod
|
||||
@@ -260,13 +397,41 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
|
||||
if config.restore_from is None:
|
||||
raise ValueError("CheckpointConfig.restore_from must be set")
|
||||
location = str(config.restore_from)
|
||||
provider = detect_provider(location)
|
||||
raw = await provider.afrom_checkpoint(location)
|
||||
state = cls.model_validate_json(raw, **kwargs)
|
||||
state._provider = provider
|
||||
checkpoint_id = provider.extract_id(location)
|
||||
state._checkpoint_id = checkpoint_id
|
||||
state._parent_id = checkpoint_id
|
||||
|
||||
crewai_event_bus.emit(config, CheckpointRestoreStartedEvent(location=location))
|
||||
start: float = time.perf_counter()
|
||||
provider_name: str | None = None
|
||||
try:
|
||||
provider = detect_provider(location)
|
||||
provider_name = type(provider).__name__
|
||||
raw = await provider.afrom_checkpoint(location)
|
||||
state = cls.model_validate_json(raw, **kwargs)
|
||||
state._provider = provider
|
||||
checkpoint_id = provider.extract_id(location)
|
||||
state._checkpoint_id = checkpoint_id
|
||||
state._parent_id = checkpoint_id
|
||||
except Exception as exc:
|
||||
crewai_event_bus.emit(
|
||||
config,
|
||||
CheckpointRestoreFailedEvent(
|
||||
location=location,
|
||||
provider=provider_name,
|
||||
error=str(exc),
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
crewai_event_bus.emit(
|
||||
config,
|
||||
CheckpointRestoreCompletedEvent(
|
||||
location=location,
|
||||
provider=provider_name,
|
||||
checkpoint_id=checkpoint_id,
|
||||
branch=state._branch,
|
||||
parent_id=state._parent_id,
|
||||
duration_ms=(time.perf_counter() - start) * 1000.0,
|
||||
),
|
||||
)
|
||||
return state
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user