diff --git a/src/crewai/flow/assets/arial_bold.ttf b/src/crewai/flow/assets/arial_bold.ttf deleted file mode 100644 index 940e255d0..000000000 Binary files a/src/crewai/flow/assets/arial_bold.ttf and /dev/null differ diff --git a/src/crewai/flow/flow.py b/src/crewai/flow/flow.py index 687816b74..4b784309a 100644 --- a/src/crewai/flow/flow.py +++ b/src/crewai/flow/flow.py @@ -58,10 +58,12 @@ def listen(condition): return decorator -def router(method): +def router(method, paths=None): def decorator(func): func.__is_router__ = True func.__router_for__ = method.__name__ + if paths: + func.__router_paths__ = paths return func return decorator @@ -102,6 +104,7 @@ class FlowMeta(type): start_methods = [] listeners = {} routers = {} + router_paths = {} for attr_name, attr_value in dct.items(): if hasattr(attr_value, "__is_start_method__"): @@ -116,10 +119,24 @@ class FlowMeta(type): listeners[attr_name] = (condition_type, methods) elif hasattr(attr_value, "__is_router__"): routers[attr_value.__router_for__] = attr_name + if hasattr(attr_value, "__router_paths__"): + router_paths[attr_name] = attr_value.__router_paths__ + + # **Register router as a listener to its triggering method** + trigger_method_name = attr_value.__router_for__ + methods = [trigger_method_name] + condition_type = "OR" + listeners[attr_name] = (condition_type, methods) setattr(cls, "_start_methods", start_methods) setattr(cls, "_listeners", listeners) setattr(cls, "_routers", routers) + setattr(cls, "_router_paths", router_paths) + + print("Start methods:", start_methods) + print("Listeners:", listeners) + print("Routers:", routers) + print("Router paths:", router_paths) return cls @@ -128,6 +145,7 @@ class Flow(Generic[T], metaclass=FlowMeta): _start_methods: List[str] = [] _listeners: Dict[str, tuple[str, List[str]]] = {} _routers: Dict[str, str] = {} + _router_paths: Dict[str, List[str]] = {} initial_state: Union[Type[T], T, None] = None def __class_getitem__(cls, item): diff --git a/src/crewai/flow/flow_visualizer.py b/src/crewai/flow/flow_visualizer.py index ef599c436..46d3df602 100644 --- a/src/crewai/flow/flow_visualizer.py +++ b/src/crewai/flow/flow_visualizer.py @@ -13,8 +13,10 @@ class FlowVisualizer(ABC): "bg": "#FFFFFF", "start": "#FF5A50", "method": "#333333", - "router": "#FF8C00", + "router": "#333333", # Dark gray for router background + "router_border": "#FF8C00", # Orange for router border "edge": "#666666", + "router_edge": "#FF8C00", # Orange for router edges "text": "#FFFFFF", } self.node_styles = { @@ -32,6 +34,10 @@ class FlowVisualizer(ABC): "color": self.colors["router"], "shape": "box", "font": {"color": self.colors["text"]}, + "borderWidth": 2, + "borderWidthSelected": 4, + "borderDashes": [5, 5], # Dashed border + "borderColor": self.colors["router_border"], }, } @@ -52,6 +58,7 @@ class PyvisFlowVisualizer(FlowVisualizer): # Calculate levels for nodes node_levels = self._calculate_node_levels() + print("node_levels", node_levels) # Assign positions to nodes based on levels y_spacing = 150 # Adjust spacing between levels (positive for top-down) @@ -61,8 +68,11 @@ class PyvisFlowVisualizer(FlowVisualizer): for method_name, level in node_levels.items(): level_nodes.setdefault(level, []).append(method_name) + print("level_nodes", level_nodes) + # Compute positions for level, nodes in level_nodes.items(): + print("level", level, "nodes", nodes) x_offset = -(len(nodes) - 1) * x_spacing / 2 # Center nodes horizontally for i, method_name in enumerate(nodes): x = x_offset + i * x_spacing @@ -85,21 +95,47 @@ class PyvisFlowVisualizer(FlowVisualizer): **node_style, ) - # Add edges with curved lines + # Add edges for method_name in self.flow._listeners: condition_type, trigger_methods = self.flow._listeners[method_name] is_and_condition = condition_type == "AND" + for trigger in trigger_methods: if trigger in self.flow._methods: - net.add_edge( - trigger, - method_name, - color=self.colors.get("edge", "#666666"), - width=2, - arrows="to", - dashes=is_and_condition, - smooth={"type": "cubicBezier"}, + is_router_edge = ( + trigger in self.flow._routers.values() + or method_name in self.flow._routers.values() ) + edge_color = ( + self.colors["router_edge"] + if is_router_edge + else self.colors["edge"] + ) + edge_style = { + "color": edge_color, + "width": 2, + "arrows": "to", + "dashes": True if is_router_edge or is_and_condition else False, + "smooth": {"type": "cubicBezier"}, + } + net.add_edge(trigger, method_name, **edge_style) + + # Add edges from router methods to their possible paths + for router_method_name, paths in self.flow._router_paths.items(): + for path in paths: + for listener_name, ( + condition_type, + trigger_methods, + ) in self.flow._listeners.items(): + if path in trigger_methods: + edge_style = { + "color": self.colors["router_edge"], + "width": 2, + "arrows": "to", + "dashes": True, + "smooth": {"type": "cubicBezier"}, + } + net.add_edge(router_method_name, listener_name, **edge_style) # Set options for curved edges and disable physics net.set_options( @@ -138,38 +174,40 @@ class PyvisFlowVisualizer(FlowVisualizer): # Generate the legend items HTML legend_items = [ - {"label": "Start Method", "color": self.colors.get("start", "#FF5A50")}, - {"label": "Method", "color": self.colors.get("method", "#333333")}, - # {"label": "Router", "color": self.colors.get("router", "#FF8C00")}, + {"label": "Start Method", "color": self.colors["start"]}, + {"label": "Method", "color": self.colors["method"]}, { - "label": "Trigger", - "color": self.colors.get("edge", "#666666"), - "dashed": False, + "label": "Router", + "color": self.colors["router"], + "border": self.colors["router_border"], + "dashed": True, }, + {"label": "Trigger", "color": self.colors["edge"], "dashed": False}, + {"label": "AND Trigger", "color": self.colors["edge"], "dashed": True}, { - "label": "AND Trigger", - "color": self.colors.get("edge", "#666666"), + "label": "Router Trigger", + "color": self.colors["router_edge"], "dashed": True, }, ] legend_items_html = "" for item in legend_items: - if item.get("dashed") is not None: - if item.get("dashed"): - legend_items_html += f""" -
-
-
{item['label']}
-
- """ - else: - legend_items_html += f""" -
-
-
{item['label']}
-
- """ + if "border" in item: + legend_items_html += f""" +
+
+
{item['label']}
+
+ """ + elif item.get("dashed") is not None: + style = "dashed" if item["dashed"] else "solid" + legend_items_html += f""" +
+
+
{item['label']}
+
+ """ else: legend_items_html += f"""
@@ -205,6 +243,7 @@ class PyvisFlowVisualizer(FlowVisualizer): levels = {} queue = [] visited = set() + pending_and_listeners = {} # Initialize start methods at level 0 for method_name, method in self.flow._methods.items(): @@ -223,15 +262,46 @@ class PyvisFlowVisualizer(FlowVisualizer): condition_type, trigger_methods, ) in self.flow._listeners.items(): - if current in trigger_methods: - if ( - listener_name not in levels - or levels[listener_name] > current_level + 1 - ): - levels[listener_name] = current_level + 1 - if listener_name not in visited: - queue.append(listener_name) + if condition_type == "OR": + if current in trigger_methods: + if ( + listener_name not in levels + or levels[listener_name] > current_level + 1 + ): + levels[listener_name] = current_level + 1 + if listener_name not in visited: + queue.append(listener_name) + elif condition_type == "AND": + if listener_name not in pending_and_listeners: + pending_and_listeners[listener_name] = set() + if current in trigger_methods: + pending_and_listeners[listener_name].add(current) + if set(trigger_methods) == pending_and_listeners[listener_name]: + if ( + listener_name not in levels + or levels[listener_name] > current_level + 1 + ): + levels[listener_name] = current_level + 1 + if listener_name not in visited: + queue.append(listener_name) + # Handle router connections (same as before) + if current in self.flow._routers.values(): + router_method_name = current + paths = self.flow._router_paths.get(router_method_name, []) + for path in paths: + for listener_name, ( + condition_type, + trigger_methods, + ) in self.flow._listeners.items(): + if path in trigger_methods: + if ( + listener_name not in levels + or levels[listener_name] > current_level + 1 + ): + levels[listener_name] = current_level + 1 + if listener_name not in visited: + queue.append(listener_name) return levels