diff --git a/src/crewai/crew.py b/src/crewai/crew.py index 9185d143d..a91e74da6 100644 --- a/src/crewai/crew.py +++ b/src/crewai/crew.py @@ -1,6 +1,7 @@ import asyncio import json import re +import threading import uuid import warnings from concurrent.futures import Future @@ -60,6 +61,7 @@ from crewai.utilities.constants import NOT_SPECIFIED, TRAINING_DATA_FILE from crewai.utilities.evaluators.crew_evaluator_handler import CrewEvaluator from crewai.utilities.evaluators.task_evaluator import TaskEvaluator from crewai.events.types.crew_events import ( + CrewKickoffCancelledEvent, CrewKickoffCompletedEvent, CrewKickoffFailedEvent, CrewKickoffStartedEvent, @@ -137,6 +139,7 @@ class Crew(FlowTrackable, BaseModel): _task_output_handler: TaskOutputStorageHandler = PrivateAttr( default_factory=TaskOutputStorageHandler ) + _cancellation_event: threading.Event = PrivateAttr(default_factory=threading.Event) name: Optional[str] = Field(default="crew") cache: bool = Field(default=True) @@ -613,6 +616,8 @@ class Crew(FlowTrackable, BaseModel): self, inputs: Optional[Dict[str, Any]] = None, ) -> CrewOutput: + self._reset_cancellation() + ctx = baggage.set_baggage( "crew_context", CrewContext(id=str(self.id), key=self.key) ) @@ -826,6 +831,18 @@ class Crew(FlowTrackable, BaseModel): last_sync_output: Optional[TaskOutput] = None for task_index, task in enumerate(tasks): + if self.is_cancelled(): + self._logger.log("info", f"Crew execution cancelled after {task_index} tasks", color="yellow") + crewai_event_bus.emit( + self, + CrewKickoffCancelledEvent( + crew_name=self.name, + completed_tasks=task_index, + total_tasks=len(tasks), + ), + ) + return self._create_crew_output(task_outputs) + if start_index is not None and task_index < start_index: if task.output: if task.async_execution: @@ -1093,6 +1110,10 @@ class Crew(FlowTrackable, BaseModel): ) -> List[TaskOutput]: task_outputs: List[TaskOutput] = [] for future_task, future, task_index in futures: + if self.is_cancelled(): + future.cancel() + continue + task_output = future.result() task_outputs.append(task_output) self._process_task_result(future_task, task_output) @@ -1525,3 +1546,16 @@ class Crew(FlowTrackable, BaseModel): and able_to_inject ): self.tasks[0].allow_crewai_trigger_context = True + + def cancel(self) -> None: + """Cancel the crew execution. This will stop the crew after the current task completes.""" + self._cancellation_event.set() + self._logger.log("info", "Crew cancellation requested", color="yellow") + + def is_cancelled(self) -> bool: + """Check if the crew execution has been cancelled.""" + return self._cancellation_event.is_set() + + def _reset_cancellation(self) -> None: + """Reset the cancellation state for reuse of the crew instance.""" + self._cancellation_event.clear() diff --git a/src/crewai/events/types/crew_events.py b/src/crewai/events/types/crew_events.py index 02d0f983d..ddd1a3056 100644 --- a/src/crewai/events/types/crew_events.py +++ b/src/crewai/events/types/crew_events.py @@ -110,3 +110,12 @@ class CrewTestResultEvent(CrewBaseEvent): execution_duration: float model: str type: str = "crew_test_result" + + +class CrewKickoffCancelledEvent(CrewBaseEvent): + """Event emitted when a crew execution is cancelled""" + + reason: str = "External cancellation requested" + completed_tasks: int = 0 + total_tasks: int = 0 + type: str = "crew_kickoff_cancelled" diff --git a/tests/test_crew_cancellation.py b/tests/test_crew_cancellation.py new file mode 100644 index 000000000..c612edb29 --- /dev/null +++ b/tests/test_crew_cancellation.py @@ -0,0 +1,224 @@ +import threading +import time +from unittest.mock import Mock, patch +import pytest +from crewai import Agent, Crew, Task +from crewai.process import Process +from crewai.events.types.crew_events import CrewKickoffCancelledEvent +from crewai.tasks.task_output import TaskOutput + + +@pytest.fixture +def mock_agent(): + return Agent( + role="Test Agent", + goal="Test goal", + backstory="Test backstory", + verbose=False, + ) + + +@pytest.fixture +def slow_task(): + """A task that takes some time to complete for testing cancellation""" + def slow_execution(*args, **kwargs): + time.sleep(0.5) + return TaskOutput( + description="Task completed", + raw="Task completed", + agent="Test Agent" + ) + + task = Task( + description="A slow task for testing", + expected_output="Task output", + ) + task.execute_sync = Mock(side_effect=slow_execution) + return task + + +def test_crew_cancellation_basic(mock_agent, slow_task): + """Test basic cancellation functionality""" + crew = Crew(agents=[mock_agent], tasks=[slow_task], verbose=False) + + assert not crew.is_cancelled() + + crew.cancel() + assert crew.is_cancelled() + + +def test_crew_cancellation_during_execution(mock_agent): + """Test cancellation during crew execution""" + tasks = [] + for i in range(3): + task = Task( + description=f"Task {i}", + expected_output="Output", + ) + task.execute_sync = Mock(return_value=TaskOutput( + description=f"Task {i} completed", + raw=f"Output {i}", + agent="Test Agent" + )) + tasks.append(task) + + crew = Crew(agents=[mock_agent], tasks=tasks, verbose=False) + + result = None + exception = None + + def run_crew(): + nonlocal result, exception + try: + result = crew.kickoff() + except Exception as e: + exception = e + + thread = threading.Thread(target=run_crew) + thread.start() + + time.sleep(0.1) + crew.cancel() + + thread.join(timeout=2) + + assert crew.is_cancelled() + assert result is not None + assert exception is None + + +def test_crew_cancellation_events(mock_agent, slow_task): + """Test that cancellation events are emitted properly""" + crew = Crew(agents=[mock_agent], tasks=[slow_task], verbose=False) + + with patch('crewai.events.event_bus.crewai_event_bus.emit') as mock_emit: + crew.cancel() + result = crew.kickoff() + + cancellation_events = [ + call for call in mock_emit.call_args_list + if len(call[0]) > 1 and isinstance(call[0][1], CrewKickoffCancelledEvent) + ] + assert len(cancellation_events) > 0 + + +def test_crew_reuse_after_cancellation(mock_agent): + """Test that crew can be reused after cancellation""" + task = Task( + description="Test task", + expected_output="Test output", + ) + task.execute_sync = Mock(return_value=TaskOutput( + description="Task completed", + raw="Task completed", + agent="Test Agent" + )) + + crew = Crew(agents=[mock_agent], tasks=[task], verbose=False) + + crew.cancel() + result1 = crew.kickoff() + + result2 = crew.kickoff() + assert not crew.is_cancelled() + + +def test_crew_cancellation_hierarchical_process(mock_agent): + """Test cancellation works with hierarchical process""" + task = Task( + description="Test task", + expected_output="Test output", + ) + task.execute_sync = Mock(return_value=TaskOutput( + description="Task completed", + raw="Task completed", + agent="Test Agent" + )) + + crew = Crew( + agents=[mock_agent], + tasks=[task], + process=Process.hierarchical, + manager_llm="gpt-3.5-turbo", + verbose=False + ) + + crew.cancel() + result = crew.kickoff() + assert crew.is_cancelled() + + +def test_crew_cancellation_thread_safety(): + """Test thread safety of cancellation mechanism""" + agent = Agent(role="Test", goal="Test", backstory="Test", verbose=False) + task = Task(description="Test", expected_output="Test") + task.execute_sync = Mock(return_value=TaskOutput( + description="Task completed", + raw="Task completed", + agent="Test Agent" + )) + crew = Crew(agents=[agent], tasks=[task], verbose=False) + + def toggle_cancellation(): + for _ in range(100): + crew.cancel() + crew._reset_cancellation() + + threads = [threading.Thread(target=toggle_cancellation) for _ in range(5)] + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + assert isinstance(crew.is_cancelled(), bool) + + +def test_crew_cancellation_with_async_tasks(mock_agent): + """Test cancellation with async tasks""" + task = Task( + description="Async test task", + expected_output="Test output", + async_execution=True + ) + + def mock_execute_async(*args, **kwargs): + from concurrent.futures import Future + future = Future() + future.set_result(TaskOutput( + description="Async task completed", + raw="Async task completed", + agent="Test Agent" + )) + return future + + task.execute_async = Mock(side_effect=mock_execute_async) + + crew = Crew(agents=[mock_agent], tasks=[task], verbose=False) + + crew.cancel() + result = crew.kickoff() + assert crew.is_cancelled() + + +def test_crew_cancellation_partial_results(mock_agent): + """Test that partial results are returned when cancelled""" + tasks = [] + for i in range(3): + task = Task( + description=f"Task {i}", + expected_output="Output", + ) + task.execute_sync = Mock(return_value=TaskOutput( + description=f"Task {i} completed", + raw=f"Output {i}", + agent="Test Agent" + )) + tasks.append(task) + + crew = Crew(agents=[mock_agent], tasks=tasks, verbose=False) + + crew.cancel() + result = crew.kickoff() + + assert result is not None + assert hasattr(result, 'tasks_output')