Refactor to make crews easier to understand

This commit is contained in:
Brandon Hancock
2024-10-01 09:56:14 -04:00
parent 09bc68078c
commit b22568aa6d
6 changed files with 467 additions and 422 deletions

46
src/crewai/flow/config.py Normal file
View File

@@ -0,0 +1,46 @@
DARK_GRAY = "#333333"
CREWAI_ORANGE = "#FF5A50"
GRAY = "#666666"
WHITE = "#FFFFFF"
COLORS = {
"bg": WHITE,
"start": CREWAI_ORANGE,
"method": DARK_GRAY,
"router": DARK_GRAY,
"router_border": CREWAI_ORANGE,
"edge": GRAY,
"router_edge": CREWAI_ORANGE,
"text": WHITE,
}
NODE_STYLES = {
"start": {
"color": COLORS["start"],
"shape": "box",
"font": {"color": COLORS["text"]},
"margin": {"top": 10, "bottom": 8, "left": 10, "right": 10},
},
"method": {
"color": COLORS["method"],
"shape": "box",
"font": {"color": COLORS["text"]},
"margin": {"top": 10, "bottom": 8, "left": 10, "right": 10},
},
"router": {
"color": {
"background": COLORS["router"],
"border": COLORS["router_border"],
"highlight": {
"border": COLORS["router_border"],
"background": COLORS["router"],
},
},
"shape": "box",
"font": {"color": COLORS["text"]},
"borderWidth": 3,
"borderWidthSelected": 4,
"shapeProperties": {"borderDashes": [5, 5]},
"margin": {"top": 10, "bottom": 8, "left": 10, "right": 10},
},
}

View File

