diff --git a/src/crewai/flow/flow.py b/src/crewai/flow/flow.py index 4a6361cce..806d9ec84 100644 --- a/src/crewai/flow/flow.py +++ b/src/crewai/flow/flow.py @@ -30,7 +30,47 @@ from crewai.telemetry import Telemetry T = TypeVar("T", bound=Union[BaseModel, Dict[str, Any]]) -def start(condition=None): +def start(condition: Optional[Union[str, dict, Callable]] = None) -> Callable: + """ + Marks a method as a flow's starting point. + + This decorator designates a method as an entry point for the flow execution. + It can optionally specify conditions that trigger the start based on other + method executions. + + Parameters + ---------- + condition : Optional[Union[str, dict, Callable]], optional + Defines when the start method should execute. Can be: + - str: Name of a method that triggers this start + - dict: Contains "type" ("AND"/"OR") and "methods" (list of triggers) + - Callable: A method reference that triggers this start + Default is None, meaning unconditional start. + + Returns + ------- + Callable + A decorator function that marks the method as a flow start point. + + Raises + ------ + ValueError + If the condition format is invalid. + + Examples + -------- + >>> @start() # Unconditional start + >>> def begin_flow(self): + ... pass + + >>> @start("method_name") # Start after specific method + >>> def conditional_start(self): + ... pass + + >>> @start(and_("method1", "method2")) # Start after multiple methods + >>> def complex_start(self): + ... pass + """ def decorator(func): func.__is_start_method__ = True if condition is not None: @@ -56,7 +96,42 @@ def start(condition=None): return decorator -def listen(condition): +def listen(condition: Union[str, dict, Callable]) -> Callable: + """ + Creates a listener that executes when specified conditions are met. + + This decorator sets up a method to execute in response to other method + executions in the flow. It supports both simple and complex triggering + conditions. + + Parameters + ---------- + condition : Union[str, dict, Callable] + Specifies when the listener should execute. Can be: + - str: Name of a method that triggers this listener + - dict: Contains "type" ("AND"/"OR") and "methods" (list of triggers) + - Callable: A method reference that triggers this listener + + Returns + ------- + Callable + A decorator function that sets up the method as a listener. + + Raises + ------ + ValueError + If the condition format is invalid. + + Examples + -------- + >>> @listen("process_data") # Listen to single method + >>> def handle_processed_data(self): + ... pass + + >>> @listen(or_("success", "failure")) # Listen to multiple methods + >>> def handle_completion(self): + ... pass + """ def decorator(func): if isinstance(condition, str): func.__trigger_methods__ = [condition] @@ -80,7 +155,47 @@ def listen(condition): return decorator -def router(condition): +def router(condition: Union[str, dict, Callable]) -> Callable: + """ + Creates a routing method that directs flow execution based on conditions. + + This decorator marks a method as a router, which can dynamically determine + the next steps in the flow based on its return value. Routers are triggered + by specified conditions and can return constants that determine which path + the flow should take. + + Parameters + ---------- + condition : Union[str, dict, Callable] + Specifies when the router should execute. Can be: + - str: Name of a method that triggers this router + - dict: Contains "type" ("AND"/"OR") and "methods" (list of triggers) + - Callable: A method reference that triggers this router + + Returns + ------- + Callable + A decorator function that sets up the method as a router. + + Raises + ------ + ValueError + If the condition format is invalid. + + Examples + -------- + >>> @router("check_status") + >>> def route_based_on_status(self): + ... if self.state.status == "success": + ... return SUCCESS + ... return FAILURE + + >>> @router(and_("validate", "process")) + >>> def complex_routing(self): + ... if all([self.state.valid, self.state.processed]): + ... return CONTINUE + ... return STOP + """ def decorator(func): func.__is_router__ = True # Handle conditions like listen/start @@ -106,7 +221,39 @@ def router(condition): return decorator -def or_(*conditions): +def or_(*conditions: Union[str, dict, Callable]) -> dict: + """ + Combines multiple conditions with OR logic for flow control. + + Creates a condition that is satisfied when any of the specified conditions + are met. This is used with @start, @listen, or @router decorators to create + complex triggering conditions. + + Parameters + ---------- + *conditions : Union[str, dict, Callable] + Variable number of conditions that can be: + - str: Method names + - dict: Existing condition dictionaries + - Callable: Method references + + Returns + ------- + dict + A condition dictionary with format: + {"type": "OR", "methods": list_of_method_names} + + Raises + ------ + ValueError + If any condition is invalid. + + Examples + -------- + >>> @listen(or_("success", "timeout")) + >>> def handle_completion(self): + ... pass + """ methods = [] for condition in conditions: if isinstance(condition, dict) and "methods" in condition: @@ -120,7 +267,39 @@ def or_(*conditions): return {"type": "OR", "methods": methods} -def and_(*conditions): +def and_(*conditions: Union[str, dict, Callable]) -> dict: + """ + Combines multiple conditions with AND logic for flow control. + + Creates a condition that is satisfied only when all specified conditions + are met. This is used with @start, @listen, or @router decorators to create + complex triggering conditions. + + Parameters + ---------- + *conditions : Union[str, dict, Callable] + Variable number of conditions that can be: + - str: Method names + - dict: Existing condition dictionaries + - Callable: Method references + + Returns + ------- + dict + A condition dictionary with format: + {"type": "AND", "methods": list_of_method_names} + + Raises + ------ + ValueError + If any condition is invalid. + + Examples + -------- + >>> @listen(and_("validated", "processed")) + >>> def handle_complete_data(self): + ... pass + """ methods = [] for condition in conditions: if isinstance(condition, dict) and "methods" in condition: @@ -286,6 +465,23 @@ class Flow(Generic[T], metaclass=FlowMeta): return final_output async def _execute_start_method(self, start_method_name: str) -> None: + """ + Executes a flow's start method and its triggered listeners. + + This internal method handles the execution of methods marked with @start + decorator and manages the subsequent chain of listener executions. + + Parameters + ---------- + start_method_name : str + The name of the start method to execute. + + Notes + ----- + - Executes the start method and captures its result + - Triggers execution of any listeners waiting on this start method + - Part of the flow's initialization sequence + """ result = await self._execute_method( start_method_name, self._methods[start_method_name] ) @@ -306,6 +502,28 @@ class Flow(Generic[T], metaclass=FlowMeta): return result async def _execute_listeners(self, trigger_method: str, result: Any) -> None: + """ + Executes all listeners and routers triggered by a method completion. + + This internal method manages the execution flow by: + 1. First executing all triggered routers sequentially + 2. Then executing all triggered listeners in parallel + + Parameters + ---------- + trigger_method : str + The name of the method that triggered these listeners. + result : Any + The result from the triggering method, passed to listeners + that accept parameters. + + Notes + ----- + - Routers are executed sequentially to maintain flow control + - Each router's result becomes the new trigger_method + - Normal listeners are executed in parallel for efficiency + - Listeners can receive the trigger method's result as a parameter + """ # First, handle routers repeatedly until no router triggers anymore while True: routers_triggered = self._find_triggered_methods( @@ -335,6 +553,33 @@ class Flow(Generic[T], metaclass=FlowMeta): def _find_triggered_methods( self, trigger_method: str, router_only: bool ) -> List[str]: + """ + Finds all methods that should be triggered based on conditions. + + This internal method evaluates both OR and AND conditions to determine + which methods should be executed next in the flow. + + Parameters + ---------- + trigger_method : str + The name of the method that just completed execution. + router_only : bool + If True, only consider router methods. + If False, only consider non-router methods. + + Returns + ------- + List[str] + Names of methods that should be triggered. + + Notes + ----- + - Handles both OR and AND conditions: + * OR: Triggers if any condition is met + * AND: Triggers only when all conditions are met + - Maintains state for AND conditions using _pending_and_listeners + - Separates router and normal listener evaluation + """ triggered = [] for listener_name, (condition_type, methods) in self._listeners.items(): is_router = listener_name in self._routers @@ -363,6 +608,33 @@ class Flow(Generic[T], metaclass=FlowMeta): return triggered async def _execute_single_listener(self, listener_name: str, result: Any) -> None: + """ + Executes a single listener method with proper event handling. + + This internal method manages the execution of an individual listener, + including parameter inspection, event emission, and error handling. + + Parameters + ---------- + listener_name : str + The name of the listener method to execute. + result : Any + The result from the triggering method, which may be passed + to the listener if it accepts parameters. + + Notes + ----- + - Inspects method signature to determine if it accepts the trigger result + - Emits events for method execution start and finish + - Handles errors gracefully with detailed logging + - Recursively triggers listeners of this listener + - Supports both parameterized and parameter-less listeners + + Error Handling + ------------- + Catches and logs any exceptions during execution, preventing + individual listener failures from breaking the entire flow. + """ try: method = self._methods[listener_name] diff --git a/src/crewai/flow/flow_visualizer.py b/src/crewai/flow/flow_visualizer.py index 988f27919..ceacee91f 100644 --- a/src/crewai/flow/flow_visualizer.py +++ b/src/crewai/flow/flow_visualizer.py @@ -1,12 +1,14 @@ # flow_visualizer.py import os +from pathlib import Path from pyvis.network import Network from crewai.flow.config import COLORS, NODE_STYLES from crewai.flow.html_template_handler import HTMLTemplateHandler from crewai.flow.legend_generator import generate_legend_items_html, get_legend_items +from crewai.flow.path_utils import safe_path_join, validate_path_exists from crewai.flow.utils import calculate_node_levels from crewai.flow.visualization_utils import ( add_edges, @@ -17,88 +19,206 @@ from crewai.flow.visualization_utils import ( class FlowPlot: def __init__(self, flow): + """ + Initialize FlowPlot with a flow object. + + Parameters + ---------- + flow : Flow + A Flow instance to visualize. + + Raises + ------ + ValueError + If flow object is invalid or missing required attributes. + """ + if not hasattr(flow, '_methods'): + raise ValueError("Invalid flow object: missing '_methods' attribute") + if not hasattr(flow, '_listeners'): + raise ValueError("Invalid flow object: missing '_listeners' attribute") + if not hasattr(flow, '_start_methods'): + raise ValueError("Invalid flow object: missing '_start_methods' attribute") + self.flow = flow self.colors = COLORS self.node_styles = NODE_STYLES def plot(self, filename): - net = Network( - directed=True, - height="750px", - width="100%", - bgcolor=self.colors["bg"], - layout=None, - ) - - # Set options to disable physics - net.set_options( - """ - var options = { - "nodes": { - "font": { - "multi": "html" - } - }, - "physics": { - "enabled": false - } - } """ - ) + Generate and save an HTML visualization of the flow. - # Calculate levels for nodes - node_levels = calculate_node_levels(self.flow) + Parameters + ---------- + filename : str + Name of the output file (without extension). - # Compute positions - node_positions = compute_positions(self.flow, node_levels) + Raises + ------ + ValueError + If filename is invalid or network generation fails. + IOError + If file operations fail or visualization cannot be generated. + RuntimeError + If network visualization generation fails. + """ + if not filename or not isinstance(filename, str): + raise ValueError("Filename must be a non-empty string") + + try: + # Initialize network + net = Network( + directed=True, + height="750px", + width="100%", + bgcolor=self.colors["bg"], + layout=None, + ) - # Add nodes to the network - add_nodes_to_network(net, self.flow, node_positions, self.node_styles) + # Set options to disable physics + net.set_options( + """ + var options = { + "nodes": { + "font": { + "multi": "html" + } + }, + "physics": { + "enabled": false + } + } + """ + ) - # Add edges to the network - add_edges(net, self.flow, node_positions, self.colors) + # Calculate levels for nodes + try: + node_levels = calculate_node_levels(self.flow) + except Exception as e: + raise ValueError(f"Failed to calculate node levels: {str(e)}") - network_html = net.generate_html() - final_html_content = self._generate_final_html(network_html) + # Compute positions + try: + node_positions = compute_positions(self.flow, node_levels) + except Exception as e: + raise ValueError(f"Failed to compute node positions: {str(e)}") - # Save the final HTML content to the file - with open(f"{filename}.html", "w", encoding="utf-8") as f: - f.write(final_html_content) - print(f"Plot saved as {filename}.html") + # Add nodes to the network + try: + add_nodes_to_network(net, self.flow, node_positions, self.node_styles) + except Exception as e: + raise RuntimeError(f"Failed to add nodes to network: {str(e)}") - self._cleanup_pyvis_lib() + # Add edges to the network + try: + add_edges(net, self.flow, node_positions, self.colors) + except Exception as e: + raise RuntimeError(f"Failed to add edges to network: {str(e)}") + + # Generate HTML + try: + network_html = net.generate_html() + final_html_content = self._generate_final_html(network_html) + except Exception as e: + raise RuntimeError(f"Failed to generate network visualization: {str(e)}") + + # Save the final HTML content to the file + try: + with open(f"{filename}.html", "w", encoding="utf-8") as f: + f.write(final_html_content) + print(f"Plot saved as {filename}.html") + except IOError as e: + raise IOError(f"Failed to save flow visualization to {filename}.html: {str(e)}") + + except (ValueError, RuntimeError, IOError) as e: + raise e + except Exception as e: + raise RuntimeError(f"Unexpected error during flow visualization: {str(e)}") + finally: + self._cleanup_pyvis_lib() def _generate_final_html(self, network_html): - # Extract just the body content from the generated HTML - current_dir = os.path.dirname(__file__) - template_path = os.path.join( - current_dir, "assets", "crewai_flow_visual_template.html" - ) - logo_path = os.path.join(current_dir, "assets", "crewai_logo.svg") + """ + Generate the final HTML content with network visualization and legend. - html_handler = HTMLTemplateHandler(template_path, logo_path) - network_body = html_handler.extract_body_content(network_html) + Parameters + ---------- + network_html : str + HTML content generated by pyvis Network. - # Generate the legend items HTML - legend_items = get_legend_items(self.colors) - legend_items_html = generate_legend_items_html(legend_items) - final_html_content = html_handler.generate_final_html( - network_body, legend_items_html - ) - return final_html_content + Returns + ------- + str + Complete HTML content with styling and legend. + + Raises + ------ + IOError + If template or logo files cannot be accessed. + ValueError + If network_html is invalid. + """ + if not network_html: + raise ValueError("Invalid network HTML content") + + try: + # Extract just the body content from the generated HTML + current_dir = os.path.dirname(__file__) + template_path = safe_path_join("assets", "crewai_flow_visual_template.html", root=current_dir) + logo_path = safe_path_join("assets", "crewai_logo.svg", root=current_dir) + + if not os.path.exists(template_path): + raise IOError(f"Template file not found: {template_path}") + if not os.path.exists(logo_path): + raise IOError(f"Logo file not found: {logo_path}") + + html_handler = HTMLTemplateHandler(template_path, logo_path) + network_body = html_handler.extract_body_content(network_html) + + # Generate the legend items HTML + legend_items = get_legend_items(self.colors) + legend_items_html = generate_legend_items_html(legend_items) + final_html_content = html_handler.generate_final_html( + network_body, legend_items_html + ) + return final_html_content + except Exception as e: + raise IOError(f"Failed to generate visualization HTML: {str(e)}") def _cleanup_pyvis_lib(self): - # Clean up the generated lib folder - lib_folder = os.path.join(os.getcwd(), "lib") + """ + Clean up the generated lib folder from pyvis. + + This method safely removes the temporary lib directory created by pyvis + during network visualization generation. + """ try: + lib_folder = safe_path_join("lib", root=os.getcwd()) if os.path.exists(lib_folder) and os.path.isdir(lib_folder): import shutil - shutil.rmtree(lib_folder) + except ValueError as e: + print(f"Error validating lib folder path: {e}") except Exception as e: - print(f"Error cleaning up {lib_folder}: {e}") + print(f"Error cleaning up lib folder: {e}") def plot_flow(flow, filename="flow_plot"): + """ + Convenience function to create and save a flow visualization. + + Parameters + ---------- + flow : Flow + Flow instance to visualize. + filename : str, optional + Output filename without extension, by default "flow_plot". + + Raises + ------ + ValueError + If flow object or filename is invalid. + IOError + If file operations fail. + """ visualizer = FlowPlot(flow) visualizer.plot(filename) diff --git a/src/crewai/flow/html_template_handler.py b/src/crewai/flow/html_template_handler.py index d521d8cf8..396af5546 100644 --- a/src/crewai/flow/html_template_handler.py +++ b/src/crewai/flow/html_template_handler.py @@ -1,11 +1,32 @@ import base64 import re +from pathlib import Path + +from crewai.flow.path_utils import safe_path_join, validate_path_exists class HTMLTemplateHandler: def __init__(self, template_path, logo_path): - self.template_path = template_path - self.logo_path = logo_path + """ + Initialize HTMLTemplateHandler with validated template and logo paths. + + Parameters + ---------- + template_path : str + Path to the HTML template file. + logo_path : str + Path to the logo image file. + + Raises + ------ + ValueError + If template or logo paths are invalid or files don't exist. + """ + try: + self.template_path = validate_path_exists(template_path, "file") + self.logo_path = validate_path_exists(logo_path, "file") + except ValueError as e: + raise ValueError(f"Invalid template or logo path: {e}") def read_template(self): with open(self.template_path, "r", encoding="utf-8") as f: diff --git a/src/crewai/flow/path_utils.py b/src/crewai/flow/path_utils.py new file mode 100644 index 000000000..09ae8cd3d --- /dev/null +++ b/src/crewai/flow/path_utils.py @@ -0,0 +1,135 @@ +""" +Path utilities for secure file operations in CrewAI flow module. + +This module provides utilities for secure path handling to prevent directory +traversal attacks and ensure paths remain within allowed boundaries. +""" + +import os +from pathlib import Path +from typing import List, Union + + +def safe_path_join(*parts: str, root: Union[str, Path, None] = None) -> str: + """ + Safely join path components and ensure the result is within allowed boundaries. + + Parameters + ---------- + *parts : str + Variable number of path components to join. + root : Union[str, Path, None], optional + Root directory to use as base. If None, uses current working directory. + + Returns + ------- + str + String representation of the resolved path. + + Raises + ------ + ValueError + If the resulting path would be outside the root directory + or if any path component is invalid. + """ + if not parts: + raise ValueError("No path components provided") + + try: + # Convert all parts to strings and clean them + clean_parts = [str(part).strip() for part in parts if part] + if not clean_parts: + raise ValueError("No valid path components provided") + + # Establish root directory + root_path = Path(root).resolve() if root else Path.cwd() + + # Join and resolve the full path + full_path = Path(root_path, *clean_parts).resolve() + + # Check if the resolved path is within root + if not str(full_path).startswith(str(root_path)): + raise ValueError( + f"Invalid path: Potential directory traversal. Path must be within {root_path}" + ) + + return str(full_path) + + except Exception as e: + if isinstance(e, ValueError): + raise + raise ValueError(f"Invalid path components: {str(e)}") + + +def validate_path_exists(path: Union[str, Path], file_type: str = "file") -> str: + """ + Validate that a path exists and is of the expected type. + + Parameters + ---------- + path : Union[str, Path] + Path to validate. + file_type : str, optional + Expected type ('file' or 'directory'), by default 'file'. + + Returns + ------- + str + Validated path as string. + + Raises + ------ + ValueError + If path doesn't exist or is not of expected type. + """ + try: + path_obj = Path(path).resolve() + + if not path_obj.exists(): + raise ValueError(f"Path does not exist: {path}") + + if file_type == "file" and not path_obj.is_file(): + raise ValueError(f"Path is not a file: {path}") + elif file_type == "directory" and not path_obj.is_dir(): + raise ValueError(f"Path is not a directory: {path}") + + return str(path_obj) + + except Exception as e: + if isinstance(e, ValueError): + raise + raise ValueError(f"Invalid path: {str(e)}") + + +def list_files(directory: Union[str, Path], pattern: str = "*") -> List[str]: + """ + Safely list files in a directory matching a pattern. + + Parameters + ---------- + directory : Union[str, Path] + Directory to search in. + pattern : str, optional + Glob pattern to match files against, by default "*". + + Returns + ------- + List[str] + List of matching file paths. + + Raises + ------ + ValueError + If directory is invalid or inaccessible. + """ + try: + dir_path = Path(directory).resolve() + if not dir_path.is_dir(): + raise ValueError(f"Not a directory: {directory}") + + return [str(p) for p in dir_path.glob(pattern) if p.is_file()] + + except Exception as e: + if isinstance(e, ValueError): + raise + raise ValueError(f"Error listing files: {str(e)}") diff --git a/src/crewai/flow/utils.py b/src/crewai/flow/utils.py index dc1f611fb..c0686222f 100644 --- a/src/crewai/flow/utils.py +++ b/src/crewai/flow/utils.py @@ -1,9 +1,25 @@ +""" +Utility functions for flow visualization and dependency analysis. + +This module provides core functionality for analyzing and manipulating flow structures, +including node level calculation, ancestor tracking, and return value analysis. +Functions in this module are primarily used by the visualization system to create +accurate and informative flow diagrams. + +Example +------- +>>> flow = Flow() +>>> node_levels = calculate_node_levels(flow) +>>> ancestors = build_ancestor_dict(flow) +""" + import ast import inspect import textwrap +from typing import Any, Dict, List, Optional, Set, Union -def get_possible_return_constants(function): +def get_possible_return_constants(function: Any) -> Optional[List[str]]: try: source = inspect.getsource(function) except OSError: @@ -77,11 +93,34 @@ def get_possible_return_constants(function): return list(return_values) if return_values else None -def calculate_node_levels(flow): - levels = {} - queue = [] - visited = set() - pending_and_listeners = {} +def calculate_node_levels(flow: Any) -> Dict[str, int]: + """ + Calculate the hierarchical level of each node in the flow. + + Performs a breadth-first traversal of the flow graph to assign levels + to nodes, starting with start methods at level 0. + + Parameters + ---------- + flow : Any + The flow instance containing methods, listeners, and router configurations. + + Returns + ------- + Dict[str, int] + Dictionary mapping method names to their hierarchical levels. + + Notes + ----- + - Start methods are assigned level 0 + - Each subsequent connected node is assigned level = parent_level + 1 + - Handles both OR and AND conditions for listeners + - Processes router paths separately + """ + levels: Dict[str, int] = {} + queue: List[str] = [] + visited: Set[str] = set() + pending_and_listeners: Dict[str, Set[str]] = {} # Make all start methods at level 0 for method_name, method in flow._methods.items(): @@ -140,7 +179,20 @@ def calculate_node_levels(flow): return levels -def count_outgoing_edges(flow): +def count_outgoing_edges(flow: Any) -> Dict[str, int]: + """ + Count the number of outgoing edges for each method in the flow. + + Parameters + ---------- + flow : Any + The flow instance to analyze. + + Returns + ------- + Dict[str, int] + Dictionary mapping method names to their outgoing edge count. + """ counts = {} for method_name in flow._methods: counts[method_name] = 0 @@ -152,16 +204,53 @@ def count_outgoing_edges(flow): return counts -def build_ancestor_dict(flow): - ancestors = {node: set() for node in flow._methods} - visited = set() +def build_ancestor_dict(flow: Any) -> Dict[str, Set[str]]: + """ + Build a dictionary mapping each node to its ancestor nodes. + + Parameters + ---------- + flow : Any + The flow instance to analyze. + + Returns + ------- + Dict[str, Set[str]] + Dictionary mapping each node to a set of its ancestor nodes. + """ + ancestors: Dict[str, Set[str]] = {node: set() for node in flow._methods} + visited: Set[str] = set() for node in flow._methods: if node not in visited: dfs_ancestors(node, ancestors, visited, flow) return ancestors -def dfs_ancestors(node, ancestors, visited, flow): +def dfs_ancestors( + node: str, + ancestors: Dict[str, Set[str]], + visited: Set[str], + flow: Any +) -> None: + """ + Perform depth-first search to build ancestor relationships. + + Parameters + ---------- + node : str + Current node being processed. + ancestors : Dict[str, Set[str]] + Dictionary tracking ancestor relationships. + visited : Set[str] + Set of already visited nodes. + flow : Any + The flow instance being analyzed. + + Notes + ----- + This function modifies the ancestors dictionary in-place to build + the complete ancestor graph. + """ if node in visited: return visited.add(node) @@ -185,12 +274,48 @@ def dfs_ancestors(node, ancestors, visited, flow): dfs_ancestors(listener_name, ancestors, visited, flow) -def is_ancestor(node, ancestor_candidate, ancestors): +def is_ancestor(node: str, ancestor_candidate: str, ancestors: Dict[str, Set[str]]) -> bool: + """ + Check if one node is an ancestor of another. + + Parameters + ---------- + node : str + The node to check ancestors for. + ancestor_candidate : str + The potential ancestor node. + ancestors : Dict[str, Set[str]] + Dictionary containing ancestor relationships. + + Returns + ------- + bool + True if ancestor_candidate is an ancestor of node, False otherwise. + """ return ancestor_candidate in ancestors.get(node, set()) -def build_parent_children_dict(flow): - parent_children = {} +def build_parent_children_dict(flow: Any) -> Dict[str, List[str]]: + """ + Build a dictionary mapping parent nodes to their children. + + Parameters + ---------- + flow : Any + The flow instance to analyze. + + Returns + ------- + Dict[str, List[str]] + Dictionary mapping parent method names to lists of their child method names. + + Notes + ----- + - Maps listeners to their trigger methods + - Maps router methods to their paths and listeners + - Children lists are sorted for consistent ordering + """ + parent_children: Dict[str, List[str]] = {} # Map listeners to their trigger methods for listener_name, (_, trigger_methods) in flow._listeners.items(): @@ -214,7 +339,24 @@ def build_parent_children_dict(flow): return parent_children -def get_child_index(parent, child, parent_children): +def get_child_index(parent: str, child: str, parent_children: Dict[str, List[str]]) -> int: + """ + Get the index of a child node in its parent's sorted children list. + + Parameters + ---------- + parent : str + The parent node name. + child : str + The child node name to find the index for. + parent_children : Dict[str, List[str]] + Dictionary mapping parents to their children lists. + + Returns + ------- + int + Zero-based index of the child in its parent's sorted children list. + """ children = parent_children.get(parent, []) children.sort() return children.index(child) diff --git a/src/crewai/flow/visualization_utils.py b/src/crewai/flow/visualization_utils.py index 321f63344..70f527f1a 100644 --- a/src/crewai/flow/visualization_utils.py +++ b/src/crewai/flow/visualization_utils.py @@ -1,5 +1,23 @@ +""" +Utilities for creating visual representations of flow structures. + +This module provides functions for generating network visualizations of flows, +including node placement, edge creation, and visual styling. It handles the +conversion of flow structures into visual network graphs with appropriate +styling and layout. + +Example +------- +>>> flow = Flow() +>>> net = Network(directed=True) +>>> node_positions = compute_positions(flow, node_levels) +>>> add_nodes_to_network(net, flow, node_positions, node_styles) +>>> add_edges(net, flow, node_positions, colors) +""" + import ast import inspect +from typing import Any, Dict, List, Optional, Tuple, Union from .utils import ( build_ancestor_dict, @@ -9,8 +27,25 @@ from .utils import ( ) -def method_calls_crew(method): - """Check if the method calls `.crew()`.""" +def method_calls_crew(method: Any) -> bool: + """ + Check if the method contains a call to `.crew()`. + + Parameters + ---------- + method : Any + The method to analyze for crew() calls. + + Returns + ------- + bool + True if the method calls .crew(), False otherwise. + + Notes + ----- + Uses AST analysis to detect method calls, specifically looking for + attribute access of 'crew'. + """ try: source = inspect.getsource(method) source = inspect.cleandoc(source) @@ -34,7 +69,34 @@ def method_calls_crew(method): return visitor.found -def add_nodes_to_network(net, flow, node_positions, node_styles): +def add_nodes_to_network( + net: Any, + flow: Any, + node_positions: Dict[str, Tuple[float, float]], + node_styles: Dict[str, Dict[str, Any]] +) -> None: + """ + Add nodes to the network visualization with appropriate styling. + + Parameters + ---------- + net : Any + The pyvis Network instance to add nodes to. + flow : Any + The flow instance containing method information. + node_positions : Dict[str, Tuple[float, float]] + Dictionary mapping node names to their (x, y) positions. + node_styles : Dict[str, Dict[str, Any]] + Dictionary containing style configurations for different node types. + + Notes + ----- + Node types include: + - Start methods + - Router methods + - Crew methods + - Regular methods + """ def human_friendly_label(method_name): return method_name.replace("_", " ").title() @@ -73,9 +135,33 @@ def add_nodes_to_network(net, flow, node_positions, node_styles): ) -def compute_positions(flow, node_levels, y_spacing=150, x_spacing=150): - level_nodes = {} - node_positions = {} +def compute_positions( + flow: Any, + node_levels: Dict[str, int], + y_spacing: float = 150, + x_spacing: float = 150 +) -> Dict[str, Tuple[float, float]]: + """ + Compute the (x, y) positions for each node in the flow graph. + + Parameters + ---------- + flow : Any + The flow instance to compute positions for. + node_levels : Dict[str, int] + Dictionary mapping node names to their hierarchical levels. + y_spacing : float, optional + Vertical spacing between levels, by default 150. + x_spacing : float, optional + Horizontal spacing between nodes, by default 150. + + Returns + ------- + Dict[str, Tuple[float, float]] + Dictionary mapping node names to their (x, y) coordinates. + """ + level_nodes: Dict[int, List[str]] = {} + node_positions: Dict[str, Tuple[float, float]] = {} for method_name, level in node_levels.items(): level_nodes.setdefault(level, []).append(method_name) @@ -90,7 +176,33 @@ def compute_positions(flow, node_levels, y_spacing=150, x_spacing=150): return node_positions -def add_edges(net, flow, node_positions, colors): +def add_edges( + net: Any, + flow: Any, + node_positions: Dict[str, Tuple[float, float]], + colors: Dict[str, str] +) -> None: + edge_smooth: Dict[str, Union[str, float]] = {"type": "continuous"} # Default value + """ + Add edges to the network visualization with appropriate styling. + + Parameters + ---------- + net : Any + The pyvis Network instance to add edges to. + flow : Any + The flow instance containing edge information. + node_positions : Dict[str, Tuple[float, float]] + Dictionary mapping node names to their positions. + colors : Dict[str, str] + Dictionary mapping edge types to their colors. + + Notes + ----- + - Handles both normal listener edges and router edges + - Applies appropriate styling (color, dashes) based on edge type + - Adds curvature to edges when needed (cycles or multiple children) + """ ancestors = build_ancestor_dict(flow) parent_children = build_parent_children_dict(flow) @@ -126,7 +238,7 @@ def add_edges(net, flow, node_positions, colors): else: edge_smooth = {"type": "cubicBezier"} else: - edge_smooth = False + edge_smooth.update({"type": "continuous"}) edge_style = { "color": edge_color, @@ -189,7 +301,7 @@ def add_edges(net, flow, node_positions, colors): else: edge_smooth = {"type": "cubicBezier"} else: - edge_smooth = False + edge_smooth.update({"type": "continuous"}) edge_style = { "color": colors["router_edge"],