mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-07-04 14:39:23 +00:00
Build flow state from FlowDefinition
Definition-driven flows previously always started with a bare dict state.
This commit is contained in:
@@ -52,8 +52,9 @@ class FlowDefinitionDiagnostic(BaseModel):
|
||||
class FlowStateDefinition(BaseModel):
|
||||
"""Static description of a Flow state contract."""
|
||||
|
||||
type: TypingLiteral["dict", "pydantic", "unknown"] = "dict"
|
||||
type: TypingLiteral["dict", "pydantic", "json_schema", "unknown"] = "dict"
|
||||
ref: str | None = None
|
||||
json_schema: dict[str, Any] | None = None
|
||||
default: Any = None
|
||||
|
||||
|
||||
|
||||
@@ -96,6 +96,7 @@ from crewai.flow.flow_definition import (
|
||||
FlowDefinition,
|
||||
FlowDefinitionCondition,
|
||||
FlowMethodDefinition,
|
||||
FlowStateDefinition,
|
||||
)
|
||||
from crewai.flow.flow_wrappers import (
|
||||
FlowMethod,
|
||||
@@ -188,6 +189,54 @@ def _resolve_handler(ref: str) -> Callable[..., Any]:
|
||||
return cast(Callable[..., Any], target)
|
||||
|
||||
|
||||
def _build_definition_state_model(
|
||||
state_definition: FlowStateDefinition,
|
||||
) -> BaseModel | None:
|
||||
kwargs = (
|
||||
dict(state_definition.default)
|
||||
if isinstance(state_definition.default, dict)
|
||||
else {}
|
||||
)
|
||||
|
||||
model_class: type[BaseModel] | None = None
|
||||
if state_definition.ref:
|
||||
try:
|
||||
resolved = _resolve_handler(state_definition.ref)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Could not import state ref %r", state_definition.ref, exc_info=True
|
||||
)
|
||||
else:
|
||||
if isinstance(resolved, type) and issubclass(resolved, BaseModel):
|
||||
model_class = resolved
|
||||
else:
|
||||
logger.warning(
|
||||
"State ref %r is not a pydantic model", state_definition.ref
|
||||
)
|
||||
|
||||
if model_class is None and state_definition.json_schema:
|
||||
from crewai.utilities.pydantic_schema_utils import create_model_from_schema
|
||||
|
||||
try:
|
||||
model_class = create_model_from_schema(state_definition.json_schema)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Could not build a state model from the declared json_schema",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
if model_class is None:
|
||||
return None
|
||||
|
||||
if not issubclass(model_class, FlowState):
|
||||
|
||||
class StateWithId(FlowState, model_class): # type: ignore[misc, valid-type]
|
||||
pass
|
||||
|
||||
model_class = StateWithId
|
||||
return model_class(**kwargs)
|
||||
|
||||
|
||||
def _iter_condition_events(condition: FlowDefinitionCondition) -> Iterator[str]:
|
||||
if isinstance(condition, str):
|
||||
yield condition
|
||||
@@ -718,10 +767,6 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
"""Build a runnable Flow directly from a definition; no subclass required."""
|
||||
return cls.model_validate({}, context={"flow_definition": definition})
|
||||
|
||||
@property
|
||||
def _flow_name(self) -> str:
|
||||
return self.name or self._definition.name
|
||||
|
||||
def _start_method_names(self) -> list[FlowMethodName]:
|
||||
return [
|
||||
FlowMethodName(method_name)
|
||||
@@ -959,6 +1004,8 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
self._initialize_runtime_extension_attrs()
|
||||
|
||||
self._definition = definition or type(self).flow_definition()
|
||||
if self.name and self.name != self._definition.name:
|
||||
self._definition = self._definition.model_copy(update={"name": self.name})
|
||||
methods = (
|
||||
self._handler_bound_methods()
|
||||
if definition is not None
|
||||
@@ -979,7 +1026,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
self,
|
||||
FlowCreatedEvent(
|
||||
type="flow_created",
|
||||
flow_name=self._flow_name,
|
||||
flow_name=self._definition.name,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -989,7 +1036,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
if self.memory is None and not getattr(self, "_skip_auto_memory", False):
|
||||
from crewai.memory.utils import sanitize_scope_name
|
||||
|
||||
flow_name = sanitize_scope_name(self._flow_name)
|
||||
flow_name = sanitize_scope_name(self._definition.name)
|
||||
self.memory = Memory(root_scope=f"/flow/{flow_name}")
|
||||
|
||||
self._methods.update(methods)
|
||||
@@ -1427,7 +1474,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
self,
|
||||
FlowStartedEvent(
|
||||
type="flow_started",
|
||||
flow_name=self._flow_name,
|
||||
flow_name=self._definition.name,
|
||||
inputs=None,
|
||||
),
|
||||
)
|
||||
@@ -1503,7 +1550,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
self,
|
||||
MethodExecutionFinishedEvent(
|
||||
type="method_execution_finished",
|
||||
flow_name=self._flow_name,
|
||||
flow_name=self._definition.name,
|
||||
method_name=context.method_name,
|
||||
result=collapsed_outcome if emit else result,
|
||||
state=self._state,
|
||||
@@ -1557,7 +1604,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
self,
|
||||
FlowPausedEvent(
|
||||
type="flow_paused",
|
||||
flow_name=self._flow_name,
|
||||
flow_name=self._definition.name,
|
||||
flow_id=e.context.flow_id,
|
||||
method_name=e.context.method_name,
|
||||
state=self._copy_and_serialize_state(),
|
||||
@@ -1588,7 +1635,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
self,
|
||||
FlowFinishedEvent(
|
||||
type="flow_finished",
|
||||
flow_name=self._flow_name,
|
||||
flow_name=self._definition.name,
|
||||
result=final_result,
|
||||
state=self._copy_and_serialize_state(),
|
||||
),
|
||||
@@ -1654,7 +1701,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
return cast(T, {"id": str(uuid4())})
|
||||
|
||||
if init_state is None:
|
||||
return cast(T, {"id": str(uuid4())})
|
||||
return cast(T, self._create_definition_state())
|
||||
|
||||
if isinstance(init_state, type):
|
||||
state_class = init_state
|
||||
@@ -1696,6 +1743,34 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
f"Initial state must be dict or BaseModel, got {type(self.initial_state)}"
|
||||
)
|
||||
|
||||
def _create_definition_state(self) -> dict[str, Any] | BaseModel:
|
||||
state_definition = self._definition.state
|
||||
if state_definition is None:
|
||||
return {"id": str(uuid4())}
|
||||
if state_definition.type in ("pydantic", "json_schema"):
|
||||
state = _build_definition_state_model(state_definition)
|
||||
if state is not None:
|
||||
return state
|
||||
logger.error(
|
||||
"Flow %r declares %s state but neither ref nor json_schema "
|
||||
"produced a model; falling back to dict state",
|
||||
self._definition.name,
|
||||
state_definition.type,
|
||||
)
|
||||
elif state_definition.type == "unknown":
|
||||
logger.warning(
|
||||
"Flow %r declares state of unknown type; falling back to dict state",
|
||||
self._definition.name,
|
||||
)
|
||||
dict_state: dict[str, Any] = (
|
||||
dict(state_definition.default)
|
||||
if isinstance(state_definition.default, dict)
|
||||
else {}
|
||||
)
|
||||
if "id" not in dict_state:
|
||||
dict_state["id"] = str(uuid4())
|
||||
return dict_state
|
||||
|
||||
def _copy_state(self) -> T:
|
||||
"""Create a copy of the current state.
|
||||
|
||||
@@ -2231,7 +2306,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
# explicit finalization call closes the batch.
|
||||
started_event = FlowStartedEvent(
|
||||
type="flow_started",
|
||||
flow_name=self._flow_name,
|
||||
flow_name=self._definition.name,
|
||||
inputs=inputs,
|
||||
)
|
||||
future = crewai_event_bus.emit(self, started_event)
|
||||
@@ -2323,7 +2398,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
self,
|
||||
FlowPausedEvent(
|
||||
type="flow_paused",
|
||||
flow_name=self._flow_name,
|
||||
flow_name=self._definition.name,
|
||||
flow_id=e.context.flow_id,
|
||||
method_name=e.context.method_name,
|
||||
state=self._copy_and_serialize_state(),
|
||||
@@ -2373,7 +2448,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
self,
|
||||
FlowFinishedEvent(
|
||||
type="flow_finished",
|
||||
flow_name=self._flow_name,
|
||||
flow_name=self._definition.name,
|
||||
result=final_output,
|
||||
state=self._copy_and_serialize_state(),
|
||||
),
|
||||
@@ -2459,7 +2534,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
MethodExecutionFinishedEvent,
|
||||
MethodExecutionFailedEvent,
|
||||
)
|
||||
flow_name = self._flow_name
|
||||
flow_name = self._definition.name
|
||||
nodes = sorted(
|
||||
(
|
||||
n
|
||||
@@ -2597,7 +2672,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
MethodExecutionStartedEvent(
|
||||
type="method_execution_started",
|
||||
method_name=method_name,
|
||||
flow_name=self._flow_name,
|
||||
flow_name=self._definition.name,
|
||||
params=dumped_params,
|
||||
state=self._copy_and_serialize_state(),
|
||||
),
|
||||
@@ -2649,7 +2724,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
finished_event = MethodExecutionFinishedEvent(
|
||||
type="method_execution_finished",
|
||||
method_name=method_name,
|
||||
flow_name=self._flow_name,
|
||||
flow_name=self._definition.name,
|
||||
state=self._copy_and_serialize_state(),
|
||||
result=result,
|
||||
)
|
||||
@@ -2678,7 +2753,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
MethodExecutionPausedEvent(
|
||||
type="method_execution_paused",
|
||||
method_name=method_name,
|
||||
flow_name=self._flow_name,
|
||||
flow_name=self._definition.name,
|
||||
state=self._copy_and_serialize_state(),
|
||||
flow_id=e.context.flow_id,
|
||||
message=e.context.message,
|
||||
@@ -2694,7 +2769,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
MethodExecutionFailedEvent(
|
||||
type="method_execution_failed",
|
||||
method_name=method_name,
|
||||
flow_name=self._flow_name,
|
||||
flow_name=self._definition.name,
|
||||
error=e,
|
||||
),
|
||||
)
|
||||
@@ -3101,7 +3176,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
self,
|
||||
FlowInputRequestedEvent(
|
||||
type="flow_input_requested",
|
||||
flow_name=self._flow_name,
|
||||
flow_name=self._definition.name,
|
||||
method_name=method_name,
|
||||
message=message,
|
||||
metadata=metadata,
|
||||
@@ -3168,7 +3243,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
self,
|
||||
FlowInputReceivedEvent(
|
||||
type="flow_input_received",
|
||||
flow_name=self._flow_name,
|
||||
flow_name=self._definition.name,
|
||||
method_name=method_name,
|
||||
message=message,
|
||||
response=response,
|
||||
@@ -3206,7 +3281,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
self,
|
||||
HumanFeedbackRequestedEvent(
|
||||
type="human_feedback_requested",
|
||||
flow_name=self._flow_name,
|
||||
flow_name=self._definition.name,
|
||||
method_name="", # Will be set by decorator if needed
|
||||
output=output,
|
||||
message=message,
|
||||
@@ -3235,7 +3310,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
self,
|
||||
HumanFeedbackReceivedEvent(
|
||||
type="human_feedback_received",
|
||||
flow_name=self._flow_name,
|
||||
flow_name=self._definition.name,
|
||||
method_name="", # Will be set by decorator if needed
|
||||
feedback=feedback,
|
||||
outcome=None, # Will be determined after collapsing
|
||||
@@ -3410,7 +3485,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
self,
|
||||
FlowPlotEvent(
|
||||
type="flow_plot",
|
||||
flow_name=self._flow_name,
|
||||
flow_name=self._definition.name,
|
||||
),
|
||||
)
|
||||
structure = build_flow_structure(cast(Any, self))
|
||||
|
||||
@@ -999,7 +999,11 @@ def _json_schema_to_pydantic_field(
|
||||
if examples:
|
||||
schema_extra["examples"] = examples
|
||||
|
||||
default = ... if is_required else None
|
||||
default = (
|
||||
json_schema["default"]
|
||||
if "default" in json_schema
|
||||
else (... if is_required else None)
|
||||
)
|
||||
|
||||
if isinstance(type_, type) and issubclass(type_, (int, float)):
|
||||
if "minimum" in json_schema:
|
||||
|
||||
@@ -1157,6 +1157,25 @@ def test_flow_name():
|
||||
assert flow.name == "MyFlow"
|
||||
|
||||
|
||||
def test_flow_custom_name_overrides_class_name_in_events():
|
||||
class InternalFlowClass(Flow):
|
||||
name = "PublicName"
|
||||
|
||||
@start()
|
||||
def begin(self):
|
||||
return "done"
|
||||
|
||||
received = []
|
||||
|
||||
@crewai_event_bus.on(FlowStartedEvent)
|
||||
def handle(source, event):
|
||||
received.append(event)
|
||||
|
||||
InternalFlowClass().kickoff()
|
||||
|
||||
assert received[0].flow_name == "PublicName"
|
||||
|
||||
|
||||
def test_nested_and_or_conditions():
|
||||
"""Test nested conditions like or_(and_(A, B), and_(C, D)).
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ from crewai.events.types.flow_events import (
|
||||
MethodExecutionStartedEvent,
|
||||
)
|
||||
from crewai.flow import Flow, and_, listen, or_, router, start
|
||||
from crewai.flow.flow import FlowState
|
||||
from crewai.flow.flow_definition import FlowDefinition
|
||||
|
||||
|
||||
@@ -163,6 +164,142 @@ methods:
|
||||
"""
|
||||
|
||||
|
||||
class CounterState(FlowState):
|
||||
count: int = 0
|
||||
label: str = "none"
|
||||
|
||||
|
||||
class PydanticStateFlow(Flow[CounterState]):
|
||||
@start()
|
||||
def begin(self):
|
||||
self.state.count += 1
|
||||
return self.state.count
|
||||
|
||||
@listen(begin)
|
||||
def finish(self):
|
||||
self.state.label = f"count={self.state.count}"
|
||||
return self.state.label
|
||||
|
||||
|
||||
PYDANTIC_STATE_YAML = f"""
|
||||
schema: crewai.flow/v1
|
||||
name: PydanticStateFlow
|
||||
state:
|
||||
type: pydantic
|
||||
ref: {__name__}:CounterState
|
||||
methods:
|
||||
begin:
|
||||
handler: {__name__}:PydanticStateFlow.begin
|
||||
start: true
|
||||
finish:
|
||||
handler: {__name__}:PydanticStateFlow.finish
|
||||
listen: begin
|
||||
"""
|
||||
|
||||
PYDANTIC_STATE_OVERLAY_YAML = f"""
|
||||
schema: crewai.flow/v1
|
||||
name: PydanticStateFlow
|
||||
state:
|
||||
type: pydantic
|
||||
ref: {__name__}:CounterState
|
||||
default:
|
||||
count: 5
|
||||
methods:
|
||||
begin:
|
||||
handler: {__name__}:PydanticStateFlow.begin
|
||||
start: true
|
||||
finish:
|
||||
handler: {__name__}:PydanticStateFlow.finish
|
||||
listen: begin
|
||||
"""
|
||||
|
||||
JSON_SCHEMA_STATE_YAML = f"""
|
||||
schema: crewai.flow/v1
|
||||
name: JsonSchemaStateFlow
|
||||
state:
|
||||
type: json_schema
|
||||
json_schema:
|
||||
title: CounterState
|
||||
type: object
|
||||
properties:
|
||||
count:
|
||||
type: integer
|
||||
default: 0
|
||||
label:
|
||||
type: string
|
||||
default: none
|
||||
methods:
|
||||
begin:
|
||||
handler: {__name__}:PydanticStateFlow.begin
|
||||
start: true
|
||||
finish:
|
||||
handler: {__name__}:PydanticStateFlow.finish
|
||||
listen: begin
|
||||
"""
|
||||
|
||||
PYDANTIC_REF_WITH_SCHEMA_FALLBACK_YAML = f"""
|
||||
schema: crewai.flow/v1
|
||||
name: SchemaFallbackFlow
|
||||
state:
|
||||
type: pydantic
|
||||
ref: definitely_not_a_module_xyz:MissingState
|
||||
json_schema:
|
||||
title: CounterState
|
||||
type: object
|
||||
properties:
|
||||
count:
|
||||
type: integer
|
||||
default: 0
|
||||
label:
|
||||
type: string
|
||||
default: none
|
||||
methods:
|
||||
begin:
|
||||
handler: {__name__}:PydanticStateFlow.begin
|
||||
start: true
|
||||
finish:
|
||||
handler: {__name__}:PydanticStateFlow.finish
|
||||
listen: begin
|
||||
"""
|
||||
|
||||
UNRESOLVABLE_STATE_YAML = f"""
|
||||
schema: crewai.flow/v1
|
||||
name: UnresolvableStateFlow
|
||||
state:
|
||||
type: pydantic
|
||||
ref: definitely_not_a_module_xyz:MissingState
|
||||
methods:
|
||||
begin:
|
||||
handler: {__name__}:ChainFlow.begin
|
||||
start: true
|
||||
"""
|
||||
|
||||
DICT_STATE_YAML = f"""
|
||||
schema: crewai.flow/v1
|
||||
name: DictStateFlow
|
||||
state:
|
||||
type: dict
|
||||
default:
|
||||
count: 5
|
||||
methods:
|
||||
begin:
|
||||
handler: {__name__}:ChainFlow.begin
|
||||
start: true
|
||||
"""
|
||||
|
||||
UNKNOWN_STATE_YAML = f"""
|
||||
schema: crewai.flow/v1
|
||||
name: UnknownStateFlow
|
||||
state:
|
||||
type: unknown
|
||||
ref: somewhere:Something
|
||||
methods:
|
||||
begin:
|
||||
handler: {__name__}:ChainFlow.begin
|
||||
start: true
|
||||
"""
|
||||
|
||||
|
||||
def _run_with_events(flow, inputs=None):
|
||||
events = []
|
||||
with crewai_event_bus.scoped_handlers():
|
||||
@@ -183,7 +320,7 @@ def _run_with_events(flow, inputs=None):
|
||||
|
||||
|
||||
def _state_without_id(flow):
|
||||
snapshot = dict(flow.state)
|
||||
snapshot = dict(flow.state.model_dump())
|
||||
snapshot.pop("id", None)
|
||||
return snapshot
|
||||
|
||||
@@ -293,3 +430,79 @@ def test_flow_definition_stamps_handler_refs():
|
||||
|
||||
assert definition.methods["begin"].handler == f"{__name__}:ChainFlow.begin"
|
||||
assert definition.methods["shout"].handler == f"{__name__}:ChainFlow.shout"
|
||||
|
||||
|
||||
def test_pydantic_state_from_ref_parity():
|
||||
flow, result = assert_parity(PydanticStateFlow, PYDANTIC_STATE_YAML)
|
||||
assert result == "count=1"
|
||||
assert flow.state.count == 1
|
||||
assert flow.state.label == "count=1"
|
||||
assert flow.state.id
|
||||
|
||||
|
||||
def test_pydantic_state_default_overlay():
|
||||
flow = Flow.from_definition(FlowDefinition.from_yaml(PYDANTIC_STATE_OVERLAY_YAML))
|
||||
result = flow.kickoff()
|
||||
assert result == "count=6"
|
||||
assert flow.state.count == 6
|
||||
|
||||
|
||||
def test_json_schema_state():
|
||||
flow = Flow.from_definition(FlowDefinition.from_yaml(JSON_SCHEMA_STATE_YAML))
|
||||
result = flow.kickoff()
|
||||
assert result == "count=1"
|
||||
assert flow.state.count == 1
|
||||
assert flow.state.label == "count=1"
|
||||
assert flow.state.id
|
||||
|
||||
|
||||
def test_json_schema_state_validates_inputs():
|
||||
flow = Flow.from_definition(FlowDefinition.from_yaml(JSON_SCHEMA_STATE_YAML))
|
||||
with pytest.raises(ValueError, match="Invalid inputs"):
|
||||
flow.kickoff(inputs={"count": "not-a-number"})
|
||||
|
||||
|
||||
def test_pydantic_state_falls_back_to_json_schema_when_ref_unimportable():
|
||||
flow = Flow.from_definition(
|
||||
FlowDefinition.from_yaml(PYDANTIC_REF_WITH_SCHEMA_FALLBACK_YAML)
|
||||
)
|
||||
result = flow.kickoff()
|
||||
assert result == "count=1"
|
||||
assert flow.state.count == 1
|
||||
|
||||
|
||||
def test_pydantic_state_without_ref_or_schema_falls_back_to_dict(caplog):
|
||||
with caplog.at_level("ERROR"):
|
||||
flow = Flow.from_definition(FlowDefinition.from_yaml(UNRESOLVABLE_STATE_YAML))
|
||||
assert "falling back to dict state" in caplog.text
|
||||
|
||||
result = flow.kickoff()
|
||||
assert result == "hello"
|
||||
assert flow.state["begin_ran"] is True
|
||||
assert flow.state["id"]
|
||||
|
||||
|
||||
def test_dict_state_is_a_copy_of_default_plus_id():
|
||||
definition = FlowDefinition.from_yaml(DICT_STATE_YAML)
|
||||
|
||||
flow = Flow.from_definition(definition)
|
||||
assert flow.state["count"] == 5
|
||||
assert flow.state["id"]
|
||||
flow.kickoff()
|
||||
assert flow.state["begin_ran"] is True
|
||||
|
||||
second = Flow.from_definition(definition)
|
||||
assert second.state["count"] == 5
|
||||
assert "begin_ran" not in second.state.model_dump()
|
||||
assert second.state["id"] != flow.state["id"]
|
||||
assert definition.state.default == {"count": 5}
|
||||
|
||||
|
||||
def test_unknown_state_type_falls_back_to_dict(caplog):
|
||||
with caplog.at_level("WARNING"):
|
||||
flow = Flow.from_definition(FlowDefinition.from_yaml(UNKNOWN_STATE_YAML))
|
||||
assert "falling back to dict state" in caplog.text
|
||||
|
||||
result = flow.kickoff()
|
||||
assert result == "hello"
|
||||
assert flow.state["begin_ran"] is True
|
||||
|
||||
Reference in New Issue
Block a user