@@ -1,62 +1,26 @@
# flow_visualizer.py
import base64
import os
import re
from pyvis.network import Network
DARK_GRAY = "#333333"
CREWAI_ORANGE = "#FF5A50"
GRAY = "#666666"
WHITE = "#FFFFFF"
from crewai.flow.config import COLORS, NODE_STYLES
from crewai.flow.html_template_handler import HTMLTemplateHandler
from crewai.flow.legend_generator import generate_legend_items_html, get_legend_items
from crewai.flow.utils import calculate_node_levels
from crewai.flow.visualization_utils import (
add_edges,
add_nodes_to_network,
compute_positions,
)
class FlowVisualizer:
def __init__(self, flow):
self.flow = flow
self.colors = {
"bg": WHITE,
"start": CREWAI_ORANGE,
"method": DARK_GRAY,
"router": DARK_GRAY,
"router_border": CREWAI_ORANGE,
"edge": GRAY,
"router_edge": CREWAI_ORANGE,
"text": WHITE,
}
self.node_styles = {
"start": {
"color": self.colors["start"],
"shape": "box",
"font": {"color": self.colors["text"]},
"margin": {"top": 10, "bottom": 8, "left": 10, "right": 10},
},
"method": {
"color": self.colors["method"],
"shape": "box",
"font": {"color": self.colors["text"]},
"margin": {"top": 10, "bottom": 8, "left": 10, "right": 10},
},
"router": {
"color": {
"background": self.colors["router"],
"border": self.colors["router_border"],
"highlight": {
"border": self.colors["router_border"],
"background": self.colors["router"],
},
},
"shape": "box",
"font": {"color": self.colors["text"]},
"borderWidth": 3,
"borderWidthSelected": 4,
"shapeProperties": {"borderDashes": [5, 5]},
"margin": {"top": 10, "bottom": 8, "left": 10, "right": 10},
},
}
self.colors = COLORS
self.node_styles = NODE_STYLES
# TODO: DROP LIB FOLDER POST GENERATION
def visualize(self, filename):
net = Network(
directed=True,
@@ -67,172 +31,16 @@ class FlowVisualizer:
)
# Calculate levels for nodes
node_levels = self._calculate_node_levels()
# Assign positions to nodes based on levels
y_spacing = 150
x_spacing = 150
level_nodes = {}
# Store node positions for edge calculations
node_positions = {}
for method_name, level in node_levels.items():
level_nodes.setdefault(level, []).append(method_name)
node_levels = calculate_node_levels(self.flow)
# Compute positions
for level, nodes in level_nodes.items():
x_offset = -(len(nodes) - 1) * x_spacing / 2 # Center nodes horizontally
for i, method_name in enumerate(nodes):
x = x_offset + i * x_spacing
y = level * y_spacing
node_positions[method_name] = (x, y)
node_positions = compute_positions(self.flow, node_levels)
method = self.flow._methods.get(method_name)
if hasattr(method, "__is_start_method__"):
node_style = self.node_styles["start"]
elif hasattr(method, "__is_router__"):
node_style = self.node_styles["router"]
else:
node_style = self.node_styles["method"]
# Add nodes to the network
add_nodes_to_network(net, self.flow, node_positions, self.node_styles)
net.add_node(
method_name,
label=method_name,
x=x,
y=y,
fixed=True,
physics=False,
**node_style,
)
ancestors = self._build_ancestor_dict()
parent_children = self._build_parent_children_dict()
# Add edges
for method_name in self.flow._listeners:
condition_type, trigger_methods = self.flow._listeners[method_name]
is_and_condition = condition_type == "AND"
for trigger in trigger_methods:
if (
trigger in self.flow._methods
or trigger in self.flow._routers.values()
):
is_router_edge = any(
trigger in paths for paths in self.flow._router_paths.values()
)
edge_color = (
self.colors["router_edge"]
if is_router_edge
else self.colors["edge"]
)
# Determine if this edge forms a cycle
is_cycle_edge = self._is_ancestor(trigger, method_name, ancestors)
# Determine if parent has multiple children
parent_has_multiple_children = (
len(parent_children.get(trigger, [])) > 1
)
# Edge curvature logic
needs_curvature = is_cycle_edge or parent_has_multiple_children
if needs_curvature:
# Get node positions
source_pos = node_positions.get(trigger)
target_pos = node_positions.get(method_name)
if source_pos and target_pos:
dx = target_pos[0] - source_pos[0]
if dx <= 0:
# Child is left or directly below
smooth_type = "curvedCCW" # Curve left and down
else:
# Child is to the right
smooth_type = "curvedCW" # Curve right and down
index = self._get_child_index(
trigger, method_name, parent_children
)
edge_smooth = {
"type": smooth_type,
"roundness": 0.2 + (0.1 * index),
}
else:
# Fallback curvature
edge_smooth = {"type": "cubicBezier"}
else:
edge_smooth = False # Draw straight line
edge_style = {
"color": edge_color,
"width": 2,
"arrows": "to",
"dashes": True if is_router_edge or is_and_condition else False,
"smooth": edge_smooth,
}
net.add_edge(trigger, method_name, **edge_style)
# Add edges from router methods to their possible paths
for router_method_name, paths in self.flow._router_paths.items():
for path in paths:
for listener_name, (
condition_type,
trigger_methods,
) in self.flow._listeners.items():
if path in trigger_methods:
is_cycle_edge = self._is_ancestor(
trigger, method_name, ancestors
)
# Determine if parent has multiple children
parent_has_multiple_children = (
len(parent_children.get(router_method_name, [])) > 1
)
# Edge curvature logic
needs_curvature = is_cycle_edge or parent_has_multiple_children
if needs_curvature:
# Get node positions
source_pos = node_positions.get(router_method_name)
target_pos = node_positions.get(listener_name)
if source_pos and target_pos:
dx = target_pos[0] - source_pos[0]
if dx <= 0:
# Child is left or directly below
smooth_type = "curvedCCW" # Curve left and down
else:
# Child is to the right
smooth_type = "curvedCW" # Curve right and down
index = self._get_child_index(
router_method_name, listener_name, parent_children
)
edge_smooth = {
"type": smooth_type,
"roundness": 0.2 + (0.1 * index),
}
else:
# Fallback curvature
edge_smooth = {"type": "cubicBezier"}
else:
edge_smooth = False # Straight line
edge_style = {
"color": self.colors["router_edge"],
"width": 2,
"arrows": "to",
"dashes": True,
"smooth": edge_smooth,
}
net.add_edge(router_method_name, listener_name, **edge_style)
# Add edges to the network
add_edges(net, self.flow, node_positions, self.colors)
# Set options to disable physics
net.set_options(
@@ -246,227 +54,31 @@ class FlowVisualizer:
)
network_html = net.generate_html()
# Extract just the body content from the generated HTML
match = re.search("<body.*?>(.*?)</body>", network_html, re.DOTALL)
if match:
network_body = match.group(1)
else:
network_body = ""
# Read the custom template
current_dir = os.path.dirname(__file__)
template_path = os.path.join(
current_dir, "assets", "crewai_flow_visual_template.html"
)
with open(template_path, "r", encoding="utf-8") as f:
html_template = f.read()
# Generate the legend items HTML
legend_items = [
{"label": "Start Method", "color": self.colors["start"]},
{"label": "Method", "color": self.colors["method"]},
{
"label": "Router",
"color": self.colors["router"],
"border": self.colors["router_border"],
"dashed": True,
},
{"label": "Trigger", "color": self.colors["edge"], "dashed": False},
{"label": "AND Trigger", "color": self.colors["edge"], "dashed": True},
{
"label": "Router Trigger",
"color": self.colors["router_edge"],
"dashed": True,
},
]
legend_items_html = ""
for item in legend_items:
if "border" in item:
legend_items_html += f"""
<div class="legend-item">
<div class="legend-color-box" style="background-color: {item['color']}; border: 2px dashed {item['border']};"></div>
<div>{item['label']}</div>
</div>
"""
elif item.get("dashed") is not None:
style = "dashed" if item["dashed"] else "solid"
legend_items_html += f"""
<div class="legend-item">
<div class="legend-{style}" style="border-bottom: 2px {style} {item['color']};"></div>
<div>{item['label']}</div>
</div>
"""
else:
legend_items_html += f"""
<div class="legend-item">
<div class="legend-color-box" style="background-color: {item['color']};"></div>
<div>{item['label']}</div>
</div>
"""
# Read the logo file and encode it
logo_path = os.path.join(current_dir, "assets", "crewai_logo.svg")
with open(logo_path, "rb") as logo_file:
logo_svg_data = logo_file.read()
logo_svg_base64 = base64.b64encode(logo_svg_data).decode("utf-8")
# Replace placeholders in the template
final_html_content = html_template.replace("{{ title }}", "Flow Graph")
final_html_content = final_html_content.replace(
"{{ network_content }}", network_body
)
final_html_content = final_html_content.replace(
"{{ logo_svg_base64 }}", logo_svg_base64
)
final_html_content = final_html_content.replace(
"<!-- LEGEND_ITEMS_PLACEHOLDER -->", legend_items_html
)
final_html_content = self._generate_final_html(network_html)
# Save the final HTML content to the file
with open(f"{filename}.html", "w", encoding="utf-8") as f:
f.write(final_html_content)
print(f"Graph saved as {filename}.html")
def _calculate_node_levels(self):
levels = {}
queue = []
visited = set()
pending_and_listeners = {}
def _generate_final_html(self, network_html):
# Extract just the body content from the generated HTML
current_dir = os.path.dirname(__file__)
template_path = os.path.join(
current_dir, "assets", "crewai_flow_visual_template.html"
)
logo_path = os.path.join(current_dir, "assets", "crewai_logo.svg")
# Make all start methods at level 0
for method_name, method in self.flow._methods.items():
if hasattr(method, "__is_start_method__"):
levels[method_name] = 0
queue.append(method_name)
html_handler = HTMLTemplateHandler(template_path, logo_path)
network_body = html_handler.extract_body_content(network_html)
# Breadth-first traversal to assign levels
while queue:
current = queue.pop(0)
current_level = levels[current]
visited.add(current)
for listener_name, (
condition_type,
trigger_methods,
) in self.flow._listeners.items():
if condition_type == "OR":
if current in trigger_methods:
if (
listener_name not in levels
or levels[listener_name] > current_level + 1
):
levels[listener_name] = current_level + 1
if listener_name not in visited:
queue.append(listener_name)
elif condition_type == "AND":
if listener_name not in pending_and_listeners:
pending_and_listeners[listener_name] = set()
if current in trigger_methods:
pending_and_listeners[listener_name].add(current)
if set(trigger_methods) == pending_and_listeners[listener_name]:
if (
listener_name not in levels
or levels[listener_name] > current_level + 1
):
levels[listener_name] = current_level + 1
if listener_name not in visited:
queue.append(listener_name)
# Handle router connections
if current in self.flow._routers.values():
router_method_name = current
paths = self.flow._router_paths.get(router_method_name, [])
for path in paths:
for listener_name, (
condition_type,
trigger_methods,
) in self.flow._listeners.items():
if path in trigger_methods:
if (
listener_name not in levels
or levels[listener_name] > current_level + 1
):
levels[listener_name] = current_level + 1
if listener_name not in visited:
queue.append(listener_name)
return levels
def _count_outgoing_edges(self):
counts = {}
for method_name in self.flow._methods:
counts[method_name] = 0
for method_name in self.flow._listeners:
_, trigger_methods = self.flow._listeners[method_name]
for trigger in trigger_methods:
if trigger in self.flow._methods:
counts[trigger] += 1
return counts
def _build_ancestor_dict(self):
ancestors = {node: set() for node in self.flow._methods}
visited = set()
for node in self.flow._methods:
if node not in visited:
self._dfs_ancestors(node, ancestors, visited)
return ancestors
def _dfs_ancestors(self, node, ancestors, visited):
if node in visited:
return
visited.add(node)
# Handle regular listeners
for listener_name, (_, trigger_methods) in self.flow._listeners.items():
if node in trigger_methods:
ancestors[listener_name].add(node)
ancestors[listener_name].update(ancestors[node])
self._dfs_ancestors(listener_name, ancestors, visited)
# Handle router methods separately
if node in self.flow._routers.values():
router_method_name = node
paths = self.flow._router_paths.get(router_method_name, [])
for path in paths:
for listener_name, (_, trigger_methods) in self.flow._listeners.items():
if path in trigger_methods:
# Only propagate the ancestors of the router method, not the router method itself
ancestors[listener_name].update(ancestors[node])
self._dfs_ancestors(listener_name, ancestors, visited)
def _is_ancestor(self, node, ancestor_candidate, ancestors):
return ancestor_candidate in ancestors.get(node, set())
def _build_parent_children_dict(self):
parent_children = {}
# Map listeners to their trigger methods
for listener_name, (_, trigger_methods) in self.flow._listeners.items():
for trigger in trigger_methods:
if trigger not in parent_children:
parent_children[trigger] = []
if listener_name not in parent_children[trigger]:
parent_children[trigger].append(listener_name)
# Map router methods to their paths and to listeners
for router_method_name, paths in self.flow._router_paths.items():
for path in paths:
# Map router method to listeners of each path
for listener_name, (_, trigger_methods) in self.flow._listeners.items():
if path in trigger_methods:
if router_method_name not in parent_children:
parent_children[router_method_name] = []
if listener_name not in parent_children[router_method_name]:
parent_children[router_method_name].append(listener_name)
return parent_children
def _get_child_index(self, parent, child, parent_children):
children = parent_children.get(parent, [])
children.sort()
return children.index(child)
# Generate the legend items HTML
legend_items = get_legend_items(self.colors)
legend_items_html = generate_legend_items_html(legend_items)
final_html_content = html_handler.generate_final_html(
network_body, legend_items_html
)
return final_html_content
def visualize_flow(flow, filename="flow_graph"):

