fix: type checker errors in flow modules

Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
Devin AI
2024-12-30 20:53:59 +00:00
parent 1dc8ce2674
commit 613dd175ee
3 changed files with 22 additions and 19 deletions

View File

@@ -17,7 +17,7 @@ from typing import Any, Callable, Dict, List, Optional, Set, Union
from pydantic import BaseModel
def get_possible_return_constants(function: callable) -> Optional[List[str]]:
def get_possible_return_constants(function: Callable[..., Any]) -> Optional[List[str]]:
"""Extract possible string return values from a function by analyzing its source code.
Analyzes the function's source code using AST to identify string constants that

View File

@@ -5,13 +5,13 @@ Flow graphs and calculating layout information. These utilities are separated
from general-purpose utilities to maintain a clean dependency structure.
"""
from typing import TYPE_CHECKING, Dict, List, Set
from typing import TYPE_CHECKING, Dict, List, Set, Any
if TYPE_CHECKING:
from crewai.flow.flow import Flow
def calculate_node_levels(flow: Flow) -> Dict[str, int]:
def calculate_node_levels(flow: Flow[Any]) -> Dict[str, int]:
"""Calculate the hierarchical level of each node in the flow graph.
Uses breadth-first traversal to assign levels to nodes, starting with
@@ -35,10 +35,10 @@ def calculate_node_levels(flow: Flow) -> Dict[str, int]:
>>> calculate_node_levels(flow)
{'start': 0, 'second': 1}
"""
levels = {}
queue = []
visited = set()
pending_and_listeners = {}
levels: Dict[str, int] = {}
queue: List[str] = []
visited: Set[str] = set()
pending_and_listeners: Dict[str, Set[str]] = {}
# Make all start methods at level 0
for method_name, method in flow._methods.items():
@@ -97,7 +97,7 @@ def calculate_node_levels(flow: Flow) -> Dict[str, int]:
return levels
def count_outgoing_edges(flow: Flow) -> Dict[str, int]:
def count_outgoing_edges(flow: Flow[Any]) -> Dict[str, int]:
"""Count the number of outgoing edges for each node in the flow graph.
An outgoing edge represents a connection from a method to a listener
@@ -111,7 +111,7 @@ def count_outgoing_edges(flow: Flow) -> Dict[str, int]:
dict[str, int]: Dictionary mapping method names to their number
of outgoing connections
"""
counts = {}
counts: Dict[str, int] = {}
for method_name in flow._methods:
counts[method_name] = 0
for method_name in flow._listeners:
@@ -122,7 +122,7 @@ def count_outgoing_edges(flow: Flow) -> Dict[str, int]:
return counts
def build_ancestor_dict(flow: Flow) -> Dict[str, Set[str]]:
def build_ancestor_dict(flow: Flow[Any]) -> Dict[str, Set[str]]:
"""Build a dictionary mapping each node to its set of ancestor nodes.
Uses depth-first search to identify all ancestors (direct and indirect
@@ -147,7 +147,7 @@ def build_ancestor_dict(flow: Flow) -> Dict[str, Set[str]]:
def dfs_ancestors(node: str, ancestors: Dict[str, Set[str]],
visited: Set[str], flow: Flow) -> None:
visited: Set[str], flow: Flow[Any]) -> None:
"""Perform depth-first search to populate the ancestors dictionary.
Helper function for build_ancestor_dict that recursively traverses
@@ -182,7 +182,7 @@ def dfs_ancestors(node: str, ancestors: Dict[str, Set[str]],
dfs_ancestors(listener_name, ancestors, visited, flow)
def build_parent_children_dict(flow: Flow) -> Dict[str, List[str]]:
def build_parent_children_dict(flow: Flow[Any]) -> Dict[str, List[str]]:
"""Build a dictionary mapping each node to its list of child nodes.
Maps both regular trigger methods to their listeners and router
@@ -196,7 +196,7 @@ def build_parent_children_dict(flow: Flow) -> Dict[str, List[str]]:
dict[str, list[str]]: Dictionary mapping each method name to a
sorted list of its child method names
"""
parent_children = {}
parent_children: Dict[str, List[str]] = {}
# Map listeners to their trigger methods
for listener_name, (_, trigger_methods) in flow._listeners.items():

View File

@@ -2,8 +2,11 @@ import ast
import inspect
import os
from pathlib import Path
from typing import Dict, Optional, Tuple
from typing import Any, Callable, Dict, Optional, Tuple, cast
from pyvis.network import Network
from crewai.flow.flow import Flow
from .core_flow_utils import is_ancestor
from .flow_visual_utils import (
build_ancestor_dict,
@@ -13,7 +16,7 @@ from .flow_visual_utils import (
from .path_utils import safe_path_join, validate_file_path
def method_calls_crew(method: callable) -> bool:
def method_calls_crew(method: Callable[..., Any]) -> bool:
"""Check if the method contains a .crew() call in its implementation.
Analyzes the method's source code using AST to detect if it makes any
@@ -65,7 +68,7 @@ def method_calls_crew(method: callable) -> bool:
return visitor.found
def add_nodes_to_network(net: object, flow: object,
def add_nodes_to_network(net: Network, flow: Flow[Any],
node_positions: Dict[str, Tuple[float, float]],
node_styles: Dict[str, dict],
output_dir: Optional[str] = None) -> None:
@@ -169,8 +172,8 @@ def add_nodes_to_network(net: object, flow: object,
)
def compute_positions(flow: object, node_levels: dict[str, int],
y_spacing: float = 150, x_spacing: float = 150) -> dict[str, tuple[float, float]]:
def compute_positions(flow: Flow[Any], node_levels: Dict[str, int],
y_spacing: float = 150, x_spacing: float = 150) -> Dict[str, Tuple[float, float]]:
if not hasattr(flow, '_methods'):
raise ValueError("Invalid flow object: missing '_methods' attribute")
if not isinstance(node_levels, dict):
@@ -218,7 +221,7 @@ def compute_positions(flow: object, node_levels: dict[str, int],
return node_positions
def add_edges(net: object, flow: object,
def add_edges(net: Network, flow: Flow[Any],
node_positions: Dict[str, Tuple[float, float]],
colors: Dict[str, str],
asset_dir: Optional[str] = None) -> None: