diff --git a/src/crewai/flow/utils.py b/src/crewai/flow/utils.py index 98d03f24f..dc1f611fb 100644 --- a/src/crewai/flow/utils.py +++ b/src/crewai/flow/utils.py @@ -31,16 +31,50 @@ def get_possible_return_constants(function): print(f"Source code:\n{source}") return None - return_values = [] + return_values = set() + dict_definitions = {} + + class DictionaryAssignmentVisitor(ast.NodeVisitor): + def visit_Assign(self, node): + # Check if this assignment is assigning a dictionary literal to a variable + if isinstance(node.value, ast.Dict) and len(node.targets) == 1: + target = node.targets[0] + if isinstance(target, ast.Name): + var_name = target.id + dict_values = [] + # Extract string values from the dictionary + for val in node.value.values: + if isinstance(val, ast.Constant) and isinstance(val.value, str): + dict_values.append(val.value) + # If non-string, skip or just ignore + if dict_values: + dict_definitions[var_name] = dict_values + self.generic_visit(node) class ReturnVisitor(ast.NodeVisitor): def visit_Return(self, node): - # Check if the return value is a constant (Python 3.8+) - if isinstance(node.value, ast.Constant): - return_values.append(node.value.value) + # Direct string return + if isinstance(node.value, ast.Constant) and isinstance( + node.value.value, str + ): + return_values.add(node.value.value) + # Dictionary-based return, like return paths[result] + elif isinstance(node.value, ast.Subscript): + # Check if we're subscripting a known dictionary variable + if isinstance(node.value.value, ast.Name): + var_name = node.value.value.id + if var_name in dict_definitions: + # Add all possible dictionary values + for v in dict_definitions[var_name]: + return_values.add(v) + self.generic_visit(node) + # First pass: identify dictionary assignments + DictionaryAssignmentVisitor().visit(code_ast) + # Second pass: identify returns ReturnVisitor().visit(code_ast) - return return_values + + return list(return_values) if return_values else None def calculate_node_levels(flow): @@ -61,10 +95,7 @@ def calculate_node_levels(flow): current_level = levels[current] visited.add(current) - for listener_name, ( - condition_type, - trigger_methods, - ) in flow._listeners.items(): + for listener_name, (condition_type, trigger_methods) in flow._listeners.items(): if condition_type == "OR": if current in trigger_methods: if ( @@ -89,7 +120,7 @@ def calculate_node_levels(flow): queue.append(listener_name) # Handle router connections - if current in flow._routers.values(): + if current in flow._routers: router_method_name = current paths = flow._router_paths.get(router_method_name, []) for path in paths: @@ -105,6 +136,7 @@ def calculate_node_levels(flow): levels[listener_name] = current_level + 1 if listener_name not in visited: queue.append(listener_name) + return levels @@ -142,7 +174,7 @@ def dfs_ancestors(node, ancestors, visited, flow): dfs_ancestors(listener_name, ancestors, visited, flow) # Handle router methods separately - if node in flow._routers.values(): + if node in flow._routers: router_method_name = node paths = flow._router_paths.get(router_method_name, []) for path in paths: diff --git a/src/crewai/flow/visualization_utils.py b/src/crewai/flow/visualization_utils.py index 5b95a1369..a367ef1db 100644 --- a/src/crewai/flow/visualization_utils.py +++ b/src/crewai/flow/visualization_utils.py @@ -99,7 +99,8 @@ def add_edges(net, flow, node_positions, colors): is_and_condition = condition_type == "AND" for trigger in trigger_methods: - if trigger in flow._methods or trigger in flow._routers.values(): + # Check if nodes exist before adding edges + if trigger in node_positions and method_name in node_positions: is_router_edge = any( trigger in paths for paths in flow._router_paths.values() ) @@ -135,6 +136,11 @@ def add_edges(net, flow, node_positions, colors): } net.add_edge(trigger, method_name, **edge_style) + else: + # Print a warning if a node does not exist. + print( + f"Warning: No node found for '{trigger}' or '{method_name}'. Skipping edge." + ) for router_method_name, paths in flow._router_paths.items(): for path in paths: @@ -143,36 +149,47 @@ def add_edges(net, flow, node_positions, colors): trigger_methods, ) in flow._listeners.items(): if path in trigger_methods: - is_cycle_edge = is_ancestor(trigger, method_name, ancestors) - parent_has_multiple_children = ( - len(parent_children.get(router_method_name, [])) > 1 - ) - needs_curvature = is_cycle_edge or parent_has_multiple_children + if ( + router_method_name in node_positions + and listener_name in node_positions + ): + is_cycle_edge = is_ancestor( + router_method_name, listener_name, ancestors + ) + parent_has_multiple_children = ( + len(parent_children.get(router_method_name, [])) > 1 + ) + needs_curvature = is_cycle_edge or parent_has_multiple_children - if needs_curvature: - source_pos = node_positions.get(router_method_name) - target_pos = node_positions.get(listener_name) + if needs_curvature: + source_pos = node_positions.get(router_method_name) + target_pos = node_positions.get(listener_name) - if source_pos and target_pos: - dx = target_pos[0] - source_pos[0] - smooth_type = "curvedCCW" if dx <= 0 else "curvedCW" - index = get_child_index( - router_method_name, listener_name, parent_children - ) - edge_smooth = { - "type": smooth_type, - "roundness": 0.2 + (0.1 * index), - } + if source_pos and target_pos: + dx = target_pos[0] - source_pos[0] + smooth_type = "curvedCCW" if dx <= 0 else "curvedCW" + index = get_child_index( + router_method_name, listener_name, parent_children + ) + edge_smooth = { + "type": smooth_type, + "roundness": 0.2 + (0.1 * index), + } + else: + edge_smooth = {"type": "cubicBezier"} else: - edge_smooth = {"type": "cubicBezier"} - else: - edge_smooth = False + edge_smooth = False - edge_style = { - "color": colors["router_edge"], - "width": 2, - "arrows": "to", - "dashes": True, - "smooth": edge_smooth, - } - net.add_edge(router_method_name, listener_name, **edge_style) + edge_style = { + "color": colors["router_edge"], + "width": 2, + "arrows": "to", + "dashes": True, + "smooth": edge_smooth, + } + net.add_edge(router_method_name, listener_name, **edge_style) + else: + # Print a warning if a node does not exist. + print( + f"Warning: No node found for '{router_method_name}' or '{listener_name}'. Skipping edge." + )