diff --git a/src/crewai/flow/__init__.py b/src/crewai/flow/__init__.py
index 48a49666d..b15c0a720 100644
--- a/src/crewai/flow/__init__.py
+++ b/src/crewai/flow/__init__.py
@@ -1,5 +1,4 @@
-from crewai.flow.flow import Flow, start, listen, or_, and_, router
+from crewai.flow.flow import Flow, and_, listen, or_, router, start
from crewai.flow.persistence import persist
-__all__ = ["Flow", "start", "listen", "or_", "and_", "router", "persist"]
-
+__all__ = ["Flow", "and_", "listen", "or_", "persist", "router", "start"]
diff --git a/src/crewai/flow/flow_trackable.py b/src/crewai/flow/flow_trackable.py
index 3cdbdeed3..30303c272 100644
--- a/src/crewai/flow/flow_trackable.py
+++ b/src/crewai/flow/flow_trackable.py
@@ -1,5 +1,4 @@
import inspect
-from typing import Optional
from pydantic import BaseModel, Field, InstanceOf, model_validator
@@ -14,7 +13,7 @@ class FlowTrackable(BaseModel):
inspecting the call stack.
"""
- parent_flow: Optional[InstanceOf[Flow]] = Field(
+ parent_flow: InstanceOf[Flow] | None = Field(
default=None,
description="The parent flow of the instance, if it was created inside a flow.",
)
diff --git a/src/crewai/flow/flow_visualizer.py b/src/crewai/flow/flow_visualizer.py
index a70e91a18..5b50c3844 100644
--- a/src/crewai/flow/flow_visualizer.py
+++ b/src/crewai/flow/flow_visualizer.py
@@ -1,14 +1,13 @@
# flow_visualizer.py
import os
-from pathlib import Path
-from pyvis.network import Network
+from pyvis.network import Network # type: ignore[import-untyped]
from crewai.flow.config import COLORS, NODE_STYLES
from crewai.flow.html_template_handler import HTMLTemplateHandler
from crewai.flow.legend_generator import generate_legend_items_html, get_legend_items
-from crewai.flow.path_utils import safe_path_join, validate_path_exists
+from crewai.flow.path_utils import safe_path_join
from crewai.flow.utils import calculate_node_levels
from crewai.flow.visualization_utils import (
add_edges,
@@ -34,13 +33,13 @@ class FlowPlot:
ValueError
If flow object is invalid or missing required attributes.
"""
- if not hasattr(flow, '_methods'):
+ if not hasattr(flow, "_methods"):
raise ValueError("Invalid flow object: missing '_methods' attribute")
- if not hasattr(flow, '_listeners'):
+ if not hasattr(flow, "_listeners"):
raise ValueError("Invalid flow object: missing '_listeners' attribute")
- if not hasattr(flow, '_start_methods'):
+ if not hasattr(flow, "_start_methods"):
raise ValueError("Invalid flow object: missing '_start_methods' attribute")
-
+
self.flow = flow
self.colors = COLORS
self.node_styles = NODE_STYLES
@@ -65,7 +64,7 @@ class FlowPlot:
"""
if not filename or not isinstance(filename, str):
raise ValueError("Filename must be a non-empty string")
-
+
try:
# Initialize network
net = Network(
@@ -96,32 +95,34 @@ class FlowPlot:
try:
node_levels = calculate_node_levels(self.flow)
except Exception as e:
- raise ValueError(f"Failed to calculate node levels: {str(e)}")
+ raise ValueError(f"Failed to calculate node levels: {e!s}") from e
# Compute positions
try:
node_positions = compute_positions(self.flow, node_levels)
except Exception as e:
- raise ValueError(f"Failed to compute node positions: {str(e)}")
+ raise ValueError(f"Failed to compute node positions: {e!s}") from e
# Add nodes to the network
try:
add_nodes_to_network(net, self.flow, node_positions, self.node_styles)
except Exception as e:
- raise RuntimeError(f"Failed to add nodes to network: {str(e)}")
+ raise RuntimeError(f"Failed to add nodes to network: {e!s}") from e
# Add edges to the network
try:
add_edges(net, self.flow, node_positions, self.colors)
except Exception as e:
- raise RuntimeError(f"Failed to add edges to network: {str(e)}")
+ raise RuntimeError(f"Failed to add edges to network: {e!s}") from e
# Generate HTML
try:
network_html = net.generate_html()
final_html_content = self._generate_final_html(network_html)
except Exception as e:
- raise RuntimeError(f"Failed to generate network visualization: {str(e)}")
+ raise RuntimeError(
+ f"Failed to generate network visualization: {e!s}"
+ ) from e
# Save the final HTML content to the file
try:
@@ -129,12 +130,16 @@ class FlowPlot:
f.write(final_html_content)
print(f"Plot saved as {filename}.html")
except IOError as e:
- raise IOError(f"Failed to save flow visualization to {filename}.html: {str(e)}")
+ raise IOError(
+ f"Failed to save flow visualization to {filename}.html: {e!s}"
+ ) from e
except (ValueError, RuntimeError, IOError) as e:
raise e
except Exception as e:
- raise RuntimeError(f"Unexpected error during flow visualization: {str(e)}")
+ raise RuntimeError(
+ f"Unexpected error during flow visualization: {e!s}"
+ ) from e
finally:
self._cleanup_pyvis_lib()
@@ -165,7 +170,9 @@ class FlowPlot:
try:
# Extract just the body content from the generated HTML
current_dir = os.path.dirname(__file__)
- template_path = safe_path_join("assets", "crewai_flow_visual_template.html", root=current_dir)
+ template_path = safe_path_join(
+ "assets", "crewai_flow_visual_template.html", root=current_dir
+ )
logo_path = safe_path_join("assets", "crewai_logo.svg", root=current_dir)
if not os.path.exists(template_path):
@@ -179,12 +186,9 @@ class FlowPlot:
# Generate the legend items HTML
legend_items = get_legend_items(self.colors)
legend_items_html = generate_legend_items_html(legend_items)
- final_html_content = html_handler.generate_final_html(
- network_body, legend_items_html
- )
- return final_html_content
+ return html_handler.generate_final_html(network_body, legend_items_html)
except Exception as e:
- raise IOError(f"Failed to generate visualization HTML: {str(e)}")
+ raise IOError(f"Failed to generate visualization HTML: {e!s}") from e
def _cleanup_pyvis_lib(self):
"""
@@ -197,6 +201,7 @@ class FlowPlot:
lib_folder = safe_path_join("lib", root=os.getcwd())
if os.path.exists(lib_folder) and os.path.isdir(lib_folder):
import shutil
+
shutil.rmtree(lib_folder)
except ValueError as e:
print(f"Error validating lib folder path: {e}")
diff --git a/src/crewai/flow/html_template_handler.py b/src/crewai/flow/html_template_handler.py
index f0d2d89ad..55567393c 100644
--- a/src/crewai/flow/html_template_handler.py
+++ b/src/crewai/flow/html_template_handler.py
@@ -1,8 +1,7 @@
import base64
import re
-from pathlib import Path
-from crewai.flow.path_utils import safe_path_join, validate_path_exists
+from crewai.flow.path_utils import validate_path_exists
class HTMLTemplateHandler:
@@ -28,7 +27,7 @@ class HTMLTemplateHandler:
self.template_path = validate_path_exists(template_path, "file")
self.logo_path = validate_path_exists(logo_path, "file")
except ValueError as e:
- raise ValueError(f"Invalid template or logo path: {e}")
+ raise ValueError(f"Invalid template or logo path: {e}") from e
def read_template(self):
"""Read and return the HTML template file contents."""
@@ -53,23 +52,23 @@ class HTMLTemplateHandler:
if "border" in item:
legend_items_html += f"""
-
-
{item['label']}
+
+
{item["label"]}
"""
elif item.get("dashed") is not None:
style = "dashed" if item["dashed"] else "solid"
legend_items_html += f"""
-
-
{item['label']}
+
+
{item["label"]}
"""
else:
legend_items_html += f"""
-
-
{item['label']}
+
+
{item["label"]}
"""
return legend_items_html
@@ -79,15 +78,9 @@ class HTMLTemplateHandler:
html_template = self.read_template()
logo_svg_base64 = self.encode_logo()
- final_html_content = html_template.replace("{{ title }}", title)
- final_html_content = final_html_content.replace(
- "{{ network_content }}", network_body
+ return (
+ html_template.replace("{{ title }}", title)
+ .replace("{{ network_content }}", network_body)
+ .replace("{{ logo_svg_base64 }}", logo_svg_base64)
+ .replace("", legend_items_html)
)
- final_html_content = final_html_content.replace(
- "{{ logo_svg_base64 }}", logo_svg_base64
- )
- final_html_content = final_html_content.replace(
- "", legend_items_html
- )
-
- return final_html_content
diff --git a/src/crewai/flow/path_utils.py b/src/crewai/flow/path_utils.py
index 09ae8cd3d..02a893865 100644
--- a/src/crewai/flow/path_utils.py
+++ b/src/crewai/flow/path_utils.py
@@ -5,12 +5,10 @@ This module provides utilities for secure path handling to prevent directory
traversal attacks and ensure paths remain within allowed boundaries.
"""
-import os
from pathlib import Path
-from typing import List, Union
-def safe_path_join(*parts: str, root: Union[str, Path, None] = None) -> str:
+def safe_path_join(*parts: str, root: str | Path | None = None) -> str:
"""
Safely join path components and ensure the result is within allowed boundaries.
@@ -43,25 +41,25 @@ def safe_path_join(*parts: str, root: Union[str, Path, None] = None) -> str:
# Establish root directory
root_path = Path(root).resolve() if root else Path.cwd()
-
+
# Join and resolve the full path
full_path = Path(root_path, *clean_parts).resolve()
-
+
# Check if the resolved path is within root
if not str(full_path).startswith(str(root_path)):
raise ValueError(
f"Invalid path: Potential directory traversal. Path must be within {root_path}"
)
-
+
return str(full_path)
-
+
except Exception as e:
if isinstance(e, ValueError):
raise
- raise ValueError(f"Invalid path components: {str(e)}")
+ raise ValueError(f"Invalid path components: {e!s}") from e
-def validate_path_exists(path: Union[str, Path], file_type: str = "file") -> str:
+def validate_path_exists(path: str | Path, file_type: str = "file") -> str:
"""
Validate that a path exists and is of the expected type.
@@ -84,24 +82,24 @@ def validate_path_exists(path: Union[str, Path], file_type: str = "file") -> str
"""
try:
path_obj = Path(path).resolve()
-
+
if not path_obj.exists():
raise ValueError(f"Path does not exist: {path}")
-
+
if file_type == "file" and not path_obj.is_file():
raise ValueError(f"Path is not a file: {path}")
- elif file_type == "directory" and not path_obj.is_dir():
+ if file_type == "directory" and not path_obj.is_dir():
raise ValueError(f"Path is not a directory: {path}")
-
+
return str(path_obj)
-
+
except Exception as e:
if isinstance(e, ValueError):
raise
- raise ValueError(f"Invalid path: {str(e)}")
+ raise ValueError(f"Invalid path: {e!s}") from e
-def list_files(directory: Union[str, Path], pattern: str = "*") -> List[str]:
+def list_files(directory: str | Path, pattern: str = "*") -> list[str]:
"""
Safely list files in a directory matching a pattern.
@@ -126,10 +124,10 @@ def list_files(directory: Union[str, Path], pattern: str = "*") -> List[str]:
dir_path = Path(directory).resolve()
if not dir_path.is_dir():
raise ValueError(f"Not a directory: {directory}")
-
+
return [str(p) for p in dir_path.glob(pattern) if p.is_file()]
-
+
except Exception as e:
if isinstance(e, ValueError):
raise
- raise ValueError(f"Error listing files: {str(e)}")
+ raise ValueError(f"Error listing files: {e!s}") from e
diff --git a/src/crewai/flow/persistence/__init__.py b/src/crewai/flow/persistence/__init__.py
index 0b673f6bf..3a542f52c 100644
--- a/src/crewai/flow/persistence/__init__.py
+++ b/src/crewai/flow/persistence/__init__.py
@@ -4,7 +4,7 @@ CrewAI Flow Persistence.
This module provides interfaces and implementations for persisting flow states.
"""
-from typing import Any, Dict, TypeVar, Union
+from typing import Any, TypeVar
from pydantic import BaseModel
@@ -12,7 +12,7 @@ from crewai.flow.persistence.base import FlowPersistence
from crewai.flow.persistence.decorators import persist
from crewai.flow.persistence.sqlite import SQLiteFlowPersistence
-__all__ = ["FlowPersistence", "persist", "SQLiteFlowPersistence"]
+__all__ = ["FlowPersistence", "SQLiteFlowPersistence", "persist"]
-StateType = TypeVar('StateType', bound=Union[Dict[str, Any], BaseModel])
-DictStateType = Dict[str, Any]
+StateType = TypeVar("StateType", bound=dict[str, Any] | BaseModel)
+DictStateType = dict[str, Any]
diff --git a/src/crewai/flow/persistence/base.py b/src/crewai/flow/persistence/base.py
index c926f6f34..df7f00add 100644
--- a/src/crewai/flow/persistence/base.py
+++ b/src/crewai/flow/persistence/base.py
@@ -1,53 +1,47 @@
"""Base class for flow state persistence."""
import abc
-from typing import Any, Dict, Optional, Union
+from typing import Any
from pydantic import BaseModel
class FlowPersistence(abc.ABC):
"""Abstract base class for flow state persistence.
-
+
This class defines the interface that all persistence implementations must follow.
It supports both structured (Pydantic BaseModel) and unstructured (dict) states.
"""
-
+
@abc.abstractmethod
def init_db(self) -> None:
"""Initialize the persistence backend.
-
+
This method should handle any necessary setup, such as:
- Creating tables
- Establishing connections
- Setting up indexes
"""
- pass
-
+
@abc.abstractmethod
def save_state(
- self,
- flow_uuid: str,
- method_name: str,
- state_data: Union[Dict[str, Any], BaseModel]
+ self, flow_uuid: str, method_name: str, state_data: dict[str, Any] | BaseModel
) -> None:
"""Persist the flow state after method completion.
-
+
Args:
flow_uuid: Unique identifier for the flow instance
method_name: Name of the method that just completed
state_data: Current state data (either dict or Pydantic model)
"""
- pass
-
+
@abc.abstractmethod
- def load_state(self, flow_uuid: str) -> Optional[Dict[str, Any]]:
+ def load_state(self, flow_uuid: str) -> dict[str, Any] | None:
"""Load the most recent state for a given flow UUID.
-
+
Args:
flow_uuid: Unique identifier for the flow instance
-
+
Returns:
The most recent state as a dictionary, or None if no state exists
"""
- pass
diff --git a/src/crewai/flow/persistence/decorators.py b/src/crewai/flow/persistence/decorators.py
index 7b3bd447c..fc7ed6bc0 100644
--- a/src/crewai/flow/persistence/decorators.py
+++ b/src/crewai/flow/persistence/decorators.py
@@ -24,13 +24,10 @@ Example:
import asyncio
import functools
import logging
+from collections.abc import Callable
from typing import (
Any,
- Callable,
- Optional,
- Type,
TypeVar,
- Union,
cast,
)
@@ -48,7 +45,7 @@ LOG_MESSAGES = {
"save_state": "Saving flow state to memory for ID: {}",
"save_error": "Failed to persist state for method {}: {}",
"state_missing": "Flow instance has no state",
- "id_missing": "Flow state must have an 'id' field for persistence"
+ "id_missing": "Flow state must have an 'id' field for persistence",
}
@@ -58,7 +55,13 @@ class PersistenceDecorator:
_printer = Printer() # Class-level printer instance
@classmethod
- def persist_state(cls, flow_instance: Any, method_name: str, persistence_instance: FlowPersistence, verbose: bool = False) -> None:
+ def persist_state(
+ cls,
+ flow_instance: Any,
+ method_name: str,
+ persistence_instance: FlowPersistence,
+ verbose: bool = False,
+ ) -> None:
"""Persist flow state with proper error handling and logging.
This method handles the persistence of flow state data, including proper
@@ -76,22 +79,24 @@ class PersistenceDecorator:
AttributeError: If flow instance lacks required state attributes
"""
try:
- state = getattr(flow_instance, 'state', None)
+ state = getattr(flow_instance, "state", None)
if state is None:
raise ValueError("Flow instance has no state")
- flow_uuid: Optional[str] = None
+ flow_uuid: str | None = None
if isinstance(state, dict):
- flow_uuid = state.get('id')
+ flow_uuid = state.get("id")
elif isinstance(state, BaseModel):
- flow_uuid = getattr(state, 'id', None)
+ flow_uuid = getattr(state, "id", None)
if not flow_uuid:
raise ValueError("Flow state must have an 'id' field for persistence")
# Log state saving only if verbose is True
if verbose:
- cls._printer.print(LOG_MESSAGES["save_state"].format(flow_uuid), color="cyan")
+ cls._printer.print(
+ LOG_MESSAGES["save_state"].format(flow_uuid), color="cyan"
+ )
logger.info(LOG_MESSAGES["save_state"].format(flow_uuid))
try:
@@ -104,12 +109,12 @@ class PersistenceDecorator:
error_msg = LOG_MESSAGES["save_error"].format(method_name, str(e))
cls._printer.print(error_msg, color="red")
logger.error(error_msg)
- raise RuntimeError(f"State persistence failed: {str(e)}") from e
- except AttributeError:
+ raise RuntimeError(f"State persistence failed: {e!s}") from e
+ except AttributeError as e:
error_msg = LOG_MESSAGES["state_missing"]
cls._printer.print(error_msg, color="red")
logger.error(error_msg)
- raise ValueError(error_msg)
+ raise ValueError(error_msg) from e
except (TypeError, ValueError) as e:
error_msg = LOG_MESSAGES["id_missing"]
cls._printer.print(error_msg, color="red")
@@ -117,7 +122,7 @@ class PersistenceDecorator:
raise ValueError(error_msg) from e
-def persist(persistence: Optional[FlowPersistence] = None, verbose: bool = False):
+def persist(persistence: FlowPersistence | None = None, verbose: bool = False):
"""Decorator to persist flow state.
This decorator can be applied at either the class level or method level.
@@ -144,111 +149,151 @@ def persist(persistence: Optional[FlowPersistence] = None, verbose: bool = False
def begin(self):
pass
"""
- def decorator(target: Union[Type, Callable[..., T]]) -> Union[Type, Callable[..., T]]:
+
+ def decorator(target: type | Callable[..., T]) -> type | Callable[..., T]:
"""Decorator that handles both class and method decoration."""
actual_persistence = persistence or SQLiteFlowPersistence()
if isinstance(target, type):
# Class decoration
- original_init = getattr(target, "__init__")
+ original_init = target.__init__ # type: ignore[misc]
@functools.wraps(original_init)
def new_init(self: Any, *args: Any, **kwargs: Any) -> None:
- if 'persistence' not in kwargs:
- kwargs['persistence'] = actual_persistence
+ if "persistence" not in kwargs:
+ kwargs["persistence"] = actual_persistence
original_init(self, *args, **kwargs)
- setattr(target, "__init__", new_init)
+ target.__init__ = new_init # type: ignore[misc]
# Store original methods to preserve their decorators
- original_methods = {}
-
- for name, method in target.__dict__.items():
- if callable(method) and (
- hasattr(method, "__is_start_method__") or
- hasattr(method, "__trigger_methods__") or
- hasattr(method, "__condition_type__") or
- hasattr(method, "__is_flow_method__") or
- hasattr(method, "__is_router__")
- ):
- original_methods[name] = method
+ original_methods = {
+ name: method
+ for name, method in target.__dict__.items()
+ if callable(method)
+ and (
+ hasattr(method, "__is_start_method__")
+ or hasattr(method, "__trigger_methods__")
+ or hasattr(method, "__condition_type__")
+ or hasattr(method, "__is_flow_method__")
+ or hasattr(method, "__is_router__")
+ )
+ }
# Create wrapped versions of the methods that include persistence
for name, method in original_methods.items():
if asyncio.iscoroutinefunction(method):
# Create a closure to capture the current name and method
- def create_async_wrapper(method_name: str, original_method: Callable):
+ def create_async_wrapper(
+ method_name: str, original_method: Callable
+ ):
@functools.wraps(original_method)
- async def method_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
+ async def method_wrapper(
+ self: Any, *args: Any, **kwargs: Any
+ ) -> Any:
result = await original_method(self, *args, **kwargs)
- PersistenceDecorator.persist_state(self, method_name, actual_persistence, verbose)
+ PersistenceDecorator.persist_state(
+ self, method_name, actual_persistence, verbose
+ )
return result
+
return method_wrapper
wrapped = create_async_wrapper(name, method)
# Preserve all original decorators and attributes
- for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
+ for attr in [
+ "__is_start_method__",
+ "__trigger_methods__",
+ "__condition_type__",
+ "__is_router__",
+ ]:
if hasattr(method, attr):
setattr(wrapped, attr, getattr(method, attr))
- setattr(wrapped, "__is_flow_method__", True)
+ wrapped.__is_flow_method__ = True # type: ignore[attr-defined]
# Update the class with the wrapped method
setattr(target, name, wrapped)
else:
# Create a closure to capture the current name and method
- def create_sync_wrapper(method_name: str, original_method: Callable):
+ def create_sync_wrapper(
+ method_name: str, original_method: Callable
+ ):
@functools.wraps(original_method)
def method_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
result = original_method(self, *args, **kwargs)
- PersistenceDecorator.persist_state(self, method_name, actual_persistence, verbose)
+ PersistenceDecorator.persist_state(
+ self, method_name, actual_persistence, verbose
+ )
return result
+
return method_wrapper
wrapped = create_sync_wrapper(name, method)
# Preserve all original decorators and attributes
- for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
+ for attr in [
+ "__is_start_method__",
+ "__trigger_methods__",
+ "__condition_type__",
+ "__is_router__",
+ ]:
if hasattr(method, attr):
setattr(wrapped, attr, getattr(method, attr))
- setattr(wrapped, "__is_flow_method__", True)
+ wrapped.__is_flow_method__ = True # type: ignore[attr-defined]
# Update the class with the wrapped method
setattr(target, name, wrapped)
return target
- else:
- # Method decoration
- method = target
- setattr(method, "__is_flow_method__", True)
+ # Method decoration
+ method = target
+ method.__is_flow_method__ = True # type: ignore[attr-defined]
- if asyncio.iscoroutinefunction(method):
- @functools.wraps(method)
- async def method_async_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T:
- method_coro = method(flow_instance, *args, **kwargs)
- if asyncio.iscoroutine(method_coro):
- result = await method_coro
- else:
- result = method_coro
- PersistenceDecorator.persist_state(flow_instance, method.__name__, actual_persistence, verbose)
- return result
+ if asyncio.iscoroutinefunction(method):
- for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
- if hasattr(method, attr):
- setattr(method_async_wrapper, attr, getattr(method, attr))
- setattr(method_async_wrapper, "__is_flow_method__", True)
- return cast(Callable[..., T], method_async_wrapper)
- else:
- @functools.wraps(method)
- def method_sync_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T:
- result = method(flow_instance, *args, **kwargs)
- PersistenceDecorator.persist_state(flow_instance, method.__name__, actual_persistence, verbose)
- return result
+ @functools.wraps(method)
+ async def method_async_wrapper(
+ flow_instance: Any, *args: Any, **kwargs: Any
+ ) -> T:
+ method_coro = method(flow_instance, *args, **kwargs)
+ if asyncio.iscoroutine(method_coro):
+ result = await method_coro
+ else:
+ result = method_coro
+ PersistenceDecorator.persist_state(
+ flow_instance, method.__name__, actual_persistence, verbose
+ )
+ return result
- for attr in ["__is_start_method__", "__trigger_methods__", "__condition_type__", "__is_router__"]:
- if hasattr(method, attr):
- setattr(method_sync_wrapper, attr, getattr(method, attr))
- setattr(method_sync_wrapper, "__is_flow_method__", True)
- return cast(Callable[..., T], method_sync_wrapper)
+ for attr in [
+ "__is_start_method__",
+ "__trigger_methods__",
+ "__condition_type__",
+ "__is_router__",
+ ]:
+ if hasattr(method, attr):
+ setattr(method_async_wrapper, attr, getattr(method, attr))
+ method_async_wrapper.__is_flow_method__ = True # type: ignore[attr-defined]
+ return cast(Callable[..., T], method_async_wrapper)
+
+ @functools.wraps(method)
+ def method_sync_wrapper(flow_instance: Any, *args: Any, **kwargs: Any) -> T:
+ result = method(flow_instance, *args, **kwargs)
+ PersistenceDecorator.persist_state(
+ flow_instance, method.__name__, actual_persistence, verbose
+ )
+ return result
+
+ for attr in [
+ "__is_start_method__",
+ "__trigger_methods__",
+ "__condition_type__",
+ "__is_router__",
+ ]:
+ if hasattr(method, attr):
+ setattr(method_sync_wrapper, attr, getattr(method, attr))
+ method_sync_wrapper.__is_flow_method__ = True # type: ignore[attr-defined]
+ return cast(Callable[..., T], method_sync_wrapper)
return decorator
diff --git a/src/crewai/flow/persistence/sqlite.py b/src/crewai/flow/persistence/sqlite.py
index ee38c614d..1163d86c5 100644
--- a/src/crewai/flow/persistence/sqlite.py
+++ b/src/crewai/flow/persistence/sqlite.py
@@ -6,7 +6,7 @@ import json
import sqlite3
from datetime import datetime, timezone
from pathlib import Path
-from typing import Any, Dict, Optional, Union
+from typing import Any
from pydantic import BaseModel
@@ -23,7 +23,7 @@ class SQLiteFlowPersistence(FlowPersistence):
db_path: str
- def __init__(self, db_path: Optional[str] = None):
+ def __init__(self, db_path: str | None = None):
"""Initialize SQLite persistence.
Args:
@@ -70,7 +70,7 @@ class SQLiteFlowPersistence(FlowPersistence):
self,
flow_uuid: str,
method_name: str,
- state_data: Union[Dict[str, Any], BaseModel],
+ state_data: dict[str, Any] | BaseModel,
) -> None:
"""Save the current flow state to SQLite.
@@ -107,7 +107,7 @@ class SQLiteFlowPersistence(FlowPersistence):
),
)
- def load_state(self, flow_uuid: str) -> Optional[Dict[str, Any]]:
+ def load_state(self, flow_uuid: str) -> dict[str, Any] | None:
"""Load the most recent state for a given flow UUID.
Args:
diff --git a/src/crewai/flow/types.py b/src/crewai/flow/types.py
index 8b6c9e6ad..38e3b7376 100644
--- a/src/crewai/flow/types.py
+++ b/src/crewai/flow/types.py
@@ -5,6 +5,7 @@ the Flow system.
"""
from typing import Any, TypedDict
+
from typing_extensions import NotRequired, Required
diff --git a/src/crewai/flow/utils.py b/src/crewai/flow/utils.py
index 81f3c1041..74e617bee 100644
--- a/src/crewai/flow/utils.py
+++ b/src/crewai/flow/utils.py
@@ -17,10 +17,10 @@ import ast
import inspect
import textwrap
from collections import defaultdict, deque
-from typing import Any, Deque, Dict, List, Optional, Set, Union
+from typing import Any
-def get_possible_return_constants(function: Any) -> Optional[List[str]]:
+def get_possible_return_constants(function: Any) -> list[str] | None:
try:
source = inspect.getsource(function)
except OSError:
@@ -58,12 +58,12 @@ def get_possible_return_constants(function: Any) -> Optional[List[str]]:
target = node.targets[0]
if isinstance(target, ast.Name):
var_name = target.id
- dict_values = []
# Extract string values from the dictionary
- for val in node.value.values:
- if isinstance(val, ast.Constant) and isinstance(val.value, str):
- dict_values.append(val.value)
- # If non-string, skip or just ignore
+ dict_values = [
+ val.value
+ for val in node.value.values
+ if isinstance(val, ast.Constant) and isinstance(val.value, str)
+ ]
if dict_values:
dict_definitions[var_name] = dict_values
self.generic_visit(node)
@@ -94,7 +94,7 @@ def get_possible_return_constants(function: Any) -> Optional[List[str]]:
return list(return_values) if return_values else None
-def calculate_node_levels(flow: Any) -> Dict[str, int]:
+def calculate_node_levels(flow: Any) -> dict[str, int]:
"""
Calculate the hierarchical level of each node in the flow.
@@ -118,10 +118,10 @@ def calculate_node_levels(flow: Any) -> Dict[str, int]:
- Handles both OR and AND conditions for listeners
- Processes router paths separately
"""
- levels: Dict[str, int] = {}
- queue: Deque[str] = deque()
- visited: Set[str] = set()
- pending_and_listeners: Dict[str, Set[str]] = {}
+ levels: dict[str, int] = {}
+ queue: deque[str] = deque()
+ visited: set[str] = set()
+ pending_and_listeners: dict[str, set[str]] = {}
# Make all start methods at level 0
for method_name, method in flow._methods.items():
@@ -172,7 +172,7 @@ def calculate_node_levels(flow: Any) -> Dict[str, int]:
return levels
-def count_outgoing_edges(flow: Any) -> Dict[str, int]:
+def count_outgoing_edges(flow: Any) -> dict[str, int]:
"""
Count the number of outgoing edges for each method in the flow.
@@ -197,7 +197,7 @@ def count_outgoing_edges(flow: Any) -> Dict[str, int]:
return counts
-def build_ancestor_dict(flow: Any) -> Dict[str, Set[str]]:
+def build_ancestor_dict(flow: Any) -> dict[str, set[str]]:
"""
Build a dictionary mapping each node to its ancestor nodes.
@@ -211,8 +211,8 @@ def build_ancestor_dict(flow: Any) -> Dict[str, Set[str]]:
Dict[str, Set[str]]
Dictionary mapping each node to a set of its ancestor nodes.
"""
- ancestors: Dict[str, Set[str]] = {node: set() for node in flow._methods}
- visited: Set[str] = set()
+ ancestors: dict[str, set[str]] = {node: set() for node in flow._methods}
+ visited: set[str] = set()
for node in flow._methods:
if node not in visited:
dfs_ancestors(node, ancestors, visited, flow)
@@ -220,7 +220,7 @@ def build_ancestor_dict(flow: Any) -> Dict[str, Set[str]]:
def dfs_ancestors(
- node: str, ancestors: Dict[str, Set[str]], visited: Set[str], flow: Any
+ node: str, ancestors: dict[str, set[str]], visited: set[str], flow: Any
) -> None:
"""
Perform depth-first search to build ancestor relationships.
@@ -265,7 +265,7 @@ def dfs_ancestors(
def is_ancestor(
- node: str, ancestor_candidate: str, ancestors: Dict[str, Set[str]]
+ node: str, ancestor_candidate: str, ancestors: dict[str, set[str]]
) -> bool:
"""
Check if one node is an ancestor of another.
@@ -287,7 +287,7 @@ def is_ancestor(
return ancestor_candidate in ancestors.get(node, set())
-def build_parent_children_dict(flow: Any) -> Dict[str, List[str]]:
+def build_parent_children_dict(flow: Any) -> dict[str, list[str]]:
"""
Build a dictionary mapping parent nodes to their children.
@@ -307,7 +307,7 @@ def build_parent_children_dict(flow: Any) -> Dict[str, List[str]]:
- Maps router methods to their paths and listeners
- Children lists are sorted for consistent ordering
"""
- parent_children: Dict[str, List[str]] = {}
+ parent_children: dict[str, list[str]] = {}
# Map listeners to their trigger methods
for listener_name, (_, trigger_methods) in flow._listeners.items():
@@ -332,7 +332,7 @@ def build_parent_children_dict(flow: Any) -> Dict[str, List[str]]:
def get_child_index(
- parent: str, child: str, parent_children: Dict[str, List[str]]
+ parent: str, child: str, parent_children: dict[str, list[str]]
) -> int:
"""
Get the index of a child node in its parent's sorted children list.
@@ -364,7 +364,7 @@ def process_router_paths(flow, current, current_level, levels, queue):
paths = flow._router_paths.get(current, [])
for path in paths:
for listener_name, (
- condition_type,
+ _condition_type,
trigger_methods,
) in flow._listeners.items():
if path in trigger_methods:
diff --git a/src/crewai/flow/visualization_utils.py b/src/crewai/flow/visualization_utils.py
index 202fb12d8..721aef23b 100644
--- a/src/crewai/flow/visualization_utils.py
+++ b/src/crewai/flow/visualization_utils.py
@@ -17,7 +17,7 @@ Example
import ast
import inspect
-from typing import Any, Dict, List, Tuple, Union
+from typing import Any
from .utils import (
build_ancestor_dict,
@@ -56,6 +56,7 @@ def method_calls_crew(method: Any) -> bool:
class CrewCallVisitor(ast.NodeVisitor):
"""AST visitor to detect .crew() method calls."""
+
def __init__(self):
self.found = False
@@ -73,8 +74,8 @@ def method_calls_crew(method: Any) -> bool:
def add_nodes_to_network(
net: Any,
flow: Any,
- node_positions: Dict[str, Tuple[float, float]],
- node_styles: Dict[str, Dict[str, Any]]
+ node_positions: dict[str, tuple[float, float]],
+ node_styles: dict[str, dict[str, Any]],
) -> None:
"""
Add nodes to the network visualization with appropriate styling.
@@ -98,6 +99,7 @@ def add_nodes_to_network(
- Crew methods
- Regular methods
"""
+
def human_friendly_label(method_name):
return method_name.replace("_", " ").title()
@@ -138,10 +140,10 @@ def add_nodes_to_network(
def compute_positions(
flow: Any,
- node_levels: Dict[str, int],
+ node_levels: dict[str, int],
y_spacing: float = 150,
- x_spacing: float = 300
-) -> Dict[str, Tuple[float, float]]:
+ x_spacing: float = 300,
+) -> dict[str, tuple[float, float]]:
"""
Compute the (x, y) positions for each node in the flow graph.
@@ -161,8 +163,8 @@ def compute_positions(
Dict[str, Tuple[float, float]]
Dictionary mapping node names to their (x, y) coordinates.
"""
- level_nodes: Dict[int, List[str]] = {}
- node_positions: Dict[str, Tuple[float, float]] = {}
+ level_nodes: dict[int, list[str]] = {}
+ node_positions: dict[str, tuple[float, float]] = {}
for method_name, level in node_levels.items():
level_nodes.setdefault(level, []).append(method_name)
@@ -180,10 +182,10 @@ def compute_positions(
def add_edges(
net: Any,
flow: Any,
- node_positions: Dict[str, Tuple[float, float]],
- colors: Dict[str, str]
+ node_positions: dict[str, tuple[float, float]],
+ colors: dict[str, str],
) -> None:
- edge_smooth: Dict[str, Union[str, float]] = {"type": "continuous"} # Default value
+ edge_smooth: dict[str, str | float] = {"type": "continuous"} # Default value
"""
Add edges to the network visualization with appropriate styling.
@@ -269,7 +271,7 @@ def add_edges(
for router_method_name, paths in flow._router_paths.items():
for path in paths:
for listener_name, (
- condition_type,
+ _condition_type,
trigger_methods,
) in flow._listeners.items():
if path in trigger_methods: