From 4ce7cf679f1d86719ab3f9eb796a72fdb8d10f46 Mon Sep 17 00:00:00 2001 From: Vinicius Brasil Date: Thu, 11 Jun 2026 18:49:58 -0700 Subject: [PATCH] Wire config and persistence from FlowDefinition into the runtime MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `from_definition` was silently dropping all config fields; it now passes `config.model_dump()` so suppress_flow_events, max_method_calls, etc. actually apply. Persistence is now engine-driven: `_persist_method_completion` fires after every method using the definition's persist metadata, so `@persist` no longer needs to wrap methods — it just stamps them. --- lib/crewai/src/crewai/flow/dsl/_utils.py | 11 +- lib/crewai/src/crewai/flow/flow_definition.py | 12 +- .../src/crewai/flow/persistence/decorators.py | 124 +---- .../src/crewai/flow/runtime/__init__.py | 77 ++- .../crewai/flow/runtime/_action_resolvers.py | 15 +- lib/crewai/tests/test_flow_from_definition.py | 466 +++++++++++++++++- 6 files changed, 572 insertions(+), 133 deletions(-) diff --git a/lib/crewai/src/crewai/flow/dsl/_utils.py b/lib/crewai/src/crewai/flow/dsl/_utils.py index c9ceebdc0..ee8202272 100644 --- a/lib/crewai/src/crewai/flow/dsl/_utils.py +++ b/lib/crewai/src/crewai/flow/dsl/_utils.py @@ -219,16 +219,19 @@ def _build_config_definition( ) -> FlowConfigDefinition: config_field_names = set(FlowConfigDefinition.model_fields) field_defaults = { - name: field.default + name: field.get_default(call_default_factory=True) for name, field in getattr(flow_class, "model_fields", {}).items() if name in config_field_names } values: dict[str, Any] = {} for field_name, default in field_defaults.items(): value = getattr(flow_class, field_name, default) - values[field_name] = _serialize_static_value( - value, diagnostics, f"config.{field_name}" - ) + if field_name == "input_provider": + values[field_name] = None if value is None else _object_ref(value) + else: + values[field_name] = _serialize_static_value( + value, diagnostics, f"config.{field_name}" + ) return FlowConfigDefinition(**values) diff --git a/lib/crewai/src/crewai/flow/flow_definition.py b/lib/crewai/src/crewai/flow/flow_definition.py index 365bfd7a7..5de3ae2e6 100644 --- a/lib/crewai/src/crewai/flow/flow_definition.py +++ b/lib/crewai/src/crewai/flow/flow_definition.py @@ -64,10 +64,12 @@ class FlowConfigDefinition(BaseModel): tracing: bool | None = None stream: bool = False - memory: Any = None - input_provider: Any = None + memory: dict[str, Any] | None = None + input_provider: str | None = None suppress_flow_events: bool = False max_method_calls: int = 100 + defer_trace_finalization: bool = False + checkpoint: bool | dict[str, Any] | None = None class FlowPersistenceDefinition(BaseModel): @@ -75,7 +77,7 @@ class FlowPersistenceDefinition(BaseModel): enabled: bool = False verbose: bool = False - persistence: Any = None + persistence: dict[str, Any] | None = None class FlowHumanFeedbackDefinition(BaseModel): @@ -126,7 +128,9 @@ class FlowDefinition(BaseModel): model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True) - schema_: str = Field(default="crewai.flow/v1", alias="schema") + schema_: TypingLiteral["crewai.flow/v1"] = Field( + default="crewai.flow/v1", alias="schema" + ) name: str description: str | None = None state: FlowStateDefinition | None = None diff --git a/lib/crewai/src/crewai/flow/persistence/decorators.py b/lib/crewai/src/crewai/flow/persistence/decorators.py index 2f30d6b0c..65da2cee1 100644 --- a/lib/crewai/src/crewai/flow/persistence/decorators.py +++ b/lib/crewai/src/crewai/flow/persistence/decorators.py @@ -24,12 +24,11 @@ Example: from __future__ import annotations -import asyncio from collections.abc import Callable import functools import logging from types import SimpleNamespace -from typing import TYPE_CHECKING, Any, Final, TypeVar, cast +from typing import TYPE_CHECKING, Any, Final, TypeVar from crewai_core.printer import PRINTER from pydantic import BaseModel @@ -39,7 +38,7 @@ from crewai.flow.persistence.factory import default_flow_persistence if TYPE_CHECKING: - from crewai.flow.flow import Flow + from crewai.flow.runtime import Flow logger = logging.getLogger(__name__) @@ -66,14 +65,6 @@ def _stamp_persistence_metadata( ) -_PRESERVED_FLOW_ATTRS: Final[tuple[str, ...]] = ( - "__human_feedback_config__", - "__flow_persistence_config__", - "__flow_method_definition__", - "_human_feedback_llm", -) - - class PersistenceDecorator: """Class to handle flow state persistence with consistent logging.""" @@ -164,6 +155,10 @@ def persist( states. When applied at the method level, it persists only that method's state. + The decorator is a pure metadata stamper: it records the persistence + configuration on the class or method, and the Flow engine saves state + after each persisted method completes, driven by the flow's definition. + Args: persistence: Optional FlowPersistence implementation to use. If not provided, uses ``default_flow_persistence()`` (the @@ -202,111 +197,10 @@ def persist( original_init(self, *args, **kwargs) target.__init__ = new_init # type: ignore[misc] - - # Preserve original methods' decorators - original_methods = { - name: method - for name, method in target.__dict__.items() - if callable(method) - and ( - hasattr(method, "__is_flow_method__") - or hasattr(method, "__flow_method_definition__") - ) - } - - for name, method in original_methods.items(): - if asyncio.iscoroutinefunction(method): - # Closure captures the current name and method - def create_async_wrapper( - method_name: str, original_method: Callable[..., Any] - ) -> Callable[..., Any]: - @functools.wraps(original_method) - async def method_wrapper( - self: Any, *args: Any, **kwargs: Any - ) -> Any: - result = await original_method(self, *args, **kwargs) - PersistenceDecorator.persist_state( - self, method_name, actual_persistence, verbose - ) - return result - - return method_wrapper - - wrapped = create_async_wrapper(name, method) - - for attr in _PRESERVED_FLOW_ATTRS: - if hasattr(method, attr): - setattr(wrapped, attr, getattr(method, attr)) - wrapped.__is_flow_method__ = True # type: ignore[attr-defined] - - setattr(target, name, wrapped) - else: - - def create_sync_wrapper( - method_name: str, original_method: Callable[..., Any] - ) -> Callable[..., Any]: - @functools.wraps(original_method) - def method_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: - result = original_method(self, *args, **kwargs) - PersistenceDecorator.persist_state( - self, method_name, actual_persistence, verbose - ) - return result - - return method_wrapper - - wrapped = create_sync_wrapper(name, method) - - for attr in _PRESERVED_FLOW_ATTRS: - if hasattr(method, attr): - setattr(wrapped, attr, getattr(method, attr)) - wrapped.__is_flow_method__ = True # type: ignore[attr-defined] - - setattr(target, name, wrapped) - return target - method = target - method.__is_flow_method__ = True # type: ignore[attr-defined] - _stamp_persistence_metadata(method, actual_persistence, verbose) - if asyncio.iscoroutinefunction(method): - - @functools.wraps(method) - async def method_async_wrapper( - flow_instance: Any, *args: Any, **kwargs: Any - ) -> T: - method_coro = method(flow_instance, *args, **kwargs) - if asyncio.iscoroutine(method_coro): - result = await method_coro - else: - result = method_coro - PersistenceDecorator.persist_state( - flow_instance, method.__name__, actual_persistence, verbose - ) - return cast(T, result) - - for attr in _PRESERVED_FLOW_ATTRS: - if hasattr(method, attr): - setattr(method_async_wrapper, attr, getattr(method, attr)) - method_async_wrapper.__is_flow_method__ = True # type: ignore[attr-defined] - _stamp_persistence_metadata( - method_async_wrapper, actual_persistence, verbose - ) - return cast(Callable[..., T], method_async_wrapper) - - @functools.wraps(method) - def method_sync_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T: - result = method(flow_instance, *args, **kwargs) - PersistenceDecorator.persist_state( - flow_instance, method.__name__, actual_persistence, verbose - ) - return result - - for attr in _PRESERVED_FLOW_ATTRS: - if hasattr(method, attr): - setattr(method_sync_wrapper, attr, getattr(method, attr)) - method_sync_wrapper.__is_flow_method__ = True # type: ignore[attr-defined] - _stamp_persistence_metadata(method_sync_wrapper, actual_persistence, verbose) - return cast(Callable[..., T], method_sync_wrapper) + target.__is_flow_method__ = True # type: ignore[attr-defined] + _stamp_persistence_metadata(target, actual_persistence, verbose) + return target return decorator diff --git a/lib/crewai/src/crewai/flow/runtime/__init__.py b/lib/crewai/src/crewai/flow/runtime/__init__.py index 33d399da5..09e6983b4 100644 --- a/lib/crewai/src/crewai/flow/runtime/__init__.py +++ b/lib/crewai/src/crewai/flow/runtime/__init__.py @@ -96,6 +96,7 @@ from crewai.flow.flow_definition import ( FlowDefinition, FlowDefinitionCondition, FlowMethodDefinition, + FlowPersistenceDefinition, FlowStateDefinition, ) from crewai.flow.flow_wrappers import ( @@ -282,9 +283,12 @@ def _serialize_persistence(value: Any) -> dict[str, Any] | None: def _validate_input_provider(value: Any) -> Any: if value is None or isinstance(value, InputProvider): return value - from crewai.types.callback import _dotted_path_to_instance + if isinstance(value, str) and ":" in value: + resolved = _resolve_input_provider_ref(value) + else: + from crewai.types.callback import _dotted_path_to_instance - resolved = _dotted_path_to_instance(value) + resolved = _dotted_path_to_instance(value) if resolved is None or isinstance(resolved, InputProvider): return resolved raise ValueError( @@ -293,6 +297,15 @@ def _validate_input_provider(value: Any) -> Any: ) +def _resolve_input_provider_ref(ref: str) -> Any: + from crewai.flow.runtime._action_resolvers import import_ref + + target = import_ref(ref) + if inspect.isclass(target): + return target() + return target + + def _serialize_input_provider(value: Any) -> str | None: if value is None: return None @@ -751,7 +764,10 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): @classmethod def from_definition(cls, definition: FlowDefinition) -> Flow[Any]: """Build a runnable Flow directly from a definition; no subclass required.""" - return cls.model_validate({}, context={"flow_definition": definition}) + return cls.model_validate( + definition.config.model_dump(), + context={"flow_definition": definition}, + ) def _start_method_names(self) -> list[FlowMethodName]: return [ @@ -960,6 +976,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): _input_history: list[InputHistoryEntry] = PrivateAttr(default_factory=list) _state: Any = PrivateAttr(default=None) _deferred_flow_started_event_id: str | None = PrivateAttr(default=None) + _persist_backends: dict[int, FlowPersistence] = PrivateAttr(default_factory=dict) def __class_getitem__(cls: type[Flow[T]], item: type[T]) -> type[Flow[T]]: # type: ignore[override] class _FlowGeneric(cls): # type: ignore[valid-type,misc] @@ -998,6 +1015,14 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): else self._class_bound_methods() ) + flow_persist = self._definition.persist + if ( + self.persistence is None + and flow_persist is not None + and flow_persist.enabled + ): + self.persistence = self._resolve_persist_backend(flow_persist) + if self._state is None: self._state = self._create_initial_state() @@ -1524,6 +1549,8 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): self._completed_methods.add(FlowMethodName(context.method_name)) + self._persist_method_completion(FlowMethodName(context.method_name)) + self._pending_feedback_context = None if self.persistence is not None: @@ -2703,6 +2730,8 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): self._completed_methods.add(method_name) + self._persist_method_completion(method_name) + finished_event_id: str | None = None if not self.suppress_flow_events: finished_event = MethodExecutionFinishedEvent( @@ -2761,6 +2790,48 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): self._event_futures.append(future) raise e + def _persist_method_completion(self, method_name: FlowMethodName) -> None: + method_definition = self._definition.methods.get(method_name) + persist_definition = ( + method_definition.persist + if method_definition is not None and method_definition.persist is not None + else self._definition.persist + ) + if persist_definition is None or not persist_definition.enabled: + return + + from crewai.flow.persistence.decorators import PersistenceDecorator + + backend = self.persistence or self._persist_backend_for(persist_definition) + PersistenceDecorator.persist_state( + self, method_name, backend, verbose=persist_definition.verbose + ) + + def _persist_backend_for( + self, persist_definition: FlowPersistenceDefinition + ) -> FlowPersistence: + cached = self._persist_backends.get(id(persist_definition)) + if cached is None: + cached = self._resolve_persist_backend(persist_definition) + self._persist_backends[id(persist_definition)] = cached + return cached + + def _resolve_persist_backend( + self, persist_definition: FlowPersistenceDefinition + ) -> FlowPersistence: + if persist_definition.persistence is None: + from crewai.flow.persistence.factory import default_flow_persistence + + return default_flow_persistence() + resolved = _resolve_persistence(persist_definition.persistence) + if not isinstance(resolved, FlowPersistence): + raise ValueError( + f"Cannot resolve persistence backend " + f"{persist_definition.persistence!r} from the flow definition " + f"for flow {self._definition.name!r}." + ) + return resolved + def _copy_and_serialize_state(self) -> dict[str, Any]: state_copy = self._copy_state() if isinstance(state_copy, BaseModel): diff --git a/lib/crewai/src/crewai/flow/runtime/_action_resolvers.py b/lib/crewai/src/crewai/flow/runtime/_action_resolvers.py index 80512b11d..d71dfacaa 100644 --- a/lib/crewai/src/crewai/flow/runtime/_action_resolvers.py +++ b/lib/crewai/src/crewai/flow/runtime/_action_resolvers.py @@ -17,17 +17,22 @@ class InvalidActionRefError(ValueError): super().__init__(f"invalid callable {ref!r}; expected 'module:qualname'") -def _resolve_code_action( - flow: Flow[Any], action: FlowActionDefinition -) -> Callable[..., Any]: - ref = action.ref +def import_ref(ref: str) -> Any: + """Import the object a `module:qualname` reference points to.""" module_name, _, qualname = ref.partition(":") if "<" in ref or not module_name or not qualname: raise InvalidActionRefError(ref) try: - target = attrgetter(qualname)(importlib.import_module(module_name)) + return attrgetter(qualname)(importlib.import_module(module_name)) except (ImportError, AttributeError) as e: raise InvalidActionRefError(ref) from e + + +def _resolve_code_action( + flow: Flow[Any], action: FlowActionDefinition +) -> Callable[..., Any]: + ref = action.ref + target = import_ref(ref) if not callable(target): raise InvalidActionRefError(ref) handler = cast(Callable[..., Any], target) diff --git a/lib/crewai/tests/test_flow_from_definition.py b/lib/crewai/tests/test_flow_from_definition.py index 14591ca69..0df09caa6 100644 --- a/lib/crewai/tests/test_flow_from_definition.py +++ b/lib/crewai/tests/test_flow_from_definition.py @@ -1,16 +1,27 @@ from __future__ import annotations +from collections import defaultdict +from typing import Any, ClassVar + import pytest from pydantic import ValidationError from crewai.events.event_bus import crewai_event_bus from crewai.events.types.flow_events import ( + FlowCreatedEvent, + FlowFinishedEvent, + FlowStartedEvent, MethodExecutionFinishedEvent, MethodExecutionStartedEvent, ) -from crewai.flow import Flow, and_, listen, or_, router, start +from crewai.flow import Flow, and_, human_feedback, listen, or_, router, start +from crewai.flow.async_feedback import PendingFeedbackContext from crewai.flow.flow import FlowState -from crewai.flow.flow_definition import FlowDefinition +from crewai.flow.flow_definition import FlowConfigDefinition, FlowDefinition +from crewai.flow.persistence import persist +from crewai.flow.persistence.base import FlowPersistence +from crewai.state.checkpoint_config import CheckpointConfig +from crewai.types.streaming import FlowStreamingOutput class ChainFlow(Flow): @@ -550,3 +561,454 @@ def test_unknown_state_type_falls_back_to_dict(caplog): result = flow.kickoff() assert result == "hello" assert flow.state["begin_ran"] is True + + +class StubInputProvider: + def request_input(self, message, flow, metadata=None): + return "stub" + + +class ConfiguredFlow(Flow): + suppress_flow_events = True + max_method_calls = 5 + input_provider = StubInputProvider() + + @start() + def begin(self): + return "configured" + + +SUPPRESSED_CHAIN_YAML = ( + CHAIN_YAML + + """ +config: + suppress_flow_events: true +""" +) + +CAPPED_LOOP_YAML = ( + LOOP_YAML + + """ +config: + max_method_calls: 2 +""" +) + +STREAMING_CHAIN_YAML = ( + CHAIN_YAML + + """ +config: + stream: true +""" +) + +DEFERRED_CHAIN_YAML = ( + CHAIN_YAML + + """ +config: + defer_trace_finalization: true +""" +) + +INPUT_PROVIDER_CHAIN_YAML = ( + CHAIN_YAML + + f""" +config: + input_provider: {__name__}:StubInputProvider +""" +) + + +def _run_capturing_flow_lifecycle(yaml_str, event_types): + events = [] + with crewai_event_bus.scoped_handlers(): + for event_type in event_types: + + @crewai_event_bus.on(event_type) + def capture(source, event): + events.append(event) + + flow = Flow.from_definition(FlowDefinition.from_yaml(yaml_str)) + result = flow.kickoff() + return flow, result, events + + +_LIFECYCLE_EVENTS = [ + FlowCreatedEvent, + FlowStartedEvent, + FlowFinishedEvent, + MethodExecutionStartedEvent, + MethodExecutionFinishedEvent, +] + + +def test_config_suppress_flow_events_from_yaml(): + twin_events = [] + with crewai_event_bus.scoped_handlers(): + for event_type in _LIFECYCLE_EVENTS: + + @crewai_event_bus.on(event_type) + def capture(source, event): + twin_events.append(type(event).__name__) + + twin_result = ChainFlow(suppress_flow_events=True).kickoff() + + flow, result, events = _run_capturing_flow_lifecycle( + SUPPRESSED_CHAIN_YAML, _LIFECYCLE_EVENTS + ) + assert result == twin_result == "confirmed:True" + assert flow.suppress_flow_events is True + assert [type(e).__name__ for e in events] == twin_events + assert not any( + isinstance(e, (MethodExecutionStartedEvent, MethodExecutionFinishedEvent)) + for e in events + ) + + +def test_config_max_method_calls_from_yaml(): + flow = Flow.from_definition(FlowDefinition.from_yaml(CAPPED_LOOP_YAML)) + with pytest.raises(RecursionError, match="has been called 2 times"): + flow.kickoff() + + +def test_config_stream_from_yaml(): + flow = Flow.from_definition(FlowDefinition.from_yaml(STREAMING_CHAIN_YAML)) + streaming = flow.kickoff() + assert isinstance(streaming, FlowStreamingOutput) + for _ in streaming: + pass + assert streaming.result == "confirmed:True" + assert flow.stream is True + + +def test_config_defer_trace_finalization_from_yaml(): + _, _, baseline_events = _run_capturing_flow_lifecycle( + CHAIN_YAML, [FlowFinishedEvent] + ) + assert len(baseline_events) == 1 + + flow, result, deferred_events = _run_capturing_flow_lifecycle( + DEFERRED_CHAIN_YAML, [FlowFinishedEvent] + ) + assert result == "confirmed:True" + assert flow.defer_trace_finalization is True + assert deferred_events == [] + + +def test_config_checkpoint_from_yaml(tmp_path): + yaml_str = ( + CHAIN_YAML + + f""" +config: + checkpoint: + location: {tmp_path} +""" + ) + flow = Flow.from_definition(FlowDefinition.from_yaml(yaml_str)) + assert isinstance(flow.checkpoint, CheckpointConfig) + assert flow.checkpoint.location == str(tmp_path) + + +def test_config_input_provider_from_yaml(): + flow = Flow.from_definition(FlowDefinition.from_yaml(INPUT_PROVIDER_CHAIN_YAML)) + assert isinstance(flow.input_provider, StubInputProvider) + + +def test_round_trip_config_equivalence(): + class_flow = ConfiguredFlow() + definition = FlowDefinition.from_yaml(ConfiguredFlow.flow_definition().to_yaml()) + definition_flow = Flow.from_definition(definition) + + assert definition.config.suppress_flow_events is True + assert definition.config.max_method_calls == 5 + assert definition.config.input_provider == f"{__name__}:StubInputProvider" + assert definition_flow.suppress_flow_events is class_flow.suppress_flow_events + assert definition_flow.max_method_calls == class_flow.max_method_calls + assert isinstance(definition_flow.input_provider, StubInputProvider) + + class_result, class_events = _run_with_events(class_flow) + definition_result, definition_events = _run_with_events(definition_flow) + assert definition_result == class_result == "configured" + assert definition_events == class_events + + +def test_unknown_schema_rejected(): + with pytest.raises(ValidationError, match="schema"): + FlowDefinition.from_dict( + { + "schema": "crewai.flow/v2", + "name": "FutureSchema", + "methods": { + "begin": {"start": True, "do": {"ref": f"{__name__}:ChainFlow.begin"}} + }, + } + ) + + +def test_flow_config_definition_mirrors_flow_fields(): + for name, field in FlowConfigDefinition.model_fields.items(): + assert name in Flow.model_fields + assert field.get_default(call_default_factory=True) == Flow.model_fields[ + name + ].get_default(call_default_factory=True) + + +class DefinitionStoreBackend(FlowPersistence): + persistence_type: str = "DefinitionStoreBackend" + store: str = "default" + + saves: ClassVar[dict[str, list[tuple[str, dict[str, Any]]]]] = defaultdict(list) + pending: ClassVar[dict[str, tuple[dict[str, Any], PendingFeedbackContext]]] = {} + + def init_db(self) -> None: + pass + + def save_state(self, flow_uuid, method_name, state_data): + data = state_data if isinstance(state_data, dict) else state_data.model_dump() + DefinitionStoreBackend.saves[self.store].append((method_name, dict(data))) + + def load_state(self, flow_uuid): + for _, data in reversed(DefinitionStoreBackend.saves[self.store]): + if data.get("id") == flow_uuid: + return data + return None + + def save_pending_feedback(self, flow_uuid, context, state_data): + data = state_data if isinstance(state_data, dict) else state_data.model_dump() + DefinitionStoreBackend.pending[flow_uuid] = (dict(data), context) + + def load_pending_feedback(self, flow_uuid): + return DefinitionStoreBackend.pending.get(flow_uuid) + + def clear_pending_feedback(self, flow_uuid): + DefinitionStoreBackend.pending.pop(flow_uuid, None) + + +def _saved_methods(store): + return [name for name, _ in DefinitionStoreBackend.saves[store]] + + +class PersistedFlow(Flow): + @start() + def first(self): + self.state["count"] = self.state.get("count", 0) + 1 + return "one" + + @listen(first) + def second(self): + self.state["count"] += 1 + return "two" + + +def _flow_level_persist_yaml(store): + return f""" +schema: crewai.flow/v1 +name: PersistedFlow +persist: + enabled: true + persistence: + persistence_type: DefinitionStoreBackend + store: {store} +methods: + first: + do: + ref: {__name__}:PersistedFlow.first + start: true + second: + do: + ref: {__name__}:PersistedFlow.second + listen: first +""" + + +def _method_level_persist_yaml(store): + return f""" +schema: crewai.flow/v1 +name: PersistedFlow +methods: + first: + do: + ref: {__name__}:PersistedFlow.first + start: true + persist: + enabled: true + persistence: + persistence_type: DefinitionStoreBackend + store: {store} + second: + do: + ref: {__name__}:PersistedFlow.second + listen: first +""" + + +_CLASS_LEVEL_BACKEND = DefinitionStoreBackend(store="class-decorator") + + +@persist(_CLASS_LEVEL_BACKEND) +class ClassPersistedFlow(Flow): + @start() + def first(self): + self.state["count"] = self.state.get("count", 0) + 1 + return "one" + + @listen(first) + def second(self): + self.state["count"] += 1 + return "two" + + +_COMBINED_BACKEND = DefinitionStoreBackend(store="combined-decorator") + + +@persist(_COMBINED_BACKEND) +class CombinedPersistedFlow(Flow): + @start() + @persist(_COMBINED_BACKEND) + def first(self): + return "one" + + @listen(first) + def second(self): + return "two" + + +class MethodPersistedFlow(Flow): + @start() + @persist(DefinitionStoreBackend(store="method-decorator")) + def first(self): + self.state["count"] = self.state.get("count", 0) + 1 + return "one" + + @listen(first) + def second(self): + self.state["count"] += 1 + return "two" + + +def test_flow_level_persist_from_yaml_saves_once_per_method(): + yaml_str = _flow_level_persist_yaml("yaml-flow-level") + flow = Flow.from_definition(FlowDefinition.from_yaml(yaml_str)) + result = flow.kickoff() + + assert result == "two" + assert _saved_methods("yaml-flow-level") == ["first", "second"] + _, final_save = DefinitionStoreBackend.saves["yaml-flow-level"][-1] + assert final_save["count"] == 2 + assert final_save["id"] == flow.state["id"] + + +def test_method_level_persist_from_yaml_saves_only_that_method(): + yaml_str = _method_level_persist_yaml("yaml-method-level") + flow = Flow.from_definition(FlowDefinition.from_yaml(yaml_str)) + flow.kickoff() + + assert _saved_methods("yaml-method-level") == ["first"] + _, save = DefinitionStoreBackend.saves["yaml-method-level"][0] + assert save["count"] == 1 + + +def test_method_level_persist_disabled_wins_over_flow_level(): + yaml_str = f""" +schema: crewai.flow/v1 +name: PersistedFlow +persist: + enabled: true + persistence: + persistence_type: DefinitionStoreBackend + store: yaml-opt-out +methods: + first: + do: + ref: {__name__}:PersistedFlow.first + start: true + second: + do: + ref: {__name__}:PersistedFlow.second + listen: first + persist: + enabled: false +""" + flow = Flow.from_definition(FlowDefinition.from_yaml(yaml_str)) + flow.kickoff() + + assert _saved_methods("yaml-opt-out") == ["first"] + + +def test_persist_restore_by_id_from_yaml(): + yaml_str = _flow_level_persist_yaml("yaml-restore") + + flow1 = Flow.from_definition(FlowDefinition.from_yaml(yaml_str)) + flow1.kickoff() + assert flow1.state["count"] == 2 + + flow2 = Flow.from_definition(FlowDefinition.from_yaml(yaml_str)) + flow2.kickoff(inputs={"id": flow1.state["id"]}) + assert flow2.state["count"] == 4 + + +def test_combined_class_and_method_persist_saves_once_per_method(): + before = len(DefinitionStoreBackend.saves["combined-decorator"]) + CombinedPersistedFlow().kickoff() + + assert _saved_methods("combined-decorator")[before:] == ["first", "second"] + + +def test_method_level_persist_decorator_saves_only_that_method(): + before = len(DefinitionStoreBackend.saves["method-decorator"]) + MethodPersistedFlow().kickoff() + + assert _saved_methods("method-decorator")[before:] == ["first"] + + +def test_round_trip_persist_equivalence(): + definition = FlowDefinition.from_yaml(ClassPersistedFlow.flow_definition().to_yaml()) + + before = len(DefinitionStoreBackend.saves["class-decorator"]) + flow = Flow.from_definition(definition) + flow.kickoff() + + assert _saved_methods("class-decorator")[before:] == ["first", "second"] + + +def test_instance_persistence_overrides_definition_backend(): + before = len(DefinitionStoreBackend.saves["method-decorator"]) + flow = MethodPersistedFlow( + persistence=DefinitionStoreBackend(store="instance-override") + ) + flow.kickoff() + + assert _saved_methods("instance-override") == ["first"] + assert len(DefinitionStoreBackend.saves["method-decorator"]) == before + + +def test_resume_synthetic_completion_persists(): + backend = DefinitionStoreBackend(store="resume-synthetic") + + class ResumableFlow(Flow): + @start() + @persist(DefinitionStoreBackend(store="resume-synthetic")) + @human_feedback(message="Review:") + def generate(self): + return "content" + + @listen(generate) + def process(self, result): + return "done" + + context = PendingFeedbackContext( + flow_id="resume-persist-1", + flow_class="ResumableFlow", + method_name="generate", + method_output="content", + message="Review:", + ) + backend.save_pending_feedback( + "resume-persist-1", context, {"id": "resume-persist-1"} + ) + + flow = ResumableFlow.from_pending("resume-persist-1", backend) + result = flow.resume("looks good") + + assert result == "done" + assert _saved_methods("resume-synthetic") == ["generate"]