diff --git a/lib/crewai/src/crewai/types/streaming.py b/lib/crewai/src/crewai/types/streaming.py index aca458565..548a7e6ac 100644 --- a/lib/crewai/src/crewai/types/streaming.py +++ b/lib/crewai/src/crewai/types/streaming.py @@ -160,9 +160,14 @@ class StreamSession(StreamSessionBase[T]): self, channels: Sequence[StreamChannel] | None = None ) -> Iterator[StreamFrame]: """Iterate over frames, optionally filtered by channel.""" + selected = set(channels) if channels is not None else None + if self._exhausted: + for frame in self._frames: + if selected is None or frame.channel in selected: + yield frame + return if self._sync_iterator is None: raise RuntimeError("Sync iterator not available") - selected = set(channels) if channels is not None else None try: for frame in self._sync_iterator: self._frames.append(frame) @@ -237,9 +242,14 @@ class AsyncStreamSession(StreamSessionBase[T]): self, channels: Sequence[StreamChannel] | None = None ) -> AsyncIterator[StreamFrame]: """Iterate over frames, optionally filtered by channel.""" + selected = set(channels) if channels is not None else None + if self._exhausted: + for frame in self._frames: + if selected is None or frame.channel in selected: + yield frame + return if self._async_iterator is None: raise RuntimeError("Async iterator not available") - selected = set(channels) if channels is not None else None try: async for frame in self._async_iterator: self._frames.append(frame) diff --git a/lib/crewai/tests/test_stream_frames.py b/lib/crewai/tests/test_stream_frames.py index d844ae78f..1515b9453 100644 --- a/lib/crewai/tests/test_stream_frames.py +++ b/lib/crewai/tests/test_stream_frames.py @@ -118,6 +118,34 @@ def test_stream_subscribe_filters_channels_without_losing_order() -> None: assert stream.result == "done" +def test_stream_projections_replay_cached_frames_after_exhaustion() -> None: + with FrameFlow().stream_events() as stream: + all_frames = list(stream.events) + + assert [frame.content for frame in stream.llm if frame.content] == [ + "hello", + "thinking", + ] + assert [frame.type for frame in stream.tools] == ["tool_usage_started"] + assert list(stream.events) == all_frames + + +def test_stream_channel_projection_can_be_followed_by_cached_projection() -> None: + with FrameFlow().stream_events() as stream: + llm_frames = list(stream.llm) + + assert [frame.content for frame in llm_frames if frame.content] == [ + "hello", + "thinking", + ] + assert [frame.type for frame in stream.flow] == [ + "flow_started", + "method_execution_started", + "method_execution_finished", + "flow_finished", + ] + + def test_stream_errors_surface_after_failed_frame() -> None: class ErrorFlow(Flow): @start() @@ -197,6 +225,42 @@ async def test_astream_scopes_concurrent_executions() -> None: assert second == ("second", ["second"]) +@pytest.mark.asyncio +async def test_async_stream_projections_replay_cached_frames_after_exhaustion() -> None: + async with FrameFlow().astream() as stream: + all_frames = [frame async for frame in stream.events] + + llm_frames = [frame async for frame in stream.llm] + tool_frames = [frame async for frame in stream.tools] + replayed_frames = [frame async for frame in stream.events] + + assert [frame.content for frame in llm_frames if frame.content] == [ + "hello", + "thinking", + ] + assert [frame.type for frame in tool_frames] == ["tool_usage_started"] + assert replayed_frames == all_frames + + +@pytest.mark.asyncio +async def test_async_stream_channel_projection_can_be_followed_by_cached_projection() -> None: + async with FrameFlow().astream() as stream: + llm_frames = [frame async for frame in stream.llm] + + flow_frames = [frame async for frame in stream.flow] + + assert [frame.content for frame in llm_frames if frame.content] == [ + "hello", + "thinking", + ] + assert [frame.type for frame in flow_frames] == [ + "flow_started", + "method_execution_started", + "method_execution_finished", + "flow_finished", + ] + + @pytest.mark.asyncio async def test_astream_cancellation_cleans_up_task() -> None: class SlowFlow(Flow):