mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-20 21:38:14 +00:00
feat: add docstring
This commit is contained in:
committed by
Devin AI
parent
ba0965ef87
commit
a375ad2a2f
@@ -31,6 +31,7 @@ T = TypeVar("T", bound=Union[BaseModel, Dict[str, Any]])
|
|||||||
|
|
||||||
|
|
||||||
def start(condition=None):
|
def start(condition=None):
|
||||||
|
"""Marks a method as a flow starting point, optionally triggered by other methods."""
|
||||||
def decorator(func):
|
def decorator(func):
|
||||||
func.__is_start_method__ = True
|
func.__is_start_method__ = True
|
||||||
if condition is not None:
|
if condition is not None:
|
||||||
@@ -57,6 +58,7 @@ def start(condition=None):
|
|||||||
|
|
||||||
|
|
||||||
def listen(condition):
|
def listen(condition):
|
||||||
|
"""Marks a method to execute when specified conditions/methods complete."""
|
||||||
def decorator(func):
|
def decorator(func):
|
||||||
if isinstance(condition, str):
|
if isinstance(condition, str):
|
||||||
func.__trigger_methods__ = [condition]
|
func.__trigger_methods__ = [condition]
|
||||||
@@ -81,9 +83,9 @@ def listen(condition):
|
|||||||
|
|
||||||
|
|
||||||
def router(condition):
|
def router(condition):
|
||||||
|
"""Marks a method as a router to direct flow based on its return value."""
|
||||||
def decorator(func):
|
def decorator(func):
|
||||||
func.__is_router__ = True
|
func.__is_router__ = True
|
||||||
# Handle conditions like listen/start
|
|
||||||
if isinstance(condition, str):
|
if isinstance(condition, str):
|
||||||
func.__trigger_methods__ = [condition]
|
func.__trigger_methods__ = [condition]
|
||||||
func.__condition_type__ = "OR"
|
func.__condition_type__ = "OR"
|
||||||
@@ -107,6 +109,7 @@ def router(condition):
|
|||||||
|
|
||||||
|
|
||||||
def or_(*conditions):
|
def or_(*conditions):
|
||||||
|
"""Combines multiple conditions with OR logic for flow control."""
|
||||||
methods = []
|
methods = []
|
||||||
for condition in conditions:
|
for condition in conditions:
|
||||||
if isinstance(condition, dict) and "methods" in condition:
|
if isinstance(condition, dict) and "methods" in condition:
|
||||||
@@ -121,6 +124,7 @@ def or_(*conditions):
|
|||||||
|
|
||||||
|
|
||||||
def and_(*conditions):
|
def and_(*conditions):
|
||||||
|
"""Combines multiple conditions with AND logic for flow control."""
|
||||||
methods = []
|
methods = []
|
||||||
for condition in conditions:
|
for condition in conditions:
|
||||||
if isinstance(condition, dict) and "methods" in condition:
|
if isinstance(condition, dict) and "methods" in condition:
|
||||||
|
|||||||
@@ -16,12 +16,16 @@ from crewai.flow.visualization_utils import (
|
|||||||
|
|
||||||
|
|
||||||
class FlowPlot:
|
class FlowPlot:
|
||||||
|
"""Handles the creation and rendering of flow visualization diagrams."""
|
||||||
|
|
||||||
def __init__(self, flow):
|
def __init__(self, flow):
|
||||||
|
"""Initialize flow plot with flow instance and styling configuration."""
|
||||||
self.flow = flow
|
self.flow = flow
|
||||||
self.colors = COLORS
|
self.colors = COLORS
|
||||||
self.node_styles = NODE_STYLES
|
self.node_styles = NODE_STYLES
|
||||||
|
|
||||||
def plot(self, filename):
|
def plot(self, filename):
|
||||||
|
"""Generate and save interactive flow visualization to HTML file."""
|
||||||
net = Network(
|
net = Network(
|
||||||
directed=True,
|
directed=True,
|
||||||
height="750px",
|
height="750px",
|
||||||
@@ -46,22 +50,14 @@ class FlowPlot:
|
|||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
# Calculate levels for nodes
|
|
||||||
node_levels = calculate_node_levels(self.flow)
|
node_levels = calculate_node_levels(self.flow)
|
||||||
|
|
||||||
# Compute positions
|
|
||||||
node_positions = compute_positions(self.flow, node_levels)
|
node_positions = compute_positions(self.flow, node_levels)
|
||||||
|
|
||||||
# Add nodes to the network
|
|
||||||
add_nodes_to_network(net, self.flow, node_positions, self.node_styles)
|
add_nodes_to_network(net, self.flow, node_positions, self.node_styles)
|
||||||
|
|
||||||
# Add edges to the network
|
|
||||||
add_edges(net, self.flow, node_positions, self.colors)
|
add_edges(net, self.flow, node_positions, self.colors)
|
||||||
|
|
||||||
network_html = net.generate_html()
|
network_html = net.generate_html()
|
||||||
final_html_content = self._generate_final_html(network_html)
|
final_html_content = self._generate_final_html(network_html)
|
||||||
|
|
||||||
# Save the final HTML content to the file
|
|
||||||
with open(f"{filename}.html", "w", encoding="utf-8") as f:
|
with open(f"{filename}.html", "w", encoding="utf-8") as f:
|
||||||
f.write(final_html_content)
|
f.write(final_html_content)
|
||||||
print(f"Plot saved as {filename}.html")
|
print(f"Plot saved as {filename}.html")
|
||||||
@@ -69,7 +65,7 @@ class FlowPlot:
|
|||||||
self._cleanup_pyvis_lib()
|
self._cleanup_pyvis_lib()
|
||||||
|
|
||||||
def _generate_final_html(self, network_html):
|
def _generate_final_html(self, network_html):
|
||||||
# Extract just the body content from the generated HTML
|
"""Generate final HTML content with network visualization and legend."""
|
||||||
current_dir = os.path.dirname(__file__)
|
current_dir = os.path.dirname(__file__)
|
||||||
template_path = os.path.join(
|
template_path = os.path.join(
|
||||||
current_dir, "assets", "crewai_flow_visual_template.html"
|
current_dir, "assets", "crewai_flow_visual_template.html"
|
||||||
@@ -79,7 +75,6 @@ class FlowPlot:
|
|||||||
html_handler = HTMLTemplateHandler(template_path, logo_path)
|
html_handler = HTMLTemplateHandler(template_path, logo_path)
|
||||||
network_body = html_handler.extract_body_content(network_html)
|
network_body = html_handler.extract_body_content(network_html)
|
||||||
|
|
||||||
# Generate the legend items HTML
|
|
||||||
legend_items = get_legend_items(self.colors)
|
legend_items = get_legend_items(self.colors)
|
||||||
legend_items_html = generate_legend_items_html(legend_items)
|
legend_items_html = generate_legend_items_html(legend_items)
|
||||||
final_html_content = html_handler.generate_final_html(
|
final_html_content = html_handler.generate_final_html(
|
||||||
@@ -88,17 +83,17 @@ class FlowPlot:
|
|||||||
return final_html_content
|
return final_html_content
|
||||||
|
|
||||||
def _cleanup_pyvis_lib(self):
|
def _cleanup_pyvis_lib(self):
|
||||||
# Clean up the generated lib folder
|
"""Clean up temporary files generated by pyvis library."""
|
||||||
lib_folder = os.path.join(os.getcwd(), "lib")
|
lib_folder = os.path.join(os.getcwd(), "lib")
|
||||||
try:
|
try:
|
||||||
if os.path.exists(lib_folder) and os.path.isdir(lib_folder):
|
if os.path.exists(lib_folder) and os.path.isdir(lib_folder):
|
||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
shutil.rmtree(lib_folder)
|
shutil.rmtree(lib_folder)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error cleaning up {lib_folder}: {e}")
|
print(f"Error cleaning up {lib_folder}: {e}")
|
||||||
|
|
||||||
|
|
||||||
def plot_flow(flow, filename="flow_plot"):
|
def plot_flow(flow, filename="flow_plot"):
|
||||||
|
"""Create and save a visualization of the given flow."""
|
||||||
visualizer = FlowPlot(flow)
|
visualizer = FlowPlot(flow)
|
||||||
visualizer.plot(filename)
|
visualizer.plot(filename)
|
||||||
|
|||||||
@@ -3,24 +3,31 @@ import re
|
|||||||
|
|
||||||
|
|
||||||
class HTMLTemplateHandler:
|
class HTMLTemplateHandler:
|
||||||
|
"""Handles HTML template processing and generation for flow visualization diagrams."""
|
||||||
|
|
||||||
def __init__(self, template_path, logo_path):
|
def __init__(self, template_path, logo_path):
|
||||||
|
"""Initialize template handler with template and logo file paths."""
|
||||||
self.template_path = template_path
|
self.template_path = template_path
|
||||||
self.logo_path = logo_path
|
self.logo_path = logo_path
|
||||||
|
|
||||||
def read_template(self):
|
def read_template(self):
|
||||||
|
"""Read and return the HTML template file contents."""
|
||||||
with open(self.template_path, "r", encoding="utf-8") as f:
|
with open(self.template_path, "r", encoding="utf-8") as f:
|
||||||
return f.read()
|
return f.read()
|
||||||
|
|
||||||
def encode_logo(self):
|
def encode_logo(self):
|
||||||
|
"""Convert the logo SVG file to base64 encoded string."""
|
||||||
with open(self.logo_path, "rb") as logo_file:
|
with open(self.logo_path, "rb") as logo_file:
|
||||||
logo_svg_data = logo_file.read()
|
logo_svg_data = logo_file.read()
|
||||||
return base64.b64encode(logo_svg_data).decode("utf-8")
|
return base64.b64encode(logo_svg_data).decode("utf-8")
|
||||||
|
|
||||||
def extract_body_content(self, html):
|
def extract_body_content(self, html):
|
||||||
|
"""Extract and return content between body tags from HTML string."""
|
||||||
match = re.search("<body.*?>(.*?)</body>", html, re.DOTALL)
|
match = re.search("<body.*?>(.*?)</body>", html, re.DOTALL)
|
||||||
return match.group(1) if match else ""
|
return match.group(1) if match else ""
|
||||||
|
|
||||||
def generate_legend_items_html(self, legend_items):
|
def generate_legend_items_html(self, legend_items):
|
||||||
|
"""Generate HTML markup for the legend items."""
|
||||||
legend_items_html = ""
|
legend_items_html = ""
|
||||||
for item in legend_items:
|
for item in legend_items:
|
||||||
if "border" in item:
|
if "border" in item:
|
||||||
@@ -48,6 +55,7 @@ class HTMLTemplateHandler:
|
|||||||
return legend_items_html
|
return legend_items_html
|
||||||
|
|
||||||
def generate_final_html(self, network_body, legend_items_html, title="Flow Plot"):
|
def generate_final_html(self, network_body, legend_items_html, title="Flow Plot"):
|
||||||
|
"""Combine all components into final HTML document with network visualization."""
|
||||||
html_template = self.read_template()
|
html_template = self.read_template()
|
||||||
logo_svg_base64 = self.encode_logo()
|
logo_svg_base64 = self.encode_logo()
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
|
||||||
def get_legend_items(colors):
|
def get_legend_items(colors):
|
||||||
return [
|
return [
|
||||||
{"label": "Start Method", "color": colors["start"]},
|
{"label": "Start Method", "color": colors["start"]},
|
||||||
|
|||||||
@@ -1,3 +1,9 @@
|
|||||||
|
"""Utility functions for flow execution and visualization.
|
||||||
|
|
||||||
|
Provides helper functions for analyzing flow structure, calculating
|
||||||
|
node positions, and extracting return values from methods.
|
||||||
|
"""
|
||||||
|
|
||||||
import ast
|
import ast
|
||||||
import inspect
|
import inspect
|
||||||
import textwrap
|
import textwrap
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ from .utils import (
|
|||||||
|
|
||||||
|
|
||||||
def method_calls_crew(method):
|
def method_calls_crew(method):
|
||||||
"""Check if the method calls `.crew()`."""
|
"""Check if the method contains a .crew() call."""
|
||||||
try:
|
try:
|
||||||
source = inspect.getsource(method)
|
source = inspect.getsource(method)
|
||||||
source = inspect.cleandoc(source)
|
source = inspect.cleandoc(source)
|
||||||
@@ -20,6 +20,7 @@ def method_calls_crew(method):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
class CrewCallVisitor(ast.NodeVisitor):
|
class CrewCallVisitor(ast.NodeVisitor):
|
||||||
|
"""AST visitor to detect .crew() method calls."""
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.found = False
|
self.found = False
|
||||||
|
|
||||||
@@ -35,6 +36,7 @@ def method_calls_crew(method):
|
|||||||
|
|
||||||
|
|
||||||
def add_nodes_to_network(net, flow, node_positions, node_styles):
|
def add_nodes_to_network(net, flow, node_positions, node_styles):
|
||||||
|
"""Add nodes to the network visualization with appropriate styling."""
|
||||||
def human_friendly_label(method_name):
|
def human_friendly_label(method_name):
|
||||||
return method_name.replace("_", " ").title()
|
return method_name.replace("_", " ").title()
|
||||||
|
|
||||||
@@ -74,6 +76,7 @@ def add_nodes_to_network(net, flow, node_positions, node_styles):
|
|||||||
|
|
||||||
|
|
||||||
def compute_positions(flow, node_levels, y_spacing=150, x_spacing=150):
|
def compute_positions(flow, node_levels, y_spacing=150, x_spacing=150):
|
||||||
|
"""Calculate x,y coordinates for each node in the flow diagram."""
|
||||||
level_nodes = {}
|
level_nodes = {}
|
||||||
node_positions = {}
|
node_positions = {}
|
||||||
|
|
||||||
@@ -91,6 +94,7 @@ def compute_positions(flow, node_levels, y_spacing=150, x_spacing=150):
|
|||||||
|
|
||||||
|
|
||||||
def add_edges(net, flow, node_positions, colors):
|
def add_edges(net, flow, node_positions, colors):
|
||||||
|
"""Add edges between nodes with appropriate styling and routing."""
|
||||||
ancestors = build_ancestor_dict(flow)
|
ancestors = build_ancestor_dict(flow)
|
||||||
parent_children = build_parent_children_dict(flow)
|
parent_children = build_parent_children_dict(flow)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user