mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-01 07:13:00 +00:00
feat: add checkpoint forking with lineage tracking
This commit is contained in:
@@ -400,6 +400,34 @@ class Crew(FlowTrackable, BaseModel):
|
||||
return entity
|
||||
raise ValueError(f"No Crew found in checkpoint: {path}")
|
||||
|
||||
@classmethod
|
||||
def fork(
|
||||
cls,
|
||||
path: str,
|
||||
*,
|
||||
branch: str | None = None,
|
||||
provider: BaseProvider | None = None,
|
||||
) -> Crew:
|
||||
"""Fork a Crew from a checkpoint, creating a new execution branch.
|
||||
|
||||
Args:
|
||||
path: Path to a checkpoint file.
|
||||
branch: Branch label for the fork. Auto-generated if not provided.
|
||||
provider: Storage backend to read from. Defaults to auto-detect.
|
||||
|
||||
Returns:
|
||||
A Crew instance on the new branch. Call kickoff() to run.
|
||||
"""
|
||||
crew = cls.from_checkpoint(path, provider=provider)
|
||||
state = crewai_event_bus._runtime_state
|
||||
if state is None:
|
||||
raise RuntimeError(
|
||||
"Cannot fork: no runtime state on the event bus. "
|
||||
"Ensure from_checkpoint() succeeded before calling fork()."
|
||||
)
|
||||
state.fork(branch)
|
||||
return crew
|
||||
|
||||
def _restore_runtime(self) -> None:
|
||||
"""Re-create runtime objects after restoring from a checkpoint."""
|
||||
for agent in self.agents:
|
||||
|
||||
@@ -960,6 +960,34 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
return instance
|
||||
raise ValueError(f"No Flow found in checkpoint: {path}")
|
||||
|
||||
@classmethod
|
||||
def fork(
|
||||
cls,
|
||||
path: str,
|
||||
*,
|
||||
branch: str | None = None,
|
||||
provider: BaseProvider | None = None,
|
||||
) -> Flow: # type: ignore[type-arg]
|
||||
"""Fork a Flow from a checkpoint, creating a new execution branch.
|
||||
|
||||
Args:
|
||||
path: Path to a checkpoint file.
|
||||
branch: Branch label for the fork. Auto-generated if not provided.
|
||||
provider: Storage backend to read from. Defaults to auto-detect.
|
||||
|
||||
Returns:
|
||||
A Flow instance on the new branch. Call kickoff() to run.
|
||||
"""
|
||||
flow = cls.from_checkpoint(path, provider=provider)
|
||||
state = crewai_event_bus._runtime_state
|
||||
if state is None:
|
||||
raise RuntimeError(
|
||||
"Cannot fork: no runtime state on the event bus. "
|
||||
"Ensure from_checkpoint() succeeded before calling fork()."
|
||||
)
|
||||
state.fork(branch)
|
||||
return flow
|
||||
|
||||
checkpoint_completed_methods: set[str] | None = Field(default=None)
|
||||
checkpoint_method_outputs: list[Any] | None = Field(default=None)
|
||||
checkpoint_method_counts: dict[str, int] | None = Field(default=None)
|
||||
|
||||
@@ -106,10 +106,16 @@ def _do_checkpoint(state: RuntimeState, cfg: CheckpointConfig) -> None:
|
||||
"""Write a checkpoint and prune old ones if configured."""
|
||||
_prepare_entities(state.root)
|
||||
data = state.model_dump_json()
|
||||
cfg.provider.checkpoint(data, cfg.location)
|
||||
location = cfg.provider.checkpoint(
|
||||
data,
|
||||
cfg.location,
|
||||
parent_id=state._parent_id,
|
||||
branch=state._branch,
|
||||
)
|
||||
state._chain_lineage(cfg.provider, location)
|
||||
|
||||
if cfg.max_checkpoints is not None:
|
||||
cfg.provider.prune(cfg.location, cfg.max_checkpoints)
|
||||
cfg.provider.prune(cfg.location, cfg.max_checkpoints, branch=state._branch)
|
||||
|
||||
|
||||
def _should_checkpoint(source: Any, event: BaseEvent) -> CheckpointConfig | None:
|
||||
|
||||
@@ -17,12 +17,21 @@ class BaseProvider(BaseModel, ABC):
|
||||
provider_type: str = "base"
|
||||
|
||||
@abstractmethod
|
||||
def checkpoint(self, data: str, location: str) -> str:
|
||||
def checkpoint(
|
||||
self,
|
||||
data: str,
|
||||
location: str,
|
||||
*,
|
||||
parent_id: str | None = None,
|
||||
branch: str = "main",
|
||||
) -> str:
|
||||
"""Persist a snapshot synchronously.
|
||||
|
||||
Args:
|
||||
data: The serialized string to persist.
|
||||
location: Storage destination (directory, file path, URI, etc.).
|
||||
parent_id: ID of the parent checkpoint for lineage tracking.
|
||||
branch: Branch label for this checkpoint.
|
||||
|
||||
Returns:
|
||||
A location identifier for the saved checkpoint.
|
||||
@@ -30,12 +39,21 @@ class BaseProvider(BaseModel, ABC):
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def acheckpoint(self, data: str, location: str) -> str:
|
||||
async def acheckpoint(
|
||||
self,
|
||||
data: str,
|
||||
location: str,
|
||||
*,
|
||||
parent_id: str | None = None,
|
||||
branch: str = "main",
|
||||
) -> str:
|
||||
"""Persist a snapshot asynchronously.
|
||||
|
||||
Args:
|
||||
data: The serialized string to persist.
|
||||
location: Storage destination (directory, file path, URI, etc.).
|
||||
parent_id: ID of the parent checkpoint for lineage tracking.
|
||||
branch: Branch label for this checkpoint.
|
||||
|
||||
Returns:
|
||||
A location identifier for the saved checkpoint.
|
||||
@@ -43,12 +61,25 @@ class BaseProvider(BaseModel, ABC):
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def prune(self, location: str, max_keep: int) -> None:
|
||||
"""Remove old checkpoints, keeping at most *max_keep*.
|
||||
def prune(self, location: str, max_keep: int, *, branch: str = "main") -> None:
|
||||
"""Remove old checkpoints, keeping at most *max_keep* per branch.
|
||||
|
||||
Args:
|
||||
location: The storage destination passed to ``checkpoint``.
|
||||
max_keep: Maximum number of checkpoints to retain.
|
||||
branch: Only prune checkpoints on this branch.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def extract_id(self, location: str) -> str:
|
||||
"""Extract the checkpoint ID from a location string.
|
||||
|
||||
Args:
|
||||
location: The identifier returned by a previous ``checkpoint`` call.
|
||||
|
||||
Returns:
|
||||
The checkpoint ID.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
@@ -19,48 +19,87 @@ from crewai.state.provider.core import BaseProvider
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _safe_branch(base: str, branch: str) -> None:
|
||||
"""Validate that a branch name doesn't escape the base directory.
|
||||
|
||||
Raises:
|
||||
ValueError: If the branch resolves outside the base directory.
|
||||
"""
|
||||
base_resolved = str(Path(base).resolve())
|
||||
target_resolved = str((Path(base) / branch).resolve())
|
||||
if (
|
||||
not target_resolved.startswith(base_resolved + os.sep)
|
||||
and target_resolved != base_resolved
|
||||
):
|
||||
raise ValueError(f"Branch name escapes checkpoint directory: {branch!r}")
|
||||
|
||||
|
||||
class JsonProvider(BaseProvider):
|
||||
"""Persists runtime state checkpoints as JSON files on the local filesystem."""
|
||||
|
||||
provider_type: Literal["json"] = "json"
|
||||
|
||||
def checkpoint(self, data: str, location: str) -> str:
|
||||
def checkpoint(
|
||||
self,
|
||||
data: str,
|
||||
location: str,
|
||||
*,
|
||||
parent_id: str | None = None,
|
||||
branch: str = "main",
|
||||
) -> str:
|
||||
"""Write a JSON checkpoint file.
|
||||
|
||||
Args:
|
||||
data: The serialized JSON string to persist.
|
||||
location: Directory where the checkpoint will be saved.
|
||||
location: Base directory where checkpoints are saved.
|
||||
parent_id: ID of the parent checkpoint for lineage tracking.
|
||||
Encoded in the filename for queryable lineage without
|
||||
parsing the blob.
|
||||
branch: Branch label. Files are stored under ``location/branch/``.
|
||||
|
||||
Returns:
|
||||
The path to the written checkpoint file.
|
||||
"""
|
||||
file_path = _build_path(location)
|
||||
file_path = _build_path(location, branch, parent_id)
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(file_path, "w") as f:
|
||||
f.write(data)
|
||||
return str(file_path)
|
||||
|
||||
async def acheckpoint(self, data: str, location: str) -> str:
|
||||
async def acheckpoint(
|
||||
self,
|
||||
data: str,
|
||||
location: str,
|
||||
*,
|
||||
parent_id: str | None = None,
|
||||
branch: str = "main",
|
||||
) -> str:
|
||||
"""Write a JSON checkpoint file asynchronously.
|
||||
|
||||
Args:
|
||||
data: The serialized JSON string to persist.
|
||||
location: Directory where the checkpoint will be saved.
|
||||
location: Base directory where checkpoints are saved.
|
||||
parent_id: ID of the parent checkpoint for lineage tracking.
|
||||
Encoded in the filename for queryable lineage without
|
||||
parsing the blob.
|
||||
branch: Branch label. Files are stored under ``location/branch/``.
|
||||
|
||||
Returns:
|
||||
The path to the written checkpoint file.
|
||||
"""
|
||||
file_path = _build_path(location)
|
||||
file_path = _build_path(location, branch, parent_id)
|
||||
await aiofiles.os.makedirs(str(file_path.parent), exist_ok=True)
|
||||
|
||||
async with aiofiles.open(file_path, "w") as f:
|
||||
await f.write(data)
|
||||
return str(file_path)
|
||||
|
||||
def prune(self, location: str, max_keep: int) -> None:
|
||||
"""Remove oldest checkpoint files beyond *max_keep*."""
|
||||
pattern = os.path.join(location, "*.json")
|
||||
def prune(self, location: str, max_keep: int, *, branch: str = "main") -> None:
|
||||
"""Remove oldest checkpoint files beyond *max_keep* on a branch."""
|
||||
_safe_branch(location, branch)
|
||||
branch_dir = os.path.join(location, branch)
|
||||
pattern = os.path.join(branch_dir, "*.json")
|
||||
files = sorted(glob.glob(pattern), key=os.path.getmtime)
|
||||
for path in files if max_keep == 0 else files[:-max_keep]:
|
||||
try:
|
||||
@@ -68,6 +107,16 @@ class JsonProvider(BaseProvider):
|
||||
except OSError: # noqa: PERF203
|
||||
logger.debug("Failed to remove %s", path, exc_info=True)
|
||||
|
||||
def extract_id(self, location: str) -> str:
|
||||
"""Extract the checkpoint ID from a file path.
|
||||
|
||||
The filename format is ``{ts}_{uuid8}_p-{parent}.json``.
|
||||
The checkpoint ID is the ``{ts}_{uuid8}`` prefix.
|
||||
"""
|
||||
stem = Path(location).stem
|
||||
idx = stem.find("_p-")
|
||||
return stem[:idx] if idx != -1 else stem
|
||||
|
||||
def from_checkpoint(self, location: str) -> str:
|
||||
"""Read a JSON checkpoint file.
|
||||
|
||||
@@ -92,15 +141,24 @@ class JsonProvider(BaseProvider):
|
||||
return await f.read()
|
||||
|
||||
|
||||
def _build_path(directory: str) -> Path:
|
||||
"""Build a timestamped checkpoint file path.
|
||||
def _build_path(
|
||||
directory: str, branch: str = "main", parent_id: str | None = None
|
||||
) -> Path:
|
||||
"""Build a timestamped checkpoint file path under a branch subdirectory.
|
||||
|
||||
Filename format: ``{ts}_{uuid8}_p-{parent_id}.json``
|
||||
|
||||
Args:
|
||||
directory: Parent directory for the checkpoint file.
|
||||
directory: Base directory for checkpoints.
|
||||
branch: Branch label used as a subdirectory name.
|
||||
parent_id: Parent checkpoint ID to encode in the filename.
|
||||
|
||||
Returns:
|
||||
The target file path.
|
||||
"""
|
||||
_safe_branch(directory, branch)
|
||||
ts = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%S")
|
||||
filename = f"{ts}_{uuid.uuid4().hex[:8]}.json"
|
||||
return Path(directory) / filename
|
||||
short_uuid = uuid.uuid4().hex[:8]
|
||||
parent_suffix = parent_id or "none"
|
||||
filename = f"{ts}_{short_uuid}_p-{parent_suffix}.json"
|
||||
return Path(directory) / branch / filename
|
||||
|
||||
@@ -17,15 +17,20 @@ _CREATE_TABLE = """
|
||||
CREATE TABLE IF NOT EXISTS checkpoints (
|
||||
id TEXT PRIMARY KEY,
|
||||
created_at TEXT NOT NULL,
|
||||
parent_id TEXT,
|
||||
branch TEXT NOT NULL DEFAULT 'main',
|
||||
data JSONB NOT NULL
|
||||
)
|
||||
"""
|
||||
|
||||
_INSERT = "INSERT INTO checkpoints (id, created_at, data) VALUES (?, ?, jsonb(?))"
|
||||
_INSERT = (
|
||||
"INSERT INTO checkpoints (id, created_at, parent_id, branch, data) "
|
||||
"VALUES (?, ?, ?, ?, jsonb(?))"
|
||||
)
|
||||
_SELECT = "SELECT json(data) FROM checkpoints WHERE id = ?"
|
||||
_PRUNE = """
|
||||
DELETE FROM checkpoints WHERE rowid NOT IN (
|
||||
SELECT rowid FROM checkpoints ORDER BY rowid DESC LIMIT ?
|
||||
DELETE FROM checkpoints WHERE branch = ? AND rowid NOT IN (
|
||||
SELECT rowid FROM checkpoints WHERE branch = ? ORDER BY rowid DESC LIMIT ?
|
||||
)
|
||||
"""
|
||||
|
||||
@@ -50,12 +55,21 @@ class SqliteProvider(BaseProvider):
|
||||
|
||||
provider_type: Literal["sqlite"] = "sqlite"
|
||||
|
||||
def checkpoint(self, data: str, location: str) -> str:
|
||||
def checkpoint(
|
||||
self,
|
||||
data: str,
|
||||
location: str,
|
||||
*,
|
||||
parent_id: str | None = None,
|
||||
branch: str = "main",
|
||||
) -> str:
|
||||
"""Write a checkpoint to the SQLite database.
|
||||
|
||||
Args:
|
||||
data: The serialized JSON string to persist.
|
||||
location: Path to the SQLite database file.
|
||||
parent_id: ID of the parent checkpoint for lineage tracking.
|
||||
branch: Branch label for this checkpoint.
|
||||
|
||||
Returns:
|
||||
A location string in the format ``"db_path#checkpoint_id"``.
|
||||
@@ -65,16 +79,25 @@ class SqliteProvider(BaseProvider):
|
||||
with sqlite3.connect(location) as conn:
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
conn.execute(_CREATE_TABLE)
|
||||
conn.execute(_INSERT, (checkpoint_id, ts, data))
|
||||
conn.execute(_INSERT, (checkpoint_id, ts, parent_id, branch, data))
|
||||
conn.commit()
|
||||
return f"{location}#{checkpoint_id}"
|
||||
|
||||
async def acheckpoint(self, data: str, location: str) -> str:
|
||||
async def acheckpoint(
|
||||
self,
|
||||
data: str,
|
||||
location: str,
|
||||
*,
|
||||
parent_id: str | None = None,
|
||||
branch: str = "main",
|
||||
) -> str:
|
||||
"""Write a checkpoint to the SQLite database asynchronously.
|
||||
|
||||
Args:
|
||||
data: The serialized JSON string to persist.
|
||||
location: Path to the SQLite database file.
|
||||
parent_id: ID of the parent checkpoint for lineage tracking.
|
||||
branch: Branch label for this checkpoint.
|
||||
|
||||
Returns:
|
||||
A location string in the format ``"db_path#checkpoint_id"``.
|
||||
@@ -84,16 +107,20 @@ class SqliteProvider(BaseProvider):
|
||||
async with aiosqlite.connect(location) as db:
|
||||
await db.execute("PRAGMA journal_mode=WAL")
|
||||
await db.execute(_CREATE_TABLE)
|
||||
await db.execute(_INSERT, (checkpoint_id, ts, data))
|
||||
await db.execute(_INSERT, (checkpoint_id, ts, parent_id, branch, data))
|
||||
await db.commit()
|
||||
return f"{location}#{checkpoint_id}"
|
||||
|
||||
def prune(self, location: str, max_keep: int) -> None:
|
||||
"""Remove oldest checkpoint rows beyond *max_keep*."""
|
||||
def prune(self, location: str, max_keep: int, *, branch: str = "main") -> None:
|
||||
"""Remove oldest checkpoint rows beyond *max_keep* on a branch."""
|
||||
with sqlite3.connect(location) as conn:
|
||||
conn.execute(_PRUNE, (max_keep,))
|
||||
conn.execute(_PRUNE, (branch, branch, max_keep))
|
||||
conn.commit()
|
||||
|
||||
def extract_id(self, location: str) -> str:
|
||||
"""Extract the checkpoint ID from a ``db_path#id`` string."""
|
||||
return location.rsplit("#", 1)[1]
|
||||
|
||||
def from_checkpoint(self, location: str) -> str:
|
||||
"""Read a checkpoint from the SQLite database.
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ via ``RuntimeState.model_rebuild()``.
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
import uuid
|
||||
|
||||
from pydantic import (
|
||||
ModelWrapValidatorHandler,
|
||||
@@ -64,6 +65,9 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
|
||||
root: list[Entity]
|
||||
_provider: BaseProvider = PrivateAttr(default_factory=JsonProvider)
|
||||
_event_record: EventRecord = PrivateAttr(default_factory=EventRecord)
|
||||
_checkpoint_id: str | None = PrivateAttr(default=None)
|
||||
_parent_id: str | None = PrivateAttr(default=None)
|
||||
_branch: str = PrivateAttr(default="main")
|
||||
|
||||
@property
|
||||
def event_record(self) -> EventRecord:
|
||||
@@ -73,6 +77,8 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
|
||||
@model_serializer(mode="plain")
|
||||
def _serialize(self) -> dict[str, Any]:
|
||||
return {
|
||||
"parent_id": self._parent_id,
|
||||
"branch": self._branch,
|
||||
"entities": [e.model_dump(mode="json") for e in self.root],
|
||||
"event_record": self._event_record.model_dump(),
|
||||
}
|
||||
@@ -87,9 +93,24 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
|
||||
state = handler(data["entities"])
|
||||
if record_data:
|
||||
state._event_record = EventRecord.model_validate(record_data)
|
||||
state._parent_id = data.get("parent_id")
|
||||
state._branch = data.get("branch", "main")
|
||||
return state
|
||||
return handler(data)
|
||||
|
||||
def _chain_lineage(self, provider: BaseProvider, location: str) -> None:
|
||||
"""Update lineage fields after a successful checkpoint write.
|
||||
|
||||
Sets ``_checkpoint_id`` and ``_parent_id`` so the next write
|
||||
records the correct parent in the lineage chain.
|
||||
|
||||
Args:
|
||||
provider: The provider that performed the write.
|
||||
location: The location string returned by the provider.
|
||||
"""
|
||||
self._checkpoint_id = provider.extract_id(location)
|
||||
self._parent_id = self._checkpoint_id
|
||||
|
||||
def checkpoint(self, location: str) -> str:
|
||||
"""Write a checkpoint.
|
||||
|
||||
@@ -101,7 +122,14 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
|
||||
A location identifier for the saved checkpoint.
|
||||
"""
|
||||
_prepare_entities(self.root)
|
||||
return self._provider.checkpoint(self.model_dump_json(), location)
|
||||
result = self._provider.checkpoint(
|
||||
self.model_dump_json(),
|
||||
location,
|
||||
parent_id=self._parent_id,
|
||||
branch=self._branch,
|
||||
)
|
||||
self._chain_lineage(self._provider, result)
|
||||
return result
|
||||
|
||||
async def acheckpoint(self, location: str) -> str:
|
||||
"""Async version of :meth:`checkpoint`.
|
||||
@@ -114,7 +142,29 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
|
||||
A location identifier for the saved checkpoint.
|
||||
"""
|
||||
_prepare_entities(self.root)
|
||||
return await self._provider.acheckpoint(self.model_dump_json(), location)
|
||||
result = await self._provider.acheckpoint(
|
||||
self.model_dump_json(),
|
||||
location,
|
||||
parent_id=self._parent_id,
|
||||
branch=self._branch,
|
||||
)
|
||||
self._chain_lineage(self._provider, result)
|
||||
return result
|
||||
|
||||
def fork(self, branch: str | None = None) -> None:
|
||||
"""Mark this state as a fork for subsequent checkpoints.
|
||||
|
||||
Args:
|
||||
branch: Branch label. Auto-generated from the current checkpoint
|
||||
ID if not provided. Always unique — safe to call multiple
|
||||
times without collisions.
|
||||
"""
|
||||
if branch:
|
||||
self._branch = branch
|
||||
elif self._checkpoint_id:
|
||||
self._branch = f"fork/{self._checkpoint_id}"
|
||||
else:
|
||||
self._branch = f"fork/{uuid.uuid4().hex[:8]}"
|
||||
|
||||
@classmethod
|
||||
def from_checkpoint(
|
||||
@@ -131,7 +181,11 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
|
||||
A restored RuntimeState.
|
||||
"""
|
||||
raw = provider.from_checkpoint(location)
|
||||
return cls.model_validate_json(raw, **kwargs)
|
||||
state = cls.model_validate_json(raw, **kwargs)
|
||||
checkpoint_id = provider.extract_id(location)
|
||||
state._checkpoint_id = checkpoint_id
|
||||
state._parent_id = checkpoint_id
|
||||
return state
|
||||
|
||||
@classmethod
|
||||
async def afrom_checkpoint(
|
||||
@@ -148,7 +202,11 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
|
||||
A restored RuntimeState.
|
||||
"""
|
||||
raw = await provider.afrom_checkpoint(location)
|
||||
return cls.model_validate_json(raw, **kwargs)
|
||||
state = cls.model_validate_json(raw, **kwargs)
|
||||
checkpoint_id = provider.extract_id(location)
|
||||
state._checkpoint_id = checkpoint_id
|
||||
state._parent_id = checkpoint_id
|
||||
return state
|
||||
|
||||
|
||||
def _prepare_entities(root: list[Entity]) -> None:
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
"""Tests for CheckpointConfig, checkpoint listener, and pruning."""
|
||||
"""Tests for CheckpointConfig, checkpoint listener, pruning, and forking."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import sqlite3
|
||||
import tempfile
|
||||
import time
|
||||
from typing import Any
|
||||
@@ -21,6 +23,8 @@ from crewai.state.checkpoint_listener import (
|
||||
_SENTINEL,
|
||||
)
|
||||
from crewai.state.provider.json_provider import JsonProvider
|
||||
from crewai.state.provider.sqlite_provider import SqliteProvider
|
||||
from crewai.state.runtime import RuntimeState
|
||||
from crewai.task import Task
|
||||
|
||||
|
||||
@@ -116,35 +120,41 @@ class TestFindCheckpoint:
|
||||
class TestPrune:
|
||||
def test_prune_keeps_newest(self) -> None:
|
||||
with tempfile.TemporaryDirectory() as d:
|
||||
branch_dir = os.path.join(d, "main")
|
||||
os.makedirs(branch_dir)
|
||||
for i in range(5):
|
||||
path = os.path.join(d, f"cp_{i}.json")
|
||||
path = os.path.join(branch_dir, f"cp_{i}.json")
|
||||
with open(path, "w") as f:
|
||||
f.write("{}")
|
||||
# Ensure distinct mtime
|
||||
time.sleep(0.01)
|
||||
|
||||
JsonProvider().prune(d, max_keep=2)
|
||||
remaining = os.listdir(d)
|
||||
JsonProvider().prune(d, max_keep=2, branch="main")
|
||||
remaining = os.listdir(branch_dir)
|
||||
assert len(remaining) == 2
|
||||
assert "cp_3.json" in remaining
|
||||
assert "cp_4.json" in remaining
|
||||
|
||||
def test_prune_zero_removes_all(self) -> None:
|
||||
with tempfile.TemporaryDirectory() as d:
|
||||
branch_dir = os.path.join(d, "main")
|
||||
os.makedirs(branch_dir)
|
||||
for i in range(3):
|
||||
with open(os.path.join(d, f"cp_{i}.json"), "w") as f:
|
||||
with open(os.path.join(branch_dir, f"cp_{i}.json"), "w") as f:
|
||||
f.write("{}")
|
||||
|
||||
JsonProvider().prune(d, max_keep=0)
|
||||
assert os.listdir(d) == []
|
||||
JsonProvider().prune(d, max_keep=0, branch="main")
|
||||
assert os.listdir(branch_dir) == []
|
||||
|
||||
def test_prune_more_than_existing(self) -> None:
|
||||
with tempfile.TemporaryDirectory() as d:
|
||||
with open(os.path.join(d, "cp.json"), "w") as f:
|
||||
branch_dir = os.path.join(d, "main")
|
||||
os.makedirs(branch_dir)
|
||||
with open(os.path.join(branch_dir, "cp.json"), "w") as f:
|
||||
f.write("{}")
|
||||
|
||||
JsonProvider().prune(d, max_keep=10)
|
||||
assert len(os.listdir(d)) == 1
|
||||
JsonProvider().prune(d, max_keep=10, branch="main")
|
||||
assert len(os.listdir(branch_dir)) == 1
|
||||
|
||||
|
||||
# ---------- CheckpointConfig ----------
|
||||
@@ -167,3 +177,273 @@ class TestCheckpointConfig:
|
||||
on_events=["task_completed", "crew_kickoff_completed"]
|
||||
)
|
||||
assert cfg.trigger_events == {"task_completed", "crew_kickoff_completed"}
|
||||
|
||||
|
||||
# ---------- RuntimeState lineage ----------
|
||||
|
||||
|
||||
class TestRuntimeStateLineage:
|
||||
def _make_state(self) -> RuntimeState:
|
||||
from crewai import Agent, Crew
|
||||
|
||||
agent = Agent(role="r", goal="g", backstory="b", llm="gpt-4o-mini")
|
||||
crew = Crew(agents=[agent], tasks=[], verbose=False)
|
||||
return RuntimeState(root=[crew])
|
||||
|
||||
def test_default_lineage_fields(self) -> None:
|
||||
state = self._make_state()
|
||||
assert state._checkpoint_id is None
|
||||
assert state._parent_id is None
|
||||
assert state._branch == "main"
|
||||
|
||||
def test_serialize_includes_lineage(self) -> None:
|
||||
state = self._make_state()
|
||||
state._parent_id = "parent456"
|
||||
state._branch = "experiment"
|
||||
dumped = json.loads(state.model_dump_json())
|
||||
assert dumped["parent_id"] == "parent456"
|
||||
assert dumped["branch"] == "experiment"
|
||||
assert "checkpoint_id" not in dumped
|
||||
|
||||
def test_deserialize_restores_lineage(self) -> None:
|
||||
state = self._make_state()
|
||||
state._parent_id = "parent456"
|
||||
state._branch = "experiment"
|
||||
raw = state.model_dump_json()
|
||||
restored = RuntimeState.model_validate_json(
|
||||
raw, context={"from_checkpoint": True}
|
||||
)
|
||||
assert restored._parent_id == "parent456"
|
||||
assert restored._branch == "experiment"
|
||||
|
||||
def test_deserialize_defaults_missing_lineage(self) -> None:
|
||||
state = self._make_state()
|
||||
raw = state.model_dump_json()
|
||||
data = json.loads(raw)
|
||||
data.pop("parent_id", None)
|
||||
data.pop("branch", None)
|
||||
restored = RuntimeState.model_validate_json(
|
||||
json.dumps(data), context={"from_checkpoint": True}
|
||||
)
|
||||
assert restored._parent_id is None
|
||||
assert restored._branch == "main"
|
||||
|
||||
def test_from_checkpoint_sets_checkpoint_id(self) -> None:
|
||||
"""from_checkpoint sets _checkpoint_id from the location, not the blob."""
|
||||
state = self._make_state()
|
||||
state._provider = JsonProvider()
|
||||
with tempfile.TemporaryDirectory() as d:
|
||||
loc = state.checkpoint(d)
|
||||
written_id = state._checkpoint_id
|
||||
|
||||
provider = JsonProvider()
|
||||
restored = RuntimeState.from_checkpoint(
|
||||
loc, provider, context={"from_checkpoint": True}
|
||||
)
|
||||
assert restored._checkpoint_id == written_id
|
||||
assert restored._parent_id == written_id
|
||||
|
||||
def test_fork_sets_branch(self) -> None:
|
||||
state = self._make_state()
|
||||
state._checkpoint_id = "abc12345"
|
||||
state._parent_id = "abc12345"
|
||||
state.fork("my-experiment")
|
||||
assert state._branch == "my-experiment"
|
||||
assert state._parent_id == "abc12345"
|
||||
|
||||
def test_fork_auto_branch(self) -> None:
|
||||
state = self._make_state()
|
||||
state._checkpoint_id = "20260409T120000_abc12345"
|
||||
state.fork()
|
||||
assert state._branch == "fork/20260409T120000_abc12345"
|
||||
|
||||
def test_fork_no_checkpoint_id_unique(self) -> None:
|
||||
state = self._make_state()
|
||||
state.fork()
|
||||
assert state._branch.startswith("fork/")
|
||||
assert len(state._branch) == len("fork/") + 8
|
||||
# Two forks without checkpoint_id produce different branches
|
||||
first = state._branch
|
||||
state.fork()
|
||||
assert state._branch != first
|
||||
|
||||
|
||||
# ---------- JsonProvider forking ----------
|
||||
|
||||
|
||||
class TestJsonProviderFork:
|
||||
def test_checkpoint_writes_to_branch_subdir(self) -> None:
|
||||
provider = JsonProvider()
|
||||
with tempfile.TemporaryDirectory() as d:
|
||||
path = provider.checkpoint("{}", d, branch="main")
|
||||
assert "/main/" in path
|
||||
assert path.endswith(".json")
|
||||
assert os.path.isfile(path)
|
||||
|
||||
def test_checkpoint_fork_branch_subdir(self) -> None:
|
||||
provider = JsonProvider()
|
||||
with tempfile.TemporaryDirectory() as d:
|
||||
path = provider.checkpoint("{}", d, branch="fork/exp1")
|
||||
assert "/fork/exp1/" in path
|
||||
assert os.path.isfile(path)
|
||||
|
||||
def test_prune_branch_aware(self) -> None:
|
||||
provider = JsonProvider()
|
||||
with tempfile.TemporaryDirectory() as d:
|
||||
# Write 3 checkpoints on main, 2 on fork
|
||||
for _ in range(3):
|
||||
provider.checkpoint("{}", d, branch="main")
|
||||
time.sleep(0.01)
|
||||
for _ in range(2):
|
||||
provider.checkpoint("{}", d, branch="fork/a")
|
||||
time.sleep(0.01)
|
||||
|
||||
# Prune main to 1
|
||||
provider.prune(d, max_keep=1, branch="main")
|
||||
|
||||
main_dir = os.path.join(d, "main")
|
||||
fork_dir = os.path.join(d, "fork", "a")
|
||||
assert len(os.listdir(main_dir)) == 1
|
||||
assert len(os.listdir(fork_dir)) == 2 # untouched
|
||||
|
||||
def test_extract_id(self) -> None:
|
||||
provider = JsonProvider()
|
||||
assert provider.extract_id("/dir/main/20260409T120000_abc12345_p-none.json") == "20260409T120000_abc12345"
|
||||
assert provider.extract_id("/dir/main/20260409T120000_abc12345_p-20260409T115900_def67890.json") == "20260409T120000_abc12345"
|
||||
|
||||
def test_branch_traversal_rejected(self) -> None:
|
||||
provider = JsonProvider()
|
||||
with tempfile.TemporaryDirectory() as d:
|
||||
with pytest.raises(ValueError, match="escapes checkpoint directory"):
|
||||
provider.checkpoint("{}", d, branch="../../etc")
|
||||
with pytest.raises(ValueError, match="escapes checkpoint directory"):
|
||||
provider.prune(d, max_keep=1, branch="../../etc")
|
||||
|
||||
def test_filename_encodes_parent_id(self) -> None:
|
||||
provider = JsonProvider()
|
||||
with tempfile.TemporaryDirectory() as d:
|
||||
# First checkpoint — no parent
|
||||
path1 = provider.checkpoint("{}", d, branch="main")
|
||||
assert "_p-none.json" in path1
|
||||
|
||||
# Second checkpoint — with parent
|
||||
id1 = provider.extract_id(path1)
|
||||
path2 = provider.checkpoint("{}", d, parent_id=id1, branch="main")
|
||||
assert f"_p-{id1}.json" in path2
|
||||
|
||||
def test_checkpoint_chaining(self) -> None:
|
||||
"""RuntimeState.checkpoint() chains parent_id after each write."""
|
||||
state = self._make_state()
|
||||
state._provider = JsonProvider()
|
||||
with tempfile.TemporaryDirectory() as d:
|
||||
state.checkpoint(d)
|
||||
id1 = state._checkpoint_id
|
||||
assert id1 is not None
|
||||
assert state._parent_id == id1
|
||||
|
||||
loc2 = state.checkpoint(d)
|
||||
id2 = state._checkpoint_id
|
||||
assert id2 is not None
|
||||
assert id2 != id1
|
||||
assert state._parent_id == id2
|
||||
|
||||
# Verify the second checkpoint blob has parent_id == id1
|
||||
with open(loc2) as f:
|
||||
data2 = json.loads(f.read())
|
||||
assert data2["parent_id"] == id1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_acheckpoint_chaining(self) -> None:
|
||||
"""Async checkpoint path chains lineage identically to sync."""
|
||||
state = self._make_state()
|
||||
state._provider = JsonProvider()
|
||||
with tempfile.TemporaryDirectory() as d:
|
||||
await state.acheckpoint(d)
|
||||
id1 = state._checkpoint_id
|
||||
assert id1 is not None
|
||||
|
||||
loc2 = await state.acheckpoint(d)
|
||||
id2 = state._checkpoint_id
|
||||
assert id2 != id1
|
||||
assert state._parent_id == id2
|
||||
|
||||
with open(loc2) as f:
|
||||
data2 = json.loads(f.read())
|
||||
assert data2["parent_id"] == id1
|
||||
|
||||
def _make_state(self) -> RuntimeState:
|
||||
from crewai import Agent, Crew
|
||||
|
||||
agent = Agent(role="r", goal="g", backstory="b", llm="gpt-4o-mini")
|
||||
crew = Crew(agents=[agent], tasks=[], verbose=False)
|
||||
return RuntimeState(root=[crew])
|
||||
|
||||
|
||||
# ---------- SqliteProvider forking ----------
|
||||
|
||||
|
||||
class TestSqliteProviderFork:
|
||||
def test_checkpoint_stores_branch_and_parent(self) -> None:
|
||||
provider = SqliteProvider()
|
||||
with tempfile.TemporaryDirectory() as d:
|
||||
db = os.path.join(d, "cp.db")
|
||||
loc = provider.checkpoint("{}", db, parent_id="p1", branch="exp")
|
||||
cid = provider.extract_id(loc)
|
||||
|
||||
with sqlite3.connect(db) as conn:
|
||||
row = conn.execute(
|
||||
"SELECT parent_id, branch FROM checkpoints WHERE id = ?",
|
||||
(cid,),
|
||||
).fetchone()
|
||||
assert row == ("p1", "exp")
|
||||
|
||||
def test_prune_branch_aware(self) -> None:
|
||||
provider = SqliteProvider()
|
||||
with tempfile.TemporaryDirectory() as d:
|
||||
db = os.path.join(d, "cp.db")
|
||||
for _ in range(3):
|
||||
provider.checkpoint("{}", db, branch="main")
|
||||
for _ in range(2):
|
||||
provider.checkpoint("{}", db, branch="fork/a")
|
||||
|
||||
provider.prune(db, max_keep=1, branch="main")
|
||||
|
||||
with sqlite3.connect(db) as conn:
|
||||
main_count = conn.execute(
|
||||
"SELECT COUNT(*) FROM checkpoints WHERE branch = 'main'"
|
||||
).fetchone()[0]
|
||||
fork_count = conn.execute(
|
||||
"SELECT COUNT(*) FROM checkpoints WHERE branch = 'fork/a'"
|
||||
).fetchone()[0]
|
||||
assert main_count == 1
|
||||
assert fork_count == 2
|
||||
|
||||
def test_extract_id(self) -> None:
|
||||
provider = SqliteProvider()
|
||||
assert provider.extract_id("/path/to/db#abc123") == "abc123"
|
||||
|
||||
def test_checkpoint_chaining_sqlite(self) -> None:
|
||||
state = self._make_state()
|
||||
state._provider = SqliteProvider()
|
||||
with tempfile.TemporaryDirectory() as d:
|
||||
db = os.path.join(d, "cp.db")
|
||||
state.checkpoint(db)
|
||||
id1 = state._checkpoint_id
|
||||
|
||||
state.checkpoint(db)
|
||||
id2 = state._checkpoint_id
|
||||
assert id2 != id1
|
||||
|
||||
# Second row should have parent_id == id1
|
||||
with sqlite3.connect(db) as conn:
|
||||
row = conn.execute(
|
||||
"SELECT parent_id FROM checkpoints WHERE id = ?", (id2,)
|
||||
).fetchone()
|
||||
assert row[0] == id1
|
||||
|
||||
def _make_state(self) -> RuntimeState:
|
||||
from crewai import Agent, Crew
|
||||
|
||||
agent = Agent(role="r", goal="g", backstory="b", llm="gpt-4o-mini")
|
||||
crew = Crew(agents=[agent], tasks=[], verbose=False)
|
||||
return RuntimeState(root=[crew])
|
||||
|
||||
Reference in New Issue
Block a user