From d6800d8957dc10cb600325a8e7b80cf40d4d9bd3 Mon Sep 17 00:00:00 2001 From: Vini Brasil Date: Wed, 12 Feb 2025 14:19:41 -0600 Subject: [PATCH] Ensure `@start` methods emit `MethodExecutionStartedEvent` (#2114) Previously, `@start` methods triggered a `FlowStartedEvent` but did not emit a `MethodExecutionStartedEvent`. This was fine for a single entry point but caused ambiguity when multiple `@start` methods existed. This commit (1) emits events for starting points, (2) adds tests ensuring ordering, (3) adds more fields to events. --- src/crewai/flow/flow.py | 49 +++++--- src/crewai/flow/flow_events.py | 10 +- tests/flow_test.py | 222 +++++++++++++++++++++++++++++++++ 3 files changed, 259 insertions(+), 22 deletions(-) diff --git a/src/crewai/flow/flow.py b/src/crewai/flow/flow.py index 382a792e5..f1242a2bf 100644 --- a/src/crewai/flow/flow.py +++ b/src/crewai/flow/flow.py @@ -1,4 +1,5 @@ import asyncio +import copy import inspect import logging from typing import ( @@ -394,7 +395,6 @@ class FlowMeta(type): or hasattr(attr_value, "__trigger_methods__") or hasattr(attr_value, "__is_router__") ): - # Register start methods if hasattr(attr_value, "__is_start_method__"): start_methods.append(attr_name) @@ -569,6 +569,9 @@ class Flow(Generic[T], metaclass=FlowMeta): f"Initial state must be dict or BaseModel, got {type(self.initial_state)}" ) + def _copy_state(self) -> T: + return copy.deepcopy(self._state) + @property def state(self) -> T: return self._state @@ -740,6 +743,7 @@ class Flow(Generic[T], metaclass=FlowMeta): event=FlowStartedEvent( type="flow_started", flow_name=self.__class__.__name__, + inputs=inputs, ), ) self._log_flow_event( @@ -803,6 +807,18 @@ class Flow(Generic[T], metaclass=FlowMeta): async def _execute_method( self, method_name: str, method: Callable, *args: Any, **kwargs: Any ) -> Any: + dumped_params = {f"_{i}": arg for i, arg in enumerate(args)} | (kwargs or {}) + self.event_emitter.send( + self, + event=MethodExecutionStartedEvent( + type="method_execution_started", + method_name=method_name, + flow_name=self.__class__.__name__, + params=dumped_params, + state=self._copy_state(), + ), + ) + result = ( await method(*args, **kwargs) if asyncio.iscoroutinefunction(method) @@ -812,6 +828,18 @@ class Flow(Generic[T], metaclass=FlowMeta): self._method_execution_counts[method_name] = ( self._method_execution_counts.get(method_name, 0) + 1 ) + + self.event_emitter.send( + self, + event=MethodExecutionFinishedEvent( + type="method_execution_finished", + method_name=method_name, + flow_name=self.__class__.__name__, + state=self._copy_state(), + result=result, + ), + ) + return result async def _execute_listeners(self, trigger_method: str, result: Any) -> None: @@ -950,16 +978,6 @@ class Flow(Generic[T], metaclass=FlowMeta): """ try: method = self._methods[listener_name] - - self.event_emitter.send( - self, - event=MethodExecutionStartedEvent( - type="method_execution_started", - method_name=listener_name, - flow_name=self.__class__.__name__, - ), - ) - sig = inspect.signature(method) params = list(sig.parameters.values()) method_params = [p for p in params if p.name != "self"] @@ -971,15 +989,6 @@ class Flow(Generic[T], metaclass=FlowMeta): else: listener_result = await self._execute_method(listener_name, method) - self.event_emitter.send( - self, - event=MethodExecutionFinishedEvent( - type="method_execution_finished", - method_name=listener_name, - flow_name=self.__class__.__name__, - ), - ) - # Execute listeners (and possibly routers) of this listener await self._execute_listeners(listener_name, listener_result) diff --git a/src/crewai/flow/flow_events.py b/src/crewai/flow/flow_events.py index 068005ebe..c8f9e9694 100644 --- a/src/crewai/flow/flow_events.py +++ b/src/crewai/flow/flow_events.py @@ -1,6 +1,8 @@ from dataclasses import dataclass, field from datetime import datetime -from typing import Any, Optional +from typing import Any, Dict, Optional, Union + +from pydantic import BaseModel @dataclass @@ -15,17 +17,21 @@ class Event: @dataclass class FlowStartedEvent(Event): - pass + inputs: Optional[Dict[str, Any]] = None @dataclass class MethodExecutionStartedEvent(Event): method_name: str + state: Union[Dict[str, Any], BaseModel] + params: Optional[Dict[str, Any]] = None @dataclass class MethodExecutionFinishedEvent(Event): method_name: str + state: Union[Dict[str, Any], BaseModel] + result: Any = None @dataclass diff --git a/tests/flow_test.py b/tests/flow_test.py index 44ea1d15d..c416d4a7d 100644 --- a/tests/flow_test.py +++ b/tests/flow_test.py @@ -1,11 +1,18 @@ """Test Flow creation and execution basic functionality.""" import asyncio +from datetime import datetime import pytest from pydantic import BaseModel from crewai.flow.flow import Flow, and_, listen, or_, router, start +from crewai.flow.flow_events import ( + FlowFinishedEvent, + FlowStartedEvent, + MethodExecutionFinishedEvent, + MethodExecutionStartedEvent, +) def test_simple_sequential_flow(): @@ -398,3 +405,218 @@ def test_router_with_multiple_conditions(): # final_step should run after router_and assert execution_order.index("log_final_step") > execution_order.index("router_and") + + +def test_unstructured_flow_event_emission(): + """Test that the correct events are emitted during unstructured flow + execution with all fields validated.""" + + class PoemFlow(Flow): + @start() + def prepare_flower(self): + self.state["flower"] = "roses" + return "foo" + + @start() + def prepare_color(self): + self.state["color"] = "red" + return "bar" + + @listen(prepare_color) + def write_first_sentence(self): + return f"{self.state["flower"]} are {self.state["color"]}" + + @listen(write_first_sentence) + def finish_poem(self, first_sentence): + separator = self.state.get("separator", "\n") + return separator.join([first_sentence, "violets are blue"]) + + @listen(finish_poem) + def save_poem_to_database(self): + # A method without args/kwargs to ensure events are sent correctly + pass + + event_log = [] + + def handle_event(_, event): + event_log.append(event) + + flow = PoemFlow() + flow.event_emitter.connect(handle_event) + flow.kickoff(inputs={"separator": ", "}) + + assert isinstance(event_log[0], FlowStartedEvent) + assert event_log[0].flow_name == "PoemFlow" + assert event_log[0].inputs == {"separator": ", "} + assert isinstance(event_log[0].timestamp, datetime) + + # Asserting for concurrent start method executions in a for loop as you + # can't guarantee ordering in asynchronous executions + for i in range(1, 5): + event = event_log[i] + assert isinstance(event.state, dict) + assert isinstance(event.state["id"], str) + + if event.method_name == "prepare_flower": + if isinstance(event, MethodExecutionStartedEvent): + assert event.params == {} + assert event.state["separator"] == ", " + elif isinstance(event, MethodExecutionFinishedEvent): + assert event.result == "foo" + assert event.state["flower"] == "roses" + assert event.state["separator"] == ", " + else: + assert False, "Unexpected event type for prepare_flower" + elif event.method_name == "prepare_color": + if isinstance(event, MethodExecutionStartedEvent): + assert event.params == {} + assert event.state["separator"] == ", " + elif isinstance(event, MethodExecutionFinishedEvent): + assert event.result == "bar" + assert event.state["color"] == "red" + assert event.state["separator"] == ", " + else: + assert False, "Unexpected event type for prepare_color" + else: + assert False, f"Unexpected method {event.method_name} in prepare events" + + assert isinstance(event_log[5], MethodExecutionStartedEvent) + assert event_log[5].method_name == "write_first_sentence" + assert event_log[5].params == {} + assert isinstance(event_log[5].state, dict) + assert event_log[5].state["flower"] == "roses" + assert event_log[5].state["color"] == "red" + assert event_log[5].state["separator"] == ", " + + assert isinstance(event_log[6], MethodExecutionFinishedEvent) + assert event_log[6].method_name == "write_first_sentence" + assert event_log[6].result == "roses are red" + + assert isinstance(event_log[7], MethodExecutionStartedEvent) + assert event_log[7].method_name == "finish_poem" + assert event_log[7].params == {"_0": "roses are red"} + assert isinstance(event_log[7].state, dict) + assert event_log[7].state["flower"] == "roses" + assert event_log[7].state["color"] == "red" + + assert isinstance(event_log[8], MethodExecutionFinishedEvent) + assert event_log[8].method_name == "finish_poem" + assert event_log[8].result == "roses are red, violets are blue" + + assert isinstance(event_log[9], MethodExecutionStartedEvent) + assert event_log[9].method_name == "save_poem_to_database" + assert event_log[9].params == {} + assert isinstance(event_log[9].state, dict) + assert event_log[9].state["flower"] == "roses" + assert event_log[9].state["color"] == "red" + + assert isinstance(event_log[10], MethodExecutionFinishedEvent) + assert event_log[10].method_name == "save_poem_to_database" + assert event_log[10].result is None + + assert isinstance(event_log[11], FlowFinishedEvent) + assert event_log[11].flow_name == "PoemFlow" + assert event_log[11].result is None + assert isinstance(event_log[11].timestamp, datetime) + + +def test_structured_flow_event_emission(): + """Test that the correct events are emitted during structured flow + execution with all fields validated.""" + + class OnboardingState(BaseModel): + name: str = "" + sent: bool = False + + class OnboardingFlow(Flow[OnboardingState]): + @start() + def user_signs_up(self): + self.state.sent = False + + @listen(user_signs_up) + def send_welcome_message(self): + self.state.sent = True + return f"Welcome, {self.state.name}!" + + event_log = [] + + def handle_event(_, event): + event_log.append(event) + + flow = OnboardingFlow() + flow.event_emitter.connect(handle_event) + flow.kickoff(inputs={"name": "Anakin"}) + + assert isinstance(event_log[0], FlowStartedEvent) + assert event_log[0].flow_name == "OnboardingFlow" + assert event_log[0].inputs == {"name": "Anakin"} + assert isinstance(event_log[0].timestamp, datetime) + + assert isinstance(event_log[1], MethodExecutionStartedEvent) + assert event_log[1].method_name == "user_signs_up" + + assert isinstance(event_log[2], MethodExecutionFinishedEvent) + assert event_log[2].method_name == "user_signs_up" + + assert isinstance(event_log[3], MethodExecutionStartedEvent) + assert event_log[3].method_name == "send_welcome_message" + assert event_log[3].params == {} + assert getattr(event_log[3].state, "sent") == False + + assert isinstance(event_log[4], MethodExecutionFinishedEvent) + assert event_log[4].method_name == "send_welcome_message" + assert getattr(event_log[4].state, "sent") == True + assert event_log[4].result == "Welcome, Anakin!" + + assert isinstance(event_log[5], FlowFinishedEvent) + assert event_log[5].flow_name == "OnboardingFlow" + assert event_log[5].result == "Welcome, Anakin!" + assert isinstance(event_log[5].timestamp, datetime) + + +def test_stateless_flow_event_emission(): + """Test that the correct events are emitted stateless during flow execution + with all fields validated.""" + + class StatelessFlow(Flow): + @start() + def init(self): + pass + + @listen(init) + def process(self): + return "Deeds will not be less valiant because they are unpraised." + + event_log = [] + + def handle_event(_, event): + event_log.append(event) + + flow = StatelessFlow() + flow.event_emitter.connect(handle_event) + flow.kickoff() + + assert isinstance(event_log[0], FlowStartedEvent) + assert event_log[0].flow_name == "StatelessFlow" + assert event_log[0].inputs is None + assert isinstance(event_log[0].timestamp, datetime) + + assert isinstance(event_log[1], MethodExecutionStartedEvent) + assert event_log[1].method_name == "init" + + assert isinstance(event_log[2], MethodExecutionFinishedEvent) + assert event_log[2].method_name == "init" + + assert isinstance(event_log[3], MethodExecutionStartedEvent) + assert event_log[3].method_name == "process" + + assert isinstance(event_log[4], MethodExecutionFinishedEvent) + assert event_log[4].method_name == "process" + + assert isinstance(event_log[5], FlowFinishedEvent) + assert event_log[5].flow_name == "StatelessFlow" + assert ( + event_log[5].result + == "Deeds will not be less valiant because they are unpraised." + ) + assert isinstance(event_log[5].timestamp, datetime)