mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-05 01:02:37 +00:00
## Summary
- Reverts `b0e2fda` ("fix(flow): add execution_id separate from state.id", COR-48): removes `Flow.execution_id` and points `current_flow_id` / `current_flow_request_id` back at `flow_id` (i.e. `state.id`). The separate per-run tracking id was no longer the right abstraction once `restore_from_state_id` reshapes how `state.id` is assigned.
- Adds an optional `restore_from_state_id` kwarg to `Flow.kickoff` / `Flow.kickoff_async` that hydrates state from a previously-persisted flow's latest snapshot
- Reassigns `state.id` to a fresh value (or `inputs["id"]` if pinned) so the new run's `@persist` writes don't extend the source's history
- Existing `inputs["id"]` resume, `@persist`, and `from_checkpoint` paths are unchanged
## Problem
`@persist` only supports *resume* today: `kickoff(inputs={"id": <uuid>})` hydrates state and continues writing under the same `flow_uuid`. There's no way to **fork** — hydrate from a snapshot but persist under a separate key, leaving the source's history intact. This PR adds that.
| | `state.id` after kickoff | `@persist` writes land under |
|---|---|---|
| `inputs["id"]` (resume) | supplied id | supplied id (extends history) |
| `restore_from_state_id` (fork) | fresh id, or `inputs["id"]` if pinned | new id (source preserved) |
## Behavior
| `inputs.id` | `restore_from_state_id` | Effect |
|---|---|---|
| — | — | Fresh kickoff |
| set | — | Existing resume |
| — | UUID | Fork — new `state.id`, hydrated from source |
| set | UUID | Fork into a pinned `state.id`, hydrated from source |
- Source not found → silent fallback (mirrors existing resume)
- Both `from_checkpoint` and `restore_from_state_id` set → `ValueError`
- `restore_from_state_id=None` → byte-identical to current main
## Design
Fork hydration runs before the existing `inputs` block in `kickoff_async`. On a hit, it calls the same `_restore_state` primitive used by resume, then overwrites `state.id` with a fresh UUID (or `inputs["id"]`). A `fork_succeeded` flag gates the existing `inputs["id"]` path so we don't double-load. `_completed_methods` / `_is_execution_resuming` are intentionally untouched — skip-completed-methods remains the territory of `apply_checkpoint` and `from_pending`.
## Test plan
- [ ] `pytest tests/test_flow_persistence.py` — 5 new tests (four-row matrix, not-found fallback, default no-op, conflict raise) + 6 existing as regression
- [ ] `pytest tests/test_flow.py` — broader flow suite
- [ ] Manual end-to-end against an HITL `@persist` flow
491 lines
16 KiB
Python
491 lines
16 KiB
Python
"""Test flow state persistence functionality."""
|
|
|
|
import os
|
|
from typing import Dict, List
|
|
|
|
import pytest
|
|
from crewai.flow.flow import Flow, FlowState, listen, start
|
|
from crewai.flow.persistence import persist
|
|
from crewai.flow.persistence.sqlite import SQLiteFlowPersistence
|
|
from pydantic import BaseModel
|
|
|
|
|
|
class TestState(FlowState):
    """Structured state shared by the persistence tests.

    Adds two payload fields on top of ``FlowState``; the ``id`` used as the
    persistence key is presumably inherited from ``FlowState`` (not visible
    in this file).
    """

    # Incremented or overwritten by the flow steps under test.
    counter: int = 0
    # Free-form text written by the flow steps under test.
    message: str = ""
|
|
|
|
|
|
def test_persist_decorator_saves_state(tmp_path, caplog):
    """A dict-state flow whose step is wrapped in @persist is saved to SQLite."""
    storage = SQLiteFlowPersistence(os.path.join(tmp_path, "test_flows.db"))

    class TestFlow(Flow[Dict[str, str]]):
        # An actual dict instance (not the dict type) is the initial state.
        initial_state = dict()

        @start()
        @persist(storage)
        def init_step(self):
            self.state["message"] = "Hello, World!"
            # Give the state an explicit id so it can be looked up afterwards.
            self.state["id"] = "test-uuid"

    # Execute the flow once; the decorated step should write a snapshot.
    runner = TestFlow(persistence=storage)
    runner.kickoff()

    # Read the snapshot back through the persistence backend.
    persisted = storage.load_state(runner.state["id"])
    assert persisted is not None
    assert persisted["message"] == "Hello, World!"
