In the middle of improving plotting

This commit is contained in:
Brandon Hancock
2024-12-17 09:26:56 -05:00
parent 8eea6bc090
commit 10706ce2ee
2 changed files with 90 additions and 41 deletions

View File

@@ -31,16 +31,50 @@ def get_possible_return_constants(function):
print(f"Source code:\n{source}")
return None
return_values = []
return_values = set()
dict_definitions = {}
class DictionaryAssignmentVisitor(ast.NodeVisitor):
def visit_Assign(self, node):
# Check if this assignment is assigning a dictionary literal to a variable
if isinstance(node.value, ast.Dict) and len(node.targets) == 1:
target = node.targets[0]
if isinstance(target, ast.Name):
var_name = target.id
dict_values = []
# Extract string values from the dictionary
for val in node.value.values:
if isinstance(val, ast.Constant) and isinstance(val.value, str):
dict_values.append(val.value)
# If non-string, skip or just ignore
if dict_values:
dict_definitions[var_name] = dict_values
self.generic_visit(node)
class ReturnVisitor(ast.NodeVisitor):
def visit_Return(self, node):
# Check if the return value is a constant (Python 3.8+)
if isinstance(node.value, ast.Constant):
return_values.append(node.value.value)
# Direct string return
if isinstance(node.value, ast.Constant) and isinstance(
node.value.value, str
):
return_values.add(node.value.value)
# Dictionary-based return, like return paths[result]
elif isinstance(node.value, ast.Subscript):
# Check if we're subscripting a known dictionary variable
if isinstance(node.value.value, ast.Name):
var_name = node.value.value.id
if var_name in dict_definitions:
# Add all possible dictionary values
for v in dict_definitions[var_name]:
return_values.add(v)
self.generic_visit(node)
# First pass: identify dictionary assignments
DictionaryAssignmentVisitor().visit(code_ast)
# Second pass: identify returns
ReturnVisitor().visit(code_ast)
return return_values
return list(return_values) if return_values else None
def calculate_node_levels(flow):
@@ -61,10 +95,7 @@ def calculate_node_levels(flow):
current_level = levels[current]
visited.add(current)
for listener_name, (
condition_type,
trigger_methods,
) in flow._listeners.items():
for listener_name, (condition_type, trigger_methods) in flow._listeners.items():
if condition_type == "OR":
if current in trigger_methods:
if (
@@ -89,7 +120,7 @@ def calculate_node_levels(flow):
queue.append(listener_name)
# Handle router connections
if current in flow._routers.values():
if current in flow._routers:
router_method_name = current
paths = flow._router_paths.get(router_method_name, [])
for path in paths:
@@ -105,6 +136,7 @@ def calculate_node_levels(flow):
levels[listener_name] = current_level + 1
if listener_name not in visited:
queue.append(listener_name)
return levels
@@ -142,7 +174,7 @@ def dfs_ancestors(node, ancestors, visited, flow):
dfs_ancestors(listener_name, ancestors, visited, flow)
# Handle router methods separately
if node in flow._routers.values():
if node in flow._routers:
router_method_name = node
paths = flow._router_paths.get(router_method_name, [])
for path in paths:

View File

@@ -99,7 +99,8 @@ def add_edges(net, flow, node_positions, colors):
is_and_condition = condition_type == "AND"
for trigger in trigger_methods:
if trigger in flow._methods or trigger in flow._routers.values():
# Check if nodes exist before adding edges
if trigger in node_positions and method_name in node_positions:
is_router_edge = any(
trigger in paths for paths in flow._router_paths.values()
)
@@ -135,6 +136,11 @@ def add_edges(net, flow, node_positions, colors):
}
net.add_edge(trigger, method_name, **edge_style)
else:
# Print a warning if a node does not exist.
print(
f"Warning: No node found for '{trigger}' or '{method_name}'. Skipping edge."
)
for router_method_name, paths in flow._router_paths.items():
for path in paths:
@@ -143,36 +149,47 @@ def add_edges(net, flow, node_positions, colors):
trigger_methods,
) in flow._listeners.items():
if path in trigger_methods:
is_cycle_edge = is_ancestor(trigger, method_name, ancestors)
parent_has_multiple_children = (
len(parent_children.get(router_method_name, [])) > 1
)
needs_curvature = is_cycle_edge or parent_has_multiple_children
if (
router_method_name in node_positions
and listener_name in node_positions
):
is_cycle_edge = is_ancestor(
router_method_name, listener_name, ancestors
)
parent_has_multiple_children = (
len(parent_children.get(router_method_name, [])) > 1
)
needs_curvature = is_cycle_edge or parent_has_multiple_children
if needs_curvature:
source_pos = node_positions.get(router_method_name)
target_pos = node_positions.get(listener_name)
if needs_curvature:
source_pos = node_positions.get(router_method_name)
target_pos = node_positions.get(listener_name)
if source_pos and target_pos:
dx = target_pos[0] - source_pos[0]
smooth_type = "curvedCCW" if dx <= 0 else "curvedCW"
index = get_child_index(
router_method_name, listener_name, parent_children
)
edge_smooth = {
"type": smooth_type,
"roundness": 0.2 + (0.1 * index),
}
if source_pos and target_pos:
dx = target_pos[0] - source_pos[0]
smooth_type = "curvedCCW" if dx <= 0 else "curvedCW"
index = get_child_index(
router_method_name, listener_name, parent_children
)
edge_smooth = {
"type": smooth_type,
"roundness": 0.2 + (0.1 * index),
}
else:
edge_smooth = {"type": "cubicBezier"}
else:
edge_smooth = {"type": "cubicBezier"}
else:
edge_smooth = False
edge_smooth = False
edge_style = {
"color": colors["router_edge"],
"width": 2,
"arrows": "to",
"dashes": True,
"smooth": edge_smooth,
}
net.add_edge(router_method_name, listener_name, **edge_style)
edge_style = {
"color": colors["router_edge"],
"width": 2,
"arrows": "to",
"dashes": True,
"smooth": edge_smooth,
}
net.add_edge(router_method_name, listener_name, **edge_style)
else:
# Print a warning if a node does not exist.
print(
f"Warning: No node found for '{router_method_name}' or '{listener_name}'. Skipping edge."
)