feat: add flow resumability support (#3312)
Some checks failed
Notify Downstream / notify-downstream (push) Has been cancelled

- Add reload() method to restore flow state from execution data
- Add FlowExecutionData type definitions
- Track completed methods for proper flow resumption
- Support OpenTelemetry baggage context for flow inputs
This commit is contained in:
Greyson LaLonde
2025-08-13 13:45:08 -04:00
committed by GitHub
parent dc6771ae95
commit 8b686fb0c6
3 changed files with 264 additions and 60 deletions

View File

@@ -17,10 +17,13 @@ from typing import (
)
from uuid import uuid4
from opentelemetry import baggage
from opentelemetry.context import attach, detach
from pydantic import BaseModel, Field, ValidationError
from crewai.flow.flow_visualizer import plot_flow
from crewai.flow.persistence.base import FlowPersistence
from crewai.flow.types import FlowExecutionData
from crewai.flow.utils import get_possible_return_constants
from crewai.utilities.events.crewai_event_bus import crewai_event_bus
from crewai.utilities.events.flow_events import (
@@ -467,6 +470,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
self._method_execution_counts: Dict[str, int] = {}
self._pending_and_listeners: Dict[str, Set[str]] = {}
self._method_outputs: List[Any] = [] # List to store all method outputs
self._completed_methods: Set[str] = set() # Track completed methods for reload
self._persistence: Optional[FlowPersistence] = persistence
# Initialize state with initial values
@@ -718,6 +722,73 @@ class Flow(Generic[T], metaclass=FlowMeta):
else:
raise TypeError(f"State must be dict or BaseModel, got {type(self._state)}")
def reload(self, execution_data: FlowExecutionData) -> None:
"""Reloads the flow from an execution data dict.
This method restores the flow's execution ID, completed methods, and state,
allowing it to resume from where it left off.
Args:
execution_data: Flow execution data containing:
- id: Flow execution ID
- flow: Flow structure
- completed_methods: List of successfully completed methods
- execution_methods: All execution methods with their status
"""
flow_id = execution_data.get("id")
if flow_id:
self._update_state_field("id", flow_id)
self._completed_methods = {
name
for method_data in execution_data.get("completed_methods", [])
if (name := method_data.get("flow_method", {}).get("name")) is not None
}
execution_methods = execution_data.get("execution_methods", [])
if not execution_methods:
return
sorted_methods = sorted(
execution_methods,
key=lambda m: m.get("started_at", ""),
)
state_to_apply = None
for method in reversed(sorted_methods):
if method.get("final_state"):
state_to_apply = method["final_state"]
break
if not state_to_apply and sorted_methods:
last_method = sorted_methods[-1]
if last_method.get("initial_state"):
state_to_apply = last_method["initial_state"]
if state_to_apply:
self._apply_state_updates(state_to_apply)
for i, method in enumerate(sorted_methods[:-1]):
method_name = method.get("flow_method", {}).get("name")
if method_name:
self._completed_methods.add(method_name)
def _update_state_field(self, field_name: str, value: Any) -> None:
"""Update a single field in the state."""
if isinstance(self._state, dict):
self._state[field_name] = value
elif hasattr(self._state, field_name):
object.__setattr__(self._state, field_name, value)
def _apply_state_updates(self, updates: Dict[str, Any]) -> None:
"""Apply multiple state updates efficiently."""
if isinstance(self._state, dict):
self._state.update(updates)
elif hasattr(self._state, "__dict__"):
for key, value in updates.items():
if hasattr(self._state, key):
object.__setattr__(self._state, key, value)
def kickoff(self, inputs: Optional[Dict[str, Any]] = None) -> Any:
"""
Start the flow execution in a synchronous context.
@@ -746,68 +817,81 @@ class Flow(Generic[T], metaclass=FlowMeta):
Returns:
The final output from the flow, which is the result of the last executed method.
"""
if inputs:
# Override the id in the state if it exists in inputs
if "id" in inputs:
if isinstance(self._state, dict):
self._state["id"] = inputs["id"]
elif isinstance(self._state, BaseModel):
setattr(self._state, "id", inputs["id"])
ctx = baggage.set_baggage("flow_inputs", inputs or {})
flow_token = attach(ctx)
# If persistence is enabled, attempt to restore the stored state using the provided id.
if "id" in inputs and self._persistence is not None:
restore_uuid = inputs["id"]
stored_state = self._persistence.load_state(restore_uuid)
if stored_state:
self._log_flow_event(
f"Loading flow state from memory for UUID: {restore_uuid}",
color="yellow",
)
self._restore_state(stored_state)
else:
self._log_flow_event(
f"No flow state found for UUID: {restore_uuid}", color="red"
)
try:
# Reset flow state for fresh execution unless restoring from persistence
is_restoring = inputs and "id" in inputs and self._persistence is not None
if not is_restoring:
# Clear completed methods and outputs for a fresh start
self._completed_methods.clear()
self._method_outputs.clear()
# Update state with any additional inputs (ignoring the 'id' key)
filtered_inputs = {k: v for k, v in inputs.items() if k != "id"}
if filtered_inputs:
self._initialize_state(filtered_inputs)
if inputs:
# Override the id in the state if it exists in inputs
if "id" in inputs:
if isinstance(self._state, dict):
self._state["id"] = inputs["id"]
elif isinstance(self._state, BaseModel):
setattr(self._state, "id", inputs["id"])
# Emit FlowStartedEvent and log the start of the flow.
crewai_event_bus.emit(
self,
FlowStartedEvent(
type="flow_started",
flow_name=self.name or self.__class__.__name__,
inputs=inputs,
),
)
self._log_flow_event(
f"Flow started with ID: {self.flow_id}", color="bold_magenta"
)
# If persistence is enabled, attempt to restore the stored state using the provided id.
if "id" in inputs and self._persistence is not None:
restore_uuid = inputs["id"]
stored_state = self._persistence.load_state(restore_uuid)
if stored_state:
self._log_flow_event(
f"Loading flow state from memory for UUID: {restore_uuid}",
color="yellow",
)
self._restore_state(stored_state)
else:
self._log_flow_event(
f"No flow state found for UUID: {restore_uuid}", color="red"
)
if inputs is not None and "id" not in inputs:
self._initialize_state(inputs)
# Update state with any additional inputs (ignoring the 'id' key)
filtered_inputs = {k: v for k, v in inputs.items() if k != "id"}
if filtered_inputs:
self._initialize_state(filtered_inputs)
tasks = [
self._execute_start_method(start_method)
for start_method in self._start_methods
]
await asyncio.gather(*tasks)
# Emit FlowStartedEvent and log the start of the flow.
crewai_event_bus.emit(
self,
FlowStartedEvent(
type="flow_started",
flow_name=self.name or self.__class__.__name__,
inputs=inputs,
),
)
self._log_flow_event(
f"Flow started with ID: {self.flow_id}", color="bold_magenta"
)
final_output = self._method_outputs[-1] if self._method_outputs else None
if inputs is not None and "id" not in inputs:
self._initialize_state(inputs)
crewai_event_bus.emit(
self,
FlowFinishedEvent(
type="flow_finished",
flow_name=self.name or self.__class__.__name__,
result=final_output,
),
)
tasks = [
self._execute_start_method(start_method)
for start_method in self._start_methods
]
await asyncio.gather(*tasks)
return final_output
final_output = self._method_outputs[-1] if self._method_outputs else None
crewai_event_bus.emit(
self,
FlowFinishedEvent(
type="flow_finished",
flow_name=self.name or self.__class__.__name__,
result=final_output,
),
)
return final_output
finally:
detach(flow_token)
async def _execute_start_method(self, start_method_name: str) -> None:
"""
@@ -826,7 +910,13 @@ class Flow(Generic[T], metaclass=FlowMeta):
- Executes the start method and captures its result
- Triggers execution of any listeners waiting on this start method
- Part of the flow's initialization sequence
- Skips execution if method was already completed (e.g., after reload)
"""
if start_method_name in self._completed_methods:
last_output = self._method_outputs[-1] if self._method_outputs else None
await self._execute_listeners(start_method_name, last_output)
return
result = await self._execute_method(
start_method_name, self._methods[start_method_name]
)
@@ -861,6 +951,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
self._method_execution_counts.get(method_name, 0) + 1
)
self._completed_methods.add(method_name)
crewai_event_bus.emit(
self,
MethodExecutionFinishedEvent(
@@ -1023,12 +1114,17 @@ class Flow(Generic[T], metaclass=FlowMeta):
- Handles errors gracefully with detailed logging
- Recursively triggers listeners of this listener
- Supports both parameterized and parameter-less listeners
- Skips execution if method was already completed (e.g., after reload)
Error Handling
-------------
Catches and logs any exceptions during execution, preventing
individual listener failures from breaking the entire flow.
"""
if listener_name in self._completed_methods:
await self._execute_listeners(listener_name, None)
return
try:
method = self._methods[listener_name]
@@ -1047,12 +1143,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
await self._execute_listeners(listener_name, listener_result)
except Exception as e:
print(
f"[Flow._execute_single_listener] Error in method {listener_name}: {e}"
)
import traceback
traceback.print_exc()
logger.error(f"Error executing listener {listener_name}: {e}")
raise
def _log_flow_event(