import ast
import inspect
import textwrap
from typing import Any, Callable, Dict, List, Optional, Set, Union


def get_possible_return_constants(
    function: Callable[..., Any],
) -> Optional[List[str]]:
    """Extract possible string return values from a function via its source.

    The function's source is parsed with ``ast``. Two kinds of ``return``
    statements contribute values:

    * ``return "literal"`` contributes the literal itself.
    * ``return name[...]`` where ``name`` was assigned a dictionary literal
      inside the same source contributes every string value of that
      dictionary.

    Args:
        function: The callable to analyze.

    Returns:
        The distinct string values in first-seen (source) order, or ``None``
        when the source cannot be retrieved, cannot be parsed, or contains
        no string return values.

    Note:
        No exception is propagated to the caller. Retrieval and parsing
        failures are reported with ``print`` and turned into ``None``.

    Example:
        >>> def get_status():
        ...     paths = {"success": "completed", "error": "failed"}
        ...     return paths["success"]
        >>> get_possible_return_constants(get_status)
        ['completed', 'failed']
    """
    # Not every callable has ``__name__`` (functools.partial objects, for
    # instance); the error paths below must never raise on their own.
    func_name = getattr(function, "__name__", repr(function))

    try:
        source = inspect.getsource(function)
    except OSError:
        # Source is simply unavailable (built-in, REPL, ...): stay silent.
        return None
    except Exception as e:
        print(f"Error retrieving source code for function {func_name}: {e}")
        return None

    try:
        # Methods and nested functions come back indented; ast.parse needs
        # the code to start at column 0.
        source = textwrap.dedent(source)
        code_ast = ast.parse(source)
    except SyntaxError as e:
        # IndentationError is a SyntaxError subclass; keep the distinct label.
        label = (
            "IndentationError" if isinstance(e, IndentationError) else "SyntaxError"
        )
        print(f"{label} while parsing source code of {func_name}: {e}")
        print(f"Source code:\n{source}")
        return None
    except Exception as e:
        print(f"Unexpected error while parsing source code of {func_name}: {e}")
        print(f"Source code:\n{source}")
        return None

    # A dict is used as an insertion-ordered set so the result is deterministic.
    return_values: Dict[str, None] = {}
    dict_definitions: Dict[str, List[str]] = {}

    class DictionaryAssignmentVisitor(ast.NodeVisitor):
        """Record ``name = {...}`` assignments and their string values."""

        def visit_Assign(self, node: ast.Assign) -> None:
            if isinstance(node.value, ast.Dict) and len(node.targets) == 1:
                target = node.targets[0]
                if isinstance(target, ast.Name):
                    # Non-string dictionary values are ignored.
                    dict_values = [
                        val.value
                        for val in node.value.values
                        if isinstance(val, ast.Constant)
                        and isinstance(val.value, str)
                    ]
                    if dict_values:
                        dict_definitions[target.id] = dict_values
            self.generic_visit(node)

    class ReturnVisitor(ast.NodeVisitor):
        """Collect string constants reachable from ``return`` statements."""

        def visit_Return(self, node: ast.Return) -> None:
            value = node.value
            if isinstance(value, ast.Constant) and isinstance(value.value, str):
                # Direct string return.
                return_values.setdefault(value.value)
            elif isinstance(value, ast.Subscript) and isinstance(
                value.value, ast.Name
            ):
                # Dictionary-based return such as ``return paths[result]``:
                # any value of the known dictionary may come back.
                for candidate in dict_definitions.get(value.value.id, ()):
                    return_values.setdefault(candidate)
            self.generic_visit(node)

    # First pass: dictionary assignments. Second pass: returns.
    DictionaryAssignmentVisitor().visit(code_ast)
    ReturnVisitor().visit(code_ast)

    return list(return_values) if return_values else None


def is_ancestor(
    node: str, ancestor_candidate: str, ancestors: Dict[str, Set[str]]
) -> bool:
    """Check if one node is an ancestor of another in the flow graph.

    Args:
        node: Target node whose ancestors are consulted.
        ancestor_candidate: Node tested for membership in ``node``'s ancestors.
        ancestors: Dictionary mapping nodes to their ancestor sets.

    Returns:
        True if ``ancestor_candidate`` is in ``ancestors[node]``; False when it
        is not, or when ``node`` has no entry in ``ancestors``.

    Raises:
        TypeError: If ``node`` or ``ancestor_candidate`` is not a string, or
            ``ancestors`` is not a dictionary.
    """
    if not isinstance(node, str):
        raise TypeError("Argument 'node' must be a string")
    if not isinstance(ancestor_candidate, str):
        raise TypeError("Argument 'ancestor_candidate' must be a string")
    if not isinstance(ancestors, dict):
        raise TypeError("Argument 'ancestors' must be a dictionary")

    return ancestor_candidate in ancestors.get(node, set())
