From 1a0f96ae03ce9daef759dd772764631e64add518 Mon Sep 17 00:00:00 2001 From: Brandon Hancock Date: Mon, 30 Sep 2024 16:00:24 -0400 Subject: [PATCH] all working. needs to be cleaned up --- src/crewai/flow/flow.py | 2 + src/crewai/flow/flow_visualizer.py | 169 +++++++++++++++++++---------- 2 files changed, 112 insertions(+), 59 deletions(-) diff --git a/src/crewai/flow/flow.py b/src/crewai/flow/flow.py index 4b784309a..1fcb4b53f 100644 --- a/src/crewai/flow/flow.py +++ b/src/crewai/flow/flow.py @@ -1,3 +1,5 @@ +# flow.py + import asyncio import inspect from typing import Any, Callable, Dict, Generic, List, Set, Type, TypeVar, Union diff --git a/src/crewai/flow/flow_visualizer.py b/src/crewai/flow/flow_visualizer.py index 2cf43fa02..839694a0b 100644 --- a/src/crewai/flow/flow_visualizer.py +++ b/src/crewai/flow/flow_visualizer.py @@ -1,3 +1,5 @@ +# flow_visualizer.py + import base64 import os import re @@ -5,39 +7,53 @@ from abc import ABC, abstractmethod from pyvis.network import Network +DARK_GRAY = "#333333" +CREWAI_ORANGE = "#FF5A50" +GRAY = "#666666" +WHITE = "#FFFFFF" + class FlowVisualizer(ABC): def __init__(self, flow): self.flow = flow self.colors = { - "bg": "#FFFFFF", - "start": "#FF5A50", - "method": "#333333", - "router": "#333333", # Dark gray for router background - "router_border": "#FF8C00", # Orange for router border - "edge": "#666666", - "router_edge": "#FF8C00", # Orange for router edges - "text": "#FFFFFF", + "bg": WHITE, + "start": CREWAI_ORANGE, + "method": DARK_GRAY, + "router": DARK_GRAY, + "router_border": CREWAI_ORANGE, + "edge": GRAY, + "router_edge": CREWAI_ORANGE, + "text": WHITE, } self.node_styles = { "start": { "color": self.colors["start"], "shape": "box", "font": {"color": self.colors["text"]}, + "margin": 15, }, "method": { "color": self.colors["method"], "shape": "box", "font": {"color": self.colors["text"]}, + "margin": 15, }, "router": { - "color": self.colors["router"], + "color": { + "background": self.colors["router"], + "border": self.colors["router_border"], + "highlight": { + "border": self.colors["router_border"], + "background": self.colors["router"], + }, + }, "shape": "box", "font": {"color": self.colors["text"]}, - "borderWidth": 2, + "borderWidth": 3, "borderWidthSelected": 4, - "borderDashes": [5, 5], # Dashed border - "borderColor": self.colors["router_border"], + "shapeProperties": {"borderDashes": [5, 5]}, + "margin": 15, }, } @@ -61,8 +77,8 @@ class PyvisFlowVisualizer(FlowVisualizer): print("node_levels", node_levels) # Assign positions to nodes based on levels - y_spacing = 150 # Adjust spacing between levels (positive for top-down) - x_spacing = 150 # Adjust spacing between nodes + y_spacing = 150 + x_spacing = 150 level_nodes = {} # Store node positions for edge calculations @@ -76,13 +92,13 @@ class PyvisFlowVisualizer(FlowVisualizer): x_offset = -(len(nodes) - 1) * x_spacing / 2 # Center nodes horizontally for i, method_name in enumerate(nodes): x = x_offset + i * x_spacing - y = level * y_spacing # Use level directly for y position - node_positions[method_name] = (x, y) # Store positions + y = level * y_spacing + node_positions[method_name] = (x, y) method = self.flow._methods.get(method_name) if hasattr(method, "__is_start_method__"): node_style = self.node_styles["start"] - elif method_name in self.flow._routers.values(): + elif hasattr(method, "__is_router__"): node_style = self.node_styles["router"] else: node_style = self.node_styles["method"] @@ -93,11 +109,10 @@ class PyvisFlowVisualizer(FlowVisualizer): x=x, y=y, fixed=True, - physics=False, # Disable physics for fixed positioning + physics=False, **node_style, ) - # Prepare data structures for edge calculations ancestors = self._build_ancestor_dict() parent_children = self._build_parent_children_dict() @@ -107,7 +122,10 @@ class PyvisFlowVisualizer(FlowVisualizer): is_and_condition = condition_type == "AND" for trigger in trigger_methods: - if trigger in self.flow._methods: + if ( + trigger in self.flow._methods + or trigger in self.flow._routers.values() + ): is_router_edge = any( trigger in paths for paths in self.flow._router_paths.values() ) @@ -118,29 +136,41 @@ class PyvisFlowVisualizer(FlowVisualizer): ) # Determine if this edge forms a cycle - is_cycle_edge = self._is_ancestor(method_name, trigger, ancestors) + is_cycle_edge = self._is_ancestor(trigger, method_name, ancestors) # Determine if parent has multiple children parent_has_multiple_children = ( len(parent_children.get(trigger, [])) > 1 ) - # Determine if edge should be curved + # Edge curvature logic needs_curvature = is_cycle_edge or parent_has_multiple_children if needs_curvature: - if is_cycle_edge: - # For cycles, curve left and up - edge_smooth = {"type": "curvedCCW", "roundness": 0.3} - else: - # For multiple children, adjust curvature to prevent overlap + # Get node positions + source_pos = node_positions.get(trigger) + target_pos = node_positions.get(method_name) + + if source_pos and target_pos: + dx = target_pos[0] - source_pos[0] + + if dx <= 0: + # Child is left or directly below + smooth_type = "curvedCCW" # Curve left and down + else: + # Child is to the right + smooth_type = "curvedCW" # Curve right and down + index = self._get_child_index( trigger, method_name, parent_children ) edge_smooth = { - "type": "curvedCW" if index % 2 == 0 else "curvedCCW", + "type": smooth_type, "roundness": 0.2 + (0.1 * index), } + else: + # Fallback curvature + edge_smooth = {"type": "cubicBezier"} else: edge_smooth = False # Draw straight line @@ -163,7 +193,7 @@ class PyvisFlowVisualizer(FlowVisualizer): ) in self.flow._listeners.items(): if path in trigger_methods: is_cycle_edge = self._is_ancestor( - listener_name, router_method_name, ancestors + trigger, method_name, ancestors ) # Determine if parent has multiple children @@ -171,24 +201,34 @@ class PyvisFlowVisualizer(FlowVisualizer): len(parent_children.get(router_method_name, [])) > 1 ) - # Determine if edge should be curved + # Edge curvature logic needs_curvature = is_cycle_edge or parent_has_multiple_children if needs_curvature: - if is_cycle_edge: - # Curve left and up for cycles - edge_smooth = {"type": "curvedCCW", "roundness": 0.3} - else: - # For multiple children, adjust curvature + # Get node positions + source_pos = node_positions.get(router_method_name) + target_pos = node_positions.get(listener_name) + + if source_pos and target_pos: + dx = target_pos[0] - source_pos[0] + + if dx <= 0: + # Child is left or directly below + smooth_type = "curvedCCW" # Curve left and down + else: + # Child is to the right + smooth_type = "curvedCW" # Curve right and down + index = self._get_child_index( router_method_name, listener_name, parent_children ) edge_smooth = { - "type": ( - "curvedCW" if index % 2 == 0 else "curvedCCW" - ), + "type": smooth_type, "roundness": 0.2 + (0.1 * index), } + else: + # Fallback curvature + edge_smooth = {"type": "cubicBezier"} else: edge_smooth = False # Straight line @@ -374,54 +414,65 @@ class PyvisFlowVisualizer(FlowVisualizer): return counts def _build_ancestor_dict(self): - # Helper method to build a dictionary of ancestors for each node ancestors = {node: set() for node in self.flow._methods} + visited = set() for node in self.flow._methods: - self._dfs_ancestors(node, ancestors, set()) + if node not in visited: + self._dfs_ancestors(node, ancestors, visited) + print("Ancestor Relationships:") + for node, node_ancestors in ancestors.items(): + print(f"{node}: {node_ancestors}") return ancestors def _dfs_ancestors(self, node, ancestors, visited): if node in visited: return visited.add(node) + + # Handle regular listeners for listener_name, (_, trigger_methods) in self.flow._listeners.items(): if node in trigger_methods: ancestors[listener_name].add(node) ancestors[listener_name].update(ancestors[node]) self._dfs_ancestors(listener_name, ancestors, visited) - # Include router paths - for router_method_name, paths in self.flow._router_paths.items(): - if node == router_method_name: - for path in paths: - for listener_name, ( - _, - trigger_methods, - ) in self.flow._listeners.items(): - if path in trigger_methods: - ancestors[listener_name].add(node) - ancestors[listener_name].update(ancestors[node]) - self._dfs_ancestors(listener_name, ancestors, visited) + + # Handle router methods separately + if node in self.flow._routers.values(): + router_method_name = node + paths = self.flow._router_paths.get(router_method_name, []) + for path in paths: + for listener_name, (_, trigger_methods) in self.flow._listeners.items(): + if path in trigger_methods: + # Only propagate the ancestors of the router method, not the router method itself + ancestors[listener_name].update(ancestors[node]) + self._dfs_ancestors(listener_name, ancestors, visited) def _is_ancestor(self, node, ancestor_candidate, ancestors): return ancestor_candidate in ancestors.get(node, set()) def _build_parent_children_dict(self): - # Helper method to build a parent to children mapping parent_children = {} + # Map listeners to their trigger methods for listener_name, (_, trigger_methods) in self.flow._listeners.items(): for trigger in trigger_methods: if trigger not in parent_children: parent_children[trigger] = [] - parent_children[trigger].append(listener_name) - # Include router paths and their connected listener methods + if listener_name not in parent_children[trigger]: + parent_children[trigger].append(listener_name) + # Map router methods to their paths and to listeners for router_method_name, paths in self.flow._router_paths.items(): for path in paths: - # Find listener methods triggered by each path + # Map router method to listeners of each path for listener_name, (_, trigger_methods) in self.flow._listeners.items(): if path in trigger_methods: - parent_children.setdefault(router_method_name, []).append( - listener_name - ) + if router_method_name not in parent_children: + parent_children[router_method_name] = [] + if listener_name not in parent_children[router_method_name]: + parent_children[router_method_name].append(listener_name) + # Debugging output + print("Parent-Children Relationships:") + for parent, children in parent_children.items(): + print(f"{parent}: {children}") return parent_children def _get_child_index(self, parent, child, parent_children):