feat: add docstring

This commit is contained in:
Marco Vinciguerra
2024-12-30 16:44:11 +01:00
parent 73f328860b
commit 8c6883e5ee
6 changed files with 32 additions and 14 deletions

View File

@@ -31,6 +31,7 @@ T = TypeVar("T", bound=Union[BaseModel, Dict[str, Any]])
def start(condition=None):
"""Marks a method as a flow starting point, optionally triggered by other methods."""
def decorator(func):
func.__is_start_method__ = True
if condition is not None:
@@ -57,6 +58,7 @@ def start(condition=None):
def listen(condition):
"""Marks a method to execute when specified conditions/methods complete."""
def decorator(func):
if isinstance(condition, str):
func.__trigger_methods__ = [condition]
@@ -81,9 +83,9 @@ def listen(condition):
def router(condition):
"""Marks a method as a router to direct flow based on its return value."""
def decorator(func):
func.__is_router__ = True
# Handle conditions like listen/start
if isinstance(condition, str):
func.__trigger_methods__ = [condition]
func.__condition_type__ = "OR"
@@ -107,6 +109,7 @@ def router(condition):
def or_(*conditions):
"""Combines multiple conditions with OR logic for flow control."""
methods = []
for condition in conditions:
if isinstance(condition, dict) and "methods" in condition:
@@ -121,6 +124,7 @@ def or_(*conditions):
def and_(*conditions):
"""Combines multiple conditions with AND logic for flow control."""
methods = []
for condition in conditions:
if isinstance(condition, dict) and "methods" in condition:

View File

@@ -16,12 +16,16 @@ from crewai.flow.visualization_utils import (
class FlowPlot:
"""Handles the creation and rendering of flow visualization diagrams."""
def __init__(self, flow):
"""Initialize flow plot with flow instance and styling configuration."""
self.flow = flow
self.colors = COLORS
self.node_styles = NODE_STYLES
def plot(self, filename):
"""Generate and save interactive flow visualization to HTML file."""
net = Network(
directed=True,
height="750px",
@@ -46,22 +50,14 @@ class FlowPlot:
"""
)
# Calculate levels for nodes
node_levels = calculate_node_levels(self.flow)
# Compute positions
node_positions = compute_positions(self.flow, node_levels)
# Add nodes to the network
add_nodes_to_network(net, self.flow, node_positions, self.node_styles)
# Add edges to the network
add_edges(net, self.flow, node_positions, self.colors)
network_html = net.generate_html()
final_html_content = self._generate_final_html(network_html)
# Save the final HTML content to the file
with open(f"{filename}.html", "w", encoding="utf-8") as f:
f.write(final_html_content)
print(f"Plot saved as {filename}.html")
@@ -69,7 +65,7 @@ class FlowPlot:
self._cleanup_pyvis_lib()
def _generate_final_html(self, network_html):
# Extract just the body content from the generated HTML
"""Generate final HTML content with network visualization and legend."""
current_dir = os.path.dirname(__file__)
template_path = os.path.join(
current_dir, "assets", "crewai_flow_visual_template.html"
@@ -79,7 +75,6 @@ class FlowPlot:
html_handler = HTMLTemplateHandler(template_path, logo_path)
network_body = html_handler.extract_body_content(network_html)
# Generate the legend items HTML
legend_items = get_legend_items(self.colors)
legend_items_html = generate_legend_items_html(legend_items)
final_html_content = html_handler.generate_final_html(
@@ -88,17 +83,17 @@ class FlowPlot:
return final_html_content
def _cleanup_pyvis_lib(self):
# Clean up the generated lib folder
"""Clean up temporary files generated by pyvis library."""
lib_folder = os.path.join(os.getcwd(), "lib")
try:
if os.path.exists(lib_folder) and os.path.isdir(lib_folder):
import shutil
shutil.rmtree(lib_folder)
except Exception as e:
print(f"Error cleaning up {lib_folder}: {e}")
def plot_flow(flow, filename="flow_plot"):
"""Create and save a visualization of the given flow."""
visualizer = FlowPlot(flow)
visualizer.plot(filename)

View File

@@ -3,24 +3,31 @@ import re
class HTMLTemplateHandler:
"""Handles HTML template processing and generation for flow visualization diagrams."""
def __init__(self, template_path, logo_path):
"""Initialize template handler with template and logo file paths."""
self.template_path = template_path
self.logo_path = logo_path
def read_template(self):
"""Read and return the HTML template file contents."""
with open(self.template_path, "r", encoding="utf-8") as f:
return f.read()
def encode_logo(self):
"""Convert the logo SVG file to base64 encoded string."""
with open(self.logo_path, "rb") as logo_file:
logo_svg_data = logo_file.read()
return base64.b64encode(logo_svg_data).decode("utf-8")
def extract_body_content(self, html):
"""Extract and return content between body tags from HTML string."""
match = re.search("<body.*?>(.*?)</body>", html, re.DOTALL)
return match.group(1) if match else ""
def generate_legend_items_html(self, legend_items):
"""Generate HTML markup for the legend items."""
legend_items_html = ""
for item in legend_items:
if "border" in item:
@@ -48,6 +55,7 @@ class HTMLTemplateHandler:
return legend_items_html
def generate_final_html(self, network_body, legend_items_html, title="Flow Plot"):
"""Combine all components into final HTML document with network visualization."""
html_template = self.read_template()
logo_svg_base64 = self.encode_logo()

View File

@@ -1,3 +1,4 @@
def get_legend_items(colors):
return [
{"label": "Start Method", "color": colors["start"]},

View File

@@ -1,3 +1,9 @@
"""Utility functions for flow execution and visualization.
Provides helper functions for analyzing flow structure, calculating
node positions, and extracting return values from methods.
"""
import ast
import inspect
import textwrap

View File

@@ -10,7 +10,7 @@ from .utils import (
def method_calls_crew(method):
"""Check if the method calls `.crew()`."""
"""Check if the method contains a .crew() call."""
try:
source = inspect.getsource(method)
source = inspect.cleandoc(source)
@@ -20,6 +20,7 @@ def method_calls_crew(method):
return False
class CrewCallVisitor(ast.NodeVisitor):
"""AST visitor to detect .crew() method calls."""
def __init__(self):
self.found = False
@@ -35,6 +36,7 @@ def method_calls_crew(method):
def add_nodes_to_network(net, flow, node_positions, node_styles):
"""Add nodes to the network visualization with appropriate styling."""
def human_friendly_label(method_name):
return method_name.replace("_", " ").title()
@@ -74,6 +76,7 @@ def add_nodes_to_network(net, flow, node_positions, node_styles):
def compute_positions(flow, node_levels, y_spacing=150, x_spacing=150):
"""Calculate x,y coordinates for each node in the flow diagram."""
level_nodes = {}
node_positions = {}
@@ -91,6 +94,7 @@ def compute_positions(flow, node_levels, y_spacing=150, x_spacing=150):
def add_edges(net, flow, node_positions, colors):
"""Add edges between nodes with appropriate styling and routing."""
ancestors = build_ancestor_dict(flow)
parent_children = build_parent_children_dict(flow)