feat: add checkpoint forking with lineage tracking

This commit is contained in:
Greyson LaLonde
2026-04-10 00:03:28 +08:00
committed by GitHub
parent ce56472fc3
commit 68c754883d
8 changed files with 560 additions and 44 deletions

View File

@@ -400,6 +400,34 @@ class Crew(FlowTrackable, BaseModel):
return entity
raise ValueError(f"No Crew found in checkpoint: {path}")
@classmethod
def fork(
    cls,
    path: str,
    *,
    branch: str | None = None,
    provider: BaseProvider | None = None,
) -> Crew:
    """Fork a Crew from a checkpoint onto a new execution branch.

    Restores the crew from *path*, then re-labels the runtime state that
    restoration attached to the event bus as a fork, so subsequent
    checkpoints land on the new branch.

    Args:
        path: Path to a checkpoint file.
        branch: Branch label for the fork. Auto-generated if not provided.
        provider: Storage backend to read from. Defaults to auto-detect.

    Returns:
        A Crew instance on the new branch. Call kickoff() to run.

    Raises:
        RuntimeError: If restoring left no runtime state on the event bus.
    """
    restored = cls.from_checkpoint(path, provider=provider)
    runtime_state = crewai_event_bus._runtime_state
    if runtime_state is None:
        raise RuntimeError(
            "Cannot fork: no runtime state on the event bus. "
            "Ensure from_checkpoint() succeeded before calling fork()."
        )
    runtime_state.fork(branch)
    return restored
def _restore_runtime(self) -> None:
"""Re-create runtime objects after restoring from a checkpoint."""
for agent in self.agents:

View File

@@ -960,6 +960,34 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
return instance
raise ValueError(f"No Flow found in checkpoint: {path}")
@classmethod
def fork(
    cls,
    path: str,
    *,
    branch: str | None = None,
    provider: BaseProvider | None = None,
) -> Flow:  # type: ignore[type-arg]
    """Fork a Flow from a checkpoint, creating a new execution branch.

    The checkpoint at *path* is restored first; the runtime state that
    restoration attaches to the event bus is then marked as a fork so
    later checkpoints are written to the new branch.

    Args:
        path: Path to a checkpoint file.
        branch: Branch label for the fork. Auto-generated if not provided.
        provider: Storage backend to read from. Defaults to auto-detect.

    Returns:
        A Flow instance on the new branch. Call kickoff() to run.

    Raises:
        RuntimeError: If restoring left no runtime state on the event bus.
    """
    instance = cls.from_checkpoint(path, provider=provider)
    if (state := crewai_event_bus._runtime_state) is None:
        raise RuntimeError(
            "Cannot fork: no runtime state on the event bus. "
            "Ensure from_checkpoint() succeeded before calling fork()."
        )
    state.fork(branch)
    return instance
checkpoint_completed_methods: set[str] | None = Field(default=None)
checkpoint_method_outputs: list[Any] | None = Field(default=None)
checkpoint_method_counts: dict[str, int] | None = Field(default=None)

View File

@@ -106,10 +106,16 @@ def _do_checkpoint(state: RuntimeState, cfg: CheckpointConfig) -> None:
"""Write a checkpoint and prune old ones if configured."""
_prepare_entities(state.root)
data = state.model_dump_json()
cfg.provider.checkpoint(data, cfg.location)
location = cfg.provider.checkpoint(
data,
cfg.location,
parent_id=state._parent_id,
branch=state._branch,
)
state._chain_lineage(cfg.provider, location)
if cfg.max_checkpoints is not None:
cfg.provider.prune(cfg.location, cfg.max_checkpoints)
cfg.provider.prune(cfg.location, cfg.max_checkpoints, branch=state._branch)
def _should_checkpoint(source: Any, event: BaseEvent) -> CheckpointConfig | None:

View File

