feat: add SqliteProvider for checkpoint storage

This commit is contained in:
Greyson LaLonde
2026-04-07 05:54:05 +08:00
committed by GitHub
parent c4e2d7ea3b
commit 6b6e191532
2 changed files with 140 additions and 1 deletions

View File

@@ -1,4 +1,5 @@
from crewai.state.checkpoint_config import CheckpointConfig, CheckpointEventType
from crewai.state.provider.sqlite_provider import SqliteProvider
__all__ = ["CheckpointConfig", "CheckpointEventType"]
__all__ = ["CheckpointConfig", "CheckpointEventType", "SqliteProvider"]

View File

@@ -0,0 +1,138 @@
"""SQLite state provider for checkpointing."""
from __future__ import annotations
from datetime import datetime, timezone
from pathlib import Path
import sqlite3
import uuid
import aiosqlite
from crewai.state.provider.core import BaseProvider
_CREATE_TABLE = """
CREATE TABLE IF NOT EXISTS checkpoints (
id TEXT PRIMARY KEY,
created_at TEXT NOT NULL,
data TEXT NOT NULL
)
"""
_INSERT = "INSERT INTO checkpoints (id, created_at, data) VALUES (?, ?, ?)"
_SELECT = "SELECT data FROM checkpoints WHERE id = ?"
_PRUNE = """
DELETE FROM checkpoints WHERE rowid NOT IN (
SELECT rowid FROM checkpoints ORDER BY rowid DESC LIMIT ?
)
"""
def _make_id() -> tuple[str, str]:
"""Generate a checkpoint ID and ISO timestamp.
Returns:
A tuple of (checkpoint_id, timestamp).
"""
ts = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%S")
checkpoint_id = f"{ts}_{uuid.uuid4().hex[:8]}"
return checkpoint_id, ts
class SqliteProvider(BaseProvider):
"""Persists runtime state checkpoints in a SQLite database.
The ``directory`` argument to ``checkpoint`` / ``acheckpoint`` is
used as the database path (e.g. ``"./.checkpoints.db"``).
Args:
max_checkpoints: Maximum number of checkpoints to retain.
Oldest rows are pruned after each write. None keeps all.
"""
def __init__(self, max_checkpoints: int | None = None) -> None:
self.max_checkpoints = max_checkpoints
def checkpoint(self, data: str, directory: str) -> str:
"""Write a checkpoint to the SQLite database.
Args:
data: The serialized JSON string to persist.
directory: Path to the SQLite database file.
Returns:
A location string in the format ``"db_path#checkpoint_id"``.
"""
checkpoint_id, ts = _make_id()
Path(directory).parent.mkdir(parents=True, exist_ok=True)
with sqlite3.connect(directory) as conn:
conn.execute("PRAGMA journal_mode=WAL")
conn.execute(_CREATE_TABLE)
conn.execute(_INSERT, (checkpoint_id, ts, data))
if self.max_checkpoints is not None:
conn.execute(_PRUNE, (self.max_checkpoints,))
conn.commit()
return f"{directory}#{checkpoint_id}"
async def acheckpoint(self, data: str, directory: str) -> str:
"""Write a checkpoint to the SQLite database asynchronously.
Args:
data: The serialized JSON string to persist.
directory: Path to the SQLite database file.
Returns:
A location string in the format ``"db_path#checkpoint_id"``.
"""
checkpoint_id, ts = _make_id()
Path(directory).parent.mkdir(parents=True, exist_ok=True)
async with aiosqlite.connect(directory) as db:
await db.execute("PRAGMA journal_mode=WAL")
await db.execute(_CREATE_TABLE)
await db.execute(_INSERT, (checkpoint_id, ts, data))
if self.max_checkpoints is not None:
await db.execute(_PRUNE, (self.max_checkpoints,))
await db.commit()
return f"{directory}#{checkpoint_id}"
def from_checkpoint(self, location: str) -> str:
"""Read a checkpoint from the SQLite database.
Args:
location: A location string returned by ``checkpoint()``.
Returns:
The raw JSON string.
Raises:
ValueError: If the checkpoint ID is not found.
"""
db_path, checkpoint_id = location.rsplit("#", 1)
with sqlite3.connect(db_path) as conn:
row = conn.execute(_SELECT, (checkpoint_id,)).fetchone()
if row is None:
raise ValueError(f"Checkpoint not found: {checkpoint_id}")
result: str = row[0]
return result
async def afrom_checkpoint(self, location: str) -> str:
"""Read a checkpoint from the SQLite database asynchronously.
Args:
location: A location string returned by ``acheckpoint()``.
Returns:
The raw JSON string.
Raises:
ValueError: If the checkpoint ID is not found.
"""
db_path, checkpoint_id = location.rsplit("#", 1)
async with aiosqlite.connect(db_path) as db:
cursor = await db.execute(_SELECT, (checkpoint_id,))
row = await cursor.fetchone()
if row is None:
raise ValueError(f"Checkpoint not found: {checkpoint_id}")
result: str = row[0]
return result