mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-07-02 05:38:12 +00:00
Wire config and persistence from FlowDefinition into the runtime
`from_definition` was silently dropping all config fields; it now passes `config.model_dump()` so suppress_flow_events, max_method_calls, etc. actually apply. Persistence is now engine-driven: `_persist_method_completion` fires after every method using the definition's persist metadata, so `@persist` no longer needs to wrap methods — it just stamps them.
This commit is contained in:
@@ -219,16 +219,19 @@ def _build_config_definition(
|
||||
) -> FlowConfigDefinition:
|
||||
config_field_names = set(FlowConfigDefinition.model_fields)
|
||||
field_defaults = {
|
||||
name: field.default
|
||||
name: field.get_default(call_default_factory=True)
|
||||
for name, field in getattr(flow_class, "model_fields", {}).items()
|
||||
if name in config_field_names
|
||||
}
|
||||
values: dict[str, Any] = {}
|
||||
for field_name, default in field_defaults.items():
|
||||
value = getattr(flow_class, field_name, default)
|
||||
values[field_name] = _serialize_static_value(
|
||||
value, diagnostics, f"config.{field_name}"
|
||||
)
|
||||
if field_name == "input_provider":
|
||||
values[field_name] = None if value is None else _object_ref(value)
|
||||
else:
|
||||
values[field_name] = _serialize_static_value(
|
||||
value, diagnostics, f"config.{field_name}"
|
||||
)
|
||||
return FlowConfigDefinition(**values)
|
||||
|
||||
|
||||
|
||||
@@ -64,10 +64,12 @@ class FlowConfigDefinition(BaseModel):
|
||||
|
||||
tracing: bool | None = None
|
||||
stream: bool = False
|
||||
memory: Any = None
|
||||
input_provider: Any = None
|
||||
memory: dict[str, Any] | None = None
|
||||
input_provider: str | None = None
|
||||
suppress_flow_events: bool = False
|
||||
max_method_calls: int = 100
|
||||
defer_trace_finalization: bool = False
|
||||
checkpoint: bool | dict[str, Any] | None = None
|
||||
|
||||
|
||||
class FlowPersistenceDefinition(BaseModel):
|
||||
@@ -75,7 +77,7 @@ class FlowPersistenceDefinition(BaseModel):
|
||||
|
||||
enabled: bool = False
|
||||
verbose: bool = False
|
||||
persistence: Any = None
|
||||
persistence: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class FlowHumanFeedbackDefinition(BaseModel):
|
||||
@@ -126,7 +128,9 @@ class FlowDefinition(BaseModel):
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True, arbitrary_types_allowed=True)
|
||||
|
||||
schema_: str = Field(default="crewai.flow/v1", alias="schema")
|
||||
schema_: TypingLiteral["crewai.flow/v1"] = Field(
|
||||
default="crewai.flow/v1", alias="schema"
|
||||
)
|
||||
name: str
|
||||
description: str | None = None
|
||||
state: FlowStateDefinition | None = None
|
||||
|
||||
@@ -24,12 +24,11 @@ Example:
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
import functools
|
||||
import logging
|
||||
from types import SimpleNamespace
|
||||
from typing import TYPE_CHECKING, Any, Final, TypeVar, cast
|
||||
from typing import TYPE_CHECKING, Any, Final, TypeVar
|
||||
|
||||
from crewai_core.printer import PRINTER
|
||||
from pydantic import BaseModel
|
||||
@@ -39,7 +38,7 @@ from crewai.flow.persistence.factory import default_flow_persistence
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.flow.flow import Flow
|
||||
from crewai.flow.runtime import Flow
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -66,14 +65,6 @@ def _stamp_persistence_metadata(
|
||||
)
|
||||
|
||||
|
||||
_PRESERVED_FLOW_ATTRS: Final[tuple[str, ...]] = (
|
||||
"__human_feedback_config__",
|
||||
"__flow_persistence_config__",
|
||||
"__flow_method_definition__",
|
||||
"_human_feedback_llm",
|
||||
)
|
||||
|
||||
|
||||
class PersistenceDecorator:
|
||||
"""Class to handle flow state persistence with consistent logging."""
|
||||
|
||||
@@ -164,6 +155,10 @@ def persist(
|
||||
states. When applied at the method level, it persists only that method's
|
||||
state.
|
||||
|
||||
The decorator is a pure metadata stamper: it records the persistence
|
||||
configuration on the class or method, and the Flow engine saves state
|
||||
after each persisted method completes, driven by the flow's definition.
|
||||
|
||||
Args:
|
||||
persistence: Optional FlowPersistence implementation to use.
|
||||
If not provided, uses ``default_flow_persistence()`` (the
|
||||
@@ -202,111 +197,10 @@ def persist(
|
||||
original_init(self, *args, **kwargs)
|
||||
|
||||
target.__init__ = new_init # type: ignore[misc]
|
||||
|
||||
# Preserve original methods' decorators
|
||||
original_methods = {
|
||||
name: method
|
||||
for name, method in target.__dict__.items()
|
||||
if callable(method)
|
||||
and (
|
||||
hasattr(method, "__is_flow_method__")
|
||||
or hasattr(method, "__flow_method_definition__")
|
||||
)
|
||||
}
|
||||
|
||||
for name, method in original_methods.items():
|
||||
if asyncio.iscoroutinefunction(method):
|
||||
# Closure captures the current name and method
|
||||
def create_async_wrapper(
|
||||
method_name: str, original_method: Callable[..., Any]
|
||||
) -> Callable[..., Any]:
|
||||
@functools.wraps(original_method)
|
||||
async def method_wrapper(
|
||||
self: Any, *args: Any, **kwargs: Any
|
||||
) -> Any:
|
||||
result = await original_method(self, *args, **kwargs)
|
||||
PersistenceDecorator.persist_state(
|
||||
self, method_name, actual_persistence, verbose
|
||||
)
|
||||
return result
|
||||
|
||||
return method_wrapper
|
||||
|
||||
wrapped = create_async_wrapper(name, method)
|
||||
|
||||
for attr in _PRESERVED_FLOW_ATTRS:
|
||||
if hasattr(method, attr):
|
||||
setattr(wrapped, attr, getattr(method, attr))
|
||||
wrapped.__is_flow_method__ = True # type: ignore[attr-defined]
|
||||
|
||||
setattr(target, name, wrapped)
|
||||
else:
|
||||
|
||||
def create_sync_wrapper(
|
||||
method_name: str, original_method: Callable[..., Any]
|
||||
) -> Callable[..., Any]:
|
||||
@functools.wraps(original_method)
|
||||
def method_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
|
||||
result = original_method(self, *args, **kwargs)
|
||||
PersistenceDecorator.persist_state(
|
||||
self, method_name, actual_persistence, verbose
|
||||
)
|
||||
return result
|
||||
|
||||
return method_wrapper
|
||||
|
||||
wrapped = create_sync_wrapper(name, method)
|
||||
|
||||
for attr in _PRESERVED_FLOW_ATTRS:
|
||||
if hasattr(method, attr):
|
||||
setattr(wrapped, attr, getattr(method, attr))
|
||||
wrapped.__is_flow_method__ = True # type: ignore[attr-defined]
|
||||
|
||||
setattr(target, name, wrapped)
|
||||
|
||||
return target
|
||||
method = target
|
||||
method.__is_flow_method__ = True # type: ignore[attr-defined]
|
||||
_stamp_persistence_metadata(method, actual_persistence, verbose)
|
||||
|
||||
if asyncio.iscoroutinefunction(method):
|
||||
|
||||
@functools.wraps(method)
|
||||
async def method_async_wrapper(
|
||||
flow_instance: Any, *args: Any, **kwargs: Any
|
||||
) -> T:
|
||||
method_coro = method(flow_instance, *args, **kwargs)
|
||||
if asyncio.iscoroutine(method_coro):
|
||||
result = await method_coro
|
||||
else:
|
||||
result = method_coro
|
||||
PersistenceDecorator.persist_state(
|
||||
flow_instance, method.__name__, actual_persistence, verbose
|
||||
)
|
||||
return cast(T, result)
|
||||
|
||||
for attr in _PRESERVED_FLOW_ATTRS:
|
||||
if hasattr(method, attr):
|
||||
setattr(method_async_wrapper, attr, getattr(method, attr))
|
||||
method_async_wrapper.__is_flow_method__ = True # type: ignore[attr-defined]
|
||||
_stamp_persistence_metadata(
|
||||
method_async_wrapper, actual_persistence, verbose
|
||||
)
|
||||
return cast(Callable[..., T], method_async_wrapper)
|
||||
|
||||
@functools.wraps(method)
|
||||
def method_sync_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T:
|
||||
result = method(flow_instance, *args, **kwargs)
|
||||
PersistenceDecorator.persist_state(
|
||||
flow_instance, method.__name__, actual_persistence, verbose
|
||||
)
|
||||
return result
|
||||
|
||||
for attr in _PRESERVED_FLOW_ATTRS:
|
||||
if hasattr(method, attr):
|
||||
setattr(method_sync_wrapper, attr, getattr(method, attr))
|
||||
method_sync_wrapper.__is_flow_method__ = True # type: ignore[attr-defined]
|
||||
_stamp_persistence_metadata(method_sync_wrapper, actual_persistence, verbose)
|
||||
return cast(Callable[..., T], method_sync_wrapper)
|
||||
target.__is_flow_method__ = True # type: ignore[attr-defined]
|
||||
_stamp_persistence_metadata(target, actual_persistence, verbose)
|
||||
return target
|
||||
|
||||
return decorator
|
||||
|
||||
@@ -96,6 +96,7 @@ from crewai.flow.flow_definition import (
|
||||
FlowDefinition,
|
||||
FlowDefinitionCondition,
|
||||
FlowMethodDefinition,
|
||||
FlowPersistenceDefinition,
|
||||
FlowStateDefinition,
|
||||
)
|
||||
from crewai.flow.flow_wrappers import (
|
||||
@@ -282,9 +283,12 @@ def _serialize_persistence(value: Any) -> dict[str, Any] | None:
|
||||
def _validate_input_provider(value: Any) -> Any:
|
||||
if value is None or isinstance(value, InputProvider):
|
||||
return value
|
||||
from crewai.types.callback import _dotted_path_to_instance
|
||||
if isinstance(value, str) and ":" in value:
|
||||
resolved = _resolve_input_provider_ref(value)
|
||||
else:
|
||||
from crewai.types.callback import _dotted_path_to_instance
|
||||
|
||||
resolved = _dotted_path_to_instance(value)
|
||||
resolved = _dotted_path_to_instance(value)
|
||||
if resolved is None or isinstance(resolved, InputProvider):
|
||||
return resolved
|
||||
raise ValueError(
|
||||
@@ -293,6 +297,15 @@ def _validate_input_provider(value: Any) -> Any:
|
||||
)
|
||||
|
||||
|
||||
def _resolve_input_provider_ref(ref: str) -> Any:
|
||||
from crewai.flow.runtime._action_resolvers import import_ref
|
||||
|
||||
target = import_ref(ref)
|
||||
if inspect.isclass(target):
|
||||
return target()
|
||||
return target
|
||||
|
||||
|
||||
def _serialize_input_provider(value: Any) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
@@ -751,7 +764,10 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
@classmethod
|
||||
def from_definition(cls, definition: FlowDefinition) -> Flow[Any]:
|
||||
"""Build a runnable Flow directly from a definition; no subclass required."""
|
||||
return cls.model_validate({}, context={"flow_definition": definition})
|
||||
return cls.model_validate(
|
||||
definition.config.model_dump(),
|
||||
context={"flow_definition": definition},
|
||||
)
|
||||
|
||||
def _start_method_names(self) -> list[FlowMethodName]:
|
||||
return [
|
||||
@@ -960,6 +976,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
_input_history: list[InputHistoryEntry] = PrivateAttr(default_factory=list)
|
||||
_state: Any = PrivateAttr(default=None)
|
||||
_deferred_flow_started_event_id: str | None = PrivateAttr(default=None)
|
||||
_persist_backends: dict[int, FlowPersistence] = PrivateAttr(default_factory=dict)
|
||||
|
||||
def __class_getitem__(cls: type[Flow[T]], item: type[T]) -> type[Flow[T]]: # type: ignore[override]
|
||||
class _FlowGeneric(cls): # type: ignore[valid-type,misc]
|
||||
@@ -998,6 +1015,14 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
else self._class_bound_methods()
|
||||
)
|
||||
|
||||
flow_persist = self._definition.persist
|
||||
if (
|
||||
self.persistence is None
|
||||
and flow_persist is not None
|
||||
and flow_persist.enabled
|
||||
):
|
||||
self.persistence = self._resolve_persist_backend(flow_persist)
|
||||
|
||||
if self._state is None:
|
||||
self._state = self._create_initial_state()
|
||||
|
||||
@@ -1524,6 +1549,8 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
|
||||
self._completed_methods.add(FlowMethodName(context.method_name))
|
||||
|
||||
self._persist_method_completion(FlowMethodName(context.method_name))
|
||||
|
||||
self._pending_feedback_context = None
|
||||
|
||||
if self.persistence is not None:
|
||||
@@ -2703,6 +2730,8 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
|
||||
self._completed_methods.add(method_name)
|
||||
|
||||
self._persist_method_completion(method_name)
|
||||
|
||||
finished_event_id: str | None = None
|
||||
if not self.suppress_flow_events:
|
||||
finished_event = MethodExecutionFinishedEvent(
|
||||
@@ -2761,6 +2790,48 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
self._event_futures.append(future)
|
||||
raise e
|
||||
|
||||
def _persist_method_completion(self, method_name: FlowMethodName) -> None:
|
||||
method_definition = self._definition.methods.get(method_name)
|
||||
persist_definition = (
|
||||
method_definition.persist
|
||||
if method_definition is not None and method_definition.persist is not None
|
||||
else self._definition.persist
|
||||
)
|
||||
if persist_definition is None or not persist_definition.enabled:
|
||||
return
|
||||
|
||||
from crewai.flow.persistence.decorators import PersistenceDecorator
|
||||
|
||||
backend = self.persistence or self._persist_backend_for(persist_definition)
|
||||
PersistenceDecorator.persist_state(
|
||||
self, method_name, backend, verbose=persist_definition.verbose
|
||||
)
|
||||
|
||||
def _persist_backend_for(
|
||||
self, persist_definition: FlowPersistenceDefinition
|
||||
) -> FlowPersistence:
|
||||
cached = self._persist_backends.get(id(persist_definition))
|
||||
if cached is None:
|
||||
cached = self._resolve_persist_backend(persist_definition)
|
||||
self._persist_backends[id(persist_definition)] = cached
|
||||
return cached
|
||||
|
||||
def _resolve_persist_backend(
|
||||
self, persist_definition: FlowPersistenceDefinition
|
||||
) -> FlowPersistence:
|
||||
if persist_definition.persistence is None:
|
||||
from crewai.flow.persistence.factory import default_flow_persistence
|
||||
|
||||
return default_flow_persistence()
|
||||
resolved = _resolve_persistence(persist_definition.persistence)
|
||||
if not isinstance(resolved, FlowPersistence):
|
||||
raise ValueError(
|
||||
f"Cannot resolve persistence backend "
|
||||
f"{persist_definition.persistence!r} from the flow definition "
|
||||
f"for flow {self._definition.name!r}."
|
||||
)
|
||||
return resolved
|
||||
|
||||
def _copy_and_serialize_state(self) -> dict[str, Any]:
|
||||
state_copy = self._copy_state()
|
||||
if isinstance(state_copy, BaseModel):
|
||||
|
||||
@@ -17,17 +17,22 @@ class InvalidActionRefError(ValueError):
|
||||
super().__init__(f"invalid callable {ref!r}; expected 'module:qualname'")
|
||||
|
||||
|
||||
def _resolve_code_action(
|
||||
flow: Flow[Any], action: FlowActionDefinition
|
||||
) -> Callable[..., Any]:
|
||||
ref = action.ref
|
||||
def import_ref(ref: str) -> Any:
|
||||
"""Import the object a `module:qualname` reference points to."""
|
||||
module_name, _, qualname = ref.partition(":")
|
||||
if "<" in ref or not module_name or not qualname:
|
||||
raise InvalidActionRefError(ref)
|
||||
try:
|
||||
target = attrgetter(qualname)(importlib.import_module(module_name))
|
||||
return attrgetter(qualname)(importlib.import_module(module_name))
|
||||
except (ImportError, AttributeError) as e:
|
||||
raise InvalidActionRefError(ref) from e
|
||||
|
||||
|
||||
def _resolve_code_action(
|
||||
flow: Flow[Any], action: FlowActionDefinition
|
||||
) -> Callable[..., Any]:
|
||||
ref = action.ref
|
||||
target = import_ref(ref)
|
||||
if not callable(target):
|
||||
raise InvalidActionRefError(ref)
|
||||
handler = cast(Callable[..., Any], target)
|
||||
|
||||
@@ -1,16 +1,27 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
from typing import Any, ClassVar
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.flow_events import (
|
||||
FlowCreatedEvent,
|
||||
FlowFinishedEvent,
|
||||
FlowStartedEvent,
|
||||
MethodExecutionFinishedEvent,
|
||||
MethodExecutionStartedEvent,
|
||||
)
|
||||
from crewai.flow import Flow, and_, listen, or_, router, start
|
||||
from crewai.flow import Flow, and_, human_feedback, listen, or_, router, start
|
||||
from crewai.flow.async_feedback import PendingFeedbackContext
|
||||
from crewai.flow.flow import FlowState
|
||||
from crewai.flow.flow_definition import FlowDefinition
|
||||
from crewai.flow.flow_definition import FlowConfigDefinition, FlowDefinition
|
||||
from crewai.flow.persistence import persist
|
||||
from crewai.flow.persistence.base import FlowPersistence
|
||||
from crewai.state.checkpoint_config import CheckpointConfig
|
||||
from crewai.types.streaming import FlowStreamingOutput
|
||||
|
||||
|
||||
class ChainFlow(Flow):
|
||||
@@ -550,3 +561,454 @@ def test_unknown_state_type_falls_back_to_dict(caplog):
|
||||
result = flow.kickoff()
|
||||
assert result == "hello"
|
||||
assert flow.state["begin_ran"] is True
|
||||
|
||||
|
||||
class StubInputProvider:
|
||||
def request_input(self, message, flow, metadata=None):
|
||||
return "stub"
|
||||
|
||||
|
||||
class ConfiguredFlow(Flow):
|
||||
suppress_flow_events = True
|
||||
max_method_calls = 5
|
||||
input_provider = StubInputProvider()
|
||||
|
||||
@start()
|
||||
def begin(self):
|
||||
return "configured"
|
||||
|
||||
|
||||
SUPPRESSED_CHAIN_YAML = (
|
||||
CHAIN_YAML
|
||||
+ """
|
||||
config:
|
||||
suppress_flow_events: true
|
||||
"""
|
||||
)
|
||||
|
||||
CAPPED_LOOP_YAML = (
|
||||
LOOP_YAML
|
||||
+ """
|
||||
config:
|
||||
max_method_calls: 2
|
||||
"""
|
||||
)
|
||||
|
||||
STREAMING_CHAIN_YAML = (
|
||||
CHAIN_YAML
|
||||
+ """
|
||||
config:
|
||||
stream: true
|
||||
"""
|
||||
)
|
||||
|
||||
DEFERRED_CHAIN_YAML = (
|
||||
CHAIN_YAML
|
||||
+ """
|
||||
config:
|
||||
defer_trace_finalization: true
|
||||
"""
|
||||
)
|
||||
|
||||
INPUT_PROVIDER_CHAIN_YAML = (
|
||||
CHAIN_YAML
|
||||
+ f"""
|
||||
config:
|
||||
input_provider: {__name__}:StubInputProvider
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def _run_capturing_flow_lifecycle(yaml_str, event_types):
|
||||
events = []
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
for event_type in event_types:
|
||||
|
||||
@crewai_event_bus.on(event_type)
|
||||
def capture(source, event):
|
||||
events.append(event)
|
||||
|
||||
flow = Flow.from_definition(FlowDefinition.from_yaml(yaml_str))
|
||||
result = flow.kickoff()
|
||||
return flow, result, events
|
||||
|
||||
|
||||
_LIFECYCLE_EVENTS = [
|
||||
FlowCreatedEvent,
|
||||
FlowStartedEvent,
|
||||
FlowFinishedEvent,
|
||||
MethodExecutionStartedEvent,
|
||||
MethodExecutionFinishedEvent,
|
||||
]
|
||||
|
||||
|
||||
def test_config_suppress_flow_events_from_yaml():
|
||||
twin_events = []
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
for event_type in _LIFECYCLE_EVENTS:
|
||||
|
||||
@crewai_event_bus.on(event_type)
|
||||
def capture(source, event):
|
||||
twin_events.append(type(event).__name__)
|
||||
|
||||
twin_result = ChainFlow(suppress_flow_events=True).kickoff()
|
||||
|
||||
flow, result, events = _run_capturing_flow_lifecycle(
|
||||
SUPPRESSED_CHAIN_YAML, _LIFECYCLE_EVENTS
|
||||
)
|
||||
assert result == twin_result == "confirmed:True"
|
||||
assert flow.suppress_flow_events is True
|
||||
assert [type(e).__name__ for e in events] == twin_events
|
||||
assert not any(
|
||||
isinstance(e, (MethodExecutionStartedEvent, MethodExecutionFinishedEvent))
|
||||
for e in events
|
||||
)
|
||||
|
||||
|
||||
def test_config_max_method_calls_from_yaml():
|
||||
flow = Flow.from_definition(FlowDefinition.from_yaml(CAPPED_LOOP_YAML))
|
||||
with pytest.raises(RecursionError, match="has been called 2 times"):
|
||||
flow.kickoff()
|
||||
|
||||
|
||||
def test_config_stream_from_yaml():
|
||||
flow = Flow.from_definition(FlowDefinition.from_yaml(STREAMING_CHAIN_YAML))
|
||||
streaming = flow.kickoff()
|
||||
assert isinstance(streaming, FlowStreamingOutput)
|
||||
for _ in streaming:
|
||||
pass
|
||||
assert streaming.result == "confirmed:True"
|
||||
assert flow.stream is True
|
||||
|
||||
|
||||
def test_config_defer_trace_finalization_from_yaml():
|
||||
_, _, baseline_events = _run_capturing_flow_lifecycle(
|
||||
CHAIN_YAML, [FlowFinishedEvent]
|
||||
)
|
||||
assert len(baseline_events) == 1
|
||||
|
||||
flow, result, deferred_events = _run_capturing_flow_lifecycle(
|
||||
DEFERRED_CHAIN_YAML, [FlowFinishedEvent]
|
||||
)
|
||||
assert result == "confirmed:True"
|
||||
assert flow.defer_trace_finalization is True
|
||||
assert deferred_events == []
|
||||
|
||||
|
||||
def test_config_checkpoint_from_yaml(tmp_path):
|
||||
yaml_str = (
|
||||
CHAIN_YAML
|
||||
+ f"""
|
||||
config:
|
||||
checkpoint:
|
||||
location: {tmp_path}
|
||||
"""
|
||||
)
|
||||
flow = Flow.from_definition(FlowDefinition.from_yaml(yaml_str))
|
||||
assert isinstance(flow.checkpoint, CheckpointConfig)
|
||||
assert flow.checkpoint.location == str(tmp_path)
|
||||
|
||||
|
||||
def test_config_input_provider_from_yaml():
|
||||
flow = Flow.from_definition(FlowDefinition.from_yaml(INPUT_PROVIDER_CHAIN_YAML))
|
||||
assert isinstance(flow.input_provider, StubInputProvider)
|
||||
|
||||
|
||||
def test_round_trip_config_equivalence():
|
||||
class_flow = ConfiguredFlow()
|
||||
definition = FlowDefinition.from_yaml(ConfiguredFlow.flow_definition().to_yaml())
|
||||
definition_flow = Flow.from_definition(definition)
|
||||
|
||||
assert definition.config.suppress_flow_events is True
|
||||
assert definition.config.max_method_calls == 5
|
||||
assert definition.config.input_provider == f"{__name__}:StubInputProvider"
|
||||
assert definition_flow.suppress_flow_events is class_flow.suppress_flow_events
|
||||
assert definition_flow.max_method_calls == class_flow.max_method_calls
|
||||
assert isinstance(definition_flow.input_provider, StubInputProvider)
|
||||
|
||||
class_result, class_events = _run_with_events(class_flow)
|
||||
definition_result, definition_events = _run_with_events(definition_flow)
|
||||
assert definition_result == class_result == "configured"
|
||||
assert definition_events == class_events
|
||||
|
||||
|
||||
def test_unknown_schema_rejected():
|
||||
with pytest.raises(ValidationError, match="schema"):
|
||||
FlowDefinition.from_dict(
|
||||
{
|
||||
"schema": "crewai.flow/v2",
|
||||
"name": "FutureSchema",
|
||||
"methods": {
|
||||
"begin": {"start": True, "do": {"ref": f"{__name__}:ChainFlow.begin"}}
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_flow_config_definition_mirrors_flow_fields():
|
||||
for name, field in FlowConfigDefinition.model_fields.items():
|
||||
assert name in Flow.model_fields
|
||||
assert field.get_default(call_default_factory=True) == Flow.model_fields[
|
||||
name
|
||||
].get_default(call_default_factory=True)
|
||||
|
||||
|
||||
class DefinitionStoreBackend(FlowPersistence):
|
||||
persistence_type: str = "DefinitionStoreBackend"
|
||||
store: str = "default"
|
||||
|
||||
saves: ClassVar[dict[str, list[tuple[str, dict[str, Any]]]]] = defaultdict(list)
|
||||
pending: ClassVar[dict[str, tuple[dict[str, Any], PendingFeedbackContext]]] = {}
|
||||
|
||||
def init_db(self) -> None:
|
||||
pass
|
||||
|
||||
def save_state(self, flow_uuid, method_name, state_data):
|
||||
data = state_data if isinstance(state_data, dict) else state_data.model_dump()
|
||||
DefinitionStoreBackend.saves[self.store].append((method_name, dict(data)))
|
||||
|
||||
def load_state(self, flow_uuid):
|
||||
for _, data in reversed(DefinitionStoreBackend.saves[self.store]):
|
||||
if data.get("id") == flow_uuid:
|
||||
return data
|
||||
return None
|
||||
|
||||
def save_pending_feedback(self, flow_uuid, context, state_data):
|
||||
data = state_data if isinstance(state_data, dict) else state_data.model_dump()
|
||||
DefinitionStoreBackend.pending[flow_uuid] = (dict(data), context)
|
||||
|
||||
def load_pending_feedback(self, flow_uuid):
|
||||
return DefinitionStoreBackend.pending.get(flow_uuid)
|
||||
|
||||
def clear_pending_feedback(self, flow_uuid):
|
||||
DefinitionStoreBackend.pending.pop(flow_uuid, None)
|
||||
|
||||
|
||||
def _saved_methods(store):
|
||||
return [name for name, _ in DefinitionStoreBackend.saves[store]]
|
||||
|
||||
|
||||
class PersistedFlow(Flow):
|
||||
@start()
|
||||
def first(self):
|
||||
self.state["count"] = self.state.get("count", 0) + 1
|
||||
return "one"
|
||||
|
||||
@listen(first)
|
||||
def second(self):
|
||||
self.state["count"] += 1
|
||||
return "two"
|
||||
|
||||
|
||||
def _flow_level_persist_yaml(store):
|
||||
return f"""
|
||||
schema: crewai.flow/v1
|
||||
name: PersistedFlow
|
||||
persist:
|
||||
enabled: true
|
||||
persistence:
|
||||
persistence_type: DefinitionStoreBackend
|
||||
store: {store}
|
||||
methods:
|
||||
first:
|
||||
do:
|
||||
ref: {__name__}:PersistedFlow.first
|
||||
start: true
|
||||
second:
|
||||
do:
|
||||
ref: {__name__}:PersistedFlow.second
|
||||
listen: first
|
||||
"""
|
||||
|
||||
|
||||
def _method_level_persist_yaml(store):
|
||||
return f"""
|
||||
schema: crewai.flow/v1
|
||||
name: PersistedFlow
|
||||
methods:
|
||||
first:
|
||||
do:
|
||||
ref: {__name__}:PersistedFlow.first
|
||||
start: true
|
||||
persist:
|
||||
enabled: true
|
||||
persistence:
|
||||
persistence_type: DefinitionStoreBackend
|
||||
store: {store}
|
||||
second:
|
||||
do:
|
||||
ref: {__name__}:PersistedFlow.second
|
||||
listen: first
|
||||
"""
|
||||
|
||||
|
||||
_CLASS_LEVEL_BACKEND = DefinitionStoreBackend(store="class-decorator")
|
||||
|
||||
|
||||
@persist(_CLASS_LEVEL_BACKEND)
|
||||
class ClassPersistedFlow(Flow):
|
||||
@start()
|
||||
def first(self):
|
||||
self.state["count"] = self.state.get("count", 0) + 1
|
||||
return "one"
|
||||
|
||||
@listen(first)
|
||||
def second(self):
|
||||
self.state["count"] += 1
|
||||
return "two"
|
||||
|
||||
|
||||
_COMBINED_BACKEND = DefinitionStoreBackend(store="combined-decorator")
|
||||
|
||||
|
||||
@persist(_COMBINED_BACKEND)
|
||||
class CombinedPersistedFlow(Flow):
|
||||
@start()
|
||||
@persist(_COMBINED_BACKEND)
|
||||
def first(self):
|
||||
return "one"
|
||||
|
||||
@listen(first)
|
||||
def second(self):
|
||||
return "two"
|
||||
|
||||
|
||||
class MethodPersistedFlow(Flow):
|
||||
@start()
|
||||
@persist(DefinitionStoreBackend(store="method-decorator"))
|
||||
def first(self):
|
||||
self.state["count"] = self.state.get("count", 0) + 1
|
||||
return "one"
|
||||
|
||||
@listen(first)
|
||||
def second(self):
|
||||
self.state["count"] += 1
|
||||
return "two"
|
||||
|
||||
|
||||
def test_flow_level_persist_from_yaml_saves_once_per_method():
|
||||
yaml_str = _flow_level_persist_yaml("yaml-flow-level")
|
||||
flow = Flow.from_definition(FlowDefinition.from_yaml(yaml_str))
|
||||
result = flow.kickoff()
|
||||
|
||||
assert result == "two"
|
||||
assert _saved_methods("yaml-flow-level") == ["first", "second"]
|
||||
_, final_save = DefinitionStoreBackend.saves["yaml-flow-level"][-1]
|
||||
assert final_save["count"] == 2
|
||||
assert final_save["id"] == flow.state["id"]
|
||||
|
||||
|
||||
def test_method_level_persist_from_yaml_saves_only_that_method():
|
||||
yaml_str = _method_level_persist_yaml("yaml-method-level")
|
||||
flow = Flow.from_definition(FlowDefinition.from_yaml(yaml_str))
|
||||
flow.kickoff()
|
||||
|
||||
assert _saved_methods("yaml-method-level") == ["first"]
|
||||
_, save = DefinitionStoreBackend.saves["yaml-method-level"][0]
|
||||
assert save["count"] == 1
|
||||
|
||||
|
||||
def test_method_level_persist_disabled_wins_over_flow_level():
|
||||
yaml_str = f"""
|
||||
schema: crewai.flow/v1
|
||||
name: PersistedFlow
|
||||
persist:
|
||||
enabled: true
|
||||
persistence:
|
||||
persistence_type: DefinitionStoreBackend
|
||||
store: yaml-opt-out
|
||||
methods:
|
||||
first:
|
||||
do:
|
||||
ref: {__name__}:PersistedFlow.first
|
||||
start: true
|
||||
second:
|
||||
do:
|
||||
ref: {__name__}:PersistedFlow.second
|
||||
listen: first
|
||||
persist:
|
||||
enabled: false
|
||||
"""
|
||||
flow = Flow.from_definition(FlowDefinition.from_yaml(yaml_str))
|
||||
flow.kickoff()
|
||||
|
||||
assert _saved_methods("yaml-opt-out") == ["first"]
|
||||
|
||||
|
||||
def test_persist_restore_by_id_from_yaml():
|
||||
yaml_str = _flow_level_persist_yaml("yaml-restore")
|
||||
|
||||
flow1 = Flow.from_definition(FlowDefinition.from_yaml(yaml_str))
|
||||
flow1.kickoff()
|
||||
assert flow1.state["count"] == 2
|
||||
|
||||
flow2 = Flow.from_definition(FlowDefinition.from_yaml(yaml_str))
|
||||
flow2.kickoff(inputs={"id": flow1.state["id"]})
|
||||
assert flow2.state["count"] == 4
|
||||
|
||||
|
||||
def test_combined_class_and_method_persist_saves_once_per_method():
|
||||
before = len(DefinitionStoreBackend.saves["combined-decorator"])
|
||||
CombinedPersistedFlow().kickoff()
|
||||
|
||||
assert _saved_methods("combined-decorator")[before:] == ["first", "second"]
|
||||
|
||||
|
||||
def test_method_level_persist_decorator_saves_only_that_method():
|
||||
before = len(DefinitionStoreBackend.saves["method-decorator"])
|
||||
MethodPersistedFlow().kickoff()
|
||||
|
||||
assert _saved_methods("method-decorator")[before:] == ["first"]
|
||||
|
||||
|
||||
def test_round_trip_persist_equivalence():
|
||||
definition = FlowDefinition.from_yaml(ClassPersistedFlow.flow_definition().to_yaml())
|
||||
|
||||
before = len(DefinitionStoreBackend.saves["class-decorator"])
|
||||
flow = Flow.from_definition(definition)
|
||||
flow.kickoff()
|
||||
|
||||
assert _saved_methods("class-decorator")[before:] == ["first", "second"]
|
||||
|
||||
|
||||
def test_instance_persistence_overrides_definition_backend():
|
||||
before = len(DefinitionStoreBackend.saves["method-decorator"])
|
||||
flow = MethodPersistedFlow(
|
||||
persistence=DefinitionStoreBackend(store="instance-override")
|
||||
)
|
||||
flow.kickoff()
|
||||
|
||||
assert _saved_methods("instance-override") == ["first"]
|
||||
assert len(DefinitionStoreBackend.saves["method-decorator"]) == before
|
||||
|
||||
|
||||
def test_resume_synthetic_completion_persists():
|
||||
backend = DefinitionStoreBackend(store="resume-synthetic")
|
||||
|
||||
class ResumableFlow(Flow):
|
||||
@start()
|
||||
@persist(DefinitionStoreBackend(store="resume-synthetic"))
|
||||
@human_feedback(message="Review:")
|
||||
def generate(self):
|
||||
return "content"
|
||||
|
||||
@listen(generate)
|
||||
def process(self, result):
|
||||
return "done"
|
||||
|
||||
context = PendingFeedbackContext(
|
||||
flow_id="resume-persist-1",
|
||||
flow_class="ResumableFlow",
|
||||
method_name="generate",
|
||||
method_output="content",
|
||||
message="Review:",
|
||||
)
|
||||
backend.save_pending_feedback(
|
||||
"resume-persist-1", context, {"id": "resume-persist-1"}
|
||||
)
|
||||
|
||||
flow = ResumableFlow.from_pending("resume-persist-1", backend)
|
||||
result = flow.resume("looks good")
|
||||
|
||||
assert result == "done"
|
||||
assert _saved_methods("resume-synthetic") == ["generate"]
|
||||
|
||||
Reference in New Issue
Block a user