@@ -17,12 +17,21 @@ class BaseProvider(BaseModel, ABC):
provider_type: str = "base"
@abstractmethod
def checkpoint(self, data: str, location: str) -> str:
def checkpoint(
self,
data: str,
location: str,
*,
parent_id: str | None = None,
branch: str = "main",
) -> str:
"""Persist a snapshot synchronously.
Args:
data: The serialized string to persist.
location: Storage destination (directory, file path, URI, etc.).
parent_id: ID of the parent checkpoint for lineage tracking.
branch: Branch label for this checkpoint.
Returns:
A location identifier for the saved checkpoint.
@@ -30,12 +39,21 @@ class BaseProvider(BaseModel, ABC):
...
@abstractmethod
async def acheckpoint(self, data: str, location: str) -> str:
async def acheckpoint(
self,
data: str,
location: str,
*,
parent_id: str | None = None,
branch: str = "main",
) -> str:
"""Persist a snapshot asynchronously.
Args:
data: The serialized string to persist.
location: Storage destination (directory, file path, URI, etc.).
parent_id: ID of the parent checkpoint for lineage tracking.
branch: Branch label for this checkpoint.
Returns:
A location identifier for the saved checkpoint.
@@ -43,12 +61,25 @@ class BaseProvider(BaseModel, ABC):
...
@abstractmethod
def prune(self, location: str, max_keep: int) -> None:
"""Remove old checkpoints, keeping at most *max_keep*.
def prune(self, location: str, max_keep: int, *, branch: str = "main") -> None:
"""Remove old checkpoints, keeping at most *max_keep* per branch.
Args:
location: The storage destination passed to ``checkpoint``.
max_keep: Maximum number of checkpoints to retain.
branch: Only prune checkpoints on this branch.
"""
...
@abstractmethod
def extract_id(self, location: str) -> str:
    """Extract the checkpoint ID from a location string.

    The location format is provider-specific (e.g. a file path, or a
    ``db_path#id`` string), so each provider decodes its own IDs.

    Args:
        location: The identifier returned by a previous ``checkpoint`` call.

    Returns:
        The checkpoint ID.
    """
    ...

View File

@@ -19,48 +19,87 @@ from crewai.state.provider.core import BaseProvider
logger = logging.getLogger(__name__)
def _safe_branch(base: str, branch: str) -> None:
"""Validate that a branch name doesn't escape the base directory.
Raises:
ValueError: If the branch resolves outside the base directory.
"""
base_resolved = str(Path(base).resolve())
target_resolved = str((Path(base) / branch).resolve())
if (
not target_resolved.startswith(base_resolved + os.sep)
and target_resolved != base_resolved
):
raise ValueError(f"Branch name escapes checkpoint directory: {branch!r}")
class JsonProvider(BaseProvider):
"""Persists runtime state checkpoints as JSON files on the local filesystem."""
provider_type: Literal["json"] = "json"
def checkpoint(self, data: str, location: str) -> str:
def checkpoint(
self,
data: str,
location: str,
*,
parent_id: str | None = None,
branch: str = "main",
) -> str:
"""Write a JSON checkpoint file.
Args:
data: The serialized JSON string to persist.
location: Directory where the checkpoint will be saved.
location: Base directory where checkpoints are saved.
parent_id: ID of the parent checkpoint for lineage tracking.
Encoded in the filename for queryable lineage without
parsing the blob.
branch: Branch label. Files are stored under ``location/branch/``.
Returns:
The path to the written checkpoint file.
"""
file_path = _build_path(location)
file_path = _build_path(location, branch, parent_id)
file_path.parent.mkdir(parents=True, exist_ok=True)
with open(file_path, "w") as f:
f.write(data)
return str(file_path)
async def acheckpoint(self, data: str, location: str) -> str:
async def acheckpoint(
self,
data: str,
location: str,
*,
parent_id: str | None = None,
branch: str = "main",
) -> str:
"""Write a JSON checkpoint file asynchronously.
Args:
data: The serialized JSON string to persist.
location: Directory where the checkpoint will be saved.
location: Base directory where checkpoints are saved.
parent_id: ID of the parent checkpoint for lineage tracking.
Encoded in the filename for queryable lineage without
parsing the blob.
branch: Branch label. Files are stored under ``location/branch/``.
Returns:
The path to the written checkpoint file.
"""
file_path = _build_path(location)
file_path = _build_path(location, branch, parent_id)
await aiofiles.os.makedirs(str(file_path.parent), exist_ok=True)
async with aiofiles.open(file_path, "w") as f:
await f.write(data)
return str(file_path)
def prune(self, location: str, max_keep: int) -> None:
"""Remove oldest checkpoint files beyond *max_keep*."""
pattern = os.path.join(location, "*.json")
def prune(self, location: str, max_keep: int, *, branch: str = "main") -> None:
"""Remove oldest checkpoint files beyond *max_keep* on a branch."""
_safe_branch(location, branch)
branch_dir = os.path.join(location, branch)
pattern = os.path.join(branch_dir, "*.json")
files = sorted(glob.glob(pattern), key=os.path.getmtime)
for path in files if max_keep == 0 else files[:-max_keep]:
try:
@@ -68,6 +107,16 @@ class JsonProvider(BaseProvider):
except OSError: # noqa: PERF203
logger.debug("Failed to remove %s", path, exc_info=True)
def extract_id(self, location: str) -> str:
    """Extract the checkpoint ID from a file path.

    Filenames look like ``{ts}_{uuid8}_p-{parent}.json``; the checkpoint
    ID is everything in the stem before the first ``_p-`` marker, or the
    whole stem when no marker is present.
    """
    checkpoint_id, _, _ = Path(location).stem.partition("_p-")
    return checkpoint_id
def from_checkpoint(self, location: str) -> str:
"""Read a JSON checkpoint file.
@@ -92,15 +141,24 @@ class JsonProvider(BaseProvider):
return await f.read()
def _build_path(directory: str) -> Path:
"""Build a timestamped checkpoint file path.
def _build_path(
directory: str, branch: str = "main", parent_id: str | None = None
) -> Path:
"""Build a timestamped checkpoint file path under a branch subdirectory.
Filename format: ``{ts}_{uuid8}_p-{parent_id}.json``
Args:
directory: Parent directory for the checkpoint file.
directory: Base directory for checkpoints.
branch: Branch label used as a subdirectory name.
parent_id: Parent checkpoint ID to encode in the filename.
Returns:
The target file path.
"""
_safe_branch(directory, branch)
ts = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%S")
filename = f"{ts}_{uuid.uuid4().hex[:8]}.json"
return Path(directory) / filename
short_uuid = uuid.uuid4().hex[:8]
parent_suffix = parent_id or "none"
filename = f"{ts}_{short_uuid}_p-{parent_suffix}.json"
return Path(directory) / branch / filename

View File

@@ -17,15 +17,20 @@ _CREATE_TABLE = """
CREATE TABLE IF NOT EXISTS checkpoints (
id TEXT PRIMARY KEY,
created_at TEXT NOT NULL,
parent_id TEXT,
branch TEXT NOT NULL DEFAULT 'main',
data JSONB NOT NULL
)
"""
_INSERT = "INSERT INTO checkpoints (id, created_at, data) VALUES (?, ?, jsonb(?))"
_INSERT = (
"INSERT INTO checkpoints (id, created_at, parent_id, branch, data) "
"VALUES (?, ?, ?, ?, jsonb(?))"
)
_SELECT = "SELECT json(data) FROM checkpoints WHERE id = ?"
_PRUNE = """
DELETE FROM checkpoints WHERE rowid NOT IN (
SELECT rowid FROM checkpoints ORDER BY rowid DESC LIMIT ?
DELETE FROM checkpoints WHERE branch = ? AND rowid NOT IN (
SELECT rowid FROM checkpoints WHERE branch = ? ORDER BY rowid DESC LIMIT ?
)
"""
@@ -50,12 +55,21 @@ class SqliteProvider(BaseProvider):
provider_type: Literal["sqlite"] = "sqlite"
def checkpoint(self, data: str, location: str) -> str:
def checkpoint(
self,
data: str,
location: str,
*,
parent_id: str | None = None,
branch: str = "main",
) -> str:
"""Write a checkpoint to the SQLite database.
Args:
data: The serialized JSON string to persist.
location: Path to the SQLite database file.
parent_id: ID of the parent checkpoint for lineage tracking.
branch: Branch label for this checkpoint.
Returns:
A location string in the format ``"db_path#checkpoint_id"``.
@@ -65,16 +79,25 @@ class SqliteProvider(BaseProvider):
with sqlite3.connect(location) as conn:
conn.execute("PRAGMA journal_mode=WAL")
conn.execute(_CREATE_TABLE)
conn.execute(_INSERT, (checkpoint_id, ts, data))
conn.execute(_INSERT, (checkpoint_id, ts, parent_id, branch, data))
conn.commit()
return f"{location}#{checkpoint_id}"
async def acheckpoint(self, data: str, location: str) -> str:
async def acheckpoint(
self,
data: str,
location: str,
*,
parent_id: str | None = None,
branch: str = "main",
) -> str:
"""Write a checkpoint to the SQLite database asynchronously.
Args:
data: The serialized JSON string to persist.
location: Path to the SQLite database file.
parent_id: ID of the parent checkpoint for lineage tracking.
branch: Branch label for this checkpoint.
Returns:
A location string in the format ``"db_path#checkpoint_id"``.
@@ -84,16 +107,20 @@ class SqliteProvider(BaseProvider):
async with aiosqlite.connect(location) as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute(_CREATE_TABLE)
await db.execute(_INSERT, (checkpoint_id, ts, data))
await db.execute(_INSERT, (checkpoint_id, ts, parent_id, branch, data))
await db.commit()
return f"{location}#{checkpoint_id}"
def prune(self, location: str, max_keep: int) -> None:
"""Remove oldest checkpoint rows beyond *max_keep*."""
def prune(self, location: str, max_keep: int, *, branch: str = "main") -> None:
"""Remove oldest checkpoint rows beyond *max_keep* on a branch."""
with sqlite3.connect(location) as conn:
conn.execute(_PRUNE, (max_keep,))
conn.execute(_PRUNE, (branch, branch, max_keep))
conn.commit()
def extract_id(self, location: str) -> str:
    """Extract the checkpoint ID from a ``db_path#id`` string.

    Args:
        location: A location of the form ``"db_path#checkpoint_id"``,
            as returned by :meth:`checkpoint`.

    Returns:
        The checkpoint ID after the final ``#``.

    Raises:
        ValueError: If *location* contains no ``#`` separator (previously
            this surfaced as an obscure ``IndexError``).
    """
    _, sep, checkpoint_id = location.rpartition("#")
    if not sep:
        raise ValueError(
            f"Not a SQLite checkpoint location (missing '#'): {location!r}"
        )
    return checkpoint_id
def from_checkpoint(self, location: str) -> str:
"""Read a checkpoint from the SQLite database.

View File

@@ -10,6 +10,7 @@ via ``RuntimeState.model_rebuild()``.
from __future__ import annotations
from typing import TYPE_CHECKING, Any
import uuid
from pydantic import (
ModelWrapValidatorHandler,
@@ -64,6 +65,9 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
root: list[Entity]
_provider: BaseProvider = PrivateAttr(default_factory=JsonProvider)
_event_record: EventRecord = PrivateAttr(default_factory=EventRecord)
_checkpoint_id: str | None = PrivateAttr(default=None)
_parent_id: str | None = PrivateAttr(default=None)
_branch: str = PrivateAttr(default="main")
@property
def event_record(self) -> EventRecord:
@@ -73,6 +77,8 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
@model_serializer(mode="plain")
def _serialize(self) -> dict[str, Any]:
return {
"parent_id": self._parent_id,
"branch": self._branch,
"entities": [e.model_dump(mode="json") for e in self.root],
"event_record": self._event_record.model_dump(),
}
@@ -87,9 +93,24 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
state = handler(data["entities"])
if record_data:
state._event_record = EventRecord.model_validate(record_data)
state._parent_id = data.get("parent_id")
state._branch = data.get("branch", "main")
return state
return handler(data)
def _chain_lineage(self, provider: BaseProvider, location: str) -> None:
"""Update lineage fields after a successful checkpoint write.
Sets ``_checkpoint_id`` and ``_parent_id`` so the next write
records the correct parent in the lineage chain.
Args:
provider: The provider that performed the write.
location: The location string returned by the provider.
"""
self._checkpoint_id = provider.extract_id(location)
self._parent_id = self._checkpoint_id
def checkpoint(self, location: str) -> str:
"""Write a checkpoint.
@@ -101,7 +122,14 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
A location identifier for the saved checkpoint.
"""
_prepare_entities(self.root)
return self._provider.checkpoint(self.model_dump_json(), location)
result = self._provider.checkpoint(
self.model_dump_json(),
location,
parent_id=self._parent_id,
branch=self._branch,
)
self._chain_lineage(self._provider, result)
return result
async def acheckpoint(self, location: str) -> str:
"""Async version of :meth:`checkpoint`.
@@ -114,7 +142,29 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
A location identifier for the saved checkpoint.
"""
_prepare_entities(self.root)
return await self._provider.acheckpoint(self.model_dump_json(), location)
result = await self._provider.acheckpoint(
self.model_dump_json(),
location,
parent_id=self._parent_id,
branch=self._branch,
)
self._chain_lineage(self._provider, result)
return result
def fork(self, branch: str | None = None) -> None:
    """Mark this state as a fork for subsequent checkpoints.

    Only the branch label is changed; ``_parent_id`` and
    ``_checkpoint_id`` are left untouched, so the fork still points back
    at the checkpoint it was created from.

    Args:
        branch: Branch label. When omitted, the label is derived from the
            current checkpoint ID (``fork/<checkpoint_id>``) — note that
            repeated no-arg calls in that case reuse the SAME label.
            Only when no checkpoint has been written yet is a random
            8-hex suffix generated, which differs on every call.
    """
    if branch:
        self._branch = branch
    elif self._checkpoint_id:
        self._branch = f"fork/{self._checkpoint_id}"
    else:
        self._branch = f"fork/{uuid.uuid4().hex[:8]}"
@classmethod
def from_checkpoint(
@@ -131,7 +181,11 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
A restored RuntimeState.
"""
raw = provider.from_checkpoint(location)
return cls.model_validate_json(raw, **kwargs)
state = cls.model_validate_json(raw, **kwargs)
checkpoint_id = provider.extract_id(location)
state._checkpoint_id = checkpoint_id
state._parent_id = checkpoint_id
return state
@classmethod
async def afrom_checkpoint(
@@ -148,7 +202,11 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
A restored RuntimeState.
"""
raw = await provider.afrom_checkpoint(location)
return cls.model_validate_json(raw, **kwargs)
state = cls.model_validate_json(raw, **kwargs)
checkpoint_id = provider.extract_id(location)
state._checkpoint_id = checkpoint_id
state._parent_id = checkpoint_id
return state
def _prepare_entities(root: list[Entity]) -> None:

View File

@@ -1,8 +1,10 @@
"""Tests for CheckpointConfig, checkpoint listener, and pruning."""
"""Tests for CheckpointConfig, checkpoint listener, pruning, and forking."""
from __future__ import annotations
import json
import os
import sqlite3
import tempfile
import time
from typing import Any
@@ -21,6 +23,8 @@ from crewai.state.checkpoint_listener import (
_SENTINEL,
)
from crewai.state.provider.json_provider import JsonProvider
from crewai.state.provider.sqlite_provider import SqliteProvider
from crewai.state.runtime import RuntimeState
from crewai.task import Task
@@ -116,35 +120,41 @@ class TestFindCheckpoint:
class TestPrune:
def test_prune_keeps_newest(self) -> None:
with tempfile.TemporaryDirectory() as d:
branch_dir = os.path.join(d, "main")
os.makedirs(branch_dir)
for i in range(5):
path = os.path.join(d, f"cp_{i}.json")
path = os.path.join(branch_dir, f"cp_{i}.json")
with open(path, "w") as f:
f.write("{}")
# Ensure distinct mtime
time.sleep(0.01)
JsonProvider().prune(d, max_keep=2)
remaining = os.listdir(d)
JsonProvider().prune(d, max_keep=2, branch="main")
remaining = os.listdir(branch_dir)
assert len(remaining) == 2
assert "cp_3.json" in remaining
assert "cp_4.json" in remaining
def test_prune_zero_removes_all(self) -> None:
with tempfile.TemporaryDirectory() as d:
branch_dir = os.path.join(d, "main")
os.makedirs(branch_dir)
for i in range(3):
with open(os.path.join(d, f"cp_{i}.json"), "w") as f:
with open(os.path.join(branch_dir, f"cp_{i}.json"), "w") as f:
f.write("{}")
JsonProvider().prune(d, max_keep=0)
assert os.listdir(d) == []
JsonProvider().prune(d, max_keep=0, branch="main")
assert os.listdir(branch_dir) == []
def test_prune_more_than_existing(self) -> None:
with tempfile.TemporaryDirectory() as d:
with open(os.path.join(d, "cp.json"), "w") as f:
branch_dir = os.path.join(d, "main")
os.makedirs(branch_dir)
with open(os.path.join(branch_dir, "cp.json"), "w") as f:
f.write("{}")
JsonProvider().prune(d, max_keep=10)
assert len(os.listdir(d)) == 1
JsonProvider().prune(d, max_keep=10, branch="main")
assert len(os.listdir(branch_dir)) == 1
# ---------- CheckpointConfig ----------
@@ -167,3 +177,273 @@ class TestCheckpointConfig:
on_events=["task_completed", "crew_kickoff_completed"]
)
assert cfg.trigger_events == {"task_completed", "crew_kickoff_completed"}
# ---------- RuntimeState lineage ----------
class TestRuntimeStateLineage:
    """Lineage bookkeeping on RuntimeState: defaults, round-tripping, fork()."""

    def _make_state(self) -> RuntimeState:
        """Build a RuntimeState wrapping a minimal one-agent Crew."""
        # Imported lazily so collecting this module does not require the
        # full crewai package at import time.
        from crewai import Agent, Crew

        agent = Agent(role="r", goal="g", backstory="b", llm="gpt-4o-mini")
        crew = Crew(agents=[agent], tasks=[], verbose=False)
        return RuntimeState(root=[crew])

    def test_default_lineage_fields(self) -> None:
        """A fresh state has no lineage and sits on the default branch."""
        state = self._make_state()
        assert state._checkpoint_id is None
        assert state._parent_id is None
        assert state._branch == "main"

    def test_serialize_includes_lineage(self) -> None:
        """parent_id and branch are serialized; checkpoint_id is not."""
        state = self._make_state()
        state._parent_id = "parent456"
        state._branch = "experiment"
        dumped = json.loads(state.model_dump_json())
        assert dumped["parent_id"] == "parent456"
        assert dumped["branch"] == "experiment"
        # checkpoint_id is derived from the write location, never persisted.
        assert "checkpoint_id" not in dumped

    def test_deserialize_restores_lineage(self) -> None:
        """Lineage fields survive a serialize/deserialize round trip."""
        state = self._make_state()
        state._parent_id = "parent456"
        state._branch = "experiment"
        raw = state.model_dump_json()
        restored = RuntimeState.model_validate_json(
            raw, context={"from_checkpoint": True}
        )
        assert restored._parent_id == "parent456"
        assert restored._branch == "experiment"

    def test_deserialize_defaults_missing_lineage(self) -> None:
        """Blobs written before lineage existed deserialize with defaults."""
        state = self._make_state()
        raw = state.model_dump_json()
        data = json.loads(raw)
        data.pop("parent_id", None)
        data.pop("branch", None)
        restored = RuntimeState.model_validate_json(
            json.dumps(data), context={"from_checkpoint": True}
        )
        assert restored._parent_id is None
        assert restored._branch == "main"

    def test_from_checkpoint_sets_checkpoint_id(self) -> None:
        """from_checkpoint sets _checkpoint_id from the location, not the blob."""
        state = self._make_state()
        state._provider = JsonProvider()
        with tempfile.TemporaryDirectory() as d:
            loc = state.checkpoint(d)
            written_id = state._checkpoint_id
            provider = JsonProvider()
            restored = RuntimeState.from_checkpoint(
                loc, provider, context={"from_checkpoint": True}
            )
            # The restored state both identifies as, and chains from, the
            # checkpoint it was loaded from.
            assert restored._checkpoint_id == written_id
            assert restored._parent_id == written_id

    def test_fork_sets_branch(self) -> None:
        """An explicit branch label is used verbatim; lineage is untouched."""
        state = self._make_state()
        state._checkpoint_id = "abc12345"
        state._parent_id = "abc12345"
        state.fork("my-experiment")
        assert state._branch == "my-experiment"
        assert state._parent_id == "abc12345"

    def test_fork_auto_branch(self) -> None:
        """Without a label, the branch is derived from the checkpoint ID."""
        state = self._make_state()
        state._checkpoint_id = "20260409T120000_abc12345"
        state.fork()
        assert state._branch == "fork/20260409T120000_abc12345"

    def test_fork_no_checkpoint_id_unique(self) -> None:
        """With no checkpoint ID, each fork gets a fresh random suffix."""
        state = self._make_state()
        state.fork()
        assert state._branch.startswith("fork/")
        assert len(state._branch) == len("fork/") + 8
        # Two forks without checkpoint_id produce different branches
        first = state._branch
        state.fork()
        assert state._branch != first
# ---------- JsonProvider forking ----------
class TestJsonProviderFork:
    """Branch-aware checkpointing and filename lineage for JsonProvider."""

    def test_checkpoint_writes_to_branch_subdir(self) -> None:
        """Checkpoints land under ``location/<branch>/``."""
        provider = JsonProvider()
        with tempfile.TemporaryDirectory() as d:
            path = provider.checkpoint("{}", d, branch="main")
            # NOTE(review): the "/" separator check assumes POSIX paths.
            assert "/main/" in path
            assert path.endswith(".json")
            assert os.path.isfile(path)

    def test_checkpoint_fork_branch_subdir(self) -> None:
        """Branch labels containing ``/`` create nested subdirectories."""
        provider = JsonProvider()
        with tempfile.TemporaryDirectory() as d:
            path = provider.checkpoint("{}", d, branch="fork/exp1")
            assert "/fork/exp1/" in path
            assert os.path.isfile(path)

    def test_prune_branch_aware(self) -> None:
        """Pruning one branch leaves other branches untouched."""
        provider = JsonProvider()
        with tempfile.TemporaryDirectory() as d:
            # Write 3 checkpoints on main, 2 on fork
            for _ in range(3):
                provider.checkpoint("{}", d, branch="main")
                time.sleep(0.01)  # distinct mtimes so prune order is stable
            for _ in range(2):
                provider.checkpoint("{}", d, branch="fork/a")
                time.sleep(0.01)
            # Prune main to 1
            provider.prune(d, max_keep=1, branch="main")
            main_dir = os.path.join(d, "main")
            fork_dir = os.path.join(d, "fork", "a")
            assert len(os.listdir(main_dir)) == 1
            assert len(os.listdir(fork_dir)) == 2  # untouched

    def test_extract_id(self) -> None:
        """The ID is the ``{ts}_{uuid8}`` stem prefix before ``_p-``."""
        provider = JsonProvider()
        assert provider.extract_id("/dir/main/20260409T120000_abc12345_p-none.json") == "20260409T120000_abc12345"
        assert provider.extract_id("/dir/main/20260409T120000_abc12345_p-20260409T115900_def67890.json") == "20260409T120000_abc12345"

    def test_branch_traversal_rejected(self) -> None:
        """Path-traversal branch names are rejected on write and on prune."""
        provider = JsonProvider()
        with tempfile.TemporaryDirectory() as d:
            with pytest.raises(ValueError, match="escapes checkpoint directory"):
                provider.checkpoint("{}", d, branch="../../etc")
            with pytest.raises(ValueError, match="escapes checkpoint directory"):
                provider.prune(d, max_keep=1, branch="../../etc")

    def test_filename_encodes_parent_id(self) -> None:
        """The parent checkpoint ID is queryable from the filename alone."""
        provider = JsonProvider()
        with tempfile.TemporaryDirectory() as d:
            # First checkpoint — no parent
            path1 = provider.checkpoint("{}", d, branch="main")
            assert "_p-none.json" in path1
            # Second checkpoint — with parent
            id1 = provider.extract_id(path1)
            path2 = provider.checkpoint("{}", d, parent_id=id1, branch="main")
            assert f"_p-{id1}.json" in path2

    def test_checkpoint_chaining(self) -> None:
        """RuntimeState.checkpoint() chains parent_id after each write."""
        state = self._make_state()
        state._provider = JsonProvider()
        with tempfile.TemporaryDirectory() as d:
            state.checkpoint(d)
            id1 = state._checkpoint_id
            assert id1 is not None
            assert state._parent_id == id1
            loc2 = state.checkpoint(d)
            id2 = state._checkpoint_id
            assert id2 is not None
            assert id2 != id1
            assert state._parent_id == id2
            # Verify the second checkpoint blob has parent_id == id1
            with open(loc2) as f:
                data2 = json.loads(f.read())
            assert data2["parent_id"] == id1

    @pytest.mark.asyncio
    async def test_acheckpoint_chaining(self) -> None:
        """Async checkpoint path chains lineage identically to sync."""
        state = self._make_state()
        state._provider = JsonProvider()
        with tempfile.TemporaryDirectory() as d:
            await state.acheckpoint(d)
            id1 = state._checkpoint_id
            assert id1 is not None
            loc2 = await state.acheckpoint(d)
            id2 = state._checkpoint_id
            assert id2 != id1
            assert state._parent_id == id2
            with open(loc2) as f:
                data2 = json.loads(f.read())
            assert data2["parent_id"] == id1

    def _make_state(self) -> RuntimeState:
        """Build a RuntimeState wrapping a minimal one-agent Crew."""
        # NOTE(review): duplicated across the fork test classes; a shared
        # fixture would remove the repetition.
        from crewai import Agent, Crew

        agent = Agent(role="r", goal="g", backstory="b", llm="gpt-4o-mini")
        crew = Crew(agents=[agent], tasks=[], verbose=False)
        return RuntimeState(root=[crew])
# ---------- SqliteProvider forking ----------
class TestSqliteProviderFork:
    """Branch and parent_id persistence for the SQLite provider."""

    def test_checkpoint_stores_branch_and_parent(self) -> None:
        """parent_id and branch are written to the checkpoints row."""
        provider = SqliteProvider()
        with tempfile.TemporaryDirectory() as d:
            db = os.path.join(d, "cp.db")
            loc = provider.checkpoint("{}", db, parent_id="p1", branch="exp")
            cid = provider.extract_id(loc)
            with sqlite3.connect(db) as conn:
                row = conn.execute(
                    "SELECT parent_id, branch FROM checkpoints WHERE id = ?",
                    (cid,),
                ).fetchone()
            assert row == ("p1", "exp")

    def test_prune_branch_aware(self) -> None:
        """Pruning is scoped to a single branch."""
        provider = SqliteProvider()
        with tempfile.TemporaryDirectory() as d:
            db = os.path.join(d, "cp.db")
            for _ in range(3):
                provider.checkpoint("{}", db, branch="main")
            for _ in range(2):
                provider.checkpoint("{}", db, branch="fork/a")
            provider.prune(db, max_keep=1, branch="main")
            with sqlite3.connect(db) as conn:
                main_count = conn.execute(
                    "SELECT COUNT(*) FROM checkpoints WHERE branch = 'main'"
                ).fetchone()[0]
                fork_count = conn.execute(
                    "SELECT COUNT(*) FROM checkpoints WHERE branch = 'fork/a'"
                ).fetchone()[0]
            assert main_count == 1
            assert fork_count == 2

    def test_extract_id(self) -> None:
        """The ID is the fragment after ``#`` in the location string."""
        provider = SqliteProvider()
        assert provider.extract_id("/path/to/db#abc123") == "abc123"

    def test_checkpoint_chaining_sqlite(self) -> None:
        """RuntimeState chains parent_id across SQLite writes too."""
        state = self._make_state()
        state._provider = SqliteProvider()
        with tempfile.TemporaryDirectory() as d:
            db = os.path.join(d, "cp.db")
            state.checkpoint(db)
            id1 = state._checkpoint_id
            state.checkpoint(db)
            id2 = state._checkpoint_id
            assert id2 != id1
            # Second row should have parent_id == id1
            with sqlite3.connect(db) as conn:
                row = conn.execute(
                    "SELECT parent_id FROM checkpoints WHERE id = ?", (id2,)
                ).fetchone()
            assert row[0] == id1

    def _make_state(self) -> RuntimeState:
        """Build a RuntimeState wrapping a minimal one-agent Crew."""
        # NOTE(review): duplicated across the fork test classes; a shared
        # fixture would remove the repetition.
        from crewai import Agent, Crew

        agent = Agent(role="r", goal="g", backstory="b", llm="gpt-4o-mini")
        crew = Crew(agents=[agent], tasks=[], verbose=False)
        return RuntimeState(root=[crew])