diff --git a/lib/crewai/src/crewai/flow/persistence/base.py b/lib/crewai/src/crewai/flow/persistence/base.py index 1114359a1..fd31c9af1 100644 --- a/lib/crewai/src/crewai/flow/persistence/base.py +++ b/lib/crewai/src/crewai/flow/persistence/base.py @@ -67,6 +67,24 @@ class FlowPersistence(BaseModel, ABC): The most recent state as a dictionary, or None if no state exists """ + def load_latest_by_class(self, flow_class: str) -> dict[str, Any] | None: + """Load the most recent state for a given flow class name. + + This method is used to auto-restore persisted state when the @persist + decorator is applied at the class level. When a new flow instance is + created, this method is called to find and restore the latest persisted + state for that flow class, enabling seamless state continuity across runs. + + Override in subclasses to support automatic state restoration. + + Args: + flow_class: The name of the flow class + + Returns: + The most recent state as a dictionary, or None if no state exists + """ + return None + def save_pending_feedback( self, flow_uuid: str, diff --git a/lib/crewai/src/crewai/flow/persistence/decorators.py b/lib/crewai/src/crewai/flow/persistence/decorators.py index 937b557f4..126bbd39a 100644 --- a/lib/crewai/src/crewai/flow/persistence/decorators.py +++ b/lib/crewai/src/crewai/flow/persistence/decorators.py @@ -109,11 +109,22 @@ class PersistenceDecorator: try: state_data = state._unwrap() if hasattr(state, "_unwrap") else state - persistence_instance.save_state( - flow_uuid=flow_uuid, - method_name=method_name, - state_data=state_data, - ) + flow_class_name = type(flow_instance).__name__ + try: + persistence_instance.save_state( + flow_uuid=flow_uuid, + method_name=method_name, + state_data=state_data, + flow_class=flow_class_name, + ) + except TypeError: + # Fallback for custom persistence backends that + # don't accept the flow_class parameter + persistence_instance.save_state( + flow_uuid=flow_uuid, + method_name=method_name, + state_data=state_data, + ) except Exception as e: error_msg = LOG_MESSAGES["save_error"].format(method_name, str(e)) if verbose: @@ -178,6 +189,23 @@ def persist( kwargs["persistence"] = actual_persistence original_init(self, *args, **kwargs) + # Auto-restore the latest persisted state for this flow class + # so that a new instance seamlessly continues from the previous run. + if hasattr(actual_persistence, "load_latest_by_class"): + flow_class_name = type(self).__name__ + stored_state = actual_persistence.load_latest_by_class( + flow_class_name + ) + if stored_state is not None: + try: + self._restore_state(stored_state) + except Exception: + logger.debug( + "Could not auto-restore persisted state for %s", + flow_class_name, + exc_info=True, + ) + target.__init__ = new_init # type: ignore[misc] # Store original methods to preserve their decorators diff --git a/lib/crewai/src/crewai/flow/persistence/sqlite.py b/lib/crewai/src/crewai/flow/persistence/sqlite.py index fa2e4e127..18ad13e0f 100644 --- a/lib/crewai/src/crewai/flow/persistence/sqlite.py +++ b/lib/crewai/src/crewai/flow/persistence/sqlite.py @@ -83,10 +83,18 @@ class SQLiteFlowPersistence(FlowPersistence): flow_uuid TEXT NOT NULL, method_name TEXT NOT NULL, timestamp DATETIME NOT NULL, - state_json TEXT NOT NULL + state_json TEXT NOT NULL, + flow_class TEXT ) """ ) + # Migration: add flow_class column for existing databases + try: + conn.execute( + "ALTER TABLE flow_states ADD COLUMN flow_class TEXT" + ) + except sqlite3.OperationalError: + pass # Column already exists # Add index for faster UUID lookups conn.execute( """ @@ -94,6 +102,13 @@ class SQLiteFlowPersistence(FlowPersistence): ON flow_states(flow_uuid) """ ) + # Add index for faster flow class lookups + conn.execute( + """ + CREATE INDEX IF NOT EXISTS idx_flow_states_class + ON flow_states(flow_class) + """ + ) # Pending feedback table for async HITL conn.execute( @@ -121,6 +136,7 @@ class SQLiteFlowPersistence(FlowPersistence): flow_uuid: str, method_name: str, state_dict: dict[str, Any], + flow_class: str | None = None, ) -> None: """Execute the save-state INSERT without acquiring the lock. @@ -129,6 +145,7 @@ class SQLiteFlowPersistence(FlowPersistence): flow_uuid: Unique identifier for the flow instance. method_name: Name of the method that just completed. state_dict: State data as a plain dict. + flow_class: Optional name of the flow class for auto-restore support. """ conn.execute( """ @@ -136,14 +153,16 @@ class SQLiteFlowPersistence(FlowPersistence): flow_uuid, method_name, timestamp, - state_json - ) VALUES (?, ?, ?, ?) + state_json, + flow_class + ) VALUES (?, ?, ?, ?, ?) """, ( flow_uuid, method_name, datetime.now(timezone.utc).isoformat(), json.dumps(state_dict), + flow_class, ), ) @@ -163,6 +182,7 @@ class SQLiteFlowPersistence(FlowPersistence): flow_uuid: str, method_name: str, state_data: dict[str, Any] | BaseModel, + flow_class: str | None = None, ) -> None: """Save the current flow state to SQLite. @@ -170,6 +190,7 @@ class SQLiteFlowPersistence(FlowPersistence): flow_uuid: Unique identifier for the flow instance method_name: Name of the method that just completed state_data: Current state data (either dict or Pydantic model) + flow_class: Optional name of the flow class for auto-restore support """ state_dict = self._to_state_dict(state_data) @@ -177,7 +198,7 @@ class SQLiteFlowPersistence(FlowPersistence): store_lock(self._lock_name), sqlite3.connect(self.db_path, timeout=30) as conn, ): - self._save_state_sql(conn, flow_uuid, method_name, state_dict) + self._save_state_sql(conn, flow_uuid, method_name, state_dict, flow_class) def load_state(self, flow_uuid: str) -> dict[str, Any] | None: """Load the most recent state for a given flow UUID. @@ -206,6 +227,38 @@ class SQLiteFlowPersistence(FlowPersistence): return result if isinstance(result, dict) else None return None + def load_latest_by_class(self, flow_class: str) -> dict[str, Any] | None: + """Load the most recent state for a given flow class name. + + This enables automatic state restoration when @persist is applied at + the class level. The most recent state entry for the given flow class + is returned, allowing new flow instances to seamlessly continue from + where the previous run left off. + + Args: + flow_class: The name of the flow class + + Returns: + The most recent state as a dictionary, or None if no state exists + """ + with sqlite3.connect(self.db_path, timeout=30) as conn: + cursor = conn.execute( + """ + SELECT state_json + FROM flow_states + WHERE flow_class = ? + ORDER BY id DESC + LIMIT 1 + """, + (flow_class,), + ) + row = cursor.fetchone() + + if row: + result = json.loads(row[0]) + return result if isinstance(result, dict) else None + return None + def save_pending_feedback( self, flow_uuid: str, diff --git a/lib/crewai/tests/test_flow_persistence.py b/lib/crewai/tests/test_flow_persistence.py index 06bbf7231..1727d7200 100644 --- a/lib/crewai/tests/test_flow_persistence.py +++ b/lib/crewai/tests/test_flow_persistence.py @@ -208,6 +208,181 @@ def test_persist_decorator_verbose_logging(tmp_path, caplog): assert "Saving flow state" in caplog.text +def test_class_level_persist_auto_restores_state(tmp_path): + """Test that class-level @persist automatically restores state on new instance. + + This is the documented behavior from the PersistentCounterFlow example: + when @persist is applied at the class level, a new flow instance should + automatically load the most recent persisted state for that flow class. + """ + db_path = os.path.join(tmp_path, "test_flows.db") + persistence = SQLiteFlowPersistence(db_path) + + class CounterState(FlowState): + value: int = 0 + + @persist(persistence) + class PersistentCounterFlow(Flow[CounterState]): + @start() + def increment(self): + self.state.value += 1 + return self.state.value + + @listen(increment) + def double(self, value): + self.state.value = value * 2 + return self.state.value + + # First run: 0 -> increment to 1 -> double to 2 + flow1 = PersistentCounterFlow() + result1 = flow1.kickoff() + assert result1 == 2 + assert flow1.state.value == 2 + + # Second run: state auto-restored to 2 -> increment to 3 -> double to 6 + flow2 = PersistentCounterFlow() + # State should be auto-restored before kickoff + assert flow2.state.value == 2 + result2 = flow2.kickoff() + assert result2 == 6 + assert flow2.state.value == 6 + + # Third run: state auto-restored to 6 -> increment to 7 -> double to 14 + flow3 = PersistentCounterFlow() + assert flow3.state.value == 6 + result3 = flow3.kickoff() + assert result3 == 14 + assert flow3.state.value == 14 + + +def test_class_level_persist_auto_restores_dict_state(tmp_path): + """Test auto-restore with unstructured (dict) state.""" + db_path = os.path.join(tmp_path, "test_flows.db") + persistence = SQLiteFlowPersistence(db_path) + + @persist(persistence) + class DictCounterFlow(Flow): + @start() + def count(self): + current = self.state.get("counter", 0) + self.state["counter"] = current + 1 + return self.state["counter"] + + # First run + flow1 = DictCounterFlow() + result1 = flow1.kickoff() + assert result1 == 1 + assert flow1.state["counter"] == 1 + + # Second run: auto-restores state + flow2 = DictCounterFlow() + assert flow2.state["counter"] == 1 + result2 = flow2.kickoff() + assert result2 == 2 + assert flow2.state["counter"] == 2 + + +def test_class_level_persist_different_classes_dont_interfere(tmp_path): + """Test that different flow classes with @persist don't interfere.""" + db_path = os.path.join(tmp_path, "test_flows.db") + persistence = SQLiteFlowPersistence(db_path) + + class SharedState(FlowState): + value: int = 0 + + @persist(persistence) + class FlowA(Flow[SharedState]): + @start() + def step(self): + self.state.value += 10 + return self.state.value + + @persist(persistence) + class FlowB(Flow[SharedState]): + @start() + def step(self): + self.state.value += 100 + return self.state.value + + # Run FlowA + flow_a1 = FlowA() + flow_a1.kickoff() + assert flow_a1.state.value == 10 + + # Run FlowB + flow_b1 = FlowB() + flow_b1.kickoff() + assert flow_b1.state.value == 100 + + # New FlowA should restore FlowA's state (10), not FlowB's (100) + flow_a2 = FlowA() + assert flow_a2.state.value == 10 + + # New FlowB should restore FlowB's state (100), not FlowA's (10) + flow_b2 = FlowB() + assert flow_b2.state.value == 100 + + +def test_load_latest_by_class_returns_none_when_empty(tmp_path): + """Test that load_latest_by_class returns None for unknown classes.""" + db_path = os.path.join(tmp_path, "test_flows.db") + persistence = SQLiteFlowPersistence(db_path) + assert persistence.load_latest_by_class("NonExistentFlow") is None + + +def test_save_state_stores_flow_class(tmp_path): + """Test that save_state stores the flow_class when provided.""" + db_path = os.path.join(tmp_path, "test_flows.db") + persistence = SQLiteFlowPersistence(db_path) + + persistence.save_state( + flow_uuid="test-uuid", + method_name="test_method", + state_data={"id": "test-uuid", "value": 42}, + flow_class="MyTestFlow", + ) + + # Should be retrievable by class name + loaded = persistence.load_latest_by_class("MyTestFlow") + assert loaded is not None + assert loaded["value"] == 42 + assert loaded["id"] == "test-uuid" + + # Should still be retrievable by UUID + loaded_by_uuid = persistence.load_state("test-uuid") + assert loaded_by_uuid is not None + assert loaded_by_uuid["value"] == 42 + + +def test_method_level_persist_does_not_auto_restore(tmp_path): + """Test that method-level @persist does NOT auto-restore state. + + Only class-level @persist should trigger auto-restore. Method-level + @persist only saves state after method execution. + """ + db_path = os.path.join(tmp_path, "test_flows.db") + persistence = SQLiteFlowPersistence(db_path) + + class CounterState(FlowState): + value: int = 0 + + class MethodLevelFlow(Flow[CounterState]): + @start() + @persist(persistence) + def increment(self): + self.state.value += 1 + return self.state.value + + # First run + flow1 = MethodLevelFlow(persistence=persistence) + flow1.kickoff() + assert flow1.state.value == 1 + + # Second run: no auto-restore since @persist is method-level + flow2 = MethodLevelFlow(persistence=persistence) + assert flow2.state.value == 0 # Default, not restored + + def test_persistence_with_base_model(tmp_path): db_path = os.path.join(tmp_path, "test_flows.db") persistence = SQLiteFlowPersistence(db_path)