Preserve LLM instance state for stream events

This commit is contained in:
lorenzejay
2026-06-29 16:50:46 -07:00
parent 5e6fdc8374
commit bb290fa967
8 changed files with 106 additions and 56 deletions

View File

@@ -749,7 +749,7 @@ class LLM(BaseLLM):
"base_url": self.base_url,
"api_version": self.api_version,
"api_key": self.api_key,
"stream": self.stream,
"stream": self._effective_stream(),
"tools": tools,
"reasoning_effort": self.reasoning_effort,
**self.additional_params,
@@ -1841,7 +1841,7 @@ class LLM(BaseLLM):
self.set_callbacks(callbacks)
try:
params = self._prepare_completion_params(messages, tools)
if self.stream:
if self._effective_stream():
result = self._handle_streaming_response(
params=params,
callbacks=callbacks,
@@ -1983,7 +1983,7 @@ class LLM(BaseLLM):
messages, tools, skip_file_processing=True
)
if self.stream:
if self._effective_stream():
return await self._ahandle_streaming_response(
params=params,
callbacks=callbacks,

View File

@@ -82,6 +82,9 @@ _current_call_id: contextvars.ContextVar[str | None] = contextvars.ContextVar(
_call_stop_override_var: contextvars.ContextVar[dict[int, list[str]] | None] = (
contextvars.ContextVar("_call_stop_override_var", default=None)
)
_call_stream_override_var: contextvars.ContextVar[dict[int, bool] | None] = (
contextvars.ContextVar("_call_stream_override_var", default=None)
)
@contextmanager
@@ -120,6 +123,19 @@ def call_stop_override(
_call_stop_override_var.reset(token)
@contextmanager
def call_stream_override(llm: BaseLLM, stream: bool) -> Generator[None, None, None]:
"""Override streaming for ``llm`` within the current call scope."""
current = _call_stream_override_var.get()
new_overrides: dict[int, bool] = dict(current) if current else {}
new_overrides[id(llm)] = stream
token = _call_stream_override_var.set(new_overrides)
try:
yield
finally:
_call_stream_override_var.reset(token)
def get_current_call_id() -> str:
"""Get current call_id from context"""
call_id = _current_call_id.get()
@@ -218,6 +234,13 @@ class BaseLLM(BaseModel, ABC):
return override
return self.stop
def _effective_stream(self) -> bool | None:
"""Return the call-scoped streaming mode for this instance."""
overrides = _call_stream_override_var.get()
if overrides is not None and id(self) in overrides:
return overrides[id(self)]
return self.stream
_token_usage: dict[str, int] = PrivateAttr(
default_factory=lambda: {
"total_tokens": 0,
@@ -339,16 +362,16 @@ class BaseLLM(BaseModel, ABC):
output_holder: list[StreamSession[Any]] = []
def run_llm_call() -> Any:
streaming_llm = self.model_copy(update={"stream": True})
return streaming_llm.call(
messages=messages,
tools=tools,
callbacks=callbacks,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
response_model=response_model,
)
with call_stream_override(self, True):
return self.call(
messages=messages,
tools=tools,
callbacks=callbacks,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
response_model=response_model,
)
stream_session: StreamSession[Any] = StreamSession(
sync_iterator=create_frame_generator(state, run_llm_call, output_holder)
@@ -547,7 +570,7 @@ class BaseLLM(BaseModel, ABC):
if max_tokens is None:
max_tokens = self._effective_max_tokens()
if stream is None:
stream = self.stream
stream = self._effective_stream()
if seed is None:
seed = self.seed
if stop_sequences is None:

View File

@@ -323,7 +323,7 @@ class AnthropicCompletion(BaseLLM):
effective_response_model = response_model or self.response_format
if self.stream:
if self._effective_stream():
return self._handle_streaming_completion(
completion_params,
available_functions,
@@ -393,7 +393,7 @@ class AnthropicCompletion(BaseLLM):
effective_response_model = response_model or self.response_format
if self.stream:
if self._effective_stream():
return await self._ahandle_streaming_completion(
completion_params,
available_functions,
@@ -441,7 +441,7 @@ class AnthropicCompletion(BaseLLM):
"model": self.model,
"messages": messages,
"max_tokens": self.max_tokens,
"stream": self.stream,
"stream": self._effective_stream(),
}
if system_message:

View File

@@ -42,7 +42,7 @@ try:
)
from crewai.events.types.llm_events import LLMCallType
from crewai.llms.base_llm import BaseLLM, llm_call_context
from crewai.llms.base_llm import BaseLLM, call_stream_override, llm_call_context
except ImportError:
raise ImportError(
@@ -493,15 +493,18 @@ class AzureCompletion(BaseLLM):
Completion response or tool call result
"""
if self.api == "responses":
return self._responses_delegate.call(
messages=messages,
tools=tools,
callbacks=callbacks,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
response_model=response_model,
)
with call_stream_override(
self._responses_delegate, bool(self._effective_stream())
):
return self._responses_delegate.call(
messages=messages,
tools=tools,
callbacks=callbacks,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
response_model=response_model,
)
with llm_call_context():
try:
@@ -527,7 +530,7 @@ class AzureCompletion(BaseLLM):
formatted_messages, tools, effective_response_model
)
if self.stream:
if self._effective_stream():
return self._handle_streaming_completion(
completion_params,
available_functions,
@@ -572,15 +575,18 @@ class AzureCompletion(BaseLLM):
Completion response or tool call result
"""
if self.api == "responses":
return await self._responses_delegate.acall(
messages=messages,
tools=tools,
callbacks=callbacks,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
response_model=response_model,
)
with call_stream_override(
self._responses_delegate, bool(self._effective_stream())
):
return await self._responses_delegate.acall(
messages=messages,
tools=tools,
callbacks=callbacks,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
response_model=response_model,
)
with llm_call_context():
try:
@@ -601,7 +607,7 @@ class AzureCompletion(BaseLLM):
formatted_messages, tools, effective_response_model
)
if self.stream:
if self._effective_stream():
return await self._ahandle_streaming_completion(
completion_params,
available_functions,
@@ -639,11 +645,11 @@ class AzureCompletion(BaseLLM):
"""
params: AzureCompletionParams = {
"messages": messages,
"stream": self.stream,
"stream": bool(self._effective_stream()),
}
model_extras: dict[str, Any] = {}
if self.stream:
if self._effective_stream():
model_extras["stream_options"] = {"include_usage": True}
if response_model and self.is_openai_model:

View File

@@ -428,7 +428,7 @@ class BedrockCompletion(BaseLLM):
self.additional_model_response_field_paths
)
if self.stream:
if self._effective_stream():
return self._handle_streaming_converse(
formatted_messages,
body,
@@ -556,7 +556,7 @@ class BedrockCompletion(BaseLLM):
self.additional_model_response_field_paths
)
if self.stream:
if self._effective_stream():
return await self._ahandle_streaming_converse(
formatted_messages,
body,

View File

@@ -322,7 +322,7 @@ class GeminiCompletion(BaseLLM):
system_instruction, tools, effective_response_model
)
if self.stream:
if self._effective_stream():
return self._handle_streaming_completion(
formatted_content,
config,
@@ -401,7 +401,7 @@ class GeminiCompletion(BaseLLM):
system_instruction, tools, effective_response_model
)
if self.stream:
if self._effective_stream():
return await self._ahandle_streaming_completion(
formatted_content,
config,

View File

@@ -469,7 +469,7 @@ class OpenAICompletion(BaseLLM):
messages=messages, tools=tools
)
if self.stream:
if self._effective_stream():
return self._handle_streaming_completion(
params=completion_params,
available_functions=available_functions,
@@ -564,7 +564,7 @@ class OpenAICompletion(BaseLLM):
messages=messages, tools=tools
)
if self.stream:
if self._effective_stream():
return await self._ahandle_streaming_completion(
params=completion_params,
available_functions=available_functions,
@@ -595,7 +595,7 @@ class OpenAICompletion(BaseLLM):
messages=messages, tools=tools, response_model=response_model
)
if self.stream:
if self._effective_stream():
return self._handle_streaming_responses(
params=params,
available_functions=available_functions,
@@ -626,7 +626,7 @@ class OpenAICompletion(BaseLLM):
messages=messages, tools=tools, response_model=response_model
)
if self.stream:
if self._effective_stream():
return await self._ahandle_streaming_responses(
params=params,
available_functions=available_functions,
@@ -685,7 +685,7 @@ class OpenAICompletion(BaseLLM):
if instructions:
params["instructions"] = instructions
if self.stream:
if self._effective_stream():
params["stream"] = True
if self.store is not None:
@@ -1540,8 +1540,8 @@ class OpenAICompletion(BaseLLM):
"model": self.model,
"messages": messages,
}
if self.stream:
params["stream"] = self.stream
if self._effective_stream():
params["stream"] = self._effective_stream()
params["stream_options"] = {"include_usage": True}
params.update(self.additional_params)

View File

@@ -13,7 +13,7 @@ from crewai.events.types.flow_events import ConversationMessageAddedEvent
from crewai.events.types.llm_events import LLMStreamChunkEvent, LLMThinkingChunkEvent
from crewai.events.types.tool_usage_events import ToolUsageStartedEvent
from crewai.flow.flow import Flow, start
from crewai.llms.base_llm import BaseLLM
from crewai.llms.base_llm import BaseLLM, call_stop_override
from crewai.types.streaming import StreamFrame
@@ -60,11 +60,22 @@ class FrameFlow(Flow):
class DirectStreamingLLM(BaseLLM):
call_stream_values: ClassVar[list[bool | None]] = []
raw_stream_values: ClassVar[list[bool | None]] = []
call_instance_ids: ClassVar[list[int]] = []
call_stop_values: ClassVar[list[list[str]]] = []
def call(self, messages: Any, *args: Any, **kwargs: Any) -> str:
self.call_stream_values.append(self.stream)
self.call_stream_values.append(self._effective_stream())
self.raw_stream_values.append(self.stream)
self.call_instance_ids.append(id(self))
self.call_stop_values.append(list(self.stop_sequences))
self._track_token_usage_internal(
{
"prompt_tokens": 1,
"completion_tokens": 2,
"total_tokens": 3,
}
)
crewai_event_bus.emit(
self,
LLMStreamChunkEvent(
@@ -182,18 +193,28 @@ def test_flow_streaming_returns_iterable_frame_session() -> None:
def test_direct_llm_stream_events_scope_and_restore_stream_flag() -> None:
DirectStreamingLLM.call_stream_values = []
DirectStreamingLLM.raw_stream_values = []
DirectStreamingLLM.call_instance_ids = []
DirectStreamingLLM.call_stop_values = []
llm = DirectStreamingLLM(model="gpt-4o-mini", stream=False)
with llm.stream_events("hello") as stream:
frames = list(stream)
with call_stop_override(llm, ["STOP"]):
with llm.stream_events("hello") as stream:
frames = list(stream)
assert [frame.content for frame in frames] == ["hel", "lo"]
assert frames[0].event["chunk"] == "hel"
assert stream.result == "hello"
assert llm.stream is False
assert DirectStreamingLLM.call_stream_values == [True]
assert DirectStreamingLLM.call_instance_ids != [id(llm)]
assert DirectStreamingLLM.raw_stream_values == [False]
assert DirectStreamingLLM.call_instance_ids == [id(llm)]
assert DirectStreamingLLM.call_stop_values == [["STOP"]]
usage = llm.get_token_usage_summary()
assert usage.total_tokens == 3
assert usage.prompt_tokens == 1
assert usage.completion_tokens == 2
assert usage.successful_requests == 1
@pytest.mark.asyncio