start(condition: Optional[Union[str, dict, Callable]] = None) -> Callable: + """Marks a method as a flow starting point, optionally triggered by other methods. + + Args: + condition: The condition that triggers this method. Can be: + - str: Name of the triggering method + - dict: Dictionary with 'type' and 'methods' keys for complex conditions + - Callable: A function reference + - None: No trigger condition (default) + + Returns: + Callable: The decorated function that will serve as a flow starting point. + + Raises: + ValueError: If the condition format is invalid. + + Example: + >>> @start() # No condition + >>> def begin_flow(): + >>> pass + >>> + >>> @start("method_name") # Triggered by specific method + >>> def conditional_start(): + >>> pass + """ def decorator(func): func.__is_start_method__ = True if condition is not None: @@ -57,8 +79,30 @@ def start(condition=None): return decorator -def listen(condition): - """Marks a method to execute when specified conditions/methods complete.""" +def listen(condition: Union[str, dict, Callable]) -> Callable: + """Marks a method to execute when specified conditions/methods complete. + + Args: + condition: The condition that triggers this method. Can be: + - str: Name of the triggering method + - dict: Dictionary with 'type' and 'methods' keys for complex conditions + - Callable: A function reference + + Returns: + Callable: The decorated function that will execute when conditions are met. + + Raises: + ValueError: If the condition format is invalid. 
+ + Example: + >>> @listen("start_method") # Listen to single method + >>> def on_start(): + >>> pass + >>> + >>> @listen(and_("method1", "method2")) # Listen with AND condition + >>> def on_both_complete(): + >>> pass + """ def decorator(func): if isinstance(condition, str): func.__trigger_methods__ = [condition] @@ -82,8 +126,31 @@ def listen(condition): return decorator -def router(condition): - """Marks a method as a router to direct flow based on its return value.""" +def router(condition: Union[str, dict, Callable]) -> Callable: + """Marks a method as a router to direct flow based on its return value. + + A router method can return different string values that trigger different + subsequent methods, allowing for dynamic flow control. + + Args: + condition: The condition that triggers this router. Can be: + - str: Name of the triggering method + - dict: Dictionary with 'type' and 'methods' keys for complex conditions + - Callable: A function reference + + Returns: + Callable: The decorated function that will serve as a router. + + Raises: + ValueError: If the condition format is invalid. + + Example: + >>> @router("process_data") + >>> def route_result(result): + >>> if result.success: + >>> return "handle_success" + >>> return "handle_error" + """ def decorator(func): func.__is_router__ = True if isinstance(condition, str): @@ -108,8 +175,27 @@ def router(condition): return decorator -def or_(*conditions): - """Combines multiple conditions with OR logic for flow control.""" +def or_(*conditions: Union[str, dict, Callable]) -> dict: + """Combines multiple conditions with OR logic for flow control. + + Args: + *conditions: Variable number of conditions. Each can be: + - str: Name of a method + - dict: Dictionary with 'type' and 'methods' keys + - Callable: A function reference + + Returns: + dict: A dictionary with 'type': 'OR' and 'methods' list. + + Raises: + ValueError: If any condition is invalid. 
+ + Example: + >>> @listen(or_("method1", "method2")) + >>> def on_either(): + >>> # Executes when either method1 OR method2 completes + >>> pass + """ methods = [] for condition in conditions: if isinstance(condition, dict) and "methods" in condition: @@ -123,8 +209,27 @@ def or_(*conditions): return {"type": "OR", "methods": methods} -def and_(*conditions): - """Combines multiple conditions with AND logic for flow control.""" +def and_(*conditions: Union[str, dict, Callable]) -> dict: + """Combines multiple conditions with AND logic for flow control. + + Args: + *conditions: Variable number of conditions. Each can be: + - str: Name of a method + - dict: Dictionary with 'type' and 'methods' keys + - Callable: A function reference + + Returns: + dict: A dictionary with 'type': 'AND' and 'methods' list. + + Raises: + ValueError: If any condition is invalid. + + Example: + >>> @listen(and_("method1", "method2")) + >>> def on_both(): + >>> # Executes when BOTH method1 AND method2 complete + >>> pass + """ methods = [] for condition in conditions: if isinstance(condition, dict) and "methods" in condition: @@ -183,6 +288,22 @@ class Flow(Generic[T], metaclass=FlowMeta): event_emitter = Signal("event_emitter") def __class_getitem__(cls: Type["Flow"], item: Type[T]) -> Type["Flow"]: + """Create a generic version of Flow with specified state type. + + Args: + cls: The Flow class + item: The type parameter for the flow's state + + Returns: + Type["Flow"]: A new Flow class with the specified state type + + Example: + >>> class MyState(BaseModel): + >>> value: int + >>> + >>> class MyFlow(Flow[MyState]): + >>> pass + """ class _FlowGeneric(cls): # type: ignore _initial_state_T = item # type: ignore @@ -190,11 +311,23 @@ class Flow(Generic[T], metaclass=FlowMeta): return _FlowGeneric def __init__(self) -> None: + """Initialize a new Flow instance. + + Sets up internal state tracking, method registration, and telemetry. 
+ The flow's methods are automatically discovered and registered during initialization. + + Attributes initialized: + _methods: Dictionary mapping method names to their callable objects + _state: The flow's state object of type T + _method_execution_counts: Tracks how many times each method has executed + _pending_and_listeners: Tracks methods waiting for AND conditions + _method_outputs: List of all outputs from executed methods + """ self._methods: Dict[str, Callable] = {} self._state: T = self._create_initial_state() self._method_execution_counts: Dict[str, int] = {} self._pending_and_listeners: Dict[str, Set[str]] = {} - self._method_outputs: List[Any] = [] # List to store all method outputs + self._method_outputs: List[Any] = [] self._telemetry.flow_creation_span(self.__class__.__name__) @@ -205,6 +338,20 @@ class Flow(Generic[T], metaclass=FlowMeta): self._methods[method_name] = getattr(self, method_name) def _create_initial_state(self) -> T: + """Create the initial state for the flow. + + The state is created based on the following priority: + 1. If initial_state is None and _initial_state_T exists (generic type), use that + 2. If initial_state is None, return empty dict + 3. If initial_state is a type, instantiate it + 4. Otherwise, use initial_state as-is + + Returns: + T: The initial state object of type T + + Note: + The type T can be either a Pydantic BaseModel or a dictionary. + """ if self.initial_state is None and hasattr(self, "_initial_state_T"): return self._initial_state_T() # type: ignore if self.initial_state is None: @@ -216,11 +363,21 @@ class Flow(Generic[T], metaclass=FlowMeta): @property def state(self) -> T: + """Get the current state of the flow. + + Returns: + T: The current state object, either a Pydantic model or dictionary + """ return self._state @property def method_outputs(self) -> List[Any]: - """Returns the list of all outputs from executed methods.""" + """Get the list of all outputs from executed methods. 
+ + Returns: + List[Any]: A list containing the output values from all executed flow methods, + in order of execution. + """ return self._method_outputs def _initialize_state(self, inputs: Dict[str, Any]) -> None: @@ -310,6 +467,23 @@ class Flow(Generic[T], metaclass=FlowMeta): return result async def _execute_listeners(self, trigger_method: str, result: Any) -> None: + """Execute all listener methods triggered by a completed method. + + This method handles both router and non-router listeners in a specific order: + 1. First executes all triggered router methods sequentially until no more routers + are triggered + 2. Then executes all regular listeners in parallel + + Args: + trigger_method: The name of the method that completed execution + result: The result value from the triggering method + + Note: + Router methods are executed sequentially to ensure proper flow control, + while regular listeners are executed concurrently for better performance. + This provides fine-grained control over the execution flow while + maintaining efficiency. + """ # First, handle routers repeatedly until no router triggers anymore while True: routers_triggered = self._find_triggered_methods( @@ -339,6 +513,27 @@ class Flow(Generic[T], metaclass=FlowMeta): def _find_triggered_methods( self, trigger_method: str, router_only: bool ) -> List[str]: + """Find all methods that should be triggered based on completed method and type. + + Provides precise control over method triggering by handling both OR and AND + conditions separately for router and non-router methods. + + Args: + trigger_method: The name of the method that completed execution + router_only: If True, only find router methods; if False, only regular + listeners + + Returns: + List[str]: Names of methods that should be executed next + + Note: + This method implements sophisticated flow control by: + 1. Filtering methods based on their router/non-router status + 2. Handling OR conditions for immediate triggering + 3. 
Managing AND conditions with state tracking for complex dependencies + + This ensures predictable and consistent execution order in complex flows. + """ triggered = [] for listener_name, (condition_type, methods) in self._listeners.items(): is_router = listener_name in self._routers @@ -367,6 +562,27 @@ class Flow(Generic[T], metaclass=FlowMeta): return triggered async def _execute_single_listener(self, listener_name: str, result: Any) -> None: + """Execute a single listener method with precise parameter handling and error tracking. + + Provides fine-grained control over method execution through: + 1. Automatic parameter inspection to determine if the method accepts results + 2. Event emission for execution tracking + 3. Comprehensive error handling + 4. Recursive listener execution + + Args: + listener_name: The name of the listener method to execute + result: The result from the triggering method, passed to the listener + if its signature accepts parameters + + Note: + This method ensures precise execution control by: + - Inspecting method signatures to handle parameters correctly + - Emitting events for execution tracking + - Providing comprehensive error handling + - Supporting both parameterized and parameter-less methods + - Maintaining execution chain through recursive listener calls + """ try: method = self._methods[listener_name] @@ -410,8 +626,32 @@ class Flow(Generic[T], metaclass=FlowMeta): traceback.print_exc() - def plot(self, filename: str = "crewai_flow") -> None: + def plot(self, *args, **kwargs): + """Generate an interactive visualization of the flow's execution graph. + + Creates a detailed HTML visualization showing the relationships between + methods, including start points, listeners, routers, and their + connections. Includes telemetry tracking for flow analysis. 
from typing import TYPE_CHECKING, Dict, List, Set

