diff --git a/lib/crewai/src/crewai/types/streaming.py b/lib/crewai/src/crewai/types/streaming.py index a1f6e4ef7..5c513ab7b 100644 --- a/lib/crewai/src/crewai/types/streaming.py +++ b/lib/crewai/src/crewai/types/streaming.py @@ -2,8 +2,10 @@ from __future__ import annotations +import asyncio from collections.abc import AsyncIterator, Iterator from enum import Enum +import threading from typing import TYPE_CHECKING, Any, Generic, TypeVar from pydantic import BaseModel, Field @@ -76,14 +78,23 @@ class StreamingOutputBase(Generic[T]): Provides iteration over stream chunks and access to final result via the .result property after streaming completes. + + Supports graceful cancellation via ``aclose()`` (async) and ``cancel()`` + (sync). When cancelled, in-flight background tasks are aborted and + resources are released promptly. """ def __init__(self) -> None: """Initialize streaming output base.""" self._result: T | None = None self._completed: bool = False + self._cancelled: bool = False self._chunks: list[StreamChunk] = [] self._error: Exception | None = None + self._cancel_event: asyncio.Event | None = None + self._cancel_thread_event: threading.Event | None = None + self._background_task: asyncio.Task[Any] | None = None + self._background_thread: threading.Thread | None = None @property def result(self) -> T: @@ -112,6 +123,11 @@ class StreamingOutputBase(Generic[T]): """Check if streaming has completed.""" return self._completed + @property + def is_cancelled(self) -> bool: + """Check if streaming was cancelled.""" + return self._cancelled + @property def chunks(self) -> list[StreamChunk]: """Get all collected chunks so far.""" @@ -129,6 +145,76 @@ class StreamingOutputBase(Generic[T]): if chunk.chunk_type == StreamChunkType.TEXT ) + async def aclose(self) -> None: + """Cancel streaming and clean up resources. + + Signals cancellation to the background task, waits briefly for it + to finish, and marks the stream as completed and cancelled. + Safe to call multiple times or on an already-completed stream. + + Example: + ```python + streaming = await crew.akickoff(inputs=inputs) + try: + async for chunk in streaming: + ... + finally: + await streaming.aclose() + ``` + """ + if self._completed: + return + + self._cancelled = True + + if self._cancel_event is not None: + self._cancel_event.set() + + if self._cancel_thread_event is not None: + self._cancel_thread_event.set() + + if self._background_task is not None and not self._background_task.done(): + self._background_task.cancel() + try: + await self._background_task + except (asyncio.CancelledError, Exception): # noqa: S110 + pass + + self._completed = True + + def cancel(self) -> None: + """Synchronously cancel streaming and clean up resources. + + Signals cancellation to the background thread/task and marks the + stream as completed and cancelled. For async contexts prefer + ``aclose()`` which can ``await`` background cleanup. + + Example: + ```python + streaming = crew.kickoff(inputs=inputs) + try: + for chunk in streaming: + ... + finally: + streaming.cancel() + ``` + """ + if self._completed: + return + + self._cancelled = True + + if self._cancel_event is not None: + self._cancel_event.set() + + if self._cancel_thread_event is not None: + self._cancel_thread_event.set() + + if self._background_task is not None and not self._background_task.done(): + self._background_task.cancel() + + self._completed = True + class CrewStreamingOutput(StreamingOutputBase["CrewOutput"]): """Streaming output wrapper for crew execution. diff --git a/lib/crewai/src/crewai/utilities/streaming.py b/lib/crewai/src/crewai/utilities/streaming.py index dd0992684..e0ea5b1f0 100644 --- a/lib/crewai/src/crewai/utilities/streaming.py +++ b/lib/crewai/src/crewai/utilities/streaming.py @@ -243,20 +243,37 @@ def create_chunk_generator( Yields: StreamChunk objects as they arrive. """ + cancel_event = threading.Event() ctx = contextvars.copy_context() thread = threading.Thread(target=ctx.run, args=(run_func,), daemon=True) thread.start() + # Wire cancellation to the streaming output once the holder is populated + def _wire_cancel() -> None: + if output_holder: + output_holder[0]._cancel_thread_event = cancel_event + output_holder[0]._background_thread = thread + try: while True: - item = state.sync_queue.get() + # Poll the queue with a timeout so we can check cancellation + while True: + _wire_cancel() + if cancel_event.is_set(): + return + try: + item = state.sync_queue.get(timeout=0.1) + break + except queue.Empty: + continue if item is None: break if isinstance(item, Exception): raise item yield item finally: - thread.join() + if not cancel_event.is_set(): + thread.join() if output_holder: _finalize_streaming(state, output_holder[0]) else: @@ -283,18 +300,49 @@ async def create_async_chunk_generator( "Async queue not initialized. Use create_streaming_state(use_async=True)." ) + cancel_event = asyncio.Event() task = asyncio.create_task(run_coro()) + # Wire cancellation to the streaming output once the holder is populated + def _wire_cancel() -> None: + if output_holder: + output_holder[0]._cancel_event = cancel_event + output_holder[0]._background_task = task + try: while True: - item = await state.async_queue.get() + _wire_cancel() + # Use asyncio.wait to race between the queue and cancellation + get_task = asyncio.ensure_future(state.async_queue.get()) + cancel_wait = asyncio.ensure_future(cancel_event.wait()) + done, pending = await asyncio.wait( + {get_task, cancel_wait}, return_when=asyncio.FIRST_COMPLETED + ) + for p in pending: + p.cancel() + try: + await p + except (asyncio.CancelledError, Exception): # noqa: S110 + pass + if cancel_wait in done: + # Cancellation was requested + return + item = get_task.result() if item is None: break if isinstance(item, Exception): raise item yield item finally: - await task + if not cancel_event.is_set(): + await task + else: + if not task.done(): + task.cancel() + try: + await task + except (asyncio.CancelledError, Exception): # noqa: S110 + pass if output_holder: _finalize_streaming(state, output_holder[0]) else: diff --git a/lib/crewai/tests/test_streaming.py b/lib/crewai/tests/test_streaming.py index 8eb63694e..727947de5 100644 --- a/lib/crewai/tests/test_streaming.py +++ b/lib/crewai/tests/test_streaming.py @@ -709,6 +709,222 @@ class TestStreamingEdgeCases: assert streaming.is_completed +class TestStreamingCancellation: + """Tests for graceful cancellation of streaming via aclose() and cancel().""" + + @pytest.mark.asyncio + async def test_aclose_stops_async_iteration(self) -> None: + """Test that aclose() stops async iteration promptly.""" + chunks_yielded: list[str] = [] + cancel_event = asyncio.Event() + + async def slow_gen() -> AsyncIterator[StreamChunk]: + for i in range(100): + if cancel_event.is_set(): + return + yield StreamChunk(content=f"chunk-{i}") + await asyncio.sleep(0.05) + + streaming = CrewStreamingOutput(async_iterator=slow_gen()) + streaming._cancel_event = cancel_event + + async for chunk in streaming: + chunks_yielded.append(chunk.content) + if len(chunks_yielded) >= 3: + await streaming.aclose() + break + + assert streaming.is_cancelled + assert streaming.is_completed + assert len(chunks_yielded) >= 3 + assert len(chunks_yielded) < 100 + + @pytest.mark.asyncio + async def test_aclose_on_completed_stream_is_noop(self) -> None: + """Test that aclose() on an already-completed stream does nothing.""" + async def simple_gen() -> AsyncIterator[StreamChunk]: + yield StreamChunk(content="done") + + streaming = CrewStreamingOutput(async_iterator=simple_gen()) + + async for _ in streaming: + pass + + assert streaming.is_completed + assert not streaming.is_cancelled + + # aclose on completed stream should not change cancelled state + await streaming.aclose() + assert streaming.is_completed + assert not streaming.is_cancelled + + @pytest.mark.asyncio + async def test_aclose_cancels_background_task(self) -> None: + """Test that aclose() cancels the background asyncio task.""" + bg_task_started = asyncio.Event() + + async def long_running_task() -> None: + bg_task_started.set() + await asyncio.sleep(100) + + bg_task = asyncio.create_task(long_running_task()) + await bg_task_started.wait() + + streaming = CrewStreamingOutput() + streaming._background_task = bg_task + + assert not bg_task.done() + + await streaming.aclose() + + assert streaming.is_cancelled + assert bg_task.done() + assert bg_task.cancelled() + + def test_cancel_stops_sync_iteration(self) -> None: + """Test that cancel() marks streaming as cancelled.""" + def slow_gen() -> Generator[StreamChunk, None, None]: + for i in range(100): + yield StreamChunk(content=f"chunk-{i}") + + streaming = CrewStreamingOutput(sync_iterator=slow_gen()) + + chunks_collected: list[str] = [] + for chunk in streaming: + chunks_collected.append(chunk.content) + if len(chunks_collected) >= 3: + streaming.cancel() + break + + assert streaming.is_cancelled + assert streaming.is_completed + assert len(chunks_collected) >= 3 + + def test_cancel_on_completed_stream_is_noop(self) -> None: + """Test that cancel() on an already-completed stream does nothing.""" + def simple_gen() -> Generator[StreamChunk, None, None]: + yield StreamChunk(content="done") + + streaming = CrewStreamingOutput(sync_iterator=simple_gen()) + list(streaming) + + assert streaming.is_completed + assert not streaming.is_cancelled + + streaming.cancel() + assert streaming.is_completed + assert not streaming.is_cancelled + + @pytest.mark.asyncio + async def test_is_cancelled_property_reflects_state(self) -> None: + """Test that is_cancelled starts False and becomes True after aclose().""" + async def simple_gen() -> AsyncIterator[StreamChunk]: + yield StreamChunk(content="test") + + streaming = CrewStreamingOutput(async_iterator=simple_gen()) + assert not streaming.is_cancelled + + await streaming.aclose() + assert streaming.is_cancelled + + @pytest.mark.asyncio + async def test_aclose_with_cancel_event(self) -> None: + """Test that aclose() sets the cancel event.""" + cancel_event = asyncio.Event() + streaming = CrewStreamingOutput() + streaming._cancel_event = cancel_event + + assert not cancel_event.is_set() + await streaming.aclose() + assert cancel_event.is_set() + assert streaming.is_cancelled + + def test_cancel_with_thread_event(self) -> None: + """Test that cancel() sets the thread cancel event.""" + import threading + + cancel_event = threading.Event() + streaming = CrewStreamingOutput() + streaming._cancel_thread_event = cancel_event + + assert not cancel_event.is_set() + streaming.cancel() + assert cancel_event.is_set() + assert streaming.is_cancelled + + @pytest.mark.asyncio + async def test_flow_streaming_aclose(self) -> None: + """Test that FlowStreamingOutput also supports aclose().""" + async def simple_gen() -> AsyncIterator[StreamChunk]: + yield StreamChunk(content="flow-chunk") + await asyncio.sleep(100) # Would block forever without cancel + + streaming = FlowStreamingOutput(async_iterator=simple_gen()) + cancel_event = asyncio.Event() + streaming._cancel_event = cancel_event + + chunks: list[str] = [] + async for chunk in streaming: + chunks.append(chunk.content) + await streaming.aclose() + break + + assert streaming.is_cancelled + assert streaming.is_completed + assert len(chunks) == 1 + assert chunks[0] == "flow-chunk" + + def test_flow_streaming_cancel(self) -> None: + """Test that FlowStreamingOutput also supports cancel().""" + def simple_gen() -> Generator[StreamChunk, None, None]: + yield StreamChunk(content="flow-chunk") + + streaming = FlowStreamingOutput(sync_iterator=simple_gen()) + assert not streaming.is_cancelled + + # Consume + list(streaming) + assert streaming.is_completed + + # Cancel on completed does nothing + streaming.cancel() + assert not streaming.is_cancelled + + # Test cancelling before completion + streaming2 = FlowStreamingOutput(sync_iterator=simple_gen()) + streaming2.cancel() + assert streaming2.is_cancelled + assert streaming2.is_completed + + @pytest.mark.asyncio + async def test_multiple_aclose_calls_are_safe(self) -> None: + """Test that calling aclose() multiple times is safe.""" + async def simple_gen() -> AsyncIterator[StreamChunk]: + yield StreamChunk(content="test") + + streaming = CrewStreamingOutput(async_iterator=simple_gen()) + + await streaming.aclose() + assert streaming.is_cancelled + + # Second call should be a no-op + await streaming.aclose() + assert streaming.is_cancelled + assert streaming.is_completed + + def test_multiple_cancel_calls_are_safe(self) -> None: + """Test that calling cancel() multiple times is safe.""" + streaming = CrewStreamingOutput() + + streaming.cancel() + assert streaming.is_cancelled + + # Second call should be a no-op + streaming.cancel() + assert streaming.is_cancelled + assert streaming.is_completed + + class TestStreamingImports: """Tests for correct imports of streaming types."""