properly identifying router and router children nodes. Need to fix color

This commit is contained in:
Brandon Hancock
2024-09-30 11:05:46 -04:00
parent 5d645cd89f
commit 66e7fc5ce3
3 changed files with 130 additions and 42 deletions

View File

@@ -58,10 +58,12 @@ def listen(condition):
return decorator return decorator
def router(method): def router(method, paths=None):
def decorator(func): def decorator(func):
func.__is_router__ = True func.__is_router__ = True
func.__router_for__ = method.__name__ func.__router_for__ = method.__name__
if paths:
func.__router_paths__ = paths
return func return func
return decorator return decorator
@@ -102,6 +104,7 @@ class FlowMeta(type):
start_methods = [] start_methods = []
listeners = {} listeners = {}
routers = {} routers = {}
router_paths = {}
for attr_name, attr_value in dct.items(): for attr_name, attr_value in dct.items():
if hasattr(attr_value, "__is_start_method__"): if hasattr(attr_value, "__is_start_method__"):
@@ -116,10 +119,24 @@ class FlowMeta(type):
listeners[attr_name] = (condition_type, methods) listeners[attr_name] = (condition_type, methods)
elif hasattr(attr_value, "__is_router__"): elif hasattr(attr_value, "__is_router__"):
routers[attr_value.__router_for__] = attr_name routers[attr_value.__router_for__] = attr_name
if hasattr(attr_value, "__router_paths__"):
router_paths[attr_name] = attr_value.__router_paths__
# **Register router as a listener to its triggering method**
trigger_method_name = attr_value.__router_for__
methods = [trigger_method_name]
condition_type = "OR"
listeners[attr_name] = (condition_type, methods)
setattr(cls, "_start_methods", start_methods) setattr(cls, "_start_methods", start_methods)
setattr(cls, "_listeners", listeners) setattr(cls, "_listeners", listeners)
setattr(cls, "_routers", routers) setattr(cls, "_routers", routers)
setattr(cls, "_router_paths", router_paths)
print("Start methods:", start_methods)
print("Listeners:", listeners)
print("Routers:", routers)
print("Router paths:", router_paths)
return cls return cls
@@ -128,6 +145,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
_start_methods: List[str] = [] _start_methods: List[str] = []
_listeners: Dict[str, tuple[str, List[str]]] = {} _listeners: Dict[str, tuple[str, List[str]]] = {}
_routers: Dict[str, str] = {} _routers: Dict[str, str] = {}
_router_paths: Dict[str, List[str]] = {}
initial_state: Union[Type[T], T, None] = None initial_state: Union[Type[T], T, None] = None
def __class_getitem__(cls, item): def __class_getitem__(cls, item):

View File

