From d029a5cd9203a6549dd2ac94c968b260c8b29928 Mon Sep 17 00:00:00 2001 From: Vinicius Brasil Date: Wed, 10 Jun 2026 15:46:24 -0700 Subject: [PATCH] Build flow state from FlowDefinition Definition-driven flows previously always started with a bare dict state. --- lib/crewai/src/crewai/flow/flow_definition.py | 3 +- lib/crewai/src/crewai/flow/runtime.py | 123 ++++++++-- .../crewai/utilities/pydantic_schema_utils.py | 6 +- lib/crewai/tests/test_flow.py | 19 ++ lib/crewai/tests/test_flow_from_definition.py | 215 +++++++++++++++++- 5 files changed, 339 insertions(+), 27 deletions(-) diff --git a/lib/crewai/src/crewai/flow/flow_definition.py b/lib/crewai/src/crewai/flow/flow_definition.py index da82abf2e..157cb56f7 100644 --- a/lib/crewai/src/crewai/flow/flow_definition.py +++ b/lib/crewai/src/crewai/flow/flow_definition.py @@ -52,8 +52,9 @@ class FlowDefinitionDiagnostic(BaseModel): class FlowStateDefinition(BaseModel): """Static description of a Flow state contract.""" - type: TypingLiteral["dict", "pydantic", "unknown"] = "dict" + type: TypingLiteral["dict", "pydantic", "json_schema", "unknown"] = "dict" ref: str | None = None + json_schema: dict[str, Any] | None = None default: Any = None diff --git a/lib/crewai/src/crewai/flow/runtime.py b/lib/crewai/src/crewai/flow/runtime.py index 1937f5363..011074992 100644 --- a/lib/crewai/src/crewai/flow/runtime.py +++ b/lib/crewai/src/crewai/flow/runtime.py @@ -96,6 +96,7 @@ from crewai.flow.flow_definition import ( FlowDefinition, FlowDefinitionCondition, FlowMethodDefinition, + FlowStateDefinition, ) from crewai.flow.flow_wrappers import ( FlowMethod, @@ -188,6 +189,54 @@ def _resolve_handler(ref: str) -> Callable[..., Any]: return cast(Callable[..., Any], target) +def _build_definition_state_model( + state_definition: FlowStateDefinition, +) -> BaseModel | None: + kwargs = ( + dict(state_definition.default) + if isinstance(state_definition.default, dict) + else {} + ) + + model_class: type[BaseModel] | None = None + if state_definition.ref: + try: + resolved = _resolve_handler(state_definition.ref) + except Exception: + logger.warning( + "Could not import state ref %r", state_definition.ref, exc_info=True + ) + else: + if isinstance(resolved, type) and issubclass(resolved, BaseModel): + model_class = resolved + else: + logger.warning( + "State ref %r is not a pydantic model", state_definition.ref + ) + + if model_class is None and state_definition.json_schema: + from crewai.utilities.pydantic_schema_utils import create_model_from_schema + + try: + model_class = create_model_from_schema(state_definition.json_schema) + except Exception: + logger.warning( + "Could not build a state model from the declared json_schema", + exc_info=True, + ) + + if model_class is None: + return None + + if not issubclass(model_class, FlowState): + + class StateWithId(FlowState, model_class): # type: ignore[misc, valid-type] + pass + + model_class = StateWithId + return model_class(**kwargs) + + def _iter_condition_events(condition: FlowDefinitionCondition) -> Iterator[str]: if isinstance(condition, str): yield condition @@ -718,10 +767,6 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): """Build a runnable Flow directly from a definition; no subclass required.""" return cls.model_validate({}, context={"flow_definition": definition}) - @property - def _flow_name(self) -> str: - return self.name or self._definition.name - def _start_method_names(self) -> list[FlowMethodName]: return [ FlowMethodName(method_name) @@ -959,6 +1004,8 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): self._initialize_runtime_extension_attrs() self._definition = definition or type(self).flow_definition() + if self.name and self.name != self._definition.name: + self._definition = self._definition.model_copy(update={"name": self.name}) methods = ( self._handler_bound_methods() if definition is not None @@ -979,7 +1026,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): self, FlowCreatedEvent( type="flow_created", - flow_name=self._flow_name, + flow_name=self._definition.name, ), ) @@ -989,7 +1036,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): if self.memory is None and not getattr(self, "_skip_auto_memory", False): from crewai.memory.utils import sanitize_scope_name - flow_name = sanitize_scope_name(self._flow_name) + flow_name = sanitize_scope_name(self._definition.name) self.memory = Memory(root_scope=f"/flow/{flow_name}") self._methods.update(methods) @@ -1427,7 +1474,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): self, FlowStartedEvent( type="flow_started", - flow_name=self._flow_name, + flow_name=self._definition.name, inputs=None, ), ) @@ -1503,7 +1550,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): self, MethodExecutionFinishedEvent( type="method_execution_finished", - flow_name=self._flow_name, + flow_name=self._definition.name, method_name=context.method_name, result=collapsed_outcome if emit else result, state=self._state, @@ -1557,7 +1604,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): self, FlowPausedEvent( type="flow_paused", - flow_name=self._flow_name, + flow_name=self._definition.name, flow_id=e.context.flow_id, method_name=e.context.method_name, state=self._copy_and_serialize_state(), @@ -1588,7 +1635,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): self, FlowFinishedEvent( type="flow_finished", - flow_name=self._flow_name, + flow_name=self._definition.name, result=final_result, state=self._copy_and_serialize_state(), ), @@ -1654,7 +1701,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): return cast(T, {"id": str(uuid4())}) if init_state is None: - return cast(T, {"id": str(uuid4())}) + return cast(T, self._create_definition_state()) if isinstance(init_state, type): state_class = init_state @@ -1696,6 +1743,34 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): f"Initial state must be dict or BaseModel, got {type(self.initial_state)}" ) + def _create_definition_state(self) -> dict[str, Any] | BaseModel: + state_definition = self._definition.state + if state_definition is None: + return {"id": str(uuid4())} + if state_definition.type in ("pydantic", "json_schema"): + state = _build_definition_state_model(state_definition) + if state is not None: + return state + logger.error( + "Flow %r declares %s state but neither ref nor json_schema " + "produced a model; falling back to dict state", + self._definition.name, + state_definition.type, + ) + elif state_definition.type == "unknown": + logger.warning( + "Flow %r declares state of unknown type; falling back to dict state", + self._definition.name, + ) + dict_state: dict[str, Any] = ( + dict(state_definition.default) + if isinstance(state_definition.default, dict) + else {} + ) + if "id" not in dict_state: + dict_state["id"] = str(uuid4()) + return dict_state + def _copy_state(self) -> T: """Create a copy of the current state. @@ -2231,7 +2306,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): # explicit finalization call closes the batch. started_event = FlowStartedEvent( type="flow_started", - flow_name=self._flow_name, + flow_name=self._definition.name, inputs=inputs, ) future = crewai_event_bus.emit(self, started_event) @@ -2323,7 +2398,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): self, FlowPausedEvent( type="flow_paused", - flow_name=self._flow_name, + flow_name=self._definition.name, flow_id=e.context.flow_id, method_name=e.context.method_name, state=self._copy_and_serialize_state(), @@ -2373,7 +2448,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): self, FlowFinishedEvent( type="flow_finished", - flow_name=self._flow_name, + flow_name=self._definition.name, result=final_output, state=self._copy_and_serialize_state(), ), @@ -2459,7 +2534,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): MethodExecutionFinishedEvent, MethodExecutionFailedEvent, ) - flow_name = self._flow_name + flow_name = self._definition.name nodes = sorted( ( n @@ -2597,7 +2672,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): MethodExecutionStartedEvent( type="method_execution_started", method_name=method_name, - flow_name=self._flow_name, + flow_name=self._definition.name, params=dumped_params, state=self._copy_and_serialize_state(), ), @@ -2649,7 +2724,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): finished_event = MethodExecutionFinishedEvent( type="method_execution_finished", method_name=method_name, - flow_name=self._flow_name, + flow_name=self._definition.name, state=self._copy_and_serialize_state(), result=result, ) @@ -2678,7 +2753,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): MethodExecutionPausedEvent( type="method_execution_paused", method_name=method_name, - flow_name=self._flow_name, + flow_name=self._definition.name, state=self._copy_and_serialize_state(), flow_id=e.context.flow_id, message=e.context.message, @@ -2694,7 +2769,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): MethodExecutionFailedEvent( type="method_execution_failed", method_name=method_name, - flow_name=self._flow_name, + flow_name=self._definition.name, error=e, ), ) @@ -3101,7 +3176,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): self, FlowInputRequestedEvent( type="flow_input_requested", - flow_name=self._flow_name, + flow_name=self._definition.name, method_name=method_name, message=message, metadata=metadata, @@ -3168,7 +3243,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): self, FlowInputReceivedEvent( type="flow_input_received", - flow_name=self._flow_name, + flow_name=self._definition.name, method_name=method_name, message=message, response=response, @@ -3206,7 +3281,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): self, HumanFeedbackRequestedEvent( type="human_feedback_requested", - flow_name=self._flow_name, + flow_name=self._definition.name, method_name="", # Will be set by decorator if needed output=output, message=message, @@ -3235,7 +3310,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): self, HumanFeedbackReceivedEvent( type="human_feedback_received", - flow_name=self._flow_name, + flow_name=self._definition.name, method_name="", # Will be set by decorator if needed feedback=feedback, outcome=None, # Will be determined after collapsing @@ -3410,7 +3485,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta): self, FlowPlotEvent( type="flow_plot", - flow_name=self._flow_name, + flow_name=self._definition.name, ), ) structure = build_flow_structure(cast(Any, self)) diff --git a/lib/crewai/src/crewai/utilities/pydantic_schema_utils.py b/lib/crewai/src/crewai/utilities/pydantic_schema_utils.py index ff1d5529b..85a53d9bc 100644 --- a/lib/crewai/src/crewai/utilities/pydantic_schema_utils.py +++ b/lib/crewai/src/crewai/utilities/pydantic_schema_utils.py @@ -999,7 +999,11 @@ def _json_schema_to_pydantic_field( if examples: schema_extra["examples"] = examples - default = ... if is_required else None + default = ( + json_schema["default"] + if "default" in json_schema + else (... if is_required else None) + ) if isinstance(type_, type) and issubclass(type_, (int, float)): if "minimum" in json_schema: diff --git a/lib/crewai/tests/test_flow.py b/lib/crewai/tests/test_flow.py index ab50af05e..c38f41778 100644 --- a/lib/crewai/tests/test_flow.py +++ b/lib/crewai/tests/test_flow.py @@ -1157,6 +1157,25 @@ def test_flow_name(): assert flow.name == "MyFlow" +def test_flow_custom_name_overrides_class_name_in_events(): + class InternalFlowClass(Flow): + name = "PublicName" + + @start() + def begin(self): + return "done" + + received = [] + + @crewai_event_bus.on(FlowStartedEvent) + def handle(source, event): + received.append(event) + + InternalFlowClass().kickoff() + + assert received[0].flow_name == "PublicName" + + def test_nested_and_or_conditions(): """Test nested conditions like or_(and_(A, B), and_(C, D)). diff --git a/lib/crewai/tests/test_flow_from_definition.py b/lib/crewai/tests/test_flow_from_definition.py index 8480b43e3..f93dab69e 100644 --- a/lib/crewai/tests/test_flow_from_definition.py +++ b/lib/crewai/tests/test_flow_from_definition.py @@ -8,6 +8,7 @@ from crewai.events.types.flow_events import ( MethodExecutionStartedEvent, ) from crewai.flow import Flow, and_, listen, or_, router, start +from crewai.flow.flow import FlowState from crewai.flow.flow_definition import FlowDefinition @@ -163,6 +164,142 @@ methods: """ +class CounterState(FlowState): + count: int = 0 + label: str = "none" + + +class PydanticStateFlow(Flow[CounterState]): + @start() + def begin(self): + self.state.count += 1 + return self.state.count + + @listen(begin) + def finish(self): + self.state.label = f"count={self.state.count}" + return self.state.label + + +PYDANTIC_STATE_YAML = f""" +schema: crewai.flow/v1 +name: PydanticStateFlow +state: + type: pydantic + ref: {__name__}:CounterState +methods: + begin: + handler: {__name__}:PydanticStateFlow.begin + start: true + finish: + handler: {__name__}:PydanticStateFlow.finish + listen: begin +""" + +PYDANTIC_STATE_OVERLAY_YAML = f""" +schema: crewai.flow/v1 +name: PydanticStateFlow +state: + type: pydantic + ref: {__name__}:CounterState + default: + count: 5 +methods: + begin: + handler: {__name__}:PydanticStateFlow.begin + start: true + finish: + handler: {__name__}:PydanticStateFlow.finish + listen: begin +""" + +JSON_SCHEMA_STATE_YAML = f""" +schema: crewai.flow/v1 +name: JsonSchemaStateFlow +state: + type: json_schema + json_schema: + title: CounterState + type: object + properties: + count: + type: integer + default: 0 + label: + type: string + default: none +methods: + begin: + handler: {__name__}:PydanticStateFlow.begin + start: true + finish: + handler: {__name__}:PydanticStateFlow.finish + listen: begin +""" + +PYDANTIC_REF_WITH_SCHEMA_FALLBACK_YAML = f""" +schema: crewai.flow/v1 +name: SchemaFallbackFlow +state: + type: pydantic + ref: definitely_not_a_module_xyz:MissingState + json_schema: + title: CounterState + type: object + properties: + count: + type: integer + default: 0 + label: + type: string + default: none +methods: + begin: + handler: {__name__}:PydanticStateFlow.begin + start: true + finish: + handler: {__name__}:PydanticStateFlow.finish + listen: begin +""" + +UNRESOLVABLE_STATE_YAML = f""" +schema: crewai.flow/v1 +name: UnresolvableStateFlow +state: + type: pydantic + ref: definitely_not_a_module_xyz:MissingState +methods: + begin: + handler: {__name__}:ChainFlow.begin + start: true +""" + +DICT_STATE_YAML = f""" +schema: crewai.flow/v1 +name: DictStateFlow +state: + type: dict + default: + count: 5 +methods: + begin: + handler: {__name__}:ChainFlow.begin + start: true +""" + +UNKNOWN_STATE_YAML = f""" +schema: crewai.flow/v1 +name: UnknownStateFlow +state: + type: unknown + ref: somewhere:Something +methods: + begin: + handler: {__name__}:ChainFlow.begin + start: true +""" + + def _run_with_events(flow, inputs=None): events = [] with crewai_event_bus.scoped_handlers(): @@ -183,7 +320,7 @@ def _run_with_events(flow, inputs=None): def _state_without_id(flow): - snapshot = dict(flow.state) + snapshot = dict(flow.state.model_dump()) snapshot.pop("id", None) return snapshot @@ -293,3 +430,79 @@ def test_flow_definition_stamps_handler_refs(): assert definition.methods["begin"].handler == f"{__name__}:ChainFlow.begin" assert definition.methods["shout"].handler == f"{__name__}:ChainFlow.shout" + + +def test_pydantic_state_from_ref_parity(): + flow, result = assert_parity(PydanticStateFlow, PYDANTIC_STATE_YAML) + assert result == "count=1" + assert flow.state.count == 1 + assert flow.state.label == "count=1" + assert flow.state.id + + +def test_pydantic_state_default_overlay(): + flow = Flow.from_definition(FlowDefinition.from_yaml(PYDANTIC_STATE_OVERLAY_YAML)) + result = flow.kickoff() + assert result == "count=6" + assert flow.state.count == 6 + + +def test_json_schema_state(): + flow = Flow.from_definition(FlowDefinition.from_yaml(JSON_SCHEMA_STATE_YAML)) + result = flow.kickoff() + assert result == "count=1" + assert flow.state.count == 1 + assert flow.state.label == "count=1" + assert flow.state.id + + +def test_json_schema_state_validates_inputs(): + flow = Flow.from_definition(FlowDefinition.from_yaml(JSON_SCHEMA_STATE_YAML)) + with pytest.raises(ValueError, match="Invalid inputs"): + flow.kickoff(inputs={"count": "not-a-number"}) + + +def test_pydantic_state_falls_back_to_json_schema_when_ref_unimportable(): + flow = Flow.from_definition( + FlowDefinition.from_yaml(PYDANTIC_REF_WITH_SCHEMA_FALLBACK_YAML) + ) + result = flow.kickoff() + assert result == "count=1" + assert flow.state.count == 1 + + +def test_pydantic_state_without_ref_or_schema_falls_back_to_dict(caplog): + with caplog.at_level("ERROR"): + flow = Flow.from_definition(FlowDefinition.from_yaml(UNRESOLVABLE_STATE_YAML)) + assert "falling back to dict state" in caplog.text + + result = flow.kickoff() + assert result == "hello" + assert flow.state["begin_ran"] is True + assert flow.state["id"] + + +def test_dict_state_is_a_copy_of_default_plus_id(): + definition = FlowDefinition.from_yaml(DICT_STATE_YAML) + + flow = Flow.from_definition(definition) + assert flow.state["count"] == 5 + assert flow.state["id"] + flow.kickoff() + assert flow.state["begin_ran"] is True + + second = Flow.from_definition(definition) + assert second.state["count"] == 5 + assert "begin_ran" not in second.state.model_dump() + assert second.state["id"] != flow.state["id"] + assert definition.state.default == {"count": 5} + + +def test_unknown_state_type_falls_back_to_dict(caplog): + with caplog.at_level("WARNING"): + flow = Flow.from_definition(FlowDefinition.from_yaml(UNKNOWN_STATE_YAML)) + assert "falling back to dict state" in caplog.text + + result = flow.kickoff() + assert result == "hello" + assert flow.state["begin_ran"] is True