feat: add CheckpointConfig for automatic checkpointing

This commit is contained in:
Greyson LaLonde
2026-04-07 05:34:25 +08:00
committed by GitHub
parent 86ce54fc82
commit c4e2d7ea3b
13 changed files with 2113 additions and 775 deletions

View File

@@ -16,6 +16,7 @@ from crewai.knowledge.knowledge import Knowledge
from crewai.llm import LLM
from crewai.llms.base_llm import BaseLLM
from crewai.process import Process
from crewai.state.checkpoint_config import CheckpointConfig # noqa: F401
from crewai.task import Task
from crewai.tasks.llm_guardrail import LLMGuardrail
from crewai.tasks.task_output import TaskOutput
@@ -210,6 +211,7 @@ try:
Agent.model_rebuild(force=True, _types_namespace=_full_namespace)
except PydanticUserError:
pass
except (ImportError, PydanticUserError):
import logging as _logging

View File

@@ -39,6 +39,7 @@ from crewai.memory.unified_memory import Memory
from crewai.rag.embeddings.types import EmbedderConfig
from crewai.security.security_config import SecurityConfig
from crewai.skills.models import Skill
from crewai.state.checkpoint_config import CheckpointConfig
from crewai.tools.base_tool import BaseTool, Tool
from crewai.types.callback import SerializableCallable
from crewai.utilities.config import process_config
@@ -299,6 +300,11 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
default_factory=SecurityConfig,
description="Security configuration for the agent, including fingerprinting.",
)
checkpoint: CheckpointConfig | bool | None = Field(
default=None,
description="Automatic checkpointing configuration. "
"True for defaults, False to opt out, None to inherit.",
)
callbacks: list[SerializableCallable] = Field(
default_factory=list, description="Callbacks to be used for the agent"
)

View File

