fix: register checkpoint handlers when CheckpointConfig is created

This commit is contained in:
Greyson LaLonde
2026-04-08 02:11:34 +08:00
committed by GitHub
parent 25eb4adc49
commit c0f3151e13
4 changed files with 38 additions and 8 deletions

View File

@@ -39,7 +39,7 @@ from crewai.memory.unified_memory import Memory
from crewai.rag.embeddings.types import EmbedderConfig
from crewai.security.security_config import SecurityConfig
from crewai.skills.models import Skill
from crewai.state.checkpoint_config import CheckpointConfig
from crewai.state.checkpoint_config import CheckpointConfig, _coerce_checkpoint
from crewai.tools.base_tool import BaseTool, Tool
from crewai.types.callback import SerializableCallable
from crewai.utilities.config import process_config
@@ -300,7 +300,10 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
default_factory=SecurityConfig,
description="Security configuration for the agent, including fingerprinting.",
)
checkpoint: CheckpointConfig | bool | None = Field(
checkpoint: Annotated[
CheckpointConfig | bool | None,
BeforeValidator(_coerce_checkpoint),
] = Field(
default=None,
description="Automatic checkpointing configuration. "
"True for defaults, False to opt out, None to inherit.",

View File

@@ -104,7 +104,7 @@ from crewai.rag.types import SearchResult
from crewai.security.fingerprint import Fingerprint
from crewai.security.security_config import SecurityConfig
from crewai.skills.models import Skill
from crewai.state.checkpoint_config import CheckpointConfig
from crewai.state.checkpoint_config import CheckpointConfig, _coerce_checkpoint
from crewai.task import Task
from crewai.tasks.conditional_task import ConditionalTask
from crewai.tasks.task_output import TaskOutput
@@ -341,7 +341,10 @@ class Crew(FlowTrackable, BaseModel):
default_factory=SecurityConfig,
description="Security configuration for the crew, including fingerprinting.",
)
checkpoint: CheckpointConfig | bool | None = Field(
checkpoint: Annotated[
CheckpointConfig | bool | None,
BeforeValidator(_coerce_checkpoint),
] = Field(
default=None,
description="Automatic checkpointing configuration. "
"True for defaults, False to opt out, None to inherit.",

View File

@@ -113,7 +113,7 @@ from crewai.flow.utils import (
)
from crewai.memory.memory_scope import MemoryScope, MemorySlice
from crewai.memory.unified_memory import Memory
from crewai.state.checkpoint_config import CheckpointConfig
from crewai.state.checkpoint_config import CheckpointConfig, _coerce_checkpoint
if TYPE_CHECKING:
@@ -921,7 +921,10 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
max_method_calls: int = Field(default=100)
execution_context: ExecutionContext | None = Field(default=None)
checkpoint: CheckpointConfig | bool | None = Field(default=None)
checkpoint: Annotated[
CheckpointConfig | bool | None,
BeforeValidator(_coerce_checkpoint),
] = Field(default=None)
@classmethod
def from_checkpoint(

View File

@@ -2,9 +2,9 @@
from __future__ import annotations
from typing import Literal
from typing import Any, Literal
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, model_validator
from crewai.state.provider.core import BaseProvider
from crewai.state.provider.json_provider import JsonProvider
@@ -158,6 +158,20 @@ CheckpointEventType = Literal[
]
def _coerce_checkpoint(v: Any) -> Any:
"""BeforeValidator for checkpoint fields on Crew/Flow/Agent.
Converts True to CheckpointConfig and triggers handler registration.
"""
if v is True:
v = CheckpointConfig()
if isinstance(v, CheckpointConfig):
from crewai.state.checkpoint_listener import _ensure_handlers_registered
_ensure_handlers_registered()
return v
class CheckpointConfig(BaseModel):
"""Configuration for automatic checkpointing.
@@ -185,6 +199,13 @@ class CheckpointConfig(BaseModel):
"each write. None means keep all.",
)
@model_validator(mode="after")
def _register_handlers(self) -> CheckpointConfig:
from crewai.state.checkpoint_listener import _ensure_handlers_registered
_ensure_handlers_registered()
return self
@property
def trigger_all(self) -> bool:
return "*" in self.on_events