mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 00:28:31 +00:00
fix: change flow viz del dir; method inspection
Some checks failed
Some checks failed
* chore: update flow viz deletion dir, add typing * tests: add flow viz tests to ensure lib dir is not deleted
This commit is contained in:
@@ -2,7 +2,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
from pyvis.network import Network # type: ignore[import-untyped]
|
from pyvis.network import Network # type: ignore[import-untyped]
|
||||||
|
|
||||||
@@ -29,7 +29,7 @@ _printer = Printer()
|
|||||||
class FlowPlot:
|
class FlowPlot:
|
||||||
"""Handles the creation and rendering of flow visualization diagrams."""
|
"""Handles the creation and rendering of flow visualization diagrams."""
|
||||||
|
|
||||||
def __init__(self, flow: Flow) -> None:
|
def __init__(self, flow: Flow[Any]) -> None:
|
||||||
"""
|
"""
|
||||||
Initialize FlowPlot with a flow object.
|
Initialize FlowPlot with a flow object.
|
||||||
|
|
||||||
@@ -136,7 +136,7 @@ class FlowPlot:
|
|||||||
f"Unexpected error during flow visualization: {e!s}"
|
f"Unexpected error during flow visualization: {e!s}"
|
||||||
) from e
|
) from e
|
||||||
finally:
|
finally:
|
||||||
self._cleanup_pyvis_lib()
|
self._cleanup_pyvis_lib(filename)
|
||||||
|
|
||||||
def _generate_final_html(self, network_html: str) -> str:
|
def _generate_final_html(self, network_html: str) -> str:
|
||||||
"""
|
"""
|
||||||
@@ -186,26 +186,33 @@ class FlowPlot:
|
|||||||
raise IOError(f"Failed to generate visualization HTML: {e!s}") from e
|
raise IOError(f"Failed to generate visualization HTML: {e!s}") from e
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _cleanup_pyvis_lib() -> None:
|
def _cleanup_pyvis_lib(filename: str) -> None:
|
||||||
"""
|
"""
|
||||||
Clean up the generated lib folder from pyvis.
|
Clean up the generated lib folder from pyvis.
|
||||||
|
|
||||||
This method safely removes the temporary lib directory created by pyvis
|
This method safely removes the temporary lib directory created by pyvis
|
||||||
during network visualization generation.
|
during network visualization generation. The lib folder is created in the
|
||||||
|
same directory as the output HTML file.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
filename : str
|
||||||
|
The output filename (without .html extension) used for the visualization.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
lib_folder = safe_path_join("lib", root=os.getcwd())
|
import shutil
|
||||||
if os.path.exists(lib_folder) and os.path.isdir(lib_folder):
|
|
||||||
import shutil
|
|
||||||
|
|
||||||
shutil.rmtree(lib_folder)
|
output_dir = os.path.dirname(os.path.abspath(filename)) or os.getcwd()
|
||||||
except ValueError as e:
|
lib_folder = os.path.join(output_dir, "lib")
|
||||||
_printer.print(f"Error validating lib folder path: {e}", color="red")
|
if os.path.exists(lib_folder) and os.path.isdir(lib_folder):
|
||||||
|
vis_js = os.path.join(lib_folder, "vis-network.min.js")
|
||||||
|
if os.path.exists(vis_js):
|
||||||
|
shutil.rmtree(lib_folder)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
_printer.print(f"Error cleaning up lib folder: {e}", color="red")
|
_printer.print(f"Error cleaning up lib folder: {e}", color="red")
|
||||||
|
|
||||||
|
|
||||||
def plot_flow(flow: Flow, filename: str = "flow_plot") -> None:
|
def plot_flow(flow: Flow[Any], filename: str = "flow_plot") -> None:
|
||||||
"""
|
"""
|
||||||
Convenience function to create and save a flow visualization.
|
Convenience function to create and save a flow visualization.
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,8 @@
|
|||||||
|
"""HTML template processing and generation for flow visualization diagrams."""
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
import re
|
import re
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from crewai.flow.path_utils import validate_path_exists
|
from crewai.flow.path_utils import validate_path_exists
|
||||||
|
|
||||||
@@ -7,7 +10,7 @@ from crewai.flow.path_utils import validate_path_exists
|
|||||||
class HTMLTemplateHandler:
|
class HTMLTemplateHandler:
|
||||||
"""Handles HTML template processing and generation for flow visualization diagrams."""
|
"""Handles HTML template processing and generation for flow visualization diagrams."""
|
||||||
|
|
||||||
def __init__(self, template_path, logo_path):
|
def __init__(self, template_path: str, logo_path: str) -> None:
|
||||||
"""
|
"""
|
||||||
Initialize HTMLTemplateHandler with validated template and logo paths.
|
Initialize HTMLTemplateHandler with validated template and logo paths.
|
||||||
|
|
||||||
@@ -29,23 +32,23 @@ class HTMLTemplateHandler:
|
|||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise ValueError(f"Invalid template or logo path: {e}") from e
|
raise ValueError(f"Invalid template or logo path: {e}") from e
|
||||||
|
|
||||||
def read_template(self):
|
def read_template(self) -> str:
|
||||||
"""Read and return the HTML template file contents."""
|
"""Read and return the HTML template file contents."""
|
||||||
with open(self.template_path, "r", encoding="utf-8") as f:
|
with open(self.template_path, "r", encoding="utf-8") as f:
|
||||||
return f.read()
|
return f.read()
|
||||||
|
|
||||||
def encode_logo(self):
|
def encode_logo(self) -> str:
|
||||||
"""Convert the logo SVG file to base64 encoded string."""
|
"""Convert the logo SVG file to base64 encoded string."""
|
||||||
with open(self.logo_path, "rb") as logo_file:
|
with open(self.logo_path, "rb") as logo_file:
|
||||||
logo_svg_data = logo_file.read()
|
logo_svg_data = logo_file.read()
|
||||||
return base64.b64encode(logo_svg_data).decode("utf-8")
|
return base64.b64encode(logo_svg_data).decode("utf-8")
|
||||||
|
|
||||||
def extract_body_content(self, html):
|
def extract_body_content(self, html: str) -> str:
|
||||||
"""Extract and return content between body tags from HTML string."""
|
"""Extract and return content between body tags from HTML string."""
|
||||||
match = re.search("<body.*?>(.*?)</body>", html, re.DOTALL)
|
match = re.search("<body.*?>(.*?)</body>", html, re.DOTALL)
|
||||||
return match.group(1) if match else ""
|
return match.group(1) if match else ""
|
||||||
|
|
||||||
def generate_legend_items_html(self, legend_items):
|
def generate_legend_items_html(self, legend_items: list[dict[str, Any]]) -> str:
|
||||||
"""Generate HTML markup for the legend items."""
|
"""Generate HTML markup for the legend items."""
|
||||||
legend_items_html = ""
|
legend_items_html = ""
|
||||||
for item in legend_items:
|
for item in legend_items:
|
||||||
@@ -73,7 +76,9 @@ class HTMLTemplateHandler:
|
|||||||
"""
|
"""
|
||||||
return legend_items_html
|
return legend_items_html
|
||||||
|
|
||||||
def generate_final_html(self, network_body, legend_items_html, title="Flow Plot"):
|
def generate_final_html(
|
||||||
|
self, network_body: str, legend_items_html: str, title: str = "Flow Plot"
|
||||||
|
) -> str:
|
||||||
"""Combine all components into final HTML document with network visualization."""
|
"""Combine all components into final HTML document with network visualization."""
|
||||||
html_template = self.read_template()
|
html_template = self.read_template()
|
||||||
logo_svg_base64 = self.encode_logo()
|
logo_svg_base64 = self.encode_logo()
|
||||||
|
|||||||
@@ -1,4 +1,23 @@
|
|||||||
def get_legend_items(colors):
|
"""Legend generation for flow visualization diagrams."""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from crewai.flow.config import FlowColors
|
||||||
|
|
||||||
|
|
||||||
|
def get_legend_items(colors: FlowColors) -> list[dict[str, Any]]:
|
||||||
|
"""Generate legend items based on flow colors.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
colors : FlowColors
|
||||||
|
Dictionary containing color definitions for flow elements.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
list[dict[str, Any]]
|
||||||
|
List of legend item dictionaries with labels and styling.
|
||||||
|
"""
|
||||||
return [
|
return [
|
||||||
{"label": "Start Method", "color": colors["start"]},
|
{"label": "Start Method", "color": colors["start"]},
|
||||||
{"label": "Method", "color": colors["method"]},
|
{"label": "Method", "color": colors["method"]},
|
||||||
@@ -24,7 +43,19 @@ def get_legend_items(colors):
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def generate_legend_items_html(legend_items):
|
def generate_legend_items_html(legend_items: list[dict[str, Any]]) -> str:
|
||||||
|
"""Generate HTML markup for legend items.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
legend_items : list[dict[str, Any]]
|
||||||
|
List of legend item dictionaries containing labels and styling.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
str
|
||||||
|
HTML string containing formatted legend items.
|
||||||
|
"""
|
||||||
legend_items_html = ""
|
legend_items_html = ""
|
||||||
for item in legend_items:
|
for item in legend_items:
|
||||||
if "border" in item:
|
if "border" in item:
|
||||||
|
|||||||
@@ -36,28 +36,29 @@ from crewai.flow.utils import (
|
|||||||
from crewai.utilities.printer import Printer
|
from crewai.utilities.printer import Printer
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
_printer = Printer()
|
_printer = Printer()
|
||||||
|
|
||||||
|
|
||||||
def method_calls_crew(method: Any) -> bool:
|
def method_calls_crew(method: Any) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if the method contains a call to `.crew()`.
|
Check if the method contains a call to `.crew()`, `.kickoff()`, or `.kickoff_async()`.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
method : Any
|
method : Any
|
||||||
The method to analyze for crew() calls.
|
The method to analyze for crew or agent execution calls.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
bool
|
bool
|
||||||
True if the method calls .crew(), False otherwise.
|
True if the method calls .crew(), .kickoff(), or .kickoff_async(), False otherwise.
|
||||||
|
|
||||||
Notes
|
Notes
|
||||||
-----
|
-----
|
||||||
Uses AST analysis to detect method calls, specifically looking for
|
Uses AST analysis to detect method calls, specifically looking for
|
||||||
attribute access of 'crew'.
|
attribute access of 'crew', 'kickoff', or 'kickoff_async'.
|
||||||
|
This includes both traditional Crew execution (.crew()) and Agent/LiteAgent
|
||||||
|
execution (.kickoff() or .kickoff_async()).
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
source = inspect.getsource(method)
|
source = inspect.getsource(method)
|
||||||
@@ -68,14 +69,14 @@ def method_calls_crew(method: Any) -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
class CrewCallVisitor(ast.NodeVisitor):
|
class CrewCallVisitor(ast.NodeVisitor):
|
||||||
"""AST visitor to detect .crew() method calls."""
|
"""AST visitor to detect .crew(), .kickoff(), or .kickoff_async() method calls."""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
self.found = False
|
self.found = False
|
||||||
|
|
||||||
def visit_Call(self, node):
|
def visit_Call(self, node: ast.Call) -> None:
|
||||||
if isinstance(node.func, ast.Attribute):
|
if isinstance(node.func, ast.Attribute):
|
||||||
if node.func.attr == "crew":
|
if node.func.attr in ("crew", "kickoff", "kickoff_async"):
|
||||||
self.found = True
|
self.found = True
|
||||||
self.generic_visit(node)
|
self.generic_visit(node)
|
||||||
|
|
||||||
@@ -113,7 +114,7 @@ def add_nodes_to_network(
|
|||||||
- Regular methods
|
- Regular methods
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def human_friendly_label(method_name):
|
def human_friendly_label(method_name: str) -> str:
|
||||||
return method_name.replace("_", " ").title()
|
return method_name.replace("_", " ").title()
|
||||||
|
|
||||||
node_style: (
|
node_style: (
|
||||||
|
|||||||
@@ -850,6 +850,31 @@ def test_flow_plotting():
|
|||||||
assert isinstance(received_events[0].timestamp, datetime)
|
assert isinstance(received_events[0].timestamp, datetime)
|
||||||
|
|
||||||
|
|
||||||
|
def test_method_calls_crew_detection():
|
||||||
|
"""Test that method_calls_crew() detects .crew(), .kickoff(), and .kickoff_async() calls."""
|
||||||
|
from crewai.flow.visualization_utils import method_calls_crew
|
||||||
|
from crewai import Agent
|
||||||
|
|
||||||
|
# Test with a real Flow that uses agent.kickoff()
|
||||||
|
class FlowWithAgentKickoff(Flow):
|
||||||
|
@start()
|
||||||
|
def run_agent(self):
|
||||||
|
agent = Agent(role="test", goal="test", backstory="test")
|
||||||
|
return agent.kickoff("query")
|
||||||
|
|
||||||
|
flow = FlowWithAgentKickoff()
|
||||||
|
assert method_calls_crew(flow.run_agent) is True
|
||||||
|
|
||||||
|
# Test with a Flow that has no crew/agent calls
|
||||||
|
class FlowWithoutCrewCalls(Flow):
|
||||||
|
@start()
|
||||||
|
def simple_method(self):
|
||||||
|
return "Just a regular method"
|
||||||
|
|
||||||
|
flow2 = FlowWithoutCrewCalls()
|
||||||
|
assert method_calls_crew(flow2.simple_method) is False
|
||||||
|
|
||||||
|
|
||||||
def test_multiple_routers_from_same_trigger():
|
def test_multiple_routers_from_same_trigger():
|
||||||
"""Test that multiple routers triggered by the same method all activate their listeners."""
|
"""Test that multiple routers triggered by the same method all activate their listeners."""
|
||||||
execution_order = []
|
execution_order = []
|
||||||
|
|||||||
Reference in New Issue
Block a user