View File

@@ -0,0 +1,66 @@
import base64
import os
import re
class HTMLTemplateHandler:
def __init__(self, template_path, logo_path):
self.template_path = template_path
self.logo_path = logo_path
def read_template(self):
with open(self.template_path, "r", encoding="utf-8") as f:
return f.read()
def encode_logo(self):
with open(self.logo_path, "rb") as logo_file:
logo_svg_data = logo_file.read()
return base64.b64encode(logo_svg_data).decode("utf-8")
def extract_body_content(self, html):
match = re.search("<body.*?>(.*?)</body>", html, re.DOTALL)
return match.group(1) if match else ""
def generate_legend_items_html(self, legend_items):
legend_items_html = ""
for item in legend_items:
if "border" in item:
legend_items_html += f"""
<div class="legend-item">
<div class="legend-color-box" style="background-color: {item['color']}; border: 2px dashed {item['border']};"></div>
<div>{item['label']}</div>
</div>
"""
elif item.get("dashed") is not None:
style = "dashed" if item["dashed"] else "solid"
legend_items_html += f"""
<div class="legend-item">
<div class="legend-{style}" style="border-bottom: 2px {style} {item['color']};"></div>
<div>{item['label']}</div>
</div>
"""
else:
legend_items_html += f"""
<div class="legend-item">
<div class="legend-color-box" style="background-color: {item['color']};"></div>
<div>{item['label']}</div>
</div>
"""
return legend_items_html
def generate_final_html(self, network_body, legend_items_html, title="Flow Graph"):
html_template = self.read_template()
logo_svg_base64 = self.encode_logo()
final_html_content = html_template.replace("{{ title }}", title)
final_html_content = final_html_content.replace(
"{{ network_content }}", network_body
)
final_html_content = final_html_content.replace(
"{{ logo_svg_base64 }}", logo_svg_base64
)
final_html_content = final_html_content.replace(
"<!-- LEGEND_ITEMS_PLACEHOLDER -->", legend_items_html
)
return final_html_content

