mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-07-02 05:38:12 +00:00
Allow @router() as start method of a flow (#6288)
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Vulnerability Scan / pip-audit (push) Has been cancelled
Nightly Canary Release / Check for new commits (push) Has been cancelled
Nightly Canary Release / Build nightly packages (push) Has been cancelled
Nightly Canary Release / Publish nightly to PyPI (push) Has been cancelled
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Vulnerability Scan / pip-audit (push) Has been cancelled
Nightly Canary Release / Check for new commits (push) Has been cancelled
Nightly Canary Release / Build nightly packages (push) Has been cancelled
Nightly Canary Release / Publish nightly to PyPI (push) Has been cancelled
This commit fixes a bug where a router method could not be the start method of a flow. This is useful when you want to route against the initial state, or even stack two routers.
This commit is contained in:
@@ -8,8 +8,8 @@ from crewai.flow.dsl._types import FlowMethodDecorator, FlowTrigger
|
||||
from crewai.flow.dsl._utils import (
|
||||
P,
|
||||
R,
|
||||
_merge_flow_method_definition,
|
||||
_method_action,
|
||||
_set_flow_method_definition,
|
||||
)
|
||||
from crewai.flow.flow_definition import FlowMethodDefinition
|
||||
from crewai.flow.flow_wrappers import ListenMethod
|
||||
@@ -45,7 +45,7 @@ def listen(condition: FlowTrigger) -> FlowMethodDecorator:
|
||||
def decorator(func: Callable[P, R]) -> ListenMethod[P, R]:
|
||||
wrapper = ListenMethod(func)
|
||||
|
||||
_set_flow_method_definition(
|
||||
_merge_flow_method_definition(
|
||||
wrapper,
|
||||
FlowMethodDefinition(
|
||||
do=_method_action(func),
|
||||
|
||||
@@ -19,8 +19,8 @@ from crewai.flow.dsl._types import FlowMethodDecorator, FlowTrigger
|
||||
from crewai.flow.dsl._utils import (
|
||||
P,
|
||||
R,
|
||||
_merge_flow_method_definition,
|
||||
_method_action,
|
||||
_set_flow_method_definition,
|
||||
)
|
||||
from crewai.flow.flow_definition import FlowMethodDefinition
|
||||
from crewai.flow.flow_wrappers import RouterMethod
|
||||
@@ -95,7 +95,7 @@ def _normalize_router_emit(value: Sequence[Any] | str) -> list[str]:
|
||||
|
||||
|
||||
def router(
|
||||
condition: FlowTrigger,
|
||||
condition: FlowTrigger | None = None,
|
||||
*,
|
||||
emit: Sequence[str] | str | None = None,
|
||||
) -> FlowMethodDecorator:
|
||||
@@ -107,6 +107,7 @@ def router(
|
||||
|
||||
Args:
|
||||
condition: Specifies when the router should execute. Can be:
|
||||
- None: no listen trigger, used when stacking with @start() or @listen()
|
||||
- str: Route label or method name that triggers this router
|
||||
- FlowCondition: Result from or_() or and_(), including nested conditions
|
||||
- Flow method reference: A method whose completion triggers this router
|
||||
@@ -146,14 +147,17 @@ def router(
|
||||
else:
|
||||
router_events = _get_router_return_events(func) or []
|
||||
|
||||
_set_flow_method_definition(
|
||||
method_definition_kwargs: dict[str, Any] = {
|
||||
"do": _method_action(func),
|
||||
"router": True,
|
||||
"emit": router_events or None,
|
||||
}
|
||||
if condition is not None:
|
||||
method_definition_kwargs["listen"] = _to_definition_condition(condition)
|
||||
|
||||
_merge_flow_method_definition(
|
||||
wrapper,
|
||||
FlowMethodDefinition(
|
||||
do=_method_action(func),
|
||||
listen=_to_definition_condition(condition),
|
||||
router=True,
|
||||
emit=router_events or None,
|
||||
),
|
||||
FlowMethodDefinition(**method_definition_kwargs),
|
||||
)
|
||||
return wrapper
|
||||
|
||||
|
||||
@@ -8,8 +8,8 @@ from crewai.flow.dsl._types import FlowMethodDecorator, FlowTrigger
|
||||
from crewai.flow.dsl._utils import (
|
||||
P,
|
||||
R,
|
||||
_merge_flow_method_definition,
|
||||
_method_action,
|
||||
_set_flow_method_definition,
|
||||
)
|
||||
from crewai.flow.flow_definition import FlowMethodDefinition
|
||||
from crewai.flow.flow_wrappers import StartMethod
|
||||
@@ -54,7 +54,7 @@ def start(
|
||||
def decorator(func: Callable[P, R]) -> StartMethod[P, R]:
|
||||
wrapper = StartMethod(func)
|
||||
|
||||
_set_flow_method_definition(
|
||||
_merge_flow_method_definition(
|
||||
wrapper,
|
||||
FlowMethodDefinition(
|
||||
do=_method_action(func),
|
||||
|
||||
@@ -106,6 +106,25 @@ def _get_flow_method_definition(method: Any) -> FlowMethodDefinition | None:
|
||||
return None
|
||||
|
||||
|
||||
def _merge_flow_method_definition(
|
||||
wrapper: FlowMethod[P, R],
|
||||
definition: FlowMethodDefinition,
|
||||
) -> None:
|
||||
existing = _get_flow_method_definition(wrapper)
|
||||
if existing is None:
|
||||
_set_flow_method_definition(wrapper, definition)
|
||||
return
|
||||
|
||||
updates = {
|
||||
field_name: getattr(definition, field_name)
|
||||
for field_name in definition.model_fields_set
|
||||
}
|
||||
_set_flow_method_definition(
|
||||
wrapper,
|
||||
existing.model_copy(deep=True, update=updates),
|
||||
)
|
||||
|
||||
|
||||
def _is_json_serializable(value: Any) -> bool:
|
||||
try:
|
||||
json.dumps(value)
|
||||
|
||||
@@ -870,14 +870,6 @@ def _validate_action_cel(
|
||||
def log_flow_definition_issues(definition: FlowDefinition) -> None:
|
||||
for method_name, method in definition.methods.items():
|
||||
path = f"methods.{method_name}"
|
||||
if method.router and not method.is_start and method.listen is None:
|
||||
_log_flow_definition_issue(
|
||||
definition.name,
|
||||
code="router_without_trigger",
|
||||
severity="error",
|
||||
path=path,
|
||||
message="router: true requires either start or listen",
|
||||
)
|
||||
if method.emit and not method.router:
|
||||
_log_flow_definition_issue(
|
||||
definition.name,
|
||||
|
||||
@@ -3007,6 +3007,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
"""
|
||||
# First, handle routers repeatedly until no router triggers anymore
|
||||
router_results = []
|
||||
router_result_payloads: dict[str, Any] = {}
|
||||
router_result_to_feedback: dict[
|
||||
str, Any
|
||||
] = {} # Map outcome -> HumanFeedbackResult
|
||||
@@ -3044,6 +3045,11 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
router_result_str = str(router_result)
|
||||
router_result_event = FlowMethodName(router_result_str)
|
||||
router_results.append(router_result_event)
|
||||
router_result_payloads[router_result_str] = (
|
||||
self.last_human_feedback
|
||||
if self.last_human_feedback is not None
|
||||
else router_result
|
||||
)
|
||||
|
||||
if self.last_human_feedback is not None:
|
||||
router_result_to_feedback[router_result_str] = (
|
||||
@@ -3064,7 +3070,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
current_trigger, router_only=False
|
||||
)
|
||||
if listeners_triggered:
|
||||
listener_result = router_result_to_feedback.get(
|
||||
listener_result = router_result_payloads.get(
|
||||
str(current_trigger), result
|
||||
)
|
||||
racing_group = self._get_racing_group_for_listeners(
|
||||
|
||||
@@ -386,6 +386,54 @@ def test_router_runtime_uses_flow_definition_without_legacy_router_metadata():
|
||||
assert execution_order == ["begin", "decide", "handle_left"]
|
||||
|
||||
|
||||
def test_start_router_runtime_routes_public_dsl_return_value():
|
||||
execution_order = []
|
||||
|
||||
class StartRouterFlow(Flow):
|
||||
@start()
|
||||
@router(emit=["continue"])
|
||||
def decide(self):
|
||||
execution_order.append("decide")
|
||||
return "continue"
|
||||
|
||||
@listen("continue")
|
||||
def handle_continue(self, result):
|
||||
execution_order.append(f"handle_continue:{result}")
|
||||
return "done"
|
||||
|
||||
assert StartRouterFlow().kickoff() == "done"
|
||||
assert execution_order == ["decide", "handle_continue:continue"]
|
||||
|
||||
|
||||
def test_start_router_runtime_chains_to_stacked_listener_router():
|
||||
execution_order = []
|
||||
|
||||
class ChainedStartRouterFlow(Flow):
|
||||
@start()
|
||||
@router(emit=["approved", "not_approved"])
|
||||
def first_router(self):
|
||||
execution_order.append("first_router")
|
||||
return "approved"
|
||||
|
||||
@listen("approved")
|
||||
@router(emit=["second_approval", "not_approved"])
|
||||
def second_router(self):
|
||||
execution_order.append("second_router")
|
||||
return "second_approval"
|
||||
|
||||
@listen("second_approval")
|
||||
def handle_second_approval(self, result):
|
||||
execution_order.append(f"handle_second_approval:{result}")
|
||||
return "done"
|
||||
|
||||
assert ChainedStartRouterFlow().kickoff() == "done"
|
||||
assert execution_order == [
|
||||
"first_router",
|
||||
"second_router",
|
||||
"handle_second_approval:second_approval",
|
||||
]
|
||||
|
||||
|
||||
def test_router_falsy_result_emits_runtime_event():
|
||||
execution_order = []
|
||||
|
||||
|
||||
@@ -565,6 +565,54 @@ def test_flow_definition_classifies_start_router_from_human_feedback_emit():
|
||||
assert entry_point.emit is None
|
||||
|
||||
|
||||
def test_flow_definition_classifies_public_dsl_start_router():
|
||||
class StartRouterFlow(Flow):
|
||||
@start()
|
||||
@router(emit=["continue", "stop"])
|
||||
def entry_point(self):
|
||||
return "continue"
|
||||
|
||||
@router(emit=["resume"])
|
||||
@start()
|
||||
def alternate_entry_point(self):
|
||||
return "resume"
|
||||
|
||||
entry_point = StartRouterFlow.flow_definition().methods["entry_point"]
|
||||
alternate_entry_point = StartRouterFlow.flow_definition().methods[
|
||||
"alternate_entry_point"
|
||||
]
|
||||
|
||||
assert entry_point.is_start is True
|
||||
assert entry_point.router is True
|
||||
assert entry_point.listen is None
|
||||
assert entry_point.emit == ["continue", "stop"]
|
||||
assert alternate_entry_point.is_start is True
|
||||
assert alternate_entry_point.router is True
|
||||
assert alternate_entry_point.listen is None
|
||||
assert alternate_entry_point.emit == ["resume"]
|
||||
|
||||
|
||||
def test_flow_definition_merges_stacked_listen_router():
|
||||
class ChainedRouterFlow(Flow):
|
||||
@start()
|
||||
@router(emit=["approved", "not_approved"])
|
||||
def first_router(self):
|
||||
return "approved"
|
||||
|
||||
@listen("approved")
|
||||
@router(emit=["second_approval", "not_approved"])
|
||||
def second_router(self):
|
||||
return "second_approval"
|
||||
|
||||
methods = ChainedRouterFlow.flow_definition().methods
|
||||
|
||||
assert methods["first_router"].is_start is True
|
||||
assert methods["first_router"].listen is None
|
||||
assert methods["second_router"].router is True
|
||||
assert methods["second_router"].listen == "approved"
|
||||
assert methods["second_router"].emit == ["second_approval", "not_approved"]
|
||||
|
||||
|
||||
def test_flow_definition_round_trips_json_and_yaml():
|
||||
class RoundTripFlow(Flow):
|
||||
@start()
|
||||
@@ -883,7 +931,7 @@ def test_flow_definition_ignores_legacy_diagnostics_loaded_from_contract():
|
||||
assert "diagnostics" not in definition.to_dict()
|
||||
|
||||
|
||||
def test_router_start_false_without_listen_logs_missing_trigger(caplog):
|
||||
def test_router_start_false_without_listen_is_allowed(caplog):
|
||||
caplog.set_level(logging.ERROR, logger="crewai.flow.flow_definition")
|
||||
|
||||
flow_definition.FlowDefinition.from_dict(
|
||||
@@ -901,12 +949,7 @@ def test_router_start_false_without_listen_logs_missing_trigger(caplog):
|
||||
}
|
||||
)
|
||||
|
||||
assert any(
|
||||
record.levelno == logging.ERROR
|
||||
and "router_without_trigger" in record.message
|
||||
and "methods.decision" in record.message
|
||||
for record in caplog.records
|
||||
)
|
||||
assert not caplog.records
|
||||
|
||||
|
||||
def test_router_human_feedback_preserves_existing_router_metadata():
|
||||
@@ -1048,7 +1091,7 @@ def test_flow_definition_cache_is_not_reused_by_subclasses():
|
||||
assert set(child_definition.methods) == {"child_step"}
|
||||
|
||||
|
||||
def test_flow_definition_logs_validation_issues_when_loaded_from_contract(caplog):
|
||||
def test_flow_definition_allows_router_without_trigger(caplog):
|
||||
caplog.set_level(logging.WARNING, logger="crewai.flow.flow_definition")
|
||||
|
||||
flow_definition.FlowDefinition.from_dict(
|
||||
@@ -1065,9 +1108,11 @@ def test_flow_definition_logs_validation_issues_when_loaded_from_contract(caplog
|
||||
}
|
||||
)
|
||||
|
||||
assert any(
|
||||
record.levelno == logging.ERROR
|
||||
and "LoadedFlow" in record.message
|
||||
and "router_without_trigger" in record.message
|
||||
for record in caplog.records
|
||||
)
|
||||
class StandaloneRouterFlow(Flow):
|
||||
@router(emit=["continue"])
|
||||
def decision(self):
|
||||
return "continue"
|
||||
|
||||
StandaloneRouterFlow.flow_definition()
|
||||
|
||||
assert not caplog.records
|
||||
|
||||
Reference in New Issue
Block a user