Build flow state from FlowDefinition

Definition-driven flows previously always started with a bare dict
state.
This commit is contained in:
Vinicius Brasil
2026-06-10 15:46:24 -07:00
parent 2652caea2d
commit d029a5cd92
5 changed files with 339 additions and 27 deletions

View File

@@ -52,8 +52,9 @@ class FlowDefinitionDiagnostic(BaseModel):
class FlowStateDefinition(BaseModel):
"""Static description of a Flow state contract."""
type: TypingLiteral["dict", "pydantic", "unknown"] = "dict"
type: TypingLiteral["dict", "pydantic", "json_schema", "unknown"] = "dict"
ref: str | None = None
json_schema: dict[str, Any] | None = None
default: Any = None

View File

@@ -96,6 +96,7 @@ from crewai.flow.flow_definition import (
FlowDefinition,
FlowDefinitionCondition,
FlowMethodDefinition,
FlowStateDefinition,
)
from crewai.flow.flow_wrappers import (
FlowMethod,
@@ -188,6 +189,54 @@ def _resolve_handler(ref: str) -> Callable[..., Any]:
return cast(Callable[..., Any], target)
def _build_definition_state_model(
    state_definition: FlowStateDefinition,
) -> BaseModel | None:
    """Instantiate a pydantic state object for a definition-declared state.

    The model class is looked up in two steps: the importable ``ref`` is
    tried first, and the inline ``json_schema`` is used only when ``ref``
    did not yield a pydantic model. Failures in either step are logged as
    warnings rather than raised.

    Args:
        state_definition: The state section of a flow definition.

    Returns:
        A state instance seeded with ``state_definition.default`` (when that
        default is a dict), or ``None`` when no model class could be
        obtained from either ``ref`` or ``json_schema``.
    """
    declared_default = state_definition.default
    init_values = dict(declared_default) if isinstance(declared_default, dict) else {}

    state_cls: type[BaseModel] | None = None

    # Step 1: the importable reference wins when it resolves to a model class.
    if state_definition.ref:
        try:
            candidate = _resolve_handler(state_definition.ref)
        except Exception:
            logger.warning(
                "Could not import state ref %r", state_definition.ref, exc_info=True
            )
        else:
            is_model_class = isinstance(candidate, type) and issubclass(
                candidate, BaseModel
            )
            if is_model_class:
                state_cls = candidate
            else:
                logger.warning(
                    "State ref %r is not a pydantic model", state_definition.ref
                )

    # Step 2: fall back to generating a model from the inline JSON schema.
    if state_cls is None and state_definition.json_schema:
        from crewai.utilities.pydantic_schema_utils import create_model_from_schema

        try:
            state_cls = create_model_from_schema(state_definition.json_schema)
        except Exception:
            logger.warning(
                "Could not build a state model from the declared json_schema",
                exc_info=True,
            )

    if state_cls is None:
        return None

    if issubclass(state_cls, FlowState):
        return state_cls(**init_values)

    # The resolved model does not derive from FlowState, so mix FlowState in
    # ahead of it before instantiating.
    class StateWithId(FlowState, state_cls):  # type: ignore[misc, valid-type]
        pass

    return StateWithId(**init_values)
def _iter_condition_events(condition: FlowDefinitionCondition) -> Iterator[str]:
if isinstance(condition, str):
yield condition
@@ -718,10 +767,6 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
"""Build a runnable Flow directly from a definition; no subclass required."""
return cls.model_validate({}, context={"flow_definition": definition})
@property
def _flow_name(self) -> str:
return self.name or self._definition.name
def _start_method_names(self) -> list[FlowMethodName]:
return [
FlowMethodName(method_name)
@@ -959,6 +1004,8 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
self._initialize_runtime_extension_attrs()
self._definition = definition or type(self).flow_definition()
if self.name and self.name != self._definition.name:
self._definition = self._definition.model_copy(update={"name": self.name})
methods = (
self._handler_bound_methods()
if definition is not None
@@ -979,7 +1026,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
self,
FlowCreatedEvent(
type="flow_created",
flow_name=self._flow_name,
flow_name=self._definition.name,
),
)
@@ -989,7 +1036,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
if self.memory is None and not getattr(self, "_skip_auto_memory", False):
from crewai.memory.utils import sanitize_scope_name
flow_name = sanitize_scope_name(self._flow_name)
flow_name = sanitize_scope_name(self._definition.name)
self.memory = Memory(root_scope=f"/flow/{flow_name}")
self._methods.update(methods)
@@ -1427,7 +1474,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
self,
FlowStartedEvent(
type="flow_started",
flow_name=self._flow_name,
flow_name=self._definition.name,
inputs=None,
),
)
@@ -1503,7 +1550,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
self,
MethodExecutionFinishedEvent(
type="method_execution_finished",
flow_name=self._flow_name,
flow_name=self._definition.name,
method_name=context.method_name,
result=collapsed_outcome if emit else result,
state=self._state,
@@ -1557,7 +1604,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
self,
FlowPausedEvent(
type="flow_paused",
flow_name=self._flow_name,
flow_name=self._definition.name,
flow_id=e.context.flow_id,
method_name=e.context.method_name,
state=self._copy_and_serialize_state(),
@@ -1588,7 +1635,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
self,
FlowFinishedEvent(
type="flow_finished",
flow_name=self._flow_name,
flow_name=self._definition.name,
result=final_result,
state=self._copy_and_serialize_state(),
),
@@ -1654,7 +1701,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
return cast(T, {"id": str(uuid4())})
if init_state is None:
return cast(T, {"id": str(uuid4())})
return cast(T, self._create_definition_state())
if isinstance(init_state, type):
state_class = init_state
@@ -1696,6 +1743,34 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
f"Initial state must be dict or BaseModel, got {type(self.initial_state)}"
)
def _create_definition_state(self) -> dict[str, Any] | BaseModel:
    """Create the initial state described by this flow's definition.

    Resolution order:

    * No ``state`` section: a dict holding only a fresh ``id``.
    * ``pydantic`` / ``json_schema``: the model instance produced by
      ``_build_definition_state_model``; if that yields ``None`` an error
      is logged and the dict fallback below is used.
    * ``unknown``: a warning is logged and the dict fallback is used.
    * ``dict`` (and every fallback): a copy of the declared default with
      an ``id`` added when the default does not already provide one.

    Returns:
        Either a pydantic state instance or a plain ``dict`` state.
    """
    from copy import deepcopy

    state_definition = self._definition.state
    if state_definition is None:
        return {"id": str(uuid4())}

    if state_definition.type in ("pydantic", "json_schema"):
        state = _build_definition_state_model(state_definition)
        if state is not None:
            return state
        logger.error(
            "Flow %r declares %s state but neither ref nor json_schema "
            "produced a model; falling back to dict state",
            self._definition.name,
            state_definition.type,
        )
    elif state_definition.type == "unknown":
        logger.warning(
            "Flow %r declares state of unknown type; falling back to dict state",
            self._definition.name,
        )

    # Deep-copy the declared default: a shallow dict() copy would share
    # nested lists/dicts with the definition, so one flow mutating its state
    # would leak into the definition and into every flow built from it later.
    dict_state: dict[str, Any] = (
        deepcopy(dict(state_definition.default))
        if isinstance(state_definition.default, dict)
        else {}
    )
    if "id" not in dict_state:
        dict_state["id"] = str(uuid4())
    return dict_state
def _copy_state(self) -> T:
"""Create a copy of the current state.
@@ -2231,7 +2306,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
# explicit finalization call closes the batch.
started_event = FlowStartedEvent(
type="flow_started",
flow_name=self._flow_name,
flow_name=self._definition.name,
inputs=inputs,
)
future = crewai_event_bus.emit(self, started_event)
@@ -2323,7 +2398,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
self,
FlowPausedEvent(
type="flow_paused",
flow_name=self._flow_name,
flow_name=self._definition.name,
flow_id=e.context.flow_id,
method_name=e.context.method_name,
state=self._copy_and_serialize_state(),
@@ -2373,7 +2448,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
self,
FlowFinishedEvent(
type="flow_finished",
flow_name=self._flow_name,
flow_name=self._definition.name,
result=final_output,
state=self._copy_and_serialize_state(),
),
@@ -2459,7 +2534,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
MethodExecutionFinishedEvent,
MethodExecutionFailedEvent,
)
flow_name = self._flow_name
flow_name = self._definition.name
nodes = sorted(
(
n
@@ -2597,7 +2672,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
MethodExecutionStartedEvent(
type="method_execution_started",
method_name=method_name,
flow_name=self._flow_name,
flow_name=self._definition.name,
params=dumped_params,
state=self._copy_and_serialize_state(),
),
@@ -2649,7 +2724,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
finished_event = MethodExecutionFinishedEvent(
type="method_execution_finished",
method_name=method_name,
flow_name=self._flow_name,
flow_name=self._definition.name,
state=self._copy_and_serialize_state(),
result=result,
)
@@ -2678,7 +2753,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
MethodExecutionPausedEvent(
type="method_execution_paused",
method_name=method_name,
flow_name=self._flow_name,
flow_name=self._definition.name,
state=self._copy_and_serialize_state(),
flow_id=e.context.flow_id,
message=e.context.message,
@@ -2694,7 +2769,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
MethodExecutionFailedEvent(
type="method_execution_failed",
method_name=method_name,
flow_name=self._flow_name,
flow_name=self._definition.name,
error=e,
),
)
@@ -3101,7 +3176,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
self,
FlowInputRequestedEvent(
type="flow_input_requested",
flow_name=self._flow_name,
flow_name=self._definition.name,
method_name=method_name,
message=message,
metadata=metadata,
@@ -3168,7 +3243,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
self,
FlowInputReceivedEvent(
type="flow_input_received",
flow_name=self._flow_name,
flow_name=self._definition.name,
method_name=method_name,
message=message,
response=response,
@@ -3206,7 +3281,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
self,
HumanFeedbackRequestedEvent(
type="human_feedback_requested",
flow_name=self._flow_name,
flow_name=self._definition.name,
method_name="", # Will be set by decorator if needed
output=output,
message=message,
@@ -3235,7 +3310,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
self,
HumanFeedbackReceivedEvent(
type="human_feedback_received",
flow_name=self._flow_name,
flow_name=self._definition.name,
method_name="", # Will be set by decorator if needed
feedback=feedback,
outcome=None, # Will be determined after collapsing
@@ -3410,7 +3485,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
self,
FlowPlotEvent(
type="flow_plot",
flow_name=self._flow_name,
flow_name=self._definition.name,
),
)
structure = build_flow_structure(cast(Any, self))

View File

@@ -999,7 +999,11 @@ def _json_schema_to_pydantic_field(
if examples:
schema_extra["examples"] = examples
default = ... if is_required else None
default = (
json_schema["default"]
if "default" in json_schema
else (... if is_required else None)
)
if isinstance(type_, type) and issubclass(type_, (int, float)):
if "minimum" in json_schema:

View File

@@ -1157,6 +1157,25 @@ def test_flow_name():
assert flow.name == "MyFlow"
# Checks that a Flow subclass declaring a custom `name` reports that name,
# rather than its Python class name, in the events it emits.
# A handler is registered for FlowStartedEvent, the flow is kicked off, and the
# first captured event must carry flow_name == "PublicName".
def test_flow_custom_name_overrides_class_name_in_events():
class InternalFlowClass(Flow):
name = "PublicName"
@start()
def begin(self):
return "done"
received = []
@crewai_event_bus.on(FlowStartedEvent)
def handle(source, event):
received.append(event)
InternalFlowClass().kickoff()
assert received[0].flow_name == "PublicName"
def test_nested_and_or_conditions():
"""Test nested conditions like or_(and_(A, B), and_(C, D)).

View File

@@ -8,6 +8,7 @@ from crewai.events.types.flow_events import (
MethodExecutionStartedEvent,
)
from crewai.flow import Flow, and_, listen, or_, router, start
from crewai.flow.flow import FlowState
from crewai.flow.flow_definition import FlowDefinition
@@ -163,6 +164,142 @@ methods:
"""
# Pydantic state fixture used by the state-definition tests.
# Both fields have defaults (count=0, label="none"), so it can be instantiated
# with no arguments.
class CounterState(FlowState):
count: int = 0
label: str = "none"
# Class-based flow over CounterState whose methods are also referenced as
# handlers by the YAML fixtures in this module.
# `begin` increments state.count and returns it; `finish` listens to `begin`,
# stores "count=<n>" in state.label and returns that label.
class PydanticStateFlow(Flow[CounterState]):
@start()
def begin(self):
self.state.count += 1
return self.state.count
@listen(begin)
def finish(self):
self.state.label = f"count={self.state.count}"
return self.state.label
PYDANTIC_STATE_YAML = f"""
schema: crewai.flow/v1
name: PydanticStateFlow
state:
type: pydantic
ref: {__name__}:CounterState
methods:
begin:
handler: {__name__}:PydanticStateFlow.begin
start: true
finish:
handler: {__name__}:PydanticStateFlow.finish
listen: begin
"""
PYDANTIC_STATE_OVERLAY_YAML = f"""
schema: crewai.flow/v1
name: PydanticStateFlow
state:
type: pydantic
ref: {__name__}:CounterState
default:
count: 5
methods:
begin:
handler: {__name__}:PydanticStateFlow.begin
start: true
finish:
handler: {__name__}:PydanticStateFlow.finish
listen: begin
"""
JSON_SCHEMA_STATE_YAML = f"""
schema: crewai.flow/v1
name: JsonSchemaStateFlow
state:
type: json_schema
json_schema:
title: CounterState
type: object
properties:
count:
type: integer
default: 0
label:
type: string
default: none
methods:
begin:
handler: {__name__}:PydanticStateFlow.begin
start: true
finish:
handler: {__name__}:PydanticStateFlow.finish
listen: begin
"""
PYDANTIC_REF_WITH_SCHEMA_FALLBACK_YAML = f"""
schema: crewai.flow/v1
name: SchemaFallbackFlow
state:
type: pydantic
ref: definitely_not_a_module_xyz:MissingState
json_schema:
title: CounterState
type: object
properties:
count:
type: integer
default: 0
label:
type: string
default: none
methods:
begin:
handler: {__name__}:PydanticStateFlow.begin
start: true
finish:
handler: {__name__}:PydanticStateFlow.finish
listen: begin
"""
UNRESOLVABLE_STATE_YAML = f"""
schema: crewai.flow/v1
name: UnresolvableStateFlow
state:
type: pydantic
ref: definitely_not_a_module_xyz:MissingState
methods:
begin:
handler: {__name__}:ChainFlow.begin
start: true
"""
DICT_STATE_YAML = f"""
schema: crewai.flow/v1
name: DictStateFlow
state:
type: dict
default:
count: 5
methods:
begin:
handler: {__name__}:ChainFlow.begin
start: true
"""
UNKNOWN_STATE_YAML = f"""
schema: crewai.flow/v1
name: UnknownStateFlow
state:
type: unknown
ref: somewhere:Something
methods:
begin:
handler: {__name__}:ChainFlow.begin
start: true
"""
def _run_with_events(flow, inputs=None):
events = []
with crewai_event_bus.scoped_handlers():
@@ -183,7 +320,7 @@ def _run_with_events(flow, inputs=None):
def _state_without_id(flow):
snapshot = dict(flow.state)
snapshot = dict(flow.state.model_dump())
snapshot.pop("id", None)
return snapshot
@@ -293,3 +430,79 @@ def test_flow_definition_stamps_handler_refs():
assert definition.methods["begin"].handler == f"{__name__}:ChainFlow.begin"
assert definition.methods["shout"].handler == f"{__name__}:ChainFlow.shout"
# A definition whose state is `type: pydantic` with a `ref` to CounterState
# must end with the expected result and state: "count=1", count == 1, matching
# label, and a truthy id.
# NOTE(review): assert_parity is defined outside this view; presumably it runs
# the class-based flow and the YAML-defined flow and compares them - confirm.
def test_pydantic_state_from_ref_parity():
flow, result = assert_parity(PydanticStateFlow, PYDANTIC_STATE_YAML)
assert result == "count=1"
assert flow.state.count == 1
assert flow.state.label == "count=1"
assert flow.state.id
# The `default` mapping declared next to a pydantic `ref` seeds the state:
# PYDANTIC_STATE_OVERLAY_YAML declares count: 5, and `begin` increments it
# once, so the flow must finish with count == 6 and result "count=6".
def test_pydantic_state_default_overlay():
flow = Flow.from_definition(FlowDefinition.from_yaml(PYDANTIC_STATE_OVERLAY_YAML))
result = flow.kickoff()
assert result == "count=6"
assert flow.state.count == 6
# A `type: json_schema` state with an inline schema (count default 0, label
# default "none") must yield attribute-accessible state: after kickoff the
# count is 1, the label is "count=1", and the state exposes a truthy id.
def test_json_schema_state():
flow = Flow.from_definition(FlowDefinition.from_yaml(JSON_SCHEMA_STATE_YAML))
result = flow.kickoff()
assert result == "count=1"
assert flow.state.count == 1
assert flow.state.label == "count=1"
assert flow.state.id
# Kickoff inputs are validated against the schema-derived state: passing a
# non-numeric string for the integer `count` field must raise ValueError with a
# message matching "Invalid inputs".
def test_json_schema_state_validates_inputs():
flow = Flow.from_definition(FlowDefinition.from_yaml(JSON_SCHEMA_STATE_YAML))
with pytest.raises(ValueError, match="Invalid inputs"):
flow.kickoff(inputs={"count": "not-a-number"})
# The fixture declares a `ref` pointing at a module that does not exist
# together with an inline json_schema. The flow must still run with
# attribute-accessible state (count == 1, result "count=1"), i.e. the schema is
# used when the ref cannot be imported.
def test_pydantic_state_falls_back_to_json_schema_when_ref_unimportable():
flow = Flow.from_definition(
FlowDefinition.from_yaml(PYDANTIC_REF_WITH_SCHEMA_FALLBACK_YAML)
)
result = flow.kickoff()
assert result == "count=1"
assert flow.state.count == 1
# With an unimportable `ref` and no json_schema, building the flow must log a
# message containing "falling back to dict state" (captured at ERROR level) and
# the flow must still run on item-accessible state that carries an id.
# NOTE(review): ChainFlow is defined outside this view; the assertions imply
# its `begin` returns "hello" and sets state["begin_ran"] - confirm.
def test_pydantic_state_without_ref_or_schema_falls_back_to_dict(caplog):
with caplog.at_level("ERROR"):
flow = Flow.from_definition(FlowDefinition.from_yaml(UNRESOLVABLE_STATE_YAML))
assert "falling back to dict state" in caplog.text
result = flow.kickoff()
assert result == "hello"
assert flow.state["begin_ran"] is True
assert flow.state["id"]
# Dict state is built from a copy of the declared default plus a generated id.
# Two flows are created from the same definition: the first is run and gains
# "begin_ran"; the second must still start from count == 5 without "begin_ran"
# and with a different id, and the definition's own default must remain
# exactly {"count": 5}.
def test_dict_state_is_a_copy_of_default_plus_id():
definition = FlowDefinition.from_yaml(DICT_STATE_YAML)
flow = Flow.from_definition(definition)
assert flow.state["count"] == 5
assert flow.state["id"]
flow.kickoff()
assert flow.state["begin_ran"] is True
second = Flow.from_definition(definition)
assert second.state["count"] == 5
assert "begin_ran" not in second.state.model_dump()
assert second.state["id"] != flow.state["id"]
assert definition.state.default == {"count": 5}
# A state declared as `type: unknown` must log a message containing
# "falling back to dict state" (captured at WARNING level) while the flow is
# built, and the flow must then run normally on item-accessible state.
def test_unknown_state_type_falls_back_to_dict(caplog):
with caplog.at_level("WARNING"):
flow = Flow.from_definition(FlowDefinition.from_yaml(UNKNOWN_STATE_YAML))
assert "falling back to dict state" in caplog.text
result = flow.kickoff()
assert result == "hello"
assert flow.state["begin_ran"] is True