feat: enhance flow event state management (#3952)

* feat: enhance flow event state management

- Added `state` attribute to `FlowFinishedEvent` to capture the flow's state as a JSON-serialized dictionary.
- Updated flow event emissions to include the serialized state, improving traceability and debugging capabilities during flow execution.

* fix: improve state serialization in Flow class

- Enhanced the `_copy_and_serialize_state` method to handle exceptions during JSON serialization of Pydantic models, ensuring robustness in state management.
- Updated test assertions to access the state as a dictionary, aligning with the new state structure.

---------

Co-authored-by: Greyson LaLonde <greyson.r.lalonde@gmail.com>
This commit is contained in:
Lorenze Jay
2025-11-24 15:55:49 -08:00
committed by GitHub
parent b049b73f2e
commit 4ae8c36815
4 changed files with 170 additions and 5 deletions

View File

@@ -64,6 +64,7 @@ class FlowFinishedEvent(FlowEvent):
flow_name: str
result: Any | None = None
type: str = "flow_finished"
state: dict[str, Any] | BaseModel
class FlowPlotEvent(FlowEvent):

View File

@@ -1008,6 +1008,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
type="flow_finished",
flow_name=self.name or self.__class__.__name__,
result=final_output,
state=self._copy_and_serialize_state(),
),
)
if future:
@@ -1109,6 +1110,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
dumped_params = {f"_{i}": arg for i, arg in enumerate(args)} | (
kwargs or {}
)
future = crewai_event_bus.emit(
self,
MethodExecutionStartedEvent(
@@ -1116,7 +1118,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
method_name=method_name,
flow_name=self.name or self.__class__.__name__,
params=dumped_params,
state=self._copy_state(),
state=self._copy_and_serialize_state(),
),
)
if future:
@@ -1134,13 +1136,14 @@ class Flow(Generic[T], metaclass=FlowMeta):
)
self._completed_methods.add(method_name)
future = crewai_event_bus.emit(
self,
MethodExecutionFinishedEvent(
type="method_execution_finished",
method_name=method_name,
flow_name=self.name or self.__class__.__name__,
state=self._copy_state(),
state=self._copy_and_serialize_state(),
result=result,
),
)
@@ -1162,6 +1165,16 @@ class Flow(Generic[T], metaclass=FlowMeta):
self._event_futures.append(future)
raise e
def _copy_and_serialize_state(self) -> dict[str, Any]:
state_copy = self._copy_state()
if isinstance(state_copy, BaseModel):
try:
return state_copy.model_dump(mode="json")
except Exception:
return state_copy.model_dump()
else:
return state_copy
async def _execute_listeners(
self, trigger_method: FlowMethodName, result: Any
) -> None:

View File

@@ -723,11 +723,11 @@ def test_structured_flow_event_emission():
assert isinstance(received_events[3], MethodExecutionStartedEvent)
assert received_events[3].method_name == "send_welcome_message"
assert received_events[3].params == {}
assert received_events[3].state.sent is False
assert received_events[3].state["sent"] is False
assert isinstance(received_events[4], MethodExecutionFinishedEvent)
assert received_events[4].method_name == "send_welcome_message"
assert received_events[4].state.sent is True
assert received_events[4].state["sent"] is True
assert received_events[4].result == "Welcome, Anakin!"
assert isinstance(received_events[5], FlowFinishedEvent)

View File

@@ -26,6 +26,7 @@ from crewai.events.types.flow_events import (
FlowFinishedEvent,
FlowStartedEvent,
MethodExecutionFailedEvent,
MethodExecutionFinishedEvent,
MethodExecutionStartedEvent,
)
from crewai.events.types.llm_events import (
@@ -47,7 +48,7 @@ from crewai.flow.flow import Flow, listen, start
from crewai.llm import LLM
from crewai.task import Task
from crewai.tools.base_tool import BaseTool
from pydantic import Field
from pydantic import BaseModel, Field
import pytest
from ..utils import wait_for_event_handlers
@@ -703,6 +704,156 @@ def test_flow_emits_method_execution_failed_event():
assert received_events[0].error == error
def test_flow_method_execution_started_includes_unstructured_state():
"""Test that MethodExecutionStartedEvent includes unstructured (dict) state."""
received_events = []
event_received = threading.Event()
@crewai_event_bus.on(MethodExecutionStartedEvent)
def handle_method_started(source, event):
received_events.append(event)
if event.method_name == "process":
event_received.set()
class TestFlow(Flow[dict]):
@start()
def begin(self):
self.state["counter"] = 1
self.state["message"] = "test"
return "started"
@listen("begin")
def process(self):
self.state["counter"] = 2
return "processed"
flow = TestFlow()
flow.kickoff()
assert event_received.wait(timeout=5), (
"Timeout waiting for method execution started event"
)
# Find the events for each method
begin_event = next(e for e in received_events if e.method_name == "begin")
process_event = next(e for e in received_events if e.method_name == "process")
# Verify state is included and is a dict
assert begin_event.state is not None
assert isinstance(begin_event.state, dict)
assert "id" in begin_event.state # Auto-generated ID
# Verify state from begin method is captured in process event
assert process_event.state is not None
assert isinstance(process_event.state, dict)
assert process_event.state["counter"] == 1
assert process_event.state["message"] == "test"
def test_flow_method_execution_started_includes_structured_state():
"""Test that MethodExecutionStartedEvent includes structured (BaseModel) state and serializes it properly."""
received_events = []
event_received = threading.Event()
class FlowState(BaseModel):
counter: int = 0
message: str = ""
items: list[str] = []
@crewai_event_bus.on(MethodExecutionStartedEvent)
def handle_method_started(source, event):
received_events.append(event)
if event.method_name == "process":
event_received.set()
class TestFlow(Flow[FlowState]):
@start()
def begin(self):
self.state.counter = 1
self.state.message = "initial"
self.state.items = ["a", "b"]
return "started"
@listen("begin")
def process(self):
self.state.counter += 1
return "processed"
flow = TestFlow()
flow.kickoff()
assert event_received.wait(timeout=5), (
"Timeout waiting for method execution started event"
)
begin_event = next(e for e in received_events if e.method_name == "begin")
process_event = next(e for e in received_events if e.method_name == "process")
assert begin_event.state is not None
assert isinstance(begin_event.state, dict)
assert begin_event.state["counter"] == 0 # Initial state
assert begin_event.state["message"] == ""
assert begin_event.state["items"] == []
assert process_event.state is not None
assert isinstance(process_event.state, dict)
assert process_event.state["counter"] == 1
assert process_event.state["message"] == "initial"
assert process_event.state["items"] == ["a", "b"]
def test_flow_method_execution_finished_includes_serialized_state():
"""Test that MethodExecutionFinishedEvent includes properly serialized state."""
received_events = []
event_received = threading.Event()
class FlowState(BaseModel):
result: str = ""
completed: bool = False
@crewai_event_bus.on(MethodExecutionFinishedEvent)
def handle_method_finished(source, event):
received_events.append(event)
if event.method_name == "process":
event_received.set()
class TestFlow(Flow[FlowState]):
@start()
def begin(self):
self.state.result = "begin done"
return "started"
@listen("begin")
def process(self):
self.state.result = "process done"
self.state.completed = True
return "final_result"
flow = TestFlow()
final_output = flow.kickoff()
assert event_received.wait(timeout=5), (
"Timeout waiting for method execution finished event"
)
begin_finished = next(e for e in received_events if e.method_name == "begin")
process_finished = next(e for e in received_events if e.method_name == "process")
assert begin_finished.state is not None
assert isinstance(begin_finished.state, dict)
assert begin_finished.state["result"] == "begin done"
assert begin_finished.state["completed"] is False
assert begin_finished.result == "started"
# Verify process finished event has final state and result
assert process_finished.state is not None
assert isinstance(process_finished.state, dict)
assert process_finished.state["result"] == "process done"
assert process_finished.state["completed"] is True
assert process_finished.result == "final_result"
assert final_output == "final_result"
@pytest.mark.vcr(filter_headers=["authorization"])
def test_llm_emits_call_started_event():
received_events = []