mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-07-02 21:58:11 +00:00
refactor: checkpoint API cleanup
This commit is contained in:
@@ -39,7 +39,7 @@ crew = Crew(
|
||||
agents=[...],
|
||||
tasks=[...],
|
||||
checkpoint=CheckpointConfig(
|
||||
directory="./my_checkpoints",
|
||||
location="./my_checkpoints",
|
||||
on_events=["task_completed", "crew_kickoff_completed"],
|
||||
max_checkpoints=5,
|
||||
),
|
||||
@@ -50,7 +50,7 @@ crew = Crew(
|
||||
|
||||
| الحقل | النوع | الافتراضي | الوصف |
|
||||
|:------|:------|:----------|:------|
|
||||
| `directory` | `str` | `"./.checkpoints"` | مسار ملفات نقاط الحفظ |
|
||||
| `location` | `str` | `"./.checkpoints"` | مسار ملفات نقاط الحفظ |
|
||||
| `on_events` | `list[str]` | `["task_completed"]` | انواع الاحداث التي تطلق نقطة حفظ |
|
||||
| `provider` | `BaseProvider` | `JsonProvider()` | واجهة التخزين |
|
||||
| `max_checkpoints` | `int \| None` | `None` | الحد الاقصى للملفات؛ يتم حذف الاقدم اولا |
|
||||
@@ -95,7 +95,7 @@ result = crew.kickoff() # يستأنف من اخر مهمة مكتملة
|
||||
crew = Crew(
|
||||
agents=[researcher, writer],
|
||||
tasks=[research_task, write_task, review_task],
|
||||
checkpoint=CheckpointConfig(directory="./crew_cp"),
|
||||
checkpoint=CheckpointConfig(location="./crew_cp"),
|
||||
)
|
||||
```
|
||||
|
||||
@@ -118,7 +118,7 @@ class MyFlow(Flow):
|
||||
|
||||
flow = MyFlow(
|
||||
checkpoint=CheckpointConfig(
|
||||
directory="./flow_cp",
|
||||
location="./flow_cp",
|
||||
on_events=["method_execution_finished"],
|
||||
),
|
||||
)
|
||||
@@ -137,7 +137,7 @@ agent = Agent(
|
||||
goal="Research topics",
|
||||
backstory="Expert researcher",
|
||||
checkpoint=CheckpointConfig(
|
||||
directory="./agent_cp",
|
||||
location="./agent_cp",
|
||||
on_events=["lite_agent_execution_completed"],
|
||||
),
|
||||
)
|
||||
@@ -160,7 +160,7 @@ crew = Crew(
|
||||
agents=[...],
|
||||
tasks=[...],
|
||||
checkpoint=CheckpointConfig(
|
||||
directory="./my_checkpoints",
|
||||
location="./my_checkpoints",
|
||||
provider=JsonProvider(),
|
||||
max_checkpoints=5,
|
||||
),
|
||||
@@ -179,15 +179,12 @@ crew = Crew(
|
||||
agents=[...],
|
||||
tasks=[...],
|
||||
checkpoint=CheckpointConfig(
|
||||
directory="./.checkpoints.db",
|
||||
provider=SqliteProvider(max_checkpoints=50),
|
||||
location="./.checkpoints.db",
|
||||
provider=SqliteProvider(),
|
||||
),
|
||||
)
|
||||
```
|
||||
|
||||
<Note>
|
||||
عند استخدام `SqliteProvider`، حقل `directory` هو مسار ملف قاعدة البيانات، وليس مجلدا.
|
||||
</Note>
|
||||
|
||||
## انواع الاحداث
|
||||
|
||||
|
||||
@@ -39,7 +39,7 @@ crew = Crew(
|
||||
agents=[...],
|
||||
tasks=[...],
|
||||
checkpoint=CheckpointConfig(
|
||||
directory="./my_checkpoints",
|
||||
location="./my_checkpoints",
|
||||
on_events=["task_completed", "crew_kickoff_completed"],
|
||||
max_checkpoints=5,
|
||||
),
|
||||
@@ -50,10 +50,10 @@ crew = Crew(
|
||||
|
||||
| Field | Type | Default | Description |
|
||||
|:------|:-----|:--------|:------------|
|
||||
| `directory` | `str` | `"./.checkpoints"` | Filesystem path for checkpoint files |
|
||||
| `location` | `str` | `"./.checkpoints"` | Storage destination — a directory for `JsonProvider`, a database file path for `SqliteProvider` |
|
||||
| `on_events` | `list[str]` | `["task_completed"]` | Event types that trigger a checkpoint |
|
||||
| `provider` | `BaseProvider` | `JsonProvider()` | Storage backend |
|
||||
| `max_checkpoints` | `int \| None` | `None` | Max files to keep; oldest pruned first |
|
||||
| `max_checkpoints` | `int \| None` | `None` | Max checkpoints to keep. Oldest are pruned after each write. Pruning is handled by the provider. |
|
||||
|
||||
### Inheritance and Opt-Out
|
||||
|
||||
@@ -95,7 +95,7 @@ The restored crew skips already-completed tasks and resumes from the first incom
|
||||
crew = Crew(
|
||||
agents=[researcher, writer],
|
||||
tasks=[research_task, write_task, review_task],
|
||||
checkpoint=CheckpointConfig(directory="./crew_cp"),
|
||||
checkpoint=CheckpointConfig(location="./crew_cp"),
|
||||
)
|
||||
```
|
||||
|
||||
@@ -118,7 +118,7 @@ class MyFlow(Flow):
|
||||
|
||||
flow = MyFlow(
|
||||
checkpoint=CheckpointConfig(
|
||||
directory="./flow_cp",
|
||||
location="./flow_cp",
|
||||
on_events=["method_execution_finished"],
|
||||
),
|
||||
)
|
||||
@@ -137,7 +137,7 @@ agent = Agent(
|
||||
goal="Research topics",
|
||||
backstory="Expert researcher",
|
||||
checkpoint=CheckpointConfig(
|
||||
directory="./agent_cp",
|
||||
location="./agent_cp",
|
||||
on_events=["lite_agent_execution_completed"],
|
||||
),
|
||||
)
|
||||
@@ -160,14 +160,14 @@ crew = Crew(
|
||||
agents=[...],
|
||||
tasks=[...],
|
||||
checkpoint=CheckpointConfig(
|
||||
directory="./my_checkpoints",
|
||||
location="./my_checkpoints",
|
||||
provider=JsonProvider(), # this is the default
|
||||
max_checkpoints=5, # prunes oldest files
|
||||
),
|
||||
)
|
||||
```
|
||||
|
||||
Files are named `<timestamp>_<uuid>.json` inside the directory.
|
||||
Files are named `<timestamp>_<uuid>.json` inside the location directory.
|
||||
|
||||
### SqliteProvider
|
||||
|
||||
@@ -181,17 +181,14 @@ crew = Crew(
|
||||
agents=[...],
|
||||
tasks=[...],
|
||||
checkpoint=CheckpointConfig(
|
||||
directory="./.checkpoints.db",
|
||||
provider=SqliteProvider(max_checkpoints=50),
|
||||
location="./.checkpoints.db",
|
||||
provider=SqliteProvider(),
|
||||
max_checkpoints=50,
|
||||
),
|
||||
)
|
||||
```
|
||||
|
||||
`SqliteProvider` accepts its own `max_checkpoints` parameter that prunes old rows via SQL. WAL journal mode is enabled for concurrent read access.
|
||||
|
||||
<Note>
|
||||
When using `SqliteProvider`, the `directory` field is the database file path, not a directory. The `max_checkpoints` on `CheckpointConfig` controls filesystem pruning (for `JsonProvider`), while `SqliteProvider.max_checkpoints` controls row pruning in the database.
|
||||
</Note>
|
||||
WAL journal mode is enabled for concurrent read access.
|
||||
|
||||
## Event Types
|
||||
|
||||
|
||||
@@ -39,7 +39,7 @@ crew = Crew(
|
||||
agents=[...],
|
||||
tasks=[...],
|
||||
checkpoint=CheckpointConfig(
|
||||
directory="./my_checkpoints",
|
||||
location="./my_checkpoints",
|
||||
on_events=["task_completed", "crew_kickoff_completed"],
|
||||
max_checkpoints=5,
|
||||
),
|
||||
@@ -50,7 +50,7 @@ crew = Crew(
|
||||
|
||||
| 필드 | 타입 | 기본값 | 설명 |
|
||||
|:-----|:-----|:-------|:-----|
|
||||
| `directory` | `str` | `"./.checkpoints"` | 체크포인트 파일 경로 |
|
||||
| `location` | `str` | `"./.checkpoints"` | 체크포인트 파일 경로 |
|
||||
| `on_events` | `list[str]` | `["task_completed"]` | 체크포인트를 트리거하는 이벤트 타입 |
|
||||
| `provider` | `BaseProvider` | `JsonProvider()` | 스토리지 백엔드 |
|
||||
| `max_checkpoints` | `int \| None` | `None` | 보관할 최대 파일 수; 오래된 것부터 삭제 |
|
||||
@@ -95,7 +95,7 @@ result = crew.kickoff() # 마지막으로 완료된 태스크부터 재개
|
||||
crew = Crew(
|
||||
agents=[researcher, writer],
|
||||
tasks=[research_task, write_task, review_task],
|
||||
checkpoint=CheckpointConfig(directory="./crew_cp"),
|
||||
checkpoint=CheckpointConfig(location="./crew_cp"),
|
||||
)
|
||||
```
|
||||
|
||||
@@ -118,7 +118,7 @@ class MyFlow(Flow):
|
||||
|
||||
flow = MyFlow(
|
||||
checkpoint=CheckpointConfig(
|
||||
directory="./flow_cp",
|
||||
location="./flow_cp",
|
||||
on_events=["method_execution_finished"],
|
||||
),
|
||||
)
|
||||
@@ -137,7 +137,7 @@ agent = Agent(
|
||||
goal="Research topics",
|
||||
backstory="Expert researcher",
|
||||
checkpoint=CheckpointConfig(
|
||||
directory="./agent_cp",
|
||||
location="./agent_cp",
|
||||
on_events=["lite_agent_execution_completed"],
|
||||
),
|
||||
)
|
||||
@@ -160,7 +160,7 @@ crew = Crew(
|
||||
agents=[...],
|
||||
tasks=[...],
|
||||
checkpoint=CheckpointConfig(
|
||||
directory="./my_checkpoints",
|
||||
location="./my_checkpoints",
|
||||
provider=JsonProvider(),
|
||||
max_checkpoints=5,
|
||||
),
|
||||
@@ -179,15 +179,12 @@ crew = Crew(
|
||||
agents=[...],
|
||||
tasks=[...],
|
||||
checkpoint=CheckpointConfig(
|
||||
directory="./.checkpoints.db",
|
||||
provider=SqliteProvider(max_checkpoints=50),
|
||||
location="./.checkpoints.db",
|
||||
provider=SqliteProvider(),
|
||||
),
|
||||
)
|
||||
```
|
||||
|
||||
<Note>
|
||||
`SqliteProvider`를 사용할 때 `directory` 필드는 디렉토리가 아닌 데이터베이스 파일 경로입니다.
|
||||
</Note>
|
||||
|
||||
## 이벤트 타입
|
||||
|
||||
|
||||
@@ -39,7 +39,7 @@ crew = Crew(
|
||||
agents=[...],
|
||||
tasks=[...],
|
||||
checkpoint=CheckpointConfig(
|
||||
directory="./my_checkpoints",
|
||||
location="./my_checkpoints",
|
||||
on_events=["task_completed", "crew_kickoff_completed"],
|
||||
max_checkpoints=5,
|
||||
),
|
||||
@@ -50,7 +50,7 @@ crew = Crew(
|
||||
|
||||
| Campo | Tipo | Padrao | Descricao |
|
||||
|:------|:-----|:-------|:----------|
|
||||
| `directory` | `str` | `"./.checkpoints"` | Caminho para os arquivos de checkpoint |
|
||||
| `location` | `str` | `"./.checkpoints"` | Caminho para os arquivos de checkpoint |
|
||||
| `on_events` | `list[str]` | `["task_completed"]` | Tipos de evento que acionam um checkpoint |
|
||||
| `provider` | `BaseProvider` | `JsonProvider()` | Backend de armazenamento |
|
||||
| `max_checkpoints` | `int \| None` | `None` | Maximo de arquivos a manter; os mais antigos sao removidos primeiro |
|
||||
@@ -95,7 +95,7 @@ A crew restaurada pula tarefas ja concluidas e retoma a partir da primeira incom
|
||||
crew = Crew(
|
||||
agents=[researcher, writer],
|
||||
tasks=[research_task, write_task, review_task],
|
||||
checkpoint=CheckpointConfig(directory="./crew_cp"),
|
||||
checkpoint=CheckpointConfig(location="./crew_cp"),
|
||||
)
|
||||
```
|
||||
|
||||
@@ -118,7 +118,7 @@ class MyFlow(Flow):
|
||||
|
||||
flow = MyFlow(
|
||||
checkpoint=CheckpointConfig(
|
||||
directory="./flow_cp",
|
||||
location="./flow_cp",
|
||||
on_events=["method_execution_finished"],
|
||||
),
|
||||
)
|
||||
@@ -137,7 +137,7 @@ agent = Agent(
|
||||
goal="Research topics",
|
||||
backstory="Expert researcher",
|
||||
checkpoint=CheckpointConfig(
|
||||
directory="./agent_cp",
|
||||
location="./agent_cp",
|
||||
on_events=["lite_agent_execution_completed"],
|
||||
),
|
||||
)
|
||||
@@ -160,7 +160,7 @@ crew = Crew(
|
||||
agents=[...],
|
||||
tasks=[...],
|
||||
checkpoint=CheckpointConfig(
|
||||
directory="./my_checkpoints",
|
||||
location="./my_checkpoints",
|
||||
provider=JsonProvider(),
|
||||
max_checkpoints=5,
|
||||
),
|
||||
@@ -179,15 +179,12 @@ crew = Crew(
|
||||
agents=[...],
|
||||
tasks=[...],
|
||||
checkpoint=CheckpointConfig(
|
||||
directory="./.checkpoints.db",
|
||||
provider=SqliteProvider(max_checkpoints=50),
|
||||
location="./.checkpoints.db",
|
||||
provider=SqliteProvider(),
|
||||
),
|
||||
)
|
||||
```
|
||||
|
||||
<Note>
|
||||
Ao usar `SqliteProvider`, o campo `directory` e o caminho do arquivo de banco de dados, nao um diretorio.
|
||||
</Note>
|
||||
|
||||
## Tipos de Evento
|
||||
|
||||
|
||||
@@ -165,9 +165,10 @@ class CheckpointConfig(BaseModel):
|
||||
automatically whenever the specified event(s) fire.
|
||||
"""
|
||||
|
||||
directory: str = Field(
|
||||
location: str = Field(
|
||||
default="./.checkpoints",
|
||||
description="Filesystem path where checkpoint JSON files are written.",
|
||||
description="Storage destination. For JsonProvider this is a directory "
|
||||
"path; for SqliteProvider it is a database file path.",
|
||||
)
|
||||
on_events: list[CheckpointEventType | Literal["*"]] = Field(
|
||||
default=["task_completed"],
|
||||
@@ -180,8 +181,8 @@ class CheckpointConfig(BaseModel):
|
||||
)
|
||||
max_checkpoints: int | None = Field(
|
||||
default=None,
|
||||
description="Maximum checkpoint files to keep. Oldest are pruned first. "
|
||||
"None means keep all.",
|
||||
description="Maximum checkpoints to keep. Oldest are pruned after "
|
||||
"each write. None means keep all.",
|
||||
)
|
||||
|
||||
@property
|
||||
|
||||
@@ -7,9 +7,7 @@ avoids per-event overhead when no entity uses checkpointing.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import glob
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
from typing import Any
|
||||
|
||||
@@ -105,29 +103,13 @@ def _find_checkpoint(source: Any) -> CheckpointConfig | None:
|
||||
|
||||
|
||||
def _do_checkpoint(state: RuntimeState, cfg: CheckpointConfig) -> None:
|
||||
"""Write a checkpoint synchronously and optionally prune old files."""
|
||||
"""Write a checkpoint and prune old ones if configured."""
|
||||
_prepare_entities(state.root)
|
||||
data = state.model_dump_json()
|
||||
cfg.provider.checkpoint(data, cfg.directory)
|
||||
cfg.provider.checkpoint(data, cfg.location)
|
||||
|
||||
if cfg.max_checkpoints is not None:
|
||||
_prune(cfg.directory, cfg.max_checkpoints)
|
||||
|
||||
|
||||
def _safe_remove(path: str) -> None:
|
||||
try:
|
||||
os.remove(path)
|
||||
except OSError:
|
||||
logger.debug("Failed to remove checkpoint file %s", path, exc_info=True)
|
||||
|
||||
|
||||
def _prune(directory: str, max_keep: int) -> None:
|
||||
"""Remove oldest checkpoint files beyond *max_keep*."""
|
||||
pattern = os.path.join(directory, "*.json")
|
||||
files = sorted(glob.glob(pattern), key=os.path.getmtime)
|
||||
to_remove = files if max_keep == 0 else files[:-max_keep]
|
||||
for path in to_remove:
|
||||
_safe_remove(path)
|
||||
cfg.provider.prune(cfg.location, cfg.max_checkpoints)
|
||||
|
||||
|
||||
def _should_checkpoint(source: Any, event: BaseEvent) -> CheckpointConfig | None:
|
||||
|
||||
@@ -34,27 +34,36 @@ class BaseProvider(Protocol):
|
||||
),
|
||||
)
|
||||
|
||||
def checkpoint(self, data: str, directory: str) -> str:
|
||||
def checkpoint(self, data: str, location: str) -> str:
|
||||
"""Persist a snapshot synchronously.
|
||||
|
||||
Args:
|
||||
data: The serialized string to persist.
|
||||
directory: Logical destination: path, bucket prefix, etc.
|
||||
location: Storage destination (directory, file path, URI, etc.).
|
||||
|
||||
Returns:
|
||||
A location identifier for the saved checkpoint, such as a file path or URI.
|
||||
A location identifier for the saved checkpoint.
|
||||
"""
|
||||
...
|
||||
|
||||
async def acheckpoint(self, data: str, directory: str) -> str:
|
||||
async def acheckpoint(self, data: str, location: str) -> str:
|
||||
"""Persist a snapshot asynchronously.
|
||||
|
||||
Args:
|
||||
data: The serialized string to persist.
|
||||
directory: Logical destination: path, bucket prefix, etc.
|
||||
location: Storage destination (directory, file path, URI, etc.).
|
||||
|
||||
Returns:
|
||||
A location identifier for the saved checkpoint, such as a file path or URI.
|
||||
A location identifier for the saved checkpoint.
|
||||
"""
|
||||
...
|
||||
|
||||
def prune(self, location: str, max_keep: int) -> None:
|
||||
"""Remove old checkpoints, keeping at most *max_keep*.
|
||||
|
||||
Args:
|
||||
location: The storage destination passed to ``checkpoint``.
|
||||
max_keep: Maximum number of checkpoints to retain.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
@@ -3,6 +3,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
import glob
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
import uuid
|
||||
|
||||
@@ -12,43 +15,56 @@ import aiofiles.os
|
||||
from crewai.state.provider.core import BaseProvider
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class JsonProvider(BaseProvider):
|
||||
"""Persists runtime state checkpoints as JSON files on the local filesystem."""
|
||||
|
||||
def checkpoint(self, data: str, directory: str) -> str:
|
||||
"""Write a JSON checkpoint file to the directory.
|
||||
def checkpoint(self, data: str, location: str) -> str:
|
||||
"""Write a JSON checkpoint file.
|
||||
|
||||
Args:
|
||||
data: The serialized JSON string to persist.
|
||||
directory: Filesystem path where the checkpoint will be saved.
|
||||
location: Directory where the checkpoint will be saved.
|
||||
|
||||
Returns:
|
||||
The path to the written checkpoint file.
|
||||
"""
|
||||
file_path = _build_path(directory)
|
||||
file_path = _build_path(location)
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(file_path, "w") as f:
|
||||
f.write(data)
|
||||
return str(file_path)
|
||||
|
||||
async def acheckpoint(self, data: str, directory: str) -> str:
|
||||
"""Write a JSON checkpoint file to the directory asynchronously.
|
||||
async def acheckpoint(self, data: str, location: str) -> str:
|
||||
"""Write a JSON checkpoint file asynchronously.
|
||||
|
||||
Args:
|
||||
data: The serialized JSON string to persist.
|
||||
directory: Filesystem path where the checkpoint will be saved.
|
||||
location: Directory where the checkpoint will be saved.
|
||||
|
||||
Returns:
|
||||
The path to the written checkpoint file.
|
||||
"""
|
||||
file_path = _build_path(directory)
|
||||
file_path = _build_path(location)
|
||||
await aiofiles.os.makedirs(str(file_path.parent), exist_ok=True)
|
||||
|
||||
async with aiofiles.open(file_path, "w") as f:
|
||||
await f.write(data)
|
||||
return str(file_path)
|
||||
|
||||
def prune(self, location: str, max_keep: int) -> None:
|
||||
"""Remove oldest checkpoint files beyond *max_keep*."""
|
||||
pattern = os.path.join(location, "*.json")
|
||||
files = sorted(glob.glob(pattern), key=os.path.getmtime)
|
||||
for path in files if max_keep == 0 else files[:-max_keep]:
|
||||
try:
|
||||
os.remove(path)
|
||||
except OSError: # noqa: PERF203
|
||||
logger.debug("Failed to remove %s", path, exc_info=True)
|
||||
|
||||
def from_checkpoint(self, location: str) -> str:
|
||||
"""Read a JSON checkpoint file.
|
||||
|
||||
|
||||
@@ -43,58 +43,53 @@ def _make_id() -> tuple[str, str]:
|
||||
class SqliteProvider(BaseProvider):
|
||||
"""Persists runtime state checkpoints in a SQLite database.
|
||||
|
||||
The ``directory`` argument to ``checkpoint`` / ``acheckpoint`` is
|
||||
used as the database path (e.g. ``"./.checkpoints.db"``).
|
||||
|
||||
Args:
|
||||
max_checkpoints: Maximum number of checkpoints to retain.
|
||||
Oldest rows are pruned after each write. None keeps all.
|
||||
The ``location`` argument to ``checkpoint`` / ``acheckpoint`` is
|
||||
used as the database file path.
|
||||
"""
|
||||
|
||||
def __init__(self, max_checkpoints: int | None = None) -> None:
|
||||
self.max_checkpoints = max_checkpoints
|
||||
|
||||
def checkpoint(self, data: str, directory: str) -> str:
|
||||
def checkpoint(self, data: str, location: str) -> str:
|
||||
"""Write a checkpoint to the SQLite database.
|
||||
|
||||
Args:
|
||||
data: The serialized JSON string to persist.
|
||||
directory: Path to the SQLite database file.
|
||||
location: Path to the SQLite database file.
|
||||
|
||||
Returns:
|
||||
A location string in the format ``"db_path#checkpoint_id"``.
|
||||
"""
|
||||
checkpoint_id, ts = _make_id()
|
||||
Path(directory).parent.mkdir(parents=True, exist_ok=True)
|
||||
with sqlite3.connect(directory) as conn:
|
||||
Path(location).parent.mkdir(parents=True, exist_ok=True)
|
||||
with sqlite3.connect(location) as conn:
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
conn.execute(_CREATE_TABLE)
|
||||
conn.execute(_INSERT, (checkpoint_id, ts, data))
|
||||
if self.max_checkpoints is not None:
|
||||
conn.execute(_PRUNE, (self.max_checkpoints,))
|
||||
conn.commit()
|
||||
return f"{directory}#{checkpoint_id}"
|
||||
return f"{location}#{checkpoint_id}"
|
||||
|
||||
async def acheckpoint(self, data: str, directory: str) -> str:
|
||||
async def acheckpoint(self, data: str, location: str) -> str:
|
||||
"""Write a checkpoint to the SQLite database asynchronously.
|
||||
|
||||
Args:
|
||||
data: The serialized JSON string to persist.
|
||||
directory: Path to the SQLite database file.
|
||||
location: Path to the SQLite database file.
|
||||
|
||||
Returns:
|
||||
A location string in the format ``"db_path#checkpoint_id"``.
|
||||
"""
|
||||
checkpoint_id, ts = _make_id()
|
||||
Path(directory).parent.mkdir(parents=True, exist_ok=True)
|
||||
async with aiosqlite.connect(directory) as db:
|
||||
Path(location).parent.mkdir(parents=True, exist_ok=True)
|
||||
async with aiosqlite.connect(location) as db:
|
||||
await db.execute("PRAGMA journal_mode=WAL")
|
||||
await db.execute(_CREATE_TABLE)
|
||||
await db.execute(_INSERT, (checkpoint_id, ts, data))
|
||||
if self.max_checkpoints is not None:
|
||||
await db.execute(_PRUNE, (self.max_checkpoints,))
|
||||
await db.commit()
|
||||
return f"{directory}#{checkpoint_id}"
|
||||
return f"{location}#{checkpoint_id}"
|
||||
|
||||
def prune(self, location: str, max_keep: int) -> None:
|
||||
"""Remove oldest checkpoint rows beyond *max_keep*."""
|
||||
with sqlite3.connect(location) as conn:
|
||||
conn.execute(_PRUNE, (max_keep,))
|
||||
conn.commit()
|
||||
|
||||
def from_checkpoint(self, location: str) -> str:
|
||||
"""Read a checkpoint from the SQLite database.
|
||||
|
||||
@@ -90,29 +90,31 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
|
||||
return state
|
||||
return handler(data)
|
||||
|
||||
def checkpoint(self, directory: str) -> str:
|
||||
"""Write a checkpoint file to the directory.
|
||||
def checkpoint(self, location: str) -> str:
|
||||
"""Write a checkpoint.
|
||||
|
||||
Args:
|
||||
directory: Filesystem path where the checkpoint JSON will be saved.
|
||||
location: Storage destination. For JsonProvider this is a directory
|
||||
path; for SqliteProvider it is a database file path.
|
||||
|
||||
Returns:
|
||||
A location identifier for the saved checkpoint.
|
||||
"""
|
||||
_prepare_entities(self.root)
|
||||
return self._provider.checkpoint(self.model_dump_json(), directory)
|
||||
return self._provider.checkpoint(self.model_dump_json(), location)
|
||||
|
||||
async def acheckpoint(self, directory: str) -> str:
|
||||
async def acheckpoint(self, location: str) -> str:
|
||||
"""Async version of :meth:`checkpoint`.
|
||||
|
||||
Args:
|
||||
directory: Filesystem path where the checkpoint JSON will be saved.
|
||||
location: Storage destination. For JsonProvider this is a directory
|
||||
path; for SqliteProvider it is a database file path.
|
||||
|
||||
Returns:
|
||||
A location identifier for the saved checkpoint.
|
||||
"""
|
||||
_prepare_entities(self.root)
|
||||
return await self._provider.acheckpoint(self.model_dump_json(), directory)
|
||||
return await self._provider.acheckpoint(self.model_dump_json(), location)
|
||||
|
||||
@classmethod
|
||||
def from_checkpoint(
|
||||
|
||||
@@ -17,10 +17,10 @@ from crewai.flow.flow import Flow, start
|
||||
from crewai.state.checkpoint_config import CheckpointConfig
|
||||
from crewai.state.checkpoint_listener import (
|
||||
_find_checkpoint,
|
||||
_prune,
|
||||
_resolve,
|
||||
_SENTINEL,
|
||||
)
|
||||
from crewai.state.provider.json_provider import JsonProvider
|
||||
from crewai.task import Task
|
||||
|
||||
|
||||
@@ -37,10 +37,10 @@ class TestResolve:
|
||||
def test_true_returns_config(self) -> None:
|
||||
result = _resolve(True)
|
||||
assert isinstance(result, CheckpointConfig)
|
||||
assert result.directory == "./.checkpoints"
|
||||
assert result.location == "./.checkpoints"
|
||||
|
||||
def test_config_returns_config(self) -> None:
|
||||
cfg = CheckpointConfig(directory="/tmp/cp")
|
||||
cfg = CheckpointConfig(location="/tmp/cp")
|
||||
assert _resolve(cfg) is cfg
|
||||
|
||||
|
||||
@@ -77,12 +77,12 @@ class TestFindCheckpoint:
|
||||
|
||||
def test_agent_config_overrides_crew(self) -> None:
|
||||
a = self._make_agent(
|
||||
checkpoint=CheckpointConfig(directory="/agent_cp")
|
||||
checkpoint=CheckpointConfig(location="/agent_cp")
|
||||
)
|
||||
self._make_crew([a], checkpoint=True)
|
||||
cfg = _find_checkpoint(a)
|
||||
assert isinstance(cfg, CheckpointConfig)
|
||||
assert cfg.directory == "/agent_cp"
|
||||
assert cfg.location == "/agent_cp"
|
||||
|
||||
def test_task_inherits_from_crew(self) -> None:
|
||||
a = self._make_agent()
|
||||
@@ -123,7 +123,7 @@ class TestPrune:
|
||||
# Ensure distinct mtime
|
||||
time.sleep(0.01)
|
||||
|
||||
_prune(d, max_keep=2)
|
||||
JsonProvider().prune(d, max_keep=2)
|
||||
remaining = os.listdir(d)
|
||||
assert len(remaining) == 2
|
||||
assert "cp_3.json" in remaining
|
||||
@@ -135,7 +135,7 @@ class TestPrune:
|
||||
with open(os.path.join(d, f"cp_{i}.json"), "w") as f:
|
||||
f.write("{}")
|
||||
|
||||
_prune(d, max_keep=0)
|
||||
JsonProvider().prune(d, max_keep=0)
|
||||
assert os.listdir(d) == []
|
||||
|
||||
def test_prune_more_than_existing(self) -> None:
|
||||
@@ -143,7 +143,7 @@ class TestPrune:
|
||||
with open(os.path.join(d, "cp.json"), "w") as f:
|
||||
f.write("{}")
|
||||
|
||||
_prune(d, max_keep=10)
|
||||
JsonProvider().prune(d, max_keep=10)
|
||||
assert len(os.listdir(d)) == 1
|
||||
|
||||
|
||||
@@ -153,7 +153,7 @@ class TestPrune:
|
||||
class TestCheckpointConfig:
|
||||
def test_defaults(self) -> None:
|
||||
cfg = CheckpointConfig()
|
||||
assert cfg.directory == "./.checkpoints"
|
||||
assert cfg.location == "./.checkpoints"
|
||||
assert cfg.on_events == ["task_completed"]
|
||||
assert cfg.max_checkpoints is None
|
||||
assert not cfg.trigger_all
|
||||
|
||||
Reference in New Issue
Block a user