if TYPE_CHECKING:
    from crewai.flow.flow import Flow


def calculate_node_levels(flow: "Flow") -> Dict[str, int]:
    """Calculate the hierarchical level of each node in the flow graph.

    Performs a breadth-first traversal starting from the start methods
    (level 0). A listener with an OR condition is placed one level below the
    first trigger that reaches it; a listener with an AND condition is placed
    once every one of its triggers has been dequeued. Listeners attached to a
    router's paths are placed one level below the router.

    Args:
        flow: Object exposing ``_methods`` (name -> callable), ``_listeners``
            (name -> ``(condition_type, trigger_methods)``), ``_routers`` and
            ``_router_paths`` (router name -> list of path names).

    Returns:
        Dictionary mapping method names to their level. Methods that are not
        reachable from any start method are absent.
    """
    # Local import: deque gives O(1) pops from the left, unlike list.pop(0).
    from collections import deque

    levels: Dict[str, int] = {}
    queue = deque()
    visited: Set[str] = set()
    pending_and_listeners: Dict[str, Set[str]] = {}

    def assign_level(listener_name: str, level: int) -> None:
        # Keep the smallest level seen; only enqueue nodes not yet expanded.
        if listener_name not in levels or levels[listener_name] > level:
            levels[listener_name] = level
            if listener_name not in visited:
                queue.append(listener_name)

    # All start methods sit at level 0.
    for method_name, method in flow._methods.items():
        if hasattr(method, "__is_start_method__"):
            levels[method_name] = 0
            queue.append(method_name)

    while queue:
        current = queue.popleft()
        next_level = levels[current] + 1
        visited.add(current)

        for listener_name, (condition_type, trigger_methods) in flow._listeners.items():
            if condition_type == "OR":
                if current in trigger_methods:
                    assign_level(listener_name, next_level)
            elif condition_type == "AND":
                pending = pending_and_listeners.setdefault(listener_name, set())
                if current in trigger_methods:
                    pending.add(current)
                # Fire only once every trigger has been seen.
                if set(trigger_methods) == pending:
                    assign_level(listener_name, next_level)

        # A router reaches listeners through the path names it may return.
        if current in flow._routers:
            for path in flow._router_paths.get(current, []):
                for listener_name, (_, trigger_methods) in flow._listeners.items():
                    if path in trigger_methods:
                        assign_level(listener_name, next_level)

    return levels


def count_outgoing_edges(flow: "Flow") -> Dict[str, int]:
    """Count the outgoing edges of every method in the flow graph.

    An outgoing edge is counted each time a method name appears in a
    listener's trigger list. Triggers that are not method names (for example
    router path labels) are ignored.

    Args:
        flow: Object exposing ``_methods`` and ``_listeners``.

    Returns:
        Dictionary mapping every method name to its number of outgoing edges.
    """
    counts = {method_name: 0 for method_name in flow._methods}
    for _, trigger_methods in flow._listeners.values():
        for trigger in trigger_methods:
            if trigger in counts:
                counts[trigger] += 1
    return counts


