feat: pass RuntimeState through event bus, add .checkpoint() and .from_checkpoint()

This commit is contained in:
Greyson LaLonde
2026-04-03 05:33:15 +08:00
parent 6627845372
commit cf241d85e8
8 changed files with 199 additions and 112 deletions

View File

@@ -185,6 +185,30 @@ try:
Discriminator(_entity_discriminator),
]
def _sync_checkpoint_fields(entity: object) -> None:
    """Copy private runtime attrs into checkpoint fields before serializing."""
    if isinstance(entity, Flow):
        completed = entity._completed_methods
        entity.checkpoint_completed_methods = set(completed) if completed else None

        outputs = entity._method_outputs
        entity.checkpoint_method_outputs = list(outputs) if outputs else None

        counts = entity._method_execution_counts
        if counts:
            # Keys are stringified so the mapping round-trips through JSON.
            entity.checkpoint_method_counts = {str(name): n for name, n in counts.items()}
        else:
            entity.checkpoint_method_counts = None

        if entity._state is not None:
            entity.checkpoint_state = entity._copy_and_serialize_state()
        else:
            entity.checkpoint_state = None
    if isinstance(entity, Crew):
        entity.checkpoint_inputs = entity._inputs
        entity.checkpoint_train = entity._train
        entity.checkpoint_kickoff_event_id = entity._kickoff_event_id
