mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-07 18:19:00 +00:00
When @persist is applied at the class level, new flow instances now automatically load the most recent persisted state for that flow class. This matches the documented behavior where creating a new instance of a @persist-decorated flow seamlessly continues from the previous run's state. Changes: - Add load_latest_by_class() to FlowPersistence base class - Implement load_latest_by_class() in SQLiteFlowPersistence with flow_class column and index for efficient lookups - Store flow class name when persisting state via PersistenceDecorator - Auto-restore latest state in class-level @persist decorator's __init__ - Add migration for existing databases (ALTER TABLE ADD COLUMN) Fixes #5378 Co-Authored-By: João <joao@crewai.com>
426 lines
13 KiB
Python
426 lines
13 KiB
Python
"""Test flow state persistence functionality."""
|
|
|
|
import os
|
|
from typing import Dict, List
|
|
|
|
from crewai.flow.flow import Flow, FlowState, listen, start
|
|
from crewai.flow.persistence import persist
|
|
from crewai.flow.persistence.sqlite import SQLiteFlowPersistence
|
|
from pydantic import BaseModel
|
|
|
|
|
|
class TestState(FlowState):
    """Structured flow state shared by several persistence tests.

    The ``id`` used as the persistence key is inherited from ``FlowState``;
    this model only adds the two fields the tests mutate.
    """

    # Opt out of pytest collection: the class name matches the default
    # ``Test*`` pattern, and because the model has its own ``__init__``
    # pytest would otherwise emit a PytestCollectionWarning for it.
    __test__ = False

    counter: int = 0
    message: str = ""
|
|
|
|
|
|
def test_persist_decorator_saves_state(tmp_path, caplog):
    """A method-level @persist writes a dict-based flow state to SQLite."""
    # NOTE(review): ``caplog`` is requested but not used by this test.
    store = SQLiteFlowPersistence(os.path.join(tmp_path, "test_flows.db"))

    class TestFlow(Flow[Dict[str, str]]):
        # A plain dict instance serves as the unstructured initial state.
        initial_state = dict()

        @start()
        @persist(store)
        def init_step(self):
            self.state["message"] = "Hello, World!"
            # Explicit id so the saved row can be looked up afterwards.
            self.state["id"] = "test-uuid"

    flow = TestFlow(persistence=store)
    flow.kickoff()

    # Reading back by the flow's id must yield the message set in init_step.
    record = store.load_state(flow.state["id"])
    assert record is not None
    assert record["message"] == "Hello, World!"
|
|
|
|
|
|
def test_structured_state_persistence(tmp_path):
    """A method-level @persist stores a Pydantic (TestState) flow state."""
    store = SQLiteFlowPersistence(os.path.join(tmp_path, "test_flows.db"))

    class StructuredFlow(Flow[TestState]):
        initial_state = TestState

        @start()
        @persist(store)
        def count_up(self):
            self.state.counter += 1
            self.state.message = f"Count is {self.state.counter}"

    flow = StructuredFlow(persistence=store)
    flow.kickoff()

    # The persisted row mirrors the state after exactly one increment.
    record = store.load_state(flow.state.id)
    assert record is not None
    assert (record["counter"], record["message"]) == (1, "Count is 1")
|
|
|
|
|
|
def test_flow_state_restoration(tmp_path):
    """Restore a persisted flow state via ``kickoff(inputs={"id": ...})``.

    Both follow-up runs pass the original state's id in ``inputs`` together
    with one other field: the restored state must keep the persisted values
    and apply only the explicitly supplied override.
    """
    db_path = os.path.join(tmp_path, "test_flows.db")
    persistence = SQLiteFlowPersistence(db_path)

    class RestorableFlow(Flow[TestState]):
        @start()
        @persist(persistence)
        def set_message(self):
            # Only fill in defaults, so restored or overridden values survive.
            if self.state.message == "":
                self.state.message = "Original message"
            if self.state.counter == 0:
                self.state.counter = 42

    # Initial run: persists message="Original message", counter=42.
    flow1 = RestorableFlow(persistence=persistence)
    flow1.kickoff()
    original_uuid = flow1.state.id

    # Case 1: restore by id and override ``counter``.
    flow2 = RestorableFlow(persistence=persistence)
    flow2.kickoff(inputs={"id": original_uuid, "counter": 43})

    assert flow2.state.id == original_uuid
    assert flow2.state.message == "Original message"  # restored from storage
    assert flow2.state.counter == 43  # overridden by inputs

    # Case 2: restore by id again and override ``message``. The counter of 43
    # comes from the state persisted by flow2's run, not from the first run.
    flow3 = RestorableFlow(persistence=persistence)
    flow3.kickoff(inputs={"id": original_uuid, "message": "Updated message"})

    assert flow3.state.id == original_uuid
    assert flow3.state.counter == 43  # restored from storage
    assert flow3.state.message == "Updated message"  # overridden by inputs
