mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-05 01:02:37 +00:00
Some checks failed
The test_older_than tests in both JSON and SQLite prune suites used hardcoded 2026-04-17 timestamps for the 'new' checkpoint. Once that date passes, the checkpoint is older than 1 day and gets pruned along with the 'old' one, causing assert count >= 1 to fail (count=0). Use 2099-01-01 for the 'new' checkpoint so tests remain stable. Co-authored-by: Joao Moura <joaomdmoura@gmail.com>
403 lines
14 KiB
Python
403 lines
14 KiB
Python
"""Tests for checkpoint CLI commands."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import os
|
|
import sqlite3
|
|
import tempfile
|
|
import time
|
|
from datetime import datetime, timedelta, timezone
|
|
from typing import Any
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
from crewai.cli.checkpoint_cli import (
|
|
_parse_checkpoint_json,
|
|
_parse_duration,
|
|
_prune_json,
|
|
_prune_sqlite,
|
|
_resolve_checkpoint,
|
|
_task_list_from_meta,
|
|
diff_checkpoints,
|
|
prune_checkpoints,
|
|
resume_checkpoint,
|
|
)
|
|
|
|
|
|
def _make_checkpoint_data(
    tasks_completed: int = 2,
    tasks_total: int = 4,
    trigger: str = "task_completed",
    branch: str = "main",
    parent_id: str | None = None,
    entity_type: str = "crew",
    name: str = "test_crew",
    inputs: dict[str, Any] | None = None,
) -> str:
    """Build a serialized checkpoint payload for tests.

    The payload contains a single entity with ``tasks_total`` tasks. The first
    ``tasks_completed`` tasks carry an ``{"raw": ...}`` output; the remaining
    ones have ``output`` set to ``None``. The event record always holds three
    empty nodes (``node_0`` .. ``node_2``).

    Args:
        tasks_completed: How many leading tasks are marked as finished.
        tasks_total: Total number of tasks on the entity.
        trigger: Value stored under the top-level ``trigger`` key.
        branch: Value stored under the top-level ``branch`` key.
        parent_id: Value stored under the top-level ``parent_id`` key.
        entity_type: ``entity_type`` of the single entity.
        name: ``name`` of the single entity.
        inputs: Stored as ``checkpoint_inputs``; ``None`` becomes ``{}``.

    Returns:
        The checkpoint encoded as a JSON string.
    """
    task_entries: list[dict[str, Any]] = [
        {
            "description": f"Task {n} description",
            "expected_output": f"Output {n}",
            "output": (
                {"raw": f"Result of task {n}"} if n <= tasks_completed else None
            ),
        }
        for n in range(1, tasks_total + 1)
    ]

    entity: dict[str, Any] = {
        "entity_type": entity_type,
        "name": name,
        "id": "abc12345-1234-1234-1234-abcdef012345",
        "tasks": task_entries,
        "agents": [],
        "checkpoint_inputs": inputs or {},
    }

    payload: dict[str, Any] = {
        "entities": [entity],
        "event_record": {"nodes": {f"node_{k}": {} for k in range(3)}},
        "trigger": trigger,
        "branch": branch,
        "parent_id": parent_id,
    }
    return json.dumps(payload)
|
|
|
|
|
|
def _write_json_checkpoint(
    base_dir: str,
    branch: str = "main",
    name: str | None = None,
    data: str | None = None,
    tasks_completed: int = 2,
    inputs: dict[str, Any] | None = None,
) -> str:
    """Write a JSON checkpoint file to ``<base_dir>/<branch>/<name>`` for tests.

    Args:
        base_dir: Root checkpoint directory. The branch subdirectory is
            created if it does not exist.
        branch: Branch subdirectory name. When ``data`` is not supplied it is
            also recorded in the generated payload, so the file's contents
            agree with the directory it lives in.
        name: File name. Defaults to ``<UTC %Y%m%dT%H%M%S>_abcd1234_p-none.json``.
        data: Raw file contents. Defaults to a payload from
            ``_make_checkpoint_data``.
        tasks_completed: Forwarded to ``_make_checkpoint_data`` when ``data``
            is ``None``.
        inputs: Forwarded to ``_make_checkpoint_data`` when ``data`` is ``None``.

    Returns:
        The path of the written file.
    """
    branch_dir = os.path.join(base_dir, branch)
    os.makedirs(branch_dir, exist_ok=True)
    if name is None:
        # Second-resolution timestamp: two default-named writes within the same
        # second target (and overwrite) the same file.
        ts = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%S")
        name = f"{ts}_abcd1234_p-none.json"
    path = os.path.join(branch_dir, name)
    if data is None:
        # Forward ``branch`` the same way ``_create_sqlite_checkpoint`` does, so
        # a file under ``feature/`` does not claim to be on ``main``.
        data = _make_checkpoint_data(
            tasks_completed=tasks_completed, branch=branch, inputs=inputs
        )
    # Explicit encoding: the default is locale-dependent.
    with open(path, "w", encoding="utf-8") as f:
        f.write(data)
    return path
|
|
|
|
|
|
def _create_sqlite_checkpoint(
    db_path: str,
    checkpoint_id: str | None = None,
    data: str | None = None,
    tasks_completed: int = 2,
    branch: str = "main",
    inputs: dict[str, Any] | None = None,
) -> str:
    """Insert one checkpoint row into a SQLite test database.

    Creates the ``checkpoints`` table if it does not exist. The row's
    ``created_at`` is the part of ``checkpoint_id`` before the first ``_``
    (e.g. ``20260101T000000``), and ``parent_id`` is always ``NULL``.

    Args:
        db_path: Path of the SQLite database file (created if missing).
        checkpoint_id: Row id. Defaults to ``<UTC %Y%m%dT%H%M%S>_abcd1234``.
        data: JSON payload text. Defaults to ``_make_checkpoint_data`` output.
        tasks_completed: Forwarded to ``_make_checkpoint_data`` when ``data``
            is ``None``.
        branch: Stored in the ``branch`` column and in the generated payload.
        inputs: Forwarded to ``_make_checkpoint_data`` when ``data`` is ``None``.

    Returns:
        The checkpoint id that was inserted.
    """
    if checkpoint_id is None:
        ts = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%S")
        checkpoint_id = f"{ts}_abcd1234"
    if data is None:
        data = _make_checkpoint_data(
            tasks_completed=tasks_completed, branch=branch, inputs=inputs
        )
    # ``sqlite3.Connection`` as a context manager only commits / rolls back; it
    # does NOT close the connection. Close it explicitly so the DB file handle
    # is released before the caller's temporary directory is removed.
    conn = sqlite3.connect(db_path)
    try:
        with conn:  # commits on success, rolls back on exception
            conn.execute(
                """CREATE TABLE IF NOT EXISTS checkpoints (
                id TEXT PRIMARY KEY,
                created_at TEXT NOT NULL,
                parent_id TEXT,
                branch TEXT NOT NULL DEFAULT 'main',
                data JSONB NOT NULL
            )"""
            )
            # jsonb() stores the payload in SQLite's binary JSON format
            # (requires SQLite >= 3.45).
            conn.execute(
                "INSERT INTO checkpoints (id, created_at, parent_id, branch, data) "
                "VALUES (?, ?, ?, ?, jsonb(?))",
                (checkpoint_id, checkpoint_id.split("_")[0], None, branch, data),
            )
    finally:
        conn.close()
    return checkpoint_id
|
|
|
|
|
|
class TestParseDuration:
    """``_parse_duration`` converts ``<number><unit>`` strings to timedeltas."""

    def test_days(self) -> None:
        expected = timedelta(days=7)
        assert _parse_duration("7d") == expected

    def test_hours(self) -> None:
        expected = timedelta(hours=24)
        assert _parse_duration("24h") == expected

    def test_minutes(self) -> None:
        expected = timedelta(minutes=30)
        assert _parse_duration("30m") == expected

    def test_invalid_raises(self) -> None:
        # Entirely non-numeric input must be rejected.
        with pytest.raises(Exception):
            _parse_duration("abc")

    def test_no_unit_raises(self) -> None:
        # A bare number without a unit suffix must be rejected.
        with pytest.raises(Exception):
            _parse_duration("7")
|
|
|
|
|
|
class TestResolveCheckpoint:
    """``_resolve_checkpoint`` finds a checkpoint by id, or the latest for ``None``."""

    def test_json_latest(self) -> None:
        with tempfile.TemporaryDirectory() as base:
            _write_json_checkpoint(base, name="20260101T000000_aaaa1111_p-none.json")
            # Brief pause so the two files do not share a modification time.
            time.sleep(0.01)
            newest = _write_json_checkpoint(
                base, name="20260102T000000_bbbb2222_p-none.json", tasks_completed=3
            )
            resolved = _resolve_checkpoint(base, None)
            assert resolved is not None
            assert resolved["path"] == newest

    def test_json_by_id(self) -> None:
        with tempfile.TemporaryDirectory() as base:
            for fname in (
                "20260101T000000_aaaa1111_p-none.json",
                "20260102T000000_bbbb2222_p-none.json",
            ):
                _write_json_checkpoint(base, name=fname)
            resolved = _resolve_checkpoint(base, "aaaa1111")
            assert resolved is not None
            assert "aaaa1111" in resolved["name"]

    def test_json_not_found(self) -> None:
        with tempfile.TemporaryDirectory() as base:
            _write_json_checkpoint(base)
            missing = _resolve_checkpoint(base, "nonexistent")
            assert missing is None

    def test_sqlite_latest(self) -> None:
        with tempfile.TemporaryDirectory() as base:
            db = os.path.join(base, "test.db")
            _create_sqlite_checkpoint(db, "20260101T000000_aaaa1111")
            _create_sqlite_checkpoint(db, "20260102T000000_bbbb2222", tasks_completed=3)
            resolved = _resolve_checkpoint(db, None)
            assert resolved is not None
            assert "bbbb2222" in resolved["name"]

    def test_sqlite_by_id(self) -> None:
        with tempfile.TemporaryDirectory() as base:
            db = os.path.join(base, "test.db")
            for cid in ("20260101T000000_aaaa1111", "20260102T000000_bbbb2222"):
                _create_sqlite_checkpoint(db, cid)
            resolved = _resolve_checkpoint(db, "20260101T000000_aaaa1111")
            assert resolved is not None
            assert "aaaa1111" in resolved["name"]

    def test_sqlite_partial_id(self) -> None:
        with tempfile.TemporaryDirectory() as base:
            db = os.path.join(base, "test.db")
            for cid in ("20260101T000000_aaaa1111", "20260102T000000_bbbb2222"):
                _create_sqlite_checkpoint(db, cid)
            # Only the suffix after the timestamp is given.
            resolved = _resolve_checkpoint(db, "aaaa1111")
            assert resolved is not None
            assert "aaaa1111" in resolved["name"]

    def test_nonexistent(self) -> None:
        resolved = _resolve_checkpoint("/nonexistent/path", None)
        assert resolved is None
|
|
|
|
|
|
class TestTaskListFromMeta:
    """``_task_list_from_meta`` flattens entity tasks and flags completion."""

    def test_flattens_tasks(self) -> None:
        raw = _make_checkpoint_data(tasks_completed=2, tasks_total=3)
        flat = _task_list_from_meta(_parse_checkpoint_json(raw, "test"))
        assert len(flat) == 3
        first, last = flat[0], flat[2]
        assert first["completed"] is True
        assert last["completed"] is False

    def test_empty_entities(self) -> None:
        no_entities: dict[str, Any] = {"entities": []}
        assert _task_list_from_meta(no_entities) == []
|
|
|
|
|
|
class TestDiffCheckpoints:
    """``diff_checkpoints`` prints a comparison of two checkpoints to stdout."""

    def test_diff_shows_status_change(self, capsys: pytest.CaptureFixture[str]) -> None:
        with tempfile.TemporaryDirectory() as base:
            for fname, done in (
                ("20260101T000000_aaaa1111_p-none.json", 1),
                ("20260102T000000_bbbb2222_p-none.json", 3),
            ):
                _write_json_checkpoint(base, name=fname, tasks_completed=done)
            diff_checkpoints(base, "aaaa1111", "bbbb2222")
            printed = capsys.readouterr().out
            assert "---" in printed
            assert "+++" in printed
            assert "status:" in printed or "pending -> done" in printed

    def test_diff_shows_output_change(self, capsys: pytest.CaptureFixture[str]) -> None:
        with tempfile.TemporaryDirectory() as base:
            before = _make_checkpoint_data(tasks_completed=2)
            # Parse a private copy and change only the first task's output.
            after = json.loads(before)
            after["entities"][0]["tasks"][0]["output"]["raw"] = "Updated result"
            _write_json_checkpoint(
                base, name="20260101T000000_aaaa1111_p-none.json", data=before
            )
            _write_json_checkpoint(
                base,
                name="20260102T000000_bbbb2222_p-none.json",
                data=json.dumps(after),
            )
            diff_checkpoints(base, "aaaa1111", "bbbb2222")
            printed = capsys.readouterr().out
            assert "output:" in printed

    def test_diff_not_found(self, capsys: pytest.CaptureFixture[str]) -> None:
        with tempfile.TemporaryDirectory() as base:
            _write_json_checkpoint(base, name="20260101T000000_aaaa1111_p-none.json")
            diff_checkpoints(base, "aaaa1111", "nonexistent")
            printed = capsys.readouterr().out
            assert "not found" in printed

    def test_diff_input_change(self, capsys: pytest.CaptureFixture[str]) -> None:
        with tempfile.TemporaryDirectory() as base:
            for fname, topic in (
                ("20260101T000000_aaaa1111_p-none.json", "AI"),
                ("20260102T000000_bbbb2222_p-none.json", "ML"),
            ):
                _write_json_checkpoint(base, name=fname, inputs={"topic": topic})
            diff_checkpoints(base, "aaaa1111", "bbbb2222")
            printed = capsys.readouterr().out
            for needle in ("Inputs:", "AI", "ML"):
                assert needle in printed
|
|
|
|
|
|
class TestPruneJson:
    """``_prune_json`` removes checkpoint files by count or by age."""

    def test_keep_n(self) -> None:
        with tempfile.TemporaryDirectory() as base:
            for idx in range(5):
                _write_json_checkpoint(
                    base, name=f"2026010{idx + 1}T000000_aaa{idx}1111_p-none.json"
                )
                # Space the writes out so modification times differ.
                time.sleep(0.01)
            assert _prune_json(base, keep=2, older_than=None) == 3
            leftover = sum(len(files) for _, _, files in os.walk(base))
            assert leftover == 2

    def test_older_than(self) -> None:
        with tempfile.TemporaryDirectory() as base:
            stale = _write_json_checkpoint(
                base, name="20250101T000000_old01111_p-none.json"
            )
            # Back-date access/modification times to the Unix epoch.
            os.utime(stale, (0, 0))
            # Far-future name; the file itself has a fresh mtime.
            _write_json_checkpoint(base, name="20990101T000000_new01111_p-none.json")
            assert _prune_json(base, keep=None, older_than=timedelta(days=1)) == 1

    def test_empty_dir(self) -> None:
        with tempfile.TemporaryDirectory() as base:
            removed = _prune_json(base, keep=2, older_than=None)
            assert removed == 0

    def test_removes_empty_branch_dirs(self) -> None:
        with tempfile.TemporaryDirectory() as base:
            lone = _write_json_checkpoint(
                base,
                branch="feature",
                name="20260101T000000_aaaa1111_p-none.json",
            )
            os.utime(lone, (0, 0))
            _prune_json(base, keep=None, older_than=timedelta(days=1))
            feature_dir = os.path.join(base, "feature")
            assert not os.path.exists(feature_dir)
|
|
|
|
|
|
class TestPruneSqlite:
    """``_prune_sqlite`` deletes rows from the ``checkpoints`` table."""

    @staticmethod
    def _row_count(db_path: str) -> int:
        """Return the row count of ``checkpoints`` and close the connection.

        ``sqlite3.Connection`` used as a context manager does not close
        itself, so the connection is closed explicitly here.
        """
        conn = sqlite3.connect(db_path)
        try:
            return conn.execute("SELECT COUNT(*) FROM checkpoints").fetchone()[0]
        finally:
            conn.close()

    def test_keep_n(self) -> None:
        with tempfile.TemporaryDirectory() as d:
            db_path = os.path.join(d, "test.db")
            for i in range(5):
                _create_sqlite_checkpoint(
                    db_path, f"2026010{i + 1}T000000_aaa{i}1111"
                )
            deleted = _prune_sqlite(db_path, keep=2, older_than=None)
            assert deleted == 3
            assert self._row_count(db_path) == 2

    def test_older_than(self) -> None:
        with tempfile.TemporaryDirectory() as d:
            db_path = os.path.join(d, "test.db")
            _create_sqlite_checkpoint(db_path, "20200101T000000_old01111")
            # Far-future timestamp so this row never ages past the cutoff.
            _create_sqlite_checkpoint(db_path, "20990101T000000_new01111")
            deleted = _prune_sqlite(db_path, keep=None, older_than=timedelta(days=1))
            # Exactly the 2020 row is older than one day; the 2099 row survives
            # (matches the JSON variant's ``deleted == 1``).
            assert deleted == 1
            assert self._row_count(db_path) == 1
|
|
|
|
|
|
class TestPruneCommand:
    """``prune_checkpoints`` CLI entry point: option checks and dry runs."""

    def test_no_options_shows_help(self, capsys: pytest.CaptureFixture[str]) -> None:
        with tempfile.TemporaryDirectory() as base:
            # Neither --keep nor --older-than given.
            prune_checkpoints(base, keep=None, older_than=None)
            assert "Specify" in capsys.readouterr().out

    def test_dry_run_json(self, capsys: pytest.CaptureFixture[str]) -> None:
        with tempfile.TemporaryDirectory() as base:
            _write_json_checkpoint(base)
            prune_checkpoints(base, keep=1, older_than=None, dry_run=True)
            assert "Would prune" in capsys.readouterr().out

    def test_not_found(self, capsys: pytest.CaptureFixture[str]) -> None:
        prune_checkpoints("/nonexistent", keep=1, older_than=None)
        assert "Not a directory" in capsys.readouterr().out
|
|
|
|
|
|
class TestResumeCheckpoint:
    """``resume_checkpoint`` reports missing ids and empty locations."""

    def test_not_found(self, capsys: pytest.CaptureFixture[str]) -> None:
        with tempfile.TemporaryDirectory() as base:
            resume_checkpoint(base, "nonexistent")
            assert "not found" in capsys.readouterr().out

    def test_no_checkpoints(self, capsys: pytest.CaptureFixture[str]) -> None:
        with tempfile.TemporaryDirectory() as base:
            # Empty directory and no id: nothing to resume.
            resume_checkpoint(base, None)
            assert "No checkpoints" in capsys.readouterr().out
|
|
|
|
|
|
class TestDiscoverabilityMessage:
    """After writing a checkpoint, the listener logs how to resume from it."""

    def test_checkpoint_listener_logs_resume_hint(self) -> None:
        from crewai.state.checkpoint_listener import _do_checkpoint
        from crewai.state.runtime import RuntimeState

        checkpoint_id = "20260101T000000_test1234"

        # Minimal runtime state: no entities, empty event record, on main.
        state = MagicMock(spec=RuntimeState)
        state.root = []
        state.model_dump.return_value = {"entities": [], "event_record": {"nodes": {}}}
        state._parent_id = None
        state._branch = "main"

        # Config whose provider pretends to have written the checkpoint.
        cfg = MagicMock()
        cfg.location = "/tmp/cp"
        cfg.max_checkpoints = None
        cfg.provider.checkpoint.return_value = (
            f"/tmp/cp/main/{checkpoint_id}_p-none.json"
        )
        cfg.provider.extract_id.return_value = checkpoint_id

        with (
            patch("crewai.state.checkpoint_listener._prepare_entities"),
            patch("crewai.state.checkpoint_listener.logger") as mock_logger,
        ):
            _do_checkpoint(state, cfg)

        cfg.provider.extract_id.assert_called_once()
        mock_logger.info.assert_called_once()
        message: str = mock_logger.info.call_args[0][0]
        assert "crewai checkpoint resume" in message
        assert checkpoint_id in message
|