mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-04-30 23:02:50 +00:00
feat: pass RuntimeState through event bus, add .checkpoint() and .from_checkpoint()
This commit is contained in:
@@ -185,6 +185,30 @@ try:
|
||||
Discriminator(_entity_discriminator),
|
||||
]
|
||||
|
||||
def _sync_checkpoint_fields(entity: object) -> None:
    """Mirror an entity's private runtime attributes onto its ``checkpoint_*`` fields.

    Intended to run right before serialization so that values held in
    private attributes are also present on the public checkpoint fields.

    For a ``Flow``: the completed methods, method outputs, per-method
    execution counts and a serialized copy of the state are copied.
    Empty or falsy collections are stored as ``None`` rather than as empty
    containers, and the state is stored as ``None`` when ``_state`` is
    ``None``.

    For a ``Crew``: the inputs, train flag and kickoff event id are copied
    verbatim.

    Entities that are neither a ``Flow`` nor a ``Crew`` are left untouched.

    Args:
        entity: The runtime entity whose checkpoint fields should be refreshed.
    """
    if isinstance(entity, Flow):
        completed = entity._completed_methods
        entity.checkpoint_completed_methods = set(completed) if completed else None

        outputs = entity._method_outputs
        entity.checkpoint_method_outputs = list(outputs) if outputs else None

        counts = entity._method_execution_counts
        # Keys are stringified so the mapping matches the ``dict[str, int]`` field.
        entity.checkpoint_method_counts = (
            {str(name): count for name, count in counts.items()} if counts else None
        )

        if entity._state is None:
            entity.checkpoint_state = None
        else:
            entity.checkpoint_state = entity._copy_and_serialize_state()

    # Deliberately a second independent check rather than ``elif``,
    # matching the two separate ``isinstance`` tests this function has always used.
    if isinstance(entity, Crew):
        entity.checkpoint_inputs = entity._inputs
        entity.checkpoint_train = entity._train
        entity.checkpoint_kickoff_event_id = entity._kickoff_event_id
