diff --git a/src/crewai/flow/core_flow_utils.py b/src/crewai/flow/core_flow_utils.py new file mode 100644 index 000000000..b886f8658 --- /dev/null +++ b/src/crewai/flow/core_flow_utils.py @@ -0,0 +1,141 @@ +"""Core utility functions for Flow class operations. + +This module contains utility functions that are specifically designed to work +with the Flow class and require direct access to Flow class internals. These +utilities are separated from general-purpose utilities to maintain a clean +dependency structure and avoid circular imports. + +Functions in this module are core to Flow functionality and are not related +to visualization or other optional features. +""" + +import ast +import inspect +import textwrap +from typing import Any, Callable, Dict, List, Optional, Set, Union + +from pydantic import BaseModel + + +def get_possible_return_constants(function: callable) -> Optional[List[str]]: + """Extract possible string return values from a function by analyzing its source code. + + Analyzes the function's source code using AST to identify string constants that + could be returned, including strings stored in dictionaries and direct returns. + + Args: + function: The function to analyze for possible return values + + Returns: + list[str] | None: List of possible string return values, or None if: + - Source code cannot be retrieved + - Source code has syntax/indentation errors + - No string return values are found + + Raises: + OSError: If source code cannot be retrieved + IndentationError: If source code has invalid indentation + SyntaxError: If source code has syntax errors + + Example: + >>> def get_status(): + ... paths = {"success": "completed", "error": "failed"} + ... return paths["success"] + >>> get_possible_return_constants(get_status) + ['completed', 'failed'] + """ + try: + source = inspect.getsource(function) + except OSError: + # Can't get source code + return None + except Exception as e: + print(f"Error retrieving source code for function {function.__name__}: {e}") + return None + + try: + # Remove leading indentation + source = textwrap.dedent(source) + # Parse the source code into an AST + code_ast = ast.parse(source) + except IndentationError as e: + print(f"IndentationError while parsing source code of {function.__name__}: {e}") + print(f"Source code:\n{source}") + return None + except SyntaxError as e: + print(f"SyntaxError while parsing source code of {function.__name__}: {e}") + print(f"Source code:\n{source}") + return None + except Exception as e: + print(f"Unexpected error while parsing source code of {function.__name__}: {e}") + print(f"Source code:\n{source}") + return None + + return_values = set() + dict_definitions = {} + + class DictionaryAssignmentVisitor(ast.NodeVisitor): + def visit_Assign(self, node): + # Check if this assignment is assigning a dictionary literal to a variable + if isinstance(node.value, ast.Dict) and len(node.targets) == 1: + target = node.targets[0] + if isinstance(target, ast.Name): + var_name = target.id + dict_values = [] + # Extract string values from the dictionary + for val in node.value.values: + if isinstance(val, ast.Constant) and isinstance(val.value, str): + dict_values.append(val.value) + # If non-string, skip or just ignore + if dict_values: + dict_definitions[var_name] = dict_values + self.generic_visit(node) + + class ReturnVisitor(ast.NodeVisitor): + def visit_Return(self, node): + # Direct string return + if isinstance(node.value, ast.Constant) and isinstance( + node.value.value, str + ): + return_values.add(node.value.value) + # Dictionary-based return, like return paths[result] + elif isinstance(node.value, ast.Subscript): + # Check if we're subscripting a known dictionary variable + if isinstance(node.value.value, ast.Name): + var_name = node.value.value.id + if var_name in dict_definitions: + # Add all possible dictionary values + for v in dict_definitions[var_name]: + return_values.add(v) + self.generic_visit(node) + + # First pass: identify dictionary assignments + DictionaryAssignmentVisitor().visit(code_ast) + # Second pass: identify returns + ReturnVisitor().visit(code_ast) + + return list(return_values) if return_values else None + + +def is_ancestor(node: str, ancestor_candidate: str, ancestors: Dict[str, Set[str]]) -> bool: + """Check if one node is an ancestor of another in the flow graph. + + Args: + node: Target node to check ancestors for + ancestor_candidate: Node to check if it's an ancestor + ancestors: Dictionary mapping nodes to their ancestor sets + + Returns: + bool: True if ancestor_candidate is an ancestor of node + + Raises: + TypeError: If any argument has an invalid type + """ + if not isinstance(node, str): + raise TypeError("Argument 'node' must be a string") + if not isinstance(ancestor_candidate, str): + raise TypeError("Argument 'ancestor_candidate' must be a string") + if not isinstance(ancestors, dict): + raise TypeError("Argument 'ancestors' must be a dictionary") + + return ancestor_candidate in ancestors.get(node, set()) diff --git a/src/crewai/flow/flow.py b/src/crewai/flow/flow.py index 25f2c2fff..cac5962a7 100644 --- a/src/crewai/flow/flow.py +++ b/src/crewai/flow/flow.py @@ -23,15 +23,37 @@ from crewai.flow.flow_events import ( MethodExecutionFinishedEvent, MethodExecutionStartedEvent, ) -from crewai.flow.flow_visualizer import plot_flow -from crewai.flow.utils import get_possible_return_constants +from crewai.flow.core_flow_utils import get_possible_return_constants from crewai.telemetry import Telemetry T = TypeVar("T", bound=Union[BaseModel, Dict[str, Any]]) -def start(condition=None): - """Marks a method as a flow starting point, optionally triggered by other methods.""" +def start(condition: Optional[Union[str, dict, Callable]] = None) -> Callable: + """Marks a method as a flow starting point, optionally triggered by other methods. + + Args: + condition: The condition that triggers this method. Can be: + - str: Name of the triggering method + - dict: Dictionary with 'type' and 'methods' keys for complex conditions + - Callable: A function reference + - None: No trigger condition (default) + + Returns: + Callable: The decorated function that will serve as a flow starting point. + + Raises: + ValueError: If the condition format is invalid. + + Example: + >>> @start() # No condition + >>> def begin_flow(): + >>> pass + >>> + >>> @start("method_name") # Triggered by specific method + >>> def conditional_start(): + >>> pass + """ def decorator(func): func.__is_start_method__ = True if condition is not None: @@ -57,8 +79,30 @@ def start(condition=None): return decorator -def listen(condition): - """Marks a method to execute when specified conditions/methods complete.""" +def listen(condition: Union[str, dict, Callable]) -> Callable: + """Marks a method to execute when specified conditions/methods complete. + + Args: + condition: The condition that triggers this method. Can be: + - str: Name of the triggering method + - dict: Dictionary with 'type' and 'methods' keys for complex conditions + - Callable: A function reference + + Returns: + Callable: The decorated function that will execute when conditions are met. + + Raises: + ValueError: If the condition format is invalid. + + Example: + >>> @listen("start_method") # Listen to single method + >>> def on_start(): + >>> pass + >>> + >>> @listen(and_("method1", "method2")) # Listen with AND condition + >>> def on_both_complete(): + >>> pass + """ def decorator(func): if isinstance(condition, str): func.__trigger_methods__ = [condition] @@ -82,8 +126,31 @@ def listen(condition): return decorator -def router(condition): - """Marks a method as a router to direct flow based on its return value.""" +def router(condition: Union[str, dict, Callable]) -> Callable: + """Marks a method as a router to direct flow based on its return value. + + A router method can return different string values that trigger different + subsequent methods, allowing for dynamic flow control. + + Args: + condition: The condition that triggers this router. Can be: + - str: Name of the triggering method + - dict: Dictionary with 'type' and 'methods' keys for complex conditions + - Callable: A function reference + + Returns: + Callable: The decorated function that will serve as a router. + + Raises: + ValueError: If the condition format is invalid. + + Example: + >>> @router("process_data") + >>> def route_result(result): + >>> if result.success: + >>> return "handle_success" + >>> return "handle_error" + """ def decorator(func): func.__is_router__ = True if isinstance(condition, str): @@ -108,8 +175,27 @@ def router(condition): return decorator -def or_(*conditions): - """Combines multiple conditions with OR logic for flow control.""" +def or_(*conditions: Union[str, dict, Callable]) -> dict: + """Combines multiple conditions with OR logic for flow control. + + Args: + *conditions: Variable number of conditions. Each can be: + - str: Name of a method + - dict: Dictionary with 'type' and 'methods' keys + - Callable: A function reference + + Returns: + dict: A dictionary with 'type': 'OR' and 'methods' list. + + Raises: + ValueError: If any condition is invalid. + + Example: + >>> @listen(or_("method1", "method2")) + >>> def on_either(): + >>> # Executes when either method1 OR method2 completes + >>> pass + """ methods = [] for condition in conditions: if isinstance(condition, dict) and "methods" in condition: @@ -123,8 +209,27 @@ def or_(*conditions): return {"type": "OR", "methods": methods} -def and_(*conditions): - """Combines multiple conditions with AND logic for flow control.""" +def and_(*conditions: Union[str, dict, Callable]) -> dict: + """Combines multiple conditions with AND logic for flow control. + + Args: + *conditions: Variable number of conditions. Each can be: + - str: Name of a method + - dict: Dictionary with 'type' and 'methods' keys + - Callable: A function reference + + Returns: + dict: A dictionary with 'type': 'AND' and 'methods' list. + + Raises: + ValueError: If any condition is invalid. + + Example: + >>> @listen(and_("method1", "method2")) + >>> def on_both(): + >>> # Executes when BOTH method1 AND method2 complete + >>> pass + """ methods = [] for condition in conditions: if isinstance(condition, dict) and "methods" in condition: @@ -183,6 +288,22 @@ class Flow(Generic[T], metaclass=FlowMeta): event_emitter = Signal("event_emitter") def __class_getitem__(cls: Type["Flow"], item: Type[T]) -> Type["Flow"]: + """Create a generic version of Flow with specified state type. + + Args: + cls: The Flow class + item: The type parameter for the flow's state + + Returns: + Type["Flow"]: A new Flow class with the specified state type + + Example: + >>> class MyState(BaseModel): + >>> value: int + >>> + >>> class MyFlow(Flow[MyState]): + >>> pass + """ class _FlowGeneric(cls): # type: ignore _initial_state_T = item # type: ignore @@ -190,11 +311,23 @@ class Flow(Generic[T], metaclass=FlowMeta): return _FlowGeneric def __init__(self) -> None: + """Initialize a new Flow instance. + + Sets up internal state tracking, method registration, and telemetry. + The flow's methods are automatically discovered and registered during initialization. + + Attributes initialized: + _methods: Dictionary mapping method names to their callable objects + _state: The flow's state object of type T + _method_execution_counts: Tracks how many times each method has executed + _pending_and_listeners: Tracks methods waiting for AND conditions + _method_outputs: List of all outputs from executed methods + """ self._methods: Dict[str, Callable] = {} self._state: T = self._create_initial_state() self._method_execution_counts: Dict[str, int] = {} self._pending_and_listeners: Dict[str, Set[str]] = {} - self._method_outputs: List[Any] = [] # List to store all method outputs + self._method_outputs: List[Any] = [] self._telemetry.flow_creation_span(self.__class__.__name__) @@ -205,6 +338,20 @@ class Flow(Generic[T], metaclass=FlowMeta): self._methods[method_name] = getattr(self, method_name) def _create_initial_state(self) -> T: + """Create the initial state for the flow. + + The state is created based on the following priority: + 1. If initial_state is None and _initial_state_T exists (generic type), use that + 2. If initial_state is None, return empty dict + 3. If initial_state is a type, instantiate it + 4. Otherwise, use initial_state as-is + + Returns: + T: The initial state object of type T + + Note: + The type T can be either a Pydantic BaseModel or a dictionary. + """ if self.initial_state is None and hasattr(self, "_initial_state_T"): return self._initial_state_T() # type: ignore if self.initial_state is None: @@ -216,11 +363,21 @@ class Flow(Generic[T], metaclass=FlowMeta): @property def state(self) -> T: + """Get the current state of the flow. + + Returns: + T: The current state object, either a Pydantic model or dictionary + """ return self._state @property def method_outputs(self) -> List[Any]: - """Returns the list of all outputs from executed methods.""" + """Get the list of all outputs from executed methods. + + Returns: + List[Any]: A list containing the output values from all executed flow methods, + in order of execution. + """ return self._method_outputs def _initialize_state(self, inputs: Dict[str, Any]) -> None: @@ -310,6 +467,23 @@ class Flow(Generic[T], metaclass=FlowMeta): return result async def _execute_listeners(self, trigger_method: str, result: Any) -> None: + """Execute all listener methods triggered by a completed method. + + This method handles both router and non-router listeners in a specific order: + 1. First executes all triggered router methods sequentially until no more routers + are triggered + 2. Then executes all regular listeners in parallel + + Args: + trigger_method: The name of the method that completed execution + result: The result value from the triggering method + + Note: + Router methods are executed sequentially to ensure proper flow control, + while regular listeners are executed concurrently for better performance. + This provides fine-grained control over the execution flow while + maintaining efficiency. + """ # First, handle routers repeatedly until no router triggers anymore while True: routers_triggered = self._find_triggered_methods( @@ -339,6 +513,27 @@ class Flow(Generic[T], metaclass=FlowMeta): def _find_triggered_methods( self, trigger_method: str, router_only: bool ) -> List[str]: + """Find all methods that should be triggered based on completed method and type. + + Provides precise control over method triggering by handling both OR and AND + conditions separately for router and non-router methods. + + Args: + trigger_method: The name of the method that completed execution + router_only: If True, only find router methods; if False, only regular + listeners + + Returns: + List[str]: Names of methods that should be executed next + + Note: + This method implements sophisticated flow control by: + 1. Filtering methods based on their router/non-router status + 2. Handling OR conditions for immediate triggering + 3. Managing AND conditions with state tracking for complex dependencies + + This ensures predictable and consistent execution order in complex flows. + """ triggered = [] for listener_name, (condition_type, methods) in self._listeners.items(): is_router = listener_name in self._routers @@ -367,6 +562,27 @@ class Flow(Generic[T], metaclass=FlowMeta): return triggered async def _execute_single_listener(self, listener_name: str, result: Any) -> None: + """Execute a single listener method with precise parameter handling and error tracking. + + Provides fine-grained control over method execution through: + 1. Automatic parameter inspection to determine if the method accepts results + 2. Event emission for execution tracking + 3. Comprehensive error handling + 4. Recursive listener execution + + Args: + listener_name: The name of the listener method to execute + result: The result from the triggering method, passed to the listener + if its signature accepts parameters + + Note: + This method ensures precise execution control by: + - Inspecting method signatures to handle parameters correctly + - Emitting events for execution tracking + - Providing comprehensive error handling + - Supporting both parameterized and parameter-less methods + - Maintaining execution chain through recursive listener calls + """ try: method = self._methods[listener_name] @@ -410,8 +626,32 @@ class Flow(Generic[T], metaclass=FlowMeta): traceback.print_exc() - def plot(self, filename: str = "crewai_flow") -> None: + def plot(self, *args, **kwargs): + """Generate an interactive visualization of the flow's execution graph. + + Creates a detailed HTML visualization showing the relationships between + methods, including start points, listeners, routers, and their + connections. Includes telemetry tracking for flow analysis. + + Args: + *args: Variable length argument list passed to plot_flow + **kwargs: Arbitrary keyword arguments passed to plot_flow + + Note: + The visualization provides: + - Clear representation of method relationships + - Visual distinction between different method types + - Interactive exploration capabilities + - Execution path tracing + - Telemetry tracking for flow analysis + + Example: + >>> flow = MyFlow() + >>> flow.plot("my_workflow") # Creates my_workflow.html + """ + from crewai.flow.flow_visualizer import plot_flow + self._telemetry.flow_plotting_span( self.__class__.__name__, list(self._methods.keys()) ) - plot_flow(self, filename) + return plot_flow(self, *args, **kwargs) diff --git a/src/crewai/flow/flow_visual_utils.py b/src/crewai/flow/flow_visual_utils.py new file mode 100644 index 000000000..aeb27eb31 --- /dev/null +++ b/src/crewai/flow/flow_visual_utils.py @@ -0,0 +1,240 @@ +"""Utility functions for Flow visualization. + +This module contains utility functions specifically designed for visualizing +Flow graphs and calculating layout information. These utilities are separated +from general-purpose utilities to maintain a clean dependency structure. +""" + +from typing import TYPE_CHECKING, Dict, List, Set + +if TYPE_CHECKING: + from crewai.flow.flow import Flow + + +def calculate_node_levels(flow: Flow) -> Dict[str, int]: + """Calculate the hierarchical level of each node in the flow graph. + + Uses breadth-first traversal to assign levels to nodes, starting with + start methods at level 0. Handles both OR and AND conditions for listeners, + and considers router paths when calculating levels. + + Args: + flow: Flow instance containing methods, listeners, and router configurations + + Returns: + dict[str, int]: Dictionary mapping method names to their hierarchical levels, + where level 0 contains start methods and each subsequent level contains + methods triggered by the previous level + + Example: + >>> flow = Flow() + >>> @flow.start + ... def start(): pass + >>> @flow.on("start") + ... def second(): pass + >>> calculate_node_levels(flow) + {'start': 0, 'second': 1} + """ + levels = {} + queue = [] + visited = set() + pending_and_listeners = {} + + # Make all start methods at level 0 + for method_name, method in flow._methods.items(): + if hasattr(method, "__is_start_method__"): + levels[method_name] = 0 + queue.append(method_name) + + # Breadth-first traversal to assign levels + while queue: + current = queue.pop(0) + current_level = levels[current] + visited.add(current) + + for listener_name, (condition_type, trigger_methods) in flow._listeners.items(): + if condition_type == "OR": + if current in trigger_methods: + if ( + listener_name not in levels + or levels[listener_name] > current_level + 1 + ): + levels[listener_name] = current_level + 1 + if listener_name not in visited: + queue.append(listener_name) + elif condition_type == "AND": + if listener_name not in pending_and_listeners: + pending_and_listeners[listener_name] = set() + if current in trigger_methods: + pending_and_listeners[listener_name].add(current) + if set(trigger_methods) == pending_and_listeners[listener_name]: + if ( + listener_name not in levels + or levels[listener_name] > current_level + 1 + ): + levels[listener_name] = current_level + 1 + if listener_name not in visited: + queue.append(listener_name) + + # Handle router connections + if current in flow._routers: + router_method_name = current + paths = flow._router_paths.get(router_method_name, []) + for path in paths: + for listener_name, ( + condition_type, + trigger_methods, + ) in flow._listeners.items(): + if path in trigger_methods: + if ( + listener_name not in levels + or levels[listener_name] > current_level + 1 + ): + levels[listener_name] = current_level + 1 + if listener_name not in visited: + queue.append(listener_name) + + return levels + + +def count_outgoing_edges(flow: Flow) -> Dict[str, int]: + """Count the number of outgoing edges for each node in the flow graph. + + An outgoing edge represents a connection from a method to a listener + that it triggers. This is useful for visualization and analysis of + flow structure. + + Args: + flow: Flow instance containing methods and their connections + + Returns: + dict[str, int]: Dictionary mapping method names to their number + of outgoing connections + """ + counts = {} + for method_name in flow._methods: + counts[method_name] = 0 + for method_name in flow._listeners: + _, trigger_methods = flow._listeners[method_name] + for trigger in trigger_methods: + if trigger in flow._methods: + counts[trigger] += 1 + return counts + + +def build_ancestor_dict(flow: Flow) -> Dict[str, Set[str]]: + """Build a dictionary mapping each node to its set of ancestor nodes. + + Uses depth-first search to identify all ancestors (direct and indirect + trigger methods) for each node in the flow graph. Handles both regular + listeners and router paths. + + Args: + flow: Flow instance containing methods and their relationships + + Returns: + dict[str, set[str]]: Dictionary mapping each method name to a set + of its ancestor method names + """ + ancestors = {node: set() for node in flow._methods} + visited = set() + for node in flow._methods: + if node not in visited: + dfs_ancestors(node, ancestors, visited, flow) + return ancestors + + + + +def dfs_ancestors(node: str, ancestors: Dict[str, Set[str]], + visited: Set[str], flow: Flow) -> None: + """Perform depth-first search to populate the ancestors dictionary. + + Helper function for build_ancestor_dict that recursively traverses + the flow graph to identify ancestors of each node. + + Args: + node: Current node being processed + ancestors: Dictionary mapping nodes to their ancestor sets + visited: Set of already visited nodes + flow: Flow instance containing the graph structure + """ + if node in visited: + return + visited.add(node) + + # Handle regular listeners + for listener_name, (_, trigger_methods) in flow._listeners.items(): + if node in trigger_methods: + ancestors[listener_name].add(node) + ancestors[listener_name].update(ancestors[node]) + dfs_ancestors(listener_name, ancestors, visited, flow) + + # Handle router methods separately + if node in flow._routers: + router_method_name = node + paths = flow._router_paths.get(router_method_name, []) + for path in paths: + for listener_name, (_, trigger_methods) in flow._listeners.items(): + if path in trigger_methods: + # Only propagate the ancestors of the router method, not the router method itself + ancestors[listener_name].update(ancestors[node]) + dfs_ancestors(listener_name, ancestors, visited, flow) + + +def build_parent_children_dict(flow: Flow) -> Dict[str, List[str]]: + """Build a dictionary mapping each node to its list of child nodes. + + Maps both regular trigger methods to their listeners and router + methods to their path listeners. Useful for visualization and + traversal of the flow graph structure. + + Args: + flow: Flow instance containing methods and their relationships + + Returns: + dict[str, list[str]]: Dictionary mapping each method name to a + sorted list of its child method names + """ + parent_children = {} + + # Map listeners to their trigger methods + for listener_name, (_, trigger_methods) in flow._listeners.items(): + for trigger in trigger_methods: + if trigger not in parent_children: + parent_children[trigger] = [] + if listener_name not in parent_children[trigger]: + parent_children[trigger].append(listener_name) + + # Map router methods to their paths and to listeners + for router_method_name, paths in flow._router_paths.items(): + for path in paths: + # Map router method to listeners of each path + for listener_name, (_, trigger_methods) in flow._listeners.items(): + if path in trigger_methods: + if router_method_name not in parent_children: + parent_children[router_method_name] = [] + if listener_name not in parent_children[router_method_name]: + parent_children[router_method_name].append(listener_name) + + return parent_children + + +def get_child_index(parent: str, child: str, + parent_children: Dict[str, List[str]]) -> int: + """Get the index of a child node in its parent's sorted children list. + + Args: + parent: Parent node name + child: Child node name to find index for + parent_children: Dictionary mapping parents to their children lists + + Returns: + int: Zero-based index of the child in parent's sorted children list + + Raises: + ValueError: If child is not found in parent's children list + """ + children = parent_children.get(parent, []) + children.sort() + return children.index(child) diff --git a/src/crewai/flow/flow_visualizer.py b/src/crewai/flow/flow_visualizer.py index b78c4b717..c1c8d61c8 100644 --- a/src/crewai/flow/flow_visualizer.py +++ b/src/crewai/flow/flow_visualizer.py @@ -1,13 +1,16 @@ # flow_visualizer.py import os +from pathlib import Path from pyvis.network import Network +from crewai.flow.path_utils import safe_path_join, validate_file_path + from crewai.flow.config import COLORS, NODE_STYLES from crewai.flow.html_template_handler import HTMLTemplateHandler from crewai.flow.legend_generator import generate_legend_items_html, get_legend_items -from crewai.flow.utils import calculate_node_levels +from crewai.flow.flow_visual_utils import calculate_node_levels from crewai.flow.visualization_utils import ( add_edges, add_nodes_to_network, @@ -19,7 +22,21 @@ class FlowPlot: """Handles the creation and rendering of flow visualization diagrams.""" def __init__(self, flow): - """Initialize flow plot with flow instance and styling configuration.""" + """Initialize flow plot with flow instance and styling configuration. + + Args: + flow: A Flow instance with required attributes for visualization + + Raises: + ValueError: If flow object is invalid or missing required attributes + """ + if not hasattr(flow, '_methods'): + raise ValueError("Invalid flow object: Missing '_methods' attribute") + if not hasattr(flow, '_start_methods'): + raise ValueError("Invalid flow object: Missing '_start_methods' attribute") + if not hasattr(flow, '_listeners'): + raise ValueError("Invalid flow object: Missing '_listeners' attribute") + self.flow = flow self.colors = COLORS self.node_styles = NODE_STYLES @@ -58,9 +75,16 @@ class FlowPlot: network_html = net.generate_html() final_html_content = self._generate_final_html(network_html) - with open(f"{filename}.html", "w", encoding="utf-8") as f: - f.write(final_html_content) - print(f"Plot saved as {filename}.html") + try: + # Ensure the output path is safe + output_dir = os.getcwd() + output_path = safe_path_join(output_dir, f"{filename}.html") + + with open(output_path, "w", encoding="utf-8") as f: + f.write(final_html_content) + print(f"Plot saved as {output_path}") + except (IOError, ValueError) as e: + raise IOError(f"Failed to save flow visualization: {str(e)}") self._cleanup_pyvis_lib() diff --git a/src/crewai/flow/html_template_handler.py b/src/crewai/flow/html_template_handler.py index 2af2665c1..a0b789874 100644 --- a/src/crewai/flow/html_template_handler.py +++ b/src/crewai/flow/html_template_handler.py @@ -1,35 +1,107 @@ import base64 +import os import re +from pathlib import Path + +from crewai.flow.path_utils import safe_path_join, validate_file_path class HTMLTemplateHandler: """Handles HTML template processing and generation for flow visualization diagrams.""" def __init__(self, template_path, logo_path): - """Initialize template handler with template and logo file paths.""" - self.template_path = template_path - self.logo_path = logo_path + """Initialize template handler with template and logo file paths. + + Args: + template_path: Path to the HTML template file + logo_path: Path to the logo SVG file + + Raises: + ValueError: If template_path or logo_path is invalid or files don't exist + """ + try: + self.template_path = validate_file_path(template_path) + self.logo_path = validate_file_path(logo_path) + except (ValueError, TypeError) as e: + raise ValueError(f"Invalid file path: {str(e)}") def read_template(self): - """Read and return the HTML template file contents.""" - with open(self.template_path, "r", encoding="utf-8") as f: - return f.read() + """Read and return the HTML template file contents. + + Returns: + str: The contents of the template file + + Raises: + IOError: If template file cannot be read + """ + try: + with open(self.template_path, "r", encoding="utf-8") as f: + return f.read() + except IOError as e: + raise IOError(f"Failed to read template file {self.template_path}: {str(e)}") def encode_logo(self): - """Convert the logo SVG file to base64 encoded string.""" - with open(self.logo_path, "rb") as logo_file: - logo_svg_data = logo_file.read() - return base64.b64encode(logo_svg_data).decode("utf-8") + """Convert the logo SVG file to base64 encoded string. + + Returns: + str: Base64 encoded logo data + + Raises: + IOError: If logo file cannot be read + ValueError: If logo data cannot be encoded + """ + try: + with open(self.logo_path, "rb") as logo_file: + logo_svg_data = logo_file.read() + try: + return base64.b64encode(logo_svg_data).decode("utf-8") + except Exception as e: + raise ValueError(f"Failed to encode logo data: {str(e)}") + except IOError as e: + raise IOError(f"Failed to read logo file {self.logo_path}: {str(e)}") def extract_body_content(self, html): - """Extract and return content between body tags from HTML string.""" + """Extract and return content between body tags from HTML string. + + Args: + html: HTML string to extract body content from + + Returns: + str: Content between body tags, or empty string if not found + + Raises: + ValueError: If input HTML is invalid + """ + if not html or not isinstance(html, str): + raise ValueError("Input HTML must be a non-empty string") + match = re.search("