mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-04-30 23:02:50 +00:00
refactor: generic from_checkpoint with provider, full LLM serialization
This commit is contained in:
@@ -51,6 +51,7 @@ from crewai.utilities.string_utils import interpolate_only
|
||||
if TYPE_CHECKING:
|
||||
from crewai.context import ExecutionContext
|
||||
from crewai.crew import Crew
|
||||
from crewai.state.provider.core import BaseProvider
|
||||
|
||||
|
||||
def _validate_crew_ref(value: Any) -> Any:
|
||||
@@ -296,17 +297,18 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
|
||||
execution_context: ExecutionContext | None = Field(default=None)
|
||||
|
||||
@classmethod
|
||||
def from_checkpoint(cls, path: str) -> Self:
|
||||
def from_checkpoint(
|
||||
cls, path: str, *, provider: BaseProvider | None = None
|
||||
) -> Self:
|
||||
"""Restore an Agent from a checkpoint file."""
|
||||
from pathlib import Path as _Path
|
||||
|
||||
from crewai.context import apply_execution_context
|
||||
from crewai.state.provider.json_provider import JsonProvider
|
||||
from crewai.state.runtime import RuntimeState
|
||||
|
||||
json_str = _Path(path).read_text()
|
||||
from crewai import RuntimeState
|
||||
|
||||
state = RuntimeState.model_validate_json(
|
||||
json_str, context={"from_checkpoint": True}
|
||||
state = RuntimeState.from_checkpoint(
|
||||
path,
|
||||
provider=provider or JsonProvider(),
|
||||
context={"from_checkpoint": True},
|
||||
)
|
||||
for entity in state.root:
|
||||
if isinstance(entity, cls):
|
||||
|
||||
@@ -21,18 +21,20 @@ if TYPE_CHECKING:
|
||||
class BaseAgentExecutor(BaseModel):
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
|
||||
crew: Crew = Field(default=None, exclude=True) # type: ignore[assignment]
|
||||
agent: BaseAgent = Field(default=None, exclude=True) # type: ignore[assignment]
|
||||
task: Task = Field(default=None, exclude=True) # type: ignore[assignment]
|
||||
crew: Crew | None = Field(default=None, exclude=True)
|
||||
agent: BaseAgent | None = Field(default=None, exclude=True)
|
||||
task: Task | None = Field(default=None, exclude=True)
|
||||
iterations: int = Field(default=0)
|
||||
max_iter: int = Field(default=25)
|
||||
messages: list[LLMMessage] = Field(default_factory=list)
|
||||
_resuming: bool = PrivateAttr(default=False)
|
||||
_i18n: I18N = PrivateAttr(default=None) # type: ignore[assignment]
|
||||
_i18n: I18N | None = PrivateAttr(default=None)
|
||||
_printer: Printer = PrivateAttr(default_factory=Printer)
|
||||
|
||||
def _save_to_memory(self, output: AgentFinish) -> None:
|
||||
"""Save task result to unified memory (memory or crew._memory)."""
|
||||
if self.agent is None:
|
||||
return
|
||||
memory = getattr(self.agent, "memory", None) or (
|
||||
getattr(self.crew, "_memory", None) if self.crew else None
|
||||
)
|
||||
|
||||
@@ -42,6 +42,7 @@ if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span
|
||||
|
||||
from crewai.context import ExecutionContext
|
||||
from crewai.state.provider.core import BaseProvider
|
||||
|
||||
try:
|
||||
from crewai_files import get_supported_content_types
|
||||
@@ -354,25 +355,27 @@ class Crew(FlowTrackable, BaseModel):
|
||||
checkpoint_kickoff_event_id: str | None = Field(default=None)
|
||||
|
||||
@classmethod
|
||||
def from_checkpoint(cls, path: str) -> Crew:
|
||||
def from_checkpoint(
|
||||
cls, path: str, *, provider: BaseProvider | None = None
|
||||
) -> Crew:
|
||||
"""Restore a Crew from a checkpoint file, ready to resume via kickoff().
|
||||
|
||||
Args:
|
||||
path: Path to a checkpoint JSON file.
|
||||
provider: Storage backend to read from. Defaults to JsonProvider.
|
||||
|
||||
Returns:
|
||||
A Crew instance. Call kickoff() to resume from the last completed task.
|
||||
"""
|
||||
from pathlib import Path as _Path
|
||||
|
||||
from crewai.context import apply_execution_context
|
||||
|
||||
json_str = _Path(path).read_text()
|
||||
from crewai import RuntimeState
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.state.provider.json_provider import JsonProvider
|
||||
from crewai.state.runtime import RuntimeState
|
||||
|
||||
state = RuntimeState.model_validate_json(
|
||||
json_str, context={"from_checkpoint": True}
|
||||
state = RuntimeState.from_checkpoint(
|
||||
path,
|
||||
provider=provider or JsonProvider(),
|
||||
context={"from_checkpoint": True},
|
||||
)
|
||||
crewai_event_bus.set_runtime_state(state)
|
||||
for entity in state.root:
|
||||
|
||||
@@ -121,6 +121,7 @@ if TYPE_CHECKING:
|
||||
from crewai.context import ExecutionContext
|
||||
from crewai.flow.async_feedback.types import PendingFeedbackContext
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.state.provider.core import BaseProvider
|
||||
|
||||
from crewai.flow.visualization import build_flow_structure, render_interactive
|
||||
from crewai.types.streaming import CrewStreamingOutput, FlowStreamingOutput
|
||||
@@ -921,17 +922,18 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
execution_context: ExecutionContext | None = Field(default=None)
|
||||
|
||||
@classmethod
|
||||
def from_checkpoint(cls, path: str) -> Flow: # type: ignore[type-arg]
|
||||
def from_checkpoint(
|
||||
cls, path: str, *, provider: BaseProvider | None = None
|
||||
) -> Flow: # type: ignore[type-arg]
|
||||
"""Restore a Flow from a checkpoint file."""
|
||||
from pathlib import Path as _Path
|
||||
|
||||
from crewai.context import apply_execution_context
|
||||
from crewai.state.provider.json_provider import JsonProvider
|
||||
from crewai.state.runtime import RuntimeState
|
||||
|
||||
json_str = _Path(path).read_text()
|
||||
from crewai import RuntimeState
|
||||
|
||||
state = RuntimeState.model_validate_json(
|
||||
json_str, context={"from_checkpoint": True}
|
||||
state = RuntimeState.from_checkpoint(
|
||||
path,
|
||||
provider=provider or JsonProvider(),
|
||||
context={"from_checkpoint": True},
|
||||
)
|
||||
for entity in state.root:
|
||||
if isinstance(entity, cls):
|
||||
|
||||
@@ -57,3 +57,25 @@ class BaseProvider(Protocol):
|
||||
A location identifier for the saved checkpoint, such as a file path or URI.
|
||||
"""
|
||||
...
|
||||
|
||||
def from_checkpoint(self, location: str) -> str:
|
||||
"""Read a snapshot synchronously.
|
||||
|
||||
Args:
|
||||
location: The identifier returned by a previous ``checkpoint`` call.
|
||||
|
||||
Returns:
|
||||
The raw serialized string.
|
||||
"""
|
||||
...
|
||||
|
||||
async def afrom_checkpoint(self, location: str) -> str:
|
||||
"""Read a snapshot asynchronously.
|
||||
|
||||
Args:
|
||||
location: The identifier returned by a previous ``acheckpoint`` call.
|
||||
|
||||
Returns:
|
||||
The raw serialized string.
|
||||
"""
|
||||
...
|
||||
|
||||
@@ -49,6 +49,29 @@ class JsonProvider(BaseProvider):
|
||||
await f.write(data)
|
||||
return str(file_path)
|
||||
|
||||
def from_checkpoint(self, location: str) -> str:
|
||||
"""Read a JSON checkpoint file.
|
||||
|
||||
Args:
|
||||
location: Filesystem path to the checkpoint file.
|
||||
|
||||
Returns:
|
||||
The raw JSON string.
|
||||
"""
|
||||
return Path(location).read_text()
|
||||
|
||||
async def afrom_checkpoint(self, location: str) -> str:
|
||||
"""Read a JSON checkpoint file asynchronously.
|
||||
|
||||
Args:
|
||||
location: Filesystem path to the checkpoint file.
|
||||
|
||||
Returns:
|
||||
The raw JSON string.
|
||||
"""
|
||||
async with aiofiles.open(location) as f:
|
||||
return await f.read()
|
||||
|
||||
|
||||
def _build_path(directory: str) -> Path:
|
||||
"""Build a timestamped checkpoint file path.
|
||||
|
||||
@@ -128,6 +128,40 @@ class RuntimeState(RootModel): # type: ignore[type-arg]
|
||||
_prepare_entities(self.root)
|
||||
return await self._provider.acheckpoint(self.model_dump_json(), directory)
|
||||
|
||||
@classmethod
|
||||
def from_checkpoint(
|
||||
cls, location: str, provider: BaseProvider, **kwargs: Any
|
||||
) -> RuntimeState:
|
||||
"""Restore a RuntimeState from a checkpoint.
|
||||
|
||||
Args:
|
||||
location: The identifier returned by a previous ``checkpoint`` call.
|
||||
provider: The storage backend to read from.
|
||||
**kwargs: Passed to ``model_validate_json``.
|
||||
|
||||
Returns:
|
||||
A restored RuntimeState.
|
||||
"""
|
||||
raw = provider.from_checkpoint(location)
|
||||
return cls.model_validate_json(raw, **kwargs)
|
||||
|
||||
@classmethod
|
||||
async def afrom_checkpoint(
|
||||
cls, location: str, provider: BaseProvider, **kwargs: Any
|
||||
) -> RuntimeState:
|
||||
"""Async version of :meth:`from_checkpoint`.
|
||||
|
||||
Args:
|
||||
location: The identifier returned by a previous ``acheckpoint`` call.
|
||||
provider: The storage backend to read from.
|
||||
**kwargs: Passed to ``model_validate_json``.
|
||||
|
||||
Returns:
|
||||
A restored RuntimeState.
|
||||
"""
|
||||
raw = await provider.afrom_checkpoint(location)
|
||||
return cls.model_validate_json(raw, **kwargs)
|
||||
|
||||
|
||||
def _prepare_entities(root: list[Entity]) -> None:
|
||||
"""Capture execution context and sync checkpoint fields on each entity.
|
||||
|
||||
Reference in New Issue
Block a user