diff --git a/src/crewai/flow/flow_visualizer.py b/src/crewai/flow/flow_visualizer.py index 839694a0b..3a2e60005 100644 --- a/src/crewai/flow/flow_visualizer.py +++ b/src/crewai/flow/flow_visualizer.py @@ -3,7 +3,6 @@ import base64 import os import re -from abc import ABC, abstractmethod from pyvis.network import Network @@ -13,7 +12,7 @@ GRAY = "#666666" WHITE = "#FFFFFF" -class FlowVisualizer(ABC): +class FlowVisualizer: def __init__(self, flow): self.flow = flow self.colors = { @@ -31,13 +30,13 @@ class FlowVisualizer(ABC): "color": self.colors["start"], "shape": "box", "font": {"color": self.colors["text"]}, - "margin": 15, + "margin": {"top": 10, "bottom": 8, "left": 10, "right": 10}, }, "method": { "color": self.colors["method"], "shape": "box", "font": {"color": self.colors["text"]}, - "margin": 15, + "margin": {"top": 10, "bottom": 8, "left": 10, "right": 10}, }, "router": { "color": { @@ -53,16 +52,10 @@ class FlowVisualizer(ABC): "borderWidth": 3, "borderWidthSelected": 4, "shapeProperties": {"borderDashes": [5, 5]}, - "margin": 15, + "margin": {"top": 10, "bottom": 8, "left": 10, "right": 10}, }, } - @abstractmethod - def visualize(self, filename): - pass - - -class PyvisFlowVisualizer(FlowVisualizer): def visualize(self, filename): net = Network( directed=True, @@ -74,7 +67,6 @@ class PyvisFlowVisualizer(FlowVisualizer): # Calculate levels for nodes node_levels = self._calculate_node_levels() - print("node_levels", node_levels) # Assign positions to nodes based on levels y_spacing = 150 @@ -342,7 +334,7 @@ class PyvisFlowVisualizer(FlowVisualizer): visited = set() pending_and_listeners = {} - # Initialize start methods at level 0 + # Make all start methods at level 0 for method_name, method in self.flow._methods.items(): if hasattr(method, "__is_start_method__"): levels[method_name] = 0 @@ -354,7 +346,6 @@ class PyvisFlowVisualizer(FlowVisualizer): current_level = levels[current] visited.add(current) - # Get methods that listen to the current method for listener_name, ( condition_type, trigger_methods, @@ -382,7 +373,7 @@ class PyvisFlowVisualizer(FlowVisualizer): if listener_name not in visited: queue.append(listener_name) - # Handle router connections (same as before) + # Handle router connections if current in self.flow._routers.values(): router_method_name = current paths = self.flow._router_paths.get(router_method_name, []) @@ -402,7 +393,6 @@ class PyvisFlowVisualizer(FlowVisualizer): return levels def _count_outgoing_edges(self): - # Helper method to count the number of outgoing edges from each node counts = {} for method_name in self.flow._methods: counts[method_name] = 0 @@ -419,9 +409,7 @@ class PyvisFlowVisualizer(FlowVisualizer): for node in self.flow._methods: if node not in visited: self._dfs_ancestors(node, ancestors, visited) - print("Ancestor Relationships:") - for node, node_ancestors in ancestors.items(): - print(f"{node}: {node_ancestors}") + return ancestors def _dfs_ancestors(self, node, ancestors, visited): @@ -452,6 +440,7 @@ class PyvisFlowVisualizer(FlowVisualizer): def _build_parent_children_dict(self): parent_children = {} + # Map listeners to their trigger methods for listener_name, (_, trigger_methods) in self.flow._listeners.items(): for trigger in trigger_methods: @@ -459,6 +448,7 @@ class PyvisFlowVisualizer(FlowVisualizer): parent_children[trigger] = [] if listener_name not in parent_children[trigger]: parent_children[trigger].append(listener_name) + # Map router methods to their paths and to listeners for router_method_name, paths in self.flow._router_paths.items(): for path in paths: @@ -469,19 +459,15 @@ class PyvisFlowVisualizer(FlowVisualizer): parent_children[router_method_name] = [] if listener_name not in parent_children[router_method_name]: parent_children[router_method_name].append(listener_name) - # Debugging output - print("Parent-Children Relationships:") - for parent, children in parent_children.items(): - print(f"{parent}: {children}") + return parent_children def _get_child_index(self, parent, child, parent_children): - # Helper method to get the index of the child among the parent's children children = parent_children.get(parent, []) children.sort() return children.index(child) def visualize_flow(flow, filename="flow_graph"): - visualizer = PyvisFlowVisualizer(flow) + visualizer = FlowVisualizer(flow) visualizer.visualize(filename)