Stateful flows (#1931)

* fix: ensure persisted state overrides class defaults

- Remove early return in Flow.__init__ to allow proper state initialization
- Add test_flow_default_override.py to verify state override behavior
- Fix issue where default values weren't being overridden by persisted state

Fixes the issue where persisted state values weren't properly overriding
class defaults when restarting a flow with a previously saved state ID.

Co-Authored-By: Joe Moura <joao@crewai.com>

* test: improve state restoration verification with has_set_count flag

Co-Authored-By: Joe Moura <joao@crewai.com>

* test: add has_set_count field to PoemState

Co-Authored-By: Joe Moura <joao@crewai.com>

* refactoring test

* fix: ensure persisted state overrides class defaults

- Remove early return in Flow.__init__ to allow proper state initialization
- Add test_flow_default_override.py to verify state override behavior
- Fix issue where default values weren't being overridden by persisted state

Fixes the issue where persisted state values weren't properly overriding
class defaults when restarting a flow with a previously saved state ID.

Co-Authored-By: Joe Moura <joao@crewai.com>

* test: improve state restoration verification with has_set_count flag

Co-Authored-By: Joe Moura <joao@crewai.com>

* test: add has_set_count field to PoemState

Co-Authored-By: Joe Moura <joao@crewai.com>

* refactoring test

* Fixing flow state

* fixing peristed stateful flows

* linter

* type fix

---------

Co-authored-by: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com>
Co-authored-by: Joe Moura <joao@crewai.com>
This commit is contained in:
João Moura
2025-01-20 13:30:09 -03:00
committed by GitHub
parent 3e4f112f39
commit ab2274caf0
9 changed files with 339 additions and 222 deletions

View File

@@ -0,0 +1,112 @@
"""Test that persisted state properly overrides default values."""
from crewai.flow.flow import Flow, FlowState, listen, start
from crewai.flow.persistence import persist
class PoemState(FlowState):
"""Test state model with default values that should be overridden."""
sentence_count: int = 1000 # Default that should be overridden
has_set_count: bool = False # Track whether we've set the count
poem_type: str = ""
def test_default_value_override():
"""Test that persisted state values override class defaults."""
@persist()
class PoemFlow(Flow[PoemState]):
initial_state = PoemState
@start()
def set_sentence_count(self):
if self.state.has_set_count and self.state.sentence_count == 2:
self.state.sentence_count = 3
elif self.state.has_set_count and self.state.sentence_count == 1000:
self.state.sentence_count = 1000
elif self.state.has_set_count and self.state.sentence_count == 5:
self.state.sentence_count = 5
else:
self.state.sentence_count = 2
self.state.has_set_count = True
# First run - should set sentence_count to 2
flow1 = PoemFlow()
flow1.kickoff()
original_uuid = flow1.state.id
assert flow1.state.sentence_count == 2
# Second run - should load sentence_count=2 instead of default 1000
flow2 = PoemFlow()
flow2.kickoff(inputs={"id": original_uuid})
assert flow2.state.sentence_count == 3 # Should load 2, not default 1000
# Fourth run - explicit override should work
flow3 = PoemFlow()
flow3.kickoff(inputs={
"id": original_uuid,
"has_set_count": True,
"sentence_count": 5, # Override persisted value
})
assert flow3.state.sentence_count == 5 # Should use override value
# Third run - should not load sentence_count=2 instead of default 1000
flow4 = PoemFlow()
flow4.kickoff(inputs={"has_set_count": True})
assert flow4.state.sentence_count == 1000 # Should load 1000, not 2
def test_multi_step_default_override():
"""Test default value override with multiple start methods."""
@persist()
class MultiStepPoemFlow(Flow[PoemState]):
initial_state = PoemState
@start()
def set_sentence_count(self):
print("Setting sentence count")
if not self.state.has_set_count:
self.state.sentence_count = 3
self.state.has_set_count = True
@listen(set_sentence_count)
def set_poem_type(self):
print("Setting poem type")
if self.state.sentence_count == 3:
self.state.poem_type = "haiku"
elif self.state.sentence_count == 5:
self.state.poem_type = "limerick"
else:
self.state.poem_type = "free_verse"
@listen(set_poem_type)
def finished(self):
print("finished")
# First run - should set both sentence count and poem type
flow1 = MultiStepPoemFlow()
flow1.kickoff()
original_uuid = flow1.state.id
assert flow1.state.sentence_count == 3
assert flow1.state.poem_type == "haiku"
# Second run - should load persisted state and update poem type
flow2 = MultiStepPoemFlow()
flow2.kickoff(inputs={
"id": original_uuid,
"sentence_count": 5
})
assert flow2.state.sentence_count == 5
assert flow2.state.poem_type == "limerick"
# Third run - new flow without persisted state should use defaults
flow3 = MultiStepPoemFlow()
flow3.kickoff(inputs={
"id": original_uuid
})
assert flow3.state.sentence_count == 5
assert flow3.state.poem_type == "limerick"

View File

@@ -1,12 +1,12 @@
"""Test flow state persistence functionality."""
import os
from typing import Dict, Optional
from typing import Dict
import pytest
from pydantic import BaseModel
from crewai.flow.flow import Flow, FlowState, start
from crewai.flow.flow import Flow, FlowState, listen, start
from crewai.flow.persistence import persist
from crewai.flow.persistence.sqlite import SQLiteFlowPersistence
@@ -73,13 +73,14 @@ def test_flow_state_restoration(tmp_path):
# First flow execution to create initial state
class RestorableFlow(Flow[TestState]):
initial_state = TestState
@start()
@persist(persistence)
def set_message(self):
self.state.message = "Original message"
self.state.counter = 42
if self.state.message == "":
self.state.message = "Original message"
if self.state.counter == 0:
self.state.counter = 42
# Create and persist initial state
flow1 = RestorableFlow(persistence=persistence)
@@ -87,11 +88,11 @@ def test_flow_state_restoration(tmp_path):
original_uuid = flow1.state.id
# Test case 1: Restore using restore_uuid with field override
flow2 = RestorableFlow(
persistence=persistence,
restore_uuid=original_uuid,
counter=43, # Override counter
)
flow2 = RestorableFlow(persistence=persistence)
flow2.kickoff(inputs={
"id": original_uuid,
"counter": 43
})
# Verify state restoration and selective field override
assert flow2.state.id == original_uuid
@@ -99,48 +100,17 @@ def test_flow_state_restoration(tmp_path):
assert flow2.state.counter == 43 # Overridden
# Test case 2: Restore using kwargs['id']
flow3 = RestorableFlow(
persistence=persistence,
id=original_uuid,
message="Updated message", # Override message
)
flow3 = RestorableFlow(persistence=persistence)
flow3.kickoff(inputs={
"id": original_uuid,
"message": "Updated message"
})
# Verify state restoration and selective field override
assert flow3.state.id == original_uuid
assert flow3.state.counter == 42 # Preserved
assert flow3.state.counter == 43 # Preserved
assert flow3.state.message == "Updated message" # Overridden
# Test case 3: Verify error on conflicting IDs
with pytest.raises(ValueError) as exc_info:
RestorableFlow(
persistence=persistence,
restore_uuid=original_uuid,
id="different-id", # Conflict with restore_uuid
)
assert "Conflicting IDs provided" in str(exc_info.value)
# Test case 4: Verify error on non-existent restore_uuid
with pytest.raises(ValueError) as exc_info:
RestorableFlow(
persistence=persistence,
restore_uuid="non-existent-uuid",
)
assert "No state found" in str(exc_info.value)
# Test case 5: Allow new state creation with kwargs['id']
new_uuid = "new-flow-id"
flow4 = RestorableFlow(
persistence=persistence,
id=new_uuid,
message="New message",
counter=100,
)
# Verify new state creation with provided ID
assert flow4.state.id == new_uuid
assert flow4.state.message == "New message"
assert flow4.state.counter == 100
def test_multiple_method_persistence(tmp_path):
"""Test state persistence across multiple method executions."""
@@ -148,48 +118,59 @@ def test_multiple_method_persistence(tmp_path):
persistence = SQLiteFlowPersistence(db_path)
class MultiStepFlow(Flow[TestState]):
initial_state = TestState
@start()
@persist(persistence)
def step_1(self):
self.state.counter = 1
self.state.message = "Step 1"
if self.state.counter == 1:
self.state.counter = 99999
self.state.message = "Step 99999"
else:
self.state.counter = 1
self.state.message = "Step 1"
@start()
@listen(step_1)
@persist(persistence)
def step_2(self):
self.state.counter = 2
self.state.message = "Step 2"
if self.state.counter == 1:
self.state.counter = 2
self.state.message = "Step 2"
flow = MultiStepFlow(persistence=persistence)
flow.kickoff()
flow2 = MultiStepFlow(persistence=persistence)
flow2.kickoff(inputs={"id": flow.state.id})
# Load final state
final_state = persistence.load_state(flow.state.id)
final_state = flow2.state
assert final_state is not None
assert final_state["counter"] == 2
assert final_state["message"] == "Step 2"
def test_persistence_error_handling(tmp_path):
"""Test error handling in persistence operations."""
db_path = os.path.join(tmp_path, "test_flows.db")
persistence = SQLiteFlowPersistence(db_path)
class InvalidFlow(Flow[TestState]):
# Missing id field in initial state
class InvalidState(BaseModel):
value: str = ""
initial_state = InvalidState
assert final_state.counter == 2
assert final_state.message == "Step 2"
class NoPersistenceMultiStepFlow(Flow[TestState]):
@start()
@persist(persistence)
def will_fail(self):
self.state.value = "test"
def step_1(self):
if self.state.counter == 1:
self.state.counter = 99999
self.state.message = "Step 99999"
else:
self.state.counter = 1
self.state.message = "Step 1"
with pytest.raises(ValueError) as exc_info:
flow = InvalidFlow(persistence=persistence)
@listen(step_1)
def step_2(self):
if self.state.counter == 1:
self.state.counter = 2
self.state.message = "Step 2"
assert "must have an 'id' field" in str(exc_info.value)
flow = NoPersistenceMultiStepFlow(persistence=persistence)
flow.kickoff()
flow2 = NoPersistenceMultiStepFlow(persistence=persistence)
flow2.kickoff(inputs={"id": flow.state.id})
# Load final state
final_state = flow2.state
assert final_state.counter == 99999
assert final_state.message == "Step 99999"