properly identifying router and router children nodes. Need to fix color

This commit is contained in:
Brandon Hancock
2024-09-30 11:05:46 -04:00
parent 5d645cd89f
commit 66e7fc5ce3
3 changed files with 130 additions and 42 deletions

View File

@@ -58,10 +58,12 @@ def listen(condition):
return decorator
def router(method):
def router(method, paths=None):
def decorator(func):
func.__is_router__ = True
func.__router_for__ = method.__name__
if paths:
func.__router_paths__ = paths
return func
return decorator
@@ -102,6 +104,7 @@ class FlowMeta(type):
start_methods = []
listeners = {}
routers = {}
router_paths = {}
for attr_name, attr_value in dct.items():
if hasattr(attr_value, "__is_start_method__"):
@@ -116,10 +119,24 @@ class FlowMeta(type):
listeners[attr_name] = (condition_type, methods)
elif hasattr(attr_value, "__is_router__"):
routers[attr_value.__router_for__] = attr_name
if hasattr(attr_value, "__router_paths__"):
router_paths[attr_name] = attr_value.__router_paths__
# **Register router as a listener to its triggering method**
trigger_method_name = attr_value.__router_for__
methods = [trigger_method_name]
condition_type = "OR"
listeners[attr_name] = (condition_type, methods)
setattr(cls, "_start_methods", start_methods)
setattr(cls, "_listeners", listeners)
setattr(cls, "_routers", routers)
setattr(cls, "_router_paths", router_paths)
print("Start methods:", start_methods)
print("Listeners:", listeners)
print("Routers:", routers)
print("Router paths:", router_paths)
return cls
@@ -128,6 +145,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
_start_methods: List[str] = []
_listeners: Dict[str, tuple[str, List[str]]] = {}
_routers: Dict[str, str] = {}
_router_paths: Dict[str, List[str]] = {}
initial_state: Union[Type[T], T, None] = None
def __class_getitem__(cls, item):

View File

