diff --git a/lib/crewai/src/crewai/llm.py b/lib/crewai/src/crewai/llm.py index 153bbd2d7..5db0d300b 100644 --- a/lib/crewai/src/crewai/llm.py +++ b/lib/crewai/src/crewai/llm.py @@ -749,7 +749,7 @@ class LLM(BaseLLM): "base_url": self.base_url, "api_version": self.api_version, "api_key": self.api_key, - "stream": self.stream, + "stream": self._effective_stream(), "tools": tools, "reasoning_effort": self.reasoning_effort, **self.additional_params, @@ -1841,7 +1841,7 @@ class LLM(BaseLLM): self.set_callbacks(callbacks) try: params = self._prepare_completion_params(messages, tools) - if self.stream: + if self._effective_stream(): result = self._handle_streaming_response( params=params, callbacks=callbacks, @@ -1983,7 +1983,7 @@ class LLM(BaseLLM): messages, tools, skip_file_processing=True ) - if self.stream: + if self._effective_stream(): return await self._ahandle_streaming_response( params=params, callbacks=callbacks, diff --git a/lib/crewai/src/crewai/llms/base_llm.py b/lib/crewai/src/crewai/llms/base_llm.py index d805b04aa..d679bf670 100644 --- a/lib/crewai/src/crewai/llms/base_llm.py +++ b/lib/crewai/src/crewai/llms/base_llm.py @@ -82,6 +82,9 @@ _current_call_id: contextvars.ContextVar[str | None] = contextvars.ContextVar( _call_stop_override_var: contextvars.ContextVar[dict[int, list[str]] | None] = ( contextvars.ContextVar("_call_stop_override_var", default=None) ) +_call_stream_override_var: contextvars.ContextVar[dict[int, bool] | None] = ( + contextvars.ContextVar("_call_stream_override_var", default=None) +) @contextmanager @@ -120,6 +123,19 @@ def call_stop_override( _call_stop_override_var.reset(token) +@contextmanager +def call_stream_override(llm: BaseLLM, stream: bool) -> Generator[None, None, None]: + """Override streaming for ``llm`` within the current call scope.""" + current = _call_stream_override_var.get() + new_overrides: dict[int, bool] = dict(current) if current else {} + new_overrides[id(llm)] = stream + token = _call_stream_override_var.set(new_overrides) + try: + yield + finally: + _call_stream_override_var.reset(token) + + def get_current_call_id() -> str: """Get current call_id from context""" call_id = _current_call_id.get() @@ -218,6 +234,13 @@ class BaseLLM(BaseModel, ABC): return override return self.stop + def _effective_stream(self) -> bool | None: + """Return the call-scoped streaming mode for this instance.""" + overrides = _call_stream_override_var.get() + if overrides is not None and id(self) in overrides: + return overrides[id(self)] + return self.stream + _token_usage: dict[str, int] = PrivateAttr( default_factory=lambda: { "total_tokens": 0, @@ -339,16 +362,16 @@ class BaseLLM(BaseModel, ABC): output_holder: list[StreamSession[Any]] = [] def run_llm_call() -> Any: - streaming_llm = self.model_copy(update={"stream": True}) - return streaming_llm.call( - messages=messages, - tools=tools, - callbacks=callbacks, - available_functions=available_functions, - from_task=from_task, - from_agent=from_agent, - response_model=response_model, - ) + with call_stream_override(self, True): + return self.call( + messages=messages, + tools=tools, + callbacks=callbacks, + available_functions=available_functions, + from_task=from_task, + from_agent=from_agent, + response_model=response_model, + ) stream_session: StreamSession[Any] = StreamSession( sync_iterator=create_frame_generator(state, run_llm_call, output_holder) @@ -547,7 +570,7 @@ class BaseLLM(BaseModel, ABC): if max_tokens is None: max_tokens = self._effective_max_tokens() if stream is None: - stream = self.stream + stream = self._effective_stream() if seed is None: seed = self.seed if stop_sequences is None: diff --git a/lib/crewai/src/crewai/llms/providers/anthropic/completion.py b/lib/crewai/src/crewai/llms/providers/anthropic/completion.py index 599ec5a3b..b012aa61d 100644 --- a/lib/crewai/src/crewai/llms/providers/anthropic/completion.py +++ b/lib/crewai/src/crewai/llms/providers/anthropic/completion.py @@ -323,7 +323,7 @@ class AnthropicCompletion(BaseLLM): effective_response_model = response_model or self.response_format - if self.stream: + if self._effective_stream(): return self._handle_streaming_completion( completion_params, available_functions, @@ -393,7 +393,7 @@ class AnthropicCompletion(BaseLLM): effective_response_model = response_model or self.response_format - if self.stream: + if self._effective_stream(): return await self._ahandle_streaming_completion( completion_params, available_functions, @@ -441,7 +441,7 @@ class AnthropicCompletion(BaseLLM): "model": self.model, "messages": messages, "max_tokens": self.max_tokens, - "stream": self.stream, + "stream": self._effective_stream(), } if system_message: diff --git a/lib/crewai/src/crewai/llms/providers/azure/completion.py b/lib/crewai/src/crewai/llms/providers/azure/completion.py index 579ca5eba..4597fd623 100644 --- a/lib/crewai/src/crewai/llms/providers/azure/completion.py +++ b/lib/crewai/src/crewai/llms/providers/azure/completion.py @@ -42,7 +42,7 @@ try: ) from crewai.events.types.llm_events import LLMCallType - from crewai.llms.base_llm import BaseLLM, llm_call_context + from crewai.llms.base_llm import BaseLLM, call_stream_override, llm_call_context except ImportError: raise ImportError( @@ -493,15 +493,18 @@ class AzureCompletion(BaseLLM): Completion response or tool call result """ if self.api == "responses": - return self._responses_delegate.call( - messages=messages, - tools=tools, - callbacks=callbacks, - available_functions=available_functions, - from_task=from_task, - from_agent=from_agent, - response_model=response_model, - ) + with call_stream_override( + self._responses_delegate, bool(self._effective_stream()) + ): + return self._responses_delegate.call( + messages=messages, + tools=tools, + callbacks=callbacks, + available_functions=available_functions, + from_task=from_task, + from_agent=from_agent, + response_model=response_model, + ) with llm_call_context(): try: @@ -527,7 +530,7 @@ class AzureCompletion(BaseLLM): formatted_messages, tools, effective_response_model ) - if self.stream: + if self._effective_stream(): return self._handle_streaming_completion( completion_params, available_functions, @@ -572,15 +575,18 @@ class AzureCompletion(BaseLLM): Completion response or tool call result """ if self.api == "responses": - return await self._responses_delegate.acall( - messages=messages, - tools=tools, - callbacks=callbacks, - available_functions=available_functions, - from_task=from_task, - from_agent=from_agent, - response_model=response_model, - ) + with call_stream_override( + self._responses_delegate, bool(self._effective_stream()) + ): + return await self._responses_delegate.acall( + messages=messages, + tools=tools, + callbacks=callbacks, + available_functions=available_functions, + from_task=from_task, + from_agent=from_agent, + response_model=response_model, + ) with llm_call_context(): try: @@ -601,7 +607,7 @@ class AzureCompletion(BaseLLM): formatted_messages, tools, effective_response_model ) - if self.stream: + if self._effective_stream(): return await self._ahandle_streaming_completion( completion_params, available_functions, @@ -639,11 +645,11 @@ class AzureCompletion(BaseLLM): """ params: AzureCompletionParams = { "messages": messages, - "stream": self.stream, + "stream": bool(self._effective_stream()), } model_extras: dict[str, Any] = {} - if self.stream: + if self._effective_stream(): model_extras["stream_options"] = {"include_usage": True} if response_model and self.is_openai_model: diff --git a/lib/crewai/src/crewai/llms/providers/bedrock/completion.py b/lib/crewai/src/crewai/llms/providers/bedrock/completion.py index 0f34b6723..ef316e8f4 100644 --- a/lib/crewai/src/crewai/llms/providers/bedrock/completion.py +++ b/lib/crewai/src/crewai/llms/providers/bedrock/completion.py @@ -428,7 +428,7 @@ class BedrockCompletion(BaseLLM): self.additional_model_response_field_paths ) - if self.stream: + if self._effective_stream(): return self._handle_streaming_converse( formatted_messages, body, @@ -556,7 +556,7 @@ class BedrockCompletion(BaseLLM): self.additional_model_response_field_paths ) - if self.stream: + if self._effective_stream(): return await self._ahandle_streaming_converse( formatted_messages, body, diff --git a/lib/crewai/src/crewai/llms/providers/gemini/completion.py b/lib/crewai/src/crewai/llms/providers/gemini/completion.py index b811614a1..b099fe237 100644 --- a/lib/crewai/src/crewai/llms/providers/gemini/completion.py +++ b/lib/crewai/src/crewai/llms/providers/gemini/completion.py @@ -322,7 +322,7 @@ class GeminiCompletion(BaseLLM): system_instruction, tools, effective_response_model ) - if self.stream: + if self._effective_stream(): return self._handle_streaming_completion( formatted_content, config, @@ -401,7 +401,7 @@ class GeminiCompletion(BaseLLM): system_instruction, tools, effective_response_model ) - if self.stream: + if self._effective_stream(): return await self._ahandle_streaming_completion( formatted_content, config, diff --git a/lib/crewai/src/crewai/llms/providers/openai/completion.py b/lib/crewai/src/crewai/llms/providers/openai/completion.py index d8972e1de..77d2bbbdd 100644 --- a/lib/crewai/src/crewai/llms/providers/openai/completion.py +++ b/lib/crewai/src/crewai/llms/providers/openai/completion.py @@ -469,7 +469,7 @@ class OpenAICompletion(BaseLLM): messages=messages, tools=tools ) - if self.stream: + if self._effective_stream(): return self._handle_streaming_completion( params=completion_params, available_functions=available_functions, @@ -564,7 +564,7 @@ class OpenAICompletion(BaseLLM): messages=messages, tools=tools ) - if self.stream: + if self._effective_stream(): return await self._ahandle_streaming_completion( params=completion_params, available_functions=available_functions, @@ -595,7 +595,7 @@ class OpenAICompletion(BaseLLM): messages=messages, tools=tools, response_model=response_model ) - if self.stream: + if self._effective_stream(): return self._handle_streaming_responses( params=params, available_functions=available_functions, @@ -626,7 +626,7 @@ class OpenAICompletion(BaseLLM): messages=messages, tools=tools, response_model=response_model ) - if self.stream: + if self._effective_stream(): return await self._ahandle_streaming_responses( params=params, available_functions=available_functions, @@ -685,7 +685,7 @@ class OpenAICompletion(BaseLLM): if instructions: params["instructions"] = instructions - if self.stream: + if self._effective_stream(): params["stream"] = True if self.store is not None: @@ -1540,8 +1540,8 @@ class OpenAICompletion(BaseLLM): "model": self.model, "messages": messages, } - if self.stream: - params["stream"] = self.stream + if self._effective_stream(): + params["stream"] = self._effective_stream() params["stream_options"] = {"include_usage": True} params.update(self.additional_params) diff --git a/lib/crewai/tests/test_stream_frames.py b/lib/crewai/tests/test_stream_frames.py index 5fc9ee168..076da1a5a 100644 --- a/lib/crewai/tests/test_stream_frames.py +++ b/lib/crewai/tests/test_stream_frames.py @@ -13,7 +13,7 @@ from crewai.events.types.flow_events import ConversationMessageAddedEvent from crewai.events.types.llm_events import LLMStreamChunkEvent, LLMThinkingChunkEvent from crewai.events.types.tool_usage_events import ToolUsageStartedEvent from crewai.flow.flow import Flow, start -from crewai.llms.base_llm import BaseLLM +from crewai.llms.base_llm import BaseLLM, call_stop_override from crewai.types.streaming import StreamFrame @@ -60,11 +60,22 @@ class FrameFlow(Flow): class DirectStreamingLLM(BaseLLM): call_stream_values: ClassVar[list[bool | None]] = [] + raw_stream_values: ClassVar[list[bool | None]] = [] call_instance_ids: ClassVar[list[int]] = [] + call_stop_values: ClassVar[list[list[str]]] = [] def call(self, messages: Any, *args: Any, **kwargs: Any) -> str: - self.call_stream_values.append(self.stream) + self.call_stream_values.append(self._effective_stream()) + self.raw_stream_values.append(self.stream) self.call_instance_ids.append(id(self)) + self.call_stop_values.append(list(self.stop_sequences)) + self._track_token_usage_internal( + { + "prompt_tokens": 1, + "completion_tokens": 2, + "total_tokens": 3, + } + ) crewai_event_bus.emit( self, LLMStreamChunkEvent( @@ -182,18 +193,28 @@ def test_flow_streaming_returns_iterable_frame_session() -> None: def test_direct_llm_stream_events_scope_and_restore_stream_flag() -> None: DirectStreamingLLM.call_stream_values = [] + DirectStreamingLLM.raw_stream_values = [] DirectStreamingLLM.call_instance_ids = [] + DirectStreamingLLM.call_stop_values = [] llm = DirectStreamingLLM(model="gpt-4o-mini", stream=False) - with llm.stream_events("hello") as stream: - frames = list(stream) + with call_stop_override(llm, ["STOP"]): + with llm.stream_events("hello") as stream: + frames = list(stream) assert [frame.content for frame in frames] == ["hel", "lo"] assert frames[0].event["chunk"] == "hel" assert stream.result == "hello" assert llm.stream is False assert DirectStreamingLLM.call_stream_values == [True] - assert DirectStreamingLLM.call_instance_ids != [id(llm)] + assert DirectStreamingLLM.raw_stream_values == [False] + assert DirectStreamingLLM.call_instance_ids == [id(llm)] + assert DirectStreamingLLM.call_stop_values == [["STOP"]] + usage = llm.get_token_usage_summary() + assert usage.total_tokens == 3 + assert usage.prompt_tokens == 1 + assert usage.completion_tokens == 2 + assert usage.successful_requests == 1 @pytest.mark.asyncio