def build_ancestor_dict(flow: "Flow") -> Dict[str, Set[str]]:
    """Build a dictionary mapping each node to its set of ancestor nodes.

    Runs ``dfs_ancestors`` from every method not yet visited, in the
    iteration order of ``flow._methods``.

    Args:
        flow: Object exposing ``_methods``, ``_listeners``, ``_routers`` and
            ``_router_paths``.

    Returns:
        Dictionary mapping each method name to the set of ancestor names
        collected by the traversal.
    """
    ancestors: Dict[str, Set[str]] = {node: set() for node in flow._methods}
    visited: Set[str] = set()
    for node in flow._methods:
        if node not in visited:
            dfs_ancestors(node, ancestors, visited, flow)
    return ancestors


def dfs_ancestors(
    node: str, ancestors: Dict[str, Set[str]], visited: Set[str], flow: "Flow"
) -> None:
    """Depth-first helper that fills ``ancestors`` in place.

    Args:
        node: Node currently being expanded.
        ancestors: Dictionary mapping nodes to their ancestor sets; mutated.
        visited: Nodes already expanded; mutated.
        flow: Object exposing ``_listeners``, ``_routers`` and
            ``_router_paths``.
    """
    if node in visited:
        return
    visited.add(node)

    # Regular listeners inherit the node itself plus the node's ancestors.
    for listener_name, (_, trigger_methods) in flow._listeners.items():
        if node in trigger_methods:
            ancestors[listener_name].add(node)
            ancestors[listener_name].update(ancestors[node])
            dfs_ancestors(listener_name, ancestors, visited, flow)

    # Listeners reached through a router path inherit only the router's
    # ancestors, not the router method itself.
    if node in flow._routers:
        for path in flow._router_paths.get(node, []):
            for listener_name, (_, trigger_methods) in flow._listeners.items():
                if path in trigger_methods:
                    ancestors[listener_name].update(ancestors[node])
                    dfs_ancestors(listener_name, ancestors, visited, flow)


def build_parent_children_dict(flow: "Flow") -> Dict[str, List[str]]:
    """Build a dictionary mapping each node to its list of child nodes.

    Every trigger name is mapped to the listeners it triggers, and every
    router method is mapped to the listeners of each of its paths.

    Args:
        flow: Object exposing ``_listeners`` and ``_router_paths``.

    Returns:
        Dictionary mapping each parent name to its children, without
        duplicates, in discovery order (the lists are not sorted here).
    """
    parent_children: Dict[str, List[str]] = {}

    def link(parent: str, child: str) -> None:
        children = parent_children.setdefault(parent, [])
        if child not in children:
            children.append(child)

    # Triggers -> listeners.
    for listener_name, (_, trigger_methods) in flow._listeners.items():
        for trigger in trigger_methods:
            link(trigger, listener_name)

    # Router methods -> listeners of each of their paths.
    for router_method_name, paths in flow._router_paths.items():
        for path in paths:
            for listener_name, (_, trigger_methods) in flow._listeners.items():
                if path in trigger_methods:
                    link(router_method_name, listener_name)

    return parent_children


def get_child_index(
    parent: str, child: str, parent_children: Dict[str, List[str]]
) -> int:
    """Get the index of a child among its parent's children, sorted by name.

    Args:
        parent: Parent node name.
        child: Child node name to locate.
        parent_children: Dictionary mapping parents to their children lists.
            It is not modified.

    Returns:
        Zero-based index of ``child`` in the sorted children of ``parent``.

    Raises:
        ValueError: If ``child`` is not among the children of ``parent``.
    """
    # sorted() works on a copy, so the caller's list keeps its order.
    return sorted(parent_children.get(parent, [])).index(child)
