fix: change flow viz del dir; method inspection
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Notify Downstream / notify-downstream (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled

* chore: update flow viz deletion dir, add typing
* tests: add flow viz tests to ensure lib dir is not deleted
This commit is contained in:
Greyson LaLonde
2025-10-22 19:32:38 -04:00
committed by GitHub
parent 4371cf5690
commit 9728388ea7
5 changed files with 99 additions and 30 deletions

View File

@@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
import os import os
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Any
from pyvis.network import Network # type: ignore[import-untyped] from pyvis.network import Network # type: ignore[import-untyped]
@@ -29,7 +29,7 @@ _printer = Printer()
class FlowPlot: class FlowPlot:
"""Handles the creation and rendering of flow visualization diagrams.""" """Handles the creation and rendering of flow visualization diagrams."""
def __init__(self, flow: Flow) -> None: def __init__(self, flow: Flow[Any]) -> None:
""" """
Initialize FlowPlot with a flow object. Initialize FlowPlot with a flow object.
@@ -136,7 +136,7 @@ class FlowPlot:
f"Unexpected error during flow visualization: {e!s}" f"Unexpected error during flow visualization: {e!s}"
) from e ) from e
finally: finally:
self._cleanup_pyvis_lib() self._cleanup_pyvis_lib(filename)
def _generate_final_html(self, network_html: str) -> str: def _generate_final_html(self, network_html: str) -> str:
""" """
@@ -186,26 +186,33 @@ class FlowPlot:
raise IOError(f"Failed to generate visualization HTML: {e!s}") from e raise IOError(f"Failed to generate visualization HTML: {e!s}") from e
@staticmethod @staticmethod
def _cleanup_pyvis_lib() -> None: def _cleanup_pyvis_lib(filename: str) -> None:
""" """
Clean up the generated lib folder from pyvis. Clean up the generated lib folder from pyvis.
This method safely removes the temporary lib directory created by pyvis This method safely removes the temporary lib directory created by pyvis
during network visualization generation. during network visualization generation. The lib folder is created in the
same directory as the output HTML file.
Parameters
----------
filename : str
The output filename (without .html extension) used for the visualization.
""" """
try: try:
lib_folder = safe_path_join("lib", root=os.getcwd()) import shutil
if os.path.exists(lib_folder) and os.path.isdir(lib_folder):
import shutil
shutil.rmtree(lib_folder) output_dir = os.path.dirname(os.path.abspath(filename)) or os.getcwd()
except ValueError as e: lib_folder = os.path.join(output_dir, "lib")
_printer.print(f"Error validating lib folder path: {e}", color="red") if os.path.exists(lib_folder) and os.path.isdir(lib_folder):
vis_js = os.path.join(lib_folder, "vis-network.min.js")
if os.path.exists(vis_js):
shutil.rmtree(lib_folder)
except Exception as e: except Exception as e:
_printer.print(f"Error cleaning up lib folder: {e}", color="red") _printer.print(f"Error cleaning up lib folder: {e}", color="red")
def plot_flow(flow: Flow, filename: str = "flow_plot") -> None: def plot_flow(flow: Flow[Any], filename: str = "flow_plot") -> None:
""" """
Convenience function to create and save a flow visualization. Convenience function to create and save a flow visualization.

View File

@@ -1,5 +1,8 @@
"""HTML template processing and generation for flow visualization diagrams."""
import base64 import base64
import re import re
from typing import Any
from crewai.flow.path_utils import validate_path_exists from crewai.flow.path_utils import validate_path_exists
@@ -7,7 +10,7 @@ from crewai.flow.path_utils import validate_path_exists
class HTMLTemplateHandler: class HTMLTemplateHandler:
"""Handles HTML template processing and generation for flow visualization diagrams.""" """Handles HTML template processing and generation for flow visualization diagrams."""
def __init__(self, template_path, logo_path): def __init__(self, template_path: str, logo_path: str) -> None:
""" """
Initialize HTMLTemplateHandler with validated template and logo paths. Initialize HTMLTemplateHandler with validated template and logo paths.
@@ -29,23 +32,23 @@ class HTMLTemplateHandler:
except ValueError as e: except ValueError as e:
raise ValueError(f"Invalid template or logo path: {e}") from e raise ValueError(f"Invalid template or logo path: {e}") from e
def read_template(self): def read_template(self) -> str:
"""Read and return the HTML template file contents.""" """Read and return the HTML template file contents."""
with open(self.template_path, "r", encoding="utf-8") as f: with open(self.template_path, "r", encoding="utf-8") as f:
return f.read() return f.read()
def encode_logo(self): def encode_logo(self) -> str:
"""Convert the logo SVG file to base64 encoded string.""" """Convert the logo SVG file to base64 encoded string."""
with open(self.logo_path, "rb") as logo_file: with open(self.logo_path, "rb") as logo_file:
logo_svg_data = logo_file.read() logo_svg_data = logo_file.read()
return base64.b64encode(logo_svg_data).decode("utf-8") return base64.b64encode(logo_svg_data).decode("utf-8")
def extract_body_content(self, html): def extract_body_content(self, html: str) -> str:
"""Extract and return content between body tags from HTML string.""" """Extract and return content between body tags from HTML string."""
match = re.search("<body.*?>(.*?)</body>", html, re.DOTALL) match = re.search("<body.*?>(.*?)</body>", html, re.DOTALL)
return match.group(1) if match else "" return match.group(1) if match else ""
def generate_legend_items_html(self, legend_items): def generate_legend_items_html(self, legend_items: list[dict[str, Any]]) -> str:
"""Generate HTML markup for the legend items.""" """Generate HTML markup for the legend items."""
legend_items_html = "" legend_items_html = ""
for item in legend_items: for item in legend_items:
@@ -73,7 +76,9 @@ class HTMLTemplateHandler:
""" """
return legend_items_html return legend_items_html
def generate_final_html(self, network_body, legend_items_html, title="Flow Plot"): def generate_final_html(
self, network_body: str, legend_items_html: str, title: str = "Flow Plot"
) -> str:
"""Combine all components into final HTML document with network visualization.""" """Combine all components into final HTML document with network visualization."""
html_template = self.read_template() html_template = self.read_template()
logo_svg_base64 = self.encode_logo() logo_svg_base64 = self.encode_logo()

