Compare commits

...

3 Commits

Author SHA1 Message Date
Devin AI
53d94df7de Fix pre-existing lint and type-checker errors in flow.py
- Import PrinterColor type from utilities.printer
- Replace unused variable 'condition_type' with '_' at line 1091
- Fix type annotation for '_log_flow_event' color parameter to use PrinterColor instead of str

These were pre-existing errors that prevented CI from passing after modifying flow.py

Co-Authored-By: João <joao@crewai.com>
2025-10-01 17:37:53 +00:00
Devin AI
1ce52c6616 Fix lint issues in test_flow.py
- Add ValidationError import from pydantic
- Use ValidationError instead of generic Exception in test
- Remove unused flow variable
- Apply ruff formatting (import ordering and getattr simplification)

Co-Authored-By: João <joao@crewai.com>
2025-10-01 17:30:23 +00:00
Devin AI
0650fbbe88 Fix Flow initialization with Pydantic models having required fields
- Modified _create_initial_state() to accept kwargs parameter
- Updated all model instantiation points to use kwargs
- Removed redundant _initialize_state() call from __init__()
- Added comprehensive tests for required fields, optional fields, and backward compatibility
- Fixes issue #3629

Co-Authored-By: João <joao@crewai.com>
2025-10-01 17:17:29 +00:00
2 changed files with 127 additions and 17 deletions

View File

