mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 08:08:32 +00:00
fix: change flow viz del dir; method inspection
Some checks failed
Some checks failed
* chore: update flow viz deletion dir, add typing * tests: add flow viz tests to ensure lib dir is not deleted
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from pyvis.network import Network # type: ignore[import-untyped]
|
||||
|
||||
@@ -29,7 +29,7 @@ _printer = Printer()
|
||||
class FlowPlot:
|
||||
"""Handles the creation and rendering of flow visualization diagrams."""
|
||||
|
||||
def __init__(self, flow: Flow) -> None:
|
||||
def __init__(self, flow: Flow[Any]) -> None:
|
||||
"""
|
||||
Initialize FlowPlot with a flow object.
|
||||
|
||||
@@ -136,7 +136,7 @@ class FlowPlot:
|
||||
f"Unexpected error during flow visualization: {e!s}"
|
||||
) from e
|
||||
finally:
|
||||
self._cleanup_pyvis_lib()
|
||||
self._cleanup_pyvis_lib(filename)
|
||||
|
||||
def _generate_final_html(self, network_html: str) -> str:
|
||||
"""
|
||||
@@ -186,26 +186,33 @@ class FlowPlot:
|
||||
raise IOError(f"Failed to generate visualization HTML: {e!s}") from e
|
||||
|
||||
@staticmethod
|
||||
def _cleanup_pyvis_lib() -> None:
|
||||
def _cleanup_pyvis_lib(filename: str) -> None:
|
||||
"""
|
||||
Clean up the generated lib folder from pyvis.
|
||||
|
||||
This method safely removes the temporary lib directory created by pyvis
|
||||
during network visualization generation.
|
||||
during network visualization generation. The lib folder is created in the
|
||||
same directory as the output HTML file.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
filename : str
|
||||
The output filename (without .html extension) used for the visualization.
|
||||
"""
|
||||
try:
|
||||
lib_folder = safe_path_join("lib", root=os.getcwd())
|
||||
if os.path.exists(lib_folder) and os.path.isdir(lib_folder):
|
||||
import shutil
|
||||
import shutil
|
||||
|
||||
shutil.rmtree(lib_folder)
|
||||
except ValueError as e:
|
||||
_printer.print(f"Error validating lib folder path: {e}", color="red")
|
||||
output_dir = os.path.dirname(os.path.abspath(filename)) or os.getcwd()
|
||||
lib_folder = os.path.join(output_dir, "lib")
|
||||
if os.path.exists(lib_folder) and os.path.isdir(lib_folder):
|
||||
vis_js = os.path.join(lib_folder, "vis-network.min.js")
|
||||
if os.path.exists(vis_js):
|
||||
shutil.rmtree(lib_folder)
|
||||
except Exception as e:
|
||||
_printer.print(f"Error cleaning up lib folder: {e}", color="red")
|
||||
|
||||
|
||||
def plot_flow(flow: Flow, filename: str = "flow_plot") -> None:
|
||||
def plot_flow(flow: Flow[Any], filename: str = "flow_plot") -> None:
|
||||
"""
|
||||
Convenience function to create and save a flow visualization.
|
||||
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
"""HTML template processing and generation for flow visualization diagrams."""
|
||||
|
||||
import base64
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from crewai.flow.path_utils import validate_path_exists
|
||||
|
||||
@@ -7,7 +10,7 @@ from crewai.flow.path_utils import validate_path_exists
|
||||
class HTMLTemplateHandler:
|
||||
"""Handles HTML template processing and generation for flow visualization diagrams."""
|
||||
|
||||
def __init__(self, template_path, logo_path):
|
||||
def __init__(self, template_path: str, logo_path: str) -> None:
|
||||
"""
|
||||
Initialize HTMLTemplateHandler with validated template and logo paths.
|
||||
|
||||
@@ -29,23 +32,23 @@ class HTMLTemplateHandler:
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Invalid template or logo path: {e}") from e
|
||||
|
||||
def read_template(self):
|
||||
def read_template(self) -> str:
|
||||
"""Read and return the HTML template file contents."""
|
||||
with open(self.template_path, "r", encoding="utf-8") as f:
|
||||
return f.read()
|
||||
|
||||
def encode_logo(self):
|
||||
def encode_logo(self) -> str:
|
||||
"""Convert the logo SVG file to base64 encoded string."""
|
||||
with open(self.logo_path, "rb") as logo_file:
|
||||
logo_svg_data = logo_file.read()
|
||||
return base64.b64encode(logo_svg_data).decode("utf-8")
|
||||
|
||||
def extract_body_content(self, html):
|
||||
def extract_body_content(self, html: str) -> str:
|
||||
"""Extract and return content between body tags from HTML string."""
|
||||
match = re.search("<body.*?>(.*?)</body>", html, re.DOTALL)
|
||||
return match.group(1) if match else ""
|
||||
|
||||
def generate_legend_items_html(self, legend_items):
|
||||
def generate_legend_items_html(self, legend_items: list[dict[str, Any]]) -> str:
|
||||
"""Generate HTML markup for the legend items."""
|
||||
legend_items_html = ""
|
||||
for item in legend_items:
|
||||
@@ -73,7 +76,9 @@ class HTMLTemplateHandler:
|
||||
"""
|
||||
return legend_items_html
|
||||
|
||||
def generate_final_html(self, network_body, legend_items_html, title="Flow Plot"):
|
||||
def generate_final_html(
|
||||
self, network_body: str, legend_items_html: str, title: str = "Flow Plot"
|
||||
) -> str:
|
||||
"""Combine all components into final HTML document with network visualization."""
|
||||
html_template = self.read_template()
|
||||
logo_svg_base64 = self.encode_logo()
|
||||
|
||||
@@ -1,4 +1,23 @@
|
||||
def get_legend_items(colors):
|
||||
"""Legend generation for flow visualization diagrams."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from crewai.flow.config import FlowColors
|
||||
|
||||
|
||||
def get_legend_items(colors: FlowColors) -> list[dict[str, Any]]:
|
||||
"""Generate legend items based on flow colors.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
colors : FlowColors
|
||||
Dictionary containing color definitions for flow elements.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[dict[str, Any]]
|
||||
List of legend item dictionaries with labels and styling.
|
||||
"""
|
||||
return [
|
||||
{"label": "Start Method", "color": colors["start"]},
|
||||
{"label": "Method", "color": colors["method"]},
|
||||
@@ -24,7 +43,19 @@ def get_legend_items(colors):
|
||||
]
|
||||
|
||||
|
||||
def generate_legend_items_html(legend_items):
|
||||
def generate_legend_items_html(legend_items: list[dict[str, Any]]) -> str:
|
||||
"""Generate HTML markup for legend items.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
legend_items : list[dict[str, Any]]
|
||||
List of legend item dictionaries containing labels and styling.
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
HTML string containing formatted legend items.
|
||||
"""
|
||||
legend_items_html = ""
|
||||
for item in legend_items:
|
||||
if "border" in item:
|
||||
|
||||
@@ -36,28 +36,29 @@ from crewai.flow.utils import (
|
||||
from crewai.utilities.printer import Printer
|
||||
|
||||
|
||||
|
||||
_printer = Printer()
|
||||
|
||||
|
||||
def method_calls_crew(method: Any) -> bool:
|
||||
"""
|
||||
Check if the method contains a call to `.crew()`.
|
||||
Check if the method contains a call to `.crew()`, `.kickoff()`, or `.kickoff_async()`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
method : Any
|
||||
The method to analyze for crew() calls.
|
||||
The method to analyze for crew or agent execution calls.
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool
|
||||
True if the method calls .crew(), False otherwise.
|
||||
True if the method calls .crew(), .kickoff(), or .kickoff_async(), False otherwise.
|
||||
|
||||
Notes
|
||||
-----
|
||||
Uses AST analysis to detect method calls, specifically looking for
|
||||
attribute access of 'crew'.
|
||||
attribute access of 'crew', 'kickoff', or 'kickoff_async'.
|
||||
This includes both traditional Crew execution (.crew()) and Agent/LiteAgent
|
||||
execution (.kickoff() or .kickoff_async()).
|
||||
"""
|
||||
try:
|
||||
source = inspect.getsource(method)
|
||||
@@ -68,14 +69,14 @@ def method_calls_crew(method: Any) -> bool:
|
||||
return False
|
||||
|
||||
class CrewCallVisitor(ast.NodeVisitor):
|
||||
"""AST visitor to detect .crew() method calls."""
|
||||
"""AST visitor to detect .crew(), .kickoff(), or .kickoff_async() method calls."""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self.found = False
|
||||
|
||||
def visit_Call(self, node):
|
||||
def visit_Call(self, node: ast.Call) -> None:
|
||||
if isinstance(node.func, ast.Attribute):
|
||||
if node.func.attr == "crew":
|
||||
if node.func.attr in ("crew", "kickoff", "kickoff_async"):
|
||||
self.found = True
|
||||
self.generic_visit(node)
|
||||
|
||||
@@ -113,7 +114,7 @@ def add_nodes_to_network(
|
||||
- Regular methods
|
||||
"""
|
||||
|
||||
def human_friendly_label(method_name):
|
||||
def human_friendly_label(method_name: str) -> str:
|
||||
return method_name.replace("_", " ").title()
|
||||
|
||||
node_style: (
|
||||
|
||||
@@ -850,6 +850,31 @@ def test_flow_plotting():
|
||||
assert isinstance(received_events[0].timestamp, datetime)
|
||||
|
||||
|
||||
def test_method_calls_crew_detection():
|
||||
"""Test that method_calls_crew() detects .crew(), .kickoff(), and .kickoff_async() calls."""
|
||||
from crewai.flow.visualization_utils import method_calls_crew
|
||||
from crewai import Agent
|
||||
|
||||
# Test with a real Flow that uses agent.kickoff()
|
||||
class FlowWithAgentKickoff(Flow):
|
||||
@start()
|
||||
def run_agent(self):
|
||||
agent = Agent(role="test", goal="test", backstory="test")
|
||||
return agent.kickoff("query")
|
||||
|
||||
flow = FlowWithAgentKickoff()
|
||||
assert method_calls_crew(flow.run_agent) is True
|
||||
|
||||
# Test with a Flow that has no crew/agent calls
|
||||
class FlowWithoutCrewCalls(Flow):
|
||||
@start()
|
||||
def simple_method(self):
|
||||
return "Just a regular method"
|
||||
|
||||
flow2 = FlowWithoutCrewCalls()
|
||||
assert method_calls_crew(flow2.simple_method) is False
|
||||
|
||||
|
||||
def test_multiple_routers_from_same_trigger():
|
||||
"""Test that multiple routers triggered by the same method all activate their listeners."""
|
||||
execution_order = []
|
||||
|
||||
Reference in New Issue
Block a user