"""Test flow state persistence functionality."""

import os
from typing import Dict, List

from crewai.flow.flow import Flow, FlowState, listen, start
from crewai.flow.persistence import persist
from crewai.flow.persistence.sqlite import SQLiteFlowPersistence
from pydantic import BaseModel


class TestState(FlowState):
    """Test state model with required id field.

    The ``id`` used as the persistence key is inherited from ``FlowState``;
    this model only adds the ``counter`` and ``message`` fields that the tests
    mutate and assert on.
    """

    # The "Test" name prefix makes pytest try to collect this Pydantic model as
    # a test class (and warn because it has an __init__). Opt out explicitly.
    # Dunder names are ignored by Pydantic, so this does not become a field.
    __test__ = False

    counter: int = 0
    message: str = ""


def test_persist_decorator_saves_state(tmp_path):
    """Test that @persist decorator saves state in SQLite."""
    db_path = os.path.join(tmp_path, "test_flows.db")
    persistence = SQLiteFlowPersistence(db_path)

    class TestFlow(Flow[Dict[str, str]]):
        initial_state = dict()  # Use dict instance as initial state

        @start()
        @persist(persistence)
        def init_step(self):
            self.state["message"] = "Hello, World!"
            self.state["id"] = "test-uuid"  # Ensure we have an ID for persistence

    # Run flow and verify state is saved
    flow = TestFlow(persistence=persistence)
    flow.kickoff()

    # Load state from DB and verify
    saved_state = persistence.load_state(flow.state["id"])
    assert saved_state is not None
    assert saved_state["message"] == "Hello, World!"
def test_structured_state_persistence(tmp_path):
    """Test persistence with Pydantic model state."""
    db_path = os.path.join(tmp_path, "test_flows.db")
    persistence = SQLiteFlowPersistence(db_path)

    class StructuredFlow(Flow[TestState]):
        initial_state = TestState

        @start()
        @persist(persistence)
        def count_up(self):
            self.state.counter += 1
            self.state.message = f"Count is {self.state.counter}"

    # Run flow and verify state changes are saved
    flow = StructuredFlow(persistence=persistence)
    flow.kickoff()

    # Load and verify state
    saved_state = persistence.load_state(flow.state.id)
    assert saved_state is not None
    assert saved_state["counter"] == 1
    assert saved_state["message"] == "Count is 1"


def test_flow_state_restoration(tmp_path):
    """Test restoring flow state from persistence with per-field overrides.

    Passing ``inputs={"id": <uuid>, ...}`` to ``kickoff`` restores the state
    persisted under that id; any other keys in ``inputs`` override the
    restored values, while fields not mentioned keep their persisted values.
    """
    db_path = os.path.join(tmp_path, "test_flows.db")
    persistence = SQLiteFlowPersistence(db_path)

    # First flow execution to create initial state
    class RestorableFlow(Flow[TestState]):
        @start()
        @persist(persistence)
        def set_message(self):
            # Only fill in defaults, so restored or overridden values survive
            # subsequent runs.
            if self.state.message == "":
                self.state.message = "Original message"
            if self.state.counter == 0:
                self.state.counter = 42

    # Create and persist initial state
    flow1 = RestorableFlow(persistence=persistence)
    flow1.kickoff()
    original_uuid = flow1.state.id

    # Test case 1: restore by inputs["id"] and override `counter`
    flow2 = RestorableFlow(persistence=persistence)
    flow2.kickoff(inputs={"id": original_uuid, "counter": 43})

    # Verify state restoration and selective field override
    assert flow2.state.id == original_uuid
    assert flow2.state.message == "Original message"  # Preserved
    assert flow2.state.counter == 43  # Overridden

    # Test case 2: restore by inputs["id"] again (now holding the state flow2
    # persisted) and override `message`
    flow3 = RestorableFlow(persistence=persistence)
    flow3.kickoff(inputs={"id": original_uuid, "message": "Updated message"})

    # Verify state restoration and selective field override
    assert flow3.state.id == original_uuid
    assert flow3.state.counter == 43  # Preserved from flow2's persisted run
    assert flow3.state.message == "Updated message"  # Overridden


