From cb522cf5005f856b21c6976e8d94709ae4f9c3f3 Mon Sep 17 00:00:00 2001 From: Lorenze Jay <63378463+lorenzejay@users.noreply.github.com> Date: Tue, 29 Jul 2025 15:41:30 -0700 Subject: [PATCH] Enhance Flow class to support custom flow names (#3234) - Added an optional `name` attribute to the Flow class for better identification. - Updated event emissions to utilize the new `name` attribute, ensuring accurate flow naming in events. - Added tests to verify the correct flow name is set and emitted during flow execution. --- src/crewai/flow/flow.py | 15 ++++++++------- tests/flow_test.py | 12 ++++++++++++ tests/utilities/test_events.py | 35 +++++++++++++++++++++++++++++++--- 3 files changed, 52 insertions(+), 10 deletions(-) diff --git a/src/crewai/flow/flow.py b/src/crewai/flow/flow.py index 99ae82c96..9bd9e3b6a 100644 --- a/src/crewai/flow/flow.py +++ b/src/crewai/flow/flow.py @@ -436,6 +436,7 @@ class Flow(Generic[T], metaclass=FlowMeta): _routers: Set[str] = set() _router_paths: Dict[str, List[str]] = {} initial_state: Union[Type[T], T, None] = None + name: Optional[str] = None def __class_getitem__(cls: Type["Flow"], item: Type[T]) -> Type["Flow"]: class _FlowGeneric(cls): # type: ignore @@ -473,7 +474,7 @@ class Flow(Generic[T], metaclass=FlowMeta): self, FlowCreatedEvent( type="flow_created", - flow_name=self.__class__.__name__, + flow_name=self.name or self.__class__.__name__, ), ) @@ -769,7 +770,7 @@ class Flow(Generic[T], metaclass=FlowMeta): self, FlowStartedEvent( type="flow_started", - flow_name=self.__class__.__name__, + flow_name=self.name or self.__class__.__name__, inputs=inputs, ), ) @@ -792,7 +793,7 @@ class Flow(Generic[T], metaclass=FlowMeta): self, FlowFinishedEvent( type="flow_finished", - flow_name=self.__class__.__name__, + flow_name=self.name or self.__class__.__name__, result=final_output, ), ) @@ -834,7 +835,7 @@ class Flow(Generic[T], metaclass=FlowMeta): MethodExecutionStartedEvent( type="method_execution_started", method_name=method_name, - flow_name=self.__class__.__name__, + flow_name=self.name or self.__class__.__name__, params=dumped_params, state=self._copy_state(), ), @@ -856,7 +857,7 @@ class Flow(Generic[T], metaclass=FlowMeta): MethodExecutionFinishedEvent( type="method_execution_finished", method_name=method_name, - flow_name=self.__class__.__name__, + flow_name=self.name or self.__class__.__name__, state=self._copy_state(), result=result, ), @@ -869,7 +870,7 @@ class Flow(Generic[T], metaclass=FlowMeta): MethodExecutionFailedEvent( type="method_execution_failed", method_name=method_name, - flow_name=self.__class__.__name__, + flow_name=self.name or self.__class__.__name__, error=e, ), ) @@ -1076,7 +1077,7 @@ class Flow(Generic[T], metaclass=FlowMeta): self, FlowPlotEvent( type="flow_plot", - flow_name=self.__class__.__name__, + flow_name=self.name or self.__class__.__name__, ), ) plot_flow(self, filename) diff --git a/tests/flow_test.py b/tests/flow_test.py index c2640fffb..32b93cd05 100644 --- a/tests/flow_test.py +++ b/tests/flow_test.py @@ -755,3 +755,15 @@ def test_multiple_routers_from_same_trigger(): assert execution_order.index("anemia_analysis") > execution_order.index( "anemia_router" ) + + +def test_flow_name(): + class MyFlow(Flow): + name = "MyFlow" + + @start() + def start(self): + return "Hello, world!" + + flow = MyFlow() + assert flow.name == "MyFlow" diff --git a/tests/utilities/test_events.py b/tests/utilities/test_events.py index 6962291c8..35d1fd887 100644 --- a/tests/utilities/test_events.py +++ b/tests/utilities/test_events.py @@ -64,7 +64,8 @@ def base_agent(): llm="gpt-4o-mini", goal="Just say hi", backstory="You are a helpful assistant that just says hi", -) + ) + @pytest.fixture(scope="module") def base_task(base_agent): @@ -74,6 +75,7 @@ def base_task(base_agent): agent=base_agent, ) + event_listener = EventListener() @@ -448,6 +450,27 @@ def test_flow_emits_start_event(): assert received_events[0].type == "flow_started" +def test_flow_name_emitted_to_event_bus(): + received_events = [] + + class MyFlowClass(Flow): + name = "PRODUCTION_FLOW" + + @start() + def start(self): + return "Hello, world!" + + @crewai_event_bus.on(FlowStartedEvent) + def handle_flow_start(source, event): + received_events.append(event) + + flow = MyFlowClass() + flow.kickoff() + + assert len(received_events) == 1 + assert received_events[0].flow_name == "PRODUCTION_FLOW" + + def test_flow_emits_finish_event(): received_events = [] @@ -756,6 +779,7 @@ def test_streaming_empty_response_handling(): received_chunks = [] with crewai_event_bus.scoped_handlers(): + @crewai_event_bus.on(LLMStreamChunkEvent) def handle_stream_chunk(source, event): received_chunks.append(event.chunk) @@ -793,6 +817,7 @@ def test_streaming_empty_response_handling(): # Restore the original method llm.call = original_call + @pytest.mark.vcr(filter_headers=["authorization"]) def test_stream_llm_emits_event_with_task_and_agent_info(): completed_event = [] @@ -801,6 +826,7 @@ def test_stream_llm_emits_event_with_task_and_agent_info(): stream_event = [] with crewai_event_bus.scoped_handlers(): + @crewai_event_bus.on(LLMCallFailedEvent) def handle_llm_failed(source, event): failed_event.append(event) @@ -827,7 +853,7 @@ def test_stream_llm_emits_event_with_task_and_agent_info(): description="Just say hi", expected_output="hi", llm=LLM(model="gpt-4o-mini", stream=True), - agent=agent + agent=agent, ) crew = Crew(agents=[agent], tasks=[task]) @@ -855,6 +881,7 @@ def test_stream_llm_emits_event_with_task_and_agent_info(): assert set(all_task_id) == {task.id} assert set(all_task_name) == {task.name} + @pytest.mark.vcr(filter_headers=["authorization"]) def test_llm_emits_event_with_task_and_agent_info(base_agent, base_task): completed_event = [] @@ -863,6 +890,7 @@ def test_llm_emits_event_with_task_and_agent_info(base_agent, base_task): stream_event = [] with crewai_event_bus.scoped_handlers(): + @crewai_event_bus.on(LLMCallFailedEvent) def handle_llm_failed(source, event): failed_event.append(event) @@ -904,6 +932,7 @@ def test_llm_emits_event_with_task_and_agent_info(base_agent, base_task): assert set(all_task_id) == {base_task.id} assert set(all_task_name) == {base_task.name} + @pytest.mark.vcr(filter_headers=["authorization"]) def test_llm_emits_event_with_lite_agent(): completed_event = [] @@ -912,6 +941,7 @@ def test_llm_emits_event_with_lite_agent(): stream_event = [] with crewai_event_bus.scoped_handlers(): + @crewai_event_bus.on(LLMCallFailedEvent) def handle_llm_failed(source, event): failed_event.append(event) @@ -936,7 +966,6 @@ def test_llm_emits_event_with_lite_agent(): ) agent.kickoff(messages=[{"role": "user", "content": "say hi!"}]) - assert len(completed_event) == 2 assert len(failed_event) == 0 assert len(started_event) == 2