Compare commits

...

3 Commits

Author SHA1 Message Date
lorenzejay
4fd8327f25 ruff following for getattr alt 2026-06-26 14:18:28 -07:00
lorenzejay
6827131cd7 Merge branch 'main' of github.com:crewAIInc/crewAI into lorenze/imp/streaming 2026-06-26 14:12:11 -07:00
lorenzejay
1cb1fa8264 feat: enhance streaming support in conversational flow
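- Introduced the `_is_streaming_output` function to determine if a result is a streaming output.
- Added the `_consume_conversation_streaming_result` method to handle streaming results before accessing them.
- Updated the TUI conversation-turn method (the one that calls `handle_turn`) to utilize the new streaming result handling.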
- Implemented context management for LLM streaming in the conversational mixin.
- Added tests to verify streaming behavior and ensure proper handling of user messages during streaming.
2026-06-26 14:11:59 -07:00
5 changed files with 195 additions and 17 deletions

View File

@@ -4,6 +4,7 @@ Two-column layout: left sidebar (tasks/agents/tokens) + main content
(task header, plan checklist, activity timeline, streaming output).
"""
from collections.abc import Iterable
import json as _json
import re
import threading
@@ -46,6 +47,19 @@ def _is_save_to_memory_tool(tool_name: str | None) -> bool:
return (tool_name or "").replace(" ", "_").lower() == "save_to_memory"
def _is_streaming_output(value: Any) -> bool:
    """Return True when *value* looks like a streaming output object.

    A value qualifies when it is iterable and its class exposes both
    ``get_full_text`` and ``result``.

    The two attributes are looked up on the type, not on the instance, so the
    check itself never triggers instance-level attribute access (for example
    a ``result`` property).
    """
    if not isinstance(value, Iterable):
        return False

    value_type = type(value)
    try:
        # Bare attribute access on purpose: only a raised AttributeError matters.
        value_type.get_full_text  # noqa: B018
        value_type.result  # noqa: B018
    except AttributeError:
        return False
    return True
def _truncate_log_text(value: Any, limit: int) -> str | None:
if value is None:
return None
@@ -836,14 +850,18 @@ FooterKey .footer-key--key {
set_suppress_tracing_messages(True)
try:
result = self._flow.handle_turn(message)
if hasattr(result, "get_full_text") and hasattr(result, "result"):
for _chunk in result:
pass
result = result.result
result = self._consume_conversation_streaming_result(result)
self.call_from_thread(self._on_conversation_turn_done, result)
except Exception as e:
self.call_from_thread(self._on_conversation_turn_failed, str(e))
def _consume_conversation_streaming_result(self, result: Any) -> Any:
    """Drain a streaming turn result and return its final value.

    Args:
        result: Whatever the conversational flow returned for the turn.

    Returns:
        ``result`` unchanged when it is not a streaming output; otherwise
        ``result.result`` read after the stream has been fully iterated.
    """
    if not _is_streaming_output(result):
        return result

    # Chunks are discarded here; iterating to the end is what matters before
    # the final value is read.
    for _chunk in result:
        pass
    return result.result
def _on_conversation_turn_done(self, result: Any) -> None:
with self._lock:
output = self._stringify_output(result)

View File

@@ -26,6 +26,7 @@ from crewai.events.types.tool_usage_events import (
ToolUsageFinishedEvent,
ToolUsageStartedEvent,
)
from crewai.types.streaming import FlowStreamingOutput, StreamChunk
from crewai_cli.command import AuthenticationRequiredError
from crewai_cli import run_crew
from crewai_cli.crew_run_tui import (
@@ -177,6 +178,29 @@ def test_conversation_turn_done_records_assistant_message() -> None:
assert isinstance(app._crew_result, RawResult)
def test_conversation_streaming_result_is_consumed_before_result_access() -> None:
    """The app must iterate a streaming result to the end before reading it."""
    streaming = FlowStreamingOutput()

    def chunks():
        yield StreamChunk(content="hello ")
        yield StreamChunk(content="world")
        # The final result only becomes available once every chunk was yielded.
        streaming._set_result("hello world")

    streaming._sync_iterator = chunks()

    # Reading the result before the stream has been consumed must fail.
    with pytest.raises(RuntimeError):
        streaming.result  # noqa: B018

    app = CrewRunApp(conversational=True)

    assert app._consume_conversation_streaming_result(streaming) == "hello world"
    assert streaming.get_full_text() == "hello world"
@pytest.mark.asyncio
async def test_conversation_input_submits_turn() -> None:
class FakeFlow:

View File

@@ -19,6 +19,7 @@ Import surface:
from __future__ import annotations
from collections.abc import Callable, Mapping, Sequence
from contextlib import contextmanager
from enum import Enum
import json
import logging
@@ -62,6 +63,21 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
@contextmanager
def _streaming_conversation_llm(llm: BaseLLM, *, enabled: bool) -> Any:
    """Temporarily enable LLM streaming for Flow streaming turns.

    Args:
        llm: LLM instance whose ``stream`` attribute is toggled in place.
        enabled: When False the context manager leaves ``llm`` untouched.

    On exit the previous ``llm.stream`` value is put back, including when the
    managed body raises.
    """
    if not enabled:
        yield
        return

    previous_stream = llm.stream
    llm.stream = True
    try:
        yield
    finally:
        llm.stream = previous_stream
def _iter_condition_labels(condition: Any) -> set[str]:
if isinstance(condition, str):
return {condition}
@@ -146,6 +162,9 @@ class _ConversationalMixin:
def _copy_and_serialize_state(self) -> dict[str, Any]:
    # NOTE(review): body is a bare placeholder; presumably the real
    # implementation comes from the host Flow class — confirm.
    pass
def _should_stream_llm_calls(self) -> bool:
    # NOTE(review): body is a bare placeholder; presumably the real
    # implementation comes from the host Flow class — confirm.
    pass
def kickoff(self, *args: Any, **kwargs: Any) -> Any:
    # NOTE(review): body is a bare placeholder; presumably the real
    # implementation comes from the host Flow class — confirm.
    pass
@@ -221,7 +240,12 @@ class _ConversationalMixin:
messages.append({"role": "system", "content": system_prompt})
messages.extend(self.conversation_messages)
response = self._coerce_llm(llm).call(messages=messages)
llm_instance = self._coerce_llm(llm)
with _streaming_conversation_llm(
llm_instance,
enabled=self._should_stream_llm_calls(),
):
response = llm_instance.call(messages=messages)
content = self._stringify_result(response)
self.append_assistant_message(content)
return content
@@ -703,6 +727,27 @@ class _ConversationalMixin:
def _apply_pending_kickoff_context(self) -> None:
    """Apply the pending conversational turn as the kickoff context."""
    self._apply_pending_conversational_turn()
def _capture_pending_kickoff_context(self) -> dict[str, Any] | None:
    """Snapshot the pending turn inputs so they can be restored later.

    Returns:
        ``None`` when ``_should_apply_pending_kickoff_context()`` is falsy;
        otherwise a dict with the keys ``user_message``, ``intents`` and
        ``intent_llm`` holding the current pending values.
    """
    if not self._should_apply_pending_kickoff_context():
        return None

    return {
        "user_message": self._pending_user_message,
        "intents": self._pending_intents,
        "intent_llm": self._pending_intent_llm,
    }
def _restore_pending_kickoff_context(self, context: Any) -> None:
    """Re-install pending turn inputs from a previously captured snapshot.

    Args:
        context: Dict with the keys ``user_message``, ``intents`` and
            ``intent_llm``. Any non-dict value is ignored.

    Raises:
        KeyError: If *context* is a dict that lacks one of those keys.
    """
    if not isinstance(context, dict):
        return

    self._pending_user_message = context["user_message"]
    self._pending_intents = context["intents"]
    self._pending_intent_llm = context["intent_llm"]
def _clear_pending_kickoff_context(self) -> None:
    """Reset the pending user message, intents and intent LLM to ``None``."""
    self._pending_user_message = None
    self._pending_intents = None
    self._pending_intent_llm = None
def _order_start_methods_for_kickoff(
self,
start_methods: list[Any],

View File

@@ -15,6 +15,7 @@ from collections.abc import (
Sequence,
)
from concurrent.futures import Future, ThreadPoolExecutor
from contextlib import contextmanager
import contextvars
import copy
from datetime import datetime
@@ -460,6 +461,16 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
def _apply_pending_kickoff_context(self) -> None:
    """Apply optional runtime-extension kickoff context.

    The base implementation does nothing.
    """
def _capture_pending_kickoff_context(self) -> Any | None:
    """Capture optional pending kickoff context for deferred execution.

    Returns:
        ``None`` in the base implementation, meaning there is nothing to
        restore later.
    """
    return None
def _restore_pending_kickoff_context(self, context: Any) -> None:
    """Restore optional pending kickoff context in deferred execution.

    Args:
        context: Value previously returned by
            ``_capture_pending_kickoff_context``. The base implementation
            ignores it.
    """
def _clear_pending_kickoff_context(self) -> None:
    """Clear optional pending kickoff context after deferred execution.

    The base implementation does nothing.
    """
def _order_start_methods_for_kickoff(
self,
start_methods: list[FlowMethodName],
@@ -471,6 +482,19 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
"""Whether this kickoff should defer final flow trace finalization."""
return bool(getattr(self, "defer_trace_finalization", False))
def _should_stream_llm_calls(self) -> bool:
    """Whether LLM calls inside the current flow run should stream chunks.

    True when ``stream`` is set on the flow, or while a ``_streaming_run()``
    context is active.
    """
    return self.stream or self._streaming_run_active
@contextmanager
def _streaming_run(self) -> Iterator[None]:
    """Mark the flow as being inside a streaming run for the managed block.

    Sets ``_streaming_run_active`` to True on entry and puts the previous
    value back on exit, including when the managed body raises. Because the
    previous value is restored rather than reset to False, nested uses leave
    the flag set until the outermost context exits.
    """
    previous_streaming_run = self._streaming_run_active
    self._streaming_run_active = True
    try:
        yield
    finally:
        self._streaming_run_active = previous_streaming_run
@classmethod
def flow_definition(cls) -> FlowDefinition:
"""Return the static Flow Definition built from this Flow class."""
@@ -735,6 +759,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
_usage_aggregation_handler: Callable[..., Any] | None = PrivateAttr(default=None)
_persist_backends: dict[int, FlowPersistence] = PrivateAttr(default_factory=dict)
_instance_persistence: bool = PrivateAttr(default=False)
# Set to True for the duration of `_streaming_run()` and read by
# `_should_stream_llm_calls()` alongside `stream`.
_streaming_run_active: bool = PrivateAttr(default=False)
def __class_getitem__(cls: type[Flow[T]], item: type[T]) -> type[Flow[T]]: # type: ignore[override]
class _FlowGeneric(cls): # type: ignore[valid-type,misc]
@@ -1872,6 +1897,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
return restored.kickoff(inputs=inputs, input_files=input_files)
if self.stream:
result_holder: list[Any] = []
pending_kickoff_context = self._capture_pending_kickoff_context()
current_task_info: TaskInfo = {
"index": 0,
"name": "",
@@ -1887,12 +1913,15 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
def run_flow() -> None:
try:
self.stream = False
result = self.kickoff(
inputs=inputs,
input_files=input_files,
restore_from_state_id=restore_from_state_id,
)
if pending_kickoff_context is not None:
self._restore_pending_kickoff_context(pending_kickoff_context)
with self._streaming_run():
self.stream = False
result = self.kickoff(
inputs=inputs,
input_files=input_files,
restore_from_state_id=restore_from_state_id,
)
result_holder.append(result)
except Exception as e:
# HumanFeedbackPending is expected control flow, not an error
@@ -1901,6 +1930,8 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
else:
signal_error(state, e)
finally:
if pending_kickoff_context is not None:
self._clear_pending_kickoff_context()
self.stream = True
signal_end(state)
@@ -1972,6 +2003,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
return await restored.kickoff_async(inputs=inputs, input_files=input_files)
if self.stream:
result_holder: list[Any] = []
pending_kickoff_context = self._capture_pending_kickoff_context()
current_task_info: TaskInfo = {
"index": 0,
"name": "",
@@ -1987,12 +2019,15 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
async def run_flow() -> None:
try:
self.stream = False
result = await self.kickoff_async(
inputs=inputs,
input_files=input_files,
restore_from_state_id=restore_from_state_id,
)
if pending_kickoff_context is not None:
self._restore_pending_kickoff_context(pending_kickoff_context)
with self._streaming_run():
self.stream = False
result = await self.kickoff_async(
inputs=inputs,
input_files=input_files,
restore_from_state_id=restore_from_state_id,
)
result_holder.append(result)
except Exception as e:
# HumanFeedbackPending is expected control flow, not an error
@@ -2001,6 +2036,8 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
else:
signal_error(state, e, is_async=True)
finally:
if pending_kickoff_context is not None:
self._clear_pending_kickoff_context()
self.stream = True
signal_end(state, is_async=True)

View File

@@ -204,6 +204,60 @@ class TestConversationalFlow:
assert flow.state.events[0].agent_name == "researcher"
assert flow.state.events[0].visibility == "public"
def test_builtin_converse_enables_llm_streaming_for_streaming_flow(self) -> None:
    """``converse_turn`` streams from the LLM only while a streaming run is active."""
    llm = MagicMock()
    llm.stream = False
    stream_values_seen: list[bool | None] = []

    def call(*args: Any, **kwargs: Any) -> str:
        # Record the flag as the LLM sees it at call time.
        stream_values_seen.append(llm.stream)
        return "streamed reply"

    llm.call.side_effect = call

    @ConversationConfig(llm=llm)
    class StreamingFlow(ConversationalFlow):
        pass

    flow = StreamingFlow()
    flow.stream = False

    with flow._streaming_run():
        result = flow.converse_turn()

    assert result == "streamed reply"
    assert stream_values_seen == [True]
    # Both the LLM flag and the flow flag are back to their prior values.
    assert llm.stream is False
    assert flow._should_stream_llm_calls() is False
    assert flow.state.messages[-1].content == "streamed reply"
def test_streaming_handle_turn_preserves_pending_user_message(self) -> None:
    """A streaming ``handle_turn`` still delivers the user message to the flow."""

    @ConversationConfig(llm="unused")
    class StreamingEchoFlow(ConversationalFlow):
        stream = True

        def route_turn(self, context: dict[str, Any]) -> str:
            return "echo"

        @listen("echo")
        def handle_echo(self) -> str:
            reply = f"heard: {self.state.current_user_message}"
            self.append_assistant_message(reply)
            return reply

    flow = StreamingEchoFlow()

    result = flow.handle_turn("hello streaming")
    # Drain the stream before reading the final result.
    for _chunk in result:
        pass

    assert result.result == "heard: hello streaming"
    assert [message.role for message in flow.state.messages] == [
        "user",
        "assistant",
    ]
    assert flow.state.messages[0].content == "hello streaming"
    assert flow.state.messages[1].content == "heard: hello streaming"
@conversational_graph_broken
def test_private_agent_results_stay_out_of_shared_history(self) -> None:
class PrivateFlow(ConversationalFlow):