fix: type checker errors in flow modules

Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
Devin AI
2024-12-30 20:53:59 +00:00
parent 1dc8ce2674
commit 613dd175ee
3 changed files with 22 additions and 19 deletions

View File

@@ -17,7 +17,7 @@ from typing import Any, Callable, Dict, List, Optional, Set, Union
from pydantic import BaseModel from pydantic import BaseModel
def get_possible_return_constants(function: callable) -> Optional[List[str]]: def get_possible_return_constants(function: Callable[..., Any]) -> Optional[List[str]]:
"""Extract possible string return values from a function by analyzing its source code. """Extract possible string return values from a function by analyzing its source code.
Analyzes the function's source code using AST to identify string constants that Analyzes the function's source code using AST to identify string constants that

View File

@@ -5,13 +5,13 @@ Flow graphs and calculating layout information. These utilities are separated
from general-purpose utilities to maintain a clean dependency structure. from general-purpose utilities to maintain a clean dependency structure.
""" """
from typing import TYPE_CHECKING, Dict, List, Set from typing import TYPE_CHECKING, Dict, List, Set, Any
if TYPE_CHECKING: if TYPE_CHECKING:
from crewai.flow.flow import Flow from crewai.flow.flow import Flow
def calculate_node_levels(flow: Flow) -> Dict[str, int]: def calculate_node_levels(flow: Flow[Any]) -> Dict[str, int]:
"""Calculate the hierarchical level of each node in the flow graph. """Calculate the hierarchical level of each node in the flow graph.
Uses breadth-first traversal to assign levels to nodes, starting with Uses breadth-first traversal to assign levels to nodes, starting with
@@ -35,10 +35,10 @@ def calculate_node_levels(flow: Flow) -> Dict[str, int]:
>>> calculate_node_levels(flow) >>> calculate_node_levels(flow)
{'start': 0, 'second': 1} {'start': 0, 'second': 1}
""" """
levels = {} levels: Dict[str, int] = {}
queue = [] queue: List[str] = []
visited = set() visited: Set[str] = set()
pending_and_listeners = {} pending_and_listeners: Dict[str, Set[str]] = {}
# Make all start methods at level 0 # Make all start methods at level 0
for method_name, method in flow._methods.items(): for method_name, method in flow._methods.items():
@@ -97,7 +97,7 @@ def calculate_node_levels(flow: Flow) -> Dict[str, int]:
return levels return levels
def count_outgoing_edges(flow: Flow) -> Dict[str, int]: def count_outgoing_edges(flow: Flow[Any]) -> Dict[str, int]:
"""Count the number of outgoing edges for each node in the flow graph. """Count the number of outgoing edges for each node in the flow graph.
An outgoing edge represents a connection from a method to a listener An outgoing edge represents a connection from a method to a listener
@@ -111,7 +111,7 @@ def count_outgoing_edges(flow: Flow) -> Dict[str, int]:
dict[str, int]: Dictionary mapping method names to their number dict[str, int]: Dictionary mapping method names to their number
of outgoing connections of outgoing connections
""" """
counts = {} counts: Dict[str, int] = {}
for method_name in flow._methods: for method_name in flow._methods:
counts[method_name] = 0 counts[method_name] = 0
for method_name in flow._listeners: for method_name in flow._listeners:
@@ -122,7 +122,7 @@ def count_outgoing_edges(flow: Flow) -> Dict[str, int]:
return counts return counts
def build_ancestor_dict(flow: Flow) -> Dict[str, Set[str]]: def build_ancestor_dict(flow: Flow[Any]) -> Dict[str, Set[str]]:
"""Build a dictionary mapping each node to its set of ancestor nodes. """Build a dictionary mapping each node to its set of ancestor nodes.
Uses depth-first search to identify all ancestors (direct and indirect Uses depth-first search to identify all ancestors (direct and indirect
@@ -147,7 +147,7 @@ def build_ancestor_dict(flow: Flow) -> Dict[str, Set[str]]:
def dfs_ancestors(node: str, ancestors: Dict[str, Set[str]], def dfs_ancestors(node: str, ancestors: Dict[str, Set[str]],
visited: Set[str], flow: Flow) -> None: visited: Set[str], flow: Flow[Any]) -> None:
"""Perform depth-first search to populate the ancestors dictionary. """Perform depth-first search to populate the ancestors dictionary.
Helper function for build_ancestor_dict that recursively traverses Helper function for build_ancestor_dict that recursively traverses
@@ -182,7 +182,7 @@ def dfs_ancestors(node: str, ancestors: Dict[str, Set[str]],
dfs_ancestors(listener_name, ancestors, visited, flow) dfs_ancestors(listener_name, ancestors, visited, flow)
def build_parent_children_dict(flow: Flow) -> Dict[str, List[str]]: def build_parent_children_dict(flow: Flow[Any]) -> Dict[str, List[str]]:
"""Build a dictionary mapping each node to its list of child nodes. """Build a dictionary mapping each node to its list of child nodes.
Maps both regular trigger methods to their listeners and router Maps both regular trigger methods to their listeners and router
@@ -196,7 +196,7 @@ def build_parent_children_dict(flow: Flow) -> Dict[str, List[str]]:
dict[str, list[str]]: Dictionary mapping each method name to a dict[str, list[str]]: Dictionary mapping each method name to a
sorted list of its child method names sorted list of its child method names
""" """
parent_children = {} parent_children: Dict[str, List[str]] = {}
# Map listeners to their trigger methods # Map listeners to their trigger methods
for listener_name, (_, trigger_methods) in flow._listeners.items(): for listener_name, (_, trigger_methods) in flow._listeners.items():

View File

@@ -2,8 +2,11 @@ import ast
import inspect import inspect
import os import os
from pathlib import Path from pathlib import Path
from typing import Dict, Optional, Tuple from typing import Any, Callable, Dict, Optional, Tuple, cast
from pyvis.network import Network
from crewai.flow.flow import Flow
from .core_flow_utils import is_ancestor from .core_flow_utils import is_ancestor
from .flow_visual_utils import ( from .flow_visual_utils import (
build_ancestor_dict, build_ancestor_dict,
@@ -13,7 +16,7 @@ from .flow_visual_utils import (
from .path_utils import safe_path_join, validate_file_path from .path_utils import safe_path_join, validate_file_path
def method_calls_crew(method: callable) -> bool: def method_calls_crew(method: Callable[..., Any]) -> bool:
"""Check if the method contains a .crew() call in its implementation. """Check if the method contains a .crew() call in its implementation.
Analyzes the method's source code using AST to detect if it makes any Analyzes the method's source code using AST to detect if it makes any
@@ -65,7 +68,7 @@ def method_calls_crew(method: callable) -> bool:
return visitor.found return visitor.found
def add_nodes_to_network(net: object, flow: object, def add_nodes_to_network(net: Network, flow: Flow[Any],
node_positions: Dict[str, Tuple[float, float]], node_positions: Dict[str, Tuple[float, float]],
node_styles: Dict[str, dict], node_styles: Dict[str, dict],
output_dir: Optional[str] = None) -> None: output_dir: Optional[str] = None) -> None:
@@ -169,8 +172,8 @@ def add_nodes_to_network(net: object, flow: object,
) )
def compute_positions(flow: object, node_levels: dict[str, int], def compute_positions(flow: Flow[Any], node_levels: Dict[str, int],
y_spacing: float = 150, x_spacing: float = 150) -> dict[str, tuple[float, float]]: y_spacing: float = 150, x_spacing: float = 150) -> Dict[str, Tuple[float, float]]:
if not hasattr(flow, '_methods'): if not hasattr(flow, '_methods'):
raise ValueError("Invalid flow object: missing '_methods' attribute") raise ValueError("Invalid flow object: missing '_methods' attribute")
if not isinstance(node_levels, dict): if not isinstance(node_levels, dict):
@@ -218,7 +221,7 @@ def compute_positions(flow: object, node_levels: dict[str, int],
return node_positions return node_positions
def add_edges(net: object, flow: object, def add_edges(net: Network, flow: Flow[Any],
node_positions: Dict[str, Tuple[float, float]], node_positions: Dict[str, Tuple[float, float]],
colors: Dict[str, str], colors: Dict[str, str],
asset_dir: Optional[str] = None) -> None: asset_dir: Optional[str] = None) -> None: