mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-11 00:58:30 +00:00
Fix Flow initial_state BaseModel dict coercion issue #3147
- Fix Flow constructor to accept initial_state parameter - Replace dict conversion with model.model_copy() in _create_initial_state - Replace dict conversion with model.copy(update=...) in _initialize_state - Add comprehensive tests covering dict method name collisions - Preserve BaseModel structure to prevent attribute collision Fixes issue where Pydantic BaseModel instances were coerced into dicts, causing field names like 'items', 'keys', 'values' to be overridden by built-in dict methods. Now BaseModel structure is preserved using Pydantic's built-in copying methods. Co-Authored-By: Jo\u00E3o <joao@crewai.com>
This commit is contained in:
@@ -446,15 +446,20 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
initial_state: Union[Type[T], T, None] = None,
|
||||||
persistence: Optional[FlowPersistence] = None,
|
persistence: Optional[FlowPersistence] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize a new Flow instance.
|
"""Initialize a new Flow instance.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
initial_state: Initial state for the flow (BaseModel instance or dict)
|
||||||
persistence: Optional persistence backend for storing flow states
|
persistence: Optional persistence backend for storing flow states
|
||||||
**kwargs: Additional state values to initialize or override
|
**kwargs: Additional state values to initialize or override
|
||||||
"""
|
"""
|
||||||
|
# Set the initial_state for this instance
|
||||||
|
if initial_state is not None:
|
||||||
|
self.initial_state = initial_state
|
||||||
# Initialize basic instance attributes
|
# Initialize basic instance attributes
|
||||||
self._methods: Dict[str, Callable] = {}
|
self._methods: Dict[str, Callable] = {}
|
||||||
self._method_execution_counts: Dict[str, int] = {}
|
self._method_execution_counts: Dict[str, int] = {}
|
||||||
@@ -552,23 +557,19 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
# Handle BaseModel instance case
|
# Handle BaseModel instance case
|
||||||
if isinstance(self.initial_state, BaseModel):
|
if isinstance(self.initial_state, BaseModel):
|
||||||
model = cast(BaseModel, self.initial_state)
|
model = cast(BaseModel, self.initial_state)
|
||||||
if not hasattr(model, "id"):
|
|
||||||
raise ValueError("Flow state model must have an 'id' field")
|
|
||||||
|
|
||||||
# Create new instance with same values to avoid mutations
|
# Create copy of the BaseModel to avoid mutations
|
||||||
if hasattr(model, "model_dump"):
|
if hasattr(model, "model_copy"):
|
||||||
# Pydantic v2
|
# Pydantic v2
|
||||||
state_dict = model.model_dump()
|
return cast(T, model.model_copy())
|
||||||
elif hasattr(model, "dict"):
|
elif hasattr(model, "copy"):
|
||||||
# Pydantic v1
|
# Pydantic v1
|
||||||
state_dict = model.dict()
|
return cast(T, model.copy())
|
||||||
else:
|
else:
|
||||||
# Fallback for other BaseModel implementations
|
# Fallback for other BaseModel implementations - preserve original logic
|
||||||
state_dict = {
|
state_dict = {
|
||||||
k: v for k, v in model.__dict__.items() if not k.startswith("_")
|
k: v for k, v in model.__dict__.items() if not k.startswith("_")
|
||||||
}
|
}
|
||||||
|
|
||||||
# Create new instance of the same class
|
|
||||||
model_class = type(model)
|
model_class = type(model)
|
||||||
return cast(T, model_class(**state_dict))
|
return cast(T, model_class(**state_dict))
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
@@ -645,29 +646,25 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
# For BaseModel states, preserve existing fields unless overridden
|
# For BaseModel states, preserve existing fields unless overridden
|
||||||
try:
|
try:
|
||||||
model = cast(BaseModel, self._state)
|
model = cast(BaseModel, self._state)
|
||||||
# Get current state as dict
|
|
||||||
if hasattr(model, "model_dump"):
|
if hasattr(model, "model_copy"):
|
||||||
current_state = model.model_dump()
|
# Pydantic v2
|
||||||
elif hasattr(model, "dict"):
|
self._state = cast(T, model.model_copy(update=inputs))
|
||||||
current_state = model.dict()
|
elif hasattr(model, "copy"):
|
||||||
|
# Pydantic v1
|
||||||
|
self._state = cast(T, model.copy(update=inputs))
|
||||||
else:
|
else:
|
||||||
|
# Fallback for other BaseModel implementations - preserve original logic
|
||||||
current_state = {
|
current_state = {
|
||||||
k: v for k, v in model.__dict__.items() if not k.startswith("_")
|
k: v for k, v in model.__dict__.items() if not k.startswith("_")
|
||||||
}
|
}
|
||||||
|
|
||||||
# Create new state with preserved fields and updates
|
|
||||||
new_state = {**current_state, **inputs}
|
new_state = {**current_state, **inputs}
|
||||||
|
|
||||||
# Create new instance with merged state
|
|
||||||
model_class = type(model)
|
model_class = type(model)
|
||||||
if hasattr(model_class, "model_validate"):
|
if hasattr(model_class, "model_validate"):
|
||||||
# Pydantic v2
|
|
||||||
self._state = cast(T, model_class.model_validate(new_state))
|
self._state = cast(T, model_class.model_validate(new_state))
|
||||||
elif hasattr(model_class, "parse_obj"):
|
elif hasattr(model_class, "parse_obj"):
|
||||||
# Pydantic v1
|
|
||||||
self._state = cast(T, model_class.parse_obj(new_state))
|
self._state = cast(T, model_class.parse_obj(new_state))
|
||||||
else:
|
else:
|
||||||
# Fallback for other BaseModel implementations
|
|
||||||
self._state = cast(T, model_class(**new_state))
|
self._state = cast(T, model_class(**new_state))
|
||||||
except ValidationError as e:
|
except ValidationError as e:
|
||||||
raise ValueError(f"Invalid inputs for structured state: {e}") from e
|
raise ValueError(f"Invalid inputs for structured state: {e}") from e
|
||||||
|
|||||||
182
tests/test_flow_initial_state_fix.py
Normal file
182
tests/test_flow_initial_state_fix.py
Normal file
@@ -0,0 +1,182 @@
|
|||||||
|
"""Test Flow initial_state BaseModel dict coercion fix for issue #3147"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from crewai.flow.flow import Flow
|
||||||
|
|
||||||
|
|
||||||
|
class StateWithItems(BaseModel):
|
||||||
|
items: list = [1, 2, 3]
|
||||||
|
metadata: dict = {"x": 1}
|
||||||
|
|
||||||
|
|
||||||
|
class StateWithKeys(BaseModel):
|
||||||
|
keys: list = ["a", "b", "c"]
|
||||||
|
data: str = "test"
|
||||||
|
|
||||||
|
|
||||||
|
class StateWithValues(BaseModel):
|
||||||
|
values: list = [10, 20, 30]
|
||||||
|
name: str = "example"
|
||||||
|
|
||||||
|
|
||||||
|
class StateWithGet(BaseModel):
|
||||||
|
get: str = "method_name"
|
||||||
|
config: dict = {"enabled": True}
|
||||||
|
|
||||||
|
|
||||||
|
class StateWithPop(BaseModel):
|
||||||
|
pop: int = 42
|
||||||
|
settings: list = ["option1", "option2"]
|
||||||
|
|
||||||
|
|
||||||
|
class StateWithUpdate(BaseModel):
|
||||||
|
update: bool = True
|
||||||
|
version: str = "1.0.0"
|
||||||
|
|
||||||
|
|
||||||
|
class StateWithClear(BaseModel):
|
||||||
|
clear: str = "action"
|
||||||
|
status: str = "active"
|
||||||
|
|
||||||
|
|
||||||
|
def test_flow_initial_state_items_field():
|
||||||
|
"""Test that BaseModel with 'items' field preserves structure and doesn't get dict coercion."""
|
||||||
|
flow = Flow(initial_state=StateWithItems())
|
||||||
|
flow.kickoff()
|
||||||
|
|
||||||
|
assert isinstance(flow.state, StateWithItems)
|
||||||
|
assert not isinstance(flow.state, dict)
|
||||||
|
|
||||||
|
assert isinstance(flow.state.items, list)
|
||||||
|
assert flow.state.items == [1, 2, 3]
|
||||||
|
assert len(flow.state.items) == 3
|
||||||
|
|
||||||
|
assert flow.state.metadata == {"x": 1}
|
||||||
|
|
||||||
|
|
||||||
|
def test_flow_initial_state_keys_field():
|
||||||
|
"""Test that BaseModel with 'keys' field preserves structure."""
|
||||||
|
flow = Flow(initial_state=StateWithKeys())
|
||||||
|
flow.kickoff()
|
||||||
|
|
||||||
|
assert isinstance(flow.state, StateWithKeys)
|
||||||
|
assert isinstance(flow.state.keys, list)
|
||||||
|
assert flow.state.keys == ["a", "b", "c"]
|
||||||
|
assert len(flow.state.keys) == 3
|
||||||
|
assert flow.state.data == "test"
|
||||||
|
|
||||||
|
|
||||||
|
def test_flow_initial_state_values_field():
|
||||||
|
"""Test that BaseModel with 'values' field preserves structure."""
|
||||||
|
flow = Flow(initial_state=StateWithValues())
|
||||||
|
flow.kickoff()
|
||||||
|
|
||||||
|
assert isinstance(flow.state, StateWithValues)
|
||||||
|
assert isinstance(flow.state.values, list)
|
||||||
|
assert flow.state.values == [10, 20, 30]
|
||||||
|
assert len(flow.state.values) == 3
|
||||||
|
assert flow.state.name == "example"
|
||||||
|
|
||||||
|
|
||||||
|
def test_flow_initial_state_get_field():
|
||||||
|
"""Test that BaseModel with 'get' field preserves structure."""
|
||||||
|
flow = Flow(initial_state=StateWithGet())
|
||||||
|
flow.kickoff()
|
||||||
|
|
||||||
|
assert isinstance(flow.state, StateWithGet)
|
||||||
|
assert isinstance(flow.state.get, str)
|
||||||
|
assert flow.state.get == "method_name"
|
||||||
|
assert flow.state.config == {"enabled": True}
|
||||||
|
|
||||||
|
|
||||||
|
def test_flow_initial_state_pop_field():
|
||||||
|
"""Test that BaseModel with 'pop' field preserves structure."""
|
||||||
|
flow = Flow(initial_state=StateWithPop())
|
||||||
|
flow.kickoff()
|
||||||
|
|
||||||
|
assert isinstance(flow.state, StateWithPop)
|
||||||
|
assert isinstance(flow.state.pop, int)
|
||||||
|
assert flow.state.pop == 42
|
||||||
|
assert flow.state.settings == ["option1", "option2"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_flow_initial_state_update_field():
|
||||||
|
"""Test that BaseModel with 'update' field preserves structure."""
|
||||||
|
flow = Flow(initial_state=StateWithUpdate())
|
||||||
|
flow.kickoff()
|
||||||
|
|
||||||
|
assert isinstance(flow.state, StateWithUpdate)
|
||||||
|
assert isinstance(flow.state.update, bool)
|
||||||
|
assert flow.state.update is True
|
||||||
|
assert flow.state.version == "1.0.0"
|
||||||
|
|
||||||
|
|
||||||
|
def test_flow_initial_state_clear_field():
|
||||||
|
"""Test that BaseModel with 'clear' field preserves structure."""
|
||||||
|
flow = Flow(initial_state=StateWithClear())
|
||||||
|
flow.kickoff()
|
||||||
|
|
||||||
|
assert isinstance(flow.state, StateWithClear)
|
||||||
|
assert isinstance(flow.state.clear, str)
|
||||||
|
assert flow.state.clear == "action"
|
||||||
|
assert flow.state.status == "active"
|
||||||
|
|
||||||
|
|
||||||
|
def test_flow_state_modification_preserves_basemodel():
|
||||||
|
"""Test that modifying flow state preserves BaseModel structure."""
|
||||||
|
|
||||||
|
class ModifiableState(BaseModel):
|
||||||
|
items: list = [1, 2, 3]
|
||||||
|
counter: int = 0
|
||||||
|
|
||||||
|
class TestFlow(Flow[ModifiableState]):
|
||||||
|
@Flow.start()
|
||||||
|
def modify_state(self):
|
||||||
|
self.state.counter += 1
|
||||||
|
self.state.items.append(4)
|
||||||
|
|
||||||
|
flow = TestFlow(initial_state=ModifiableState())
|
||||||
|
flow.kickoff()
|
||||||
|
|
||||||
|
assert isinstance(flow.state, ModifiableState)
|
||||||
|
assert not isinstance(flow.state, dict)
|
||||||
|
|
||||||
|
assert flow.state.counter == 1
|
||||||
|
assert flow.state.items == [1, 2, 3, 4]
|
||||||
|
|
||||||
|
|
||||||
|
def test_flow_with_inputs_preserves_basemodel():
|
||||||
|
"""Test that providing inputs to flow preserves BaseModel structure."""
|
||||||
|
|
||||||
|
class InputState(BaseModel):
|
||||||
|
items: list = []
|
||||||
|
name: str = ""
|
||||||
|
|
||||||
|
flow = Flow(initial_state=InputState())
|
||||||
|
flow.kickoff(inputs={"name": "test_flow", "items": [5, 6, 7]})
|
||||||
|
|
||||||
|
assert isinstance(flow.state, InputState)
|
||||||
|
assert not isinstance(flow.state, dict)
|
||||||
|
|
||||||
|
assert flow.state.name == "test_flow"
|
||||||
|
assert flow.state.items == [5, 6, 7]
|
||||||
|
|
||||||
|
|
||||||
|
def test_reproduction_case_from_issue_3147():
|
||||||
|
"""Test the exact reproduction case from GitHub issue #3147."""
|
||||||
|
|
||||||
|
class MyState(BaseModel):
|
||||||
|
items: list = [1, 2, 3]
|
||||||
|
metadata: dict = {"x": 1}
|
||||||
|
|
||||||
|
flow = Flow(initial_state=MyState())
|
||||||
|
flow.kickoff()
|
||||||
|
|
||||||
|
assert isinstance(flow.state.items, list)
|
||||||
|
assert len(flow.state.items) == 3
|
||||||
|
assert flow.state.items == [1, 2, 3]
|
||||||
|
|
||||||
|
assert not callable(flow.state.items)
|
||||||
|
assert str(type(flow.state.items)) != "<class 'builtin_function_or_method'>"
|
||||||
Reference in New Issue
Block a user