mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-29 01:58:14 +00:00
adjust padding
This commit is contained in:
@@ -3,7 +3,6 @@
|
|||||||
import base64
|
import base64
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
|
|
||||||
from pyvis.network import Network
|
from pyvis.network import Network
|
||||||
|
|
||||||
@@ -13,7 +12,7 @@ GRAY = "#666666"
|
|||||||
WHITE = "#FFFFFF"
|
WHITE = "#FFFFFF"
|
||||||
|
|
||||||
|
|
||||||
class FlowVisualizer(ABC):
|
class FlowVisualizer:
|
||||||
def __init__(self, flow):
|
def __init__(self, flow):
|
||||||
self.flow = flow
|
self.flow = flow
|
||||||
self.colors = {
|
self.colors = {
|
||||||
@@ -31,13 +30,13 @@ class FlowVisualizer(ABC):
|
|||||||
"color": self.colors["start"],
|
"color": self.colors["start"],
|
||||||
"shape": "box",
|
"shape": "box",
|
||||||
"font": {"color": self.colors["text"]},
|
"font": {"color": self.colors["text"]},
|
||||||
"margin": 15,
|
"margin": {"top": 10, "bottom": 8, "left": 10, "right": 10},
|
||||||
},
|
},
|
||||||
"method": {
|
"method": {
|
||||||
"color": self.colors["method"],
|
"color": self.colors["method"],
|
||||||
"shape": "box",
|
"shape": "box",
|
||||||
"font": {"color": self.colors["text"]},
|
"font": {"color": self.colors["text"]},
|
||||||
"margin": 15,
|
"margin": {"top": 10, "bottom": 8, "left": 10, "right": 10},
|
||||||
},
|
},
|
||||||
"router": {
|
"router": {
|
||||||
"color": {
|
"color": {
|
||||||
@@ -53,16 +52,10 @@ class FlowVisualizer(ABC):
|
|||||||
"borderWidth": 3,
|
"borderWidth": 3,
|
||||||
"borderWidthSelected": 4,
|
"borderWidthSelected": 4,
|
||||||
"shapeProperties": {"borderDashes": [5, 5]},
|
"shapeProperties": {"borderDashes": [5, 5]},
|
||||||
"margin": 15,
|
"margin": {"top": 10, "bottom": 8, "left": 10, "right": 10},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def visualize(self, filename):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class PyvisFlowVisualizer(FlowVisualizer):
|
|
||||||
def visualize(self, filename):
|
def visualize(self, filename):
|
||||||
net = Network(
|
net = Network(
|
||||||
directed=True,
|
directed=True,
|
||||||
@@ -74,7 +67,6 @@ class PyvisFlowVisualizer(FlowVisualizer):
|
|||||||
|
|
||||||
# Calculate levels for nodes
|
# Calculate levels for nodes
|
||||||
node_levels = self._calculate_node_levels()
|
node_levels = self._calculate_node_levels()
|
||||||
print("node_levels", node_levels)
|
|
||||||
|
|
||||||
# Assign positions to nodes based on levels
|
# Assign positions to nodes based on levels
|
||||||
y_spacing = 150
|
y_spacing = 150
|
||||||
@@ -342,7 +334,7 @@ class PyvisFlowVisualizer(FlowVisualizer):
|
|||||||
visited = set()
|
visited = set()
|
||||||
pending_and_listeners = {}
|
pending_and_listeners = {}
|
||||||
|
|
||||||
# Initialize start methods at level 0
|
# Make all start methods at level 0
|
||||||
for method_name, method in self.flow._methods.items():
|
for method_name, method in self.flow._methods.items():
|
||||||
if hasattr(method, "__is_start_method__"):
|
if hasattr(method, "__is_start_method__"):
|
||||||
levels[method_name] = 0
|
levels[method_name] = 0
|
||||||
@@ -354,7 +346,6 @@ class PyvisFlowVisualizer(FlowVisualizer):
|
|||||||
current_level = levels[current]
|
current_level = levels[current]
|
||||||
visited.add(current)
|
visited.add(current)
|
||||||
|
|
||||||
# Get methods that listen to the current method
|
|
||||||
for listener_name, (
|
for listener_name, (
|
||||||
condition_type,
|
condition_type,
|
||||||
trigger_methods,
|
trigger_methods,
|
||||||
@@ -382,7 +373,7 @@ class PyvisFlowVisualizer(FlowVisualizer):
|
|||||||
if listener_name not in visited:
|
if listener_name not in visited:
|
||||||
queue.append(listener_name)
|
queue.append(listener_name)
|
||||||
|
|
||||||
# Handle router connections (same as before)
|
# Handle router connections
|
||||||
if current in self.flow._routers.values():
|
if current in self.flow._routers.values():
|
||||||
router_method_name = current
|
router_method_name = current
|
||||||
paths = self.flow._router_paths.get(router_method_name, [])
|
paths = self.flow._router_paths.get(router_method_name, [])
|
||||||
@@ -402,7 +393,6 @@ class PyvisFlowVisualizer(FlowVisualizer):
|
|||||||
return levels
|
return levels
|
||||||
|
|
||||||
def _count_outgoing_edges(self):
|
def _count_outgoing_edges(self):
|
||||||
# Helper method to count the number of outgoing edges from each node
|
|
||||||
counts = {}
|
counts = {}
|
||||||
for method_name in self.flow._methods:
|
for method_name in self.flow._methods:
|
||||||
counts[method_name] = 0
|
counts[method_name] = 0
|
||||||
@@ -419,9 +409,7 @@ class PyvisFlowVisualizer(FlowVisualizer):
|
|||||||
for node in self.flow._methods:
|
for node in self.flow._methods:
|
||||||
if node not in visited:
|
if node not in visited:
|
||||||
self._dfs_ancestors(node, ancestors, visited)
|
self._dfs_ancestors(node, ancestors, visited)
|
||||||
print("Ancestor Relationships:")
|
|
||||||
for node, node_ancestors in ancestors.items():
|
|
||||||
print(f"{node}: {node_ancestors}")
|
|
||||||
return ancestors
|
return ancestors
|
||||||
|
|
||||||
def _dfs_ancestors(self, node, ancestors, visited):
|
def _dfs_ancestors(self, node, ancestors, visited):
|
||||||
@@ -452,6 +440,7 @@ class PyvisFlowVisualizer(FlowVisualizer):
|
|||||||
|
|
||||||
def _build_parent_children_dict(self):
|
def _build_parent_children_dict(self):
|
||||||
parent_children = {}
|
parent_children = {}
|
||||||
|
|
||||||
# Map listeners to their trigger methods
|
# Map listeners to their trigger methods
|
||||||
for listener_name, (_, trigger_methods) in self.flow._listeners.items():
|
for listener_name, (_, trigger_methods) in self.flow._listeners.items():
|
||||||
for trigger in trigger_methods:
|
for trigger in trigger_methods:
|
||||||
@@ -459,6 +448,7 @@ class PyvisFlowVisualizer(FlowVisualizer):
|
|||||||
parent_children[trigger] = []
|
parent_children[trigger] = []
|
||||||
if listener_name not in parent_children[trigger]:
|
if listener_name not in parent_children[trigger]:
|
||||||
parent_children[trigger].append(listener_name)
|
parent_children[trigger].append(listener_name)
|
||||||
|
|
||||||
# Map router methods to their paths and to listeners
|
# Map router methods to their paths and to listeners
|
||||||
for router_method_name, paths in self.flow._router_paths.items():
|
for router_method_name, paths in self.flow._router_paths.items():
|
||||||
for path in paths:
|
for path in paths:
|
||||||
@@ -469,19 +459,15 @@ class PyvisFlowVisualizer(FlowVisualizer):
|
|||||||
parent_children[router_method_name] = []
|
parent_children[router_method_name] = []
|
||||||
if listener_name not in parent_children[router_method_name]:
|
if listener_name not in parent_children[router_method_name]:
|
||||||
parent_children[router_method_name].append(listener_name)
|
parent_children[router_method_name].append(listener_name)
|
||||||
# Debugging output
|
|
||||||
print("Parent-Children Relationships:")
|
|
||||||
for parent, children in parent_children.items():
|
|
||||||
print(f"{parent}: {children}")
|
|
||||||
return parent_children
|
return parent_children
|
||||||
|
|
||||||
def _get_child_index(self, parent, child, parent_children):
|
def _get_child_index(self, parent, child, parent_children):
|
||||||
# Helper method to get the index of the child among the parent's children
|
|
||||||
children = parent_children.get(parent, [])
|
children = parent_children.get(parent, [])
|
||||||
children.sort()
|
children.sort()
|
||||||
return children.index(child)
|
return children.index(child)
|
||||||
|
|
||||||
|
|
||||||
def visualize_flow(flow, filename="flow_graph"):
|
def visualize_flow(flow, filename="flow_graph"):
|
||||||
visualizer = PyvisFlowVisualizer(flow)
|
visualizer = FlowVisualizer(flow)
|
||||||
visualizer.visualize(filename)
|
visualizer.visualize(filename)
|
||||||
|
|||||||
Reference in New Issue
Block a user