mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-04-19 09:22:39 +00:00
Write the crewAI package version into every checkpoint blob. On restore, run version-based migrations so older checkpoints can be transformed forward to the current format. Adds crewai.utilities.version module.
483 lines
17 KiB
Python
483 lines
17 KiB
Python
"""Tests for CheckpointConfig, checkpoint listener, pruning, and forking."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import os
|
|
import sqlite3
|
|
import tempfile
|
|
import time
|
|
from typing import Any
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from crewai.agent.core import Agent
|
|
from crewai.agents.agent_builder.base_agent import BaseAgent
|
|
from crewai.crew import Crew
|
|
from crewai.flow.flow import Flow, start
|
|
from crewai.state.checkpoint_config import CheckpointConfig
|
|
from crewai.state.checkpoint_listener import (
|
|
_find_checkpoint,
|
|
_resolve,
|
|
_SENTINEL,
|
|
)
|
|
from crewai.state.provider.json_provider import JsonProvider
|
|
from crewai.state.provider.sqlite_provider import SqliteProvider
|
|
from crewai.state.runtime import RuntimeState
|
|
from crewai.task import Task
|
|
|
|
|
|
# ---------- _resolve ----------
|
|
|
|
|
|
class TestResolve:
    """Checks how ``_resolve`` normalizes a raw ``checkpoint`` setting.

    ``None`` stays ``None``, ``False`` becomes the ``_SENTINEL`` marker,
    ``True`` expands to a default ``CheckpointConfig``, and an existing
    ``CheckpointConfig`` is handed back unchanged.
    """

    def test_none_returns_none(self) -> None:
        resolved = _resolve(None)
        assert resolved is None

    def test_false_returns_sentinel(self) -> None:
        resolved = _resolve(False)
        assert resolved is _SENTINEL

    def test_true_returns_config(self) -> None:
        resolved = _resolve(True)
        assert isinstance(resolved, CheckpointConfig)
        # ``True`` expands to a config pointing at the default location.
        assert resolved.location == "./.checkpoints"

    def test_config_returns_config(self) -> None:
        explicit = CheckpointConfig(location="/tmp/cp")
        # Identity check: the very same object must come back, not a copy.
        assert _resolve(explicit) is explicit
|
|
|
|
|
|
# ---------- _find_checkpoint inheritance ----------
|
|
|
|
|
|
class TestFindCheckpoint:
    """Checks how ``_find_checkpoint`` locates the effective config.

    A config set directly on an object wins; agents with no setting pick up
    their crew's setting; ``checkpoint=False`` on an agent blocks that
    inheritance (also for tasks owned by the agent); unrelated objects
    resolve to ``None``.
    """

    def _make_agent(self, checkpoint: Any = None) -> Agent:
        """Build a minimal agent carrying the given ``checkpoint`` value."""
        return Agent(
            role="r",
            goal="g",
            backstory="b",
            checkpoint=checkpoint,
        )

    def _make_crew(
        self, agents: list[Agent], checkpoint: Any = None
    ) -> Crew:
        """Build a task-less crew and point each agent's ``crew`` at it."""
        crew = Crew(agents=agents, tasks=[], checkpoint=checkpoint)
        # Set the back-reference explicitly so lookups starting from an
        # agent can reach the crew.
        for agent in agents:
            agent.crew = crew
        return crew

    def test_crew_true(self) -> None:
        agent = self._make_agent()
        self._make_crew([agent], checkpoint=True)
        assert isinstance(_find_checkpoint(agent), CheckpointConfig)

    def test_crew_true_agent_false_opts_out(self) -> None:
        agent = self._make_agent(checkpoint=False)
        self._make_crew([agent], checkpoint=True)
        assert _find_checkpoint(agent) is None

    def test_crew_none_agent_none(self) -> None:
        agent = self._make_agent()
        self._make_crew([agent])
        assert _find_checkpoint(agent) is None

    def test_agent_config_overrides_crew(self) -> None:
        own_config = CheckpointConfig(location="/agent_cp")
        agent = self._make_agent(checkpoint=own_config)
        self._make_crew([agent], checkpoint=True)
        found = _find_checkpoint(agent)
        assert isinstance(found, CheckpointConfig)
        assert found.location == "/agent_cp"

    def test_task_inherits_from_crew(self) -> None:
        agent = self._make_agent()
        self._make_crew([agent], checkpoint=True)
        task = Task(description="d", expected_output="e", agent=agent)
        assert isinstance(_find_checkpoint(task), CheckpointConfig)

    def test_task_agent_false_blocks(self) -> None:
        agent = self._make_agent(checkpoint=False)
        self._make_crew([agent], checkpoint=True)
        task = Task(description="d", expected_output="e", agent=agent)
        assert _find_checkpoint(task) is None

    def test_flow_direct(self) -> None:
        found = _find_checkpoint(Flow(checkpoint=True))
        assert isinstance(found, CheckpointConfig)

    def test_flow_none(self) -> None:
        assert _find_checkpoint(Flow()) is None

    def test_unknown_source(self) -> None:
        # A plain string is not a recognised checkpoint source.
        assert _find_checkpoint("random") is None