+ + Args: + flow: A Flow instance with required attributes for visualization + + Raises: + ValueError: If flow object is invalid or missing required attributes + """ + if not hasattr(flow, '_methods'): + raise ValueError("Invalid flow object: Missing '_methods' attribute") + if not hasattr(flow, '_start_methods'): + raise ValueError("Invalid flow object: Missing '_start_methods' attribute") + if not hasattr(flow, '_listeners'): + raise ValueError("Invalid flow object: Missing '_listeners' attribute") + self.flow = flow self.colors = COLORS self.node_styles = NODE_STYLES @@ -58,9 +75,16 @@ class FlowPlot: network_html = net.generate_html() final_html_content = self._generate_final_html(network_html) - with open(f"{filename}.html", "w", encoding="utf-8") as f: - f.write(final_html_content) - print(f"Plot saved as {filename}.html") + try: + # Ensure the output path is safe + output_dir = os.getcwd() + output_path = safe_path_join(output_dir, f"{filename}.html") + + with open(output_path, "w", encoding="utf-8") as f: + f.write(final_html_content) + print(f"Plot saved as {output_path}") + except (IOError, ValueError) as e: + raise IOError(f"Failed to save flow visualization: {str(e)}") self._cleanup_pyvis_lib() diff --git a/src/crewai/flow/html_template_handler.py b/src/crewai/flow/html_template_handler.py index 2af2665c1..a0b789874 100644 --- a/src/crewai/flow/html_template_handler.py +++ b/src/crewai/flow/html_template_handler.py @@ -1,35 +1,107 @@ import base64 +import os import re +from pathlib import Path + +from crewai.flow.path_utils import safe_path_join, validate_file_path class HTMLTemplateHandler: """Handles HTML template processing and generation for flow visualization diagrams.""" def __init__(self, template_path, logo_path): - """Initialize template handler with template and logo file paths.""" - self.template_path = template_path - self.logo_path = logo_path + """Initialize template handler with template and logo file paths. 
+ + Args: + template_path: Path to the HTML template file + logo_path: Path to the logo SVG file + + Raises: + ValueError: If template_path or logo_path is invalid or files don't exist + """ + try: + self.template_path = validate_file_path(template_path) + self.logo_path = validate_file_path(logo_path) + except (ValueError, TypeError) as e: + raise ValueError(f"Invalid file path: {str(e)}") def read_template(self): - """Read and return the HTML template file contents.""" - with open(self.template_path, "r", encoding="utf-8") as f: - return f.read() + """Read and return the HTML template file contents. + + Returns: + str: The contents of the template file + + Raises: + IOError: If template file cannot be read + """ + try: + with open(self.template_path, "r", encoding="utf-8") as f: + return f.read() + except IOError as e: + raise IOError(f"Failed to read template file {self.template_path}: {str(e)}") def encode_logo(self): - """Convert the logo SVG file to base64 encoded string.""" - with open(self.logo_path, "rb") as logo_file: - logo_svg_data = logo_file.read() - return base64.b64encode(logo_svg_data).decode("utf-8") + """Convert the logo SVG file to base64 encoded string. + + Returns: + str: Base64 encoded logo data + + Raises: + IOError: If logo file cannot be read + ValueError: If logo data cannot be encoded + """ + try: + with open(self.logo_path, "rb") as logo_file: + logo_svg_data = logo_file.read() + try: + return base64.b64encode(logo_svg_data).decode("utf-8") + except Exception as e: + raise ValueError(f"Failed to encode logo data: {str(e)}") + except IOError as e: + raise IOError(f"Failed to read logo file {self.logo_path}: {str(e)}") def extract_body_content(self, html): - """Extract and return content between body tags from HTML string.""" + """Extract and return content between body tags from HTML string. 
+ + Args: + html: HTML string to extract body content from + + Returns: + str: Content between body tags, or empty string if not found + + Raises: + ValueError: If input HTML is invalid + """ + if not html or not isinstance(html, str): + raise ValueError("Input HTML must be a non-empty string") + match = re.search("(.*?)", html, re.DOTALL) return match.group(1) if match else "" def generate_legend_items_html(self, legend_items): - """Generate HTML markup for the legend items.""" + """Generate HTML markup for the legend items. + + Args: + legend_items: List of dictionaries containing legend item properties + + Returns: + str: Generated HTML markup for legend items + + Raises: + ValueError: If legend_items is invalid or missing required properties + """ + if not isinstance(legend_items, list): + raise ValueError("legend_items must be a list") + legend_items_html = "" for item in legend_items: + if not isinstance(item, dict): + raise ValueError("Each legend item must be a dictionary") + if "color" not in item: + raise ValueError("Each legend item must have a 'color' property") + if "label" not in item: + raise ValueError("Each legend item must have a 'label' property") + if "border" in item: legend_items_html += f"""
@@ -55,19 +127,42 @@ class HTMLTemplateHandler: return legend_items_html def generate_final_html(self, network_body, legend_items_html, title="Flow Plot"): - """Combine all components into final HTML document with network visualization.""" - html_template = self.read_template() - logo_svg_base64 = self.encode_logo() + """Combine all components into final HTML document with network visualization. + + Args: + network_body: HTML string containing network visualization + legend_items_html: HTML string containing legend items markup + title: Title for the visualization page (default: "Flow Plot") + + Returns: + str: Complete HTML document with all components integrated + + Raises: + ValueError: If any input parameters are invalid + IOError: If template or logo files cannot be read + """ + if not isinstance(network_body, str): + raise ValueError("network_body must be a string") + if not isinstance(legend_items_html, str): + raise ValueError("legend_items_html must be a string") + if not isinstance(title, str): + raise ValueError("title must be a string") + + try: + html_template = self.read_template() + logo_svg_base64 = self.encode_logo() - final_html_content = html_template.replace("{{ title }}", title) - final_html_content = final_html_content.replace( - "{{ network_content }}", network_body - ) - final_html_content = final_html_content.replace( - "{{ logo_svg_base64 }}", logo_svg_base64 - ) - final_html_content = final_html_content.replace( - "", legend_items_html - ) + final_html_content = html_template.replace("{{ title }}", title) + final_html_content = final_html_content.replace( + "{{ network_content }}", network_body + ) + final_html_content = final_html_content.replace( + "{{ logo_svg_base64 }}", logo_svg_base64 + ) + final_html_content = final_html_content.replace( + "", legend_items_html + ) - return final_html_content + return final_html_content + except Exception as e: + raise ValueError(f"Failed to generate final HTML: {str(e)}") diff --git 
a/src/crewai/flow/path_utils.py b/src/crewai/flow/path_utils.py new file mode 100644 index 000000000..e770477a5 --- /dev/null +++ b/src/crewai/flow/path_utils.py @@ -0,0 +1,123 @@ +"""Utilities for safe path handling in flow visualization. + +This module provides a comprehensive set of utilities for secure path handling, +including path joining, validation, and normalization. It helps prevent common +security issues like directory traversal attacks while providing a consistent +interface for path operations. +""" + +import os +from pathlib import Path +from typing import Union, List, Optional + + +def safe_path_join(base_dir: Union[str, Path], filename: str) -> str: + """Safely join base directory with filename, preventing directory traversal. + + Args: + base_dir: Base directory path + filename: Filename or path to join with base_dir + + Returns: + str: Safely joined absolute path + + Raises: + ValueError: If resulting path would escape base_dir or contains dangerous patterns + TypeError: If inputs are not strings or Path objects + OSError: If path resolution fails + """ + if not isinstance(base_dir, (str, Path)): + raise TypeError("base_dir must be a string or Path object") + if not isinstance(filename, str): + raise TypeError("filename must be a string") + + # Check for dangerous patterns + dangerous_patterns = ['..', '~', '*', '?', '|', '>', '<', '$', '&', '`'] + if any(pattern in filename for pattern in dangerous_patterns): + raise ValueError(f"Invalid filename: Contains dangerous pattern") + + try: + base_path = Path(base_dir).resolve(strict=True) + full_path = Path(base_path, filename).resolve(strict=True) + + if not str(full_path).startswith(str(base_path)): + raise ValueError( + f"Invalid path: {filename} would escape base directory {base_dir}" + ) + + return str(full_path) + except OSError as e: + raise OSError(f"Failed to resolve path: {str(e)}") + except Exception as e: + raise ValueError(f"Failed to process paths: {str(e)}") + + +def 
normalize_path(path: Union[str, Path]) -> str: + """Normalize a path by resolving symlinks and removing redundant separators. + + Args: + path: Path to normalize + + Returns: + str: Normalized absolute path + + Raises: + TypeError: If path is not a string or Path object + OSError: If path resolution fails + """ + if not isinstance(path, (str, Path)): + raise TypeError("path must be a string or Path object") + + try: + return str(Path(path).resolve(strict=True)) + except OSError as e: + raise OSError(f"Failed to normalize path: {str(e)}") + + +def validate_path_components(components: List[str]) -> None: + """Validate path components for potentially dangerous patterns. + + Args: + components: List of path components to validate + + Raises: + TypeError: If components is not a list or contains non-string items + ValueError: If any component contains dangerous patterns + """ + if not isinstance(components, list): + raise TypeError("components must be a list") + + dangerous_patterns = ['..', '~', '*', '?', '|', '>', '<', '$', '&', '`'] + for component in components: + if not isinstance(component, str): + raise TypeError(f"Path component '{component}' must be a string") + if any(pattern in component for pattern in dangerous_patterns): + raise ValueError(f"Invalid path component '{component}': Contains dangerous pattern") + + +def validate_file_path(path: Union[str, Path], must_exist: bool = True) -> str: + """Validate a file path for security and existence. 
+ + Args: + path: File path to validate + must_exist: Whether the file must exist (default: True) + + Returns: + str: Validated absolute path + + Raises: + ValueError: If path is invalid or file doesn't exist when required + TypeError: If path is not a string or Path object + """ + if not isinstance(path, (str, Path)): + raise TypeError("path must be a string or Path object") + + try: + resolved_path = Path(path).resolve() + + if must_exist and not resolved_path.is_file(): + raise ValueError(f"File not found: {path}") + + return str(resolved_path) + except Exception as e: + raise ValueError(f"Invalid file path {path}: {str(e)}") diff --git a/src/crewai/flow/utils.py b/src/crewai/flow/utils.py index abbd69a33..26f266ae2 100644 --- a/src/crewai/flow/utils.py +++ b/src/crewai/flow/utils.py @@ -1,226 +1,36 @@ -"""Utility functions for flow execution and visualization. +"""General utility functions for flow execution. -Provides helper functions for analyzing flow structure, calculating -node positions, and extracting return values from methods. +This module has been deprecated. All functionality has been moved to: +- core_flow_utils.py: Core flow execution utilities +- flow_visual_utils.py: Visualization-related utilities + +This module is kept as a temporary redirect to maintain backwards compatibility. +New code should import from the appropriate new modules directly. 
""" -import ast -import inspect -import textwrap +from typing import Any, Dict, List, Optional, Set +from .core_flow_utils import get_possible_return_constants +from .flow_visual_utils import ( + build_ancestor_dict, + build_parent_children_dict, + calculate_node_levels, + count_outgoing_edges, + dfs_ancestors, + get_child_index, + is_ancestor, +) -def get_possible_return_constants(function): - try: - source = inspect.getsource(function) - except OSError: - # Can't get source code - return None - except Exception as e: - print(f"Error retrieving source code for function {function.__name__}: {e}") - return None +# Re-export all functions for backwards compatibility +__all__ = [ + 'get_possible_return_constants', + 'calculate_node_levels', + 'count_outgoing_edges', + 'build_ancestor_dict', + 'dfs_ancestors', + 'is_ancestor', + 'build_parent_children_dict', + 'get_child_index', +] - try: - # Remove leading indentation - source = textwrap.dedent(source) - # Parse the source code into an AST - code_ast = ast.parse(source) - except IndentationError as e: - print(f"IndentationError while parsing source code of {function.__name__}: {e}") - print(f"Source code:\n{source}") - return None - except SyntaxError as e: - print(f"SyntaxError while parsing source code of {function.__name__}: {e}") - print(f"Source code:\n{source}") - return None - except Exception as e: - print(f"Unexpected error while parsing source code of {function.__name__}: {e}") - print(f"Source code:\n{source}") - return None - - return_values = set() - dict_definitions = {} - - class DictionaryAssignmentVisitor(ast.NodeVisitor): - def visit_Assign(self, node): - # Check if this assignment is assigning a dictionary literal to a variable - if isinstance(node.value, ast.Dict) and len(node.targets) == 1: - target = node.targets[0] - if isinstance(target, ast.Name): - var_name = target.id - dict_values = [] - # Extract string values from the dictionary - for val in node.value.values: - if isinstance(val, 
ast.Constant) and isinstance(val.value, str): - dict_values.append(val.value) - # If non-string, skip or just ignore - if dict_values: - dict_definitions[var_name] = dict_values - self.generic_visit(node) - - class ReturnVisitor(ast.NodeVisitor): - def visit_Return(self, node): - # Direct string return - if isinstance(node.value, ast.Constant) and isinstance( - node.value.value, str - ): - return_values.add(node.value.value) - # Dictionary-based return, like return paths[result] - elif isinstance(node.value, ast.Subscript): - # Check if we're subscripting a known dictionary variable - if isinstance(node.value.value, ast.Name): - var_name = node.value.value.id - if var_name in dict_definitions: - # Add all possible dictionary values - for v in dict_definitions[var_name]: - return_values.add(v) - self.generic_visit(node) - - # First pass: identify dictionary assignments - DictionaryAssignmentVisitor().visit(code_ast) - # Second pass: identify returns - ReturnVisitor().visit(code_ast) - - return list(return_values) if return_values else None - - -def calculate_node_levels(flow): - levels = {} - queue = [] - visited = set() - pending_and_listeners = {} - - # Make all start methods at level 0 - for method_name, method in flow._methods.items(): - if hasattr(method, "__is_start_method__"): - levels[method_name] = 0 - queue.append(method_name) - - # Breadth-first traversal to assign levels - while queue: - current = queue.pop(0) - current_level = levels[current] - visited.add(current) - - for listener_name, (condition_type, trigger_methods) in flow._listeners.items(): - if condition_type == "OR": - if current in trigger_methods: - if ( - listener_name not in levels - or levels[listener_name] > current_level + 1 - ): - levels[listener_name] = current_level + 1 - if listener_name not in visited: - queue.append(listener_name) - elif condition_type == "AND": - if listener_name not in pending_and_listeners: - pending_and_listeners[listener_name] = set() - if current in 
trigger_methods: - pending_and_listeners[listener_name].add(current) - if set(trigger_methods) == pending_and_listeners[listener_name]: - if ( - listener_name not in levels - or levels[listener_name] > current_level + 1 - ): - levels[listener_name] = current_level + 1 - if listener_name not in visited: - queue.append(listener_name) - - # Handle router connections - if current in flow._routers: - router_method_name = current - paths = flow._router_paths.get(router_method_name, []) - for path in paths: - for listener_name, ( - condition_type, - trigger_methods, - ) in flow._listeners.items(): - if path in trigger_methods: - if ( - listener_name not in levels - or levels[listener_name] > current_level + 1 - ): - levels[listener_name] = current_level + 1 - if listener_name not in visited: - queue.append(listener_name) - - return levels - - -def count_outgoing_edges(flow): - counts = {} - for method_name in flow._methods: - counts[method_name] = 0 - for method_name in flow._listeners: - _, trigger_methods = flow._listeners[method_name] - for trigger in trigger_methods: - if trigger in flow._methods: - counts[trigger] += 1 - return counts - - -def build_ancestor_dict(flow): - ancestors = {node: set() for node in flow._methods} - visited = set() - for node in flow._methods: - if node not in visited: - dfs_ancestors(node, ancestors, visited, flow) - return ancestors - - -def dfs_ancestors(node, ancestors, visited, flow): - if node in visited: - return - visited.add(node) - - # Handle regular listeners - for listener_name, (_, trigger_methods) in flow._listeners.items(): - if node in trigger_methods: - ancestors[listener_name].add(node) - ancestors[listener_name].update(ancestors[node]) - dfs_ancestors(listener_name, ancestors, visited, flow) - - # Handle router methods separately - if node in flow._routers: - router_method_name = node - paths = flow._router_paths.get(router_method_name, []) - for path in paths: - for listener_name, (_, trigger_methods) in 
flow._listeners.items(): - if path in trigger_methods: - # Only propagate the ancestors of the router method, not the router method itself - ancestors[listener_name].update(ancestors[node]) - dfs_ancestors(listener_name, ancestors, visited, flow) - - -def is_ancestor(node, ancestor_candidate, ancestors): - return ancestor_candidate in ancestors.get(node, set()) - - -def build_parent_children_dict(flow): - parent_children = {} - - # Map listeners to their trigger methods - for listener_name, (_, trigger_methods) in flow._listeners.items(): - for trigger in trigger_methods: - if trigger not in parent_children: - parent_children[trigger] = [] - if listener_name not in parent_children[trigger]: - parent_children[trigger].append(listener_name) - - # Map router methods to their paths and to listeners - for router_method_name, paths in flow._router_paths.items(): - for path in paths: - # Map router method to listeners of each path - for listener_name, (_, trigger_methods) in flow._listeners.items(): - if path in trigger_methods: - if router_method_name not in parent_children: - parent_children[router_method_name] = [] - if listener_name not in parent_children[router_method_name]: - parent_children[router_method_name].append(listener_name) - - return parent_children - - -def get_child_index(parent, child, parent_children): - children = parent_children.get(parent, []) - children.sort() - return children.index(child) +# Function implementations have been moved to core_flow_utils.py and flow_visual_utils.py diff --git a/src/crewai/flow/visualization_utils.py b/src/crewai/flow/visualization_utils.py index f87a94e61..a10c3198d 100644 --- a/src/crewai/flow/visualization_utils.py +++ b/src/crewai/flow/visualization_utils.py @@ -1,26 +1,56 @@ import ast import inspect +import os +from pathlib import Path +from typing import Dict, Optional, Tuple -from .utils import ( +from .core_flow_utils import is_ancestor +from .flow_visual_utils import ( build_ancestor_dict, 
build_parent_children_dict, get_child_index, - is_ancestor, ) +from .path_utils import safe_path_join, validate_file_path -def method_calls_crew(method): - """Check if the method contains a .crew() call.""" +def method_calls_crew(method: callable) -> bool: + """Check if the method contains a .crew() call in its implementation. + + Analyzes the method's source code using AST to detect if it makes any + calls to the .crew() method, which indicates crew involvement in the + flow execution. + + Args: + method: The method to analyze for crew calls + + Returns: + bool: True if the method contains a .crew() call, False otherwise + + Raises: + Exception: If method source code cannot be parsed + """ + if not callable(method): + raise TypeError("Input must be a callable method") + try: source = inspect.getsource(method) source = inspect.cleandoc(source) tree = ast.parse(source) + except (TypeError, ValueError, OSError) as e: + raise ValueError(f"Could not parse method {getattr(method, '__name__', str(method))}: {e}") except Exception as e: - print(f"Could not parse method {method.__name__}: {e}") - return False + raise RuntimeError(f"Unexpected error parsing method: {e}") class CrewCallVisitor(ast.NodeVisitor): - """AST visitor to detect .crew() method calls.""" + """AST visitor to detect .crew() method calls in source code. + + A specialized AST visitor that analyzes Python source code to precisely + identify calls to the .crew() method, enabling accurate detection of + crew involvement in flow methods. 
+ + Attributes: + found (bool): Indicates whether a .crew() call was found + """ def __init__(self): self.found = False @@ -35,9 +65,64 @@ def method_calls_crew(method): return visitor.found -def add_nodes_to_network(net, flow, node_positions, node_styles): - """Add nodes to the network visualization with appropriate styling.""" - def human_friendly_label(method_name): +def add_nodes_to_network(net: object, flow: object, + node_positions: Dict[str, Tuple[float, float]], + node_styles: Dict[str, dict], + output_dir: Optional[str] = None) -> None: + """Add nodes to the network visualization with precise styling and positioning. + + Creates and styles nodes in the visualization network based on their type + (start, router, crew, or regular method) with fine-grained control over + appearance and positioning. + + Args: + net: The network visualization object to add nodes to + flow: Flow object containing method definitions and relationships + node_positions: Dictionary mapping method names to (x,y) coordinates + node_styles: Dictionary mapping node types to their visual styles + output_dir: Optional directory path for saving visualization assets + + Returns: + None + + Raises: + ValueError: If flow object is invalid or required styles are missing + TypeError: If input arguments have incorrect types + OSError: If output directory operations fail + + Note: + Node styles are applied with precise control over shape, font, color, + and positioning to ensure accurate visual representation of the flow. + If output_dir is provided, it will be validated and created if needed. 
+ """ + if not hasattr(flow, '_methods'): + raise ValueError("Invalid flow object: missing '_methods' attribute") + if not isinstance(node_positions, dict): + raise TypeError("node_positions must be a dictionary") + if not isinstance(node_styles, dict): + raise TypeError("node_styles must be a dictionary") + + required_styles = {'start', 'router', 'crew', 'method'} + missing_styles = required_styles - set(node_styles.keys()) + if missing_styles: + raise ValueError(f"Missing required node styles: {missing_styles}") + + # Validate and create output directory if specified + if output_dir: + try: + output_dir = validate_file_path(output_dir, must_exist=False) + os.makedirs(output_dir, exist_ok=True) + except (ValueError, OSError) as e: + raise OSError(f"Failed to create or validate output directory: {e}") + def human_friendly_label(method_name: str) -> str: + """Convert method name to human-readable format. + + Args: + method_name: Original method name with underscores + + Returns: + str: Formatted method name with spaces and title case + """ return method_name.replace("_", " ").title() for method_name, (x, y) in node_positions.items(): @@ -54,6 +139,15 @@ def add_nodes_to_network(net, flow, node_positions, node_styles): node_style = node_style.copy() label = human_friendly_label(method_name) + # Handle file-based assets if output directory is provided + if output_dir and node_style.get("image"): + try: + image_path = node_style["image"] + safe_image_path = safe_path_join(output_dir, Path(image_path).name) + node_style["image"] = str(safe_image_path) + except (ValueError, OSError) as e: + raise OSError(f"Failed to process node image path: {e}") + node_style.update( { "label": label, @@ -75,8 +169,39 @@ def add_nodes_to_network(net, flow, node_positions, node_styles): ) -def compute_positions(flow, node_levels, y_spacing=150, x_spacing=150): - """Calculate x,y coordinates for each node in the flow diagram.""" +def compute_positions(flow: object, node_levels: dict[str, 
int], + y_spacing: float = 150, x_spacing: float = 150) -> dict[str, tuple[float, float]]: + if not hasattr(flow, '_methods'): + raise ValueError("Invalid flow object: missing '_methods' attribute") + if not isinstance(node_levels, dict): + raise TypeError("node_levels must be a dictionary") + if not isinstance(y_spacing, (int, float)) or y_spacing <= 0: + raise ValueError("y_spacing must be a positive number") + if not isinstance(x_spacing, (int, float)) or x_spacing <= 0: + raise ValueError("x_spacing must be a positive number") + + if not node_levels: + raise ValueError("node_levels dictionary cannot be empty") + """Calculate precise x,y coordinates for each node in the flow diagram. + + Computes optimal node positions with fine-grained control over spacing + and alignment, ensuring clear visualization of flow hierarchy and + relationships. + + Args: + flow: Flow object containing method definitions + node_levels: Dictionary mapping method names to their hierarchy levels + y_spacing: Vertical spacing between hierarchy levels (default: 150) + x_spacing: Horizontal spacing between nodes at same level (default: 150) + + Returns: + dict[str, tuple[float, float]]: Dictionary mapping method names to + their calculated (x,y) coordinates in the visualization + + Note: + Positions are calculated to maintain clear hierarchical structure while + ensuring optimal spacing and readability of the flow diagram. 
+ """ level_nodes = {} node_positions = {} @@ -93,8 +218,56 @@ def compute_positions(flow, node_levels, y_spacing=150, x_spacing=150): return node_positions -def add_edges(net, flow, node_positions, colors): - """Add edges between nodes with appropriate styling and routing.""" +def add_edges(net: object, flow: object, + node_positions: Dict[str, Tuple[float, float]], + colors: Dict[str, str], + asset_dir: Optional[str] = None) -> None: + if not hasattr(flow, '_methods'): + raise ValueError("Invalid flow object: missing '_methods' attribute") + if not hasattr(flow, '_listeners'): + raise ValueError("Invalid flow object: missing '_listeners' attribute") + if not hasattr(flow, '_router_paths'): + raise ValueError("Invalid flow object: missing '_router_paths' attribute") + + if not isinstance(node_positions, dict): + raise TypeError("node_positions must be a dictionary") + if not isinstance(colors, dict): + raise TypeError("colors must be a dictionary") + + required_colors = {'edge', 'router_edge'} + missing_colors = required_colors - set(colors.keys()) + if missing_colors: + raise ValueError(f"Missing required edge colors: {missing_colors}") + + # Validate asset directory if provided + if asset_dir: + try: + asset_dir = validate_file_path(asset_dir, must_exist=False) + os.makedirs(asset_dir, exist_ok=True) + except (ValueError, OSError) as e: + raise OSError(f"Failed to create or validate asset directory: {e}") + """Add edges between nodes with precise styling and intelligent routing. + + Creates and styles edges in the visualization with fine-grained control over + appearance, routing, and curvature. Handles both normal method connections + and router paths with specialized styling. 
+ + Args: + net: The network visualization object to add edges to + flow: Flow object containing method relationships and router paths + node_positions: Dictionary mapping method names to (x,y) coordinates + colors: Dictionary mapping edge types to their colors + + Returns: + None + + Note: + Implements sophisticated edge routing with: + - Automatic curve direction based on node positions + - Dynamic curvature adjustment for multiple edges + - Distinct styling for router paths and AND conditions + - Cycle detection and appropriate visualization + """ ancestors = build_ancestor_dict(flow) parent_children = build_parent_children_dict(flow)