refactor: generic from_checkpoint with provider, full LLM serialization

This commit is contained in:
Greyson LaLonde
2026-04-04 01:25:31 +08:00
parent 6dc9f462f9
commit 191053c41b
7 changed files with 116 additions and 28 deletions

View File

@@ -51,6 +51,7 @@ from crewai.utilities.string_utils import interpolate_only
if TYPE_CHECKING:
from crewai.context import ExecutionContext
from crewai.crew import Crew
from crewai.state.provider.core import BaseProvider
def _validate_crew_ref(value: Any) -> Any:
@@ -296,17 +297,18 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
execution_context: ExecutionContext | None = Field(default=None)
@classmethod
def from_checkpoint(cls, path: str) -> Self:
def from_checkpoint(
cls, path: str, *, provider: BaseProvider | None = None
) -> Self:
"""Restore an Agent from a checkpoint file."""
from pathlib import Path as _Path
from crewai.context import apply_execution_context
from crewai.state.provider.json_provider import JsonProvider
from crewai.state.runtime import RuntimeState
json_str = _Path(path).read_text()
from crewai import RuntimeState
state = RuntimeState.model_validate_json(
json_str, context={"from_checkpoint": True}
state = RuntimeState.from_checkpoint(
path,
provider=provider or JsonProvider(),
context={"from_checkpoint": True},
)
for entity in state.root:
if isinstance(entity, cls):

View File

@@ -21,18 +21,20 @@ if TYPE_CHECKING:
class BaseAgentExecutor(BaseModel):
model_config = {"arbitrary_types_allowed": True}
crew: Crew = Field(default=None, exclude=True) # type: ignore[assignment]
agent: BaseAgent = Field(default=None, exclude=True) # type: ignore[assignment]
task: Task = Field(default=None, exclude=True) # type: ignore[assignment]
crew: Crew | None = Field(default=None, exclude=True)
agent: BaseAgent | None = Field(default=None, exclude=True)
task: Task | None = Field(default=None, exclude=True)
iterations: int = Field(default=0)
max_iter: int = Field(default=25)
messages: list[LLMMessage] = Field(default_factory=list)
_resuming: bool = PrivateAttr(default=False)
_i18n: I18N = PrivateAttr(default=None) # type: ignore[assignment]
_i18n: I18N | None = PrivateAttr(default=None)
_printer: Printer = PrivateAttr(default_factory=Printer)
def _save_to_memory(self, output: AgentFinish) -> None:
"""Save task result to unified memory (memory or crew._memory)."""
if self.agent is None:
return
memory = getattr(self.agent, "memory", None) or (
getattr(self.crew, "_memory", None) if self.crew else None
)

View File

@@ -42,6 +42,7 @@ if TYPE_CHECKING:
from opentelemetry.trace import Span
from crewai.context import ExecutionContext
from crewai.state.provider.core import BaseProvider
try:
from crewai_files import get_supported_content_types
@@ -354,25 +355,27 @@ class Crew(FlowTrackable, BaseModel):
checkpoint_kickoff_event_id: str | None = Field(default=None)
@classmethod
def from_checkpoint(cls, path: str) -> Crew:
def from_checkpoint(
cls, path: str, *, provider: BaseProvider | None = None
) -> Crew:
"""Restore a Crew from a checkpoint file, ready to resume via kickoff().
Args:
path: Path to a checkpoint JSON file.
provider: Storage backend to read from. Defaults to JsonProvider.
Returns:
A Crew instance. Call kickoff() to resume from the last completed task.
"""
from pathlib import Path as _Path
from crewai.context import apply_execution_context
json_str = _Path(path).read_text()
from crewai import RuntimeState
from crewai.events.event_bus import crewai_event_bus
from crewai.state.provider.json_provider import JsonProvider
from crewai.state.runtime import RuntimeState
state = RuntimeState.model_validate_json(
json_str, context={"from_checkpoint": True}
state = RuntimeState.from_checkpoint(
path,
provider=provider or JsonProvider(),
context={"from_checkpoint": True},
)
crewai_event_bus.set_runtime_state(state)
for entity in state.root:

View File

@@ -121,6 +121,7 @@ if TYPE_CHECKING:
from crewai.context import ExecutionContext
from crewai.flow.async_feedback.types import PendingFeedbackContext
from crewai.llms.base_llm import BaseLLM
from crewai.state.provider.core import BaseProvider
from crewai.flow.visualization import build_flow_structure, render_interactive
from crewai.types.streaming import CrewStreamingOutput, FlowStreamingOutput
@@ -921,17 +922,18 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
execution_context: ExecutionContext | None = Field(default=None)
@classmethod
def from_checkpoint(cls, path: str) -> Flow: # type: ignore[type-arg]
def from_checkpoint(
cls, path: str, *, provider: BaseProvider | None = None
) -> Flow: # type: ignore[type-arg]
"""Restore a Flow from a checkpoint file."""
from pathlib import Path as _Path
from crewai.context import apply_execution_context
from crewai.state.provider.json_provider import JsonProvider
from crewai.state.runtime import RuntimeState
json_str = _Path(path).read_text()
from crewai import RuntimeState
state = RuntimeState.model_validate_json(
json_str, context={"from_checkpoint": True}
state = RuntimeState.from_checkpoint(
path,
provider=provider or JsonProvider(),
context={"from_checkpoint": True},
)
for entity in state.root:
if isinstance(entity, cls):

View File

@@ -57,3 +57,25 @@ class BaseProvider(Protocol):
A location identifier for the saved checkpoint, such as a file path or URI.
"""
...
def from_checkpoint(self, location: str) -> str:
"""Read a snapshot synchronously.
Args:
location: The identifier returned by a previous ``checkpoint`` call.
Returns:
The raw serialized string.
"""
...
async def afrom_checkpoint(self, location: str) -> str:
"""Read a snapshot asynchronously.
Args:
location: The identifier returned by a previous ``acheckpoint`` call.
Returns:
The raw serialized string.
"""
...

View File

@@ -49,6 +49,29 @@ class JsonProvider(BaseProvider):
await f.write(data)
return str(file_path)
def from_checkpoint(self, location: str) -> str:
"""Read a JSON checkpoint file.
Args:
location: Filesystem path to the checkpoint file.
Returns:
The raw JSON string.
"""
return Path(location).read_text()
async def afrom_checkpoint(self, location: str) -> str:
"""Read a JSON checkpoint file asynchronously.
Args:
location: Filesystem path to the checkpoint file.
Returns:
The raw JSON string.
"""
async with aiofiles.open(location) as f:
return await f.read()
def _build_path(directory: str) -> Path:
"""Build a timestamped checkpoint file path.

View File

@@ -128,6 +128,40 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
_prepare_entities(self.root)
return await self._provider.acheckpoint(self.model_dump_json(), directory)
@classmethod
def from_checkpoint(
cls, location: str, provider: BaseProvider, **kwargs: Any
) -> RuntimeState:
"""Restore a RuntimeState from a checkpoint.
Args:
location: The identifier returned by a previous ``checkpoint`` call.
provider: The storage backend to read from.
**kwargs: Passed to ``model_validate_json``.
Returns:
A restored RuntimeState.
"""
raw = provider.from_checkpoint(location)
return cls.model_validate_json(raw, **kwargs)
@classmethod
async def afrom_checkpoint(
cls, location: str, provider: BaseProvider, **kwargs: Any
) -> RuntimeState:
"""Async version of :meth:`from_checkpoint`.
Args:
location: The identifier returned by a previous ``acheckpoint`` call.
provider: The storage backend to read from.
**kwargs: Passed to ``model_validate_json``.
Returns:
A restored RuntimeState.
"""
raw = await provider.afrom_checkpoint(location)
return cls.model_validate_json(raw, **kwargs)
def _prepare_entities(root: list[Entity]) -> None:
"""Capture execution context and sync checkpoint fields on each entity.