mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 00:28:31 +00:00
In the middle of improving plotting
This commit is contained in:
@@ -31,16 +31,50 @@ def get_possible_return_constants(function):
|
|||||||
print(f"Source code:\n{source}")
|
print(f"Source code:\n{source}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return_values = []
|
return_values = set()
|
||||||
|
dict_definitions = {}
|
||||||
|
|
||||||
|
class DictionaryAssignmentVisitor(ast.NodeVisitor):
|
||||||
|
def visit_Assign(self, node):
|
||||||
|
# Check if this assignment is assigning a dictionary literal to a variable
|
||||||
|
if isinstance(node.value, ast.Dict) and len(node.targets) == 1:
|
||||||
|
target = node.targets[0]
|
||||||
|
if isinstance(target, ast.Name):
|
||||||
|
var_name = target.id
|
||||||
|
dict_values = []
|
||||||
|
# Extract string values from the dictionary
|
||||||
|
for val in node.value.values:
|
||||||
|
if isinstance(val, ast.Constant) and isinstance(val.value, str):
|
||||||
|
dict_values.append(val.value)
|
||||||
|
# If non-string, skip or just ignore
|
||||||
|
if dict_values:
|
||||||
|
dict_definitions[var_name] = dict_values
|
||||||
|
self.generic_visit(node)
|
||||||
|
|
||||||
class ReturnVisitor(ast.NodeVisitor):
|
class ReturnVisitor(ast.NodeVisitor):
|
||||||
def visit_Return(self, node):
|
def visit_Return(self, node):
|
||||||
# Check if the return value is a constant (Python 3.8+)
|
# Direct string return
|
||||||
if isinstance(node.value, ast.Constant):
|
if isinstance(node.value, ast.Constant) and isinstance(
|
||||||
return_values.append(node.value.value)
|
node.value.value, str
|
||||||
|
):
|
||||||
|
return_values.add(node.value.value)
|
||||||
|
# Dictionary-based return, like return paths[result]
|
||||||
|
elif isinstance(node.value, ast.Subscript):
|
||||||
|
# Check if we're subscripting a known dictionary variable
|
||||||
|
if isinstance(node.value.value, ast.Name):
|
||||||
|
var_name = node.value.value.id
|
||||||
|
if var_name in dict_definitions:
|
||||||
|
# Add all possible dictionary values
|
||||||
|
for v in dict_definitions[var_name]:
|
||||||
|
return_values.add(v)
|
||||||
|
self.generic_visit(node)
|
||||||
|
|
||||||
|
# First pass: identify dictionary assignments
|
||||||
|
DictionaryAssignmentVisitor().visit(code_ast)
|
||||||
|
# Second pass: identify returns
|
||||||
ReturnVisitor().visit(code_ast)
|
ReturnVisitor().visit(code_ast)
|
||||||
return return_values
|
|
||||||
|
return list(return_values) if return_values else None
|
||||||
|
|
||||||
|
|
||||||
def calculate_node_levels(flow):
|
def calculate_node_levels(flow):
|
||||||
@@ -61,10 +95,7 @@ def calculate_node_levels(flow):
|
|||||||
current_level = levels[current]
|
current_level = levels[current]
|
||||||
visited.add(current)
|
visited.add(current)
|
||||||
|
|
||||||
for listener_name, (
|
for listener_name, (condition_type, trigger_methods) in flow._listeners.items():
|
||||||
condition_type,
|
|
||||||
trigger_methods,
|
|
||||||
) in flow._listeners.items():
|
|
||||||
if condition_type == "OR":
|
if condition_type == "OR":
|
||||||
if current in trigger_methods:
|
if current in trigger_methods:
|
||||||
if (
|
if (
|
||||||
@@ -89,7 +120,7 @@ def calculate_node_levels(flow):
|
|||||||
queue.append(listener_name)
|
queue.append(listener_name)
|
||||||
|
|
||||||
# Handle router connections
|
# Handle router connections
|
||||||
if current in flow._routers.values():
|
if current in flow._routers:
|
||||||
router_method_name = current
|
router_method_name = current
|
||||||
paths = flow._router_paths.get(router_method_name, [])
|
paths = flow._router_paths.get(router_method_name, [])
|
||||||
for path in paths:
|
for path in paths:
|
||||||
@@ -105,6 +136,7 @@ def calculate_node_levels(flow):
|
|||||||
levels[listener_name] = current_level + 1
|
levels[listener_name] = current_level + 1
|
||||||
if listener_name not in visited:
|
if listener_name not in visited:
|
||||||
queue.append(listener_name)
|
queue.append(listener_name)
|
||||||
|
|
||||||
return levels
|
return levels
|
||||||
|
|
||||||
|
|
||||||
@@ -142,7 +174,7 @@ def dfs_ancestors(node, ancestors, visited, flow):
|
|||||||
dfs_ancestors(listener_name, ancestors, visited, flow)
|
dfs_ancestors(listener_name, ancestors, visited, flow)
|
||||||
|
|
||||||
# Handle router methods separately
|
# Handle router methods separately
|
||||||
if node in flow._routers.values():
|
if node in flow._routers:
|
||||||
router_method_name = node
|
router_method_name = node
|
||||||
paths = flow._router_paths.get(router_method_name, [])
|
paths = flow._router_paths.get(router_method_name, [])
|
||||||
for path in paths:
|
for path in paths:
|
||||||
|
|||||||
@@ -99,7 +99,8 @@ def add_edges(net, flow, node_positions, colors):
|
|||||||
is_and_condition = condition_type == "AND"
|
is_and_condition = condition_type == "AND"
|
||||||
|
|
||||||
for trigger in trigger_methods:
|
for trigger in trigger_methods:
|
||||||
if trigger in flow._methods or trigger in flow._routers.values():
|
# Check if nodes exist before adding edges
|
||||||
|
if trigger in node_positions and method_name in node_positions:
|
||||||
is_router_edge = any(
|
is_router_edge = any(
|
||||||
trigger in paths for paths in flow._router_paths.values()
|
trigger in paths for paths in flow._router_paths.values()
|
||||||
)
|
)
|
||||||
@@ -135,6 +136,11 @@ def add_edges(net, flow, node_positions, colors):
|
|||||||
}
|
}
|
||||||
|
|
||||||
net.add_edge(trigger, method_name, **edge_style)
|
net.add_edge(trigger, method_name, **edge_style)
|
||||||
|
else:
|
||||||
|
# Print a warning if a node does not exist.
|
||||||
|
print(
|
||||||
|
f"Warning: No node found for '{trigger}' or '{method_name}'. Skipping edge."
|
||||||
|
)
|
||||||
|
|
||||||
for router_method_name, paths in flow._router_paths.items():
|
for router_method_name, paths in flow._router_paths.items():
|
||||||
for path in paths:
|
for path in paths:
|
||||||
@@ -143,36 +149,47 @@ def add_edges(net, flow, node_positions, colors):
|
|||||||
trigger_methods,
|
trigger_methods,
|
||||||
) in flow._listeners.items():
|
) in flow._listeners.items():
|
||||||
if path in trigger_methods:
|
if path in trigger_methods:
|
||||||
is_cycle_edge = is_ancestor(trigger, method_name, ancestors)
|
if (
|
||||||
parent_has_multiple_children = (
|
router_method_name in node_positions
|
||||||
len(parent_children.get(router_method_name, [])) > 1
|
and listener_name in node_positions
|
||||||
)
|
):
|
||||||
needs_curvature = is_cycle_edge or parent_has_multiple_children
|
is_cycle_edge = is_ancestor(
|
||||||
|
router_method_name, listener_name, ancestors
|
||||||
|
)
|
||||||
|
parent_has_multiple_children = (
|
||||||
|
len(parent_children.get(router_method_name, [])) > 1
|
||||||
|
)
|
||||||
|
needs_curvature = is_cycle_edge or parent_has_multiple_children
|
||||||
|
|
||||||
if needs_curvature:
|
if needs_curvature:
|
||||||
source_pos = node_positions.get(router_method_name)
|
source_pos = node_positions.get(router_method_name)
|
||||||
target_pos = node_positions.get(listener_name)
|
target_pos = node_positions.get(listener_name)
|
||||||
|
|
||||||
if source_pos and target_pos:
|
if source_pos and target_pos:
|
||||||
dx = target_pos[0] - source_pos[0]
|
dx = target_pos[0] - source_pos[0]
|
||||||
smooth_type = "curvedCCW" if dx <= 0 else "curvedCW"
|
smooth_type = "curvedCCW" if dx <= 0 else "curvedCW"
|
||||||
index = get_child_index(
|
index = get_child_index(
|
||||||
router_method_name, listener_name, parent_children
|
router_method_name, listener_name, parent_children
|
||||||
)
|
)
|
||||||
edge_smooth = {
|
edge_smooth = {
|
||||||
"type": smooth_type,
|
"type": smooth_type,
|
||||||
"roundness": 0.2 + (0.1 * index),
|
"roundness": 0.2 + (0.1 * index),
|
||||||
}
|
}
|
||||||
|
else:
|
||||||
|
edge_smooth = {"type": "cubicBezier"}
|
||||||
else:
|
else:
|
||||||
edge_smooth = {"type": "cubicBezier"}
|
edge_smooth = False
|
||||||
else:
|
|
||||||
edge_smooth = False
|
|
||||||
|
|
||||||
edge_style = {
|
edge_style = {
|
||||||
"color": colors["router_edge"],
|
"color": colors["router_edge"],
|
||||||
"width": 2,
|
"width": 2,
|
||||||
"arrows": "to",
|
"arrows": "to",
|
||||||
"dashes": True,
|
"dashes": True,
|
||||||
"smooth": edge_smooth,
|
"smooth": edge_smooth,
|
||||||
}
|
}
|
||||||
net.add_edge(router_method_name, listener_name, **edge_style)
|
net.add_edge(router_method_name, listener_name, **edge_style)
|
||||||
|
else:
|
||||||
|
# Print a warning if a node does not exist.
|
||||||
|
print(
|
||||||
|
f"Warning: No node found for '{router_method_name}' or '{listener_name}'. Skipping edge."
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user