Files
crewAI/lib/crewai/tests/test_flow_persistence.py
Devin AI a24b011f6f fix: auto-restore persisted state for class-level @persist decorator
When @persist is applied at the class level, new flow instances now
automatically load the most recent persisted state for that flow class.
This matches the documented behavior where creating a new instance of a
@persist-decorated flow seamlessly continues from the previous run's state.

Changes:
- Add load_latest_by_class() to FlowPersistence base class
- Implement load_latest_by_class() in SQLiteFlowPersistence with
  flow_class column and index for efficient lookups
- Store flow class name when persisting state via PersistenceDecorator
- Auto-restore latest state in class-level @persist decorator's __init__
- Add migration for existing databases (ALTER TABLE ADD COLUMN)

Fixes #5378

Co-Authored-By: João <joao@crewai.com>
2026-04-09 06:32:49 +00:00

426 lines
13 KiB
Python

"""Test flow state persistence functionality."""
import os
from typing import Dict, List
from crewai.flow.flow import Flow, FlowState, listen, start
from crewai.flow.persistence import persist
from crewai.flow.persistence.sqlite import SQLiteFlowPersistence
from pydantic import BaseModel
class TestState(FlowState):
    """Structured flow state shared by several tests in this module.

    ``FlowState`` supplies the ``id`` field that the tests use as the
    persistence key; this model adds two simple fields the flows mutate.
    """

    # Opt out of pytest collection: the ``Test`` name prefix would otherwise
    # make pytest try to collect this pydantic model as a test class and emit
    # a PytestCollectionWarning because the model defines ``__init__``.
    # Pydantic ignores dunder class attributes, so this does not add a field.
    __test__ = False

    counter: int = 0
    message: str = ""
def test_persist_decorator_saves_state(tmp_path, caplog):
    """Method-level @persist writes the dict state to SQLite after the step."""
    store = SQLiteFlowPersistence(os.path.join(tmp_path, "test_flows.db"))

    class TestFlow(Flow[Dict[str, str]]):
        # A plain dict instance serves as the unstructured initial state.
        initial_state = dict()

        @start()
        @persist(store)
        def init_step(self):
            self.state["message"] = "Hello, World!"
            # Pin the id so the persisted record can be looked up afterwards.
            self.state["id"] = "test-uuid"

    flow = TestFlow(persistence=store)
    flow.kickoff()

    # Read the record back from the database using the flow's id.
    record = store.load_state(flow.state["id"])
    assert record is not None
    assert record["message"] == "Hello, World!"
def test_structured_state_persistence(tmp_path):
    """Method-level @persist stores a Pydantic (structured) flow state."""
    store = SQLiteFlowPersistence(os.path.join(tmp_path, "test_flows.db"))

    class StructuredFlow(Flow[TestState]):
        initial_state = TestState

        @start()
        @persist(store)
        def count_up(self):
            self.state.counter += 1
            self.state.message = f"Count is {self.state.counter}"

    flow = StructuredFlow(persistence=store)
    flow.kickoff()

    # The saved record must reflect the state after count_up ran once.
    record = store.load_state(flow.state.id)
    assert record is not None
    assert record["counter"] == 1
    assert record["message"] == "Count is 1"
def test_flow_state_restoration(tmp_path):
    """Restore a persisted state by passing its id in ``kickoff`` inputs.

    Any other keys supplied in ``inputs`` override the corresponding restored
    fields, while fields that are not supplied keep their persisted values.
    """
    db_path = os.path.join(tmp_path, "test_flows.db")
    persistence = SQLiteFlowPersistence(db_path)

    class RestorableFlow(Flow[TestState]):
        @start()
        @persist(persistence)
        def set_message(self):
            # Only fill in defaults, so restored or overridden values survive.
            if self.state.message == "":
                self.state.message = "Original message"
            if self.state.counter == 0:
                self.state.counter = 42

    # First execution creates and persists the initial state.
    flow1 = RestorableFlow(persistence=persistence)
    flow1.kickoff()
    original_uuid = flow1.state.id

    # Case 1: restore via inputs["id"] and override ``counter``.
    flow2 = RestorableFlow(persistence=persistence)
    flow2.kickoff(inputs={"id": original_uuid, "counter": 43})
    assert flow2.state.id == original_uuid
    assert flow2.state.message == "Original message"  # preserved
    assert flow2.state.counter == 43  # overridden

    # Case 2: restore again via inputs["id"] (now picking up case 1's saved
    # state) and override ``message``.
    flow3 = RestorableFlow(persistence=persistence)
    flow3.kickoff(inputs={"id": original_uuid, "message": "Updated message"})
    assert flow3.state.id == original_uuid
    assert flow3.state.counter == 43  # preserved from case 1
    assert flow3.state.message == "Updated message"  # overridden

    # The latest persisted record for this id reflects both overrides.
    latest = persistence.load_state(original_uuid)
    assert latest is not None
    assert latest["counter"] == 43
    assert latest["message"] == "Updated message"
