Compare commits

...

1 Commit

Author SHA1 Message Date
Vinicius Brasil
7bb9bc7e1a Discriminate FlowDefinition state types (#6196)
Some checks are pending
CodeQL Advanced / Analyze (actions) (push) Waiting to run
CodeQL Advanced / Analyze (python) (push) Waiting to run
Check Documentation Broken Links / Check broken links (push) Waiting to run
Vulnerability Scan / pip-audit (push) Waiting to run
Replace the single FlowStateDefinition model with a `type`-discriminated
union of FlowDictStateDefinition, FlowPydanticStateDefinition,
FlowJsonSchemaStateDefinition, and FlowUnknownStateDefinition.

Each branch only carries the fields it actually uses and forbids extras,
so an invalid combination like a `dict` state with a `ref` now fails
validation instead of being silently accepted. The runtime now reads `ref`
and `json_schema` via `getattr(..., None)`, since those fields no longer
exist on every branch of the union.

```yaml
state:
  type: json_schema
  json_schema:
    type: object
    properties:
      topic:
        type: string
```

Co-authored-by: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-16 21:31:07 -07:00
4 changed files with 160 additions and 23 deletions

View File

@@ -15,10 +15,13 @@ from crewai.flow.flow_definition import (
FlowConversationalRouterDefinition,
FlowDefinition,
FlowDefinitionDiagnostic,
FlowDictStateDefinition,
FlowHumanFeedbackDefinition,
FlowMethodDefinition,
FlowPersistenceDefinition,
FlowPydanticStateDefinition,
FlowStateDefinition,
FlowUnknownStateDefinition,
_object_ref,
)
from crewai.flow.flow_wrappers import (
@@ -185,12 +188,11 @@ def _build_state_definition(
default = None
if isinstance(state_value, dict):
default = _serialize_static_value(state_value, diagnostics, "state.default")
return FlowStateDefinition(type="dict", default=default)
return FlowDictStateDefinition(default=default)
if isinstance(state_value, type) and issubclass(state_value, PydanticBaseModel):
return FlowStateDefinition(type="pydantic", ref=_state_ref(state_value))
return FlowPydanticStateDefinition(ref=_state_ref(state_value))
if isinstance(state_value, PydanticBaseModel):
return FlowStateDefinition(
type="pydantic",
return FlowPydanticStateDefinition(
ref=_state_ref(state_value),
default=_serialize_static_value(state_value, diagnostics, "state.default"),
)
@@ -201,7 +203,7 @@ def _build_state_definition(
message=f"could not serialize state type {_object_ref(state_value)}",
)
)
return FlowStateDefinition(type="unknown", ref=_state_ref(state_value))
return FlowUnknownStateDefinition(ref=_state_ref(state_value))
def _build_config_definition(

View File

@@ -12,7 +12,7 @@ from __future__ import annotations
import json
import logging
import re
from typing import Any, Literal as TypingLiteral
from typing import Annotated, Any, Literal as TypingLiteral, TypeAlias
from pydantic import (
BaseModel,
@@ -46,14 +46,18 @@ __all__ = [
"FlowDefinition",
"FlowDefinitionCondition",
"FlowDefinitionDiagnostic",
"FlowDictStateDefinition",
"FlowEachActionDefinition",
"FlowEachInnerActionDefinition",
"FlowExpressionActionDefinition",
"FlowHumanFeedbackDefinition",
"FlowJsonSchemaStateDefinition",
"FlowMethodDefinition",
"FlowPersistenceDefinition",
"FlowPydanticStateDefinition",
"FlowStateDefinition",
"FlowToolActionDefinition",
"FlowUnknownStateDefinition",
]
@@ -74,13 +78,114 @@ class FlowDefinitionDiagnostic(BaseModel):
path: str | None = None
class FlowStateDefinition(BaseModel):
"""Static description of a Flow state contract."""
class FlowDictStateDefinition(BaseModel):
"""Static description of a plain dictionary Flow state contract."""
type: TypingLiteral["dict", "pydantic", "json_schema", "unknown"] = "dict"
ref: str | None = None
json_schema: dict[str, Any] | None = None
default: dict[str, Any] | None = None
model_config = ConfigDict(extra="forbid")
type: TypingLiteral["dict"] = Field(
default="dict",
description="Plain dictionary state with optional default values.",
examples=["dict"],
)
default: dict[str, Any] | None = Field(
default=None,
description="Default state values applied before kickoff inputs.",
examples=[{"topic": "AI agents", "limit": 3}],
)
class FlowPydanticStateDefinition(BaseModel):
"""Static description of an importable Pydantic Flow state contract."""
model_config = ConfigDict(extra="forbid")
type: TypingLiteral["pydantic"] = Field(
default="pydantic",
description="Importable Pydantic model used as the Flow state type.",
examples=["pydantic"],
)
ref: str | None = Field(
default=None,
description="Import reference for the state model, formatted as module:qualname.",
examples=["my_project.flows:ResearchState"],
)
json_schema: dict[str, Any] | None = Field(
default=None,
description=(
"Fallback JSON Schema used when the Pydantic state ref is unavailable."
),
examples=[
{
"type": "object",
"properties": {"topic": {"type": "string"}},
"required": ["topic"],
}
],
)
default: dict[str, Any] | None = Field(
default=None,
description="Default state values applied before kickoff inputs.",
examples=[{"topic": "AI agents", "limit": 3}],
)
class FlowJsonSchemaStateDefinition(BaseModel):
"""Static description of an inline JSON Schema Flow state contract."""
model_config = ConfigDict(extra="forbid")
type: TypingLiteral["json_schema"] = Field(
default="json_schema",
description="Inline JSON Schema used as the Flow state contract.",
examples=["json_schema"],
)
json_schema: dict[str, Any] = Field(
description="JSON Schema used to validate and document flow state.",
examples=[
{
"type": "object",
"properties": {"topic": {"type": "string"}},
"required": ["topic"],
}
],
)
default: dict[str, Any] | None = Field(
default=None,
description="Default state values applied before kickoff inputs.",
examples=[{"topic": "AI agents", "limit": 3}],
)
class FlowUnknownStateDefinition(BaseModel):
"""Static description of a state contract that could not be serialized."""
model_config = ConfigDict(extra="forbid")
type: TypingLiteral["unknown"] = Field(
default="unknown",
description="Unknown state representation; runtime falls back to dictionary state.",
examples=["unknown"],
)
ref: str | None = Field(
default=None,
description="Best-effort import reference for the unknown state type.",
examples=["my_project.flows:CustomState"],
)
default: dict[str, Any] | None = Field(
default=None,
description="Default state values applied before kickoff inputs.",
examples=[{"topic": "AI agents", "limit": 3}],
)
FlowStateDefinition: TypeAlias = Annotated[
FlowDictStateDefinition
| FlowPydanticStateDefinition
| FlowJsonSchemaStateDefinition
| FlowUnknownStateDefinition,
Field(discriminator="type"),
]
class FlowConfigDefinition(BaseModel):

View File

@@ -193,26 +193,24 @@ def _build_definition_state_model(
kwargs = dict(state_definition.default or {})
model_class: type[BaseModel] | None = None
if state_definition.ref:
state_ref = getattr(state_definition, "ref", None)
if state_ref:
try:
resolved: Any = resolve_ref(state_definition.ref, field="state")
resolved: Any = resolve_ref(state_ref, field="state")
except Exception:
logger.warning(
"Could not import state ref %r", state_definition.ref, exc_info=True
)
logger.warning("Could not import state ref %r", state_ref, exc_info=True)
else:
if isinstance(resolved, type) and issubclass(resolved, BaseModel):
model_class = resolved
else:
logger.warning(
"State ref %r is not a pydantic model", state_definition.ref
)
logger.warning("State ref %r is not a pydantic model", state_ref)
if model_class is None and state_definition.json_schema:
json_schema = getattr(state_definition, "json_schema", None)
if model_class is None and json_schema:
from crewai.utilities.pydantic_schema_utils import create_model_from_schema
try:
model_class = create_model_from_schema(state_definition.json_schema)
model_class = create_model_from_schema(json_schema)
except Exception:
logger.warning(
"Could not build a state model from the declared json_schema",

View File

@@ -8,7 +8,7 @@ from pathlib import Path
from typing import Annotated, Literal
import pytest
from pydantic import BaseModel
from pydantic import BaseModel, ValidationError
import crewai.flow.dsl as flow_dsl
import crewai.flow.flow_definition as flow_definition
@@ -45,19 +45,51 @@ def test_flow_public_exports_are_explicit():
"FlowDefinition",
"FlowDefinitionCondition",
"FlowDefinitionDiagnostic",
"FlowDictStateDefinition",
"FlowEachActionDefinition",
"FlowEachInnerActionDefinition",
"FlowExpressionActionDefinition",
"FlowHumanFeedbackDefinition",
"FlowJsonSchemaStateDefinition",
"FlowMethodDefinition",
"FlowPersistenceDefinition",
"FlowPydanticStateDefinition",
"FlowStateDefinition",
"FlowToolActionDefinition",
"FlowUnknownStateDefinition",
}
assert "build_flow_structure" in flow_visualization.__all__
assert "calculate_node_levels" not in flow_visualization.__all__
def test_flow_state_definition_uses_discriminated_branches():
definition = flow_definition.FlowDefinition.model_validate(
{
"name": "TypedStateFlow",
"state": {
"type": "json_schema",
"json_schema": {"type": "object"},
},
}
)
assert isinstance(
definition.state,
flow_definition.FlowJsonSchemaStateDefinition,
)
with pytest.raises(ValidationError, match="extra_forbidden"):
flow_definition.FlowDefinition.model_validate(
{
"name": "InvalidStateFlow",
"state": {
"type": "dict",
"ref": "my_project.flows:ResearchState",
},
}
)
def test_condition_combinators_return_nested_runtime_tree():
condition = and_("event_a", "event_b", or_("event_c"))