mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-07-01 13:18:10 +00:00
* Define stream frame protocol for flows * Add direct LLM streaming helpers * Unify flow streaming frame items * Update flow streaming integration properties * Drop stream frame debug runner example * Address streaming contract review feedback * Replay cached stream frame projections * Remove stream frame version field * Fix streaming contract docs link * Preserve LLM instance state for stream events * Address streaming review cleanup
301 lines
9.7 KiB
Python
301 lines
9.7 KiB
Python
"""Tests for the public stream frame contract."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
from collections.abc import AsyncIterator
|
|
from typing import Any, ClassVar
|
|
|
|
import pytest
|
|
|
|
from crewai.events.event_bus import crewai_event_bus
|
|
from crewai.events.types.flow_events import ConversationMessageAddedEvent
|
|
from crewai.events.types.llm_events import LLMStreamChunkEvent, LLMThinkingChunkEvent
|
|
from crewai.events.types.tool_usage_events import ToolUsageStartedEvent
|
|
from crewai.flow.flow import Flow, start
|
|
from crewai.llms.base_llm import BaseLLM, call_stop_override
|
|
from crewai.types.streaming import StreamFrame
|
|
|
|
|
|
class FrameFlow(Flow):
|
|
@start()
|
|
def run(self) -> str:
|
|
crewai_event_bus.emit(
|
|
self,
|
|
LLMStreamChunkEvent(
|
|
type="llm_stream_chunk",
|
|
chunk="hello",
|
|
call_id="call-1",
|
|
),
|
|
)
|
|
crewai_event_bus.emit(
|
|
self,
|
|
LLMThinkingChunkEvent(
|
|
type="llm_thinking_chunk",
|
|
chunk="thinking",
|
|
call_id="call-1",
|
|
),
|
|
)
|
|
crewai_event_bus.emit(
|
|
self,
|
|
ConversationMessageAddedEvent(
|
|
type="conversation_message_added",
|
|
flow_name=self._definition.name,
|
|
session_id="session-1",
|
|
role="assistant",
|
|
content="hello",
|
|
message_index=0,
|
|
),
|
|
)
|
|
crewai_event_bus.emit(
|
|
self,
|
|
ToolUsageStartedEvent(
|
|
type="tool_usage_started",
|
|
tool_name="search",
|
|
tool_args={"query": "crew"},
|
|
),
|
|
)
|
|
return "done"
|
|
|
|
|
|
class DirectStreamingLLM(BaseLLM):
|
|
call_stream_values: ClassVar[list[bool | None]] = []
|
|
raw_stream_values: ClassVar[list[bool | None]] = []
|
|
call_instance_ids: ClassVar[list[int]] = []
|
|
call_stop_values: ClassVar[list[list[str]]] = []
|
|
|
|
def call(self, messages: Any, *args: Any, **kwargs: Any) -> str:
|
|
self.call_stream_values.append(self._effective_stream())
|
|
self.raw_stream_values.append(self.stream)
|
|
self.call_instance_ids.append(id(self))
|
|
self.call_stop_values.append(list(self.stop_sequences))
|
|
self._track_token_usage_internal(
|
|
{
|
|
"prompt_tokens": 1,
|
|
"completion_tokens": 2,
|
|
"total_tokens": 3,
|
|
}
|
|
)
|
|
crewai_event_bus.emit(
|
|
self,
|
|
LLMStreamChunkEvent(
|
|
type="llm_stream_chunk",
|
|
chunk="hel",
|
|
call_id="call-1",
|
|
),
|
|
)
|
|
crewai_event_bus.emit(
|
|
self,
|
|
LLMStreamChunkEvent(
|
|
type="llm_stream_chunk",
|
|
chunk="lo",
|
|
call_id="call-1",
|
|
),
|
|
)
|
|
return "hello"
|
|
|
|
|
|
def test_stream_frame_contract_and_ordering() -> None:
|
|
stream = FrameFlow().stream_events()
|
|
|
|
with pytest.raises(RuntimeError, match="Streaming has not completed yet"):
|
|
_ = stream.result
|
|
|
|
with stream:
|
|
frames = list(stream.events)
|
|
|
|
assert stream.result == "done"
|
|
assert all(isinstance(frame, StreamFrame) for frame in frames)
|
|
assert [frame.seq for frame in frames] == sorted(frame.seq for frame in frames)
|
|
|
|
by_type = {frame.type: frame for frame in frames}
|
|
assert by_type["flow_started"].channel == "flow"
|
|
assert by_type["method_execution_started"].parent_id == by_type["flow_started"].id
|
|
assert by_type["llm_stream_chunk"].channel == "llm"
|
|
assert by_type["llm_thinking_chunk"].channel == "llm"
|
|
assert by_type["conversation_message_added"].channel == "messages"
|
|
assert by_type["tool_usage_started"].channel == "tools"
|
|
assert "FrameFlow" in by_type["method_execution_started"].namespace
|
|
assert "run" in by_type["method_execution_started"].namespace
|
|
|
|
|
|
def test_stream_subscribe_filters_channels_without_losing_order() -> None:
|
|
with FrameFlow().stream_events() as stream:
|
|
frames = list(stream.interleave(["messages", "tools"]))
|
|
|
|
assert [frame.channel for frame in frames] == ["messages", "tools"]
|
|
assert [frame.seq for frame in frames] == sorted(frame.seq for frame in frames)
|
|
assert stream.result == "done"
|
|
|
|
|
|
def test_stream_projections_replay_cached_frames_after_exhaustion() -> None:
|
|
with FrameFlow().stream_events() as stream:
|
|
all_frames = list(stream.events)
|
|
|
|
assert [frame.content for frame in stream.llm if frame.content] == [
|
|
"hello",
|
|
"thinking",
|
|
]
|
|
assert [frame.type for frame in stream.tools] == ["tool_usage_started"]
|
|
assert list(stream.events) == all_frames
|
|
|
|
|
|
def test_stream_channel_projection_can_be_followed_by_cached_projection() -> None:
|
|
with FrameFlow().stream_events() as stream:
|
|
llm_frames = list(stream.llm)
|
|
|
|
assert [frame.content for frame in llm_frames if frame.content] == [
|
|
"hello",
|
|
"thinking",
|
|
]
|
|
assert [frame.type for frame in stream.flow] == [
|
|
"flow_started",
|
|
"method_execution_started",
|
|
"method_execution_finished",
|
|
"flow_finished",
|
|
]
|
|
|
|
|
|
def test_stream_errors_surface_after_failed_frame() -> None:
|
|
class ErrorFlow(Flow):
|
|
@start()
|
|
def run(self) -> str:
|
|
raise ValueError("boom")
|
|
|
|
stream = ErrorFlow().stream_events()
|
|
|
|
with pytest.raises(ValueError, match="boom"):
|
|
list(stream.events)
|
|
|
|
assert any(frame.type == "method_execution_failed" for frame in stream.frames)
|
|
with pytest.raises(ValueError, match="boom"):
|
|
_ = stream.result
|
|
|
|
|
|
def test_flow_streaming_returns_iterable_frame_session() -> None:
|
|
flow = FrameFlow()
|
|
flow.stream = True
|
|
|
|
stream = flow.kickoff()
|
|
|
|
with stream:
|
|
frames = list(stream)
|
|
|
|
assert all(isinstance(frame, StreamFrame) for frame in frames)
|
|
assert [frame.content for frame in frames if frame.content] == [
|
|
"hello",
|
|
"thinking",
|
|
]
|
|
first_content_frame = next(frame for frame in frames if frame.content)
|
|
assert first_content_frame.event["chunk"] == "hello"
|
|
assert stream.result == "done"
|
|
|
|
|
|
def test_direct_llm_stream_events_scope_and_restore_stream_flag() -> None:
|
|
DirectStreamingLLM.call_stream_values = []
|
|
DirectStreamingLLM.raw_stream_values = []
|
|
DirectStreamingLLM.call_instance_ids = []
|
|
DirectStreamingLLM.call_stop_values = []
|
|
llm = DirectStreamingLLM(model="gpt-4o-mini", stream=False)
|
|
|
|
with call_stop_override(llm, ["STOP"]):
|
|
with llm.stream_events("hello") as stream:
|
|
frames = list(stream)
|
|
|
|
assert [frame.content for frame in frames] == ["hel", "lo"]
|
|
assert frames[0].event["chunk"] == "hel"
|
|
assert stream.result == "hello"
|
|
assert llm.stream is False
|
|
assert DirectStreamingLLM.call_stream_values == [True]
|
|
assert DirectStreamingLLM.raw_stream_values == [False]
|
|
assert DirectStreamingLLM.call_instance_ids == [id(llm)]
|
|
assert DirectStreamingLLM.call_stop_values == [["STOP"]]
|
|
usage = llm.get_token_usage_summary()
|
|
assert usage.total_tokens == 3
|
|
assert usage.prompt_tokens == 1
|
|
assert usage.completion_tokens == 2
|
|
assert usage.successful_requests == 1
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_astream_scopes_concurrent_executions() -> None:
|
|
class ConcurrentFlow(Flow):
|
|
@start()
|
|
async def run(self) -> str:
|
|
label = str(self.state["label"])
|
|
await asyncio.sleep(0)
|
|
crewai_event_bus.emit(
|
|
self,
|
|
LLMStreamChunkEvent(
|
|
type="llm_stream_chunk",
|
|
chunk=label,
|
|
call_id=label,
|
|
),
|
|
)
|
|
return label
|
|
|
|
async def collect(label: str) -> tuple[str, list[str]]:
|
|
async with ConcurrentFlow().astream(inputs={"label": label}) as stream:
|
|
frames = [frame async for frame in stream.llm]
|
|
return stream.result, [frame.data["chunk"] for frame in frames]
|
|
|
|
first, second = await asyncio.gather(collect("first"), collect("second"))
|
|
|
|
assert first == ("first", ["first"])
|
|
assert second == ("second", ["second"])
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_async_stream_projections_replay_cached_frames_after_exhaustion() -> None:
|
|
async with FrameFlow().astream() as stream:
|
|
all_frames = [frame async for frame in stream.events]
|
|
|
|
llm_frames = [frame async for frame in stream.llm]
|
|
tool_frames = [frame async for frame in stream.tools]
|
|
replayed_frames = [frame async for frame in stream.events]
|
|
|
|
assert [frame.content for frame in llm_frames if frame.content] == [
|
|
"hello",
|
|
"thinking",
|
|
]
|
|
assert [frame.type for frame in tool_frames] == ["tool_usage_started"]
|
|
assert replayed_frames == all_frames
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_async_stream_channel_projection_can_be_followed_by_cached_projection() -> None:
|
|
async with FrameFlow().astream() as stream:
|
|
llm_frames = [frame async for frame in stream.llm]
|
|
|
|
flow_frames = [frame async for frame in stream.flow]
|
|
|
|
assert [frame.content for frame in llm_frames if frame.content] == [
|
|
"hello",
|
|
"thinking",
|
|
]
|
|
assert [frame.type for frame in flow_frames] == [
|
|
"flow_started",
|
|
"method_execution_started",
|
|
"method_execution_finished",
|
|
"flow_finished",
|
|
]
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_astream_cancellation_cleans_up_task() -> None:
|
|
class SlowFlow(Flow):
|
|
@start()
|
|
async def run(self) -> str:
|
|
await asyncio.sleep(10)
|
|
return "too late"
|
|
|
|
stream = SlowFlow().astream()
|
|
events: AsyncIterator[StreamFrame] = stream.events
|
|
first_frame = await anext(events)
|
|
|
|
assert first_frame.type == "flow_started"
|
|
await stream.aclose()
|
|
|
|
assert stream.is_cancelled is True
|
|
assert stream.is_completed is True
|