|
|
|
|
|
|
def test_multiple_method_persistence(tmp_path):
    """Persisted state is restored across runs of a multi-step flow.

    Two scenarios share one database:
    * both steps carry @persist, so a resumed run starts from step_2's state;
    * only step_1 carries @persist, so a resumed run starts from step_1's
      state and takes the alternative branch in step_1.
    """
    store = SQLiteFlowPersistence(os.path.join(tmp_path, "test_flows.db"))

    # --- Scenario 1: every step is decorated with @persist ----------------
    class MultiStepFlow(Flow[TestState]):
        @start()
        @persist(store)
        def step_1(self):
            if self.state.counter != 1:
                self.state.counter = 1
                self.state.message = "Step 1"
            else:
                self.state.counter = 99999
                self.state.message = "Step 99999"

        @listen(step_1)
        @persist(store)
        def step_2(self):
            if self.state.counter == 1:
                self.state.counter = 2
                self.state.message = "Step 2"

    first = MultiStepFlow(persistence=store)
    first.kickoff()

    resumed = MultiStepFlow(persistence=store)
    resumed.kickoff(inputs={"id": first.state.id})

    # Resuming from counter=2: step_1 resets it to 1, step_2 advances to 2.
    outcome = resumed.state
    assert outcome is not None
    assert outcome.counter == 2
    assert outcome.message == "Step 2"

    # --- Scenario 2: only step_1 is decorated with @persist ---------------
    class NoPersistenceMultiStepFlow(Flow[TestState]):
        @start()
        @persist(store)
        def step_1(self):
            if self.state.counter != 1:
                self.state.counter = 1
                self.state.message = "Step 1"
            else:
                self.state.counter = 99999
                self.state.message = "Step 99999"

        @listen(step_1)
        def step_2(self):
            # No @persist here; the assertions below expect this update to be
            # absent from the state restored by the next run.
            if self.state.counter == 1:
                self.state.counter = 2
                self.state.message = "Step 2"

    first = NoPersistenceMultiStepFlow(persistence=store)
    first.kickoff()

    resumed = NoPersistenceMultiStepFlow(persistence=store)
    resumed.kickoff(inputs={"id": first.state.id})

    # Resuming from step_1's snapshot (counter=1) sends step_1 down the 99999
    # branch, which step_2 then leaves untouched.
    outcome = resumed.state
    assert outcome.counter == 99999
    assert outcome.message == "Step 99999"
|
|
|
|
|
|
def test_persist_decorator_verbose_logging(tmp_path, caplog):
    """@persist emits "Saving flow state" only when ``verbose=True``."""
    # Capture INFO records so a verbose save message would be visible.
    caplog.set_level("INFO")

    store = SQLiteFlowPersistence(os.path.join(tmp_path, "test_flows.db"))

    # verbose is left at its default (False): no save message expected.
    class QuietFlow(Flow[Dict[str, str]]):
        initial_state = dict()

        @start()
        @persist(store)
        def init_step(self):
            self.state["message"] = "Hello, World!"
            self.state["id"] = "test-uuid-1"

    quiet = QuietFlow(persistence=store)
    quiet.kickoff()
    assert "Saving flow state" not in caplog.text

    # Start the verbose case from an empty log buffer.
    caplog.clear()

    class VerboseFlow(Flow[Dict[str, str]]):
        initial_state = dict()

        @start()
        @persist(store, verbose=True)
        def init_step(self):
            self.state["message"] = "Hello, World!"
            self.state["id"] = "test-uuid-2"

    loud = VerboseFlow(persistence=store)
    loud.kickoff()
    assert "Saving flow state" in caplog.text
|
|
|
|
|
|
def test_class_level_persist_auto_restores_state(tmp_path):
    """Class-level @persist makes each new instance resume the latest state.

    Mirrors the documented PersistentCounterFlow example: a new instance of a
    class decorated with @persist loads the most recently persisted state for
    that class before ``kickoff`` is called.
    """
    store = SQLiteFlowPersistence(os.path.join(tmp_path, "test_flows.db"))

    class CounterState(FlowState):
        value: int = 0

    @persist(store)
    class PersistentCounterFlow(Flow[CounterState]):
        @start()
        def increment(self):
            self.state.value += 1
            return self.state.value

        @listen(increment)
        def double(self, value):
            self.state.value = value * 2
            return self.state.value

    # Fresh database: 0 -> increment to 1 -> double to 2.
    first = PersistentCounterFlow()
    result = first.kickoff()
    assert result == 2
    assert first.state.value == 2

    # Every later instance starts from the previous run's final value:
    # 2 -> 3 -> 6, then 6 -> 7 -> 14.
    for restored, expected in ((2, 6), (6, 14)):
        flow = PersistentCounterFlow()
        assert flow.state.value == restored  # restored before kickoff
        result = flow.kickoff()
        assert result == expected
        assert flow.state.value == expected
