diff --git a/src/crewai/flow/flow_visualizer.py b/src/crewai/flow/flow_visualizer.py index 52902bdfd..2cf43fa02 100644 --- a/src/crewai/flow/flow_visualizer.py +++ b/src/crewai/flow/flow_visualizer.py @@ -65,6 +65,9 @@ class PyvisFlowVisualizer(FlowVisualizer): x_spacing = 150 # Adjust spacing between nodes level_nodes = {} + # Store node positions for edge calculations + node_positions = {} + for method_name, level in node_levels.items(): level_nodes.setdefault(level, []).append(method_name) @@ -74,6 +77,8 @@ class PyvisFlowVisualizer(FlowVisualizer): for i, method_name in enumerate(nodes): x = x_offset + i * x_spacing y = level * y_spacing # Use level directly for y position + node_positions[method_name] = (x, y) # Store positions + method = self.flow._methods.get(method_name) if hasattr(method, "__is_start_method__"): node_style = self.node_styles["start"] @@ -92,6 +97,10 @@ class PyvisFlowVisualizer(FlowVisualizer): **node_style, ) + # Prepare data structures for edge calculations + ancestors = self._build_ancestor_dict() + parent_children = self._build_parent_children_dict() + # Add edges for method_name in self.flow._listeners: condition_type, trigger_methods = self.flow._listeners[method_name] @@ -108,18 +117,39 @@ class PyvisFlowVisualizer(FlowVisualizer): else self.colors["edge"] ) - # **New: Determine if this is a cycle edge to apply curvature** - is_cycle_edge = trigger == method_name + # Determine if this edge forms a cycle + is_cycle_edge = self._is_ancestor(method_name, trigger, ancestors) + + # Determine if parent has multiple children + parent_has_multiple_children = ( + len(parent_children.get(trigger, [])) > 1 + ) + + # Determine if edge should be curved + needs_curvature = is_cycle_edge or parent_has_multiple_children + + if needs_curvature: + if is_cycle_edge: + # For cycles, curve left and up + edge_smooth = {"type": "curvedCCW", "roundness": 0.3} + else: + # For multiple children, adjust curvature to prevent overlap + index = self._get_child_index( + trigger, method_name, parent_children + ) + edge_smooth = { + "type": "curvedCW" if index % 2 == 0 else "curvedCCW", + "roundness": 0.2 + (0.1 * index), + } + else: + edge_smooth = False # Draw straight line + edge_style = { "color": edge_color, "width": 2, "arrows": "to", "dashes": True if is_router_edge or is_and_condition else False, - "smooth": ( - {"type": "curvedCCW", "roundness": 0.5} - if is_cycle_edge - else {"type": "cubicBezier"} - ), + "smooth": edge_smooth, } net.add_edge(trigger, method_name, **edge_style) @@ -132,31 +162,51 @@ class PyvisFlowVisualizer(FlowVisualizer): trigger_methods, ) in self.flow._listeners.items(): if path in trigger_methods: + is_cycle_edge = self._is_ancestor( + listener_name, router_method_name, ancestors + ) + + # Determine if parent has multiple children + parent_has_multiple_children = ( + len(parent_children.get(router_method_name, [])) > 1 + ) + + # Determine if edge should be curved + needs_curvature = is_cycle_edge or parent_has_multiple_children + + if needs_curvature: + if is_cycle_edge: + # Curve left and up for cycles + edge_smooth = {"type": "curvedCCW", "roundness": 0.3} + else: + # For multiple children, adjust curvature + index = self._get_child_index( + router_method_name, listener_name, parent_children + ) + edge_smooth = { + "type": ( + "curvedCW" if index % 2 == 0 else "curvedCCW" + ), + "roundness": 0.2 + (0.1 * index), + } + else: + edge_smooth = False # Straight line + edge_style = { "color": self.colors["router_edge"], "width": 2, "arrows": "to", "dashes": True, - "smooth": { - "type": "curvedCW", - "roundness": 0.3, - }, # Curvature for router edges + "smooth": edge_smooth, } net.add_edge(router_method_name, listener_name, **edge_style) - # Set options for curved edges and disable physics + # Set options to disable physics net.set_options( """ var options = { "physics": { "enabled": false - }, - "edges": { - "smooth": { - "enabled": true, - "type": "cubicBezier", - "roundness": 0.5 - } } } """ @@ -311,6 +361,75 @@ class PyvisFlowVisualizer(FlowVisualizer): queue.append(listener_name) return levels + def _count_outgoing_edges(self): + # Helper method to count the number of outgoing edges from each node + counts = {} + for method_name in self.flow._methods: + counts[method_name] = 0 + for method_name in self.flow._listeners: + _, trigger_methods = self.flow._listeners[method_name] + for trigger in trigger_methods: + if trigger in self.flow._methods: + counts[trigger] += 1 + return counts + + def _build_ancestor_dict(self): + # Helper method to build a dictionary of ancestors for each node + ancestors = {node: set() for node in self.flow._methods} + for node in self.flow._methods: + self._dfs_ancestors(node, ancestors, set()) + return ancestors + + def _dfs_ancestors(self, node, ancestors, visited): + if node in visited: + return + visited.add(node) + for listener_name, (_, trigger_methods) in self.flow._listeners.items(): + if node in trigger_methods: + ancestors[listener_name].add(node) + ancestors[listener_name].update(ancestors[node]) + self._dfs_ancestors(listener_name, ancestors, visited) + # Include router paths + for router_method_name, paths in self.flow._router_paths.items(): + if node == router_method_name: + for path in paths: + for listener_name, ( + _, + trigger_methods, + ) in self.flow._listeners.items(): + if path in trigger_methods: + ancestors[listener_name].add(node) + ancestors[listener_name].update(ancestors[node]) + self._dfs_ancestors(listener_name, ancestors, visited) + + def _is_ancestor(self, node, ancestor_candidate, ancestors): + return ancestor_candidate in ancestors.get(node, set()) + + def _build_parent_children_dict(self): + # Helper method to build a parent to children mapping + parent_children = {} + for listener_name, (_, trigger_methods) in self.flow._listeners.items(): + for trigger in trigger_methods: + if trigger not in parent_children: + parent_children[trigger] = [] + parent_children[trigger].append(listener_name) + # Include router paths and their connected listener methods + for router_method_name, paths in self.flow._router_paths.items(): + for path in paths: + # Find listener methods triggered by each path + for listener_name, (_, trigger_methods) in self.flow._listeners.items(): + if path in trigger_methods: + parent_children.setdefault(router_method_name, []).append( + listener_name + ) + return parent_children + + def _get_child_index(self, parent, child, parent_children): + # Helper method to get the index of the child among the parent's children + children = parent_children.get(parent, []) + children.sort() + return children.index(child) + def visualize_flow(flow, filename="flow_graph"): visualizer = PyvisFlowVisualizer(flow)