mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-07-05 15:09:22 +00:00
fix: scope streaming handlers to prevent cross-run chunk contamination
Concurrent streaming runs registered handlers on the singleton event bus that received all LLMStreamChunkEvent emissions, causing chunks to fan out across unrelated queues. Introduces a ContextVar-based stream scope ID so each handler only accepts events from its own execution context. Closes #5376
This commit is contained in:
@@ -7,6 +7,7 @@ import logging
|
||||
import queue
|
||||
import threading
|
||||
from typing import Any, NamedTuple
|
||||
import uuid
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
@@ -25,6 +26,10 @@ from crewai.utilities.string_utils import sanitize_tool_name
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_current_stream_ids: contextvars.ContextVar[tuple[str, ...]] = contextvars.ContextVar(
|
||||
"_current_stream_ids", default=()
|
||||
)
|
||||
|
||||
|
||||
class TaskInfo(TypedDict):
|
||||
"""Task context information for streaming."""
|
||||
@@ -45,6 +50,7 @@ class StreamingState(NamedTuple):
|
||||
async_queue: asyncio.Queue[StreamChunk | None | Exception] | None
|
||||
loop: asyncio.AbstractEventLoop | None
|
||||
handler: Callable[[Any, BaseEvent], None]
|
||||
stream_id: str | None = None
|
||||
|
||||
|
||||
def _extract_tool_call_info(
|
||||
@@ -106,6 +112,7 @@ def _create_stream_handler(
|
||||
sync_queue: queue.Queue[StreamChunk | None | Exception],
|
||||
async_queue: asyncio.Queue[StreamChunk | None | Exception] | None = None,
|
||||
loop: asyncio.AbstractEventLoop | None = None,
|
||||
stream_id: str | None = None,
|
||||
) -> Callable[[Any, BaseEvent], None]:
|
||||
"""Create a stream handler function.
|
||||
|
||||
@@ -114,21 +121,19 @@ def _create_stream_handler(
|
||||
sync_queue: Synchronous queue for chunks.
|
||||
async_queue: Optional async queue for chunks.
|
||||
loop: Optional event loop for async operations.
|
||||
stream_id: Stream scope ID for concurrent isolation.
|
||||
|
||||
Returns:
|
||||
Handler function that can be registered with the event bus.
|
||||
"""
|
||||
|
||||
def stream_handler(_: Any, event: BaseEvent) -> None:
|
||||
"""Handle LLM stream chunk events and enqueue them.
|
||||
|
||||
Args:
|
||||
_: Event source (unused).
|
||||
event: The event to process.
|
||||
"""
|
||||
if not isinstance(event, LLMStreamChunkEvent):
|
||||
return
|
||||
|
||||
if stream_id is not None and stream_id not in _current_stream_ids.get():
|
||||
return
|
||||
|
||||
chunk = _create_stream_chunk(event, current_task_info)
|
||||
|
||||
if async_queue is not None and loop is not None:
|
||||
@@ -203,7 +208,11 @@ def create_streaming_state(
|
||||
async_queue = asyncio.Queue()
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
handler = _create_stream_handler(current_task_info, sync_queue, async_queue, loop)
|
||||
stream_id = str(uuid.uuid4())
|
||||
|
||||
handler = _create_stream_handler(
|
||||
current_task_info, sync_queue, async_queue, loop, stream_id=stream_id
|
||||
)
|
||||
crewai_event_bus.register_handler(LLMStreamChunkEvent, handler)
|
||||
|
||||
return StreamingState(
|
||||
@@ -213,6 +222,7 @@ def create_streaming_state(
|
||||
async_queue=async_queue,
|
||||
loop=loop,
|
||||
handler=handler,
|
||||
stream_id=stream_id,
|
||||
)
|
||||
|
||||
|
||||
@@ -260,7 +270,12 @@ def create_chunk_generator(
|
||||
Yields:
|
||||
StreamChunk objects as they arrive.
|
||||
"""
|
||||
ctx = contextvars.copy_context()
|
||||
if state.stream_id is not None:
|
||||
token = _current_stream_ids.set((*_current_stream_ids.get(), state.stream_id))
|
||||
ctx = contextvars.copy_context()
|
||||
_current_stream_ids.reset(token)
|
||||
else:
|
||||
ctx = contextvars.copy_context()
|
||||
thread = threading.Thread(target=ctx.run, args=(run_func,), daemon=True)
|
||||
thread.start()
|
||||
|
||||
@@ -300,7 +315,12 @@ async def create_async_chunk_generator(
|
||||
"Async queue not initialized. Use create_streaming_state(use_async=True)."
|
||||
)
|
||||
|
||||
task = asyncio.create_task(run_coro())
|
||||
if state.stream_id is not None:
|
||||
token = _current_stream_ids.set((*_current_stream_ids.get(), state.stream_id))
|
||||
task = asyncio.create_task(run_coro())
|
||||
_current_stream_ids.reset(token)
|
||||
else:
|
||||
task = asyncio.create_task(run_coro())
|
||||
|
||||
try:
|
||||
while True:
|
||||
|
||||
@@ -879,3 +879,91 @@ class TestStreamingImports:
|
||||
assert StreamChunk is not None
|
||||
assert StreamChunkType is not None
|
||||
assert ToolCallChunk is not None
|
||||
|
||||
|
||||
class TestConcurrentStreamIsolation:
|
||||
"""Regression tests for concurrent streaming isolation (issue #5376)."""
|
||||
|
||||
def test_concurrent_streams_do_not_cross_contaminate(self) -> None:
|
||||
"""Two concurrent streaming runs must each receive only their own chunks.
|
||||
|
||||
Mirrors the real production path: create_streaming_state in the caller,
|
||||
then temporarily push the stream_id into the ContextVar, copy_context,
|
||||
and reset — exactly as create_chunk_generator does.
|
||||
"""
|
||||
import contextvars
|
||||
import threading
|
||||
|
||||
from crewai.utilities.streaming import (
|
||||
TaskInfo,
|
||||
_current_stream_ids,
|
||||
_unregister_handler,
|
||||
create_streaming_state,
|
||||
)
|
||||
|
||||
task_info_a: TaskInfo = {
|
||||
"index": 0,
|
||||
"name": "task_a",
|
||||
"id": "a",
|
||||
"agent_role": "A",
|
||||
"agent_id": "a",
|
||||
}
|
||||
task_info_b: TaskInfo = {
|
||||
"index": 1,
|
||||
"name": "task_b",
|
||||
"id": "b",
|
||||
"agent_role": "B",
|
||||
"agent_id": "b",
|
||||
}
|
||||
|
||||
state_a = create_streaming_state(task_info_a, [])
|
||||
state_b = create_streaming_state(task_info_b, [])
|
||||
|
||||
def make_emitter_ctx(state: Any) -> contextvars.Context:
|
||||
token = _current_stream_ids.set(
|
||||
(*_current_stream_ids.get(), state.stream_id)
|
||||
)
|
||||
ctx = contextvars.copy_context()
|
||||
_current_stream_ids.reset(token)
|
||||
return ctx
|
||||
|
||||
ctx_a = make_emitter_ctx(state_a)
|
||||
ctx_b = make_emitter_ctx(state_b)
|
||||
|
||||
def emit_chunks(prefix: str, call_id: str) -> None:
|
||||
for text in [f"{prefix}1", f"{prefix}2", f"{prefix}3"]:
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
event=LLMStreamChunkEvent(
|
||||
chunk=text, call_id=call_id, response_id="r"
|
||||
),
|
||||
)
|
||||
|
||||
t_a = threading.Thread(target=ctx_a.run, args=(lambda: emit_chunks("A", "ca"),))
|
||||
t_b = threading.Thread(target=ctx_b.run, args=(lambda: emit_chunks("B", "cb"),))
|
||||
t_a.start()
|
||||
t_b.start()
|
||||
t_a.join()
|
||||
t_b.join()
|
||||
|
||||
chunks_a: list[str] = []
|
||||
while not state_a.sync_queue.empty():
|
||||
item = state_a.sync_queue.get_nowait()
|
||||
if isinstance(item, StreamChunk):
|
||||
chunks_a.append(item.content)
|
||||
|
||||
chunks_b: list[str] = []
|
||||
while not state_b.sync_queue.empty():
|
||||
item = state_b.sync_queue.get_nowait()
|
||||
if isinstance(item, StreamChunk):
|
||||
chunks_b.append(item.content)
|
||||
|
||||
assert set(chunks_a) == {"A1", "A2", "A3"}, (
|
||||
f"Stream A received unexpected chunks: {chunks_a}"
|
||||
)
|
||||
assert set(chunks_b) == {"B1", "B2", "B3"}, (
|
||||
f"Stream B received unexpected chunks: {chunks_b}"
|
||||
)
|
||||
|
||||
_unregister_handler(state_a.handler)
|
||||
_unregister_handler(state_b.handler)
|
||||
|
||||
Reference in New Issue
Block a user