fix: scope streaming handlers to prevent cross-run chunk contamination

Concurrent streaming runs registered handlers on the singleton event bus
that received all LLMStreamChunkEvent emissions, causing chunks to fan
out across unrelated queues. Introduces a ContextVar-based stream scope
ID so each handler only accepts events from its own execution context.

Closes #5376
This commit is contained in:
Greyson LaLonde
2026-04-17 03:02:03 +08:00
committed by GitHub
parent fbe2a04064
commit 6136228a66
2 changed files with 117 additions and 9 deletions

View File

@@ -7,6 +7,7 @@ import logging
import queue
import threading
from typing import Any, NamedTuple
import uuid
from typing_extensions import TypedDict
@@ -25,6 +26,10 @@ from crewai.utilities.string_utils import sanitize_tool_name
logger = logging.getLogger(__name__)
_current_stream_ids: contextvars.ContextVar[tuple[str, ...]] = contextvars.ContextVar(
"_current_stream_ids", default=()
)
class TaskInfo(TypedDict):
"""Task context information for streaming."""
@@ -45,6 +50,7 @@ class StreamingState(NamedTuple):
async_queue: asyncio.Queue[StreamChunk | None | Exception] | None
loop: asyncio.AbstractEventLoop | None
handler: Callable[[Any, BaseEvent], None]
stream_id: str | None = None
def _extract_tool_call_info(
@@ -106,6 +112,7 @@ def _create_stream_handler(
sync_queue: queue.Queue[StreamChunk | None | Exception],
async_queue: asyncio.Queue[StreamChunk | None | Exception] | None = None,
loop: asyncio.AbstractEventLoop | None = None,
stream_id: str | None = None,
) -> Callable[[Any, BaseEvent], None]:
"""Create a stream handler function.
@@ -114,21 +121,19 @@ def _create_stream_handler(
sync_queue: Synchronous queue for chunks.
async_queue: Optional async queue for chunks.
loop: Optional event loop for async operations.
stream_id: Stream scope ID for concurrent isolation.
Returns:
Handler function that can be registered with the event bus.
"""
def stream_handler(_: Any, event: BaseEvent) -> None:
"""Handle LLM stream chunk events and enqueue them.
Args:
_: Event source (unused).
event: The event to process.
"""
if not isinstance(event, LLMStreamChunkEvent):
return
if stream_id is not None and stream_id not in _current_stream_ids.get():
return
chunk = _create_stream_chunk(event, current_task_info)
if async_queue is not None and loop is not None:
@@ -203,7 +208,11 @@ def create_streaming_state(
async_queue = asyncio.Queue()
loop = asyncio.get_event_loop()
handler = _create_stream_handler(current_task_info, sync_queue, async_queue, loop)
stream_id = str(uuid.uuid4())
handler = _create_stream_handler(
current_task_info, sync_queue, async_queue, loop, stream_id=stream_id
)
crewai_event_bus.register_handler(LLMStreamChunkEvent, handler)
return StreamingState(
@@ -213,6 +222,7 @@ def create_streaming_state(
async_queue=async_queue,
loop=loop,
handler=handler,
stream_id=stream_id,
)
@@ -260,7 +270,12 @@ def create_chunk_generator(
Yields:
StreamChunk objects as they arrive.
"""
ctx = contextvars.copy_context()
if state.stream_id is not None:
token = _current_stream_ids.set((*_current_stream_ids.get(), state.stream_id))
ctx = contextvars.copy_context()
_current_stream_ids.reset(token)
else:
ctx = contextvars.copy_context()
thread = threading.Thread(target=ctx.run, args=(run_func,), daemon=True)
thread.start()
@@ -300,7 +315,12 @@ async def create_async_chunk_generator(
"Async queue not initialized. Use create_streaming_state(use_async=True)."
)
task = asyncio.create_task(run_coro())
if state.stream_id is not None:
token = _current_stream_ids.set((*_current_stream_ids.get(), state.stream_id))
task = asyncio.create_task(run_coro())
_current_stream_ids.reset(token)
else:
task = asyncio.create_task(run_coro())
try:
while True:

View File

@@ -879,3 +879,91 @@ class TestStreamingImports:
assert StreamChunk is not None
assert StreamChunkType is not None
assert ToolCallChunk is not None
class TestConcurrentStreamIsolation:
"""Regression tests for concurrent streaming isolation (issue #5376)."""
def test_concurrent_streams_do_not_cross_contaminate(self) -> None:
"""Two concurrent streaming runs must each receive only their own chunks.
Mirrors the real production path: create_streaming_state in the caller,
then temporarily push the stream_id into the ContextVar, copy_context,
and reset — exactly as create_chunk_generator does.
"""
import contextvars
import threading
from crewai.utilities.streaming import (
TaskInfo,
_current_stream_ids,
_unregister_handler,
create_streaming_state,
)
task_info_a: TaskInfo = {
"index": 0,
"name": "task_a",
"id": "a",
"agent_role": "A",
"agent_id": "a",
}
task_info_b: TaskInfo = {
"index": 1,
"name": "task_b",
"id": "b",
"agent_role": "B",
"agent_id": "b",
}
state_a = create_streaming_state(task_info_a, [])
state_b = create_streaming_state(task_info_b, [])
def make_emitter_ctx(state: Any) -> contextvars.Context:
token = _current_stream_ids.set(
(*_current_stream_ids.get(), state.stream_id)
)
ctx = contextvars.copy_context()
_current_stream_ids.reset(token)
return ctx
ctx_a = make_emitter_ctx(state_a)
ctx_b = make_emitter_ctx(state_b)
def emit_chunks(prefix: str, call_id: str) -> None:
for text in [f"{prefix}1", f"{prefix}2", f"{prefix}3"]:
crewai_event_bus.emit(
None,
event=LLMStreamChunkEvent(
chunk=text, call_id=call_id, response_id="r"
),
)
t_a = threading.Thread(target=ctx_a.run, args=(lambda: emit_chunks("A", "ca"),))
t_b = threading.Thread(target=ctx_b.run, args=(lambda: emit_chunks("B", "cb"),))
t_a.start()
t_b.start()
t_a.join()
t_b.join()
chunks_a: list[str] = []
while not state_a.sync_queue.empty():
item = state_a.sync_queue.get_nowait()
if isinstance(item, StreamChunk):
chunks_a.append(item.content)
chunks_b: list[str] = []
while not state_b.sync_queue.empty():
item = state_b.sync_queue.get_nowait()
if isinstance(item, StreamChunk):
chunks_b.append(item.content)
assert set(chunks_a) == {"A1", "A2", "A3"}, (
f"Stream A received unexpected chunks: {chunks_a}"
)
assert set(chunks_b) == {"B1", "B2", "B3"}, (
f"Stream B received unexpected chunks: {chunks_b}"
)
_unregister_handler(state_a.handler)
_unregister_handler(state_b.handler)