mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-11 17:18:29 +00:00
Fix Flow initialization with Pydantic models having required fields
- Modified _create_initial_state() to accept kwargs parameter - Updated all model instantiation points to use kwargs - Removed redundant _initialize_state() call from __init__() - Added comprehensive tests for required fields, optional fields, and backward compatibility - Fixes issue #3629 Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
@@ -465,7 +465,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
self._is_execution_resuming: bool = False
|
||||
|
||||
# Initialize state with initial values
|
||||
self._state = self._create_initial_state()
|
||||
self._state = self._create_initial_state(kwargs)
|
||||
self.tracing = tracing
|
||||
if (
|
||||
is_tracing_enabled()
|
||||
@@ -474,9 +474,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
):
|
||||
trace_listener = TraceCollectionListener()
|
||||
trace_listener.setup_listeners(crewai_event_bus)
|
||||
# Apply any additional kwargs
|
||||
if kwargs:
|
||||
self._initialize_state(kwargs)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
@@ -502,9 +499,12 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
method = method.__get__(self, self.__class__)
|
||||
self._methods[method_name] = method
|
||||
|
||||
def _create_initial_state(self) -> T:
|
||||
def _create_initial_state(self, kwargs: dict[str, Any] | None = None) -> T:
|
||||
"""Create and initialize flow state with UUID and default values.
|
||||
|
||||
Args:
|
||||
kwargs: Optional initial values for state fields
|
||||
|
||||
Returns:
|
||||
New state instance with UUID and default values initialized
|
||||
|
||||
@@ -518,7 +518,8 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
if isinstance(state_type, type):
|
||||
if issubclass(state_type, FlowState):
|
||||
# Create instance without id, then set it
|
||||
instance = state_type()
|
||||
init_kwargs = kwargs or {}
|
||||
instance = state_type(**init_kwargs)
|
||||
if not hasattr(instance, "id"):
|
||||
instance.id = str(uuid4())
|
||||
return cast(T, instance)
|
||||
@@ -527,7 +528,8 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
class StateWithId(state_type, FlowState): # type: ignore
|
||||
pass
|
||||
|
||||
instance = StateWithId()
|
||||
init_kwargs = kwargs or {}
|
||||
instance = StateWithId(**init_kwargs)
|
||||
if not hasattr(instance, "id"):
|
||||
instance.id = str(uuid4())
|
||||
return cast(T, instance)
|
||||
@@ -541,13 +543,13 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
# Handle case where initial_state is a type (class)
|
||||
if isinstance(self.initial_state, type):
|
||||
if issubclass(self.initial_state, FlowState):
|
||||
return cast(T, self.initial_state()) # Uses model defaults
|
||||
return cast(T, self.initial_state(**(kwargs or {})))
|
||||
if issubclass(self.initial_state, BaseModel):
|
||||
# Validate that the model has an id field
|
||||
model_fields = getattr(self.initial_state, "model_fields", None)
|
||||
if not model_fields or "id" not in model_fields:
|
||||
raise ValueError("Flow state model must have an 'id' field")
|
||||
return cast(T, self.initial_state()) # Uses model defaults
|
||||
return cast(T, self.initial_state(**(kwargs or {})))
|
||||
if self.initial_state is dict:
|
||||
return cast(T, {"id": str(uuid4())})
|
||||
|
||||
|
||||
@@ -894,3 +894,111 @@ def test_flow_name():
|
||||
|
||||
flow = MyFlow()
|
||||
assert flow.name == "MyFlow"
|
||||
|
||||
|
||||
def test_flow_init_with_required_fields():
|
||||
"""Test Flow initialization with Pydantic models having required fields."""
|
||||
|
||||
class RequiredFieldsState(BaseModel):
|
||||
name: str
|
||||
age: int
|
||||
|
||||
class RequiredFieldsFlow(Flow[RequiredFieldsState]):
|
||||
@start()
|
||||
def step_1(self):
|
||||
assert self.state.name == "Alice"
|
||||
assert self.state.age == 30
|
||||
|
||||
flow = RequiredFieldsFlow(name="Alice", age=30)
|
||||
flow.kickoff()
|
||||
|
||||
assert flow.state.name == "Alice"
|
||||
assert flow.state.age == 30
|
||||
assert hasattr(flow.state, "id")
|
||||
assert len(flow.state.id) == 36
|
||||
|
||||
|
||||
def test_flow_init_with_required_fields_missing_values():
|
||||
"""Test that Flow initialization fails when required fields are missing."""
|
||||
|
||||
class RequiredFieldsState(BaseModel):
|
||||
name: str
|
||||
age: int
|
||||
|
||||
class RequiredFieldsFlow(Flow[RequiredFieldsState]):
|
||||
@start()
|
||||
def step_1(self):
|
||||
pass
|
||||
|
||||
with pytest.raises(Exception):
|
||||
flow = RequiredFieldsFlow()
|
||||
|
||||
|
||||
def test_flow_init_with_mixed_required_optional_fields():
|
||||
"""Test Flow with both required and optional fields."""
|
||||
|
||||
class MixedFieldsState(BaseModel):
|
||||
name: str
|
||||
age: int = 25
|
||||
city: str | None = None
|
||||
|
||||
class MixedFieldsFlow(Flow[MixedFieldsState]):
|
||||
@start()
|
||||
def step_1(self):
|
||||
assert self.state.name == "Bob"
|
||||
assert self.state.age == 25
|
||||
assert self.state.city is None
|
||||
|
||||
flow = MixedFieldsFlow(name="Bob")
|
||||
flow.kickoff()
|
||||
|
||||
assert flow.state.name == "Bob"
|
||||
assert flow.state.age == 25
|
||||
assert flow.state.city is None
|
||||
|
||||
|
||||
def test_flow_init_with_required_fields_and_overrides():
|
||||
"""Test that kwargs override default values."""
|
||||
|
||||
class DefaultFieldsState(BaseModel):
|
||||
name: str
|
||||
age: int = 18
|
||||
active: bool = True
|
||||
|
||||
class DefaultFieldsFlow(Flow[DefaultFieldsState]):
|
||||
@start()
|
||||
def step_1(self):
|
||||
assert self.state.name == "Charlie"
|
||||
assert self.state.age == 35
|
||||
assert self.state.active is False
|
||||
|
||||
flow = DefaultFieldsFlow(name="Charlie", age=35, active=False)
|
||||
flow.kickoff()
|
||||
|
||||
assert flow.state.name == "Charlie"
|
||||
assert flow.state.age == 35
|
||||
assert flow.state.active is False
|
||||
|
||||
|
||||
def test_flow_init_backward_compatibility_with_flowstate():
|
||||
"""Test that existing FlowState subclasses still work."""
|
||||
from crewai.flow.flow import FlowState
|
||||
|
||||
class MyFlowState(FlowState):
|
||||
counter: int = 0
|
||||
message: str = "default"
|
||||
|
||||
class BackwardCompatFlow(Flow[MyFlowState]):
|
||||
@start()
|
||||
def step_1(self):
|
||||
self.state.counter += 1
|
||||
|
||||
flow1 = BackwardCompatFlow()
|
||||
flow1.kickoff()
|
||||
assert flow1.state.counter == 1
|
||||
assert flow1.state.message == "default"
|
||||
|
||||
flow2 = BackwardCompatFlow(counter=10, message="custom")
|
||||
flow2.kickoff()
|
||||
assert flow2.state.counter == 11
|
||||
assert flow2.state.message == "custom"
|
||||
|
||||
Reference in New Issue
Block a user