chore: fix ruff linting and mypy issues in flow module

This commit is contained in:
Greyson LaLonde
2025-09-22 13:03:06 -04:00
committed by GitHub
parent 0e370593f1
commit 37636f0dd7
12 changed files with 230 additions and 194 deletions

View File

@@ -1,5 +1,4 @@
from crewai.flow.flow import Flow, start, listen, or_, and_, router from crewai.flow.flow import Flow, and_, listen, or_, router, start
from crewai.flow.persistence import persist from crewai.flow.persistence import persist
__all__ = ["Flow", "start", "listen", "or_", "and_", "router", "persist"] __all__ = ["Flow", "and_", "listen", "or_", "persist", "router", "start"]

View File

@@ -1,5 +1,4 @@
import inspect import inspect
from typing import Optional
from pydantic import BaseModel, Field, InstanceOf, model_validator from pydantic import BaseModel, Field, InstanceOf, model_validator
@@ -14,7 +13,7 @@ class FlowTrackable(BaseModel):
inspecting the call stack. inspecting the call stack.
""" """
parent_flow: Optional[InstanceOf[Flow]] = Field( parent_flow: InstanceOf[Flow] | None = Field(
default=None, default=None,
description="The parent flow of the instance, if it was created inside a flow.", description="The parent flow of the instance, if it was created inside a flow.",
) )

View File

