fix: change flow viz del dir; method inspection
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Notify Downstream / notify-downstream (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled

* chore: update flow viz deletion dir, add typing
* tests: add flow viz tests to ensure lib dir is not deleted
This commit is contained in:
Greyson LaLonde
2025-10-22 19:32:38 -04:00
committed by GitHub
parent 4371cf5690
commit 9728388ea7
5 changed files with 99 additions and 30 deletions

View File

@@ -2,7 +2,7 @@
from __future__ import annotations
import os
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any
from pyvis.network import Network # type: ignore[import-untyped]
@@ -29,7 +29,7 @@ _printer = Printer()
class FlowPlot:
"""Handles the creation and rendering of flow visualization diagrams."""
def __init__(self, flow: Flow) -> None:
def __init__(self, flow: Flow[Any]) -> None:
"""
Initialize FlowPlot with a flow object.
@@ -136,7 +136,7 @@ class FlowPlot:
f"Unexpected error during flow visualization: {e!s}"
) from e
finally:
self._cleanup_pyvis_lib()
self._cleanup_pyvis_lib(filename)
def _generate_final_html(self, network_html: str) -> str:
"""
@@ -186,26 +186,33 @@ class FlowPlot:
raise IOError(f"Failed to generate visualization HTML: {e!s}") from e
@staticmethod
def _cleanup_pyvis_lib() -> None:
def _cleanup_pyvis_lib(filename: str) -> None:
"""
Clean up the generated lib folder from pyvis.
This method safely removes the temporary lib directory created by pyvis
during network visualization generation.
during network visualization generation. The lib folder is created in the
same directory as the output HTML file.
Parameters
----------
filename : str
The output filename (without .html extension) used for the visualization.
"""
try:
lib_folder = safe_path_join("lib", root=os.getcwd())
if os.path.exists(lib_folder) and os.path.isdir(lib_folder):
import shutil
import shutil
shutil.rmtree(lib_folder)
except ValueError as e:
_printer.print(f"Error validating lib folder path: {e}", color="red")
output_dir = os.path.dirname(os.path.abspath(filename)) or os.getcwd()
lib_folder = os.path.join(output_dir, "lib")
if os.path.exists(lib_folder) and os.path.isdir(lib_folder):
vis_js = os.path.join(lib_folder, "vis-network.min.js")
if os.path.exists(vis_js):
shutil.rmtree(lib_folder)
except Exception as e:
_printer.print(f"Error cleaning up lib folder: {e}", color="red")
def plot_flow(flow: Flow, filename: str = "flow_plot") -> None:
def plot_flow(flow: Flow[Any], filename: str = "flow_plot") -> None:
"""
Convenience function to create and save a flow visualization.

View File

@@ -1,5 +1,8 @@
"""HTML template processing and generation for flow visualization diagrams."""
import base64
import re
from typing import Any
from crewai.flow.path_utils import validate_path_exists
@@ -7,7 +10,7 @@ from crewai.flow.path_utils import validate_path_exists
class HTMLTemplateHandler:
"""Handles HTML template processing and generation for flow visualization diagrams."""
def __init__(self, template_path, logo_path):
def __init__(self, template_path: str, logo_path: str) -> None:
"""
Initialize HTMLTemplateHandler with validated template and logo paths.
@@ -29,23 +32,23 @@ class HTMLTemplateHandler:
except ValueError as e:
raise ValueError(f"Invalid template or logo path: {e}") from e
def read_template(self):
def read_template(self) -> str:
"""Read and return the HTML template file contents."""
with open(self.template_path, "r", encoding="utf-8") as f:
return f.read()
def encode_logo(self):
def encode_logo(self) -> str:
"""Convert the logo SVG file to base64 encoded string."""
with open(self.logo_path, "rb") as logo_file:
logo_svg_data = logo_file.read()
return base64.b64encode(logo_svg_data).decode("utf-8")
def extract_body_content(self, html):
def extract_body_content(self, html: str) -> str:
"""Extract and return content between body tags from HTML string."""
match = re.search("<body.*?>(.*?)</body>", html, re.DOTALL)
return match.group(1) if match else ""
def generate_legend_items_html(self, legend_items):
def generate_legend_items_html(self, legend_items: list[dict[str, Any]]) -> str:
"""Generate HTML markup for the legend items."""
legend_items_html = ""
for item in legend_items:
@@ -73,7 +76,9 @@ class HTMLTemplateHandler:
"""
return legend_items_html
def generate_final_html(self, network_body, legend_items_html, title="Flow Plot"):
def generate_final_html(
self, network_body: str, legend_items_html: str, title: str = "Flow Plot"
) -> str:
"""Combine all components into final HTML document with network visualization."""
html_template = self.read_template()
logo_svg_base64 = self.encode_logo()

View File

@@ -1,4 +1,23 @@
def get_legend_items(colors):
"""Legend generation for flow visualization diagrams."""
from typing import Any
from crewai.flow.config import FlowColors
def get_legend_items(colors: FlowColors) -> list[dict[str, Any]]:
"""Generate legend items based on flow colors.
Parameters
----------
colors : FlowColors
Dictionary containing color definitions for flow elements.
Returns
-------
list[dict[str, Any]]
List of legend item dictionaries with labels and styling.
"""
return [
{"label": "Start Method", "color": colors["start"]},
{"label": "Method", "color": colors["method"]},
@@ -24,7 +43,19 @@ def get_legend_items(colors):
]
def generate_legend_items_html(legend_items):
def generate_legend_items_html(legend_items: list[dict[str, Any]]) -> str:
"""Generate HTML markup for legend items.
Parameters
----------
legend_items : list[dict[str, Any]]
List of legend item dictionaries containing labels and styling.
Returns
-------
str
HTML string containing formatted legend items.
"""
legend_items_html = ""
for item in legend_items:
if "border" in item:

View File

@@ -36,28 +36,29 @@ from crewai.flow.utils import (
from crewai.utilities.printer import Printer
_printer = Printer()
def method_calls_crew(method: Any) -> bool:
"""
Check if the method contains a call to `.crew()`.
Check if the method contains a call to `.crew()`, `.kickoff()`, or `.kickoff_async()`.
Parameters
----------
method : Any
The method to analyze for crew() calls.
The method to analyze for crew or agent execution calls.
Returns
-------
bool
True if the method calls .crew(), False otherwise.
True if the method calls .crew(), .kickoff(), or .kickoff_async(), False otherwise.
Notes
-----
Uses AST analysis to detect method calls, specifically looking for
attribute access of 'crew'.
attribute access of 'crew', 'kickoff', or 'kickoff_async'.
This includes both traditional Crew execution (.crew()) and Agent/LiteAgent
execution (.kickoff() or .kickoff_async()).
"""
try:
source = inspect.getsource(method)
@@ -68,14 +69,14 @@ def method_calls_crew(method: Any) -> bool:
return False
class CrewCallVisitor(ast.NodeVisitor):
"""AST visitor to detect .crew() method calls."""
"""AST visitor to detect .crew(), .kickoff(), or .kickoff_async() method calls."""
def __init__(self):
def __init__(self) -> None:
self.found = False
def visit_Call(self, node):
def visit_Call(self, node: ast.Call) -> None:
if isinstance(node.func, ast.Attribute):
if node.func.attr == "crew":
if node.func.attr in ("crew", "kickoff", "kickoff_async"):
self.found = True
self.generic_visit(node)
@@ -113,7 +114,7 @@ def add_nodes_to_network(
- Regular methods
"""
def human_friendly_label(method_name):
def human_friendly_label(method_name: str) -> str:
return method_name.replace("_", " ").title()
node_style: (

View File

@@ -850,6 +850,31 @@ def test_flow_plotting():
assert isinstance(received_events[0].timestamp, datetime)
def test_method_calls_crew_detection():
"""Test that method_calls_crew() detects .crew(), .kickoff(), and .kickoff_async() calls."""
from crewai.flow.visualization_utils import method_calls_crew
from crewai import Agent
# Test with a real Flow that uses agent.kickoff()
class FlowWithAgentKickoff(Flow):
@start()
def run_agent(self):
agent = Agent(role="test", goal="test", backstory="test")
return agent.kickoff("query")
flow = FlowWithAgentKickoff()
assert method_calls_crew(flow.run_agent) is True
# Test with a Flow that has no crew/agent calls
class FlowWithoutCrewCalls(Flow):
@start()
def simple_method(self):
return "Just a regular method"
flow2 = FlowWithoutCrewCalls()
assert method_calls_crew(flow2.simple_method) is False
def test_multiple_routers_from_same_trigger():
"""Test that multiple routers triggered by the same method all activate their listeners."""
execution_order = []