mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 15:48:29 +00:00
all working. needs to be cleaned up
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
# flow.py
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, Generic, List, Set, Type, TypeVar, Union
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
# flow_visualizer.py
|
||||
|
||||
import base64
|
||||
import os
|
||||
import re
|
||||
@@ -5,39 +7,53 @@ from abc import ABC, abstractmethod
|
||||
|
||||
from pyvis.network import Network
|
||||
|
||||
DARK_GRAY = "#333333"
|
||||
CREWAI_ORANGE = "#FF5A50"
|
||||
GRAY = "#666666"
|
||||
WHITE = "#FFFFFF"
|
||||
|
||||
|
||||
class FlowVisualizer(ABC):
|
||||
def __init__(self, flow):
|
||||
self.flow = flow
|
||||
self.colors = {
|
||||
"bg": "#FFFFFF",
|
||||
"start": "#FF5A50",
|
||||
"method": "#333333",
|
||||
"router": "#333333", # Dark gray for router background
|
||||
"router_border": "#FF8C00", # Orange for router border
|
||||
"edge": "#666666",
|
||||
"router_edge": "#FF8C00", # Orange for router edges
|
||||
"text": "#FFFFFF",
|
||||
"bg": WHITE,
|
||||
"start": CREWAI_ORANGE,
|
||||
"method": DARK_GRAY,
|
||||
"router": DARK_GRAY,
|
||||
"router_border": CREWAI_ORANGE,
|
||||
"edge": GRAY,
|
||||
"router_edge": CREWAI_ORANGE,
|
||||
"text": WHITE,
|
||||
}
|
||||
self.node_styles = {
|
||||
"start": {
|
||||
"color": self.colors["start"],
|
||||
"shape": "box",
|
||||
"font": {"color": self.colors["text"]},
|
||||
"margin": 15,
|
||||
},
|
||||
"method": {
|
||||
"color": self.colors["method"],
|
||||
"shape": "box",
|
||||
"font": {"color": self.colors["text"]},
|
||||
"margin": 15,
|
||||
},
|
||||
"router": {
|
||||
"color": self.colors["router"],
|
||||
"color": {
|
||||
"background": self.colors["router"],
|
||||
"border": self.colors["router_border"],
|
||||
"highlight": {
|
||||
"border": self.colors["router_border"],
|
||||
"background": self.colors["router"],
|
||||
},
|
||||
},
|
||||
"shape": "box",
|
||||
"font": {"color": self.colors["text"]},
|
||||
"borderWidth": 2,
|
||||
"borderWidth": 3,
|
||||
"borderWidthSelected": 4,
|
||||
"borderDashes": [5, 5], # Dashed border
|
||||
"borderColor": self.colors["router_border"],
|
||||
"shapeProperties": {"borderDashes": [5, 5]},
|
||||
"margin": 15,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -61,8 +77,8 @@ class PyvisFlowVisualizer(FlowVisualizer):
|
||||
print("node_levels", node_levels)
|
||||
|
||||
# Assign positions to nodes based on levels
|
||||
y_spacing = 150 # Adjust spacing between levels (positive for top-down)
|
||||
x_spacing = 150 # Adjust spacing between nodes
|
||||
y_spacing = 150
|
||||
x_spacing = 150
|
||||
level_nodes = {}
|
||||
|
||||
# Store node positions for edge calculations
|
||||
@@ -76,13 +92,13 @@ class PyvisFlowVisualizer(FlowVisualizer):
|
||||
x_offset = -(len(nodes) - 1) * x_spacing / 2 # Center nodes horizontally
|
||||
for i, method_name in enumerate(nodes):
|
||||
x = x_offset + i * x_spacing
|
||||
y = level * y_spacing # Use level directly for y position
|
||||
node_positions[method_name] = (x, y) # Store positions
|
||||
y = level * y_spacing
|
||||
node_positions[method_name] = (x, y)
|
||||
|
||||
method = self.flow._methods.get(method_name)
|
||||
if hasattr(method, "__is_start_method__"):
|
||||
node_style = self.node_styles["start"]
|
||||
elif method_name in self.flow._routers.values():
|
||||
elif hasattr(method, "__is_router__"):
|
||||
node_style = self.node_styles["router"]
|
||||
else:
|
||||
node_style = self.node_styles["method"]
|
||||
@@ -93,11 +109,10 @@ class PyvisFlowVisualizer(FlowVisualizer):
|
||||
x=x,
|
||||
y=y,
|
||||
fixed=True,
|
||||
physics=False, # Disable physics for fixed positioning
|
||||
physics=False,
|
||||
**node_style,
|
||||
)
|
||||
|
||||
# Prepare data structures for edge calculations
|
||||
ancestors = self._build_ancestor_dict()
|
||||
parent_children = self._build_parent_children_dict()
|
||||
|
||||
@@ -107,7 +122,10 @@ class PyvisFlowVisualizer(FlowVisualizer):
|
||||
is_and_condition = condition_type == "AND"
|
||||
|
||||
for trigger in trigger_methods:
|
||||
if trigger in self.flow._methods:
|
||||
if (
|
||||
trigger in self.flow._methods
|
||||
or trigger in self.flow._routers.values()
|
||||
):
|
||||
is_router_edge = any(
|
||||
trigger in paths for paths in self.flow._router_paths.values()
|
||||
)
|
||||
@@ -118,29 +136,41 @@ class PyvisFlowVisualizer(FlowVisualizer):
|
||||
)
|
||||
|
||||
# Determine if this edge forms a cycle
|
||||
is_cycle_edge = self._is_ancestor(method_name, trigger, ancestors)
|
||||
is_cycle_edge = self._is_ancestor(trigger, method_name, ancestors)
|
||||
|
||||
# Determine if parent has multiple children
|
||||
parent_has_multiple_children = (
|
||||
len(parent_children.get(trigger, [])) > 1
|
||||
)
|
||||
|
||||
# Determine if edge should be curved
|
||||
# Edge curvature logic
|
||||
needs_curvature = is_cycle_edge or parent_has_multiple_children
|
||||
|
||||
if needs_curvature:
|
||||
if is_cycle_edge:
|
||||
# For cycles, curve left and up
|
||||
edge_smooth = {"type": "curvedCCW", "roundness": 0.3}
|
||||
else:
|
||||
# For multiple children, adjust curvature to prevent overlap
|
||||
# Get node positions
|
||||
source_pos = node_positions.get(trigger)
|
||||
target_pos = node_positions.get(method_name)
|
||||
|
||||
if source_pos and target_pos:
|
||||
dx = target_pos[0] - source_pos[0]
|
||||
|
||||
if dx <= 0:
|
||||
# Child is left or directly below
|
||||
smooth_type = "curvedCCW" # Curve left and down
|
||||
else:
|
||||
# Child is to the right
|
||||
smooth_type = "curvedCW" # Curve right and down
|
||||
|
||||
index = self._get_child_index(
|
||||
trigger, method_name, parent_children
|
||||
)
|
||||
edge_smooth = {
|
||||
"type": "curvedCW" if index % 2 == 0 else "curvedCCW",
|
||||
"type": smooth_type,
|
||||
"roundness": 0.2 + (0.1 * index),
|
||||
}
|
||||
else:
|
||||
# Fallback curvature
|
||||
edge_smooth = {"type": "cubicBezier"}
|
||||
else:
|
||||
edge_smooth = False # Draw straight line
|
||||
|
||||
@@ -163,7 +193,7 @@ class PyvisFlowVisualizer(FlowVisualizer):
|
||||
) in self.flow._listeners.items():
|
||||
if path in trigger_methods:
|
||||
is_cycle_edge = self._is_ancestor(
|
||||
listener_name, router_method_name, ancestors
|
||||
trigger, method_name, ancestors
|
||||
)
|
||||
|
||||
# Determine if parent has multiple children
|
||||
@@ -171,24 +201,34 @@ class PyvisFlowVisualizer(FlowVisualizer):
|
||||
len(parent_children.get(router_method_name, [])) > 1
|
||||
)
|
||||
|
||||
# Determine if edge should be curved
|
||||
# Edge curvature logic
|
||||
needs_curvature = is_cycle_edge or parent_has_multiple_children
|
||||
|
||||
if needs_curvature:
|
||||
if is_cycle_edge:
|
||||
# Curve left and up for cycles
|
||||
edge_smooth = {"type": "curvedCCW", "roundness": 0.3}
|
||||
else:
|
||||
# For multiple children, adjust curvature
|
||||
# Get node positions
|
||||
source_pos = node_positions.get(router_method_name)
|
||||
target_pos = node_positions.get(listener_name)
|
||||
|
||||
if source_pos and target_pos:
|
||||
dx = target_pos[0] - source_pos[0]
|
||||
|
||||
if dx <= 0:
|
||||
# Child is left or directly below
|
||||
smooth_type = "curvedCCW" # Curve left and down
|
||||
else:
|
||||
# Child is to the right
|
||||
smooth_type = "curvedCW" # Curve right and down
|
||||
|
||||
index = self._get_child_index(
|
||||
router_method_name, listener_name, parent_children
|
||||
)
|
||||
edge_smooth = {
|
||||
"type": (
|
||||
"curvedCW" if index % 2 == 0 else "curvedCCW"
|
||||
),
|
||||
"type": smooth_type,
|
||||
"roundness": 0.2 + (0.1 * index),
|
||||
}
|
||||
else:
|
||||
# Fallback curvature
|
||||
edge_smooth = {"type": "cubicBezier"}
|
||||
else:
|
||||
edge_smooth = False # Straight line
|
||||
|
||||
@@ -374,54 +414,65 @@ class PyvisFlowVisualizer(FlowVisualizer):
|
||||
return counts
|
||||
|
||||
def _build_ancestor_dict(self):
|
||||
# Helper method to build a dictionary of ancestors for each node
|
||||
ancestors = {node: set() for node in self.flow._methods}
|
||||
visited = set()
|
||||
for node in self.flow._methods:
|
||||
self._dfs_ancestors(node, ancestors, set())
|
||||
if node not in visited:
|
||||
self._dfs_ancestors(node, ancestors, visited)
|
||||
print("Ancestor Relationships:")
|
||||
for node, node_ancestors in ancestors.items():
|
||||
print(f"{node}: {node_ancestors}")
|
||||
return ancestors
|
||||
|
||||
def _dfs_ancestors(self, node, ancestors, visited):
|
||||
if node in visited:
|
||||
return
|
||||
visited.add(node)
|
||||
|
||||
# Handle regular listeners
|
||||
for listener_name, (_, trigger_methods) in self.flow._listeners.items():
|
||||
if node in trigger_methods:
|
||||
ancestors[listener_name].add(node)
|
||||
ancestors[listener_name].update(ancestors[node])
|
||||
self._dfs_ancestors(listener_name, ancestors, visited)
|
||||
# Include router paths
|
||||
for router_method_name, paths in self.flow._router_paths.items():
|
||||
if node == router_method_name:
|
||||
for path in paths:
|
||||
for listener_name, (
|
||||
_,
|
||||
trigger_methods,
|
||||
) in self.flow._listeners.items():
|
||||
if path in trigger_methods:
|
||||
ancestors[listener_name].add(node)
|
||||
ancestors[listener_name].update(ancestors[node])
|
||||
self._dfs_ancestors(listener_name, ancestors, visited)
|
||||
|
||||
# Handle router methods separately
|
||||
if node in self.flow._routers.values():
|
||||
router_method_name = node
|
||||
paths = self.flow._router_paths.get(router_method_name, [])
|
||||
for path in paths:
|
||||
for listener_name, (_, trigger_methods) in self.flow._listeners.items():
|
||||
if path in trigger_methods:
|
||||
# Only propagate the ancestors of the router method, not the router method itself
|
||||
ancestors[listener_name].update(ancestors[node])
|
||||
self._dfs_ancestors(listener_name, ancestors, visited)
|
||||
|
||||
def _is_ancestor(self, node, ancestor_candidate, ancestors):
|
||||
return ancestor_candidate in ancestors.get(node, set())
|
||||
|
||||
def _build_parent_children_dict(self):
|
||||
# Helper method to build a parent to children mapping
|
||||
parent_children = {}
|
||||
# Map listeners to their trigger methods
|
||||
for listener_name, (_, trigger_methods) in self.flow._listeners.items():
|
||||
for trigger in trigger_methods:
|
||||
if trigger not in parent_children:
|
||||
parent_children[trigger] = []
|
||||
parent_children[trigger].append(listener_name)
|
||||
# Include router paths and their connected listener methods
|
||||
if listener_name not in parent_children[trigger]:
|
||||
parent_children[trigger].append(listener_name)
|
||||
# Map router methods to their paths and to listeners
|
||||
for router_method_name, paths in self.flow._router_paths.items():
|
||||
for path in paths:
|
||||
# Find listener methods triggered by each path
|
||||
# Map router method to listeners of each path
|
||||
for listener_name, (_, trigger_methods) in self.flow._listeners.items():
|
||||
if path in trigger_methods:
|
||||
parent_children.setdefault(router_method_name, []).append(
|
||||
listener_name
|
||||
)
|
||||
if router_method_name not in parent_children:
|
||||
parent_children[router_method_name] = []
|
||||
if listener_name not in parent_children[router_method_name]:
|
||||
parent_children[router_method_name].append(listener_name)
|
||||
# Debugging output
|
||||
print("Parent-Children Relationships:")
|
||||
for parent, children in parent_children.items():
|
||||
print(f"{parent}: {children}")
|
||||
return parent_children
|
||||
|
||||
def _get_child_index(self, parent, child, parent_children):
|
||||
|
||||
Reference in New Issue
Block a user