mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-01 07:13:00 +00:00
feat: add aclose()/close() and async context manager to streaming outputs
This commit is contained in:
@@ -134,6 +134,7 @@ from crewai.utilities.rpm_controller import RPMController
|
||||
from crewai.utilities.streaming import (
|
||||
create_async_chunk_generator,
|
||||
create_chunk_generator,
|
||||
register_cleanup,
|
||||
signal_end,
|
||||
signal_error,
|
||||
)
|
||||
@@ -882,6 +883,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
ctx.state, run_crew, ctx.output_holder
|
||||
)
|
||||
)
|
||||
register_cleanup(streaming_output, ctx.state)
|
||||
ctx.output_holder.append(streaming_output)
|
||||
return streaming_output
|
||||
|
||||
@@ -1007,6 +1009,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
ctx.state, run_crew, ctx.output_holder
|
||||
)
|
||||
)
|
||||
register_cleanup(streaming_output, ctx.state)
|
||||
ctx.output_holder.append(streaming_output)
|
||||
|
||||
return streaming_output
|
||||
@@ -1078,6 +1081,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
ctx.state, run_crew, ctx.output_holder
|
||||
)
|
||||
)
|
||||
register_cleanup(streaming_output, ctx.state)
|
||||
ctx.output_holder.append(streaming_output)
|
||||
|
||||
return streaming_output
|
||||
|
||||
@@ -431,6 +431,7 @@ async def run_for_each_async(
|
||||
from crewai.types.usage_metrics import UsageMetrics
|
||||
from crewai.utilities.streaming import (
|
||||
create_async_chunk_generator,
|
||||
register_cleanup,
|
||||
signal_end,
|
||||
signal_error,
|
||||
)
|
||||
@@ -480,6 +481,7 @@ async def run_for_each_async(
|
||||
streaming_output._set_results(result)
|
||||
|
||||
streaming_output._set_result = set_results_wrapper # type: ignore[method-assign]
|
||||
register_cleanup(streaming_output, ctx.state)
|
||||
ctx.output_holder.append(streaming_output)
|
||||
|
||||
return streaming_output
|
||||
|
||||
@@ -132,6 +132,7 @@ from crewai.utilities.streaming import (
|
||||
create_async_chunk_generator,
|
||||
create_chunk_generator,
|
||||
create_streaming_state,
|
||||
register_cleanup,
|
||||
signal_end,
|
||||
signal_error,
|
||||
)
|
||||
@@ -1962,6 +1963,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
streaming_output = FlowStreamingOutput(
|
||||
sync_iterator=create_chunk_generator(state, run_flow, output_holder)
|
||||
)
|
||||
register_cleanup(streaming_output, state)
|
||||
output_holder.append(streaming_output)
|
||||
|
||||
return streaming_output
|
||||
@@ -2035,6 +2037,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
state, run_flow, output_holder
|
||||
)
|
||||
)
|
||||
register_cleanup(streaming_output, state)
|
||||
output_holder.append(streaming_output)
|
||||
|
||||
return streaming_output
|
||||
|
||||
@@ -2,11 +2,12 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator, Iterator
|
||||
from collections.abc import AsyncIterator, Callable, Iterator
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any, Generic, TypeVar
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Self
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -78,12 +79,21 @@ class StreamingOutputBase(Generic[T]):
|
||||
via the .result property after streaming completes.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
sync_iterator: Iterator[StreamChunk] | None = None,
|
||||
async_iterator: AsyncIterator[StreamChunk] | None = None,
|
||||
) -> None:
|
||||
"""Initialize streaming output base."""
|
||||
self._result: T | None = None
|
||||
self._completed: bool = False
|
||||
self._chunks: list[StreamChunk] = []
|
||||
self._error: Exception | None = None
|
||||
self._cancelled: bool = False
|
||||
self._exhausted: bool = False
|
||||
self._on_cleanup: Callable[[], None] | None = None
|
||||
self._sync_iterator = sync_iterator
|
||||
self._async_iterator = async_iterator
|
||||
|
||||
@property
|
||||
def result(self) -> T:
|
||||
@@ -112,6 +122,11 @@ class StreamingOutputBase(Generic[T]):
|
||||
"""Check if streaming has completed."""
|
||||
return self._completed
|
||||
|
||||
@property
|
||||
def is_cancelled(self) -> bool:
|
||||
"""Check if streaming was cancelled."""
|
||||
return self._cancelled
|
||||
|
||||
@property
|
||||
def chunks(self) -> list[StreamChunk]:
|
||||
"""Get all collected chunks so far."""
|
||||
@@ -129,6 +144,98 @@ class StreamingOutputBase(Generic[T]):
|
||||
if chunk.chunk_type == StreamChunkType.TEXT
|
||||
)
|
||||
|
||||
async def __aenter__(self) -> Self:
|
||||
"""Enter async context manager."""
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *exc_info: Any) -> None:
|
||||
"""Exit async context manager, cancelling if still running."""
|
||||
await self.aclose()
|
||||
|
||||
async def aclose(self) -> None:
|
||||
"""Cancel streaming and clean up resources.
|
||||
|
||||
Cancels any in-flight tasks and closes the underlying async iterator.
|
||||
Safe to call multiple times. No-op if already cancelled or fully consumed.
|
||||
"""
|
||||
if self._cancelled or self._exhausted or self._error is not None:
|
||||
return
|
||||
self._cancelled = True
|
||||
self._completed = True
|
||||
if self._async_iterator is not None and hasattr(self._async_iterator, "aclose"):
|
||||
await self._async_iterator.aclose()
|
||||
if self._on_cleanup is not None:
|
||||
self._on_cleanup()
|
||||
self._on_cleanup = None
|
||||
|
||||
def close(self) -> None:
|
||||
"""Cancel streaming and clean up resources (sync).
|
||||
|
||||
Closes the underlying sync iterator. Safe to call multiple times.
|
||||
No-op if already cancelled, fully consumed, or errored.
|
||||
"""
|
||||
if self._cancelled or self._exhausted or self._error is not None:
|
||||
return
|
||||
self._cancelled = True
|
||||
self._completed = True
|
||||
if self._sync_iterator is not None and hasattr(self._sync_iterator, "close"):
|
||||
self._sync_iterator.close()
|
||||
if self._on_cleanup is not None:
|
||||
self._on_cleanup()
|
||||
self._on_cleanup = None
|
||||
|
||||
def __iter__(self) -> Iterator[StreamChunk]:
|
||||
"""Iterate over stream chunks synchronously.
|
||||
|
||||
Yields:
|
||||
StreamChunk objects as they arrive.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If sync iterator not available.
|
||||
"""
|
||||
if self._sync_iterator is None:
|
||||
raise RuntimeError("Sync iterator not available")
|
||||
try:
|
||||
for chunk in self._sync_iterator:
|
||||
self._chunks.append(chunk)
|
||||
yield chunk
|
||||
self._exhausted = True
|
||||
except Exception as e:
|
||||
self._error = e
|
||||
raise
|
||||
finally:
|
||||
self._completed = True
|
||||
|
||||
def __aiter__(self) -> AsyncIterator[StreamChunk]:
|
||||
"""Return async iterator for stream chunks.
|
||||
|
||||
Returns:
|
||||
Async iterator for StreamChunk objects.
|
||||
"""
|
||||
return self._async_iterate()
|
||||
|
||||
async def _async_iterate(self) -> AsyncIterator[StreamChunk]:
|
||||
"""Iterate over stream chunks asynchronously.
|
||||
|
||||
Yields:
|
||||
StreamChunk objects as they arrive.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If async iterator not available.
|
||||
"""
|
||||
if self._async_iterator is None:
|
||||
raise RuntimeError("Async iterator not available")
|
||||
try:
|
||||
async for chunk in self._async_iterator:
|
||||
self._chunks.append(chunk)
|
||||
yield chunk
|
||||
self._exhausted = True
|
||||
except Exception as e:
|
||||
self._error = e
|
||||
raise
|
||||
finally:
|
||||
self._completed = True
|
||||
|
||||
|
||||
class CrewStreamingOutput(StreamingOutputBase["CrewOutput"]):
|
||||
"""Streaming output wrapper for crew execution.
|
||||
@@ -167,9 +274,7 @@ class CrewStreamingOutput(StreamingOutputBase["CrewOutput"]):
|
||||
sync_iterator: Synchronous iterator for chunks.
|
||||
async_iterator: Asynchronous iterator for chunks.
|
||||
"""
|
||||
super().__init__()
|
||||
self._sync_iterator = sync_iterator
|
||||
self._async_iterator = async_iterator
|
||||
super().__init__(sync_iterator=sync_iterator, async_iterator=async_iterator)
|
||||
self._results: list[CrewOutput] | None = None
|
||||
|
||||
@property
|
||||
@@ -204,56 +309,6 @@ class CrewStreamingOutput(StreamingOutputBase["CrewOutput"]):
|
||||
self._results = results
|
||||
self._completed = True
|
||||
|
||||
def __iter__(self) -> Iterator[StreamChunk]:
|
||||
"""Iterate over stream chunks synchronously.
|
||||
|
||||
Yields:
|
||||
StreamChunk objects as they arrive.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If sync iterator not available.
|
||||
"""
|
||||
if self._sync_iterator is None:
|
||||
raise RuntimeError("Sync iterator not available")
|
||||
try:
|
||||
for chunk in self._sync_iterator:
|
||||
self._chunks.append(chunk)
|
||||
yield chunk
|
||||
except Exception as e:
|
||||
self._error = e
|
||||
raise
|
||||
finally:
|
||||
self._completed = True
|
||||
|
||||
def __aiter__(self) -> AsyncIterator[StreamChunk]:
|
||||
"""Return async iterator for stream chunks.
|
||||
|
||||
Returns:
|
||||
Async iterator for StreamChunk objects.
|
||||
"""
|
||||
return self._async_iterate()
|
||||
|
||||
async def _async_iterate(self) -> AsyncIterator[StreamChunk]:
|
||||
"""Iterate over stream chunks asynchronously.
|
||||
|
||||
Yields:
|
||||
StreamChunk objects as they arrive.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If async iterator not available.
|
||||
"""
|
||||
if self._async_iterator is None:
|
||||
raise RuntimeError("Async iterator not available")
|
||||
try:
|
||||
async for chunk in self._async_iterator:
|
||||
self._chunks.append(chunk)
|
||||
yield chunk
|
||||
except Exception as e:
|
||||
self._error = e
|
||||
raise
|
||||
finally:
|
||||
self._completed = True
|
||||
|
||||
def _set_result(self, result: CrewOutput) -> None:
|
||||
"""Set the final result after streaming completes.
|
||||
|
||||
@@ -286,71 +341,6 @@ class FlowStreamingOutput(StreamingOutputBase[Any]):
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
sync_iterator: Iterator[StreamChunk] | None = None,
|
||||
async_iterator: AsyncIterator[StreamChunk] | None = None,
|
||||
) -> None:
|
||||
"""Initialize flow streaming output.
|
||||
|
||||
Args:
|
||||
sync_iterator: Synchronous iterator for chunks.
|
||||
async_iterator: Asynchronous iterator for chunks.
|
||||
"""
|
||||
super().__init__()
|
||||
self._sync_iterator = sync_iterator
|
||||
self._async_iterator = async_iterator
|
||||
|
||||
def __iter__(self) -> Iterator[StreamChunk]:
|
||||
"""Iterate over stream chunks synchronously.
|
||||
|
||||
Yields:
|
||||
StreamChunk objects as they arrive.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If sync iterator not available.
|
||||
"""
|
||||
if self._sync_iterator is None:
|
||||
raise RuntimeError("Sync iterator not available")
|
||||
try:
|
||||
for chunk in self._sync_iterator:
|
||||
self._chunks.append(chunk)
|
||||
yield chunk
|
||||
except Exception as e:
|
||||
self._error = e
|
||||
raise
|
||||
finally:
|
||||
self._completed = True
|
||||
|
||||
def __aiter__(self) -> AsyncIterator[StreamChunk]:
|
||||
"""Return async iterator for stream chunks.
|
||||
|
||||
Returns:
|
||||
Async iterator for StreamChunk objects.
|
||||
"""
|
||||
return self._async_iterate()
|
||||
|
||||
async def _async_iterate(self) -> AsyncIterator[StreamChunk]:
|
||||
"""Iterate over stream chunks asynchronously.
|
||||
|
||||
Yields:
|
||||
StreamChunk objects as they arrive.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If async iterator not available.
|
||||
"""
|
||||
if self._async_iterator is None:
|
||||
raise RuntimeError("Async iterator not available")
|
||||
try:
|
||||
async for chunk in self._async_iterator:
|
||||
self._chunks.append(chunk)
|
||||
yield chunk
|
||||
except Exception as e:
|
||||
self._error = e
|
||||
raise
|
||||
finally:
|
||||
self._completed = True
|
||||
|
||||
def _set_result(self, result: Any) -> None:
|
||||
"""Set the final result after streaming completes.
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterator, Callable, Iterator
|
||||
import contextvars
|
||||
import logging
|
||||
import queue
|
||||
import threading
|
||||
from typing import Any, NamedTuple
|
||||
@@ -22,6 +23,9 @@ from crewai.types.streaming import (
|
||||
from crewai.utilities.string_utils import sanitize_tool_name
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TaskInfo(TypedDict):
|
||||
"""Task context information for streaming."""
|
||||
|
||||
@@ -159,10 +163,23 @@ def _finalize_streaming(
|
||||
streaming_output: The streaming output to set the result on.
|
||||
"""
|
||||
_unregister_handler(state.handler)
|
||||
streaming_output._on_cleanup = None
|
||||
if state.result_holder:
|
||||
streaming_output._set_result(state.result_holder[0])
|
||||
|
||||
|
||||
def register_cleanup(
|
||||
streaming_output: CrewStreamingOutput | FlowStreamingOutput,
|
||||
state: StreamingState,
|
||||
) -> None:
|
||||
"""Register a cleanup callback on the streaming output.
|
||||
|
||||
Ensures the event handler is unregistered even if aclose()/close()
|
||||
is called before iteration starts.
|
||||
"""
|
||||
streaming_output._on_cleanup = lambda: _unregister_handler(state.handler)
|
||||
|
||||
|
||||
def create_streaming_state(
|
||||
current_task_info: TaskInfo,
|
||||
result_holder: list[Any],
|
||||
@@ -294,7 +311,14 @@ async def create_async_chunk_generator(
|
||||
raise item
|
||||
yield item
|
||||
finally:
|
||||
await task
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception:
|
||||
logger.debug("Background streaming task failed", exc_info=True)
|
||||
if output_holder:
|
||||
_finalize_streaming(state, output_holder[0])
|
||||
else:
|
||||
|
||||
@@ -709,6 +709,158 @@ class TestStreamingEdgeCases:
|
||||
assert streaming.is_completed
|
||||
|
||||
|
||||
class TestStreamingCancellation:
|
||||
"""Tests for streaming cancellation and resource cleanup."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aclose_cancels_async_streaming(self) -> None:
|
||||
"""Test that aclose() stops iteration and marks as cancelled."""
|
||||
chunks_yielded: list[str] = []
|
||||
|
||||
async def slow_gen() -> AsyncIterator[StreamChunk]:
|
||||
for i in range(100):
|
||||
await asyncio.sleep(0.01)
|
||||
chunks_yielded.append(f"chunk-{i}")
|
||||
yield StreamChunk(content=f"chunk-{i}")
|
||||
|
||||
streaming = CrewStreamingOutput(async_iterator=slow_gen())
|
||||
collected: list[StreamChunk] = []
|
||||
|
||||
async for chunk in streaming:
|
||||
collected.append(chunk)
|
||||
if len(collected) >= 3:
|
||||
break
|
||||
|
||||
await streaming.aclose()
|
||||
|
||||
assert streaming.is_cancelled
|
||||
assert streaming.is_completed
|
||||
assert len(collected) == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aclose_idempotent(self) -> None:
|
||||
"""Test that calling aclose() multiple times is safe."""
|
||||
async def gen() -> AsyncIterator[StreamChunk]:
|
||||
yield StreamChunk(content="test")
|
||||
|
||||
streaming = CrewStreamingOutput(async_iterator=gen())
|
||||
async for _ in streaming:
|
||||
pass
|
||||
|
||||
await streaming.aclose()
|
||||
await streaming.aclose()
|
||||
assert not streaming.is_cancelled
|
||||
assert streaming.is_completed
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_context_manager(self) -> None:
|
||||
"""Test using streaming output as async context manager."""
|
||||
async def gen() -> AsyncIterator[StreamChunk]:
|
||||
yield StreamChunk(content="hello")
|
||||
yield StreamChunk(content="world")
|
||||
|
||||
streaming = CrewStreamingOutput(async_iterator=gen())
|
||||
collected: list[StreamChunk] = []
|
||||
|
||||
async with streaming:
|
||||
async for chunk in streaming:
|
||||
collected.append(chunk)
|
||||
|
||||
assert not streaming.is_cancelled
|
||||
assert streaming.is_completed
|
||||
assert len(collected) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_context_manager_early_exit(self) -> None:
|
||||
"""Test context manager cleans up on early exit."""
|
||||
async def gen() -> AsyncIterator[StreamChunk]:
|
||||
for i in range(100):
|
||||
await asyncio.sleep(0.01)
|
||||
yield StreamChunk(content=f"chunk-{i}")
|
||||
|
||||
streaming = CrewStreamingOutput(async_iterator=gen())
|
||||
|
||||
async with streaming:
|
||||
async for chunk in streaming:
|
||||
if chunk.content == "chunk-2":
|
||||
break
|
||||
|
||||
assert streaming.is_cancelled
|
||||
assert streaming.is_completed
|
||||
|
||||
def test_close_cancels_sync_streaming(self) -> None:
|
||||
"""Test that close() stops sync streaming and marks as cancelled."""
|
||||
def gen() -> Generator[StreamChunk, None, None]:
|
||||
for i in range(100):
|
||||
yield StreamChunk(content=f"chunk-{i}")
|
||||
|
||||
streaming = CrewStreamingOutput(sync_iterator=gen())
|
||||
collected: list[StreamChunk] = []
|
||||
|
||||
for chunk in streaming:
|
||||
collected.append(chunk)
|
||||
if len(collected) >= 3:
|
||||
break
|
||||
|
||||
streaming.close()
|
||||
|
||||
assert streaming.is_cancelled
|
||||
assert streaming.is_completed
|
||||
|
||||
def test_close_idempotent(self) -> None:
|
||||
"""Test that calling close() multiple times is safe."""
|
||||
def gen() -> Generator[StreamChunk, None, None]:
|
||||
yield StreamChunk(content="test")
|
||||
|
||||
streaming = CrewStreamingOutput(sync_iterator=gen())
|
||||
list(streaming)
|
||||
|
||||
streaming.close()
|
||||
streaming.close()
|
||||
assert not streaming.is_cancelled
|
||||
assert streaming.is_completed
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_flow_aclose(self) -> None:
|
||||
"""Test that FlowStreamingOutput aclose() is no-op after normal completion."""
|
||||
async def gen() -> AsyncIterator[StreamChunk]:
|
||||
yield StreamChunk(content="flow-chunk")
|
||||
|
||||
streaming = FlowStreamingOutput(async_iterator=gen())
|
||||
async for _ in streaming:
|
||||
pass
|
||||
|
||||
await streaming.aclose()
|
||||
assert not streaming.is_cancelled
|
||||
assert streaming.is_completed
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_flow_async_context_manager(self) -> None:
|
||||
"""Test FlowStreamingOutput as async context manager with full consumption."""
|
||||
async def gen() -> AsyncIterator[StreamChunk]:
|
||||
yield StreamChunk(content="flow-chunk")
|
||||
|
||||
streaming = FlowStreamingOutput(async_iterator=gen())
|
||||
|
||||
async with streaming:
|
||||
async for _ in streaming:
|
||||
pass
|
||||
|
||||
assert not streaming.is_cancelled
|
||||
assert streaming.is_completed
|
||||
|
||||
def test_flow_close(self) -> None:
|
||||
"""Test that FlowStreamingOutput close() is no-op after normal completion."""
|
||||
def gen() -> Generator[StreamChunk, None, None]:
|
||||
yield StreamChunk(content="flow-chunk")
|
||||
|
||||
streaming = FlowStreamingOutput(sync_iterator=gen())
|
||||
list(streaming)
|
||||
|
||||
streaming.close()
|
||||
assert not streaming.is_cancelled
|
||||
|
||||
|
||||
class TestStreamingImports:
|
||||
"""Tests for correct imports of streaming types."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user