Allow @router() as start method of a flow (#6288)
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Vulnerability Scan / pip-audit (push) Has been cancelled
Nightly Canary Release / Check for new commits (push) Has been cancelled
Nightly Canary Release / Build nightly packages (push) Has been cancelled
Nightly Canary Release / Publish nightly to PyPI (push) Has been cancelled

This commit fixes a bug where a router method could not be the start
method of a flow.

This is useful when you want to route against the initial state, or even
stack two routers.
This commit is contained in:
Vinicius Brasil
2026-06-22 14:04:45 -07:00
committed by GitHub
parent 4cbfbdb232
commit 0391febc6c
8 changed files with 150 additions and 36 deletions

View File

@@ -8,8 +8,8 @@ from crewai.flow.dsl._types import FlowMethodDecorator, FlowTrigger
from crewai.flow.dsl._utils import (
P,
R,
_merge_flow_method_definition,
_method_action,
_set_flow_method_definition,
)
from crewai.flow.flow_definition import FlowMethodDefinition
from crewai.flow.flow_wrappers import ListenMethod
@@ -45,7 +45,7 @@ def listen(condition: FlowTrigger) -> FlowMethodDecorator:
def decorator(func: Callable[P, R]) -> ListenMethod[P, R]:
wrapper = ListenMethod(func)
_set_flow_method_definition(
_merge_flow_method_definition(
wrapper,
FlowMethodDefinition(
do=_method_action(func),

View File

@@ -19,8 +19,8 @@ from crewai.flow.dsl._types import FlowMethodDecorator, FlowTrigger
from crewai.flow.dsl._utils import (
P,
R,
_merge_flow_method_definition,
_method_action,
_set_flow_method_definition,
)
from crewai.flow.flow_definition import FlowMethodDefinition
from crewai.flow.flow_wrappers import RouterMethod
@@ -95,7 +95,7 @@ def _normalize_router_emit(value: Sequence[Any] | str) -> list[str]:
def router(
condition: FlowTrigger,
condition: FlowTrigger | None = None,
*,
emit: Sequence[str] | str | None = None,
) -> FlowMethodDecorator:
@@ -107,6 +107,7 @@ def router(
Args:
condition: Specifies when the router should execute. Can be:
- None: no listen trigger, used when stacking with @start() or @listen()
- str: Route label or method name that triggers this router
- FlowCondition: Result from or_() or and_(), including nested conditions
- Flow method reference: A method whose completion triggers this router
@@ -146,14 +147,17 @@ def router(
else:
router_events = _get_router_return_events(func) or []
_set_flow_method_definition(
method_definition_kwargs: dict[str, Any] = {
"do": _method_action(func),
"router": True,
"emit": router_events or None,
}
if condition is not None:
method_definition_kwargs["listen"] = _to_definition_condition(condition)
_merge_flow_method_definition(
wrapper,
FlowMethodDefinition(
do=_method_action(func),
listen=_to_definition_condition(condition),
router=True,
emit=router_events or None,
),
FlowMethodDefinition(**method_definition_kwargs),
)
return wrapper

View File

@@ -8,8 +8,8 @@ from crewai.flow.dsl._types import FlowMethodDecorator, FlowTrigger
from crewai.flow.dsl._utils import (
P,
R,
_merge_flow_method_definition,
_method_action,
_set_flow_method_definition,
)
from crewai.flow.flow_definition import FlowMethodDefinition
from crewai.flow.flow_wrappers import StartMethod
@@ -54,7 +54,7 @@ def start(
def decorator(func: Callable[P, R]) -> StartMethod[P, R]:
wrapper = StartMethod(func)
_set_flow_method_definition(
_merge_flow_method_definition(
wrapper,
FlowMethodDefinition(
do=_method_action(func),

View File

@@ -106,6 +106,25 @@ def _get_flow_method_definition(method: Any) -> FlowMethodDefinition | None:
return None
def _merge_flow_method_definition(
wrapper: FlowMethod[P, R],
definition: FlowMethodDefinition,
) -> None:
existing = _get_flow_method_definition(wrapper)
if existing is None:
_set_flow_method_definition(wrapper, definition)
return
updates = {
field_name: getattr(definition, field_name)
for field_name in definition.model_fields_set
}
_set_flow_method_definition(
wrapper,
existing.model_copy(deep=True, update=updates),
)
def _is_json_serializable(value: Any) -> bool:
try:
json.dumps(value)

View File

@@ -870,14 +870,6 @@ def _validate_action_cel(
def log_flow_definition_issues(definition: FlowDefinition) -> None:
for method_name, method in definition.methods.items():
path = f"methods.{method_name}"
if method.router and not method.is_start and method.listen is None:
_log_flow_definition_issue(
definition.name,
code="router_without_trigger",
severity="error",
path=path,
message="router: true requires either start or listen",
)
if method.emit and not method.router:
_log_flow_definition_issue(
definition.name,

View File

@@ -3007,6 +3007,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
"""
# First, handle routers repeatedly until no router triggers anymore
router_results = []
router_result_payloads: dict[str, Any] = {}
router_result_to_feedback: dict[
str, Any
] = {} # Map outcome -> HumanFeedbackResult
@@ -3044,6 +3045,11 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
router_result_str = str(router_result)
router_result_event = FlowMethodName(router_result_str)
router_results.append(router_result_event)
router_result_payloads[router_result_str] = (
self.last_human_feedback
if self.last_human_feedback is not None
else router_result
)
if self.last_human_feedback is not None:
router_result_to_feedback[router_result_str] = (
@@ -3064,7 +3070,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
current_trigger, router_only=False
)
if listeners_triggered:
listener_result = router_result_to_feedback.get(
listener_result = router_result_payloads.get(
str(current_trigger), result
)
racing_group = self._get_racing_group_for_listeners(

View File

@@ -386,6 +386,54 @@ def test_router_runtime_uses_flow_definition_without_legacy_router_metadata():
assert execution_order == ["begin", "decide", "handle_left"]
def test_start_router_runtime_routes_public_dsl_return_value():
execution_order = []
class StartRouterFlow(Flow):
@start()
@router(emit=["continue"])
def decide(self):
execution_order.append("decide")
return "continue"
@listen("continue")
def handle_continue(self, result):
execution_order.append(f"handle_continue:{result}")
return "done"
assert StartRouterFlow().kickoff() == "done"
assert execution_order == ["decide", "handle_continue:continue"]
def test_start_router_runtime_chains_to_stacked_listener_router():
execution_order = []
class ChainedStartRouterFlow(Flow):
@start()
@router(emit=["approved", "not_approved"])
def first_router(self):
execution_order.append("first_router")
return "approved"
@listen("approved")
@router(emit=["second_approval", "not_approved"])
def second_router(self):
execution_order.append("second_router")
return "second_approval"
@listen("second_approval")
def handle_second_approval(self, result):
execution_order.append(f"handle_second_approval:{result}")
return "done"
assert ChainedStartRouterFlow().kickoff() == "done"
assert execution_order == [
"first_router",
"second_router",
"handle_second_approval:second_approval",
]
def test_router_falsy_result_emits_runtime_event():
execution_order = []

View File

@@ -565,6 +565,54 @@ def test_flow_definition_classifies_start_router_from_human_feedback_emit():
assert entry_point.emit is None
def test_flow_definition_classifies_public_dsl_start_router():
class StartRouterFlow(Flow):
@start()
@router(emit=["continue", "stop"])
def entry_point(self):
return "continue"
@router(emit=["resume"])
@start()
def alternate_entry_point(self):
return "resume"
entry_point = StartRouterFlow.flow_definition().methods["entry_point"]
alternate_entry_point = StartRouterFlow.flow_definition().methods[
"alternate_entry_point"
]
assert entry_point.is_start is True
assert entry_point.router is True
assert entry_point.listen is None
assert entry_point.emit == ["continue", "stop"]
assert alternate_entry_point.is_start is True
assert alternate_entry_point.router is True
assert alternate_entry_point.listen is None
assert alternate_entry_point.emit == ["resume"]
def test_flow_definition_merges_stacked_listen_router():
class ChainedRouterFlow(Flow):
@start()
@router(emit=["approved", "not_approved"])
def first_router(self):
return "approved"
@listen("approved")
@router(emit=["second_approval", "not_approved"])
def second_router(self):
return "second_approval"
methods = ChainedRouterFlow.flow_definition().methods
assert methods["first_router"].is_start is True
assert methods["first_router"].listen is None
assert methods["second_router"].router is True
assert methods["second_router"].listen == "approved"
assert methods["second_router"].emit == ["second_approval", "not_approved"]
def test_flow_definition_round_trips_json_and_yaml():
class RoundTripFlow(Flow):
@start()
@@ -883,7 +931,7 @@ def test_flow_definition_ignores_legacy_diagnostics_loaded_from_contract():
assert "diagnostics" not in definition.to_dict()
def test_router_start_false_without_listen_logs_missing_trigger(caplog):
def test_router_start_false_without_listen_is_allowed(caplog):
caplog.set_level(logging.ERROR, logger="crewai.flow.flow_definition")
flow_definition.FlowDefinition.from_dict(
@@ -901,12 +949,7 @@ def test_router_start_false_without_listen_logs_missing_trigger(caplog):
}
)
assert any(
record.levelno == logging.ERROR
and "router_without_trigger" in record.message
and "methods.decision" in record.message
for record in caplog.records
)
assert not caplog.records
def test_router_human_feedback_preserves_existing_router_metadata():
@@ -1048,7 +1091,7 @@ def test_flow_definition_cache_is_not_reused_by_subclasses():
assert set(child_definition.methods) == {"child_step"}
def test_flow_definition_logs_validation_issues_when_loaded_from_contract(caplog):
def test_flow_definition_allows_router_without_trigger(caplog):
caplog.set_level(logging.WARNING, logger="crewai.flow.flow_definition")
flow_definition.FlowDefinition.from_dict(
@@ -1065,9 +1108,11 @@ def test_flow_definition_logs_validation_issues_when_loaded_from_contract(caplog
}
)
assert any(
record.levelno == logging.ERROR
and "LoadedFlow" in record.message
and "router_without_trigger" in record.message
for record in caplog.records
)
class StandaloneRouterFlow(Flow):
@router(emit=["continue"])
def decision(self):
return "continue"
StandaloneRouterFlow.flow_definition()
assert not caplog.records