fix: ensure fuzzy returns are more strict, show type warning

This commit is contained in:
Greyson LaLonde
2025-11-24 17:35:12 -05:00
committed by GitHub
parent d2b9c54931
commit b049b73f2e
3 changed files with 418 additions and 25 deletions

View File

@@ -17,6 +17,7 @@ from __future__ import annotations
import ast
from collections import defaultdict, deque
from enum import Enum
import inspect
import textwrap
from typing import TYPE_CHECKING, Any
@@ -40,11 +41,123 @@ if TYPE_CHECKING:
_printer = Printer()
def _extract_string_literals_from_type_annotation(
node: ast.expr,
function_globals: dict[str, Any] | None = None,
) -> list[str]:
"""Extract string literals from a type annotation AST node.
Handles:
- Literal["a", "b", "c"]
- "a" | "b" | "c" (union of string literals)
- Just "a" (single string constant annotation)
- Enum types with string values (e.g., class MyEnum(str, Enum))
Args:
node: The AST node representing a type annotation.
function_globals: The globals dict from the function, used to resolve Enum types.
Returns:
List of string literals found in the annotation.
"""
strings: list[str] = []
if isinstance(node, ast.Constant) and isinstance(node.value, str):
strings.append(node.value)
elif isinstance(node, ast.Name) and function_globals:
enum_class = function_globals.get(node.id)
if (
enum_class is not None
and isinstance(enum_class, type)
and issubclass(enum_class, Enum)
):
strings.extend(
member.value for member in enum_class if isinstance(member.value, str)
)
elif isinstance(node, ast.Attribute) and function_globals:
try:
if isinstance(node.value, ast.Name):
module = function_globals.get(node.value.id)
if module is not None:
enum_class = getattr(module, node.attr, None)
if (
enum_class is not None
and isinstance(enum_class, type)
and issubclass(enum_class, Enum)
):
strings.extend(
member.value
for member in enum_class
if isinstance(member.value, str)
)
except (AttributeError, TypeError):
pass
elif isinstance(node, ast.Subscript):
is_literal = False
if isinstance(node.value, ast.Name) and node.value.id == "Literal":
is_literal = True
elif isinstance(node.value, ast.Attribute) and node.value.attr == "Literal":
is_literal = True
if is_literal:
if isinstance(node.slice, ast.Tuple):
strings.extend(
elt.value
for elt in node.slice.elts
if isinstance(elt, ast.Constant) and isinstance(elt.value, str)
)
elif isinstance(node.slice, ast.Constant) and isinstance(
node.slice.value, str
):
strings.append(node.slice.value)
elif isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr):
strings.extend(
_extract_string_literals_from_type_annotation(node.left, function_globals)
)
strings.extend(
_extract_string_literals_from_type_annotation(node.right, function_globals)
)
return strings
def _unwrap_function(function: Any) -> Any:
"""Unwrap a function to get the original function with correct globals.
Flow methods are wrapped by decorators like @router, @listen, etc.
This function unwraps them to get the original function which has
the correct __globals__ for resolving type annotations like Enums.
Args:
function: The potentially wrapped function.
Returns:
The unwrapped original function.
"""
if hasattr(function, "__func__"):
function = function.__func__
if hasattr(function, "__wrapped__"):
wrapped = function.__wrapped__
if hasattr(wrapped, "unwrap"):
return wrapped.unwrap()
return wrapped
return function
def get_possible_return_constants(function: Any) -> list[str] | None:
"""Extract possible string return values from a function using AST parsing.
This function analyzes the source code of a router method to identify
all possible string values it might return. It handles:
- Return type annotations: -> Literal["a", "b"] or -> "a" | "b" | "c"
- Enum type annotations: -> MyEnum (extracts string values from members)
- Direct string literals: return "value"
- Variable assignments: x = "value"; return x
- Dictionary lookups: d = {"k": "v"}; return d[key]
@@ -57,6 +170,8 @@ def get_possible_return_constants(function: Any) -> list[str] | None:
Returns:
List of possible string return values, or None if analysis fails.
"""
unwrapped = _unwrap_function(function)
try:
source = inspect.getsource(function)
except OSError:
@@ -97,6 +212,17 @@ def get_possible_return_constants(function: Any) -> list[str] | None:
return None
return_values: set[str] = set()
function_globals = getattr(unwrapped, "__globals__", None)
for node in ast.walk(code_ast):
if isinstance(node, ast.FunctionDef):
if node.returns:
annotation_values = _extract_string_literals_from_type_annotation(
node.returns, function_globals
)
return_values.update(annotation_values)
break # Only process the first function definition
dict_definitions: dict[str, list[str]] = {}
variable_values: dict[str, list[str]] = {}
state_attribute_values: dict[str, list[str]] = {}

View File