@@ -1,14 +1,13 @@
# flow_visualizer.py # flow_visualizer.py
import os import os
from pathlib import Path
from pyvis.network import Network from pyvis.network import Network # type: ignore[import-untyped]
from crewai.flow.config import COLORS, NODE_STYLES from crewai.flow.config import COLORS, NODE_STYLES
from crewai.flow.html_template_handler import HTMLTemplateHandler from crewai.flow.html_template_handler import HTMLTemplateHandler
from crewai.flow.legend_generator import generate_legend_items_html, get_legend_items from crewai.flow.legend_generator import generate_legend_items_html, get_legend_items
from crewai.flow.path_utils import safe_path_join, validate_path_exists from crewai.flow.path_utils import safe_path_join
from crewai.flow.utils import calculate_node_levels from crewai.flow.utils import calculate_node_levels
from crewai.flow.visualization_utils import ( from crewai.flow.visualization_utils import (
add_edges, add_edges,
@@ -34,13 +33,13 @@ class FlowPlot:
ValueError ValueError
If flow object is invalid or missing required attributes. If flow object is invalid or missing required attributes.
""" """
if not hasattr(flow, '_methods'): if not hasattr(flow, "_methods"):
raise ValueError("Invalid flow object: missing '_methods' attribute") raise ValueError("Invalid flow object: missing '_methods' attribute")
if not hasattr(flow, '_listeners'): if not hasattr(flow, "_listeners"):
raise ValueError("Invalid flow object: missing '_listeners' attribute") raise ValueError("Invalid flow object: missing '_listeners' attribute")
if not hasattr(flow, '_start_methods'): if not hasattr(flow, "_start_methods"):
raise ValueError("Invalid flow object: missing '_start_methods' attribute") raise ValueError("Invalid flow object: missing '_start_methods' attribute")
self.flow = flow self.flow = flow
self.colors = COLORS self.colors = COLORS
self.node_styles = NODE_STYLES self.node_styles = NODE_STYLES
@@ -65,7 +64,7 @@ class FlowPlot:
""" """
if not filename or not isinstance(filename, str): if not filename or not isinstance(filename, str):
raise ValueError("Filename must be a non-empty string") raise ValueError("Filename must be a non-empty string")
try: try:
# Initialize network # Initialize network
net = Network( net = Network(
@@ -96,32 +95,34 @@ class FlowPlot:
try: try:
node_levels = calculate_node_levels(self.flow) node_levels = calculate_node_levels(self.flow)
except Exception as e: except Exception as e:
raise ValueError(f"Failed to calculate node levels: {str(e)}") raise ValueError(f"Failed to calculate node levels: {e!s}") from e
# Compute positions # Compute positions
try: try:
node_positions = compute_positions(self.flow, node_levels) node_positions = compute_positions(self.flow, node_levels)
except Exception as e: except Exception as e:
raise ValueError(f"Failed to compute node positions: {str(e)}") raise ValueError(f"Failed to compute node positions: {e!s}") from e
# Add nodes to the network # Add nodes to the network
try: try:
add_nodes_to_network(net, self.flow, node_positions, self.node_styles) add_nodes_to_network(net, self.flow, node_positions, self.node_styles)
except Exception as e: except Exception as e:
raise RuntimeError(f"Failed to add nodes to network: {str(e)}") raise RuntimeError(f"Failed to add nodes to network: {e!s}") from e
# Add edges to the network # Add edges to the network
try: try:
add_edges(net, self.flow, node_positions, self.colors) add_edges(net, self.flow, node_positions, self.colors)
except Exception as e: except Exception as e:
raise RuntimeError(f"Failed to add edges to network: {str(e)}") raise RuntimeError(f"Failed to add edges to network: {e!s}") from e
# Generate HTML # Generate HTML
try: try:
network_html = net.generate_html() network_html = net.generate_html()
final_html_content = self._generate_final_html(network_html) final_html_content = self._generate_final_html(network_html)
except Exception as e: except Exception as e:
raise RuntimeError(f"Failed to generate network visualization: {str(e)}") raise RuntimeError(
f"Failed to generate network visualization: {e!s}"
) from e
# Save the final HTML content to the file # Save the final HTML content to the file
try: try:
@@ -129,12 +130,16 @@ class FlowPlot:
f.write(final_html_content) f.write(final_html_content)
print(f"Plot saved as {filename}.html") print(f"Plot saved as {filename}.html")
except IOError as e: except IOError as e:
raise IOError(f"Failed to save flow visualization to {filename}.html: {str(e)}") raise IOError(
f"Failed to save flow visualization to {filename}.html: {e!s}"
) from e
except (ValueError, RuntimeError, IOError) as e: except (ValueError, RuntimeError, IOError) as e:
raise e raise e
except Exception as e: except Exception as e:
raise RuntimeError(f"Unexpected error during flow visualization: {str(e)}") raise RuntimeError(
f"Unexpected error during flow visualization: {e!s}"
) from e
finally: finally:
self._cleanup_pyvis_lib() self._cleanup_pyvis_lib()
@@ -165,7 +170,9 @@ class FlowPlot:
try: try:
# Extract just the body content from the generated HTML # Extract just the body content from the generated HTML
current_dir = os.path.dirname(__file__) current_dir = os.path.dirname(__file__)
template_path = safe_path_join("assets", "crewai_flow_visual_template.html", root=current_dir) template_path = safe_path_join(
"assets", "crewai_flow_visual_template.html", root=current_dir
)
logo_path = safe_path_join("assets", "crewai_logo.svg", root=current_dir) logo_path = safe_path_join("assets", "crewai_logo.svg", root=current_dir)
if not os.path.exists(template_path): if not os.path.exists(template_path):
@@ -179,12 +186,9 @@ class FlowPlot:
# Generate the legend items HTML # Generate the legend items HTML
legend_items = get_legend_items(self.colors) legend_items = get_legend_items(self.colors)
legend_items_html = generate_legend_items_html(legend_items) legend_items_html = generate_legend_items_html(legend_items)
final_html_content = html_handler.generate_final_html( return html_handler.generate_final_html(network_body, legend_items_html)
network_body, legend_items_html
)
return final_html_content
except Exception as e: except Exception as e:
raise IOError(f"Failed to generate visualization HTML: {str(e)}") raise IOError(f"Failed to generate visualization HTML: {e!s}") from e
def _cleanup_pyvis_lib(self): def _cleanup_pyvis_lib(self):
""" """
@@ -197,6 +201,7 @@ class FlowPlot:
lib_folder = safe_path_join("lib", root=os.getcwd()) lib_folder = safe_path_join("lib", root=os.getcwd())
if os.path.exists(lib_folder) and os.path.isdir(lib_folder): if os.path.exists(lib_folder) and os.path.isdir(lib_folder):
import shutil import shutil
shutil.rmtree(lib_folder) shutil.rmtree(lib_folder)
except ValueError as e: except ValueError as e:
print(f"Error validating lib folder path: {e}") print(f"Error validating lib folder path: {e}")

View File

@@ -1,8 +1,7 @@
import base64 import base64
import re import re
from pathlib import Path
from crewai.flow.path_utils import safe_path_join, validate_path_exists from crewai.flow.path_utils import validate_path_exists
class HTMLTemplateHandler: class HTMLTemplateHandler:
@@ -28,7 +27,7 @@ class HTMLTemplateHandler:
self.template_path = validate_path_exists(template_path, "file") self.template_path = validate_path_exists(template_path, "file")
self.logo_path = validate_path_exists(logo_path, "file") self.logo_path = validate_path_exists(logo_path, "file")
except ValueError as e: except ValueError as e:
raise ValueError(f"Invalid template or logo path: {e}") raise ValueError(f"Invalid template or logo path: {e}") from e
def read_template(self): def read_template(self):
"""Read and return the HTML template file contents.""" """Read and return the HTML template file contents."""
@@ -53,23 +52,23 @@ class HTMLTemplateHandler:
if "border" in item: if "border" in item:
legend_items_html += f""" legend_items_html += f"""
<div class="legend-item"> <div class="legend-item">
<div class="legend-color-box" style="background-color: {item['color']}; border: 2px dashed {item['border']};"></div> <div class="legend-color-box" style="background-color: {item["color"]}; border: 2px dashed {item["border"]};"></div>
<div>{item['label']}</div> <div>{item["label"]}</div>
</div> </div>
""" """
elif item.get("dashed") is not None: elif item.get("dashed") is not None:
style = "dashed" if item["dashed"] else "solid" style = "dashed" if item["dashed"] else "solid"
legend_items_html += f""" legend_items_html += f"""
<div class="legend-item"> <div class="legend-item">
<div class="legend-{style}" style="border-bottom: 2px {style} {item['color']};"></div> <div class="legend-{style}" style="border-bottom: 2px {style} {item["color"]};"></div>
<div>{item['label']}</div> <div>{item["label"]}</div>
</div> </div>
""" """
else: else:
legend_items_html += f""" legend_items_html += f"""
<div class="legend-item"> <div class="legend-item">
<div class="legend-color-box" style="background-color: {item['color']};"></div> <div class="legend-color-box" style="background-color: {item["color"]};"></div>
<div>{item['label']}</div> <div>{item["label"]}</div>
</div> </div>
""" """
return legend_items_html return legend_items_html
@@ -79,15 +78,9 @@ class HTMLTemplateHandler:
html_template = self.read_template() html_template = self.read_template()
logo_svg_base64 = self.encode_logo() logo_svg_base64 = self.encode_logo()
final_html_content = html_template.replace("{{ title }}", title) return (
final_html_content = final_html_content.replace( html_template.replace("{{ title }}", title)
"{{ network_content }}", network_body .replace("{{ network_content }}", network_body)
.replace("{{ logo_svg_base64 }}", logo_svg_base64)
.replace("<!-- LEGEND_ITEMS_PLACEHOLDER -->", legend_items_html)
) )
final_html_content = final_html_content.replace(
"{{ logo_svg_base64 }}", logo_svg_base64
)
final_html_content = final_html_content.replace(
"<!-- LEGEND_ITEMS_PLACEHOLDER -->", legend_items_html
)
return final_html_content

View File

@@ -5,12 +5,10 @@ This module provides utilities for secure path handling to prevent directory
traversal attacks and ensure paths remain within allowed boundaries. traversal attacks and ensure paths remain within allowed boundaries.
""" """
import os
from pathlib import Path from pathlib import Path
from typing import List, Union
def safe_path_join(*parts: str, root: Union[str, Path, None] = None) -> str: def safe_path_join(*parts: str, root: str | Path | None = None) -> str:
""" """
Safely join path components and ensure the result is within allowed boundaries. Safely join path components and ensure the result is within allowed boundaries.
@@ -43,25 +41,25 @@ def safe_path_join(*parts: str, root: Union[str, Path, None] = None) -> str:
# Establish root directory # Establish root directory
root_path = Path(root).resolve() if root else Path.cwd() root_path = Path(root).resolve() if root else Path.cwd()
# Join and resolve the full path # Join and resolve the full path
full_path = Path(root_path, *clean_parts).resolve() full_path = Path(root_path, *clean_parts).resolve()
# Check if the resolved path is within root # Check if the resolved path is within root
if not str(full_path).startswith(str(root_path)): if not str(full_path).startswith(str(root_path)):
raise ValueError( raise ValueError(
f"Invalid path: Potential directory traversal. Path must be within {root_path}" f"Invalid path: Potential directory traversal. Path must be within {root_path}"
) )
return str(full_path) return str(full_path)
except Exception as e: except Exception as e:
if isinstance(e, ValueError): if isinstance(e, ValueError):
raise raise
raise ValueError(f"Invalid path components: {str(e)}") raise ValueError(f"Invalid path components: {e!s}") from e
def validate_path_exists(path: Union[str, Path], file_type: str = "file") -> str: def validate_path_exists(path: str | Path, file_type: str = "file") -> str:
""" """
Validate that a path exists and is of the expected type. Validate that a path exists and is of the expected type.
@@ -84,24 +82,24 @@ def validate_path_exists(path: Union[str, Path], file_type: str = "file") -> str
""" """
try: try:
path_obj = Path(path).resolve() path_obj = Path(path).resolve()
if not path_obj.exists(): if not path_obj.exists():
raise ValueError(f"Path does not exist: {path}") raise ValueError(f"Path does not exist: {path}")
if file_type == "file" and not path_obj.is_file(): if file_type == "file" and not path_obj.is_file():
raise ValueError(f"Path is not a file: {path}") raise ValueError(f"Path is not a file: {path}")
elif file_type == "directory" and not path_obj.is_dir(): if file_type == "directory" and not path_obj.is_dir():
raise ValueError(f"Path is not a directory: {path}") raise ValueError(f"Path is not a directory: {path}")
return str(path_obj) return str(path_obj)
except Exception as e: except Exception as e:
if isinstance(e, ValueError): if isinstance(e, ValueError):
raise raise
raise ValueError(f"Invalid path: {str(e)}") raise ValueError(f"Invalid path: {e!s}") from e
def list_files(directory: Union[str, Path], pattern: str = "*") -> List[str]: def list_files(directory: str | Path, pattern: str = "*") -> list[str]:
""" """
Safely list files in a directory matching a pattern. Safely list files in a directory matching a pattern.
@@ -126,10 +124,10 @@ def list_files(directory: Union[str, Path], pattern: str = "*") -> List[str]:
dir_path = Path(directory).resolve() dir_path = Path(directory).resolve()
if not dir_path.is_dir(): if not dir_path.is_dir():
raise ValueError(f"Not a directory: {directory}") raise ValueError(f"Not a directory: {directory}")
return [str(p) for p in dir_path.glob(pattern) if p.is_file()] return [str(p) for p in dir_path.glob(pattern) if p.is_file()]
except Exception as e: except Exception as e:
if isinstance(e, ValueError): if isinstance(e, ValueError):
raise raise
raise ValueError(f"Error listing files: {str(e)}") raise ValueError(f"Error listing files: {e!s}") from e

View File

@@ -4,7 +4,7 @@ CrewAI Flow Persistence.
This module provides interfaces and implementations for persisting flow states. This module provides interfaces and implementations for persisting flow states.
""" """
from typing import Any, Dict, TypeVar, Union from typing import Any, TypeVar
from pydantic import BaseModel from pydantic import BaseModel
@@ -12,7 +12,7 @@ from crewai.flow.persistence.base import FlowPersistence
from crewai.flow.persistence.decorators import persist from crewai.flow.persistence.decorators import persist
from crewai.flow.persistence.sqlite import SQLiteFlowPersistence from crewai.flow.persistence.sqlite import SQLiteFlowPersistence
__all__ = ["FlowPersistence", "persist", "SQLiteFlowPersistence"] __all__ = ["FlowPersistence", "SQLiteFlowPersistence", "persist"]
StateType = TypeVar('StateType', bound=Union[Dict[str, Any], BaseModel]) StateType = TypeVar("StateType", bound=dict[str, Any] | BaseModel)
DictStateType = Dict[str, Any] DictStateType = dict[str, Any]

View File

@@ -1,53 +1,47 @@
"""Base class for flow state persistence.""" """Base class for flow state persistence."""
import abc import abc
from typing import Any, Dict, Optional, Union from typing import Any
from pydantic import BaseModel from pydantic import BaseModel
class FlowPersistence(abc.ABC): class FlowPersistence(abc.ABC):
"""Abstract base class for flow state persistence. """Abstract base class for flow state persistence.
This class defines the interface that all persistence implementations must follow. This class defines the interface that all persistence implementations must follow.
It supports both structured (Pydantic BaseModel) and unstructured (dict) states. It supports both structured (Pydantic BaseModel) and unstructured (dict) states.
""" """
@abc.abstractmethod @abc.abstractmethod
def init_db(self) -> None: def init_db(self) -> None:
"""Initialize the persistence backend. """Initialize the persistence backend.
This method should handle any necessary setup, such as: This method should handle any necessary setup, such as:
- Creating tables - Creating tables
- Establishing connections - Establishing connections
- Setting up indexes - Setting up indexes
""" """
pass
@abc.abstractmethod @abc.abstractmethod
def save_state( def save_state(
self, self, flow_uuid: str, method_name: str, state_data: dict[str, Any] | BaseModel
flow_uuid: str,
method_name: str,
state_data: Union[Dict[str, Any], BaseModel]
) -> None: ) -> None:
"""Persist the flow state after method completion. """Persist the flow state after method completion.
Args: Args:
flow_uuid: Unique identifier for the flow instance flow_uuid: Unique identifier for the flow instance
method_name: Name of the method that just completed method_name: Name of the method that just completed
state_data: Current state data (either dict or Pydantic model) state_data: Current state data (either dict or Pydantic model)
""" """
pass
@abc.abstractmethod @abc.abstractmethod
def load_state(self, flow_uuid: str) -> Optional[Dict[str, Any]]: def load_state(self, flow_uuid: str) -> dict[str, Any] | None:
"""Load the most recent state for a given flow UUID. """Load the most recent state for a given flow UUID.
Args: Args:
flow_uuid: Unique identifier for the flow instance flow_uuid: Unique identifier for the flow instance
Returns: Returns:
The most recent state as a dictionary, or None if no state exists The most recent state as a dictionary, or None if no state exists
""" """
pass

View File

@@ -24,13 +24,10 @@ Example:
import asyncio import asyncio
import functools import functools
import logging import logging
from collections.abc import Callable
from typing import ( from typing import (
Any, Any,
Callable,
Optional,
Type,
TypeVar, TypeVar,
Union,
cast, cast,
) )
@@ -48,7 +45,7 @@ LOG_MESSAGES = {
"save_state": "Saving flow state to memory for ID: {}", "save_state": "Saving flow state to memory for ID: {}",
"save_error": "Failed to persist state for method {}: {}", "save_error": "Failed to persist state for method {}: {}",
"state_missing": "Flow instance has no state", "state_missing": "Flow instance has no state",
"id_missing": "Flow state must have an 'id' field for persistence" "id_missing": "Flow state must have an 'id' field for persistence",
} }
@@ -58,7 +55,13 @@ class PersistenceDecorator:
_printer = Printer() # Class-level printer instance _printer = Printer() # Class-level printer instance
@classmethod @classmethod
def persist_state(cls, flow_instance: Any, method_name: str, persistence_instance: FlowPersistence, verbose: bool = False) -> None: def persist_state(
cls,
flow_instance: Any,
method_name: str,
persistence_instance: FlowPersistence,
verbose: bool = False,
) -> None:
"""Persist flow state with proper error handling and logging. """Persist flow state with proper error handling and logging.
This method handles the persistence of flow state data, including proper This method handles the persistence of flow state data, including proper
@@ -76,22 +79,24 @@ class PersistenceDecorator:
AttributeError: If flow instance lacks required state attributes AttributeError: If flow instance lacks required state attributes
""" """
try: try:
state = getattr(flow_instance, 'state', None) state = getattr(flow_instance, "state", None)
if state is None: if state is None:
raise ValueError("Flow instance has no state") raise ValueError("Flow instance has no state")
flow_uuid: Optional[str] = None flow_uuid: str | None = None
if isinstance(state, dict): if isinstance(state, dict):
flow_uuid = state.get('id') flow_uuid = state.get("id")
elif isinstance(state, BaseModel): elif isinstance(state, BaseModel):
flow_uuid = getattr(state, 'id', None) flow_uuid = getattr(state, "id", None)
if not flow_uuid: if not flow_uuid:
raise ValueError("Flow state must have an 'id' field for persistence") raise ValueError("Flow state must have an 'id' field for persistence")
# Log state saving only if verbose is True # Log state saving only if verbose is True
if verbose: if verbose:
cls._printer.print(LOG_MESSAGES["save_state"].format(flow_uuid), color="cyan") cls._printer.print(
LOG_MESSAGES["save_state"].format(flow_uuid), color="cyan"
)
logger.info(LOG_MESSAGES["save_state"].format(flow_uuid)) logger.info(LOG_MESSAGES["save_state"].format(flow_uuid))
try: try:
@@ -104,12 +109,12 @@ class PersistenceDecorator:
error_msg = LOG_MESSAGES["save_error"].format(method_name, str(e)) error_msg = LOG_MESSAGES["save_error"].format(method_name, str(e))
cls._printer.print(error_msg, color="red") cls._printer.print(error_msg, color="red")
logger.error(error_msg) logger.error(error_msg)
raise RuntimeError(f"State persistence failed: {str(e)}") from e raise RuntimeError(f"State persistence failed: {e!s}") from e
except AttributeError: except AttributeError as e:
error_msg = LOG_MESSAGES["state_missing"] error_msg = LOG_MESSAGES["state_missing"]
cls._printer.print(error_msg, color="red") cls._printer.print(error_msg, color="red")
logger.error(error_msg) logger.error(error_msg)
raise ValueError(error_msg) raise ValueError(error_msg) from e
except (TypeError, ValueError) as e: except (TypeError, ValueError) as e:
error_msg = LOG_MESSAGES["id_missing"] error_msg = LOG_MESSAGES["id_missing"]
cls._printer.print(error_msg, color="red") cls._printer.print(error_msg, color="red")
@@ -117,7 +122,7 @@ class PersistenceDecorator:
raise ValueError(error_msg) from e raise ValueError(error_msg) from e
def persist(persistence: Optional[FlowPersistence] = None, verbose: bool = False): def persist(persistence: FlowPersistence | None = None, verbose: bool = False):
"""Decorator to persist flow state. """Decorator to persist flow state.
This decorator can be applied at either the class level or method level. This decorator can be applied at either the class level or method level.
@@ -144,111 +149,151 @@ def persist(persistence: Optional[FlowPersistence] = None, verbose: bool = False
def begin(self): def begin(self):
pass pass
""" """
def decorator(target: Union[Type, Callable[..., T]]) -> Union[Type, Callable[..., T]]:
def decorator(target: type | Callable[..., T]) -> type | Callable[..., T]:
"""Decorator that handles both class and method decoration.""" """Decorator that handles both class and method decoration."""
actual_persistence = persistence or SQLiteFlowPersistence() actual_persistence = persistence or SQLiteFlowPersistence()
if isinstance(target, type): if isinstance(target, type):
# Class decoration # Class decoration
original_init = getattr(target, "__init__") original_init = target.__init__ # type: ignore[misc]
@functools.wraps(original_init) @functools.wraps(original_init)
def new_init(self: Any, *args: Any, **kwargs: Any) -> None: def new_init(self: Any, *args: Any, **kwargs: Any) -> None:
if 'persistence' not in kwargs: if "persistence" not in kwargs:
kwargs['persistence'] = actual_persistence kwargs["persistence"] = actual_persistence
original_init(self, *args, **kwargs) original_init(self, *args, **kwargs)
setattr(target, "__init__", new_init) target.__init__ = new_init # type: ignore[misc]
# Store original methods to preserve their decorators # Store original methods to preserve their decorators
original_methods = {} original_methods = {
name: method
for name, method in target.__dict__.items(): for name, method in target.__dict__.items()
if callable(method) and ( if callable(method)
hasattr(method, "__is_start_method__") or and (
hasattr(method, "__trigger_methods__") or hasattr(method, "__is_start_method__")
hasattr(method, "__condition_type__") or or hasattr(method, "__trigger_methods__")
hasattr(method, "__is_flow_method__") or or hasattr(method, "__condition_type__")
hasattr(method, "__is_router__") or hasattr(method, "__is_flow_method__")
): or hasattr(method, "__is_router__")
original_methods[name] = method )
}
# Create wrapped versions of the methods that include persistence # Create wrapped versions of the methods that include persistence
for name, method in original_methods.items(): for name, method in original_methods.items():
if asyncio.iscoroutinefunction(method): if asyncio.iscoroutinefunction(method):
# Create a closure to capture the current name and method # Create a closure to capture the current name and method
def create_async_wrapper(method_name: str, original_method: Callable): def create_async_wrapper(
method_name: str, original_method: Callable
):
@functools.wraps(original_method) @functools.wraps(original_method)
async def method_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: async def method_wrapper(
self: Any, *args: Any, **kwargs: Any
) -> Any:
result = await original_method(self, *args, **kwargs) result = await original_method(self, *args, **kwargs)
PersistenceDecorator.persist_state(self, method_name, actual_persistence, verbose) PersistenceDecorator.persist_state(
self, method_name, actual_persistence, verbose
)
return result return result
return method_wrapper return method_wrapper
wrapped = create_async_wrapper(name, method) wrapped = create_async_wrapper(name, method)
# Preserve all original decorators and attributes # Preserve all original decorators and attributes
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]: for attr in [
"__is_start_method__",
"__trigger_methods__",
"__condition_type__",
"__is_router__",
]:
if hasattr(method, attr): if hasattr(method, attr):
setattr(wrapped, attr, getattr(method, attr)) setattr(wrapped, attr, getattr(method, attr))
setattr(wrapped, "__is_flow_method__", True) wrapped.__is_flow_method__ = True # type: ignore[attr-defined]
# Update the class with the wrapped method # Update the class with the wrapped method
setattr(target, name, wrapped) setattr(target, name, wrapped)
else: else:
# Create a closure to capture the current name and method # Create a closure to capture the current name and method
def create_sync_wrapper(method_name: str, original_method: Callable): def create_sync_wrapper(
method_name: str, original_method: Callable
):
@functools.wraps(original_method) @functools.wraps(original_method)
def method_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: def method_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
result = original_method(self, *args, **kwargs) result = original_method(self, *args, **kwargs)
PersistenceDecorator.persist_state(self, method_name, actual_persistence, verbose) PersistenceDecorator.persist_state(
self, method_name, actual_persistence, verbose
)
return result return result
return method_wrapper return method_wrapper
wrapped = create_sync_wrapper(name, method) wrapped = create_sync_wrapper(name, method)
# Preserve all original decorators and attributes # Preserve all original decorators and attributes
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]: for attr in [
"__is_start_method__",
"__trigger_methods__",
"__condition_type__",
"__is_router__",
]:
if hasattr(method, attr): if hasattr(method, attr):
setattr(wrapped, attr, getattr(method, attr)) setattr(wrapped, attr, getattr(method, attr))
setattr(wrapped, "__is_flow_method__", True) wrapped.__is_flow_method__ = True # type: ignore[attr-defined]
# Update the class with the wrapped method # Update the class with the wrapped method
setattr(target, name, wrapped) setattr(target, name, wrapped)
return target return target
else: # Method decoration
# Method decoration method = target
method = target method.__is_flow_method__ = True # type: ignore[attr-defined]
setattr(method, "__is_flow_method__", True)
if asyncio.iscoroutinefunction(method): if asyncio.iscoroutinefunction(method):
@functools.wraps(method)
async def method_async_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T:
method_coro = method(flow_instance, *args, **kwargs)
if asyncio.iscoroutine(method_coro):
result = await method_coro
else:
result = method_coro
PersistenceDecorator.persist_state(flow_instance, method.__name__, actual_persistence, verbose)
return result
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]: @functools.wraps(method)
if hasattr(method, attr): async def method_async_wrapper(
setattr(method_async_wrapper, attr, getattr(method, attr)) flow_instance: Any, *args: Any, **kwargs: Any
setattr(method_async_wrapper, "__is_flow_method__", True) ) -> T:
return cast(Callable[..., T], method_async_wrapper) method_coro = method(flow_instance, *args, **kwargs)
else: if asyncio.iscoroutine(method_coro):
@functools.wraps(method) result = await method_coro
def method_sync_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T: else:
result = method(flow_instance, *args, **kwargs) result = method_coro
PersistenceDecorator.persist_state(flow_instance, method.__name__, actual_persistence, verbose) PersistenceDecorator.persist_state(
return result flow_instance, method.__name__, actual_persistence, verbose
)
return result
for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]: for attr in [
if hasattr(method, attr): "__is_start_method__",
setattr(method_sync_wrapper, attr, getattr(method, attr)) "__trigger_methods__",
setattr(method_sync_wrapper, "__is_flow_method__", True) "__condition_type__",
return cast(Callable[..., T], method_sync_wrapper) "__is_router__",
]:
if hasattr(method, attr):
setattr(method_async_wrapper, attr, getattr(method, attr))
method_async_wrapper.__is_flow_method__ = True # type: ignore[attr-defined]
return cast(Callable[..., T], method_async_wrapper)
@functools.wraps(method)
def method_sync_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T:
result = method(flow_instance, *args, **kwargs)
PersistenceDecorator.persist_state(
flow_instance, method.__name__, actual_persistence, verbose
)
return result
for attr in [
"__is_start_method__",
"__trigger_methods__",
"__condition_type__",
"__is_router__",
]:
if hasattr(method, attr):
setattr(method_sync_wrapper, attr, getattr(method, attr))
method_sync_wrapper.__is_flow_method__ = True # type: ignore[attr-defined]
return cast(Callable[..., T], method_sync_wrapper)
return decorator return decorator