def test_multiple_method_persistence(tmp_path):
    """Only steps decorated with @persist contribute to the restored state.

    In the first scenario both steps persist, so a resumed run starts from the
    previous run's final state. In the second scenario only step_1 persists,
    so a resumed run starts from step_1's output and takes the 99999 branch.
    """
    store = SQLiteFlowPersistence(os.path.join(tmp_path, "test_flows.db"))

    # Scenario 1: both steps persist their results.
    class MultiStepFlow(Flow[TestState]):
        @start()
        @persist(store)
        def step_1(self):
            if self.state.counter != 1:
                self.state.counter = 1
                self.state.message = "Step 1"
            else:
                self.state.counter = 99999
                self.state.message = "Step 99999"

        @listen(step_1)
        @persist(store)
        def step_2(self):
            if self.state.counter == 1:
                self.state.counter = 2
                self.state.message = "Step 2"

    first_run = MultiStepFlow(persistence=store)
    first_run.kickoff()
    resumed = MultiStepFlow(persistence=store)
    resumed.kickoff(inputs={"id": first_run.state.id})

    # Restored counter is 2, so step_1 resets it to 1 and step_2 bumps it to 2.
    outcome = resumed.state
    assert outcome is not None
    assert outcome.counter == 2
    assert outcome.message == "Step 2"

    # Scenario 2: step_2 is NOT persisted.
    class NoPersistenceMultiStepFlow(Flow[TestState]):
        @start()
        @persist(store)
        def step_1(self):
            if self.state.counter != 1:
                self.state.counter = 1
                self.state.message = "Step 1"
            else:
                self.state.counter = 99999
                self.state.message = "Step 99999"

        @listen(step_1)
        def step_2(self):
            if self.state.counter == 1:
                self.state.counter = 2
                self.state.message = "Step 2"

    first_run = NoPersistenceMultiStepFlow(persistence=store)
    first_run.kickoff()
    resumed = NoPersistenceMultiStepFlow(persistence=store)
    resumed.kickoff(inputs={"id": first_run.state.id})

    # The restored counter is step_1's value (1), so step_1 jumps to 99999 and
    # step_2's guard no longer matches.
    outcome = resumed.state
    assert outcome.counter == 99999
    assert outcome.message == "Step 99999"
def test_persist_decorator_verbose_logging(tmp_path, caplog):
    """The ``verbose`` flag of @persist controls the "Saving flow state" log."""
    # Capture records at INFO level and above.
    caplog.set_level("INFO")
    store = SQLiteFlowPersistence(os.path.join(tmp_path, "test_flows.db"))

    # verbose left at its default: no save message is expected.
    class QuietFlow(Flow[Dict[str, str]]):
        initial_state = dict()

        @start()
        @persist(store)
        def init_step(self):
            self.state["message"] = "Hello, World!"
            self.state["id"] = "test-uuid-1"

    quiet = QuietFlow(persistence=store)
    quiet.kickoff()
    assert "Saving flow state" not in caplog.text

    # Start the verbose check from an empty log.
    caplog.clear()

    # verbose=True: the save message must appear.
    class VerboseFlow(Flow[Dict[str, str]]):
        initial_state = dict()

        @start()
        @persist(store, verbose=True)
        def init_step(self):
            self.state["message"] = "Hello, World!"
            self.state["id"] = "test-uuid-2"

    chatty = VerboseFlow(persistence=store)
    chatty.kickoff()
    assert "Saving flow state" in caplog.text
def test_class_level_persist_auto_restores_state(tmp_path):
    """A class-level @persist makes each new instance resume from the last run.

    Mirrors the documented PersistentCounterFlow example: constructing a new
    instance of a flow class decorated with @persist loads the most recently
    persisted state for that class before ``kickoff`` is called.
    """
    store = SQLiteFlowPersistence(os.path.join(tmp_path, "test_flows.db"))

    class CounterState(FlowState):
        value: int = 0

    @persist(store)
    class PersistentCounterFlow(Flow[CounterState]):
        @start()
        def increment(self):
            self.state.value += 1
            return self.state.value

        @listen(increment)
        def double(self, value):
            self.state.value = value * 2
            return self.state.value

    # Fresh database: 0 -> increment to 1 -> double to 2.
    first = PersistentCounterFlow()
    assert first.kickoff() == 2
    assert first.state.value == 2

    # Each later run starts from the previous run's final value, restored
    # before kickoff: 2 -> 3 -> 6, then 6 -> 7 -> 14.
    for restored, expected in ((2, 6), (6, 14)):
        flow = PersistentCounterFlow()
        assert flow.state.value == restored
        assert flow.kickoff() == expected
        assert flow.state.value == expected
def test_class_level_persist_auto_restores_dict_state(tmp_path):
    """Class-level @persist also auto-restores unstructured (dict) state."""
    store = SQLiteFlowPersistence(os.path.join(tmp_path, "test_flows.db"))

    @persist(store)
    class DictCounterFlow(Flow):
        @start()
        def count(self):
            self.state["counter"] = self.state.get("counter", 0) + 1
            return self.state["counter"]

    # First run starts from an empty dict.
    first = DictCounterFlow()
    assert first.kickoff() == 1
    assert first.state["counter"] == 1

    # Second run sees the restored counter before kickoff, then increments it.
    second = DictCounterFlow()
    assert second.state["counter"] == 1
    assert second.kickoff() == 2
    assert second.state["counter"] == 2
