mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-03 08:12:39 +00:00
pyvis is beginning to work
This commit is contained in:
@@ -2,6 +2,7 @@ import asyncio
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, Generic, List, Set, Type, TypeVar, Union
|
||||
|
||||
from crewai.flow.flow_visualizer import visualize_flow
|
||||
from pydantic import BaseModel
|
||||
|
||||
T = TypeVar("T", bound=Union[BaseModel, Dict[str, Any]])
|
||||
@@ -250,3 +251,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
def visualize(self, filename: str = "crewai_flow_graph"):
|
||||
visualize_flow(self, filename)
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import math
|
||||
import shutil
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import requests
|
||||
from IPython.display import Image, display
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
|
||||
class FlowVisualizer(ABC):
|
||||
@@ -103,73 +103,144 @@ class GraphvizVisualizer(FlowVisualizer):
|
||||
print(f"Graph saved as {filename}.png")
|
||||
|
||||
|
||||
class MermaidFlowVisualizer(FlowVisualizer):
|
||||
class PyvisFlowVisualizer:
|
||||
def __init__(self, flow):
|
||||
self.flow = flow
|
||||
self.colors = {
|
||||
"bg": "#FFFFFF",
|
||||
"start": "#FF5A50",
|
||||
"method": "#333333",
|
||||
"router": "#FF8C00", # Orange color for routers
|
||||
"edge": "#666666",
|
||||
"text": "#FFFFFF",
|
||||
}
|
||||
|
||||
def visualize(self, filename):
|
||||
mermaid_code = self.generate_mermaid_code()
|
||||
# Get decorated methods
|
||||
start_methods = [
|
||||
name
|
||||
for name, method in self.flow._methods.items()
|
||||
if hasattr(method, "__is_start_method__")
|
||||
]
|
||||
listen_methods = list(self.flow._listeners.keys())
|
||||
router_methods = list(self.flow._routers.values())
|
||||
|
||||
# Use Mermaid.ink API to generate the diagram
|
||||
response = requests.post(
|
||||
"https://mermaid.ink/img/",
|
||||
data=mermaid_code,
|
||||
headers={"Content-Type": "text/plain"},
|
||||
)
|
||||
all_methods = start_methods + listen_methods + router_methods
|
||||
node_positions = self._calculate_positions(all_methods)
|
||||
|
||||
if response.status_code == 200:
|
||||
image_url = response.url
|
||||
print(f"Graph available at {image_url}")
|
||||
# Create image
|
||||
img_width = 800
|
||||
img_height = len(all_methods) * 120 + 100
|
||||
img = Image.new("RGB", (img_width, img_height), color=self.colors["bg"])
|
||||
draw = ImageDraw.Draw(img)
|
||||
|
||||
# Optionally, download the image and save it locally
|
||||
image_data = requests.get(image_url).content
|
||||
with open(f"{filename}.png", "wb") as f:
|
||||
f.write(image_data)
|
||||
print(f"Graph saved as {filename}.png")
|
||||
|
||||
# Display the image in Jupyter notebook
|
||||
display(Image(image_url))
|
||||
else:
|
||||
print(f"Failed to generate graph: {response.status_code} {response.text}")
|
||||
|
||||
def generate_mermaid_code(self):
|
||||
mermaid_code = ["graph TB"]
|
||||
|
||||
# Add nodes
|
||||
for method_name, method in self.flow._methods.items():
|
||||
if (
|
||||
hasattr(method, "__is_start_method__")
|
||||
or method_name in self.flow._listeners
|
||||
or method_name in self.flow._routers.values()
|
||||
):
|
||||
shape = '((" "))' if hasattr(method, "__is_start_method__") else '[" "]'
|
||||
color = (
|
||||
self.colors["start"]
|
||||
if hasattr(method, "__is_start_method__")
|
||||
else self.colors["method"]
|
||||
)
|
||||
mermaid_code.append(
|
||||
f' {method_name}{shape}:::{"startNode" if hasattr(method, "__is_start_method__") else "methodNode"}'
|
||||
)
|
||||
mermaid_code.append(
|
||||
f' style {method_name} fill:{color},color:{self.colors["text"]}'
|
||||
)
|
||||
|
||||
# Add edges
|
||||
for method_name, method in self.flow._methods.items():
|
||||
# Draw edges
|
||||
for method_name in listen_methods + router_methods:
|
||||
if method_name in self.flow._listeners:
|
||||
condition_type, trigger_methods = self.flow._listeners[method_name]
|
||||
_, trigger_methods = self.flow._listeners[method_name]
|
||||
for trigger in trigger_methods:
|
||||
edge_style = " -.- " if condition_type == "AND" else " --> "
|
||||
mermaid_code.append(f" {trigger}{edge_style}{method_name}")
|
||||
if trigger in node_positions and method_name in node_positions:
|
||||
start = node_positions[trigger]
|
||||
end = node_positions[method_name]
|
||||
self._draw_curved_arrow(draw, start, end, self.colors["edge"])
|
||||
|
||||
# Add styles
|
||||
mermaid_code.extend(
|
||||
# Draw nodes
|
||||
for method_name, pos in node_positions.items():
|
||||
if method_name in start_methods:
|
||||
color = self.colors["start"]
|
||||
elif method_name in router_methods:
|
||||
color = self.colors["router"]
|
||||
else:
|
||||
color = self.colors["method"]
|
||||
|
||||
self._draw_node(draw, method_name, pos, color)
|
||||
|
||||
# Save image
|
||||
img.save(f"{filename}.png")
|
||||
print(f"Graph saved as {filename}.png")
|
||||
|
||||
def _calculate_positions(self, nodes):
|
||||
positions = {}
|
||||
start_methods = [
|
||||
node
|
||||
for node in nodes
|
||||
if hasattr(self.flow._methods[node], "__is_start_method__")
|
||||
]
|
||||
other_methods = [node for node in nodes if node not in start_methods]
|
||||
|
||||
# Position start methods at the top
|
||||
for i, node in enumerate(start_methods):
|
||||
positions[node] = (400, 100 + i * 120)
|
||||
|
||||
# Position other methods below start methods
|
||||
for i, node in enumerate(other_methods):
|
||||
positions[node] = (400, 100 + (len(start_methods) + i) * 120)
|
||||
|
||||
return positions
|
||||
|
||||
def _draw_node(self, draw, label, position, color):
|
||||
x, y = position
|
||||
if color == self.colors["router"]:
|
||||
# Draw router node as rounded rectangle
|
||||
draw.rounded_rectangle(
|
||||
[x - 70, y - 40, x + 70, y + 40],
|
||||
radius=10,
|
||||
fill=color,
|
||||
outline=self.colors["edge"],
|
||||
)
|
||||
font = ImageFont.load_default()
|
||||
text_width = draw.textlength(label, font=font)
|
||||
draw.text(
|
||||
(x - text_width / 2, y - 20), label, fill=self.colors["text"], font=font
|
||||
)
|
||||
draw.text((x - 30, y + 5), "Success", fill=self.colors["text"], font=font)
|
||||
draw.text((x - 30, y + 25), "Failure", fill=self.colors["text"], font=font)
|
||||
else:
|
||||
# Draw regular node
|
||||
draw.rectangle(
|
||||
[x - 60, y - 30, x + 60, y + 30],
|
||||
fill=color,
|
||||
outline=self.colors["edge"],
|
||||
)
|
||||
font = ImageFont.load_default()
|
||||
text_width = draw.textlength(label, font=font)
|
||||
draw.text(
|
||||
(x - text_width / 2, y - 7), label, fill=self.colors["text"], font=font
|
||||
)
|
||||
|
||||
def _draw_curved_arrow(self, draw, start, end, color):
|
||||
# Calculate control point for the curve
|
||||
control_x = (start[0] + end[0]) / 2
|
||||
control_y = (
|
||||
start[1] + end[1]
|
||||
) / 2 - 50 # Adjust this value to change curve height
|
||||
|
||||
# Draw the curved line
|
||||
points = [start, (control_x, control_y), end]
|
||||
draw.line(points, fill=color, width=2, joint="curve")
|
||||
|
||||
# Draw arrow head
|
||||
self._draw_arrow_head(draw, points[-2], end, color)
|
||||
|
||||
def _draw_arrow_head(self, draw, start, end, color):
|
||||
angle = math.atan2(end[1] - start[1], end[0] - start[0])
|
||||
x = end[0] - 15 * math.cos(angle)
|
||||
y = end[1] - 15 * math.sin(angle)
|
||||
draw.polygon(
|
||||
[
|
||||
" classDef startNode stroke:#333,stroke-width:4px;",
|
||||
" classDef methodNode stroke:#333,stroke-width:2px;",
|
||||
]
|
||||
(x, y),
|
||||
(
|
||||
x - 10 * math.cos(angle - math.pi / 6),
|
||||
y - 10 * math.sin(angle - math.pi / 6),
|
||||
),
|
||||
(
|
||||
x - 10 * math.cos(angle + math.pi / 6),
|
||||
y - 10 * math.sin(angle + math.pi / 6),
|
||||
),
|
||||
],
|
||||
fill=color,
|
||||
)
|
||||
|
||||
return "\n".join(mermaid_code)
|
||||
|
||||
|
||||
def is_graphviz_available():
|
||||
try:
|
||||
@@ -191,6 +262,6 @@ def visualize_flow(flow, filename="flow_graph"):
|
||||
"For better visualization, please install Graphviz. "
|
||||
"See our documentation for installation instructions: https://docs.crewai.com/advanced-usage/visualization/"
|
||||
)
|
||||
visualizer = MermaidFlowVisualizer(flow)
|
||||
visualizer = PyvisFlowVisualizer(flow)
|
||||
|
||||
visualizer.visualize(filename)
|
||||
|
||||
Reference in New Issue
Block a user