mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-07-02 05:38:12 +00:00
fix: register checkpoint handlers when CheckpointConfig is created
This commit is contained in:
@@ -39,7 +39,7 @@ from crewai.memory.unified_memory import Memory
|
||||
from crewai.rag.embeddings.types import EmbedderConfig
|
||||
from crewai.security.security_config import SecurityConfig
|
||||
from crewai.skills.models import Skill
|
||||
from crewai.state.checkpoint_config import CheckpointConfig
|
||||
from crewai.state.checkpoint_config import CheckpointConfig, _coerce_checkpoint
|
||||
from crewai.tools.base_tool import BaseTool, Tool
|
||||
from crewai.types.callback import SerializableCallable
|
||||
from crewai.utilities.config import process_config
|
||||
@@ -300,7 +300,10 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
|
||||
default_factory=SecurityConfig,
|
||||
description="Security configuration for the agent, including fingerprinting.",
|
||||
)
|
||||
checkpoint: CheckpointConfig | bool | None = Field(
|
||||
checkpoint: Annotated[
|
||||
CheckpointConfig | bool | None,
|
||||
BeforeValidator(_coerce_checkpoint),
|
||||
] = Field(
|
||||
default=None,
|
||||
description="Automatic checkpointing configuration. "
|
||||
"True for defaults, False to opt out, None to inherit.",
|
||||
|
||||
@@ -104,7 +104,7 @@ from crewai.rag.types import SearchResult
|
||||
from crewai.security.fingerprint import Fingerprint
|
||||
from crewai.security.security_config import SecurityConfig
|
||||
from crewai.skills.models import Skill
|
||||
from crewai.state.checkpoint_config import CheckpointConfig
|
||||
from crewai.state.checkpoint_config import CheckpointConfig, _coerce_checkpoint
|
||||
from crewai.task import Task
|
||||
from crewai.tasks.conditional_task import ConditionalTask
|
||||
from crewai.tasks.task_output import TaskOutput
|
||||
@@ -341,7 +341,10 @@ class Crew(FlowTrackable, BaseModel):
|
||||
default_factory=SecurityConfig,
|
||||
description="Security configuration for the crew, including fingerprinting.",
|
||||
)
|
||||
checkpoint: CheckpointConfig | bool | None = Field(
|
||||
checkpoint: Annotated[
|
||||
CheckpointConfig | bool | None,
|
||||
BeforeValidator(_coerce_checkpoint),
|
||||
] = Field(
|
||||
default=None,
|
||||
description="Automatic checkpointing configuration. "
|
||||
"True for defaults, False to opt out, None to inherit.",
|
||||
|
||||
@@ -113,7 +113,7 @@ from crewai.flow.utils import (
|
||||
)
|
||||
from crewai.memory.memory_scope import MemoryScope, MemorySlice
|
||||
from crewai.memory.unified_memory import Memory
|
||||
from crewai.state.checkpoint_config import CheckpointConfig
|
||||
from crewai.state.checkpoint_config import CheckpointConfig, _coerce_checkpoint
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -921,7 +921,10 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
max_method_calls: int = Field(default=100)
|
||||
|
||||
execution_context: ExecutionContext | None = Field(default=None)
|
||||
checkpoint: CheckpointConfig | bool | None = Field(default=None)
|
||||
checkpoint: Annotated[
|
||||
CheckpointConfig | bool | None,
|
||||
BeforeValidator(_coerce_checkpoint),
|
||||
] = Field(default=None)
|
||||
|
||||
@classmethod
|
||||
def from_checkpoint(
|
||||
|
||||
@@ -2,9 +2,9 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from crewai.state.provider.core import BaseProvider
|
||||
from crewai.state.provider.json_provider import JsonProvider
|
||||
@@ -158,6 +158,20 @@ CheckpointEventType = Literal[
|
||||
]
|
||||
|
||||
|
||||
def _coerce_checkpoint(v: Any) -> Any:
|
||||
"""BeforeValidator for checkpoint fields on Crew/Flow/Agent.
|
||||
|
||||
Converts True to CheckpointConfig and triggers handler registration.
|
||||
"""
|
||||
if v is True:
|
||||
v = CheckpointConfig()
|
||||
if isinstance(v, CheckpointConfig):
|
||||
from crewai.state.checkpoint_listener import _ensure_handlers_registered
|
||||
|
||||
_ensure_handlers_registered()
|
||||
return v
|
||||
|
||||
|
||||
class CheckpointConfig(BaseModel):
|
||||
"""Configuration for automatic checkpointing.
|
||||
|
||||
@@ -185,6 +199,13 @@ class CheckpointConfig(BaseModel):
|
||||
"each write. None means keep all.",
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _register_handlers(self) -> CheckpointConfig:
|
||||
from crewai.state.checkpoint_listener import _ensure_handlers_registered
|
||||
|
||||
_ensure_handlers_registered()
|
||||
return self
|
||||
|
||||
@property
|
||||
def trigger_all(self) -> bool:
|
||||
return "*" in self.on_events
|
||||
|
||||
Reference in New Issue
Block a user