mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 00:28:31 +00:00
Stateful flows (#1931)
* fix: ensure persisted state overrides class defaults - Remove early return in Flow.__init__ to allow proper state initialization - Add test_flow_default_override.py to verify state override behavior - Fix issue where default values weren't being overridden by persisted state Fixes the issue where persisted state values weren't properly overriding class defaults when restarting a flow with a previously saved state ID. Co-Authored-By: Joe Moura <joao@crewai.com> * test: improve state restoration verification with has_set_count flag Co-Authored-By: Joe Moura <joao@crewai.com> * test: add has_set_count field to PoemState Co-Authored-By: Joe Moura <joao@crewai.com> * refactoring test * fix: ensure persisted state overrides class defaults - Remove early return in Flow.__init__ to allow proper state initialization - Add test_flow_default_override.py to verify state override behavior - Fix issue where default values weren't being overridden by persisted state Fixes the issue where persisted state values weren't properly overriding class defaults when restarting a flow with a previously saved state ID. Co-Authored-By: Joe Moura <joao@crewai.com> * test: improve state restoration verification with has_set_count flag Co-Authored-By: Joe Moura <joao@crewai.com> * test: add has_set_count field to PoemState Co-Authored-By: Joe Moura <joao@crewai.com> * refactoring test * Fixing flow state * fixing peristed stateful flows * linter * type fix --------- Co-authored-by: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Co-authored-by: Joe Moura <joao@crewai.com>
This commit is contained in:
@@ -1,12 +1,12 @@
|
||||
"""Test flow state persistence functionality."""
|
||||
|
||||
import os
|
||||
from typing import Dict, Optional
|
||||
from typing import Dict
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai.flow.flow import Flow, FlowState, start
|
||||
from crewai.flow.flow import Flow, FlowState, listen, start
|
||||
from crewai.flow.persistence import persist
|
||||
from crewai.flow.persistence.sqlite import SQLiteFlowPersistence
|
||||
|
||||
@@ -73,13 +73,14 @@ def test_flow_state_restoration(tmp_path):
|
||||
|
||||
# First flow execution to create initial state
|
||||
class RestorableFlow(Flow[TestState]):
|
||||
initial_state = TestState
|
||||
|
||||
@start()
|
||||
@persist(persistence)
|
||||
def set_message(self):
|
||||
self.state.message = "Original message"
|
||||
self.state.counter = 42
|
||||
if self.state.message == "":
|
||||
self.state.message = "Original message"
|
||||
if self.state.counter == 0:
|
||||
self.state.counter = 42
|
||||
|
||||
# Create and persist initial state
|
||||
flow1 = RestorableFlow(persistence=persistence)
|
||||
@@ -87,11 +88,11 @@ def test_flow_state_restoration(tmp_path):
|
||||
original_uuid = flow1.state.id
|
||||
|
||||
# Test case 1: Restore using restore_uuid with field override
|
||||
flow2 = RestorableFlow(
|
||||
persistence=persistence,
|
||||
restore_uuid=original_uuid,
|
||||
counter=43, # Override counter
|
||||
)
|
||||
flow2 = RestorableFlow(persistence=persistence)
|
||||
flow2.kickoff(inputs={
|
||||
"id": original_uuid,
|
||||
"counter": 43
|
||||
})
|
||||
|
||||
# Verify state restoration and selective field override
|
||||
assert flow2.state.id == original_uuid
|
||||
@@ -99,48 +100,17 @@ def test_flow_state_restoration(tmp_path):
|
||||
assert flow2.state.counter == 43 # Overridden
|
||||
|
||||
# Test case 2: Restore using kwargs['id']
|
||||
flow3 = RestorableFlow(
|
||||
persistence=persistence,
|
||||
id=original_uuid,
|
||||
message="Updated message", # Override message
|
||||
)
|
||||
flow3 = RestorableFlow(persistence=persistence)
|
||||
flow3.kickoff(inputs={
|
||||
"id": original_uuid,
|
||||
"message": "Updated message"
|
||||
})
|
||||
|
||||
# Verify state restoration and selective field override
|
||||
assert flow3.state.id == original_uuid
|
||||
assert flow3.state.counter == 42 # Preserved
|
||||
assert flow3.state.counter == 43 # Preserved
|
||||
assert flow3.state.message == "Updated message" # Overridden
|
||||
|
||||
# Test case 3: Verify error on conflicting IDs
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
RestorableFlow(
|
||||
persistence=persistence,
|
||||
restore_uuid=original_uuid,
|
||||
id="different-id", # Conflict with restore_uuid
|
||||
)
|
||||
assert "Conflicting IDs provided" in str(exc_info.value)
|
||||
|
||||
# Test case 4: Verify error on non-existent restore_uuid
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
RestorableFlow(
|
||||
persistence=persistence,
|
||||
restore_uuid="non-existent-uuid",
|
||||
)
|
||||
assert "No state found" in str(exc_info.value)
|
||||
|
||||
# Test case 5: Allow new state creation with kwargs['id']
|
||||
new_uuid = "new-flow-id"
|
||||
flow4 = RestorableFlow(
|
||||
persistence=persistence,
|
||||
id=new_uuid,
|
||||
message="New message",
|
||||
counter=100,
|
||||
)
|
||||
|
||||
# Verify new state creation with provided ID
|
||||
assert flow4.state.id == new_uuid
|
||||
assert flow4.state.message == "New message"
|
||||
assert flow4.state.counter == 100
|
||||
|
||||
|
||||
def test_multiple_method_persistence(tmp_path):
|
||||
"""Test state persistence across multiple method executions."""
|
||||
@@ -148,48 +118,59 @@ def test_multiple_method_persistence(tmp_path):
|
||||
persistence = SQLiteFlowPersistence(db_path)
|
||||
|
||||
class MultiStepFlow(Flow[TestState]):
|
||||
initial_state = TestState
|
||||
|
||||
@start()
|
||||
@persist(persistence)
|
||||
def step_1(self):
|
||||
self.state.counter = 1
|
||||
self.state.message = "Step 1"
|
||||
if self.state.counter == 1:
|
||||
self.state.counter = 99999
|
||||
self.state.message = "Step 99999"
|
||||
else:
|
||||
self.state.counter = 1
|
||||
self.state.message = "Step 1"
|
||||
|
||||
@start()
|
||||
@listen(step_1)
|
||||
@persist(persistence)
|
||||
def step_2(self):
|
||||
self.state.counter = 2
|
||||
self.state.message = "Step 2"
|
||||
if self.state.counter == 1:
|
||||
self.state.counter = 2
|
||||
self.state.message = "Step 2"
|
||||
|
||||
flow = MultiStepFlow(persistence=persistence)
|
||||
flow.kickoff()
|
||||
|
||||
flow2 = MultiStepFlow(persistence=persistence)
|
||||
flow2.kickoff(inputs={"id": flow.state.id})
|
||||
|
||||
# Load final state
|
||||
final_state = persistence.load_state(flow.state.id)
|
||||
final_state = flow2.state
|
||||
assert final_state is not None
|
||||
assert final_state["counter"] == 2
|
||||
assert final_state["message"] == "Step 2"
|
||||
|
||||
|
||||
def test_persistence_error_handling(tmp_path):
|
||||
"""Test error handling in persistence operations."""
|
||||
db_path = os.path.join(tmp_path, "test_flows.db")
|
||||
persistence = SQLiteFlowPersistence(db_path)
|
||||
|
||||
class InvalidFlow(Flow[TestState]):
|
||||
# Missing id field in initial state
|
||||
class InvalidState(BaseModel):
|
||||
value: str = ""
|
||||
|
||||
initial_state = InvalidState
|
||||
assert final_state.counter == 2
|
||||
assert final_state.message == "Step 2"
|
||||
|
||||
class NoPersistenceMultiStepFlow(Flow[TestState]):
|
||||
@start()
|
||||
@persist(persistence)
|
||||
def will_fail(self):
|
||||
self.state.value = "test"
|
||||
def step_1(self):
|
||||
if self.state.counter == 1:
|
||||
self.state.counter = 99999
|
||||
self.state.message = "Step 99999"
|
||||
else:
|
||||
self.state.counter = 1
|
||||
self.state.message = "Step 1"
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
flow = InvalidFlow(persistence=persistence)
|
||||
@listen(step_1)
|
||||
def step_2(self):
|
||||
if self.state.counter == 1:
|
||||
self.state.counter = 2
|
||||
self.state.message = "Step 2"
|
||||
|
||||
assert "must have an 'id' field" in str(exc_info.value)
|
||||
flow = NoPersistenceMultiStepFlow(persistence=persistence)
|
||||
flow.kickoff()
|
||||
|
||||
flow2 = NoPersistenceMultiStepFlow(persistence=persistence)
|
||||
flow2.kickoff(inputs={"id": flow.state.id})
|
||||
|
||||
# Load final state
|
||||
final_state = flow2.state
|
||||
assert final_state.counter == 99999
|
||||
assert final_state.message == "Step 99999"
|
||||
|
||||
Reference in New Issue
Block a user