@@ -3,13 +3,13 @@
from __future__ import annotations
from collections import defaultdict
from collections.abc import Iterable
import inspect
import logging
from typing import TYPE_CHECKING, Any
from crewai.flow.constants import AND_CONDITION, OR_CONDITION
from crewai.flow.flow_wrappers import FlowCondition
from crewai.flow.types import FlowMethodName, FlowRouteName
from crewai.flow.types import FlowMethodName
from crewai.flow.utils import (
is_flow_condition_dict,
is_simple_flow_condition,
@@ -18,6 +18,9 @@ from crewai.flow.visualization.schema import extract_method_signature
from crewai.flow.visualization.types import FlowStructure, NodeMetadata, StructureEdge
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from crewai.flow.flow import Flow
@@ -346,34 +349,43 @@ def build_flow_structure(flow: Flow[Any]) -> FlowStructure:
if trigger_method in nodes
)
all_string_triggers: set[str] = set()
for condition_data in flow._listeners.values():
if is_simple_flow_condition(condition_data):
_, methods = condition_data
for m in methods:
if str(m) not in nodes: # It's a string trigger, not a method name
all_string_triggers.add(str(m))
elif is_flow_condition_dict(condition_data):
for trigger in _extract_direct_or_triggers(condition_data):
if trigger not in nodes:
all_string_triggers.add(trigger)
all_router_outputs: set[str] = set()
for router_method_name in router_methods:
if router_method_name not in flow._router_paths:
flow._router_paths[FlowMethodName(router_method_name)] = []
inferred_paths: Iterable[FlowMethodName | FlowRouteName] = set(
flow._router_paths.get(FlowMethodName(router_method_name), [])
)
current_paths = flow._router_paths.get(FlowMethodName(router_method_name), [])
if current_paths and router_method_name in nodes:
nodes[router_method_name]["router_paths"] = [str(p) for p in current_paths]
all_router_outputs.update(str(p) for p in current_paths)
for condition_data in flow._listeners.values():
trigger_strings: list[str] = []
if is_simple_flow_condition(condition_data):
_, methods = condition_data
trigger_strings = [str(m) for m in methods]
elif is_flow_condition_dict(condition_data):
trigger_strings = _extract_direct_or_triggers(condition_data)
for trigger_str in trigger_strings:
if trigger_str not in nodes:
# This is likely a router path output
inferred_paths.add(trigger_str) # type: ignore[attr-defined]
if inferred_paths:
flow._router_paths[FlowMethodName(router_method_name)] = list(
inferred_paths # type: ignore[arg-type]
if not current_paths:
logger.warning(
f"Could not determine return paths for router '{router_method_name}'. "
f"Add a return type annotation like "
f"'-> Literal[\"path1\", \"path2\"]' or '-> YourEnum' "
f"to enable proper flow visualization."
)
if router_method_name in nodes:
nodes[router_method_name]["router_paths"] = list(inferred_paths)
orphaned_triggers = all_string_triggers - all_router_outputs
if orphaned_triggers:
logger.error(
f"Found listeners waiting for triggers {orphaned_triggers} "
f"but no router outputs these values explicitly. "
f"If your router returns a non-static value, check that your router has proper return type annotations."
)
for router_method_name in router_methods:
if router_method_name not in flow._router_paths:
@@ -383,6 +395,9 @@ def build_flow_structure(flow: Flow[Any]) -> FlowStructure:
for path in router_paths:
for listener_name, condition_data in flow._listeners.items():
if listener_name == router_method_name:
continue
trigger_strings_from_cond: list[str] = []
if is_simple_flow_condition(condition_data):

View File

@@ -415,4 +415,256 @@ def test_router_paths_not_in_and_conditions():
assert "step_1" in targets
assert "step_3_or" in targets
assert "step_2_and" not in targets
assert "step_2_and" not in targets
def test_chained_routers_no_self_loops():
"""Test that chained routers don't create self-referencing edges.
This tests the bug where routers with string triggers (like 'auth', 'exp')
would incorrectly create edges to themselves when another router outputs
those strings.
"""
class ChainedRouterFlow(Flow):
"""Flow with multiple chained routers using string outputs."""
@start()
def entrance(self):
return "started"
@router(entrance)
def session_in_cache(self):
return "exp"
@router("exp")
def check_exp(self):
return "auth"
@router("auth")
def call_ai_auth(self):
return "action"
@listen("action")
def forward_to_action(self):
return "done"
@listen("authenticate")
def forward_to_authenticate(self):
return "need_auth"
flow = ChainedRouterFlow()
structure = build_flow_structure(flow)
# Check that no self-loops exist
for edge in structure["edges"]:
assert edge["source"] != edge["target"], (
f"Self-loop detected: {edge['source']} -> {edge['target']}"
)
# Verify correct connections
router_edges = [edge for edge in structure["edges"] if edge["is_router_path"]]
# session_in_cache -> check_exp (via 'exp')
exp_edges = [
edge
for edge in router_edges
if edge["router_path_label"] == "exp" and edge["source"] == "session_in_cache"
]
assert len(exp_edges) == 1
assert exp_edges[0]["target"] == "check_exp"
# check_exp -> call_ai_auth (via 'auth')
auth_edges = [
edge
for edge in router_edges
if edge["router_path_label"] == "auth" and edge["source"] == "check_exp"
]
assert len(auth_edges) == 1
assert auth_edges[0]["target"] == "call_ai_auth"
# call_ai_auth -> forward_to_action (via 'action')
action_edges = [
edge
for edge in router_edges
if edge["router_path_label"] == "action" and edge["source"] == "call_ai_auth"
]
assert len(action_edges) == 1
assert action_edges[0]["target"] == "forward_to_action"
def test_routers_with_shared_output_strings():
"""Test that routers with shared output strings don't create incorrect edges.
This tests a scenario where multiple routers can output the same string,
ensuring the visualization only creates edges for the router that actually
outputs the string, not all routers.
"""
class SharedOutputRouterFlow(Flow):
"""Flow where multiple routers can output 'auth'."""
@start()
def start(self):
return "started"
@router(start)
def router_a(self):
# This router can output 'auth' or 'skip'
return "auth"
@router("auth")
def router_b(self):
# This router listens to 'auth' but outputs 'done'
return "done"
@listen("done")
def finalize(self):
return "complete"
@listen("skip")
def handle_skip(self):
return "skipped"
flow = SharedOutputRouterFlow()
structure = build_flow_structure(flow)
# Check no self-loops
for edge in structure["edges"]:
assert edge["source"] != edge["target"], (
f"Self-loop detected: {edge['source']} -> {edge['target']}"
)
# router_a should connect to router_b via 'auth'
router_edges = [edge for edge in structure["edges"] if edge["is_router_path"]]
auth_from_a = [
edge
for edge in router_edges
if edge["source"] == "router_a" and edge["router_path_label"] == "auth"
]
assert len(auth_from_a) == 1
assert auth_from_a[0]["target"] == "router_b"
# router_b should connect to finalize via 'done'
done_from_b = [
edge
for edge in router_edges
if edge["source"] == "router_b" and edge["router_path_label"] == "done"
]
assert len(done_from_b) == 1
assert done_from_b[0]["target"] == "finalize"
def test_warning_for_router_without_paths(caplog):
"""Test that a warning is logged when a router has no determinable paths."""
import logging
class RouterWithoutPathsFlow(Flow):
"""Flow with a router that returns a dynamic value."""
@start()
def begin(self):
return "started"
@router(begin)
def dynamic_router(self):
# Returns a variable that can't be statically analyzed
import random
return random.choice(["path_a", "path_b"])
@listen("path_a")
def handle_a(self):
return "a"
@listen("path_b")
def handle_b(self):
return "b"
flow = RouterWithoutPathsFlow()
with caplog.at_level(logging.WARNING):
build_flow_structure(flow)
# Check that warning was logged for the router
assert any(
"Could not determine return paths for router 'dynamic_router'" in record.message
for record in caplog.records
)
# Check that error was logged for orphaned triggers
assert any(
"Found listeners waiting for triggers" in record.message
for record in caplog.records
)
def test_warning_for_orphaned_listeners(caplog):
"""Test that an error is logged when listeners wait for triggers no router outputs."""
import logging
from typing import Literal
class OrphanedListenerFlow(Flow):
"""Flow where a listener waits for a trigger that no router outputs."""
@start()
def begin(self):
return "started"
@router(begin)
def my_router(self) -> Literal["option_a", "option_b"]:
return "option_a"
@listen("option_a")
def handle_a(self):
return "a"
@listen("option_c") # This trigger is never output by any router
def handle_orphan(self):
return "orphan"
flow = OrphanedListenerFlow()
with caplog.at_level(logging.ERROR):
build_flow_structure(flow)
# Check that error was logged for orphaned trigger
assert any(
"Found listeners waiting for triggers" in record.message
and "option_c" in record.message
for record in caplog.records
)
def test_no_warning_for_properly_typed_router(caplog):
"""Test that no warning is logged when router has proper type annotations."""
import logging
from typing import Literal
class ProperlyTypedRouterFlow(Flow):
"""Flow with properly typed router."""
@start()
def begin(self):
return "started"
@router(begin)
def typed_router(self) -> Literal["path_a", "path_b"]:
return "path_a"
@listen("path_a")
def handle_a(self):
return "a"
@listen("path_b")
def handle_b(self):
return "b"
flow = ProperlyTypedRouterFlow()
with caplog.at_level(logging.WARNING):
build_flow_structure(flow)
# No warnings should be logged
warning_messages = [r.message for r in caplog.records if r.levelno >= logging.WARNING]
assert not any("Could not determine return paths" in msg for msg in warning_messages)
assert not any("Found listeners waiting for triggers" in msg for msg in warning_messages)