From b22568aa6d22315cea1ec3b0c8809f02254ce2c4 Mon Sep 17 00:00:00 2001 From: Brandon Hancock Date: Tue, 1 Oct 2024 09:56:14 -0400 Subject: [PATCH] Refactor to make crews easier to understand --- src/crewai/flow/config.py | 46 +++ src/crewai/flow/flow_visualizer.py | 456 ++--------------------- src/crewai/flow/html_template_handler.py | 66 ++++ src/crewai/flow/legend_generator.py | 46 +++ src/crewai/flow/utils.py | 143 +++++++ src/crewai/flow/visualization_utils.py | 132 +++++++ 6 files changed, 467 insertions(+), 422 deletions(-) create mode 100644 src/crewai/flow/config.py create mode 100644 src/crewai/flow/html_template_handler.py create mode 100644 src/crewai/flow/legend_generator.py create mode 100644 src/crewai/flow/utils.py create mode 100644 src/crewai/flow/visualization_utils.py diff --git a/src/crewai/flow/config.py b/src/crewai/flow/config.py new file mode 100644 index 000000000..ddaddc7a8 --- /dev/null +++ b/src/crewai/flow/config.py @@ -0,0 +1,46 @@ +DARK_GRAY = "#333333" +CREWAI_ORANGE = "#FF5A50" +GRAY = "#666666" +WHITE = "#FFFFFF" + +COLORS = { + "bg": WHITE, + "start": CREWAI_ORANGE, + "method": DARK_GRAY, + "router": DARK_GRAY, + "router_border": CREWAI_ORANGE, + "edge": GRAY, + "router_edge": CREWAI_ORANGE, + "text": WHITE, +} + +NODE_STYLES = { + "start": { + "color": COLORS["start"], + "shape": "box", + "font": {"color": COLORS["text"]}, + "margin": {"top": 10, "bottom": 8, "left": 10, "right": 10}, + }, + "method": { + "color": COLORS["method"], + "shape": "box", + "font": {"color": COLORS["text"]}, + "margin": {"top": 10, "bottom": 8, "left": 10, "right": 10}, + }, + "router": { + "color": { + "background": COLORS["router"], + "border": COLORS["router_border"], + "highlight": { + "border": COLORS["router_border"], + "background": COLORS["router"], + }, + }, + "shape": "box", + "font": {"color": COLORS["text"]}, + "borderWidth": 3, + "borderWidthSelected": 4, + "shapeProperties": {"borderDashes": [5, 5]}, + "margin": {"top": 10, "bottom": 8, "left": 10, "right": 10}, + }, +} diff --git a/src/crewai/flow/flow_visualizer.py b/src/crewai/flow/flow_visualizer.py index 8b00d8822..468b390d1 100644 --- a/src/crewai/flow/flow_visualizer.py +++ b/src/crewai/flow/flow_visualizer.py @@ -1,62 +1,26 @@ # flow_visualizer.py -import base64 import os -import re from pyvis.network import Network -DARK_GRAY = "#333333" -CREWAI_ORANGE = "#FF5A50" -GRAY = "#666666" -WHITE = "#FFFFFF" +from crewai.flow.config import COLORS, NODE_STYLES +from crewai.flow.html_template_handler import HTMLTemplateHandler +from crewai.flow.legend_generator import generate_legend_items_html, get_legend_items +from crewai.flow.utils import calculate_node_levels +from crewai.flow.visualization_utils import ( + add_edges, + add_nodes_to_network, + compute_positions, +) class FlowVisualizer: def __init__(self, flow): self.flow = flow - self.colors = { - "bg": WHITE, - "start": CREWAI_ORANGE, - "method": DARK_GRAY, - "router": DARK_GRAY, - "router_border": CREWAI_ORANGE, - "edge": GRAY, - "router_edge": CREWAI_ORANGE, - "text": WHITE, - } - self.node_styles = { - "start": { - "color": self.colors["start"], - "shape": "box", - "font": {"color": self.colors["text"]}, - "margin": {"top": 10, "bottom": 8, "left": 10, "right": 10}, - }, - "method": { - "color": self.colors["method"], - "shape": "box", - "font": {"color": self.colors["text"]}, - "margin": {"top": 10, "bottom": 8, "left": 10, "right": 10}, - }, - "router": { - "color": { - "background": self.colors["router"], - "border": self.colors["router_border"], - "highlight": { - "border": self.colors["router_border"], - "background": self.colors["router"], - }, - }, - "shape": "box", - "font": {"color": self.colors["text"]}, - "borderWidth": 3, - "borderWidthSelected": 4, - "shapeProperties": {"borderDashes": [5, 5]}, - "margin": {"top": 10, "bottom": 8, "left": 10, "right": 10}, - }, - } + self.colors = COLORS + self.node_styles = NODE_STYLES - # TODO: DROP LIB FOLDER POST GENERATION def visualize(self, filename): net = Network( directed=True, @@ -67,172 +31,16 @@ class FlowVisualizer: ) # Calculate levels for nodes - node_levels = self._calculate_node_levels() - - # Assign positions to nodes based on levels - y_spacing = 150 - x_spacing = 150 - level_nodes = {} - - # Store node positions for edge calculations - node_positions = {} - - for method_name, level in node_levels.items(): - level_nodes.setdefault(level, []).append(method_name) + node_levels = calculate_node_levels(self.flow) # Compute positions - for level, nodes in level_nodes.items(): - x_offset = -(len(nodes) - 1) * x_spacing / 2 # Center nodes horizontally - for i, method_name in enumerate(nodes): - x = x_offset + i * x_spacing - y = level * y_spacing - node_positions[method_name] = (x, y) + node_positions = compute_positions(self.flow, node_levels) - method = self.flow._methods.get(method_name) - if hasattr(method, "__is_start_method__"): - node_style = self.node_styles["start"] - elif hasattr(method, "__is_router__"): - node_style = self.node_styles["router"] - else: - node_style = self.node_styles["method"] + # Add nodes to the network + add_nodes_to_network(net, self.flow, node_positions, self.node_styles) - net.add_node( - method_name, - label=method_name, - x=x, - y=y, - fixed=True, - physics=False, - **node_style, - ) - - ancestors = self._build_ancestor_dict() - parent_children = self._build_parent_children_dict() - - # Add edges - for method_name in self.flow._listeners: - condition_type, trigger_methods = self.flow._listeners[method_name] - is_and_condition = condition_type == "AND" - - for trigger in trigger_methods: - if ( - trigger in self.flow._methods - or trigger in self.flow._routers.values() - ): - is_router_edge = any( - trigger in paths for paths in self.flow._router_paths.values() - ) - edge_color = ( - self.colors["router_edge"] - if is_router_edge - else self.colors["edge"] - ) - - # Determine if this edge forms a cycle - is_cycle_edge = self._is_ancestor(trigger, method_name, ancestors) - - # Determine if parent has multiple children - parent_has_multiple_children = ( - len(parent_children.get(trigger, [])) > 1 - ) - - # Edge curvature logic - needs_curvature = is_cycle_edge or parent_has_multiple_children - - if needs_curvature: - # Get node positions - source_pos = node_positions.get(trigger) - target_pos = node_positions.get(method_name) - - if source_pos and target_pos: - dx = target_pos[0] - source_pos[0] - - if dx <= 0: - # Child is left or directly below - smooth_type = "curvedCCW" # Curve left and down - else: - # Child is to the right - smooth_type = "curvedCW" # Curve right and down - - index = self._get_child_index( - trigger, method_name, parent_children - ) - edge_smooth = { - "type": smooth_type, - "roundness": 0.2 + (0.1 * index), - } - else: - # Fallback curvature - edge_smooth = {"type": "cubicBezier"} - else: - edge_smooth = False # Draw straight line - - edge_style = { - "color": edge_color, - "width": 2, - "arrows": "to", - "dashes": True if is_router_edge or is_and_condition else False, - "smooth": edge_smooth, - } - - net.add_edge(trigger, method_name, **edge_style) - - # Add edges from router methods to their possible paths - for router_method_name, paths in self.flow._router_paths.items(): - for path in paths: - for listener_name, ( - condition_type, - trigger_methods, - ) in self.flow._listeners.items(): - if path in trigger_methods: - is_cycle_edge = self._is_ancestor( - trigger, method_name, ancestors - ) - - # Determine if parent has multiple children - parent_has_multiple_children = ( - len(parent_children.get(router_method_name, [])) > 1 - ) - - # Edge curvature logic - needs_curvature = is_cycle_edge or parent_has_multiple_children - - if needs_curvature: - # Get node positions - source_pos = node_positions.get(router_method_name) - target_pos = node_positions.get(listener_name) - - if source_pos and target_pos: - dx = target_pos[0] - source_pos[0] - - if dx <= 0: - # Child is left or directly below - smooth_type = "curvedCCW" # Curve left and down - else: - # Child is to the right - smooth_type = "curvedCW" # Curve right and down - - index = self._get_child_index( - router_method_name, listener_name, parent_children - ) - edge_smooth = { - "type": smooth_type, - "roundness": 0.2 + (0.1 * index), - } - else: - # Fallback curvature - edge_smooth = {"type": "cubicBezier"} - else: - edge_smooth = False # Straight line - - edge_style = { - "color": self.colors["router_edge"], - "width": 2, - "arrows": "to", - "dashes": True, - "smooth": edge_smooth, - } - net.add_edge(router_method_name, listener_name, **edge_style) + # Add edges to the network + add_edges(net, self.flow, node_positions, self.colors) # Set options to disable physics net.set_options( @@ -246,227 +54,31 @@ class FlowVisualizer: ) network_html = net.generate_html() - - # Extract just the body content from the generated HTML - match = re.search("(.*?)", network_html, re.DOTALL) - if match: - network_body = match.group(1) - else: - network_body = "" - - # Read the custom template - current_dir = os.path.dirname(__file__) - template_path = os.path.join( - current_dir, "assets", "crewai_flow_visual_template.html" - ) - with open(template_path, "r", encoding="utf-8") as f: - html_template = f.read() - - # Generate the legend items HTML - legend_items = [ - {"label": "Start Method", "color": self.colors["start"]}, - {"label": "Method", "color": self.colors["method"]}, - { - "label": "Router", - "color": self.colors["router"], - "border": self.colors["router_border"], - "dashed": True, - }, - {"label": "Trigger", "color": self.colors["edge"], "dashed": False}, - {"label": "AND Trigger", "color": self.colors["edge"], "dashed": True}, - { - "label": "Router Trigger", - "color": self.colors["router_edge"], - "dashed": True, - }, - ] - - legend_items_html = "" - for item in legend_items: - if "border" in item: - legend_items_html += f""" -
-
-
{item['label']}
-
- """ - elif item.get("dashed") is not None: - style = "dashed" if item["dashed"] else "solid" - legend_items_html += f""" -
-
-
{item['label']}
-
- """ - else: - legend_items_html += f""" -
-
-
{item['label']}
-
- """ - - # Read the logo file and encode it - logo_path = os.path.join(current_dir, "assets", "crewai_logo.svg") - with open(logo_path, "rb") as logo_file: - logo_svg_data = logo_file.read() - logo_svg_base64 = base64.b64encode(logo_svg_data).decode("utf-8") - - # Replace placeholders in the template - final_html_content = html_template.replace("{{ title }}", "Flow Graph") - final_html_content = final_html_content.replace( - "{{ network_content }}", network_body - ) - final_html_content = final_html_content.replace( - "{{ logo_svg_base64 }}", logo_svg_base64 - ) - final_html_content = final_html_content.replace( - "", legend_items_html - ) + final_html_content = self._generate_final_html(network_html) # Save the final HTML content to the file with open(f"{filename}.html", "w", encoding="utf-8") as f: f.write(final_html_content) print(f"Graph saved as {filename}.html") - def _calculate_node_levels(self): - levels = {} - queue = [] - visited = set() - pending_and_listeners = {} + def _generate_final_html(self, network_html): + # Extract just the body content from the generated HTML + current_dir = os.path.dirname(__file__) + template_path = os.path.join( + current_dir, "assets", "crewai_flow_visual_template.html" + ) + logo_path = os.path.join(current_dir, "assets", "crewai_logo.svg") - # Make all start methods at level 0 - for method_name, method in self.flow._methods.items(): - if hasattr(method, "__is_start_method__"): - levels[method_name] = 0 - queue.append(method_name) + html_handler = HTMLTemplateHandler(template_path, logo_path) + network_body = html_handler.extract_body_content(network_html) - # Breadth-first traversal to assign levels - while queue: - current = queue.pop(0) - current_level = levels[current] - visited.add(current) - - for listener_name, ( - condition_type, - trigger_methods, - ) in self.flow._listeners.items(): - if condition_type == "OR": - if current in trigger_methods: - if ( - listener_name not in levels - or levels[listener_name] > current_level + 1 - ): - levels[listener_name] = current_level + 1 - if listener_name not in visited: - queue.append(listener_name) - elif condition_type == "AND": - if listener_name not in pending_and_listeners: - pending_and_listeners[listener_name] = set() - if current in trigger_methods: - pending_and_listeners[listener_name].add(current) - if set(trigger_methods) == pending_and_listeners[listener_name]: - if ( - listener_name not in levels - or levels[listener_name] > current_level + 1 - ): - levels[listener_name] = current_level + 1 - if listener_name not in visited: - queue.append(listener_name) - - # Handle router connections - if current in self.flow._routers.values(): - router_method_name = current - paths = self.flow._router_paths.get(router_method_name, []) - for path in paths: - for listener_name, ( - condition_type, - trigger_methods, - ) in self.flow._listeners.items(): - if path in trigger_methods: - if ( - listener_name not in levels - or levels[listener_name] > current_level + 1 - ): - levels[listener_name] = current_level + 1 - if listener_name not in visited: - queue.append(listener_name) - return levels - - def _count_outgoing_edges(self): - counts = {} - for method_name in self.flow._methods: - counts[method_name] = 0 - for method_name in self.flow._listeners: - _, trigger_methods = self.flow._listeners[method_name] - for trigger in trigger_methods: - if trigger in self.flow._methods: - counts[trigger] += 1 - return counts - - def _build_ancestor_dict(self): - ancestors = {node: set() for node in self.flow._methods} - visited = set() - for node in self.flow._methods: - if node not in visited: - self._dfs_ancestors(node, ancestors, visited) - - return ancestors - - def _dfs_ancestors(self, node, ancestors, visited): - if node in visited: - return - visited.add(node) - - # Handle regular listeners - for listener_name, (_, trigger_methods) in self.flow._listeners.items(): - if node in trigger_methods: - ancestors[listener_name].add(node) - ancestors[listener_name].update(ancestors[node]) - self._dfs_ancestors(listener_name, ancestors, visited) - - # Handle router methods separately - if node in self.flow._routers.values(): - router_method_name = node - paths = self.flow._router_paths.get(router_method_name, []) - for path in paths: - for listener_name, (_, trigger_methods) in self.flow._listeners.items(): - if path in trigger_methods: - # Only propagate the ancestors of the router method, not the router method itself - ancestors[listener_name].update(ancestors[node]) - self._dfs_ancestors(listener_name, ancestors, visited) - - def _is_ancestor(self, node, ancestor_candidate, ancestors): - return ancestor_candidate in ancestors.get(node, set()) - - def _build_parent_children_dict(self): - parent_children = {} - - # Map listeners to their trigger methods - for listener_name, (_, trigger_methods) in self.flow._listeners.items(): - for trigger in trigger_methods: - if trigger not in parent_children: - parent_children[trigger] = [] - if listener_name not in parent_children[trigger]: - parent_children[trigger].append(listener_name) - - # Map router methods to their paths and to listeners - for router_method_name, paths in self.flow._router_paths.items(): - for path in paths: - # Map router method to listeners of each path - for listener_name, (_, trigger_methods) in self.flow._listeners.items(): - if path in trigger_methods: - if router_method_name not in parent_children: - parent_children[router_method_name] = [] - if listener_name not in parent_children[router_method_name]: - parent_children[router_method_name].append(listener_name) - - return parent_children - - def _get_child_index(self, parent, child, parent_children): - children = parent_children.get(parent, []) - children.sort() - return children.index(child) + # Generate the legend items HTML + legend_items = get_legend_items(self.colors) + legend_items_html = generate_legend_items_html(legend_items) + final_html_content = html_handler.generate_final_html( + network_body, legend_items_html + ) + return final_html_content def visualize_flow(flow, filename="flow_graph"): diff --git a/src/crewai/flow/html_template_handler.py b/src/crewai/flow/html_template_handler.py new file mode 100644 index 000000000..8a88da42a --- /dev/null +++ b/src/crewai/flow/html_template_handler.py @@ -0,0 +1,66 @@ +import base64 +import os +import re + + +class HTMLTemplateHandler: + def __init__(self, template_path, logo_path): + self.template_path = template_path + self.logo_path = logo_path + + def read_template(self): + with open(self.template_path, "r", encoding="utf-8") as f: + return f.read() + + def encode_logo(self): + with open(self.logo_path, "rb") as logo_file: + logo_svg_data = logo_file.read() + return base64.b64encode(logo_svg_data).decode("utf-8") + + def extract_body_content(self, html): + match = re.search("(.*?)", html, re.DOTALL) + return match.group(1) if match else "" + + def generate_legend_items_html(self, legend_items): + legend_items_html = "" + for item in legend_items: + if "border" in item: + legend_items_html += f""" +
+
+
{item['label']}
+
+ """ + elif item.get("dashed") is not None: + style = "dashed" if item["dashed"] else "solid" + legend_items_html += f""" +
+
+
{item['label']}
+
+ """ + else: + legend_items_html += f""" +
+
+
{item['label']}
+
+ """ + return legend_items_html + + def generate_final_html(self, network_body, legend_items_html, title="Flow Graph"): + html_template = self.read_template() + logo_svg_base64 = self.encode_logo() + + final_html_content = html_template.replace("{{ title }}", title) + final_html_content = final_html_content.replace( + "{{ network_content }}", network_body + ) + final_html_content = final_html_content.replace( + "{{ logo_svg_base64 }}", logo_svg_base64 + ) + final_html_content = final_html_content.replace( + "", legend_items_html + ) + + return final_html_content diff --git a/src/crewai/flow/legend_generator.py b/src/crewai/flow/legend_generator.py new file mode 100644 index 000000000..83d9b97a2 --- /dev/null +++ b/src/crewai/flow/legend_generator.py @@ -0,0 +1,46 @@ +def get_legend_items(colors): + return [ + {"label": "Start Method", "color": colors["start"]}, + {"label": "Method", "color": colors["method"]}, + { + "label": "Router", + "color": colors["router"], + "border": colors["router_border"], + "dashed": True, + }, + {"label": "Trigger", "color": colors["edge"], "dashed": False}, + {"label": "AND Trigger", "color": colors["edge"], "dashed": True}, + { + "label": "Router Trigger", + "color": colors["router_edge"], + "dashed": True, + }, + ] + + +def generate_legend_items_html(legend_items): + legend_items_html = "" + for item in legend_items: + if "border" in item: + legend_items_html += f""" +
+
+
{item['label']}
+
+ """ + elif item.get("dashed") is not None: + style = "dashed" if item["dashed"] else "solid" + legend_items_html += f""" +
+
+
{item['label']}
+
+ """ + else: + legend_items_html += f""" +
+
+
{item['label']}
+
+ """ + return legend_items_html diff --git a/src/crewai/flow/utils.py b/src/crewai/flow/utils.py new file mode 100644 index 000000000..f2dbfb7fd --- /dev/null +++ b/src/crewai/flow/utils.py @@ -0,0 +1,143 @@ +def calculate_node_levels(flow): + levels = {} + queue = [] + visited = set() + pending_and_listeners = {} + + # Make all start methods at level 0 + for method_name, method in flow._methods.items(): + if hasattr(method, "__is_start_method__"): + levels[method_name] = 0 + queue.append(method_name) + + # Breadth-first traversal to assign levels + while queue: + current = queue.pop(0) + current_level = levels[current] + visited.add(current) + + for listener_name, ( + condition_type, + trigger_methods, + ) in flow._listeners.items(): + if condition_type == "OR": + if current in trigger_methods: + if ( + listener_name not in levels + or levels[listener_name] > current_level + 1 + ): + levels[listener_name] = current_level + 1 + if listener_name not in visited: + queue.append(listener_name) + elif condition_type == "AND": + if listener_name not in pending_and_listeners: + pending_and_listeners[listener_name] = set() + if current in trigger_methods: + pending_and_listeners[listener_name].add(current) + if set(trigger_methods) == pending_and_listeners[listener_name]: + if ( + listener_name not in levels + or levels[listener_name] > current_level + 1 + ): + levels[listener_name] = current_level + 1 + if listener_name not in visited: + queue.append(listener_name) + + # Handle router connections + if current in flow._routers.values(): + router_method_name = current + paths = flow._router_paths.get(router_method_name, []) + for path in paths: + for listener_name, ( + condition_type, + trigger_methods, + ) in flow._listeners.items(): + if path in trigger_methods: + if ( + listener_name not in levels + or levels[listener_name] > current_level + 1 + ): + levels[listener_name] = current_level + 1 + if listener_name not in visited: + queue.append(listener_name) + return levels + + +def count_outgoing_edges(flow): + counts = {} + for method_name in flow._methods: + counts[method_name] = 0 + for method_name in flow._listeners: + _, trigger_methods = flow._listeners[method_name] + for trigger in trigger_methods: + if trigger in flow._methods: + counts[trigger] += 1 + return counts + + +def build_ancestor_dict(flow): + ancestors = {node: set() for node in flow._methods} + visited = set() + for node in flow._methods: + if node not in visited: + dfs_ancestors(node, ancestors, visited, flow) + return ancestors + + +def dfs_ancestors(node, ancestors, visited, flow): + if node in visited: + return + visited.add(node) + + # Handle regular listeners + for listener_name, (_, trigger_methods) in flow._listeners.items(): + if node in trigger_methods: + ancestors[listener_name].add(node) + ancestors[listener_name].update(ancestors[node]) + dfs_ancestors(listener_name, ancestors, visited, flow) + + # Handle router methods separately + if node in flow._routers.values(): + router_method_name = node + paths = flow._router_paths.get(router_method_name, []) + for path in paths: + for listener_name, (_, trigger_methods) in flow._listeners.items(): + if path in trigger_methods: + # Only propagate the ancestors of the router method, not the router method itself + ancestors[listener_name].update(ancestors[node]) + dfs_ancestors(listener_name, ancestors, visited, flow) + + +def is_ancestor(node, ancestor_candidate, ancestors): + return ancestor_candidate in ancestors.get(node, set()) + + +def build_parent_children_dict(flow): + parent_children = {} + + # Map listeners to their trigger methods + for listener_name, (_, trigger_methods) in flow._listeners.items(): + for trigger in trigger_methods: + if trigger not in parent_children: + parent_children[trigger] = [] + if listener_name not in parent_children[trigger]: + parent_children[trigger].append(listener_name) + + # Map router methods to their paths and to listeners + for router_method_name, paths in flow._router_paths.items(): + for path in paths: + # Map router method to listeners of each path + for listener_name, (_, trigger_methods) in flow._listeners.items(): + if path in trigger_methods: + if router_method_name not in parent_children: + parent_children[router_method_name] = [] + if listener_name not in parent_children[router_method_name]: + parent_children[router_method_name].append(listener_name) + + return parent_children + + +def get_child_index(parent, child, parent_children): + children = parent_children.get(parent, []) + children.sort() + return children.index(child) diff --git a/src/crewai/flow/visualization_utils.py b/src/crewai/flow/visualization_utils.py new file mode 100644 index 000000000..ba2ba5f18 --- /dev/null +++ b/src/crewai/flow/visualization_utils.py @@ -0,0 +1,132 @@ +from .utils import ( + build_ancestor_dict, + build_parent_children_dict, + get_child_index, + is_ancestor, +) + + +def compute_positions(flow, node_levels, y_spacing=150, x_spacing=150): + level_nodes = {} + node_positions = {} + + for method_name, level in node_levels.items(): + level_nodes.setdefault(level, []).append(method_name) + + for level, nodes in level_nodes.items(): + x_offset = -(len(nodes) - 1) * x_spacing / 2 # Center nodes horizontally + for i, method_name in enumerate(nodes): + x = x_offset + i * x_spacing + y = level * y_spacing + node_positions[method_name] = (x, y) + + return node_positions + + +def add_edges(net, flow, node_positions, colors): + ancestors = build_ancestor_dict(flow) + parent_children = build_parent_children_dict(flow) + + for method_name in flow._listeners: + condition_type, trigger_methods = flow._listeners[method_name] + is_and_condition = condition_type == "AND" + + for trigger in trigger_methods: + if trigger in flow._methods or trigger in flow._routers.values(): + is_router_edge = any( + trigger in paths for paths in flow._router_paths.values() + ) + edge_color = colors["router_edge"] if is_router_edge else colors["edge"] + + is_cycle_edge = is_ancestor(trigger, method_name, ancestors) + parent_has_multiple_children = len(parent_children.get(trigger, [])) > 1 + needs_curvature = is_cycle_edge or parent_has_multiple_children + + if needs_curvature: + source_pos = node_positions.get(trigger) + target_pos = node_positions.get(method_name) + + if source_pos and target_pos: + dx = target_pos[0] - source_pos[0] + smooth_type = "curvedCCW" if dx <= 0 else "curvedCW" + index = get_child_index(trigger, method_name, parent_children) + edge_smooth = { + "type": smooth_type, + "roundness": 0.2 + (0.1 * index), + } + else: + edge_smooth = {"type": "cubicBezier"} + else: + edge_smooth = False + + edge_style = { + "color": edge_color, + "width": 2, + "arrows": "to", + "dashes": True if is_router_edge or is_and_condition else False, + "smooth": edge_smooth, + } + + net.add_edge(trigger, method_name, **edge_style) + + for router_method_name, paths in flow._router_paths.items(): + for path in paths: + for listener_name, ( + condition_type, + trigger_methods, + ) in flow._listeners.items(): + if path in trigger_methods: + is_cycle_edge = is_ancestor(trigger, method_name, ancestors) + parent_has_multiple_children = ( + len(parent_children.get(router_method_name, [])) > 1 + ) + needs_curvature = is_cycle_edge or parent_has_multiple_children + + if needs_curvature: + source_pos = node_positions.get(router_method_name) + target_pos = node_positions.get(listener_name) + + if source_pos and target_pos: + dx = target_pos[0] - source_pos[0] + smooth_type = "curvedCCW" if dx <= 0 else "curvedCW" + index = get_child_index( + router_method_name, listener_name, parent_children + ) + edge_smooth = { + "type": smooth_type, + "roundness": 0.2 + (0.1 * index), + } + else: + edge_smooth = {"type": "cubicBezier"} + else: + edge_smooth = False + + edge_style = { + "color": colors["router_edge"], + "width": 2, + "arrows": "to", + "dashes": True, + "smooth": edge_smooth, + } + net.add_edge(router_method_name, listener_name, **edge_style) + + +def add_nodes_to_network(net, flow, node_positions, node_styles): + for method_name, (x, y) in node_positions.items(): + method = flow._methods.get(method_name) + if hasattr(method, "__is_start_method__"): + node_style = node_styles["start"] + elif hasattr(method, "__is_router__"): + node_style = node_styles["router"] + else: + node_style = node_styles["method"] + + net.add_node( + method_name, + label=method_name, + x=x, + y=y, + fixed=True, + physics=False, + **node_style, + )