From 7bb9bc7e1a6ec7485f274fad14c615273d9a5a26 Mon Sep 17 00:00:00 2001 From: Vinicius Brasil Date: Tue, 16 Jun 2026 21:31:07 -0700 Subject: [PATCH] Discriminate FlowDefinition state types (#6196) Replace the single FlowStateDefinition model with a `type`-discriminated union of FlowDictStateDefinition, FlowPydanticStateDefinition, FlowJsonSchemaStateDefinition, and FlowUnknownStateDefinition. Each branch only carries the fields it actually uses and forbids extras, so an invalid combination like a `dict` state with a `ref` now fails validation instead of being silently accepted. The runtime reads `ref` and `json_schema` defensively since they no longer exist on every branch. ```yaml state: type: json_schema json_schema: type: object properties: topic: type: string ``` Co-authored-by: Claude Opus 4.8 (1M context) --- lib/crewai/src/crewai/flow/dsl/_utils.py | 12 +- lib/crewai/src/crewai/flow/flow_definition.py | 119 ++++++++++++++++-- .../src/crewai/flow/runtime/__init__.py | 18 ++- lib/crewai/tests/test_flow_definition.py | 34 ++++- 4 files changed, 160 insertions(+), 23 deletions(-) diff --git a/lib/crewai/src/crewai/flow/dsl/_utils.py b/lib/crewai/src/crewai/flow/dsl/_utils.py index b203bcd62..35b8006da 100644 --- a/lib/crewai/src/crewai/flow/dsl/_utils.py +++ b/lib/crewai/src/crewai/flow/dsl/_utils.py @@ -15,10 +15,13 @@ from crewai.flow.flow_definition import ( FlowConversationalRouterDefinition, FlowDefinition, FlowDefinitionDiagnostic, + FlowDictStateDefinition, FlowHumanFeedbackDefinition, FlowMethodDefinition, FlowPersistenceDefinition, + FlowPydanticStateDefinition, FlowStateDefinition, + FlowUnknownStateDefinition, _object_ref, ) from crewai.flow.flow_wrappers import ( @@ -185,12 +188,11 @@ def _build_state_definition( default = None if isinstance(state_value, dict): default = _serialize_static_value(state_value, diagnostics, "state.default") - return FlowStateDefinition(type="dict", default=default) + return FlowDictStateDefinition(default=default) if isinstance(state_value, type) and issubclass(state_value, PydanticBaseModel): - return FlowStateDefinition(type="pydantic", ref=_state_ref(state_value)) + return FlowPydanticStateDefinition(ref=_state_ref(state_value)) if isinstance(state_value, PydanticBaseModel): - return FlowStateDefinition( - type="pydantic", + return FlowPydanticStateDefinition( ref=_state_ref(state_value), default=_serialize_static_value(state_value, diagnostics, "state.default"), ) @@ -201,7 +203,7 @@ def _build_state_definition( message=f"could not serialize state type {_object_ref(state_value)}", ) ) - return FlowStateDefinition(type="unknown", ref=_state_ref(state_value)) + return FlowUnknownStateDefinition(ref=_state_ref(state_value)) def _build_config_definition( diff --git a/lib/crewai/src/crewai/flow/flow_definition.py b/lib/crewai/src/crewai/flow/flow_definition.py index b8a32d68e..29c561486 100644 --- a/lib/crewai/src/crewai/flow/flow_definition.py +++ b/lib/crewai/src/crewai/flow/flow_definition.py @@ -12,7 +12,7 @@ from __future__ import annotations import json import logging import re -from typing import Any, Literal as TypingLiteral +from typing import Annotated, Any, Literal as TypingLiteral, TypeAlias from pydantic import ( BaseModel, @@ -46,14 +46,18 @@ __all__ = [ "FlowDefinition", "FlowDefinitionCondition", "FlowDefinitionDiagnostic", + "FlowDictStateDefinition", "FlowEachActionDefinition", "FlowEachInnerActionDefinition", "FlowExpressionActionDefinition", "FlowHumanFeedbackDefinition", + "FlowJsonSchemaStateDefinition", "FlowMethodDefinition", "FlowPersistenceDefinition", + "FlowPydanticStateDefinition", "FlowStateDefinition", "FlowToolActionDefinition", + "FlowUnknownStateDefinition", ] @@ -74,13 +78,114 @@ class FlowDefinitionDiagnostic(BaseModel): path: str | None = None -class FlowStateDefinition(BaseModel): - """Static description of a Flow state contract.""" +class FlowDictStateDefinition(BaseModel): + """Static description of a plain dictionary Flow state contract.""" - type: TypingLiteral["dict", "pydantic", "json_schema", "unknown"] = "dict" - ref: str | None = None - json_schema: dict[str, Any] | None = None - default: dict[str, Any] | None = None + model_config = ConfigDict(extra="forbid") + + type: TypingLiteral["dict"] = Field( + default="dict", + description="Plain dictionary state with optional default values.", + examples=["dict"], + ) + default: dict[str, Any] | None = Field( + default=None, + description="Default state values applied before kickoff inputs.", + examples=[{"topic": "AI agents", "limit": 3}], + ) + + +class FlowPydanticStateDefinition(BaseModel): + """Static description of an importable Pydantic Flow state contract.""" + + model_config = ConfigDict(extra="forbid") + + type: TypingLiteral["pydantic"] = Field( + default="pydantic", + description="Importable Pydantic model used as the Flow state type.", + examples=["pydantic"], + ) + ref: str | None = Field( + default=None, + description="Import reference for the state model, formatted as module:qualname.", + examples=["my_project.flows:ResearchState"], + ) + json_schema: dict[str, Any] | None = Field( + default=None, + description=( + "Fallback JSON Schema used when the Pydantic state ref is unavailable." + ), + examples=[ + { + "type": "object", + "properties": {"topic": {"type": "string"}}, + "required": ["topic"], + } + ], + ) + default: dict[str, Any] | None = Field( + default=None, + description="Default state values applied before kickoff inputs.", + examples=[{"topic": "AI agents", "limit": 3}], + ) + + +class FlowJsonSchemaStateDefinition(BaseModel): + """Static description of an inline JSON Schema Flow state contract.""" + + model_config = ConfigDict(extra="forbid") + + type: TypingLiteral["json_schema"] = Field( + default="json_schema", + description="Inline JSON Schema used as the Flow state contract.", + examples=["json_schema"], + ) + json_schema: dict[str, Any] = Field( + description="JSON Schema used to validate and document flow state.", + examples=[ + { + "type": "object", + "properties": {"topic": {"type": "string"}}, + "required": ["topic"], + } + ], + ) + default: dict[str, Any] | None = Field( + default=None, + description="Default state values applied before kickoff inputs.", + examples=[{"topic": "AI agents", "limit": 3}], + ) + + +class FlowUnknownStateDefinition(BaseModel): + """Static description of a state contract that could not be serialized.""" + + model_config = ConfigDict(extra="forbid") + + type: TypingLiteral["unknown"] = Field( + default="unknown", + description="Unknown state representation; runtime falls back to dictionary state.", + examples=["unknown"], + ) + ref: str | None = Field( + default=None, + description="Best-effort import reference for the unknown state type.", + examples=["my_project.flows:CustomState"], + ) + default: dict[str, Any] | None = Field( + default=None, + description="Default state values applied before kickoff inputs.", + examples=[{"topic": "AI agents", "limit": 3}], + ) + + +FlowStateDefinition: TypeAlias = Annotated[ + FlowDictStateDefinition + | FlowPydanticStateDefinition + | FlowJsonSchemaStateDefinition + | FlowUnknownStateDefinition, + Field(discriminator="type"), +] class FlowConfigDefinition(BaseModel): diff --git a/lib/crewai/src/crewai/flow/runtime/__init__.py b/lib/crewai/src/crewai/flow/runtime/__init__.py index e91de05d3..9451ecf49 100644 --- a/lib/crewai/src/crewai/flow/runtime/__init__.py +++ b/lib/crewai/src/crewai/flow/runtime/__init__.py @@ -193,26 +193,24 @@ def _build_definition_state_model( kwargs = dict(state_definition.default or {}) model_class: type[BaseModel] | None = None - if state_definition.ref: + state_ref = getattr(state_definition, "ref", None) + if state_ref: try: - resolved: Any = resolve_ref(state_definition.ref, field="state") + resolved: Any = resolve_ref(state_ref, field="state") except Exception: - logger.warning( - "Could not import state ref %r", state_definition.ref, exc_info=True - ) + logger.warning("Could not import state ref %r", state_ref, exc_info=True) else: if isinstance(resolved, type) and issubclass(resolved, BaseModel): model_class = resolved else: - logger.warning( - "State ref %r is not a pydantic model", state_definition.ref - ) + logger.warning("State ref %r is not a pydantic model", state_ref) - if model_class is None and state_definition.json_schema: + json_schema = getattr(state_definition, "json_schema", None) + if model_class is None and json_schema: from crewai.utilities.pydantic_schema_utils import create_model_from_schema try: - model_class = create_model_from_schema(state_definition.json_schema) + model_class = create_model_from_schema(json_schema) except Exception: logger.warning( "Could not build a state model from the declared json_schema", diff --git a/lib/crewai/tests/test_flow_definition.py b/lib/crewai/tests/test_flow_definition.py index 946ebd336..dc2bd1c37 100644 --- a/lib/crewai/tests/test_flow_definition.py +++ b/lib/crewai/tests/test_flow_definition.py @@ -8,7 +8,7 @@ from pathlib import Path from typing import Annotated, Literal import pytest -from pydantic import BaseModel +from pydantic import BaseModel, ValidationError import crewai.flow.dsl as flow_dsl import crewai.flow.flow_definition as flow_definition @@ -45,19 +45,51 @@ def test_flow_public_exports_are_explicit(): "FlowDefinition", "FlowDefinitionCondition", "FlowDefinitionDiagnostic", + "FlowDictStateDefinition", "FlowEachActionDefinition", "FlowEachInnerActionDefinition", "FlowExpressionActionDefinition", "FlowHumanFeedbackDefinition", + "FlowJsonSchemaStateDefinition", "FlowMethodDefinition", "FlowPersistenceDefinition", + "FlowPydanticStateDefinition", "FlowStateDefinition", "FlowToolActionDefinition", + "FlowUnknownStateDefinition", } assert "build_flow_structure" in flow_visualization.__all__ assert "calculate_node_levels" not in flow_visualization.__all__ +def test_flow_state_definition_uses_discriminated_branches(): + definition = flow_definition.FlowDefinition.model_validate( + { + "name": "TypedStateFlow", + "state": { + "type": "json_schema", + "json_schema": {"type": "object"}, + }, + } + ) + + assert isinstance( + definition.state, + flow_definition.FlowJsonSchemaStateDefinition, + ) + + with pytest.raises(ValidationError, match="extra_forbidden"): + flow_definition.FlowDefinition.model_validate( + { + "name": "InvalidStateFlow", + "state": { + "type": "dict", + "ref": "my_project.flows:ResearchState", + }, + } + ) + + def test_condition_combinators_return_nested_runtime_tree(): condition = and_("event_a", "event_b", or_("event_c"))