From 3f813832855775b6ac99e878613809f0e2040dd8 Mon Sep 17 00:00:00 2001 From: "Brandon Hancock (bhancock_ai)" <109994880+bhancockio@users.noreply.github.com> Date: Fri, 4 Oct 2024 12:22:46 -0400 Subject: [PATCH] Brandon/cre 291 flow improvements (#1390) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Implement joao feedback * update colors for crew nodes * clean up * more linting clean up * round legend corners --------- Co-authored-by: João Moura --- src/crewai/flow/config.py | 31 ++++++--- src/crewai/flow/flow.py | 12 ++-- src/crewai/flow/flow_visualizer.py | 31 +++++---- src/crewai/flow/html_template_handler.py | 3 +- src/crewai/flow/legend_generator.py | 13 +++- src/crewai/flow/utils.py | 45 ++++++++++++ src/crewai/flow/visualization_utils.py | 88 ++++++++++++++++++------ 7 files changed, 169 insertions(+), 54 deletions(-) diff --git a/src/crewai/flow/config.py b/src/crewai/flow/config.py index ddaddc7a8..b04d5d0c2 100644 --- a/src/crewai/flow/config.py +++ b/src/crewai/flow/config.py @@ -2,6 +2,7 @@ DARK_GRAY = "#333333" CREWAI_ORANGE = "#FF5A50" GRAY = "#666666" WHITE = "#FFFFFF" +BLACK = "#000000" COLORS = { "bg": WHITE, @@ -16,31 +17,43 @@ COLORS = { NODE_STYLES = { "start": { - "color": COLORS["start"], + "color": CREWAI_ORANGE, "shape": "box", - "font": {"color": COLORS["text"]}, + "font": {"color": WHITE}, "margin": {"top": 10, "bottom": 8, "left": 10, "right": 10}, }, "method": { - "color": COLORS["method"], + "color": DARK_GRAY, "shape": "box", - "font": {"color": COLORS["text"]}, + "font": {"color": WHITE}, "margin": {"top": 10, "bottom": 8, "left": 10, "right": 10}, }, "router": { "color": { - "background": COLORS["router"], - "border": COLORS["router_border"], + "background": DARK_GRAY, + "border": CREWAI_ORANGE, "highlight": { - "border": COLORS["router_border"], - "background": COLORS["router"], + "border": CREWAI_ORANGE, + "background": DARK_GRAY, }, }, "shape": "box", - "font": {"color": COLORS["text"]}, + "font": {"color": WHITE}, "borderWidth": 3, "borderWidthSelected": 4, "shapeProperties": {"borderDashes": [5, 5]}, "margin": {"top": 10, "bottom": 8, "left": 10, "right": 10}, }, + "crew": { + "color": { + "background": WHITE, + "border": CREWAI_ORANGE, + }, + "shape": "box", + "font": {"color": BLACK}, + "borderWidth": 3, + "borderWidthSelected": 4, + "shapeProperties": {"borderDashes": False}, + "margin": {"top": 10, "bottom": 8, "left": 10, "right": 10}, + }, } diff --git a/src/crewai/flow/flow.py b/src/crewai/flow/flow.py index 6ae3941a4..5f4953de7 100644 --- a/src/crewai/flow/flow.py +++ b/src/crewai/flow/flow.py @@ -7,6 +7,7 @@ from typing import Any, Callable, Dict, Generic, List, Set, Type, TypeVar, Union from pydantic import BaseModel from crewai.flow.flow_visualizer import plot_flow +from crewai.flow.utils import get_possible_return_constants from crewai.telemetry import Telemetry T = TypeVar("T", bound=Union[BaseModel, Dict[str, Any]]) @@ -62,12 +63,10 @@ def listen(condition): return decorator -def router(method, paths=None): +def router(method): def decorator(func): func.__is_router__ = True func.__router_for__ = method.__name__ - if paths: - func.__router_paths__ = paths return func return decorator @@ -123,10 +122,11 @@ class FlowMeta(type): listeners[attr_name] = (condition_type, methods) elif hasattr(attr_value, "__is_router__"): routers[attr_value.__router_for__] = attr_name - if hasattr(attr_value, "__router_paths__"): - router_paths[attr_name] = attr_value.__router_paths__ + possible_returns = get_possible_return_constants(attr_value) + if possible_returns: + router_paths[attr_name] = possible_returns - # **Register router as a listener to its triggering method** + # Register router as a listener to its triggering method trigger_method_name = attr_value.__router_for__ methods = [trigger_method_name] condition_type = "OR" diff --git a/src/crewai/flow/flow_visualizer.py b/src/crewai/flow/flow_visualizer.py index 822f192b0..988f27919 100644 --- a/src/crewai/flow/flow_visualizer.py +++ b/src/crewai/flow/flow_visualizer.py @@ -30,6 +30,22 @@ class FlowPlot: layout=None, ) + # Set options to disable physics + net.set_options( + """ + var options = { + "nodes": { + "font": { + "multi": "html" + } + }, + "physics": { + "enabled": false + } + } + """ + ) + # Calculate levels for nodes node_levels = calculate_node_levels(self.flow) @@ -42,24 +58,13 @@ class FlowPlot: # Add edges to the network add_edges(net, self.flow, node_positions, self.colors) - # Set options to disable physics - net.set_options( - """ - var options = { - "physics": { - "enabled": false - } - } - """ - ) - network_html = net.generate_html() final_html_content = self._generate_final_html(network_html) # Save the final HTML content to the file with open(f"{filename}.html", "w", encoding="utf-8") as f: f.write(final_html_content) - print(f"Graph saved as {filename}.html") + print(f"Plot saved as {filename}.html") self._cleanup_pyvis_lib() @@ -94,6 +99,6 @@ class FlowPlot: print(f"Error cleaning up {lib_folder}: {e}") -def plot_flow(flow, filename="flow_graph"): +def plot_flow(flow, filename="flow_plot"): visualizer = FlowPlot(flow) visualizer.plot(filename) diff --git a/src/crewai/flow/html_template_handler.py b/src/crewai/flow/html_template_handler.py index 8a88da42a..d521d8cf8 100644 --- a/src/crewai/flow/html_template_handler.py +++ b/src/crewai/flow/html_template_handler.py @@ -1,5 +1,4 @@ import base64 -import os import re @@ -48,7 +47,7 @@ class HTMLTemplateHandler: """ return legend_items_html - def generate_final_html(self, network_body, legend_items_html, title="Flow Graph"): + def generate_final_html(self, network_body, legend_items_html, title="Flow Plot"): html_template = self.read_template() logo_svg_base64 = self.encode_logo() diff --git a/src/crewai/flow/legend_generator.py b/src/crewai/flow/legend_generator.py index 83d9b97a2..fb3d5cfd6 100644 --- a/src/crewai/flow/legend_generator.py +++ b/src/crewai/flow/legend_generator.py @@ -2,6 +2,12 @@ def get_legend_items(colors): return [ {"label": "Start Method", "color": colors["start"]}, {"label": "Method", "color": colors["method"]}, + { + "label": "Crew Method", + "color": colors["bg"], + "border": colors["start"], + "dashed": False, + }, { "label": "Router", "color": colors["router"], @@ -22,9 +28,10 @@ def generate_legend_items_html(legend_items): legend_items_html = "" for item in legend_items: if "border" in item: + style = "dashed" if item["dashed"] else "solid" legend_items_html += f"""
-
+
{item['label']}
""" @@ -32,14 +39,14 @@ def generate_legend_items_html(legend_items): style = "dashed" if item["dashed"] else "solid" legend_items_html += f"""
-
+
{item['label']}
""" else: legend_items_html += f"""
-
+
{item['label']}
""" diff --git a/src/crewai/flow/utils.py b/src/crewai/flow/utils.py index f2dbfb7fd..98d03f24f 100644 --- a/src/crewai/flow/utils.py +++ b/src/crewai/flow/utils.py @@ -1,3 +1,48 @@ +import ast +import inspect +import textwrap + + +def get_possible_return_constants(function): + try: + source = inspect.getsource(function) + except OSError: + # Can't get source code + return None + except Exception as e: + print(f"Error retrieving source code for function {function.__name__}: {e}") + return None + + try: + # Remove leading indentation + source = textwrap.dedent(source) + # Parse the source code into an AST + code_ast = ast.parse(source) + except IndentationError as e: + print(f"IndentationError while parsing source code of {function.__name__}: {e}") + print(f"Source code:\n{source}") + return None + except SyntaxError as e: + print(f"SyntaxError while parsing source code of {function.__name__}: {e}") + print(f"Source code:\n{source}") + return None + except Exception as e: + print(f"Unexpected error while parsing source code of {function.__name__}: {e}") + print(f"Source code:\n{source}") + return None + + return_values = [] + + class ReturnVisitor(ast.NodeVisitor): + def visit_Return(self, node): + # Check if the return value is a constant (Python 3.8+) + if isinstance(node.value, ast.Constant): + return_values.append(node.value.value) + + ReturnVisitor().visit(code_ast) + return return_values + + def calculate_node_levels(flow): levels = {} queue = [] diff --git a/src/crewai/flow/visualization_utils.py b/src/crewai/flow/visualization_utils.py index ba2ba5f18..5b95a1369 100644 --- a/src/crewai/flow/visualization_utils.py +++ b/src/crewai/flow/visualization_utils.py @@ -1,3 +1,6 @@ +import ast +import inspect + from .utils import ( build_ancestor_dict, build_parent_children_dict, @@ -6,6 +9,70 @@ from .utils import ( ) +def method_calls_crew(method): + """Check if the method calls `.crew()`.""" + try: + source = inspect.getsource(method) + source = inspect.cleandoc(source) + tree = ast.parse(source) + except Exception as e: + print(f"Could not parse method {method.__name__}: {e}") + return False + + class CrewCallVisitor(ast.NodeVisitor): + def __init__(self): + self.found = False + + def visit_Call(self, node): + if isinstance(node.func, ast.Attribute): + if node.func.attr == "crew": + self.found = True + self.generic_visit(node) + + visitor = CrewCallVisitor() + visitor.visit(tree) + return visitor.found + + +def add_nodes_to_network(net, flow, node_positions, node_styles): + def human_friendly_label(method_name): + return method_name.replace("_", " ").title() + + for method_name, (x, y) in node_positions.items(): + method = flow._methods.get(method_name) + if hasattr(method, "__is_start_method__"): + node_style = node_styles["start"] + elif hasattr(method, "__is_router__"): + node_style = node_styles["router"] + elif method_calls_crew(method): + node_style = node_styles["crew"] + else: + node_style = node_styles["method"] + + node_style = node_style.copy() + label = human_friendly_label(method_name) + + node_style.update( + { + "label": label, + "shape": "box", + "font": { + "multi": "html", + "color": node_style.get("font", {}).get("color", "#FFFFFF"), + }, + } + ) + + net.add_node( + method_name, + x=x, + y=y, + fixed=True, + physics=False, + **node_style, + ) + + def compute_positions(flow, node_levels, y_spacing=150, x_spacing=150): level_nodes = {} node_positions = {} @@ -109,24 +176,3 @@ def add_edges(net, flow, node_positions, colors): "smooth": edge_smooth, } net.add_edge(router_method_name, listener_name, **edge_style) - - -def add_nodes_to_network(net, flow, node_positions, node_styles): - for method_name, (x, y) in node_positions.items(): - method = flow._methods.get(method_name) - if hasattr(method, "__is_start_method__"): - node_style = node_styles["start"] - elif hasattr(method, "__is_router__"): - node_style = node_styles["router"] - else: - node_style = node_styles["method"] - - net.add_node( - method_name, - label=method_name, - x=x, - y=y, - fixed=True, - physics=False, - **node_style, - )