mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 08:38:30 +00:00
feat: add flow resumability support (#3312)
Some checks failed
Notify Downstream / notify-downstream (push) Has been cancelled
Some checks failed
Notify Downstream / notify-downstream (push) Has been cancelled
- Add reload() method to restore flow state from execution data - Add FlowExecutionData type definitions - Track completed methods for proper flow resumption - Support OpenTelemetry baggage context for flow inputs
This commit is contained in:
@@ -17,10 +17,13 @@ from typing import (
|
||||
)
|
||||
from uuid import uuid4
|
||||
|
||||
from opentelemetry import baggage
|
||||
from opentelemetry.context import attach, detach
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
|
||||
from crewai.flow.flow_visualizer import plot_flow
|
||||
from crewai.flow.persistence.base import FlowPersistence
|
||||
from crewai.flow.types import FlowExecutionData
|
||||
from crewai.flow.utils import get_possible_return_constants
|
||||
from crewai.utilities.events.crewai_event_bus import crewai_event_bus
|
||||
from crewai.utilities.events.flow_events import (
|
||||
@@ -467,6 +470,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
self._method_execution_counts: Dict[str, int] = {}
|
||||
self._pending_and_listeners: Dict[str, Set[str]] = {}
|
||||
self._method_outputs: List[Any] = [] # List to store all method outputs
|
||||
self._completed_methods: Set[str] = set() # Track completed methods for reload
|
||||
self._persistence: Optional[FlowPersistence] = persistence
|
||||
|
||||
# Initialize state with initial values
|
||||
@@ -718,6 +722,73 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
else:
|
||||
raise TypeError(f"State must be dict or BaseModel, got {type(self._state)}")
|
||||
|
||||
def reload(self, execution_data: FlowExecutionData) -> None:
|
||||
"""Reloads the flow from an execution data dict.
|
||||
|
||||
This method restores the flow's execution ID, completed methods, and state,
|
||||
allowing it to resume from where it left off.
|
||||
|
||||
Args:
|
||||
execution_data: Flow execution data containing:
|
||||
- id: Flow execution ID
|
||||
- flow: Flow structure
|
||||
- completed_methods: List of successfully completed methods
|
||||
- execution_methods: All execution methods with their status
|
||||
"""
|
||||
flow_id = execution_data.get("id")
|
||||
if flow_id:
|
||||
self._update_state_field("id", flow_id)
|
||||
|
||||
self._completed_methods = {
|
||||
name
|
||||
for method_data in execution_data.get("completed_methods", [])
|
||||
if (name := method_data.get("flow_method", {}).get("name")) is not None
|
||||
}
|
||||
|
||||
execution_methods = execution_data.get("execution_methods", [])
|
||||
if not execution_methods:
|
||||
return
|
||||
|
||||
sorted_methods = sorted(
|
||||
execution_methods,
|
||||
key=lambda m: m.get("started_at", ""),
|
||||
)
|
||||
|
||||
state_to_apply = None
|
||||
for method in reversed(sorted_methods):
|
||||
if method.get("final_state"):
|
||||
state_to_apply = method["final_state"]
|
||||
break
|
||||
|
||||
if not state_to_apply and sorted_methods:
|
||||
last_method = sorted_methods[-1]
|
||||
if last_method.get("initial_state"):
|
||||
state_to_apply = last_method["initial_state"]
|
||||
|
||||
if state_to_apply:
|
||||
self._apply_state_updates(state_to_apply)
|
||||
|
||||
for i, method in enumerate(sorted_methods[:-1]):
|
||||
method_name = method.get("flow_method", {}).get("name")
|
||||
if method_name:
|
||||
self._completed_methods.add(method_name)
|
||||
|
||||
def _update_state_field(self, field_name: str, value: Any) -> None:
|
||||
"""Update a single field in the state."""
|
||||
if isinstance(self._state, dict):
|
||||
self._state[field_name] = value
|
||||
elif hasattr(self._state, field_name):
|
||||
object.__setattr__(self._state, field_name, value)
|
||||
|
||||
def _apply_state_updates(self, updates: Dict[str, Any]) -> None:
|
||||
"""Apply multiple state updates efficiently."""
|
||||
if isinstance(self._state, dict):
|
||||
self._state.update(updates)
|
||||
elif hasattr(self._state, "__dict__"):
|
||||
for key, value in updates.items():
|
||||
if hasattr(self._state, key):
|
||||
object.__setattr__(self._state, key, value)
|
||||
|
||||
def kickoff(self, inputs: Optional[Dict[str, Any]] = None) -> Any:
|
||||
"""
|
||||
Start the flow execution in a synchronous context.
|
||||
@@ -746,6 +817,17 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
Returns:
|
||||
The final output from the flow, which is the result of the last executed method.
|
||||
"""
|
||||
ctx = baggage.set_baggage("flow_inputs", inputs or {})
|
||||
flow_token = attach(ctx)
|
||||
|
||||
try:
|
||||
# Reset flow state for fresh execution unless restoring from persistence
|
||||
is_restoring = inputs and "id" in inputs and self._persistence is not None
|
||||
if not is_restoring:
|
||||
# Clear completed methods and outputs for a fresh start
|
||||
self._completed_methods.clear()
|
||||
self._method_outputs.clear()
|
||||
|
||||
if inputs:
|
||||
# Override the id in the state if it exists in inputs
|
||||
if "id" in inputs:
|
||||
@@ -808,6 +890,8 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
)
|
||||
|
||||
return final_output
|
||||
finally:
|
||||
detach(flow_token)
|
||||
|
||||
async def _execute_start_method(self, start_method_name: str) -> None:
|
||||
"""
|
||||
@@ -826,7 +910,13 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
- Executes the start method and captures its result
|
||||
- Triggers execution of any listeners waiting on this start method
|
||||
- Part of the flow's initialization sequence
|
||||
- Skips execution if method was already completed (e.g., after reload)
|
||||
"""
|
||||
if start_method_name in self._completed_methods:
|
||||
last_output = self._method_outputs[-1] if self._method_outputs else None
|
||||
await self._execute_listeners(start_method_name, last_output)
|
||||
return
|
||||
|
||||
result = await self._execute_method(
|
||||
start_method_name, self._methods[start_method_name]
|
||||
)
|
||||
@@ -861,6 +951,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
self._method_execution_counts.get(method_name, 0) + 1
|
||||
)
|
||||
|
||||
self._completed_methods.add(method_name)
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
MethodExecutionFinishedEvent(
|
||||
@@ -1023,12 +1114,17 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
- Handles errors gracefully with detailed logging
|
||||
- Recursively triggers listeners of this listener
|
||||
- Supports both parameterized and parameter-less listeners
|
||||
- Skips execution if method was already completed (e.g., after reload)
|
||||
|
||||
Error Handling
|
||||
-------------
|
||||
Catches and logs any exceptions during execution, preventing
|
||||
individual listener failures from breaking the entire flow.
|
||||
"""
|
||||
if listener_name in self._completed_methods:
|
||||
await self._execute_listeners(listener_name, None)
|
||||
return
|
||||
|
||||
try:
|
||||
method = self._methods[listener_name]
|
||||
|
||||
@@ -1047,12 +1143,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
await self._execute_listeners(listener_name, listener_result)
|
||||
|
||||
except Exception as e:
|
||||
print(
|
||||
f"[Flow._execute_single_listener] Error in method {listener_name}: {e}"
|
||||
)
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
logger.error(f"Error executing listener {listener_name}: {e}")
|
||||
raise
|
||||
|
||||
def _log_flow_event(
|
||||
|
||||
95
src/crewai/flow/types.py
Normal file
95
src/crewai/flow/types.py
Normal file
@@ -0,0 +1,95 @@
|
||||
"""Type definitions for CrewAI Flow module.
|
||||
|
||||
This module contains TypedDict definitions and type aliases used throughout
|
||||
the Flow system.
|
||||
"""
|
||||
|
||||
from typing import Any, TypedDict
|
||||
from typing_extensions import NotRequired, Required
|
||||
|
||||
|
||||
class FlowMethodData(TypedDict):
|
||||
"""Flow method information.
|
||||
|
||||
Attributes:
|
||||
name: The name of the flow method.
|
||||
starting_point: Whether this method is a starting point for the flow.
|
||||
"""
|
||||
|
||||
name: str
|
||||
starting_point: NotRequired[bool]
|
||||
|
||||
|
||||
class CompletedMethodData(TypedDict):
|
||||
"""Completed method information.
|
||||
|
||||
Represents a flow method that has been successfully executed.
|
||||
|
||||
Attributes:
|
||||
flow_method: The flow method information.
|
||||
status: The completion status of the method.
|
||||
"""
|
||||
|
||||
flow_method: FlowMethodData
|
||||
status: str
|
||||
|
||||
|
||||
class ExecutionMethodData(TypedDict, total=False):
|
||||
"""Execution method information.
|
||||
|
||||
Contains detailed information about a method's execution, including
|
||||
timing, state, and any error details.
|
||||
|
||||
Attributes:
|
||||
flow_method: The flow method information.
|
||||
started_at: ISO timestamp when the method started execution.
|
||||
finished_at: ISO timestamp when the method finished execution, if completed.
|
||||
status: Current status of the method execution.
|
||||
initial_state: The state before method execution.
|
||||
final_state: The state after method execution.
|
||||
error_details: Details about any error that occurred during execution.
|
||||
"""
|
||||
|
||||
flow_method: Required[FlowMethodData]
|
||||
started_at: Required[str]
|
||||
status: Required[str]
|
||||
finished_at: str
|
||||
initial_state: dict[str, Any]
|
||||
final_state: dict[str, Any]
|
||||
error_details: dict[str, Any]
|
||||
|
||||
|
||||
class FlowData(TypedDict):
|
||||
"""Flow structure information.
|
||||
|
||||
Contains metadata about the flow structure and its methods.
|
||||
|
||||
Attributes:
|
||||
name: The name of the flow.
|
||||
flow_methods_attributes: List of all flow methods and their attributes.
|
||||
"""
|
||||
|
||||
name: str
|
||||
flow_methods_attributes: list[FlowMethodData]
|
||||
|
||||
|
||||
class FlowExecutionData(TypedDict):
|
||||
"""Flow execution data.
|
||||
|
||||
Complete execution data for a flow, including its current state,
|
||||
completed methods, and execution history. Used for resuming flows
|
||||
from a previous state.
|
||||
|
||||
Attributes:
|
||||
id: Unique identifier for the flow execution.
|
||||
flow: Flow structure and metadata.
|
||||
inputs: Input data provided to the flow.
|
||||
completed_methods: List of methods that have been successfully completed.
|
||||
execution_methods: Detailed execution history for all methods.
|
||||
"""
|
||||
|
||||
id: str
|
||||
flow: FlowData
|
||||
inputs: dict[str, Any]
|
||||
completed_methods: list[CompletedMethodData]
|
||||
execution_methods: list[ExecutionMethodData]
|
||||
18
src/crewai/types/hitl.py
Normal file
18
src/crewai/types/hitl.py
Normal file
@@ -0,0 +1,18 @@
|
||||
from typing import List, Dict, TypedDict
|
||||
|
||||
|
||||
class HITLResumeInfo(TypedDict, total=False):
|
||||
"""HITL resume information passed from flow to crew."""
|
||||
|
||||
task_id: str
|
||||
crew_execution_id: str
|
||||
task_key: str
|
||||
task_output: str
|
||||
human_feedback: str
|
||||
previous_messages: List[Dict[str, str]]
|
||||
|
||||
|
||||
class CrewInputsWithHITL(TypedDict, total=False):
|
||||
"""Crew inputs that may contain HITL resume information."""
|
||||
|
||||
_hitl_resume: HITLResumeInfo
|
||||
Reference in New Issue
Block a user