Compare commits

...

8 Commits

Author SHA1 Message Date
João Moura
4aaa37d7af Merge branch 'main' into devin/1735591359-circular-import-fix 2024-12-30 22:15:32 -03:00
Devin AI
4604ac618d fix: rename edge_props to router_edge_props to fix type checker error
Co-Authored-By: Joe Moura <joao@crewai.com>
2024-12-30 21:18:29 +00:00
Devin AI
b0d1d86c26 fix: type checker errors and docstrings in flow modules
Co-Authored-By: Joe Moura <joao@crewai.com>
2024-12-30 21:15:51 +00:00
Devin AI
454ab55a26 fix: type annotations and import sorting in flow modules
Co-Authored-By: Joe Moura <joao@crewai.com>
2024-12-30 20:59:06 +00:00
Devin AI
613dd175ee fix: type checker errors in flow modules
Co-Authored-By: Joe Moura <joao@crewai.com>
2024-12-30 20:53:59 +00:00
Devin AI
1dc8ce2674 style: fix import sorting in flow modules
Co-Authored-By: Joe Moura <joao@crewai.com>
2024-12-30 20:47:54 +00:00
Devin AI
e0590e516f fix: break circular import by refactoring flow/visualizer/utils import structure
- Split utils.py into specialized modules (core_flow_utils.py, flow_visual_utils.py)
- Add path_utils.py for secure file path handling
- Update imports to prevent circular dependencies
- Use TYPE_CHECKING for type hints
- Fix import sorting issues

Co-Authored-By: Joe Moura <joao@crewai.com>
2024-12-30 20:42:39 +00:00
Marco Vinciguerra
8c6883e5ee feat: add docstring 2024-12-30 16:44:11 +01:00
10 changed files with 1170 additions and 317 deletions

View File

@@ -0,0 +1,141 @@
"""Core utility functions for Flow class operations.
This module contains utility functions that are specifically designed to work
with the Flow class and require direct access to Flow class internals. These
utilities are separated from general-purpose utilities to maintain a clean
dependency structure and avoid circular imports.
Functions in this module are core to Flow functionality and are not related
to visualization or other optional features.
"""
import ast
import inspect
import textwrap
from typing import Any, Callable, Dict, List, Optional, Set, Union
from pydantic import BaseModel
def get_possible_return_constants(function: Callable[..., Any]) -> Optional[List[str]]:
"""Extract possible string return values from a function by analyzing its source code.
Analyzes the function's source code using AST to identify string constants that
could be returned, including strings stored in dictionaries and direct returns.
Args:
function: The function to analyze for possible return values
Returns:
list[str] | None: List of possible string return values, or None if:
- Source code cannot be retrieved
- Source code has syntax/indentation errors
- No string return values are found
Raises:
OSError: If source code cannot be retrieved
IndentationError: If source code has invalid indentation
SyntaxError: If source code has syntax errors
Example:
>>> def get_status():
... paths = {"success": "completed", "error": "failed"}
... return paths["success"]
>>> get_possible_return_constants(get_status)
['completed', 'failed']
"""
try:
source = inspect.getsource(function)
except OSError:
# Can't get source code
return None
except Exception as e:
print(f"Error retrieving source code for function {function.__name__}: {e}")
return None
try:
# Remove leading indentation
source = textwrap.dedent(source)
# Parse the source code into an AST
code_ast = ast.parse(source)
except IndentationError as e:
print(f"IndentationError while parsing source code of {function.__name__}: {e}")
print(f"Source code:\n{source}")
return None
except SyntaxError as e:
print(f"SyntaxError while parsing source code of {function.__name__}: {e}")
print(f"Source code:\n{source}")
return None
except Exception as e:
print(f"Unexpected error while parsing source code of {function.__name__}: {e}")
print(f"Source code:\n{source}")
return None
return_values = set()
dict_definitions = {}
class DictionaryAssignmentVisitor(ast.NodeVisitor):
def visit_Assign(self, node):
# Check if this assignment is assigning a dictionary literal to a variable
if isinstance(node.value, ast.Dict) and len(node.targets) == 1:
target = node.targets[0]
if isinstance(target, ast.Name):
var_name = target.id
dict_values = []
# Extract string values from the dictionary
for val in node.value.values:
if isinstance(val, ast.Constant) and isinstance(val.value, str):
dict_values.append(val.value)
# If non-string, skip or just ignore
if dict_values:
dict_definitions[var_name] = dict_values
self.generic_visit(node)
class ReturnVisitor(ast.NodeVisitor):
def visit_Return(self, node):
# Direct string return
if isinstance(node.value, ast.Constant) and isinstance(
node.value.value, str
):
return_values.add(node.value.value)
# Dictionary-based return, like return paths[result]
elif isinstance(node.value, ast.Subscript):
# Check if we're subscripting a known dictionary variable
if isinstance(node.value.value, ast.Name):
var_name = node.value.value.id
if var_name in dict_definitions:
# Add all possible dictionary values
for v in dict_definitions[var_name]:
return_values.add(v)
self.generic_visit(node)
# First pass: identify dictionary assignments
DictionaryAssignmentVisitor().visit(code_ast)
# Second pass: identify returns
ReturnVisitor().visit(code_ast)
return list(return_values) if return_values else None
def is_ancestor(node: str, ancestor_candidate: str, ancestors: Dict[str, Set[str]]) -> bool:
"""Check if one node is an ancestor of another in the flow graph.
Args:
node: Target node to check ancestors for
ancestor_candidate: Node to check if it's an ancestor
ancestors: Dictionary mapping nodes to their ancestor sets
Returns:
bool: True if ancestor_candidate is an ancestor of node
Raises:
TypeError: If any argument has an invalid type
"""
if not isinstance(node, str):
raise TypeError("Argument 'node' must be a string")
if not isinstance(ancestor_candidate, str):
raise TypeError("Argument 'ancestor_candidate' must be a string")
if not isinstance(ancestors, dict):
raise TypeError("Argument 'ancestors' must be a dictionary")
return ancestor_candidate in ancestors.get(node, set())

View File

