mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-07 07:08:31 +00:00
Compare commits
3 Commits
devin/1762
...
devin/1759
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
53d94df7de | ||
|
|
1ce52c6616 | ||
|
|
0650fbbe88 |
@@ -31,7 +31,7 @@ from crewai.flow.flow_visualizer import plot_flow
|
||||
from crewai.flow.persistence.base import FlowPersistence
|
||||
from crewai.flow.types import FlowExecutionData
|
||||
from crewai.flow.utils import get_possible_return_constants
|
||||
from crewai.utilities.printer import Printer
|
||||
from crewai.utilities.printer import Printer, PrinterColor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -465,7 +465,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
self._is_execution_resuming: bool = False
|
||||
|
||||
# Initialize state with initial values
|
||||
self._state = self._create_initial_state()
|
||||
self._state = self._create_initial_state(kwargs)
|
||||
self.tracing = tracing
|
||||
if (
|
||||
is_tracing_enabled()
|
||||
@@ -474,9 +474,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
):
|
||||
trace_listener = TraceCollectionListener()
|
||||
trace_listener.setup_listeners(crewai_event_bus)
|
||||
# Apply any additional kwargs
|
||||
if kwargs:
|
||||
self._initialize_state(kwargs)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
@@ -502,9 +499,12 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
method = method.__get__(self, self.__class__)
|
||||
self._methods[method_name] = method
|
||||
|
||||
def _create_initial_state(self) -> T:
|
||||
def _create_initial_state(self, kwargs: dict[str, Any] | None = None) -> T:
|
||||
"""Create and initialize flow state with UUID and default values.
|
||||
|
||||
Args:
|
||||
kwargs: Optional initial values for state fields
|
||||
|
||||
Returns:
|
||||
New state instance with UUID and default values initialized
|
||||
|
||||
@@ -518,7 +518,8 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
if isinstance(state_type, type):
|
||||
if issubclass(state_type, FlowState):
|
||||
# Create instance without id, then set it
|
||||
instance = state_type()
|
||||
init_kwargs = kwargs or {}
|
||||
instance = state_type(**init_kwargs)
|
||||
if not hasattr(instance, "id"):
|
||||
instance.id = str(uuid4())
|
||||
return cast(T, instance)
|
||||
@@ -527,7 +528,8 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
class StateWithId(state_type, FlowState): # type: ignore
|
||||
pass
|
||||
|
||||
instance = StateWithId()
|
||||
init_kwargs = kwargs or {}
|
||||
instance = StateWithId(**init_kwargs)
|
||||
if not hasattr(instance, "id"):
|
||||
instance.id = str(uuid4())
|
||||
return cast(T, instance)
|
||||
@@ -541,13 +543,13 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
# Handle case where initial_state is a type (class)
|
||||
if isinstance(self.initial_state, type):
|
||||
if issubclass(self.initial_state, FlowState):
|
||||
return cast(T, self.initial_state()) # Uses model defaults
|
||||
return cast(T, self.initial_state(**(kwargs or {})))
|
||||
if issubclass(self.initial_state, BaseModel):
|
||||
# Validate that the model has an id field
|
||||
model_fields = getattr(self.initial_state, "model_fields", None)
|
||||
if not model_fields or "id" not in model_fields:
|
||||
raise ValueError("Flow state model must have an 'id' field")
|
||||
return cast(T, self.initial_state()) # Uses model defaults
|
||||
return cast(T, self.initial_state(**(kwargs or {})))
|
||||
if self.initial_state is dict:
|
||||
return cast(T, {"id": str(uuid4())})
|
||||
|
||||
@@ -1086,7 +1088,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
for method_name in self._start_methods:
|
||||
# Check if this start method is triggered by the current trigger
|
||||
if method_name in self._listeners:
|
||||
condition_type, trigger_methods = self._listeners[
|
||||
_, trigger_methods = self._listeners[
|
||||
method_name
|
||||
]
|
||||
if current_trigger in trigger_methods:
|
||||
@@ -1218,7 +1220,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
raise
|
||||
|
||||
def _log_flow_event(
|
||||
self, message: str, color: str = "yellow", level: str = "info"
|
||||
self, message: str, color: PrinterColor = "yellow", level: str = "info"
|
||||
) -> None:
|
||||
"""Centralized logging method for flow events.
|
||||
|
||||
|
||||
@@ -4,17 +4,17 @@ import asyncio
|
||||
from datetime import datetime
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from crewai.flow.flow import Flow, and_, listen, or_, router, start
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.flow_events import (
|
||||
FlowFinishedEvent,
|
||||
FlowStartedEvent,
|
||||
FlowPlotEvent,
|
||||
FlowStartedEvent,
|
||||
MethodExecutionFinishedEvent,
|
||||
MethodExecutionStartedEvent,
|
||||
)
|
||||
from crewai.flow.flow import Flow, and_, listen, or_, router, start
|
||||
|
||||
|
||||
def test_simple_sequential_flow():
|
||||
@@ -679,11 +679,11 @@ def test_structured_flow_event_emission():
|
||||
assert isinstance(received_events[3], MethodExecutionStartedEvent)
|
||||
assert received_events[3].method_name == "send_welcome_message"
|
||||
assert received_events[3].params == {}
|
||||
assert getattr(received_events[3].state, "sent") is False
|
||||
assert received_events[3].state.sent is False
|
||||
|
||||
assert isinstance(received_events[4], MethodExecutionFinishedEvent)
|
||||
assert received_events[4].method_name == "send_welcome_message"
|
||||
assert getattr(received_events[4].state, "sent") is True
|
||||
assert received_events[4].state.sent is True
|
||||
assert received_events[4].result == "Welcome, Anakin!"
|
||||
|
||||
assert isinstance(received_events[5], FlowFinishedEvent)
|
||||
@@ -894,3 +894,111 @@ def test_flow_name():
|
||||
|
||||
flow = MyFlow()
|
||||
assert flow.name == "MyFlow"
|
||||
|
||||
|
||||
def test_flow_init_with_required_fields():
|
||||
"""Test Flow initialization with Pydantic models having required fields."""
|
||||
|
||||
class RequiredFieldsState(BaseModel):
|
||||
name: str
|
||||
age: int
|
||||
|
||||
class RequiredFieldsFlow(Flow[RequiredFieldsState]):
|
||||
@start()
|
||||
def step_1(self):
|
||||
assert self.state.name == "Alice"
|
||||
assert self.state.age == 30
|
||||
|
||||
flow = RequiredFieldsFlow(name="Alice", age=30)
|
||||
flow.kickoff()
|
||||
|
||||
assert flow.state.name == "Alice"
|
||||
assert flow.state.age == 30
|
||||
assert hasattr(flow.state, "id")
|
||||
assert len(flow.state.id) == 36
|
||||
|
||||
|
||||
def test_flow_init_with_required_fields_missing_values():
|
||||
"""Test that Flow initialization fails when required fields are missing."""
|
||||
|
||||
class RequiredFieldsState(BaseModel):
|
||||
name: str
|
||||
age: int
|
||||
|
||||
class RequiredFieldsFlow(Flow[RequiredFieldsState]):
|
||||
@start()
|
||||
def step_1(self):
|
||||
pass
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
RequiredFieldsFlow()
|
||||
|
||||
|
||||
def test_flow_init_with_mixed_required_optional_fields():
|
||||
"""Test Flow with both required and optional fields."""
|
||||
|
||||
class MixedFieldsState(BaseModel):
|
||||
name: str
|
||||
age: int = 25
|
||||
city: str | None = None
|
||||
|
||||
class MixedFieldsFlow(Flow[MixedFieldsState]):
|
||||
@start()
|
||||
def step_1(self):
|
||||
assert self.state.name == "Bob"
|
||||
assert self.state.age == 25
|
||||
assert self.state.city is None
|
||||
|
||||
flow = MixedFieldsFlow(name="Bob")
|
||||
flow.kickoff()
|
||||
|
||||
assert flow.state.name == "Bob"
|
||||
assert flow.state.age == 25
|
||||
assert flow.state.city is None
|
||||
|
||||
|
||||
def test_flow_init_with_required_fields_and_overrides():
|
||||
"""Test that kwargs override default values."""
|
||||
|
||||
class DefaultFieldsState(BaseModel):
|
||||
name: str
|
||||
age: int = 18
|
||||
active: bool = True
|
||||
|
||||
class DefaultFieldsFlow(Flow[DefaultFieldsState]):
|
||||
@start()
|
||||
def step_1(self):
|
||||
assert self.state.name == "Charlie"
|
||||
assert self.state.age == 35
|
||||
assert self.state.active is False
|
||||
|
||||
flow = DefaultFieldsFlow(name="Charlie", age=35, active=False)
|
||||
flow.kickoff()
|
||||
|
||||
assert flow.state.name == "Charlie"
|
||||
assert flow.state.age == 35
|
||||
assert flow.state.active is False
|
||||
|
||||
|
||||
def test_flow_init_backward_compatibility_with_flowstate():
|
||||
"""Test that existing FlowState subclasses still work."""
|
||||
from crewai.flow.flow import FlowState
|
||||
|
||||
class MyFlowState(FlowState):
|
||||
counter: int = 0
|
||||
message: str = "default"
|
||||
|
||||
class BackwardCompatFlow(Flow[MyFlowState]):
|
||||
@start()
|
||||
def step_1(self):
|
||||
self.state.counter += 1
|
||||
|
||||
flow1 = BackwardCompatFlow()
|
||||
flow1.kickoff()
|
||||
assert flow1.state.counter == 1
|
||||
assert flow1.state.message == "default"
|
||||
|
||||
flow2 = BackwardCompatFlow(counter=10, message="custom")
|
||||
flow2.kickoff()
|
||||
assert flow2.state.counter == 11
|
||||
assert flow2.state.message == "custom"
|
||||
|
||||
Reference in New Issue
Block a user