mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-07-03 06:08:15 +00:00
feat(flow): support custom persistence key in @persist (#5649)
* feat(flow): add optional key param to @persist decorator Allows users to specify which state attribute to use as the persistence key instead of always defaulting to state.id. Usage: @persist(key='conversation_id') Falls back to state.id when key is not provided (no breaking change). Raises ValueError if the specified key is missing or falsy on state. * docs(flow): document @persist key parameter for custom persistence keys * fix(flow): use explicit None check for persist key to avoid empty-string fallback --------- Co-authored-by: iris-clawd <iris-clawd@anthropic.com> Co-authored-by: iris-clawd <iris@crewai.com> Co-authored-by: Lorenze Jay <63378463+lorenzejay@users.noreply.github.com>
This commit is contained in:
@@ -50,6 +50,7 @@ LOG_MESSAGES: Final[dict[str, str]] = {
|
||||
"save_error": "Failed to persist state for method {}: {}",
|
||||
"state_missing": "Flow instance has no state",
|
||||
"id_missing": "Flow state must have an 'id' field for persistence",
|
||||
"key_missing": "Flow state is missing required persistence key '{}'",
|
||||
}
|
||||
|
||||
|
||||
@@ -63,6 +64,7 @@ class PersistenceDecorator:
|
||||
method_name: str,
|
||||
persistence_instance: FlowPersistence,
|
||||
verbose: bool = False,
|
||||
key: str | None = None,
|
||||
) -> None:
|
||||
"""Persist flow state with proper error handling and logging.
|
||||
|
||||
@@ -74,9 +76,12 @@ class PersistenceDecorator:
|
||||
method_name: Name of the method that triggered persistence
|
||||
persistence_instance: The persistence backend to use
|
||||
verbose: Whether to log persistence operations
|
||||
key: Optional state attribute/key to use as the persistence key.
|
||||
When None, falls back to ``state.id``.
|
||||
|
||||
Raises:
|
||||
ValueError: If flow has no state or state lacks an ID
|
||||
ValueError: If flow has no state, state lacks an ID, or the
|
||||
requested ``key`` is missing or falsy on state.
|
||||
RuntimeError: If state persistence fails
|
||||
AttributeError: If flow instance lacks required state attributes
|
||||
"""
|
||||
@@ -85,19 +90,22 @@ class PersistenceDecorator:
|
||||
if state is None:
|
||||
raise ValueError("Flow instance has no state")
|
||||
|
||||
lookup_key = key if key is not None else "id"
|
||||
flow_uuid: str | None = None
|
||||
if isinstance(state, dict):
|
||||
flow_uuid = state.get("id")
|
||||
flow_uuid = state.get(lookup_key)
|
||||
elif hasattr(state, "_unwrap"):
|
||||
unwrapped = state._unwrap()
|
||||
if isinstance(unwrapped, dict):
|
||||
flow_uuid = unwrapped.get("id")
|
||||
flow_uuid = unwrapped.get(lookup_key)
|
||||
else:
|
||||
flow_uuid = getattr(unwrapped, "id", None)
|
||||
elif isinstance(state, BaseModel) or hasattr(state, "id"):
|
||||
flow_uuid = getattr(state, "id", None)
|
||||
flow_uuid = getattr(unwrapped, lookup_key, None)
|
||||
elif isinstance(state, BaseModel) or hasattr(state, lookup_key):
|
||||
flow_uuid = getattr(state, lookup_key, None)
|
||||
|
||||
if not flow_uuid:
|
||||
if key is not None:
|
||||
raise ValueError(LOG_MESSAGES["key_missing"].format(key))
|
||||
raise ValueError("Flow state must have an 'id' field for persistence")
|
||||
|
||||
# Log state saving only if verbose is True
|
||||
@@ -127,7 +135,7 @@ class PersistenceDecorator:
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg) from e
|
||||
except (TypeError, ValueError) as e:
|
||||
error_msg = LOG_MESSAGES["id_missing"]
|
||||
error_msg = str(e) or LOG_MESSAGES["id_missing"]
|
||||
if verbose:
|
||||
PRINTER.print(error_msg, color="red")
|
||||
logger.error(error_msg)
|
||||
@@ -135,7 +143,9 @@ class PersistenceDecorator:
|
||||
|
||||
|
||||
def persist(
|
||||
persistence: FlowPersistence | None = None, verbose: bool = False
|
||||
persistence: FlowPersistence | None = None,
|
||||
verbose: bool = False,
|
||||
key: str | None = None,
|
||||
) -> Callable[[type | Callable[..., T]], type | Callable[..., T]]:
|
||||
"""Decorator to persist flow state.
|
||||
|
||||
@@ -148,12 +158,16 @@ def persist(
|
||||
persistence: Optional FlowPersistence implementation to use.
|
||||
If not provided, uses SQLiteFlowPersistence.
|
||||
verbose: Whether to log persistence operations. Defaults to False.
|
||||
key: Optional name of the state attribute (for Pydantic/object states)
|
||||
or dict key (for dict states) to use as the persistence key. When
|
||||
``None`` (default) the decorator falls back to ``state.id``.
|
||||
|
||||
Returns:
|
||||
A decorator that can be applied to either a class or method
|
||||
|
||||
Raises:
|
||||
ValueError: If the flow state doesn't have an 'id' field
|
||||
ValueError: If the flow state doesn't have an 'id' field, or the
|
||||
specified ``key`` is missing or falsy on state.
|
||||
RuntimeError: If state persistence fails
|
||||
|
||||
Example:
|
||||
@@ -162,6 +176,10 @@ def persist(
|
||||
@start()
|
||||
def begin(self):
|
||||
pass
|
||||
|
||||
@persist(key="conversation_id") # Custom persistence key
|
||||
class MyFlow(Flow[MyState]):
|
||||
...
|
||||
"""
|
||||
|
||||
def decorator(target: type | Callable[..., T]) -> type | Callable[..., T]:
|
||||
@@ -207,7 +225,7 @@ def persist(
|
||||
) -> Any:
|
||||
result = await original_method(self, *args, **kwargs)
|
||||
PersistenceDecorator.persist_state(
|
||||
self, method_name, actual_persistence, verbose
|
||||
self, method_name, actual_persistence, verbose, key
|
||||
)
|
||||
return result
|
||||
|
||||
@@ -237,7 +255,7 @@ def persist(
|
||||
def method_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
|
||||
result = original_method(self, *args, **kwargs)
|
||||
PersistenceDecorator.persist_state(
|
||||
self, method_name, actual_persistence, verbose
|
||||
self, method_name, actual_persistence, verbose, key
|
||||
)
|
||||
return result
|
||||
|
||||
@@ -276,7 +294,7 @@ def persist(
|
||||
else:
|
||||
result = method_coro
|
||||
PersistenceDecorator.persist_state(
|
||||
flow_instance, method.__name__, actual_persistence, verbose
|
||||
flow_instance, method.__name__, actual_persistence, verbose, key
|
||||
)
|
||||
return cast(T, result)
|
||||
|
||||
@@ -295,7 +313,7 @@ def persist(
|
||||
def method_sync_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T:
|
||||
result = method(flow_instance, *args, **kwargs)
|
||||
PersistenceDecorator.persist_state(
|
||||
flow_instance, method.__name__, actual_persistence, verbose
|
||||
flow_instance, method.__name__, actual_persistence, verbose, key
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
import os
|
||||
from typing import Dict, List
|
||||
|
||||
import pytest
|
||||
from crewai.flow.flow import Flow, FlowState, listen, start
|
||||
from crewai.flow.persistence import persist
|
||||
from crewai.flow.persistence.sqlite import SQLiteFlowPersistence
|
||||
@@ -248,3 +249,69 @@ def test_persistence_with_base_model(tmp_path):
|
||||
assert message.type == "text"
|
||||
assert message.content == "Hello, World!"
|
||||
assert isinstance(flow.state._unwrap(), State)
|
||||
|
||||
|
||||
def test_persist_custom_key_with_pydantic_state(tmp_path):
|
||||
"""`@persist(key=...)` uses the named attribute on a Pydantic state."""
|
||||
db_path = os.path.join(tmp_path, "test_flows.db")
|
||||
persistence = SQLiteFlowPersistence(db_path)
|
||||
|
||||
class KeyedState(FlowState):
|
||||
conversation_id: str = "conv-42"
|
||||
message: str = ""
|
||||
|
||||
class KeyedFlow(Flow[KeyedState]):
|
||||
@start()
|
||||
@persist(persistence, key="conversation_id")
|
||||
def init_step(self):
|
||||
self.state.message = "hello"
|
||||
|
||||
flow = KeyedFlow(persistence=persistence)
|
||||
flow.kickoff()
|
||||
|
||||
saved_state = persistence.load_state("conv-42")
|
||||
assert saved_state is not None
|
||||
assert saved_state["message"] == "hello"
|
||||
# The default `state.id` lookup must NOT have been used as the key.
|
||||
assert persistence.load_state(flow.state.id) is None
|
||||
|
||||
|
||||
def test_persist_custom_key_with_dict_state(tmp_path):
|
||||
"""`@persist(key=...)` uses the named key on a dict state."""
|
||||
db_path = os.path.join(tmp_path, "test_flows.db")
|
||||
persistence = SQLiteFlowPersistence(db_path)
|
||||
|
||||
class DictKeyedFlow(Flow[Dict[str, str]]):
|
||||
initial_state = dict()
|
||||
|
||||
@start()
|
||||
@persist(persistence, key="conversation_id")
|
||||
def init_step(self):
|
||||
self.state["conversation_id"] = "conv-dict-7"
|
||||
self.state["message"] = "hi from dict"
|
||||
|
||||
flow = DictKeyedFlow(persistence=persistence)
|
||||
flow.kickoff()
|
||||
|
||||
saved_state = persistence.load_state("conv-dict-7")
|
||||
assert saved_state is not None
|
||||
assert saved_state["message"] == "hi from dict"
|
||||
|
||||
|
||||
def test_persist_custom_key_missing_raises(tmp_path):
|
||||
"""A missing/falsy custom key must raise a clear ValueError."""
|
||||
db_path = os.path.join(tmp_path, "test_flows.db")
|
||||
persistence = SQLiteFlowPersistence(db_path)
|
||||
|
||||
class MissingKeyFlow(Flow[Dict[str, str]]):
|
||||
initial_state = dict()
|
||||
|
||||
@start()
|
||||
@persist(persistence, key="conversation_id")
|
||||
def init_step(self):
|
||||
# Intentionally do NOT set "conversation_id" on state.
|
||||
self.state["message"] = "no key here"
|
||||
|
||||
flow = MissingKeyFlow(persistence=persistence)
|
||||
with pytest.raises(ValueError, match="conversation_id"):
|
||||
flow.kickoff()
|
||||
|
||||
Reference in New Issue
Block a user