Brandon/cre 291 flow improvements (#1390)

* Implement joao feedback

* update colors for crew nodes

* clean up

* more linting clean up

* round legend corners

---------

Co-authored-by: João Moura <joaomdmoura@gmail.com>
This commit is contained in:
Brandon Hancock (bhancock_ai)
2024-10-04 12:22:46 -04:00
committed by GitHub
parent e8a49e7687
commit 3f81383285
7 changed files with 169 additions and 54 deletions

View File

@@ -2,6 +2,7 @@ DARK_GRAY = "#333333"
CREWAI_ORANGE = "#FF5A50"
GRAY = "#666666"
WHITE = "#FFFFFF"
BLACK = "#000000"
COLORS = {
"bg": WHITE,
@@ -16,31 +17,43 @@ COLORS = {
NODE_STYLES = {
"start": {
"color": COLORS["start"],
"color": CREWAI_ORANGE,
"shape": "box",
"font": {"color": COLORS["text"]},
"font": {"color": WHITE},
"margin": {"top": 10, "bottom": 8, "left": 10, "right": 10},
},
"method": {
"color": COLORS["method"],
"color": DARK_GRAY,
"shape": "box",
"font": {"color": COLORS["text"]},
"font": {"color": WHITE},
"margin": {"top": 10, "bottom": 8, "left": 10, "right": 10},
},
"router": {
"color": {
"background": COLORS["router"],
"border": COLORS["router_border"],
"background": DARK_GRAY,
"border": CREWAI_ORANGE,
"highlight": {
"border": COLORS["router_border"],
"background": COLORS["router"],
"border": CREWAI_ORANGE,
"background": DARK_GRAY,
},
},
"shape": "box",
"font": {"color": COLORS["text"]},
"font": {"color": WHITE},
"borderWidth": 3,
"borderWidthSelected": 4,
"shapeProperties": {"borderDashes": [5, 5]},
"margin": {"top": 10, "bottom": 8, "left": 10, "right": 10},
},
"crew": {
"color": {
"background": WHITE,
"border": CREWAI_ORANGE,
},
"shape": "box",
"font": {"color": BLACK},
"borderWidth": 3,
"borderWidthSelected": 4,
"shapeProperties": {"borderDashes": False},
"margin": {"top": 10, "bottom": 8, "left": 10, "right": 10},
},
}

View File

@@ -7,6 +7,7 @@ from typing import Any, Callable, Dict, Generic, List, Set, Type, TypeVar, Union
from pydantic import BaseModel
from crewai.flow.flow_visualizer import plot_flow
from crewai.flow.utils import get_possible_return_constants
from crewai.telemetry import Telemetry
T = TypeVar("T", bound=Union[BaseModel, Dict[str, Any]])
@@ -62,12 +63,10 @@ def listen(condition):
return decorator
def router(method, paths=None):
def router(method):
def decorator(func):
func.__is_router__ = True
func.__router_for__ = method.__name__
if paths:
func.__router_paths__ = paths
return func
return decorator
@@ -123,10 +122,11 @@ class FlowMeta(type):
listeners[attr_name] = (condition_type, methods)
elif hasattr(attr_value, "__is_router__"):
routers[attr_value.__router_for__] = attr_name
if hasattr(attr_value, "__router_paths__"):
router_paths[attr_name] = attr_value.__router_paths__
possible_returns = get_possible_return_constants(attr_value)
if possible_returns:
router_paths[attr_name] = possible_returns
# **Register router as a listener to its triggering method**
# Register router as a listener to its triggering method
trigger_method_name = attr_value.__router_for__
methods = [trigger_method_name]
condition_type = "OR"

View File

@@ -30,6 +30,22 @@ class FlowPlot:
layout=None,
)
# Set options to disable physics
net.set_options(
"""
var options = {
"nodes": {
"font": {
"multi": "html"
}
},
"physics": {
"enabled": false
}
}
"""
)
# Calculate levels for nodes
node_levels = calculate_node_levels(self.flow)
@@ -42,24 +58,13 @@ class FlowPlot:
# Add edges to the network
add_edges(net, self.flow, node_positions, self.colors)
# Set options to disable physics
net.set_options(
"""
var options = {
"physics": {
"enabled": false
}
}
"""
)
network_html = net.generate_html()
final_html_content = self._generate_final_html(network_html)
# Save the final HTML content to the file
with open(f"{filename}.html", "w", encoding="utf-8") as f:
f.write(final_html_content)
print(f"Graph saved as {filename}.html")
print(f"Plot saved as {filename}.html")
self._cleanup_pyvis_lib()
@@ -94,6 +99,6 @@ class FlowPlot:
print(f"Error cleaning up {lib_folder}: {e}")
def plot_flow(flow, filename="flow_graph"):
def plot_flow(flow, filename="flow_plot"):
visualizer = FlowPlot(flow)
visualizer.plot(filename)

View File

@@ -1,5 +1,4 @@
import base64
import os
import re
@@ -48,7 +47,7 @@ class HTMLTemplateHandler:
"""
return legend_items_html
def generate_final_html(self, network_body, legend_items_html, title="Flow Graph"):
def generate_final_html(self, network_body, legend_items_html, title="Flow Plot"):
html_template = self.read_template()
logo_svg_base64 = self.encode_logo()

View File

@@ -2,6 +2,12 @@ def get_legend_items(colors):
return [
{"label": "Start Method", "color": colors["start"]},
{"label": "Method", "color": colors["method"]},
{
"label": "Crew Method",
"color": colors["bg"],
"border": colors["start"],
"dashed": False,
},
{
"label": "Router",
"color": colors["router"],
@@ -22,9 +28,10 @@ def generate_legend_items_html(legend_items):
legend_items_html = ""
for item in legend_items:
if "border" in item:
style = "dashed" if item["dashed"] else "solid"
legend_items_html += f"""
<div class="legend-item">
<div class="legend-color-box" style="background-color: {item['color']}; border: 2px dashed {item['border']};"></div>
<div class="legend-color-box" style="background-color: {item['color']}; border: 2px {style} {item['border']}; border-radius: 5px;"></div>
<div>{item['label']}</div>
</div>
"""
@@ -32,14 +39,14 @@ def generate_legend_items_html(legend_items):
style = "dashed" if item["dashed"] else "solid"
legend_items_html += f"""
<div class="legend-item">
<div class="legend-{style}" style="border-bottom: 2px {style} {item['color']};"></div>
<div class="legend-{style}" style="border-bottom: 2px {style} {item['color']}; border-radius: 5px;"></div>
<div>{item['label']}</div>
</div>
"""
else:
legend_items_html += f"""
<div class="legend-item">
<div class="legend-color-box" style="background-color: {item['color']};"></div>
<div class="legend-color-box" style="background-color: {item['color']}; border-radius: 5px;"></div>
<div>{item['label']}</div>
</div>
"""

View File

@@ -1,3 +1,48 @@
import ast
import inspect
import textwrap
def get_possible_return_constants(function):
try:
source = inspect.getsource(function)
except OSError:
# Can't get source code
return None
except Exception as e:
print(f"Error retrieving source code for function {function.__name__}: {e}")
return None
try:
# Remove leading indentation
source = textwrap.dedent(source)
# Parse the source code into an AST
code_ast = ast.parse(source)
except IndentationError as e:
print(f"IndentationError while parsing source code of {function.__name__}: {e}")
print(f"Source code:\n{source}")
return None
except SyntaxError as e:
print(f"SyntaxError while parsing source code of {function.__name__}: {e}")
print(f"Source code:\n{source}")
return None
except Exception as e:
print(f"Unexpected error while parsing source code of {function.__name__}: {e}")
print(f"Source code:\n{source}")
return None
return_values = []
class ReturnVisitor(ast.NodeVisitor):
def visit_Return(self, node):
# Check if the return value is a constant (Python 3.8+)
if isinstance(node.value, ast.Constant):
return_values.append(node.value.value)
ReturnVisitor().visit(code_ast)
return return_values
def calculate_node_levels(flow):
levels = {}
queue = []

View File

@@ -1,3 +1,6 @@
import ast
import inspect
from .utils import (
build_ancestor_dict,
build_parent_children_dict,
@@ -6,6 +9,70 @@ from .utils import (
)
def method_calls_crew(method):
"""Check if the method calls `.crew()`."""
try:
source = inspect.getsource(method)
source = inspect.cleandoc(source)
tree = ast.parse(source)
except Exception as e:
print(f"Could not parse method {method.__name__}: {e}")
return False
class CrewCallVisitor(ast.NodeVisitor):
def __init__(self):
self.found = False
def visit_Call(self, node):
if isinstance(node.func, ast.Attribute):
if node.func.attr == "crew":
self.found = True
self.generic_visit(node)
visitor = CrewCallVisitor()
visitor.visit(tree)
return visitor.found
def add_nodes_to_network(net, flow, node_positions, node_styles):
def human_friendly_label(method_name):
return method_name.replace("_", " ").title()
for method_name, (x, y) in node_positions.items():
method = flow._methods.get(method_name)
if hasattr(method, "__is_start_method__"):
node_style = node_styles["start"]
elif hasattr(method, "__is_router__"):
node_style = node_styles["router"]
elif method_calls_crew(method):
node_style = node_styles["crew"]
else:
node_style = node_styles["method"]
node_style = node_style.copy()
label = human_friendly_label(method_name)
node_style.update(
{
"label": label,
"shape": "box",
"font": {
"multi": "html",
"color": node_style.get("font", {}).get("color", "#FFFFFF"),
},
}
)
net.add_node(
method_name,
x=x,
y=y,
fixed=True,
physics=False,
**node_style,
)
def compute_positions(flow, node_levels, y_spacing=150, x_spacing=150):
level_nodes = {}
node_positions = {}
@@ -109,24 +176,3 @@ def add_edges(net, flow, node_positions, colors):
"smooth": edge_smooth,
}
net.add_edge(router_method_name, listener_name, **edge_style)
def add_nodes_to_network(net, flow, node_positions, node_styles):
for method_name, (x, y) in node_positions.items():
method = flow._methods.get(method_name)
if hasattr(method, "__is_start_method__"):
node_style = node_styles["start"]
elif hasattr(method, "__is_router__"):
node_style = node_styles["router"]
else:
node_style = node_styles["method"]
net.add_node(
method_name,
label=method_name,
x=x,
y=y,
fixed=True,
physics=False,
**node_style,
)