mirror of
https://github.com/crewAIInc/crewAI.git
synced 2025-12-16 12:28:30 +00:00
Stateful flows (#1931)
* fix: ensure persisted state overrides class defaults - Remove early return in Flow.__init__ to allow proper state initialization - Add test_flow_default_override.py to verify state override behavior - Fix issue where default values weren't being overridden by persisted state Fixes the issue where persisted state values weren't properly overriding class defaults when restarting a flow with a previously saved state ID. Co-Authored-By: Joe Moura <joao@crewai.com> * test: improve state restoration verification with has_set_count flag Co-Authored-By: Joe Moura <joao@crewai.com> * test: add has_set_count field to PoemState Co-Authored-By: Joe Moura <joao@crewai.com> * refactoring test * fix: ensure persisted state overrides class defaults - Remove early return in Flow.__init__ to allow proper state initialization - Add test_flow_default_override.py to verify state override behavior - Fix issue where default values weren't being overridden by persisted state Fixes the issue where persisted state values weren't properly overriding class defaults when restarting a flow with a previously saved state ID. Co-Authored-By: Joe Moura <joao@crewai.com> * test: improve state restoration verification with has_set_count flag Co-Authored-By: Joe Moura <joao@crewai.com> * test: add has_set_count field to PoemState Co-Authored-By: Joe Moura <joao@crewai.com> * refactoring test * Fixing flow state * fixing peristed stateful flows * linter * type fix --------- Co-authored-by: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Co-authored-by: Joe Moura <joao@crewai.com>
This commit is contained in:
@@ -447,14 +447,12 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
def __init__(
|
||||
self,
|
||||
persistence: Optional[FlowPersistence] = None,
|
||||
restore_uuid: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize a new Flow instance.
|
||||
|
||||
Args:
|
||||
persistence: Optional persistence backend for storing flow states
|
||||
restore_uuid: Optional UUID to restore state from persistence
|
||||
**kwargs: Additional state values to initialize or override
|
||||
"""
|
||||
# Initialize basic instance attributes
|
||||
@@ -464,64 +462,12 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
self._method_outputs: List[Any] = [] # List to store all method outputs
|
||||
self._persistence: Optional[FlowPersistence] = persistence
|
||||
|
||||
# Validate state model before initialization
|
||||
if isinstance(self.initial_state, type):
|
||||
if issubclass(self.initial_state, BaseModel) and not issubclass(
|
||||
self.initial_state, FlowState
|
||||
):
|
||||
# Check if model has id field
|
||||
model_fields = getattr(self.initial_state, "model_fields", None)
|
||||
if not model_fields or "id" not in model_fields:
|
||||
raise ValueError("Flow state model must have an 'id' field")
|
||||
# Initialize state with initial values
|
||||
self._state = self._create_initial_state()
|
||||
|
||||
# Handle persistence and potential ID conflicts
|
||||
stored_state = None
|
||||
if self._persistence is not None:
|
||||
if (
|
||||
restore_uuid
|
||||
and kwargs
|
||||
and "id" in kwargs
|
||||
and restore_uuid != kwargs["id"]
|
||||
):
|
||||
raise ValueError(
|
||||
f"Conflicting IDs provided: restore_uuid='{restore_uuid}' "
|
||||
f"vs kwargs['id']='{kwargs['id']}'. Use only one ID for restoration."
|
||||
)
|
||||
|
||||
# Attempt to load state, prioritizing restore_uuid
|
||||
if restore_uuid:
|
||||
self._log_flow_event(f"Loading flow state from memory for UUID: {restore_uuid}", color="bold_yellow")
|
||||
stored_state = self._persistence.load_state(restore_uuid)
|
||||
if not stored_state:
|
||||
raise ValueError(
|
||||
f"No state found for restore_uuid='{restore_uuid}'"
|
||||
)
|
||||
elif kwargs and "id" in kwargs:
|
||||
self._log_flow_event(f"Loading flow state from memory for ID: {kwargs['id']}", color="bold_yellow")
|
||||
stored_state = self._persistence.load_state(kwargs["id"])
|
||||
if not stored_state:
|
||||
# For kwargs["id"], we allow creating new state if not found
|
||||
self._state = self._create_initial_state()
|
||||
if kwargs:
|
||||
self._initialize_state(kwargs)
|
||||
return
|
||||
|
||||
# Initialize state based on persistence and kwargs
|
||||
if stored_state:
|
||||
# Create initial state and restore from persistence
|
||||
self._state = self._create_initial_state()
|
||||
self._restore_state(stored_state)
|
||||
# Apply any additional kwargs to override specific fields
|
||||
if kwargs:
|
||||
filtered_kwargs = {k: v for k, v in kwargs.items() if k != "id"}
|
||||
if filtered_kwargs:
|
||||
self._initialize_state(filtered_kwargs)
|
||||
else:
|
||||
# No stored state, create new state with initial values
|
||||
self._state = self._create_initial_state()
|
||||
# Apply any additional kwargs
|
||||
if kwargs:
|
||||
self._initialize_state(kwargs)
|
||||
# Apply any additional kwargs
|
||||
if kwargs:
|
||||
self._initialize_state(kwargs)
|
||||
|
||||
self._telemetry.flow_creation_span(self.__class__.__name__)
|
||||
|
||||
@@ -635,18 +581,18 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
@property
|
||||
def flow_id(self) -> str:
|
||||
"""Returns the unique identifier of this flow instance.
|
||||
|
||||
|
||||
This property provides a consistent way to access the flow's unique identifier
|
||||
regardless of the underlying state implementation (dict or BaseModel).
|
||||
|
||||
|
||||
Returns:
|
||||
str: The flow's unique identifier, or an empty string if not found
|
||||
|
||||
|
||||
Note:
|
||||
This property safely handles both dictionary and BaseModel state types,
|
||||
returning an empty string if the ID cannot be retrieved rather than raising
|
||||
an exception.
|
||||
|
||||
|
||||
Example:
|
||||
```python
|
||||
flow = MyFlow()
|
||||
@@ -656,7 +602,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
try:
|
||||
if not hasattr(self, '_state'):
|
||||
return ""
|
||||
|
||||
|
||||
if isinstance(self._state, dict):
|
||||
return str(self._state.get("id", ""))
|
||||
elif isinstance(self._state, BaseModel):
|
||||
@@ -731,7 +677,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
"""
|
||||
# When restoring from persistence, use the stored ID
|
||||
stored_id = stored_state.get("id")
|
||||
self._log_flow_event(f"Restoring flow state from memory for ID: {stored_id}", color="bold_yellow")
|
||||
if not stored_id:
|
||||
raise ValueError("Stored state must have an 'id' field")
|
||||
|
||||
@@ -755,6 +700,36 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
raise TypeError(f"State must be dict or BaseModel, got {type(self._state)}")
|
||||
|
||||
def kickoff(self, inputs: Optional[Dict[str, Any]] = None) -> Any:
|
||||
"""Start the flow execution.
|
||||
|
||||
Args:
|
||||
inputs: Optional dictionary containing input values and potentially a state ID to restore
|
||||
"""
|
||||
# Handle state restoration if ID is provided in inputs
|
||||
if inputs and 'id' in inputs and self._persistence is not None:
|
||||
restore_uuid = inputs['id']
|
||||
stored_state = self._persistence.load_state(restore_uuid)
|
||||
|
||||
# Override the id in the state if it exists in inputs
|
||||
if 'id' in inputs:
|
||||
if isinstance(self._state, dict):
|
||||
self._state['id'] = inputs['id']
|
||||
elif isinstance(self._state, BaseModel):
|
||||
setattr(self._state, 'id', inputs['id'])
|
||||
|
||||
if stored_state:
|
||||
self._log_flow_event(f"Loading flow state from memory for UUID: {restore_uuid}", color="yellow")
|
||||
# Restore the state
|
||||
self._restore_state(stored_state)
|
||||
else:
|
||||
self._log_flow_event(f"No flow state found for UUID: {restore_uuid}", color="red")
|
||||
|
||||
# Apply any additional inputs after restoration
|
||||
filtered_inputs = {k: v for k, v in inputs.items() if k != 'id'}
|
||||
if filtered_inputs:
|
||||
self._initialize_state(filtered_inputs)
|
||||
|
||||
# Start flow execution
|
||||
self.event_emitter.send(
|
||||
self,
|
||||
event=FlowStartedEvent(
|
||||
@@ -762,10 +737,11 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
flow_name=self.__class__.__name__,
|
||||
),
|
||||
)
|
||||
self._log_flow_event(f"Flow started with ID: {self.flow_id}", color="yellow")
|
||||
self._log_flow_event(f"Flow started with ID: {self.flow_id}", color="bold_magenta")
|
||||
|
||||
if inputs is not None:
|
||||
if inputs is not None and 'id' not in inputs:
|
||||
self._initialize_state(inputs)
|
||||
|
||||
return asyncio.run(self.kickoff_async())
|
||||
|
||||
async def kickoff_async(self, inputs: Optional[Dict[str, Any]] = None) -> Any:
|
||||
@@ -1010,18 +986,18 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
|
||||
def _log_flow_event(self, message: str, color: str = "yellow", level: str = "info") -> None:
|
||||
"""Centralized logging method for flow events.
|
||||
|
||||
|
||||
This method provides a consistent interface for logging flow-related events,
|
||||
combining both console output with colors and proper logging levels.
|
||||
|
||||
|
||||
Args:
|
||||
message: The message to log
|
||||
color: Color to use for console output (default: yellow)
|
||||
Available colors: purple, red, bold_green, bold_purple,
|
||||
bold_blue, yellow, bold_yellow
|
||||
bold_blue, yellow, yellow
|
||||
level: Log level to use (default: info)
|
||||
Supported levels: info, warning
|
||||
|
||||
|
||||
Note:
|
||||
This method uses the Printer utility for colored console output
|
||||
and the standard logging module for log level support.
|
||||
@@ -1031,7 +1007,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
logger.info(message)
|
||||
elif level == "warning":
|
||||
logger.warning(message)
|
||||
|
||||
|
||||
def plot(self, filename: str = "crewai_flow") -> None:
|
||||
self._telemetry.flow_plotting_span(
|
||||
self.__class__.__name__, list(self._methods.keys())
|
||||
|
||||
@@ -54,57 +54,44 @@ LOG_MESSAGES = {
|
||||
|
||||
class PersistenceDecorator:
|
||||
"""Class to handle flow state persistence with consistent logging."""
|
||||
|
||||
|
||||
_printer = Printer() # Class-level printer instance
|
||||
|
||||
|
||||
@classmethod
|
||||
def persist_state(cls, flow_instance: Any, method_name: str, persistence_instance: FlowPersistence) -> None:
|
||||
"""Persist flow state with proper error handling and logging.
|
||||
|
||||
|
||||
This method handles the persistence of flow state data, including proper
|
||||
error handling and colored console output for status updates.
|
||||
|
||||
|
||||
Args:
|
||||
flow_instance: The flow instance whose state to persist
|
||||
method_name: Name of the method that triggered persistence
|
||||
persistence_instance: The persistence backend to use
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: If flow has no state or state lacks an ID
|
||||
RuntimeError: If state persistence fails
|
||||
AttributeError: If flow instance lacks required state attributes
|
||||
|
||||
Note:
|
||||
Uses bold_yellow color for success messages and red for errors.
|
||||
All operations are logged at appropriate levels (info/error).
|
||||
|
||||
Example:
|
||||
```python
|
||||
@persist
|
||||
def my_flow_method(self):
|
||||
# Method implementation
|
||||
pass
|
||||
# State will be automatically persisted after method execution
|
||||
```
|
||||
"""
|
||||
try:
|
||||
state = getattr(flow_instance, 'state', None)
|
||||
if state is None:
|
||||
raise ValueError("Flow instance has no state")
|
||||
|
||||
|
||||
flow_uuid: Optional[str] = None
|
||||
if isinstance(state, dict):
|
||||
flow_uuid = state.get('id')
|
||||
elif isinstance(state, BaseModel):
|
||||
flow_uuid = getattr(state, 'id', None)
|
||||
|
||||
|
||||
if not flow_uuid:
|
||||
raise ValueError("Flow state must have an 'id' field for persistence")
|
||||
|
||||
|
||||
# Log state saving with consistent message
|
||||
cls._printer.print(LOG_MESSAGES["save_state"].format(flow_uuid), color="bold_yellow")
|
||||
cls._printer.print(LOG_MESSAGES["save_state"].format(flow_uuid), color="cyan")
|
||||
logger.info(LOG_MESSAGES["save_state"].format(flow_uuid))
|
||||
|
||||
|
||||
try:
|
||||
persistence_instance.save_state(
|
||||
flow_uuid=flow_uuid,
|
||||
@@ -154,44 +141,79 @@ def persist(persistence: Optional[FlowPersistence] = None):
|
||||
def begin(self):
|
||||
pass
|
||||
"""
|
||||
|
||||
def decorator(target: Union[Type, Callable[..., T]]) -> Union[Type, Callable[..., T]]:
|
||||
"""Decorator that handles both class and method decoration."""
|
||||
actual_persistence = persistence or SQLiteFlowPersistence()
|
||||
|
||||
if isinstance(target, type):
|
||||
# Class decoration
|
||||
class_methods = {}
|
||||
for name, method in target.__dict__.items():
|
||||
if callable(method) and hasattr(method, "__is_flow_method__"):
|
||||
# Wrap each flow method with persistence
|
||||
if asyncio.iscoroutinefunction(method):
|
||||
@functools.wraps(method)
|
||||
async def class_async_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
|
||||
method_coro = method(self, *args, **kwargs)
|
||||
if asyncio.iscoroutine(method_coro):
|
||||
result = await method_coro
|
||||
else:
|
||||
result = method_coro
|
||||
PersistenceDecorator.persist_state(self, method.__name__, actual_persistence)
|
||||
return result
|
||||
class_methods[name] = class_async_wrapper
|
||||
else:
|
||||
@functools.wraps(method)
|
||||
def class_sync_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
|
||||
result = method(self, *args, **kwargs)
|
||||
PersistenceDecorator.persist_state(self, method.__name__, actual_persistence)
|
||||
return result
|
||||
class_methods[name] = class_sync_wrapper
|
||||
original_init = getattr(target, "__init__")
|
||||
|
||||
# Preserve flow-specific attributes
|
||||
@functools.wraps(original_init)
|
||||
def new_init(self: Any, *args: Any, **kwargs: Any) -> None:
|
||||
if 'persistence' not in kwargs:
|
||||
kwargs['persistence'] = actual_persistence
|
||||
original_init(self, *args, **kwargs)
|
||||
|
||||
setattr(target, "__init__", new_init)
|
||||
|
||||
# Store original methods to preserve their decorators
|
||||
original_methods = {}
|
||||
|
||||
for name, method in target.__dict__.items():
|
||||
if callable(method) and (
|
||||
hasattr(method, "__is_start_method__") or
|
||||
hasattr(method, "__trigger_methods__") or
|
||||
hasattr(method, "__condition_type__") or
|
||||
hasattr(method, "__is_flow_method__") or
|
||||
hasattr(method, "__is_router__")
|
||||
):
|
||||
original_methods[name] = method
|
||||
|
||||
# Create wrapped versions of the methods that include persistence
|
||||
for name, method in original_methods.items():
|
||||
if asyncio.iscoroutinefunction(method):
|
||||
# Create a closure to capture the current name and method
|
||||
def create_async_wrapper(method_name: str, original_method: Callable):
|
||||
@functools.wraps(original_method)
|
||||
async def method_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
|
||||
result = await original_method(self, *args, **kwargs)
|
||||
PersistenceDecorator.persist_state(self, method_name, actual_persistence)
|
||||
return result
|
||||
return method_wrapper
|
||||
|
||||
wrapped = create_async_wrapper(name, method)
|
||||
|
||||
# Preserve all original decorators and attributes
|
||||
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
|
||||
if hasattr(method, attr):
|
||||
setattr(class_methods[name], attr, getattr(method, attr))
|
||||
setattr(class_methods[name], "__is_flow_method__", True)
|
||||
setattr(wrapped, attr, getattr(method, attr))
|
||||
setattr(wrapped, "__is_flow_method__", True)
|
||||
|
||||
# Update the class with the wrapped method
|
||||
setattr(target, name, wrapped)
|
||||
else:
|
||||
# Create a closure to capture the current name and method
|
||||
def create_sync_wrapper(method_name: str, original_method: Callable):
|
||||
@functools.wraps(original_method)
|
||||
def method_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
|
||||
result = original_method(self, *args, **kwargs)
|
||||
PersistenceDecorator.persist_state(self, method_name, actual_persistence)
|
||||
return result
|
||||
return method_wrapper
|
||||
|
||||
wrapped = create_sync_wrapper(name, method)
|
||||
|
||||
# Preserve all original decorators and attributes
|
||||
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
|
||||
if hasattr(method, attr):
|
||||
setattr(wrapped, attr, getattr(method, attr))
|
||||
setattr(wrapped, "__is_flow_method__", True)
|
||||
|
||||
# Update the class with the wrapped method
|
||||
setattr(target, name, wrapped)
|
||||
|
||||
# Update class with wrapped methods
|
||||
for name, method in class_methods.items():
|
||||
setattr(target, name, method)
|
||||
return target
|
||||
else:
|
||||
# Method decoration
|
||||
@@ -208,6 +230,7 @@ def persist(persistence: Optional[FlowPersistence] = None):
|
||||
result = method_coro
|
||||
PersistenceDecorator.persist_state(flow_instance, method.__name__, actual_persistence)
|
||||
return result
|
||||
|
||||
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
|
||||
if hasattr(method, attr):
|
||||
setattr(method_async_wrapper, attr, getattr(method, attr))
|
||||
@@ -219,6 +242,7 @@ def persist(persistence: Optional[FlowPersistence] = None):
|
||||
result = method(flow_instance, *args, **kwargs)
|
||||
PersistenceDecorator.persist_state(flow_instance, method.__name__, actual_persistence)
|
||||
return result
|
||||
|
||||
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
|
||||
if hasattr(method, attr):
|
||||
setattr(method_sync_wrapper, attr, getattr(method, attr))
|
||||
|
||||
@@ -3,10 +3,9 @@ SQLite-based implementation of flow state persistence.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sqlite3
|
||||
import tempfile
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
@@ -16,34 +15,34 @@ from crewai.flow.persistence.base import FlowPersistence
|
||||
|
||||
class SQLiteFlowPersistence(FlowPersistence):
|
||||
"""SQLite-based implementation of flow state persistence.
|
||||
|
||||
|
||||
This class provides a simple, file-based persistence implementation using SQLite.
|
||||
It's suitable for development and testing, or for production use cases with
|
||||
moderate performance requirements.
|
||||
"""
|
||||
|
||||
|
||||
db_path: str # Type annotation for instance variable
|
||||
|
||||
|
||||
def __init__(self, db_path: Optional[str] = None):
|
||||
"""Initialize SQLite persistence.
|
||||
|
||||
|
||||
Args:
|
||||
db_path: Path to the SQLite database file. If not provided, uses
|
||||
db_storage_path() from utilities.paths.
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: If db_path is invalid
|
||||
"""
|
||||
from crewai.utilities.paths import db_storage_path
|
||||
# Get path from argument or default location
|
||||
path = db_path or db_storage_path()
|
||||
|
||||
path = db_path or str(Path(db_storage_path()) / "flow_states.db")
|
||||
|
||||
if not path:
|
||||
raise ValueError("Database path must be provided")
|
||||
|
||||
|
||||
self.db_path = path # Now mypy knows this is str
|
||||
self.init_db()
|
||||
|
||||
|
||||
def init_db(self) -> None:
|
||||
"""Create the necessary tables if they don't exist."""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
@@ -58,10 +57,10 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
""")
|
||||
# Add index for faster UUID lookups
|
||||
conn.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_flow_states_uuid
|
||||
CREATE INDEX IF NOT EXISTS idx_flow_states_uuid
|
||||
ON flow_states(flow_uuid)
|
||||
""")
|
||||
|
||||
|
||||
def save_state(
|
||||
self,
|
||||
flow_uuid: str,
|
||||
@@ -69,7 +68,7 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
state_data: Union[Dict[str, Any], BaseModel],
|
||||
) -> None:
|
||||
"""Save the current flow state to SQLite.
|
||||
|
||||
|
||||
Args:
|
||||
flow_uuid: Unique identifier for the flow instance
|
||||
method_name: Name of the method that just completed
|
||||
@@ -84,7 +83,7 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
raise ValueError(
|
||||
f"state_data must be either a Pydantic BaseModel or dict, got {type(state_data)}"
|
||||
)
|
||||
|
||||
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
conn.execute("""
|
||||
INSERT INTO flow_states (
|
||||
@@ -99,13 +98,13 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
datetime.utcnow().isoformat(),
|
||||
json.dumps(state_dict),
|
||||
))
|
||||
|
||||
|
||||
def load_state(self, flow_uuid: str) -> Optional[Dict[str, Any]]:
|
||||
"""Load the most recent state for a given flow UUID.
|
||||
|
||||
|
||||
Args:
|
||||
flow_uuid: Unique identifier for the flow instance
|
||||
|
||||
|
||||
Returns:
|
||||
The most recent state as a dictionary, or None if no state exists
|
||||
"""
|
||||
@@ -118,7 +117,7 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
LIMIT 1
|
||||
""", (flow_uuid,))
|
||||
row = cursor.fetchone()
|
||||
|
||||
|
||||
if row:
|
||||
return json.loads(row[0])
|
||||
return None
|
||||
|
||||
@@ -23,7 +23,7 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
) -> None:
|
||||
if db_path is None:
|
||||
# Get the parent directory of the default db path and create our db file there
|
||||
db_path = str(Path(db_storage_path()).parent / "latest_kickoff_task_outputs.db")
|
||||
db_path = str(Path(db_storage_path()) / "latest_kickoff_task_outputs.db")
|
||||
self.db_path = db_path
|
||||
self._printer: Printer = Printer()
|
||||
self._initialize_db()
|
||||
|
||||
@@ -17,7 +17,7 @@ class LTMSQLiteStorage:
|
||||
) -> None:
|
||||
if db_path is None:
|
||||
# Get the parent directory of the default db path and create our db file there
|
||||
db_path = str(Path(db_storage_path()).parent / "long_term_memory_storage.db")
|
||||
db_path = str(Path(db_storage_path()) / "long_term_memory_storage.db")
|
||||
self.db_path = db_path
|
||||
self._printer: Printer = Printer()
|
||||
# Ensure parent directory exists
|
||||
|
||||
@@ -7,7 +7,7 @@ import appdirs
|
||||
|
||||
def db_storage_path() -> str:
|
||||
"""Returns the path for SQLite database storage.
|
||||
|
||||
|
||||
Returns:
|
||||
str: Full path to the SQLite database file
|
||||
"""
|
||||
@@ -16,7 +16,7 @@ def db_storage_path() -> str:
|
||||
|
||||
data_dir = Path(appdirs.user_data_dir(app_name, app_author))
|
||||
data_dir.mkdir(parents=True, exist_ok=True)
|
||||
return str(data_dir / "crewai_flows.db")
|
||||
return str(data_dir)
|
||||
|
||||
|
||||
def get_project_directory_name():
|
||||
@@ -28,4 +28,4 @@ def get_project_directory_name():
|
||||
else:
|
||||
cwd = Path.cwd()
|
||||
project_directory_name = cwd.name
|
||||
return project_directory_name
|
||||
return project_directory_name
|
||||
@@ -21,6 +21,16 @@ class Printer:
|
||||
self._print_yellow(content)
|
||||
elif color == "bold_yellow":
|
||||
self._print_bold_yellow(content)
|
||||
elif color == "cyan":
|
||||
self._print_cyan(content)
|
||||
elif color == "bold_cyan":
|
||||
self._print_bold_cyan(content)
|
||||
elif color == "magenta":
|
||||
self._print_magenta(content)
|
||||
elif color == "bold_magenta":
|
||||
self._print_bold_magenta(content)
|
||||
elif color == "green":
|
||||
self._print_green(content)
|
||||
else:
|
||||
print(content)
|
||||
|
||||
@@ -44,3 +54,18 @@ class Printer:
|
||||
|
||||
def _print_bold_yellow(self, content):
|
||||
print("\033[1m\033[93m {}\033[00m".format(content))
|
||||
|
||||
def _print_cyan(self, content):
|
||||
print("\033[96m {}\033[00m".format(content))
|
||||
|
||||
def _print_bold_cyan(self, content):
|
||||
print("\033[1m\033[96m {}\033[00m".format(content))
|
||||
|
||||
def _print_magenta(self, content):
|
||||
print("\033[35m {}\033[00m".format(content))
|
||||
|
||||
def _print_bold_magenta(self, content):
|
||||
print("\033[1m\033[35m {}\033[00m".format(content))
|
||||
|
||||
def _print_green(self, content):
|
||||
print("\033[32m {}\033[00m".format(content))
|
||||
|
||||
112
tests/test_flow_default_override.py
Normal file
112
tests/test_flow_default_override.py
Normal file
@@ -0,0 +1,112 @@
|
||||
"""Test that persisted state properly overrides default values."""
|
||||
|
||||
from crewai.flow.flow import Flow, FlowState, listen, start
|
||||
from crewai.flow.persistence import persist
|
||||
|
||||
|
||||
class PoemState(FlowState):
|
||||
"""Test state model with default values that should be overridden."""
|
||||
sentence_count: int = 1000 # Default that should be overridden
|
||||
has_set_count: bool = False # Track whether we've set the count
|
||||
poem_type: str = ""
|
||||
|
||||
|
||||
def test_default_value_override():
|
||||
"""Test that persisted state values override class defaults."""
|
||||
|
||||
@persist()
|
||||
class PoemFlow(Flow[PoemState]):
|
||||
initial_state = PoemState
|
||||
|
||||
@start()
|
||||
def set_sentence_count(self):
|
||||
if self.state.has_set_count and self.state.sentence_count == 2:
|
||||
self.state.sentence_count = 3
|
||||
|
||||
elif self.state.has_set_count and self.state.sentence_count == 1000:
|
||||
self.state.sentence_count = 1000
|
||||
|
||||
elif self.state.has_set_count and self.state.sentence_count == 5:
|
||||
self.state.sentence_count = 5
|
||||
|
||||
else:
|
||||
self.state.sentence_count = 2
|
||||
self.state.has_set_count = True
|
||||
|
||||
# First run - should set sentence_count to 2
|
||||
flow1 = PoemFlow()
|
||||
flow1.kickoff()
|
||||
original_uuid = flow1.state.id
|
||||
assert flow1.state.sentence_count == 2
|
||||
|
||||
# Second run - should load sentence_count=2 instead of default 1000
|
||||
flow2 = PoemFlow()
|
||||
flow2.kickoff(inputs={"id": original_uuid})
|
||||
assert flow2.state.sentence_count == 3 # Should load 2, not default 1000
|
||||
|
||||
# Fourth run - explicit override should work
|
||||
flow3 = PoemFlow()
|
||||
flow3.kickoff(inputs={
|
||||
"id": original_uuid,
|
||||
"has_set_count": True,
|
||||
"sentence_count": 5, # Override persisted value
|
||||
})
|
||||
assert flow3.state.sentence_count == 5 # Should use override value
|
||||
|
||||
# Third run - should not load sentence_count=2 instead of default 1000
|
||||
flow4 = PoemFlow()
|
||||
flow4.kickoff(inputs={"has_set_count": True})
|
||||
assert flow4.state.sentence_count == 1000 # Should load 1000, not 2
|
||||
|
||||
|
||||
def test_multi_step_default_override():
|
||||
"""Test default value override with multiple start methods."""
|
||||
|
||||
@persist()
|
||||
class MultiStepPoemFlow(Flow[PoemState]):
|
||||
initial_state = PoemState
|
||||
|
||||
@start()
|
||||
def set_sentence_count(self):
|
||||
print("Setting sentence count")
|
||||
if not self.state.has_set_count:
|
||||
self.state.sentence_count = 3
|
||||
self.state.has_set_count = True
|
||||
|
||||
@listen(set_sentence_count)
|
||||
def set_poem_type(self):
|
||||
print("Setting poem type")
|
||||
if self.state.sentence_count == 3:
|
||||
self.state.poem_type = "haiku"
|
||||
elif self.state.sentence_count == 5:
|
||||
self.state.poem_type = "limerick"
|
||||
else:
|
||||
self.state.poem_type = "free_verse"
|
||||
|
||||
@listen(set_poem_type)
|
||||
def finished(self):
|
||||
print("finished")
|
||||
|
||||
# First run - should set both sentence count and poem type
|
||||
flow1 = MultiStepPoemFlow()
|
||||
flow1.kickoff()
|
||||
original_uuid = flow1.state.id
|
||||
assert flow1.state.sentence_count == 3
|
||||
assert flow1.state.poem_type == "haiku"
|
||||
|
||||
# Second run - should load persisted state and update poem type
|
||||
flow2 = MultiStepPoemFlow()
|
||||
flow2.kickoff(inputs={
|
||||
"id": original_uuid,
|
||||
"sentence_count": 5
|
||||
})
|
||||
assert flow2.state.sentence_count == 5
|
||||
assert flow2.state.poem_type == "limerick"
|
||||
|
||||
# Third run - new flow without persisted state should use defaults
|
||||
flow3 = MultiStepPoemFlow()
|
||||
flow3.kickoff(inputs={
|
||||
"id": original_uuid
|
||||
})
|
||||
assert flow3.state.sentence_count == 5
|
||||
assert flow3.state.poem_type == "limerick"
|
||||
@@ -1,12 +1,12 @@
|
||||
"""Test flow state persistence functionality."""
|
||||
|
||||
import os
|
||||
from typing import Dict, Optional
|
||||
from typing import Dict
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai.flow.flow import Flow, FlowState, start
|
||||
from crewai.flow.flow import Flow, FlowState, listen, start
|
||||
from crewai.flow.persistence import persist
|
||||
from crewai.flow.persistence.sqlite import SQLiteFlowPersistence
|
||||
|
||||
@@ -73,13 +73,14 @@ def test_flow_state_restoration(tmp_path):
|
||||
|
||||
# First flow execution to create initial state
|
||||
class RestorableFlow(Flow[TestState]):
|
||||
initial_state = TestState
|
||||
|
||||
@start()
|
||||
@persist(persistence)
|
||||
def set_message(self):
|
||||
self.state.message = "Original message"
|
||||
self.state.counter = 42
|
||||
if self.state.message == "":
|
||||
self.state.message = "Original message"
|
||||
if self.state.counter == 0:
|
||||
self.state.counter = 42
|
||||
|
||||
# Create and persist initial state
|
||||
flow1 = RestorableFlow(persistence=persistence)
|
||||
@@ -87,11 +88,11 @@ def test_flow_state_restoration(tmp_path):
|
||||
original_uuid = flow1.state.id
|
||||
|
||||
# Test case 1: Restore using restore_uuid with field override
|
||||
flow2 = RestorableFlow(
|
||||
persistence=persistence,
|
||||
restore_uuid=original_uuid,
|
||||
counter=43, # Override counter
|
||||
)
|
||||
flow2 = RestorableFlow(persistence=persistence)
|
||||
flow2.kickoff(inputs={
|
||||
"id": original_uuid,
|
||||
"counter": 43
|
||||
})
|
||||
|
||||
# Verify state restoration and selective field override
|
||||
assert flow2.state.id == original_uuid
|
||||
@@ -99,48 +100,17 @@ def test_flow_state_restoration(tmp_path):
|
||||
assert flow2.state.counter == 43 # Overridden
|
||||
|
||||
# Test case 2: Restore using kwargs['id']
|
||||
flow3 = RestorableFlow(
|
||||
persistence=persistence,
|
||||
id=original_uuid,
|
||||
message="Updated message", # Override message
|
||||
)
|
||||
flow3 = RestorableFlow(persistence=persistence)
|
||||
flow3.kickoff(inputs={
|
||||
"id": original_uuid,
|
||||
"message": "Updated message"
|
||||
})
|
||||
|
||||
# Verify state restoration and selective field override
|
||||
assert flow3.state.id == original_uuid
|
||||
assert flow3.state.counter == 42 # Preserved
|
||||
assert flow3.state.counter == 43 # Preserved
|
||||
assert flow3.state.message == "Updated message" # Overridden
|
||||
|
||||
# Test case 3: Verify error on conflicting IDs
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
RestorableFlow(
|
||||
persistence=persistence,
|
||||
restore_uuid=original_uuid,
|
||||
id="different-id", # Conflict with restore_uuid
|
||||
)
|
||||
assert "Conflicting IDs provided" in str(exc_info.value)
|
||||
|
||||
# Test case 4: Verify error on non-existent restore_uuid
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
RestorableFlow(
|
||||
persistence=persistence,
|
||||
restore_uuid="non-existent-uuid",
|
||||
)
|
||||
assert "No state found" in str(exc_info.value)
|
||||
|
||||
# Test case 5: Allow new state creation with kwargs['id']
|
||||
new_uuid = "new-flow-id"
|
||||
flow4 = RestorableFlow(
|
||||
persistence=persistence,
|
||||
id=new_uuid,
|
||||
message="New message",
|
||||
counter=100,
|
||||
)
|
||||
|
||||
# Verify new state creation with provided ID
|
||||
assert flow4.state.id == new_uuid
|
||||
assert flow4.state.message == "New message"
|
||||
assert flow4.state.counter == 100
|
||||
|
||||
|
||||
def test_multiple_method_persistence(tmp_path):
|
||||
"""Test state persistence across multiple method executions."""
|
||||
@@ -148,48 +118,59 @@ def test_multiple_method_persistence(tmp_path):
|
||||
persistence = SQLiteFlowPersistence(db_path)
|
||||
|
||||
class MultiStepFlow(Flow[TestState]):
|
||||
initial_state = TestState
|
||||
|
||||
@start()
|
||||
@persist(persistence)
|
||||
def step_1(self):
|
||||
self.state.counter = 1
|
||||
self.state.message = "Step 1"
|
||||
if self.state.counter == 1:
|
||||
self.state.counter = 99999
|
||||
self.state.message = "Step 99999"
|
||||
else:
|
||||
self.state.counter = 1
|
||||
self.state.message = "Step 1"
|
||||
|
||||
@start()
|
||||
@listen(step_1)
|
||||
@persist(persistence)
|
||||
def step_2(self):
|
||||
self.state.counter = 2
|
||||
self.state.message = "Step 2"
|
||||
if self.state.counter == 1:
|
||||
self.state.counter = 2
|
||||
self.state.message = "Step 2"
|
||||
|
||||
flow = MultiStepFlow(persistence=persistence)
|
||||
flow.kickoff()
|
||||
|
||||
flow2 = MultiStepFlow(persistence=persistence)
|
||||
flow2.kickoff(inputs={"id": flow.state.id})
|
||||
|
||||
# Load final state
|
||||
final_state = persistence.load_state(flow.state.id)
|
||||
final_state = flow2.state
|
||||
assert final_state is not None
|
||||
assert final_state["counter"] == 2
|
||||
assert final_state["message"] == "Step 2"
|
||||
|
||||
|
||||
def test_persistence_error_handling(tmp_path):
|
||||
"""Test error handling in persistence operations."""
|
||||
db_path = os.path.join(tmp_path, "test_flows.db")
|
||||
persistence = SQLiteFlowPersistence(db_path)
|
||||
|
||||
class InvalidFlow(Flow[TestState]):
|
||||
# Missing id field in initial state
|
||||
class InvalidState(BaseModel):
|
||||
value: str = ""
|
||||
|
||||
initial_state = InvalidState
|
||||
assert final_state.counter == 2
|
||||
assert final_state.message == "Step 2"
|
||||
|
||||
class NoPersistenceMultiStepFlow(Flow[TestState]):
|
||||
@start()
|
||||
@persist(persistence)
|
||||
def will_fail(self):
|
||||
self.state.value = "test"
|
||||
def step_1(self):
|
||||
if self.state.counter == 1:
|
||||
self.state.counter = 99999
|
||||
self.state.message = "Step 99999"
|
||||
else:
|
||||
self.state.counter = 1
|
||||
self.state.message = "Step 1"
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
flow = InvalidFlow(persistence=persistence)
|
||||
@listen(step_1)
|
||||
def step_2(self):
|
||||
if self.state.counter == 1:
|
||||
self.state.counter = 2
|
||||
self.state.message = "Step 2"
|
||||
|
||||
assert "must have an 'id' field" in str(exc_info.value)
|
||||
flow = NoPersistenceMultiStepFlow(persistence=persistence)
|
||||
flow.kickoff()
|
||||
|
||||
flow2 = NoPersistenceMultiStepFlow(persistence=persistence)
|
||||
flow2.kickoff(inputs={"id": flow.state.id})
|
||||
|
||||
# Load final state
|
||||
final_state = flow2.state
|
||||
assert final_state.counter == 99999
|
||||
assert final_state.message == "Step 99999"
|
||||
|
||||
Reference in New Issue
Block a user