From 5d645cd89f63884f9338de7c27041a8611f2b193 Mon Sep 17 00:00:00 2001 From: Brandon Hancock Date: Fri, 27 Sep 2024 16:01:07 -0400 Subject: [PATCH] regular methods and triggers working. Need to work on router next. --- .../flow/{fonts => assets}/arial_bold.ttf | Bin .../assets/crewai_flow_visual_template.html | 93 +++++++ src/crewai/flow/assets/crewai_logo.svg | 12 + src/crewai/flow/flow_visualizer.py | 257 ++++++++---------- 4 files changed, 222 insertions(+), 140 deletions(-) rename src/crewai/flow/{fonts => assets}/arial_bold.ttf (100%) create mode 100644 src/crewai/flow/assets/crewai_flow_visual_template.html create mode 100644 src/crewai/flow/assets/crewai_logo.svg diff --git a/src/crewai/flow/fonts/arial_bold.ttf b/src/crewai/flow/assets/arial_bold.ttf similarity index 100% rename from src/crewai/flow/fonts/arial_bold.ttf rename to src/crewai/flow/assets/arial_bold.ttf diff --git a/src/crewai/flow/assets/crewai_flow_visual_template.html b/src/crewai/flow/assets/crewai_flow_visual_template.html new file mode 100644 index 000000000..f175ef1a7 --- /dev/null +++ b/src/crewai/flow/assets/crewai_flow_visual_template.html @@ -0,0 +1,93 @@ + + + + + {{ title }} + + + + + +
+
+
+
+
+ + +
+
+ {{ network_content }} + + diff --git a/src/crewai/flow/assets/crewai_logo.svg b/src/crewai/flow/assets/crewai_logo.svg new file mode 100644 index 000000000..1668a48e5 --- /dev/null +++ b/src/crewai/flow/assets/crewai_logo.svg @@ -0,0 +1,12 @@ + + + + + + + + + + + + diff --git a/src/crewai/flow/flow_visualizer.py b/src/crewai/flow/flow_visualizer.py index d473a34bc..ef599c436 100644 --- a/src/crewai/flow/flow_visualizer.py +++ b/src/crewai/flow/flow_visualizer.py @@ -1,5 +1,6 @@ -import shutil -import warnings +import base64 +import os +import re from abc import ABC, abstractmethod from pyvis.network import Network @@ -16,92 +17,29 @@ class FlowVisualizer(ABC): "edge": "#666666", "text": "#FFFFFF", } + self.node_styles = { + "start": { + "color": self.colors["start"], + "shape": "box", + "font": {"color": self.colors["text"]}, + }, + "method": { + "color": self.colors["method"], + "shape": "box", + "font": {"color": self.colors["text"]}, + }, + "router": { + "color": self.colors["router"], + "shape": "box", + "font": {"color": self.colors["text"]}, + }, + } @abstractmethod def visualize(self, filename): pass -class GraphvizVisualizer(FlowVisualizer): - def visualize(self, filename): - import graphviz - - dot = graphviz.Digraph(comment="Flow Graph", engine="dot") - dot.attr(rankdir="TB", size="20,20", splines="curved") - dot.attr(bgcolor=self.colors["bg"]) - - # Add nodes - for method_name, method in self.flow._methods.items(): - if ( - hasattr(method, "__is_start_method__") - or method_name in self.flow._listeners - or method_name in self.flow._routers.values() - ): - shape = "rectangle" - style = "filled,rounded" - fillcolor = ( - self.colors["start"] - if hasattr(method, "__is_start_method__") - else self.colors["method"] - ) - - dot.node( - method_name, - method_name, - shape=shape, - style=style, - fillcolor=fillcolor, - fontcolor=self.colors["text"], - penwidth="2", - ) - - # Add edges and routers - for method_name, method in self.flow._methods.items(): - if method_name in self.flow._listeners: - condition_type, trigger_methods = self.flow._listeners[method_name] - for trigger in trigger_methods: - style = "dashed" if condition_type == "AND" else "solid" - dot.edge( - trigger, - method_name, - color=self.colors["edge"], - style=style, - penwidth="2", - ) - - if method_name in self.flow._routers.values(): - for trigger, router in self.flow._routers.items(): - if router == method_name: - subgraph_name = f"cluster_{method_name}" - subgraph = graphviz.Digraph(name=subgraph_name) - subgraph.attr( - label="", - style="filled,rounded", - color=self.colors["router_outline"], - fillcolor=self.colors["method"], - penwidth="3", - ) - label = f"{method_name}\\n\\nPossible outcomes:\\n• Success\\n• Failure" - subgraph.node( - method_name, - label, - shape="plaintext", - fontcolor=self.colors["text"], - ) - dot.subgraph(subgraph) - dot.edge( - trigger, - method_name, - color=self.colors["edge"], - style="solid", - penwidth="2", - lhead=subgraph_name, - ) - - dot.render(filename, format="png", cleanup=True, view=True) - print(f"Graph saved as {filename}.png") - - class PyvisFlowVisualizer(FlowVisualizer): def visualize(self, filename): net = Network( @@ -112,25 +50,6 @@ class PyvisFlowVisualizer(FlowVisualizer): layout=None, ) - # Define custom node styles - node_styles = { - "start": { - "color": self.colors.get("start", "#FF5A50"), - "shape": "box", - "font": {"color": self.colors.get("text", "#FFFFFF")}, - }, - "method": { - "color": self.colors.get("method", "#333333"), - "shape": "box", - "font": {"color": self.colors.get("text", "#FFFFFF")}, - }, - "router": { - "color": self.colors.get("router", "#FF8C00"), - "shape": "box", - "font": {"color": self.colors.get("text", "#FFFFFF")}, - }, - } - # Calculate levels for nodes node_levels = self._calculate_node_levels() @@ -150,11 +69,11 @@ class PyvisFlowVisualizer(FlowVisualizer): y = level * y_spacing # Use level directly for y position method = self.flow._methods.get(method_name) if hasattr(method, "__is_start_method__"): - node_style = node_styles["start"] + node_style = self.node_styles["start"] elif method_name in self.flow._routers.values(): - node_style = node_styles["router"] + node_style = self.node_styles["router"] else: - node_style = node_styles["method"] + node_style = self.node_styles["method"] net.add_node( method_name, @@ -185,23 +104,101 @@ class PyvisFlowVisualizer(FlowVisualizer): # Set options for curved edges and disable physics net.set_options( """ - var options = { - "physics": { - "enabled": false - }, - "edges": { - "smooth": { - "enabled": true, - "type": "cubicBezier", - "roundness": 0.5 - } - } - } - """ + var options = { + "physics": { + "enabled": false + }, + "edges": { + "smooth": { + "enabled": true, + "type": "cubicBezier", + "roundness": 0.5 + } + } + } + """ ) - # Generate and save the graph - net.write_html(f"{filename}.html") + network_html = net.generate_html() + + # Extract just the body content from the generated HTML + match = re.search("(.*?)", network_html, re.DOTALL) + if match: + network_body = match.group(1) + else: + network_body = "" + + # Read the custom template + current_dir = os.path.dirname(__file__) + template_path = os.path.join( + current_dir, "assets", "crewai_flow_visual_template.html" + ) + with open(template_path, "r", encoding="utf-8") as f: + html_template = f.read() + + # Generate the legend items HTML + legend_items = [ + {"label": "Start Method", "color": self.colors.get("start", "#FF5A50")}, + {"label": "Method", "color": self.colors.get("method", "#333333")}, + # {"label": "Router", "color": self.colors.get("router", "#FF8C00")}, + { + "label": "Trigger", + "color": self.colors.get("edge", "#666666"), + "dashed": False, + }, + { + "label": "AND Trigger", + "color": self.colors.get("edge", "#666666"), + "dashed": True, + }, + ] + + legend_items_html = "" + for item in legend_items: + if item.get("dashed") is not None: + if item.get("dashed"): + legend_items_html += f""" +
+
+
{item['label']}
+
+ """ + else: + legend_items_html += f""" +
+
+
{item['label']}
+
+ """ + else: + legend_items_html += f""" +
+
+
{item['label']}
+
+ """ + + # Read the logo file and encode it + logo_path = os.path.join(current_dir, "assets", "crewai_logo.svg") + with open(logo_path, "rb") as logo_file: + logo_svg_data = logo_file.read() + logo_svg_base64 = base64.b64encode(logo_svg_data).decode("utf-8") + + # Replace placeholders in the template + final_html_content = html_template.replace("{{ title }}", "Flow Graph") + final_html_content = final_html_content.replace( + "{{ network_content }}", network_body + ) + final_html_content = final_html_content.replace( + "{{ logo_svg_base64 }}", logo_svg_base64 + ) + final_html_content = final_html_content.replace( + "", legend_items_html + ) + + # Save the final HTML content to the file + with open(f"{filename}.html", "w", encoding="utf-8") as f: + f.write(final_html_content) print(f"Graph saved as {filename}.html") def _calculate_node_levels(self): @@ -238,26 +235,6 @@ class PyvisFlowVisualizer(FlowVisualizer): return levels -def is_graphviz_available(): - try: - import graphviz - - if shutil.which("dot") is None: # Check for Graphviz executable - raise ImportError("Graphviz executable not found") - return True - except ImportError: - return False - - def visualize_flow(flow, filename="flow_graph"): - if False: - visualizer = GraphvizVisualizer(flow) - else: - warnings.warn( - "Graphviz is not available. Falling back to NetworkX and Matplotlib for visualization. " - "For better visualization, please install Graphviz. " - "See our documentation for installation instructions: https://docs.crewai.com/advanced-usage/visualization/" - ) - visualizer = PyvisFlowVisualizer(flow) - + visualizer = PyvisFlowVisualizer(flow) visualizer.visualize(filename)