diff --git a/src/crewai/flow/utils.py b/src/crewai/flow/utils.py index c0686222f..81f3c1041 100644 --- a/src/crewai/flow/utils.py +++ b/src/crewai/flow/utils.py @@ -16,7 +16,8 @@ Example import ast import inspect import textwrap -from typing import Any, Dict, List, Optional, Set, Union +from collections import defaultdict, deque +from typing import Any, Deque, Dict, List, Optional, Set, Union def get_possible_return_constants(function: Any) -> Optional[List[str]]: @@ -118,7 +119,7 @@ def calculate_node_levels(flow: Any) -> Dict[str, int]: - Processes router paths separately """ levels: Dict[str, int] = {} - queue: List[str] = [] + queue: Deque[str] = deque() visited: Set[str] = set() pending_and_listeners: Dict[str, Set[str]] = {} @@ -128,28 +129,35 @@ def calculate_node_levels(flow: Any) -> Dict[str, int]: levels[method_name] = 0 queue.append(method_name) + # Precompute listener dependencies + or_listeners = defaultdict(list) + and_listeners = defaultdict(set) + for listener_name, (condition_type, trigger_methods) in flow._listeners.items(): + if condition_type == "OR": + for method in trigger_methods: + or_listeners[method].append(listener_name) + elif condition_type == "AND": + and_listeners[listener_name] = set(trigger_methods) + # Breadth-first traversal to assign levels while queue: - current = queue.pop(0) + current = queue.popleft() current_level = levels[current] visited.add(current) - for listener_name, (condition_type, trigger_methods) in flow._listeners.items(): - if condition_type == "OR": - if current in trigger_methods: - if ( - listener_name not in levels - or levels[listener_name] > current_level + 1 - ): - levels[listener_name] = current_level + 1 - if listener_name not in visited: - queue.append(listener_name) - elif condition_type == "AND": + for listener_name in or_listeners[current]: + if listener_name not in levels or levels[listener_name] > current_level + 1: + levels[listener_name] = current_level + 1 + if listener_name not in visited: + queue.append(listener_name) + + for listener_name, required_methods in and_listeners.items(): + if current in required_methods: if listener_name not in pending_and_listeners: pending_and_listeners[listener_name] = set() - if current in trigger_methods: - pending_and_listeners[listener_name].add(current) - if set(trigger_methods) == pending_and_listeners[listener_name]: + pending_and_listeners[listener_name].add(current) + + if required_methods == pending_and_listeners[listener_name]: if ( listener_name not in levels or levels[listener_name] > current_level + 1 @@ -159,22 +167,7 @@ def calculate_node_levels(flow: Any) -> Dict[str, int]: queue.append(listener_name) # Handle router connections - if current in flow._routers: - router_method_name = current - paths = flow._router_paths.get(router_method_name, []) - for path in paths: - for listener_name, ( - condition_type, - trigger_methods, - ) in flow._listeners.items(): - if path in trigger_methods: - if ( - listener_name not in levels - or levels[listener_name] > current_level + 1 - ): - levels[listener_name] = current_level + 1 - if listener_name not in visited: - queue.append(listener_name) + process_router_paths(flow, current, current_level, levels, queue) return levels @@ -227,10 +220,7 @@ def build_ancestor_dict(flow: Any) -> Dict[str, Set[str]]: def dfs_ancestors( - node: str, - ancestors: Dict[str, Set[str]], - visited: Set[str], - flow: Any + node: str, ancestors: Dict[str, Set[str]], visited: Set[str], flow: Any ) -> None: """ Perform depth-first search to build ancestor relationships. @@ -274,7 +264,9 @@ def dfs_ancestors( dfs_ancestors(listener_name, ancestors, visited, flow) -def is_ancestor(node: str, ancestor_candidate: str, ancestors: Dict[str, Set[str]]) -> bool: +def is_ancestor( + node: str, ancestor_candidate: str, ancestors: Dict[str, Set[str]] +) -> bool: """ Check if one node is an ancestor of another. @@ -339,7 +331,9 @@ def build_parent_children_dict(flow: Any) -> Dict[str, List[str]]: return parent_children -def get_child_index(parent: str, child: str, parent_children: Dict[str, List[str]]) -> int: +def get_child_index( + parent: str, child: str, parent_children: Dict[str, List[str]] +) -> int: """ Get the index of a child node in its parent's sorted children list. @@ -360,3 +354,23 @@ def get_child_index(parent: str, child: str, parent_children: Dict[str, List[str children = parent_children.get(parent, []) children.sort() return children.index(child) + + +def process_router_paths(flow, current, current_level, levels, queue): + """ + Handle the router connections for the current node. + """ + if current in flow._routers: + paths = flow._router_paths.get(current, []) + for path in paths: + for listener_name, ( + condition_type, + trigger_methods, + ) in flow._listeners.items(): + if path in trigger_methods: + if ( + listener_name not in levels + or levels[listener_name] > current_level + 1 + ): + levels[listener_name] = current_level + 1 + queue.append(listener_name)