feat: add aclose()/close() and async context manager to streaming outputs

This commit is contained in:
Greyson LaLonde
2026-04-08 23:32:37 +08:00
committed by GitHub
parent 98e0d1054f
commit 0e8ed75947
12 changed files with 464 additions and 121 deletions

View File

@@ -134,6 +134,7 @@ from crewai.utilities.rpm_controller import RPMController
from crewai.utilities.streaming import (
create_async_chunk_generator,
create_chunk_generator,
register_cleanup,
signal_end,
signal_error,
)
@@ -882,6 +883,7 @@ class Crew(FlowTrackable, BaseModel):
ctx.state, run_crew, ctx.output_holder
)
)
register_cleanup(streaming_output, ctx.state)
ctx.output_holder.append(streaming_output)
return streaming_output
@@ -1007,6 +1009,7 @@ class Crew(FlowTrackable, BaseModel):
ctx.state, run_crew, ctx.output_holder
)
)
register_cleanup(streaming_output, ctx.state)
ctx.output_holder.append(streaming_output)
return streaming_output
@@ -1078,6 +1081,7 @@ class Crew(FlowTrackable, BaseModel):
ctx.state, run_crew, ctx.output_holder
)
)
register_cleanup(streaming_output, ctx.state)
ctx.output_holder.append(streaming_output)
return streaming_output

View File

@@ -431,6 +431,7 @@ async def run_for_each_async(
from crewai.types.usage_metrics import UsageMetrics
from crewai.utilities.streaming import (
create_async_chunk_generator,
register_cleanup,
signal_end,
signal_error,
)
@@ -480,6 +481,7 @@ async def run_for_each_async(
streaming_output._set_results(result)
streaming_output._set_result = set_results_wrapper # type: ignore[method-assign]
register_cleanup(streaming_output, ctx.state)
ctx.output_holder.append(streaming_output)
return streaming_output

View File

@@ -132,6 +132,7 @@ from crewai.utilities.streaming import (
create_async_chunk_generator,
create_chunk_generator,
create_streaming_state,
register_cleanup,
signal_end,
signal_error,
)
@@ -1962,6 +1963,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
streaming_output = FlowStreamingOutput(
sync_iterator=create_chunk_generator(state, run_flow, output_holder)
)
register_cleanup(streaming_output, state)
output_holder.append(streaming_output)
return streaming_output
@@ -2035,6 +2037,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
state, run_flow, output_holder
)
)
register_cleanup(streaming_output, state)
output_holder.append(streaming_output)
return streaming_output

View File

