mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 08:38:30 +00:00
fix: flow listener resumability for HITL and cyclic flows (#3322)
* fix: flow listener resumability for HITL and cyclic flows - Add resumption context flag to distinguish HITL resumption from cyclic execution - Skip method re-execution only during HITL resumption, not for cyclic flows - Ensure cyclic flows like test_cyclic_flow continue to work correctly * fix: prevent duplicate execution of conditional start methods in flows * fix: resolve type error in flow.py line 1040 assignment
This commit is contained in:
@@ -474,6 +474,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
self._method_outputs: List[Any] = [] # List to store all method outputs
|
self._method_outputs: List[Any] = [] # List to store all method outputs
|
||||||
self._completed_methods: Set[str] = set() # Track completed methods for reload
|
self._completed_methods: Set[str] = set() # Track completed methods for reload
|
||||||
self._persistence: Optional[FlowPersistence] = persistence
|
self._persistence: Optional[FlowPersistence] = persistence
|
||||||
|
self._is_execution_resuming: bool = False
|
||||||
|
|
||||||
# Initialize state with initial values
|
# Initialize state with initial values
|
||||||
self._state = self._create_initial_state()
|
self._state = self._create_initial_state()
|
||||||
@@ -829,6 +830,9 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
# Clear completed methods and outputs for a fresh start
|
# Clear completed methods and outputs for a fresh start
|
||||||
self._completed_methods.clear()
|
self._completed_methods.clear()
|
||||||
self._method_outputs.clear()
|
self._method_outputs.clear()
|
||||||
|
else:
|
||||||
|
# We're restoring from persistence, set the flag
|
||||||
|
self._is_execution_resuming = True
|
||||||
|
|
||||||
if inputs:
|
if inputs:
|
||||||
# Override the id in the state if it exists in inputs
|
# Override the id in the state if it exists in inputs
|
||||||
@@ -880,6 +884,9 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
]
|
]
|
||||||
await asyncio.gather(*tasks)
|
await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
# Clear the resumption flag after initial execution completes
|
||||||
|
self._is_execution_resuming = False
|
||||||
|
|
||||||
final_output = self._method_outputs[-1] if self._method_outputs else None
|
final_output = self._method_outputs[-1] if self._method_outputs else None
|
||||||
|
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
@@ -916,9 +923,13 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
- Automatically injects crewai_trigger_payload if available in flow inputs
|
- Automatically injects crewai_trigger_payload if available in flow inputs
|
||||||
"""
|
"""
|
||||||
if start_method_name in self._completed_methods:
|
if start_method_name in self._completed_methods:
|
||||||
last_output = self._method_outputs[-1] if self._method_outputs else None
|
if self._is_execution_resuming:
|
||||||
await self._execute_listeners(start_method_name, last_output)
|
# During resumption, skip execution but continue listeners
|
||||||
return
|
last_output = self._method_outputs[-1] if self._method_outputs else None
|
||||||
|
await self._execute_listeners(start_method_name, last_output)
|
||||||
|
return
|
||||||
|
# For cyclic flows, clear from completed to allow re-execution
|
||||||
|
self._completed_methods.discard(start_method_name)
|
||||||
|
|
||||||
method = self._methods[start_method_name]
|
method = self._methods[start_method_name]
|
||||||
enhanced_method = self._inject_trigger_payload_for_start_method(method)
|
enhanced_method = self._inject_trigger_payload_for_start_method(method)
|
||||||
@@ -1050,11 +1061,15 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
for router_name in routers_triggered:
|
for router_name in routers_triggered:
|
||||||
await self._execute_single_listener(router_name, result)
|
await self._execute_single_listener(router_name, result)
|
||||||
# After executing router, the router's result is the path
|
# After executing router, the router's result is the path
|
||||||
router_result = self._method_outputs[-1]
|
router_result = (
|
||||||
|
self._method_outputs[-1] if self._method_outputs else None
|
||||||
|
)
|
||||||
if router_result: # Only add non-None results
|
if router_result: # Only add non-None results
|
||||||
router_results.append(router_result)
|
router_results.append(router_result)
|
||||||
current_trigger = (
|
current_trigger = (
|
||||||
router_result # Update for next iteration of router chain
|
str(router_result)
|
||||||
|
if router_result is not None
|
||||||
|
else "" # Update for next iteration of router chain
|
||||||
)
|
)
|
||||||
|
|
||||||
# Now execute normal listeners for all router results and the original trigger
|
# Now execute normal listeners for all router results and the original trigger
|
||||||
@@ -1072,6 +1087,24 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
]
|
]
|
||||||
await asyncio.gather(*tasks)
|
await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
if current_trigger in router_results:
|
||||||
|
# Find start methods triggered by this router result
|
||||||
|
for method_name in self._start_methods:
|
||||||
|
# Check if this start method is triggered by the current trigger
|
||||||
|
if method_name in self._listeners:
|
||||||
|
condition_type, trigger_methods = self._listeners[
|
||||||
|
method_name
|
||||||
|
]
|
||||||
|
if current_trigger in trigger_methods:
|
||||||
|
# Only execute if this is a cycle (method was already completed)
|
||||||
|
if method_name in self._completed_methods:
|
||||||
|
# For router-triggered start methods in cycles, temporarily clear resumption flag
|
||||||
|
# to allow cyclic execution
|
||||||
|
was_resuming = self._is_execution_resuming
|
||||||
|
self._is_execution_resuming = False
|
||||||
|
await self._execute_start_method(method_name)
|
||||||
|
self._is_execution_resuming = was_resuming
|
||||||
|
|
||||||
def _find_triggered_methods(
|
def _find_triggered_methods(
|
||||||
self, trigger_method: str, router_only: bool
|
self, trigger_method: str, router_only: bool
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
@@ -1109,6 +1142,9 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
if router_only != is_router:
|
if router_only != is_router:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
if not router_only and listener_name in self._start_methods:
|
||||||
|
continue
|
||||||
|
|
||||||
if condition_type == "OR":
|
if condition_type == "OR":
|
||||||
# If the trigger_method matches any in methods, run this
|
# If the trigger_method matches any in methods, run this
|
||||||
if trigger_method in methods:
|
if trigger_method in methods:
|
||||||
@@ -1158,10 +1194,13 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
Catches and logs any exceptions during execution, preventing
|
Catches and logs any exceptions during execution, preventing
|
||||||
individual listener failures from breaking the entire flow.
|
individual listener failures from breaking the entire flow.
|
||||||
"""
|
"""
|
||||||
# TODO: greyson fix
|
if listener_name in self._completed_methods:
|
||||||
# if listener_name in self._completed_methods:
|
if self._is_execution_resuming:
|
||||||
# await self._execute_listeners(listener_name, None)
|
# During resumption, skip execution but continue listeners
|
||||||
# return
|
await self._execute_listeners(listener_name, None)
|
||||||
|
return
|
||||||
|
# For cyclic flows, clear from completed to allow re-execution
|
||||||
|
self._completed_methods.discard(listener_name)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
method = self._methods[listener_name]
|
method = self._methods[listener_name]
|
||||||
|
|||||||
177
tests/test_flow_resumability_regression.py
Normal file
177
tests/test_flow_resumability_regression.py
Normal file
@@ -0,0 +1,177 @@
|
|||||||
|
"""Regression tests for flow listener resumability fix.
|
||||||
|
|
||||||
|
These tests ensure that:
|
||||||
|
1. HITL flows can resume properly without re-executing completed methods
|
||||||
|
2. Cyclic flows can re-execute methods on each iteration
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Dict
|
||||||
|
from crewai.flow.flow import Flow, listen, router, start
|
||||||
|
from crewai.flow.persistence.sqlite import SQLiteFlowPersistence
|
||||||
|
|
||||||
|
|
||||||
|
def test_hitl_resumption_skips_completed_listeners(tmp_path):
|
||||||
|
"""Test that HITL resumption skips completed listener methods but continues chains."""
|
||||||
|
db_path = tmp_path / "test_flows.db"
|
||||||
|
persistence = SQLiteFlowPersistence(str(db_path))
|
||||||
|
execution_log = []
|
||||||
|
|
||||||
|
class HitlFlow(Flow[Dict[str, str]]):
|
||||||
|
@start()
|
||||||
|
def step_1(self):
|
||||||
|
execution_log.append("step_1_executed")
|
||||||
|
self.state["step1"] = "done"
|
||||||
|
return "step1_result"
|
||||||
|
|
||||||
|
@listen(step_1)
|
||||||
|
def step_2(self):
|
||||||
|
execution_log.append("step_2_executed")
|
||||||
|
self.state["step2"] = "done"
|
||||||
|
return "step2_result"
|
||||||
|
|
||||||
|
@listen(step_2)
|
||||||
|
def step_3(self):
|
||||||
|
execution_log.append("step_3_executed")
|
||||||
|
self.state["step3"] = "done"
|
||||||
|
return "step3_result"
|
||||||
|
|
||||||
|
flow1 = HitlFlow(persistence=persistence)
|
||||||
|
flow1.kickoff()
|
||||||
|
flow_id = flow1.state["id"]
|
||||||
|
|
||||||
|
assert execution_log == ["step_1_executed", "step_2_executed", "step_3_executed"]
|
||||||
|
|
||||||
|
flow2 = HitlFlow(persistence=persistence)
|
||||||
|
flow2._completed_methods = {"step_1", "step_2"} # Simulate partial completion
|
||||||
|
execution_log.clear()
|
||||||
|
|
||||||
|
flow2.kickoff(inputs={"id": flow_id})
|
||||||
|
|
||||||
|
assert "step_1_executed" not in execution_log
|
||||||
|
assert "step_2_executed" not in execution_log
|
||||||
|
assert "step_3_executed" in execution_log
|
||||||
|
|
||||||
|
|
||||||
|
def test_cyclic_flow_re_executes_on_each_iteration():
|
||||||
|
"""Test that cyclic flows properly re-execute methods on each iteration."""
|
||||||
|
execution_log = []
|
||||||
|
|
||||||
|
class CyclicFlowTest(Flow[Dict[str, str]]):
|
||||||
|
iteration = 0
|
||||||
|
max_iterations = 3
|
||||||
|
|
||||||
|
@start("loop")
|
||||||
|
def step_1(self):
|
||||||
|
if self.iteration >= self.max_iterations:
|
||||||
|
return None
|
||||||
|
execution_log.append(f"step_1_{self.iteration}")
|
||||||
|
return f"result_{self.iteration}"
|
||||||
|
|
||||||
|
@listen(step_1)
|
||||||
|
def step_2(self):
|
||||||
|
execution_log.append(f"step_2_{self.iteration}")
|
||||||
|
|
||||||
|
@router(step_2)
|
||||||
|
def step_3(self):
|
||||||
|
execution_log.append(f"step_3_{self.iteration}")
|
||||||
|
self.iteration += 1
|
||||||
|
if self.iteration < self.max_iterations:
|
||||||
|
return "loop"
|
||||||
|
return "exit"
|
||||||
|
|
||||||
|
flow = CyclicFlowTest()
|
||||||
|
flow.kickoff()
|
||||||
|
|
||||||
|
expected = []
|
||||||
|
for i in range(3):
|
||||||
|
expected.extend([f"step_1_{i}", f"step_2_{i}", f"step_3_{i}"])
|
||||||
|
|
||||||
|
assert execution_log == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_conditional_start_with_resumption(tmp_path):
|
||||||
|
"""Test that conditional start methods work correctly with resumption."""
|
||||||
|
db_path = tmp_path / "test_flows.db"
|
||||||
|
persistence = SQLiteFlowPersistence(str(db_path))
|
||||||
|
execution_log = []
|
||||||
|
|
||||||
|
class ConditionalStartFlow(Flow[Dict[str, str]]):
|
||||||
|
@start()
|
||||||
|
def init(self):
|
||||||
|
execution_log.append("init")
|
||||||
|
return "initialized"
|
||||||
|
|
||||||
|
@router(init)
|
||||||
|
def route_to_branch(self):
|
||||||
|
execution_log.append("router")
|
||||||
|
return "branch_a"
|
||||||
|
|
||||||
|
@start("branch_a")
|
||||||
|
def branch_a_start(self):
|
||||||
|
execution_log.append("branch_a_start")
|
||||||
|
self.state["branch"] = "a"
|
||||||
|
|
||||||
|
@listen(branch_a_start)
|
||||||
|
def branch_a_process(self):
|
||||||
|
execution_log.append("branch_a_process")
|
||||||
|
self.state["processed"] = "yes"
|
||||||
|
|
||||||
|
flow1 = ConditionalStartFlow(persistence=persistence)
|
||||||
|
flow1.kickoff()
|
||||||
|
flow_id = flow1.state["id"]
|
||||||
|
|
||||||
|
assert execution_log == ["init", "router", "branch_a_start", "branch_a_process"]
|
||||||
|
|
||||||
|
flow2 = ConditionalStartFlow(persistence=persistence)
|
||||||
|
flow2._completed_methods = {"init", "route_to_branch", "branch_a_start"}
|
||||||
|
execution_log.clear()
|
||||||
|
|
||||||
|
flow2.kickoff(inputs={"id": flow_id})
|
||||||
|
|
||||||
|
assert execution_log == ["branch_a_process"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_cyclic_flow_with_conditional_start():
|
||||||
|
"""Test that cyclic flows work properly with conditional start methods."""
|
||||||
|
execution_log = []
|
||||||
|
|
||||||
|
class CyclicConditionalFlow(Flow[Dict[str, str]]):
|
||||||
|
iteration = 0
|
||||||
|
|
||||||
|
@start()
|
||||||
|
def initial(self):
|
||||||
|
execution_log.append("initial")
|
||||||
|
return "init_done"
|
||||||
|
|
||||||
|
@router(initial)
|
||||||
|
def route_to_cycle(self):
|
||||||
|
execution_log.append("router_initial")
|
||||||
|
return "loop"
|
||||||
|
|
||||||
|
@start("loop")
|
||||||
|
def cycle_entry(self):
|
||||||
|
execution_log.append(f"cycle_{self.iteration}")
|
||||||
|
self.iteration += 1
|
||||||
|
|
||||||
|
@router(cycle_entry)
|
||||||
|
def cycle_router(self):
|
||||||
|
execution_log.append(f"router_{self.iteration - 1}")
|
||||||
|
if self.iteration < 3:
|
||||||
|
return "loop"
|
||||||
|
return "exit"
|
||||||
|
|
||||||
|
flow = CyclicConditionalFlow()
|
||||||
|
flow.kickoff()
|
||||||
|
|
||||||
|
expected = [
|
||||||
|
"initial",
|
||||||
|
"router_initial",
|
||||||
|
"cycle_0",
|
||||||
|
"router_0",
|
||||||
|
"cycle_1",
|
||||||
|
"router_1",
|
||||||
|
"cycle_2",
|
||||||
|
"router_2",
|
||||||
|
]
|
||||||
|
|
||||||
|
assert execution_log == expected
|
||||||
Reference in New Issue
Block a user