@@ -13,8 +13,10 @@ class FlowVisualizer(ABC):
"bg": "#FFFFFF", "bg": "#FFFFFF",
"start": "#FF5A50", "start": "#FF5A50",
"method": "#333333", "method": "#333333",
"router": "#FF8C00", "router": "#333333", # Dark gray for router background
"router_border": "#FF8C00", # Orange for router border
"edge": "#666666", "edge": "#666666",
"router_edge": "#FF8C00", # Orange for router edges
"text": "#FFFFFF", "text": "#FFFFFF",
} }
self.node_styles = { self.node_styles = {
@@ -32,6 +34,10 @@ class FlowVisualizer(ABC):
"color": self.colors["router"], "color": self.colors["router"],
"shape": "box", "shape": "box",
"font": {"color": self.colors["text"]}, "font": {"color": self.colors["text"]},
"borderWidth": 2,
"borderWidthSelected": 4,
"borderDashes": [5, 5], # Dashed border
"borderColor": self.colors["router_border"],
}, },
} }
@@ -52,6 +58,7 @@ class PyvisFlowVisualizer(FlowVisualizer):
# Calculate levels for nodes # Calculate levels for nodes
node_levels = self._calculate_node_levels() node_levels = self._calculate_node_levels()
print("node_levels", node_levels)
# Assign positions to nodes based on levels # Assign positions to nodes based on levels
y_spacing = 150 # Adjust spacing between levels (positive for top-down) y_spacing = 150 # Adjust spacing between levels (positive for top-down)
@@ -61,8 +68,11 @@ class PyvisFlowVisualizer(FlowVisualizer):
for method_name, level in node_levels.items(): for method_name, level in node_levels.items():
level_nodes.setdefault(level, []).append(method_name) level_nodes.setdefault(level, []).append(method_name)
print("level_nodes", level_nodes)
# Compute positions # Compute positions
for level, nodes in level_nodes.items(): for level, nodes in level_nodes.items():
print("level", level, "nodes", nodes)
x_offset = -(len(nodes) - 1) * x_spacing / 2 # Center nodes horizontally x_offset = -(len(nodes) - 1) * x_spacing / 2 # Center nodes horizontally
for i, method_name in enumerate(nodes): for i, method_name in enumerate(nodes):
x = x_offset + i * x_spacing x = x_offset + i * x_spacing
@@ -85,21 +95,47 @@ class PyvisFlowVisualizer(FlowVisualizer):
**node_style, **node_style,
) )
# Add edges with curved lines # Add edges
for method_name in self.flow._listeners: for method_name in self.flow._listeners:
condition_type, trigger_methods = self.flow._listeners[method_name] condition_type, trigger_methods = self.flow._listeners[method_name]
is_and_condition = condition_type == "AND" is_and_condition = condition_type == "AND"
for trigger in trigger_methods: for trigger in trigger_methods:
if trigger in self.flow._methods: if trigger in self.flow._methods:
net.add_edge( is_router_edge = (
trigger, trigger in self.flow._routers.values()
method_name, or method_name in self.flow._routers.values()
color=self.colors.get("edge", "#666666"),
width=2,
arrows="to",
dashes=is_and_condition,
smooth={"type": "cubicBezier"},
) )
edge_color = (
self.colors["router_edge"]
if is_router_edge
else self.colors["edge"]
)
edge_style = {
"color": edge_color,
"width": 2,
"arrows": "to",
"dashes": True if is_router_edge or is_and_condition else False,
"smooth": {"type": "cubicBezier"},
}
net.add_edge(trigger, method_name, **edge_style)
# Add edges from router methods to their possible paths
for router_method_name, paths in self.flow._router_paths.items():
for path in paths:
for listener_name, (
condition_type,
trigger_methods,
) in self.flow._listeners.items():
if path in trigger_methods:
edge_style = {
"color": self.colors["router_edge"],
"width": 2,
"arrows": "to",
"dashes": True,
"smooth": {"type": "cubicBezier"},
}
net.add_edge(router_method_name, listener_name, **edge_style)
# Set options for curved edges and disable physics # Set options for curved edges and disable physics
net.set_options( net.set_options(
@@ -138,38 +174,40 @@ class PyvisFlowVisualizer(FlowVisualizer):
# Generate the legend items HTML # Generate the legend items HTML
legend_items = [ legend_items = [
{"label": "Start Method", "color": self.colors.get("start", "#FF5A50")}, {"label": "Start Method", "color": self.colors["start"]},
{"label": "Method", "color": self.colors.get("method", "#333333")}, {"label": "Method", "color": self.colors["method"]},
# {"label": "Router", "color": self.colors.get("router", "#FF8C00")},
{ {
"label": "Trigger", "label": "Router",
"color": self.colors.get("edge", "#666666"), "color": self.colors["router"],
"dashed": False, "border": self.colors["router_border"],
"dashed": True,
}, },
{"label": "Trigger", "color": self.colors["edge"], "dashed": False},
{"label": "AND Trigger", "color": self.colors["edge"], "dashed": True},
{ {
"label": "AND Trigger", "label": "Router Trigger",
"color": self.colors.get("edge", "#666666"), "color": self.colors["router_edge"],
"dashed": True, "dashed": True,
}, },
] ]
legend_items_html = "" legend_items_html = ""
for item in legend_items: for item in legend_items:
if item.get("dashed") is not None: if "border" in item:
if item.get("dashed"): legend_items_html += f"""
legend_items_html += f""" <div class="legend-item">
<div class="legend-item"> <div class="legend-color-box" style="background-color: {item['color']}; border: 2px dashed {item['border']};"></div>
<div class="legend-dashed"></div> <div>{item['label']}</div>
<div>{item['label']}</div> </div>
</div> """
""" elif item.get("dashed") is not None:
else: style = "dashed" if item["dashed"] else "solid"
legend_items_html += f""" legend_items_html += f"""
<div class="legend-item"> <div class="legend-item">
<div class="legend-solid" style="border-bottom: 2px solid {item['color']};"></div> <div class="legend-{style}" style="border-bottom: 2px {style} {item['color']};"></div>
<div>{item['label']}</div> <div>{item['label']}</div>
</div> </div>
""" """
else: else:
legend_items_html += f""" legend_items_html += f"""
<div class="legend-item"> <div class="legend-item">
@@ -205,6 +243,7 @@ class PyvisFlowVisualizer(FlowVisualizer):
levels = {} levels = {}
queue = [] queue = []
visited = set() visited = set()
pending_and_listeners = {}
# Initialize start methods at level 0 # Initialize start methods at level 0
for method_name, method in self.flow._methods.items(): for method_name, method in self.flow._methods.items():
@@ -223,15 +262,46 @@ class PyvisFlowVisualizer(FlowVisualizer):
condition_type, condition_type,
trigger_methods, trigger_methods,
) in self.flow._listeners.items(): ) in self.flow._listeners.items():
if current in trigger_methods: if condition_type == "OR":
if ( if current in trigger_methods:
listener_name not in levels if (
or levels[listener_name] > current_level + 1 listener_name not in levels
): or levels[listener_name] > current_level + 1
levels[listener_name] = current_level + 1 ):
if listener_name not in visited: levels[listener_name] = current_level + 1
queue.append(listener_name) if listener_name not in visited:
queue.append(listener_name)
elif condition_type == "AND":
if listener_name not in pending_and_listeners:
pending_and_listeners[listener_name] = set()
if current in trigger_methods:
pending_and_listeners[listener_name].add(current)
if set(trigger_methods) == pending_and_listeners[listener_name]:
if (
listener_name not in levels
or levels[listener_name] > current_level + 1
):
levels[listener_name] = current_level + 1
if listener_name not in visited:
queue.append(listener_name)
# Handle router connections (same as before)
if current in self.flow._routers.values():
router_method_name = current
paths = self.flow._router_paths.get(router_method_name, [])
for path in paths:
for listener_name, (
condition_type,
trigger_methods,
) in self.flow._listeners.items():
if path in trigger_methods:
if (
listener_name not in levels
or levels[listener_name] > current_level + 1
):
levels[listener_name] = current_level + 1
if listener_name not in visited:
queue.append(listener_name)
return levels return levels