mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 15:48:29 +00:00
fix: type checker errors in flow modules
Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
@@ -17,7 +17,7 @@ from typing import Any, Callable, Dict, List, Optional, Set, Union
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
def get_possible_return_constants(function: callable) -> Optional[List[str]]:
|
||||
def get_possible_return_constants(function: Callable[..., Any]) -> Optional[List[str]]:
|
||||
"""Extract possible string return values from a function by analyzing its source code.
|
||||
|
||||
Analyzes the function's source code using AST to identify string constants that
|
||||
|
||||
@@ -5,13 +5,13 @@ Flow graphs and calculating layout information. These utilities are separated
|
||||
from general-purpose utilities to maintain a clean dependency structure.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Dict, List, Set
|
||||
from typing import TYPE_CHECKING, Dict, List, Set, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.flow.flow import Flow
|
||||
|
||||
|
||||
def calculate_node_levels(flow: Flow) -> Dict[str, int]:
|
||||
def calculate_node_levels(flow: Flow[Any]) -> Dict[str, int]:
|
||||
"""Calculate the hierarchical level of each node in the flow graph.
|
||||
|
||||
Uses breadth-first traversal to assign levels to nodes, starting with
|
||||
@@ -35,10 +35,10 @@ def calculate_node_levels(flow: Flow) -> Dict[str, int]:
|
||||
>>> calculate_node_levels(flow)
|
||||
{'start': 0, 'second': 1}
|
||||
"""
|
||||
levels = {}
|
||||
queue = []
|
||||
visited = set()
|
||||
pending_and_listeners = {}
|
||||
levels: Dict[str, int] = {}
|
||||
queue: List[str] = []
|
||||
visited: Set[str] = set()
|
||||
pending_and_listeners: Dict[str, Set[str]] = {}
|
||||
|
||||
# Make all start methods at level 0
|
||||
for method_name, method in flow._methods.items():
|
||||
@@ -97,7 +97,7 @@ def calculate_node_levels(flow: Flow) -> Dict[str, int]:
|
||||
return levels
|
||||
|
||||
|
||||
def count_outgoing_edges(flow: Flow) -> Dict[str, int]:
|
||||
def count_outgoing_edges(flow: Flow[Any]) -> Dict[str, int]:
|
||||
"""Count the number of outgoing edges for each node in the flow graph.
|
||||
|
||||
An outgoing edge represents a connection from a method to a listener
|
||||
@@ -111,7 +111,7 @@ def count_outgoing_edges(flow: Flow) -> Dict[str, int]:
|
||||
dict[str, int]: Dictionary mapping method names to their number
|
||||
of outgoing connections
|
||||
"""
|
||||
counts = {}
|
||||
counts: Dict[str, int] = {}
|
||||
for method_name in flow._methods:
|
||||
counts[method_name] = 0
|
||||
for method_name in flow._listeners:
|
||||
@@ -122,7 +122,7 @@ def count_outgoing_edges(flow: Flow) -> Dict[str, int]:
|
||||
return counts
|
||||
|
||||
|
||||
def build_ancestor_dict(flow: Flow) -> Dict[str, Set[str]]:
|
||||
def build_ancestor_dict(flow: Flow[Any]) -> Dict[str, Set[str]]:
|
||||
"""Build a dictionary mapping each node to its set of ancestor nodes.
|
||||
|
||||
Uses depth-first search to identify all ancestors (direct and indirect
|
||||
@@ -147,7 +147,7 @@ def build_ancestor_dict(flow: Flow) -> Dict[str, Set[str]]:
|
||||
|
||||
|
||||
def dfs_ancestors(node: str, ancestors: Dict[str, Set[str]],
|
||||
visited: Set[str], flow: Flow) -> None:
|
||||
visited: Set[str], flow: Flow[Any]) -> None:
|
||||
"""Perform depth-first search to populate the ancestors dictionary.
|
||||
|
||||
Helper function for build_ancestor_dict that recursively traverses
|
||||
@@ -182,7 +182,7 @@ def dfs_ancestors(node: str, ancestors: Dict[str, Set[str]],
|
||||
dfs_ancestors(listener_name, ancestors, visited, flow)
|
||||
|
||||
|
||||
def build_parent_children_dict(flow: Flow) -> Dict[str, List[str]]:
|
||||
def build_parent_children_dict(flow: Flow[Any]) -> Dict[str, List[str]]:
|
||||
"""Build a dictionary mapping each node to its list of child nodes.
|
||||
|
||||
Maps both regular trigger methods to their listeners and router
|
||||
@@ -196,7 +196,7 @@ def build_parent_children_dict(flow: Flow) -> Dict[str, List[str]]:
|
||||
dict[str, list[str]]: Dictionary mapping each method name to a
|
||||
sorted list of its child method names
|
||||
"""
|
||||
parent_children = {}
|
||||
parent_children: Dict[str, List[str]] = {}
|
||||
|
||||
# Map listeners to their trigger methods
|
||||
for listener_name, (_, trigger_methods) in flow._listeners.items():
|
||||
|
||||
@@ -2,8 +2,11 @@ import ast
|
||||
import inspect
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, Tuple
|
||||
from typing import Any, Callable, Dict, Optional, Tuple, cast
|
||||
|
||||
from pyvis.network import Network
|
||||
|
||||
from crewai.flow.flow import Flow
|
||||
from .core_flow_utils import is_ancestor
|
||||
from .flow_visual_utils import (
|
||||
build_ancestor_dict,
|
||||
@@ -13,7 +16,7 @@ from .flow_visual_utils import (
|
||||
from .path_utils import safe_path_join, validate_file_path
|
||||
|
||||
|
||||
def method_calls_crew(method: callable) -> bool:
|
||||
def method_calls_crew(method: Callable[..., Any]) -> bool:
|
||||
"""Check if the method contains a .crew() call in its implementation.
|
||||
|
||||
Analyzes the method's source code using AST to detect if it makes any
|
||||
@@ -65,7 +68,7 @@ def method_calls_crew(method: callable) -> bool:
|
||||
return visitor.found
|
||||
|
||||
|
||||
def add_nodes_to_network(net: object, flow: object,
|
||||
def add_nodes_to_network(net: Network, flow: Flow[Any],
|
||||
node_positions: Dict[str, Tuple[float, float]],
|
||||
node_styles: Dict[str, dict],
|
||||
output_dir: Optional[str] = None) -> None:
|
||||
@@ -169,8 +172,8 @@ def add_nodes_to_network(net: object, flow: object,
|
||||
)
|
||||
|
||||
|
||||
def compute_positions(flow: object, node_levels: dict[str, int],
|
||||
y_spacing: float = 150, x_spacing: float = 150) -> dict[str, tuple[float, float]]:
|
||||
def compute_positions(flow: Flow[Any], node_levels: Dict[str, int],
|
||||
y_spacing: float = 150, x_spacing: float = 150) -> Dict[str, Tuple[float, float]]:
|
||||
if not hasattr(flow, '_methods'):
|
||||
raise ValueError("Invalid flow object: missing '_methods' attribute")
|
||||
if not isinstance(node_levels, dict):
|
||||
@@ -218,7 +221,7 @@ def compute_positions(flow: object, node_levels: dict[str, int],
|
||||
return node_positions
|
||||
|
||||
|
||||
def add_edges(net: object, flow: object,
|
||||
def add_edges(net: Network, flow: Flow[Any],
|
||||
node_positions: Dict[str, Tuple[float, float]],
|
||||
colors: Dict[str, str],
|
||||
asset_dir: Optional[str] = None) -> None:
|
||||
|
||||
Reference in New Issue
Block a user