|
|
|
|
|
|
def test_class_level_persist_auto_restores_dict_state(tmp_path):
    """Class-level @persist also auto-restores unstructured (dict) state."""
    store = SQLiteFlowPersistence(os.path.join(tmp_path, "test_flows.db"))

    @persist(store)
    class DictCounterFlow(Flow):
        @start()
        def count(self):
            self.state["counter"] = self.state.get("counter", 0) + 1
            return self.state["counter"]

    # Nothing persisted yet, so the counter starts from its fallback of 0.
    first = DictCounterFlow()
    outcome = first.kickoff()
    assert outcome == 1
    assert first.state["counter"] == 1

    # A new instance picks up counter=1 before running and bumps it to 2.
    second = DictCounterFlow()
    assert second.state["counter"] == 1
    outcome = second.kickoff()
    assert outcome == 2
    assert second.state["counter"] == 2
|
|
|
|
|
|
def test_class_level_persist_different_classes_dont_interfere(tmp_path):
    """Auto-restore is scoped per flow class, even when classes share a DB."""
    store = SQLiteFlowPersistence(os.path.join(tmp_path, "test_flows.db"))

    class SharedState(FlowState):
        value: int = 0

    @persist(store)
    class FlowA(Flow[SharedState]):
        @start()
        def step(self):
            self.state.value += 10
            return self.state.value

    @persist(store)
    class FlowB(Flow[SharedState]):
        @start()
        def step(self):
            self.state.value += 100
            return self.state.value

    expectations = ((FlowA, 10), (FlowB, 100))

    # One run per class (FlowA first, then FlowB), each starting from 0.
    for flow_cls, expected in expectations:
        run = flow_cls()
        run.kickoff()
        assert run.state.value == expected

    # A new instance of each class restores its own class's latest state,
    # never the other class's.
    for flow_cls, expected in expectations:
        fresh = flow_cls()
        assert fresh.state.value == expected
|
|
|
|
|
|
def test_load_latest_by_class_returns_none_when_empty(tmp_path):
    """load_latest_by_class yields None for a class with no saved state."""
    store = SQLiteFlowPersistence(os.path.join(tmp_path, "test_flows.db"))
    missing = store.load_latest_by_class("NonExistentFlow")
    assert missing is None
|
|
|
|
|
|
def test_save_state_stores_flow_class(tmp_path):
    """save_state records ``flow_class`` without breaking lookup by UUID."""
    store = SQLiteFlowPersistence(os.path.join(tmp_path, "test_flows.db"))
    payload = {"id": "test-uuid", "value": 42}

    store.save_state(
        flow_uuid="test-uuid",
        method_name="test_method",
        state_data=payload,
        flow_class="MyTestFlow",
    )

    # Lookup by class name returns the saved snapshot.
    by_class = store.load_latest_by_class("MyTestFlow")
    assert by_class is not None
    assert by_class["value"] == 42
    assert by_class["id"] == "test-uuid"

    # The lookup by flow UUID keeps working alongside it.
    by_uuid = store.load_state("test-uuid")
    assert by_uuid is not None
    assert by_uuid["value"] == 42
|
|
|
|
|
|
def test_method_level_persist_does_not_auto_restore(tmp_path):
    """Only class-level @persist auto-restores; method-level @persist does not.

    A method-level @persist saves state after the method runs, but a new
    instance of the flow still begins from the default state.
    """
    store = SQLiteFlowPersistence(os.path.join(tmp_path, "test_flows.db"))

    class CounterState(FlowState):
        value: int = 0

    class MethodLevelFlow(Flow[CounterState]):
        @start()
        @persist(store)
        def increment(self):
            self.state.value += 1
            return self.state.value

    first = MethodLevelFlow(persistence=store)
    first.kickoff()
    assert first.state.value == 1

    # A state was saved above, yet the new instance is not seeded with it,
    # because @persist decorates a method rather than the class.
    second = MethodLevelFlow(persistence=store)
    assert second.state.value == 0
|
|
|
|
|
|
def test_persistence_with_base_model(tmp_path):
    """A class-level @persist flow whose state nests Pydantic models.

    The state holds an optional ``Message`` plus a list of them. After
    kickoff, the message fields must be readable on the flow's state and the
    state must unwrap to the declared ``State`` model.
    """
    db_path = os.path.join(tmp_path, "test_flows.db")
    persistence = SQLiteFlowPersistence(db_path)

    class Message(BaseModel):
        role: str
        type: str
        content: str

    class State(FlowState):
        latest_message: Message | None = None
        history: List[Message] = []

    @persist(persistence)
    class BaseModelFlow(Flow[State]):
        initial_state = State(latest_message=None, history=[])

        @start()
        def init_step(self):
            self.state.latest_message = Message(
                role="user", type="text", content="Hello, World!"
            )
            self.state.history.append(self.state.latest_message)

    flow = BaseModelFlow(persistence=persistence)
    flow.kickoff()

    latest_message = flow.state.latest_message
    assert latest_message is not None
    assert latest_message.role == "user"
    assert latest_message.type == "text"
    assert latest_message.content == "Hello, World!"

    # Check the length before indexing, so a wrong count fails with a clear
    # assertion instead of an unpacking error that hides the actual length.
    assert len(flow.state.history) == 1
    message = flow.state.history[0]
    assert message.role == "user"
    assert message.type == "text"
    assert message.content == "Hello, World!"

    assert isinstance(flow.state._unwrap(), State)
|