|
||||
|
||||
class RuntimeState(RootModel[list[Entity]]):
|
||||
def checkpoint(self, directory: str) -> str:
|
||||
"""Write a checkpoint file to the directory.
|
||||
@@ -203,6 +227,7 @@ try:
|
||||
|
||||
for entity in self.root:
|
||||
entity.execution_context = capture_execution_context()
|
||||
_sync_checkpoint_fields(entity)
|
||||
|
||||
dir_path = _Path(directory)
|
||||
dir_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@@ -27,7 +27,6 @@ from pydantic import (
|
||||
BeforeValidator,
|
||||
ConfigDict,
|
||||
Field,
|
||||
InstanceOf,
|
||||
PrivateAttr,
|
||||
model_validator,
|
||||
)
|
||||
@@ -297,8 +296,8 @@ class Agent(BaseAgent):
|
||||
Can be a single A2AConfig/A2AClientConfig/A2AServerConfig, or a list of any number of A2AConfig/A2AClientConfig with a single A2AServerConfig.
|
||||
""",
|
||||
)
|
||||
agent_executor: InstanceOf[CrewAgentExecutor] | InstanceOf[AgentExecutor] | None = (
|
||||
Field(default=None, description="An instance of the CrewAgentExecutor class.")
|
||||
agent_executor: CrewAgentExecutor | AgentExecutor | None = Field(
|
||||
default=None, description="An instance of the CrewAgentExecutor class."
|
||||
)
|
||||
executor_class: Annotated[
|
||||
type[CrewAgentExecutor] | type[AgentExecutor],
|
||||
@@ -1011,7 +1010,7 @@ class Agent(BaseAgent):
|
||||
)
|
||||
self.agent_executor = self.executor_class(
|
||||
llm=self.llm,
|
||||
task=task, # type: ignore[arg-type]
|
||||
task=task,
|
||||
i18n=self.i18n,
|
||||
agent=self,
|
||||
crew=self.crew, # type: ignore[arg-type]
|
||||
|
||||
@@ -14,7 +14,6 @@ from pydantic import (
|
||||
BaseModel,
|
||||
BeforeValidator,
|
||||
Field,
|
||||
InstanceOf,
|
||||
PrivateAttr,
|
||||
field_validator,
|
||||
model_validator,
|
||||
@@ -197,7 +196,7 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
|
||||
max_iter: int = Field(
|
||||
default=25, description="Maximum iterations for an agent to execute a task"
|
||||
)
|
||||
agent_executor: InstanceOf[CrewAgentExecutorMixin] | None = Field(
|
||||
agent_executor: CrewAgentExecutorMixin | None = Field(
|
||||
default=None, description="An instance of the CrewAgentExecutor class."
|
||||
)
|
||||
llm: Annotated[
|
||||
@@ -276,6 +275,26 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
|
||||
)
|
||||
execution_context: ExecutionContext | None = Field(default=None)
|
||||
|
||||
@classmethod
def from_checkpoint(cls, path: str) -> Self:
    """Load the first entity of this class stored in a checkpoint file.

    The file's text is parsed as a JSON-serialized ``RuntimeState`` with the
    validation context ``{"from_checkpoint": True}``. The restored entities
    are scanned in order and the first one that is an instance of ``cls`` is
    returned. If that entity carries an execution context, the context is
    applied before the entity is handed back.

    Args:
        path: Path to the checkpoint file.

    Returns:
        The first restored entity that is an instance of ``cls``.

    Raises:
        ValueError: If the checkpoint contains no entity of this class.
    """
    # NOTE(review): imports are function-local, presumably to avoid an
    # import cycle with the top-level ``crewai`` package — confirm.
    from pathlib import Path as _Path

    from crewai.context import apply_execution_context

    payload = _Path(path).read_text()
    from crewai import RuntimeState

    restored = RuntimeState.model_validate_json(
        payload, context={"from_checkpoint": True}
    )
    found = next(
        (candidate for candidate in restored.root if isinstance(candidate, cls)),
        None,
    )
    if found is None:
        raise ValueError(f"No {cls.__name__} found in checkpoint: {path}")

    saved_context = found.execution_context
    if saved_context is not None:
        apply_execution_context(saved_context)
    return found
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def process_model_config(cls, values: Any) -> dict[str, Any]:
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
|
||||
from crewai.agents.parser import AgentFinish
|
||||
from crewai.memory.utils import sanitize_scope_name
|
||||
@@ -9,22 +11,44 @@ from crewai.utilities.string_utils import sanitize_tool_name
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.agent import Agent
|
||||
from crewai.crew import Crew
|
||||
from crewai.task import Task
|
||||
from crewai.utilities.i18n import I18N
|
||||
from crewai.utilities.types import LLMMessage
|
||||
pass
|
||||
|
||||
|
||||
class CrewAgentExecutorMixin:
|
||||
crew: Crew | None
|
||||
agent: Agent
|
||||
task: Task | None
|
||||
iterations: int
|
||||
max_iter: int
|
||||
messages: list[LLMMessage]
|
||||
_i18n: I18N
|
||||
_printer: Printer = Printer()
|
||||
class CrewAgentExecutorMixin(BaseModel):
|
||||
model_config = {"arbitrary_types_allowed": True}
|
||||
|
||||
_crew: Any = PrivateAttr(default=None)
|
||||
_agent: Any = PrivateAttr(default=None)
|
||||
_task: Any = PrivateAttr(default=None)
|
||||
iterations: int = Field(default=0)
|
||||
max_iter: int = Field(default=25)
|
||||
messages: list[Any] = Field(default_factory=list)
|
||||
_i18n: Any = PrivateAttr(default=None)
|
||||
_printer: Printer = PrivateAttr(default_factory=Printer)
|
||||
|
||||
@property
|
||||
def crew(self) -> Any:
|
||||
return self._crew
|
||||
|
||||
@crew.setter
|
||||
def crew(self, value: Any) -> None:
|
||||
self._crew = value
|
||||
|
||||
@property
|
||||
def agent(self) -> Any:
|
||||
return self._agent
|
||||
|
||||
@agent.setter
|
||||
def agent(self, value: Any) -> None:
|
||||
self._agent = value
|
||||
|
||||
@property
|
||||
def task(self) -> Any:
|
||||
return self._task
|
||||
|
||||
@task.setter
|
||||
def task(self, value: Any) -> None:
|
||||
self._task = value
|
||||
|
||||
def _save_to_memory(self, output: AgentFinish) -> None:
|
||||
"""Save task result to unified memory (memory or crew._memory).
|
||||
@@ -49,11 +73,9 @@ class CrewAgentExecutorMixin:
|
||||
)
|
||||
extracted = memory.extract_memories(raw)
|
||||
if extracted:
|
||||
# Get the memory's existing root_scope
|
||||
base_root = getattr(memory, "root_scope", None)
|
||||
|
||||
if isinstance(base_root, str) and base_root:
|
||||
# Memory has a root_scope — extend it with agent info
|
||||
agent_role = self.agent.role or "unknown"
|
||||
sanitized_role = sanitize_scope_name(agent_role)
|
||||
agent_root = f"{base_root.rstrip('/')}/agent/{sanitized_role}"
|
||||
@@ -63,7 +85,6 @@ class CrewAgentExecutorMixin:
|
||||
extracted, agent_role=self.agent.role, root_scope=agent_root
|
||||
)
|
||||
else:
|
||||
# No base root_scope — don't inject one, preserve backward compat
|
||||
memory.remember_many(extracted, agent_role=self.agent.role)
|
||||
except Exception as e:
|
||||
self.agent._logger.log("error", f"Failed to save to memory: {e}")
|
||||
|
||||
@@ -14,8 +14,7 @@ import inspect
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
|
||||
from pydantic import BaseModel, GetCoreSchemaHandler, ValidationError
|
||||
from pydantic_core import CoreSchema, core_schema
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
|
||||
from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin
|
||||
from crewai.agents.parser import (
|
||||
@@ -58,7 +57,6 @@ from crewai.utilities.agent_utils import (
|
||||
from crewai.utilities.constants import TRAINING_DATA_FILE
|
||||
from crewai.utilities.file_store import aget_all_files, get_all_files
|
||||
from crewai.utilities.i18n import I18N, get_i18n
|
||||
from crewai.utilities.printer import Printer
|
||||
from crewai.utilities.string_utils import sanitize_tool_name
|
||||
from crewai.utilities.tool_utils import (
|
||||
aexecute_tool_and_check_finality,
|
||||
@@ -89,19 +87,38 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
LLM interactions, tool execution, and feedback handling.
|
||||
"""
|
||||
|
||||
llm: Any = Field(default=None)
|
||||
prompt: Any = Field(default=None)
|
||||
tools: list[Any] = Field(default_factory=list)
|
||||
tools_names: str = Field(default="")
|
||||
stop: list[str] = Field(default_factory=list)
|
||||
tools_description: str = Field(default="")
|
||||
tools_handler: Any = Field(default=None)
|
||||
step_callback: Any = Field(default=None)
|
||||
original_tools: list[Any] = Field(default_factory=list)
|
||||
function_calling_llm: Any = Field(default=None)
|
||||
respect_context_window: bool = Field(default=False)
|
||||
request_within_rpm_limit: Any = Field(default=None)
|
||||
callbacks: list[Any] = Field(default_factory=list)
|
||||
response_model: Any = Field(default=None)
|
||||
ask_for_human_input: bool = Field(default=False)
|
||||
log_error_after: int = Field(default=3)
|
||||
before_llm_call_hooks: list[Any] = Field(default_factory=list)
|
||||
after_llm_call_hooks: list[Any] = Field(default_factory=list)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm: BaseLLM,
|
||||
task: Task,
|
||||
crew: Crew,
|
||||
agent: Agent,
|
||||
prompt: SystemPromptResult | StandardPromptResult,
|
||||
max_iter: int,
|
||||
tools: list[CrewStructuredTool],
|
||||
tools_names: str,
|
||||
stop_words: list[str],
|
||||
tools_description: str,
|
||||
tools_handler: ToolsHandler,
|
||||
llm: BaseLLM | None = None,
|
||||
task: Task | None = None,
|
||||
crew: Crew | None = None,
|
||||
agent: Agent | None = None,
|
||||
prompt: SystemPromptResult | StandardPromptResult | None = None,
|
||||
max_iter: int = 25,
|
||||
tools: list[CrewStructuredTool] | None = None,
|
||||
tools_names: str = "",
|
||||
stop_words: list[str] | None = None,
|
||||
tools_description: str = "",
|
||||
tools_handler: ToolsHandler | None = None,
|
||||
step_callback: Any = None,
|
||||
original_tools: list[BaseTool] | None = None,
|
||||
function_calling_llm: BaseLLM | Any | None = None,
|
||||
@@ -110,59 +127,33 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
callbacks: list[Any] | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
i18n: I18N | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize executor.
|
||||
|
||||
Args:
|
||||
llm: Language model instance.
|
||||
task: Task to execute.
|
||||
crew: Crew instance.
|
||||
agent: Agent to execute.
|
||||
prompt: Prompt templates.
|
||||
max_iter: Maximum iterations.
|
||||
tools: Available tools.
|
||||
tools_names: Tool names string.
|
||||
stop_words: Stop word list.
|
||||
tools_description: Tool descriptions.
|
||||
tools_handler: Tool handler instance.
|
||||
step_callback: Optional step callback.
|
||||
original_tools: Original tool list.
|
||||
function_calling_llm: Optional function calling LLM.
|
||||
respect_context_window: Respect context limits.
|
||||
request_within_rpm_limit: RPM limit check function.
|
||||
callbacks: Optional callbacks list.
|
||||
response_model: Optional Pydantic model for structured outputs.
|
||||
"""
|
||||
self._i18n: I18N = i18n or get_i18n()
|
||||
self.llm = llm
|
||||
self.task = task
|
||||
self.agent = agent
|
||||
super().__init__(
|
||||
llm=llm,
|
||||
prompt=prompt,
|
||||
tools=tools or [],
|
||||
tools_names=tools_names,
|
||||
stop=stop_words or [],
|
||||
max_iter=max_iter,
|
||||
callbacks=callbacks or [],
|
||||
tools_handler=tools_handler,
|
||||
original_tools=original_tools or [],
|
||||
step_callback=step_callback,
|
||||
tools_description=tools_description,
|
||||
function_calling_llm=function_calling_llm,
|
||||
respect_context_window=respect_context_window,
|
||||
request_within_rpm_limit=request_within_rpm_limit,
|
||||
response_model=response_model,
|
||||
**kwargs,
|
||||
)
|
||||
self.crew = crew
|
||||
self.prompt = prompt
|
||||
self.tools = tools
|
||||
self.tools_names = tools_names
|
||||
self.stop = stop_words
|
||||
self.max_iter = max_iter
|
||||
self.callbacks = callbacks or []
|
||||
self._printer: Printer = Printer()
|
||||
self.tools_handler = tools_handler
|
||||
self.original_tools = original_tools or []
|
||||
self.step_callback = step_callback
|
||||
self.tools_description = tools_description
|
||||
self.function_calling_llm = function_calling_llm
|
||||
self.respect_context_window = respect_context_window
|
||||
self.request_within_rpm_limit = request_within_rpm_limit
|
||||
self.response_model = response_model
|
||||
self.ask_for_human_input = False
|
||||
self.messages: list[LLMMessage] = []
|
||||
self.iterations = 0
|
||||
self.log_error_after = 3
|
||||
self.before_llm_call_hooks: list[Callable[..., Any]] = []
|
||||
self.after_llm_call_hooks: list[Callable[..., Any]] = []
|
||||
self.agent = agent
|
||||
self.task = task
|
||||
self._i18n = i18n or get_i18n()
|
||||
self.before_llm_call_hooks.extend(get_before_llm_call_hooks())
|
||||
self.after_llm_call_hooks.extend(get_after_llm_call_hooks())
|
||||
if self.llm:
|
||||
# This may be mutating the shared llm object and needs further evaluation
|
||||
existing_stop = getattr(self.llm, "stop", [])
|
||||
self.llm.stop = list(
|
||||
set(
|
||||
@@ -1687,14 +1678,3 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
return format_message_for_llm(
|
||||
self._i18n.slice("feedback_instructions").format(feedback=feedback)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(
|
||||
cls, _source_type: Any, _handler: GetCoreSchemaHandler
|
||||
) -> CoreSchema:
|
||||
"""Generate Pydantic core schema for BaseClient Protocol.
|
||||
|
||||
This allows the Protocol to be used in Pydantic models without
|
||||
requiring arbitrary_types_allowed=True.
|
||||
"""
|
||||
return core_schema.any_schema()
|
||||
|
||||
@@ -353,6 +353,34 @@ class Crew(FlowTrackable, BaseModel):
|
||||
checkpoint_train: bool | None = Field(default=None)
|
||||
checkpoint_kickoff_event_id: str | None = Field(default=None)
|
||||
|
||||
@classmethod
def from_checkpoint(cls, path: str) -> Crew:
    """Restore a Crew from a checkpoint file.

    The file's text is parsed as a ``RuntimeState`` (which resolves the
    discriminated union of entity types) with the validation context
    ``{"from_checkpoint": True}``. The first restored entity that is an
    instance of ``cls`` is returned; if it carries an execution context,
    that context is applied first.

    Args:
        path: Path to a checkpoint JSON file.

    Returns:
        A Crew instance with state restored from the checkpoint.

    Raises:
        ValueError: If the checkpoint contains no entity of this class.
    """
    from pathlib import Path as _Path

    from crewai.context import apply_execution_context

    json_str = _Path(path).read_text()
    # Parse as RuntimeState to handle the discriminated union of entities.
    from crewai import RuntimeState

    state = RuntimeState.model_validate_json(
        json_str, context={"from_checkpoint": True}
    )
    for entity in state.root:
        if isinstance(entity, cls):
            if entity.execution_context is not None:
                apply_execution_context(entity.execution_context)
            return entity
    # Name the class that was actually searched for: the lookup above is
    # against ``cls``, so a subclass call must not report a missing "Crew".
    # For ``Crew`` itself the message text is unchanged.
    raise ValueError(f"No {cls.__name__} found in checkpoint: {path}")
|
||||
|
||||
@field_validator("id", mode="before")
|
||||
@classmethod
|
||||
def _deny_user_set_id(cls, v: UUID4 | None, info: Any) -> UUID4 | None:
|
||||
|
||||
@@ -106,11 +106,8 @@ from crewai.utilities.types import LLMMessage
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.agent import Agent
|
||||
from crewai.agents.tools_handler import ToolsHandler
|
||||
from crewai.crew import Crew
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.task import Task
|
||||
from crewai.tools.tool_types import ToolResult
|
||||
from crewai.utilities.prompts import StandardPromptResult, SystemPromptResult
|
||||
|
||||
@@ -155,7 +152,7 @@ class AgentExecutorState(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
|
||||
class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin): # type: ignore[pydantic-unexpected]
|
||||
"""Agent Executor for both standalone agents and crew-bound agents.
|
||||
|
||||
_skip_auto_memory prevents Flow from eagerly allocating a Memory
|
||||
@@ -174,7 +171,6 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
|
||||
|
||||
suppress_flow_events: bool = True # always suppress for executor
|
||||
llm: BaseLLM = Field(exclude=True)
|
||||
agent: Agent = Field(exclude=True)
|
||||
prompt: SystemPromptResult | StandardPromptResult = Field(exclude=True)
|
||||
max_iter: int = Field(default=25, exclude=True)
|
||||
tools: list[CrewStructuredTool] = Field(default_factory=list, exclude=True)
|
||||
@@ -182,8 +178,6 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
|
||||
stop_words: list[str] = Field(default_factory=list, exclude=True)
|
||||
tools_description: str = Field(default="", exclude=True)
|
||||
tools_handler: ToolsHandler | None = Field(default=None, exclude=True)
|
||||
task: Task | None = Field(default=None, exclude=True)
|
||||
crew: Crew | None = Field(default=None, exclude=True)
|
||||
step_callback: Any = Field(default=None, exclude=True)
|
||||
original_tools: list[BaseTool] = Field(default_factory=list, exclude=True)
|
||||
function_calling_llm: BaseLLM | None = Field(default=None, exclude=True)
|
||||
@@ -268,20 +262,20 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
|
||||
"""Get thread-safe state proxy."""
|
||||
return StateProxy(self._state, self._state_lock) # type: ignore[return-value]
|
||||
|
||||
@property
|
||||
@property # type: ignore[misc]
|
||||
def iterations(self) -> int:
|
||||
"""Compatibility property for mixin - returns state iterations."""
|
||||
return self._state.iterations # type: ignore[no-any-return]
|
||||
return int(self._state.iterations)
|
||||
|
||||
@iterations.setter
|
||||
def iterations(self, value: int) -> None:
|
||||
"""Set state iterations."""
|
||||
self._state.iterations = value
|
||||
|
||||
@property
|
||||
@property # type: ignore[misc]
|
||||
def messages(self) -> list[LLMMessage]:
|
||||
"""Compatibility property - returns state messages."""
|
||||
return self._state.messages # type: ignore[no-any-return]
|
||||
return list(self._state.messages)
|
||||
|
||||
@messages.setter
|
||||
def messages(self, value: list[LLMMessage]) -> None:
|
||||
@@ -395,28 +389,28 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
|
||||
"""
|
||||
config = self.agent.planning_config
|
||||
if config is not None:
|
||||
return config.reasoning_effort
|
||||
return str(config.reasoning_effort)
|
||||
return "medium"
|
||||
|
||||
def _get_max_replans(self) -> int:
|
||||
"""Get max replans from planning config or default to 3."""
|
||||
config = self.agent.planning_config
|
||||
if config is not None:
|
||||
return config.max_replans
|
||||
return int(config.max_replans)
|
||||
return 3
|
||||
|
||||
def _get_max_step_iterations(self) -> int:
|
||||
"""Get max step iterations from planning config or default to 15."""
|
||||
config = self.agent.planning_config
|
||||
if config is not None:
|
||||
return config.max_step_iterations
|
||||
return int(config.max_step_iterations)
|
||||
return 15
|
||||
|
||||
def _get_step_timeout(self) -> int | None:
|
||||
"""Get per-step timeout from planning config or default to None."""
|
||||
config = self.agent.planning_config
|
||||
if config is not None:
|
||||
return config.step_timeout
|
||||
return int(config.step_timeout) if config.step_timeout is not None else None
|
||||
return None
|
||||
|
||||
def _build_context_for_todo(self, todo: TodoItem) -> StepExecutionContext:
|
||||
|
||||
@@ -919,6 +919,27 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
max_method_calls: int = Field(default=100)
|
||||
|
||||
execution_context: ExecutionContext | None = Field(default=None)
|
||||
|
||||
@classmethod
def from_checkpoint(cls, path: str) -> Flow:  # type: ignore[type-arg]
    """Load the first Flow of this class stored in a checkpoint file.

    The file's text is parsed as a JSON-serialized ``RuntimeState`` with the
    validation context ``{"from_checkpoint": True}``. Restored entities are
    examined in order; the first instance of ``cls`` is returned, after its
    execution context (when present) has been applied.

    Args:
        path: Path to the checkpoint file.

    Returns:
        The first restored entity that is an instance of ``cls``.

    Raises:
        ValueError: If the checkpoint contains no entity of this class.
    """
    from pathlib import Path as _Path

    from crewai.context import apply_execution_context

    serialized = _Path(path).read_text()
    from crewai import RuntimeState

    runtime_state = RuntimeState.model_validate_json(
        serialized, context={"from_checkpoint": True}
    )
    restored_flow = next(
        (item for item in runtime_state.root if isinstance(item, cls)),
        None,
    )
    if restored_flow is None:
        raise ValueError(f"No {cls.__name__} found in checkpoint: {path}")

    context = restored_flow.execution_context
    if context is not None:
        apply_execution_context(context)
    return restored_flow
|
||||
|
||||
checkpoint_completed_methods: set[str] | None = Field(default=None)
|
||||
checkpoint_method_outputs: list[Any] | None = Field(default=None)
|
||||
checkpoint_method_counts: dict[str, int] | None = Field(default=None)
|
||||
|
||||
Reference in New Issue
Block a user