|
|
|
|
|
|
# ---------- _prune ----------
|
|
|
|
|
|
class TestPrune:
    """Checks that ``JsonProvider.prune`` keeps only the newest files of a branch."""

    @staticmethod
    def _write_checkpoints(branch_dir: str, names: list[str]) -> None:
        """Create ``{}`` JSON files in ``branch_dir`` with increasing mtimes.

        Files are written in the order given. Each file's modification time
        is set explicitly, 10 s after the previous one. Sleeping a few
        milliseconds between writes is not enough on filesystems that store
        mtimes at 1-2 s resolution: several files would share one mtime and
        the "newest" ordering would be undefined.
        """
        spacing = 10
        base = int(time.time()) - spacing * len(names)
        for offset, name in enumerate(names):
            path = os.path.join(branch_dir, name)
            with open(path, "w", encoding="utf-8") as f:
                f.write("{}")
            stamp = base + spacing * offset
            os.utime(path, (stamp, stamp))

    def test_prune_keeps_newest(self) -> None:
        with tempfile.TemporaryDirectory() as d:
            branch_dir = os.path.join(d, "main")
            os.makedirs(branch_dir)
            self._write_checkpoints(
                branch_dir, [f"cp_{i}.json" for i in range(5)]
            )

            JsonProvider().prune(d, max_keep=2, branch="main")

            # Exactly the two most recently modified files survive.
            assert sorted(os.listdir(branch_dir)) == ["cp_3.json", "cp_4.json"]

    def test_prune_zero_removes_all(self) -> None:
        with tempfile.TemporaryDirectory() as d:
            branch_dir = os.path.join(d, "main")
            os.makedirs(branch_dir)
            self._write_checkpoints(
                branch_dir, [f"cp_{i}.json" for i in range(3)]
            )

            JsonProvider().prune(d, max_keep=0, branch="main")

            assert os.listdir(branch_dir) == []

    def test_prune_more_than_existing(self) -> None:
        with tempfile.TemporaryDirectory() as d:
            branch_dir = os.path.join(d, "main")
            os.makedirs(branch_dir)
            self._write_checkpoints(branch_dir, ["cp.json"])

            JsonProvider().prune(d, max_keep=10, branch="main")

            # Fewer files than ``max_keep``: nothing is removed.
            assert os.listdir(branch_dir) == ["cp.json"]
|
|
|
|
|
|
# ---------- CheckpointConfig ----------
|
|
|
|
|
|
class TestCheckpointConfig:
    """Checks ``CheckpointConfig`` defaults and its derived trigger properties."""

    def test_defaults(self) -> None:
        config = CheckpointConfig()
        assert config.location == "./.checkpoints"
        assert config.on_events == ["task_completed"]
        assert config.max_checkpoints is None
        # The default event list has no "*" entry, so trigger_all is off.
        assert not config.trigger_all

    def test_trigger_all(self) -> None:
        # A "*" entry in on_events switches trigger_all on.
        assert CheckpointConfig(on_events=["*"]).trigger_all

    def test_trigger_events(self) -> None:
        events = ["task_completed", "crew_kickoff_completed"]
        config = CheckpointConfig(on_events=events)
        assert config.trigger_events == set(events)
