Files
crewAI/lib/crewai/tests/test_flow_persistence.py
Tiago Freire cd2b9ee38a feat(flow): add restore_from_state_id kickoff parameter (#5674)
## Summary

- Reverts `b0e2fda` ("fix(flow): add execution_id separate from state.id", COR-48): removes `Flow.execution_id` and points `current_flow_id` / `current_flow_request_id` back at `flow_id` (i.e. `state.id`). The separate per-run tracking id was no longer the right abstraction once `restore_from_state_id` reshapes how `state.id` is assigned;

- Adds an optional `restore_from_state_id` kwarg to `Flow.kickoff` / `Flow.kickoff_async` that hydrates state from a previously-persisted flow's latest snapshot

- Reassigns `state.id` to a fresh value (or `inputs["id"]` if pinned) so the new run's `@persist` writes don't extend the source's history

- Existing `inputs["id"]` resume, `@persist`, and `from_checkpoint` paths are unchanged

## Problem
`@persist` only supports *resume* today: `kickoff(inputs={"id": <uuid>})` hydrates state and continues writing under the same `flow_uuid`. There's no way to **fork** — hydrate from a snapshot but persist under a separate key, leaving the source's history intact. This PR adds that.

| | `state.id` after kickoff | `@persist` writes land under |
|---|---|---|
| `inputs["id"]` (resume) | supplied id | supplied id (extends history) |
| `restore_from_state_id` (fork) | fresh id, or `inputs["id"]` if pinned | new id (source preserved) |

## Behavior

| `inputs.id` | `restore_from_state_id` | Effect |
|---|---|---|
| — | — | Fresh kickoff |
| set | — | Existing resume |
| — | UUID | Fork — new `state.id`, hydrated from source |
| set | UUID | Fork into a pinned `state.id`, hydrated from source |

- Source not found → silent fallback (mirrors existing resume)
- Both `from_checkpoint` and `restore_from_state_id` set → `ValueError`
- `restore_from_state_id=None` → byte-identical to current main

## Design
Fork hydration runs before the existing `inputs` block in `kickoff_async`. On a hit, it calls the same `_restore_state` primitive used by resume, then overwrites `state.id` with a fresh UUID (or `inputs["id"]`). A `fork_succeeded` flag gates the existing `inputs["id"]` path so we don't double-load. `_completed_methods` / `_is_execution_resuming` are intentionally untouched — skip-completed-methods remains the territory of `apply_checkpoint` and `from_pending`.

## Test plan
- [ ] `pytest tests/test_flow_persistence.py` — 5 new tests (four-row matrix, not-found fallback, default no-op, conflict raise) + 6 existing as regression
- [ ] `pytest tests/test_flow.py` — broader flow suite
- [ ] Manual end-to-end against an HITL `@persist` flow
2026-05-01 11:46:07 -04:00

491 lines
16 KiB
Python

"""Test flow state persistence functionality."""
import os
from typing import Dict, List
import pytest
from crewai.flow.flow import Flow, FlowState, listen, start
from crewai.flow.persistence import persist
from crewai.flow.persistence.sqlite import SQLiteFlowPersistence
from pydantic import BaseModel
class TestState(FlowState):
"""Test state model with required id field."""
counter: int = 0
message: str = ""
def test_persist_decorator_saves_state(tmp_path, caplog):
"""Test that @persist decorator saves state in SQLite."""
db_path = os.path.join(tmp_path, "test_flows.db")
persistence = SQLiteFlowPersistence(db_path)
class TestFlow(Flow[Dict[str, str]]):
initial_state = dict() # Use dict instance as initial state
@start()
@persist(persistence)
def init_step(self):
self.state["message"] = "Hello, World!"
self.state["id"] = "test-uuid" # Ensure we have an ID for persistence
# Run flow and verify state is saved
flow = TestFlow(persistence=persistence)
flow.kickoff()
# Load state from DB and verify
saved_state = persistence.load_state(flow.state["id"])
assert saved_state is not None
assert saved_state["message"] == "Hello, World!"
def test_structured_state_persistence(tmp_path):
"""Test persistence with Pydantic model state."""
db_path = os.path.join(tmp_path, "test_flows.db")
persistence = SQLiteFlowPersistence(db_path)
class StructuredFlow(Flow[TestState]):
initial_state = TestState
@start()
@persist(persistence)
def count_up(self):
self.state.counter += 1
self.state.message = f"Count is {self.state.counter}"
# Run flow and verify state changes are saved
flow = StructuredFlow(persistence=persistence)
flow.kickoff()
# Load and verify state
saved_state = persistence.load_state(flow.state.id)
assert saved_state is not None
assert saved_state["counter"] == 1
assert saved_state["message"] == "Count is 1"
def test_flow_state_restoration(tmp_path):
"""Test restoring flow state from persistence with various restoration methods."""
db_path = os.path.join(tmp_path, "test_flows.db")
persistence = SQLiteFlowPersistence(db_path)
# First flow execution to create initial state
class RestorableFlow(Flow[TestState]):
@start()
@persist(persistence)
def set_message(self):
if self.state.message == "":
self.state.message = "Original message"
if self.state.counter == 0:
self.state.counter = 42
# Create and persist initial state
flow1 = RestorableFlow(persistence=persistence)
flow1.kickoff()
original_uuid = flow1.state.id
# Test case 1: Restore using restore_uuid with field override
flow2 = RestorableFlow(persistence=persistence)
flow2.kickoff(inputs={"id": original_uuid, "counter": 43})
# Verify state restoration and selective field override
assert flow2.state.id == original_uuid
assert flow2.state.message == "Original message" # Preserved
assert flow2.state.counter == 43 # Overridden
# Test case 2: Restore using kwargs['id']
flow3 = RestorableFlow(persistence=persistence)
flow3.kickoff(inputs={"id": original_uuid, "message": "Updated message"})
# Verify state restoration and selective field override
assert flow3.state.id == original_uuid
assert flow3.state.counter == 43 # Preserved
assert flow3.state.message == "Updated message" # Overridden
def test_multiple_method_persistence(tmp_path):
"""Test state persistence across multiple method executions."""
db_path = os.path.join(tmp_path, "test_flows.db")
persistence = SQLiteFlowPersistence(db_path)
class MultiStepFlow(Flow[TestState]):
@start()
@persist(persistence)
def step_1(self):
if self.state.counter == 1:
self.state.counter = 99999
self.state.message = "Step 99999"
else:
self.state.counter = 1
self.state.message = "Step 1"
@listen(step_1)
@persist(persistence)
def step_2(self):
if self.state.counter == 1:
self.state.counter = 2
self.state.message = "Step 2"
flow = MultiStepFlow(persistence=persistence)
flow.kickoff()
flow2 = MultiStepFlow(persistence=persistence)
flow2.kickoff(inputs={"id": flow.state.id})
# Load final state
final_state = flow2.state
assert final_state is not None
assert final_state.counter == 2
assert final_state.message == "Step 2"
class NoPersistenceMultiStepFlow(Flow[TestState]):
@start()
@persist(persistence)
def step_1(self):
if self.state.counter == 1:
self.state.counter = 99999
self.state.message = "Step 99999"
else:
self.state.counter = 1
self.state.message = "Step 1"
@listen(step_1)
def step_2(self):
if self.state.counter == 1:
self.state.counter = 2
self.state.message = "Step 2"
flow = NoPersistenceMultiStepFlow(persistence=persistence)
flow.kickoff()
flow2 = NoPersistenceMultiStepFlow(persistence=persistence)
flow2.kickoff(inputs={"id": flow.state.id})
# Load final state
final_state = flow2.state
assert final_state.counter == 99999
assert final_state.message == "Step 99999"
def test_persist_decorator_verbose_logging(tmp_path, caplog):
"""Test that @persist decorator's verbose parameter controls logging."""
# Set logging level to ensure we capture all logs
caplog.set_level("INFO")
db_path = os.path.join(tmp_path, "test_flows.db")
persistence = SQLiteFlowPersistence(db_path)
# Test with verbose=False (default)
class QuietFlow(Flow[Dict[str, str]]):
initial_state = dict()
@start()
@persist(persistence) # Default verbose=False
def init_step(self):
self.state["message"] = "Hello, World!"
self.state["id"] = "test-uuid-1"
flow = QuietFlow(persistence=persistence)
flow.kickoff()
assert "Saving flow state" not in caplog.text
# Clear the log
caplog.clear()
# Test with verbose=True
class VerboseFlow(Flow[Dict[str, str]]):
initial_state = dict()
@start()
@persist(persistence, verbose=True)
def init_step(self):
self.state["message"] = "Hello, World!"
self.state["id"] = "test-uuid-2"
flow = VerboseFlow(persistence=persistence)
flow.kickoff()
assert "Saving flow state" in caplog.text
def test_persistence_with_base_model(tmp_path):
db_path = os.path.join(tmp_path, "test_flows.db")
persistence = SQLiteFlowPersistence(db_path)
class Message(BaseModel):
role: str
type: str
content: str
class State(FlowState):
latest_message: Message | None = None
history: List[Message] = []
@persist(persistence)
class BaseModelFlow(Flow[State]):
initial_state = State(latest_message=None, history=[])
@start()
def init_step(self):
self.state.latest_message = Message(
role="user", type="text", content="Hello, World!"
)
self.state.history.append(self.state.latest_message)
flow = BaseModelFlow(persistence=persistence)
flow.kickoff()
latest_message = flow.state.latest_message
(message,) = flow.state.history
assert latest_message is not None
assert latest_message.role == "user"
assert latest_message.type == "text"
assert latest_message.content == "Hello, World!"
assert len(flow.state.history) == 1
assert message.role == "user"
assert message.type == "text"
assert message.content == "Hello, World!"
assert isinstance(flow.state._unwrap(), State)
def test_fork_with_restore_from_state_id(tmp_path):
"""Fork: restore_from_state_id hydrates state from source flow_uuid; new run gets a
fresh state.id; source's history is preserved (the fork's @persist writes go under
the new state.id, not the source's)."""
db_path = os.path.join(tmp_path, "test_flows.db")
persistence = SQLiteFlowPersistence(db_path)
class ForkableFlow(Flow[TestState]):
@start()
@persist(persistence)
def step(self):
self.state.counter += 1
# Run 1: build up source state. counter goes 0 -> 1.
flow1 = ForkableFlow(persistence=persistence)
flow1.kickoff()
source_uuid = flow1.state.id
assert flow1.state.counter == 1
# Resume on the same uuid bumps counter to 2 in the SAME flow_uuid history.
flow1b = ForkableFlow(persistence=persistence)
flow1b.kickoff(inputs={"id": source_uuid})
assert flow1b.state.counter == 2
assert persistence.load_state(source_uuid)["counter"] == 2
# Fork: hydrate from source, but persist under a fresh state.id.
flow2 = ForkableFlow(persistence=persistence)
flow2.kickoff(restore_from_state_id=source_uuid)
# Fork has a different state.id from the source.
assert flow2.state.id != source_uuid
# Hydrated from source's latest snapshot (counter=2), then incremented to 3.
assert flow2.state.counter == 3
# Source's history is unchanged after the fork.
assert persistence.load_state(source_uuid)["counter"] == 2
# Fork's writes landed under its own state.id.
assert persistence.load_state(flow2.state.id)["counter"] == 3
def test_fork_with_pinned_state_id(tmp_path):
"""Fork into a pinned state.id (inputs.id supplied alongside restore_from_state_id):
the new run uses inputs.id as state.id and hydrates from restore_from_state_id."""
db_path = os.path.join(tmp_path, "test_flows.db")
persistence = SQLiteFlowPersistence(db_path)
class PinnableFlow(Flow[TestState]):
@start()
@persist(persistence)
def step(self):
self.state.counter += 1
flow1 = PinnableFlow(persistence=persistence)
flow1.kickoff()
source_uuid = flow1.state.id
assert flow1.state.counter == 1
pinned_uuid = "pinned-fork-uuid-1234"
flow2 = PinnableFlow(persistence=persistence)
flow2.kickoff(
inputs={"id": pinned_uuid},
restore_from_state_id=source_uuid,
)
# state.id pinned to inputs.id, NOT the source uuid.
assert flow2.state.id == pinned_uuid
# Hydrated from source: counter started at 1, step incremented to 2.
assert flow2.state.counter == 2
# Source's history is unchanged.
assert persistence.load_state(source_uuid)["counter"] == 1
# Fork's writes are under the pinned uuid.
assert persistence.load_state(pinned_uuid)["counter"] == 2
def test_restore_from_state_id_not_found_silent_fallback(tmp_path):
"""Lookup miss on restore_from_state_id silently falls through to default behavior."""
db_path = os.path.join(tmp_path, "test_flows.db")
persistence = SQLiteFlowPersistence(db_path)
class FallbackFlow(Flow[TestState]):
@start()
@persist(persistence)
def step(self):
self.state.counter += 1
flow = FallbackFlow(persistence=persistence)
# No source UUID exists — should not raise.
flow.kickoff(restore_from_state_id="no-such-uuid")
# Default state path: counter starts at 0 and step increments to 1.
assert flow.state.counter == 1
# state.id is the auto-generated one, NOT the missing source.
assert flow.state.id != "no-such-uuid"
def test_restore_from_state_id_none_is_no_op(tmp_path):
"""restore_from_state_id=None (default) preserves baseline kickoff behavior."""
db_path = os.path.join(tmp_path, "test_flows.db")
persistence = SQLiteFlowPersistence(db_path)
class BaselineFlow(Flow[TestState]):
@start()
@persist(persistence)
def step(self):
self.state.counter += 1
flow = BaselineFlow(persistence=persistence)
flow.kickoff(restore_from_state_id=None)
assert flow.state.counter == 1
def test_fork_conflict_with_from_checkpoint_raises():
"""Passing both from_checkpoint and restore_from_state_id raises ValueError, naming
both parameters."""
from crewai.state import CheckpointConfig
class ConflictFlow(Flow[TestState]):
@start()
def step(self):
pass
flow = ConflictFlow()
with pytest.raises(ValueError) as excinfo:
flow.kickoff(
from_checkpoint=CheckpointConfig(),
restore_from_state_id="some-uuid",
)
msg = str(excinfo.value)
assert "from_checkpoint" in msg
assert "restore_from_state_id" in msg
@pytest.mark.asyncio
async def test_fork_via_kickoff_async(tmp_path):
"""kickoff_async honors restore_from_state_id: hydrates from source, mints fresh
state.id, persists under the new id, source history preserved."""
db_path = os.path.join(tmp_path, "test_flows.db")
persistence = SQLiteFlowPersistence(db_path)
class AsyncForkableFlow(Flow[TestState]):
@start()
@persist(persistence)
def step(self):
self.state.counter += 1
flow1 = AsyncForkableFlow(persistence=persistence)
await flow1.kickoff_async()
source_uuid = flow1.state.id
assert flow1.state.counter == 1
flow2 = AsyncForkableFlow(persistence=persistence)
await flow2.kickoff_async(restore_from_state_id=source_uuid)
assert flow2.state.id != source_uuid
assert flow2.state.counter == 2
assert persistence.load_state(source_uuid)["counter"] == 1
assert persistence.load_state(flow2.state.id)["counter"] == 2
@pytest.mark.asyncio
async def test_fork_via_akickoff(tmp_path):
"""akickoff is the public async alias and must accept restore_from_state_id with
the same semantics as kickoff_async."""
db_path = os.path.join(tmp_path, "test_flows.db")
persistence = SQLiteFlowPersistence(db_path)
class AkickoffForkableFlow(Flow[TestState]):
@start()
@persist(persistence)
def step(self):
self.state.counter += 1
flow1 = AkickoffForkableFlow(persistence=persistence)
await flow1.akickoff()
source_uuid = flow1.state.id
assert flow1.state.counter == 1
flow2 = AkickoffForkableFlow(persistence=persistence)
await flow2.akickoff(restore_from_state_id=source_uuid)
assert flow2.state.id != source_uuid
assert flow2.state.counter == 2
assert persistence.load_state(source_uuid)["counter"] == 1
assert persistence.load_state(flow2.state.id)["counter"] == 2
@pytest.mark.asyncio
async def test_akickoff_pinned_fork(tmp_path):
"""akickoff with both inputs.id and restore_from_state_id pins state.id while
hydrating from the source."""
db_path = os.path.join(tmp_path, "test_flows.db")
persistence = SQLiteFlowPersistence(db_path)
class PinnableAsyncFlow(Flow[TestState]):
@start()
@persist(persistence)
def step(self):
self.state.counter += 1
flow1 = PinnableAsyncFlow(persistence=persistence)
await flow1.akickoff()
source_uuid = flow1.state.id
pinned_uuid = "pinned-akickoff-fork-uuid"
flow2 = PinnableAsyncFlow(persistence=persistence)
await flow2.akickoff(
inputs={"id": pinned_uuid},
restore_from_state_id=source_uuid,
)
assert flow2.state.id == pinned_uuid
assert flow2.state.counter == 2
assert persistence.load_state(source_uuid)["counter"] == 1
assert persistence.load_state(pinned_uuid)["counter"] == 2
@pytest.mark.asyncio
async def test_akickoff_fork_conflict_with_from_checkpoint_raises():
"""akickoff must raise the same conflict ValueError as kickoff/kickoff_async when
both from_checkpoint and restore_from_state_id are set."""
from crewai.state import CheckpointConfig
class AsyncConflictFlow(Flow[TestState]):
@start()
def step(self):
pass
flow = AsyncConflictFlow()
with pytest.raises(ValueError) as excinfo:
await flow.akickoff(
from_checkpoint=CheckpointConfig(),
restore_from_state_id="some-uuid",
)
msg = str(excinfo.value)
assert "from_checkpoint" in msg
assert "restore_from_state_id" in msg