diff --git a/lib/crewai/src/crewai/state/__init__.py b/lib/crewai/src/crewai/state/__init__.py index d8f3419c7..c08d19dcd 100644 --- a/lib/crewai/src/crewai/state/__init__.py +++ b/lib/crewai/src/crewai/state/__init__.py @@ -1,4 +1,5 @@ from crewai.state.checkpoint_config import CheckpointConfig, CheckpointEventType +from crewai.state.provider.sqlite_provider import SqliteProvider -__all__ = ["CheckpointConfig", "CheckpointEventType"] +__all__ = ["CheckpointConfig", "CheckpointEventType", "SqliteProvider"] diff --git a/lib/crewai/src/crewai/state/provider/sqlite_provider.py b/lib/crewai/src/crewai/state/provider/sqlite_provider.py new file mode 100644 index 000000000..51f7096d2 --- /dev/null +++ b/lib/crewai/src/crewai/state/provider/sqlite_provider.py @@ -0,0 +1,138 @@ +"""SQLite state provider for checkpointing.""" + +from __future__ import annotations + +from datetime import datetime, timezone +from pathlib import Path +import sqlite3 +import uuid + +import aiosqlite + +from crewai.state.provider.core import BaseProvider + + +_CREATE_TABLE = """ +CREATE TABLE IF NOT EXISTS checkpoints ( + id TEXT PRIMARY KEY, + created_at TEXT NOT NULL, + data TEXT NOT NULL +) +""" + +_INSERT = "INSERT INTO checkpoints (id, created_at, data) VALUES (?, ?, ?)" +_SELECT = "SELECT data FROM checkpoints WHERE id = ?" +_PRUNE = """ +DELETE FROM checkpoints WHERE rowid NOT IN ( + SELECT rowid FROM checkpoints ORDER BY rowid DESC LIMIT ? +) +""" + + +def _make_id() -> tuple[str, str]: + """Generate a checkpoint ID and ISO timestamp. + + Returns: + A tuple of (checkpoint_id, timestamp). + """ + ts = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%S") + checkpoint_id = f"{ts}_{uuid.uuid4().hex[:8]}" + return checkpoint_id, ts + + +class SqliteProvider(BaseProvider): + """Persists runtime state checkpoints in a SQLite database. + + The ``directory`` argument to ``checkpoint`` / ``acheckpoint`` is + used as the database path (e.g. ``"./.checkpoints.db"``). + + Args: + max_checkpoints: Maximum number of checkpoints to retain. + Oldest rows are pruned after each write. None keeps all. + """ + + def __init__(self, max_checkpoints: int | None = None) -> None: + self.max_checkpoints = max_checkpoints + + def checkpoint(self, data: str, directory: str) -> str: + """Write a checkpoint to the SQLite database. + + Args: + data: The serialized JSON string to persist. + directory: Path to the SQLite database file. + + Returns: + A location string in the format ``"db_path#checkpoint_id"``. + """ + checkpoint_id, ts = _make_id() + Path(directory).parent.mkdir(parents=True, exist_ok=True) + with sqlite3.connect(directory) as conn: + conn.execute("PRAGMA journal_mode=WAL") + conn.execute(_CREATE_TABLE) + conn.execute(_INSERT, (checkpoint_id, ts, data)) + if self.max_checkpoints is not None: + conn.execute(_PRUNE, (self.max_checkpoints,)) + conn.commit() + return f"{directory}#{checkpoint_id}" + + async def acheckpoint(self, data: str, directory: str) -> str: + """Write a checkpoint to the SQLite database asynchronously. + + Args: + data: The serialized JSON string to persist. + directory: Path to the SQLite database file. + + Returns: + A location string in the format ``"db_path#checkpoint_id"``. + """ + checkpoint_id, ts = _make_id() + Path(directory).parent.mkdir(parents=True, exist_ok=True) + async with aiosqlite.connect(directory) as db: + await db.execute("PRAGMA journal_mode=WAL") + await db.execute(_CREATE_TABLE) + await db.execute(_INSERT, (checkpoint_id, ts, data)) + if self.max_checkpoints is not None: + await db.execute(_PRUNE, (self.max_checkpoints,)) + await db.commit() + return f"{directory}#{checkpoint_id}" + + def from_checkpoint(self, location: str) -> str: + """Read a checkpoint from the SQLite database. + + Args: + location: A location string returned by ``checkpoint()``. + + Returns: + The raw JSON string. + + Raises: + ValueError: If the checkpoint ID is not found. + """ + db_path, checkpoint_id = location.rsplit("#", 1) + with sqlite3.connect(db_path) as conn: + row = conn.execute(_SELECT, (checkpoint_id,)).fetchone() + if row is None: + raise ValueError(f"Checkpoint not found: {checkpoint_id}") + result: str = row[0] + return result + + async def afrom_checkpoint(self, location: str) -> str: + """Read a checkpoint from the SQLite database asynchronously. + + Args: + location: A location string returned by ``acheckpoint()``. + + Returns: + The raw JSON string. + + Raises: + ValueError: If the checkpoint ID is not found. + """ + db_path, checkpoint_id = location.rsplit("#", 1) + async with aiosqlite.connect(db_path) as db: + cursor = await db.execute(_SELECT, (checkpoint_id,)) + row = await cursor.fetchone() + if row is None: + raise ValueError(f"Checkpoint not found: {checkpoint_id}") + result: str = row[0] + return result