@@ -13,8 +13,10 @@ class FlowVisualizer(ABC):
"bg": "#FFFFFF",
"start": "#FF5A50",
"method": "#333333",
"router": "#FF8C00",
"router": "#333333", # Dark gray for router background
"router_border": "#FF8C00", # Orange for router border
"edge": "#666666",
"router_edge": "#FF8C00", # Orange for router edges
"text": "#FFFFFF",
}
self.node_styles = {
@@ -32,6 +34,10 @@ class FlowVisualizer(ABC):
"color": self.colors["router"],
"shape": "box",
"font": {"color": self.colors["text"]},
"borderWidth": 2,
"borderWidthSelected": 4,
"borderDashes": [5, 5], # Dashed border
"borderColor": self.colors["router_border"],
},
}
@@ -52,6 +58,7 @@ class PyvisFlowVisualizer(FlowVisualizer):
# Calculate levels for nodes
node_levels = self._calculate_node_levels()
print("node_levels", node_levels)
# Assign positions to nodes based on levels
y_spacing = 150 # Adjust spacing between levels (positive for top-down)
@@ -61,8 +68,11 @@ class PyvisFlowVisualizer(FlowVisualizer):
for method_name, level in node_levels.items():
level_nodes.setdefault(level, []).append(method_name)
print("level_nodes", level_nodes)
# Compute positions
for level, nodes in level_nodes.items():
print("level", level, "nodes", nodes)
x_offset = -(len(nodes) - 1) * x_spacing / 2 # Center nodes horizontally
for i, method_name in enumerate(nodes):
x = x_offset + i * x_spacing
@@ -85,21 +95,47 @@ class PyvisFlowVisualizer(FlowVisualizer):
**node_style,
)
# Add edges with curved lines
# Add edges
for method_name in self.flow._listeners:
condition_type, trigger_methods = self.flow._listeners[method_name]
is_and_condition = condition_type == "AND"
for trigger in trigger_methods:
if trigger in self.flow._methods:
net.add_edge(
trigger,
method_name,
color=self.colors.get("edge", "#666666"),
width=2,
arrows="to",
dashes=is_and_condition,
smooth={"type": "cubicBezier"},
is_router_edge = (
trigger in self.flow._routers.values()
or method_name in self.flow._routers.values()
)
edge_color = (
self.colors["router_edge"]
if is_router_edge
else self.colors["edge"]
)
edge_style = {
"color": edge_color,
"width": 2,
"arrows": "to",
"dashes": True if is_router_edge or is_and_condition else False,
"smooth": {"type": "cubicBezier"},
}
net.add_edge(trigger, method_name, **edge_style)
# Add edges from router methods to their possible paths
for router_method_name, paths in self.flow._router_paths.items():
for path in paths:
for listener_name, (
condition_type,
trigger_methods,
) in self.flow._listeners.items():
if path in trigger_methods:
edge_style = {
"color": self.colors["router_edge"],
"width": 2,
"arrows": "to",
"dashes": True,
"smooth": {"type": "cubicBezier"},
}
net.add_edge(router_method_name, listener_name, **edge_style)
# Set options for curved edges and disable physics
net.set_options(
@@ -138,38 +174,40 @@ class PyvisFlowVisualizer(FlowVisualizer):
# Generate the legend items HTML
legend_items = [
{"label": "Start Method", "color": self.colors.get("start", "#FF5A50")},
{"label": "Method", "color": self.colors.get("method", "#333333")},
# {"label": "Router", "color": self.colors.get("router", "#FF8C00")},
{"label": "Start Method", "color": self.colors["start"]},
{"label": "Method", "color": self.colors["method"]},
{
"label": "Trigger",
"color": self.colors.get("edge", "#666666"),
"dashed": False,
"label": "Router",
"color": self.colors["router"],
"border": self.colors["router_border"],
"dashed": True,
},
{"label": "Trigger", "color": self.colors["edge"], "dashed": False},
{"label": "AND Trigger", "color": self.colors["edge"], "dashed": True},
{
"label": "AND Trigger",
"color": self.colors.get("edge", "#666666"),
"label": "Router Trigger",
"color": self.colors["router_edge"],
"dashed": True,
},
]
legend_items_html = ""
for item in legend_items:
if item.get("dashed") is not None:
if item.get("dashed"):
legend_items_html += f"""
<div class="legend-item">
<div class="legend-dashed"></div>
<div>{item['label']}</div>
</div>
"""
else:
legend_items_html += f"""
<div class="legend-item">
<div class="legend-solid" style="border-bottom: 2px solid {item['color']};"></div>
<div>{item['label']}</div>
</div>
"""
if "border" in item:
legend_items_html += f"""
<div class="legend-item">
<div class="legend-color-box" style="background-color: {item['color']}; border: 2px dashed {item['border']};"></div>
<div>{item['label']}</div>
</div>
"""
elif item.get("dashed") is not None:
style = "dashed" if item["dashed"] else "solid"
legend_items_html += f"""
<div class="legend-item">
<div class="legend-{style}" style="border-bottom: 2px {style} {item['color']};"></div>
<div>{item['label']}</div>
</div>
"""
else:
legend_items_html += f"""
<div class="legend-item">
@@ -205,6 +243,7 @@ class PyvisFlowVisualizer(FlowVisualizer):
levels = {}
queue = []
visited = set()
pending_and_listeners = {}
# Initialize start methods at level 0
for method_name, method in self.flow._methods.items():
@@ -223,15 +262,46 @@ class PyvisFlowVisualizer(FlowVisualizer):
condition_type,
trigger_methods,
) in self.flow._listeners.items():
if current in trigger_methods:
if (
listener_name not in levels
or levels[listener_name] > current_level + 1
):
levels[listener_name] = current_level + 1
if listener_name not in visited:
queue.append(listener_name)
if condition_type == "OR":
if current in trigger_methods:
if (
listener_name not in levels
or levels[listener_name] > current_level + 1
):
levels[listener_name] = current_level + 1
if listener_name not in visited:
queue.append(listener_name)
elif condition_type == "AND":
if listener_name not in pending_and_listeners:
pending_and_listeners[listener_name] = set()
if current in trigger_methods:
pending_and_listeners[listener_name].add(current)
if set(trigger_methods) == pending_and_listeners[listener_name]:
if (
listener_name not in levels
or levels[listener_name] > current_level + 1
):
levels[listener_name] = current_level + 1
if listener_name not in visited:
queue.append(listener_name)
# Handle router connections (same as before)
if current in self.flow._routers.values():
router_method_name = current
paths = self.flow._router_paths.get(router_method_name, [])
for path in paths:
for listener_name, (
condition_type,
trigger_methods,
) in self.flow._listeners.items():
if path in trigger_methods:
if (
listener_name not in levels
or levels[listener_name] > current_level + 1
):
levels[listener_name] = current_level + 1
if listener_name not in visited:
queue.append(listener_name)
return levels