|
|
|
|
|
|
# ---------- RuntimeState lineage ----------
|
|
|
|
|
|
class TestRuntimeStateLineage:
    """Checks lineage metadata on ``RuntimeState``: id, parent, branch and version."""

    def _make_state(self) -> RuntimeState:
        """Wrap a single-agent, task-less crew in a new ``RuntimeState``."""
        from crewai import Agent, Crew

        crew = Crew(
            agents=[Agent(role="r", goal="g", backstory="b", llm="gpt-4o-mini")],
            tasks=[],
            verbose=False,
        )
        return RuntimeState(root=[crew])

    @staticmethod
    def _dump(state: RuntimeState) -> dict[str, Any]:
        """Serialize ``state`` to JSON and parse it back into a dict."""
        return json.loads(state.model_dump_json())

    @staticmethod
    def _restore(payload: str) -> RuntimeState:
        """Validate ``payload`` with the ``from_checkpoint`` context enabled."""
        return RuntimeState.model_validate_json(
            payload, context={"from_checkpoint": True}
        )

    def test_default_lineage_fields(self) -> None:
        state = self._make_state()
        assert state._checkpoint_id is None
        assert state._parent_id is None
        assert state._branch == "main"

    def test_serialize_includes_version(self) -> None:
        from crewai.utilities.version import get_crewai_version

        dumped = self._dump(self._make_state())
        assert dumped["crewai_version"] == get_crewai_version()

    def test_deserialize_migrates_on_version_mismatch(self, caplog: Any) -> None:
        import logging

        payload = self._dump(self._make_state())
        # Stamp the blob with an older release so a migration is required.
        payload["crewai_version"] = "0.1.0"
        with caplog.at_level(logging.DEBUG):
            self._restore(json.dumps(payload))
        assert "Migrating checkpoint from crewAI 0.1.0" in caplog.text

    def test_deserialize_warns_on_missing_version(self, caplog: Any) -> None:
        import logging

        payload = self._dump(self._make_state())
        # Simulate a blob with no version key at all.
        payload.pop("crewai_version", None)
        with caplog.at_level(logging.WARNING):
            self._restore(json.dumps(payload))
        assert "treating as 0.0.0" in caplog.text

    def test_serialize_includes_lineage(self) -> None:
        state = self._make_state()
        state._parent_id = "parent456"
        state._branch = "experiment"
        dumped = self._dump(state)
        assert dumped["parent_id"] == "parent456"
        assert dumped["branch"] == "experiment"
        # The checkpoint id is not part of the serialized blob.
        assert "checkpoint_id" not in dumped

    def test_deserialize_restores_lineage(self) -> None:
        state = self._make_state()
        state._parent_id = "parent456"
        state._branch = "experiment"
        restored = self._restore(state.model_dump_json())
        assert restored._parent_id == "parent456"
        assert restored._branch == "experiment"

    def test_deserialize_defaults_missing_lineage(self) -> None:
        payload = self._dump(self._make_state())
        for key in ("parent_id", "branch"):
            payload.pop(key, None)
        restored = self._restore(json.dumps(payload))
        assert restored._parent_id is None
        assert restored._branch == "main"

    def test_from_checkpoint_sets_checkpoint_id(self) -> None:
        """from_checkpoint sets _checkpoint_id from the location, not the blob."""
        state = self._make_state()
        state._provider = JsonProvider()
        with tempfile.TemporaryDirectory() as d:
            location = state.checkpoint(d)
            written_id = state._checkpoint_id

            restored = RuntimeState.from_checkpoint(
                location, JsonProvider(), context={"from_checkpoint": True}
            )
            assert restored._checkpoint_id == written_id
            # The restored state's parent is the checkpoint it was loaded from.
            assert restored._parent_id == written_id

    def test_fork_sets_branch(self) -> None:
        state = self._make_state()
        state._checkpoint_id = state._parent_id = "abc12345"
        state.fork("my-experiment")
        assert state._branch == "my-experiment"
        # Forking switches the branch but keeps the parent link.
        assert state._parent_id == "abc12345"

    def test_fork_auto_branch(self) -> None:
        state = self._make_state()
        checkpoint_id = "20260409T120000_abc12345"
        state._checkpoint_id = checkpoint_id
        state.fork()
        assert state._branch == f"fork/{checkpoint_id}"

    def test_fork_no_checkpoint_id_unique(self) -> None:
        state = self._make_state()
        state.fork()
        first = state._branch
        assert first.startswith("fork/")
        assert len(first) == len("fork/") + 8
        # Two forks without checkpoint_id produce different branches
        state.fork()
        assert state._branch != first
