diff --git a/lib/crewai/src/crewai/flow/dsl/_utils.py b/lib/crewai/src/crewai/flow/dsl/_utils.py index 119173500..d25cb3b54 100644 --- a/lib/crewai/src/crewai/flow/dsl/_utils.py +++ b/lib/crewai/src/crewai/flow/dsl/_utils.py @@ -377,6 +377,11 @@ def _build_method_definition( else: method_definition = fragment.model_copy(deep=True) + # Skip / qualnames: they can never be re-imported, so a + # missing handler is more honest than a dead reference. + if "<" not in method.__qualname__: + method_definition.handler = f"{method.__module__}:{method.__qualname__}" + human_feedback = _build_human_feedback_definition( method, diagnostics, f"{path}.human_feedback" ) diff --git a/lib/crewai/src/crewai/flow/flow_definition.py b/lib/crewai/src/crewai/flow/flow_definition.py index 0830f7a65..da82abf2e 100644 --- a/lib/crewai/src/crewai/flow/flow_definition.py +++ b/lib/crewai/src/crewai/flow/flow_definition.py @@ -93,6 +93,7 @@ class FlowHumanFeedbackDefinition(BaseModel): class FlowMethodDefinition(BaseModel): """Static definition of one Flow method and its execution roles.""" + handler: str | None = None start: bool | FlowDefinitionCondition | None = None listen: FlowDefinitionCondition | None = None router: bool = False diff --git a/lib/crewai/src/crewai/flow/runtime.py b/lib/crewai/src/crewai/flow/runtime.py index 0ceb0815d..1937f5363 100644 --- a/lib/crewai/src/crewai/flow/runtime.py +++ b/lib/crewai/src/crewai/flow/runtime.py @@ -22,6 +22,7 @@ from concurrent.futures import Future, ThreadPoolExecutor import contextvars import copy import enum +import importlib import inspect import logging import threading @@ -169,6 +170,24 @@ def _condition_satisfied(condition: FlowDefinitionCondition, events: set[str]) - return combine(_condition_satisfied(branch, events) for branch in branches) +def _resolve_handler(ref: str) -> Callable[..., Any]: + module_name, separator, qualname = ref.partition(":") + if not separator or not module_name or not qualname: + raise ValueError( + f"invalid handler reference {ref!r}; expected 'module:qualname'" + ) + module = importlib.import_module(module_name) + target: Any = module + for part in qualname.split("."): + target = getattr(target, part) + if not callable(target): + raise TypeError( + f"handler reference {ref!r} resolved to a non-callable " + f"{type(target).__name__}" + ) + return cast(Callable[..., Any], target) + + def _iter_condition_events(condition: FlowDefinitionCondition) -> Iterator[str]: if isinstance(condition, str): yield condition @@ -694,6 +713,15 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): cls._flow_definition = flow_definition return flow_definition + @classmethod + def from_definition(cls, definition: FlowDefinition) -> Flow[Any]: + """Build a runnable Flow directly from a definition; no subclass required.""" + return cls.model_validate({}, context={"flow_definition": definition}) + + @property + def _flow_name(self) -> str: + return self.name or self._definition.name + def _start_method_names(self) -> list[FlowMethodName]: return [ FlowMethodName(method_name) @@ -874,7 +902,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): restore_event_scope(()) reset_last_event_id() - _methods: dict[FlowMethodName, FlowMethod[Any, Any]] = PrivateAttr( + _methods: dict[FlowMethodName, Callable[..., Any]] = PrivateAttr( default_factory=dict ) _method_execution_counts: dict[FlowMethodName, int] = PrivateAttr( @@ -918,16 +946,24 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): object.__setattr__(self, name, value) def model_post_init(self, __context: Any) -> None: - self._flow_post_init() + definition = ( + __context.get("flow_definition") if isinstance(__context, dict) else None + ) + self._flow_post_init(definition) - def _flow_post_init(self) -> None: + def _flow_post_init(self, definition: FlowDefinition | None = None) -> None: """Heavy initialization: state creation, events, memory, method registration.""" if getattr(self, "_flow_post_init_done", False): return object.__setattr__(self, "_flow_post_init_done", True) self._initialize_runtime_extension_attrs() - self._definition = type(self).flow_definition() + self._definition = definition or type(self).flow_definition() + methods = ( + self._handler_bound_methods() + if definition is not None + else self._class_bound_methods() + ) if self._state is None: self._state = self._create_initial_state() @@ -943,7 +979,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): self, FlowCreatedEvent( type="flow_created", - flow_name=self.name or self._definition.name, + flow_name=self._flow_name, ), ) @@ -953,17 +989,44 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): if self.memory is None and not getattr(self, "_skip_auto_memory", False): from crewai.memory.utils import sanitize_scope_name - flow_name = sanitize_scope_name(self.name or self._definition.name) + flow_name = sanitize_scope_name(self._flow_name) self.memory = Memory(root_scope=f"/flow/{flow_name}") - # Build the runtime method lookup from the static FlowDefinition. + self._methods.update(methods) + + def _handler_bound_methods(self) -> dict[FlowMethodName, Callable[..., Any]]: + methods: dict[FlowMethodName, Callable[..., Any]] = {} + unresolved: list[str] = [] + for method_name, method_definition in self._definition.methods.items(): + if method_definition.handler is None: + unresolved.append(f"{method_name}: no handler") + continue + try: + handler = _resolve_handler(method_definition.handler) + except Exception as e: + unresolved.append(f"{method_name}: {e}") + continue + if getattr(handler, "__self__", None) is None: + handler = handler.__get__(self, type(self)) + methods[FlowMethodName(method_name)] = handler + if unresolved: + raise ValueError( + f"Cannot build flow {self._definition.name!r} from its definition; " + "methods with missing or unresolvable handlers: " + + "; ".join(unresolved) + ) + return methods + + def _class_bound_methods(self) -> dict[FlowMethodName, Callable[..., Any]]: + methods: dict[FlowMethodName, Callable[..., Any]] = {} for method_name in self._definition.methods: method = getattr(self, method_name, None) if method is None: continue if not hasattr(method, "__self__"): - method = method.__get__(self, self.__class__) - self._methods[FlowMethodName(method_name)] = method + method = method.__get__(self, type(self)) + methods[FlowMethodName(method_name)] = method + return methods def recall(self, query: str, **kwargs: Any) -> Any: """Recall relevant memories. Delegates to this flow's memory. @@ -1364,7 +1427,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): self, FlowStartedEvent( type="flow_started", - flow_name=self.name or self._definition.name, + flow_name=self._flow_name, inputs=None, ), ) @@ -1440,7 +1503,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): self, MethodExecutionFinishedEvent( type="method_execution_finished", - flow_name=self.name or self._definition.name, + flow_name=self._flow_name, method_name=context.method_name, result=collapsed_outcome if emit else result, state=self._state, @@ -1494,7 +1557,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): self, FlowPausedEvent( type="flow_paused", - flow_name=self.name or self._definition.name, + flow_name=self._flow_name, flow_id=e.context.flow_id, method_name=e.context.method_name, state=self._copy_and_serialize_state(), @@ -1525,7 +1588,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): self, FlowFinishedEvent( type="flow_finished", - flow_name=self.name or self._definition.name, + flow_name=self._flow_name, result=final_result, state=self._copy_and_serialize_state(), ), @@ -2168,7 +2231,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): # explicit finalization call closes the batch. started_event = FlowStartedEvent( type="flow_started", - flow_name=self.name or self._definition.name, + flow_name=self._flow_name, inputs=inputs, ) future = crewai_event_bus.emit(self, started_event) @@ -2260,7 +2323,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): self, FlowPausedEvent( type="flow_paused", - flow_name=self.name or self._definition.name, + flow_name=self._flow_name, flow_id=e.context.flow_id, method_name=e.context.method_name, state=self._copy_and_serialize_state(), @@ -2310,7 +2373,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): self, FlowFinishedEvent( type="flow_finished", - flow_name=self.name or self._definition.name, + flow_name=self._flow_name, result=final_output, state=self._copy_and_serialize_state(), ), @@ -2396,7 +2459,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): MethodExecutionFinishedEvent, MethodExecutionFailedEvent, ) - flow_name = self.name or self._definition.name + flow_name = self._flow_name nodes = sorted( ( n @@ -2475,15 +2538,16 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): def _inject_trigger_payload_for_start_method( self, original_method: Callable[..., Any] ) -> Callable[..., Any]: + accepts_trigger_payload = ( + "crewai_trigger_payload" in inspect.signature(original_method).parameters + ) + def prepare_kwargs( *args: Any, **kwargs: Any ) -> tuple[tuple[Any, ...], dict[str, Any]]: inputs = cast(dict[str, Any], baggage.get_baggage("flow_inputs") or {}) trigger_payload = inputs.get("crewai_trigger_payload") - sig = inspect.signature(original_method) - accepts_trigger_payload = "crewai_trigger_payload" in sig.parameters - if trigger_payload is not None and accepts_trigger_payload: kwargs["crewai_trigger_payload"] = trigger_payload elif trigger_payload is not None: @@ -2533,7 +2597,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): MethodExecutionStartedEvent( type="method_execution_started", method_name=method_name, - flow_name=self.name or self._definition.name, + flow_name=self._flow_name, params=dumped_params, state=self._copy_and_serialize_state(), ), @@ -2585,7 +2649,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): finished_event = MethodExecutionFinishedEvent( type="method_execution_finished", method_name=method_name, - flow_name=self.name or self._definition.name, + flow_name=self._flow_name, state=self._copy_and_serialize_state(), result=result, ) @@ -2614,7 +2678,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): MethodExecutionPausedEvent( type="method_execution_paused", method_name=method_name, - flow_name=self.name or self._definition.name, + flow_name=self._flow_name, state=self._copy_and_serialize_state(), flow_id=e.context.flow_id, message=e.context.message, @@ -2630,7 +2694,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): MethodExecutionFailedEvent( type="method_execution_failed", method_name=method_name, - flow_name=self.name or self._definition.name, + flow_name=self._flow_name, error=e, ), ) @@ -2881,8 +2945,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): method = self._methods[listener_name] sig = inspect.signature(method) - params = list(sig.parameters.values()) - method_params = [p for p in params if p.name != "self"] + method_params = [p for p in sig.parameters.values() if p.name != "self"] if triggering_event_id: with triggered_by_scope(triggering_event_id): @@ -3038,7 +3101,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): self, FlowInputRequestedEvent( type="flow_input_requested", - flow_name=self.name or self._definition.name, + flow_name=self._flow_name, method_name=method_name, message=message, metadata=metadata, @@ -3105,7 +3168,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): self, FlowInputReceivedEvent( type="flow_input_received", - flow_name=self.name or self._definition.name, + flow_name=self._flow_name, method_name=method_name, message=message, response=response, @@ -3143,7 +3206,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): self, HumanFeedbackRequestedEvent( type="human_feedback_requested", - flow_name=self.name or self._definition.name, + flow_name=self._flow_name, method_name="", # Will be set by decorator if needed output=output, message=message, @@ -3172,7 +3235,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): self, HumanFeedbackReceivedEvent( type="human_feedback_received", - flow_name=self.name or self._definition.name, + flow_name=self._flow_name, method_name="", # Will be set by decorator if needed feedback=feedback, outcome=None, # Will be determined after collapsing @@ -3347,7 +3410,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): self, FlowPlotEvent( type="flow_plot", - flow_name=self.name or self._definition.name, + flow_name=self._flow_name, ), ) structure = build_flow_structure(cast(Any, self)) diff --git a/lib/crewai/tests/test_flow_from_definition.py b/lib/crewai/tests/test_flow_from_definition.py new file mode 100644 index 000000000..8480b43e3 --- /dev/null +++ b/lib/crewai/tests/test_flow_from_definition.py @@ -0,0 +1,295 @@ +from __future__ import annotations + +import pytest + +from crewai.events.event_bus import crewai_event_bus +from crewai.events.types.flow_events import ( + MethodExecutionFinishedEvent, + MethodExecutionStartedEvent, +) +from crewai.flow import Flow, and_, listen, or_, router, start +from crewai.flow.flow_definition import FlowDefinition + + +class ChainFlow(Flow): + @start() + def begin(self): + self.state["begin_ran"] = True + return "hello" + + @listen(begin) + def shout(self, result): + return result.upper() + + @listen(shout) + def confirm(self): + self.state["confirmed"] = True + return f"confirmed:{self.state['confirmed']}" + + +CHAIN_YAML = f""" +schema: crewai.flow/v1 +name: ChainFlow +methods: + begin: + handler: {__name__}:ChainFlow.begin + start: true + shout: + handler: {__name__}:ChainFlow.shout + listen: begin + confirm: + handler: {__name__}:ChainFlow.confirm + listen: shout +""" + + +class MergeFlow(Flow): + @start() + def begin(self): + return "go" + + @listen(begin) + def left(self): + return "left" + + @listen(begin) + def right(self): + return "right" + + @listen(or_(left, right)) + def either(self): + self.state["either_ran"] = True + return "either" + + @listen(and_(left, right, either)) + def join(self): + self.state["joined"] = True + return "joined" + + +MERGE_YAML = f""" +schema: crewai.flow/v1 +name: MergeFlow +methods: + begin: + handler: {__name__}:MergeFlow.begin + start: true + left: + handler: {__name__}:MergeFlow.left + listen: begin + right: + handler: {__name__}:MergeFlow.right + listen: begin + either: + handler: {__name__}:MergeFlow.either + listen: + or: [left, right] + join: + handler: {__name__}:MergeFlow.join + listen: + and: [left, right, either] +""" + + +class RouteFlow(Flow): + @start() + def begin(self): + return "go" + + @router(begin) + def decide(self): + return "left" if self.state.get("direction") == "left" else "right" + + @listen("left") + def take_left(self): + return "took-left" + + @listen("right") + def take_right(self): + return "took-right" + + +ROUTE_YAML = f""" +schema: crewai.flow/v1 +name: RouteFlow +methods: + begin: + handler: {__name__}:RouteFlow.begin + start: true + decide: + handler: {__name__}:RouteFlow.decide + listen: begin + router: true + take_left: + handler: {__name__}:RouteFlow.take_left + listen: left + take_right: + handler: {__name__}:RouteFlow.take_right + listen: right +""" + + +class LoopFlow(Flow): + @start("retry") + def step(self): + self.state["count"] = self.state.get("count", 0) + 1 + return self.state["count"] + + @router(step) + def decide(self): + if self.state["count"] < 3: + return "retry" + return "done" + + @listen("done") + def finish(self): + return f"finished:{self.state['count']}" + + +LOOP_YAML = f""" +schema: crewai.flow/v1 +name: LoopFlow +methods: + step: + handler: {__name__}:LoopFlow.step + start: retry + decide: + handler: {__name__}:LoopFlow.decide + listen: step + router: true + finish: + handler: {__name__}:LoopFlow.finish + listen: done +""" + + +def _run_with_events(flow, inputs=None): + events = [] + with crewai_event_bus.scoped_handlers(): + + @crewai_event_bus.on(MethodExecutionStartedEvent) + def on_started(source, event): + events.append(event) + + @crewai_event_bus.on(MethodExecutionFinishedEvent) + def on_finished(source, event): + events.append(event) + + result = flow.kickoff(inputs=inputs) + events.sort(key=lambda e: e.timestamp) + return result, [ + (type(e).__name__, str(e.method_name), e.flow_name) for e in events + ] + + +def _state_without_id(flow): + snapshot = dict(flow.state) + snapshot.pop("id", None) + return snapshot + + +def assert_parity(flow_cls, yaml_str, inputs=None, ordered=True): + class_flow = flow_cls() + class_result, class_events = _run_with_events(class_flow, inputs) + + definition = FlowDefinition.from_yaml(yaml_str) + definition_flow = Flow.from_definition(definition) + definition_result, definition_events = _run_with_events(definition_flow, inputs) + + assert definition_result == class_result + assert _state_without_id(definition_flow) == _state_without_id(class_flow) + if ordered: + assert definition_flow.method_outputs == class_flow.method_outputs + assert definition_events == class_events + else: + assert sorted(map(repr, definition_flow.method_outputs)) == sorted( + map(repr, class_flow.method_outputs) + ) + assert sorted(definition_events) == sorted(class_events) + return definition_flow, definition_result + + +def test_simple_chain_parity(): + flow, result = assert_parity(ChainFlow, CHAIN_YAML) + assert result == "confirmed:True" + assert flow.method_outputs == ["hello", "HELLO", "confirmed:True"] + + +def test_and_or_merge_parity(): + flow, _ = assert_parity(MergeFlow, MERGE_YAML, ordered=False) + assert flow.state["joined"] is True + assert flow.state["either_ran"] is True + + +def test_router_label_parity_for_each_branch(): + left_flow, _ = assert_parity(RouteFlow, ROUTE_YAML, inputs={"direction": "left"}) + assert "took-left" in left_flow.method_outputs + assert "took-right" not in left_flow.method_outputs + + right_flow, _ = assert_parity(RouteFlow, ROUTE_YAML, inputs={"direction": "right"}) + assert "took-right" in right_flow.method_outputs + + +def test_cyclic_flow_parity(): + flow, result = assert_parity(LoopFlow, LOOP_YAML) + assert result == "finished:3" + assert flow.state["count"] == 3 + + +def test_definition_flow_events_use_definition_name(): + definition = FlowDefinition.from_yaml(CHAIN_YAML) + flow = Flow.from_definition(definition) + _, events = _run_with_events(flow) + assert events + assert all(flow_name == "ChainFlow" for _, _, flow_name in events) + + +def test_from_definition_missing_handler_raises(): + definition = FlowDefinition.from_dict( + { + "schema": "crewai.flow/v1", + "name": "NoHandlers", + "methods": {"begin": {"start": True}}, + } + ) + + with pytest.raises(ValueError, match="begin: no handler"): + Flow.from_definition(definition) + + +def test_from_definition_unresolvable_handler_raises(): + definition = FlowDefinition.from_dict( + { + "schema": "crewai.flow/v1", + "name": "BadHandlers", + "methods": { + "begin": { + "start": True, + "handler": "definitely_not_a_module_xyz:nope", + } + }, + } + ) + + with pytest.raises(ValueError, match="missing or unresolvable handlers.*begin"): + Flow.from_definition(definition) + + +def test_from_definition_malformed_handler_raises(): + definition = FlowDefinition.from_dict( + { + "schema": "crewai.flow/v1", + "name": "MalformedHandlers", + "methods": {"begin": {"start": True, "handler": "no-colon-here"}}, + } + ) + + with pytest.raises(ValueError, match="expected 'module:qualname'"): + Flow.from_definition(definition) + + +def test_flow_definition_stamps_handler_refs(): + definition = ChainFlow.flow_definition() + + assert definition.methods["begin"].handler == f"{__name__}:ChainFlow.begin" + assert definition.methods["shout"].handler == f"{__name__}:ChainFlow.shout"