|
|
|
|
|
|
def test_structured_state_persistence(tmp_path):
    """A flow with a Pydantic (``TestState``) state is persisted field by field."""
    storage = SQLiteFlowPersistence(os.path.join(tmp_path, "test_flows.db"))

    class StructuredFlow(Flow[TestState]):
        initial_state = TestState

        @start()
        @persist(storage)
        def count_up(self):
            self.state.counter += 1
            self.state.message = f"Count is {self.state.counter}"

    # One run takes the counter from 0 to 1 and records the message.
    runner = StructuredFlow(persistence=storage)
    runner.kickoff()

    # The stored snapshot must mirror the in-memory state.
    persisted = storage.load_state(runner.state.id)
    assert persisted is not None
    assert persisted["counter"] == 1
    assert persisted["message"] == "Count is 1"
|
|
|
|
|
|
def test_flow_state_restoration(tmp_path):
    """Resuming via ``inputs["id"]`` reloads saved state and applies overrides."""
    storage = SQLiteFlowPersistence(os.path.join(tmp_path, "test_flows.db"))

    class RestorableFlow(Flow[TestState]):
        @start()
        @persist(storage)
        def set_message(self):
            # Only fill in defaults; values restored or overridden are kept.
            if self.state.message == "":
                self.state.message = "Original message"
            if self.state.counter == 0:
                self.state.counter = 42

    # Seed run: produces the snapshot that later runs resume from.
    seed_run = RestorableFlow(persistence=storage)
    seed_run.kickoff()
    original_uuid = seed_run.state.id

    # Case 1: resume by id while overriding ``counter`` through inputs.
    counter_override_run = RestorableFlow(persistence=storage)
    counter_override_run.kickoff(inputs={"id": original_uuid, "counter": 43})

    assert counter_override_run.state.id == original_uuid
    assert counter_override_run.state.message == "Original message"  # Preserved
    assert counter_override_run.state.counter == 43  # Overridden

    # Case 2: resume by id again, this time overriding ``message``.
    message_override_run = RestorableFlow(persistence=storage)
    message_override_run.kickoff(
        inputs={"id": original_uuid, "message": "Updated message"}
    )

    assert message_override_run.state.id == original_uuid
    # 43 is what the previous resumed run left in the snapshot.
    assert message_override_run.state.counter == 43  # Preserved
    assert message_override_run.state.message == "Updated message"  # Overridden
|
|
|
|
|
|
def test_multiple_method_persistence(tmp_path):
    """State persistence across several steps, with and without @persist on each.

    Two scenarios are exercised against the same database:

    * both steps persisted -- a resumed run starts from the state left by the
      last step;
    * only the first step persisted -- a resumed run starts from the state
      left by the first step.
    """
    storage = SQLiteFlowPersistence(os.path.join(tmp_path, "test_flows.db"))

    class MultiStepFlow(Flow[TestState]):
        @start()
        @persist(storage)
        def step_1(self):
            if self.state.counter == 1:
                self.state.counter = 99999
                self.state.message = "Step 99999"
            else:
                self.state.counter = 1
                self.state.message = "Step 1"

        @listen(step_1)
        @persist(storage)
        def step_2(self):
            if self.state.counter == 1:
                self.state.counter = 2
                self.state.message = "Step 2"

    full_first = MultiStepFlow(persistence=storage)
    full_first.kickoff()

    full_resumed = MultiStepFlow(persistence=storage)
    full_resumed.kickoff(inputs={"id": full_first.state.id})

    # The resumed run starts at counter == 2, so step_1 takes its else branch
    # (counter -> 1) and step_2 then moves it to 2 again.
    resumed_state = full_resumed.state
    assert resumed_state is not None
    assert resumed_state.counter == 2
    assert resumed_state.message == "Step 2"

    class NoPersistenceMultiStepFlow(Flow[TestState]):
        @start()
        @persist(storage)
        def step_1(self):
            if self.state.counter == 1:
                self.state.counter = 99999
                self.state.message = "Step 99999"
            else:
                self.state.counter = 1
                self.state.message = "Step 1"

        # Deliberately not decorated with @persist.
        @listen(step_1)
        def step_2(self):
            if self.state.counter == 1:
                self.state.counter = 2
                self.state.message = "Step 2"

    partial_first = NoPersistenceMultiStepFlow(persistence=storage)
    partial_first.kickoff()

    partial_resumed = NoPersistenceMultiStepFlow(persistence=storage)
    partial_resumed.kickoff(inputs={"id": partial_first.state.id})

    # Only step_1's write (counter == 1) was stored, so on resume step_1 takes
    # the 99999 branch and step_2 leaves the state untouched.
    resumed_state = partial_resumed.state
    assert resumed_state.counter == 99999
    assert resumed_state.message == "Step 99999"
