everythin is showing up properly need to fix curves

This commit is contained in:
Brandon Hancock
2024-09-30 13:28:32 -04:00
parent e1c01ae907
commit b927989c4d

View File

@@ -65,6 +65,9 @@ class PyvisFlowVisualizer(FlowVisualizer):
x_spacing = 150 # Adjust spacing between nodes
level_nodes = {}
# Store node positions for edge calculations
node_positions = {}
for method_name, level in node_levels.items():
level_nodes.setdefault(level, []).append(method_name)
@@ -74,6 +77,8 @@ class PyvisFlowVisualizer(FlowVisualizer):
for i, method_name in enumerate(nodes):
x = x_offset + i * x_spacing
y = level * y_spacing # Use level directly for y position
node_positions[method_name] = (x, y) # Store positions
method = self.flow._methods.get(method_name)
if hasattr(method, "__is_start_method__"):
node_style = self.node_styles["start"]
@@ -92,6 +97,10 @@ class PyvisFlowVisualizer(FlowVisualizer):
**node_style,
)
# Prepare data structures for edge calculations
ancestors = self._build_ancestor_dict()
parent_children = self._build_parent_children_dict()
# Add edges
for method_name in self.flow._listeners:
condition_type, trigger_methods = self.flow._listeners[method_name]
@@ -108,18 +117,39 @@ class PyvisFlowVisualizer(FlowVisualizer):
else self.colors["edge"]
)
# **New: Determine if this is a cycle edge to apply curvature**
is_cycle_edge = trigger == method_name
# Determine if this edge forms a cycle
is_cycle_edge = self._is_ancestor(method_name, trigger, ancestors)
# Determine if parent has multiple children
parent_has_multiple_children = (
len(parent_children.get(trigger, [])) > 1
)
# Determine if edge should be curved
needs_curvature = is_cycle_edge or parent_has_multiple_children
if needs_curvature:
if is_cycle_edge:
# For cycles, curve left and up
edge_smooth = {"type": "curvedCCW", "roundness": 0.3}
else:
# For multiple children, adjust curvature to prevent overlap
index = self._get_child_index(
trigger, method_name, parent_children
)
edge_smooth = {
"type": "curvedCW" if index % 2 == 0 else "curvedCCW",
"roundness": 0.2 + (0.1 * index),
}
else:
edge_smooth = False # Draw straight line
edge_style = {
"color": edge_color,
"width": 2,
"arrows": "to",
"dashes": True if is_router_edge or is_and_condition else False,
"smooth": (
{"type": "curvedCCW", "roundness": 0.5}
if is_cycle_edge
else {"type": "cubicBezier"}
),
"smooth": edge_smooth,
}
net.add_edge(trigger, method_name, **edge_style)
@@ -132,31 +162,51 @@ class PyvisFlowVisualizer(FlowVisualizer):
trigger_methods,
) in self.flow._listeners.items():
if path in trigger_methods:
is_cycle_edge = self._is_ancestor(
listener_name, router_method_name, ancestors
)
# Determine if parent has multiple children
parent_has_multiple_children = (
len(parent_children.get(router_method_name, [])) > 1
)
# Determine if edge should be curved
needs_curvature = is_cycle_edge or parent_has_multiple_children
if needs_curvature:
if is_cycle_edge:
# Curve left and up for cycles
edge_smooth = {"type": "curvedCCW", "roundness": 0.3}
else:
# For multiple children, adjust curvature
index = self._get_child_index(
router_method_name, listener_name, parent_children
)
edge_smooth = {
"type": (
"curvedCW" if index % 2 == 0 else "curvedCCW"
),
"roundness": 0.2 + (0.1 * index),
}
else:
edge_smooth = False # Straight line
edge_style = {
"color": self.colors["router_edge"],
"width": 2,
"arrows": "to",
"dashes": True,
"smooth": {
"type": "curvedCW",
"roundness": 0.3,
}, # Curvature for router edges
"smooth": edge_smooth,
}
net.add_edge(router_method_name, listener_name, **edge_style)
# Set options for curved edges and disable physics
# Set options to disable physics
net.set_options(
"""
var options = {
"physics": {
"enabled": false
},
"edges": {
"smooth": {
"enabled": true,
"type": "cubicBezier",
"roundness": 0.5
}
}
}
"""
@@ -311,6 +361,75 @@ class PyvisFlowVisualizer(FlowVisualizer):
queue.append(listener_name)
return levels
def _count_outgoing_edges(self):
# Helper method to count the number of outgoing edges from each node
counts = {}
for method_name in self.flow._methods:
counts[method_name] = 0
for method_name in self.flow._listeners:
_, trigger_methods = self.flow._listeners[method_name]
for trigger in trigger_methods:
if trigger in self.flow._methods:
counts[trigger] += 1
return counts
def _build_ancestor_dict(self):
# Helper method to build a dictionary of ancestors for each node
ancestors = {node: set() for node in self.flow._methods}
for node in self.flow._methods:
self._dfs_ancestors(node, ancestors, set())
return ancestors
def _dfs_ancestors(self, node, ancestors, visited):
if node in visited:
return
visited.add(node)
for listener_name, (_, trigger_methods) in self.flow._listeners.items():
if node in trigger_methods:
ancestors[listener_name].add(node)
ancestors[listener_name].update(ancestors[node])
self._dfs_ancestors(listener_name, ancestors, visited)
# Include router paths
for router_method_name, paths in self.flow._router_paths.items():
if node == router_method_name:
for path in paths:
for listener_name, (
_,
trigger_methods,
) in self.flow._listeners.items():
if path in trigger_methods:
ancestors[listener_name].add(node)
ancestors[listener_name].update(ancestors[node])
self._dfs_ancestors(listener_name, ancestors, visited)
def _is_ancestor(self, node, ancestor_candidate, ancestors):
return ancestor_candidate in ancestors.get(node, set())
def _build_parent_children_dict(self):
# Helper method to build a parent to children mapping
parent_children = {}
for listener_name, (_, trigger_methods) in self.flow._listeners.items():
for trigger in trigger_methods:
if trigger not in parent_children:
parent_children[trigger] = []
parent_children[trigger].append(listener_name)
# Include router paths and their connected listener methods
for router_method_name, paths in self.flow._router_paths.items():
for path in paths:
# Find listener methods triggered by each path
for listener_name, (_, trigger_methods) in self.flow._listeners.items():
if path in trigger_methods:
parent_children.setdefault(router_method_name, []).append(
listener_name
)
return parent_children
def _get_child_index(self, parent, child, parent_children):
# Helper method to get the index of the child among the parent's children
children = parent_children.get(parent, [])
children.sort()
return children.index(child)
def visualize_flow(flow, filename="flow_graph"):
visualizer = PyvisFlowVisualizer(flow)