View File

@@ -0,0 +1,46 @@
def get_legend_items(colors):
return [
{"label": "Start Method", "color": colors["start"]},
{"label": "Method", "color": colors["method"]},
{
"label": "Router",
"color": colors["router"],
"border": colors["router_border"],
"dashed": True,
},
{"label": "Trigger", "color": colors["edge"], "dashed": False},
{"label": "AND Trigger", "color": colors["edge"], "dashed": True},
{
"label": "Router Trigger",
"color": colors["router_edge"],
"dashed": True,
},
]
def generate_legend_items_html(legend_items):
legend_items_html = ""
for item in legend_items:
if "border" in item:
legend_items_html += f"""
<div class="legend-item">
<div class="legend-color-box" style="background-color: {item['color']}; border: 2px dashed {item['border']};"></div>
<div>{item['label']}</div>
</div>
"""
elif item.get("dashed") is not None:
style = "dashed" if item["dashed"] else "solid"
legend_items_html += f"""
<div class="legend-item">
<div class="legend-{style}" style="border-bottom: 2px {style} {item['color']};"></div>
<div>{item['label']}</div>
</div>
"""
else:
legend_items_html += f"""
<div class="legend-item">
<div class="legend-color-box" style="background-color: {item['color']};"></div>
<div>{item['label']}</div>
</div>
"""
return legend_items_html

