From 37636f0dd7d91cd51fec1d4d8d0e665f7a51a19f Mon Sep 17 00:00:00 2001 From: Greyson LaLonde Date: Mon, 22 Sep 2025 13:03:06 -0400 Subject: [PATCH] chore: fix ruff linting and mypy issues in flow module --- src/crewai/flow/__init__.py | 5 +- src/crewai/flow/flow_trackable.py | 3 +- src/crewai/flow/flow_visualizer.py | 47 +++--- src/crewai/flow/html_template_handler.py | 33 ++-- src/crewai/flow/path_utils.py | 36 ++--- src/crewai/flow/persistence/__init__.py | 8 +- src/crewai/flow/persistence/base.py | 28 ++-- src/crewai/flow/persistence/decorators.py | 185 ++++++++++++++-------- src/crewai/flow/persistence/sqlite.py | 8 +- src/crewai/flow/types.py | 1 + src/crewai/flow/utils.py | 44 ++--- src/crewai/flow/visualization_utils.py | 26 +-- 12 files changed, 230 insertions(+), 194 deletions(-) diff --git a/src/crewai/flow/__init__.py b/src/crewai/flow/__init__.py index 48a49666d..b15c0a720 100644 --- a/src/crewai/flow/__init__.py +++ b/src/crewai/flow/__init__.py @@ -1,5 +1,4 @@ -from crewai.flow.flow import Flow, start, listen, or_, and_, router +from crewai.flow.flow import Flow, and_, listen, or_, router, start from crewai.flow.persistence import persist -__all__ = ["Flow", "start", "listen", "or_", "and_", "router", "persist"] - +__all__ = ["Flow", "and_", "listen", "or_", "persist", "router", "start"] diff --git a/src/crewai/flow/flow_trackable.py b/src/crewai/flow/flow_trackable.py index 3cdbdeed3..30303c272 100644 --- a/src/crewai/flow/flow_trackable.py +++ b/src/crewai/flow/flow_trackable.py @@ -1,5 +1,4 @@ import inspect -from typing import Optional from pydantic import BaseModel, Field, InstanceOf, model_validator @@ -14,7 +13,7 @@ class FlowTrackable(BaseModel): inspecting the call stack. """ - parent_flow: Optional[InstanceOf[Flow]] = Field( + parent_flow: InstanceOf[Flow] | None = Field( default=None, description="The parent flow of the instance, if it was created inside a flow.", ) diff --git a/src/crewai/flow/flow_visualizer.py b/src/crewai/flow/flow_visualizer.py index a70e91a18..5b50c3844 100644 --- a/src/crewai/flow/flow_visualizer.py +++ b/src/crewai/flow/flow_visualizer.py @@ -1,14 +1,13 @@ # flow_visualizer.py import os -from pathlib import Path -from pyvis.network import Network +from pyvis.network import Network # type: ignore[import-untyped] from crewai.flow.config import COLORS, NODE_STYLES from crewai.flow.html_template_handler import HTMLTemplateHandler from crewai.flow.legend_generator import generate_legend_items_html, get_legend_items -from crewai.flow.path_utils import safe_path_join, validate_path_exists +from crewai.flow.path_utils import safe_path_join from crewai.flow.utils import calculate_node_levels from crewai.flow.visualization_utils import ( add_edges, @@ -34,13 +33,13 @@ class FlowPlot: ValueError If flow object is invalid or missing required attributes. """ - if not hasattr(flow, '_methods'): + if not hasattr(flow, "_methods"): raise ValueError("Invalid flow object: missing '_methods' attribute") - if not hasattr(flow, '_listeners'): + if not hasattr(flow, "_listeners"): raise ValueError("Invalid flow object: missing '_listeners' attribute") - if not hasattr(flow, '_start_methods'): + if not hasattr(flow, "_start_methods"): raise ValueError("Invalid flow object: missing '_start_methods' attribute") - + self.flow = flow self.colors = COLORS self.node_styles = NODE_STYLES @@ -65,7 +64,7 @@ class FlowPlot: """ if not filename or not isinstance(filename, str): raise ValueError("Filename must be a non-empty string") - + try: # Initialize network net = Network( @@ -96,32 +95,34 @@ class FlowPlot: try: node_levels = calculate_node_levels(self.flow) except Exception as e: - raise ValueError(f"Failed to calculate node levels: {str(e)}") + raise ValueError(f"Failed to calculate node levels: {e!s}") from e # Compute positions try: node_positions = compute_positions(self.flow, node_levels) except Exception as e: - raise ValueError(f"Failed to compute node positions: {str(e)}") + raise ValueError(f"Failed to compute node positions: {e!s}") from e # Add nodes to the network try: add_nodes_to_network(net, self.flow, node_positions, self.node_styles) except Exception as e: - raise RuntimeError(f"Failed to add nodes to network: {str(e)}") + raise RuntimeError(f"Failed to add nodes to network: {e!s}") from e # Add edges to the network try: add_edges(net, self.flow, node_positions, self.colors) except Exception as e: - raise RuntimeError(f"Failed to add edges to network: {str(e)}") + raise RuntimeError(f"Failed to add edges to network: {e!s}") from e # Generate HTML try: network_html = net.generate_html() final_html_content = self._generate_final_html(network_html) except Exception as e: - raise RuntimeError(f"Failed to generate network visualization: {str(e)}") + raise RuntimeError( + f"Failed to generate network visualization: {e!s}" + ) from e # Save the final HTML content to the file try: @@ -129,12 +130,16 @@ class FlowPlot: f.write(final_html_content) print(f"Plot saved as {filename}.html") except IOError as e: - raise IOError(f"Failed to save flow visualization to {filename}.html: {str(e)}") + raise IOError( + f"Failed to save flow visualization to {filename}.html: {e!s}" + ) from e except (ValueError, RuntimeError, IOError) as e: raise e except Exception as e: - raise RuntimeError(f"Unexpected error during flow visualization: {str(e)}") + raise RuntimeError( + f"Unexpected error during flow visualization: {e!s}" + ) from e finally: self._cleanup_pyvis_lib() @@ -165,7 +170,9 @@ class FlowPlot: try: # Extract just the body content from the generated HTML current_dir = os.path.dirname(__file__) - template_path = safe_path_join("assets", "crewai_flow_visual_template.html", root=current_dir) + template_path = safe_path_join( + "assets", "crewai_flow_visual_template.html", root=current_dir + ) logo_path = safe_path_join("assets", "crewai_logo.svg", root=current_dir) if not os.path.exists(template_path): @@ -179,12 +186,9 @@ class FlowPlot: # Generate the legend items HTML legend_items = get_legend_items(self.colors) legend_items_html = generate_legend_items_html(legend_items) - final_html_content = html_handler.generate_final_html( - network_body, legend_items_html - ) - return final_html_content + return html_handler.generate_final_html(network_body, legend_items_html) except Exception as e: - raise IOError(f"Failed to generate visualization HTML: {str(e)}") + raise IOError(f"Failed to generate visualization HTML: {e!s}") from e def _cleanup_pyvis_lib(self): """ @@ -197,6 +201,7 @@ class FlowPlot: lib_folder = safe_path_join("lib", root=os.getcwd()) if os.path.exists(lib_folder) and os.path.isdir(lib_folder): import shutil + shutil.rmtree(lib_folder) except ValueError as e: print(f"Error validating lib folder path: {e}") diff --git a/src/crewai/flow/html_template_handler.py b/src/crewai/flow/html_template_handler.py index f0d2d89ad..55567393c 100644 --- a/src/crewai/flow/html_template_handler.py +++ b/src/crewai/flow/html_template_handler.py @@ -1,8 +1,7 @@ import base64 import re -from pathlib import Path -from crewai.flow.path_utils import safe_path_join, validate_path_exists +from crewai.flow.path_utils import validate_path_exists class HTMLTemplateHandler: @@ -28,7 +27,7 @@ class HTMLTemplateHandler: self.template_path = validate_path_exists(template_path, "file") self.logo_path = validate_path_exists(logo_path, "file") except ValueError as e: - raise ValueError(f"Invalid template or logo path: {e}") + raise ValueError(f"Invalid template or logo path: {e}") from e def read_template(self): """Read and return the HTML template file contents.""" @@ -53,23 +52,23 @@ class HTMLTemplateHandler: if "border" in item: legend_items_html += f"""
-
-
{item['label']}
+
+
{item["label"]}
""" elif item.get("dashed") is not None: style = "dashed" if item["dashed"] else "solid" legend_items_html += f"""
-
-
{item['label']}
+
+
{item["label"]}
""" else: legend_items_html += f"""
-
-
{item['label']}
+
+
{item["label"]}
""" return legend_items_html @@ -79,15 +78,9 @@ class HTMLTemplateHandler: html_template = self.read_template() logo_svg_base64 = self.encode_logo() - final_html_content = html_template.replace("{{ title }}", title) - final_html_content = final_html_content.replace( - "{{ network_content }}", network_body + return ( + html_template.replace("{{ title }}", title) + .replace("{{ network_content }}", network_body) + .replace("{{ logo_svg_base64 }}", logo_svg_base64) + .replace("", legend_items_html) ) - final_html_content = final_html_content.replace( - "{{ logo_svg_base64 }}", logo_svg_base64 - ) - final_html_content = final_html_content.replace( - "", legend_items_html - ) - - return final_html_content diff --git a/src/crewai/flow/path_utils.py b/src/crewai/flow/path_utils.py index 09ae8cd3d..02a893865 100644 --- a/src/crewai/flow/path_utils.py +++ b/src/crewai/flow/path_utils.py @@ -5,12 +5,10 @@ This module provides utilities for secure path handling to prevent directory traversal attacks and ensure paths remain within allowed boundaries. """ -import os from pathlib import Path -from typing import List, Union -def safe_path_join(*parts: str, root: Union[str, Path, None] = None) -> str: +def safe_path_join(*parts: str, root: str | Path | None = None) -> str: """ Safely join path components and ensure the result is within allowed boundaries. @@ -43,25 +41,25 @@ def safe_path_join(*parts: str, root: Union[str, Path, None] = None) -> str: # Establish root directory root_path = Path(root).resolve() if root else Path.cwd() - + # Join and resolve the full path full_path = Path(root_path, *clean_parts).resolve() - + # Check if the resolved path is within root if not str(full_path).startswith(str(root_path)): raise ValueError( f"Invalid path: Potential directory traversal. Path must be within {root_path}" ) - + return str(full_path) - + except Exception as e: if isinstance(e, ValueError): raise - raise ValueError(f"Invalid path components: {str(e)}") + raise ValueError(f"Invalid path components: {e!s}") from e -def validate_path_exists(path: Union[str, Path], file_type: str = "file") -> str: +def validate_path_exists(path: str | Path, file_type: str = "file") -> str: """ Validate that a path exists and is of the expected type. @@ -84,24 +82,24 @@ def validate_path_exists(path: Union[str, Path], file_type: str = "file") -> str """ try: path_obj = Path(path).resolve() - + if not path_obj.exists(): raise ValueError(f"Path does not exist: {path}") - + if file_type == "file" and not path_obj.is_file(): raise ValueError(f"Path is not a file: {path}") - elif file_type == "directory" and not path_obj.is_dir(): + if file_type == "directory" and not path_obj.is_dir(): raise ValueError(f"Path is not a directory: {path}") - + return str(path_obj) - + except Exception as e: if isinstance(e, ValueError): raise - raise ValueError(f"Invalid path: {str(e)}") + raise ValueError(f"Invalid path: {e!s}") from e -def list_files(directory: Union[str, Path], pattern: str = "*") -> List[str]: +def list_files(directory: str | Path, pattern: str = "*") -> list[str]: """ Safely list files in a directory matching a pattern. @@ -126,10 +124,10 @@ def list_files(directory: Union[str, Path], pattern: str = "*") -> List[str]: dir_path = Path(directory).resolve() if not dir_path.is_dir(): raise ValueError(f"Not a directory: {directory}") - + return [str(p) for p in dir_path.glob(pattern) if p.is_file()] - + except Exception as e: if isinstance(e, ValueError): raise - raise ValueError(f"Error listing files: {str(e)}") + raise ValueError(f"Error listing files: {e!s}") from e diff --git a/src/crewai/flow/persistence/__init__.py b/src/crewai/flow/persistence/__init__.py index 0b673f6bf..3a542f52c 100644 --- a/src/crewai/flow/persistence/__init__.py +++ b/src/crewai/flow/persistence/__init__.py @@ -4,7 +4,7 @@ CrewAI Flow Persistence. This module provides interfaces and implementations for persisting flow states. """ -from typing import Any, Dict, TypeVar, Union +from typing import Any, TypeVar from pydantic import BaseModel @@ -12,7 +12,7 @@ from crewai.flow.persistence.base import FlowPersistence from crewai.flow.persistence.decorators import persist from crewai.flow.persistence.sqlite import SQLiteFlowPersistence -__all__ = ["FlowPersistence", "persist", "SQLiteFlowPersistence"] +__all__ = ["FlowPersistence", "SQLiteFlowPersistence", "persist"] -StateType = TypeVar('StateType', bound=Union[Dict[str, Any], BaseModel]) -DictStateType = Dict[str, Any] +StateType = TypeVar("StateType", bound=dict[str, Any] | BaseModel) +DictStateType = dict[str, Any] diff --git a/src/crewai/flow/persistence/base.py b/src/crewai/flow/persistence/base.py index c926f6f34..df7f00add 100644 --- a/src/crewai/flow/persistence/base.py +++ b/src/crewai/flow/persistence/base.py @@ -1,53 +1,47 @@ """Base class for flow state persistence.""" import abc -from typing import Any, Dict, Optional, Union +from typing import Any from pydantic import BaseModel class FlowPersistence(abc.ABC): """Abstract base class for flow state persistence. - + This class defines the interface that all persistence implementations must follow. It supports both structured (Pydantic BaseModel) and unstructured (dict) states. """ - + @abc.abstractmethod def init_db(self) -> None: """Initialize the persistence backend. - + This method should handle any necessary setup, such as: - Creating tables - Establishing connections - Setting up indexes """ - pass - + @abc.abstractmethod def save_state( - self, - flow_uuid: str, - method_name: str, - state_data: Union[Dict[str, Any], BaseModel] + self, flow_uuid: str, method_name: str, state_data: dict[str, Any] | BaseModel ) -> None: """Persist the flow state after method completion. - + Args: flow_uuid: Unique identifier for the flow instance method_name: Name of the method that just completed state_data: Current state data (either dict or Pydantic model) """ - pass - + @abc.abstractmethod - def load_state(self, flow_uuid: str) -> Optional[Dict[str, Any]]: + def load_state(self, flow_uuid: str) -> dict[str, Any] | None: """Load the most recent state for a given flow UUID. - + Args: flow_uuid: Unique identifier for the flow instance - + Returns: The most recent state as a dictionary, or None if no state exists """ - pass diff --git a/src/crewai/flow/persistence/decorators.py b/src/crewai/flow/persistence/decorators.py index 7b3bd447c..fc7ed6bc0 100644 --- a/src/crewai/flow/persistence/decorators.py +++ b/src/crewai/flow/persistence/decorators.py @@ -24,13 +24,10 @@ Example: import asyncio import functools import logging +from collections.abc import Callable from typing import ( Any, - Callable, - Optional, - Type, TypeVar, - Union, cast, ) @@ -48,7 +45,7 @@ LOG_MESSAGES = { "save_state": "Saving flow state to memory for ID: {}", "save_error": "Failed to persist state for method {}: {}", "state_missing": "Flow instance has no state", - "id_missing": "Flow state must have an 'id' field for persistence" + "id_missing": "Flow state must have an 'id' field for persistence", } @@ -58,7 +55,13 @@ class PersistenceDecorator: _printer = Printer() # Class-level printer instance @classmethod - def persist_state(cls, flow_instance: Any, method_name: str, persistence_instance: FlowPersistence, verbose: bool = False) -> None: + def persist_state( + cls, + flow_instance: Any, + method_name: str, + persistence_instance: FlowPersistence, + verbose: bool = False, + ) -> None: """Persist flow state with proper error handling and logging. This method handles the persistence of flow state data, including proper @@ -76,22 +79,24 @@ class PersistenceDecorator: AttributeError: If flow instance lacks required state attributes """ try: - state = getattr(flow_instance, 'state', None) + state = getattr(flow_instance, "state", None) if state is None: raise ValueError("Flow instance has no state") - flow_uuid: Optional[str] = None + flow_uuid: str | None = None if isinstance(state, dict): - flow_uuid = state.get('id') + flow_uuid = state.get("id") elif isinstance(state, BaseModel): - flow_uuid = getattr(state, 'id', None) + flow_uuid = getattr(state, "id", None) if not flow_uuid: raise ValueError("Flow state must have an 'id' field for persistence") # Log state saving only if verbose is True if verbose: - cls._printer.print(LOG_MESSAGES["save_state"].format(flow_uuid), color="cyan") + cls._printer.print( + LOG_MESSAGES["save_state"].format(flow_uuid), color="cyan" + ) logger.info(LOG_MESSAGES["save_state"].format(flow_uuid)) try: @@ -104,12 +109,12 @@ class PersistenceDecorator: error_msg = LOG_MESSAGES["save_error"].format(method_name, str(e)) cls._printer.print(error_msg, color="red") logger.error(error_msg) - raise RuntimeError(f"State persistence failed: {str(e)}") from e - except AttributeError: + raise RuntimeError(f"State persistence failed: {e!s}") from e + except AttributeError as e: error_msg = LOG_MESSAGES["state_missing"] cls._printer.print(error_msg, color="red") logger.error(error_msg) - raise ValueError(error_msg) + raise ValueError(error_msg) from e except (TypeError, ValueError) as e: error_msg = LOG_MESSAGES["id_missing"] cls._printer.print(error_msg, color="red") @@ -117,7 +122,7 @@ class PersistenceDecorator: raise ValueError(error_msg) from e -def persist(persistence: Optional[FlowPersistence] = None, verbose: bool = False): +def persist(persistence: FlowPersistence | None = None, verbose: bool = False): """Decorator to persist flow state. This decorator can be applied at either the class level or method level. @@ -144,111 +149,151 @@ def persist(persistence: Optional[FlowPersistence] = None, verbose: bool = False def begin(self): pass """ - def decorator(target: Union[Type, Callable[..., T]]) -> Union[Type, Callable[..., T]]: + + def decorator(target: type | Callable[..., T]) -> type | Callable[..., T]: """Decorator that handles both class and method decoration.""" actual_persistence = persistence or SQLiteFlowPersistence() if isinstance(target, type): # Class decoration - original_init = getattr(target, "__init__") + original_init = target.__init__ # type: ignore[misc] @functools.wraps(original_init) def new_init(self: Any, *args: Any, **kwargs: Any) -> None: - if 'persistence' not in kwargs: - kwargs['persistence'] = actual_persistence + if "persistence" not in kwargs: + kwargs["persistence"] = actual_persistence original_init(self, *args, **kwargs) - setattr(target, "__init__", new_init) + target.__init__ = new_init # type: ignore[misc] # Store original methods to preserve their decorators - original_methods = {} - - for name, method in target.__dict__.items(): - if callable(method) and ( - hasattr(method, "__is_start_method__") or - hasattr(method, "__trigger_methods__") or - hasattr(method, "__condition_type__") or - hasattr(method, "__is_flow_method__") or - hasattr(method, "__is_router__") - ): - original_methods[name] = method + original_methods = { + name: method + for name, method in target.__dict__.items() + if callable(method) + and ( + hasattr(method, "__is_start_method__") + or hasattr(method, "__trigger_methods__") + or hasattr(method, "__condition_type__") + or hasattr(method, "__is_flow_method__") + or hasattr(method, "__is_router__") + ) + } # Create wrapped versions of the methods that include persistence for name, method in original_methods.items(): if asyncio.iscoroutinefunction(method): # Create a closure to capture the current name and method - def create_async_wrapper(method_name: str, original_method: Callable): + def create_async_wrapper( + method_name: str, original_method: Callable + ): @functools.wraps(original_method) - async def method_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: + async def method_wrapper( + self: Any, *args: Any, **kwargs: Any + ) -> Any: result = await original_method(self, *args, **kwargs) - PersistenceDecorator.persist_state(self, method_name, actual_persistence, verbose) + PersistenceDecorator.persist_state( + self, method_name, actual_persistence, verbose + ) return result + return method_wrapper wrapped = create_async_wrapper(name, method) # Preserve all original decorators and attributes - for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]: + for attr in [ + "__is_start_method__", + "__trigger_methods__", + "__condition_type__", + "__is_router__", + ]: if hasattr(method, attr): setattr(wrapped, attr, getattr(method, attr)) - setattr(wrapped, "__is_flow_method__", True) + wrapped.__is_flow_method__ = True # type: ignore[attr-defined] # Update the class with the wrapped method setattr(target, name, wrapped) else: # Create a closure to capture the current name and method - def create_sync_wrapper(method_name: str, original_method: Callable): + def create_sync_wrapper( + method_name: str, original_method: Callable + ): @functools.wraps(original_method) def method_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: result = original_method(self, *args, **kwargs) - PersistenceDecorator.persist_state(self, method_name, actual_persistence, verbose) + PersistenceDecorator.persist_state( + self, method_name, actual_persistence, verbose + ) return result + return method_wrapper wrapped = create_sync_wrapper(name, method) # Preserve all original decorators and attributes - for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]: + for attr in [ + "__is_start_method__", + "__trigger_methods__", + "__condition_type__", + "__is_router__", + ]: if hasattr(method, attr): setattr(wrapped, attr, getattr(method, attr)) - setattr(wrapped, "__is_flow_method__", True) + wrapped.__is_flow_method__ = True # type: ignore[attr-defined] # Update the class with the wrapped method setattr(target, name, wrapped) return target - else: - # Method decoration - method = target - setattr(method, "__is_flow_method__", True) + # Method decoration + method = target + method.__is_flow_method__ = True # type: ignore[attr-defined] - if asyncio.iscoroutinefunction(method): - @functools.wraps(method) - async def method_async_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T: - method_coro = method(flow_instance, *args, **kwargs) - if asyncio.iscoroutine(method_coro): - result = await method_coro - else: - result = method_coro - PersistenceDecorator.persist_state(flow_instance, method.__name__, actual_persistence, verbose) - return result + if asyncio.iscoroutinefunction(method): - for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]: - if hasattr(method, attr): - setattr(method_async_wrapper, attr, getattr(method, attr)) - setattr(method_async_wrapper, "__is_flow_method__", True) - return cast(Callable[..., T], method_async_wrapper) - else: - @functools.wraps(method) - def method_sync_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T: - result = method(flow_instance, *args, **kwargs) - PersistenceDecorator.persist_state(flow_instance, method.__name__, actual_persistence, verbose) - return result + @functools.wraps(method) + async def method_async_wrapper( + flow_instance: Any, *args: Any, **kwargs: Any + ) -> T: + method_coro = method(flow_instance, *args, **kwargs) + if asyncio.iscoroutine(method_coro): + result = await method_coro + else: + result = method_coro + PersistenceDecorator.persist_state( + flow_instance, method.__name__, actual_persistence, verbose + ) + return result - for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]: - if hasattr(method, attr): - setattr(method_sync_wrapper, attr, getattr(method, attr)) - setattr(method_sync_wrapper, "__is_flow_method__", True) - return cast(Callable[..., T], method_sync_wrapper) + for attr in [ + "__is_start_method__", + "__trigger_methods__", + "__condition_type__", + "__is_router__", + ]: + if hasattr(method, attr): + setattr(method_async_wrapper, attr, getattr(method, attr)) + method_async_wrapper.__is_flow_method__ = True # type: ignore[attr-defined] + return cast(Callable[..., T], method_async_wrapper) + + @functools.wraps(method) + def method_sync_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T: + result = method(flow_instance, *args, **kwargs) + PersistenceDecorator.persist_state( + flow_instance, method.__name__, actual_persistence, verbose + ) + return result + + for attr in [ + "__is_start_method__", + "__trigger_methods__", + "__condition_type__", + "__is_router__", + ]: + if hasattr(method, attr): + setattr(method_sync_wrapper, attr, getattr(method, attr)) + method_sync_wrapper.__is_flow_method__ = True # type: ignore[attr-defined] + return cast(Callable[..., T], method_sync_wrapper) return decorator diff --git a/src/crewai/flow/persistence/sqlite.py b/src/crewai/flow/persistence/sqlite.py index ee38c614d..1163d86c5 100644 --- a/src/crewai/flow/persistence/sqlite.py +++ b/src/crewai/flow/persistence/sqlite.py @@ -6,7 +6,7 @@ import json import sqlite3 from datetime import datetime, timezone from pathlib import Path -from typing import Any, Dict, Optional, Union +from typing import Any from pydantic import BaseModel @@ -23,7 +23,7 @@ class SQLiteFlowPersistence(FlowPersistence): db_path: str - def __init__(self, db_path: Optional[str] = None): + def __init__(self, db_path: str | None = None): """Initialize SQLite persistence. Args: @@ -70,7 +70,7 @@ class SQLiteFlowPersistence(FlowPersistence): self, flow_uuid: str, method_name: str, - state_data: Union[Dict[str, Any], BaseModel], + state_data: dict[str, Any] | BaseModel, ) -> None: """Save the current flow state to SQLite. @@ -107,7 +107,7 @@ class SQLiteFlowPersistence(FlowPersistence): ), ) - def load_state(self, flow_uuid: str) -> Optional[Dict[str, Any]]: + def load_state(self, flow_uuid: str) -> dict[str, Any] | None: """Load the most recent state for a given flow UUID. Args: diff --git a/src/crewai/flow/types.py b/src/crewai/flow/types.py index 8b6c9e6ad..38e3b7376 100644 --- a/src/crewai/flow/types.py +++ b/src/crewai/flow/types.py @@ -5,6 +5,7 @@ the Flow system. """ from typing import Any, TypedDict + from typing_extensions import NotRequired, Required diff --git a/src/crewai/flow/utils.py b/src/crewai/flow/utils.py index 81f3c1041..74e617bee 100644 --- a/src/crewai/flow/utils.py +++ b/src/crewai/flow/utils.py @@ -17,10 +17,10 @@ import ast import inspect import textwrap from collections import defaultdict, deque -from typing import Any, Deque, Dict, List, Optional, Set, Union +from typing import Any -def get_possible_return_constants(function: Any) -> Optional[List[str]]: +def get_possible_return_constants(function: Any) -> list[str] | None: try: source = inspect.getsource(function) except OSError: @@ -58,12 +58,12 @@ def get_possible_return_constants(function: Any) -> Optional[List[str]]: target = node.targets[0] if isinstance(target, ast.Name): var_name = target.id - dict_values = [] # Extract string values from the dictionary - for val in node.value.values: - if isinstance(val, ast.Constant) and isinstance(val.value, str): - dict_values.append(val.value) - # If non-string, skip or just ignore + dict_values = [ + val.value + for val in node.value.values + if isinstance(val, ast.Constant) and isinstance(val.value, str) + ] if dict_values: dict_definitions[var_name] = dict_values self.generic_visit(node) @@ -94,7 +94,7 @@ def get_possible_return_constants(function: Any) -> Optional[List[str]]: return list(return_values) if return_values else None -def calculate_node_levels(flow: Any) -> Dict[str, int]: +def calculate_node_levels(flow: Any) -> dict[str, int]: """ Calculate the hierarchical level of each node in the flow. @@ -118,10 +118,10 @@ def calculate_node_levels(flow: Any) -> Dict[str, int]: - Handles both OR and AND conditions for listeners - Processes router paths separately """ - levels: Dict[str, int] = {} - queue: Deque[str] = deque() - visited: Set[str] = set() - pending_and_listeners: Dict[str, Set[str]] = {} + levels: dict[str, int] = {} + queue: deque[str] = deque() + visited: set[str] = set() + pending_and_listeners: dict[str, set[str]] = {} # Make all start methods at level 0 for method_name, method in flow._methods.items(): @@ -172,7 +172,7 @@ def calculate_node_levels(flow: Any) -> Dict[str, int]: return levels -def count_outgoing_edges(flow: Any) -> Dict[str, int]: +def count_outgoing_edges(flow: Any) -> dict[str, int]: """ Count the number of outgoing edges for each method in the flow. @@ -197,7 +197,7 @@ def count_outgoing_edges(flow: Any) -> Dict[str, int]: return counts -def build_ancestor_dict(flow: Any) -> Dict[str, Set[str]]: +def build_ancestor_dict(flow: Any) -> dict[str, set[str]]: """ Build a dictionary mapping each node to its ancestor nodes. @@ -211,8 +211,8 @@ def build_ancestor_dict(flow: Any) -> Dict[str, Set[str]]: Dict[str, Set[str]] Dictionary mapping each node to a set of its ancestor nodes. """ - ancestors: Dict[str, Set[str]] = {node: set() for node in flow._methods} - visited: Set[str] = set() + ancestors: dict[str, set[str]] = {node: set() for node in flow._methods} + visited: set[str] = set() for node in flow._methods: if node not in visited: dfs_ancestors(node, ancestors, visited, flow) @@ -220,7 +220,7 @@ def build_ancestor_dict(flow: Any) -> Dict[str, Set[str]]: def dfs_ancestors( - node: str, ancestors: Dict[str, Set[str]], visited: Set[str], flow: Any + node: str, ancestors: dict[str, set[str]], visited: set[str], flow: Any ) -> None: """ Perform depth-first search to build ancestor relationships. @@ -265,7 +265,7 @@ def dfs_ancestors( def is_ancestor( - node: str, ancestor_candidate: str, ancestors: Dict[str, Set[str]] + node: str, ancestor_candidate: str, ancestors: dict[str, set[str]] ) -> bool: """ Check if one node is an ancestor of another. @@ -287,7 +287,7 @@ def is_ancestor( return ancestor_candidate in ancestors.get(node, set()) -def build_parent_children_dict(flow: Any) -> Dict[str, List[str]]: +def build_parent_children_dict(flow: Any) -> dict[str, list[str]]: """ Build a dictionary mapping parent nodes to their children. @@ -307,7 +307,7 @@ def build_parent_children_dict(flow: Any) -> Dict[str, List[str]]: - Maps router methods to their paths and listeners - Children lists are sorted for consistent ordering """ - parent_children: Dict[str, List[str]] = {} + parent_children: dict[str, list[str]] = {} # Map listeners to their trigger methods for listener_name, (_, trigger_methods) in flow._listeners.items(): @@ -332,7 +332,7 @@ def build_parent_children_dict(flow: Any) -> Dict[str, List[str]]: def get_child_index( - parent: str, child: str, parent_children: Dict[str, List[str]] + parent: str, child: str, parent_children: dict[str, list[str]] ) -> int: """ Get the index of a child node in its parent's sorted children list. @@ -364,7 +364,7 @@ def process_router_paths(flow, current, current_level, levels, queue): paths = flow._router_paths.get(current, []) for path in paths: for listener_name, ( - condition_type, + _condition_type, trigger_methods, ) in flow._listeners.items(): if path in trigger_methods: diff --git a/src/crewai/flow/visualization_utils.py b/src/crewai/flow/visualization_utils.py index 202fb12d8..721aef23b 100644 --- a/src/crewai/flow/visualization_utils.py +++ b/src/crewai/flow/visualization_utils.py @@ -17,7 +17,7 @@ Example import ast import inspect -from typing import Any, Dict, List, Tuple, Union +from typing import Any from .utils import ( build_ancestor_dict, @@ -56,6 +56,7 @@ def method_calls_crew(method: Any) -> bool: class CrewCallVisitor(ast.NodeVisitor): """AST visitor to detect .crew() method calls.""" + def __init__(self): self.found = False @@ -73,8 +74,8 @@ def method_calls_crew(method: Any) -> bool: def add_nodes_to_network( net: Any, flow: Any, - node_positions: Dict[str, Tuple[float, float]], - node_styles: Dict[str, Dict[str, Any]] + node_positions: dict[str, tuple[float, float]], + node_styles: dict[str, dict[str, Any]], ) -> None: """ Add nodes to the network visualization with appropriate styling. @@ -98,6 +99,7 @@ def add_nodes_to_network( - Crew methods - Regular methods """ + def human_friendly_label(method_name): return method_name.replace("_", " ").title() @@ -138,10 +140,10 @@ def add_nodes_to_network( def compute_positions( flow: Any, - node_levels: Dict[str, int], + node_levels: dict[str, int], y_spacing: float = 150, - x_spacing: float = 300 -) -> Dict[str, Tuple[float, float]]: + x_spacing: float = 300, +) -> dict[str, tuple[float, float]]: """ Compute the (x, y) positions for each node in the flow graph. @@ -161,8 +163,8 @@ def compute_positions( Dict[str, Tuple[float, float]] Dictionary mapping node names to their (x, y) coordinates. """ - level_nodes: Dict[int, List[str]] = {} - node_positions: Dict[str, Tuple[float, float]] = {} + level_nodes: dict[int, list[str]] = {} + node_positions: dict[str, tuple[float, float]] = {} for method_name, level in node_levels.items(): level_nodes.setdefault(level, []).append(method_name) @@ -180,10 +182,10 @@ def compute_positions( def add_edges( net: Any, flow: Any, - node_positions: Dict[str, Tuple[float, float]], - colors: Dict[str, str] + node_positions: dict[str, tuple[float, float]], + colors: dict[str, str], ) -> None: - edge_smooth: Dict[str, Union[str, float]] = {"type": "continuous"} # Default value + edge_smooth: dict[str, str | float] = {"type": "continuous"} # Default value """ Add edges to the network visualization with appropriate styling. @@ -269,7 +271,7 @@ def add_edges( for router_method_name, paths in flow._router_paths.items(): for path in paths: for listener_name, ( - condition_type, + _condition_type, trigger_methods, ) in flow._listeners.items(): if path in trigger_methods: