Fix Flow initialization with Pydantic models having required fields

- Modified _create_initial_state() to accept kwargs parameter
- Pass kwargs when instantiating BaseModel classes
- Updated __init__() to pass kwargs to _create_initial_state()
- Added comprehensive tests covering various scenarios

Fixes #3744

Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
Devin AI
2025-10-20 19:50:16 +00:00
parent 42f2b4d551
commit 29617cd228
2 changed files with 171 additions and 14 deletions

View File

@@ -539,8 +539,8 @@ class Flow(Generic[T], metaclass=FlowMeta):
self._persistence: FlowPersistence | None = persistence
self._is_execution_resuming: bool = False
# Initialize state with initial values
self._state = self._create_initial_state()
# Initialize state with initial values and kwargs
self._state = self._create_initial_state(kwargs if kwargs else None)
self.tracing = tracing
if (
is_tracing_enabled()
@@ -549,9 +549,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
):
trace_listener = TraceCollectionListener()
trace_listener.setup_listeners(crewai_event_bus)
# Apply any additional kwargs
if kwargs:
self._initialize_state(kwargs)
crewai_event_bus.emit(
self,
@@ -577,9 +574,12 @@ class Flow(Generic[T], metaclass=FlowMeta):
method = method.__get__(self, self.__class__)
self._methods[method_name] = method
def _create_initial_state(self) -> T:
def _create_initial_state(self, kwargs: dict[str, Any] | None = None) -> T:
"""Create and initialize flow state with UUID and default values.
Args:
kwargs: Optional dictionary of initial state values
Returns:
New state instance with UUID and default values initialized
@@ -587,13 +587,16 @@ class Flow(Generic[T], metaclass=FlowMeta):
ValueError: If structured state model lacks 'id' field
TypeError: If state is neither BaseModel nor dictionary
"""
if kwargs is None:
kwargs = {}
# Handle case where initial_state is None but we have a type parameter
if self.initial_state is None and hasattr(self, "_initial_state_t"):
state_type = self._initial_state_t
if isinstance(state_type, type):
if issubclass(state_type, FlowState):
# Create instance without id, then set it
instance = state_type()
# Create instance with kwargs
instance = state_type(**kwargs)
if not hasattr(instance, "id"):
instance.id = str(uuid4())
return cast(T, instance)
@@ -602,35 +605,42 @@ class Flow(Generic[T], metaclass=FlowMeta):
class StateWithId(state_type, FlowState): # type: ignore
pass
instance = StateWithId()
instance = StateWithId(**kwargs)
if not hasattr(instance, "id"):
instance.id = str(uuid4())
return cast(T, instance)
if state_type is dict:
return cast(T, {"id": str(uuid4())})
state_dict = {"id": str(uuid4())}
state_dict.update(kwargs)
return cast(T, state_dict)
# Handle case where no initial state is provided
if self.initial_state is None:
return cast(T, {"id": str(uuid4())})
state_dict = {"id": str(uuid4())}
state_dict.update(kwargs)
return cast(T, state_dict)
# Handle case where initial_state is a type (class)
if isinstance(self.initial_state, type):
if issubclass(self.initial_state, FlowState):
return cast(T, self.initial_state()) # Uses model defaults
return cast(T, self.initial_state(**kwargs)) # Uses model defaults and kwargs
if issubclass(self.initial_state, BaseModel):
# Validate that the model has an id field
model_fields = getattr(self.initial_state, "model_fields", None)
if not model_fields or "id" not in model_fields:
raise ValueError("Flow state model must have an 'id' field")
return cast(T, self.initial_state()) # Uses model defaults
return cast(T, self.initial_state(**kwargs)) # Uses model defaults and kwargs
if self.initial_state is dict:
return cast(T, {"id": str(uuid4())})
state_dict = {"id": str(uuid4())}
state_dict.update(kwargs)
return cast(T, state_dict)
# Handle dictionary instance case
if isinstance(self.initial_state, dict):
new_state = dict(self.initial_state) # Copy to avoid mutations
if "id" not in new_state:
new_state["id"] = str(uuid4())
new_state.update(kwargs) # Apply kwargs
return cast(T, new_state)
# Handle BaseModel instance case
@@ -652,6 +662,8 @@ class Flow(Generic[T], metaclass=FlowMeta):
k: v for k, v in model.__dict__.items() if not k.startswith("_")
}
state_dict.update(kwargs)
# Create new instance of the same class
model_class = type(model)
return cast(T, model_class(**state_dict))