143
src/crewai/flow/utils.py Normal file
View File

@@ -0,0 +1,143 @@
def calculate_node_levels(flow):
levels = {}
queue = []
visited = set()
pending_and_listeners = {}
# Make all start methods at level 0
for method_name, method in flow._methods.items():
if hasattr(method, "__is_start_method__"):
levels[method_name] = 0
queue.append(method_name)
# Breadth-first traversal to assign levels
while queue:
current = queue.pop(0)
current_level = levels[current]
visited.add(current)
for listener_name, (
condition_type,
trigger_methods,
) in flow._listeners.items():
if condition_type == "OR":
if current in trigger_methods:
if (
listener_name not in levels
or levels[listener_name] > current_level + 1
):
levels[listener_name] = current_level + 1
if listener_name not in visited:
queue.append(listener_name)
elif condition_type == "AND":
if listener_name not in pending_and_listeners:
pending_and_listeners[listener_name] = set()
if current in trigger_methods:
pending_and_listeners[listener_name].add(current)
if set(trigger_methods) == pending_and_listeners[listener_name]:
if (
listener_name not in levels
or levels[listener_name] > current_level + 1
):
levels[listener_name] = current_level + 1
if listener_name not in visited:
queue.append(listener_name)
# Handle router connections
if current in flow._routers.values():
router_method_name = current
paths = flow._router_paths.get(router_method_name, [])
for path in paths:
for listener_name, (
condition_type,
trigger_methods,
) in flow._listeners.items():
if path in trigger_methods:
if (
listener_name not in levels
or levels[listener_name] > current_level + 1
):
levels[listener_name] = current_level + 1
if listener_name not in visited:
queue.append(listener_name)
return levels
def count_outgoing_edges(flow):
counts = {}
for method_name in flow._methods:
counts[method_name] = 0
for method_name in flow._listeners:
_, trigger_methods = flow._listeners[method_name]
for trigger in trigger_methods:
if trigger in flow._methods:
counts[trigger] += 1
return counts
def build_ancestor_dict(flow):
ancestors = {node: set() for node in flow._methods}
visited = set()
for node in flow._methods:
if node not in visited:
dfs_ancestors(node, ancestors, visited, flow)
return ancestors
def dfs_ancestors(node, ancestors, visited, flow):
if node in visited:
return
visited.add(node)
# Handle regular listeners
for listener_name, (_, trigger_methods) in flow._listeners.items():
if node in trigger_methods:
ancestors[listener_name].add(node)
ancestors[listener_name].update(ancestors[node])
dfs_ancestors(listener_name, ancestors, visited, flow)
# Handle router methods separately
if node in flow._routers.values():
router_method_name = node
paths = flow._router_paths.get(router_method_name, [])
for path in paths:
for listener_name, (_, trigger_methods) in flow._listeners.items():
if path in trigger_methods:
# Only propagate the ancestors of the router method, not the router method itself
ancestors[listener_name].update(ancestors[node])
dfs_ancestors(listener_name, ancestors, visited, flow)
def is_ancestor(node, ancestor_candidate, ancestors):
return ancestor_candidate in ancestors.get(node, set())
def build_parent_children_dict(flow):
parent_children = {}
# Map listeners to their trigger methods
for listener_name, (_, trigger_methods) in flow._listeners.items():
for trigger in trigger_methods:
if trigger not in parent_children:
parent_children[trigger] = []
if listener_name not in parent_children[trigger]:
parent_children[trigger].append(listener_name)
# Map router methods to their paths and to listeners
for router_method_name, paths in flow._router_paths.items():
for path in paths:
# Map router method to listeners of each path
for listener_name, (_, trigger_methods) in flow._listeners.items():
if path in trigger_methods:
if router_method_name not in parent_children:
parent_children[router_method_name] = []
if listener_name not in parent_children[router_method_name]:
parent_children[router_method_name].append(listener_name)
return parent_children
def get_child_index(parent, child, parent_children):
children = parent_children.get(parent, [])
children.sort()
return children.index(child)

