diff --git a/lib/crewai/src/crewai/agents/agent_builder/base_agent.py b/lib/crewai/src/crewai/agents/agent_builder/base_agent.py index 5cbe3ccbf..9ea223b46 100644 --- a/lib/crewai/src/crewai/agents/agent_builder/base_agent.py +++ b/lib/crewai/src/crewai/agents/agent_builder/base_agent.py @@ -51,6 +51,7 @@ from crewai.utilities.string_utils import interpolate_only if TYPE_CHECKING: from crewai.context import ExecutionContext from crewai.crew import Crew + from crewai.state.provider.core import BaseProvider def _validate_crew_ref(value: Any) -> Any: @@ -296,17 +297,18 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta): execution_context: ExecutionContext | None = Field(default=None) @classmethod - def from_checkpoint(cls, path: str) -> Self: + def from_checkpoint( + cls, path: str, *, provider: BaseProvider | None = None + ) -> Self: """Restore an Agent from a checkpoint file.""" - from pathlib import Path as _Path - from crewai.context import apply_execution_context + from crewai.state.provider.json_provider import JsonProvider + from crewai.state.runtime import RuntimeState - json_str = _Path(path).read_text() - from crewai import RuntimeState - - state = RuntimeState.model_validate_json( - json_str, context={"from_checkpoint": True} + state = RuntimeState.from_checkpoint( + path, + provider=provider or JsonProvider(), + context={"from_checkpoint": True}, ) for entity in state.root: if isinstance(entity, cls): diff --git a/lib/crewai/src/crewai/agents/agent_builder/base_agent_executor.py b/lib/crewai/src/crewai/agents/agent_builder/base_agent_executor.py index ee076dfe0..37028a63b 100644 --- a/lib/crewai/src/crewai/agents/agent_builder/base_agent_executor.py +++ b/lib/crewai/src/crewai/agents/agent_builder/base_agent_executor.py @@ -21,18 +21,20 @@ if TYPE_CHECKING: class BaseAgentExecutor(BaseModel): model_config = {"arbitrary_types_allowed": True} - crew: Crew = Field(default=None, exclude=True) # type: ignore[assignment] - agent: BaseAgent = Field(default=None, exclude=True) # type: ignore[assignment] - task: Task = Field(default=None, exclude=True) # type: ignore[assignment] + crew: Crew | None = Field(default=None, exclude=True) + agent: BaseAgent | None = Field(default=None, exclude=True) + task: Task | None = Field(default=None, exclude=True) iterations: int = Field(default=0) max_iter: int = Field(default=25) messages: list[LLMMessage] = Field(default_factory=list) _resuming: bool = PrivateAttr(default=False) - _i18n: I18N = PrivateAttr(default=None) # type: ignore[assignment] + _i18n: I18N | None = PrivateAttr(default=None) _printer: Printer = PrivateAttr(default_factory=Printer) def _save_to_memory(self, output: AgentFinish) -> None: """Save task result to unified memory (memory or crew._memory).""" + if self.agent is None: + return memory = getattr(self.agent, "memory", None) or ( getattr(self.crew, "_memory", None) if self.crew else None ) diff --git a/lib/crewai/src/crewai/crew.py b/lib/crewai/src/crewai/crew.py index 4f0162314..19d4aeb93 100644 --- a/lib/crewai/src/crewai/crew.py +++ b/lib/crewai/src/crewai/crew.py @@ -42,6 +42,7 @@ if TYPE_CHECKING: from opentelemetry.trace import Span from crewai.context import ExecutionContext + from crewai.state.provider.core import BaseProvider try: from crewai_files import get_supported_content_types @@ -354,25 +355,27 @@ class Crew(FlowTrackable, BaseModel): checkpoint_kickoff_event_id: str | None = Field(default=None) @classmethod - def from_checkpoint(cls, path: str) -> Crew: + def from_checkpoint( + cls, path: str, *, provider: BaseProvider | None = None + ) -> Crew: """Restore a Crew from a checkpoint file, ready to resume via kickoff(). Args: path: Path to a checkpoint JSON file. + provider: Storage backend to read from. Defaults to JsonProvider. Returns: A Crew instance. Call kickoff() to resume from the last completed task. """ - from pathlib import Path as _Path - from crewai.context import apply_execution_context - - json_str = _Path(path).read_text() - from crewai import RuntimeState from crewai.events.event_bus import crewai_event_bus + from crewai.state.provider.json_provider import JsonProvider + from crewai.state.runtime import RuntimeState - state = RuntimeState.model_validate_json( - json_str, context={"from_checkpoint": True} + state = RuntimeState.from_checkpoint( + path, + provider=provider or JsonProvider(), + context={"from_checkpoint": True}, ) crewai_event_bus.set_runtime_state(state) for entity in state.root: diff --git a/lib/crewai/src/crewai/flow/flow.py b/lib/crewai/src/crewai/flow/flow.py index cad4da304..ed1eda7e5 100644 --- a/lib/crewai/src/crewai/flow/flow.py +++ b/lib/crewai/src/crewai/flow/flow.py @@ -121,6 +121,7 @@ if TYPE_CHECKING: from crewai.context import ExecutionContext from crewai.flow.async_feedback.types import PendingFeedbackContext from crewai.llms.base_llm import BaseLLM + from crewai.state.provider.core import BaseProvider from crewai.flow.visualization import build_flow_structure, render_interactive from crewai.types.streaming import CrewStreamingOutput, FlowStreamingOutput @@ -921,17 +922,18 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): execution_context: ExecutionContext | None = Field(default=None) @classmethod - def from_checkpoint(cls, path: str) -> Flow: # type: ignore[type-arg] + def from_checkpoint( + cls, path: str, *, provider: BaseProvider | None = None + ) -> Flow: # type: ignore[type-arg] """Restore a Flow from a checkpoint file.""" - from pathlib import Path as _Path - from crewai.context import apply_execution_context + from crewai.state.provider.json_provider import JsonProvider + from crewai.state.runtime import RuntimeState - json_str = _Path(path).read_text() - from crewai import RuntimeState - - state = RuntimeState.model_validate_json( - json_str, context={"from_checkpoint": True} + state = RuntimeState.from_checkpoint( + path, + provider=provider or JsonProvider(), + context={"from_checkpoint": True}, ) for entity in state.root: if isinstance(entity, cls): diff --git a/lib/crewai/src/crewai/state/provider/core.py b/lib/crewai/src/crewai/state/provider/core.py index 71698c712..ee420eea0 100644 --- a/lib/crewai/src/crewai/state/provider/core.py +++ b/lib/crewai/src/crewai/state/provider/core.py @@ -57,3 +57,25 @@ class BaseProvider(Protocol): A location identifier for the saved checkpoint, such as a file path or URI. """ ... + + def from_checkpoint(self, location: str) -> str: + """Read a snapshot synchronously. + + Args: + location: The identifier returned by a previous ``checkpoint`` call. + + Returns: + The raw serialized string. + """ + ... + + async def afrom_checkpoint(self, location: str) -> str: + """Read a snapshot asynchronously. + + Args: + location: The identifier returned by a previous ``acheckpoint`` call. + + Returns: + The raw serialized string. + """ + ... diff --git a/lib/crewai/src/crewai/state/provider/json_provider.py b/lib/crewai/src/crewai/state/provider/json_provider.py index 67f770925..656e19fe0 100644 --- a/lib/crewai/src/crewai/state/provider/json_provider.py +++ b/lib/crewai/src/crewai/state/provider/json_provider.py @@ -49,6 +49,29 @@ class JsonProvider(BaseProvider): await f.write(data) return str(file_path) + def from_checkpoint(self, location: str) -> str: + """Read a JSON checkpoint file. + + Args: + location: Filesystem path to the checkpoint file. + + Returns: + The raw JSON string. + """ + return Path(location).read_text() + + async def afrom_checkpoint(self, location: str) -> str: + """Read a JSON checkpoint file asynchronously. + + Args: + location: Filesystem path to the checkpoint file. + + Returns: + The raw JSON string. + """ + async with aiofiles.open(location) as f: + return await f.read() + def _build_path(directory: str) -> Path: """Build a timestamped checkpoint file path. diff --git a/lib/crewai/src/crewai/state/runtime.py b/lib/crewai/src/crewai/state/runtime.py index 8fa684264..a4bc8584c 100644 --- a/lib/crewai/src/crewai/state/runtime.py +++ b/lib/crewai/src/crewai/state/runtime.py @@ -128,6 +128,40 @@ class RuntimeState(RootModel): # type: ignore[type-arg] _prepare_entities(self.root) return await self._provider.acheckpoint(self.model_dump_json(), directory) + @classmethod + def from_checkpoint( + cls, location: str, provider: BaseProvider, **kwargs: Any + ) -> RuntimeState: + """Restore a RuntimeState from a checkpoint. + + Args: + location: The identifier returned by a previous ``checkpoint`` call. + provider: The storage backend to read from. + **kwargs: Passed to ``model_validate_json``. + + Returns: + A restored RuntimeState. + """ + raw = provider.from_checkpoint(location) + return cls.model_validate_json(raw, **kwargs) + + @classmethod + async def afrom_checkpoint( + cls, location: str, provider: BaseProvider, **kwargs: Any + ) -> RuntimeState: + """Async version of :meth:`from_checkpoint`. + + Args: + location: The identifier returned by a previous ``acheckpoint`` call. + provider: The storage backend to read from. + **kwargs: Passed to ``model_validate_json``. + + Returns: + A restored RuntimeState. + """ + raw = await provider.afrom_checkpoint(location) + return cls.model_validate_json(raw, **kwargs) + def _prepare_entities(root: list[Entity]) -> None: """Capture execution context and sync checkpoint fields on each entity.