fix: flow listener resumability for HITL and cyclic flows (#3322)

* fix: flow listener resumability for HITL and cyclic flows

- Add resumption context flag to distinguish HITL resumption from cyclic execution
- Skip method re-execution only during HITL resumption, not for cyclic flows
- Ensure cyclic flows like test_cyclic_flow continue to work correctly

* fix: prevent duplicate execution of conditional start methods in flows

* fix: resolve type error in flow.py line 1040 assignment
This commit is contained in:
Greyson LaLonde
2025-08-20 10:06:18 -04:00
committed by GitHub
parent ed187b495b
commit c0d2bf4c12
2 changed files with 225 additions and 9 deletions

View File

@@ -474,6 +474,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
self._method_outputs: List[Any] = [] # List to store all method outputs self._method_outputs: List[Any] = [] # List to store all method outputs
self._completed_methods: Set[str] = set() # Track completed methods for reload self._completed_methods: Set[str] = set() # Track completed methods for reload
self._persistence: Optional[FlowPersistence] = persistence self._persistence: Optional[FlowPersistence] = persistence
self._is_execution_resuming: bool = False
# Initialize state with initial values # Initialize state with initial values
self._state = self._create_initial_state() self._state = self._create_initial_state()
@@ -829,6 +830,9 @@ class Flow(Generic[T], metaclass=FlowMeta):
# Clear completed methods and outputs for a fresh start # Clear completed methods and outputs for a fresh start
self._completed_methods.clear() self._completed_methods.clear()
self._method_outputs.clear() self._method_outputs.clear()
else:
# We're restoring from persistence, set the flag
self._is_execution_resuming = True
if inputs: if inputs:
# Override the id in the state if it exists in inputs # Override the id in the state if it exists in inputs
@@ -880,6 +884,9 @@ class Flow(Generic[T], metaclass=FlowMeta):
] ]
await asyncio.gather(*tasks) await asyncio.gather(*tasks)
# Clear the resumption flag after initial execution completes
self._is_execution_resuming = False
final_output = self._method_outputs[-1] if self._method_outputs else None final_output = self._method_outputs[-1] if self._method_outputs else None
crewai_event_bus.emit( crewai_event_bus.emit(
@@ -916,9 +923,13 @@ class Flow(Generic[T], metaclass=FlowMeta):
- Automatically injects crewai_trigger_payload if available in flow inputs - Automatically injects crewai_trigger_payload if available in flow inputs
""" """
if start_method_name in self._completed_methods: if start_method_name in self._completed_methods:
last_output = self._method_outputs[-1] if self._method_outputs else None if self._is_execution_resuming:
await self._execute_listeners(start_method_name, last_output) # During resumption, skip execution but continue listeners
return last_output = self._method_outputs[-1] if self._method_outputs else None
await self._execute_listeners(start_method_name, last_output)
return
# For cyclic flows, clear from completed to allow re-execution
self._completed_methods.discard(start_method_name)
method = self._methods[start_method_name] method = self._methods[start_method_name]
enhanced_method = self._inject_trigger_payload_for_start_method(method) enhanced_method = self._inject_trigger_payload_for_start_method(method)
@@ -1050,11 +1061,15 @@ class Flow(Generic[T], metaclass=FlowMeta):
for router_name in routers_triggered: for router_name in routers_triggered:
await self._execute_single_listener(router_name, result) await self._execute_single_listener(router_name, result)
# After executing router, the router's result is the path # After executing router, the router's result is the path
router_result = self._method_outputs[-1] router_result = (
self._method_outputs[-1] if self._method_outputs else None
)
if router_result: # Only add non-None results if router_result: # Only add non-None results
router_results.append(router_result) router_results.append(router_result)
current_trigger = ( current_trigger = (
router_result # Update for next iteration of router chain str(router_result)
if router_result is not None
else "" # Update for next iteration of router chain
) )
# Now execute normal listeners for all router results and the original trigger # Now execute normal listeners for all router results and the original trigger
@@ -1072,6 +1087,24 @@ class Flow(Generic[T], metaclass=FlowMeta):
] ]
await asyncio.gather(*tasks) await asyncio.gather(*tasks)
if current_trigger in router_results:
# Find start methods triggered by this router result
for method_name in self._start_methods:
# Check if this start method is triggered by the current trigger
if method_name in self._listeners:
condition_type, trigger_methods = self._listeners[
method_name
]
if current_trigger in trigger_methods:
# Only execute if this is a cycle (method was already completed)
if method_name in self._completed_methods:
# For router-triggered start methods in cycles, temporarily clear resumption flag
# to allow cyclic execution
was_resuming = self._is_execution_resuming
self._is_execution_resuming = False
await self._execute_start_method(method_name)
self._is_execution_resuming = was_resuming
def _find_triggered_methods( def _find_triggered_methods(
self, trigger_method: str, router_only: bool self, trigger_method: str, router_only: bool
) -> List[str]: ) -> List[str]:
@@ -1109,6 +1142,9 @@ class Flow(Generic[T], metaclass=FlowMeta):
if router_only != is_router: if router_only != is_router:
continue continue
if not router_only and listener_name in self._start_methods:
continue
if condition_type == "OR": if condition_type == "OR":
# If the trigger_method matches any in methods, run this # If the trigger_method matches any in methods, run this
if trigger_method in methods: if trigger_method in methods:
@@ -1158,10 +1194,13 @@ class Flow(Generic[T], metaclass=FlowMeta):
Catches and logs any exceptions during execution, preventing Catches and logs any exceptions during execution, preventing
individual listener failures from breaking the entire flow. individual listener failures from breaking the entire flow.
""" """
# TODO: greyson fix if listener_name in self._completed_methods:
# if listener_name in self._completed_methods: if self._is_execution_resuming:
# await self._execute_listeners(listener_name, None) # During resumption, skip execution but continue listeners
# return await self._execute_listeners(listener_name, None)
return
# For cyclic flows, clear from completed to allow re-execution
self._completed_methods.discard(listener_name)
try: try:
method = self._methods[listener_name] method = self._methods[listener_name]

