mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-27 09:08:14 +00:00
Merge branch 'main' into devin/1737272386-flow-override-fix
This commit is contained in:
@@ -1,5 +1,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import inspect
|
import inspect
|
||||||
|
import logging
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
Callable,
|
Callable,
|
||||||
@@ -28,6 +29,9 @@ from crewai.flow.flow_visualizer import plot_flow
|
|||||||
from crewai.flow.persistence.base import FlowPersistence
|
from crewai.flow.persistence.base import FlowPersistence
|
||||||
from crewai.flow.utils import get_possible_return_constants
|
from crewai.flow.utils import get_possible_return_constants
|
||||||
from crewai.telemetry import Telemetry
|
from crewai.telemetry import Telemetry
|
||||||
|
from crewai.utilities.printer import Printer
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class FlowState(BaseModel):
|
class FlowState(BaseModel):
|
||||||
@@ -424,6 +428,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
Type parameter T must be either Dict[str, Any] or a subclass of BaseModel."""
|
Type parameter T must be either Dict[str, Any] or a subclass of BaseModel."""
|
||||||
|
|
||||||
_telemetry = Telemetry()
|
_telemetry = Telemetry()
|
||||||
|
_printer = Printer()
|
||||||
|
|
||||||
_start_methods: List[str] = []
|
_start_methods: List[str] = []
|
||||||
_listeners: Dict[str, tuple[str, List[str]]] = {}
|
_listeners: Dict[str, tuple[str, List[str]]] = {}
|
||||||
@@ -485,12 +490,14 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
|
|
||||||
# Attempt to load state, prioritizing restore_uuid
|
# Attempt to load state, prioritizing restore_uuid
|
||||||
if restore_uuid:
|
if restore_uuid:
|
||||||
|
self._log_flow_event(f"Loading flow state from memory for UUID: {restore_uuid}", color="bold_yellow")
|
||||||
stored_state = self._persistence.load_state(restore_uuid)
|
stored_state = self._persistence.load_state(restore_uuid)
|
||||||
if not stored_state:
|
if not stored_state:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"No state found for restore_uuid='{restore_uuid}'"
|
f"No state found for restore_uuid='{restore_uuid}'"
|
||||||
)
|
)
|
||||||
elif kwargs and "id" in kwargs:
|
elif kwargs and "id" in kwargs:
|
||||||
|
self._log_flow_event(f"Loading flow state from memory for ID: {kwargs['id']}", color="bold_yellow")
|
||||||
stored_state = self._persistence.load_state(kwargs["id"])
|
stored_state = self._persistence.load_state(kwargs["id"])
|
||||||
# Don't return early if state not found - let the normal initialization flow handle it
|
# Don't return early if state not found - let the normal initialization flow handle it
|
||||||
# This ensures proper state initialization and override behavior
|
# This ensures proper state initialization and override behavior
|
||||||
@@ -621,6 +628,39 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
"""Returns the list of all outputs from executed methods."""
|
"""Returns the list of all outputs from executed methods."""
|
||||||
return self._method_outputs
|
return self._method_outputs
|
||||||
|
|
||||||
|
@property
|
||||||
|
def flow_id(self) -> str:
|
||||||
|
"""Returns the unique identifier of this flow instance.
|
||||||
|
|
||||||
|
This property provides a consistent way to access the flow's unique identifier
|
||||||
|
regardless of the underlying state implementation (dict or BaseModel).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The flow's unique identifier, or an empty string if not found
|
||||||
|
|
||||||
|
Note:
|
||||||
|
This property safely handles both dictionary and BaseModel state types,
|
||||||
|
returning an empty string if the ID cannot be retrieved rather than raising
|
||||||
|
an exception.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
flow = MyFlow()
|
||||||
|
print(f"Current flow ID: {flow.flow_id}") # Safely get flow ID
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if not hasattr(self, '_state'):
|
||||||
|
return ""
|
||||||
|
|
||||||
|
if isinstance(self._state, dict):
|
||||||
|
return str(self._state.get("id", ""))
|
||||||
|
elif isinstance(self._state, BaseModel):
|
||||||
|
return str(getattr(self._state, "id", ""))
|
||||||
|
return ""
|
||||||
|
except (AttributeError, TypeError):
|
||||||
|
return "" # Safely handle any unexpected attribute access issues
|
||||||
|
|
||||||
def _initialize_state(self, inputs: Dict[str, Any]) -> None:
|
def _initialize_state(self, inputs: Dict[str, Any]) -> None:
|
||||||
"""Initialize or update flow state with new inputs.
|
"""Initialize or update flow state with new inputs.
|
||||||
|
|
||||||
@@ -687,6 +727,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
"""
|
"""
|
||||||
# When restoring from persistence, use the stored ID
|
# When restoring from persistence, use the stored ID
|
||||||
stored_id = stored_state.get("id")
|
stored_id = stored_state.get("id")
|
||||||
|
self._log_flow_event(f"Restoring flow state from memory for ID: {stored_id}", color="bold_yellow")
|
||||||
if not stored_id:
|
if not stored_id:
|
||||||
raise ValueError("Stored state must have an 'id' field")
|
raise ValueError("Stored state must have an 'id' field")
|
||||||
|
|
||||||
@@ -717,6 +758,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
flow_name=self.__class__.__name__,
|
flow_name=self.__class__.__name__,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
self._log_flow_event(f"Flow started with ID: {self.flow_id}", color="yellow")
|
||||||
|
|
||||||
if inputs is not None:
|
if inputs is not None:
|
||||||
self._initialize_state(inputs)
|
self._initialize_state(inputs)
|
||||||
@@ -962,6 +1004,30 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
|
def _log_flow_event(self, message: str, color: str = "yellow", level: str = "info") -> None:
|
||||||
|
"""Centralized logging method for flow events.
|
||||||
|
|
||||||
|
This method provides a consistent interface for logging flow-related events,
|
||||||
|
combining both console output with colors and proper logging levels.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: The message to log
|
||||||
|
color: Color to use for console output (default: yellow)
|
||||||
|
Available colors: purple, red, bold_green, bold_purple,
|
||||||
|
bold_blue, yellow, bold_yellow
|
||||||
|
level: Log level to use (default: info)
|
||||||
|
Supported levels: info, warning
|
||||||
|
|
||||||
|
Note:
|
||||||
|
This method uses the Printer utility for colored console output
|
||||||
|
and the standard logging module for log level support.
|
||||||
|
"""
|
||||||
|
self._printer.print(message, color=color)
|
||||||
|
if level == "info":
|
||||||
|
logger.info(message)
|
||||||
|
elif level == "warning":
|
||||||
|
logger.warning(message)
|
||||||
|
|
||||||
def plot(self, filename: str = "crewai_flow") -> None:
|
def plot(self, filename: str = "crewai_flow") -> None:
|
||||||
self._telemetry.flow_plotting_span(
|
self._telemetry.flow_plotting_span(
|
||||||
self.__class__.__name__, list(self._methods.keys())
|
self.__class__.__name__, list(self._methods.keys())
|
||||||
|
|||||||
@@ -38,10 +38,95 @@ from pydantic import BaseModel
|
|||||||
|
|
||||||
from crewai.flow.persistence.base import FlowPersistence
|
from crewai.flow.persistence.base import FlowPersistence
|
||||||
from crewai.flow.persistence.sqlite import SQLiteFlowPersistence
|
from crewai.flow.persistence.sqlite import SQLiteFlowPersistence
|
||||||
|
from crewai.utilities.printer import Printer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
# Constants for log messages
|
||||||
|
LOG_MESSAGES = {
|
||||||
|
"save_state": "Saving flow state to memory for ID: {}",
|
||||||
|
"save_error": "Failed to persist state for method {}: {}",
|
||||||
|
"state_missing": "Flow instance has no state",
|
||||||
|
"id_missing": "Flow state must have an 'id' field for persistence"
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class PersistenceDecorator:
|
||||||
|
"""Class to handle flow state persistence with consistent logging."""
|
||||||
|
|
||||||
|
_printer = Printer() # Class-level printer instance
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def persist_state(cls, flow_instance: Any, method_name: str, persistence_instance: FlowPersistence) -> None:
|
||||||
|
"""Persist flow state with proper error handling and logging.
|
||||||
|
|
||||||
|
This method handles the persistence of flow state data, including proper
|
||||||
|
error handling and colored console output for status updates.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
flow_instance: The flow instance whose state to persist
|
||||||
|
method_name: Name of the method that triggered persistence
|
||||||
|
persistence_instance: The persistence backend to use
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If flow has no state or state lacks an ID
|
||||||
|
RuntimeError: If state persistence fails
|
||||||
|
AttributeError: If flow instance lacks required state attributes
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Uses bold_yellow color for success messages and red for errors.
|
||||||
|
All operations are logged at appropriate levels (info/error).
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
@persist
|
||||||
|
def my_flow_method(self):
|
||||||
|
# Method implementation
|
||||||
|
pass
|
||||||
|
# State will be automatically persisted after method execution
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
state = getattr(flow_instance, 'state', None)
|
||||||
|
if state is None:
|
||||||
|
raise ValueError("Flow instance has no state")
|
||||||
|
|
||||||
|
flow_uuid: Optional[str] = None
|
||||||
|
if isinstance(state, dict):
|
||||||
|
flow_uuid = state.get('id')
|
||||||
|
elif isinstance(state, BaseModel):
|
||||||
|
flow_uuid = getattr(state, 'id', None)
|
||||||
|
|
||||||
|
if not flow_uuid:
|
||||||
|
raise ValueError("Flow state must have an 'id' field for persistence")
|
||||||
|
|
||||||
|
# Log state saving with consistent message
|
||||||
|
cls._printer.print(LOG_MESSAGES["save_state"].format(flow_uuid), color="bold_yellow")
|
||||||
|
logger.info(LOG_MESSAGES["save_state"].format(flow_uuid))
|
||||||
|
|
||||||
|
try:
|
||||||
|
persistence_instance.save_state(
|
||||||
|
flow_uuid=flow_uuid,
|
||||||
|
method_name=method_name,
|
||||||
|
state_data=state,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = LOG_MESSAGES["save_error"].format(method_name, str(e))
|
||||||
|
cls._printer.print(error_msg, color="red")
|
||||||
|
logger.error(error_msg)
|
||||||
|
raise RuntimeError(f"State persistence failed: {str(e)}") from e
|
||||||
|
except AttributeError:
|
||||||
|
error_msg = LOG_MESSAGES["state_missing"]
|
||||||
|
cls._printer.print(error_msg, color="red")
|
||||||
|
logger.error(error_msg)
|
||||||
|
raise ValueError(error_msg)
|
||||||
|
except (TypeError, ValueError) as e:
|
||||||
|
error_msg = LOG_MESSAGES["id_missing"]
|
||||||
|
cls._printer.print(error_msg, color="red")
|
||||||
|
logger.error(error_msg)
|
||||||
|
raise ValueError(error_msg) from e
|
||||||
|
|
||||||
|
|
||||||
def persist(persistence: Optional[FlowPersistence] = None):
|
def persist(persistence: Optional[FlowPersistence] = None):
|
||||||
"""Decorator to persist flow state.
|
"""Decorator to persist flow state.
|
||||||
@@ -69,37 +154,6 @@ def persist(persistence: Optional[FlowPersistence] = None):
|
|||||||
def begin(self):
|
def begin(self):
|
||||||
pass
|
pass
|
||||||
"""
|
"""
|
||||||
def _persist_state(flow_instance: Any, method_name: str, persistence_instance: FlowPersistence) -> None:
|
|
||||||
"""Helper to persist state with error handling."""
|
|
||||||
try:
|
|
||||||
# Get flow UUID from state
|
|
||||||
state = getattr(flow_instance, 'state', None)
|
|
||||||
if state is None:
|
|
||||||
raise ValueError("Flow instance has no state")
|
|
||||||
|
|
||||||
flow_uuid: Optional[str] = None
|
|
||||||
if isinstance(state, dict):
|
|
||||||
flow_uuid = state.get('id')
|
|
||||||
elif isinstance(state, BaseModel):
|
|
||||||
flow_uuid = getattr(state, 'id', None)
|
|
||||||
|
|
||||||
if not flow_uuid:
|
|
||||||
raise ValueError(
|
|
||||||
"Flow state must have an 'id' field for persistence"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Persist the state
|
|
||||||
persistence_instance.save_state(
|
|
||||||
flow_uuid=flow_uuid,
|
|
||||||
method_name=method_name,
|
|
||||||
state_data=state,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"Failed to persist state for method {method_name}: {str(e)}"
|
|
||||||
)
|
|
||||||
raise RuntimeError(f"State persistence failed: {str(e)}") from e
|
|
||||||
|
|
||||||
def decorator(target: Union[Type, Callable[..., T]]) -> Union[Type, Callable[..., T]]:
|
def decorator(target: Union[Type, Callable[..., T]]) -> Union[Type, Callable[..., T]]:
|
||||||
"""Decorator that handles both class and method decoration."""
|
"""Decorator that handles both class and method decoration."""
|
||||||
actual_persistence = persistence or SQLiteFlowPersistence()
|
actual_persistence = persistence or SQLiteFlowPersistence()
|
||||||
@@ -118,14 +172,14 @@ def persist(persistence: Optional[FlowPersistence] = None):
|
|||||||
result = await method_coro
|
result = await method_coro
|
||||||
else:
|
else:
|
||||||
result = method_coro
|
result = method_coro
|
||||||
_persist_state(self, method.__name__, actual_persistence)
|
PersistenceDecorator.persist_state(self, method.__name__, actual_persistence)
|
||||||
return result
|
return result
|
||||||
class_methods[name] = class_async_wrapper
|
class_methods[name] = class_async_wrapper
|
||||||
else:
|
else:
|
||||||
@functools.wraps(method)
|
@functools.wraps(method)
|
||||||
def class_sync_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
|
def class_sync_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
|
||||||
result = method(self, *args, **kwargs)
|
result = method(self, *args, **kwargs)
|
||||||
_persist_state(self, method.__name__, actual_persistence)
|
PersistenceDecorator.persist_state(self, method.__name__, actual_persistence)
|
||||||
return result
|
return result
|
||||||
class_methods[name] = class_sync_wrapper
|
class_methods[name] = class_sync_wrapper
|
||||||
|
|
||||||
@@ -152,7 +206,7 @@ def persist(persistence: Optional[FlowPersistence] = None):
|
|||||||
result = await method_coro
|
result = await method_coro
|
||||||
else:
|
else:
|
||||||
result = method_coro
|
result = method_coro
|
||||||
_persist_state(flow_instance, method.__name__, actual_persistence)
|
PersistenceDecorator.persist_state(flow_instance, method.__name__, actual_persistence)
|
||||||
return result
|
return result
|
||||||
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
|
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
|
||||||
if hasattr(method, attr):
|
if hasattr(method, attr):
|
||||||
@@ -163,7 +217,7 @@ def persist(persistence: Optional[FlowPersistence] = None):
|
|||||||
@functools.wraps(method)
|
@functools.wraps(method)
|
||||||
def method_sync_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T:
|
def method_sync_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T:
|
||||||
result = method(flow_instance, *args, **kwargs)
|
result = method(flow_instance, *args, **kwargs)
|
||||||
_persist_state(flow_instance, method.__name__, actual_persistence)
|
PersistenceDecorator.persist_state(flow_instance, method.__name__, actual_persistence)
|
||||||
return result
|
return result
|
||||||
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
|
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
|
||||||
if hasattr(method, attr):
|
if hasattr(method, attr):
|
||||||
|
|||||||
Reference in New Issue
Block a user