|
|
|
|
|
|
def test_persist_decorator_verbose_logging(tmp_path, caplog):
    """The ``verbose`` flag of @persist decides whether saves are logged."""
    # Capture INFO and above so the save message is visible when emitted.
    caplog.set_level("INFO")

    storage = SQLiteFlowPersistence(os.path.join(tmp_path, "test_flows.db"))

    # Scenario 1: verbose left at its default -- nothing should be logged.
    class QuietFlow(Flow[Dict[str, str]]):
        initial_state = dict()

        @start()
        @persist(storage)  # verbose not passed
        def init_step(self):
            self.state["message"] = "Hello, World!"
            self.state["id"] = "test-uuid-1"

    quiet_run = QuietFlow(persistence=storage)
    quiet_run.kickoff()
    assert "Saving flow state" not in caplog.text

    # Start the second scenario from an empty log buffer.
    caplog.clear()

    # Scenario 2: verbose=True -- the save message must show up.
    class VerboseFlow(Flow[Dict[str, str]]):
        initial_state = dict()

        @start()
        @persist(storage, verbose=True)
        def init_step(self):
            self.state["message"] = "Hello, World!"
            self.state["id"] = "test-uuid-2"

    verbose_run = VerboseFlow(persistence=storage)
    verbose_run.kickoff()
    assert "Saving flow state" in caplog.text
|
|
|
|
|
|
def test_persistence_with_base_model(tmp_path):
    """A class-level @persist flow keeps nested BaseModel values in its state."""
    storage = SQLiteFlowPersistence(os.path.join(tmp_path, "test_flows.db"))

    class Message(BaseModel):
        role: str
        type: str
        content: str

    class State(FlowState):
        latest_message: Message | None = None
        history: List[Message] = []

    # Here @persist decorates the flow class rather than a single step.
    @persist(storage)
    class BaseModelFlow(Flow[State]):
        initial_state = State(latest_message=None, history=[])

        @start()
        def init_step(self):
            self.state.latest_message = Message(
                role="user", type="text", content="Hello, World!"
            )
            self.state.history.append(self.state.latest_message)

    runner = BaseModelFlow(persistence=storage)
    runner.kickoff()

    latest = runner.state.latest_message
    # Unpacking into a one-element tuple also fails unless there is exactly one entry.
    (only_entry,) = runner.state.history

    assert latest is not None
    assert latest.role == "user"
    assert latest.type == "text"
    assert latest.content == "Hello, World!"

    assert len(runner.state.history) == 1
    assert only_entry.role == "user"
    assert only_entry.type == "text"
    assert only_entry.content == "Hello, World!"
    # The unwrapped state object must still be the declared State model.
    assert isinstance(runner.state._unwrap(), State)
