mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 08:08:32 +00:00
Fix Flow initialization with Pydantic models having required fields
- Modified _create_initial_state() to accept kwargs parameter - Pass kwargs when instantiating BaseModel classes - Updated __init__() to pass kwargs to _create_initial_state() - Added comprehensive tests covering various scenarios Fixes #3744 Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
@@ -539,8 +539,8 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
self._persistence: FlowPersistence | None = persistence
|
self._persistence: FlowPersistence | None = persistence
|
||||||
self._is_execution_resuming: bool = False
|
self._is_execution_resuming: bool = False
|
||||||
|
|
||||||
# Initialize state with initial values
|
# Initialize state with initial values and kwargs
|
||||||
self._state = self._create_initial_state()
|
self._state = self._create_initial_state(kwargs if kwargs else None)
|
||||||
self.tracing = tracing
|
self.tracing = tracing
|
||||||
if (
|
if (
|
||||||
is_tracing_enabled()
|
is_tracing_enabled()
|
||||||
@@ -549,9 +549,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
):
|
):
|
||||||
trace_listener = TraceCollectionListener()
|
trace_listener = TraceCollectionListener()
|
||||||
trace_listener.setup_listeners(crewai_event_bus)
|
trace_listener.setup_listeners(crewai_event_bus)
|
||||||
# Apply any additional kwargs
|
|
||||||
if kwargs:
|
|
||||||
self._initialize_state(kwargs)
|
|
||||||
|
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
self,
|
self,
|
||||||
@@ -577,9 +574,12 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
method = method.__get__(self, self.__class__)
|
method = method.__get__(self, self.__class__)
|
||||||
self._methods[method_name] = method
|
self._methods[method_name] = method
|
||||||
|
|
||||||
def _create_initial_state(self) -> T:
|
def _create_initial_state(self, kwargs: dict[str, Any] | None = None) -> T:
|
||||||
"""Create and initialize flow state with UUID and default values.
|
"""Create and initialize flow state with UUID and default values.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
kwargs: Optional dictionary of initial state values
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
New state instance with UUID and default values initialized
|
New state instance with UUID and default values initialized
|
||||||
|
|
||||||
@@ -587,13 +587,16 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
ValueError: If structured state model lacks 'id' field
|
ValueError: If structured state model lacks 'id' field
|
||||||
TypeError: If state is neither BaseModel nor dictionary
|
TypeError: If state is neither BaseModel nor dictionary
|
||||||
"""
|
"""
|
||||||
|
if kwargs is None:
|
||||||
|
kwargs = {}
|
||||||
|
|
||||||
# Handle case where initial_state is None but we have a type parameter
|
# Handle case where initial_state is None but we have a type parameter
|
||||||
if self.initial_state is None and hasattr(self, "_initial_state_t"):
|
if self.initial_state is None and hasattr(self, "_initial_state_t"):
|
||||||
state_type = self._initial_state_t
|
state_type = self._initial_state_t
|
||||||
if isinstance(state_type, type):
|
if isinstance(state_type, type):
|
||||||
if issubclass(state_type, FlowState):
|
if issubclass(state_type, FlowState):
|
||||||
# Create instance without id, then set it
|
# Create instance with kwargs
|
||||||
instance = state_type()
|
instance = state_type(**kwargs)
|
||||||
if not hasattr(instance, "id"):
|
if not hasattr(instance, "id"):
|
||||||
instance.id = str(uuid4())
|
instance.id = str(uuid4())
|
||||||
return cast(T, instance)
|
return cast(T, instance)
|
||||||
@@ -602,35 +605,42 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
class StateWithId(state_type, FlowState): # type: ignore
|
class StateWithId(state_type, FlowState): # type: ignore
|
||||||
pass
|
pass
|
||||||
|
|
||||||
instance = StateWithId()
|
instance = StateWithId(**kwargs)
|
||||||
if not hasattr(instance, "id"):
|
if not hasattr(instance, "id"):
|
||||||
instance.id = str(uuid4())
|
instance.id = str(uuid4())
|
||||||
return cast(T, instance)
|
return cast(T, instance)
|
||||||
if state_type is dict:
|
if state_type is dict:
|
||||||
return cast(T, {"id": str(uuid4())})
|
state_dict = {"id": str(uuid4())}
|
||||||
|
state_dict.update(kwargs)
|
||||||
|
return cast(T, state_dict)
|
||||||
|
|
||||||
# Handle case where no initial state is provided
|
# Handle case where no initial state is provided
|
||||||
if self.initial_state is None:
|
if self.initial_state is None:
|
||||||
return cast(T, {"id": str(uuid4())})
|
state_dict = {"id": str(uuid4())}
|
||||||
|
state_dict.update(kwargs)
|
||||||
|
return cast(T, state_dict)
|
||||||
|
|
||||||
# Handle case where initial_state is a type (class)
|
# Handle case where initial_state is a type (class)
|
||||||
if isinstance(self.initial_state, type):
|
if isinstance(self.initial_state, type):
|
||||||
if issubclass(self.initial_state, FlowState):
|
if issubclass(self.initial_state, FlowState):
|
||||||
return cast(T, self.initial_state()) # Uses model defaults
|
return cast(T, self.initial_state(**kwargs)) # Uses model defaults and kwargs
|
||||||
if issubclass(self.initial_state, BaseModel):
|
if issubclass(self.initial_state, BaseModel):
|
||||||
# Validate that the model has an id field
|
# Validate that the model has an id field
|
||||||
model_fields = getattr(self.initial_state, "model_fields", None)
|
model_fields = getattr(self.initial_state, "model_fields", None)
|
||||||
if not model_fields or "id" not in model_fields:
|
if not model_fields or "id" not in model_fields:
|
||||||
raise ValueError("Flow state model must have an 'id' field")
|
raise ValueError("Flow state model must have an 'id' field")
|
||||||
return cast(T, self.initial_state()) # Uses model defaults
|
return cast(T, self.initial_state(**kwargs)) # Uses model defaults and kwargs
|
||||||
if self.initial_state is dict:
|
if self.initial_state is dict:
|
||||||
return cast(T, {"id": str(uuid4())})
|
state_dict = {"id": str(uuid4())}
|
||||||
|
state_dict.update(kwargs)
|
||||||
|
return cast(T, state_dict)
|
||||||
|
|
||||||
# Handle dictionary instance case
|
# Handle dictionary instance case
|
||||||
if isinstance(self.initial_state, dict):
|
if isinstance(self.initial_state, dict):
|
||||||
new_state = dict(self.initial_state) # Copy to avoid mutations
|
new_state = dict(self.initial_state) # Copy to avoid mutations
|
||||||
if "id" not in new_state:
|
if "id" not in new_state:
|
||||||
new_state["id"] = str(uuid4())
|
new_state["id"] = str(uuid4())
|
||||||
|
new_state.update(kwargs) # Apply kwargs
|
||||||
return cast(T, new_state)
|
return cast(T, new_state)
|
||||||
|
|
||||||
# Handle BaseModel instance case
|
# Handle BaseModel instance case
|
||||||
@@ -652,6 +662,8 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
k: v for k, v in model.__dict__.items() if not k.startswith("_")
|
k: v for k, v in model.__dict__.items() if not k.startswith("_")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
state_dict.update(kwargs)
|
||||||
|
|
||||||
# Create new instance of the same class
|
# Create new instance of the same class
|
||||||
model_class = type(model)
|
model_class = type(model)
|
||||||
return cast(T, model_class(**state_dict))
|
return cast(T, model_class(**state_dict))
|
||||||
|
|||||||
145
tests/test_flow_pydantic_required_fields.py
Normal file
145
tests/test_flow_pydantic_required_fields.py
Normal file
@@ -0,0 +1,145 @@
|
|||||||
|
"""Tests for Flow initialization with Pydantic models having required fields.
|
||||||
|
Covers https://github.com/crewAIInc/crewAI/issues/3744
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pydantic import BaseModel, ValidationError
|
||||||
|
|
||||||
|
from crewai.flow.flow import Flow, FlowState, listen, start
|
||||||
|
|
||||||
|
|
||||||
|
class RequiredState(BaseModel):
|
||||||
|
"""State model with required fields."""
|
||||||
|
name: str
|
||||||
|
age: int
|
||||||
|
|
||||||
|
|
||||||
|
class RequiredStateFlow(Flow[RequiredState]):
|
||||||
|
"""Flow with required state fields."""
|
||||||
|
|
||||||
|
@start()
|
||||||
|
def begin(self):
|
||||||
|
return "started"
|
||||||
|
|
||||||
|
|
||||||
|
class MixedState(BaseModel):
|
||||||
|
"""State model with both required and optional fields."""
|
||||||
|
name: str # Required
|
||||||
|
age: int # Required
|
||||||
|
email: str = "default@example.com" # Optional with default
|
||||||
|
|
||||||
|
|
||||||
|
class MixedStateFlow(Flow[MixedState]):
|
||||||
|
"""Flow with mixed required and optional state fields."""
|
||||||
|
|
||||||
|
@start()
|
||||||
|
def begin(self):
|
||||||
|
return f"Started with {self.state.name}, {self.state.age}, {self.state.email}"
|
||||||
|
|
||||||
|
|
||||||
|
class RequiredStateWithFlowState(FlowState):
|
||||||
|
"""State model extending FlowState with required fields."""
|
||||||
|
name: str
|
||||||
|
age: int
|
||||||
|
|
||||||
|
|
||||||
|
class RequiredFlowStateFlow(Flow[RequiredStateWithFlowState]):
|
||||||
|
"""Flow with required FlowState fields."""
|
||||||
|
|
||||||
|
@start()
|
||||||
|
def begin(self):
|
||||||
|
return "started"
|
||||||
|
|
||||||
|
|
||||||
|
def test_flow_initialization_without_kwargs_raises_validation_error():
|
||||||
|
"""Test that Flow initialization without kwargs raises ValidationError for required fields."""
|
||||||
|
with pytest.raises(ValidationError) as exc_info:
|
||||||
|
RequiredStateFlow()
|
||||||
|
|
||||||
|
error_str = str(exc_info.value)
|
||||||
|
assert "name" in error_str
|
||||||
|
assert "age" in error_str
|
||||||
|
|
||||||
|
|
||||||
|
def test_flow_initialization_with_kwargs_passes_and_sets_state():
|
||||||
|
"""Test that Flow initialization with kwargs properly sets state values."""
|
||||||
|
flow = RequiredStateFlow(name="John", age=30)
|
||||||
|
assert flow.state.name == "John"
|
||||||
|
assert flow.state.age == 30
|
||||||
|
assert hasattr(flow.state, "id")
|
||||||
|
assert flow.state.id is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_flow_initialization_with_partial_kwargs_raises_validation_error():
|
||||||
|
"""Test that Flow initialization with only some required kwargs raises ValidationError."""
|
||||||
|
with pytest.raises(ValidationError) as exc_info:
|
||||||
|
RequiredStateFlow(name="John")
|
||||||
|
|
||||||
|
error_str = str(exc_info.value)
|
||||||
|
assert "age" in error_str
|
||||||
|
|
||||||
|
|
||||||
|
def test_flow_initialization_with_mixed_required_and_optional_fields():
|
||||||
|
"""Test Flow initialization with both required and optional fields."""
|
||||||
|
flow1 = MixedStateFlow(name="Alice", age=25)
|
||||||
|
assert flow1.state.name == "Alice"
|
||||||
|
assert flow1.state.age == 25
|
||||||
|
assert flow1.state.email == "default@example.com"
|
||||||
|
|
||||||
|
flow2 = MixedStateFlow(name="Bob", age=35, email="bob@example.com")
|
||||||
|
assert flow2.state.name == "Bob"
|
||||||
|
assert flow2.state.age == 35
|
||||||
|
assert flow2.state.email == "bob@example.com"
|
||||||
|
|
||||||
|
|
||||||
|
def test_flow_initialization_with_flowstate_and_required_fields():
|
||||||
|
"""Test Flow initialization with FlowState subclass having required fields."""
|
||||||
|
flow = RequiredFlowStateFlow(name="Charlie", age=40)
|
||||||
|
assert flow.state.name == "Charlie"
|
||||||
|
assert flow.state.age == 40
|
||||||
|
assert hasattr(flow.state, "id")
|
||||||
|
assert flow.state.id is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_flow_execution_with_required_state():
|
||||||
|
"""Test that Flow execution works correctly with required state fields."""
|
||||||
|
flow = RequiredStateFlow(name="David", age=45)
|
||||||
|
result = flow.kickoff()
|
||||||
|
assert result == "started"
|
||||||
|
assert flow.state.name == "David"
|
||||||
|
assert flow.state.age == 45
|
||||||
|
|
||||||
|
|
||||||
|
def test_flow_with_state_modification():
|
||||||
|
"""Test that state can be modified during flow execution."""
|
||||||
|
|
||||||
|
class ModifiableState(BaseModel):
|
||||||
|
counter: int
|
||||||
|
name: str
|
||||||
|
|
||||||
|
class ModifiableFlow(Flow[ModifiableState]):
|
||||||
|
@start()
|
||||||
|
def increment(self):
|
||||||
|
self.state.counter += 10
|
||||||
|
return "incremented"
|
||||||
|
|
||||||
|
@listen(increment)
|
||||||
|
def check_value(self):
|
||||||
|
assert self.state.counter == 15
|
||||||
|
return "checked"
|
||||||
|
|
||||||
|
flow = ModifiableFlow(counter=5, name="Test")
|
||||||
|
result = flow.kickoff()
|
||||||
|
assert result == "checked"
|
||||||
|
assert flow.state.counter == 15
|
||||||
|
|
||||||
|
|
||||||
|
def test_flow_initialization_preserves_id_field():
|
||||||
|
"""Test that the automatically generated id field is preserved."""
|
||||||
|
flow = RequiredStateFlow(name="Eve", age=28)
|
||||||
|
original_id = flow.state.id
|
||||||
|
|
||||||
|
assert isinstance(original_id, str)
|
||||||
|
assert len(original_id) == 36 # UUID format with hyphens
|
||||||
|
|
||||||
|
assert flow.state.id == original_id
|
||||||
Reference in New Issue
Block a user