|
|
|
|
|
|
# ---------- JsonProvider forking ----------
|
|
|
|
|
|
class TestJsonProviderFork:
    """Checks branch-aware layout, pruning and lineage chaining for ``JsonProvider``."""

    @staticmethod
    def _path_parts(path: str) -> list[str]:
        """Split ``path`` into components using the platform separator.

        ``os.path.normpath`` also turns ``/`` into ``\\`` on Windows. Branch
        directories therefore compare equal however the provider joined them,
        unlike a hard-coded ``"/main/" in path`` check.
        """
        return os.path.normpath(path).split(os.sep)

    def test_checkpoint_writes_to_branch_subdir(self) -> None:
        provider = JsonProvider()
        with tempfile.TemporaryDirectory() as d:
            path = provider.checkpoint("{}", d, branch="main")
            # The file sits directly inside the "main" branch directory.
            assert self._path_parts(path)[-2] == "main"
            assert path.endswith(".json")
            assert os.path.isfile(path)

    def test_checkpoint_fork_branch_subdir(self) -> None:
        provider = JsonProvider()
        with tempfile.TemporaryDirectory() as d:
            path = provider.checkpoint("{}", d, branch="fork/exp1")
            # A "/" in the branch name maps to nested directories.
            assert self._path_parts(path)[-3:-1] == ["fork", "exp1"]
            assert os.path.isfile(path)

    def test_prune_branch_aware(self) -> None:
        provider = JsonProvider()
        with tempfile.TemporaryDirectory() as d:
            # Write 3 checkpoints on main, 2 on fork. Only counts are
            # asserted below, so no delay between writes is needed.
            for _ in range(3):
                provider.checkpoint("{}", d, branch="main")
            for _ in range(2):
                provider.checkpoint("{}", d, branch="fork/a")

            # Prune main to 1
            provider.prune(d, max_keep=1, branch="main")

            main_dir = os.path.join(d, "main")
            fork_dir = os.path.join(d, "fork", "a")
            assert len(os.listdir(main_dir)) == 1
            assert len(os.listdir(fork_dir)) == 2  # untouched

    def test_extract_id(self) -> None:
        provider = JsonProvider()
        assert (
            provider.extract_id("/dir/main/20260409T120000_abc12345_p-none.json")
            == "20260409T120000_abc12345"
        )
        assert (
            provider.extract_id(
                "/dir/main/20260409T120000_abc12345_p-20260409T115900_def67890.json"
            )
            == "20260409T120000_abc12345"
        )

    def test_branch_traversal_rejected(self) -> None:
        provider = JsonProvider()
        with tempfile.TemporaryDirectory() as d:
            with pytest.raises(ValueError, match="escapes checkpoint directory"):
                provider.checkpoint("{}", d, branch="../../etc")
            with pytest.raises(ValueError, match="escapes checkpoint directory"):
                provider.prune(d, max_keep=1, branch="../../etc")

    def test_filename_encodes_parent_id(self) -> None:
        provider = JsonProvider()
        with tempfile.TemporaryDirectory() as d:
            # First checkpoint — no parent
            path1 = provider.checkpoint("{}", d, branch="main")
            assert "_p-none.json" in path1

            # Second checkpoint — with parent
            id1 = provider.extract_id(path1)
            path2 = provider.checkpoint("{}", d, parent_id=id1, branch="main")
            assert f"_p-{id1}.json" in path2

    def test_checkpoint_chaining(self) -> None:
        """RuntimeState.checkpoint() chains parent_id after each write."""
        state = self._make_state()
        state._provider = JsonProvider()
        with tempfile.TemporaryDirectory() as d:
            state.checkpoint(d)
            id1 = state._checkpoint_id
            assert id1 is not None
            assert state._parent_id == id1

            loc2 = state.checkpoint(d)
            id2 = state._checkpoint_id
            assert id2 is not None
            assert id2 != id1
            assert state._parent_id == id2

            # Verify the second checkpoint blob has parent_id == id1.
            # Read as UTF-8 explicitly rather than the locale default.
            with open(loc2, encoding="utf-8") as f:
                data2 = json.load(f)
            assert data2["parent_id"] == id1

    @pytest.mark.asyncio
    async def test_acheckpoint_chaining(self) -> None:
        """Async checkpoint path chains lineage identically to sync."""
        state = self._make_state()
        state._provider = JsonProvider()
        with tempfile.TemporaryDirectory() as d:
            await state.acheckpoint(d)
            id1 = state._checkpoint_id
            assert id1 is not None
            # Same post-write invariant as the sync test.
            assert state._parent_id == id1

            loc2 = await state.acheckpoint(d)
            id2 = state._checkpoint_id
            assert id2 is not None
            assert id2 != id1
            assert state._parent_id == id2

            with open(loc2, encoding="utf-8") as f:
                data2 = json.load(f)
            assert data2["parent_id"] == id1

    def _make_state(self) -> RuntimeState:
        """Wrap a single-agent, task-less crew in a new ``RuntimeState``."""
        from crewai import Agent, Crew

        agent = Agent(role="r", goal="g", backstory="b", llm="gpt-4o-mini")
        crew = Crew(agents=[agent], tasks=[], verbose=False)
        return RuntimeState(root=[crew])