@@ -104,6 +104,7 @@ from crewai.rag.types import SearchResult
from crewai.security.fingerprint import Fingerprint
from crewai.security.security_config import SecurityConfig
from crewai.skills.models import Skill
from crewai.state.checkpoint_config import CheckpointConfig
from crewai.task import Task
from crewai.tasks.conditional_task import ConditionalTask
from crewai.tasks.task_output import TaskOutput
@@ -340,6 +341,11 @@ class Crew(FlowTrackable, BaseModel):
default_factory=SecurityConfig,
description="Security configuration for the crew, including fingerprinting.",
)
checkpoint: CheckpointConfig | bool | None = Field(
default=None,
description="Automatic checkpointing configuration. "
"True for defaults, False to opt out, None to inherit.",
)
token_usage: UsageMetrics | None = Field(
default=None,
description="Metrics for the LLM usage during all tasks execution.",

View File

@@ -113,6 +113,7 @@ from crewai.flow.utils import (
)
from crewai.memory.memory_scope import MemoryScope, MemorySlice
from crewai.memory.unified_memory import Memory
from crewai.state.checkpoint_config import CheckpointConfig
if TYPE_CHECKING:
@@ -920,6 +921,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
max_method_calls: int = Field(default=100)
execution_context: ExecutionContext | None = Field(default=None)
checkpoint: CheckpointConfig | bool | None = Field(default=None)
@classmethod
def from_checkpoint(

View File

@@ -0,0 +1,4 @@
from crewai.state.checkpoint_config import CheckpointConfig, CheckpointEventType
__all__ = ["CheckpointConfig", "CheckpointEventType"]

View File

@@ -0,0 +1,193 @@
"""Checkpoint configuration for automatic state persistence."""
from __future__ import annotations
from typing import Literal
from pydantic import BaseModel, Field
from crewai.state.provider.core import BaseProvider
from crewai.state.provider.json_provider import JsonProvider
# Closed set of event-type names, grouped by subsystem. It is the element
# type of ``CheckpointConfig.on_events`` (alongside the ``"*"`` wildcard).
# NOTE(review): entries presumably mirror the ``type`` value of the
# corresponding event classes — confirm when adding a new event.
CheckpointEventType = Literal[
    # Task
    "task_started",
    "task_completed",
    "task_failed",
    "task_evaluation",
    # Crew
    "crew_kickoff_started",
    "crew_kickoff_completed",
    "crew_kickoff_failed",
    "crew_train_started",
    "crew_train_completed",
    "crew_train_failed",
    "crew_test_started",
    "crew_test_completed",
    "crew_test_failed",
    "crew_test_result",
    # Agent
    "agent_execution_started",
    "agent_execution_completed",
    "agent_execution_error",
    "lite_agent_execution_started",
    "lite_agent_execution_completed",
    "lite_agent_execution_error",
    "agent_evaluation_started",
    "agent_evaluation_completed",
    "agent_evaluation_failed",
    # Flow
    "flow_created",
    "flow_started",
    "flow_finished",
    "flow_paused",
    "method_execution_started",
    "method_execution_finished",
    "method_execution_failed",
    "method_execution_paused",
    "human_feedback_requested",
    "human_feedback_received",
    "flow_input_requested",
    "flow_input_received",
    # LLM
    "llm_call_started",
    "llm_call_completed",
    "llm_call_failed",
    "llm_stream_chunk",
    "llm_thinking_chunk",
    # LLM Guardrail
    "llm_guardrail_started",
    "llm_guardrail_completed",
    "llm_guardrail_failed",
    # Tool
    "tool_usage_started",
    "tool_usage_finished",
    "tool_usage_error",
    "tool_validate_input_error",
    "tool_selection_error",
    "tool_execution_error",
    # Memory
    "memory_save_started",
    "memory_save_completed",
    "memory_save_failed",
    "memory_query_started",
    "memory_query_completed",
    "memory_query_failed",
    "memory_retrieval_started",
    "memory_retrieval_completed",
    "memory_retrieval_failed",
    # Knowledge
    "knowledge_search_query_started",
    "knowledge_search_query_completed",
    "knowledge_query_started",
    "knowledge_query_completed",
    "knowledge_query_failed",
    "knowledge_search_query_failed",
    # Reasoning
    "agent_reasoning_started",
    "agent_reasoning_completed",
    "agent_reasoning_failed",
    # MCP
    "mcp_connection_started",
    "mcp_connection_completed",
    "mcp_connection_failed",
    "mcp_tool_execution_started",
    "mcp_tool_execution_completed",
    "mcp_tool_execution_failed",
    "mcp_config_fetch_failed",
    # Observation
    "step_observation_started",
    "step_observation_completed",
    "step_observation_failed",
    "plan_refinement",
    "plan_replan_triggered",
    "goal_achieved_early",
    # Skill
    "skill_discovery_started",
    "skill_discovery_completed",
    "skill_loaded",
    "skill_activated",
    "skill_load_failed",
    # Logging
    "agent_logs_started",
    "agent_logs_execution",
    # A2A
    "a2a_delegation_started",
    "a2a_delegation_completed",
    "a2a_conversation_started",
    "a2a_conversation_completed",
    "a2a_message_sent",
    "a2a_response_received",
    "a2a_polling_started",
    "a2a_polling_status",
    "a2a_push_notification_registered",
    "a2a_push_notification_received",
    "a2a_push_notification_sent",
    "a2a_push_notification_timeout",
    "a2a_streaming_started",
    "a2a_streaming_chunk",
    "a2a_agent_card_fetched",
    "a2a_authentication_failed",
    "a2a_artifact_received",
    "a2a_connection_error",
    "a2a_server_task_started",
    "a2a_server_task_completed",
    "a2a_server_task_canceled",
    "a2a_server_task_failed",
    "a2a_parallel_delegation_started",
    "a2a_parallel_delegation_completed",
    "a2a_transport_negotiated",
    "a2a_content_type_negotiated",
    "a2a_context_created",
    "a2a_context_expired",
    "a2a_context_idle",
    "a2a_context_completed",
    "a2a_context_pruned",
    # System
    "SIGTERM",
    "SIGINT",
    "SIGHUP",
    "SIGTSTP",
    "SIGCONT",
    # Env
    "cc_env",
    "codex_env",
    "cursor_env",
    "default_env",
]
class CheckpointConfig(BaseModel):
    """Configuration for automatic checkpointing.

    When set on a Crew, Flow, or Agent, checkpoints are written
    automatically whenever the specified event(s) fire.

    Attributes:
        directory: Filesystem path the provider writes checkpoints into.
        on_events: Event-type names that trigger a write; ``["*"]`` means
            every event.
        provider: Storage backend; a fresh ``JsonProvider`` by default.
        max_checkpoints: Upper bound on retained checkpoint files, or
            ``None`` to keep everything. Must be non-negative.
    """

    directory: str = Field(
        default="./.checkpoints",
        description="Filesystem path where checkpoint JSON files are written.",
    )
    on_events: list[CheckpointEventType | Literal["*"]] = Field(
        default=["task_completed"],
        description="Event types that trigger a checkpoint write. "
        'Use ["*"] to checkpoint on every event.',
    )
    provider: BaseProvider = Field(
        default_factory=JsonProvider,
        description="Storage backend. Defaults to JsonProvider.",
    )
    # ``ge=0`` rejects negative limits at construction time. A negative value
    # would otherwise reach the pruning slice ``files[:-max_keep]`` and delete
    # the wrong files instead of failing loudly.
    max_checkpoints: int | None = Field(
        default=None,
        ge=0,
        description="Maximum checkpoint files to keep. Oldest are pruned first. "
        "None means keep all.",
    )

    @property
    def trigger_all(self) -> bool:
        """Return True when the ``"*"`` wildcard is present in ``on_events``."""
        return "*" in self.on_events

    @property
    def trigger_events(self) -> set[str]:
        """Return ``on_events`` as a set, suitable for membership checks."""
        return set(self.on_events)

View File

@@ -0,0 +1,176 @@
"""Event listener that writes checkpoints automatically.
Handlers are registered lazily — only when the first ``CheckpointConfig``
is resolved (i.e. an entity actually has checkpointing enabled). This
avoids per-event overhead when no entity uses checkpointing.
"""
from __future__ import annotations
import glob
import logging
import os
import threading
from typing import Any
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.crew import Crew
from crewai.events.base_events import BaseEvent
from crewai.events.event_bus import CrewAIEventsBus, crewai_event_bus
from crewai.flow.flow import Flow
from crewai.state.checkpoint_config import CheckpointConfig
from crewai.state.runtime import RuntimeState, _prepare_entities
from crewai.task import Task
# Module logger used by the checkpoint handlers in this file.
logger = logging.getLogger(__name__)
# Flipped to True by _ensure_handlers_registered() after registration.
_handlers_registered = False
# Held by _ensure_handlers_registered() while it registers the handlers.
_register_lock = threading.Lock()
# Unique marker returned by _resolve() for an explicit ``False`` (opt-out),
# so callers can tell it apart from ``None`` (inherit).
_SENTINEL = object()
def _ensure_handlers_registered() -> None:
    """Attach the checkpoint handlers to the global event bus at most once.

    Repeated calls are no-ops once ``_handlers_registered`` has been set.
    """
    global _handlers_registered
    # Unlocked check first so the common already-registered case skips the lock.
    if not _handlers_registered:
        with _register_lock:
            # Check again now that the lock is held: another caller may have
            # finished registering while this one was waiting.
            if not _handlers_registered:
                _register_all_handlers(crewai_event_bus)
                _handlers_registered = True
def _resolve(value: CheckpointConfig | bool | None) -> CheckpointConfig | None | object:
    """Normalise the raw value of a ``checkpoint`` field.

    Returns:
        CheckpointConfig: the given config, or a default one for ``True``.
        _SENTINEL: the value was ``False`` — explicit opt-out, callers
            should stop looking at parents.
        None: nothing configured here — callers should keep looking at
            parents.
    """
    if value is False:
        return _SENTINEL

    wants_checkpointing = value is True or isinstance(value, CheckpointConfig)
    if not wants_checkpointing:
        return None  # None (or anything unrecognised) = inherit

    # Checkpointing is enabled for this entity, so make sure the handlers
    # exist before handing back a config.
    _ensure_handlers_registered()
    return CheckpointConfig() if value is True else value
def _find_checkpoint(source: Any) -> CheckpointConfig | None:
    """Locate the CheckpointConfig that applies to an event source.

    Resolution order:
      * Flow / Crew — only their own ``checkpoint`` field is consulted.
      * Agent — its own field first, then the crew it belongs to.
      * Task — resolved through its agent, exactly as for an Agent.

    ``None`` on a field means "inherit from the parent"; ``False`` means
    "opt out" and ends the lookup immediately. Any other source type
    yields ``None``.
    """
    if isinstance(source, Flow):
        own = _resolve(source.checkpoint)
        return own if isinstance(own, CheckpointConfig) else None

    if isinstance(source, Crew):
        own = _resolve(source.checkpoint)
        return own if isinstance(own, CheckpointConfig) else None

    # Agents and tasks share one lookup path: pick the agent to start from.
    if isinstance(source, BaseAgent):
        agent = source
    elif isinstance(source, Task):
        agent = source.agent
        if not isinstance(agent, BaseAgent):
            return None
    else:
        return None

    own = _resolve(agent.checkpoint)
    if isinstance(own, CheckpointConfig):
        return own
    if own is _SENTINEL:
        # Explicit opt-out on the agent: do not fall back to the crew.
        return None

    parent = agent.crew
    if not isinstance(parent, Crew):
        return None
    inherited = _resolve(parent.checkpoint)
    return inherited if isinstance(inherited, CheckpointConfig) else None
def _do_checkpoint(state: RuntimeState, cfg: CheckpointConfig) -> None:
    """Serialise *state*, hand it to the provider, then prune if a limit is set.

    Args:
        state: Runtime state to dump as JSON.
        cfg: Supplies the provider, target directory and retention limit.
    """
    _prepare_entities(state.root)
    payload = state.model_dump_json()
    cfg.provider.checkpoint(payload, cfg.directory)

    if cfg.max_checkpoints is None:
        # No retention limit configured: keep every checkpoint.
        return
    _prune(cfg.directory, cfg.max_checkpoints)
def _safe_remove(path: str) -> None:
    """Delete *path*, logging instead of raising if the removal fails.

    Any ``OSError`` from ``os.remove`` is recorded at debug level and
    swallowed, so one undeletable file does not abort a prune pass.
    """
    try:
        os.remove(path)
    except OSError:
        logger.debug("Failed to remove checkpoint file %s", path, exc_info=True)


def _prune(directory: str, max_keep: int) -> None:
    """Remove oldest checkpoint files beyond *max_keep*.

    Only ``*.json`` files directly inside *directory* are considered; age
    is determined by modification time.

    Args:
        directory: Directory holding the checkpoint files. Taken literally,
            so glob metacharacters in the path are not interpreted.
        max_keep: Number of newest files to retain. ``0`` removes all.

    Raises:
        ValueError: If *max_keep* is negative.
    """
    if max_keep < 0:
        # A negative count would make ``[:-max_keep]`` delete the *oldest*
        # ``abs(max_keep)`` files, which is never what a caller intends.
        raise ValueError(f"max_keep must be >= 0, got {max_keep}")

    # Escape the directory part: without it, a path such as "runs[1]" is read
    # as a character class and the pattern silently matches nothing.
    pattern = os.path.join(glob.escape(directory), "*.json")

    stamped: list[tuple[float, str]] = []
    for path in glob.glob(pattern):
        try:
            stamped.append((os.path.getmtime(path), path))
        except OSError:
            # The file disappeared between listing and stat (e.g. removed by
            # a concurrent prune); it needs no further handling.
            continue
    # Sort on mtime only so ties keep glob's order, as a plain
    # ``sorted(..., key=getmtime)`` would.
    stamped.sort(key=lambda entry: entry[0])

    doomed = stamped if max_keep == 0 else stamped[:-max_keep]
    for _, path in doomed:
        _safe_remove(path)
def _should_checkpoint(source: Any, event: BaseEvent) -> CheckpointConfig | None:
    """Decide whether *event* from *source* should produce a checkpoint.

    Returns:
        The applicable CheckpointConfig when one is found and it either
        uses the wildcard or lists ``event.type``; otherwise ``None``.
    """
    cfg = _find_checkpoint(source)
    if cfg is not None and (cfg.trigger_all or event.type in cfg.trigger_events):
        return cfg
    return None
def _on_any_event(source: Any, event: BaseEvent, state: Any) -> None:
    """Sync handler registered on every event class.

    Writes a checkpoint when the source's configuration asks for one on
    this event. Failures are logged as warnings and never propagate.
    """
    cfg = _should_checkpoint(source, event)
    if cfg is not None:
        try:
            _do_checkpoint(state, cfg)
        except Exception:
            # Best effort: a failed checkpoint is logged, not raised.
            logger.warning(
                "Auto-checkpoint failed for event %s", event.type, exc_info=True
            )
def _register_all_handlers(event_bus: CrewAIEventsBus) -> None:
    """Register the checkpoint handler on all known event classes.

    Only the sync handler is registered. The event bus runs sync handlers
    in a ``ThreadPoolExecutor``, so blocking I/O is safe and we avoid
    writing duplicate checkpoints from both sync and async dispatch.

    Every transitive subclass of ``BaseEvent`` is visited once; a class
    gets the handler only if its ``type`` field has a truthy default
    other than ``"base_event"``.
    """
    visited: set[type] = set()

    def _descendants(cls: type[BaseEvent]):
        # Depth-first, pre-order walk that yields each subclass only once,
        # even when it is reachable through several parents.
        for child in cls.__subclasses__():
            if child in visited:
                continue
            visited.add(child)
            yield child
            yield from _descendants(child)

    for event_cls in _descendants(BaseEvent):
        type_field = event_cls.model_fields.get("type")
        if type_field and type_field.default and type_field.default != "base_event":
            event_bus.register_handler(event_cls, _on_any_event)

View File

@@ -0,0 +1,169 @@
"""Tests for CheckpointConfig, checkpoint listener, and pruning."""
from __future__ import annotations
import os
import tempfile
import time
from typing import Any
from unittest.mock import MagicMock, patch
import pytest
from crewai.agent.core import Agent
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.crew import Crew
from crewai.flow.flow import Flow, start
from crewai.state.checkpoint_config import CheckpointConfig
from crewai.state.checkpoint_listener import (
_find_checkpoint,
_prune,
_resolve,
_SENTINEL,
)
from crewai.task import Task
# ---------- _resolve ----------
class TestResolve:
    """``_resolve`` maps each kind of checkpoint field value to its outcome."""

    def test_none_returns_none(self) -> None:
        # None means "inherit", which resolves to None.
        resolved = _resolve(None)
        assert resolved is None

    def test_false_returns_sentinel(self) -> None:
        # False is an explicit opt-out and yields the sentinel marker.
        resolved = _resolve(False)
        assert resolved is _SENTINEL

    def test_true_returns_config(self) -> None:
        # True expands to a default-constructed config.
        resolved = _resolve(True)
        assert isinstance(resolved, CheckpointConfig)
        assert resolved.directory == "./.checkpoints"

    def test_config_returns_config(self) -> None:
        # An explicit config comes back as the very same object.
        explicit = CheckpointConfig(directory="/tmp/cp")
        resolved = _resolve(explicit)
        assert resolved is explicit
# ---------- _find_checkpoint inheritance ----------
class TestFindCheckpoint:
    """Inheritance rules of ``_find_checkpoint`` across Task, Agent, Crew, Flow."""

    def _make_agent(self, checkpoint: Any = None) -> Agent:
        """Build a minimal agent carrying the given checkpoint setting."""
        return Agent(role="r", goal="g", backstory="b", checkpoint=checkpoint)

    def _make_crew(
        self, agents: list[Agent], checkpoint: Any = None
    ) -> Crew:
        """Build a task-less crew and point every agent back at it."""
        crew = Crew(agents=agents, tasks=[], checkpoint=checkpoint)
        for member in agents:
            member.crew = crew
        return crew

    def test_crew_true(self) -> None:
        agent = self._make_agent()
        self._make_crew([agent], checkpoint=True)
        found = _find_checkpoint(agent)
        assert isinstance(found, CheckpointConfig)

    def test_crew_true_agent_false_opts_out(self) -> None:
        agent = self._make_agent(checkpoint=False)
        self._make_crew([agent], checkpoint=True)
        found = _find_checkpoint(agent)
        assert found is None

    def test_crew_none_agent_none(self) -> None:
        agent = self._make_agent()
        self._make_crew([agent])
        found = _find_checkpoint(agent)
        assert found is None

    def test_agent_config_overrides_crew(self) -> None:
        agent_cfg = CheckpointConfig(directory="/agent_cp")
        agent = self._make_agent(checkpoint=agent_cfg)
        self._make_crew([agent], checkpoint=True)
        found = _find_checkpoint(agent)
        assert isinstance(found, CheckpointConfig)
        assert found.directory == "/agent_cp"

    def test_task_inherits_from_crew(self) -> None:
        agent = self._make_agent()
        self._make_crew([agent], checkpoint=True)
        task = Task(description="d", expected_output="e", agent=agent)
        found = _find_checkpoint(task)
        assert isinstance(found, CheckpointConfig)

    def test_task_agent_false_blocks(self) -> None:
        agent = self._make_agent(checkpoint=False)
        self._make_crew([agent], checkpoint=True)
        task = Task(description="d", expected_output="e", agent=agent)
        found = _find_checkpoint(task)
        assert found is None

    def test_flow_direct(self) -> None:
        found = _find_checkpoint(Flow(checkpoint=True))
        assert isinstance(found, CheckpointConfig)

    def test_flow_none(self) -> None:
        found = _find_checkpoint(Flow())
        assert found is None

    def test_unknown_source(self) -> None:
        # Objects of unrelated types never resolve to a config.
        found = _find_checkpoint("random")
        assert found is None
# ---------- _prune ----------
class TestPrune:
    """Behaviour of ``_prune`` against real temporary directories."""

    def test_prune_keeps_newest(self) -> None:
        with tempfile.TemporaryDirectory() as d:
            for i in range(5):
                path = os.path.join(d, f"cp_{i}.json")
                with open(path, "w") as f:
                    f.write("{}")
                # Stamp explicit, strictly increasing mtimes. Sleeping between
                # writes depends on the filesystem's timestamp resolution and
                # can produce ties; os.utime makes the ordering deterministic.
                os.utime(path, (1_000_000 + i, 1_000_000 + i))
            _prune(d, max_keep=2)
            remaining = os.listdir(d)
            assert len(remaining) == 2
            assert "cp_3.json" in remaining
            assert "cp_4.json" in remaining

    def test_prune_zero_removes_all(self) -> None:
        with tempfile.TemporaryDirectory() as d:
            for i in range(3):
                with open(os.path.join(d, f"cp_{i}.json"), "w") as f:
                    f.write("{}")
            _prune(d, max_keep=0)
            assert os.listdir(d) == []

    def test_prune_more_than_existing(self) -> None:
        with tempfile.TemporaryDirectory() as d:
            with open(os.path.join(d, "cp.json"), "w") as f:
                f.write("{}")
            # A limit above the file count must leave everything in place.
            _prune(d, max_keep=10)
            assert len(os.listdir(d)) == 1
# ---------- CheckpointConfig ----------
class TestCheckpointConfig:
    """Defaults and derived properties of ``CheckpointConfig``."""

    def test_defaults(self) -> None:
        default_cfg = CheckpointConfig()
        assert default_cfg.directory == "./.checkpoints"
        assert default_cfg.on_events == ["task_completed"]
        assert default_cfg.max_checkpoints is None
        assert not default_cfg.trigger_all

    def test_trigger_all(self) -> None:
        # The wildcard entry switches trigger_all on.
        wildcard_cfg = CheckpointConfig(on_events=["*"])
        assert wildcard_cfg.trigger_all

    def test_trigger_events(self) -> None:
        wanted = ["task_completed", "crew_kickoff_completed"]
        cfg = CheckpointConfig(on_events=wanted)
        assert cfg.trigger_events == {"task_completed", "crew_kickoff_completed"}