View File

@@ -0,0 +1,132 @@
from .utils import (
build_ancestor_dict,
build_parent_children_dict,
get_child_index,
is_ancestor,
)
def compute_positions(flow, node_levels, y_spacing=150, x_spacing=150):
level_nodes = {}
node_positions = {}
for method_name, level in node_levels.items():
level_nodes.setdefault(level, []).append(method_name)
for level, nodes in level_nodes.items():
x_offset = -(len(nodes) - 1) * x_spacing / 2 # Center nodes horizontally
for i, method_name in enumerate(nodes):
x = x_offset + i * x_spacing
y = level * y_spacing
node_positions[method_name] = (x, y)
return node_positions
def add_edges(net, flow, node_positions, colors):
ancestors = build_ancestor_dict(flow)
parent_children = build_parent_children_dict(flow)
for method_name in flow._listeners:
condition_type, trigger_methods = flow._listeners[method_name]
is_and_condition = condition_type == "AND"
for trigger in trigger_methods:
if trigger in flow._methods or trigger in flow._routers.values():
is_router_edge = any(
trigger in paths for paths in flow._router_paths.values()
)
edge_color = colors["router_edge"] if is_router_edge else colors["edge"]
is_cycle_edge = is_ancestor(trigger, method_name, ancestors)
parent_has_multiple_children = len(parent_children.get(trigger, [])) > 1
needs_curvature = is_cycle_edge or parent_has_multiple_children
if needs_curvature:
source_pos = node_positions.get(trigger)
target_pos = node_positions.get(method_name)
if source_pos and target_pos:
dx = target_pos[0] - source_pos[0]
smooth_type = "curvedCCW" if dx <= 0 else "curvedCW"
index = get_child_index(trigger, method_name, parent_children)
edge_smooth = {
"type": smooth_type,
"roundness": 0.2 + (0.1 * index),
}
else:
edge_smooth = {"type": "cubicBezier"}
else:
edge_smooth = False
edge_style = {
"color": edge_color,
"width": 2,
"arrows": "to",
"dashes": True if is_router_edge or is_and_condition else False,
"smooth": edge_smooth,
}
net.add_edge(trigger, method_name, **edge_style)
for router_method_name, paths in flow._router_paths.items():
for path in paths:
for listener_name, (
condition_type,
trigger_methods,
) in flow._listeners.items():
if path in trigger_methods:
is_cycle_edge = is_ancestor(trigger, method_name, ancestors)
parent_has_multiple_children = (
len(parent_children.get(router_method_name, [])) > 1
)
needs_curvature = is_cycle_edge or parent_has_multiple_children
if needs_curvature:
source_pos = node_positions.get(router_method_name)
target_pos = node_positions.get(listener_name)
if source_pos and target_pos:
dx = target_pos[0] - source_pos[0]
smooth_type = "curvedCCW" if dx <= 0 else "curvedCW"
index = get_child_index(
router_method_name, listener_name, parent_children
)
edge_smooth = {
"type": smooth_type,
"roundness": 0.2 + (0.1 * index),
}
else:
edge_smooth = {"type": "cubicBezier"}
else:
edge_smooth = False
edge_style = {
"color": colors["router_edge"],
"width": 2,
"arrows": "to",
"dashes": True,
"smooth": edge_smooth,
}
net.add_edge(router_method_name, listener_name, **edge_style)
def add_nodes_to_network(net, flow, node_positions, node_styles):
for method_name, (x, y) in node_positions.items():
method = flow._methods.get(method_name)
if hasattr(method, "__is_start_method__"):
node_style = node_styles["start"]
elif hasattr(method, "__is_router__"):
node_style = node_styles["router"]
else:
node_style = node_styles["method"]
net.add_node(
method_name,
label=method_name,
x=x,
y=y,
fixed=True,
physics=False,
**node_style,
)