diff --git a/src/crewai/flow/flow.py b/src/crewai/flow/flow.py index 84783b081..45b093a1a 100644 --- a/src/crewai/flow/flow.py +++ b/src/crewai/flow/flow.py @@ -539,8 +539,8 @@ class Flow(Generic[T], metaclass=FlowMeta): self._persistence: FlowPersistence | None = persistence self._is_execution_resuming: bool = False - # Initialize state with initial values - self._state = self._create_initial_state() + # Initialize state with initial values and kwargs + self._state = self._create_initial_state(kwargs if kwargs else None) self.tracing = tracing if ( is_tracing_enabled() @@ -549,9 +549,6 @@ class Flow(Generic[T], metaclass=FlowMeta): ): trace_listener = TraceCollectionListener() trace_listener.setup_listeners(crewai_event_bus) - # Apply any additional kwargs - if kwargs: - self._initialize_state(kwargs) crewai_event_bus.emit( self, @@ -577,9 +574,12 @@ class Flow(Generic[T], metaclass=FlowMeta): method = method.__get__(self, self.__class__) self._methods[method_name] = method - def _create_initial_state(self) -> T: + def _create_initial_state(self, kwargs: dict[str, Any] | None = None) -> T: """Create and initialize flow state with UUID and default values. + Args: + kwargs: Optional dictionary of initial state values + Returns: New state instance with UUID and default values initialized @@ -587,13 +587,16 @@ class Flow(Generic[T], metaclass=FlowMeta): ValueError: If structured state model lacks 'id' field TypeError: If state is neither BaseModel nor dictionary """ + if kwargs is None: + kwargs = {} + # Handle case where initial_state is None but we have a type parameter if self.initial_state is None and hasattr(self, "_initial_state_t"): state_type = self._initial_state_t if isinstance(state_type, type): if issubclass(state_type, FlowState): - # Create instance without id, then set it - instance = state_type() + # Create instance with kwargs + instance = state_type(**kwargs) if not hasattr(instance, "id"): instance.id = str(uuid4()) return cast(T, instance) @@ -602,35 +605,42 @@ class Flow(Generic[T], metaclass=FlowMeta): class StateWithId(state_type, FlowState): # type: ignore pass - instance = StateWithId() + instance = StateWithId(**kwargs) if not hasattr(instance, "id"): instance.id = str(uuid4()) return cast(T, instance) if state_type is dict: - return cast(T, {"id": str(uuid4())}) + state_dict = {"id": str(uuid4())} + state_dict.update(kwargs) + return cast(T, state_dict) # Handle case where no initial state is provided if self.initial_state is None: - return cast(T, {"id": str(uuid4())}) + state_dict = {"id": str(uuid4())} + state_dict.update(kwargs) + return cast(T, state_dict) # Handle case where initial_state is a type (class) if isinstance(self.initial_state, type): if issubclass(self.initial_state, FlowState): - return cast(T, self.initial_state()) # Uses model defaults + return cast(T, self.initial_state(**kwargs)) # Uses model defaults and kwargs if issubclass(self.initial_state, BaseModel): # Validate that the model has an id field model_fields = getattr(self.initial_state, "model_fields", None) if not model_fields or "id" not in model_fields: raise ValueError("Flow state model must have an 'id' field") - return cast(T, self.initial_state()) # Uses model defaults + return cast(T, self.initial_state(**kwargs)) # Uses model defaults and kwargs if self.initial_state is dict: - return cast(T, {"id": str(uuid4())}) + state_dict = {"id": str(uuid4())} + state_dict.update(kwargs) + return cast(T, state_dict) # Handle dictionary instance case if isinstance(self.initial_state, dict): new_state = dict(self.initial_state) # Copy to avoid mutations if "id" not in new_state: new_state["id"] = str(uuid4()) + new_state.update(kwargs) # Apply kwargs return cast(T, new_state) # Handle BaseModel instance case @@ -652,6 +662,8 @@ class Flow(Generic[T], metaclass=FlowMeta): k: v for k, v in model.__dict__.items() if not k.startswith("_") } + state_dict.update(kwargs) + # Create new instance of the same class model_class = type(model) return cast(T, model_class(**state_dict)) diff --git a/tests/test_flow_pydantic_required_fields.py b/tests/test_flow_pydantic_required_fields.py new file mode 100644 index 000000000..e7ff3871a --- /dev/null +++ b/tests/test_flow_pydantic_required_fields.py @@ -0,0 +1,145 @@ +"""Tests for Flow initialization with Pydantic models having required fields. +Covers https://github.com/crewAIInc/crewAI/issues/3744 +""" + +import pytest +from pydantic import BaseModel, ValidationError + +from crewai.flow.flow import Flow, FlowState, listen, start + + +class RequiredState(BaseModel): + """State model with required fields.""" + name: str + age: int + + +class RequiredStateFlow(Flow[RequiredState]): + """Flow with required state fields.""" + + @start() + def begin(self): + return "started" + + +class MixedState(BaseModel): + """State model with both required and optional fields.""" + name: str # Required + age: int # Required + email: str = "default@example.com" # Optional with default + + +class MixedStateFlow(Flow[MixedState]): + """Flow with mixed required and optional state fields.""" + + @start() + def begin(self): + return f"Started with {self.state.name}, {self.state.age}, {self.state.email}" + + +class RequiredStateWithFlowState(FlowState): + """State model extending FlowState with required fields.""" + name: str + age: int + + +class RequiredFlowStateFlow(Flow[RequiredStateWithFlowState]): + """Flow with required FlowState fields.""" + + @start() + def begin(self): + return "started" + + +def test_flow_initialization_without_kwargs_raises_validation_error(): + """Test that Flow initialization without kwargs raises ValidationError for required fields.""" + with pytest.raises(ValidationError) as exc_info: + RequiredStateFlow() + + error_str = str(exc_info.value) + assert "name" in error_str + assert "age" in error_str + + +def test_flow_initialization_with_kwargs_passes_and_sets_state(): + """Test that Flow initialization with kwargs properly sets state values.""" + flow = RequiredStateFlow(name="John", age=30) + assert flow.state.name == "John" + assert flow.state.age == 30 + assert hasattr(flow.state, "id") + assert flow.state.id is not None + + +def test_flow_initialization_with_partial_kwargs_raises_validation_error(): + """Test that Flow initialization with only some required kwargs raises ValidationError.""" + with pytest.raises(ValidationError) as exc_info: + RequiredStateFlow(name="John") + + error_str = str(exc_info.value) + assert "age" in error_str + + +def test_flow_initialization_with_mixed_required_and_optional_fields(): + """Test Flow initialization with both required and optional fields.""" + flow1 = MixedStateFlow(name="Alice", age=25) + assert flow1.state.name == "Alice" + assert flow1.state.age == 25 + assert flow1.state.email == "default@example.com" + + flow2 = MixedStateFlow(name="Bob", age=35, email="bob@example.com") + assert flow2.state.name == "Bob" + assert flow2.state.age == 35 + assert flow2.state.email == "bob@example.com" + + +def test_flow_initialization_with_flowstate_and_required_fields(): + """Test Flow initialization with FlowState subclass having required fields.""" + flow = RequiredFlowStateFlow(name="Charlie", age=40) + assert flow.state.name == "Charlie" + assert flow.state.age == 40 + assert hasattr(flow.state, "id") + assert flow.state.id is not None + + +def test_flow_execution_with_required_state(): + """Test that Flow execution works correctly with required state fields.""" + flow = RequiredStateFlow(name="David", age=45) + result = flow.kickoff() + assert result == "started" + assert flow.state.name == "David" + assert flow.state.age == 45 + + +def test_flow_with_state_modification(): + """Test that state can be modified during flow execution.""" + + class ModifiableState(BaseModel): + counter: int + name: str + + class ModifiableFlow(Flow[ModifiableState]): + @start() + def increment(self): + self.state.counter += 10 + return "incremented" + + @listen(increment) + def check_value(self): + assert self.state.counter == 15 + return "checked" + + flow = ModifiableFlow(counter=5, name="Test") + result = flow.kickoff() + assert result == "checked" + assert flow.state.counter == 15 + + +def test_flow_initialization_preserves_id_field(): + """Test that the automatically generated id field is preserved.""" + flow = RequiredStateFlow(name="Eve", age=28) + original_id = flow.state.id + + assert isinstance(original_id, str) + assert len(original_id) == 36 # UUID format with hyphens + + assert flow.state.id == original_id