diff --git a/poetry.lock b/poetry.lock index c1bdc9f8e..4bb2c71e7 100644 --- a/poetry.lock +++ b/poetry.lock @@ -7663,4 +7663,4 @@ tools = ["crewai-tools"] [metadata] lock-version = "2.0" python-versions = ">=3.10,<=3.13" -content-hash = "8edc2b56582cce28793790bf6526cf35ccf54b982a5cfd97330f0f3d6ac2a5b9" +content-hash = "13875b4236719007d8c126a03deefc6c59ce6717e39547d3d099053a89359eb0" diff --git a/pyproject.toml b/pyproject.toml index 49efc7f83..36d218e23 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,7 @@ networkx = "^3.3" ipython = "^8.27.0" pyvis = "^0.3.2" playwright = "^1.47.0" +pillow = "^10.4.0" [tool.poetry.extras] tools = ["crewai-tools"] diff --git a/src/crewai/flow/flow_visualizer.py b/src/crewai/flow/flow_visualizer.py index 737501d6a..beaaa36ee 100644 --- a/src/crewai/flow/flow_visualizer.py +++ b/src/crewai/flow/flow_visualizer.py @@ -1,9 +1,8 @@ -import math import shutil import warnings from abc import ABC, abstractmethod -from PIL import Image, ImageDraw, ImageFont +from pyvis.network import Network class FlowVisualizer(ABC): @@ -17,6 +16,8 @@ class FlowVisualizer(ABC): "edge": "#333333", "text": "#FFFFFF", } + self.node_rectangles = {} + self.node_positions = {} @abstractmethod def visualize(self, filename): @@ -103,143 +104,66 @@ class GraphvizVisualizer(FlowVisualizer): print(f"Graph saved as {filename}.png") -class PyvisFlowVisualizer: - def __init__(self, flow): - self.flow = flow - self.colors = { - "bg": "#FFFFFF", - "start": "#FF5A50", - "method": "#333333", - "router": "#FF8C00", # Orange color for routers - "edge": "#666666", - "text": "#FFFFFF", +class PyvisFlowVisualizer(FlowVisualizer): + def visualize(self, filename): + net = Network( + directed=True, height="750px", width="100%", bgcolor=self.colors["bg"] + ) + + # Define custom node styles + node_styles = { + "start": { + "color": self.colors["start"], + "shape": "box", + "font": {"color": self.colors["text"]}, + }, + "method": { + "color": self.colors["method"], + "shape": "box", + "font": {"color": self.colors["text"]}, + }, + # "router": { + # "color": self.colors["router"], + # "shape": "box", + # "font": {"color": self.colors["text"]}, + # }, } - def visualize(self, filename): - # Get decorated methods - start_methods = [ - name - for name, method in self.flow._methods.items() - if hasattr(method, "__is_start_method__") - ] - listen_methods = list(self.flow._listeners.keys()) - router_methods = list(self.flow._routers.values()) + # Add nodes + for method_name, method in self.flow._methods.items(): + if ( + hasattr(method, "__is_start_method__") + or method_name in self.flow._listeners + or method_name in self.flow._routers.values() + ): + if hasattr(method, "__is_start_method__"): + node_style = node_styles["start"] + elif method_name in self.flow._routers.values(): + node_style = node_styles["router"] + else: + node_style = node_styles["method"] - all_methods = start_methods + listen_methods + router_methods - node_positions = self._calculate_positions(all_methods) + net.add_node(method_name, label=method_name, **node_style) - # Create image - img_width = 800 - img_height = len(all_methods) * 120 + 100 - img = Image.new("RGB", (img_width, img_height), color=self.colors["bg"]) - draw = ImageDraw.Draw(img) + # Add edges + for method_name in self.flow._listeners: + condition_type, trigger_methods = self.flow._listeners[method_name] + is_and_condition = condition_type == "AND" + for trigger in trigger_methods: + if trigger in self.flow._methods: + net.add_edge( + trigger, + method_name, + color=self.colors["edge"], + width=2, + arrows="to", + dashes=is_and_condition, # Dashed lines for AND conditions + smooth={"type": "cubicBezier"}, + ) - # Draw edges - for method_name in listen_methods + router_methods: - if method_name in self.flow._listeners: - _, trigger_methods = self.flow._listeners[method_name] - for trigger in trigger_methods: - if trigger in node_positions and method_name in node_positions: - start = node_positions[trigger] - end = node_positions[method_name] - self._draw_curved_arrow(draw, start, end, self.colors["edge"]) - - # Draw nodes - for method_name, pos in node_positions.items(): - if method_name in start_methods: - color = self.colors["start"] - elif method_name in router_methods: - color = self.colors["router"] - else: - color = self.colors["method"] - - self._draw_node(draw, method_name, pos, color) - - # Save image - img.save(f"{filename}.png") - print(f"Graph saved as {filename}.png") - - def _calculate_positions(self, nodes): - positions = {} - start_methods = [ - node - for node in nodes - if hasattr(self.flow._methods[node], "__is_start_method__") - ] - other_methods = [node for node in nodes if node not in start_methods] - - # Position start methods at the top - for i, node in enumerate(start_methods): - positions[node] = (400, 100 + i * 120) - - # Position other methods below start methods - for i, node in enumerate(other_methods): - positions[node] = (400, 100 + (len(start_methods) + i) * 120) - - return positions - - def _draw_node(self, draw, label, position, color): - x, y = position - if color == self.colors["router"]: - # Draw router node as rounded rectangle - draw.rounded_rectangle( - [x - 70, y - 40, x + 70, y + 40], - radius=10, - fill=color, - outline=self.colors["edge"], - ) - font = ImageFont.load_default() - text_width = draw.textlength(label, font=font) - draw.text( - (x - text_width / 2, y - 20), label, fill=self.colors["text"], font=font - ) - draw.text((x - 30, y + 5), "Success", fill=self.colors["text"], font=font) - draw.text((x - 30, y + 25), "Failure", fill=self.colors["text"], font=font) - else: - # Draw regular node - draw.rectangle( - [x - 60, y - 30, x + 60, y + 30], - fill=color, - outline=self.colors["edge"], - ) - font = ImageFont.load_default() - text_width = draw.textlength(label, font=font) - draw.text( - (x - text_width / 2, y - 7), label, fill=self.colors["text"], font=font - ) - - def _draw_curved_arrow(self, draw, start, end, color): - # Calculate control point for the curve - control_x = (start[0] + end[0]) / 2 - control_y = ( - start[1] + end[1] - ) / 2 - 50 # Adjust this value to change curve height - - # Draw the curved line - points = [start, (control_x, control_y), end] - draw.line(points, fill=color, width=2, joint="curve") - - # Draw arrow head - self._draw_arrow_head(draw, points[-2], end, color) - - def _draw_arrow_head(self, draw, start, end, color): - angle = math.atan2(end[1] - start[1], end[0] - start[0]) - x = end[0] - 15 * math.cos(angle) - y = end[1] - 15 * math.sin(angle) - draw.polygon( - [ - (x, y), - ( - x - 10 * math.cos(angle - math.pi / 6), - y - 10 * math.sin(angle - math.pi / 6), - ), - ( - x - 10 * math.cos(angle + math.pi / 6), - y - 10 * math.sin(angle + math.pi / 6), - ), - ], - fill=color, - ) + # Generate and save the graph + net.write_html(f"{filename}.html") + print(f"Graph saved as {filename}.html") def is_graphviz_available(): diff --git a/src/crewai/flow/fonts/arial_bold.ttf b/src/crewai/flow/fonts/arial_bold.ttf new file mode 100644 index 000000000..940e255d0 Binary files /dev/null and b/src/crewai/flow/fonts/arial_bold.ttf differ