mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 08:08:32 +00:00
pyvis working
This commit is contained in:
2
poetry.lock
generated
2
poetry.lock
generated
@@ -7663,4 +7663,4 @@ tools = ["crewai-tools"]
|
|||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = ">=3.10,<=3.13"
|
python-versions = ">=3.10,<=3.13"
|
||||||
content-hash = "8edc2b56582cce28793790bf6526cf35ccf54b982a5cfd97330f0f3d6ac2a5b9"
|
content-hash = "13875b4236719007d8c126a03deefc6c59ce6717e39547d3d099053a89359eb0"
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ networkx = "^3.3"
|
|||||||
ipython = "^8.27.0"
|
ipython = "^8.27.0"
|
||||||
pyvis = "^0.3.2"
|
pyvis = "^0.3.2"
|
||||||
playwright = "^1.47.0"
|
playwright = "^1.47.0"
|
||||||
|
pillow = "^10.4.0"
|
||||||
|
|
||||||
[tool.poetry.extras]
|
[tool.poetry.extras]
|
||||||
tools = ["crewai-tools"]
|
tools = ["crewai-tools"]
|
||||||
|
|||||||
@@ -1,9 +1,8 @@
|
|||||||
import math
|
|
||||||
import shutil
|
import shutil
|
||||||
import warnings
|
import warnings
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
from PIL import Image, ImageDraw, ImageFont
|
from pyvis.network import Network
|
||||||
|
|
||||||
|
|
||||||
class FlowVisualizer(ABC):
|
class FlowVisualizer(ABC):
|
||||||
@@ -17,6 +16,8 @@ class FlowVisualizer(ABC):
|
|||||||
"edge": "#333333",
|
"edge": "#333333",
|
||||||
"text": "#FFFFFF",
|
"text": "#FFFFFF",
|
||||||
}
|
}
|
||||||
|
self.node_rectangles = {}
|
||||||
|
self.node_positions = {}
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def visualize(self, filename):
|
def visualize(self, filename):
|
||||||
@@ -103,143 +104,66 @@ class GraphvizVisualizer(FlowVisualizer):
|
|||||||
print(f"Graph saved as {filename}.png")
|
print(f"Graph saved as {filename}.png")
|
||||||
|
|
||||||
|
|
||||||
class PyvisFlowVisualizer:
|
class PyvisFlowVisualizer(FlowVisualizer):
|
||||||
def __init__(self, flow):
|
def visualize(self, filename):
|
||||||
self.flow = flow
|
net = Network(
|
||||||
self.colors = {
|
directed=True, height="750px", width="100%", bgcolor=self.colors["bg"]
|
||||||
"bg": "#FFFFFF",
|
)
|
||||||
"start": "#FF5A50",
|
|
||||||
"method": "#333333",
|
# Define custom node styles
|
||||||
"router": "#FF8C00", # Orange color for routers
|
node_styles = {
|
||||||
"edge": "#666666",
|
"start": {
|
||||||
"text": "#FFFFFF",
|
"color": self.colors["start"],
|
||||||
|
"shape": "box",
|
||||||
|
"font": {"color": self.colors["text"]},
|
||||||
|
},
|
||||||
|
"method": {
|
||||||
|
"color": self.colors["method"],
|
||||||
|
"shape": "box",
|
||||||
|
"font": {"color": self.colors["text"]},
|
||||||
|
},
|
||||||
|
# "router": {
|
||||||
|
# "color": self.colors["router"],
|
||||||
|
# "shape": "box",
|
||||||
|
# "font": {"color": self.colors["text"]},
|
||||||
|
# },
|
||||||
}
|
}
|
||||||
|
|
||||||
def visualize(self, filename):
|
# Add nodes
|
||||||
# Get decorated methods
|
for method_name, method in self.flow._methods.items():
|
||||||
start_methods = [
|
if (
|
||||||
name
|
hasattr(method, "__is_start_method__")
|
||||||
for name, method in self.flow._methods.items()
|
or method_name in self.flow._listeners
|
||||||
if hasattr(method, "__is_start_method__")
|
or method_name in self.flow._routers.values()
|
||||||
]
|
):
|
||||||
listen_methods = list(self.flow._listeners.keys())
|
if hasattr(method, "__is_start_method__"):
|
||||||
router_methods = list(self.flow._routers.values())
|
node_style = node_styles["start"]
|
||||||
|
elif method_name in self.flow._routers.values():
|
||||||
|
node_style = node_styles["router"]
|
||||||
|
else:
|
||||||
|
node_style = node_styles["method"]
|
||||||
|
|
||||||
all_methods = start_methods + listen_methods + router_methods
|
net.add_node(method_name, label=method_name, **node_style)
|
||||||
node_positions = self._calculate_positions(all_methods)
|
|
||||||
|
|
||||||
# Create image
|
# Add edges
|
||||||
img_width = 800
|
for method_name in self.flow._listeners:
|
||||||
img_height = len(all_methods) * 120 + 100
|
condition_type, trigger_methods = self.flow._listeners[method_name]
|
||||||
img = Image.new("RGB", (img_width, img_height), color=self.colors["bg"])
|
is_and_condition = condition_type == "AND"
|
||||||
draw = ImageDraw.Draw(img)
|
for trigger in trigger_methods:
|
||||||
|
if trigger in self.flow._methods:
|
||||||
|
net.add_edge(
|
||||||
|
trigger,
|
||||||
|
method_name,
|
||||||
|
color=self.colors["edge"],
|
||||||
|
width=2,
|
||||||
|
arrows="to",
|
||||||
|
dashes=is_and_condition, # Dashed lines for AND conditions
|
||||||
|
smooth={"type": "cubicBezier"},
|
||||||
|
)
|
||||||
|
|
||||||
# Draw edges
|
# Generate and save the graph
|
||||||
for method_name in listen_methods + router_methods:
|
net.write_html(f"{filename}.html")
|
||||||
if method_name in self.flow._listeners:
|
print(f"Graph saved as {filename}.html")
|
||||||
_, trigger_methods = self.flow._listeners[method_name]
|
|
||||||
for trigger in trigger_methods:
|
|
||||||
if trigger in node_positions and method_name in node_positions:
|
|
||||||
start = node_positions[trigger]
|
|
||||||
end = node_positions[method_name]
|
|
||||||
self._draw_curved_arrow(draw, start, end, self.colors["edge"])
|
|
||||||
|
|
||||||
# Draw nodes
|
|
||||||
for method_name, pos in node_positions.items():
|
|
||||||
if method_name in start_methods:
|
|
||||||
color = self.colors["start"]
|
|
||||||
elif method_name in router_methods:
|
|
||||||
color = self.colors["router"]
|
|
||||||
else:
|
|
||||||
color = self.colors["method"]
|
|
||||||
|
|
||||||
self._draw_node(draw, method_name, pos, color)
|
|
||||||
|
|
||||||
# Save image
|
|
||||||
img.save(f"{filename}.png")
|
|
||||||
print(f"Graph saved as {filename}.png")
|
|
||||||
|
|
||||||
def _calculate_positions(self, nodes):
|
|
||||||
positions = {}
|
|
||||||
start_methods = [
|
|
||||||
node
|
|
||||||
for node in nodes
|
|
||||||
if hasattr(self.flow._methods[node], "__is_start_method__")
|
|
||||||
]
|
|
||||||
other_methods = [node for node in nodes if node not in start_methods]
|
|
||||||
|
|
||||||
# Position start methods at the top
|
|
||||||
for i, node in enumerate(start_methods):
|
|
||||||
positions[node] = (400, 100 + i * 120)
|
|
||||||
|
|
||||||
# Position other methods below start methods
|
|
||||||
for i, node in enumerate(other_methods):
|
|
||||||
positions[node] = (400, 100 + (len(start_methods) + i) * 120)
|
|
||||||
|
|
||||||
return positions
|
|
||||||
|
|
||||||
def _draw_node(self, draw, label, position, color):
|
|
||||||
x, y = position
|
|
||||||
if color == self.colors["router"]:
|
|
||||||
# Draw router node as rounded rectangle
|
|
||||||
draw.rounded_rectangle(
|
|
||||||
[x - 70, y - 40, x + 70, y + 40],
|
|
||||||
radius=10,
|
|
||||||
fill=color,
|
|
||||||
outline=self.colors["edge"],
|
|
||||||
)
|
|
||||||
font = ImageFont.load_default()
|
|
||||||
text_width = draw.textlength(label, font=font)
|
|
||||||
draw.text(
|
|
||||||
(x - text_width / 2, y - 20), label, fill=self.colors["text"], font=font
|
|
||||||
)
|
|
||||||
draw.text((x - 30, y + 5), "Success", fill=self.colors["text"], font=font)
|
|
||||||
draw.text((x - 30, y + 25), "Failure", fill=self.colors["text"], font=font)
|
|
||||||
else:
|
|
||||||
# Draw regular node
|
|
||||||
draw.rectangle(
|
|
||||||
[x - 60, y - 30, x + 60, y + 30],
|
|
||||||
fill=color,
|
|
||||||
outline=self.colors["edge"],
|
|
||||||
)
|
|
||||||
font = ImageFont.load_default()
|
|
||||||
text_width = draw.textlength(label, font=font)
|
|
||||||
draw.text(
|
|
||||||
(x - text_width / 2, y - 7), label, fill=self.colors["text"], font=font
|
|
||||||
)
|
|
||||||
|
|
||||||
def _draw_curved_arrow(self, draw, start, end, color):
|
|
||||||
# Calculate control point for the curve
|
|
||||||
control_x = (start[0] + end[0]) / 2
|
|
||||||
control_y = (
|
|
||||||
start[1] + end[1]
|
|
||||||
) / 2 - 50 # Adjust this value to change curve height
|
|
||||||
|
|
||||||
# Draw the curved line
|
|
||||||
points = [start, (control_x, control_y), end]
|
|
||||||
draw.line(points, fill=color, width=2, joint="curve")
|
|
||||||
|
|
||||||
# Draw arrow head
|
|
||||||
self._draw_arrow_head(draw, points[-2], end, color)
|
|
||||||
|
|
||||||
def _draw_arrow_head(self, draw, start, end, color):
|
|
||||||
angle = math.atan2(end[1] - start[1], end[0] - start[0])
|
|
||||||
x = end[0] - 15 * math.cos(angle)
|
|
||||||
y = end[1] - 15 * math.sin(angle)
|
|
||||||
draw.polygon(
|
|
||||||
[
|
|
||||||
(x, y),
|
|
||||||
(
|
|
||||||
x - 10 * math.cos(angle - math.pi / 6),
|
|
||||||
y - 10 * math.sin(angle - math.pi / 6),
|
|
||||||
),
|
|
||||||
(
|
|
||||||
x - 10 * math.cos(angle + math.pi / 6),
|
|
||||||
y - 10 * math.sin(angle + math.pi / 6),
|
|
||||||
),
|
|
||||||
],
|
|
||||||
fill=color,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def is_graphviz_available():
|
def is_graphviz_available():
|
||||||
|
|||||||
BIN
src/crewai/flow/fonts/arial_bold.ttf
Normal file
BIN
src/crewai/flow/fonts/arial_bold.ttf
Normal file
Binary file not shown.
Reference in New Issue
Block a user