View File

@@ -0,0 +1,177 @@
"""Regression tests for flow listener resumability fix.
These tests ensure that:
1. HITL flows can resume properly without re-executing completed methods
2. Cyclic flows can re-execute methods on each iteration
"""
from typing import Dict
from crewai.flow.flow import Flow, listen, router, start
from crewai.flow.persistence.sqlite import SQLiteFlowPersistence
def test_hitl_resumption_skips_completed_listeners(tmp_path):
"""Test that HITL resumption skips completed listener methods but continues chains."""
db_path = tmp_path / "test_flows.db"
persistence = SQLiteFlowPersistence(str(db_path))
execution_log = []
class HitlFlow(Flow[Dict[str, str]]):
@start()
def step_1(self):
execution_log.append("step_1_executed")
self.state["step1"] = "done"
return "step1_result"
@listen(step_1)
def step_2(self):
execution_log.append("step_2_executed")
self.state["step2"] = "done"
return "step2_result"
@listen(step_2)
def step_3(self):
execution_log.append("step_3_executed")
self.state["step3"] = "done"
return "step3_result"
flow1 = HitlFlow(persistence=persistence)
flow1.kickoff()
flow_id = flow1.state["id"]
assert execution_log == ["step_1_executed", "step_2_executed", "step_3_executed"]
flow2 = HitlFlow(persistence=persistence)
flow2._completed_methods = {"step_1", "step_2"} # Simulate partial completion
execution_log.clear()
flow2.kickoff(inputs={"id": flow_id})
assert "step_1_executed" not in execution_log
assert "step_2_executed" not in execution_log
assert "step_3_executed" in execution_log
def test_cyclic_flow_re_executes_on_each_iteration():
"""Test that cyclic flows properly re-execute methods on each iteration."""
execution_log = []
class CyclicFlowTest(Flow[Dict[str, str]]):
iteration = 0
max_iterations = 3
@start("loop")
def step_1(self):
if self.iteration >= self.max_iterations:
return None
execution_log.append(f"step_1_{self.iteration}")
return f"result_{self.iteration}"
@listen(step_1)
def step_2(self):
execution_log.append(f"step_2_{self.iteration}")
@router(step_2)
def step_3(self):
execution_log.append(f"step_3_{self.iteration}")
self.iteration += 1
if self.iteration < self.max_iterations:
return "loop"
return "exit"
flow = CyclicFlowTest()
flow.kickoff()
expected = []
for i in range(3):
expected.extend([f"step_1_{i}", f"step_2_{i}", f"step_3_{i}"])
assert execution_log == expected
def test_conditional_start_with_resumption(tmp_path):
"""Test that conditional start methods work correctly with resumption."""
db_path = tmp_path / "test_flows.db"
persistence = SQLiteFlowPersistence(str(db_path))
execution_log = []
class ConditionalStartFlow(Flow[Dict[str, str]]):
@start()
def init(self):
execution_log.append("init")
return "initialized"
@router(init)
def route_to_branch(self):
execution_log.append("router")
return "branch_a"
@start("branch_a")
def branch_a_start(self):
execution_log.append("branch_a_start")
self.state["branch"] = "a"
@listen(branch_a_start)
def branch_a_process(self):
execution_log.append("branch_a_process")
self.state["processed"] = "yes"
flow1 = ConditionalStartFlow(persistence=persistence)
flow1.kickoff()
flow_id = flow1.state["id"]
assert execution_log == ["init", "router", "branch_a_start", "branch_a_process"]
flow2 = ConditionalStartFlow(persistence=persistence)
flow2._completed_methods = {"init", "route_to_branch", "branch_a_start"}
execution_log.clear()
flow2.kickoff(inputs={"id": flow_id})
assert execution_log == ["branch_a_process"]
def test_cyclic_flow_with_conditional_start():
"""Test that cyclic flows work properly with conditional start methods."""
execution_log = []
class CyclicConditionalFlow(Flow[Dict[str, str]]):
iteration = 0
@start()
def initial(self):
execution_log.append("initial")
return "init_done"
@router(initial)
def route_to_cycle(self):
execution_log.append("router_initial")
return "loop"
@start("loop")
def cycle_entry(self):
execution_log.append(f"cycle_{self.iteration}")
self.iteration += 1
@router(cycle_entry)
def cycle_router(self):
execution_log.append(f"router_{self.iteration - 1}")
if self.iteration < 3:
return "loop"
return "exit"
flow = CyclicConditionalFlow()
flow.kickoff()
expected = [
"initial",
"router_initial",
"cycle_0",
"router_0",
"cycle_1",
"router_1",
"cycle_2",
"router_2",
]
assert execution_log == expected