Replay cached stream frame projections

This commit is contained in:
lorenzejay
2026-06-29 15:54:37 -07:00
parent 7de7e32bb2
commit 926057635d
2 changed files with 76 additions and 2 deletions

View File

@@ -160,9 +160,14 @@ class StreamSession(StreamSessionBase[T]):
self, channels: Sequence[StreamChannel] | None = None
) -> Iterator[StreamFrame]:
"""Iterate over frames, optionally filtered by channel."""
selected = set(channels) if channels is not None else None
if self._exhausted:
for frame in self._frames:
if selected is None or frame.channel in selected:
yield frame
return
if self._sync_iterator is None:
raise RuntimeError("Sync iterator not available")
selected = set(channels) if channels is not None else None
try:
for frame in self._sync_iterator:
self._frames.append(frame)
@@ -237,9 +242,14 @@ class AsyncStreamSession(StreamSessionBase[T]):
self, channels: Sequence[StreamChannel] | None = None
) -> AsyncIterator[StreamFrame]:
"""Iterate over frames, optionally filtered by channel."""
selected = set(channels) if channels is not None else None
if self._exhausted:
for frame in self._frames:
if selected is None or frame.channel in selected:
yield frame
return
if self._async_iterator is None:
raise RuntimeError("Async iterator not available")
selected = set(channels) if channels is not None else None
try:
async for frame in self._async_iterator:
self._frames.append(frame)

View File

@@ -118,6 +118,34 @@ def test_stream_subscribe_filters_channels_without_losing_order() -> None:
assert stream.result == "done"
def test_stream_projections_replay_cached_frames_after_exhaustion() -> None:
with FrameFlow().stream_events() as stream:
all_frames = list(stream.events)
assert [frame.content for frame in stream.llm if frame.content] == [
"hello",
"thinking",
]
assert [frame.type for frame in stream.tools] == ["tool_usage_started"]
assert list(stream.events) == all_frames
def test_stream_channel_projection_can_be_followed_by_cached_projection() -> None:
with FrameFlow().stream_events() as stream:
llm_frames = list(stream.llm)
assert [frame.content for frame in llm_frames if frame.content] == [
"hello",
"thinking",
]
assert [frame.type for frame in stream.flow] == [
"flow_started",
"method_execution_started",
"method_execution_finished",
"flow_finished",
]
def test_stream_errors_surface_after_failed_frame() -> None:
class ErrorFlow(Flow):
@start()
@@ -197,6 +225,42 @@ async def test_astream_scopes_concurrent_executions() -> None:
assert second == ("second", ["second"])
@pytest.mark.asyncio
async def test_async_stream_projections_replay_cached_frames_after_exhaustion() -> None:
async with FrameFlow().astream() as stream:
all_frames = [frame async for frame in stream.events]
llm_frames = [frame async for frame in stream.llm]
tool_frames = [frame async for frame in stream.tools]
replayed_frames = [frame async for frame in stream.events]
assert [frame.content for frame in llm_frames if frame.content] == [
"hello",
"thinking",
]
assert [frame.type for frame in tool_frames] == ["tool_usage_started"]
assert replayed_frames == all_frames
@pytest.mark.asyncio
async def test_async_stream_channel_projection_can_be_followed_by_cached_projection() -> None:
async with FrameFlow().astream() as stream:
llm_frames = [frame async for frame in stream.llm]
flow_frames = [frame async for frame in stream.flow]
assert [frame.content for frame in llm_frames if frame.content] == [
"hello",
"thinking",
]
assert [frame.type for frame in flow_frames] == [
"flow_started",
"method_execution_started",
"method_execution_finished",
"flow_finished",
]
@pytest.mark.asyncio
async def test_astream_cancellation_cleans_up_task() -> None:
class SlowFlow(Flow):