Files
crewAI/lib/crewai/tests/test_flow_persistence.py
Vinicius Brasil 340d23ae5d Remove StateProxy from flow state access (#6327)
`StateProxy` looked like a thread-safety boundary, but it only protected
a small slice of state operations. Some examples of operations that were
not covered:

- `self.state.counter += 1`, `self.state["counter"] += 1` (increments)
- `self.state.user.profile.score += 1` (nested object mutations)
- `self.state.config["limits"]["max"] = 10` (mutation through model fields)
- `self.state.items[0].status = "done"` (list/container mutations)

This commit decided to remove it completely for simplicity and
performance:

- Simpler runtime code
- attr read: 24x faster, attr write: 27x faster, list append: 19x faster (local benchmark)
- Clearer concurrency contract (lifecycle locks remain, but arbitrary
  shared state mutation is not presented as thread-safe)
2026-06-24 16:37:51 -07:00

475 lines
15 KiB
Python

"""Test flow state persistence functionality."""
import os
from typing import Dict, List
import pytest
from crewai.flow.flow import Flow, FlowState, listen, start
from crewai.flow.persistence import persist
from crewai.flow.persistence.sqlite import SQLiteFlowPersistence
from pydantic import BaseModel
class TestState(FlowState):
"""Test state model with required id field."""
counter: int = 0
message: str = ""
def test_persist_decorator_saves_state(tmp_path, caplog):
"""Test that @persist decorator saves state in SQLite."""
db_path = os.path.join(tmp_path, "test_flows.db")
persistence = SQLiteFlowPersistence(db_path)
class TestFlow(Flow[Dict[str, str]]):
initial_state = dict()
@start()
@persist(persistence)
def init_step(self):
self.state["message"] = "Hello, World!"
self.state["id"] = "test-uuid" # Ensure we have an ID for persistence
flow = TestFlow(persistence=persistence)
flow.kickoff()
saved_state = persistence.load_state(flow.state["id"])
assert saved_state is not None
assert saved_state["message"] == "Hello, World!"
def test_structured_state_persistence(tmp_path):
"""Test persistence with Pydantic model state."""
db_path = os.path.join(tmp_path, "test_flows.db")
persistence = SQLiteFlowPersistence(db_path)
class StructuredFlow(Flow[TestState]):
initial_state = TestState
@start()
@persist(persistence)
def count_up(self):
self.state.counter += 1
self.state.message = f"Count is {self.state.counter}"
flow = StructuredFlow(persistence=persistence)
flow.kickoff()
saved_state = persistence.load_state(flow.state.id)
assert saved_state is not None
assert saved_state["counter"] == 1
assert saved_state["message"] == "Count is 1"
def test_flow_state_restoration(tmp_path):
"""Test restoring flow state from persistence with various restoration methods."""
db_path = os.path.join(tmp_path, "test_flows.db")
persistence = SQLiteFlowPersistence(db_path)
class RestorableFlow(Flow[TestState]):
@start()
@persist(persistence)
def set_message(self):
if self.state.message == "":
self.state.message = "Original message"
if self.state.counter == 0:
self.state.counter = 42
flow1 = RestorableFlow(persistence=persistence)
flow1.kickoff()
original_uuid = flow1.state.id
flow2 = RestorableFlow(persistence=persistence)
flow2.kickoff(inputs={"id": original_uuid, "counter": 43})
assert flow2.state.id == original_uuid
assert flow2.state.message == "Original message"
assert flow2.state.counter == 43
flow3 = RestorableFlow(persistence=persistence)
flow3.kickoff(inputs={"id": original_uuid, "message": "Updated message"})
assert flow3.state.id == original_uuid
assert flow3.state.counter == 43
assert flow3.state.message == "Updated message"
def test_multiple_method_persistence(tmp_path):
"""Test state persistence across multiple method executions."""
db_path = os.path.join(tmp_path, "test_flows.db")
persistence = SQLiteFlowPersistence(db_path)
class MultiStepFlow(Flow[TestState]):
@start()
@persist(persistence)
def step_1(self):
if self.state.counter == 1:
self.state.counter = 99999
self.state.message = "Step 99999"
else:
self.state.counter = 1
self.state.message = "Step 1"
@listen(step_1)
@persist(persistence)
def step_2(self):
if self.state.counter == 1:
self.state.counter = 2
self.state.message = "Step 2"
flow = MultiStepFlow(persistence=persistence)
flow.kickoff()
flow2 = MultiStepFlow(persistence=persistence)
flow2.kickoff(inputs={"id": flow.state.id})
final_state = flow2.state
assert final_state is not None
assert final_state.counter == 2
assert final_state.message == "Step 2"
class NoPersistenceMultiStepFlow(Flow[TestState]):
@start()
@persist(persistence)
def step_1(self):
if self.state.counter == 1:
self.state.counter = 99999
self.state.message = "Step 99999"
else:
self.state.counter = 1
self.state.message = "Step 1"
@listen(step_1)
def step_2(self):
if self.state.counter == 1:
self.state.counter = 2
self.state.message = "Step 2"
flow = NoPersistenceMultiStepFlow(persistence=persistence)
flow.kickoff()
flow2 = NoPersistenceMultiStepFlow(persistence=persistence)
flow2.kickoff(inputs={"id": flow.state.id})
final_state = flow2.state
assert final_state.counter == 99999
assert final_state.message == "Step 99999"
def test_persist_decorator_verbose_logging(tmp_path, caplog):
"""Test that @persist decorator's verbose parameter controls logging."""
# Set logging level to ensure we capture all logs
caplog.set_level("INFO")
db_path = os.path.join(tmp_path, "test_flows.db")
persistence = SQLiteFlowPersistence(db_path)
class QuietFlow(Flow[Dict[str, str]]):
initial_state = dict()
@start()
@persist(persistence)
def init_step(self):
self.state["message"] = "Hello, World!"
self.state["id"] = "test-uuid-1"
flow = QuietFlow(persistence=persistence)
flow.kickoff()
assert "Saving flow state" not in caplog.text
caplog.clear()
class VerboseFlow(Flow[Dict[str, str]]):
initial_state = dict()
@start()
@persist(persistence, verbose=True)
def init_step(self):
self.state["message"] = "Hello, World!"
self.state["id"] = "test-uuid-2"
flow = VerboseFlow(persistence=persistence)
flow.kickoff()
assert "Saving flow state" in caplog.text
def test_persistence_with_base_model(tmp_path):
db_path = os.path.join(tmp_path, "test_flows.db")
persistence = SQLiteFlowPersistence(db_path)
class Message(BaseModel):
role: str
type: str
content: str
class State(FlowState):
latest_message: Message | None = None
history: List[Message] = []
@persist(persistence)
class BaseModelFlow(Flow[State]):
initial_state = State(latest_message=None, history=[])
@start()
def init_step(self):
self.state.latest_message = Message(
role="user", type="text", content="Hello, World!"
)
self.state.history.append(self.state.latest_message)
flow = BaseModelFlow(persistence=persistence)
flow.kickoff()
latest_message = flow.state.latest_message
(message,) = flow.state.history
assert latest_message is not None
assert latest_message.role == "user"
assert latest_message.type == "text"
assert latest_message.content == "Hello, World!"
assert len(flow.state.history) == 1
assert message.role == "user"
assert message.type == "text"
assert message.content == "Hello, World!"
assert isinstance(flow.state, State)
def test_fork_with_restore_from_state_id(tmp_path):
"""Fork: restore_from_state_id hydrates state from source flow_uuid; new run gets a
fresh state.id; source's history is preserved (the fork's @persist writes go under
the new state.id, not the source's)."""
db_path = os.path.join(tmp_path, "test_flows.db")
persistence = SQLiteFlowPersistence(db_path)
class ForkableFlow(Flow[TestState]):
@start()
@persist(persistence)
def step(self):
self.state.counter += 1
flow1 = ForkableFlow(persistence=persistence)
flow1.kickoff()
source_uuid = flow1.state.id
assert flow1.state.counter == 1
# Resume on the same uuid bumps counter to 2 in the SAME flow_uuid history.
flow1b = ForkableFlow(persistence=persistence)
flow1b.kickoff(inputs={"id": source_uuid})
assert flow1b.state.counter == 2
assert persistence.load_state(source_uuid)["counter"] == 2
# Fork: hydrate from source, but persist under a fresh state.id.
flow2 = ForkableFlow(persistence=persistence)
flow2.kickoff(restore_from_state_id=source_uuid)
# Fork has a different state.id from the source.
assert flow2.state.id != source_uuid
# Hydrated from source's latest snapshot (counter=2), then incremented to 3.
assert flow2.state.counter == 3
# Source's history is unchanged after the fork.
assert persistence.load_state(source_uuid)["counter"] == 2
# Fork's writes landed under its own state.id.
assert persistence.load_state(flow2.state.id)["counter"] == 3
def test_fork_with_pinned_state_id(tmp_path):
"""Fork into a pinned state.id (inputs.id supplied alongside restore_from_state_id):
the new run uses inputs.id as state.id and hydrates from restore_from_state_id."""
db_path = os.path.join(tmp_path, "test_flows.db")
persistence = SQLiteFlowPersistence(db_path)
class PinnableFlow(Flow[TestState]):
@start()
@persist(persistence)
def step(self):
self.state.counter += 1
flow1 = PinnableFlow(persistence=persistence)
flow1.kickoff()
source_uuid = flow1.state.id
assert flow1.state.counter == 1
pinned_uuid = "pinned-fork-uuid-1234"
flow2 = PinnableFlow(persistence=persistence)
flow2.kickoff(
inputs={"id": pinned_uuid},
restore_from_state_id=source_uuid,
)
# state.id pinned to inputs.id, NOT the source uuid.
assert flow2.state.id == pinned_uuid
# Hydrated from source: counter started at 1, step incremented to 2.
assert flow2.state.counter == 2
# Source's history is unchanged.
assert persistence.load_state(source_uuid)["counter"] == 1
# Fork's writes are under the pinned uuid.
assert persistence.load_state(pinned_uuid)["counter"] == 2
def test_restore_from_state_id_not_found_silent_fallback(tmp_path):
"""Lookup miss on restore_from_state_id silently falls through to default behavior."""
db_path = os.path.join(tmp_path, "test_flows.db")
persistence = SQLiteFlowPersistence(db_path)
class FallbackFlow(Flow[TestState]):
@start()
@persist(persistence)
def step(self):
self.state.counter += 1
flow = FallbackFlow(persistence=persistence)
# No source UUID exists — should not raise.
flow.kickoff(restore_from_state_id="no-such-uuid")
# Default state path: counter starts at 0 and step increments to 1.
assert flow.state.counter == 1
# state.id is the auto-generated one, NOT the missing source.
assert flow.state.id != "no-such-uuid"
def test_restore_from_state_id_none_is_no_op(tmp_path):
"""restore_from_state_id=None (default) preserves baseline kickoff behavior."""
db_path = os.path.join(tmp_path, "test_flows.db")
persistence = SQLiteFlowPersistence(db_path)
class BaselineFlow(Flow[TestState]):
@start()
@persist(persistence)
def step(self):
self.state.counter += 1
flow = BaselineFlow(persistence=persistence)
flow.kickoff(restore_from_state_id=None)
assert flow.state.counter == 1
def test_fork_conflict_with_from_checkpoint_raises():
"""Passing both from_checkpoint and restore_from_state_id raises ValueError, naming
both parameters."""
from crewai.state import CheckpointConfig
class ConflictFlow(Flow[TestState]):
@start()
def step(self):
pass
flow = ConflictFlow()
with pytest.raises(ValueError) as excinfo:
flow.kickoff(
from_checkpoint=CheckpointConfig(),
restore_from_state_id="some-uuid",
)
msg = str(excinfo.value)
assert "from_checkpoint" in msg
assert "restore_from_state_id" in msg
@pytest.mark.asyncio
async def test_fork_via_kickoff_async(tmp_path):
"""kickoff_async honors restore_from_state_id: hydrates from source, mints fresh
state.id, persists under the new id, source history preserved."""
db_path = os.path.join(tmp_path, "test_flows.db")
persistence = SQLiteFlowPersistence(db_path)
class AsyncForkableFlow(Flow[TestState]):
@start()
@persist(persistence)
def step(self):
self.state.counter += 1
flow1 = AsyncForkableFlow(persistence=persistence)
await flow1.kickoff_async()
source_uuid = flow1.state.id
assert flow1.state.counter == 1
flow2 = AsyncForkableFlow(persistence=persistence)
await flow2.kickoff_async(restore_from_state_id=source_uuid)
assert flow2.state.id != source_uuid
assert flow2.state.counter == 2
assert persistence.load_state(source_uuid)["counter"] == 1
assert persistence.load_state(flow2.state.id)["counter"] == 2
@pytest.mark.asyncio
async def test_fork_via_akickoff(tmp_path):
"""akickoff is the public async alias and must accept restore_from_state_id with
the same semantics as kickoff_async."""
db_path = os.path.join(tmp_path, "test_flows.db")
persistence = SQLiteFlowPersistence(db_path)
class AkickoffForkableFlow(Flow[TestState]):
@start()
@persist(persistence)
def step(self):
self.state.counter += 1
flow1 = AkickoffForkableFlow(persistence=persistence)
await flow1.akickoff()
source_uuid = flow1.state.id
assert flow1.state.counter == 1
flow2 = AkickoffForkableFlow(persistence=persistence)
await flow2.akickoff(restore_from_state_id=source_uuid)
assert flow2.state.id != source_uuid
assert flow2.state.counter == 2
assert persistence.load_state(source_uuid)["counter"] == 1
assert persistence.load_state(flow2.state.id)["counter"] == 2
@pytest.mark.asyncio
async def test_akickoff_pinned_fork(tmp_path):
"""akickoff with both inputs.id and restore_from_state_id pins state.id while
hydrating from the source."""
db_path = os.path.join(tmp_path, "test_flows.db")
persistence = SQLiteFlowPersistence(db_path)
class PinnableAsyncFlow(Flow[TestState]):
@start()
@persist(persistence)
def step(self):
self.state.counter += 1
flow1 = PinnableAsyncFlow(persistence=persistence)
await flow1.akickoff()
source_uuid = flow1.state.id
pinned_uuid = "pinned-akickoff-fork-uuid"
flow2 = PinnableAsyncFlow(persistence=persistence)
await flow2.akickoff(
inputs={"id": pinned_uuid},
restore_from_state_id=source_uuid,
)
assert flow2.state.id == pinned_uuid
assert flow2.state.counter == 2
assert persistence.load_state(source_uuid)["counter"] == 1
assert persistence.load_state(pinned_uuid)["counter"] == 2
@pytest.mark.asyncio
async def test_akickoff_fork_conflict_with_from_checkpoint_raises():
"""akickoff must raise the same conflict ValueError as kickoff/kickoff_async when
both from_checkpoint and restore_from_state_id are set."""
from crewai.state import CheckpointConfig
class AsyncConflictFlow(Flow[TestState]):
@start()
def step(self):
pass
flow = AsyncConflictFlow()
with pytest.raises(ValueError) as excinfo:
await flow.akickoff(
from_checkpoint=CheckpointConfig(),
restore_from_state_id="some-uuid",
)
msg = str(excinfo.value)
assert "from_checkpoint" in msg
assert "restore_from_state_id" in msg