mirror of
https://github.com/crewAIInc/crewAI.git
synced 2025-12-15 11:58:31 +00:00
feat: enhance flow event state management (#3952)
* feat: enhance flow event state management - Added `state` attribute to `FlowFinishedEvent` to capture the flow's state as a JSON-serialized dictionary. - Updated flow event emissions to include the serialized state, improving traceability and debugging capabilities during flow execution. * fix: improve state serialization in Flow class - Enhanced the `_copy_and_serialize_state` method to handle exceptions during JSON serialization of Pydantic models, ensuring robustness in state management. - Updated test assertions to access the state as a dictionary, aligning with the new state structure. --------- Co-authored-by: Greyson LaLonde <greyson.r.lalonde@gmail.com>
This commit is contained in:
@@ -64,6 +64,7 @@ class FlowFinishedEvent(FlowEvent):
|
||||
flow_name: str
|
||||
result: Any | None = None
|
||||
type: str = "flow_finished"
|
||||
state: dict[str, Any] | BaseModel
|
||||
|
||||
|
||||
class FlowPlotEvent(FlowEvent):
|
||||
|
||||
@@ -1008,6 +1008,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
type="flow_finished",
|
||||
flow_name=self.name or self.__class__.__name__,
|
||||
result=final_output,
|
||||
state=self._copy_and_serialize_state(),
|
||||
),
|
||||
)
|
||||
if future:
|
||||
@@ -1109,6 +1110,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
dumped_params = {f"_{i}": arg for i, arg in enumerate(args)} | (
|
||||
kwargs or {}
|
||||
)
|
||||
|
||||
future = crewai_event_bus.emit(
|
||||
self,
|
||||
MethodExecutionStartedEvent(
|
||||
@@ -1116,7 +1118,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
method_name=method_name,
|
||||
flow_name=self.name or self.__class__.__name__,
|
||||
params=dumped_params,
|
||||
state=self._copy_state(),
|
||||
state=self._copy_and_serialize_state(),
|
||||
),
|
||||
)
|
||||
if future:
|
||||
@@ -1134,13 +1136,14 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
)
|
||||
|
||||
self._completed_methods.add(method_name)
|
||||
|
||||
future = crewai_event_bus.emit(
|
||||
self,
|
||||
MethodExecutionFinishedEvent(
|
||||
type="method_execution_finished",
|
||||
method_name=method_name,
|
||||
flow_name=self.name or self.__class__.__name__,
|
||||
state=self._copy_state(),
|
||||
state=self._copy_and_serialize_state(),
|
||||
result=result,
|
||||
),
|
||||
)
|
||||
@@ -1162,6 +1165,16 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
self._event_futures.append(future)
|
||||
raise e
|
||||
|
||||
def _copy_and_serialize_state(self) -> dict[str, Any]:
|
||||
state_copy = self._copy_state()
|
||||
if isinstance(state_copy, BaseModel):
|
||||
try:
|
||||
return state_copy.model_dump(mode="json")
|
||||
except Exception:
|
||||
return state_copy.model_dump()
|
||||
else:
|
||||
return state_copy
|
||||
|
||||
async def _execute_listeners(
|
||||
self, trigger_method: FlowMethodName, result: Any
|
||||
) -> None:
|
||||
|
||||
@@ -723,11 +723,11 @@ def test_structured_flow_event_emission():
|
||||
assert isinstance(received_events[3], MethodExecutionStartedEvent)
|
||||
assert received_events[3].method_name == "send_welcome_message"
|
||||
assert received_events[3].params == {}
|
||||
assert received_events[3].state.sent is False
|
||||
assert received_events[3].state["sent"] is False
|
||||
|
||||
assert isinstance(received_events[4], MethodExecutionFinishedEvent)
|
||||
assert received_events[4].method_name == "send_welcome_message"
|
||||
assert received_events[4].state.sent is True
|
||||
assert received_events[4].state["sent"] is True
|
||||
assert received_events[4].result == "Welcome, Anakin!"
|
||||
|
||||
assert isinstance(received_events[5], FlowFinishedEvent)
|
||||
|
||||
@@ -26,6 +26,7 @@ from crewai.events.types.flow_events import (
|
||||
FlowFinishedEvent,
|
||||
FlowStartedEvent,
|
||||
MethodExecutionFailedEvent,
|
||||
MethodExecutionFinishedEvent,
|
||||
MethodExecutionStartedEvent,
|
||||
)
|
||||
from crewai.events.types.llm_events import (
|
||||
@@ -47,7 +48,7 @@ from crewai.flow.flow import Flow, listen, start
|
||||
from crewai.llm import LLM
|
||||
from crewai.task import Task
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from pydantic import Field
|
||||
from pydantic import BaseModel, Field
|
||||
import pytest
|
||||
|
||||
from ..utils import wait_for_event_handlers
|
||||
@@ -703,6 +704,156 @@ def test_flow_emits_method_execution_failed_event():
|
||||
assert received_events[0].error == error
|
||||
|
||||
|
||||
def test_flow_method_execution_started_includes_unstructured_state():
|
||||
"""Test that MethodExecutionStartedEvent includes unstructured (dict) state."""
|
||||
received_events = []
|
||||
event_received = threading.Event()
|
||||
|
||||
@crewai_event_bus.on(MethodExecutionStartedEvent)
|
||||
def handle_method_started(source, event):
|
||||
received_events.append(event)
|
||||
if event.method_name == "process":
|
||||
event_received.set()
|
||||
|
||||
class TestFlow(Flow[dict]):
|
||||
@start()
|
||||
def begin(self):
|
||||
self.state["counter"] = 1
|
||||
self.state["message"] = "test"
|
||||
return "started"
|
||||
|
||||
@listen("begin")
|
||||
def process(self):
|
||||
self.state["counter"] = 2
|
||||
return "processed"
|
||||
|
||||
flow = TestFlow()
|
||||
flow.kickoff()
|
||||
|
||||
assert event_received.wait(timeout=5), (
|
||||
"Timeout waiting for method execution started event"
|
||||
)
|
||||
|
||||
# Find the events for each method
|
||||
begin_event = next(e for e in received_events if e.method_name == "begin")
|
||||
process_event = next(e for e in received_events if e.method_name == "process")
|
||||
|
||||
# Verify state is included and is a dict
|
||||
assert begin_event.state is not None
|
||||
assert isinstance(begin_event.state, dict)
|
||||
assert "id" in begin_event.state # Auto-generated ID
|
||||
|
||||
# Verify state from begin method is captured in process event
|
||||
assert process_event.state is not None
|
||||
assert isinstance(process_event.state, dict)
|
||||
assert process_event.state["counter"] == 1
|
||||
assert process_event.state["message"] == "test"
|
||||
|
||||
|
||||
def test_flow_method_execution_started_includes_structured_state():
|
||||
"""Test that MethodExecutionStartedEvent includes structured (BaseModel) state and serializes it properly."""
|
||||
received_events = []
|
||||
event_received = threading.Event()
|
||||
|
||||
class FlowState(BaseModel):
|
||||
counter: int = 0
|
||||
message: str = ""
|
||||
items: list[str] = []
|
||||
|
||||
@crewai_event_bus.on(MethodExecutionStartedEvent)
|
||||
def handle_method_started(source, event):
|
||||
received_events.append(event)
|
||||
if event.method_name == "process":
|
||||
event_received.set()
|
||||
|
||||
class TestFlow(Flow[FlowState]):
|
||||
@start()
|
||||
def begin(self):
|
||||
self.state.counter = 1
|
||||
self.state.message = "initial"
|
||||
self.state.items = ["a", "b"]
|
||||
return "started"
|
||||
|
||||
@listen("begin")
|
||||
def process(self):
|
||||
self.state.counter += 1
|
||||
return "processed"
|
||||
|
||||
flow = TestFlow()
|
||||
flow.kickoff()
|
||||
|
||||
assert event_received.wait(timeout=5), (
|
||||
"Timeout waiting for method execution started event"
|
||||
)
|
||||
|
||||
begin_event = next(e for e in received_events if e.method_name == "begin")
|
||||
process_event = next(e for e in received_events if e.method_name == "process")
|
||||
|
||||
assert begin_event.state is not None
|
||||
assert isinstance(begin_event.state, dict)
|
||||
assert begin_event.state["counter"] == 0 # Initial state
|
||||
assert begin_event.state["message"] == ""
|
||||
assert begin_event.state["items"] == []
|
||||
|
||||
assert process_event.state is not None
|
||||
assert isinstance(process_event.state, dict)
|
||||
assert process_event.state["counter"] == 1
|
||||
assert process_event.state["message"] == "initial"
|
||||
assert process_event.state["items"] == ["a", "b"]
|
||||
|
||||
|
||||
def test_flow_method_execution_finished_includes_serialized_state():
|
||||
"""Test that MethodExecutionFinishedEvent includes properly serialized state."""
|
||||
received_events = []
|
||||
event_received = threading.Event()
|
||||
|
||||
class FlowState(BaseModel):
|
||||
result: str = ""
|
||||
completed: bool = False
|
||||
|
||||
@crewai_event_bus.on(MethodExecutionFinishedEvent)
|
||||
def handle_method_finished(source, event):
|
||||
received_events.append(event)
|
||||
if event.method_name == "process":
|
||||
event_received.set()
|
||||
|
||||
class TestFlow(Flow[FlowState]):
|
||||
@start()
|
||||
def begin(self):
|
||||
self.state.result = "begin done"
|
||||
return "started"
|
||||
|
||||
@listen("begin")
|
||||
def process(self):
|
||||
self.state.result = "process done"
|
||||
self.state.completed = True
|
||||
return "final_result"
|
||||
|
||||
flow = TestFlow()
|
||||
final_output = flow.kickoff()
|
||||
|
||||
assert event_received.wait(timeout=5), (
|
||||
"Timeout waiting for method execution finished event"
|
||||
)
|
||||
|
||||
begin_finished = next(e for e in received_events if e.method_name == "begin")
|
||||
process_finished = next(e for e in received_events if e.method_name == "process")
|
||||
|
||||
assert begin_finished.state is not None
|
||||
assert isinstance(begin_finished.state, dict)
|
||||
assert begin_finished.state["result"] == "begin done"
|
||||
assert begin_finished.state["completed"] is False
|
||||
assert begin_finished.result == "started"
|
||||
|
||||
# Verify process finished event has final state and result
|
||||
assert process_finished.state is not None
|
||||
assert isinstance(process_finished.state, dict)
|
||||
assert process_finished.state["result"] == "process done"
|
||||
assert process_finished.state["completed"] is True
|
||||
assert process_finished.result == "final_result"
|
||||
assert final_output == "final_result"
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_llm_emits_call_started_event():
|
||||
received_events = []
|
||||
|
||||
Reference in New Issue
Block a user