refactor: checkpoint API cleanup

This commit is contained in:
Greyson LaLonde
2026-04-08 01:13:23 +08:00
committed by GitHub
parent 9325e2f6a4
commit 5958a16ade
11 changed files with 119 additions and 126 deletions

View File

@@ -39,7 +39,7 @@ crew = Crew(
agents=[...],
tasks=[...],
checkpoint=CheckpointConfig(
directory="./my_checkpoints",
location="./my_checkpoints",
on_events=["task_completed", "crew_kickoff_completed"],
max_checkpoints=5,
),
@@ -50,7 +50,7 @@ crew = Crew(
| الحقل | النوع | الافتراضي | الوصف |
|:------|:------|:----------|:------|
| `directory` | `str` | `"./.checkpoints"` | مسار ملفات نقاط الحفظ |
| `location` | `str` | `"./.checkpoints"` | وجهة التخزين — مجلد عند استخدام `JsonProvider`، ومسار ملف قاعدة البيانات عند استخدام `SqliteProvider` |
| `on_events` | `list[str]` | `["task_completed"]` | انواع الاحداث التي تطلق نقطة حفظ |
| `provider` | `BaseProvider` | `JsonProvider()` | واجهة التخزين |
| `max_checkpoints` | `int \| None` | `None` | الحد الاقصى لنقاط الحفظ المحتفظ بها؛ يتم حذف الاقدم اولا بعد كل عملية كتابة |
@@ -95,7 +95,7 @@ result = crew.kickoff() # يستأنف من اخر مهمة مكتملة
crew = Crew(
agents=[researcher, writer],
tasks=[research_task, write_task, review_task],
checkpoint=CheckpointConfig(directory="./crew_cp"),
checkpoint=CheckpointConfig(location="./crew_cp"),
)
```
@@ -118,7 +118,7 @@ class MyFlow(Flow):
flow = MyFlow(
checkpoint=CheckpointConfig(
directory="./flow_cp",
location="./flow_cp",
on_events=["method_execution_finished"],
),
)
@@ -137,7 +137,7 @@ agent = Agent(
goal="Research topics",
backstory="Expert researcher",
checkpoint=CheckpointConfig(
directory="./agent_cp",
location="./agent_cp",
on_events=["lite_agent_execution_completed"],
),
)
@@ -160,7 +160,7 @@ crew = Crew(
agents=[...],
tasks=[...],
checkpoint=CheckpointConfig(
directory="./my_checkpoints",
location="./my_checkpoints",
provider=JsonProvider(),
max_checkpoints=5,
),
@@ -179,15 +179,12 @@ crew = Crew(
agents=[...],
tasks=[...],
checkpoint=CheckpointConfig(
directory="./.checkpoints.db",
provider=SqliteProvider(max_checkpoints=50),
location="./.checkpoints.db",
provider=SqliteProvider(),
),
)
```
<Note>
عند استخدام `SqliteProvider`، حقل `directory` هو مسار ملف قاعدة البيانات، وليس مجلدا.
</Note>
## انواع الاحداث

View File

@@ -39,7 +39,7 @@ crew = Crew(
agents=[...],
tasks=[...],
checkpoint=CheckpointConfig(
directory="./my_checkpoints",
location="./my_checkpoints",
on_events=["task_completed", "crew_kickoff_completed"],
max_checkpoints=5,
),
@@ -50,10 +50,10 @@ crew = Crew(
| Field | Type | Default | Description |
|:------|:-----|:--------|:------------|
| `directory` | `str` | `"./.checkpoints"` | Filesystem path for checkpoint files |
| `location` | `str` | `"./.checkpoints"` | Storage destination — a directory for `JsonProvider`, a database file path for `SqliteProvider` |
| `on_events` | `list[str]` | `["task_completed"]` | Event types that trigger a checkpoint |
| `provider` | `BaseProvider` | `JsonProvider()` | Storage backend |
| `max_checkpoints` | `int \| None` | `None` | Max files to keep; oldest pruned first |
| `max_checkpoints` | `int \| None` | `None` | Max checkpoints to keep. Oldest are pruned after each write. Pruning is handled by the provider. |
### Inheritance and Opt-Out
@@ -95,7 +95,7 @@ The restored crew skips already-completed tasks and resumes from the first incom
crew = Crew(
agents=[researcher, writer],
tasks=[research_task, write_task, review_task],
checkpoint=CheckpointConfig(directory="./crew_cp"),
checkpoint=CheckpointConfig(location="./crew_cp"),
)
```
@@ -118,7 +118,7 @@ class MyFlow(Flow):
flow = MyFlow(
checkpoint=CheckpointConfig(
directory="./flow_cp",
location="./flow_cp",
on_events=["method_execution_finished"],
),
)
@@ -137,7 +137,7 @@ agent = Agent(
goal="Research topics",
backstory="Expert researcher",
checkpoint=CheckpointConfig(
directory="./agent_cp",
location="./agent_cp",
on_events=["lite_agent_execution_completed"],
),
)
@@ -160,14 +160,14 @@ crew = Crew(
agents=[...],
tasks=[...],
checkpoint=CheckpointConfig(
directory="./my_checkpoints",
location="./my_checkpoints",
provider=JsonProvider(), # this is the default
max_checkpoints=5, # prunes oldest files
),
)
```
Files are named `<timestamp>_<uuid>.json` inside the directory.
Files are named `<timestamp>_<uuid>.json` inside the `location` directory.
### SqliteProvider
@@ -181,17 +181,14 @@ crew = Crew(
agents=[...],
tasks=[...],
checkpoint=CheckpointConfig(
directory="./.checkpoints.db",
provider=SqliteProvider(max_checkpoints=50),
location="./.checkpoints.db",
provider=SqliteProvider(),
max_checkpoints=50,
),
)
```
`SqliteProvider` accepts its own `max_checkpoints` parameter that prunes old rows via SQL. WAL journal mode is enabled for concurrent read access.
<Note>
When using `SqliteProvider`, the `directory` field is the database file path, not a directory. The `max_checkpoints` on `CheckpointConfig` controls filesystem pruning (for `JsonProvider`), while `SqliteProvider.max_checkpoints` controls row pruning in the database.
</Note>
WAL journal mode is enabled for concurrent read access.
## Event Types

View File

@@ -39,7 +39,7 @@ crew = Crew(
agents=[...],
tasks=[...],
checkpoint=CheckpointConfig(
directory="./my_checkpoints",
location="./my_checkpoints",
on_events=["task_completed", "crew_kickoff_completed"],
max_checkpoints=5,
),
@@ -50,7 +50,7 @@ crew = Crew(
| 필드 | 타입 | 기본값 | 설명 |
|:-----|:-----|:-------|:-----|
| `directory` | `str` | `"./.checkpoints"` | 체크포인트 파일 경로 |
| `location` | `str` | `"./.checkpoints"` | 저장 위치 — `JsonProvider`는 디렉토리, `SqliteProvider`는 데이터베이스 파일 경로 |
| `on_events` | `list[str]` | `["task_completed"]` | 체크포인트를 트리거하는 이벤트 타입 |
| `provider` | `BaseProvider` | `JsonProvider()` | 스토리지 백엔드 |
| `max_checkpoints` | `int \| None` | `None` | 보관할 최대 체크포인트 수; 쓰기 후 오래된 것부터 삭제 |
@@ -95,7 +95,7 @@ result = crew.kickoff() # 마지막으로 완료된 태스크부터 재개
crew = Crew(
agents=[researcher, writer],
tasks=[research_task, write_task, review_task],
checkpoint=CheckpointConfig(directory="./crew_cp"),
checkpoint=CheckpointConfig(location="./crew_cp"),
)
```
@@ -118,7 +118,7 @@ class MyFlow(Flow):
flow = MyFlow(
checkpoint=CheckpointConfig(
directory="./flow_cp",
location="./flow_cp",
on_events=["method_execution_finished"],
),
)
@@ -137,7 +137,7 @@ agent = Agent(
goal="Research topics",
backstory="Expert researcher",
checkpoint=CheckpointConfig(
directory="./agent_cp",
location="./agent_cp",
on_events=["lite_agent_execution_completed"],
),
)
@@ -160,7 +160,7 @@ crew = Crew(
agents=[...],
tasks=[...],
checkpoint=CheckpointConfig(
directory="./my_checkpoints",
location="./my_checkpoints",
provider=JsonProvider(),
max_checkpoints=5,
),
@@ -179,15 +179,12 @@ crew = Crew(
agents=[...],
tasks=[...],
checkpoint=CheckpointConfig(
directory="./.checkpoints.db",
provider=SqliteProvider(max_checkpoints=50),
location="./.checkpoints.db",
provider=SqliteProvider(),
),
)
```
<Note>
`SqliteProvider`를 사용할 때 `directory` 필드는 디렉토리가 아닌 데이터베이스 파일 경로입니다.
</Note>
## 이벤트 타입

View File

@@ -39,7 +39,7 @@ crew = Crew(
agents=[...],
tasks=[...],
checkpoint=CheckpointConfig(
directory="./my_checkpoints",
location="./my_checkpoints",
on_events=["task_completed", "crew_kickoff_completed"],
max_checkpoints=5,
),
@@ -50,7 +50,7 @@ crew = Crew(
| Campo | Tipo | Padrao | Descricao |
|:------|:-----|:-------|:----------|
| `directory` | `str` | `"./.checkpoints"` | Caminho para os arquivos de checkpoint |
| `location` | `str` | `"./.checkpoints"` | Destino de armazenamento — um diretorio para `JsonProvider`, um caminho de arquivo de banco de dados para `SqliteProvider` |
| `on_events` | `list[str]` | `["task_completed"]` | Tipos de evento que acionam um checkpoint |
| `provider` | `BaseProvider` | `JsonProvider()` | Backend de armazenamento |
| `max_checkpoints` | `int \| None` | `None` | Maximo de checkpoints a manter; os mais antigos sao removidos apos cada gravacao |
@@ -95,7 +95,7 @@ A crew restaurada pula tarefas ja concluidas e retoma a partir da primeira incom
crew = Crew(
agents=[researcher, writer],
tasks=[research_task, write_task, review_task],
checkpoint=CheckpointConfig(directory="./crew_cp"),
checkpoint=CheckpointConfig(location="./crew_cp"),
)
```
@@ -118,7 +118,7 @@ class MyFlow(Flow):
flow = MyFlow(
checkpoint=CheckpointConfig(
directory="./flow_cp",
location="./flow_cp",
on_events=["method_execution_finished"],
),
)
@@ -137,7 +137,7 @@ agent = Agent(
goal="Research topics",
backstory="Expert researcher",
checkpoint=CheckpointConfig(
directory="./agent_cp",
location="./agent_cp",
on_events=["lite_agent_execution_completed"],
),
)
@@ -160,7 +160,7 @@ crew = Crew(
agents=[...],
tasks=[...],
checkpoint=CheckpointConfig(
directory="./my_checkpoints",
location="./my_checkpoints",
provider=JsonProvider(),
max_checkpoints=5,
),
@@ -179,15 +179,12 @@ crew = Crew(
agents=[...],
tasks=[...],
checkpoint=CheckpointConfig(
directory="./.checkpoints.db",
provider=SqliteProvider(max_checkpoints=50),
location="./.checkpoints.db",
provider=SqliteProvider(),
),
)
```
<Note>
Ao usar `SqliteProvider`, o campo `directory` e o caminho do arquivo de banco de dados, nao um diretorio.
</Note>
## Tipos de Evento

View File

@@ -165,9 +165,10 @@ class CheckpointConfig(BaseModel):
automatically whenever the specified event(s) fire.
"""
directory: str = Field(
location: str = Field(
default="./.checkpoints",
description="Filesystem path where checkpoint JSON files are written.",
description="Storage destination. For JsonProvider this is a directory "
"path; for SqliteProvider it is a database file path.",
)
on_events: list[CheckpointEventType | Literal["*"]] = Field(
default=["task_completed"],
@@ -180,8 +181,8 @@ class CheckpointConfig(BaseModel):
)
max_checkpoints: int | None = Field(
default=None,
description="Maximum checkpoint files to keep. Oldest are pruned first. "
"None means keep all.",
description="Maximum checkpoints to keep. Oldest are pruned after "
"each write. None means keep all.",
)
@property

View File

@@ -7,9 +7,7 @@ avoids per-event overhead when no entity uses checkpointing.
from __future__ import annotations
import glob
import logging
import os
import threading
from typing import Any
@@ -105,29 +103,13 @@ def _find_checkpoint(source: Any) -> CheckpointConfig | None:
def _do_checkpoint(state: RuntimeState, cfg: CheckpointConfig) -> None:
"""Write a checkpoint synchronously and optionally prune old files."""
"""Write a checkpoint and prune old ones if configured."""
_prepare_entities(state.root)
data = state.model_dump_json()
cfg.provider.checkpoint(data, cfg.directory)
cfg.provider.checkpoint(data, cfg.location)
if cfg.max_checkpoints is not None:
_prune(cfg.directory, cfg.max_checkpoints)
def _safe_remove(path: str) -> None:
try:
os.remove(path)
except OSError:
logger.debug("Failed to remove checkpoint file %s", path, exc_info=True)
def _prune(directory: str, max_keep: int) -> None:
"""Remove oldest checkpoint files beyond *max_keep*."""
pattern = os.path.join(directory, "*.json")
files = sorted(glob.glob(pattern), key=os.path.getmtime)
to_remove = files if max_keep == 0 else files[:-max_keep]
for path in to_remove:
_safe_remove(path)
cfg.provider.prune(cfg.location, cfg.max_checkpoints)
def _should_checkpoint(source: Any, event: BaseEvent) -> CheckpointConfig | None:

View File

@@ -34,27 +34,36 @@ class BaseProvider(Protocol):
),
)
def checkpoint(self, data: str, directory: str) -> str:
def checkpoint(self, data: str, location: str) -> str:
"""Persist a snapshot synchronously.
Args:
data: The serialized string to persist.
directory: Logical destination: path, bucket prefix, etc.
location: Storage destination (directory, file path, URI, etc.).
Returns:
A location identifier for the saved checkpoint, such as a file path or URI.
A location identifier for the saved checkpoint.
"""
...
async def acheckpoint(self, data: str, directory: str) -> str:
async def acheckpoint(self, data: str, location: str) -> str:
"""Persist a snapshot asynchronously.
Args:
data: The serialized string to persist.
directory: Logical destination: path, bucket prefix, etc.
location: Storage destination (directory, file path, URI, etc.).
Returns:
A location identifier for the saved checkpoint, such as a file path or URI.
A location identifier for the saved checkpoint.
"""
...
def prune(self, location: str, max_keep: int) -> None:
    """Discard old checkpoints so that no more than *max_keep* remain.

    Args:
        location: Storage destination, the same value that was handed
            to ``checkpoint``.
        max_keep: Upper bound on the number of checkpoints retained.
    """
    ...

View File

@@ -3,6 +3,9 @@
from __future__ import annotations
from datetime import datetime, timezone
import glob
import logging
import os
from pathlib import Path
import uuid
@@ -12,43 +15,56 @@ import aiofiles.os
from crewai.state.provider.core import BaseProvider
logger = logging.getLogger(__name__)
class JsonProvider(BaseProvider):
"""Persists runtime state checkpoints as JSON files on the local filesystem."""
def checkpoint(self, data: str, directory: str) -> str:
"""Write a JSON checkpoint file to the directory.
def checkpoint(self, data: str, location: str) -> str:
"""Write a JSON checkpoint file.
Args:
data: The serialized JSON string to persist.
directory: Filesystem path where the checkpoint will be saved.
location: Directory where the checkpoint will be saved.
Returns:
The path to the written checkpoint file.
"""
file_path = _build_path(directory)
file_path = _build_path(location)
file_path.parent.mkdir(parents=True, exist_ok=True)
with open(file_path, "w") as f:
f.write(data)
return str(file_path)
async def acheckpoint(self, data: str, directory: str) -> str:
"""Write a JSON checkpoint file to the directory asynchronously.
async def acheckpoint(self, data: str, location: str) -> str:
"""Write a JSON checkpoint file asynchronously.
Args:
data: The serialized JSON string to persist.
directory: Filesystem path where the checkpoint will be saved.
location: Directory where the checkpoint will be saved.
Returns:
The path to the written checkpoint file.
"""
file_path = _build_path(directory)
file_path = _build_path(location)
await aiofiles.os.makedirs(str(file_path.parent), exist_ok=True)
async with aiofiles.open(file_path, "w") as f:
await f.write(data)
return str(file_path)
def prune(self, location: str, max_keep: int) -> None:
    """Remove the oldest ``*.json`` checkpoint files beyond *max_keep*.

    Files directly inside *location* are ordered by modification time and
    everything except the newest *max_keep* is deleted. Deletion is
    best-effort: a file that cannot be stat'ed or removed is skipped and
    the remaining files are still processed.

    Args:
        location: Directory holding the checkpoint files. A missing
            directory is treated as empty.
        max_keep: Maximum number of checkpoint files to retain. ``0``
            removes every checkpoint file.

    Raises:
        ValueError: If *max_keep* is negative.
    """
    if max_keep < 0:
        # A negative value would otherwise turn into a prefix slice and
        # silently delete an arbitrary number of files.
        raise ValueError(f"max_keep must be >= 0, got {max_keep}")

    # Escape the directory part so glob metacharacters in the path
    # (e.g. "runs[1]") are matched literally instead of as a pattern.
    pattern = os.path.join(glob.escape(location), "*.json")

    stamped: list[tuple[float, str]] = []
    for path in glob.glob(pattern):
        try:
            stamped.append((os.path.getmtime(path), path))
        except OSError:
            # The file vanished between listing and stat (for example a
            # concurrent prune); there is nothing left to remove.
            continue
    # Oldest first; equal mtimes fall back to the path for a stable order.
    stamped.sort()

    excess = len(stamped) - max_keep
    if excess <= 0:
        return
    for _, path in stamped[:excess]:
        try:
            os.remove(path)
        except OSError:  # noqa: PERF203
            logger.debug("Failed to remove %s", path, exc_info=True)
def from_checkpoint(self, location: str) -> str:
"""Read a JSON checkpoint file.

View File

@@ -43,58 +43,53 @@ def _make_id() -> tuple[str, str]:
class SqliteProvider(BaseProvider):
"""Persists runtime state checkpoints in a SQLite database.
The ``directory`` argument to ``checkpoint`` / ``acheckpoint`` is
used as the database path (e.g. ``"./.checkpoints.db"``).
Args:
max_checkpoints: Maximum number of checkpoints to retain.
Oldest rows are pruned after each write. None keeps all.
The ``location`` argument to ``checkpoint`` / ``acheckpoint`` is
used as the database file path.
"""
def __init__(self, max_checkpoints: int | None = None) -> None:
self.max_checkpoints = max_checkpoints
def checkpoint(self, data: str, directory: str) -> str:
def checkpoint(self, data: str, location: str) -> str:
"""Write a checkpoint to the SQLite database.
Args:
data: The serialized JSON string to persist.
directory: Path to the SQLite database file.
location: Path to the SQLite database file.
Returns:
A location string in the format ``"db_path#checkpoint_id"``.
"""
checkpoint_id, ts = _make_id()
Path(directory).parent.mkdir(parents=True, exist_ok=True)
with sqlite3.connect(directory) as conn:
Path(location).parent.mkdir(parents=True, exist_ok=True)
with sqlite3.connect(location) as conn:
conn.execute("PRAGMA journal_mode=WAL")
conn.execute(_CREATE_TABLE)
conn.execute(_INSERT, (checkpoint_id, ts, data))
if self.max_checkpoints is not None:
conn.execute(_PRUNE, (self.max_checkpoints,))
conn.commit()
return f"{directory}#{checkpoint_id}"
return f"{location}#{checkpoint_id}"
async def acheckpoint(self, data: str, directory: str) -> str:
async def acheckpoint(self, data: str, location: str) -> str:
"""Write a checkpoint to the SQLite database asynchronously.
Args:
data: The serialized JSON string to persist.
directory: Path to the SQLite database file.
location: Path to the SQLite database file.
Returns:
A location string in the format ``"db_path#checkpoint_id"``.
"""
checkpoint_id, ts = _make_id()
Path(directory).parent.mkdir(parents=True, exist_ok=True)
async with aiosqlite.connect(directory) as db:
Path(location).parent.mkdir(parents=True, exist_ok=True)
async with aiosqlite.connect(location) as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute(_CREATE_TABLE)
await db.execute(_INSERT, (checkpoint_id, ts, data))
if self.max_checkpoints is not None:
await db.execute(_PRUNE, (self.max_checkpoints,))
await db.commit()
return f"{directory}#{checkpoint_id}"
return f"{location}#{checkpoint_id}"
def prune(self, location: str, max_keep: int) -> None:
    """Remove the oldest checkpoint rows beyond *max_keep*.

    Executes the module's ``_PRUNE`` statement with *max_keep* as its
    single parameter, commits, and closes the connection.

    Args:
        location: Path to the SQLite database file.
        max_keep: Maximum number of checkpoint rows to retain.
    """
    # Local import: this module's import block is not visible here.
    from contextlib import closing

    # ``with conn`` only scopes the transaction (commit on success,
    # rollback on error); ``closing`` is what releases the connection.
    with closing(sqlite3.connect(location)) as conn:
        with conn:
            conn.execute(_PRUNE, (max_keep,))
def from_checkpoint(self, location: str) -> str:
"""Read a checkpoint from the SQLite database.

View File

@@ -90,29 +90,31 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
return state
return handler(data)
def checkpoint(self, directory: str) -> str:
"""Write a checkpoint file to the directory.
def checkpoint(self, location: str) -> str:
"""Write a checkpoint.
Args:
directory: Filesystem path where the checkpoint JSON will be saved.
location: Storage destination. For JsonProvider this is a directory
path; for SqliteProvider it is a database file path.
Returns:
A location identifier for the saved checkpoint.
"""
_prepare_entities(self.root)
return self._provider.checkpoint(self.model_dump_json(), directory)
return self._provider.checkpoint(self.model_dump_json(), location)
async def acheckpoint(self, directory: str) -> str:
async def acheckpoint(self, location: str) -> str:
"""Async version of :meth:`checkpoint`.
Args:
directory: Filesystem path where the checkpoint JSON will be saved.
location: Storage destination. For JsonProvider this is a directory
path; for SqliteProvider it is a database file path.
Returns:
A location identifier for the saved checkpoint.
"""
_prepare_entities(self.root)
return await self._provider.acheckpoint(self.model_dump_json(), directory)
return await self._provider.acheckpoint(self.model_dump_json(), location)
@classmethod
def from_checkpoint(

View File

@@ -17,10 +17,10 @@ from crewai.flow.flow import Flow, start
from crewai.state.checkpoint_config import CheckpointConfig
from crewai.state.checkpoint_listener import (
_find_checkpoint,
_prune,
_resolve,
_SENTINEL,
)
from crewai.state.provider.json_provider import JsonProvider
from crewai.task import Task
@@ -37,10 +37,10 @@ class TestResolve:
def test_true_returns_config(self) -> None:
result = _resolve(True)
assert isinstance(result, CheckpointConfig)
assert result.directory == "./.checkpoints"
assert result.location == "./.checkpoints"
def test_config_returns_config(self) -> None:
cfg = CheckpointConfig(directory="/tmp/cp")
cfg = CheckpointConfig(location="/tmp/cp")
assert _resolve(cfg) is cfg
@@ -77,12 +77,12 @@ class TestFindCheckpoint:
def test_agent_config_overrides_crew(self) -> None:
a = self._make_agent(
checkpoint=CheckpointConfig(directory="/agent_cp")
checkpoint=CheckpointConfig(location="/agent_cp")
)
self._make_crew([a], checkpoint=True)
cfg = _find_checkpoint(a)
assert isinstance(cfg, CheckpointConfig)
assert cfg.directory == "/agent_cp"
assert cfg.location == "/agent_cp"
def test_task_inherits_from_crew(self) -> None:
a = self._make_agent()
@@ -123,7 +123,7 @@ class TestPrune:
# Ensure distinct mtime
time.sleep(0.01)
_prune(d, max_keep=2)
JsonProvider().prune(d, max_keep=2)
remaining = os.listdir(d)
assert len(remaining) == 2
assert "cp_3.json" in remaining
@@ -135,7 +135,7 @@ class TestPrune:
with open(os.path.join(d, f"cp_{i}.json"), "w") as f:
f.write("{}")
_prune(d, max_keep=0)
JsonProvider().prune(d, max_keep=0)
assert os.listdir(d) == []
def test_prune_more_than_existing(self) -> None:
@@ -143,7 +143,7 @@ class TestPrune:
with open(os.path.join(d, "cp.json"), "w") as f:
f.write("{}")
_prune(d, max_keep=10)
JsonProvider().prune(d, max_keep=10)
assert len(os.listdir(d)) == 1
@@ -153,7 +153,7 @@ class TestPrune:
class TestCheckpointConfig:
def test_defaults(self) -> None:
cfg = CheckpointConfig()
assert cfg.directory == "./.checkpoints"
assert cfg.location == "./.checkpoints"
assert cfg.on_events == ["task_completed"]
assert cfg.max_checkpoints is None
assert not cfg.trigger_all