Files
crewAI/tests/test_flow_persistence.py
João Moura ab2274caf0 Stateful flows (#1931)
* fix: ensure persisted state overrides class defaults

- Remove early return in Flow.__init__ to allow proper state initialization
- Add test_flow_default_override.py to verify state override behavior
- Fix issue where default values weren't being overridden by persisted state

Fixes the issue where persisted state values weren't properly overriding
class defaults when restarting a flow with a previously saved state ID.

Co-Authored-By: Joe Moura <joao@crewai.com>

* test: improve state restoration verification with has_set_count flag

Co-Authored-By: Joe Moura <joao@crewai.com>

* test: add has_set_count field to PoemState

Co-Authored-By: Joe Moura <joao@crewai.com>

* refactoring test

* fix: ensure persisted state overrides class defaults

- Remove early return in Flow.__init__ to allow proper state initialization
- Add test_flow_default_override.py to verify state override behavior
- Fix issue where default values weren't being overridden by persisted state

Fixes the issue where persisted state values weren't properly overriding
class defaults when restarting a flow with a previously saved state ID.

Co-Authored-By: Joe Moura <joao@crewai.com>

* test: improve state restoration verification with has_set_count flag

Co-Authored-By: Joe Moura <joao@crewai.com>

* test: add has_set_count field to PoemState

Co-Authored-By: Joe Moura <joao@crewai.com>

* refactoring test

* Fixing flow state

* fixing persisted stateful flows

* linter

* type fix

---------

Co-authored-by: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com>
Co-authored-by: Joe Moura <joao@crewai.com>
2025-01-20 13:30:09 -03:00

177 lines
5.5 KiB
Python

"""Test flow state persistence functionality."""
import os
from typing import Dict
import pytest
from pydantic import BaseModel
from crewai.flow.flow import Flow, FlowState, listen, start
from crewai.flow.persistence import persist
from crewai.flow.persistence.sqlite import SQLiteFlowPersistence
class TestState(FlowState):
    """Structured state shared by the persistence tests in this module.

    Only ``counter`` and ``message`` are declared here; the ``id`` attribute
    that the tests read from the state is not defined in this class and is
    therefore supplied by the ``FlowState`` base.
    """

    # Numeric field the flows mutate; starts at zero.
    counter: int = 0
    # Text field the flows mutate; starts empty.
    message: str = ""
def test_persist_decorator_saves_state(tmp_path):
    """Verify a ``@persist``-decorated step stores dict-based state in SQLite."""
    store = SQLiteFlowPersistence(str(tmp_path / "test_flows.db"))

    class TestFlow(Flow[Dict[str, str]]):
        # Unstructured state: start from an empty dict instance.
        initial_state = {}

        @start()
        @persist(store)
        def init_step(self):
            self.state["message"] = "Hello, World!"
            # Fix the id so the saved record can be looked up afterwards.
            self.state["id"] = "test-uuid"

    # Execute the flow once; the decorated step runs during kickoff.
    flow = TestFlow(persistence=store)
    flow.kickoff()

    # Read the record back from the database by the flow's state id.
    record = store.load_state(flow.state["id"])
    assert record is not None
    assert record["message"] == "Hello, World!"
def test_structured_state_persistence(tmp_path):
    """Verify persistence when the flow state is a Pydantic-style model."""
    store = SQLiteFlowPersistence(str(tmp_path / "test_flows.db"))

    class StructuredFlow(Flow[TestState]):
        initial_state = TestState

        @start()
        @persist(store)
        def count_up(self):
            # Bump the counter, then mirror the new value in the message.
            self.state.counter = self.state.counter + 1
            self.state.message = f"Count is {self.state.counter}"

    # One run of the flow should increment the counter from 0 to 1.
    flow = StructuredFlow(persistence=store)
    flow.kickoff()

    # The stored record is accessed by key, like a mapping.
    record = store.load_state(flow.state.id)
    assert record is not None
    assert record["counter"] == 1
    assert record["message"] == "Count is 1"
def test_flow_state_restoration(tmp_path):
    """Verify that a later run can resume saved state and override single fields.

    A first run saves state. Each following run passes the saved ``id`` through
    ``kickoff(inputs=...)`` together with exactly one field to override, and
    the test checks that the other field keeps its previously saved value.
    """
    store = SQLiteFlowPersistence(str(tmp_path / "test_flows.db"))

    class RestorableFlow(Flow[TestState]):
        @start()
        @persist(store)
        def set_message(self):
            # Only fill in fields that still hold their class defaults, so
            # restored or overridden values are left untouched.
            if self.state.message == "":
                self.state.message = "Original message"
            if self.state.counter == 0:
                self.state.counter = 42

    # Initial run: creates and saves the baseline state.
    baseline = RestorableFlow(persistence=store)
    baseline.kickoff()
    original_uuid = baseline.state.id

    # Case 1: resume by id and override only the counter.
    counter_override = {"id": original_uuid, "counter": 43}
    resumed_with_counter = RestorableFlow(persistence=store)
    resumed_with_counter.kickoff(inputs=counter_override)

    assert resumed_with_counter.state.id == original_uuid
    assert resumed_with_counter.state.message == "Original message"  # kept
    assert resumed_with_counter.state.counter == 43  # overridden

    # Case 2: resume by id again and override only the message.
    message_override = {"id": original_uuid, "message": "Updated message"}
    resumed_with_message = RestorableFlow(persistence=store)
    resumed_with_message.kickoff(inputs=message_override)

    assert resumed_with_message.state.id == original_uuid
    assert resumed_with_message.state.counter == 43  # kept from case 1
    assert resumed_with_message.state.message == "Updated message"  # overridden
def test_multiple_method_persistence(tmp_path):
    """Verify which step's state a resumed flow sees across multiple methods.

    Two flows with the same steps are each run once and then resumed by id.
    In the first flow both steps are decorated with ``@persist``; in the
    second only ``step_1`` is. The assertions expect the resumed run of the
    second flow to start from the state ``step_1`` saved, not from what the
    undecorated ``step_2`` left behind.
    """
    store = SQLiteFlowPersistence(str(tmp_path / "test_flows.db"))

    def run_then_resume(flow_cls):
        # Run a fresh flow, then start another one from the first one's id
        # and hand back the state of that second run.
        first = flow_cls(persistence=store)
        first.kickoff()
        resumed = flow_cls(persistence=store)
        resumed.kickoff(inputs={"id": first.state.id})
        return resumed.state

    class MultiStepFlow(Flow[TestState]):
        @start()
        @persist(store)
        def step_1(self):
            # A counter of exactly 1 on entry means step_1's own result was
            # carried into this run; flag that with a sentinel value.
            if self.state.counter != 1:
                self.state.counter = 1
                self.state.message = "Step 1"
            else:
                self.state.counter = 99999
                self.state.message = "Step 99999"

        @listen(step_1)
        @persist(store)
        def step_2(self):
            if self.state.counter == 1:
                self.state.counter = 2
                self.state.message = "Step 2"

    # Both steps persisted: the resumed run is expected to finish at step 2.
    final_state = run_then_resume(MultiStepFlow)
    assert final_state is not None
    assert final_state.counter == 2
    assert final_state.message == "Step 2"

    class NoPersistenceMultiStepFlow(Flow[TestState]):
        @start()
        @persist(store)
        def step_1(self):
            if self.state.counter != 1:
                self.state.counter = 1
                self.state.message = "Step 1"
            else:
                self.state.counter = 99999
                self.state.message = "Step 99999"

        # Deliberately not decorated with @persist.
        @listen(step_1)
        def step_2(self):
            if self.state.counter == 1:
                self.state.counter = 2
                self.state.message = "Step 2"

    # step_2 not persisted: the resumed run is expected to hit the sentinel
    # branch of step_1, after which step_2 leaves the state alone.
    final_state = run_then_resume(NoPersistenceMultiStepFlow)
    assert final_state.counter == 99999
    assert final_state.message == "Step 99999"