regular methods and triggers working. Need to work on router next.

This commit is contained in:
Brandon Hancock
2024-09-27 16:01:07 -04:00
parent 16fabdd4b5
commit 5d645cd89f
4 changed files with 222 additions and 140 deletions

View File

@@ -0,0 +1,93 @@
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8" />
<title>{{ title }}</title>
<script
src="https://cdnjs.cloudflare.com/ajax/libs/vis-network/9.1.2/dist/vis-network.min.js"
integrity="sha512-LnvoEWDFrqGHlHmDD2101OrLcbsfkrzoSpvtSQtxK3RMnRV0eOkhhBN2dXHKRrUU8p2DGRTk35n4O8nWSVe1mQ=="
crossorigin="anonymous"
referrerpolicy="no-referrer"
></script>
<link
rel="stylesheet"
href="https://cdnjs.cloudflare.com/ajax/libs/vis-network/9.1.2/dist/dist/vis-network.min.css"
integrity="sha512-WgxfT5LWjfszlPHXRmBWHkV2eceiWTOBvrKCNbdgDYTHrT2AeLCGbF4sZlZw3UMN3WtL0tGUoIAKsu8mllg/XA=="
crossorigin="anonymous"
referrerpolicy="no-referrer"
/>
<style type="text/css">
body {
font-family: verdana;
margin: 0;
padding: 0;
}
.container {
display: flex;
flex-direction: column;
height: 100vh;
}
#mynetwork {
flex-grow: 1;
width: 100%;
height: 750px;
background-color: #ffffff;
}
.card {
border: none;
}
.legend-container {
display: flex;
align-items: center;
justify-content: center;
padding: 10px;
background-color: #f8f9fa;
position: fixed; /* Make the legend fixed */
bottom: 0; /* Position it at the bottom */
width: 100%; /* Make it span the full width */
}
.legend-item {
display: flex;
align-items: center;
margin-right: 20px;
}
.legend-color-box {
width: 20px;
height: 20px;
margin-right: 5px;
}
.logo {
height: 50px;
margin-right: 20px;
}
.legend-dashed {
border-bottom: 2px dashed #666666;
width: 20px;
height: 0;
margin-right: 5px;
}
.legend-solid {
border-bottom: 2px solid #666666;
width: 20px;
height: 0;
margin-right: 5px;
}
</style>
</head>
<body>
<div class="container">
<div class="card" style="width: 100%">
<div id="mynetwork" class="card-body"></div>
</div>
<div class="legend-container">
<img
src="data:image/svg+xml;base64,{{ logo_svg_base64 }}"
alt="CrewAI logo"
class="logo"
/>
<!-- LEGEND_ITEMS_PLACEHOLDER -->
</div>
</div>
{{ network_content }}
</body>
</html>

File diff suppressed because one or more lines are too long

After

Width:  |  Height:  |  Size: 27 KiB

View File

