mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-04-30 23:02:50 +00:00
fix: auto-restore persisted state for class-level @persist decorator
When @persist is applied at the class level, new flow instances now automatically load the most recent persisted state for that flow class. This matches the documented behavior where creating a new instance of a @persist-decorated flow seamlessly continues from the previous run's state. Changes: - Add load_latest_by_class() to FlowPersistence base class - Implement load_latest_by_class() in SQLiteFlowPersistence with flow_class column and index for efficient lookups - Store flow class name when persisting state via PersistenceDecorator - Auto-restore latest state in class-level @persist decorator's __init__ - Add migration for existing databases (ALTER TABLE ADD COLUMN) Fixes #5378 Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
@@ -67,6 +67,24 @@ class FlowPersistence(BaseModel, ABC):
|
||||
The most recent state as a dictionary, or None if no state exists
|
||||
"""
|
||||
|
||||
def load_latest_by_class(self, flow_class: str) -> dict[str, Any] | None:
|
||||
"""Load the most recent state for a given flow class name.
|
||||
|
||||
This method is used to auto-restore persisted state when the @persist
|
||||
decorator is applied at the class level. When a new flow instance is
|
||||
created, this method is called to find and restore the latest persisted
|
||||
state for that flow class, enabling seamless state continuity across runs.
|
||||
|
||||
Override in subclasses to support automatic state restoration.
|
||||
|
||||
Args:
|
||||
flow_class: The name of the flow class
|
||||
|
||||
Returns:
|
||||
The most recent state as a dictionary, or None if no state exists
|
||||
"""
|
||||
return None
|
||||
|
||||
def save_pending_feedback(
|
||||
self,
|
||||
flow_uuid: str,
|
||||
|
||||
@@ -109,11 +109,22 @@ class PersistenceDecorator:
|
||||
|
||||
try:
|
||||
state_data = state._unwrap() if hasattr(state, "_unwrap") else state
|
||||
persistence_instance.save_state(
|
||||
flow_uuid=flow_uuid,
|
||||
method_name=method_name,
|
||||
state_data=state_data,
|
||||
)
|
||||
flow_class_name = type(flow_instance).__name__
|
||||
try:
|
||||
persistence_instance.save_state(
|
||||
flow_uuid=flow_uuid,
|
||||
method_name=method_name,
|
||||
state_data=state_data,
|
||||
flow_class=flow_class_name,
|
||||
)
|
||||
except TypeError:
|
||||
# Fallback for custom persistence backends that
|
||||
# don't accept the flow_class parameter
|
||||
persistence_instance.save_state(
|
||||
flow_uuid=flow_uuid,
|
||||
method_name=method_name,
|
||||
state_data=state_data,
|
||||
)
|
||||
except Exception as e:
|
||||
error_msg = LOG_MESSAGES["save_error"].format(method_name, str(e))
|
||||
if verbose:
|
||||
@@ -178,6 +189,23 @@ def persist(
|
||||
kwargs["persistence"] = actual_persistence
|
||||
original_init(self, *args, **kwargs)
|
||||
|
||||
# Auto-restore the latest persisted state for this flow class
|
||||
# so that a new instance seamlessly continues from the previous run.
|
||||
if hasattr(actual_persistence, "load_latest_by_class"):
|
||||
flow_class_name = type(self).__name__
|
||||
stored_state = actual_persistence.load_latest_by_class(
|
||||
flow_class_name
|
||||
)
|
||||
if stored_state is not None:
|
||||
try:
|
||||
self._restore_state(stored_state)
|
||||
except Exception:
|
||||
logger.debug(
|
||||
"Could not auto-restore persisted state for %s",
|
||||
flow_class_name,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
target.__init__ = new_init # type: ignore[misc]
|
||||
|
||||
# Store original methods to preserve their decorators
|
||||
|
||||
@@ -83,10 +83,18 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
flow_uuid TEXT NOT NULL,
|
||||
method_name TEXT NOT NULL,
|
||||
timestamp DATETIME NOT NULL,
|
||||
state_json TEXT NOT NULL
|
||||
state_json TEXT NOT NULL,
|
||||
flow_class TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
# Migration: add flow_class column for existing databases
|
||||
try:
|
||||
conn.execute(
|
||||
"ALTER TABLE flow_states ADD COLUMN flow_class TEXT"
|
||||
)
|
||||
except sqlite3.OperationalError:
|
||||
pass # Column already exists
|
||||
# Add index for faster UUID lookups
|
||||
conn.execute(
|
||||
"""
|
||||
@@ -94,6 +102,13 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
ON flow_states(flow_uuid)
|
||||
"""
|
||||
)
|
||||
# Add index for faster flow class lookups
|
||||
conn.execute(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_flow_states_class
|
||||
ON flow_states(flow_class)
|
||||
"""
|
||||
)
|
||||
|
||||
# Pending feedback table for async HITL
|
||||
conn.execute(
|
||||
@@ -121,6 +136,7 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
flow_uuid: str,
|
||||
method_name: str,
|
||||
state_dict: dict[str, Any],
|
||||
flow_class: str | None = None,
|
||||
) -> None:
|
||||
"""Execute the save-state INSERT without acquiring the lock.
|
||||
|
||||
@@ -129,6 +145,7 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
flow_uuid: Unique identifier for the flow instance.
|
||||
method_name: Name of the method that just completed.
|
||||
state_dict: State data as a plain dict.
|
||||
flow_class: Optional name of the flow class for auto-restore support.
|
||||
"""
|
||||
conn.execute(
|
||||
"""
|
||||
@@ -136,14 +153,16 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
flow_uuid,
|
||||
method_name,
|
||||
timestamp,
|
||||
state_json
|
||||
) VALUES (?, ?, ?, ?)
|
||||
state_json,
|
||||
flow_class
|
||||
) VALUES (?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
flow_uuid,
|
||||
method_name,
|
||||
datetime.now(timezone.utc).isoformat(),
|
||||
json.dumps(state_dict),
|
||||
flow_class,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -163,6 +182,7 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
flow_uuid: str,
|
||||
method_name: str,
|
||||
state_data: dict[str, Any] | BaseModel,
|
||||
flow_class: str | None = None,
|
||||
) -> None:
|
||||
"""Save the current flow state to SQLite.
|
||||
|
||||
@@ -170,6 +190,7 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
flow_uuid: Unique identifier for the flow instance
|
||||
method_name: Name of the method that just completed
|
||||
state_data: Current state data (either dict or Pydantic model)
|
||||
flow_class: Optional name of the flow class for auto-restore support
|
||||
"""
|
||||
state_dict = self._to_state_dict(state_data)
|
||||
|
||||
@@ -177,7 +198,7 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
store_lock(self._lock_name),
|
||||
sqlite3.connect(self.db_path, timeout=30) as conn,
|
||||
):
|
||||
self._save_state_sql(conn, flow_uuid, method_name, state_dict)
|
||||
self._save_state_sql(conn, flow_uuid, method_name, state_dict, flow_class)
|
||||
|
||||
def load_state(self, flow_uuid: str) -> dict[str, Any] | None:
|
||||
"""Load the most recent state for a given flow UUID.
|
||||
@@ -206,6 +227,38 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
return result if isinstance(result, dict) else None
|
||||
return None
|
||||
|
||||
def load_latest_by_class(self, flow_class: str) -> dict[str, Any] | None:
|
||||
"""Load the most recent state for a given flow class name.
|
||||
|
||||
This enables automatic state restoration when @persist is applied at
|
||||
the class level. The most recent state entry for the given flow class
|
||||
is returned, allowing new flow instances to seamlessly continue from
|
||||
where the previous run left off.
|
||||
|
||||
Args:
|
||||
flow_class: The name of the flow class
|
||||
|
||||
Returns:
|
||||
The most recent state as a dictionary, or None if no state exists
|
||||
"""
|
||||
with sqlite3.connect(self.db_path, timeout=30) as conn:
|
||||
cursor = conn.execute(
|
||||
"""
|
||||
SELECT state_json
|
||||
FROM flow_states
|
||||
WHERE flow_class = ?
|
||||
ORDER BY id DESC
|
||||
LIMIT 1
|
||||
""",
|
||||
(flow_class,),
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
|
||||
if row:
|
||||
result = json.loads(row[0])
|
||||
return result if isinstance(result, dict) else None
|
||||
return None
|
||||
|
||||
def save_pending_feedback(
|
||||
self,
|
||||
flow_uuid: str,
|
||||
|
||||
@@ -208,6 +208,181 @@ def test_persist_decorator_verbose_logging(tmp_path, caplog):
|
||||
assert "Saving flow state" in caplog.text
|
||||
|
||||
|
||||
def test_class_level_persist_auto_restores_state(tmp_path):
|
||||
"""Test that class-level @persist automatically restores state on new instance.
|
||||
|
||||
This is the documented behavior from the PersistentCounterFlow example:
|
||||
when @persist is applied at the class level, a new flow instance should
|
||||
automatically load the most recent persisted state for that flow class.
|
||||
"""
|
||||
db_path = os.path.join(tmp_path, "test_flows.db")
|
||||
persistence = SQLiteFlowPersistence(db_path)
|
||||
|
||||
class CounterState(FlowState):
|
||||
value: int = 0
|
||||
|
||||
@persist(persistence)
|
||||
class PersistentCounterFlow(Flow[CounterState]):
|
||||
@start()
|
||||
def increment(self):
|
||||
self.state.value += 1
|
||||
return self.state.value
|
||||
|
||||
@listen(increment)
|
||||
def double(self, value):
|
||||
self.state.value = value * 2
|
||||
return self.state.value
|
||||
|
||||
# First run: 0 -> increment to 1 -> double to 2
|
||||
flow1 = PersistentCounterFlow()
|
||||
result1 = flow1.kickoff()
|
||||
assert result1 == 2
|
||||
assert flow1.state.value == 2
|
||||
|
||||
# Second run: state auto-restored to 2 -> increment to 3 -> double to 6
|
||||
flow2 = PersistentCounterFlow()
|
||||
# State should be auto-restored before kickoff
|
||||
assert flow2.state.value == 2
|
||||
result2 = flow2.kickoff()
|
||||
assert result2 == 6
|
||||
assert flow2.state.value == 6
|
||||
|
||||
# Third run: state auto-restored to 6 -> increment to 7 -> double to 14
|
||||
flow3 = PersistentCounterFlow()
|
||||
assert flow3.state.value == 6
|
||||
result3 = flow3.kickoff()
|
||||
assert result3 == 14
|
||||
assert flow3.state.value == 14
|
||||
|
||||
|
||||
def test_class_level_persist_auto_restores_dict_state(tmp_path):
|
||||
"""Test auto-restore with unstructured (dict) state."""
|
||||
db_path = os.path.join(tmp_path, "test_flows.db")
|
||||
persistence = SQLiteFlowPersistence(db_path)
|
||||
|
||||
@persist(persistence)
|
||||
class DictCounterFlow(Flow):
|
||||
@start()
|
||||
def count(self):
|
||||
current = self.state.get("counter", 0)
|
||||
self.state["counter"] = current + 1
|
||||
return self.state["counter"]
|
||||
|
||||
# First run
|
||||
flow1 = DictCounterFlow()
|
||||
result1 = flow1.kickoff()
|
||||
assert result1 == 1
|
||||
assert flow1.state["counter"] == 1
|
||||
|
||||
# Second run: auto-restores state
|
||||
flow2 = DictCounterFlow()
|
||||
assert flow2.state["counter"] == 1
|
||||
result2 = flow2.kickoff()
|
||||
assert result2 == 2
|
||||
assert flow2.state["counter"] == 2
|
||||
|
||||
|
||||
def test_class_level_persist_different_classes_dont_interfere(tmp_path):
|
||||
"""Test that different flow classes with @persist don't interfere."""
|
||||
db_path = os.path.join(tmp_path, "test_flows.db")
|
||||
persistence = SQLiteFlowPersistence(db_path)
|
||||
|
||||
class SharedState(FlowState):
|
||||
value: int = 0
|
||||
|
||||
@persist(persistence)
|
||||
class FlowA(Flow[SharedState]):
|
||||
@start()
|
||||
def step(self):
|
||||
self.state.value += 10
|
||||
return self.state.value
|
||||
|
||||
@persist(persistence)
|
||||
class FlowB(Flow[SharedState]):
|
||||
@start()
|
||||
def step(self):
|
||||
self.state.value += 100
|
||||
return self.state.value
|
||||
|
||||
# Run FlowA
|
||||
flow_a1 = FlowA()
|
||||
flow_a1.kickoff()
|
||||
assert flow_a1.state.value == 10
|
||||
|
||||
# Run FlowB
|
||||
flow_b1 = FlowB()
|
||||
flow_b1.kickoff()
|
||||
assert flow_b1.state.value == 100
|
||||
|
||||
# New FlowA should restore FlowA's state (10), not FlowB's (100)
|
||||
flow_a2 = FlowA()
|
||||
assert flow_a2.state.value == 10
|
||||
|
||||
# New FlowB should restore FlowB's state (100), not FlowA's (10)
|
||||
flow_b2 = FlowB()
|
||||
assert flow_b2.state.value == 100
|
||||
|
||||
|
||||
def test_load_latest_by_class_returns_none_when_empty(tmp_path):
|
||||
"""Test that load_latest_by_class returns None for unknown classes."""
|
||||
db_path = os.path.join(tmp_path, "test_flows.db")
|
||||
persistence = SQLiteFlowPersistence(db_path)
|
||||
assert persistence.load_latest_by_class("NonExistentFlow") is None
|
||||
|
||||
|
||||
def test_save_state_stores_flow_class(tmp_path):
|
||||
"""Test that save_state stores the flow_class when provided."""
|
||||
db_path = os.path.join(tmp_path, "test_flows.db")
|
||||
persistence = SQLiteFlowPersistence(db_path)
|
||||
|
||||
persistence.save_state(
|
||||
flow_uuid="test-uuid",
|
||||
method_name="test_method",
|
||||
state_data={"id": "test-uuid", "value": 42},
|
||||
flow_class="MyTestFlow",
|
||||
)
|
||||
|
||||
# Should be retrievable by class name
|
||||
loaded = persistence.load_latest_by_class("MyTestFlow")
|
||||
assert loaded is not None
|
||||
assert loaded["value"] == 42
|
||||
assert loaded["id"] == "test-uuid"
|
||||
|
||||
# Should still be retrievable by UUID
|
||||
loaded_by_uuid = persistence.load_state("test-uuid")
|
||||
assert loaded_by_uuid is not None
|
||||
assert loaded_by_uuid["value"] == 42
|
||||
|
||||
|
||||
def test_method_level_persist_does_not_auto_restore(tmp_path):
|
||||
"""Test that method-level @persist does NOT auto-restore state.
|
||||
|
||||
Only class-level @persist should trigger auto-restore. Method-level
|
||||
@persist only saves state after method execution.
|
||||
"""
|
||||
db_path = os.path.join(tmp_path, "test_flows.db")
|
||||
persistence = SQLiteFlowPersistence(db_path)
|
||||
|
||||
class CounterState(FlowState):
|
||||
value: int = 0
|
||||
|
||||
class MethodLevelFlow(Flow[CounterState]):
|
||||
@start()
|
||||
@persist(persistence)
|
||||
def increment(self):
|
||||
self.state.value += 1
|
||||
return self.state.value
|
||||
|
||||
# First run
|
||||
flow1 = MethodLevelFlow(persistence=persistence)
|
||||
flow1.kickoff()
|
||||
assert flow1.state.value == 1
|
||||
|
||||
# Second run: no auto-restore since @persist is method-level
|
||||
flow2 = MethodLevelFlow(persistence=persistence)
|
||||
assert flow2.state.value == 0 # Default, not restored
|
||||
|
||||
|
||||
def test_persistence_with_base_model(tmp_path):
|
||||
db_path = os.path.join(tmp_path, "test_flows.db")
|
||||
persistence = SQLiteFlowPersistence(db_path)
|
||||
|
||||
Reference in New Issue
Block a user