mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-04-12 14:02:47 +00:00
Compare commits
2 Commits
chore/clea
...
devin/1775
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
da65140cf8 | ||
|
|
3702a47bfe |
@@ -2,8 +2,10 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterator, Iterator
|
||||
from enum import Enum
|
||||
import threading
|
||||
from typing import TYPE_CHECKING, Any, Generic, TypeVar
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -76,14 +78,23 @@ class StreamingOutputBase(Generic[T]):
|
||||
|
||||
Provides iteration over stream chunks and access to final result
|
||||
via the .result property after streaming completes.
|
||||
|
||||
Supports graceful cancellation via ``aclose()`` (async) and ``cancel()``
|
||||
(sync). When cancelled, in-flight background tasks are aborted and
|
||||
resources are released promptly.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize streaming output base."""
|
||||
self._result: T | None = None
|
||||
self._completed: bool = False
|
||||
self._cancelled: bool = False
|
||||
self._chunks: list[StreamChunk] = []
|
||||
self._error: Exception | None = None
|
||||
self._cancel_event: asyncio.Event | None = None
|
||||
self._cancel_thread_event: threading.Event | None = None
|
||||
self._background_task: asyncio.Task[Any] | None = None
|
||||
self._background_thread: threading.Thread | None = None
|
||||
|
||||
@property
|
||||
def result(self) -> T:
|
||||
@@ -112,6 +123,11 @@ class StreamingOutputBase(Generic[T]):
|
||||
"""Check if streaming has completed."""
|
||||
return self._completed
|
||||
|
||||
@property
|
||||
def is_cancelled(self) -> bool:
|
||||
"""Check if streaming was cancelled."""
|
||||
return self._cancelled
|
||||
|
||||
@property
|
||||
def chunks(self) -> list[StreamChunk]:
|
||||
"""Get all collected chunks so far."""
|
||||
@@ -129,6 +145,76 @@ class StreamingOutputBase(Generic[T]):
|
||||
if chunk.chunk_type == StreamChunkType.TEXT
|
||||
)
|
||||
|
||||
async def aclose(self) -> None:
|
||||
"""Cancel streaming and clean up resources.
|
||||
|
||||
Signals cancellation to the background task, waits briefly for it
|
||||
to finish, and marks the stream as completed and cancelled.
|
||||
Safe to call multiple times or on an already-completed stream.
|
||||
|
||||
Example:
|
||||
```python
|
||||
streaming = await crew.akickoff(inputs=inputs)
|
||||
try:
|
||||
async for chunk in streaming:
|
||||
...
|
||||
finally:
|
||||
await streaming.aclose()
|
||||
```
|
||||
"""
|
||||
if self._completed:
|
||||
return
|
||||
|
||||
self._cancelled = True
|
||||
|
||||
if self._cancel_event is not None:
|
||||
self._cancel_event.set()
|
||||
|
||||
if self._cancel_thread_event is not None:
|
||||
self._cancel_thread_event.set()
|
||||
|
||||
if self._background_task is not None and not self._background_task.done():
|
||||
self._background_task.cancel()
|
||||
try:
|
||||
await self._background_task
|
||||
except (asyncio.CancelledError, Exception): # noqa: S110
|
||||
pass
|
||||
|
||||
self._completed = True
|
||||
|
||||
def cancel(self) -> None:
|
||||
"""Synchronously cancel streaming and clean up resources.
|
||||
|
||||
Signals cancellation to the background thread/task and marks the
|
||||
stream as completed and cancelled. For async contexts prefer
|
||||
``aclose()`` which can ``await`` background cleanup.
|
||||
|
||||
Example:
|
||||
```python
|
||||
streaming = crew.kickoff(inputs=inputs)
|
||||
try:
|
||||
for chunk in streaming:
|
||||
...
|
||||
finally:
|
||||
streaming.cancel()
|
||||
```
|
||||
"""
|
||||
if self._completed:
|
||||
return
|
||||
|
||||
self._cancelled = True
|
||||
|
||||
if self._cancel_event is not None:
|
||||
self._cancel_event.set()
|
||||
|
||||
if self._cancel_thread_event is not None:
|
||||
self._cancel_thread_event.set()
|
||||
|
||||
if self._background_task is not None and not self._background_task.done():
|
||||
self._background_task.cancel()
|
||||
|
||||
self._completed = True
|
||||
|
||||
|
||||
class CrewStreamingOutput(StreamingOutputBase["CrewOutput"]):
|
||||
"""Streaming output wrapper for crew execution.
|
||||
|
||||
@@ -243,20 +243,37 @@ def create_chunk_generator(
|
||||
Yields:
|
||||
StreamChunk objects as they arrive.
|
||||
"""
|
||||
cancel_event = threading.Event()
|
||||
ctx = contextvars.copy_context()
|
||||
thread = threading.Thread(target=ctx.run, args=(run_func,), daemon=True)
|
||||
thread.start()
|
||||
|
||||
# Wire cancellation to the streaming output once the holder is populated
|
||||
def _wire_cancel() -> None:
|
||||
if output_holder:
|
||||
output_holder[0]._cancel_thread_event = cancel_event
|
||||
output_holder[0]._background_thread = thread
|
||||
|
||||
try:
|
||||
while True:
|
||||
item = state.sync_queue.get()
|
||||
# Poll the queue with a timeout so we can check cancellation
|
||||
while True:
|
||||
_wire_cancel()
|
||||
if cancel_event.is_set():
|
||||
return
|
||||
try:
|
||||
item = state.sync_queue.get(timeout=0.1)
|
||||
break
|
||||
except queue.Empty:
|
||||
continue
|
||||
if item is None:
|
||||
break
|
||||
if isinstance(item, Exception):
|
||||
raise item
|
||||
yield item
|
||||
finally:
|
||||
thread.join()
|
||||
if not cancel_event.is_set():
|
||||
thread.join()
|
||||
if output_holder:
|
||||
_finalize_streaming(state, output_holder[0])
|
||||
else:
|
||||
@@ -283,18 +300,52 @@ async def create_async_chunk_generator(
|
||||
"Async queue not initialized. Use create_streaming_state(use_async=True)."
|
||||
)
|
||||
|
||||
cancel_event = asyncio.Event()
|
||||
task = asyncio.create_task(run_coro())
|
||||
|
||||
# Wire cancellation to the streaming output once the holder is populated
|
||||
def _wire_cancel() -> None:
|
||||
if output_holder:
|
||||
output_holder[0]._cancel_event = cancel_event
|
||||
output_holder[0]._background_task = task
|
||||
|
||||
try:
|
||||
while True:
|
||||
item = await state.async_queue.get()
|
||||
_wire_cancel()
|
||||
# Use asyncio.wait to race between the queue and cancellation
|
||||
get_task = asyncio.ensure_future(state.async_queue.get())
|
||||
cancel_wait = asyncio.ensure_future(cancel_event.wait())
|
||||
done, pending = await asyncio.wait(
|
||||
{get_task, cancel_wait}, return_when=asyncio.FIRST_COMPLETED
|
||||
)
|
||||
for p in pending:
|
||||
p.cancel()
|
||||
try:
|
||||
await p
|
||||
except (asyncio.CancelledError, Exception): # noqa: S110
|
||||
pass
|
||||
if cancel_wait in done:
|
||||
# Cancellation was requested
|
||||
return
|
||||
item = get_task.result()
|
||||
if item is None:
|
||||
break
|
||||
if isinstance(item, Exception):
|
||||
raise item
|
||||
yield item
|
||||
finally:
|
||||
await task
|
||||
if not cancel_event.is_set():
|
||||
try:
|
||||
await task
|
||||
except (asyncio.CancelledError, Exception): # noqa: S110
|
||||
pass
|
||||
else:
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except (asyncio.CancelledError, Exception): # noqa: S110
|
||||
pass
|
||||
if output_holder:
|
||||
_finalize_streaming(state, output_holder[0])
|
||||
else:
|
||||
|
||||
@@ -709,6 +709,222 @@ class TestStreamingEdgeCases:
|
||||
assert streaming.is_completed
|
||||
|
||||
|
||||
class TestStreamingCancellation:
|
||||
"""Tests for graceful cancellation of streaming via aclose() and cancel()."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aclose_stops_async_iteration(self) -> None:
|
||||
"""Test that aclose() stops async iteration promptly."""
|
||||
chunks_yielded: list[str] = []
|
||||
cancel_event = asyncio.Event()
|
||||
|
||||
async def slow_gen() -> AsyncIterator[StreamChunk]:
|
||||
for i in range(100):
|
||||
if cancel_event.is_set():
|
||||
return
|
||||
yield StreamChunk(content=f"chunk-{i}")
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
streaming = CrewStreamingOutput(async_iterator=slow_gen())
|
||||
streaming._cancel_event = cancel_event
|
||||
|
||||
async for chunk in streaming:
|
||||
chunks_yielded.append(chunk.content)
|
||||
if len(chunks_yielded) >= 3:
|
||||
await streaming.aclose()
|
||||
break
|
||||
|
||||
assert streaming.is_cancelled
|
||||
assert streaming.is_completed
|
||||
assert len(chunks_yielded) >= 3
|
||||
assert len(chunks_yielded) < 100
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aclose_on_completed_stream_is_noop(self) -> None:
|
||||
"""Test that aclose() on an already-completed stream does nothing."""
|
||||
async def simple_gen() -> AsyncIterator[StreamChunk]:
|
||||
yield StreamChunk(content="done")
|
||||
|
||||
streaming = CrewStreamingOutput(async_iterator=simple_gen())
|
||||
|
||||
async for _ in streaming:
|
||||
pass
|
||||
|
||||
assert streaming.is_completed
|
||||
assert not streaming.is_cancelled
|
||||
|
||||
# aclose on completed stream should not change cancelled state
|
||||
await streaming.aclose()
|
||||
assert streaming.is_completed
|
||||
assert not streaming.is_cancelled
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aclose_cancels_background_task(self) -> None:
|
||||
"""Test that aclose() cancels the background asyncio task."""
|
||||
bg_task_started = asyncio.Event()
|
||||
|
||||
async def long_running_task() -> None:
|
||||
bg_task_started.set()
|
||||
await asyncio.sleep(100)
|
||||
|
||||
bg_task = asyncio.create_task(long_running_task())
|
||||
await bg_task_started.wait()
|
||||
|
||||
streaming = CrewStreamingOutput()
|
||||
streaming._background_task = bg_task
|
||||
|
||||
assert not bg_task.done()
|
||||
|
||||
await streaming.aclose()
|
||||
|
||||
assert streaming.is_cancelled
|
||||
assert bg_task.done()
|
||||
assert bg_task.cancelled()
|
||||
|
||||
def test_cancel_stops_sync_iteration(self) -> None:
|
||||
"""Test that cancel() marks streaming as cancelled."""
|
||||
def slow_gen() -> Generator[StreamChunk, None, None]:
|
||||
for i in range(100):
|
||||
yield StreamChunk(content=f"chunk-{i}")
|
||||
|
||||
streaming = CrewStreamingOutput(sync_iterator=slow_gen())
|
||||
|
||||
chunks_collected: list[str] = []
|
||||
for chunk in streaming:
|
||||
chunks_collected.append(chunk.content)
|
||||
if len(chunks_collected) >= 3:
|
||||
streaming.cancel()
|
||||
break
|
||||
|
||||
assert streaming.is_cancelled
|
||||
assert streaming.is_completed
|
||||
assert len(chunks_collected) >= 3
|
||||
|
||||
def test_cancel_on_completed_stream_is_noop(self) -> None:
|
||||
"""Test that cancel() on an already-completed stream does nothing."""
|
||||
def simple_gen() -> Generator[StreamChunk, None, None]:
|
||||
yield StreamChunk(content="done")
|
||||
|
||||
streaming = CrewStreamingOutput(sync_iterator=simple_gen())
|
||||
list(streaming)
|
||||
|
||||
assert streaming.is_completed
|
||||
assert not streaming.is_cancelled
|
||||
|
||||
streaming.cancel()
|
||||
assert streaming.is_completed
|
||||
assert not streaming.is_cancelled
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_cancelled_property_reflects_state(self) -> None:
|
||||
"""Test that is_cancelled starts False and becomes True after aclose()."""
|
||||
async def simple_gen() -> AsyncIterator[StreamChunk]:
|
||||
yield StreamChunk(content="test")
|
||||
|
||||
streaming = CrewStreamingOutput(async_iterator=simple_gen())
|
||||
assert not streaming.is_cancelled
|
||||
|
||||
await streaming.aclose()
|
||||
assert streaming.is_cancelled
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aclose_with_cancel_event(self) -> None:
|
||||
"""Test that aclose() sets the cancel event."""
|
||||
cancel_event = asyncio.Event()
|
||||
streaming = CrewStreamingOutput()
|
||||
streaming._cancel_event = cancel_event
|
||||
|
||||
assert not cancel_event.is_set()
|
||||
await streaming.aclose()
|
||||
assert cancel_event.is_set()
|
||||
assert streaming.is_cancelled
|
||||
|
||||
def test_cancel_with_thread_event(self) -> None:
|
||||
"""Test that cancel() sets the thread cancel event."""
|
||||
import threading
|
||||
|
||||
cancel_event = threading.Event()
|
||||
streaming = CrewStreamingOutput()
|
||||
streaming._cancel_thread_event = cancel_event
|
||||
|
||||
assert not cancel_event.is_set()
|
||||
streaming.cancel()
|
||||
assert cancel_event.is_set()
|
||||
assert streaming.is_cancelled
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_flow_streaming_aclose(self) -> None:
|
||||
"""Test that FlowStreamingOutput also supports aclose()."""
|
||||
async def simple_gen() -> AsyncIterator[StreamChunk]:
|
||||
yield StreamChunk(content="flow-chunk")
|
||||
await asyncio.sleep(100) # Would block forever without cancel
|
||||
|
||||
streaming = FlowStreamingOutput(async_iterator=simple_gen())
|
||||
cancel_event = asyncio.Event()
|
||||
streaming._cancel_event = cancel_event
|
||||
|
||||
chunks: list[str] = []
|
||||
async for chunk in streaming:
|
||||
chunks.append(chunk.content)
|
||||
await streaming.aclose()
|
||||
break
|
||||
|
||||
assert streaming.is_cancelled
|
||||
assert streaming.is_completed
|
||||
assert len(chunks) == 1
|
||||
assert chunks[0] == "flow-chunk"
|
||||
|
||||
def test_flow_streaming_cancel(self) -> None:
|
||||
"""Test that FlowStreamingOutput also supports cancel()."""
|
||||
def simple_gen() -> Generator[StreamChunk, None, None]:
|
||||
yield StreamChunk(content="flow-chunk")
|
||||
|
||||
streaming = FlowStreamingOutput(sync_iterator=simple_gen())
|
||||
assert not streaming.is_cancelled
|
||||
|
||||
# Consume
|
||||
list(streaming)
|
||||
assert streaming.is_completed
|
||||
|
||||
# Cancel on completed does nothing
|
||||
streaming.cancel()
|
||||
assert not streaming.is_cancelled
|
||||
|
||||
# Test cancelling before completion
|
||||
streaming2 = FlowStreamingOutput(sync_iterator=simple_gen())
|
||||
streaming2.cancel()
|
||||
assert streaming2.is_cancelled
|
||||
assert streaming2.is_completed
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_aclose_calls_are_safe(self) -> None:
|
||||
"""Test that calling aclose() multiple times is safe."""
|
||||
async def simple_gen() -> AsyncIterator[StreamChunk]:
|
||||
yield StreamChunk(content="test")
|
||||
|
||||
streaming = CrewStreamingOutput(async_iterator=simple_gen())
|
||||
|
||||
await streaming.aclose()
|
||||
assert streaming.is_cancelled
|
||||
|
||||
# Second call should be a no-op
|
||||
await streaming.aclose()
|
||||
assert streaming.is_cancelled
|
||||
assert streaming.is_completed
|
||||
|
||||
def test_multiple_cancel_calls_are_safe(self) -> None:
|
||||
"""Test that calling cancel() multiple times is safe."""
|
||||
streaming = CrewStreamingOutput()
|
||||
|
||||
streaming.cancel()
|
||||
assert streaming.is_cancelled
|
||||
|
||||
# Second call should be a no-op
|
||||
streaming.cancel()
|
||||
assert streaming.is_cancelled
|
||||
assert streaming.is_completed
|
||||
|
||||
|
||||
class TestStreamingImports:
|
||||
"""Tests for correct imports of streaming types."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user