class RuntimeState(RootModel[list[Entity]]):
def checkpoint(self, directory: str) -> str:
"""Write a checkpoint file to the directory.
@@ -203,6 +227,7 @@ try:
for entity in self.root:
entity.execution_context = capture_execution_context()
_sync_checkpoint_fields(entity)
dir_path = _Path(directory)
dir_path.mkdir(parents=True, exist_ok=True)

View File

@@ -27,7 +27,6 @@ from pydantic import (
BeforeValidator,
ConfigDict,
Field,
InstanceOf,
PrivateAttr,
model_validator,
)
@@ -297,8 +296,8 @@ class Agent(BaseAgent):
Can be a single A2AConfig/A2AClientConfig/A2AServerConfig, or a list of any number of A2AConfig/A2AClientConfig with a single A2AServerConfig.
""",
)
agent_executor: InstanceOf[CrewAgentExecutor] | InstanceOf[AgentExecutor] | None = (
Field(default=None, description="An instance of the CrewAgentExecutor class.")
agent_executor: CrewAgentExecutor | AgentExecutor | None = Field(
default=None, description="An instance of the CrewAgentExecutor class."
)
executor_class: Annotated[
type[CrewAgentExecutor] | type[AgentExecutor],
@@ -1011,7 +1010,7 @@ class Agent(BaseAgent):
)
self.agent_executor = self.executor_class(
llm=self.llm,
task=task, # type: ignore[arg-type]
task=task,
i18n=self.i18n,
agent=self,
crew=self.crew, # type: ignore[arg-type]

View File

@@ -14,7 +14,6 @@ from pydantic import (
BaseModel,
BeforeValidator,
Field,
InstanceOf,
PrivateAttr,
field_validator,
model_validator,
@@ -197,7 +196,7 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
max_iter: int = Field(
default=25, description="Maximum iterations for an agent to execute a task"
)
agent_executor: InstanceOf[CrewAgentExecutorMixin] | None = Field(
agent_executor: CrewAgentExecutorMixin | None = Field(
default=None, description="An instance of the CrewAgentExecutor class."
)
llm: Annotated[
@@ -276,6 +275,26 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
)
execution_context: ExecutionContext | None = Field(default=None)
@classmethod
def from_checkpoint(cls, path: str) -> Self:
    """Restore an Agent from a checkpoint file."""
    from pathlib import Path as _Path

    from crewai import RuntimeState
    from crewai.context import apply_execution_context

    payload = _Path(path).read_text()
    state = RuntimeState.model_validate_json(
        payload, context={"from_checkpoint": True}
    )
    # Take the first entity in the checkpoint that is (a subclass of) cls.
    match = next((e for e in state.root if isinstance(e, cls)), None)
    if match is None:
        raise ValueError(f"No {cls.__name__} found in checkpoint: {path}")
    if match.execution_context is not None:
        apply_execution_context(match.execution_context)
    return match
@model_validator(mode="before")
@classmethod
def process_model_config(cls, values: Any) -> dict[str, Any]:

View File

@@ -1,6 +1,8 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any
from pydantic import BaseModel, Field, PrivateAttr
from crewai.agents.parser import AgentFinish
from crewai.memory.utils import sanitize_scope_name
@@ -9,22 +11,44 @@ from crewai.utilities.string_utils import sanitize_tool_name
if TYPE_CHECKING:
from crewai.agent import Agent
from crewai.crew import Crew
from crewai.task import Task
from crewai.utilities.i18n import I18N
from crewai.utilities.types import LLMMessage
pass
class CrewAgentExecutorMixin:
crew: Crew | None
agent: Agent
task: Task | None
iterations: int
max_iter: int
messages: list[LLMMessage]
_i18n: I18N
_printer: Printer = Printer()
class CrewAgentExecutorMixin(BaseModel):
model_config = {"arbitrary_types_allowed": True}
_crew: Any = PrivateAttr(default=None)
_agent: Any = PrivateAttr(default=None)
_task: Any = PrivateAttr(default=None)
iterations: int = Field(default=0)
max_iter: int = Field(default=25)
messages: list[Any] = Field(default_factory=list)
_i18n: Any = PrivateAttr(default=None)
_printer: Printer = PrivateAttr(default_factory=Printer)
@property
def crew(self) -> Any:
    """Compatibility accessor: return the crew stored in the ``_crew`` PrivateAttr."""
    return self._crew

@crew.setter
def crew(self, value: Any) -> None:
    """Store the crew on the private attr (PrivateAttrs bypass pydantic field validation)."""
    self._crew = value
@property
def agent(self) -> Any:
    """Compatibility accessor: return the agent stored in the ``_agent`` PrivateAttr."""
    return self._agent

@agent.setter
def agent(self, value: Any) -> None:
    """Store the agent on the private attr (PrivateAttrs bypass pydantic field validation)."""
    self._agent = value
@property
def task(self) -> Any:
    """Compatibility accessor: return the task stored in the ``_task`` PrivateAttr."""
    return self._task

@task.setter
def task(self, value: Any) -> None:
    """Store the task on the private attr (PrivateAttrs bypass pydantic field validation)."""
    self._task = value
def _save_to_memory(self, output: AgentFinish) -> None:
"""Save task result to unified memory (memory or crew._memory).
@@ -49,11 +73,9 @@ class CrewAgentExecutorMixin:
)
extracted = memory.extract_memories(raw)
if extracted:
# Get the memory's existing root_scope
base_root = getattr(memory, "root_scope", None)
if isinstance(base_root, str) and base_root:
# Memory has a root_scope — extend it with agent info
agent_role = self.agent.role or "unknown"
sanitized_role = sanitize_scope_name(agent_role)
agent_root = f"{base_root.rstrip('/')}/agent/{sanitized_role}"
@@ -63,7 +85,6 @@ class CrewAgentExecutorMixin:
extracted, agent_role=self.agent.role, root_scope=agent_root
)
else:
# No base root_scope — don't inject one, preserve backward compat
memory.remember_many(extracted, agent_role=self.agent.role)
except Exception as e:
self.agent._logger.log("error", f"Failed to save to memory: {e}")

View File

@@ -14,8 +14,7 @@ import inspect
import logging
from typing import TYPE_CHECKING, Any, Literal, cast
from pydantic import BaseModel, GetCoreSchemaHandler, ValidationError
from pydantic_core import CoreSchema, core_schema
from pydantic import BaseModel, Field, ValidationError
from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin
from crewai.agents.parser import (
@@ -58,7 +57,6 @@ from crewai.utilities.agent_utils import (
from crewai.utilities.constants import TRAINING_DATA_FILE
from crewai.utilities.file_store import aget_all_files, get_all_files
from crewai.utilities.i18n import I18N, get_i18n
from crewai.utilities.printer import Printer
from crewai.utilities.string_utils import sanitize_tool_name
from crewai.utilities.tool_utils import (
aexecute_tool_and_check_finality,
@@ -89,19 +87,38 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
LLM interactions, tool execution, and feedback handling.
"""
llm: Any = Field(default=None)
prompt: Any = Field(default=None)
tools: list[Any] = Field(default_factory=list)
tools_names: str = Field(default="")
stop: list[str] = Field(default_factory=list)
tools_description: str = Field(default="")
tools_handler: Any = Field(default=None)
step_callback: Any = Field(default=None)
original_tools: list[Any] = Field(default_factory=list)
function_calling_llm: Any = Field(default=None)
respect_context_window: bool = Field(default=False)
request_within_rpm_limit: Any = Field(default=None)
callbacks: list[Any] = Field(default_factory=list)
response_model: Any = Field(default=None)
ask_for_human_input: bool = Field(default=False)
log_error_after: int = Field(default=3)
before_llm_call_hooks: list[Any] = Field(default_factory=list)
after_llm_call_hooks: list[Any] = Field(default_factory=list)
def __init__(
self,
llm: BaseLLM,
task: Task,
crew: Crew,
agent: Agent,
prompt: SystemPromptResult | StandardPromptResult,
max_iter: int,
tools: list[CrewStructuredTool],
tools_names: str,
stop_words: list[str],
tools_description: str,
tools_handler: ToolsHandler,
llm: BaseLLM | None = None,
task: Task | None = None,
crew: Crew | None = None,
agent: Agent | None = None,
prompt: SystemPromptResult | StandardPromptResult | None = None,
max_iter: int = 25,
tools: list[CrewStructuredTool] | None = None,
tools_names: str = "",
stop_words: list[str] | None = None,
tools_description: str = "",
tools_handler: ToolsHandler | None = None,
step_callback: Any = None,
original_tools: list[BaseTool] | None = None,
function_calling_llm: BaseLLM | Any | None = None,
@@ -110,59 +127,33 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
callbacks: list[Any] | None = None,
response_model: type[BaseModel] | None = None,
i18n: I18N | None = None,
**kwargs: Any,
) -> None:
"""Initialize executor.
Args:
llm: Language model instance.
task: Task to execute.
crew: Crew instance.
agent: Agent to execute.
prompt: Prompt templates.
max_iter: Maximum iterations.
tools: Available tools.
tools_names: Tool names string.
stop_words: Stop word list.
tools_description: Tool descriptions.
tools_handler: Tool handler instance.
step_callback: Optional step callback.
original_tools: Original tool list.
function_calling_llm: Optional function calling LLM.
respect_context_window: Respect context limits.
request_within_rpm_limit: RPM limit check function.
callbacks: Optional callbacks list.
response_model: Optional Pydantic model for structured outputs.
"""
self._i18n: I18N = i18n or get_i18n()
self.llm = llm
self.task = task
self.agent = agent
super().__init__(
llm=llm,
prompt=prompt,
tools=tools or [],
tools_names=tools_names,
stop=stop_words or [],
max_iter=max_iter,
callbacks=callbacks or [],
tools_handler=tools_handler,
original_tools=original_tools or [],
step_callback=step_callback,
tools_description=tools_description,
function_calling_llm=function_calling_llm,
respect_context_window=respect_context_window,
request_within_rpm_limit=request_within_rpm_limit,
response_model=response_model,
**kwargs,
)
self.crew = crew
self.prompt = prompt
self.tools = tools
self.tools_names = tools_names
self.stop = stop_words
self.max_iter = max_iter
self.callbacks = callbacks or []
self._printer: Printer = Printer()
self.tools_handler = tools_handler
self.original_tools = original_tools or []
self.step_callback = step_callback
self.tools_description = tools_description
self.function_calling_llm = function_calling_llm
self.respect_context_window = respect_context_window
self.request_within_rpm_limit = request_within_rpm_limit
self.response_model = response_model
self.ask_for_human_input = False
self.messages: list[LLMMessage] = []
self.iterations = 0
self.log_error_after = 3
self.before_llm_call_hooks: list[Callable[..., Any]] = []
self.after_llm_call_hooks: list[Callable[..., Any]] = []
self.agent = agent
self.task = task
self._i18n = i18n or get_i18n()
self.before_llm_call_hooks.extend(get_before_llm_call_hooks())
self.after_llm_call_hooks.extend(get_after_llm_call_hooks())
if self.llm:
# This may be mutating the shared llm object and needs further evaluation
existing_stop = getattr(self.llm, "stop", [])
self.llm.stop = list(
set(
@@ -1687,14 +1678,3 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
return format_message_for_llm(
self._i18n.slice("feedback_instructions").format(feedback=feedback)
)
@classmethod
def __get_pydantic_core_schema__(
cls, _source_type: Any, _handler: GetCoreSchemaHandler
) -> CoreSchema:
"""Generate Pydantic core schema for BaseClient Protocol.
This allows the Protocol to be used in Pydantic models without
requiring arbitrary_types_allowed=True.
"""
return core_schema.any_schema()

View File

@@ -353,6 +353,34 @@ class Crew(FlowTrackable, BaseModel):
checkpoint_train: bool | None = Field(default=None)
checkpoint_kickoff_event_id: str | None = Field(default=None)
@classmethod
def from_checkpoint(cls, path: str) -> Crew:
    """Restore a Crew from a checkpoint file.

    Args:
        path: Path to a checkpoint JSON file.

    Returns:
        A Crew instance with state restored from the checkpoint.

    Raises:
        ValueError: If the checkpoint contains no entity of this class.
    """
    from pathlib import Path as _Path

    from crewai.context import apply_execution_context

    json_str = _Path(path).read_text()
    # Parse as RuntimeState to handle discriminated union
    from crewai import RuntimeState

    state = RuntimeState.model_validate_json(
        json_str, context={"from_checkpoint": True}
    )
    for entity in state.root:
        if isinstance(entity, cls):
            if entity.execution_context is not None:
                apply_execution_context(entity.execution_context)
            return entity
    # Use cls.__name__ (not a hardcoded "Crew") so subclasses report their
    # own name; this also matches the sibling from_checkpoint implementations.
    raise ValueError(f"No {cls.__name__} found in checkpoint: {path}")
@field_validator("id", mode="before")
@classmethod
def _deny_user_set_id(cls, v: UUID4 | None, info: Any) -> UUID4 | None:

View File

@@ -106,11 +106,8 @@ from crewai.utilities.types import LLMMessage
if TYPE_CHECKING:
from crewai.agent import Agent
from crewai.agents.tools_handler import ToolsHandler
from crewai.crew import Crew
from crewai.llms.base_llm import BaseLLM
from crewai.task import Task
from crewai.tools.tool_types import ToolResult
from crewai.utilities.prompts import StandardPromptResult, SystemPromptResult
@@ -155,7 +152,7 @@ class AgentExecutorState(BaseModel):
)
class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin): # type: ignore[pydantic-unexpected]
"""Agent Executor for both standalone agents and crew-bound agents.
_skip_auto_memory prevents Flow from eagerly allocating a Memory
@@ -174,7 +171,6 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
suppress_flow_events: bool = True # always suppress for executor
llm: BaseLLM = Field(exclude=True)
agent: Agent = Field(exclude=True)
prompt: SystemPromptResult | StandardPromptResult = Field(exclude=True)
max_iter: int = Field(default=25, exclude=True)
tools: list[CrewStructuredTool] = Field(default_factory=list, exclude=True)
@@ -182,8 +178,6 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
stop_words: list[str] = Field(default_factory=list, exclude=True)
tools_description: str = Field(default="", exclude=True)
tools_handler: ToolsHandler | None = Field(default=None, exclude=True)
task: Task | None = Field(default=None, exclude=True)
crew: Crew | None = Field(default=None, exclude=True)
step_callback: Any = Field(default=None, exclude=True)
original_tools: list[BaseTool] = Field(default_factory=list, exclude=True)
function_calling_llm: BaseLLM | None = Field(default=None, exclude=True)
@@ -268,20 +262,20 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
"""Get thread-safe state proxy."""
return StateProxy(self._state, self._state_lock) # type: ignore[return-value]
@property  # type: ignore[misc]
def iterations(self) -> int:
    """Compatibility property for mixin - returns state iterations."""
    # Defect fixed: the span contained a duplicated decorator and an
    # unreachable duplicate return (stale old/new diff lines). Keep the
    # single, type-coerced form so the declared int return always holds.
    return int(self._state.iterations)

@iterations.setter
def iterations(self, value: int) -> None:
    """Set state iterations."""
    self._state.iterations = value
@property
@property # type: ignore[misc]
def messages(self) -> list[LLMMessage]:
"""Compatibility property - returns state messages."""
return self._state.messages # type: ignore[no-any-return]
return list(self._state.messages)
@messages.setter
def messages(self, value: list[LLMMessage]) -> None:
@@ -395,28 +389,28 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
"""
config = self.agent.planning_config
if config is not None:
return config.reasoning_effort
return str(config.reasoning_effort)
return "medium"
def _get_max_replans(self) -> int:
    """Get max replans from planning config or default to 3.

    Returns:
        The configured ``max_replans`` coerced to int, or 3 when no
        planning config is set.
    """
    # Defect fixed: the span contained an unreachable duplicate
    # `return config.max_replans` (stale old diff line); keep only the
    # coerced return so the declared int type always holds.
    config = self.agent.planning_config
    if config is not None:
        return int(config.max_replans)
    return 3
def _get_max_step_iterations(self) -> int:
    """Get max step iterations from planning config or default to 15.

    Returns:
        The configured ``max_step_iterations`` coerced to int, or 15 when
        no planning config is set.
    """
    # Defect fixed: the span contained an unreachable duplicate
    # `return config.max_step_iterations` (stale old diff line); keep only
    # the coerced return so the declared int type always holds.
    config = self.agent.planning_config
    if config is not None:
        return int(config.max_step_iterations)
    return 15
def _get_step_timeout(self) -> int | None:
    """Get per-step timeout from planning config or default to None.

    Returns:
        The configured ``step_timeout`` coerced to int, or None when the
        config is absent or the timeout is unset.
    """
    # Defect fixed: the span contained an unreachable duplicate
    # `return config.step_timeout` (stale old diff line); keep only the
    # None-safe coerced return.
    config = self.agent.planning_config
    if config is not None:
        return int(config.step_timeout) if config.step_timeout is not None else None
    return None
def _build_context_for_todo(self, todo: TodoItem) -> StepExecutionContext:

View File

@@ -919,6 +919,27 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
max_method_calls: int = Field(default=100)
execution_context: ExecutionContext | None = Field(default=None)
@classmethod
def from_checkpoint(cls, path: str) -> Flow:  # type: ignore[type-arg]
    """Restore a Flow from a checkpoint file."""
    from pathlib import Path as _Path

    from crewai import RuntimeState
    from crewai.context import apply_execution_context

    raw = _Path(path).read_text()
    runtime = RuntimeState.model_validate_json(
        raw, context={"from_checkpoint": True}
    )
    # First checkpointed entity that is (a subclass of) this Flow class wins.
    restored = next((e for e in runtime.root if isinstance(e, cls)), None)
    if restored is None:
        raise ValueError(f"No {cls.__name__} found in checkpoint: {path}")
    if restored.execution_context is not None:
        apply_execution_context(restored.execution_context)
    return restored
checkpoint_completed_methods: set[str] | None = Field(default=None)
checkpoint_method_outputs: list[Any] | None = Field(default=None)
checkpoint_method_counts: dict[str, int] | None = Field(default=None)