View File

@@ -1,4 +1,23 @@
def get_legend_items(colors): """Legend generation for flow visualization diagrams."""
from typing import Any
from crewai.flow.config import FlowColors
def get_legend_items(colors: FlowColors) -> list[dict[str, Any]]:
"""Generate legend items based on flow colors.
Parameters
----------
colors : FlowColors
Dictionary containing color definitions for flow elements.
Returns
-------
list[dict[str, Any]]
List of legend item dictionaries with labels and styling.
"""
return [ return [
{"label": "Start Method", "color": colors["start"]}, {"label": "Start Method", "color": colors["start"]},
{"label": "Method", "color": colors["method"]}, {"label": "Method", "color": colors["method"]},
@@ -24,7 +43,19 @@ def get_legend_items(colors):
] ]
def generate_legend_items_html(legend_items): def generate_legend_items_html(legend_items: list[dict[str, Any]]) -> str:
"""Generate HTML markup for legend items.
Parameters
----------
legend_items : list[dict[str, Any]]
List of legend item dictionaries containing labels and styling.
Returns
-------
str
HTML string containing formatted legend items.
"""
legend_items_html = "" legend_items_html = ""
for item in legend_items: for item in legend_items:
if "border" in item: if "border" in item:

View File

@@ -36,28 +36,29 @@ from crewai.flow.utils import (
from crewai.utilities.printer import Printer from crewai.utilities.printer import Printer
_printer = Printer() _printer = Printer()
def method_calls_crew(method: Any) -> bool: def method_calls_crew(method: Any) -> bool:
""" """
Check if the method contains a call to `.crew()`. Check if the method contains a call to `.crew()`, `.kickoff()`, or `.kickoff_async()`.
Parameters Parameters
---------- ----------
method : Any method : Any
The method to analyze for crew() calls. The method to analyze for crew or agent execution calls.
Returns Returns
------- -------
bool bool
True if the method calls .crew(), False otherwise. True if the method calls .crew(), .kickoff(), or .kickoff_async(), False otherwise.
Notes Notes
----- -----
Uses AST analysis to detect method calls, specifically looking for Uses AST analysis to detect method calls, specifically looking for
attribute access of 'crew'. attribute access of 'crew', 'kickoff', or 'kickoff_async'.
This includes both traditional Crew execution (.crew()) and Agent/LiteAgent
execution (.kickoff() or .kickoff_async()).
""" """
try: try:
source = inspect.getsource(method) source = inspect.getsource(method)
@@ -68,14 +69,14 @@ def method_calls_crew(method: Any) -> bool:
return False return False
class CrewCallVisitor(ast.NodeVisitor): class CrewCallVisitor(ast.NodeVisitor):
"""AST visitor to detect .crew() method calls.""" """AST visitor to detect .crew(), .kickoff(), or .kickoff_async() method calls."""
def __init__(self): def __init__(self) -> None:
self.found = False self.found = False
def visit_Call(self, node): def visit_Call(self, node: ast.Call) -> None:
if isinstance(node.func, ast.Attribute): if isinstance(node.func, ast.Attribute):
if node.func.attr == "crew": if node.func.attr in ("crew", "kickoff", "kickoff_async"):
self.found = True self.found = True
self.generic_visit(node) self.generic_visit(node)
@@ -113,7 +114,7 @@ def add_nodes_to_network(
- Regular methods - Regular methods
""" """
def human_friendly_label(method_name): def human_friendly_label(method_name: str) -> str:
return method_name.replace("_", " ").title() return method_name.replace("_", " ").title()
node_style: ( node_style: (

View File

@@ -850,6 +850,31 @@ def test_flow_plotting():
assert isinstance(received_events[0].timestamp, datetime) assert isinstance(received_events[0].timestamp, datetime)
def test_method_calls_crew_detection():
"""Test that method_calls_crew() detects .crew(), .kickoff(), and .kickoff_async() calls."""
from crewai.flow.visualization_utils import method_calls_crew
from crewai import Agent
# Test with a real Flow that uses agent.kickoff()
class FlowWithAgentKickoff(Flow):
@start()
def run_agent(self):
agent = Agent(role="test", goal="test", backstory="test")
return agent.kickoff("query")
flow = FlowWithAgentKickoff()
assert method_calls_crew(flow.run_agent) is True
# Test with a Flow that has no crew/agent calls
class FlowWithoutCrewCalls(Flow):
@start()
def simple_method(self):
return "Just a regular method"
flow2 = FlowWithoutCrewCalls()
assert method_calls_crew(flow2.simple_method) is False
def test_multiple_routers_from_same_trigger(): def test_multiple_routers_from_same_trigger():
"""Test that multiple routers triggered by the same method all activate their listeners.""" """Test that multiple routers triggered by the same method all activate their listeners."""
execution_order = [] execution_order = []