all working. needs to be cleaned up

This commit is contained in:
Brandon Hancock
2024-09-30 16:00:24 -04:00
parent b927989c4d
commit 1a0f96ae03
2 changed files with 112 additions and 59 deletions

View File

@@ -1,3 +1,5 @@
# flow.py
import asyncio import asyncio
import inspect import inspect
from typing import Any, Callable, Dict, Generic, List, Set, Type, TypeVar, Union from typing import Any, Callable, Dict, Generic, List, Set, Type, TypeVar, Union

View File

@@ -1,3 +1,5 @@
# flow_visualizer.py
import base64 import base64
import os import os
import re import re
@@ -5,39 +7,53 @@ from abc import ABC, abstractmethod
from pyvis.network import Network from pyvis.network import Network
DARK_GRAY = "#333333"
CREWAI_ORANGE = "#FF5A50"
GRAY = "#666666"
WHITE = "#FFFFFF"
class FlowVisualizer(ABC): class FlowVisualizer(ABC):
def __init__(self, flow): def __init__(self, flow):
self.flow = flow self.flow = flow
self.colors = { self.colors = {
"bg": "#FFFFFF", "bg": WHITE,
"start": "#FF5A50", "start": CREWAI_ORANGE,
"method": "#333333", "method": DARK_GRAY,
"router": "#333333", # Dark gray for router background "router": DARK_GRAY,
"router_border": "#FF8C00", # Orange for router border "router_border": CREWAI_ORANGE,
"edge": "#666666", "edge": GRAY,
"router_edge": "#FF8C00", # Orange for router edges "router_edge": CREWAI_ORANGE,
"text": "#FFFFFF", "text": WHITE,
} }
self.node_styles = { self.node_styles = {
"start": { "start": {
"color": self.colors["start"], "color": self.colors["start"],
"shape": "box", "shape": "box",
"font": {"color": self.colors["text"]}, "font": {"color": self.colors["text"]},
"margin": 15,
}, },
"method": { "method": {
"color": self.colors["method"], "color": self.colors["method"],
"shape": "box", "shape": "box",
"font": {"color": self.colors["text"]}, "font": {"color": self.colors["text"]},
"margin": 15,
}, },
"router": { "router": {
"color": self.colors["router"], "color": {
"background": self.colors["router"],
"border": self.colors["router_border"],
"highlight": {
"border": self.colors["router_border"],
"background": self.colors["router"],
},
},
"shape": "box", "shape": "box",
"font": {"color": self.colors["text"]}, "font": {"color": self.colors["text"]},
"borderWidth": 2, "borderWidth": 3,
"borderWidthSelected": 4, "borderWidthSelected": 4,
"borderDashes": [5, 5], # Dashed border "shapeProperties": {"borderDashes": [5, 5]},
"borderColor": self.colors["router_border"], "margin": 15,
}, },
} }
@@ -61,8 +77,8 @@ class PyvisFlowVisualizer(FlowVisualizer):
print("node_levels", node_levels) print("node_levels", node_levels)
# Assign positions to nodes based on levels # Assign positions to nodes based on levels
y_spacing = 150 # Adjust spacing between levels (positive for top-down) y_spacing = 150
x_spacing = 150 # Adjust spacing between nodes x_spacing = 150
level_nodes = {} level_nodes = {}
# Store node positions for edge calculations # Store node positions for edge calculations
@@ -76,13 +92,13 @@ class PyvisFlowVisualizer(FlowVisualizer):
x_offset = -(len(nodes) - 1) * x_spacing / 2 # Center nodes horizontally x_offset = -(len(nodes) - 1) * x_spacing / 2 # Center nodes horizontally
for i, method_name in enumerate(nodes): for i, method_name in enumerate(nodes):
x = x_offset + i * x_spacing x = x_offset + i * x_spacing
y = level * y_spacing # Use level directly for y position y = level * y_spacing
node_positions[method_name] = (x, y) # Store positions node_positions[method_name] = (x, y)
method = self.flow._methods.get(method_name) method = self.flow._methods.get(method_name)
if hasattr(method, "__is_start_method__"): if hasattr(method, "__is_start_method__"):
node_style = self.node_styles["start"] node_style = self.node_styles["start"]
elif method_name in self.flow._routers.values(): elif hasattr(method, "__is_router__"):
node_style = self.node_styles["router"] node_style = self.node_styles["router"]
else: else:
node_style = self.node_styles["method"] node_style = self.node_styles["method"]
@@ -93,11 +109,10 @@ class PyvisFlowVisualizer(FlowVisualizer):
x=x, x=x,
y=y, y=y,
fixed=True, fixed=True,
physics=False, # Disable physics for fixed positioning physics=False,
**node_style, **node_style,
) )
# Prepare data structures for edge calculations
ancestors = self._build_ancestor_dict() ancestors = self._build_ancestor_dict()
parent_children = self._build_parent_children_dict() parent_children = self._build_parent_children_dict()
@@ -107,7 +122,10 @@ class PyvisFlowVisualizer(FlowVisualizer):
is_and_condition = condition_type == "AND" is_and_condition = condition_type == "AND"
for trigger in trigger_methods: for trigger in trigger_methods:
if trigger in self.flow._methods: if (
trigger in self.flow._methods
or trigger in self.flow._routers.values()
):
is_router_edge = any( is_router_edge = any(
trigger in paths for paths in self.flow._router_paths.values() trigger in paths for paths in self.flow._router_paths.values()
) )
@@ -118,29 +136,41 @@ class PyvisFlowVisualizer(FlowVisualizer):
) )
# Determine if this edge forms a cycle # Determine if this edge forms a cycle
is_cycle_edge = self._is_ancestor(method_name, trigger, ancestors) is_cycle_edge = self._is_ancestor(trigger, method_name, ancestors)
# Determine if parent has multiple children # Determine if parent has multiple children
parent_has_multiple_children = ( parent_has_multiple_children = (
len(parent_children.get(trigger, [])) > 1 len(parent_children.get(trigger, [])) > 1
) )
# Determine if edge should be curved # Edge curvature logic
needs_curvature = is_cycle_edge or parent_has_multiple_children needs_curvature = is_cycle_edge or parent_has_multiple_children
if needs_curvature: if needs_curvature:
if is_cycle_edge: # Get node positions
# For cycles, curve left and up source_pos = node_positions.get(trigger)
edge_smooth = {"type": "curvedCCW", "roundness": 0.3} target_pos = node_positions.get(method_name)
else:
# For multiple children, adjust curvature to prevent overlap if source_pos and target_pos:
dx = target_pos[0] - source_pos[0]
if dx <= 0:
# Child is left or directly below
smooth_type = "curvedCCW" # Curve left and down
else:
# Child is to the right
smooth_type = "curvedCW" # Curve right and down
index = self._get_child_index( index = self._get_child_index(
trigger, method_name, parent_children trigger, method_name, parent_children
) )
edge_smooth = { edge_smooth = {
"type": "curvedCW" if index % 2 == 0 else "curvedCCW", "type": smooth_type,
"roundness": 0.2 + (0.1 * index), "roundness": 0.2 + (0.1 * index),
} }
else:
# Fallback curvature
edge_smooth = {"type": "cubicBezier"}
else: else:
edge_smooth = False # Draw straight line edge_smooth = False # Draw straight line
@@ -163,7 +193,7 @@ class PyvisFlowVisualizer(FlowVisualizer):
) in self.flow._listeners.items(): ) in self.flow._listeners.items():
if path in trigger_methods: if path in trigger_methods:
is_cycle_edge = self._is_ancestor( is_cycle_edge = self._is_ancestor(
listener_name, router_method_name, ancestors trigger, method_name, ancestors
) )
# Determine if parent has multiple children # Determine if parent has multiple children
@@ -171,24 +201,34 @@ class PyvisFlowVisualizer(FlowVisualizer):
len(parent_children.get(router_method_name, [])) > 1 len(parent_children.get(router_method_name, [])) > 1
) )
# Determine if edge should be curved # Edge curvature logic
needs_curvature = is_cycle_edge or parent_has_multiple_children needs_curvature = is_cycle_edge or parent_has_multiple_children
if needs_curvature: if needs_curvature:
if is_cycle_edge: # Get node positions
# Curve left and up for cycles source_pos = node_positions.get(router_method_name)
edge_smooth = {"type": "curvedCCW", "roundness": 0.3} target_pos = node_positions.get(listener_name)
else:
# For multiple children, adjust curvature if source_pos and target_pos:
dx = target_pos[0] - source_pos[0]
if dx <= 0:
# Child is left or directly below
smooth_type = "curvedCCW" # Curve left and down
else:
# Child is to the right
smooth_type = "curvedCW" # Curve right and down
index = self._get_child_index( index = self._get_child_index(
router_method_name, listener_name, parent_children router_method_name, listener_name, parent_children
) )
edge_smooth = { edge_smooth = {
"type": ( "type": smooth_type,
"curvedCW" if index % 2 == 0 else "curvedCCW"
),
"roundness": 0.2 + (0.1 * index), "roundness": 0.2 + (0.1 * index),
} }
else:
# Fallback curvature
edge_smooth = {"type": "cubicBezier"}
else: else:
edge_smooth = False # Straight line edge_smooth = False # Straight line
@@ -374,54 +414,65 @@ class PyvisFlowVisualizer(FlowVisualizer):
return counts return counts
def _build_ancestor_dict(self): def _build_ancestor_dict(self):
# Helper method to build a dictionary of ancestors for each node
ancestors = {node: set() for node in self.flow._methods} ancestors = {node: set() for node in self.flow._methods}
visited = set()
for node in self.flow._methods: for node in self.flow._methods:
self._dfs_ancestors(node, ancestors, set()) if node not in visited:
self._dfs_ancestors(node, ancestors, visited)
print("Ancestor Relationships:")
for node, node_ancestors in ancestors.items():
print(f"{node}: {node_ancestors}")
return ancestors return ancestors
def _dfs_ancestors(self, node, ancestors, visited): def _dfs_ancestors(self, node, ancestors, visited):
if node in visited: if node in visited:
return return
visited.add(node) visited.add(node)
# Handle regular listeners
for listener_name, (_, trigger_methods) in self.flow._listeners.items(): for listener_name, (_, trigger_methods) in self.flow._listeners.items():
if node in trigger_methods: if node in trigger_methods:
ancestors[listener_name].add(node) ancestors[listener_name].add(node)
ancestors[listener_name].update(ancestors[node]) ancestors[listener_name].update(ancestors[node])
self._dfs_ancestors(listener_name, ancestors, visited) self._dfs_ancestors(listener_name, ancestors, visited)
# Include router paths
for router_method_name, paths in self.flow._router_paths.items(): # Handle router methods separately
if node == router_method_name: if node in self.flow._routers.values():
for path in paths: router_method_name = node
for listener_name, ( paths = self.flow._router_paths.get(router_method_name, [])
_, for path in paths:
trigger_methods, for listener_name, (_, trigger_methods) in self.flow._listeners.items():
) in self.flow._listeners.items(): if path in trigger_methods:
if path in trigger_methods: # Only propagate the ancestors of the router method, not the router method itself
ancestors[listener_name].add(node) ancestors[listener_name].update(ancestors[node])
ancestors[listener_name].update(ancestors[node]) self._dfs_ancestors(listener_name, ancestors, visited)
self._dfs_ancestors(listener_name, ancestors, visited)
def _is_ancestor(self, node, ancestor_candidate, ancestors): def _is_ancestor(self, node, ancestor_candidate, ancestors):
return ancestor_candidate in ancestors.get(node, set()) return ancestor_candidate in ancestors.get(node, set())
def _build_parent_children_dict(self): def _build_parent_children_dict(self):
# Helper method to build a parent to children mapping
parent_children = {} parent_children = {}
# Map listeners to their trigger methods
for listener_name, (_, trigger_methods) in self.flow._listeners.items(): for listener_name, (_, trigger_methods) in self.flow._listeners.items():
for trigger in trigger_methods: for trigger in trigger_methods:
if trigger not in parent_children: if trigger not in parent_children:
parent_children[trigger] = [] parent_children[trigger] = []
parent_children[trigger].append(listener_name) if listener_name not in parent_children[trigger]:
# Include router paths and their connected listener methods parent_children[trigger].append(listener_name)
# Map router methods to their paths and to listeners
for router_method_name, paths in self.flow._router_paths.items(): for router_method_name, paths in self.flow._router_paths.items():
for path in paths: for path in paths:
# Find listener methods triggered by each path # Map router method to listeners of each path
for listener_name, (_, trigger_methods) in self.flow._listeners.items(): for listener_name, (_, trigger_methods) in self.flow._listeners.items():
if path in trigger_methods: if path in trigger_methods:
parent_children.setdefault(router_method_name, []).append( if router_method_name not in parent_children:
listener_name parent_children[router_method_name] = []
) if listener_name not in parent_children[router_method_name]:
parent_children[router_method_name].append(listener_name)
# Debugging output
print("Parent-Children Relationships:")
for parent, children in parent_children.items():
print(f"{parent}: {children}")
return parent_children return parent_children
def _get_child_index(self, parent, child, parent_children): def _get_child_index(self, parent, child, parent_children):