def test_class_level_persist_different_classes_dont_interfere(tmp_path):
    """Auto-restore is scoped to the flow class, even with a shared database."""
    store = SQLiteFlowPersistence(os.path.join(tmp_path, "test_flows.db"))

    class SharedState(FlowState):
        value: int = 0

    @persist(store)
    class FlowA(Flow[SharedState]):
        @start()
        def step(self):
            self.state.value += 10
            return self.state.value

    @persist(store)
    class FlowB(Flow[SharedState]):
        @start()
        def step(self):
            self.state.value += 100
            return self.state.value

    # Run FlowA, then FlowB; each starts from the default value of 0.
    for flow_cls, expected in ((FlowA, 10), (FlowB, 100)):
        instance = flow_cls()
        instance.kickoff()
        assert instance.state.value == expected

    # Fresh instances must restore their own class's state, not the other's.
    assert FlowA().state.value == 10
    assert FlowB().state.value == 100
def test_load_latest_by_class_returns_none_when_empty(tmp_path):
    """load_latest_by_class returns None when no state exists for the class.

    Covers both an empty database and a database that only holds state
    persisted for a different flow class.
    """
    db_path = os.path.join(tmp_path, "test_flows.db")
    persistence = SQLiteFlowPersistence(db_path)

    # Empty database.
    assert persistence.load_latest_by_class("NonExistentFlow") is None

    # State saved for another class must not be returned for an unknown one.
    persistence.save_state(
        flow_uuid="other-uuid",
        method_name="other_method",
        state_data={"id": "other-uuid", "value": 1},
        flow_class="OtherFlow",
    )
    assert persistence.load_latest_by_class("NonExistentFlow") is None
def test_save_state_stores_flow_class(tmp_path):
    """save_state records flow_class so the state can be found by class name."""
    store = SQLiteFlowPersistence(os.path.join(tmp_path, "test_flows.db"))
    store.save_state(
        flow_uuid="test-uuid",
        method_name="test_method",
        state_data={"id": "test-uuid", "value": 42},
        flow_class="MyTestFlow",
    )

    # Lookup by class name.
    by_class = store.load_latest_by_class("MyTestFlow")
    assert by_class is not None
    assert by_class["value"] == 42
    assert by_class["id"] == "test-uuid"

    # Lookup by uuid must keep working alongside the class-name lookup.
    by_uuid = store.load_state("test-uuid")
    assert by_uuid is not None
    assert by_uuid["value"] == 42
def test_method_level_persist_does_not_auto_restore(tmp_path):
    """Method-level @persist saves state but does NOT auto-restore it.

    Only a class-level @persist triggers auto-restore on construction; a
    method-level @persist only saves state after the decorated method runs.
    """
    db_path = os.path.join(tmp_path, "test_flows.db")
    persistence = SQLiteFlowPersistence(db_path)

    class CounterState(FlowState):
        value: int = 0

    class MethodLevelFlow(Flow[CounterState]):
        @start()
        @persist(persistence)
        def increment(self):
            self.state.value += 1
            return self.state.value

    # First run.
    flow1 = MethodLevelFlow(persistence=persistence)
    flow1.kickoff()
    assert flow1.state.value == 1

    # Confirm there *is* persisted state that could be restored; without this
    # the check below would also pass if nothing had been saved at all.
    saved = persistence.load_state(flow1.state.id)
    assert saved is not None
    assert saved["value"] == 1

    # Second run: a new instance starts from the default, not the saved value.
    flow2 = MethodLevelFlow(persistence=persistence)
    assert flow2.state.value == 0
def test_persistence_with_base_model(tmp_path):
    """Nested Pydantic models in state stay intact through a @persist run."""
    store = SQLiteFlowPersistence(os.path.join(tmp_path, "test_flows.db"))

    class Message(BaseModel):
        role: str
        type: str
        content: str

    class State(FlowState):
        latest_message: Message | None = None
        history: List[Message] = []

    @persist(store)
    class BaseModelFlow(Flow[State]):
        initial_state = State(latest_message=None, history=[])

        @start()
        def init_step(self):
            self.state.latest_message = Message(
                role="user", type="text", content="Hello, World!"
            )
            self.state.history.append(self.state.latest_message)

    flow = BaseModelFlow(persistence=store)
    flow.kickoff()

    latest = flow.state.latest_message
    # Single-element unpacking also fails unless exactly one entry exists.
    (recorded,) = flow.state.history

    expected = {"role": "user", "type": "text", "content": "Hello, World!"}
    assert latest is not None
    assert {field: getattr(latest, field) for field in expected} == expected
    assert len(flow.state.history) == 1
    assert {field: getattr(recorded, field) for field in expected} == expected
    # The proxied state must still unwrap to the declared State model.
    assert isinstance(flow.state._unwrap(), State)