""" Decorators for flow state persistence. Example: ```python from crewai.flow.flow import Flow, start from crewai.flow.persistence import persist, SQLiteFlowPersistence class MyFlow(Flow): @start() @persist(SQLiteFlowPersistence()) def sync_method(self): # Synchronous method implementation pass @start() @persist(SQLiteFlowPersistence()) async def async_method(self): # Asynchronous method implementation await some_async_operation() ``` """ import asyncio import functools import logging from typing import ( Any, Callable, Optional, Type, TypeVar, Union, cast, ) from pydantic import BaseModel from crewai.flow.persistence.base import FlowPersistence from crewai.flow.persistence.sqlite import SQLiteFlowPersistence from crewai.utilities.printer import Printer logger = logging.getLogger(__name__) T = TypeVar("T") # Constants for log messages LOG_MESSAGES = { "save_state": "Saving flow state to memory for ID: {}", "save_error": "Failed to persist state for method {}: {}", "state_missing": "Flow instance has no state", "id_missing": "Flow state must have an 'id' field for persistence" } class PersistenceDecorator: """Class to handle flow state persistence with consistent logging.""" _printer = Printer() # Class-level printer instance @classmethod def persist_state(cls, flow_instance: Any, method_name: str, persistence_instance: FlowPersistence) -> None: """Persist flow state with proper error handling and logging. This method handles the persistence of flow state data, including proper error handling and colored console output for status updates. Args: flow_instance: The flow instance whose state to persist method_name: Name of the method that triggered persistence persistence_instance: The persistence backend to use Raises: ValueError: If flow has no state or state lacks an ID RuntimeError: If state persistence fails AttributeError: If flow instance lacks required state attributes """ try: state = getattr(flow_instance, 'state', None) if state is None: raise ValueError("Flow instance has no state") flow_uuid: Optional[str] = None if isinstance(state, dict): flow_uuid = state.get('id') elif isinstance(state, BaseModel): flow_uuid = getattr(state, 'id', None) if not flow_uuid: raise ValueError("Flow state must have an 'id' field for persistence") # Log state saving with consistent message cls._printer.print(LOG_MESSAGES["save_state"].format(flow_uuid), color="cyan") logger.info(LOG_MESSAGES["save_state"].format(flow_uuid)) try: persistence_instance.save_state( flow_uuid=flow_uuid, method_name=method_name, state_data=state, ) except Exception as e: error_msg = LOG_MESSAGES["save_error"].format(method_name, str(e)) cls._printer.print(error_msg, color="red") logger.error(error_msg) raise RuntimeError(f"State persistence failed: {str(e)}") from e except AttributeError: error_msg = LOG_MESSAGES["state_missing"] cls._printer.print(error_msg, color="red") logger.error(error_msg) raise ValueError(error_msg) except (TypeError, ValueError) as e: error_msg = LOG_MESSAGES["id_missing"] cls._printer.print(error_msg, color="red") logger.error(error_msg) raise ValueError(error_msg) from e def persist(persistence: Optional[FlowPersistence] = None): """Decorator to persist flow state. This decorator can be applied at either the class level or method level. When applied at the class level, it automatically persists all flow method states. When applied at the method level, it persists only that method's state. Args: persistence: Optional FlowPersistence implementation to use. If not provided, uses SQLiteFlowPersistence. Returns: A decorator that can be applied to either a class or method Raises: ValueError: If the flow state doesn't have an 'id' field RuntimeError: If state persistence fails Example: @persist # Class-level persistence with default SQLite class MyFlow(Flow[MyState]): @start() def begin(self): pass """ def decorator(target: Union[Type, Callable[..., T]]) -> Union[Type, Callable[..., T]]: """Decorator that handles both class and method decoration.""" actual_persistence = persistence or SQLiteFlowPersistence() if isinstance(target, type): # Class decoration original_init = getattr(target, "__init__") @functools.wraps(original_init) def new_init(self: Any, *args: Any, **kwargs: Any) -> None: if 'persistence' not in kwargs: kwargs['persistence'] = actual_persistence original_init(self, *args, **kwargs) setattr(target, "__init__", new_init) # Store original methods to preserve their decorators original_methods = {} for name, method in target.__dict__.items(): if callable(method) and ( hasattr(method, "__is_start_method__") or hasattr(method, "__trigger_methods__") or hasattr(method, "__condition_type__") or hasattr(method, "__is_flow_method__") or hasattr(method, "__is_router__") ): original_methods[name] = method # Create wrapped versions of the methods that include persistence for name, method in original_methods.items(): if asyncio.iscoroutinefunction(method): # Create a closure to capture the current name and method def create_async_wrapper(method_name: str, original_method: Callable): @functools.wraps(original_method) async def method_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: result = await original_method(self, *args, **kwargs) PersistenceDecorator.persist_state(self, method_name, actual_persistence) return result return method_wrapper wrapped = create_async_wrapper(name, method) # Preserve all original decorators and attributes for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]: if hasattr(method, attr): setattr(wrapped, attr, getattr(method, attr)) setattr(wrapped, "__is_flow_method__", True) # Update the class with the wrapped method setattr(target, name, wrapped) else: # Create a closure to capture the current name and method def create_sync_wrapper(method_name: str, original_method: Callable): @functools.wraps(original_method) def method_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: result = original_method(self, *args, **kwargs) PersistenceDecorator.persist_state(self, method_name, actual_persistence) return result return method_wrapper wrapped = create_sync_wrapper(name, method) # Preserve all original decorators and attributes for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]: if hasattr(method, attr): setattr(wrapped, attr, getattr(method, attr)) setattr(wrapped, "__is_flow_method__", True) # Update the class with the wrapped method setattr(target, name, wrapped) return target else: # Method decoration method = target setattr(method, "__is_flow_method__", True) if asyncio.iscoroutinefunction(method): @functools.wraps(method) async def method_async_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T: method_coro = method(flow_instance, *args, **kwargs) if asyncio.iscoroutine(method_coro): result = await method_coro else: result = method_coro PersistenceDecorator.persist_state(flow_instance, method.__name__, actual_persistence) return result for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]: if hasattr(method, attr): setattr(method_async_wrapper, attr, getattr(method, attr)) setattr(method_async_wrapper, "__is_flow_method__", True) return cast(Callable[..., T], method_async_wrapper) else: @functools.wraps(method) def method_sync_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T: result = method(flow_instance, *args, **kwargs) PersistenceDecorator.persist_state(flow_instance, method.__name__, actual_persistence) return result for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]: if hasattr(method, attr): setattr(method_sync_wrapper, attr, getattr(method, attr)) setattr(method_sync_wrapper, "__is_flow_method__", True) return cast(Callable[..., T], method_sync_wrapper) return decorator