@@ -1,5 +1,6 @@
import shutil import base64
import warnings import os
import re
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from pyvis.network import Network from pyvis.network import Network
@@ -16,92 +17,29 @@ class FlowVisualizer(ABC):
"edge": "#666666", "edge": "#666666",
"text": "#FFFFFF", "text": "#FFFFFF",
} }
self.node_styles = {
"start": {
"color": self.colors["start"],
"shape": "box",
"font": {"color": self.colors["text"]},
},
"method": {
"color": self.colors["method"],
"shape": "box",
"font": {"color": self.colors["text"]},
},
"router": {
"color": self.colors["router"],
"shape": "box",
"font": {"color": self.colors["text"]},
},
}
@abstractmethod @abstractmethod
def visualize(self, filename): def visualize(self, filename):
pass pass
class GraphvizVisualizer(FlowVisualizer):
def visualize(self, filename):
import graphviz
dot = graphviz.Digraph(comment="Flow Graph", engine="dot")
dot.attr(rankdir="TB", size="20,20", splines="curved")
dot.attr(bgcolor=self.colors["bg"])
# Add nodes
for method_name, method in self.flow._methods.items():
if (
hasattr(method, "__is_start_method__")
or method_name in self.flow._listeners
or method_name in self.flow._routers.values()
):
shape = "rectangle"
style = "filled,rounded"
fillcolor = (
self.colors["start"]
if hasattr(method, "__is_start_method__")
else self.colors["method"]
)
dot.node(
method_name,
method_name,
shape=shape,
style=style,
fillcolor=fillcolor,
fontcolor=self.colors["text"],
penwidth="2",
)
# Add edges and routers
for method_name, method in self.flow._methods.items():
if method_name in self.flow._listeners:
condition_type, trigger_methods = self.flow._listeners[method_name]
for trigger in trigger_methods:
style = "dashed" if condition_type == "AND" else "solid"
dot.edge(
trigger,
method_name,
color=self.colors["edge"],
style=style,
penwidth="2",
)
if method_name in self.flow._routers.values():
for trigger, router in self.flow._routers.items():
if router == method_name:
subgraph_name = f"cluster_{method_name}"
subgraph = graphviz.Digraph(name=subgraph_name)
subgraph.attr(
label="",
style="filled,rounded",
color=self.colors["router_outline"],
fillcolor=self.colors["method"],
penwidth="3",
)
label = f"{method_name}\\n\\nPossible outcomes:\\n• Success\\n• Failure"
subgraph.node(
method_name,
label,
shape="plaintext",
fontcolor=self.colors["text"],
)
dot.subgraph(subgraph)
dot.edge(
trigger,
method_name,
color=self.colors["edge"],
style="solid",
penwidth="2",
lhead=subgraph_name,
)
dot.render(filename, format="png", cleanup=True, view=True)
print(f"Graph saved as {filename}.png")
class PyvisFlowVisualizer(FlowVisualizer): class PyvisFlowVisualizer(FlowVisualizer):
def visualize(self, filename): def visualize(self, filename):
net = Network( net = Network(
@@ -112,25 +50,6 @@ class PyvisFlowVisualizer(FlowVisualizer):
layout=None, layout=None,
) )
# Define custom node styles
node_styles = {
"start": {
"color": self.colors.get("start", "#FF5A50"),
"shape": "box",
"font": {"color": self.colors.get("text", "#FFFFFF")},
},
"method": {
"color": self.colors.get("method", "#333333"),
"shape": "box",
"font": {"color": self.colors.get("text", "#FFFFFF")},
},
"router": {
"color": self.colors.get("router", "#FF8C00"),
"shape": "box",
"font": {"color": self.colors.get("text", "#FFFFFF")},
},
}
# Calculate levels for nodes # Calculate levels for nodes
node_levels = self._calculate_node_levels() node_levels = self._calculate_node_levels()
@@ -150,11 +69,11 @@ class PyvisFlowVisualizer(FlowVisualizer):
y = level * y_spacing # Use level directly for y position y = level * y_spacing # Use level directly for y position
method = self.flow._methods.get(method_name) method = self.flow._methods.get(method_name)
if hasattr(method, "__is_start_method__"): if hasattr(method, "__is_start_method__"):
node_style = node_styles["start"] node_style = self.node_styles["start"]
elif method_name in self.flow._routers.values(): elif method_name in self.flow._routers.values():
node_style = node_styles["router"] node_style = self.node_styles["router"]
else: else:
node_style = node_styles["method"] node_style = self.node_styles["method"]
net.add_node( net.add_node(
method_name, method_name,
@@ -185,23 +104,101 @@ class PyvisFlowVisualizer(FlowVisualizer):
# Set options for curved edges and disable physics # Set options for curved edges and disable physics
net.set_options( net.set_options(
""" """
var options = { var options = {
"physics": { "physics": {
"enabled": false "enabled": false
}, },
"edges": { "edges": {
"smooth": { "smooth": {
"enabled": true, "enabled": true,
"type": "cubicBezier", "type": "cubicBezier",
"roundness": 0.5 "roundness": 0.5
} }
} }
} }
""" """
) )
# Generate and save the graph network_html = net.generate_html()
net.write_html(f"{filename}.html")
# Extract just the body content from the generated HTML
match = re.search("<body.*?>(.*?)</body>", network_html, re.DOTALL)
if match:
network_body = match.group(1)
else:
network_body = ""
# Read the custom template
current_dir = os.path.dirname(__file__)
template_path = os.path.join(
current_dir, "assets", "crewai_flow_visual_template.html"
)
with open(template_path, "r", encoding="utf-8") as f:
html_template = f.read()
# Generate the legend items HTML
legend_items = [
{"label": "Start Method", "color": self.colors.get("start", "#FF5A50")},
{"label": "Method", "color": self.colors.get("method", "#333333")},
# {"label": "Router", "color": self.colors.get("router", "#FF8C00")},
{
"label": "Trigger",
"color": self.colors.get("edge", "#666666"),
"dashed": False,
},
{
"label": "AND Trigger",
"color": self.colors.get("edge", "#666666"),
"dashed": True,
},
]
legend_items_html = ""
for item in legend_items:
if item.get("dashed") is not None:
if item.get("dashed"):
legend_items_html += f"""
<div class="legend-item">
<div class="legend-dashed"></div>
<div>{item['label']}</div>
</div>
"""
else:
legend_items_html += f"""
<div class="legend-item">
<div class="legend-solid" style="border-bottom: 2px solid {item['color']};"></div>
<div>{item['label']}</div>
</div>
"""
else:
legend_items_html += f"""
<div class="legend-item">
<div class="legend-color-box" style="background-color: {item['color']};"></div>
<div>{item['label']}</div>
</div>
"""
# Read the logo file and encode it
logo_path = os.path.join(current_dir, "assets", "crewai_logo.svg")
with open(logo_path, "rb") as logo_file:
logo_svg_data = logo_file.read()
logo_svg_base64 = base64.b64encode(logo_svg_data).decode("utf-8")
# Replace placeholders in the template
final_html_content = html_template.replace("{{ title }}", "Flow Graph")
final_html_content = final_html_content.replace(
"{{ network_content }}", network_body
)
final_html_content = final_html_content.replace(
"{{ logo_svg_base64 }}", logo_svg_base64
)
final_html_content = final_html_content.replace(
"<!-- LEGEND_ITEMS_PLACEHOLDER -->", legend_items_html
)
# Save the final HTML content to the file
with open(f"{filename}.html", "w", encoding="utf-8") as f:
f.write(final_html_content)
print(f"Graph saved as {filename}.html") print(f"Graph saved as {filename}.html")
def _calculate_node_levels(self): def _calculate_node_levels(self):
@@ -238,26 +235,6 @@ class PyvisFlowVisualizer(FlowVisualizer):
return levels return levels
def is_graphviz_available():
try:
import graphviz
if shutil.which("dot") is None: # Check for Graphviz executable
raise ImportError("Graphviz executable not found")
return True
except ImportError:
return False
def visualize_flow(flow, filename="flow_graph"): def visualize_flow(flow, filename="flow_graph"):
if False: visualizer = PyvisFlowVisualizer(flow)
visualizer = GraphvizVisualizer(flow)
else:
warnings.warn(
"Graphviz is not available. Falling back to NetworkX and Matplotlib for visualization. "
"For better visualization, please install Graphviz. "
"See our documentation for installation instructions: https://docs.crewai.com/advanced-usage/visualization/"
)
visualizer = PyvisFlowVisualizer(flow)
visualizer.visualize(filename) visualizer.visualize(filename)