Mirror of https://github.com/crewAIInc/crewAI.git (synced 2026-04-10 04:52:40 +00:00).
Accept CheckpointConfig on Crew and Flow kickoff/kickoff_async/akickoff. When restore_from is set, the entity resumes from that checkpoint. When only config fields are set, checkpointing is enabled for the run. Adds restore_from field (Path | str | None) to CheckpointConfig.
539 lines · 19 KiB · Python
"""Tests for CheckpointConfig, checkpoint listener, pruning, and forking."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import os
|
|
import sqlite3
|
|
import tempfile
|
|
import time
|
|
from typing import Any
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from crewai.agent.core import Agent
|
|
from crewai.agents.agent_builder.base_agent import BaseAgent
|
|
from crewai.crew import Crew
|
|
from crewai.flow.flow import Flow, start
|
|
from crewai.state.checkpoint_config import CheckpointConfig
|
|
from crewai.state.checkpoint_listener import (
|
|
_find_checkpoint,
|
|
_resolve,
|
|
_SENTINEL,
|
|
)
|
|
from crewai.state.provider.json_provider import JsonProvider
|
|
from crewai.state.provider.sqlite_provider import SqliteProvider
|
|
from crewai.state.runtime import RuntimeState
|
|
from crewai.task import Task
|
|
|
|
|
|
# ---------- _resolve ----------
|
|
|
|
|
|
class TestResolve:
    """``_resolve`` turns the user-facing ``checkpoint`` value into a config."""

    def test_none_returns_none(self) -> None:
        resolved = _resolve(None)
        assert resolved is None

    def test_false_returns_sentinel(self) -> None:
        # ``False`` is an explicit opt-out and must stay distinguishable from ``None``.
        assert _resolve(False) is _SENTINEL

    def test_true_returns_config(self) -> None:
        resolved = _resolve(True)
        assert isinstance(resolved, CheckpointConfig)
        assert resolved.location == "./.checkpoints"

    def test_config_returns_config(self) -> None:
        explicit = CheckpointConfig(location="/tmp/cp")
        # An explicit config is passed through untouched (same object).
        assert _resolve(explicit) is explicit
|
|
|
|
|
|
# ---------- _find_checkpoint inheritance ----------
|
|
|
|
|
|
class TestFindCheckpoint:
    """How ``_find_checkpoint`` resolves a config across agent/crew/task/flow."""

    def _make_agent(self, checkpoint: Any = None) -> Agent:
        return Agent(role="r", goal="g", backstory="b", checkpoint=checkpoint)

    def _make_crew(self, agents: list[Agent], checkpoint: Any = None) -> Crew:
        crew = Crew(agents=agents, tasks=[], checkpoint=checkpoint)
        # Back-link every agent to the crew so crew-level settings are reachable.
        for member in agents:
            member.crew = crew
        return crew

    def test_crew_true(self) -> None:
        agent = self._make_agent()
        self._make_crew([agent], checkpoint=True)
        assert isinstance(_find_checkpoint(agent), CheckpointConfig)

    def test_crew_true_agent_false_opts_out(self) -> None:
        # An explicit ``False`` on the agent wins over the crew-level ``True``.
        agent = self._make_agent(checkpoint=False)
        self._make_crew([agent], checkpoint=True)
        assert _find_checkpoint(agent) is None

    def test_crew_none_agent_none(self) -> None:
        agent = self._make_agent()
        self._make_crew([agent])
        assert _find_checkpoint(agent) is None

    def test_agent_config_overrides_crew(self) -> None:
        agent = self._make_agent(checkpoint=CheckpointConfig(location="/agent_cp"))
        self._make_crew([agent], checkpoint=True)
        found = _find_checkpoint(agent)
        assert isinstance(found, CheckpointConfig)
        assert found.location == "/agent_cp"

    def test_task_inherits_from_crew(self) -> None:
        agent = self._make_agent()
        self._make_crew([agent], checkpoint=True)
        task = Task(description="d", expected_output="e", agent=agent)
        assert isinstance(_find_checkpoint(task), CheckpointConfig)

    def test_task_agent_false_blocks(self) -> None:
        agent = self._make_agent(checkpoint=False)
        self._make_crew([agent], checkpoint=True)
        task = Task(description="d", expected_output="e", agent=agent)
        assert _find_checkpoint(task) is None

    def test_flow_direct(self) -> None:
        flow = Flow(checkpoint=True)
        assert isinstance(_find_checkpoint(flow), CheckpointConfig)

    def test_flow_none(self) -> None:
        flow = Flow()
        assert _find_checkpoint(flow) is None

    def test_unknown_source(self) -> None:
        # Arbitrary objects carry no checkpoint configuration.
        assert _find_checkpoint("random") is None
|
|
|
|
|
|
# ---------- _prune ----------
|
|
|
|
|
|
class TestPrune:
|
|
def test_prune_keeps_newest(self) -> None:
|
|
with tempfile.TemporaryDirectory() as d:
|
|
branch_dir = os.path.join(d, "main")
|
|
os.makedirs(branch_dir)
|
|
for i in range(5):
|
|
path = os.path.join(branch_dir, f"cp_{i}.json")
|
|
with open(path, "w") as f:
|
|
f.write("{}")
|
|
# Ensure distinct mtime
|
|
time.sleep(0.01)
|
|
|
|
JsonProvider().prune(d, max_keep=2, branch="main")
|
|
remaining = os.listdir(branch_dir)
|
|
assert len(remaining) == 2
|
|
assert "cp_3.json" in remaining
|
|
assert "cp_4.json" in remaining
|
|
|
|
def test_prune_zero_removes_all(self) -> None:
|
|
with tempfile.TemporaryDirectory() as d:
|
|
branch_dir = os.path.join(d, "main")
|
|
os.makedirs(branch_dir)
|
|
for i in range(3):
|
|
with open(os.path.join(branch_dir, f"cp_{i}.json"), "w") as f:
|
|
f.write("{}")
|
|
|
|
JsonProvider().prune(d, max_keep=0, branch="main")
|
|
assert os.listdir(branch_dir) == []
|
|
|
|
def test_prune_more_than_existing(self) -> None:
|
|
with tempfile.TemporaryDirectory() as d:
|
|
branch_dir = os.path.join(d, "main")
|
|
os.makedirs(branch_dir)
|
|
with open(os.path.join(branch_dir, "cp.json"), "w") as f:
|
|
f.write("{}")
|
|
|
|
JsonProvider().prune(d, max_keep=10, branch="main")
|
|
assert len(os.listdir(branch_dir)) == 1
|
|
|
|
|
|
# ---------- CheckpointConfig ----------
|
|
|
|
|
|
class TestCheckpointConfig:
    """Field defaults and derived properties of ``CheckpointConfig``."""

    def test_defaults(self) -> None:
        defaults = CheckpointConfig()
        assert defaults.location == "./.checkpoints"
        assert defaults.on_events == ["task_completed"]
        assert defaults.max_checkpoints is None
        assert not defaults.trigger_all

    def test_trigger_all(self) -> None:
        # The "*" wildcard subscribes to every event.
        assert CheckpointConfig(on_events=["*"]).trigger_all

    def test_restore_from_field(self) -> None:
        target = "/path/to/checkpoint.json"
        assert CheckpointConfig(restore_from=target).restore_from == target

    def test_restore_from_default_none(self) -> None:
        assert CheckpointConfig().restore_from is None

    def test_trigger_events(self) -> None:
        events = ["task_completed", "crew_kickoff_completed"]
        cfg = CheckpointConfig(on_events=events)
        assert cfg.trigger_events == set(events)
|
|
|
|
|
|
# ---------- RuntimeState lineage ----------
|
|
|
|
|
|
class TestRuntimeStateLineage:
    """Lineage bookkeeping (checkpoint id, parent id, branch) on ``RuntimeState``."""

    def _make_state(self) -> RuntimeState:
        from crewai import Agent, Crew

        solo_agent = Agent(role="r", goal="g", backstory="b", llm="gpt-4o-mini")
        return RuntimeState(
            root=[Crew(agents=[solo_agent], tasks=[], verbose=False)]
        )

    def test_default_lineage_fields(self) -> None:
        fresh = self._make_state()
        assert fresh._checkpoint_id is None
        assert fresh._parent_id is None
        assert fresh._branch == "main"

    def test_serialize_includes_version(self) -> None:
        from crewai.utilities.version import get_crewai_version

        payload = json.loads(self._make_state().model_dump_json())
        assert payload["crewai_version"] == get_crewai_version()

    def test_deserialize_migrates_on_version_mismatch(self, caplog: Any) -> None:
        import logging

        payload = json.loads(self._make_state().model_dump_json())
        # Pretend the blob was written by a much older release.
        payload["crewai_version"] = "0.1.0"
        with caplog.at_level(logging.DEBUG):
            RuntimeState.model_validate_json(
                json.dumps(payload), context={"from_checkpoint": True}
            )
        assert "Migrating checkpoint from crewAI 0.1.0" in caplog.text

    def test_deserialize_warns_on_missing_version(self, caplog: Any) -> None:
        import logging

        payload = json.loads(self._make_state().model_dump_json())
        # A blob with no version field should be handled as version 0.0.0.
        payload.pop("crewai_version", None)
        with caplog.at_level(logging.WARNING):
            RuntimeState.model_validate_json(
                json.dumps(payload), context={"from_checkpoint": True}
            )
        assert "treating as 0.0.0" in caplog.text

    def test_serialize_includes_lineage(self) -> None:
        state = self._make_state()
        state._parent_id, state._branch = "parent456", "experiment"
        payload = json.loads(state.model_dump_json())
        assert payload["parent_id"] == "parent456"
        assert payload["branch"] == "experiment"
        # The checkpoint id comes from the storage location, so it is not serialized.
        assert "checkpoint_id" not in payload

    def test_deserialize_restores_lineage(self) -> None:
        original = self._make_state()
        original._parent_id, original._branch = "parent456", "experiment"
        restored = RuntimeState.model_validate_json(
            original.model_dump_json(), context={"from_checkpoint": True}
        )
        assert restored._parent_id == "parent456"
        assert restored._branch == "experiment"

    def test_deserialize_defaults_missing_lineage(self) -> None:
        payload = json.loads(self._make_state().model_dump_json())
        for key in ("parent_id", "branch"):
            payload.pop(key, None)
        restored = RuntimeState.model_validate_json(
            json.dumps(payload), context={"from_checkpoint": True}
        )
        assert restored._parent_id is None
        assert restored._branch == "main"

    def test_from_checkpoint_sets_checkpoint_id(self) -> None:
        """from_checkpoint sets _checkpoint_id from the location, not the blob."""
        state = self._make_state()
        state._provider = JsonProvider()
        with tempfile.TemporaryDirectory() as d:
            location = state.checkpoint(d)
            expected_id = state._checkpoint_id

            restored = RuntimeState.from_checkpoint(
                CheckpointConfig(restore_from=location),
                context={"from_checkpoint": True},
            )
            assert restored._checkpoint_id == expected_id
            assert restored._parent_id == expected_id

    def test_fork_sets_branch(self) -> None:
        state = self._make_state()
        state._checkpoint_id = state._parent_id = "abc12345"
        state.fork("my-experiment")
        assert state._branch == "my-experiment"
        assert state._parent_id == "abc12345"

    def test_fork_auto_branch(self) -> None:
        state = self._make_state()
        state._checkpoint_id = "20260409T120000_abc12345"
        state.fork()
        # Without an explicit name the branch is derived from the checkpoint id.
        assert state._branch == "fork/20260409T120000_abc12345"

    def test_fork_no_checkpoint_id_unique(self) -> None:
        state = self._make_state()
        state.fork()
        first = state._branch
        assert first.startswith("fork/")
        # "fork/" followed by an 8-character generated suffix.
        assert len(first) == len("fork/") + 8
        # A second fork without a checkpoint id must land on a different branch.
        state.fork()
        assert state._branch != first
|
|
|
|
|
|
# ---------- JsonProvider forking ----------
|
|
|
|
|
|
class TestJsonProviderFork:
    """JsonProvider writes one subdirectory per branch and encodes lineage in filenames."""

    @staticmethod
    def _assert_in_dir(path: str, expected_dir: str) -> None:
        """Assert that ``path`` sits directly inside ``expected_dir``.

        ``os.path.samefile`` compares the directories on disk rather than as
        strings. The check is therefore independent of the path separator
        (``/`` vs ``\\``) and of symlinked temp roots (e.g. ``/var`` ->
        ``/private/var`` on macOS).
        """
        assert os.path.samefile(os.path.dirname(path), expected_dir), (
            f"{path!r} is not inside {expected_dir!r}"
        )

    def test_checkpoint_writes_to_branch_subdir(self) -> None:
        provider = JsonProvider()
        with tempfile.TemporaryDirectory() as d:
            path = provider.checkpoint("{}", d, branch="main")
            self._assert_in_dir(path, os.path.join(d, "main"))
            assert path.endswith(".json")
            assert os.path.isfile(path)

    def test_checkpoint_fork_branch_subdir(self) -> None:
        provider = JsonProvider()
        with tempfile.TemporaryDirectory() as d:
            path = provider.checkpoint("{}", d, branch="fork/exp1")
            # A "/" in the branch name maps to nested directories.
            self._assert_in_dir(path, os.path.join(d, "fork", "exp1"))
            assert os.path.isfile(path)

    def test_prune_branch_aware(self) -> None:
        provider = JsonProvider()
        with tempfile.TemporaryDirectory() as d:
            # Write 3 checkpoints on main, 2 on fork
            for _ in range(3):
                provider.checkpoint("{}", d, branch="main")
                time.sleep(0.01)
            for _ in range(2):
                provider.checkpoint("{}", d, branch="fork/a")
                time.sleep(0.01)

            # Prune main to 1
            provider.prune(d, max_keep=1, branch="main")

            main_dir = os.path.join(d, "main")
            fork_dir = os.path.join(d, "fork", "a")
            assert len(os.listdir(main_dir)) == 1
            assert len(os.listdir(fork_dir)) == 2  # untouched

    def test_extract_id(self) -> None:
        provider = JsonProvider()
        # The id is the filename stem before the "_p-<parent>" suffix.
        root_cp = "/dir/main/20260409T120000_abc12345_p-none.json"
        child_cp = (
            "/dir/main/20260409T120000_abc12345_p-20260409T115900_def67890.json"
        )
        assert provider.extract_id(root_cp) == "20260409T120000_abc12345"
        assert provider.extract_id(child_cp) == "20260409T120000_abc12345"

    def test_branch_traversal_rejected(self) -> None:
        provider = JsonProvider()
        with tempfile.TemporaryDirectory() as d:
            with pytest.raises(ValueError, match="escapes checkpoint directory"):
                provider.checkpoint("{}", d, branch="../../etc")
            with pytest.raises(ValueError, match="escapes checkpoint directory"):
                provider.prune(d, max_keep=1, branch="../../etc")

    def test_filename_encodes_parent_id(self) -> None:
        provider = JsonProvider()
        with tempfile.TemporaryDirectory() as d:
            # First checkpoint — no parent
            path1 = provider.checkpoint("{}", d, branch="main")
            assert "_p-none.json" in path1

            # Second checkpoint — with parent
            id1 = provider.extract_id(path1)
            path2 = provider.checkpoint("{}", d, parent_id=id1, branch="main")
            assert f"_p-{id1}.json" in path2

    def test_checkpoint_chaining(self) -> None:
        """RuntimeState.checkpoint() chains parent_id after each write."""
        state = self._make_state()
        state._provider = JsonProvider()
        with tempfile.TemporaryDirectory() as d:
            state.checkpoint(d)
            id1 = state._checkpoint_id
            assert id1 is not None
            assert state._parent_id == id1

            loc2 = state.checkpoint(d)
            id2 = state._checkpoint_id
            assert id2 is not None
            assert id2 != id1
            assert state._parent_id == id2

            # Verify the second checkpoint blob has parent_id == id1
            with open(loc2) as f:
                data2 = json.load(f)
            assert data2["parent_id"] == id1

    @pytest.mark.asyncio
    async def test_acheckpoint_chaining(self) -> None:
        """Async checkpoint path chains lineage identically to sync."""
        state = self._make_state()
        state._provider = JsonProvider()
        with tempfile.TemporaryDirectory() as d:
            await state.acheckpoint(d)
            id1 = state._checkpoint_id
            assert id1 is not None
            # Same invariant as the sync test: the parent pointer advances on write.
            assert state._parent_id == id1

            loc2 = await state.acheckpoint(d)
            id2 = state._checkpoint_id
            assert id2 is not None
            assert id2 != id1
            assert state._parent_id == id2

            with open(loc2) as f:
                data2 = json.load(f)
            assert data2["parent_id"] == id1

    def _make_state(self) -> RuntimeState:
        from crewai import Agent, Crew

        agent = Agent(role="r", goal="g", backstory="b", llm="gpt-4o-mini")
        crew = Crew(agents=[agent], tasks=[], verbose=False)
        return RuntimeState(root=[crew])
|
|
|
|
|
|
# ---------- SqliteProvider forking ----------
|
|
|
|
|
|
class TestSqliteProviderFork:
|
|
def test_checkpoint_stores_branch_and_parent(self) -> None:
|
|
provider = SqliteProvider()
|
|
with tempfile.TemporaryDirectory() as d:
|
|
db = os.path.join(d, "cp.db")
|
|
loc = provider.checkpoint("{}", db, parent_id="p1", branch="exp")
|
|
cid = provider.extract_id(loc)
|
|
|
|
with sqlite3.connect(db) as conn:
|
|
row = conn.execute(
|
|
"SELECT parent_id, branch FROM checkpoints WHERE id = ?",
|
|
(cid,),
|
|
).fetchone()
|
|
assert row == ("p1", "exp")
|
|
|
|
def test_prune_branch_aware(self) -> None:
|
|
provider = SqliteProvider()
|
|
with tempfile.TemporaryDirectory() as d:
|
|
db = os.path.join(d, "cp.db")
|
|
for _ in range(3):
|
|
provider.checkpoint("{}", db, branch="main")
|
|
for _ in range(2):
|
|
provider.checkpoint("{}", db, branch="fork/a")
|
|
|
|
provider.prune(db, max_keep=1, branch="main")
|
|
|
|
with sqlite3.connect(db) as conn:
|
|
main_count = conn.execute(
|
|
"SELECT COUNT(*) FROM checkpoints WHERE branch = 'main'"
|
|
).fetchone()[0]
|
|
fork_count = conn.execute(
|
|
"SELECT COUNT(*) FROM checkpoints WHERE branch = 'fork/a'"
|
|
).fetchone()[0]
|
|
assert main_count == 1
|
|
assert fork_count == 2
|
|
|
|
def test_extract_id(self) -> None:
|
|
provider = SqliteProvider()
|
|
assert provider.extract_id("/path/to/db#abc123") == "abc123"
|
|
|
|
def test_checkpoint_chaining_sqlite(self) -> None:
|
|
state = self._make_state()
|
|
state._provider = SqliteProvider()
|
|
with tempfile.TemporaryDirectory() as d:
|
|
db = os.path.join(d, "cp.db")
|
|
state.checkpoint(db)
|
|
id1 = state._checkpoint_id
|
|
|
|
state.checkpoint(db)
|
|
id2 = state._checkpoint_id
|
|
assert id2 != id1
|
|
|
|
# Second row should have parent_id == id1
|
|
with sqlite3.connect(db) as conn:
|
|
row = conn.execute(
|
|
"SELECT parent_id FROM checkpoints WHERE id = ?", (id2,)
|
|
).fetchone()
|
|
assert row[0] == id1
|
|
|
|
def _make_state(self) -> RuntimeState:
|
|
from crewai import Agent, Crew
|
|
|
|
agent = Agent(role="r", goal="g", backstory="b", llm="gpt-4o-mini")
|
|
crew = Crew(agents=[agent], tasks=[], verbose=False)
|
|
return RuntimeState(root=[crew])
|
|
|
|
|
|
# ---------- Kickoff from_checkpoint parameter ----------
|
|
|
|
|
|
class TestKickoffFromCheckpoint:
    """``kickoff(from_checkpoint=...)`` on Crew and Flow."""

    def test_crew_kickoff_delegates_to_from_checkpoint(self) -> None:
        restored = MagicMock(spec=Crew)
        restored.kickoff.return_value = "result"
        cfg = CheckpointConfig(restore_from="/path/to/cp.json")

        with patch.object(Crew, "from_checkpoint", return_value=restored):
            crew = Crew(
                agents=[Agent(role="r", goal="g", backstory="b", llm="gpt-4o-mini")],
                tasks=[],
                verbose=False,
            )
            outcome = crew.kickoff(inputs={"k": "v"}, from_checkpoint=cfg)

        # The restored crew runs with the caller's inputs...
        restored.kickoff.assert_called_once_with(inputs={"k": "v"}, input_files=None)
        # ...and its checkpoint config no longer points at the restore source.
        assert restored.checkpoint.restore_from is None
        assert outcome == "result"

    def test_crew_kickoff_config_only_sets_checkpoint(self) -> None:
        cfg = CheckpointConfig(on_events=["task_completed"])
        crew = Crew(
            agents=[Agent(role="r", goal="g", backstory="b", llm="gpt-4o-mini")],
            tasks=[],
            verbose=False,
        )
        assert crew.checkpoint is None

        # Abort kickoff right after the checkpoint config has been applied.
        with patch("crewai.crew.get_env_context"):
            with patch(
                "crewai.crew.prepare_kickoff", side_effect=RuntimeError("stop")
            ):
                with pytest.raises(RuntimeError, match="stop"):
                    crew.kickoff(from_checkpoint=cfg)

        assert isinstance(crew.checkpoint, CheckpointConfig)
        assert crew.checkpoint.on_events == ["task_completed"]

    def test_flow_kickoff_delegates_to_from_checkpoint(self) -> None:
        restored = MagicMock(spec=Flow)
        restored.kickoff.return_value = "flow_result"
        cfg = CheckpointConfig(restore_from="/path/to/flow_cp.json")

        with patch.object(Flow, "from_checkpoint", return_value=restored):
            outcome = Flow().kickoff(from_checkpoint=cfg)

        restored.kickoff.assert_called_once_with(inputs=None, input_files=None)
        assert restored.checkpoint.restore_from is None
        assert outcome == "flow_result"
|