fix: auto-restore persisted state for class-level @persist decorator

When @persist is applied at the class level, new flow instances now
automatically load the most recent persisted state for that flow class.
This matches the documented behavior where creating a new instance of a
@persist-decorated flow seamlessly continues from the previous run's state.

Changes:
- Add load_latest_by_class() to FlowPersistence base class
- Implement load_latest_by_class() in SQLiteFlowPersistence with
  flow_class column and index for efficient lookups
- Store flow class name when persisting state via PersistenceDecorator
- Auto-restore latest state in class-level @persist decorator's __init__
- Add migration for existing databases (ALTER TABLE ADD COLUMN)

Fixes #5378

Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
Devin AI
2026-04-09 06:32:49 +00:00
parent 06fe163611
commit a24b011f6f
4 changed files with 283 additions and 9 deletions

View File

@@ -67,6 +67,24 @@ class FlowPersistence(BaseModel, ABC):
The most recent state as a dictionary, or None if no state exists
"""
def load_latest_by_class(self, flow_class: str) -> dict[str, Any] | None:
"""Load the most recent state for a given flow class name.
This method is used to auto-restore persisted state when the @persist
decorator is applied at the class level. When a new flow instance is
created, this method is called to find and restore the latest persisted
state for that flow class, enabling seamless state continuity across runs.
Override in subclasses to support automatic state restoration.
Args:
flow_class: The name of the flow class
Returns:
The most recent state as a dictionary, or None if no state exists
"""
return None
def save_pending_feedback(
self,
flow_uuid: str,

View File

@@ -109,11 +109,22 @@ class PersistenceDecorator:
try:
state_data = state._unwrap() if hasattr(state, "_unwrap") else state
persistence_instance.save_state(
flow_uuid=flow_uuid,
method_name=method_name,
state_data=state_data,
)
flow_class_name = type(flow_instance).__name__
try:
persistence_instance.save_state(
flow_uuid=flow_uuid,
method_name=method_name,
state_data=state_data,
flow_class=flow_class_name,
)
except TypeError:
# Fallback for custom persistence backends that
# don't accept the flow_class parameter
persistence_instance.save_state(
flow_uuid=flow_uuid,
method_name=method_name,
state_data=state_data,
)
except Exception as e:
error_msg = LOG_MESSAGES["save_error"].format(method_name, str(e))
if verbose:
@@ -178,6 +189,23 @@ def persist(
kwargs["persistence"] = actual_persistence
original_init(self, *args, **kwargs)
# Auto-restore the latest persisted state for this flow class
# so that a new instance seamlessly continues from the previous run.
if hasattr(actual_persistence, "load_latest_by_class"):
flow_class_name = type(self).__name__
stored_state = actual_persistence.load_latest_by_class(
flow_class_name
)
if stored_state is not None:
try:
self._restore_state(stored_state)
except Exception:
logger.debug(
"Could not auto-restore persisted state for %s",
flow_class_name,
exc_info=True,
)
target.__init__ = new_init # type: ignore[misc]
# Store original methods to preserve their decorators

View File

@@ -83,10 +83,18 @@ class SQLiteFlowPersistence(FlowPersistence):
flow_uuid TEXT NOT NULL,
method_name TEXT NOT NULL,
timestamp DATETIME NOT NULL,
state_json TEXT NOT NULL
state_json TEXT NOT NULL,
flow_class TEXT
)
"""
)
# Migration: add flow_class column for existing databases
try:
conn.execute(
"ALTER TABLE flow_states ADD COLUMN flow_class TEXT"
)
except sqlite3.OperationalError:
pass # Column already exists
# Add index for faster UUID lookups
conn.execute(
"""
@@ -94,6 +102,13 @@ class SQLiteFlowPersistence(FlowPersistence):
ON flow_states(flow_uuid)
"""
)
# Add index for faster flow class lookups
conn.execute(
"""
CREATE INDEX IF NOT EXISTS idx_flow_states_class
ON flow_states(flow_class)
"""
)
# Pending feedback table for async HITL
conn.execute(
@@ -121,6 +136,7 @@ class SQLiteFlowPersistence(FlowPersistence):
flow_uuid: str,
method_name: str,
state_dict: dict[str, Any],
flow_class: str | None = None,
) -> None:
"""Execute the save-state INSERT without acquiring the lock.
@@ -129,6 +145,7 @@ class SQLiteFlowPersistence(FlowPersistence):
flow_uuid: Unique identifier for the flow instance.
method_name: Name of the method that just completed.
state_dict: State data as a plain dict.
flow_class: Optional name of the flow class for auto-restore support.
"""
conn.execute(
"""
@@ -136,14 +153,16 @@ class SQLiteFlowPersistence(FlowPersistence):
flow_uuid,
method_name,
timestamp,
state_json
) VALUES (?, ?, ?, ?)
state_json,
flow_class
) VALUES (?, ?, ?, ?, ?)
""",
(
flow_uuid,
method_name,
datetime.now(timezone.utc).isoformat(),
json.dumps(state_dict),
flow_class,
),
)
@@ -163,6 +182,7 @@ class SQLiteFlowPersistence(FlowPersistence):
flow_uuid: str,
method_name: str,
state_data: dict[str, Any] | BaseModel,
flow_class: str | None = None,
) -> None:
"""Save the current flow state to SQLite.
@@ -170,6 +190,7 @@ class SQLiteFlowPersistence(FlowPersistence):
flow_uuid: Unique identifier for the flow instance
method_name: Name of the method that just completed
state_data: Current state data (either dict or Pydantic model)
flow_class: Optional name of the flow class for auto-restore support
"""
state_dict = self._to_state_dict(state_data)
@@ -177,7 +198,7 @@ class SQLiteFlowPersistence(FlowPersistence):
store_lock(self._lock_name),
sqlite3.connect(self.db_path, timeout=30) as conn,
):
self._save_state_sql(conn, flow_uuid, method_name, state_dict)
self._save_state_sql(conn, flow_uuid, method_name, state_dict, flow_class)
def load_state(self, flow_uuid: str) -> dict[str, Any] | None:
"""Load the most recent state for a given flow UUID.
@@ -206,6 +227,38 @@ class SQLiteFlowPersistence(FlowPersistence):
return result if isinstance(result, dict) else None
return None
def load_latest_by_class(self, flow_class: str) -> dict[str, Any] | None:
"""Load the most recent state for a given flow class name.
This enables automatic state restoration when @persist is applied at
the class level. The most recent state entry for the given flow class
is returned, allowing new flow instances to seamlessly continue from
where the previous run left off.
Args:
flow_class: The name of the flow class
Returns:
The most recent state as a dictionary, or None if no state exists
"""
with sqlite3.connect(self.db_path, timeout=30) as conn:
cursor = conn.execute(
"""
SELECT state_json
FROM flow_states
WHERE flow_class = ?
ORDER BY id DESC
LIMIT 1
""",
(flow_class,),
)
row = cursor.fetchone()
if row:
result = json.loads(row[0])
return result if isinstance(result, dict) else None
return None
def save_pending_feedback(
self,
flow_uuid: str,

View File

@@ -208,6 +208,181 @@ def test_persist_decorator_verbose_logging(tmp_path, caplog):
assert "Saving flow state" in caplog.text
def test_class_level_persist_auto_restores_state(tmp_path):
"""Test that class-level @persist automatically restores state on new instance.
This is the documented behavior from the PersistentCounterFlow example:
when @persist is applied at the class level, a new flow instance should
automatically load the most recent persisted state for that flow class.
"""
db_path = os.path.join(tmp_path, "test_flows.db")
persistence = SQLiteFlowPersistence(db_path)
class CounterState(FlowState):
value: int = 0
@persist(persistence)
class PersistentCounterFlow(Flow[CounterState]):
@start()
def increment(self):
self.state.value += 1
return self.state.value
@listen(increment)
def double(self, value):
self.state.value = value * 2
return self.state.value
# First run: 0 -> increment to 1 -> double to 2
flow1 = PersistentCounterFlow()
result1 = flow1.kickoff()
assert result1 == 2
assert flow1.state.value == 2
# Second run: state auto-restored to 2 -> increment to 3 -> double to 6
flow2 = PersistentCounterFlow()
# State should be auto-restored before kickoff
assert flow2.state.value == 2
result2 = flow2.kickoff()
assert result2 == 6
assert flow2.state.value == 6
# Third run: state auto-restored to 6 -> increment to 7 -> double to 14
flow3 = PersistentCounterFlow()
assert flow3.state.value == 6
result3 = flow3.kickoff()
assert result3 == 14
assert flow3.state.value == 14
def test_class_level_persist_auto_restores_dict_state(tmp_path):
"""Test auto-restore with unstructured (dict) state."""
db_path = os.path.join(tmp_path, "test_flows.db")
persistence = SQLiteFlowPersistence(db_path)
@persist(persistence)
class DictCounterFlow(Flow):
@start()
def count(self):
current = self.state.get("counter", 0)
self.state["counter"] = current + 1
return self.state["counter"]
# First run
flow1 = DictCounterFlow()
result1 = flow1.kickoff()
assert result1 == 1
assert flow1.state["counter"] == 1
# Second run: auto-restores state
flow2 = DictCounterFlow()
assert flow2.state["counter"] == 1
result2 = flow2.kickoff()
assert result2 == 2
assert flow2.state["counter"] == 2
def test_class_level_persist_different_classes_dont_interfere(tmp_path):
"""Test that different flow classes with @persist don't interfere."""
db_path = os.path.join(tmp_path, "test_flows.db")
persistence = SQLiteFlowPersistence(db_path)
class SharedState(FlowState):
value: int = 0
@persist(persistence)
class FlowA(Flow[SharedState]):
@start()
def step(self):
self.state.value += 10
return self.state.value
@persist(persistence)
class FlowB(Flow[SharedState]):
@start()
def step(self):
self.state.value += 100
return self.state.value
# Run FlowA
flow_a1 = FlowA()
flow_a1.kickoff()
assert flow_a1.state.value == 10
# Run FlowB
flow_b1 = FlowB()
flow_b1.kickoff()
assert flow_b1.state.value == 100
# New FlowA should restore FlowA's state (10), not FlowB's (100)
flow_a2 = FlowA()
assert flow_a2.state.value == 10
# New FlowB should restore FlowB's state (100), not FlowA's (10)
flow_b2 = FlowB()
assert flow_b2.state.value == 100
def test_load_latest_by_class_returns_none_when_empty(tmp_path):
"""Test that load_latest_by_class returns None for unknown classes."""
db_path = os.path.join(tmp_path, "test_flows.db")
persistence = SQLiteFlowPersistence(db_path)
assert persistence.load_latest_by_class("NonExistentFlow") is None
def test_save_state_stores_flow_class(tmp_path):
"""Test that save_state stores the flow_class when provided."""
db_path = os.path.join(tmp_path, "test_flows.db")
persistence = SQLiteFlowPersistence(db_path)
persistence.save_state(
flow_uuid="test-uuid",
method_name="test_method",
state_data={"id": "test-uuid", "value": 42},
flow_class="MyTestFlow",
)
# Should be retrievable by class name
loaded = persistence.load_latest_by_class("MyTestFlow")
assert loaded is not None
assert loaded["value"] == 42
assert loaded["id"] == "test-uuid"
# Should still be retrievable by UUID
loaded_by_uuid = persistence.load_state("test-uuid")
assert loaded_by_uuid is not None
assert loaded_by_uuid["value"] == 42
def test_method_level_persist_does_not_auto_restore(tmp_path):
"""Test that method-level @persist does NOT auto-restore state.
Only class-level @persist should trigger auto-restore. Method-level
@persist only saves state after method execution.
"""
db_path = os.path.join(tmp_path, "test_flows.db")
persistence = SQLiteFlowPersistence(db_path)
class CounterState(FlowState):
value: int = 0
class MethodLevelFlow(Flow[CounterState]):
@start()
@persist(persistence)
def increment(self):
self.state.value += 1
return self.state.value
# First run
flow1 = MethodLevelFlow(persistence=persistence)
flow1.kickoff()
assert flow1.state.value == 1
# Second run: no auto-restore since @persist is method-level
flow2 = MethodLevelFlow(persistence=persistence)
assert flow2.state.value == 0 # Default, not restored
def test_persistence_with_base_model(tmp_path):
db_path = os.path.join(tmp_path, "test_flows.db")
persistence = SQLiteFlowPersistence(db_path)