|
|
|
|
|
|
# ---------- SqliteProvider forking ----------
|
|
|
|
|
|
class TestSqliteProviderFork:
    """Checks branch/parent bookkeeping and pruning for ``SqliteProvider``."""

    @staticmethod
    def _fetchone(db: str, sql: str, params: tuple[Any, ...] = ()) -> Any:
        """Run one query against the database file ``db`` and return its first row.

        A ``sqlite3.Connection`` used as a context manager only commits or
        rolls back; it does not close the connection. The connection is
        closed explicitly so no handle on the database file outlives the
        enclosing ``TemporaryDirectory``. An open handle would block the
        directory's cleanup on Windows.
        """
        conn = sqlite3.connect(db)
        try:
            return conn.execute(sql, params).fetchone()
        finally:
            conn.close()

    def test_checkpoint_stores_branch_and_parent(self) -> None:
        provider = SqliteProvider()
        with tempfile.TemporaryDirectory() as d:
            db = os.path.join(d, "cp.db")
            loc = provider.checkpoint("{}", db, parent_id="p1", branch="exp")
            cid = provider.extract_id(loc)

            row = self._fetchone(
                db,
                "SELECT parent_id, branch FROM checkpoints WHERE id = ?",
                (cid,),
            )
            assert row == ("p1", "exp")

    def test_prune_branch_aware(self) -> None:
        provider = SqliteProvider()
        with tempfile.TemporaryDirectory() as d:
            db = os.path.join(d, "cp.db")
            for _ in range(3):
                provider.checkpoint("{}", db, branch="main")
            for _ in range(2):
                provider.checkpoint("{}", db, branch="fork/a")

            provider.prune(db, max_keep=1, branch="main")

            count_sql = "SELECT COUNT(*) FROM checkpoints WHERE branch = ?"
            assert self._fetchone(db, count_sql, ("main",))[0] == 1
            # Pruning "main" must leave the fork branch untouched.
            assert self._fetchone(db, count_sql, ("fork/a",))[0] == 2

    def test_extract_id(self) -> None:
        provider = SqliteProvider()
        assert provider.extract_id("/path/to/db#abc123") == "abc123"

    def test_checkpoint_chaining_sqlite(self) -> None:
        state = self._make_state()
        state._provider = SqliteProvider()
        with tempfile.TemporaryDirectory() as d:
            db = os.path.join(d, "cp.db")
            state.checkpoint(db)
            id1 = state._checkpoint_id
            assert id1 is not None

            state.checkpoint(db)
            id2 = state._checkpoint_id
            assert id2 != id1

            # Second row should have parent_id == id1
            row = self._fetchone(
                db, "SELECT parent_id FROM checkpoints WHERE id = ?", (id2,)
            )
            assert row is not None
            assert row[0] == id1

    def _make_state(self) -> RuntimeState:
        """Wrap a single-agent, task-less crew in a new ``RuntimeState``."""
        from crewai import Agent, Crew

        agent = Agent(role="r", goal="g", backstory="b", llm="gpt-4o-mini")
        crew = Crew(agents=[agent], tasks=[], verbose=False)
        return RuntimeState(root=[crew])
|