mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-19 21:08:13 +00:00
Compare commits
3 Commits
lorenze/fi
...
devin/1759
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
53d94df7de | ||
|
|
1ce52c6616 | ||
|
|
0650fbbe88 |
@@ -31,7 +31,7 @@ from crewai.flow.flow_visualizer import plot_flow
|
|||||||
from crewai.flow.persistence.base import FlowPersistence
|
from crewai.flow.persistence.base import FlowPersistence
|
||||||
from crewai.flow.types import FlowExecutionData
|
from crewai.flow.types import FlowExecutionData
|
||||||
from crewai.flow.utils import get_possible_return_constants
|
from crewai.flow.utils import get_possible_return_constants
|
||||||
from crewai.utilities.printer import Printer
|
from crewai.utilities.printer import Printer, PrinterColor
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -465,7 +465,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
self._is_execution_resuming: bool = False
|
self._is_execution_resuming: bool = False
|
||||||
|
|
||||||
# Initialize state with initial values
|
# Initialize state with initial values
|
||||||
self._state = self._create_initial_state()
|
self._state = self._create_initial_state(kwargs)
|
||||||
self.tracing = tracing
|
self.tracing = tracing
|
||||||
if (
|
if (
|
||||||
is_tracing_enabled()
|
is_tracing_enabled()
|
||||||
@@ -474,9 +474,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
):
|
):
|
||||||
trace_listener = TraceCollectionListener()
|
trace_listener = TraceCollectionListener()
|
||||||
trace_listener.setup_listeners(crewai_event_bus)
|
trace_listener.setup_listeners(crewai_event_bus)
|
||||||
# Apply any additional kwargs
|
|
||||||
if kwargs:
|
|
||||||
self._initialize_state(kwargs)
|
|
||||||
|
|
||||||
crewai_event_bus.emit(
|
crewai_event_bus.emit(
|
||||||
self,
|
self,
|
||||||
@@ -502,9 +499,12 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
method = method.__get__(self, self.__class__)
|
method = method.__get__(self, self.__class__)
|
||||||
self._methods[method_name] = method
|
self._methods[method_name] = method
|
||||||
|
|
||||||
def _create_initial_state(self) -> T:
|
def _create_initial_state(self, kwargs: dict[str, Any] | None = None) -> T:
|
||||||
"""Create and initialize flow state with UUID and default values.
|
"""Create and initialize flow state with UUID and default values.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
kwargs: Optional initial values for state fields
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
New state instance with UUID and default values initialized
|
New state instance with UUID and default values initialized
|
||||||
|
|
||||||
@@ -518,7 +518,8 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
if isinstance(state_type, type):
|
if isinstance(state_type, type):
|
||||||
if issubclass(state_type, FlowState):
|
if issubclass(state_type, FlowState):
|
||||||
# Create instance without id, then set it
|
# Create instance without id, then set it
|
||||||
instance = state_type()
|
init_kwargs = kwargs or {}
|
||||||
|
instance = state_type(**init_kwargs)
|
||||||
if not hasattr(instance, "id"):
|
if not hasattr(instance, "id"):
|
||||||
instance.id = str(uuid4())
|
instance.id = str(uuid4())
|
||||||
return cast(T, instance)
|
return cast(T, instance)
|
||||||
@@ -527,7 +528,8 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
class StateWithId(state_type, FlowState): # type: ignore
|
class StateWithId(state_type, FlowState): # type: ignore
|
||||||
pass
|
pass
|
||||||
|
|
||||||
instance = StateWithId()
|
init_kwargs = kwargs or {}
|
||||||
|
instance = StateWithId(**init_kwargs)
|
||||||
if not hasattr(instance, "id"):
|
if not hasattr(instance, "id"):
|
||||||
instance.id = str(uuid4())
|
instance.id = str(uuid4())
|
||||||
return cast(T, instance)
|
return cast(T, instance)
|
||||||
@@ -541,13 +543,13 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
# Handle case where initial_state is a type (class)
|
# Handle case where initial_state is a type (class)
|
||||||
if isinstance(self.initial_state, type):
|
if isinstance(self.initial_state, type):
|
||||||
if issubclass(self.initial_state, FlowState):
|
if issubclass(self.initial_state, FlowState):
|
||||||
return cast(T, self.initial_state()) # Uses model defaults
|
return cast(T, self.initial_state(**(kwargs or {})))
|
||||||
if issubclass(self.initial_state, BaseModel):
|
if issubclass(self.initial_state, BaseModel):
|
||||||
# Validate that the model has an id field
|
# Validate that the model has an id field
|
||||||
model_fields = getattr(self.initial_state, "model_fields", None)
|
model_fields = getattr(self.initial_state, "model_fields", None)
|
||||||
if not model_fields or "id" not in model_fields:
|
if not model_fields or "id" not in model_fields:
|
||||||
raise ValueError("Flow state model must have an 'id' field")
|
raise ValueError("Flow state model must have an 'id' field")
|
||||||
return cast(T, self.initial_state()) # Uses model defaults
|
return cast(T, self.initial_state(**(kwargs or {})))
|
||||||
if self.initial_state is dict:
|
if self.initial_state is dict:
|
||||||
return cast(T, {"id": str(uuid4())})
|
return cast(T, {"id": str(uuid4())})
|
||||||
|
|
||||||
@@ -1086,7 +1088,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
for method_name in self._start_methods:
|
for method_name in self._start_methods:
|
||||||
# Check if this start method is triggered by the current trigger
|
# Check if this start method is triggered by the current trigger
|
||||||
if method_name in self._listeners:
|
if method_name in self._listeners:
|
||||||
condition_type, trigger_methods = self._listeners[
|
_, trigger_methods = self._listeners[
|
||||||
method_name
|
method_name
|
||||||
]
|
]
|
||||||
if current_trigger in trigger_methods:
|
if current_trigger in trigger_methods:
|
||||||
@@ -1218,7 +1220,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
def _log_flow_event(
|
def _log_flow_event(
|
||||||
self, message: str, color: str = "yellow", level: str = "info"
|
self, message: str, color: PrinterColor = "yellow", level: str = "info"
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Centralized logging method for flow events.
|
"""Centralized logging method for flow events.
|
||||||
|
|
||||||
|
|||||||
@@ -4,17 +4,17 @@ import asyncio
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, ValidationError
|
||||||
|
|
||||||
from crewai.flow.flow import Flow, and_, listen, or_, router, start
|
|
||||||
from crewai.events.event_bus import crewai_event_bus
|
from crewai.events.event_bus import crewai_event_bus
|
||||||
from crewai.events.types.flow_events import (
|
from crewai.events.types.flow_events import (
|
||||||
FlowFinishedEvent,
|
FlowFinishedEvent,
|
||||||
FlowStartedEvent,
|
|
||||||
FlowPlotEvent,
|
FlowPlotEvent,
|
||||||
|
FlowStartedEvent,
|
||||||
MethodExecutionFinishedEvent,
|
MethodExecutionFinishedEvent,
|
||||||
MethodExecutionStartedEvent,
|
MethodExecutionStartedEvent,
|
||||||
)
|
)
|
||||||
|
from crewai.flow.flow import Flow, and_, listen, or_, router, start
|
||||||
|
|
||||||
|
|
||||||
def test_simple_sequential_flow():
|
def test_simple_sequential_flow():
|
||||||
@@ -679,11 +679,11 @@ def test_structured_flow_event_emission():
|
|||||||
assert isinstance(received_events[3], MethodExecutionStartedEvent)
|
assert isinstance(received_events[3], MethodExecutionStartedEvent)
|
||||||
assert received_events[3].method_name == "send_welcome_message"
|
assert received_events[3].method_name == "send_welcome_message"
|
||||||
assert received_events[3].params == {}
|
assert received_events[3].params == {}
|
||||||
assert getattr(received_events[3].state, "sent") is False
|
assert received_events[3].state.sent is False
|
||||||
|
|
||||||
assert isinstance(received_events[4], MethodExecutionFinishedEvent)
|
assert isinstance(received_events[4], MethodExecutionFinishedEvent)
|
||||||
assert received_events[4].method_name == "send_welcome_message"
|
assert received_events[4].method_name == "send_welcome_message"
|
||||||
assert getattr(received_events[4].state, "sent") is True
|
assert received_events[4].state.sent is True
|
||||||
assert received_events[4].result == "Welcome, Anakin!"
|
assert received_events[4].result == "Welcome, Anakin!"
|
||||||
|
|
||||||
assert isinstance(received_events[5], FlowFinishedEvent)
|
assert isinstance(received_events[5], FlowFinishedEvent)
|
||||||
@@ -894,3 +894,111 @@ def test_flow_name():
|
|||||||
|
|
||||||
flow = MyFlow()
|
flow = MyFlow()
|
||||||
assert flow.name == "MyFlow"
|
assert flow.name == "MyFlow"
|
||||||
|
|
||||||
|
|
||||||
|
def test_flow_init_with_required_fields():
|
||||||
|
"""Test Flow initialization with Pydantic models having required fields."""
|
||||||
|
|
||||||
|
class RequiredFieldsState(BaseModel):
|
||||||
|
name: str
|
||||||
|
age: int
|
||||||
|
|
||||||
|
class RequiredFieldsFlow(Flow[RequiredFieldsState]):
|
||||||
|
@start()
|
||||||
|
def step_1(self):
|
||||||
|
assert self.state.name == "Alice"
|
||||||
|
assert self.state.age == 30
|
||||||
|
|
||||||
|
flow = RequiredFieldsFlow(name="Alice", age=30)
|
||||||
|
flow.kickoff()
|
||||||
|
|
||||||
|
assert flow.state.name == "Alice"
|
||||||
|
assert flow.state.age == 30
|
||||||
|
assert hasattr(flow.state, "id")
|
||||||
|
assert len(flow.state.id) == 36
|
||||||
|
|
||||||
|
|
||||||
|
def test_flow_init_with_required_fields_missing_values():
|
||||||
|
"""Test that Flow initialization fails when required fields are missing."""
|
||||||
|
|
||||||
|
class RequiredFieldsState(BaseModel):
|
||||||
|
name: str
|
||||||
|
age: int
|
||||||
|
|
||||||
|
class RequiredFieldsFlow(Flow[RequiredFieldsState]):
|
||||||
|
@start()
|
||||||
|
def step_1(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
RequiredFieldsFlow()
|
||||||
|
|
||||||
|
|
||||||
|
def test_flow_init_with_mixed_required_optional_fields():
|
||||||
|
"""Test Flow with both required and optional fields."""
|
||||||
|
|
||||||
|
class MixedFieldsState(BaseModel):
|
||||||
|
name: str
|
||||||
|
age: int = 25
|
||||||
|
city: str | None = None
|
||||||
|
|
||||||
|
class MixedFieldsFlow(Flow[MixedFieldsState]):
|
||||||
|
@start()
|
||||||
|
def step_1(self):
|
||||||
|
assert self.state.name == "Bob"
|
||||||
|
assert self.state.age == 25
|
||||||
|
assert self.state.city is None
|
||||||
|
|
||||||
|
flow = MixedFieldsFlow(name="Bob")
|
||||||
|
flow.kickoff()
|
||||||
|
|
||||||
|
assert flow.state.name == "Bob"
|
||||||
|
assert flow.state.age == 25
|
||||||
|
assert flow.state.city is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_flow_init_with_required_fields_and_overrides():
|
||||||
|
"""Test that kwargs override default values."""
|
||||||
|
|
||||||
|
class DefaultFieldsState(BaseModel):
|
||||||
|
name: str
|
||||||
|
age: int = 18
|
||||||
|
active: bool = True
|
||||||
|
|
||||||
|
class DefaultFieldsFlow(Flow[DefaultFieldsState]):
|
||||||
|
@start()
|
||||||
|
def step_1(self):
|
||||||
|
assert self.state.name == "Charlie"
|
||||||
|
assert self.state.age == 35
|
||||||
|
assert self.state.active is False
|
||||||
|
|
||||||
|
flow = DefaultFieldsFlow(name="Charlie", age=35, active=False)
|
||||||
|
flow.kickoff()
|
||||||
|
|
||||||
|
assert flow.state.name == "Charlie"
|
||||||
|
assert flow.state.age == 35
|
||||||
|
assert flow.state.active is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_flow_init_backward_compatibility_with_flowstate():
|
||||||
|
"""Test that existing FlowState subclasses still work."""
|
||||||
|
from crewai.flow.flow import FlowState
|
||||||
|
|
||||||
|
class MyFlowState(FlowState):
|
||||||
|
counter: int = 0
|
||||||
|
message: str = "default"
|
||||||
|
|
||||||
|
class BackwardCompatFlow(Flow[MyFlowState]):
|
||||||
|
@start()
|
||||||
|
def step_1(self):
|
||||||
|
self.state.counter += 1
|
||||||
|
|
||||||
|
flow1 = BackwardCompatFlow()
|
||||||
|
flow1.kickoff()
|
||||||
|
assert flow1.state.counter == 1
|
||||||
|
assert flow1.state.message == "default"
|
||||||
|
|
||||||
|
flow2 = BackwardCompatFlow(counter=10, message="custom")
|
||||||
|
flow2.kickoff()
|
||||||
|
assert flow2.state.counter == 11
|
||||||
|
assert flow2.state.message == "custom"
|
||||||
|
|||||||
Reference in New Issue
Block a user