View File

@@ -6,7 +6,7 @@ import json
import sqlite3 import sqlite3
from datetime import datetime, timezone from datetime import datetime, timezone
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Optional, Union from typing import Any
from pydantic import BaseModel from pydantic import BaseModel
@@ -23,7 +23,7 @@ class SQLiteFlowPersistence(FlowPersistence):
db_path: str db_path: str
def __init__(self, db_path: Optional[str] = None): def __init__(self, db_path: str | None = None):
"""Initialize SQLite persistence. """Initialize SQLite persistence.
Args: Args:
@@ -70,7 +70,7 @@ class SQLiteFlowPersistence(FlowPersistence):
self, self,
flow_uuid: str, flow_uuid: str,
method_name: str, method_name: str,
state_data: Union[Dict[str, Any], BaseModel], state_data: dict[str, Any] | BaseModel,
) -> None: ) -> None:
"""Save the current flow state to SQLite. """Save the current flow state to SQLite.
@@ -107,7 +107,7 @@ class SQLiteFlowPersistence(FlowPersistence):
), ),
) )
def load_state(self, flow_uuid: str) -> Optional[Dict[str, Any]]: def load_state(self, flow_uuid: str) -> dict[str, Any] | None:
"""Load the most recent state for a given flow UUID. """Load the most recent state for a given flow UUID.
Args: Args:

View File

@@ -5,6 +5,7 @@ the Flow system.
""" """
from typing import Any, TypedDict from typing import Any, TypedDict
from typing_extensions import NotRequired, Required from typing_extensions import NotRequired, Required

View File

@@ -17,10 +17,10 @@ import ast
import inspect import inspect
import textwrap import textwrap
from collections import defaultdict, deque from collections import defaultdict, deque
from typing import Any, Deque, Dict, List, Optional, Set, Union from typing import Any
def get_possible_return_constants(function: Any) -> Optional[List[str]]: def get_possible_return_constants(function: Any) -> list[str] | None:
try: try:
source = inspect.getsource(function) source = inspect.getsource(function)
except OSError: except OSError:
@@ -58,12 +58,12 @@ def get_possible_return_constants(function: Any) -> Optional[List[str]]:
target = node.targets[0] target = node.targets[0]
if isinstance(target, ast.Name): if isinstance(target, ast.Name):
var_name = target.id var_name = target.id
dict_values = []
# Extract string values from the dictionary # Extract string values from the dictionary
for val in node.value.values: dict_values = [
if isinstance(val, ast.Constant) and isinstance(val.value, str): val.value
dict_values.append(val.value) for val in node.value.values
# If non-string, skip or just ignore if isinstance(val, ast.Constant) and isinstance(val.value, str)
]
if dict_values: if dict_values:
dict_definitions[var_name] = dict_values dict_definitions[var_name] = dict_values
self.generic_visit(node) self.generic_visit(node)
@@ -94,7 +94,7 @@ def get_possible_return_constants(function: Any) -> Optional[List[str]]:
return list(return_values) if return_values else None return list(return_values) if return_values else None
def calculate_node_levels(flow: Any) -> Dict[str, int]: def calculate_node_levels(flow: Any) -> dict[str, int]:
""" """
Calculate the hierarchical level of each node in the flow. Calculate the hierarchical level of each node in the flow.
@@ -118,10 +118,10 @@ def calculate_node_levels(flow: Any) -> Dict[str, int]:
- Handles both OR and AND conditions for listeners - Handles both OR and AND conditions for listeners
- Processes router paths separately - Processes router paths separately
""" """
levels: Dict[str, int] = {} levels: dict[str, int] = {}
queue: Deque[str] = deque() queue: deque[str] = deque()
visited: Set[str] = set() visited: set[str] = set()
pending_and_listeners: Dict[str, Set[str]] = {} pending_and_listeners: dict[str, set[str]] = {}
# Make all start methods at level 0 # Make all start methods at level 0
for method_name, method in flow._methods.items(): for method_name, method in flow._methods.items():
@@ -172,7 +172,7 @@ def calculate_node_levels(flow: Any) -> Dict[str, int]:
return levels return levels
def count_outgoing_edges(flow: Any) -> Dict[str, int]: def count_outgoing_edges(flow: Any) -> dict[str, int]:
""" """
Count the number of outgoing edges for each method in the flow. Count the number of outgoing edges for each method in the flow.
@@ -197,7 +197,7 @@ def count_outgoing_edges(flow: Any) -> Dict[str, int]:
return counts return counts
def build_ancestor_dict(flow: Any) -> Dict[str, Set[str]]: def build_ancestor_dict(flow: Any) -> dict[str, set[str]]:
""" """
Build a dictionary mapping each node to its ancestor nodes. Build a dictionary mapping each node to its ancestor nodes.
@@ -211,8 +211,8 @@ def build_ancestor_dict(flow: Any) -> Dict[str, Set[str]]:
Dict[str, Set[str]] Dict[str, Set[str]]
Dictionary mapping each node to a set of its ancestor nodes. Dictionary mapping each node to a set of its ancestor nodes.
""" """
ancestors: Dict[str, Set[str]] = {node: set() for node in flow._methods} ancestors: dict[str, set[str]] = {node: set() for node in flow._methods}
visited: Set[str] = set() visited: set[str] = set()
for node in flow._methods: for node in flow._methods:
if node not in visited: if node not in visited:
dfs_ancestors(node, ancestors, visited, flow) dfs_ancestors(node, ancestors, visited, flow)
@@ -220,7 +220,7 @@ def build_ancestor_dict(flow: Any) -> Dict[str, Set[str]]:
def dfs_ancestors( def dfs_ancestors(
node: str, ancestors: Dict[str, Set[str]], visited: Set[str], flow: Any node: str, ancestors: dict[str, set[str]], visited: set[str], flow: Any
) -> None: ) -> None:
""" """
Perform depth-first search to build ancestor relationships. Perform depth-first search to build ancestor relationships.
@@ -265,7 +265,7 @@ def dfs_ancestors(
def is_ancestor( def is_ancestor(
node: str, ancestor_candidate: str, ancestors: Dict[str, Set[str]] node: str, ancestor_candidate: str, ancestors: dict[str, set[str]]
) -> bool: ) -> bool:
""" """
Check if one node is an ancestor of another. Check if one node is an ancestor of another.
@@ -287,7 +287,7 @@ def is_ancestor(
return ancestor_candidate in ancestors.get(node, set()) return ancestor_candidate in ancestors.get(node, set())
def build_parent_children_dict(flow: Any) -> Dict[str, List[str]]: def build_parent_children_dict(flow: Any) -> dict[str, list[str]]:
""" """
Build a dictionary mapping parent nodes to their children. Build a dictionary mapping parent nodes to their children.
@@ -307,7 +307,7 @@ def build_parent_children_dict(flow: Any) -> Dict[str, List[str]]:
- Maps router methods to their paths and listeners - Maps router methods to their paths and listeners
- Children lists are sorted for consistent ordering - Children lists are sorted for consistent ordering
""" """
parent_children: Dict[str, List[str]] = {} parent_children: dict[str, list[str]] = {}
# Map listeners to their trigger methods # Map listeners to their trigger methods
for listener_name, (_, trigger_methods) in flow._listeners.items(): for listener_name, (_, trigger_methods) in flow._listeners.items():
@@ -332,7 +332,7 @@ def build_parent_children_dict(flow: Any) -> Dict[str, List[str]]:
def get_child_index( def get_child_index(
parent: str, child: str, parent_children: Dict[str, List[str]] parent: str, child: str, parent_children: dict[str, list[str]]
) -> int: ) -> int:
""" """
Get the index of a child node in its parent's sorted children list. Get the index of a child node in its parent's sorted children list.
@@ -364,7 +364,7 @@ def process_router_paths(flow, current, current_level, levels, queue):
paths = flow._router_paths.get(current, []) paths = flow._router_paths.get(current, [])
for path in paths: for path in paths:
for listener_name, ( for listener_name, (
condition_type, _condition_type,
trigger_methods, trigger_methods,
) in flow._listeners.items(): ) in flow._listeners.items():
if path in trigger_methods: if path in trigger_methods:

View File

@@ -17,7 +17,7 @@ Example
import ast import ast
import inspect import inspect
from typing import Any, Dict, List, Tuple, Union from typing import Any
from .utils import ( from .utils import (
build_ancestor_dict, build_ancestor_dict,
@@ -56,6 +56,7 @@ def method_calls_crew(method: Any) -> bool:
class CrewCallVisitor(ast.NodeVisitor): class CrewCallVisitor(ast.NodeVisitor):
"""AST visitor to detect .crew() method calls.""" """AST visitor to detect .crew() method calls."""
def __init__(self): def __init__(self):
self.found = False self.found = False
@@ -73,8 +74,8 @@ def method_calls_crew(method: Any) -> bool:
def add_nodes_to_network( def add_nodes_to_network(
net: Any, net: Any,
flow: Any, flow: Any,
node_positions: Dict[str, Tuple[float, float]], node_positions: dict[str, tuple[float, float]],
node_styles: Dict[str, Dict[str, Any]] node_styles: dict[str, dict[str, Any]],
) -> None: ) -> None:
""" """
Add nodes to the network visualization with appropriate styling. Add nodes to the network visualization with appropriate styling.
@@ -98,6 +99,7 @@ def add_nodes_to_network(
- Crew methods - Crew methods
- Regular methods - Regular methods
""" """
def human_friendly_label(method_name): def human_friendly_label(method_name):
return method_name.replace("_", " ").title() return method_name.replace("_", " ").title()
@@ -138,10 +140,10 @@ def add_nodes_to_network(
def compute_positions( def compute_positions(
flow: Any, flow: Any,
node_levels: Dict[str, int], node_levels: dict[str, int],
y_spacing: float = 150, y_spacing: float = 150,
x_spacing: float = 300 x_spacing: float = 300,
) -> Dict[str, Tuple[float, float]]: ) -> dict[str, tuple[float, float]]:
""" """
Compute the (x, y) positions for each node in the flow graph. Compute the (x, y) positions for each node in the flow graph.
@@ -161,8 +163,8 @@ def compute_positions(
Dict[str, Tuple[float, float]] Dict[str, Tuple[float, float]]
Dictionary mapping node names to their (x, y) coordinates. Dictionary mapping node names to their (x, y) coordinates.
""" """
level_nodes: Dict[int, List[str]] = {} level_nodes: dict[int, list[str]] = {}
node_positions: Dict[str, Tuple[float, float]] = {} node_positions: dict[str, tuple[float, float]] = {}
for method_name, level in node_levels.items(): for method_name, level in node_levels.items():
level_nodes.setdefault(level, []).append(method_name) level_nodes.setdefault(level, []).append(method_name)
@@ -180,10 +182,10 @@ def compute_positions(
def add_edges( def add_edges(
net: Any, net: Any,
flow: Any, flow: Any,
node_positions: Dict[str, Tuple[float, float]], node_positions: dict[str, tuple[float, float]],
colors: Dict[str, str] colors: dict[str, str],
) -> None: ) -> None:
edge_smooth: Dict[str, Union[str, float]] = {"type": "continuous"} # Default value edge_smooth: dict[str, str | float] = {"type": "continuous"} # Default value
""" """
Add edges to the network visualization with appropriate styling. Add edges to the network visualization with appropriate styling.
@@ -269,7 +271,7 @@ def add_edges(
for router_method_name, paths in flow._router_paths.items(): for router_method_name, paths in flow._router_paths.items():
for path in paths: for path in paths:
for listener_name, ( for listener_name, (
condition_type, _condition_type,
trigger_methods, trigger_methods,
) in flow._listeners.items(): ) in flow._listeners.items():
if path in trigger_methods: if path in trigger_methods: