mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-11 00:58:30 +00:00
Brandon/cre 291 flow improvements (#1390)
* Implement joao feedback * update colors for crew nodes * clean up * more linting clean up * round legend corners --------- Co-authored-by: João Moura <joaomdmoura@gmail.com>
This commit is contained in:
committed by
GitHub
parent
f98c2c00d8
commit
4361f4ff75
@@ -2,6 +2,7 @@ DARK_GRAY = "#333333"
|
|||||||
CREWAI_ORANGE = "#FF5A50"
|
CREWAI_ORANGE = "#FF5A50"
|
||||||
GRAY = "#666666"
|
GRAY = "#666666"
|
||||||
WHITE = "#FFFFFF"
|
WHITE = "#FFFFFF"
|
||||||
|
BLACK = "#000000"
|
||||||
|
|
||||||
COLORS = {
|
COLORS = {
|
||||||
"bg": WHITE,
|
"bg": WHITE,
|
||||||
@@ -16,31 +17,43 @@ COLORS = {
|
|||||||
|
|
||||||
NODE_STYLES = {
|
NODE_STYLES = {
|
||||||
"start": {
|
"start": {
|
||||||
"color": COLORS["start"],
|
"color": CREWAI_ORANGE,
|
||||||
"shape": "box",
|
"shape": "box",
|
||||||
"font": {"color": COLORS["text"]},
|
"font": {"color": WHITE},
|
||||||
"margin": {"top": 10, "bottom": 8, "left": 10, "right": 10},
|
"margin": {"top": 10, "bottom": 8, "left": 10, "right": 10},
|
||||||
},
|
},
|
||||||
"method": {
|
"method": {
|
||||||
"color": COLORS["method"],
|
"color": DARK_GRAY,
|
||||||
"shape": "box",
|
"shape": "box",
|
||||||
"font": {"color": COLORS["text"]},
|
"font": {"color": WHITE},
|
||||||
"margin": {"top": 10, "bottom": 8, "left": 10, "right": 10},
|
"margin": {"top": 10, "bottom": 8, "left": 10, "right": 10},
|
||||||
},
|
},
|
||||||
"router": {
|
"router": {
|
||||||
"color": {
|
"color": {
|
||||||
"background": COLORS["router"],
|
"background": DARK_GRAY,
|
||||||
"border": COLORS["router_border"],
|
"border": CREWAI_ORANGE,
|
||||||
"highlight": {
|
"highlight": {
|
||||||
"border": COLORS["router_border"],
|
"border": CREWAI_ORANGE,
|
||||||
"background": COLORS["router"],
|
"background": DARK_GRAY,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"shape": "box",
|
"shape": "box",
|
||||||
"font": {"color": COLORS["text"]},
|
"font": {"color": WHITE},
|
||||||
"borderWidth": 3,
|
"borderWidth": 3,
|
||||||
"borderWidthSelected": 4,
|
"borderWidthSelected": 4,
|
||||||
"shapeProperties": {"borderDashes": [5, 5]},
|
"shapeProperties": {"borderDashes": [5, 5]},
|
||||||
"margin": {"top": 10, "bottom": 8, "left": 10, "right": 10},
|
"margin": {"top": 10, "bottom": 8, "left": 10, "right": 10},
|
||||||
},
|
},
|
||||||
|
"crew": {
|
||||||
|
"color": {
|
||||||
|
"background": WHITE,
|
||||||
|
"border": CREWAI_ORANGE,
|
||||||
|
},
|
||||||
|
"shape": "box",
|
||||||
|
"font": {"color": BLACK},
|
||||||
|
"borderWidth": 3,
|
||||||
|
"borderWidthSelected": 4,
|
||||||
|
"shapeProperties": {"borderDashes": False},
|
||||||
|
"margin": {"top": 10, "bottom": 8, "left": 10, "right": 10},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from typing import Any, Callable, Dict, Generic, List, Set, Type, TypeVar, Union
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from crewai.flow.flow_visualizer import plot_flow
|
from crewai.flow.flow_visualizer import plot_flow
|
||||||
|
from crewai.flow.utils import get_possible_return_constants
|
||||||
from crewai.telemetry import Telemetry
|
from crewai.telemetry import Telemetry
|
||||||
|
|
||||||
T = TypeVar("T", bound=Union[BaseModel, Dict[str, Any]])
|
T = TypeVar("T", bound=Union[BaseModel, Dict[str, Any]])
|
||||||
@@ -62,12 +63,10 @@ def listen(condition):
|
|||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
def router(method, paths=None):
|
def router(method):
|
||||||
def decorator(func):
|
def decorator(func):
|
||||||
func.__is_router__ = True
|
func.__is_router__ = True
|
||||||
func.__router_for__ = method.__name__
|
func.__router_for__ = method.__name__
|
||||||
if paths:
|
|
||||||
func.__router_paths__ = paths
|
|
||||||
return func
|
return func
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
@@ -123,10 +122,11 @@ class FlowMeta(type):
|
|||||||
listeners[attr_name] = (condition_type, methods)
|
listeners[attr_name] = (condition_type, methods)
|
||||||
elif hasattr(attr_value, "__is_router__"):
|
elif hasattr(attr_value, "__is_router__"):
|
||||||
routers[attr_value.__router_for__] = attr_name
|
routers[attr_value.__router_for__] = attr_name
|
||||||
if hasattr(attr_value, "__router_paths__"):
|
possible_returns = get_possible_return_constants(attr_value)
|
||||||
router_paths[attr_name] = attr_value.__router_paths__
|
if possible_returns:
|
||||||
|
router_paths[attr_name] = possible_returns
|
||||||
|
|
||||||
# **Register router as a listener to its triggering method**
|
# Register router as a listener to its triggering method
|
||||||
trigger_method_name = attr_value.__router_for__
|
trigger_method_name = attr_value.__router_for__
|
||||||
methods = [trigger_method_name]
|
methods = [trigger_method_name]
|
||||||
condition_type = "OR"
|
condition_type = "OR"
|
||||||
|
|||||||
@@ -30,6 +30,22 @@ class FlowPlot:
|
|||||||
layout=None,
|
layout=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Set options to disable physics
|
||||||
|
net.set_options(
|
||||||
|
"""
|
||||||
|
var options = {
|
||||||
|
"nodes": {
|
||||||
|
"font": {
|
||||||
|
"multi": "html"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"physics": {
|
||||||
|
"enabled": false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
# Calculate levels for nodes
|
# Calculate levels for nodes
|
||||||
node_levels = calculate_node_levels(self.flow)
|
node_levels = calculate_node_levels(self.flow)
|
||||||
|
|
||||||
@@ -42,24 +58,13 @@ class FlowPlot:
|
|||||||
# Add edges to the network
|
# Add edges to the network
|
||||||
add_edges(net, self.flow, node_positions, self.colors)
|
add_edges(net, self.flow, node_positions, self.colors)
|
||||||
|
|
||||||
# Set options to disable physics
|
|
||||||
net.set_options(
|
|
||||||
"""
|
|
||||||
var options = {
|
|
||||||
"physics": {
|
|
||||||
"enabled": false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
network_html = net.generate_html()
|
network_html = net.generate_html()
|
||||||
final_html_content = self._generate_final_html(network_html)
|
final_html_content = self._generate_final_html(network_html)
|
||||||
|
|
||||||
# Save the final HTML content to the file
|
# Save the final HTML content to the file
|
||||||
with open(f"{filename}.html", "w", encoding="utf-8") as f:
|
with open(f"{filename}.html", "w", encoding="utf-8") as f:
|
||||||
f.write(final_html_content)
|
f.write(final_html_content)
|
||||||
print(f"Graph saved as {filename}.html")
|
print(f"Plot saved as {filename}.html")
|
||||||
|
|
||||||
self._cleanup_pyvis_lib()
|
self._cleanup_pyvis_lib()
|
||||||
|
|
||||||
@@ -94,6 +99,6 @@ class FlowPlot:
|
|||||||
print(f"Error cleaning up {lib_folder}: {e}")
|
print(f"Error cleaning up {lib_folder}: {e}")
|
||||||
|
|
||||||
|
|
||||||
def plot_flow(flow, filename="flow_graph"):
|
def plot_flow(flow, filename="flow_plot"):
|
||||||
visualizer = FlowPlot(flow)
|
visualizer = FlowPlot(flow)
|
||||||
visualizer.plot(filename)
|
visualizer.plot(filename)
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
import base64
|
import base64
|
||||||
import os
|
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
|
||||||
@@ -48,7 +47,7 @@ class HTMLTemplateHandler:
|
|||||||
"""
|
"""
|
||||||
return legend_items_html
|
return legend_items_html
|
||||||
|
|
||||||
def generate_final_html(self, network_body, legend_items_html, title="Flow Graph"):
|
def generate_final_html(self, network_body, legend_items_html, title="Flow Plot"):
|
||||||
html_template = self.read_template()
|
html_template = self.read_template()
|
||||||
logo_svg_base64 = self.encode_logo()
|
logo_svg_base64 = self.encode_logo()
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,12 @@ def get_legend_items(colors):
|
|||||||
return [
|
return [
|
||||||
{"label": "Start Method", "color": colors["start"]},
|
{"label": "Start Method", "color": colors["start"]},
|
||||||
{"label": "Method", "color": colors["method"]},
|
{"label": "Method", "color": colors["method"]},
|
||||||
|
{
|
||||||
|
"label": "Crew Method",
|
||||||
|
"color": colors["bg"],
|
||||||
|
"border": colors["start"],
|
||||||
|
"dashed": False,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"label": "Router",
|
"label": "Router",
|
||||||
"color": colors["router"],
|
"color": colors["router"],
|
||||||
@@ -22,9 +28,10 @@ def generate_legend_items_html(legend_items):
|
|||||||
legend_items_html = ""
|
legend_items_html = ""
|
||||||
for item in legend_items:
|
for item in legend_items:
|
||||||
if "border" in item:
|
if "border" in item:
|
||||||
|
style = "dashed" if item["dashed"] else "solid"
|
||||||
legend_items_html += f"""
|
legend_items_html += f"""
|
||||||
<div class="legend-item">
|
<div class="legend-item">
|
||||||
<div class="legend-color-box" style="background-color: {item['color']}; border: 2px dashed {item['border']};"></div>
|
<div class="legend-color-box" style="background-color: {item['color']}; border: 2px {style} {item['border']}; border-radius: 5px;"></div>
|
||||||
<div>{item['label']}</div>
|
<div>{item['label']}</div>
|
||||||
</div>
|
</div>
|
||||||
"""
|
"""
|
||||||
@@ -32,14 +39,14 @@ def generate_legend_items_html(legend_items):
|
|||||||
style = "dashed" if item["dashed"] else "solid"
|
style = "dashed" if item["dashed"] else "solid"
|
||||||
legend_items_html += f"""
|
legend_items_html += f"""
|
||||||
<div class="legend-item">
|
<div class="legend-item">
|
||||||
<div class="legend-{style}" style="border-bottom: 2px {style} {item['color']};"></div>
|
<div class="legend-{style}" style="border-bottom: 2px {style} {item['color']}; border-radius: 5px;"></div>
|
||||||
<div>{item['label']}</div>
|
<div>{item['label']}</div>
|
||||||
</div>
|
</div>
|
||||||
"""
|
"""
|
||||||
else:
|
else:
|
||||||
legend_items_html += f"""
|
legend_items_html += f"""
|
||||||
<div class="legend-item">
|
<div class="legend-item">
|
||||||
<div class="legend-color-box" style="background-color: {item['color']};"></div>
|
<div class="legend-color-box" style="background-color: {item['color']}; border-radius: 5px;"></div>
|
||||||
<div>{item['label']}</div>
|
<div>{item['label']}</div>
|
||||||
</div>
|
</div>
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1,3 +1,48 @@
|
|||||||
|
import ast
|
||||||
|
import inspect
|
||||||
|
import textwrap
|
||||||
|
|
||||||
|
|
||||||
|
def get_possible_return_constants(function):
|
||||||
|
try:
|
||||||
|
source = inspect.getsource(function)
|
||||||
|
except OSError:
|
||||||
|
# Can't get source code
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error retrieving source code for function {function.__name__}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Remove leading indentation
|
||||||
|
source = textwrap.dedent(source)
|
||||||
|
# Parse the source code into an AST
|
||||||
|
code_ast = ast.parse(source)
|
||||||
|
except IndentationError as e:
|
||||||
|
print(f"IndentationError while parsing source code of {function.__name__}: {e}")
|
||||||
|
print(f"Source code:\n{source}")
|
||||||
|
return None
|
||||||
|
except SyntaxError as e:
|
||||||
|
print(f"SyntaxError while parsing source code of {function.__name__}: {e}")
|
||||||
|
print(f"Source code:\n{source}")
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Unexpected error while parsing source code of {function.__name__}: {e}")
|
||||||
|
print(f"Source code:\n{source}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
return_values = []
|
||||||
|
|
||||||
|
class ReturnVisitor(ast.NodeVisitor):
|
||||||
|
def visit_Return(self, node):
|
||||||
|
# Check if the return value is a constant (Python 3.8+)
|
||||||
|
if isinstance(node.value, ast.Constant):
|
||||||
|
return_values.append(node.value.value)
|
||||||
|
|
||||||
|
ReturnVisitor().visit(code_ast)
|
||||||
|
return return_values
|
||||||
|
|
||||||
|
|
||||||
def calculate_node_levels(flow):
|
def calculate_node_levels(flow):
|
||||||
levels = {}
|
levels = {}
|
||||||
queue = []
|
queue = []
|
||||||
|
|||||||
@@ -1,3 +1,6 @@
|
|||||||
|
import ast
|
||||||
|
import inspect
|
||||||
|
|
||||||
from .utils import (
|
from .utils import (
|
||||||
build_ancestor_dict,
|
build_ancestor_dict,
|
||||||
build_parent_children_dict,
|
build_parent_children_dict,
|
||||||
@@ -6,6 +9,70 @@ from .utils import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def method_calls_crew(method):
|
||||||
|
"""Check if the method calls `.crew()`."""
|
||||||
|
try:
|
||||||
|
source = inspect.getsource(method)
|
||||||
|
source = inspect.cleandoc(source)
|
||||||
|
tree = ast.parse(source)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Could not parse method {method.__name__}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
class CrewCallVisitor(ast.NodeVisitor):
|
||||||
|
def __init__(self):
|
||||||
|
self.found = False
|
||||||
|
|
||||||
|
def visit_Call(self, node):
|
||||||
|
if isinstance(node.func, ast.Attribute):
|
||||||
|
if node.func.attr == "crew":
|
||||||
|
self.found = True
|
||||||
|
self.generic_visit(node)
|
||||||
|
|
||||||
|
visitor = CrewCallVisitor()
|
||||||
|
visitor.visit(tree)
|
||||||
|
return visitor.found
|
||||||
|
|
||||||
|
|
||||||
|
def add_nodes_to_network(net, flow, node_positions, node_styles):
|
||||||
|
def human_friendly_label(method_name):
|
||||||
|
return method_name.replace("_", " ").title()
|
||||||
|
|
||||||
|
for method_name, (x, y) in node_positions.items():
|
||||||
|
method = flow._methods.get(method_name)
|
||||||
|
if hasattr(method, "__is_start_method__"):
|
||||||
|
node_style = node_styles["start"]
|
||||||
|
elif hasattr(method, "__is_router__"):
|
||||||
|
node_style = node_styles["router"]
|
||||||
|
elif method_calls_crew(method):
|
||||||
|
node_style = node_styles["crew"]
|
||||||
|
else:
|
||||||
|
node_style = node_styles["method"]
|
||||||
|
|
||||||
|
node_style = node_style.copy()
|
||||||
|
label = human_friendly_label(method_name)
|
||||||
|
|
||||||
|
node_style.update(
|
||||||
|
{
|
||||||
|
"label": label,
|
||||||
|
"shape": "box",
|
||||||
|
"font": {
|
||||||
|
"multi": "html",
|
||||||
|
"color": node_style.get("font", {}).get("color", "#FFFFFF"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
net.add_node(
|
||||||
|
method_name,
|
||||||
|
x=x,
|
||||||
|
y=y,
|
||||||
|
fixed=True,
|
||||||
|
physics=False,
|
||||||
|
**node_style,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def compute_positions(flow, node_levels, y_spacing=150, x_spacing=150):
|
def compute_positions(flow, node_levels, y_spacing=150, x_spacing=150):
|
||||||
level_nodes = {}
|
level_nodes = {}
|
||||||
node_positions = {}
|
node_positions = {}
|
||||||
@@ -109,24 +176,3 @@ def add_edges(net, flow, node_positions, colors):
|
|||||||
"smooth": edge_smooth,
|
"smooth": edge_smooth,
|
||||||
}
|
}
|
||||||
net.add_edge(router_method_name, listener_name, **edge_style)
|
net.add_edge(router_method_name, listener_name, **edge_style)
|
||||||
|
|
||||||
|
|
||||||
def add_nodes_to_network(net, flow, node_positions, node_styles):
|
|
||||||
for method_name, (x, y) in node_positions.items():
|
|
||||||
method = flow._methods.get(method_name)
|
|
||||||
if hasattr(method, "__is_start_method__"):
|
|
||||||
node_style = node_styles["start"]
|
|
||||||
elif hasattr(method, "__is_router__"):
|
|
||||||
node_style = node_styles["router"]
|
|
||||||
else:
|
|
||||||
node_style = node_styles["method"]
|
|
||||||
|
|
||||||
net.add_node(
|
|
||||||
method_name,
|
|
||||||
label=method_name,
|
|
||||||
x=x,
|
|
||||||
y=y,
|
|
||||||
fixed=True,
|
|
||||||
physics=False,
|
|
||||||
**node_style,
|
|
||||||
)
|
|
||||||
|
|||||||
Reference in New Issue
Block a user