@@ -31,7 +31,7 @@ from crewai.flow.flow_visualizer import plot_flow
from crewai.flow.persistence.base import FlowPersistence
from crewai.flow.types import FlowExecutionData
from crewai.flow.utils import get_possible_return_constants
from crewai.utilities.printer import Printer
from crewai.utilities.printer import Printer, PrinterColor
logger = logging.getLogger(__name__)
@@ -465,7 +465,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
self._is_execution_resuming: bool = False
# Initialize state with initial values
self._state = self._create_initial_state()
self._state = self._create_initial_state(kwargs)
self.tracing = tracing
if (
is_tracing_enabled()
@@ -474,9 +474,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
):
trace_listener = TraceCollectionListener()
trace_listener.setup_listeners(crewai_event_bus)
# Apply any additional kwargs
if kwargs:
self._initialize_state(kwargs)
crewai_event_bus.emit(
self,
@@ -502,9 +499,12 @@ class Flow(Generic[T], metaclass=FlowMeta):
method = method.__get__(self, self.__class__)
self._methods[method_name] = method
def _create_initial_state(self) -> T:
def _create_initial_state(self, kwargs: dict[str, Any] | None = None) -> T:
"""Create and initialize flow state with UUID and default values.
Args:
kwargs: Optional initial values for state fields
Returns:
New state instance with UUID and default values initialized
@@ -518,7 +518,8 @@ class Flow(Generic[T], metaclass=FlowMeta):
if isinstance(state_type, type):
if issubclass(state_type, FlowState):
# Create instance without id, then set it
instance = state_type()
init_kwargs = kwargs or {}
instance = state_type(**init_kwargs)
if not hasattr(instance, "id"):
instance.id = str(uuid4())
return cast(T, instance)
@@ -527,7 +528,8 @@ class Flow(Generic[T], metaclass=FlowMeta):
class StateWithId(state_type, FlowState): # type: ignore
pass
instance = StateWithId()
init_kwargs = kwargs or {}
instance = StateWithId(**init_kwargs)
if not hasattr(instance, "id"):
instance.id = str(uuid4())
return cast(T, instance)
@@ -541,13 +543,13 @@ class Flow(Generic[T], metaclass=FlowMeta):
# Handle case where initial_state is a type (class)
if isinstance(self.initial_state, type):
if issubclass(self.initial_state, FlowState):
return cast(T, self.initial_state()) # Uses model defaults
return cast(T, self.initial_state(**(kwargs or {})))
if issubclass(self.initial_state, BaseModel):
# Validate that the model has an id field
model_fields = getattr(self.initial_state, "model_fields", None)
if not model_fields or "id" not in model_fields:
raise ValueError("Flow state model must have an 'id' field")
return cast(T, self.initial_state()) # Uses model defaults
return cast(T, self.initial_state(**(kwargs or {})))
if self.initial_state is dict:
return cast(T, {"id": str(uuid4())})
@@ -1086,7 +1088,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
for method_name in self._start_methods:
# Check if this start method is triggered by the current trigger
if method_name in self._listeners:
condition_type, trigger_methods = self._listeners[
_, trigger_methods = self._listeners[
method_name
]
if current_trigger in trigger_methods:
@@ -1218,7 +1220,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
raise
def _log_flow_event(
self, message: str, color: str = "yellow", level: str = "info"
self, message: str, color: PrinterColor = "yellow", level: str = "info"
) -> None:
"""Centralized logging method for flow events.

View File

@@ -4,17 +4,17 @@ import asyncio
from datetime import datetime
import pytest
from pydantic import BaseModel
from pydantic import BaseModel, ValidationError
from crewai.flow.flow import Flow, and_, listen, or_, router, start
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.flow_events import (
FlowFinishedEvent,
FlowStartedEvent,
FlowPlotEvent,
FlowStartedEvent,
MethodExecutionFinishedEvent,
MethodExecutionStartedEvent,
)
from crewai.flow.flow import Flow, and_, listen, or_, router, start
def test_simple_sequential_flow():
@@ -679,11 +679,11 @@ def test_structured_flow_event_emission():
assert isinstance(received_events[3], MethodExecutionStartedEvent)
assert received_events[3].method_name == "send_welcome_message"
assert received_events[3].params == {}
assert getattr(received_events[3].state, "sent") is False
assert received_events[3].state.sent is False
assert isinstance(received_events[4], MethodExecutionFinishedEvent)
assert received_events[4].method_name == "send_welcome_message"
assert getattr(received_events[4].state, "sent") is True
assert received_events[4].state.sent is True
assert received_events[4].result == "Welcome, Anakin!"
assert isinstance(received_events[5], FlowFinishedEvent)
@@ -894,3 +894,111 @@ def test_flow_name():
flow = MyFlow()
assert flow.name == "MyFlow"
def test_flow_init_with_required_fields():
"""Test Flow initialization with Pydantic models having required fields."""
class RequiredFieldsState(BaseModel):
name: str
age: int
class RequiredFieldsFlow(Flow[RequiredFieldsState]):
@start()
def step_1(self):
assert self.state.name == "Alice"
assert self.state.age == 30
flow = RequiredFieldsFlow(name="Alice", age=30)
flow.kickoff()
assert flow.state.name == "Alice"
assert flow.state.age == 30
assert hasattr(flow.state, "id")
assert len(flow.state.id) == 36
def test_flow_init_with_required_fields_missing_values():
"""Test that Flow initialization fails when required fields are missing."""
class RequiredFieldsState(BaseModel):
name: str
age: int
class RequiredFieldsFlow(Flow[RequiredFieldsState]):
@start()
def step_1(self):
pass
with pytest.raises(ValidationError):
RequiredFieldsFlow()
def test_flow_init_with_mixed_required_optional_fields():
"""Test Flow with both required and optional fields."""
class MixedFieldsState(BaseModel):
name: str
age: int = 25
city: str | None = None
class MixedFieldsFlow(Flow[MixedFieldsState]):
@start()
def step_1(self):
assert self.state.name == "Bob"
assert self.state.age == 25
assert self.state.city is None
flow = MixedFieldsFlow(name="Bob")
flow.kickoff()
assert flow.state.name == "Bob"
assert flow.state.age == 25
assert flow.state.city is None
def test_flow_init_with_required_fields_and_overrides():
"""Test that kwargs override default values."""
class DefaultFieldsState(BaseModel):
name: str
age: int = 18
active: bool = True
class DefaultFieldsFlow(Flow[DefaultFieldsState]):
@start()
def step_1(self):
assert self.state.name == "Charlie"
assert self.state.age == 35
assert self.state.active is False
flow = DefaultFieldsFlow(name="Charlie", age=35, active=False)
flow.kickoff()
assert flow.state.name == "Charlie"
assert flow.state.age == 35
assert flow.state.active is False
def test_flow_init_backward_compatibility_with_flowstate():
"""Test that existing FlowState subclasses still work."""
from crewai.flow.flow import FlowState
class MyFlowState(FlowState):
counter: int = 0
message: str = "default"
class BackwardCompatFlow(Flow[MyFlowState]):
@start()
def step_1(self):
self.state.counter += 1
flow1 = BackwardCompatFlow()
flow1.kickoff()
assert flow1.state.counter == 1
assert flow1.state.message == "default"
flow2 = BackwardCompatFlow(counter=10, message="custom")
flow2.kickoff()
assert flow2.state.counter == 11
assert flow2.state.message == "custom"