@@ -2,11 +2,12 @@
from __future__ import annotations
from collections.abc import AsyncIterator, Iterator
from collections.abc import AsyncIterator, Callable, Iterator
from enum import Enum
from typing import TYPE_CHECKING, Any, Generic, TypeVar
from pydantic import BaseModel, Field
from typing_extensions import Self
if TYPE_CHECKING:
@@ -78,12 +79,21 @@ class StreamingOutputBase(Generic[T]):
via the .result property after streaming completes.
"""
def __init__(self) -> None:
def __init__(
self,
sync_iterator: Iterator[StreamChunk] | None = None,
async_iterator: AsyncIterator[StreamChunk] | None = None,
) -> None:
"""Initialize streaming output base."""
self._result: T | None = None
self._completed: bool = False
self._chunks: list[StreamChunk] = []
self._error: Exception | None = None
self._cancelled: bool = False
self._exhausted: bool = False
self._on_cleanup: Callable[[], None] | None = None
self._sync_iterator = sync_iterator
self._async_iterator = async_iterator
@property
def result(self) -> T:
@@ -112,6 +122,11 @@ class StreamingOutputBase(Generic[T]):
"""Check if streaming has completed."""
return self._completed
@property
def is_cancelled(self) -> bool:
"""Check if streaming was cancelled."""
return self._cancelled
@property
def chunks(self) -> list[StreamChunk]:
"""Get all collected chunks so far."""
@@ -129,6 +144,98 @@ class StreamingOutputBase(Generic[T]):
if chunk.chunk_type == StreamChunkType.TEXT
)
async def __aenter__(self) -> Self:
"""Enter async context manager."""
return self
async def __aexit__(self, *exc_info: Any) -> None:
"""Exit async context manager, cancelling if still running."""
await self.aclose()
async def aclose(self) -> None:
"""Cancel streaming and clean up resources.
Cancels any in-flight tasks and closes the underlying async iterator.
Safe to call multiple times. No-op if already cancelled, fully consumed, or errored.
"""
if self._cancelled or self._exhausted or self._error is not None:
return
self._cancelled = True
self._completed = True
if self._async_iterator is not None and hasattr(self._async_iterator, "aclose"):
await self._async_iterator.aclose()
if self._on_cleanup is not None:
self._on_cleanup()
self._on_cleanup = None
def close(self) -> None:
"""Cancel streaming and clean up resources (sync).
Closes the underlying sync iterator. Safe to call multiple times.
No-op if already cancelled, fully consumed, or errored.
"""
if self._cancelled or self._exhausted or self._error is not None:
return
self._cancelled = True
self._completed = True
if self._sync_iterator is not None and hasattr(self._sync_iterator, "close"):
self._sync_iterator.close()
if self._on_cleanup is not None:
self._on_cleanup()
self._on_cleanup = None
def __iter__(self) -> Iterator[StreamChunk]:
"""Iterate over stream chunks synchronously.
Yields:
StreamChunk objects as they arrive.
Raises:
RuntimeError: If sync iterator not available.
"""
if self._sync_iterator is None:
raise RuntimeError("Sync iterator not available")
try:
for chunk in self._sync_iterator:
self._chunks.append(chunk)
yield chunk
self._exhausted = True
except Exception as e:
self._error = e
raise
finally:
self._completed = True
def __aiter__(self) -> AsyncIterator[StreamChunk]:
"""Return async iterator for stream chunks.
Returns:
Async iterator for StreamChunk objects.
"""
return self._async_iterate()
async def _async_iterate(self) -> AsyncIterator[StreamChunk]:
"""Iterate over stream chunks asynchronously.
Yields:
StreamChunk objects as they arrive.
Raises:
RuntimeError: If async iterator not available.
"""
if self._async_iterator is None:
raise RuntimeError("Async iterator not available")
try:
async for chunk in self._async_iterator:
self._chunks.append(chunk)
yield chunk
self._exhausted = True
except Exception as e:
self._error = e
raise
finally:
self._completed = True
class CrewStreamingOutput(StreamingOutputBase["CrewOutput"]):
"""Streaming output wrapper for crew execution.
@@ -167,9 +274,7 @@ class CrewStreamingOutput(StreamingOutputBase["CrewOutput"]):
sync_iterator: Synchronous iterator for chunks.
async_iterator: Asynchronous iterator for chunks.
"""
super().__init__()
self._sync_iterator = sync_iterator
self._async_iterator = async_iterator
super().__init__(sync_iterator=sync_iterator, async_iterator=async_iterator)
self._results: list[CrewOutput] | None = None
@property
@@ -204,56 +309,6 @@ class CrewStreamingOutput(StreamingOutputBase["CrewOutput"]):
self._results = results
self._completed = True
def __iter__(self) -> Iterator[StreamChunk]:
"""Iterate over stream chunks synchronously.
Yields:
StreamChunk objects as they arrive.
Raises:
RuntimeError: If sync iterator not available.
"""
if self._sync_iterator is None:
raise RuntimeError("Sync iterator not available")
try:
for chunk in self._sync_iterator:
self._chunks.append(chunk)
yield chunk
except Exception as e:
self._error = e
raise
finally:
self._completed = True
def __aiter__(self) -> AsyncIterator[StreamChunk]:
"""Return async iterator for stream chunks.
Returns:
Async iterator for StreamChunk objects.
"""
return self._async_iterate()
async def _async_iterate(self) -> AsyncIterator[StreamChunk]:
"""Iterate over stream chunks asynchronously.
Yields:
StreamChunk objects as they arrive.
Raises:
RuntimeError: If async iterator not available.
"""
if self._async_iterator is None:
raise RuntimeError("Async iterator not available")
try:
async for chunk in self._async_iterator:
self._chunks.append(chunk)
yield chunk
except Exception as e:
self._error = e
raise
finally:
self._completed = True
def _set_result(self, result: CrewOutput) -> None:
"""Set the final result after streaming completes.
@@ -286,71 +341,6 @@ class FlowStreamingOutput(StreamingOutputBase[Any]):
```
"""
def __init__(
self,
sync_iterator: Iterator[StreamChunk] | None = None,
async_iterator: AsyncIterator[StreamChunk] | None = None,
) -> None:
"""Initialize flow streaming output.
Args:
sync_iterator: Synchronous iterator for chunks.
async_iterator: Asynchronous iterator for chunks.
"""
super().__init__()
self._sync_iterator = sync_iterator
self._async_iterator = async_iterator
def __iter__(self) -> Iterator[StreamChunk]:
"""Iterate over stream chunks synchronously.
Yields:
StreamChunk objects as they arrive.
Raises:
RuntimeError: If sync iterator not available.
"""
if self._sync_iterator is None:
raise RuntimeError("Sync iterator not available")
try:
for chunk in self._sync_iterator:
self._chunks.append(chunk)
yield chunk
except Exception as e:
self._error = e
raise
finally:
self._completed = True
def __aiter__(self) -> AsyncIterator[StreamChunk]:
"""Return async iterator for stream chunks.
Returns:
Async iterator for StreamChunk objects.
"""
return self._async_iterate()
async def _async_iterate(self) -> AsyncIterator[StreamChunk]:
"""Iterate over stream chunks asynchronously.
Yields:
StreamChunk objects as they arrive.
Raises:
RuntimeError: If async iterator not available.
"""
if self._async_iterator is None:
raise RuntimeError("Async iterator not available")
try:
async for chunk in self._async_iterator:
self._chunks.append(chunk)
yield chunk
except Exception as e:
self._error = e
raise
finally:
self._completed = True
def _set_result(self, result: Any) -> None:
"""Set the final result after streaming completes.

View File

@@ -3,6 +3,7 @@
import asyncio
from collections.abc import AsyncIterator, Callable, Iterator
import contextvars
import logging
import queue
import threading
from typing import Any, NamedTuple
@@ -22,6 +23,9 @@ from crewai.types.streaming import (
from crewai.utilities.string_utils import sanitize_tool_name
logger = logging.getLogger(__name__)
class TaskInfo(TypedDict):
"""Task context information for streaming."""
@@ -159,10 +163,23 @@ def _finalize_streaming(
streaming_output: The streaming output to set the result on.
"""
_unregister_handler(state.handler)
streaming_output._on_cleanup = None
if state.result_holder:
streaming_output._set_result(state.result_holder[0])
def register_cleanup(
streaming_output: CrewStreamingOutput | FlowStreamingOutput,
state: StreamingState,
) -> None:
"""Register a cleanup callback on the streaming output.
Ensures the event handler is unregistered even if aclose()/close()
is called before iteration starts.
"""
streaming_output._on_cleanup = lambda: _unregister_handler(state.handler)
def create_streaming_state(
current_task_info: TaskInfo,
result_holder: list[Any],
@@ -294,7 +311,14 @@ async def create_async_chunk_generator(
raise item
yield item
finally:
await task
if not task.done():
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
except Exception:
logger.debug("Background streaming task failed", exc_info=True)
if output_holder:
_finalize_streaming(state, output_holder[0])
else:

View File

@@ -709,6 +709,158 @@ class TestStreamingEdgeCases:
assert streaming.is_completed
class TestStreamingCancellation:
"""Tests for streaming cancellation and resource cleanup."""
@pytest.mark.asyncio
async def test_aclose_cancels_async_streaming(self) -> None:
"""Test that aclose() stops iteration and marks as cancelled."""
chunks_yielded: list[str] = []
async def slow_gen() -> AsyncIterator[StreamChunk]:
for i in range(100):
await asyncio.sleep(0.01)
chunks_yielded.append(f"chunk-{i}")
yield StreamChunk(content=f"chunk-{i}")
streaming = CrewStreamingOutput(async_iterator=slow_gen())
collected: list[StreamChunk] = []
async for chunk in streaming:
collected.append(chunk)
if len(collected) >= 3:
break
await streaming.aclose()
assert streaming.is_cancelled
assert streaming.is_completed
assert len(collected) == 3
@pytest.mark.asyncio
async def test_aclose_idempotent(self) -> None:
"""Test that calling aclose() multiple times is safe."""
async def gen() -> AsyncIterator[StreamChunk]:
yield StreamChunk(content="test")
streaming = CrewStreamingOutput(async_iterator=gen())
async for _ in streaming:
pass
await streaming.aclose()
await streaming.aclose()
assert not streaming.is_cancelled
assert streaming.is_completed
@pytest.mark.asyncio
async def test_async_context_manager(self) -> None:
"""Test using streaming output as async context manager."""
async def gen() -> AsyncIterator[StreamChunk]:
yield StreamChunk(content="hello")
yield StreamChunk(content="world")
streaming = CrewStreamingOutput(async_iterator=gen())
collected: list[StreamChunk] = []
async with streaming:
async for chunk in streaming:
collected.append(chunk)
assert not streaming.is_cancelled
assert streaming.is_completed
assert len(collected) == 2
@pytest.mark.asyncio
async def test_async_context_manager_early_exit(self) -> None:
"""Test context manager cleans up on early exit."""
async def gen() -> AsyncIterator[StreamChunk]:
for i in range(100):
await asyncio.sleep(0.01)
yield StreamChunk(content=f"chunk-{i}")
streaming = CrewStreamingOutput(async_iterator=gen())
async with streaming:
async for chunk in streaming:
if chunk.content == "chunk-2":
break
assert streaming.is_cancelled
assert streaming.is_completed
def test_close_cancels_sync_streaming(self) -> None:
"""Test that close() stops sync streaming and marks as cancelled."""
def gen() -> Generator[StreamChunk, None, None]:
for i in range(100):
yield StreamChunk(content=f"chunk-{i}")
streaming = CrewStreamingOutput(sync_iterator=gen())
collected: list[StreamChunk] = []
for chunk in streaming:
collected.append(chunk)
if len(collected) >= 3:
break
streaming.close()
assert streaming.is_cancelled
assert streaming.is_completed
def test_close_idempotent(self) -> None:
"""Test that calling close() multiple times is safe."""
def gen() -> Generator[StreamChunk, None, None]:
yield StreamChunk(content="test")
streaming = CrewStreamingOutput(sync_iterator=gen())
list(streaming)
streaming.close()
streaming.close()
assert not streaming.is_cancelled
assert streaming.is_completed
@pytest.mark.asyncio
async def test_flow_aclose(self) -> None:
"""Test that FlowStreamingOutput aclose() is no-op after normal completion."""
async def gen() -> AsyncIterator[StreamChunk]:
yield StreamChunk(content="flow-chunk")
streaming = FlowStreamingOutput(async_iterator=gen())
async for _ in streaming:
pass
await streaming.aclose()
assert not streaming.is_cancelled
assert streaming.is_completed
@pytest.mark.asyncio
async def test_flow_async_context_manager(self) -> None:
"""Test FlowStreamingOutput as async context manager with full consumption."""
async def gen() -> AsyncIterator[StreamChunk]:
yield StreamChunk(content="flow-chunk")
streaming = FlowStreamingOutput(async_iterator=gen())
async with streaming:
async for _ in streaming:
pass
assert not streaming.is_cancelled
assert streaming.is_completed
def test_flow_close(self) -> None:
"""Test that FlowStreamingOutput close() is no-op after normal completion."""
def gen() -> Generator[StreamChunk, None, None]:
yield StreamChunk(content="flow-chunk")
streaming = FlowStreamingOutput(sync_iterator=gen())
list(streaming)
streaming.close()
assert not streaming.is_cancelled
class TestStreamingImports:
"""Tests for correct imports of streaming types."""