diff --git a/src/crewai/flow/flow_visualizer.py b/src/crewai/flow/flow_visualizer.py index beaaa36ee..d473a34bc 100644 --- a/src/crewai/flow/flow_visualizer.py +++ b/src/crewai/flow/flow_visualizer.py @@ -12,12 +12,10 @@ class FlowVisualizer(ABC): "bg": "#FFFFFF", "start": "#FF5A50", "method": "#333333", - "router_outline": "#FF5A50", - "edge": "#333333", + "router": "#FF8C00", + "edge": "#666666", "text": "#FFFFFF", } - self.node_rectangles = {} - self.node_positions = {} @abstractmethod def visualize(self, filename): @@ -107,35 +105,50 @@ class GraphvizVisualizer(FlowVisualizer): class PyvisFlowVisualizer(FlowVisualizer): def visualize(self, filename): net = Network( - directed=True, height="750px", width="100%", bgcolor=self.colors["bg"] + directed=True, + height="750px", + width="100%", + bgcolor=self.colors["bg"], + layout=None, ) # Define custom node styles node_styles = { "start": { - "color": self.colors["start"], + "color": self.colors.get("start", "#FF5A50"), "shape": "box", - "font": {"color": self.colors["text"]}, + "font": {"color": self.colors.get("text", "#FFFFFF")}, }, "method": { - "color": self.colors["method"], + "color": self.colors.get("method", "#333333"), "shape": "box", - "font": {"color": self.colors["text"]}, + "font": {"color": self.colors.get("text", "#FFFFFF")}, + }, + "router": { + "color": self.colors.get("router", "#FF8C00"), + "shape": "box", + "font": {"color": self.colors.get("text", "#FFFFFF")}, }, - # "router": { - # "color": self.colors["router"], - # "shape": "box", - # "font": {"color": self.colors["text"]}, - # }, } - # Add nodes - for method_name, method in self.flow._methods.items(): - if ( - hasattr(method, "__is_start_method__") - or method_name in self.flow._listeners - or method_name in self.flow._routers.values() - ): + # Calculate levels for nodes + node_levels = self._calculate_node_levels() + + # Assign positions to nodes based on levels + y_spacing = 150 # Adjust spacing between levels (positive for top-down) + x_spacing = 150 # Adjust spacing between nodes + level_nodes = {} + + for method_name, level in node_levels.items(): + level_nodes.setdefault(level, []).append(method_name) + + # Compute positions + for level, nodes in level_nodes.items(): + x_offset = -(len(nodes) - 1) * x_spacing / 2 # Center nodes horizontally + for i, method_name in enumerate(nodes): + x = x_offset + i * x_spacing + y = level * y_spacing # Use level directly for y position + method = self.flow._methods.get(method_name) if hasattr(method, "__is_start_method__"): node_style = node_styles["start"] elif method_name in self.flow._routers.values(): @@ -143,9 +156,17 @@ class PyvisFlowVisualizer(FlowVisualizer): else: node_style = node_styles["method"] - net.add_node(method_name, label=method_name, **node_style) + net.add_node( + method_name, + label=method_name, + x=x, + y=y, + fixed=True, + physics=False, # Disable physics for fixed positioning + **node_style, + ) - # Add edges + # Add edges with curved lines for method_name in self.flow._listeners: condition_type, trigger_methods = self.flow._listeners[method_name] is_and_condition = condition_type == "AND" @@ -154,17 +175,68 @@ class PyvisFlowVisualizer(FlowVisualizer): net.add_edge( trigger, method_name, - color=self.colors["edge"], + color=self.colors.get("edge", "#666666"), width=2, arrows="to", - dashes=is_and_condition, # Dashed lines for AND conditions + dashes=is_and_condition, smooth={"type": "cubicBezier"}, ) + # Set options for curved edges and disable physics + net.set_options( + """ + var options = { + "physics": { + "enabled": false + }, + "edges": { + "smooth": { + "enabled": true, + "type": "cubicBezier", + "roundness": 0.5 + } + } + } + """ + ) + # Generate and save the graph net.write_html(f"{filename}.html") print(f"Graph saved as {filename}.html") + def _calculate_node_levels(self): + levels = {} + queue = [] + visited = set() + + # Initialize start methods at level 0 + for method_name, method in self.flow._methods.items(): + if hasattr(method, "__is_start_method__"): + levels[method_name] = 0 + queue.append(method_name) + + # Breadth-first traversal to assign levels + while queue: + current = queue.pop(0) + current_level = levels[current] + visited.add(current) + + # Get methods that listen to the current method + for listener_name, ( + condition_type, + trigger_methods, + ) in self.flow._listeners.items(): + if current in trigger_methods: + if ( + listener_name not in levels + or levels[listener_name] > current_level + 1 + ): + levels[listener_name] = current_level + 1 + if listener_name not in visited: + queue.append(listener_name) + + return levels + def is_graphviz_available(): try: