mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 08:08:32 +00:00
feat: add docstring
This commit is contained in:
@@ -31,6 +31,7 @@ T = TypeVar("T", bound=Union[BaseModel, Dict[str, Any]])
|
||||
|
||||
|
||||
def start(condition=None):
|
||||
"""Marks a method as a flow starting point, optionally triggered by other methods."""
|
||||
def decorator(func):
|
||||
func.__is_start_method__ = True
|
||||
if condition is not None:
|
||||
@@ -57,6 +58,7 @@ def start(condition=None):
|
||||
|
||||
|
||||
def listen(condition):
|
||||
"""Marks a method to execute when specified conditions/methods complete."""
|
||||
def decorator(func):
|
||||
if isinstance(condition, str):
|
||||
func.__trigger_methods__ = [condition]
|
||||
@@ -81,9 +83,9 @@ def listen(condition):
|
||||
|
||||
|
||||
def router(condition):
|
||||
"""Marks a method as a router to direct flow based on its return value."""
|
||||
def decorator(func):
|
||||
func.__is_router__ = True
|
||||
# Handle conditions like listen/start
|
||||
if isinstance(condition, str):
|
||||
func.__trigger_methods__ = [condition]
|
||||
func.__condition_type__ = "OR"
|
||||
@@ -107,6 +109,7 @@ def router(condition):
|
||||
|
||||
|
||||
def or_(*conditions):
|
||||
"""Combines multiple conditions with OR logic for flow control."""
|
||||
methods = []
|
||||
for condition in conditions:
|
||||
if isinstance(condition, dict) and "methods" in condition:
|
||||
@@ -121,6 +124,7 @@ def or_(*conditions):
|
||||
|
||||
|
||||
def and_(*conditions):
|
||||
"""Combines multiple conditions with AND logic for flow control."""
|
||||
methods = []
|
||||
for condition in conditions:
|
||||
if isinstance(condition, dict) and "methods" in condition:
|
||||
|
||||
@@ -16,12 +16,16 @@ from crewai.flow.visualization_utils import (
|
||||
|
||||
|
||||
class FlowPlot:
|
||||
"""Handles the creation and rendering of flow visualization diagrams."""
|
||||
|
||||
def __init__(self, flow):
|
||||
"""Initialize flow plot with flow instance and styling configuration."""
|
||||
self.flow = flow
|
||||
self.colors = COLORS
|
||||
self.node_styles = NODE_STYLES
|
||||
|
||||
def plot(self, filename):
|
||||
"""Generate and save interactive flow visualization to HTML file."""
|
||||
net = Network(
|
||||
directed=True,
|
||||
height="750px",
|
||||
@@ -46,22 +50,14 @@ class FlowPlot:
|
||||
"""
|
||||
)
|
||||
|
||||
# Calculate levels for nodes
|
||||
node_levels = calculate_node_levels(self.flow)
|
||||
|
||||
# Compute positions
|
||||
node_positions = compute_positions(self.flow, node_levels)
|
||||
|
||||
# Add nodes to the network
|
||||
add_nodes_to_network(net, self.flow, node_positions, self.node_styles)
|
||||
|
||||
# Add edges to the network
|
||||
add_edges(net, self.flow, node_positions, self.colors)
|
||||
|
||||
network_html = net.generate_html()
|
||||
final_html_content = self._generate_final_html(network_html)
|
||||
|
||||
# Save the final HTML content to the file
|
||||
with open(f"{filename}.html", "w", encoding="utf-8") as f:
|
||||
f.write(final_html_content)
|
||||
print(f"Plot saved as {filename}.html")
|
||||
@@ -69,7 +65,7 @@ class FlowPlot:
|
||||
self._cleanup_pyvis_lib()
|
||||
|
||||
def _generate_final_html(self, network_html):
|
||||
# Extract just the body content from the generated HTML
|
||||
"""Generate final HTML content with network visualization and legend."""
|
||||
current_dir = os.path.dirname(__file__)
|
||||
template_path = os.path.join(
|
||||
current_dir, "assets", "crewai_flow_visual_template.html"
|
||||
@@ -79,7 +75,6 @@ class FlowPlot:
|
||||
html_handler = HTMLTemplateHandler(template_path, logo_path)
|
||||
network_body = html_handler.extract_body_content(network_html)
|
||||
|
||||
# Generate the legend items HTML
|
||||
legend_items = get_legend_items(self.colors)
|
||||
legend_items_html = generate_legend_items_html(legend_items)
|
||||
final_html_content = html_handler.generate_final_html(
|
||||
@@ -88,17 +83,17 @@ class FlowPlot:
|
||||
return final_html_content
|
||||
|
||||
def _cleanup_pyvis_lib(self):
|
||||
# Clean up the generated lib folder
|
||||
"""Clean up temporary files generated by pyvis library."""
|
||||
lib_folder = os.path.join(os.getcwd(), "lib")
|
||||
try:
|
||||
if os.path.exists(lib_folder) and os.path.isdir(lib_folder):
|
||||
import shutil
|
||||
|
||||
shutil.rmtree(lib_folder)
|
||||
except Exception as e:
|
||||
print(f"Error cleaning up {lib_folder}: {e}")
|
||||
|
||||
|
||||
def plot_flow(flow, filename="flow_plot"):
|
||||
"""Create and save a visualization of the given flow."""
|
||||
visualizer = FlowPlot(flow)
|
||||
visualizer.plot(filename)
|
||||
|
||||
@@ -3,24 +3,31 @@ import re
|
||||
|
||||
|
||||
class HTMLTemplateHandler:
|
||||
"""Handles HTML template processing and generation for flow visualization diagrams."""
|
||||
|
||||
def __init__(self, template_path, logo_path):
|
||||
"""Initialize template handler with template and logo file paths."""
|
||||
self.template_path = template_path
|
||||
self.logo_path = logo_path
|
||||
|
||||
def read_template(self):
|
||||
"""Read and return the HTML template file contents."""
|
||||
with open(self.template_path, "r", encoding="utf-8") as f:
|
||||
return f.read()
|
||||
|
||||
def encode_logo(self):
|
||||
"""Convert the logo SVG file to base64 encoded string."""
|
||||
with open(self.logo_path, "rb") as logo_file:
|
||||
logo_svg_data = logo_file.read()
|
||||
return base64.b64encode(logo_svg_data).decode("utf-8")
|
||||
|
||||
def extract_body_content(self, html):
|
||||
"""Extract and return content between body tags from HTML string."""
|
||||
match = re.search("<body.*?>(.*?)</body>", html, re.DOTALL)
|
||||
return match.group(1) if match else ""
|
||||
|
||||
def generate_legend_items_html(self, legend_items):
|
||||
"""Generate HTML markup for the legend items."""
|
||||
legend_items_html = ""
|
||||
for item in legend_items:
|
||||
if "border" in item:
|
||||
@@ -48,6 +55,7 @@ class HTMLTemplateHandler:
|
||||
return legend_items_html
|
||||
|
||||
def generate_final_html(self, network_body, legend_items_html, title="Flow Plot"):
|
||||
"""Combine all components into final HTML document with network visualization."""
|
||||
html_template = self.read_template()
|
||||
logo_svg_base64 = self.encode_logo()
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
|
||||
def get_legend_items(colors):
|
||||
return [
|
||||
{"label": "Start Method", "color": colors["start"]},
|
||||
|
||||
@@ -1,3 +1,9 @@
|
||||
"""Utility functions for flow execution and visualization.
|
||||
|
||||
Provides helper functions for analyzing flow structure, calculating
|
||||
node positions, and extracting return values from methods.
|
||||
"""
|
||||
|
||||
import ast
|
||||
import inspect
|
||||
import textwrap
|
||||
|
||||
@@ -10,7 +10,7 @@ from .utils import (
|
||||
|
||||
|
||||
def method_calls_crew(method):
|
||||
"""Check if the method calls `.crew()`."""
|
||||
"""Check if the method contains a .crew() call."""
|
||||
try:
|
||||
source = inspect.getsource(method)
|
||||
source = inspect.cleandoc(source)
|
||||
@@ -20,6 +20,7 @@ def method_calls_crew(method):
|
||||
return False
|
||||
|
||||
class CrewCallVisitor(ast.NodeVisitor):
|
||||
"""AST visitor to detect .crew() method calls."""
|
||||
def __init__(self):
|
||||
self.found = False
|
||||
|
||||
@@ -35,6 +36,7 @@ def method_calls_crew(method):
|
||||
|
||||
|
||||
def add_nodes_to_network(net, flow, node_positions, node_styles):
|
||||
"""Add nodes to the network visualization with appropriate styling."""
|
||||
def human_friendly_label(method_name):
|
||||
return method_name.replace("_", " ").title()
|
||||
|
||||
@@ -74,6 +76,7 @@ def add_nodes_to_network(net, flow, node_positions, node_styles):
|
||||
|
||||
|
||||
def compute_positions(flow, node_levels, y_spacing=150, x_spacing=150):
|
||||
"""Calculate x,y coordinates for each node in the flow diagram."""
|
||||
level_nodes = {}
|
||||
node_positions = {}
|
||||
|
||||
@@ -91,6 +94,7 @@ def compute_positions(flow, node_levels, y_spacing=150, x_spacing=150):
|
||||
|
||||
|
||||
def add_edges(net, flow, node_positions, colors):
|
||||
"""Add edges between nodes with appropriate styling and routing."""
|
||||
ancestors = build_ancestor_dict(flow)
|
||||
parent_children = build_parent_children_dict(flow)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user