|
|
|
|
|
|
def test_fork_with_restore_from_state_id(tmp_path):
    """Forking via ``restore_from_state_id`` copies state but not identity.

    The forked run is hydrated from the source's latest snapshot, receives a
    different ``state.id``, and writes its own snapshots under that new id, so
    the source's stored state is left as it was.
    """
    storage = SQLiteFlowPersistence(os.path.join(tmp_path, "test_flows.db"))

    class ForkableFlow(Flow[TestState]):
        @start()
        @persist(storage)
        def step(self):
            self.state.counter += 1

    # Source run: counter 0 -> 1.
    source_run = ForkableFlow(persistence=storage)
    source_run.kickoff()
    source_uuid = source_run.state.id
    assert source_run.state.counter == 1

    # Plain resume on the same id: counter 1 -> 2, stored under the source id.
    resumed_run = ForkableFlow(persistence=storage)
    resumed_run.kickoff(inputs={"id": source_uuid})
    assert resumed_run.state.counter == 2
    assert storage.load_state(source_uuid)["counter"] == 2

    # Fork: start from the source snapshot but under a new state.id.
    forked_run = ForkableFlow(persistence=storage)
    forked_run.kickoff(restore_from_state_id=source_uuid)

    # The fork must not reuse the source's id.
    assert forked_run.state.id != source_uuid
    # Started from counter == 2 and the step added one more.
    assert forked_run.state.counter == 3

    # The source snapshot is untouched by the fork's writes.
    assert storage.load_state(source_uuid)["counter"] == 2

    # The fork's snapshot lives under the fork's own id.
    assert storage.load_state(forked_run.state.id)["counter"] == 3
|
|
|
|
|
|
def test_fork_with_pinned_state_id(tmp_path):
    """Forking with both ``inputs["id"]`` and ``restore_from_state_id`` pins the id.

    The new run takes ``inputs["id"]`` as its ``state.id`` while its field
    values come from the snapshot named by ``restore_from_state_id``.
    """
    storage = SQLiteFlowPersistence(os.path.join(tmp_path, "test_flows.db"))

    class PinnableFlow(Flow[TestState]):
        @start()
        @persist(storage)
        def step(self):
            self.state.counter += 1

    source_run = PinnableFlow(persistence=storage)
    source_run.kickoff()
    source_uuid = source_run.state.id
    assert source_run.state.counter == 1

    pinned_uuid = "pinned-fork-uuid-1234"
    forked_run = PinnableFlow(persistence=storage)
    forked_run.kickoff(
        inputs={"id": pinned_uuid},
        restore_from_state_id=source_uuid,
    )

    # The fork carries the caller-chosen id rather than the source's.
    assert forked_run.state.id == pinned_uuid
    # Started from the source's counter == 1; the step raised it to 2.
    assert forked_run.state.counter == 2
    # The source snapshot still holds its own value.
    assert storage.load_state(source_uuid)["counter"] == 1
    # The fork's snapshot is stored under the pinned id.
    assert storage.load_state(pinned_uuid)["counter"] == 2
|
|
|
|
|
|
def test_restore_from_state_id_not_found_silent_fallback(tmp_path):
    """An unknown ``restore_from_state_id`` does not raise; the run starts fresh."""
    storage = SQLiteFlowPersistence(os.path.join(tmp_path, "test_flows.db"))

    class FallbackFlow(Flow[TestState]):
        @start()
        @persist(storage)
        def step(self):
            self.state.counter += 1

    runner = FallbackFlow(persistence=storage)
    # Nothing is stored under this id; kickoff is expected to proceed anyway.
    runner.kickoff(restore_from_state_id="no-such-uuid")

    # Default starting state: counter 0 -> 1 after the single step.
    assert runner.state.counter == 1
    # The missing source id must not leak into the new run's state.id.
    assert runner.state.id != "no-such-uuid"
|
|
|
|
|
|
def test_restore_from_state_id_none_is_no_op(tmp_path):
    """Passing ``restore_from_state_id=None`` explicitly behaves like a plain kickoff."""
    storage = SQLiteFlowPersistence(os.path.join(tmp_path, "test_flows.db"))

    class BaselineFlow(Flow[TestState]):
        @start()
        @persist(storage)
        def step(self):
            self.state.counter += 1

    runner = BaselineFlow(persistence=storage)
    runner.kickoff(restore_from_state_id=None)
    # A fresh run: the single step takes the counter from 0 to 1.
    assert runner.state.counter == 1