def test_multiple_method_persistence(tmp_path):
    """Test state persistence across multiple method executions.

    Restoring a flow by id resumes from the state saved by the last
    ``@persist``-decorated method that ran; updates made by an undecorated
    method are not saved.
    """
    db_path = os.path.join(tmp_path, "test_flows.db")
    persistence = SQLiteFlowPersistence(db_path)

    class MultiStepFlow(Flow[TestState]):
        @start()
        @persist(persistence)
        def step_1(self):
            # counter == 1 on entry means the restored state was the one saved
            # by step_1, i.e. step_2's update never reached persistence.
            if self.state.counter == 1:
                self.state.counter = 99999
                self.state.message = "Step 99999"
            else:
                self.state.counter = 1
                self.state.message = "Step 1"

        @listen(step_1)
        @persist(persistence)
        def step_2(self):
            if self.state.counter == 1:
                self.state.counter = 2
                self.state.message = "Step 2"

    flow = MultiStepFlow(persistence=persistence)
    flow.kickoff()

    flow2 = MultiStepFlow(persistence=persistence)
    flow2.kickoff(inputs={"id": flow.state.id})

    # Inspect the in-memory final state of the restored run: step_2 persisted
    # counter == 2, so step_1 resets to 1 and step_2 moves it back to 2.
    final_state = flow2.state
    assert final_state is not None
    assert final_state.counter == 2
    assert final_state.message == "Step 2"

    class NoPersistenceMultiStepFlow(Flow[TestState]):
        @start()
        @persist(persistence)
        def step_1(self):
            if self.state.counter == 1:
                self.state.counter = 99999
                self.state.message = "Step 99999"
            else:
                self.state.counter = 1
                self.state.message = "Step 1"

        @listen(step_1)
        def step_2(self):
            if self.state.counter == 1:
                self.state.counter = 2
                self.state.message = "Step 2"

    flow = NoPersistenceMultiStepFlow(persistence=persistence)
    flow.kickoff()

    flow2 = NoPersistenceMultiStepFlow(persistence=persistence)
    flow2.kickoff(inputs={"id": flow.state.id})

    # step_2 is not persisted here, so the restored run starts from step_1's
    # saved counter == 1 and takes the 99999 branch; step_2 then does nothing.
    final_state = flow2.state
    assert final_state.counter == 99999
    assert final_state.message == "Step 99999"


def test_persist_decorator_verbose_logging(tmp_path, caplog):
    """Test that @persist decorator's verbose parameter controls logging."""
    # Set logging level to ensure we capture all logs
    caplog.set_level("INFO")

    db_path = os.path.join(tmp_path, "test_flows.db")
    persistence = SQLiteFlowPersistence(db_path)

    # Test with verbose=False (default)
    class QuietFlow(Flow[Dict[str, str]]):
        initial_state = dict()

        @start()
        @persist(persistence)  # Default verbose=False
        def init_step(self):
            self.state["message"] = "Hello, World!"
            self.state["id"] = "test-uuid-1"

    flow = QuietFlow(persistence=persistence)
    flow.kickoff()
    assert "Saving flow state" not in caplog.text

    # Clear the log
    caplog.clear()

    # Test with verbose=True
    class VerboseFlow(Flow[Dict[str, str]]):
        initial_state = dict()

        @start()
        @persist(persistence, verbose=True)
        def init_step(self):
            self.state["message"] = "Hello, World!"
            self.state["id"] = "test-uuid-2"

    flow = VerboseFlow(persistence=persistence)
    flow.kickoff()
    assert "Saving flow state" in caplog.text


def test_class_level_persist_auto_restores_state(tmp_path):
    """Test that class-level @persist automatically restores state on new instance.

    This is the documented behavior from the PersistentCounterFlow example:
    when @persist is applied at the class level, a new flow instance should
    automatically load the most recent persisted state for that flow class.
    """
    db_path = os.path.join(tmp_path, "test_flows.db")
    persistence = SQLiteFlowPersistence(db_path)

    class CounterState(FlowState):
        value: int = 0

    @persist(persistence)
    class PersistentCounterFlow(Flow[CounterState]):
        @start()
        def increment(self):
            self.state.value += 1
            return self.state.value

        @listen(increment)
        def double(self, value):
            self.state.value = value * 2
            return self.state.value

    # First run: 0 -> increment to 1 -> double to 2
    flow1 = PersistentCounterFlow()
    result1 = flow1.kickoff()
    assert result1 == 2
    assert flow1.state.value == 2

    # Second run: state auto-restored to 2 -> increment to 3 -> double to 6
    flow2 = PersistentCounterFlow()
    # State should be auto-restored before kickoff
    assert flow2.state.value == 2
    result2 = flow2.kickoff()
    assert result2 == 6
    assert flow2.state.value == 6

    # Third run: state auto-restored to 6 -> increment to 7 -> double to 14
    flow3 = PersistentCounterFlow()
    assert flow3.state.value == 6
    result3 = flow3.kickoff()
    assert result3 == 14
    assert flow3.state.value == 14


def test_class_level_persist_auto_restores_dict_state(tmp_path):
    """Test auto-restore with unstructured (dict) state."""
    db_path = os.path.join(tmp_path, "test_flows.db")
    persistence = SQLiteFlowPersistence(db_path)

    @persist(persistence)
    class DictCounterFlow(Flow):
        @start()
        def count(self):
            current = self.state.get("counter", 0)
            self.state["counter"] = current + 1
            return self.state["counter"]

    # First run
    flow1 = DictCounterFlow()
    result1 = flow1.kickoff()
    assert result1 == 1
    assert flow1.state["counter"] == 1

    # Second run: auto-restores state
    flow2 = DictCounterFlow()
    assert flow2.state["counter"] == 1
    result2 = flow2.kickoff()
    assert result2 == 2
    assert flow2.state["counter"] == 2