@@ -17,20 +17,43 @@ from typing import (
from blinker import Signal
from pydantic import BaseModel, ValidationError
from crewai.flow.core_flow_utils import get_possible_return_constants
from crewai.flow.flow_events import (
FlowFinishedEvent,
FlowStartedEvent,
MethodExecutionFinishedEvent,
MethodExecutionStartedEvent,
)
from crewai.flow.flow_visualizer import plot_flow
from crewai.flow.utils import get_possible_return_constants
from crewai.telemetry import Telemetry
T = TypeVar("T", bound=Union[BaseModel, Dict[str, Any]])
def start(condition=None):
def start(condition: Optional[Union[str, dict, Callable]] = None) -> Callable:
"""Marks a method as a flow starting point, optionally triggered by other methods.
Args:
condition: The condition that triggers this method. Can be:
- str: Name of the triggering method
- dict: Dictionary with 'type' and 'methods' keys for complex conditions
- Callable: A function reference
- None: No trigger condition (default)
Returns:
Callable: The decorated function that will serve as a flow starting point.
Raises:
ValueError: If the condition format is invalid.
Example:
>>> @start() # No condition
>>> def begin_flow():
>>> pass
>>>
>>> @start("method_name") # Triggered by specific method
>>> def conditional_start():
>>> pass
"""
def decorator(func):
func.__is_start_method__ = True
if condition is not None:
@@ -56,7 +79,30 @@ def start(condition=None):
return decorator
def listen(condition):
def listen(condition: Union[str, dict, Callable]) -> Callable:
"""Marks a method to execute when specified conditions/methods complete.
Args:
condition: The condition that triggers this method. Can be:
- str: Name of the triggering method
- dict: Dictionary with 'type' and 'methods' keys for complex conditions
- Callable: A function reference
Returns:
Callable: The decorated function that will execute when conditions are met.
Raises:
ValueError: If the condition format is invalid.
Example:
>>> @listen("start_method") # Listen to single method
>>> def on_start():
>>> pass
>>>
>>> @listen(and_("method1", "method2")) # Listen with AND condition
>>> def on_both_complete():
>>> pass
"""
def decorator(func):
if isinstance(condition, str):
func.__trigger_methods__ = [condition]
@@ -80,10 +126,33 @@ def listen(condition):
return decorator
def router(condition):
def router(condition: Union[str, dict, Callable]) -> Callable:
"""Marks a method as a router to direct flow based on its return value.
A router method can return different string values that trigger different
subsequent methods, allowing for dynamic flow control.
Args:
condition: The condition that triggers this router. Can be:
- str: Name of the triggering method
- dict: Dictionary with 'type' and 'methods' keys for complex conditions
- Callable: A function reference
Returns:
Callable: The decorated function that will serve as a router.
Raises:
ValueError: If the condition format is invalid.
Example:
>>> @router("process_data")
>>> def route_result(result):
>>> if result.success:
>>> return "handle_success"
>>> return "handle_error"
"""
def decorator(func):
func.__is_router__ = True
# Handle conditions like listen/start
if isinstance(condition, str):
func.__trigger_methods__ = [condition]
func.__condition_type__ = "OR"
@@ -106,7 +175,27 @@ def router(condition):
return decorator
def or_(*conditions):
def or_(*conditions: Union[str, dict, Callable]) -> dict:
"""Combines multiple conditions with OR logic for flow control.
Args:
*conditions: Variable number of conditions. Each can be:
- str: Name of a method
- dict: Dictionary with 'type' and 'methods' keys
- Callable: A function reference
Returns:
dict: A dictionary with 'type': 'OR' and 'methods' list.
Raises:
ValueError: If any condition is invalid.
Example:
>>> @listen(or_("method1", "method2"))
>>> def on_either():
>>> # Executes when either method1 OR method2 completes
>>> pass
"""
methods = []
for condition in conditions:
if isinstance(condition, dict) and "methods" in condition:
@@ -120,7 +209,27 @@ def or_(*conditions):
return {"type": "OR", "methods": methods}
def and_(*conditions):
def and_(*conditions: Union[str, dict, Callable]) -> dict:
"""Combines multiple conditions with AND logic for flow control.
Args:
*conditions: Variable number of conditions. Each can be:
- str: Name of a method
- dict: Dictionary with 'type' and 'methods' keys
- Callable: A function reference
Returns:
dict: A dictionary with 'type': 'AND' and 'methods' list.
Raises:
ValueError: If any condition is invalid.
Example:
>>> @listen(and_("method1", "method2"))
>>> def on_both():
>>> # Executes when BOTH method1 AND method2 complete
>>> pass
"""
methods = []
for condition in conditions:
if isinstance(condition, dict) and "methods" in condition:
@@ -179,6 +288,22 @@ class Flow(Generic[T], metaclass=FlowMeta):
event_emitter = Signal("event_emitter")
def __class_getitem__(cls: Type["Flow"], item: Type[T]) -> Type["Flow"]:
"""Create a generic version of Flow with specified state type.
Args:
cls: The Flow class
item: The type parameter for the flow's state
Returns:
Type["Flow"]: A new Flow class with the specified state type
Example:
>>> class MyState(BaseModel):
>>> value: int
>>>
>>> class MyFlow(Flow[MyState]):
>>> pass
"""
class _FlowGeneric(cls): # type: ignore
_initial_state_T = item # type: ignore
@@ -186,11 +311,23 @@ class Flow(Generic[T], metaclass=FlowMeta):
return _FlowGeneric
def __init__(self) -> None:
"""Initialize a new Flow instance.
Sets up internal state tracking, method registration, and telemetry.
The flow's methods are automatically discovered and registered during initialization.
Attributes initialized:
_methods: Dictionary mapping method names to their callable objects
_state: The flow's state object of type T
_method_execution_counts: Tracks how many times each method has executed
_pending_and_listeners: Tracks methods waiting for AND conditions
_method_outputs: List of all outputs from executed methods
"""
self._methods: Dict[str, Callable] = {}
self._state: T = self._create_initial_state()
self._method_execution_counts: Dict[str, int] = {}
self._pending_and_listeners: Dict[str, Set[str]] = {}
self._method_outputs: List[Any] = [] # List to store all method outputs
self._method_outputs: List[Any] = []
self._telemetry.flow_creation_span(self.__class__.__name__)
@@ -201,6 +338,20 @@ class Flow(Generic[T], metaclass=FlowMeta):
self._methods[method_name] = getattr(self, method_name)
def _create_initial_state(self) -> T:
"""Create the initial state for the flow.
The state is created based on the following priority:
1. If initial_state is None and _initial_state_T exists (generic type), use that
2. If initial_state is None, return empty dict
3. If initial_state is a type, instantiate it
4. Otherwise, use initial_state as-is
Returns:
T: The initial state object of type T
Note:
The type T can be either a Pydantic BaseModel or a dictionary.
"""
if self.initial_state is None and hasattr(self, "_initial_state_T"):
return self._initial_state_T() # type: ignore
if self.initial_state is None:
@@ -212,11 +363,21 @@ class Flow(Generic[T], metaclass=FlowMeta):
@property
def state(self) -> T:
"""Get the current state of the flow.
Returns:
T: The current state object, either a Pydantic model or dictionary
"""
return self._state
@property
def method_outputs(self) -> List[Any]:
"""Returns the list of all outputs from executed methods."""
"""Get the list of all outputs from executed methods.
Returns:
List[Any]: A list containing the output values from all executed flow methods,
in order of execution.
"""
return self._method_outputs
def _initialize_state(self, inputs: Dict[str, Any]) -> None:
@@ -306,6 +467,23 @@ class Flow(Generic[T], metaclass=FlowMeta):
return result
async def _execute_listeners(self, trigger_method: str, result: Any) -> None:
"""Execute all listener methods triggered by a completed method.
This method handles both router and non-router listeners in a specific order:
1. First executes all triggered router methods sequentially until no more routers
are triggered
2. Then executes all regular listeners in parallel
Args:
trigger_method: The name of the method that completed execution
result: The result value from the triggering method
Note:
Router methods are executed sequentially to ensure proper flow control,
while regular listeners are executed concurrently for better performance.
This provides fine-grained control over the execution flow while
maintaining efficiency.
"""
# First, handle routers repeatedly until no router triggers anymore
while True:
routers_triggered = self._find_triggered_methods(
@@ -335,6 +513,27 @@ class Flow(Generic[T], metaclass=FlowMeta):
def _find_triggered_methods(
self, trigger_method: str, router_only: bool
) -> List[str]:
"""Find all methods that should be triggered based on completed method and type.
Provides precise control over method triggering by handling both OR and AND
conditions separately for router and non-router methods.
Args:
trigger_method: The name of the method that completed execution
router_only: If True, only find router methods; if False, only regular
listeners
Returns:
List[str]: Names of methods that should be executed next
Note:
This method implements sophisticated flow control by:
1. Filtering methods based on their router/non-router status
2. Handling OR conditions for immediate triggering
3. Managing AND conditions with state tracking for complex dependencies
This ensures predictable and consistent execution order in complex flows.
"""
triggered = []
for listener_name, (condition_type, methods) in self._listeners.items():
is_router = listener_name in self._routers
@@ -363,6 +562,27 @@ class Flow(Generic[T], metaclass=FlowMeta):
return triggered
async def _execute_single_listener(self, listener_name: str, result: Any) -> None:
"""Execute a single listener method with precise parameter handling and error tracking.
Provides fine-grained control over method execution through:
1. Automatic parameter inspection to determine if the method accepts results
2. Event emission for execution tracking
3. Comprehensive error handling
4. Recursive listener execution
Args:
listener_name: The name of the listener method to execute
result: The result from the triggering method, passed to the listener
if its signature accepts parameters
Note:
This method ensures precise execution control by:
- Inspecting method signatures to handle parameters correctly
- Emitting events for execution tracking
- Providing comprehensive error handling
- Supporting both parameterized and parameter-less methods
- Maintaining execution chain through recursive listener calls
"""
try:
method = self._methods[listener_name]
@@ -406,8 +626,32 @@ class Flow(Generic[T], metaclass=FlowMeta):
traceback.print_exc()
def plot(self, filename: str = "crewai_flow") -> None:
def plot(self, *args, **kwargs):
"""Generate an interactive visualization of the flow's execution graph.
Creates a detailed HTML visualization showing the relationships between
methods, including start points, listeners, routers, and their
connections. Includes telemetry tracking for flow analysis.
Args:
*args: Variable length argument list passed to plot_flow
**kwargs: Arbitrary keyword arguments passed to plot_flow
Note:
The visualization provides:
- Clear representation of method relationships
- Visual distinction between different method types
- Interactive exploration capabilities
- Execution path tracing
- Telemetry tracking for flow analysis
Example:
>>> flow = MyFlow()
>>> flow.plot("my_workflow") # Creates my_workflow.html
"""
from crewai.flow.flow_visualizer import plot_flow
self._telemetry.flow_plotting_span(
self.__class__.__name__, list(self._methods.keys())
)
plot_flow(self, filename)
return plot_flow(self, *args, **kwargs)

View File

@@ -0,0 +1,240 @@
"""Utility functions for Flow visualization.
This module contains utility functions specifically designed for visualizing
Flow graphs and calculating layout information. These utilities are separated
from general-purpose utilities to maintain a clean dependency structure.
"""
from typing import Any, Dict, List, Set, TYPE_CHECKING
if TYPE_CHECKING:
from crewai.flow.flow import Flow
def calculate_node_levels(flow: Flow[Any]) -> Dict[str, int]:
"""Calculate the hierarchical level of each node in the flow graph.
Uses breadth-first traversal to assign levels to nodes, starting with
start methods at level 0. Handles both OR and AND conditions for listeners,
and considers router paths when calculating levels.
Args:
flow: Flow instance containing methods, listeners, and router configurations
Returns:
dict[str, int]: Dictionary mapping method names to their hierarchical levels,
where level 0 contains start methods and each subsequent level contains
methods triggered by the previous level
Example:
>>> flow = Flow()
>>> @flow.start
... def start(): pass
>>> @flow.on("start")
... def second(): pass
>>> calculate_node_levels(flow)
{'start': 0, 'second': 1}
"""
levels: Dict[str, int] = {}
queue: List[str] = []
visited: Set[str] = set()
pending_and_listeners: Dict[str, Set[str]] = {}
# Make all start methods at level 0
for method_name, method in flow._methods.items():
if hasattr(method, "__is_start_method__"):
levels[method_name] = 0
queue.append(method_name)
# Breadth-first traversal to assign levels
while queue:
current = queue.pop(0)
current_level = levels[current]
visited.add(current)
for listener_name, (condition_type, trigger_methods) in flow._listeners.items():
if condition_type == "OR":
if current in trigger_methods:
if (
listener_name not in levels
or levels[listener_name] > current_level + 1
):
levels[listener_name] = current_level + 1
if listener_name not in visited:
queue.append(listener_name)
elif condition_type == "AND":
if listener_name not in pending_and_listeners:
pending_and_listeners[listener_name] = set()
if current in trigger_methods:
pending_and_listeners[listener_name].add(current)
if set(trigger_methods) == pending_and_listeners[listener_name]:
if (
listener_name not in levels
or levels[listener_name] > current_level + 1
):
levels[listener_name] = current_level + 1
if listener_name not in visited:
queue.append(listener_name)
# Handle router connections
if current in flow._routers:
router_method_name = current
paths = flow._router_paths.get(router_method_name, [])
for path in paths:
for listener_name, (
condition_type,
trigger_methods,
) in flow._listeners.items():
if path in trigger_methods:
if (
listener_name not in levels
or levels[listener_name] > current_level + 1
):
levels[listener_name] = current_level + 1
if listener_name not in visited:
queue.append(listener_name)
return levels
def count_outgoing_edges(flow: Flow[Any]) -> Dict[str, int]:
"""Count the number of outgoing edges for each node in the flow graph.
An outgoing edge represents a connection from a method to a listener
that it triggers. This is useful for visualization and analysis of
flow structure.
Args:
flow: Flow instance containing methods and their connections
Returns:
dict[str, int]: Dictionary mapping method names to their number
of outgoing connections
"""
counts: Dict[str, int] = {}
for method_name in flow._methods:
counts[method_name] = 0
for method_name in flow._listeners:
_, trigger_methods = flow._listeners[method_name]
for trigger in trigger_methods:
if trigger in flow._methods:
counts[trigger] += 1
return counts
def build_ancestor_dict(flow: Flow[Any]) -> Dict[str, Set[str]]:
"""Build a dictionary mapping each node to its set of ancestor nodes.
Uses depth-first search to identify all ancestors (direct and indirect
trigger methods) for each node in the flow graph. Handles both regular
listeners and router paths.
Args:
flow: Flow instance containing methods and their relationships
Returns:
dict[str, set[str]]: Dictionary mapping each method name to a set
of its ancestor method names
"""
ancestors: Dict[str, Set[str]] = {node: set() for node in flow._methods}
visited: Set[str] = set()
for node in flow._methods:
if node not in visited:
dfs_ancestors(node, ancestors, visited, flow)
return ancestors
def dfs_ancestors(node: str, ancestors: Dict[str, Set[str]],
visited: Set[str], flow: Flow[Any]) -> None:
"""Perform depth-first search to populate the ancestors dictionary.
Helper function for build_ancestor_dict that recursively traverses
the flow graph to identify ancestors of each node.
Args:
node: Current node being processed
ancestors: Dictionary mapping nodes to their ancestor sets
visited: Set of already visited nodes
flow: Flow instance containing the graph structure
"""
if node in visited:
return
visited.add(node)
# Handle regular listeners
for listener_name, (_, trigger_methods) in flow._listeners.items():
if node in trigger_methods:
ancestors[listener_name].add(node)
ancestors[listener_name].update(ancestors[node])
dfs_ancestors(listener_name, ancestors, visited, flow)
# Handle router methods separately
if node in flow._routers:
router_method_name = node
paths = flow._router_paths.get(router_method_name, [])
for path in paths:
for listener_name, (_, trigger_methods) in flow._listeners.items():
if path in trigger_methods:
# Only propagate the ancestors of the router method, not the router method itself
ancestors[listener_name].update(ancestors[node])
dfs_ancestors(listener_name, ancestors, visited, flow)
def build_parent_children_dict(flow: Flow[Any]) -> Dict[str, List[str]]:
"""Build a dictionary mapping each node to its list of child nodes.
Maps both regular trigger methods to their listeners and router
methods to their path listeners. Useful for visualization and
traversal of the flow graph structure.
Args:
flow: Flow instance containing methods and their relationships
Returns:
dict[str, list[str]]: Dictionary mapping each method name to a
sorted list of its child method names
"""
parent_children: Dict[str, List[str]] = {}
# Map listeners to their trigger methods
for listener_name, (_, trigger_methods) in flow._listeners.items():
for trigger in trigger_methods:
if trigger not in parent_children:
parent_children[trigger] = []
if listener_name not in parent_children[trigger]:
parent_children[trigger].append(listener_name)
# Map router methods to their paths and to listeners
for router_method_name, paths in flow._router_paths.items():
for path in paths:
# Map router method to listeners of each path
for listener_name, (_, trigger_methods) in flow._listeners.items():
if path in trigger_methods:
if router_method_name not in parent_children:
parent_children[router_method_name] = []
if listener_name not in parent_children[router_method_name]:
parent_children[router_method_name].append(listener_name)
return parent_children
def get_child_index(parent: str, child: str,
parent_children: Dict[str, List[str]]) -> int:
"""Get the index of a child node in its parent's sorted children list.
Args:
parent: Parent node name
child: Child node name to find index for
parent_children: Dictionary mapping parents to their children lists
Returns:
int: Zero-based index of the child in parent's sorted children list
Raises:
ValueError: If child is not found in parent's children list
"""
children = parent_children.get(parent, [])
children.sort()
return children.index(child)

View File

@@ -1,13 +1,15 @@
# flow_visualizer.py
import os
from pathlib import Path
from pyvis.network import Network
from crewai.flow.config import COLORS, NODE_STYLES
from crewai.flow.flow_visual_utils import calculate_node_levels
from crewai.flow.html_template_handler import HTMLTemplateHandler
from crewai.flow.legend_generator import generate_legend_items_html, get_legend_items
from crewai.flow.utils import calculate_node_levels
from crewai.flow.path_utils import safe_path_join, validate_file_path
from crewai.flow.visualization_utils import (
add_edges,
add_nodes_to_network,
@@ -16,12 +18,30 @@ from crewai.flow.visualization_utils import (
class FlowPlot:
"""Handles the creation and rendering of flow visualization diagrams."""
def __init__(self, flow):
"""Initialize flow plot with flow instance and styling configuration.
Args:
flow: A Flow instance with required attributes for visualization
Raises:
ValueError: If flow object is invalid or missing required attributes
"""
if not hasattr(flow, '_methods'):
raise ValueError("Invalid flow object: Missing '_methods' attribute")
if not hasattr(flow, '_start_methods'):
raise ValueError("Invalid flow object: Missing '_start_methods' attribute")
if not hasattr(flow, '_listeners'):
raise ValueError("Invalid flow object: Missing '_listeners' attribute")
self.flow = flow
self.colors = COLORS
self.node_styles = NODE_STYLES
def plot(self, filename):
"""Generate and save interactive flow visualization to HTML file."""
net = Network(
directed=True,
height="750px",
@@ -46,30 +66,29 @@ class FlowPlot:
"""
)
# Calculate levels for nodes
node_levels = calculate_node_levels(self.flow)
# Compute positions
node_positions = compute_positions(self.flow, node_levels)
# Add nodes to the network
add_nodes_to_network(net, self.flow, node_positions, self.node_styles)
# Add edges to the network
add_edges(net, self.flow, node_positions, self.colors)
network_html = net.generate_html()
final_html_content = self._generate_final_html(network_html)
# Save the final HTML content to the file
with open(f"{filename}.html", "w", encoding="utf-8") as f:
f.write(final_html_content)
print(f"Plot saved as {filename}.html")
try:
# Ensure the output path is safe
output_dir = os.getcwd()
output_path = safe_path_join(output_dir, f"{filename}.html")
with open(output_path, "w", encoding="utf-8") as f:
f.write(final_html_content)
print(f"Plot saved as {output_path}")
except (IOError, ValueError) as e:
raise IOError(f"Failed to save flow visualization: {str(e)}")
self._cleanup_pyvis_lib()
def _generate_final_html(self, network_html):
# Extract just the body content from the generated HTML
"""Generate final HTML content with network visualization and legend."""
current_dir = os.path.dirname(__file__)
template_path = os.path.join(
current_dir, "assets", "crewai_flow_visual_template.html"
@@ -79,7 +98,6 @@ class FlowPlot:
html_handler = HTMLTemplateHandler(template_path, logo_path)
network_body = html_handler.extract_body_content(network_html)
# Generate the legend items HTML
legend_items = get_legend_items(self.colors)
legend_items_html = generate_legend_items_html(legend_items)
final_html_content = html_handler.generate_final_html(
@@ -88,17 +106,17 @@ class FlowPlot:
return final_html_content
def _cleanup_pyvis_lib(self):
# Clean up the generated lib folder
"""Clean up temporary files generated by pyvis library."""
lib_folder = os.path.join(os.getcwd(), "lib")
try:
if os.path.exists(lib_folder) and os.path.isdir(lib_folder):
import shutil
shutil.rmtree(lib_folder)
except Exception as e:
print(f"Error cleaning up {lib_folder}: {e}")
def plot_flow(flow, filename="flow_plot"):
"""Create and save a visualization of the given flow."""
visualizer = FlowPlot(flow)
visualizer.plot(filename)

View File

@@ -1,28 +1,107 @@
import base64
import os
import re
from pathlib import Path
from crewai.flow.path_utils import safe_path_join, validate_file_path
class HTMLTemplateHandler:
"""Handles HTML template processing and generation for flow visualization diagrams."""
def __init__(self, template_path, logo_path):
self.template_path = template_path
self.logo_path = logo_path
"""Initialize template handler with template and logo file paths.
Args:
template_path: Path to the HTML template file
logo_path: Path to the logo SVG file
Raises:
ValueError: If template_path or logo_path is invalid or files don't exist
"""
try:
self.template_path = validate_file_path(template_path)
self.logo_path = validate_file_path(logo_path)
except (ValueError, TypeError) as e:
raise ValueError(f"Invalid file path: {str(e)}")
def read_template(self):
with open(self.template_path, "r", encoding="utf-8") as f:
return f.read()
"""Read and return the HTML template file contents.
Returns:
str: The contents of the template file
Raises:
IOError: If template file cannot be read
"""
try:
with open(self.template_path, "r", encoding="utf-8") as f:
return f.read()
except IOError as e:
raise IOError(f"Failed to read template file {self.template_path}: {str(e)}")
def encode_logo(self):
with open(self.logo_path, "rb") as logo_file:
logo_svg_data = logo_file.read()
return base64.b64encode(logo_svg_data).decode("utf-8")
"""Convert the logo SVG file to base64 encoded string.
Returns:
str: Base64 encoded logo data
Raises:
IOError: If logo file cannot be read
ValueError: If logo data cannot be encoded
"""
try:
with open(self.logo_path, "rb") as logo_file:
logo_svg_data = logo_file.read()
try:
return base64.b64encode(logo_svg_data).decode("utf-8")
except Exception as e:
raise ValueError(f"Failed to encode logo data: {str(e)}")
except IOError as e:
raise IOError(f"Failed to read logo file {self.logo_path}: {str(e)}")
def extract_body_content(self, html):
"""Extract and return content between body tags from HTML string.
Args:
html: HTML string to extract body content from
Returns:
str: Content between body tags, or empty string if not found
Raises:
ValueError: If input HTML is invalid
"""
if not html or not isinstance(html, str):
raise ValueError("Input HTML must be a non-empty string")
match = re.search("<body.*?>(.*?)</body>", html, re.DOTALL)
return match.group(1) if match else ""
def generate_legend_items_html(self, legend_items):
"""Generate HTML markup for the legend items.
Args:
legend_items: List of dictionaries containing legend item properties
Returns:
str: Generated HTML markup for legend items
Raises:
ValueError: If legend_items is invalid or missing required properties
"""
if not isinstance(legend_items, list):
raise ValueError("legend_items must be a list")
legend_items_html = ""
for item in legend_items:
if not isinstance(item, dict):
raise ValueError("Each legend item must be a dictionary")
if "color" not in item:
raise ValueError("Each legend item must have a 'color' property")
if "label" not in item:
raise ValueError("Each legend item must have a 'label' property")
if "border" in item:
legend_items_html += f"""
<div class="legend-item">
@@ -48,18 +127,42 @@ class HTMLTemplateHandler:
return legend_items_html
def generate_final_html(self, network_body, legend_items_html, title="Flow Plot"):
html_template = self.read_template()
logo_svg_base64 = self.encode_logo()
"""Combine all components into final HTML document with network visualization.
Args:
network_body: HTML string containing network visualization
legend_items_html: HTML string containing legend items markup
title: Title for the visualization page (default: "Flow Plot")
Returns:
str: Complete HTML document with all components integrated
Raises:
ValueError: If any input parameters are invalid
IOError: If template or logo files cannot be read
"""
if not isinstance(network_body, str):
raise ValueError("network_body must be a string")
if not isinstance(legend_items_html, str):
raise ValueError("legend_items_html must be a string")
if not isinstance(title, str):
raise ValueError("title must be a string")
try:
html_template = self.read_template()
logo_svg_base64 = self.encode_logo()
final_html_content = html_template.replace("{{ title }}", title)
final_html_content = final_html_content.replace(
"{{ network_content }}", network_body
)
final_html_content = final_html_content.replace(
"{{ logo_svg_base64 }}", logo_svg_base64
)
final_html_content = final_html_content.replace(
"<!-- LEGEND_ITEMS_PLACEHOLDER -->", legend_items_html
)
final_html_content = html_template.replace("{{ title }}", title)
final_html_content = final_html_content.replace(
"{{ network_content }}", network_body
)
final_html_content = final_html_content.replace(
"{{ logo_svg_base64 }}", logo_svg_base64
)
final_html_content = final_html_content.replace(
"<!-- LEGEND_ITEMS_PLACEHOLDER -->", legend_items_html
)
return final_html_content
return final_html_content
except Exception as e:
raise ValueError(f"Failed to generate final HTML: {str(e)}")

View File

@@ -1,3 +1,4 @@
def get_legend_items(colors):
return [
{"label": "Start Method", "color": colors["start"]},

View File

@@ -0,0 +1,123 @@
"""Utilities for safe path handling in flow visualization.
This module provides a comprehensive set of utilities for secure path handling,
including path joining, validation, and normalization. It helps prevent common
security issues like directory traversal attacks while providing a consistent
interface for path operations.
"""
import os
from pathlib import Path
from typing import List, Optional, Union
def safe_path_join(base_dir: Union[str, Path], filename: str) -> str:
"""Safely join base directory with filename, preventing directory traversal.
Args:
base_dir: Base directory path
filename: Filename or path to join with base_dir
Returns:
str: Safely joined absolute path
Raises:
ValueError: If resulting path would escape base_dir or contains dangerous patterns
TypeError: If inputs are not strings or Path objects
OSError: If path resolution fails
"""
if not isinstance(base_dir, (str, Path)):
raise TypeError("base_dir must be a string or Path object")
if not isinstance(filename, str):
raise TypeError("filename must be a string")
# Check for dangerous patterns
dangerous_patterns = ['..', '~', '*', '?', '|', '>', '<', '$', '&', '`']
if any(pattern in filename for pattern in dangerous_patterns):
raise ValueError(f"Invalid filename: Contains dangerous pattern")
try:
base_path = Path(base_dir).resolve(strict=True)
full_path = Path(base_path, filename).resolve(strict=True)
if not str(full_path).startswith(str(base_path)):
raise ValueError(
f"Invalid path: {filename} would escape base directory {base_dir}"
)
return str(full_path)
except OSError as e:
raise OSError(f"Failed to resolve path: {str(e)}")
except Exception as e:
raise ValueError(f"Failed to process paths: {str(e)}")
def normalize_path(path: Union[str, Path]) -> str:
"""Normalize a path by resolving symlinks and removing redundant separators.
Args:
path: Path to normalize
Returns:
str: Normalized absolute path
Raises:
TypeError: If path is not a string or Path object
OSError: If path resolution fails
"""
if not isinstance(path, (str, Path)):
raise TypeError("path must be a string or Path object")
try:
return str(Path(path).resolve(strict=True))
except OSError as e:
raise OSError(f"Failed to normalize path: {str(e)}")
def validate_path_components(components: List[str]) -> None:
"""Validate path components for potentially dangerous patterns.
Args:
components: List of path components to validate
Raises:
TypeError: If components is not a list or contains non-string items
ValueError: If any component contains dangerous patterns
"""
if not isinstance(components, list):
raise TypeError("components must be a list")
dangerous_patterns = ['..', '~', '*', '?', '|', '>', '<', '$', '&', '`']
for component in components:
if not isinstance(component, str):
raise TypeError(f"Path component '{component}' must be a string")
if any(pattern in component for pattern in dangerous_patterns):
raise ValueError(f"Invalid path component '{component}': Contains dangerous pattern")
def validate_file_path(path: Union[str, Path], must_exist: bool = True) -> str:
"""Validate a file path for security and existence.
Args:
path: File path to validate
must_exist: Whether the file must exist (default: True)
Returns:
str: Validated absolute path
Raises:
ValueError: If path is invalid or file doesn't exist when required
TypeError: If path is not a string or Path object
"""
if not isinstance(path, (str, Path)):
raise TypeError("path must be a string or Path object")
try:
resolved_path = Path(path).resolve()
if must_exist and not resolved_path.is_file():
raise ValueError(f"File not found: {path}")
return str(resolved_path)
except Exception as e:
raise ValueError(f"Invalid file path {path}: {str(e)}")

View File

@@ -1,220 +1,35 @@
import ast
import inspect
import textwrap
"""General utility functions for flow execution.
This module has been deprecated. All functionality has been moved to:
- core_flow_utils.py: Core flow execution utilities
- flow_visual_utils.py: Visualization-related utilities
def get_possible_return_constants(function):
try:
source = inspect.getsource(function)
except OSError:
# Can't get source code
return None
except Exception as e:
print(f"Error retrieving source code for function {function.__name__}: {e}")
return None
This module is kept as a temporary redirect to maintain backwards compatibility.
New code should import from the appropriate new modules directly.
"""
try:
# Remove leading indentation
source = textwrap.dedent(source)
# Parse the source code into an AST
code_ast = ast.parse(source)
except IndentationError as e:
print(f"IndentationError while parsing source code of {function.__name__}: {e}")
print(f"Source code:\n{source}")
return None
except SyntaxError as e:
print(f"SyntaxError while parsing source code of {function.__name__}: {e}")
print(f"Source code:\n{source}")
return None
except Exception as e:
print(f"Unexpected error while parsing source code of {function.__name__}: {e}")
print(f"Source code:\n{source}")
return None
from typing import Any, Dict, List, Optional, Set
return_values = set()
dict_definitions = {}
from .core_flow_utils import get_possible_return_constants, is_ancestor
from .flow_visual_utils import (
build_ancestor_dict,
build_parent_children_dict,
calculate_node_levels,
count_outgoing_edges,
dfs_ancestors,
get_child_index,
)
class DictionaryAssignmentVisitor(ast.NodeVisitor):
def visit_Assign(self, node):
# Check if this assignment is assigning a dictionary literal to a variable
if isinstance(node.value, ast.Dict) and len(node.targets) == 1:
target = node.targets[0]
if isinstance(target, ast.Name):
var_name = target.id
dict_values = []
# Extract string values from the dictionary
for val in node.value.values:
if isinstance(val, ast.Constant) and isinstance(val.value, str):
dict_values.append(val.value)
# If non-string, skip or just ignore
if dict_values:
dict_definitions[var_name] = dict_values
self.generic_visit(node)
# Re-export all functions for backwards compatibility
__all__ = [
'get_possible_return_constants',
'calculate_node_levels',
'count_outgoing_edges',
'build_ancestor_dict',
'dfs_ancestors',
'is_ancestor',
'build_parent_children_dict',
'get_child_index',
]
class ReturnVisitor(ast.NodeVisitor):
def visit_Return(self, node):
# Direct string return
if isinstance(node.value, ast.Constant) and isinstance(
node.value.value, str
):
return_values.add(node.value.value)
# Dictionary-based return, like return paths[result]
elif isinstance(node.value, ast.Subscript):
# Check if we're subscripting a known dictionary variable
if isinstance(node.value.value, ast.Name):
var_name = node.value.value.id
if var_name in dict_definitions:
# Add all possible dictionary values
for v in dict_definitions[var_name]:
return_values.add(v)
self.generic_visit(node)
# First pass: identify dictionary assignments
DictionaryAssignmentVisitor().visit(code_ast)
# Second pass: identify returns
ReturnVisitor().visit(code_ast)
return list(return_values) if return_values else None
def calculate_node_levels(flow):
levels = {}
queue = []
visited = set()
pending_and_listeners = {}
# Make all start methods at level 0
for method_name, method in flow._methods.items():
if hasattr(method, "__is_start_method__"):
levels[method_name] = 0
queue.append(method_name)
# Breadth-first traversal to assign levels
while queue:
current = queue.pop(0)
current_level = levels[current]
visited.add(current)
for listener_name, (condition_type, trigger_methods) in flow._listeners.items():
if condition_type == "OR":
if current in trigger_methods:
if (
listener_name not in levels
or levels[listener_name] > current_level + 1
):
levels[listener_name] = current_level + 1
if listener_name not in visited:
queue.append(listener_name)
elif condition_type == "AND":
if listener_name not in pending_and_listeners:
pending_and_listeners[listener_name] = set()
if current in trigger_methods:
pending_and_listeners[listener_name].add(current)
if set(trigger_methods) == pending_and_listeners[listener_name]:
if (
listener_name not in levels
or levels[listener_name] > current_level + 1
):
levels[listener_name] = current_level + 1
if listener_name not in visited:
queue.append(listener_name)
# Handle router connections
if current in flow._routers:
router_method_name = current
paths = flow._router_paths.get(router_method_name, [])
for path in paths:
for listener_name, (
condition_type,
trigger_methods,
) in flow._listeners.items():
if path in trigger_methods:
if (
listener_name not in levels
or levels[listener_name] > current_level + 1
):
levels[listener_name] = current_level + 1
if listener_name not in visited:
queue.append(listener_name)
return levels
def count_outgoing_edges(flow):
counts = {}
for method_name in flow._methods:
counts[method_name] = 0
for method_name in flow._listeners:
_, trigger_methods = flow._listeners[method_name]
for trigger in trigger_methods:
if trigger in flow._methods:
counts[trigger] += 1
return counts
def build_ancestor_dict(flow):
ancestors = {node: set() for node in flow._methods}
visited = set()
for node in flow._methods:
if node not in visited:
dfs_ancestors(node, ancestors, visited, flow)
return ancestors
def dfs_ancestors(node, ancestors, visited, flow):
if node in visited:
return
visited.add(node)
# Handle regular listeners
for listener_name, (_, trigger_methods) in flow._listeners.items():
if node in trigger_methods:
ancestors[listener_name].add(node)
ancestors[listener_name].update(ancestors[node])
dfs_ancestors(listener_name, ancestors, visited, flow)
# Handle router methods separately
if node in flow._routers:
router_method_name = node
paths = flow._router_paths.get(router_method_name, [])
for path in paths:
for listener_name, (_, trigger_methods) in flow._listeners.items():
if path in trigger_methods:
# Only propagate the ancestors of the router method, not the router method itself
ancestors[listener_name].update(ancestors[node])
dfs_ancestors(listener_name, ancestors, visited, flow)
def is_ancestor(node, ancestor_candidate, ancestors):
return ancestor_candidate in ancestors.get(node, set())
def build_parent_children_dict(flow):
parent_children = {}
# Map listeners to their trigger methods
for listener_name, (_, trigger_methods) in flow._listeners.items():
for trigger in trigger_methods:
if trigger not in parent_children:
parent_children[trigger] = []
if listener_name not in parent_children[trigger]:
parent_children[trigger].append(listener_name)
# Map router methods to their paths and to listeners
for router_method_name, paths in flow._router_paths.items():
for path in paths:
# Map router method to listeners of each path
for listener_name, (_, trigger_methods) in flow._listeners.items():
if path in trigger_methods:
if router_method_name not in parent_children:
parent_children[router_method_name] = []
if listener_name not in parent_children[router_method_name]:
parent_children[router_method_name].append(listener_name)
return parent_children
def get_child_index(parent, child, parent_children):
children = parent_children.get(parent, [])
children.sort()
return children.index(child)
# Function implementations have been moved to core_flow_utils.py and flow_visual_utils.py

View File

@@ -1,25 +1,63 @@
import ast
import inspect
import os
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, cast
from .utils import (
from crewai.flow.flow import Flow
from pyvis.network import Network
from .core_flow_utils import is_ancestor
from .flow_visual_utils import (
build_ancestor_dict,
build_parent_children_dict,
get_child_index,
is_ancestor,
)
from .path_utils import safe_path_join, validate_file_path
def method_calls_crew(method):
"""Check if the method calls `.crew()`."""
def method_calls_crew(method: Optional[Callable[..., Any]]) -> bool:
"""Check if the method contains a .crew() call in its implementation.
Analyzes the method's source code using AST to detect if it makes any
calls to the .crew() method, which indicates crew involvement in the
flow execution.
Args:
method: The method to analyze for crew calls, can be None
Returns:
bool: True if the method contains a .crew() call, False otherwise
Raises:
TypeError: If input is not None and not a callable method
ValueError: If method source code cannot be parsed
RuntimeError: If unexpected error occurs during parsing
"""
if method is None:
return False
if not callable(method):
raise TypeError("Input must be a callable method")
try:
source = inspect.getsource(method)
source = inspect.cleandoc(source)
tree = ast.parse(source)
except (TypeError, ValueError, OSError) as e:
raise ValueError(f"Could not parse method {getattr(method, '__name__', str(method))}: {e}")
except Exception as e:
print(f"Could not parse method {method.__name__}: {e}")
return False
raise RuntimeError(f"Unexpected error parsing method: {e}")
class CrewCallVisitor(ast.NodeVisitor):
"""AST visitor to detect .crew() method calls in source code.
A specialized AST visitor that analyzes Python source code to precisely
identify calls to the .crew() method, enabling accurate detection of
crew involvement in flow methods.
Attributes:
found (bool): Indicates whether a .crew() call was found
"""
def __init__(self):
self.found = False
@@ -34,8 +72,64 @@ def method_calls_crew(method):
return visitor.found
def add_nodes_to_network(net, flow, node_positions, node_styles):
def human_friendly_label(method_name):
def add_nodes_to_network(net: Network, flow: Flow[Any],
node_positions: Dict[str, Tuple[float, float]],
node_styles: Dict[str, dict],
output_dir: Optional[str] = None) -> None:
"""Add nodes to the network visualization with precise styling and positioning.
Creates and styles nodes in the visualization network based on their type
(start, router, crew, or regular method) with fine-grained control over
appearance and positioning.
Args:
net: The network visualization object to add nodes to
flow: Flow object containing method definitions and relationships
node_positions: Dictionary mapping method names to (x,y) coordinates
node_styles: Dictionary mapping node types to their visual styles
output_dir: Optional directory path for saving visualization assets
Returns:
None
Raises:
ValueError: If flow object is invalid or required styles are missing
TypeError: If input arguments have incorrect types
OSError: If output directory operations fail
Note:
Node styles are applied with precise control over shape, font, color,
and positioning to ensure accurate visual representation of the flow.
If output_dir is provided, it will be validated and created if needed.
"""
if not hasattr(flow, '_methods'):
raise ValueError("Invalid flow object: missing '_methods' attribute")
if not isinstance(node_positions, dict):
raise TypeError("node_positions must be a dictionary")
if not isinstance(node_styles, dict):
raise TypeError("node_styles must be a dictionary")
required_styles = {'start', 'router', 'crew', 'method'}
missing_styles = required_styles - set(node_styles.keys())
if missing_styles:
raise ValueError(f"Missing required node styles: {missing_styles}")
# Validate and create output directory if specified
if output_dir:
try:
output_dir = validate_file_path(output_dir, must_exist=False)
os.makedirs(output_dir, exist_ok=True)
except (ValueError, OSError) as e:
raise OSError(f"Failed to create or validate output directory: {e}")
def human_friendly_label(method_name: str) -> str:
"""Convert method name to human-readable format.
Args:
method_name: Original method name with underscores
Returns:
str: Formatted method name with spaces and title case
"""
return method_name.replace("_", " ").title()
for method_name, (x, y) in node_positions.items():
@@ -52,6 +146,15 @@ def add_nodes_to_network(net, flow, node_positions, node_styles):
node_style = node_style.copy()
label = human_friendly_label(method_name)
# Handle file-based assets if output directory is provided
if output_dir and node_style.get("image"):
try:
image_path = node_style["image"]
safe_image_path = safe_path_join(output_dir, Path(image_path).name)
node_style["image"] = str(safe_image_path)
except (ValueError, OSError) as e:
raise OSError(f"Failed to process node image path: {e}")
node_style.update(
{
"label": label,
@@ -73,9 +176,41 @@ def add_nodes_to_network(net, flow, node_positions, node_styles):
)
def compute_positions(flow, node_levels, y_spacing=150, x_spacing=150):
level_nodes = {}
node_positions = {}
def compute_positions(flow: Flow[Any], node_levels: Dict[str, int],
y_spacing: float = 150, x_spacing: float = 150) -> Dict[str, Tuple[float, float]]:
"""Calculate precise x,y coordinates for each node in the flow diagram.
Computes optimal node positions with fine-grained control over spacing
and alignment, ensuring clear visualization of flow hierarchy and
relationships.
Args:
flow: Flow object containing method definitions
node_levels: Dictionary mapping method names to their hierarchy levels
y_spacing: Vertical spacing between hierarchy levels (default: 150)
x_spacing: Horizontal spacing between nodes at same level (default: 150)
Returns:
dict[str, tuple[float, float]]: Dictionary mapping method names to
their calculated (x,y) coordinates in the visualization
Note:
Positions are calculated to maintain clear hierarchical structure while
ensuring optimal spacing and readability of the flow diagram.
"""
if not hasattr(flow, '_methods'):
raise ValueError("Invalid flow object: missing '_methods' attribute")
if not isinstance(node_levels, dict):
raise TypeError("node_levels must be a dictionary")
if not isinstance(y_spacing, (int, float)) or y_spacing <= 0:
raise ValueError("y_spacing must be a positive number")
if not isinstance(x_spacing, (int, float)) or x_spacing <= 0:
raise ValueError("x_spacing must be a positive number")
if not node_levels:
raise ValueError("node_levels dictionary cannot be empty")
level_nodes: Dict[int, List[str]] = {}
node_positions: Dict[str, Tuple[float, float]] = {}
for method_name, level in node_levels.items():
level_nodes.setdefault(level, []).append(method_name)
@@ -90,7 +225,34 @@ def compute_positions(flow, node_levels, y_spacing=150, x_spacing=150):
return node_positions
def add_edges(net, flow, node_positions, colors):
def add_edges(net: Network, flow: Flow[Any],
node_positions: Dict[str, Tuple[float, float]],
colors: Dict[str, str],
asset_dir: Optional[str] = None) -> None:
if not hasattr(flow, '_methods'):
raise ValueError("Invalid flow object: missing '_methods' attribute")
if not hasattr(flow, '_listeners'):
raise ValueError("Invalid flow object: missing '_listeners' attribute")
if not hasattr(flow, '_router_paths'):
raise ValueError("Invalid flow object: missing '_router_paths' attribute")
if not isinstance(node_positions, dict):
raise TypeError("node_positions must be a dictionary")
if not isinstance(colors, dict):
raise TypeError("colors must be a dictionary")
required_colors = {'edge', 'router_edge'}
missing_colors = required_colors - set(colors.keys())
if missing_colors:
raise ValueError(f"Missing required edge colors: {missing_colors}")
# Validate asset directory if provided
if asset_dir:
try:
asset_dir = validate_file_path(asset_dir, must_exist=False)
os.makedirs(asset_dir, exist_ok=True)
except (ValueError, OSError) as e:
raise OSError(f"Failed to create or validate asset directory: {e}")
ancestors = build_ancestor_dict(flow)
parent_children = build_parent_children_dict(flow)
@@ -119,24 +281,24 @@ def add_edges(net, flow, node_positions, colors):
dx = target_pos[0] - source_pos[0]
smooth_type = "curvedCCW" if dx <= 0 else "curvedCW"
index = get_child_index(trigger, method_name, parent_children)
edge_smooth = {
edge_config = {
"type": smooth_type,
"roundness": 0.2 + (0.1 * index),
}
else:
edge_smooth = {"type": "cubicBezier"}
edge_config = {"type": "cubicBezier"}
else:
edge_smooth = False
edge_config = {"type": "straight"}
edge_style = {
edge_props: Dict[str, Any] = {
"color": edge_color,
"width": 2,
"arrows": "to",
"dashes": True if is_router_edge or is_and_condition else False,
"smooth": edge_smooth,
"smooth": edge_config,
}
net.add_edge(trigger, method_name, **edge_style)
net.add_edge(trigger, method_name, **edge_props)
else:
# Nodes not found in node_positions. Check if it's a known router outcome and a known method.
is_router_edge = any(
@@ -182,23 +344,23 @@ def add_edges(net, flow, node_positions, colors):
index = get_child_index(
router_method_name, listener_name, parent_children
)
edge_smooth = {
edge_config = {
"type": smooth_type,
"roundness": 0.2 + (0.1 * index),
}
else:
edge_smooth = {"type": "cubicBezier"}
edge_config = {"type": "cubicBezier"}
else:
edge_smooth = False
edge_config = {"type": "straight"}
edge_style = {
router_edge_props: Dict[str, Any] = {
"color": colors["router_edge"],
"width": 2,
"arrows": "to",
"dashes": True,
"smooth": edge_smooth,
"smooth": edge_config,
}
net.add_edge(router_method_name, listener_name, **edge_style)
net.add_edge(router_method_name, listener_name, **router_edge_props)
else:
# Same check here: known router edge and known method?
method_known = listener_name in flow._methods

68
uv.lock generated
View File

@@ -1,10 +1,18 @@
version = 1
requires-python = ">=3.10, <3.13"
resolution-markers = [
"python_full_version < '3.11'",
"python_full_version == '3.11.*'",
"python_full_version >= '3.12' and python_full_version < '3.12.4'",
"python_full_version >= '3.12.4'",
"python_full_version < '3.11' and sys_platform == 'darwin'",
"python_full_version < '3.11' and platform_machine == 'aarch64' and sys_platform == 'linux'",
"(python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux')",
"python_full_version == '3.11.*' and sys_platform == 'darwin'",
"python_full_version == '3.11.*' and platform_machine == 'aarch64' and sys_platform == 'linux'",
"(python_full_version == '3.11.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.11.*' and sys_platform != 'darwin' and sys_platform != 'linux')",
"python_full_version >= '3.12' and python_full_version < '3.12.4' and sys_platform == 'darwin'",
"python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_machine == 'aarch64' and sys_platform == 'linux'",
"(python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.12' and python_full_version < '3.12.4' and sys_platform != 'darwin' and sys_platform != 'linux')",
"python_full_version >= '3.12.4' and sys_platform == 'darwin'",
"python_full_version >= '3.12.4' and platform_machine == 'aarch64' and sys_platform == 'linux'",
"(python_full_version >= '3.12.4' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.12.4' and sys_platform != 'darwin' and sys_platform != 'linux')",
]
[[package]]
@@ -300,7 +308,7 @@ name = "build"
version = "1.2.2.post1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "colorama", marker = "os_name == 'nt'" },
{ name = "colorama", marker = "(os_name == 'nt' and platform_machine != 'aarch64' and sys_platform == 'linux') or (os_name == 'nt' and sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "importlib-metadata", marker = "python_full_version < '3.10.2'" },
{ name = "packaging" },
{ name = "pyproject-hooks" },
@@ -535,7 +543,7 @@ name = "click"
version = "8.1.7"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "colorama", marker = "platform_system == 'Windows'" },
{ name = "colorama", marker = "sys_platform == 'win32'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/96/d3/f04c7bfcf5c1862a2a5b845c6b2b360488cf47af55dfa79c98f6a6bf98b5/click-8.1.7.tar.gz", hash = "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de", size = 336121 }
wheels = [
@@ -642,7 +650,6 @@ tools = [
[package.dev-dependencies]
dev = [
{ name = "cairosvg" },
{ name = "crewai-tools" },
{ name = "mkdocs" },
{ name = "mkdocs-material" },
{ name = "mkdocs-material-extensions" },
@@ -696,7 +703,6 @@ requires-dist = [
[package.metadata.requires-dev]
dev = [
{ name = "cairosvg", specifier = ">=2.7.1" },
{ name = "crewai-tools", specifier = ">=0.17.0" },
{ name = "mkdocs", specifier = ">=1.4.3" },
{ name = "mkdocs-material", specifier = ">=9.5.7" },
{ name = "mkdocs-material-extensions", specifier = ">=1.3.1" },
@@ -2462,7 +2468,7 @@ version = "1.6.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "click" },
{ name = "colorama", marker = "platform_system == 'Windows'" },
{ name = "colorama", marker = "sys_platform == 'win32'" },
{ name = "ghp-import" },
{ name = "jinja2" },
{ name = "markdown" },
@@ -2643,7 +2649,7 @@ version = "2.10.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "pygments" },
{ name = "pywin32", marker = "platform_system == 'Windows'" },
{ name = "pywin32", marker = "sys_platform == 'win32'" },
{ name = "tqdm" },
]
sdist = { url = "https://files.pythonhosted.org/packages/3a/93/80ac75c20ce54c785648b4ed363c88f148bf22637e10c9863db4fbe73e74/mpire-2.10.2.tar.gz", hash = "sha256:f66a321e93fadff34585a4bfa05e95bd946cf714b442f51c529038eb45773d97", size = 271270 }
@@ -2890,7 +2896,7 @@ name = "nvidia-cudnn-cu12"
version = "9.1.0.70"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
{ name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/9f/fd/713452cd72343f682b1c7b9321e23829f00b842ceaedcda96e742ea0b0b3/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl", hash = "sha256:165764f44ef8c61fcdfdfdbe769d687e06374059fbb388b6c89ecb0e28793a6f", size = 664752741 },
@@ -2917,9 +2923,9 @@ name = "nvidia-cusolver-cu12"
version = "11.4.5.107"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
{ name = "nvidia-cusparse-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
{ name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
{ name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "nvidia-cusparse-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/bc/1d/8de1e5c67099015c834315e333911273a8c6aaba78923dd1d1e25fc5f217/nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl", hash = "sha256:8a7ec542f0412294b15072fa7dab71d31334014a69f953004ea7a118206fe0dd", size = 124161928 },
@@ -2930,7 +2936,7 @@ name = "nvidia-cusparse-cu12"
version = "12.1.0.106"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
{ name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/65/5b/cfaeebf25cd9fdec14338ccb16f6b2c4c7fa9163aefcf057d86b9cc248bb/nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:f3b50f42cf363f86ab21f720998517a659a48131e8d538dc02f8768237bd884c", size = 195958278 },
@@ -3480,7 +3486,7 @@ name = "portalocker"
version = "2.10.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "pywin32", marker = "platform_system == 'Windows'" },
{ name = "pywin32", marker = "sys_platform == 'win32'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/ed/d3/c6c64067759e87af98cc668c1cc75171347d0f1577fab7ca3749134e3cd4/portalocker-2.10.1.tar.gz", hash = "sha256:ef1bf844e878ab08aee7e40184156e1151f228f103aa5c6bd0724cc330960f8f", size = 40891 }
wheels = [
@@ -5022,19 +5028,19 @@ dependencies = [
{ name = "fsspec" },
{ name = "jinja2" },
{ name = "networkx" },
{ name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
{ name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
{ name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
{ name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
{ name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
{ name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
{ name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
{ name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
{ name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
{ name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
{ name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
{ name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "sympy" },
{ name = "triton", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
{ name = "triton", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "typing-extensions" },
]
wheels = [
@@ -5081,7 +5087,7 @@ name = "tqdm"
version = "4.66.5"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "colorama", marker = "platform_system == 'Windows'" },
{ name = "colorama", marker = "sys_platform == 'win32'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/58/83/6ba9844a41128c62e810fddddd72473201f3eacde02046066142a2d96cc5/tqdm-4.66.5.tar.gz", hash = "sha256:e1020aef2e5096702d8a025ac7d16b1577279c9d63f8375b63083e9a5f0fcbad", size = 169504 }
wheels = [
@@ -5124,7 +5130,7 @@ version = "0.27.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "attrs" },
{ name = "cffi", marker = "implementation_name != 'pypy' and os_name == 'nt'" },
{ name = "cffi", marker = "(implementation_name != 'pypy' and os_name == 'nt' and platform_machine != 'aarch64' and sys_platform == 'linux') or (implementation_name != 'pypy' and os_name == 'nt' and sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "exceptiongroup", marker = "python_full_version < '3.11'" },
{ name = "idna" },
{ name = "outcome" },
@@ -5155,7 +5161,7 @@ name = "triton"
version = "3.0.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "filelock", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
{ name = "filelock", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/45/27/14cc3101409b9b4b9241d2ba7deaa93535a217a211c86c4cc7151fb12181/triton-3.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e1efef76935b2febc365bfadf74bcb65a6f959a9872e5bddf44cc9e0adce1e1a", size = 209376304 },