|
|
|
|
|
|
def test_fork_conflict_with_from_checkpoint_raises():
    """``kickoff`` rejects ``from_checkpoint`` combined with ``restore_from_state_id``.

    The raised ``ValueError`` must mention both parameter names.
    """
    from crewai.state import CheckpointConfig

    class ConflictFlow(Flow[TestState]):
        @start()
        def step(self):
            pass

    runner = ConflictFlow()
    with pytest.raises(ValueError) as raised:
        runner.kickoff(
            from_checkpoint=CheckpointConfig(),
            restore_from_state_id="some-uuid",
        )

    error_text = str(raised.value)
    assert "from_checkpoint" in error_text
    assert "restore_from_state_id" in error_text
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_fork_via_kickoff_async(tmp_path):
    """``kickoff_async`` supports forking through ``restore_from_state_id``.

    The fork is hydrated from the source snapshot, gets a different
    ``state.id``, stores its snapshot under that id, and leaves the source's
    stored state unchanged.
    """
    storage = SQLiteFlowPersistence(os.path.join(tmp_path, "test_flows.db"))

    class AsyncForkableFlow(Flow[TestState]):
        @start()
        @persist(storage)
        def step(self):
            self.state.counter += 1

    # Source run: counter 0 -> 1.
    source_run = AsyncForkableFlow(persistence=storage)
    await source_run.kickoff_async()
    source_uuid = source_run.state.id
    assert source_run.state.counter == 1

    # Fork from the source snapshot: counter 1 -> 2 under a new id.
    forked_run = AsyncForkableFlow(persistence=storage)
    await forked_run.kickoff_async(restore_from_state_id=source_uuid)

    assert forked_run.state.id != source_uuid
    assert forked_run.state.counter == 2
    assert storage.load_state(source_uuid)["counter"] == 1
    assert storage.load_state(forked_run.state.id)["counter"] == 2
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_fork_via_akickoff(tmp_path):
    """``akickoff`` accepts ``restore_from_state_id`` and forks like ``kickoff_async``."""
    storage = SQLiteFlowPersistence(os.path.join(tmp_path, "test_flows.db"))

    class AkickoffForkableFlow(Flow[TestState]):
        @start()
        @persist(storage)
        def step(self):
            self.state.counter += 1

    # Source run: counter 0 -> 1.
    source_run = AkickoffForkableFlow(persistence=storage)
    await source_run.akickoff()
    source_uuid = source_run.state.id
    assert source_run.state.counter == 1

    # Fork from the source snapshot: counter 1 -> 2 under a new id.
    forked_run = AkickoffForkableFlow(persistence=storage)
    await forked_run.akickoff(restore_from_state_id=source_uuid)

    assert forked_run.state.id != source_uuid
    assert forked_run.state.counter == 2
    assert storage.load_state(source_uuid)["counter"] == 1
    assert storage.load_state(forked_run.state.id)["counter"] == 2
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_akickoff_pinned_fork(tmp_path):
    """``akickoff`` with ``inputs["id"]`` plus ``restore_from_state_id`` pins the id.

    The forked run uses the supplied id as ``state.id`` and takes its field
    values from the source snapshot.
    """
    storage = SQLiteFlowPersistence(os.path.join(tmp_path, "test_flows.db"))

    class PinnableAsyncFlow(Flow[TestState]):
        @start()
        @persist(storage)
        def step(self):
            self.state.counter += 1

    source_run = PinnableAsyncFlow(persistence=storage)
    await source_run.akickoff()
    source_uuid = source_run.state.id

    pinned_uuid = "pinned-akickoff-fork-uuid"
    forked_run = PinnableAsyncFlow(persistence=storage)
    await forked_run.akickoff(
        inputs={"id": pinned_uuid},
        restore_from_state_id=source_uuid,
    )

    # Identity comes from inputs; the counter continues from the source's 1.
    assert forked_run.state.id == pinned_uuid
    assert forked_run.state.counter == 2
    # Each id keeps its own snapshot.
    assert storage.load_state(source_uuid)["counter"] == 1
    assert storage.load_state(pinned_uuid)["counter"] == 2
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_akickoff_fork_conflict_with_from_checkpoint_raises():
    """``akickoff`` rejects ``from_checkpoint`` combined with ``restore_from_state_id``.

    The raised ``ValueError`` must mention both parameter names.
    """
    from crewai.state import CheckpointConfig

    class AsyncConflictFlow(Flow[TestState]):
        @start()
        def step(self):
            pass

    runner = AsyncConflictFlow()
    with pytest.raises(ValueError) as raised:
        await runner.akickoff(
            from_checkpoint=CheckpointConfig(),
            restore_from_state_id="some-uuid",
        )

    error_text = str(raised.value)
    assert "from_checkpoint" in error_text
    assert "restore_from_state_id" in error_text
|