def test_class_level_persist_different_classes_dont_interfere(tmp_path):
    """Test that different flow classes with @persist don't interfere."""
    db_path = os.path.join(tmp_path, "test_flows.db")
    persistence = SQLiteFlowPersistence(db_path)

    class SharedState(FlowState):
        value: int = 0

    @persist(persistence)
    class FlowA(Flow[SharedState]):
        @start()
        def step(self):
            self.state.value += 10
            return self.state.value

    @persist(persistence)
    class FlowB(Flow[SharedState]):
        @start()
        def step(self):
            self.state.value += 100
            return self.state.value

    # Run FlowA
    flow_a1 = FlowA()
    flow_a1.kickoff()
    assert flow_a1.state.value == 10

    # Run FlowB
    flow_b1 = FlowB()
    flow_b1.kickoff()
    assert flow_b1.state.value == 100

    # New FlowA should restore FlowA's state (10), not FlowB's (100)
    flow_a2 = FlowA()
    assert flow_a2.state.value == 10

    # New FlowB should restore FlowB's state (100), not FlowA's (10)
    flow_b2 = FlowB()
    assert flow_b2.state.value == 100


def test_load_latest_by_class_returns_none_when_empty(tmp_path):
    """Test that load_latest_by_class returns None for unknown classes."""
    db_path = os.path.join(tmp_path, "test_flows.db")
    persistence = SQLiteFlowPersistence(db_path)

    assert persistence.load_latest_by_class("NonExistentFlow") is None


def test_save_state_stores_flow_class(tmp_path):
    """Test that save_state stores the flow_class when provided."""
    db_path = os.path.join(tmp_path, "test_flows.db")
    persistence = SQLiteFlowPersistence(db_path)

    persistence.save_state(
        flow_uuid="test-uuid",
        method_name="test_method",
        state_data={"id": "test-uuid", "value": 42},
        flow_class="MyTestFlow",
    )

    # Should be retrievable by class name
    loaded = persistence.load_latest_by_class("MyTestFlow")
    assert loaded is not None
    assert loaded["value"] == 42
    assert loaded["id"] == "test-uuid"

    # Should still be retrievable by UUID
    loaded_by_uuid = persistence.load_state("test-uuid")
    assert loaded_by_uuid is not None
    assert loaded_by_uuid["value"] == 42


def test_method_level_persist_does_not_auto_restore(tmp_path):
    """Test that method-level @persist does NOT auto-restore state.

    Only class-level @persist should trigger auto-restore. Method-level
    @persist only saves state after method execution.
    """
    db_path = os.path.join(tmp_path, "test_flows.db")
    persistence = SQLiteFlowPersistence(db_path)

    class CounterState(FlowState):
        value: int = 0

    class MethodLevelFlow(Flow[CounterState]):
        @start()
        @persist(persistence)
        def increment(self):
            self.state.value += 1
            return self.state.value

    # First run
    flow1 = MethodLevelFlow(persistence=persistence)
    flow1.kickoff()
    assert flow1.state.value == 1

    # Second run: no auto-restore since @persist is method-level
    flow2 = MethodLevelFlow(persistence=persistence)
    assert flow2.state.value == 0  # Default, not restored


def test_persistence_with_base_model(tmp_path):
    """Test structured state containing nested Pydantic models under @persist.

    After kickoff of a class-level ``@persist`` flow, the state must still
    expose ``Message`` model instances (attribute access on ``latest_message``
    and ``history`` items) and unwrap to the declared ``State`` model.
    """
    db_path = os.path.join(tmp_path, "test_flows.db")
    persistence = SQLiteFlowPersistence(db_path)

    class Message(BaseModel):
        role: str
        type: str
        content: str

    class State(FlowState):
        latest_message: Message | None = None
        history: List[Message] = []

    @persist(persistence)
    class BaseModelFlow(Flow[State]):
        initial_state = State(latest_message=None, history=[])

        @start()
        def init_step(self):
            self.state.latest_message = Message(
                role="user", type="text", content="Hello, World!"
            )
            self.state.history.append(self.state.latest_message)

    flow = BaseModelFlow(persistence=persistence)
    flow.kickoff()

    latest_message = flow.state.latest_message
    assert latest_message is not None
    assert latest_message.role == "user"
    assert latest_message.type == "text"
    assert latest_message.content == "Hello, World!"

    # Check the length before unpacking so a mismatch fails as an assertion
    # rather than as a ValueError from tuple unpacking.
    assert len(flow.state.history) == 1
    (message,) = flow.state.history
    assert message.role == "user"
    assert message.type == "text"
    assert message.content == "Hello, World!"
    assert isinstance(flow.state._unwrap(), State)