diff --git a/src/crewai/flow/flow.py b/src/crewai/flow/flow.py index 4e06d85d8..167d3d416 100644 --- a/src/crewai/flow/flow.py +++ b/src/crewai/flow/flow.py @@ -474,6 +474,7 @@ class Flow(Generic[T], metaclass=FlowMeta): self._method_outputs: List[Any] = [] # List to store all method outputs self._completed_methods: Set[str] = set() # Track completed methods for reload self._persistence: Optional[FlowPersistence] = persistence + self._is_execution_resuming: bool = False # Initialize state with initial values self._state = self._create_initial_state() @@ -829,6 +830,9 @@ class Flow(Generic[T], metaclass=FlowMeta): # Clear completed methods and outputs for a fresh start self._completed_methods.clear() self._method_outputs.clear() + else: + # We're restoring from persistence, set the flag + self._is_execution_resuming = True if inputs: # Override the id in the state if it exists in inputs @@ -880,6 +884,9 @@ class Flow(Generic[T], metaclass=FlowMeta): ] await asyncio.gather(*tasks) + # Clear the resumption flag after initial execution completes + self._is_execution_resuming = False + final_output = self._method_outputs[-1] if self._method_outputs else None crewai_event_bus.emit( @@ -916,9 +923,13 @@ class Flow(Generic[T], metaclass=FlowMeta): - Automatically injects crewai_trigger_payload if available in flow inputs """ if start_method_name in self._completed_methods: - last_output = self._method_outputs[-1] if self._method_outputs else None - await self._execute_listeners(start_method_name, last_output) - return + if self._is_execution_resuming: + # During resumption, skip execution but continue listeners + last_output = self._method_outputs[-1] if self._method_outputs else None + await self._execute_listeners(start_method_name, last_output) + return + # For cyclic flows, clear from completed to allow re-execution + self._completed_methods.discard(start_method_name) method = self._methods[start_method_name] enhanced_method = self._inject_trigger_payload_for_start_method(method) @@ -1050,11 +1061,15 @@ class Flow(Generic[T], metaclass=FlowMeta): for router_name in routers_triggered: await self._execute_single_listener(router_name, result) # After executing router, the router's result is the path - router_result = self._method_outputs[-1] + router_result = ( + self._method_outputs[-1] if self._method_outputs else None + ) if router_result: # Only add non-None results router_results.append(router_result) current_trigger = ( - router_result # Update for next iteration of router chain + str(router_result) + if router_result is not None + else "" # Update for next iteration of router chain ) # Now execute normal listeners for all router results and the original trigger @@ -1072,6 +1087,24 @@ class Flow(Generic[T], metaclass=FlowMeta): ] await asyncio.gather(*tasks) + if current_trigger in router_results: + # Find start methods triggered by this router result + for method_name in self._start_methods: + # Check if this start method is triggered by the current trigger + if method_name in self._listeners: + condition_type, trigger_methods = self._listeners[ + method_name + ] + if current_trigger in trigger_methods: + # Only execute if this is a cycle (method was already completed) + if method_name in self._completed_methods: + # For router-triggered start methods in cycles, temporarily clear resumption flag + # to allow cyclic execution + was_resuming = self._is_execution_resuming + self._is_execution_resuming = False + await self._execute_start_method(method_name) + self._is_execution_resuming = was_resuming + def _find_triggered_methods( self, trigger_method: str, router_only: bool ) -> List[str]: @@ -1109,6 +1142,9 @@ class Flow(Generic[T], metaclass=FlowMeta): if router_only != is_router: continue + if not router_only and listener_name in self._start_methods: + continue + if condition_type == "OR": # If the trigger_method matches any in methods, run this if trigger_method in methods: @@ -1158,10 +1194,13 @@ class Flow(Generic[T], metaclass=FlowMeta): Catches and logs any exceptions during execution, preventing individual listener failures from breaking the entire flow. """ - # TODO: greyson fix - # if listener_name in self._completed_methods: - # await self._execute_listeners(listener_name, None) - # return + if listener_name in self._completed_methods: + if self._is_execution_resuming: + # During resumption, skip execution but continue listeners + await self._execute_listeners(listener_name, None) + return + # For cyclic flows, clear from completed to allow re-execution + self._completed_methods.discard(listener_name) try: method = self._methods[listener_name] diff --git a/tests/test_flow_resumability_regression.py b/tests/test_flow_resumability_regression.py new file mode 100644 index 000000000..87f67173d --- /dev/null +++ b/tests/test_flow_resumability_regression.py @@ -0,0 +1,177 @@ +"""Regression tests for flow listener resumability fix. + +These tests ensure that: +1. HITL flows can resume properly without re-executing completed methods +2. Cyclic flows can re-execute methods on each iteration +""" + +from typing import Dict +from crewai.flow.flow import Flow, listen, router, start +from crewai.flow.persistence.sqlite import SQLiteFlowPersistence + + +def test_hitl_resumption_skips_completed_listeners(tmp_path): + """Test that HITL resumption skips completed listener methods but continues chains.""" + db_path = tmp_path / "test_flows.db" + persistence = SQLiteFlowPersistence(str(db_path)) + execution_log = [] + + class HitlFlow(Flow[Dict[str, str]]): + @start() + def step_1(self): + execution_log.append("step_1_executed") + self.state["step1"] = "done" + return "step1_result" + + @listen(step_1) + def step_2(self): + execution_log.append("step_2_executed") + self.state["step2"] = "done" + return "step2_result" + + @listen(step_2) + def step_3(self): + execution_log.append("step_3_executed") + self.state["step3"] = "done" + return "step3_result" + + flow1 = HitlFlow(persistence=persistence) + flow1.kickoff() + flow_id = flow1.state["id"] + + assert execution_log == ["step_1_executed", "step_2_executed", "step_3_executed"] + + flow2 = HitlFlow(persistence=persistence) + flow2._completed_methods = {"step_1", "step_2"} # Simulate partial completion + execution_log.clear() + + flow2.kickoff(inputs={"id": flow_id}) + + assert "step_1_executed" not in execution_log + assert "step_2_executed" not in execution_log + assert "step_3_executed" in execution_log + + +def test_cyclic_flow_re_executes_on_each_iteration(): + """Test that cyclic flows properly re-execute methods on each iteration.""" + execution_log = [] + + class CyclicFlowTest(Flow[Dict[str, str]]): + iteration = 0 + max_iterations = 3 + + @start("loop") + def step_1(self): + if self.iteration >= self.max_iterations: + return None + execution_log.append(f"step_1_{self.iteration}") + return f"result_{self.iteration}" + + @listen(step_1) + def step_2(self): + execution_log.append(f"step_2_{self.iteration}") + + @router(step_2) + def step_3(self): + execution_log.append(f"step_3_{self.iteration}") + self.iteration += 1 + if self.iteration < self.max_iterations: + return "loop" + return "exit" + + flow = CyclicFlowTest() + flow.kickoff() + + expected = [] + for i in range(3): + expected.extend([f"step_1_{i}", f"step_2_{i}", f"step_3_{i}"]) + + assert execution_log == expected + + +def test_conditional_start_with_resumption(tmp_path): + """Test that conditional start methods work correctly with resumption.""" + db_path = tmp_path / "test_flows.db" + persistence = SQLiteFlowPersistence(str(db_path)) + execution_log = [] + + class ConditionalStartFlow(Flow[Dict[str, str]]): + @start() + def init(self): + execution_log.append("init") + return "initialized" + + @router(init) + def route_to_branch(self): + execution_log.append("router") + return "branch_a" + + @start("branch_a") + def branch_a_start(self): + execution_log.append("branch_a_start") + self.state["branch"] = "a" + + @listen(branch_a_start) + def branch_a_process(self): + execution_log.append("branch_a_process") + self.state["processed"] = "yes" + + flow1 = ConditionalStartFlow(persistence=persistence) + flow1.kickoff() + flow_id = flow1.state["id"] + + assert execution_log == ["init", "router", "branch_a_start", "branch_a_process"] + + flow2 = ConditionalStartFlow(persistence=persistence) + flow2._completed_methods = {"init", "route_to_branch", "branch_a_start"} + execution_log.clear() + + flow2.kickoff(inputs={"id": flow_id}) + + assert execution_log == ["branch_a_process"] + + +def test_cyclic_flow_with_conditional_start(): + """Test that cyclic flows work properly with conditional start methods.""" + execution_log = [] + + class CyclicConditionalFlow(Flow[Dict[str, str]]): + iteration = 0 + + @start() + def initial(self): + execution_log.append("initial") + return "init_done" + + @router(initial) + def route_to_cycle(self): + execution_log.append("router_initial") + return "loop" + + @start("loop") + def cycle_entry(self): + execution_log.append(f"cycle_{self.iteration}") + self.iteration += 1 + + @router(cycle_entry) + def cycle_router(self): + execution_log.append(f"router_{self.iteration - 1}") + if self.iteration < 3: + return "loop" + return "exit" + + flow = CyclicConditionalFlow() + flow.kickoff() + + expected = [ + "initial", + "router_initial", + "cycle_0", + "router_0", + "cycle_1", + "router_1", + "cycle_2", + "router_2", + ] + + assert execution_log == expected