diff --git a/lib/crewai/src/crewai/crew.py b/lib/crewai/src/crewai/crew.py index 4090e706b..d279a265d 100644 --- a/lib/crewai/src/crewai/crew.py +++ b/lib/crewai/src/crewai/crew.py @@ -400,6 +400,34 @@ class Crew(FlowTrackable, BaseModel): return entity raise ValueError(f"No Crew found in checkpoint: {path}") + @classmethod + def fork( + cls, + path: str, + *, + branch: str | None = None, + provider: BaseProvider | None = None, + ) -> Crew: + """Fork a Crew from a checkpoint, creating a new execution branch. + + Args: + path: Path to a checkpoint file. + branch: Branch label for the fork. Auto-generated if not provided. + provider: Storage backend to read from. Defaults to auto-detect. + + Returns: + A Crew instance on the new branch. Call kickoff() to run. + """ + crew = cls.from_checkpoint(path, provider=provider) + state = crewai_event_bus._runtime_state + if state is None: + raise RuntimeError( + "Cannot fork: no runtime state on the event bus. " + "Ensure from_checkpoint() succeeded before calling fork()." + ) + state.fork(branch) + return crew + def _restore_runtime(self) -> None: """Re-create runtime objects after restoring from a checkpoint.""" for agent in self.agents: diff --git a/lib/crewai/src/crewai/flow/flow.py b/lib/crewai/src/crewai/flow/flow.py index a057da581..a85885bd9 100644 --- a/lib/crewai/src/crewai/flow/flow.py +++ b/lib/crewai/src/crewai/flow/flow.py @@ -960,6 +960,34 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): return instance raise ValueError(f"No Flow found in checkpoint: {path}") + @classmethod + def fork( + cls, + path: str, + *, + branch: str | None = None, + provider: BaseProvider | None = None, + ) -> Flow: # type: ignore[type-arg] + """Fork a Flow from a checkpoint, creating a new execution branch. + + Args: + path: Path to a checkpoint file. + branch: Branch label for the fork. Auto-generated if not provided. + provider: Storage backend to read from. Defaults to auto-detect. 
+ + Returns: + A Flow instance on the new branch. Call kickoff() to run. + """ + flow = cls.from_checkpoint(path, provider=provider) + state = crewai_event_bus._runtime_state + if state is None: + raise RuntimeError( + "Cannot fork: no runtime state on the event bus. " + "Ensure from_checkpoint() succeeded before calling fork()." + ) + state.fork(branch) + return flow + checkpoint_completed_methods: set[str] | None = Field(default=None) checkpoint_method_outputs: list[Any] | None = Field(default=None) checkpoint_method_counts: dict[str, int] | None = Field(default=None) diff --git a/lib/crewai/src/crewai/state/checkpoint_listener.py b/lib/crewai/src/crewai/state/checkpoint_listener.py index 6471b9bde..c2ac728a8 100644 --- a/lib/crewai/src/crewai/state/checkpoint_listener.py +++ b/lib/crewai/src/crewai/state/checkpoint_listener.py @@ -106,10 +106,16 @@ def _do_checkpoint(state: RuntimeState, cfg: CheckpointConfig) -> None: """Write a checkpoint and prune old ones if configured.""" _prepare_entities(state.root) data = state.model_dump_json() - cfg.provider.checkpoint(data, cfg.location) + location = cfg.provider.checkpoint( + data, + cfg.location, + parent_id=state._parent_id, + branch=state._branch, + ) + state._chain_lineage(cfg.provider, location) if cfg.max_checkpoints is not None: - cfg.provider.prune(cfg.location, cfg.max_checkpoints) + cfg.provider.prune(cfg.location, cfg.max_checkpoints, branch=state._branch) def _should_checkpoint(source: Any, event: BaseEvent) -> CheckpointConfig | None: diff --git a/lib/crewai/src/crewai/state/provider/core.py b/lib/crewai/src/crewai/state/provider/core.py index 0b12364c0..c386d519f 100644 --- a/lib/crewai/src/crewai/state/provider/core.py +++ b/lib/crewai/src/crewai/state/provider/core.py @@ -17,12 +17,21 @@ class BaseProvider(BaseModel, ABC): provider_type: str = "base" @abstractmethod - def checkpoint(self, data: str, location: str) -> str: + def checkpoint( + self, + data: str, + location: str, + *, + parent_id: str | 
None = None, + branch: str = "main", + ) -> str: """Persist a snapshot synchronously. Args: data: The serialized string to persist. location: Storage destination (directory, file path, URI, etc.). + parent_id: ID of the parent checkpoint for lineage tracking. + branch: Branch label for this checkpoint. Returns: A location identifier for the saved checkpoint. @@ -30,12 +39,21 @@ class BaseProvider(BaseModel, ABC): ... @abstractmethod - async def acheckpoint(self, data: str, location: str) -> str: + async def acheckpoint( + self, + data: str, + location: str, + *, + parent_id: str | None = None, + branch: str = "main", + ) -> str: """Persist a snapshot asynchronously. Args: data: The serialized string to persist. location: Storage destination (directory, file path, URI, etc.). + parent_id: ID of the parent checkpoint for lineage tracking. + branch: Branch label for this checkpoint. Returns: A location identifier for the saved checkpoint. @@ -43,12 +61,25 @@ class BaseProvider(BaseModel, ABC): ... @abstractmethod - def prune(self, location: str, max_keep: int) -> None: - """Remove old checkpoints, keeping at most *max_keep*. + def prune(self, location: str, max_keep: int, *, branch: str = "main") -> None: + """Remove old checkpoints, keeping at most *max_keep* per branch. Args: location: The storage destination passed to ``checkpoint``. max_keep: Maximum number of checkpoints to retain. + branch: Only prune checkpoints on this branch. + """ + ... + + @abstractmethod + def extract_id(self, location: str) -> str: + """Extract the checkpoint ID from a location string. + + Args: + location: The identifier returned by a previous ``checkpoint`` call. + + Returns: + The checkpoint ID. """ ... 
diff --git a/lib/crewai/src/crewai/state/provider/json_provider.py b/lib/crewai/src/crewai/state/provider/json_provider.py index f9763e6f3..0f18a5901 100644 --- a/lib/crewai/src/crewai/state/provider/json_provider.py +++ b/lib/crewai/src/crewai/state/provider/json_provider.py @@ -19,48 +19,87 @@ from crewai.state.provider.core import BaseProvider logger = logging.getLogger(__name__) +def _safe_branch(base: str, branch: str) -> None: + """Validate that a branch name doesn't escape the base directory. + + Raises: + ValueError: If the branch resolves outside the base directory. + """ + base_resolved = str(Path(base).resolve()) + target_resolved = str((Path(base) / branch).resolve()) + if ( + not target_resolved.startswith(base_resolved + os.sep) + and target_resolved != base_resolved + ): + raise ValueError(f"Branch name escapes checkpoint directory: {branch!r}") + + class JsonProvider(BaseProvider): """Persists runtime state checkpoints as JSON files on the local filesystem.""" provider_type: Literal["json"] = "json" - def checkpoint(self, data: str, location: str) -> str: + def checkpoint( + self, + data: str, + location: str, + *, + parent_id: str | None = None, + branch: str = "main", + ) -> str: """Write a JSON checkpoint file. Args: data: The serialized JSON string to persist. - location: Directory where the checkpoint will be saved. + location: Base directory where checkpoints are saved. + parent_id: ID of the parent checkpoint for lineage tracking. + Encoded in the filename for queryable lineage without + parsing the blob. + branch: Branch label. Files are stored under ``location/branch/``. Returns: The path to the written checkpoint file. 
""" - file_path = _build_path(location) + file_path = _build_path(location, branch, parent_id) file_path.parent.mkdir(parents=True, exist_ok=True) with open(file_path, "w") as f: f.write(data) return str(file_path) - async def acheckpoint(self, data: str, location: str) -> str: + async def acheckpoint( + self, + data: str, + location: str, + *, + parent_id: str | None = None, + branch: str = "main", + ) -> str: """Write a JSON checkpoint file asynchronously. Args: data: The serialized JSON string to persist. - location: Directory where the checkpoint will be saved. + location: Base directory where checkpoints are saved. + parent_id: ID of the parent checkpoint for lineage tracking. + Encoded in the filename for queryable lineage without + parsing the blob. + branch: Branch label. Files are stored under ``location/branch/``. Returns: The path to the written checkpoint file. """ - file_path = _build_path(location) + file_path = _build_path(location, branch, parent_id) await aiofiles.os.makedirs(str(file_path.parent), exist_ok=True) async with aiofiles.open(file_path, "w") as f: await f.write(data) return str(file_path) - def prune(self, location: str, max_keep: int) -> None: - """Remove oldest checkpoint files beyond *max_keep*.""" - pattern = os.path.join(location, "*.json") + def prune(self, location: str, max_keep: int, *, branch: str = "main") -> None: + """Remove oldest checkpoint files beyond *max_keep* on a branch.""" + _safe_branch(location, branch) + branch_dir = os.path.join(location, branch) + pattern = os.path.join(branch_dir, "*.json") files = sorted(glob.glob(pattern), key=os.path.getmtime) for path in files if max_keep == 0 else files[:-max_keep]: try: @@ -68,6 +107,16 @@ class JsonProvider(BaseProvider): except OSError: # noqa: PERF203 logger.debug("Failed to remove %s", path, exc_info=True) + def extract_id(self, location: str) -> str: + """Extract the checkpoint ID from a file path. + + The filename format is ``{ts}_{uuid8}_p-{parent}.json``. 
+ The checkpoint ID is the ``{ts}_{uuid8}`` prefix. + """ + stem = Path(location).stem + idx = stem.find("_p-") + return stem[:idx] if idx != -1 else stem + def from_checkpoint(self, location: str) -> str: """Read a JSON checkpoint file. @@ -92,15 +141,24 @@ class JsonProvider(BaseProvider): return await f.read() -def _build_path(directory: str) -> Path: - """Build a timestamped checkpoint file path. +def _build_path( + directory: str, branch: str = "main", parent_id: str | None = None +) -> Path: + """Build a timestamped checkpoint file path under a branch subdirectory. + + Filename format: ``{ts}_{uuid8}_p-{parent_id}.json`` Args: - directory: Parent directory for the checkpoint file. + directory: Base directory for checkpoints. + branch: Branch label used as a subdirectory name. + parent_id: Parent checkpoint ID to encode in the filename. Returns: The target file path. """ + _safe_branch(directory, branch) ts = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%S") - filename = f"{ts}_{uuid.uuid4().hex[:8]}.json" - return Path(directory) / filename + short_uuid = uuid.uuid4().hex[:8] + parent_suffix = parent_id or "none" + filename = f"{ts}_{short_uuid}_p-{parent_suffix}.json" + return Path(directory) / branch / filename diff --git a/lib/crewai/src/crewai/state/provider/sqlite_provider.py b/lib/crewai/src/crewai/state/provider/sqlite_provider.py index e54f56180..5ee4dca26 100644 --- a/lib/crewai/src/crewai/state/provider/sqlite_provider.py +++ b/lib/crewai/src/crewai/state/provider/sqlite_provider.py @@ -17,15 +17,20 @@ _CREATE_TABLE = """ CREATE TABLE IF NOT EXISTS checkpoints ( id TEXT PRIMARY KEY, created_at TEXT NOT NULL, + parent_id TEXT, + branch TEXT NOT NULL DEFAULT 'main', data JSONB NOT NULL ) """ -_INSERT = "INSERT INTO checkpoints (id, created_at, data) VALUES (?, ?, jsonb(?))" +_INSERT = ( + "INSERT INTO checkpoints (id, created_at, parent_id, branch, data) " + "VALUES (?, ?, ?, ?, jsonb(?))" +) _SELECT = "SELECT json(data) FROM checkpoints WHERE id = 
?" _PRUNE = """ -DELETE FROM checkpoints WHERE rowid NOT IN ( - SELECT rowid FROM checkpoints ORDER BY rowid DESC LIMIT ? +DELETE FROM checkpoints WHERE branch = ? AND rowid NOT IN ( + SELECT rowid FROM checkpoints WHERE branch = ? ORDER BY rowid DESC LIMIT ? ) """ @@ -50,12 +55,21 @@ class SqliteProvider(BaseProvider): provider_type: Literal["sqlite"] = "sqlite" - def checkpoint(self, data: str, location: str) -> str: + def checkpoint( + self, + data: str, + location: str, + *, + parent_id: str | None = None, + branch: str = "main", + ) -> str: """Write a checkpoint to the SQLite database. Args: data: The serialized JSON string to persist. location: Path to the SQLite database file. + parent_id: ID of the parent checkpoint for lineage tracking. + branch: Branch label for this checkpoint. Returns: A location string in the format ``"db_path#checkpoint_id"``. @@ -65,16 +79,25 @@ class SqliteProvider(BaseProvider): with sqlite3.connect(location) as conn: conn.execute("PRAGMA journal_mode=WAL") conn.execute(_CREATE_TABLE) - conn.execute(_INSERT, (checkpoint_id, ts, data)) + conn.execute(_INSERT, (checkpoint_id, ts, parent_id, branch, data)) conn.commit() return f"{location}#{checkpoint_id}" - async def acheckpoint(self, data: str, location: str) -> str: + async def acheckpoint( + self, + data: str, + location: str, + *, + parent_id: str | None = None, + branch: str = "main", + ) -> str: """Write a checkpoint to the SQLite database asynchronously. Args: data: The serialized JSON string to persist. location: Path to the SQLite database file. + parent_id: ID of the parent checkpoint for lineage tracking. + branch: Branch label for this checkpoint. Returns: A location string in the format ``"db_path#checkpoint_id"``. 
@@ -84,16 +107,20 @@ class SqliteProvider(BaseProvider): async with aiosqlite.connect(location) as db: await db.execute("PRAGMA journal_mode=WAL") await db.execute(_CREATE_TABLE) - await db.execute(_INSERT, (checkpoint_id, ts, data)) + await db.execute(_INSERT, (checkpoint_id, ts, parent_id, branch, data)) await db.commit() return f"{location}#{checkpoint_id}" - def prune(self, location: str, max_keep: int) -> None: - """Remove oldest checkpoint rows beyond *max_keep*.""" + def prune(self, location: str, max_keep: int, *, branch: str = "main") -> None: + """Remove oldest checkpoint rows beyond *max_keep* on a branch.""" with sqlite3.connect(location) as conn: - conn.execute(_PRUNE, (max_keep,)) + conn.execute(_PRUNE, (branch, branch, max_keep)) conn.commit() + def extract_id(self, location: str) -> str: + """Extract the checkpoint ID from a ``db_path#id`` string.""" + return location.rsplit("#", 1)[1] + def from_checkpoint(self, location: str) -> str: """Read a checkpoint from the SQLite database. diff --git a/lib/crewai/src/crewai/state/runtime.py b/lib/crewai/src/crewai/state/runtime.py index 6f1c5de80..b4293ad39 100644 --- a/lib/crewai/src/crewai/state/runtime.py +++ b/lib/crewai/src/crewai/state/runtime.py @@ -10,6 +10,7 @@ via ``RuntimeState.model_rebuild()``. 
from __future__ import annotations from typing import TYPE_CHECKING, Any +import uuid from pydantic import ( ModelWrapValidatorHandler, @@ -64,6 +65,9 @@ class RuntimeState(RootModel): # type: ignore[type-arg] root: list[Entity] _provider: BaseProvider = PrivateAttr(default_factory=JsonProvider) _event_record: EventRecord = PrivateAttr(default_factory=EventRecord) + _checkpoint_id: str | None = PrivateAttr(default=None) + _parent_id: str | None = PrivateAttr(default=None) + _branch: str = PrivateAttr(default="main") @property def event_record(self) -> EventRecord: @@ -73,6 +77,8 @@ class RuntimeState(RootModel): # type: ignore[type-arg] @model_serializer(mode="plain") def _serialize(self) -> dict[str, Any]: return { + "parent_id": self._parent_id, + "branch": self._branch, "entities": [e.model_dump(mode="json") for e in self.root], "event_record": self._event_record.model_dump(), } @@ -87,9 +93,24 @@ class RuntimeState(RootModel): # type: ignore[type-arg] state = handler(data["entities"]) if record_data: state._event_record = EventRecord.model_validate(record_data) + state._parent_id = data.get("parent_id") + state._branch = data.get("branch", "main") return state return handler(data) + def _chain_lineage(self, provider: BaseProvider, location: str) -> None: + """Update lineage fields after a successful checkpoint write. + + Sets ``_checkpoint_id`` and ``_parent_id`` so the next write + records the correct parent in the lineage chain. + + Args: + provider: The provider that performed the write. + location: The location string returned by the provider. + """ + self._checkpoint_id = provider.extract_id(location) + self._parent_id = self._checkpoint_id + def checkpoint(self, location: str) -> str: """Write a checkpoint. @@ -101,7 +122,14 @@ class RuntimeState(RootModel): # type: ignore[type-arg] A location identifier for the saved checkpoint. 
 """ _prepare_entities(self.root) - return self._provider.checkpoint(self.model_dump_json(), location) + result = self._provider.checkpoint( + self.model_dump_json(), + location, + parent_id=self._parent_id, + branch=self._branch, + ) + self._chain_lineage(self._provider, result) + return result async def acheckpoint(self, location: str) -> str: """Async version of :meth:`checkpoint`. @@ -114,7 +142,29 @@ class RuntimeState(RootModel): # type: ignore[type-arg] A location identifier for the saved checkpoint. """ _prepare_entities(self.root) - return await self._provider.acheckpoint(self.model_dump_json(), location) + result = await self._provider.acheckpoint( + self.model_dump_json(), + location, + parent_id=self._parent_id, + branch=self._branch, + ) + self._chain_lineage(self._provider, result) + return result + + def fork(self, branch: str | None = None) -> None: + """Mark this state as a fork for subsequent checkpoints. + + Args: + branch: Branch label. Auto-generated from the current checkpoint + ID if not provided (random when none exists yet). Repeat calls + reuse the same label until a new checkpoint is written. + """ + if branch: + self._branch = branch + elif self._checkpoint_id: + self._branch = f"fork/{self._checkpoint_id}" + else: + self._branch = f"fork/{uuid.uuid4().hex[:8]}" @classmethod def from_checkpoint( @@ -131,7 +181,11 @@ class RuntimeState(RootModel): # type: ignore[type-arg] A restored RuntimeState. """ raw = provider.from_checkpoint(location) - return cls.model_validate_json(raw, **kwargs) + state = cls.model_validate_json(raw, **kwargs) + checkpoint_id = provider.extract_id(location) + state._checkpoint_id = checkpoint_id + state._parent_id = checkpoint_id + return state @classmethod async def afrom_checkpoint( @@ -148,7 +202,11 @@ class RuntimeState(RootModel): # type: ignore[type-arg] A restored RuntimeState. 
""" raw = await provider.afrom_checkpoint(location) - return cls.model_validate_json(raw, **kwargs) + state = cls.model_validate_json(raw, **kwargs) + checkpoint_id = provider.extract_id(location) + state._checkpoint_id = checkpoint_id + state._parent_id = checkpoint_id + return state def _prepare_entities(root: list[Entity]) -> None: diff --git a/lib/crewai/tests/test_checkpoint.py b/lib/crewai/tests/test_checkpoint.py index 29dc289b4..cbea4b562 100644 --- a/lib/crewai/tests/test_checkpoint.py +++ b/lib/crewai/tests/test_checkpoint.py @@ -1,8 +1,10 @@ -"""Tests for CheckpointConfig, checkpoint listener, and pruning.""" +"""Tests for CheckpointConfig, checkpoint listener, pruning, and forking.""" from __future__ import annotations +import json import os +import sqlite3 import tempfile import time from typing import Any @@ -21,6 +23,8 @@ from crewai.state.checkpoint_listener import ( _SENTINEL, ) from crewai.state.provider.json_provider import JsonProvider +from crewai.state.provider.sqlite_provider import SqliteProvider +from crewai.state.runtime import RuntimeState from crewai.task import Task @@ -116,35 +120,41 @@ class TestFindCheckpoint: class TestPrune: def test_prune_keeps_newest(self) -> None: with tempfile.TemporaryDirectory() as d: + branch_dir = os.path.join(d, "main") + os.makedirs(branch_dir) for i in range(5): - path = os.path.join(d, f"cp_{i}.json") + path = os.path.join(branch_dir, f"cp_{i}.json") with open(path, "w") as f: f.write("{}") # Ensure distinct mtime time.sleep(0.01) - JsonProvider().prune(d, max_keep=2) - remaining = os.listdir(d) + JsonProvider().prune(d, max_keep=2, branch="main") + remaining = os.listdir(branch_dir) assert len(remaining) == 2 assert "cp_3.json" in remaining assert "cp_4.json" in remaining def test_prune_zero_removes_all(self) -> None: with tempfile.TemporaryDirectory() as d: + branch_dir = os.path.join(d, "main") + os.makedirs(branch_dir) for i in range(3): - with open(os.path.join(d, f"cp_{i}.json"), "w") as f: + with 
open(os.path.join(branch_dir, f"cp_{i}.json"), "w") as f: f.write("{}") - JsonProvider().prune(d, max_keep=0) - assert os.listdir(d) == [] + JsonProvider().prune(d, max_keep=0, branch="main") + assert os.listdir(branch_dir) == [] def test_prune_more_than_existing(self) -> None: with tempfile.TemporaryDirectory() as d: - with open(os.path.join(d, "cp.json"), "w") as f: + branch_dir = os.path.join(d, "main") + os.makedirs(branch_dir) + with open(os.path.join(branch_dir, "cp.json"), "w") as f: f.write("{}") - JsonProvider().prune(d, max_keep=10) - assert len(os.listdir(d)) == 1 + JsonProvider().prune(d, max_keep=10, branch="main") + assert len(os.listdir(branch_dir)) == 1 # ---------- CheckpointConfig ---------- @@ -167,3 +177,273 @@ class TestCheckpointConfig: on_events=["task_completed", "crew_kickoff_completed"] ) assert cfg.trigger_events == {"task_completed", "crew_kickoff_completed"} + + +# ---------- RuntimeState lineage ---------- + + +class TestRuntimeStateLineage: + def _make_state(self) -> RuntimeState: + from crewai import Agent, Crew + + agent = Agent(role="r", goal="g", backstory="b", llm="gpt-4o-mini") + crew = Crew(agents=[agent], tasks=[], verbose=False) + return RuntimeState(root=[crew]) + + def test_default_lineage_fields(self) -> None: + state = self._make_state() + assert state._checkpoint_id is None + assert state._parent_id is None + assert state._branch == "main" + + def test_serialize_includes_lineage(self) -> None: + state = self._make_state() + state._parent_id = "parent456" + state._branch = "experiment" + dumped = json.loads(state.model_dump_json()) + assert dumped["parent_id"] == "parent456" + assert dumped["branch"] == "experiment" + assert "checkpoint_id" not in dumped + + def test_deserialize_restores_lineage(self) -> None: + state = self._make_state() + state._parent_id = "parent456" + state._branch = "experiment" + raw = state.model_dump_json() + restored = RuntimeState.model_validate_json( + raw, context={"from_checkpoint": True} + 
) + assert restored._parent_id == "parent456" + assert restored._branch == "experiment" + + def test_deserialize_defaults_missing_lineage(self) -> None: + state = self._make_state() + raw = state.model_dump_json() + data = json.loads(raw) + data.pop("parent_id", None) + data.pop("branch", None) + restored = RuntimeState.model_validate_json( + json.dumps(data), context={"from_checkpoint": True} + ) + assert restored._parent_id is None + assert restored._branch == "main" + + def test_from_checkpoint_sets_checkpoint_id(self) -> None: + """from_checkpoint sets _checkpoint_id from the location, not the blob.""" + state = self._make_state() + state._provider = JsonProvider() + with tempfile.TemporaryDirectory() as d: + loc = state.checkpoint(d) + written_id = state._checkpoint_id + + provider = JsonProvider() + restored = RuntimeState.from_checkpoint( + loc, provider, context={"from_checkpoint": True} + ) + assert restored._checkpoint_id == written_id + assert restored._parent_id == written_id + + def test_fork_sets_branch(self) -> None: + state = self._make_state() + state._checkpoint_id = "abc12345" + state._parent_id = "abc12345" + state.fork("my-experiment") + assert state._branch == "my-experiment" + assert state._parent_id == "abc12345" + + def test_fork_auto_branch(self) -> None: + state = self._make_state() + state._checkpoint_id = "20260409T120000_abc12345" + state.fork() + assert state._branch == "fork/20260409T120000_abc12345" + + def test_fork_no_checkpoint_id_unique(self) -> None: + state = self._make_state() + state.fork() + assert state._branch.startswith("fork/") + assert len(state._branch) == len("fork/") + 8 + # Two forks without checkpoint_id produce different branches + first = state._branch + state.fork() + assert state._branch != first + + +# ---------- JsonProvider forking ---------- + + +class TestJsonProviderFork: + def test_checkpoint_writes_to_branch_subdir(self) -> None: + provider = JsonProvider() + with tempfile.TemporaryDirectory() as d: + 
path = provider.checkpoint("{}", d, branch="main") + assert "/main/" in path + assert path.endswith(".json") + assert os.path.isfile(path) + + def test_checkpoint_fork_branch_subdir(self) -> None: + provider = JsonProvider() + with tempfile.TemporaryDirectory() as d: + path = provider.checkpoint("{}", d, branch="fork/exp1") + assert "/fork/exp1/" in path + assert os.path.isfile(path) + + def test_prune_branch_aware(self) -> None: + provider = JsonProvider() + with tempfile.TemporaryDirectory() as d: + # Write 3 checkpoints on main, 2 on fork + for _ in range(3): + provider.checkpoint("{}", d, branch="main") + time.sleep(0.01) + for _ in range(2): + provider.checkpoint("{}", d, branch="fork/a") + time.sleep(0.01) + + # Prune main to 1 + provider.prune(d, max_keep=1, branch="main") + + main_dir = os.path.join(d, "main") + fork_dir = os.path.join(d, "fork", "a") + assert len(os.listdir(main_dir)) == 1 + assert len(os.listdir(fork_dir)) == 2 # untouched + + def test_extract_id(self) -> None: + provider = JsonProvider() + assert provider.extract_id("/dir/main/20260409T120000_abc12345_p-none.json") == "20260409T120000_abc12345" + assert provider.extract_id("/dir/main/20260409T120000_abc12345_p-20260409T115900_def67890.json") == "20260409T120000_abc12345" + + def test_branch_traversal_rejected(self) -> None: + provider = JsonProvider() + with tempfile.TemporaryDirectory() as d: + with pytest.raises(ValueError, match="escapes checkpoint directory"): + provider.checkpoint("{}", d, branch="../../etc") + with pytest.raises(ValueError, match="escapes checkpoint directory"): + provider.prune(d, max_keep=1, branch="../../etc") + + def test_filename_encodes_parent_id(self) -> None: + provider = JsonProvider() + with tempfile.TemporaryDirectory() as d: + # First checkpoint — no parent + path1 = provider.checkpoint("{}", d, branch="main") + assert "_p-none.json" in path1 + + # Second checkpoint — with parent + id1 = provider.extract_id(path1) + path2 = provider.checkpoint("{}", d, 
parent_id=id1, branch="main") + assert f"_p-{id1}.json" in path2 + + def test_checkpoint_chaining(self) -> None: + """RuntimeState.checkpoint() chains parent_id after each write.""" + state = self._make_state() + state._provider = JsonProvider() + with tempfile.TemporaryDirectory() as d: + state.checkpoint(d) + id1 = state._checkpoint_id + assert id1 is not None + assert state._parent_id == id1 + + loc2 = state.checkpoint(d) + id2 = state._checkpoint_id + assert id2 is not None + assert id2 != id1 + assert state._parent_id == id2 + + # Verify the second checkpoint blob has parent_id == id1 + with open(loc2) as f: + data2 = json.loads(f.read()) + assert data2["parent_id"] == id1 + + @pytest.mark.asyncio + async def test_acheckpoint_chaining(self) -> None: + """Async checkpoint path chains lineage identically to sync.""" + state = self._make_state() + state._provider = JsonProvider() + with tempfile.TemporaryDirectory() as d: + await state.acheckpoint(d) + id1 = state._checkpoint_id + assert id1 is not None + + loc2 = await state.acheckpoint(d) + id2 = state._checkpoint_id + assert id2 != id1 + assert state._parent_id == id2 + + with open(loc2) as f: + data2 = json.loads(f.read()) + assert data2["parent_id"] == id1 + + def _make_state(self) -> RuntimeState: + from crewai import Agent, Crew + + agent = Agent(role="r", goal="g", backstory="b", llm="gpt-4o-mini") + crew = Crew(agents=[agent], tasks=[], verbose=False) + return RuntimeState(root=[crew]) + + +# ---------- SqliteProvider forking ---------- + + +class TestSqliteProviderFork: + def test_checkpoint_stores_branch_and_parent(self) -> None: + provider = SqliteProvider() + with tempfile.TemporaryDirectory() as d: + db = os.path.join(d, "cp.db") + loc = provider.checkpoint("{}", db, parent_id="p1", branch="exp") + cid = provider.extract_id(loc) + + with sqlite3.connect(db) as conn: + row = conn.execute( + "SELECT parent_id, branch FROM checkpoints WHERE id = ?", + (cid,), + ).fetchone() + assert row == ("p1", 
"exp") + + def test_prune_branch_aware(self) -> None: + provider = SqliteProvider() + with tempfile.TemporaryDirectory() as d: + db = os.path.join(d, "cp.db") + for _ in range(3): + provider.checkpoint("{}", db, branch="main") + for _ in range(2): + provider.checkpoint("{}", db, branch="fork/a") + + provider.prune(db, max_keep=1, branch="main") + + with sqlite3.connect(db) as conn: + main_count = conn.execute( + "SELECT COUNT(*) FROM checkpoints WHERE branch = 'main'" + ).fetchone()[0] + fork_count = conn.execute( + "SELECT COUNT(*) FROM checkpoints WHERE branch = 'fork/a'" + ).fetchone()[0] + assert main_count == 1 + assert fork_count == 2 + + def test_extract_id(self) -> None: + provider = SqliteProvider() + assert provider.extract_id("/path/to/db#abc123") == "abc123" + + def test_checkpoint_chaining_sqlite(self) -> None: + state = self._make_state() + state._provider = SqliteProvider() + with tempfile.TemporaryDirectory() as d: + db = os.path.join(d, "cp.db") + state.checkpoint(db) + id1 = state._checkpoint_id + + state.checkpoint(db) + id2 = state._checkpoint_id + assert id2 != id1 + + # Second row should have parent_id == id1 + with sqlite3.connect(db) as conn: + row = conn.execute( + "SELECT parent_id FROM checkpoints WHERE id = ?", (id2,) + ).fetchone() + assert row[0] == id1 + + def _make_state(self) -> RuntimeState: + from crewai import Agent, Crew + + agent = Agent(role="r", goal="g", backstory="b", llm="gpt-4o-mini") + crew = Crew(agents=[agent], tasks=[], verbose=False) + return RuntimeState(root=[crew])