mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-02 07:42:40 +00:00
Enhance Flow class to support custom flow names (#3234)
- Added an optional `name` attribute to the Flow class for better identification. - Updated event emissions to utilize the new `name` attribute, ensuring accurate flow naming in events. - Added tests to verify the correct flow name is set and emitted during flow execution.
This commit is contained in:
@@ -436,6 +436,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
_routers: Set[str] = set()
|
_routers: Set[str] = set()
|
||||||
_router_paths: Dict[str, List[str]] = {}
|
_router_paths: Dict[str, List[str]] = {}
|
||||||
initial_state: Union[Type[T], T, None] = None
|
initial_state: Union[Type[T], T, None] = None
|
||||||
|
name: Optional[str] = None
|
||||||
|
|
||||||
def __class_getitem__(cls: Type["Flow"], item: Type[T]) -> Type["Flow"]:
|
def __class_getitem__(cls: Type["Flow"], item: Type[T]) -> Type["Flow"]:
|
||||||
class _FlowGeneric(cls): # type: ignore
|
class _FlowGeneric(cls): # type: ignore
|
||||||
@@ -473,7 +474,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
self,
|
self,
|
||||||
FlowCreatedEvent(
|
FlowCreatedEvent(
|
||||||
type="flow_created",
|
type="flow_created",
|
||||||
flow_name=self.__class__.__name__,
|
flow_name=self.name or self.__class__.__name__,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -769,7 +770,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
self,
|
self,
|
||||||
FlowStartedEvent(
|
FlowStartedEvent(
|
||||||
type="flow_started",
|
type="flow_started",
|
||||||
flow_name=self.__class__.__name__,
|
flow_name=self.name or self.__class__.__name__,
|
||||||
inputs=inputs,
|
inputs=inputs,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@@ -792,7 +793,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
self,
|
self,
|
||||||
FlowFinishedEvent(
|
FlowFinishedEvent(
|
||||||
type="flow_finished",
|
type="flow_finished",
|
||||||
flow_name=self.__class__.__name__,
|
flow_name=self.name or self.__class__.__name__,
|
||||||
result=final_output,
|
result=final_output,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@@ -834,7 +835,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
MethodExecutionStartedEvent(
|
MethodExecutionStartedEvent(
|
||||||
type="method_execution_started",
|
type="method_execution_started",
|
||||||
method_name=method_name,
|
method_name=method_name,
|
||||||
flow_name=self.__class__.__name__,
|
flow_name=self.name or self.__class__.__name__,
|
||||||
params=dumped_params,
|
params=dumped_params,
|
||||||
state=self._copy_state(),
|
state=self._copy_state(),
|
||||||
),
|
),
|
||||||
@@ -856,7 +857,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
MethodExecutionFinishedEvent(
|
MethodExecutionFinishedEvent(
|
||||||
type="method_execution_finished",
|
type="method_execution_finished",
|
||||||
method_name=method_name,
|
method_name=method_name,
|
||||||
flow_name=self.__class__.__name__,
|
flow_name=self.name or self.__class__.__name__,
|
||||||
state=self._copy_state(),
|
state=self._copy_state(),
|
||||||
result=result,
|
result=result,
|
||||||
),
|
),
|
||||||
@@ -869,7 +870,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
MethodExecutionFailedEvent(
|
MethodExecutionFailedEvent(
|
||||||
type="method_execution_failed",
|
type="method_execution_failed",
|
||||||
method_name=method_name,
|
method_name=method_name,
|
||||||
flow_name=self.__class__.__name__,
|
flow_name=self.name or self.__class__.__name__,
|
||||||
error=e,
|
error=e,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@@ -1076,7 +1077,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
|||||||
self,
|
self,
|
||||||
FlowPlotEvent(
|
FlowPlotEvent(
|
||||||
type="flow_plot",
|
type="flow_plot",
|
||||||
flow_name=self.__class__.__name__,
|
flow_name=self.name or self.__class__.__name__,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
plot_flow(self, filename)
|
plot_flow(self, filename)
|
||||||
|
|||||||
@@ -755,3 +755,15 @@ def test_multiple_routers_from_same_trigger():
|
|||||||
assert execution_order.index("anemia_analysis") > execution_order.index(
|
assert execution_order.index("anemia_analysis") > execution_order.index(
|
||||||
"anemia_router"
|
"anemia_router"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_flow_name():
|
||||||
|
class MyFlow(Flow):
|
||||||
|
name = "MyFlow"
|
||||||
|
|
||||||
|
@start()
|
||||||
|
def start(self):
|
||||||
|
return "Hello, world!"
|
||||||
|
|
||||||
|
flow = MyFlow()
|
||||||
|
assert flow.name == "MyFlow"
|
||||||
|
|||||||
@@ -64,7 +64,8 @@ def base_agent():
|
|||||||
llm="gpt-4o-mini",
|
llm="gpt-4o-mini",
|
||||||
goal="Just say hi",
|
goal="Just say hi",
|
||||||
backstory="You are a helpful assistant that just says hi",
|
backstory="You are a helpful assistant that just says hi",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
def base_task(base_agent):
|
def base_task(base_agent):
|
||||||
@@ -74,6 +75,7 @@ def base_task(base_agent):
|
|||||||
agent=base_agent,
|
agent=base_agent,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
event_listener = EventListener()
|
event_listener = EventListener()
|
||||||
|
|
||||||
|
|
||||||
@@ -448,6 +450,27 @@ def test_flow_emits_start_event():
|
|||||||
assert received_events[0].type == "flow_started"
|
assert received_events[0].type == "flow_started"
|
||||||
|
|
||||||
|
|
||||||
|
def test_flow_name_emitted_to_event_bus():
|
||||||
|
received_events = []
|
||||||
|
|
||||||
|
class MyFlowClass(Flow):
|
||||||
|
name = "PRODUCTION_FLOW"
|
||||||
|
|
||||||
|
@start()
|
||||||
|
def start(self):
|
||||||
|
return "Hello, world!"
|
||||||
|
|
||||||
|
@crewai_event_bus.on(FlowStartedEvent)
|
||||||
|
def handle_flow_start(source, event):
|
||||||
|
received_events.append(event)
|
||||||
|
|
||||||
|
flow = MyFlowClass()
|
||||||
|
flow.kickoff()
|
||||||
|
|
||||||
|
assert len(received_events) == 1
|
||||||
|
assert received_events[0].flow_name == "PRODUCTION_FLOW"
|
||||||
|
|
||||||
|
|
||||||
def test_flow_emits_finish_event():
|
def test_flow_emits_finish_event():
|
||||||
received_events = []
|
received_events = []
|
||||||
|
|
||||||
@@ -756,6 +779,7 @@ def test_streaming_empty_response_handling():
|
|||||||
received_chunks = []
|
received_chunks = []
|
||||||
|
|
||||||
with crewai_event_bus.scoped_handlers():
|
with crewai_event_bus.scoped_handlers():
|
||||||
|
|
||||||
@crewai_event_bus.on(LLMStreamChunkEvent)
|
@crewai_event_bus.on(LLMStreamChunkEvent)
|
||||||
def handle_stream_chunk(source, event):
|
def handle_stream_chunk(source, event):
|
||||||
received_chunks.append(event.chunk)
|
received_chunks.append(event.chunk)
|
||||||
@@ -793,6 +817,7 @@ def test_streaming_empty_response_handling():
|
|||||||
# Restore the original method
|
# Restore the original method
|
||||||
llm.call = original_call
|
llm.call = original_call
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||||
def test_stream_llm_emits_event_with_task_and_agent_info():
|
def test_stream_llm_emits_event_with_task_and_agent_info():
|
||||||
completed_event = []
|
completed_event = []
|
||||||
@@ -801,6 +826,7 @@ def test_stream_llm_emits_event_with_task_and_agent_info():
|
|||||||
stream_event = []
|
stream_event = []
|
||||||
|
|
||||||
with crewai_event_bus.scoped_handlers():
|
with crewai_event_bus.scoped_handlers():
|
||||||
|
|
||||||
@crewai_event_bus.on(LLMCallFailedEvent)
|
@crewai_event_bus.on(LLMCallFailedEvent)
|
||||||
def handle_llm_failed(source, event):
|
def handle_llm_failed(source, event):
|
||||||
failed_event.append(event)
|
failed_event.append(event)
|
||||||
@@ -827,7 +853,7 @@ def test_stream_llm_emits_event_with_task_and_agent_info():
|
|||||||
description="Just say hi",
|
description="Just say hi",
|
||||||
expected_output="hi",
|
expected_output="hi",
|
||||||
llm=LLM(model="gpt-4o-mini", stream=True),
|
llm=LLM(model="gpt-4o-mini", stream=True),
|
||||||
agent=agent
|
agent=agent,
|
||||||
)
|
)
|
||||||
|
|
||||||
crew = Crew(agents=[agent], tasks=[task])
|
crew = Crew(agents=[agent], tasks=[task])
|
||||||
@@ -855,6 +881,7 @@ def test_stream_llm_emits_event_with_task_and_agent_info():
|
|||||||
assert set(all_task_id) == {task.id}
|
assert set(all_task_id) == {task.id}
|
||||||
assert set(all_task_name) == {task.name}
|
assert set(all_task_name) == {task.name}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||||
def test_llm_emits_event_with_task_and_agent_info(base_agent, base_task):
|
def test_llm_emits_event_with_task_and_agent_info(base_agent, base_task):
|
||||||
completed_event = []
|
completed_event = []
|
||||||
@@ -863,6 +890,7 @@ def test_llm_emits_event_with_task_and_agent_info(base_agent, base_task):
|
|||||||
stream_event = []
|
stream_event = []
|
||||||
|
|
||||||
with crewai_event_bus.scoped_handlers():
|
with crewai_event_bus.scoped_handlers():
|
||||||
|
|
||||||
@crewai_event_bus.on(LLMCallFailedEvent)
|
@crewai_event_bus.on(LLMCallFailedEvent)
|
||||||
def handle_llm_failed(source, event):
|
def handle_llm_failed(source, event):
|
||||||
failed_event.append(event)
|
failed_event.append(event)
|
||||||
@@ -904,6 +932,7 @@ def test_llm_emits_event_with_task_and_agent_info(base_agent, base_task):
|
|||||||
assert set(all_task_id) == {base_task.id}
|
assert set(all_task_id) == {base_task.id}
|
||||||
assert set(all_task_name) == {base_task.name}
|
assert set(all_task_name) == {base_task.name}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||||
def test_llm_emits_event_with_lite_agent():
|
def test_llm_emits_event_with_lite_agent():
|
||||||
completed_event = []
|
completed_event = []
|
||||||
@@ -912,6 +941,7 @@ def test_llm_emits_event_with_lite_agent():
|
|||||||
stream_event = []
|
stream_event = []
|
||||||
|
|
||||||
with crewai_event_bus.scoped_handlers():
|
with crewai_event_bus.scoped_handlers():
|
||||||
|
|
||||||
@crewai_event_bus.on(LLMCallFailedEvent)
|
@crewai_event_bus.on(LLMCallFailedEvent)
|
||||||
def handle_llm_failed(source, event):
|
def handle_llm_failed(source, event):
|
||||||
failed_event.append(event)
|
failed_event.append(event)
|
||||||
@@ -936,7 +966,6 @@ def test_llm_emits_event_with_lite_agent():
|
|||||||
)
|
)
|
||||||
agent.kickoff(messages=[{"role": "user", "content": "say hi!"}])
|
agent.kickoff(messages=[{"role": "user", "content": "say hi!"}])
|
||||||
|
|
||||||
|
|
||||||
assert len(completed_event) == 2
|
assert len(completed_event) == 2
|
||||||
assert len(failed_event) == 0
|
assert len(failed_event) == 0
|
||||||
assert len(started_event) == 2
|
assert len(started_event) == 2
|
||||||
|
|||||||
Reference in New Issue
Block a user