Mirror of https://github.com/crewAIInc/crewAI.git, synced 2025-12-15 11:58:31 +00:00
fix: ensure fuzzy returns are more strict, show type warning
@@ -17,6 +17,7 @@ from __future__ import annotations

import ast
from collections import defaultdict, deque
from enum import Enum
import inspect
import textwrap
from typing import TYPE_CHECKING, Any
@@ -40,11 +41,123 @@ if TYPE_CHECKING:
_printer = Printer()


def _extract_string_literals_from_type_annotation(
    node: ast.expr,
    function_globals: dict[str, Any] | None = None,
) -> list[str]:
    """Extract string literals from a type annotation AST node.

    Handles:
    - Literal["a", "b", "c"]
    - "a" | "b" | "c" (union of string literals)
    - Just "a" (single string constant annotation)
    - Enum types with string values (e.g., class MyEnum(str, Enum))

    Args:
        node: The AST node representing a type annotation.
        function_globals: The globals dict from the function, used to resolve Enum types.

    Returns:
        List of string literals found in the annotation.
    """

    strings: list[str] = []

    if isinstance(node, ast.Constant) and isinstance(node.value, str):
        strings.append(node.value)

    elif isinstance(node, ast.Name) and function_globals:
        enum_class = function_globals.get(node.id)
        if (
            enum_class is not None
            and isinstance(enum_class, type)
            and issubclass(enum_class, Enum)
        ):
            strings.extend(
                member.value for member in enum_class if isinstance(member.value, str)
            )

    elif isinstance(node, ast.Attribute) and function_globals:
        try:
            if isinstance(node.value, ast.Name):
                module = function_globals.get(node.value.id)
                if module is not None:
                    enum_class = getattr(module, node.attr, None)
                    if (
                        enum_class is not None
                        and isinstance(enum_class, type)
                        and issubclass(enum_class, Enum)
                    ):
                        strings.extend(
                            member.value
                            for member in enum_class
                            if isinstance(member.value, str)
                        )
        except (AttributeError, TypeError):
            pass

    elif isinstance(node, ast.Subscript):
        is_literal = False
        if isinstance(node.value, ast.Name) and node.value.id == "Literal":
            is_literal = True
        elif isinstance(node.value, ast.Attribute) and node.value.attr == "Literal":
            is_literal = True

        if is_literal:
            if isinstance(node.slice, ast.Tuple):
                strings.extend(
                    elt.value
                    for elt in node.slice.elts
                    if isinstance(elt, ast.Constant) and isinstance(elt.value, str)
                )
            elif isinstance(node.slice, ast.Constant) and isinstance(
                node.slice.value, str
            ):
                strings.append(node.slice.value)

    elif isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr):
        strings.extend(
            _extract_string_literals_from_type_annotation(node.left, function_globals)
        )
        strings.extend(
            _extract_string_literals_from_type_annotation(node.right, function_globals)
        )

    return strings
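A minimal usage sketch for the helper above (not part of the commit): it parses a hypothetical router signature and shows which strings the extractor should recover from a Literal/Enum union annotation. The Color enum and the parsed source string are made up for illustration.

import ast
from enum import Enum


class Color(str, Enum):
    RED = "red"
    BLUE = "blue"


# Only the signature is parsed; the annotation is never evaluated, so the
# names inside it do not need to be importable here.
source = 'def route() -> Literal["a", "b"] | Color: ...'
func_def = ast.parse(source).body[0]
assert isinstance(func_def, ast.FunctionDef)

values = _extract_string_literals_from_type_annotation(
    func_def.returns, function_globals={"Color": Color}
)
# Expected: ["a", "b", "red", "blue"] (Literal members first, then Enum values).
assert values == ["a", "b", "red", "blue"]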


def _unwrap_function(function: Any) -> Any:
    """Unwrap a function to get the original function with correct globals.

    Flow methods are wrapped by decorators like @router, @listen, etc.
    This function unwraps them to get the original function which has
    the correct __globals__ for resolving type annotations like Enums.

    Args:
        function: The potentially wrapped function.

    Returns:
        The unwrapped original function.
    """
    if hasattr(function, "__func__"):
        function = function.__func__

    if hasattr(function, "__wrapped__"):
        wrapped = function.__wrapped__
        if hasattr(wrapped, "unwrap"):
            return wrapped.unwrap()
        return wrapped

    return function
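A short sketch (not part of the commit) of why the unwrapping matters: a functools.wraps-based decorator, fake_listen, stands in for the real flow decorators, and the helper is expected to recover the undecorated function, whose __globals__ can then resolve names used in its return annotation.

import functools


def fake_listen(func):
    """Hypothetical stand-in for a flow decorator such as @listen or @router."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)

    return wrapper


class Demo:
    @fake_listen
    def route(self):
        return "done"


bound = Demo().route
original = _unwrap_function(bound)
# __func__ strips the bound-method layer, __wrapped__ (set by functools.wraps)
# strips the decorator, leaving the original function object.
assert original is Demo.route.__wrapped__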


def get_possible_return_constants(function: Any) -> list[str] | None:
    """Extract possible string return values from a function using AST parsing.

    This function analyzes the source code of a router method to identify
    all possible string values it might return. It handles:
    - Return type annotations: -> Literal["a", "b"] or -> "a" | "b" | "c"
    - Enum type annotations: -> MyEnum (extracts string values from members)
    - Direct string literals: return "value"
    - Variable assignments: x = "value"; return x
    - Dictionary lookups: d = {"k": "v"}; return d[key]
@@ -57,6 +170,8 @@ def get_possible_return_constants(function: Any) -> list[str] | None:
    Returns:
        List of possible string return values, or None if analysis fails.
    """
    unwrapped = _unwrap_function(function)

    try:
        source = inspect.getsource(function)
    except OSError:
@@ -97,6 +212,17 @@ def get_possible_return_constants(function: Any) -> list[str] | None:
        return None

    return_values: set[str] = set()

    function_globals = getattr(unwrapped, "__globals__", None)

    for node in ast.walk(code_ast):
        if isinstance(node, ast.FunctionDef):
            if node.returns:
                annotation_values = _extract_string_literals_from_type_annotation(
                    node.returns, function_globals
                )
                return_values.update(annotation_values)
            break  # Only process the first function definition
    dict_definitions: dict[str, list[str]] = {}
    variable_values: dict[str, list[str]] = {}
    state_attribute_values: dict[str, list[str]] = {}
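The docstring above lists the supported return forms; the sketch below (not part of the commit) shows what the analysis is expected to recover for an annotated router. It assumes get_possible_return_constants is importable from crewai.flow.utils, that the snippet lives in a file so inspect.getsource can find it, and Outcome/annotated_router are made-up names.

from enum import Enum
from typing import Literal

from crewai.flow.utils import get_possible_return_constants


class Outcome(str, Enum):
    APPROVED = "approved"
    REJECTED = "rejected"


def annotated_router(self) -> Literal["retry"] | Outcome:
    return "retry"


paths = get_possible_return_constants(annotated_router)
# The annotation alone should yield "retry", "approved", and "rejected";
# ordering and duplicates are implementation details, so compare as a set.
assert paths is not None
assert {"retry", "approved", "rejected"} <= set(paths)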

@@ -3,13 +3,13 @@
from __future__ import annotations

from collections import defaultdict
from collections.abc import Iterable
import inspect
import logging
from typing import TYPE_CHECKING, Any

from crewai.flow.constants import AND_CONDITION, OR_CONDITION
from crewai.flow.flow_wrappers import FlowCondition
from crewai.flow.types import FlowMethodName, FlowRouteName
from crewai.flow.types import FlowMethodName
from crewai.flow.utils import (
    is_flow_condition_dict,
    is_simple_flow_condition,
@@ -18,6 +18,9 @@ from crewai.flow.visualization.schema import extract_method_signature
from crewai.flow.visualization.types import FlowStructure, NodeMetadata, StructureEdge


logger = logging.getLogger(__name__)


if TYPE_CHECKING:
    from crewai.flow.flow import Flow

@@ -346,34 +349,43 @@ def build_flow_structure(flow: Flow[Any]) -> FlowStructure:
        if trigger_method in nodes
    )

    all_string_triggers: set[str] = set()
    for condition_data in flow._listeners.values():
        if is_simple_flow_condition(condition_data):
            _, methods = condition_data
            for m in methods:
                if str(m) not in nodes:  # It's a string trigger, not a method name
                    all_string_triggers.add(str(m))
        elif is_flow_condition_dict(condition_data):
            for trigger in _extract_direct_or_triggers(condition_data):
                if trigger not in nodes:
                    all_string_triggers.add(trigger)

    all_router_outputs: set[str] = set()
    for router_method_name in router_methods:
        if router_method_name not in flow._router_paths:
            flow._router_paths[FlowMethodName(router_method_name)] = []

        inferred_paths: Iterable[FlowMethodName | FlowRouteName] = set(
            flow._router_paths.get(FlowMethodName(router_method_name), [])
        )
        current_paths = flow._router_paths.get(FlowMethodName(router_method_name), [])
        if current_paths and router_method_name in nodes:
            nodes[router_method_name]["router_paths"] = [str(p) for p in current_paths]
        all_router_outputs.update(str(p) for p in current_paths)

        for condition_data in flow._listeners.values():
            trigger_strings: list[str] = []

            if is_simple_flow_condition(condition_data):
                _, methods = condition_data
                trigger_strings = [str(m) for m in methods]
            elif is_flow_condition_dict(condition_data):
                trigger_strings = _extract_direct_or_triggers(condition_data)

            for trigger_str in trigger_strings:
                if trigger_str not in nodes:
                    # This is likely a router path output
                    inferred_paths.add(trigger_str)  # type: ignore[attr-defined]

        if inferred_paths:
            flow._router_paths[FlowMethodName(router_method_name)] = list(
                inferred_paths  # type: ignore[arg-type]
            )
            if router_method_name in nodes:
                nodes[router_method_name]["router_paths"] = list(inferred_paths)

        if not current_paths:
            logger.warning(
                f"Could not determine return paths for router '{router_method_name}'. "
                f"Add a return type annotation like "
                f"'-> Literal[\"path1\", \"path2\"]' or '-> YourEnum' "
                f"to enable proper flow visualization."
            )

    orphaned_triggers = all_string_triggers - all_router_outputs
    if orphaned_triggers:
        logger.error(
            f"Found listeners waiting for triggers {orphaned_triggers} "
            f"but no router outputs these values explicitly. "
            f"If your router returns a non-static value, check that your router has proper return type annotations."
        )

    for router_method_name in router_methods:
        if router_method_name not in flow._router_paths:
@@ -383,6 +395,9 @@ def build_flow_structure(flow: Flow[Any]) -> FlowStructure:

        for path in router_paths:
            for listener_name, condition_data in flow._listeners.items():
                if listener_name == router_method_name:
                    continue

                trigger_strings_from_cond: list[str] = []

                if is_simple_flow_condition(condition_data):

@@ -415,4 +415,256 @@ def test_router_paths_not_in_and_conditions():

    assert "step_1" in targets
    assert "step_3_or" in targets
    assert "step_2_and" not in targets
    assert "step_2_and" not in targets


def test_chained_routers_no_self_loops():
    """Test that chained routers don't create self-referencing edges.

    This tests the bug where routers with string triggers (like 'auth', 'exp')
    would incorrectly create edges to themselves when another router outputs
    those strings.
    """

    class ChainedRouterFlow(Flow):
        """Flow with multiple chained routers using string outputs."""

        @start()
        def entrance(self):
            return "started"

        @router(entrance)
        def session_in_cache(self):
            return "exp"

        @router("exp")
        def check_exp(self):
            return "auth"

        @router("auth")
        def call_ai_auth(self):
            return "action"

        @listen("action")
        def forward_to_action(self):
            return "done"

        @listen("authenticate")
        def forward_to_authenticate(self):
            return "need_auth"

    flow = ChainedRouterFlow()
    structure = build_flow_structure(flow)

    # Check that no self-loops exist
    for edge in structure["edges"]:
        assert edge["source"] != edge["target"], (
            f"Self-loop detected: {edge['source']} -> {edge['target']}"
        )

    # Verify correct connections
    router_edges = [edge for edge in structure["edges"] if edge["is_router_path"]]

    # session_in_cache -> check_exp (via 'exp')
    exp_edges = [
        edge
        for edge in router_edges
        if edge["router_path_label"] == "exp" and edge["source"] == "session_in_cache"
    ]
    assert len(exp_edges) == 1
    assert exp_edges[0]["target"] == "check_exp"

    # check_exp -> call_ai_auth (via 'auth')
    auth_edges = [
        edge
        for edge in router_edges
        if edge["router_path_label"] == "auth" and edge["source"] == "check_exp"
    ]
    assert len(auth_edges) == 1
    assert auth_edges[0]["target"] == "call_ai_auth"

    # call_ai_auth -> forward_to_action (via 'action')
    action_edges = [
        edge
        for edge in router_edges
        if edge["router_path_label"] == "action" and edge["source"] == "call_ai_auth"
    ]
    assert len(action_edges) == 1
    assert action_edges[0]["target"] == "forward_to_action"


def test_routers_with_shared_output_strings():
    """Test that routers with shared output strings don't create incorrect edges.

    This tests a scenario where multiple routers can output the same string,
    ensuring the visualization only creates edges for the router that actually
    outputs the string, not all routers.
    """

    class SharedOutputRouterFlow(Flow):
        """Flow where multiple routers can output 'auth'."""

        @start()
        def start(self):
            return "started"

        @router(start)
        def router_a(self):
            # This router can output 'auth' or 'skip'
            return "auth"

        @router("auth")
        def router_b(self):
            # This router listens to 'auth' but outputs 'done'
            return "done"

        @listen("done")
        def finalize(self):
            return "complete"

        @listen("skip")
        def handle_skip(self):
            return "skipped"

    flow = SharedOutputRouterFlow()
    structure = build_flow_structure(flow)

    # Check no self-loops
    for edge in structure["edges"]:
        assert edge["source"] != edge["target"], (
            f"Self-loop detected: {edge['source']} -> {edge['target']}"
        )

    # router_a should connect to router_b via 'auth'
    router_edges = [edge for edge in structure["edges"] if edge["is_router_path"]]
    auth_from_a = [
        edge
        for edge in router_edges
        if edge["source"] == "router_a" and edge["router_path_label"] == "auth"
    ]
    assert len(auth_from_a) == 1
    assert auth_from_a[0]["target"] == "router_b"

    # router_b should connect to finalize via 'done'
    done_from_b = [
        edge
        for edge in router_edges
        if edge["source"] == "router_b" and edge["router_path_label"] == "done"
    ]
    assert len(done_from_b) == 1
    assert done_from_b[0]["target"] == "finalize"


def test_warning_for_router_without_paths(caplog):
    """Test that a warning is logged when a router has no determinable paths."""
    import logging

    class RouterWithoutPathsFlow(Flow):
        """Flow with a router that returns a dynamic value."""

        @start()
        def begin(self):
            return "started"

        @router(begin)
        def dynamic_router(self):
            # Returns a variable that can't be statically analyzed
            import random
            return random.choice(["path_a", "path_b"])

        @listen("path_a")
        def handle_a(self):
            return "a"

        @listen("path_b")
        def handle_b(self):
            return "b"

    flow = RouterWithoutPathsFlow()

    with caplog.at_level(logging.WARNING):
        build_flow_structure(flow)

    # Check that warning was logged for the router
    assert any(
        "Could not determine return paths for router 'dynamic_router'" in record.message
        for record in caplog.records
    )

    # Check that error was logged for orphaned triggers
    assert any(
        "Found listeners waiting for triggers" in record.message
        for record in caplog.records
    )


def test_warning_for_orphaned_listeners(caplog):
    """Test that an error is logged when listeners wait for triggers no router outputs."""
    import logging
    from typing import Literal

    class OrphanedListenerFlow(Flow):
        """Flow where a listener waits for a trigger that no router outputs."""

        @start()
        def begin(self):
            return "started"

        @router(begin)
        def my_router(self) -> Literal["option_a", "option_b"]:
            return "option_a"

        @listen("option_a")
        def handle_a(self):
            return "a"

        @listen("option_c")  # This trigger is never output by any router
        def handle_orphan(self):
            return "orphan"

    flow = OrphanedListenerFlow()

    with caplog.at_level(logging.ERROR):
        build_flow_structure(flow)

    # Check that error was logged for orphaned trigger
    assert any(
        "Found listeners waiting for triggers" in record.message
        and "option_c" in record.message
        for record in caplog.records
    )


def test_no_warning_for_properly_typed_router(caplog):
    """Test that no warning is logged when router has proper type annotations."""
    import logging
    from typing import Literal

    class ProperlyTypedRouterFlow(Flow):
        """Flow with properly typed router."""

        @start()
        def begin(self):
            return "started"

        @router(begin)
        def typed_router(self) -> Literal["path_a", "path_b"]:
            return "path_a"

        @listen("path_a")
        def handle_a(self):
            return "a"

        @listen("path_b")
        def handle_b(self):
            return "b"

    flow = ProperlyTypedRouterFlow()

    with caplog.at_level(logging.WARNING):
        build_flow_structure(flow)

    # No warnings should be logged
    warning_messages = [r.message for r in caplog.records if r.levelno >= logging.WARNING]
    assert not any("Could not determine return paths" in msg for msg in warning_messages)
    assert not any("Found listeners waiting for triggers" in msg for msg in warning_messages)