mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 00:28:31 +00:00
properly identifying router and router children nodes. Need to fix color
This commit is contained in:
Binary file not shown.
@@ -58,10 +58,12 @@ def listen(condition):
|
|||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
def router(method):
|
def router(method, paths=None):
|
||||||
def decorator(func):
|
def decorator(func):
|
||||||
func.__is_router__ = True
|
func.__is_router__ = True
|
||||||
func.__router_for__ = method.__name__
|
func.__router_for__ = method.__name__
|
||||||
|
if paths:
|
||||||
|
func.__router_paths__ = paths
|
||||||
return func
|
return func
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
@@ -102,6 +104,7 @@ class FlowMeta(type):
|
|||||||
start_methods = []
|
start_methods = []
|
||||||
listeners = {}
|
listeners = {}
|
||||||
routers = {}
|
routers = {}
|
||||||
|
router_paths = {}
|
||||||
|
|
||||||
for attr_name, attr_value in dct.items():
|
for attr_name, attr_value in dct.items():
|
||||||
if hasattr(attr_value, "__is_start_method__"):
|
if hasattr(attr_value, "__is_start_method__"):
|
||||||
@@ -116,10 +119,24 @@ class FlowMeta(type):
|
|||||||
listeners[attr_name] = (condition_type, methods)
|
listeners[attr_name] = (condition_type, methods)
|
||||||
elif hasattr(attr_value, "__is_router__"):
|
elif hasattr(attr_value, "__is_router__"):
|
||||||
routers[attr_value.__router_for__] = attr_name
|
routers[attr_value.__router_for__] = attr_name
|
||||||
|
if hasattr(attr_value, "__router_paths__"):
|
||||||
|
router_paths[attr_name] = attr_value.__router_paths__
|
||||||
|
|
||||||
|
# **Register router as a listener to its triggering method**
|
||||||
|
trigger_method_name = attr_value.__router_for__
|
||||||
|
methods = [trigger_method_name]
|
||||||
|
condition_type = "OR"
|
||||||
|
listeners[attr_name] = (condition_type, methods)
|
||||||
|
|
||||||
setattr(cls, "_start_methods", start_methods)
|
setattr(cls, "_start_methods", start_methods)
|
||||||
setattr(cls, "_listeners", listeners)
|
setattr(cls, "_listeners", listeners)
|
||||||
setattr(cls, "_routers", routers)
|
setattr(cls, "_routers", routers)
|
||||||
|
setattr(cls, "_router_paths", router_paths)
|
||||||
|
|
||||||
|
print("Start methods:", start_methods)
|
||||||
|
print("Listeners:", listeners)
|
||||||
|
print("Routers:", routers)
|
||||||
|
print("Router paths:", router_paths)
|
||||||
|
|
||||||
return cls
|
return cls
|
||||||
|
|
||||||
@@ -128,6 +145,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
_start_methods: List[str] = []
|
_start_methods: List[str] = []
|
||||||
_listeners: Dict[str, tuple[str, List[str]]] = {}
|
_listeners: Dict[str, tuple[str, List[str]]] = {}
|
||||||
_routers: Dict[str, str] = {}
|
_routers: Dict[str, str] = {}
|
||||||
|
_router_paths: Dict[str, List[str]] = {}
|
||||||
initial_state: Union[Type[T], T, None] = None
|
initial_state: Union[Type[T], T, None] = None
|
||||||
|
|
||||||
def __class_getitem__(cls, item):
|
def __class_getitem__(cls, item):
|
||||||
|
|||||||
@@ -13,8 +13,10 @@ class FlowVisualizer(ABC):
|
|||||||
"bg": "#FFFFFF",
|
"bg": "#FFFFFF",
|
||||||
"start": "#FF5A50",
|
"start": "#FF5A50",
|
||||||
"method": "#333333",
|
"method": "#333333",
|
||||||
"router": "#FF8C00",
|
"router": "#333333", # Dark gray for router background
|
||||||
|
"router_border": "#FF8C00", # Orange for router border
|
||||||
"edge": "#666666",
|
"edge": "#666666",
|
||||||
|
"router_edge": "#FF8C00", # Orange for router edges
|
||||||
"text": "#FFFFFF",
|
"text": "#FFFFFF",
|
||||||
}
|
}
|
||||||
self.node_styles = {
|
self.node_styles = {
|
||||||
@@ -32,6 +34,10 @@ class FlowVisualizer(ABC):
|
|||||||
"color": self.colors["router"],
|
"color": self.colors["router"],
|
||||||
"shape": "box",
|
"shape": "box",
|
||||||
"font": {"color": self.colors["text"]},
|
"font": {"color": self.colors["text"]},
|
||||||
|
"borderWidth": 2,
|
||||||
|
"borderWidthSelected": 4,
|
||||||
|
"borderDashes": [5, 5], # Dashed border
|
||||||
|
"borderColor": self.colors["router_border"],
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -52,6 +58,7 @@ class PyvisFlowVisualizer(FlowVisualizer):
|
|||||||
|
|
||||||
# Calculate levels for nodes
|
# Calculate levels for nodes
|
||||||
node_levels = self._calculate_node_levels()
|
node_levels = self._calculate_node_levels()
|
||||||
|
print("node_levels", node_levels)
|
||||||
|
|
||||||
# Assign positions to nodes based on levels
|
# Assign positions to nodes based on levels
|
||||||
y_spacing = 150 # Adjust spacing between levels (positive for top-down)
|
y_spacing = 150 # Adjust spacing between levels (positive for top-down)
|
||||||
@@ -61,8 +68,11 @@ class PyvisFlowVisualizer(FlowVisualizer):
|
|||||||
for method_name, level in node_levels.items():
|
for method_name, level in node_levels.items():
|
||||||
level_nodes.setdefault(level, []).append(method_name)
|
level_nodes.setdefault(level, []).append(method_name)
|
||||||
|
|
||||||
|
print("level_nodes", level_nodes)
|
||||||
|
|
||||||
# Compute positions
|
# Compute positions
|
||||||
for level, nodes in level_nodes.items():
|
for level, nodes in level_nodes.items():
|
||||||
|
print("level", level, "nodes", nodes)
|
||||||
x_offset = -(len(nodes) - 1) * x_spacing / 2 # Center nodes horizontally
|
x_offset = -(len(nodes) - 1) * x_spacing / 2 # Center nodes horizontally
|
||||||
for i, method_name in enumerate(nodes):
|
for i, method_name in enumerate(nodes):
|
||||||
x = x_offset + i * x_spacing
|
x = x_offset + i * x_spacing
|
||||||
@@ -85,21 +95,47 @@ class PyvisFlowVisualizer(FlowVisualizer):
|
|||||||
**node_style,
|
**node_style,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add edges with curved lines
|
# Add edges
|
||||||
for method_name in self.flow._listeners:
|
for method_name in self.flow._listeners:
|
||||||
condition_type, trigger_methods = self.flow._listeners[method_name]
|
condition_type, trigger_methods = self.flow._listeners[method_name]
|
||||||
is_and_condition = condition_type == "AND"
|
is_and_condition = condition_type == "AND"
|
||||||
|
|
||||||
for trigger in trigger_methods:
|
for trigger in trigger_methods:
|
||||||
if trigger in self.flow._methods:
|
if trigger in self.flow._methods:
|
||||||
net.add_edge(
|
is_router_edge = (
|
||||||
trigger,
|
trigger in self.flow._routers.values()
|
||||||
method_name,
|
or method_name in self.flow._routers.values()
|
||||||
color=self.colors.get("edge", "#666666"),
|
|
||||||
width=2,
|
|
||||||
arrows="to",
|
|
||||||
dashes=is_and_condition,
|
|
||||||
smooth={"type": "cubicBezier"},
|
|
||||||
)
|
)
|
||||||
|
edge_color = (
|
||||||
|
self.colors["router_edge"]
|
||||||
|
if is_router_edge
|
||||||
|
else self.colors["edge"]
|
||||||
|
)
|
||||||
|
edge_style = {
|
||||||
|
"color": edge_color,
|
||||||
|
"width": 2,
|
||||||
|
"arrows": "to",
|
||||||
|
"dashes": True if is_router_edge or is_and_condition else False,
|
||||||
|
"smooth": {"type": "cubicBezier"},
|
||||||
|
}
|
||||||
|
net.add_edge(trigger, method_name, **edge_style)
|
||||||
|
|
||||||
|
# Add edges from router methods to their possible paths
|
||||||
|
for router_method_name, paths in self.flow._router_paths.items():
|
||||||
|
for path in paths:
|
||||||
|
for listener_name, (
|
||||||
|
condition_type,
|
||||||
|
trigger_methods,
|
||||||
|
) in self.flow._listeners.items():
|
||||||
|
if path in trigger_methods:
|
||||||
|
edge_style = {
|
||||||
|
"color": self.colors["router_edge"],
|
||||||
|
"width": 2,
|
||||||
|
"arrows": "to",
|
||||||
|
"dashes": True,
|
||||||
|
"smooth": {"type": "cubicBezier"},
|
||||||
|
}
|
||||||
|
net.add_edge(router_method_name, listener_name, **edge_style)
|
||||||
|
|
||||||
# Set options for curved edges and disable physics
|
# Set options for curved edges and disable physics
|
||||||
net.set_options(
|
net.set_options(
|
||||||
@@ -138,38 +174,40 @@ class PyvisFlowVisualizer(FlowVisualizer):
|
|||||||
|
|
||||||
# Generate the legend items HTML
|
# Generate the legend items HTML
|
||||||
legend_items = [
|
legend_items = [
|
||||||
{"label": "Start Method", "color": self.colors.get("start", "#FF5A50")},
|
{"label": "Start Method", "color": self.colors["start"]},
|
||||||
{"label": "Method", "color": self.colors.get("method", "#333333")},
|
{"label": "Method", "color": self.colors["method"]},
|
||||||
# {"label": "Router", "color": self.colors.get("router", "#FF8C00")},
|
|
||||||
{
|
{
|
||||||
"label": "Trigger",
|
"label": "Router",
|
||||||
"color": self.colors.get("edge", "#666666"),
|
"color": self.colors["router"],
|
||||||
"dashed": False,
|
"border": self.colors["router_border"],
|
||||||
|
"dashed": True,
|
||||||
},
|
},
|
||||||
|
{"label": "Trigger", "color": self.colors["edge"], "dashed": False},
|
||||||
|
{"label": "AND Trigger", "color": self.colors["edge"], "dashed": True},
|
||||||
{
|
{
|
||||||
"label": "AND Trigger",
|
"label": "Router Trigger",
|
||||||
"color": self.colors.get("edge", "#666666"),
|
"color": self.colors["router_edge"],
|
||||||
"dashed": True,
|
"dashed": True,
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
legend_items_html = ""
|
legend_items_html = ""
|
||||||
for item in legend_items:
|
for item in legend_items:
|
||||||
if item.get("dashed") is not None:
|
if "border" in item:
|
||||||
if item.get("dashed"):
|
legend_items_html += f"""
|
||||||
legend_items_html += f"""
|
<div class="legend-item">
|
||||||
<div class="legend-item">
|
<div class="legend-color-box" style="background-color: {item['color']}; border: 2px dashed {item['border']};"></div>
|
||||||
<div class="legend-dashed"></div>
|
<div>{item['label']}</div>
|
||||||
<div>{item['label']}</div>
|
</div>
|
||||||
</div>
|
"""
|
||||||
"""
|
elif item.get("dashed") is not None:
|
||||||
else:
|
style = "dashed" if item["dashed"] else "solid"
|
||||||
legend_items_html += f"""
|
legend_items_html += f"""
|
||||||
<div class="legend-item">
|
<div class="legend-item">
|
||||||
<div class="legend-solid" style="border-bottom: 2px solid {item['color']};"></div>
|
<div class="legend-{style}" style="border-bottom: 2px {style} {item['color']};"></div>
|
||||||
<div>{item['label']}</div>
|
<div>{item['label']}</div>
|
||||||
</div>
|
</div>
|
||||||
"""
|
"""
|
||||||
else:
|
else:
|
||||||
legend_items_html += f"""
|
legend_items_html += f"""
|
||||||
<div class="legend-item">
|
<div class="legend-item">
|
||||||
@@ -205,6 +243,7 @@ class PyvisFlowVisualizer(FlowVisualizer):
|
|||||||
levels = {}
|
levels = {}
|
||||||
queue = []
|
queue = []
|
||||||
visited = set()
|
visited = set()
|
||||||
|
pending_and_listeners = {}
|
||||||
|
|
||||||
# Initialize start methods at level 0
|
# Initialize start methods at level 0
|
||||||
for method_name, method in self.flow._methods.items():
|
for method_name, method in self.flow._methods.items():
|
||||||
@@ -223,15 +262,46 @@ class PyvisFlowVisualizer(FlowVisualizer):
|
|||||||
condition_type,
|
condition_type,
|
||||||
trigger_methods,
|
trigger_methods,
|
||||||
) in self.flow._listeners.items():
|
) in self.flow._listeners.items():
|
||||||
if current in trigger_methods:
|
if condition_type == "OR":
|
||||||
if (
|
if current in trigger_methods:
|
||||||
listener_name not in levels
|
if (
|
||||||
or levels[listener_name] > current_level + 1
|
listener_name not in levels
|
||||||
):
|
or levels[listener_name] > current_level + 1
|
||||||
levels[listener_name] = current_level + 1
|
):
|
||||||
if listener_name not in visited:
|
levels[listener_name] = current_level + 1
|
||||||
queue.append(listener_name)
|
if listener_name not in visited:
|
||||||
|
queue.append(listener_name)
|
||||||
|
elif condition_type == "AND":
|
||||||
|
if listener_name not in pending_and_listeners:
|
||||||
|
pending_and_listeners[listener_name] = set()
|
||||||
|
if current in trigger_methods:
|
||||||
|
pending_and_listeners[listener_name].add(current)
|
||||||
|
if set(trigger_methods) == pending_and_listeners[listener_name]:
|
||||||
|
if (
|
||||||
|
listener_name not in levels
|
||||||
|
or levels[listener_name] > current_level + 1
|
||||||
|
):
|
||||||
|
levels[listener_name] = current_level + 1
|
||||||
|
if listener_name not in visited:
|
||||||
|
queue.append(listener_name)
|
||||||
|
|
||||||
|
# Handle router connections (same as before)
|
||||||
|
if current in self.flow._routers.values():
|
||||||
|
router_method_name = current
|
||||||
|
paths = self.flow._router_paths.get(router_method_name, [])
|
||||||
|
for path in paths:
|
||||||
|
for listener_name, (
|
||||||
|
condition_type,
|
||||||
|
trigger_methods,
|
||||||
|
) in self.flow._listeners.items():
|
||||||
|
if path in trigger_methods:
|
||||||
|
if (
|
||||||
|
listener_name not in levels
|
||||||
|
or levels[listener_name] > current_level + 1
|
||||||
|
):
|
||||||
|
levels[listener_name] = current_level + 1
|
||||||
|
if listener_name not in visited:
|
||||||
|
queue.append(listener_name)
|
||||||
return levels
|
return levels
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user