mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-07-01 05:08:12 +00:00
Preserve LLM instance state for stream events
This commit is contained in:
@@ -749,7 +749,7 @@ class LLM(BaseLLM):
|
||||
"base_url": self.base_url,
|
||||
"api_version": self.api_version,
|
||||
"api_key": self.api_key,
|
||||
"stream": self.stream,
|
||||
"stream": self._effective_stream(),
|
||||
"tools": tools,
|
||||
"reasoning_effort": self.reasoning_effort,
|
||||
**self.additional_params,
|
||||
@@ -1841,7 +1841,7 @@ class LLM(BaseLLM):
|
||||
self.set_callbacks(callbacks)
|
||||
try:
|
||||
params = self._prepare_completion_params(messages, tools)
|
||||
if self.stream:
|
||||
if self._effective_stream():
|
||||
result = self._handle_streaming_response(
|
||||
params=params,
|
||||
callbacks=callbacks,
|
||||
@@ -1983,7 +1983,7 @@ class LLM(BaseLLM):
|
||||
messages, tools, skip_file_processing=True
|
||||
)
|
||||
|
||||
if self.stream:
|
||||
if self._effective_stream():
|
||||
return await self._ahandle_streaming_response(
|
||||
params=params,
|
||||
callbacks=callbacks,
|
||||
|
||||
@@ -82,6 +82,9 @@ _current_call_id: contextvars.ContextVar[str | None] = contextvars.ContextVar(
|
||||
_call_stop_override_var: contextvars.ContextVar[dict[int, list[str]] | None] = (
|
||||
contextvars.ContextVar("_call_stop_override_var", default=None)
|
||||
)
|
||||
_call_stream_override_var: contextvars.ContextVar[dict[int, bool] | None] = (
|
||||
contextvars.ContextVar("_call_stream_override_var", default=None)
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
@@ -120,6 +123,19 @@ def call_stop_override(
|
||||
_call_stop_override_var.reset(token)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def call_stream_override(llm: BaseLLM, stream: bool) -> Generator[None, None, None]:
|
||||
"""Override streaming for ``llm`` within the current call scope."""
|
||||
current = _call_stream_override_var.get()
|
||||
new_overrides: dict[int, bool] = dict(current) if current else {}
|
||||
new_overrides[id(llm)] = stream
|
||||
token = _call_stream_override_var.set(new_overrides)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
_call_stream_override_var.reset(token)
|
||||
|
||||
|
||||
def get_current_call_id() -> str:
|
||||
"""Get current call_id from context"""
|
||||
call_id = _current_call_id.get()
|
||||
@@ -218,6 +234,13 @@ class BaseLLM(BaseModel, ABC):
|
||||
return override
|
||||
return self.stop
|
||||
|
||||
def _effective_stream(self) -> bool | None:
|
||||
"""Return the call-scoped streaming mode for this instance."""
|
||||
overrides = _call_stream_override_var.get()
|
||||
if overrides is not None and id(self) in overrides:
|
||||
return overrides[id(self)]
|
||||
return self.stream
|
||||
|
||||
_token_usage: dict[str, int] = PrivateAttr(
|
||||
default_factory=lambda: {
|
||||
"total_tokens": 0,
|
||||
@@ -339,16 +362,16 @@ class BaseLLM(BaseModel, ABC):
|
||||
output_holder: list[StreamSession[Any]] = []
|
||||
|
||||
def run_llm_call() -> Any:
|
||||
streaming_llm = self.model_copy(update={"stream": True})
|
||||
return streaming_llm.call(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
response_model=response_model,
|
||||
)
|
||||
with call_stream_override(self, True):
|
||||
return self.call(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
response_model=response_model,
|
||||
)
|
||||
|
||||
stream_session: StreamSession[Any] = StreamSession(
|
||||
sync_iterator=create_frame_generator(state, run_llm_call, output_holder)
|
||||
@@ -547,7 +570,7 @@ class BaseLLM(BaseModel, ABC):
|
||||
if max_tokens is None:
|
||||
max_tokens = self._effective_max_tokens()
|
||||
if stream is None:
|
||||
stream = self.stream
|
||||
stream = self._effective_stream()
|
||||
if seed is None:
|
||||
seed = self.seed
|
||||
if stop_sequences is None:
|
||||
|
||||
@@ -323,7 +323,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
|
||||
effective_response_model = response_model or self.response_format
|
||||
|
||||
if self.stream:
|
||||
if self._effective_stream():
|
||||
return self._handle_streaming_completion(
|
||||
completion_params,
|
||||
available_functions,
|
||||
@@ -393,7 +393,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
|
||||
effective_response_model = response_model or self.response_format
|
||||
|
||||
if self.stream:
|
||||
if self._effective_stream():
|
||||
return await self._ahandle_streaming_completion(
|
||||
completion_params,
|
||||
available_functions,
|
||||
@@ -441,7 +441,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
"max_tokens": self.max_tokens,
|
||||
"stream": self.stream,
|
||||
"stream": self._effective_stream(),
|
||||
}
|
||||
|
||||
if system_message:
|
||||
|
||||
@@ -42,7 +42,7 @@ try:
|
||||
)
|
||||
|
||||
from crewai.events.types.llm_events import LLMCallType
|
||||
from crewai.llms.base_llm import BaseLLM, llm_call_context
|
||||
from crewai.llms.base_llm import BaseLLM, call_stream_override, llm_call_context
|
||||
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
@@ -493,15 +493,18 @@ class AzureCompletion(BaseLLM):
|
||||
Completion response or tool call result
|
||||
"""
|
||||
if self.api == "responses":
|
||||
return self._responses_delegate.call(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
response_model=response_model,
|
||||
)
|
||||
with call_stream_override(
|
||||
self._responses_delegate, bool(self._effective_stream())
|
||||
):
|
||||
return self._responses_delegate.call(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
response_model=response_model,
|
||||
)
|
||||
|
||||
with llm_call_context():
|
||||
try:
|
||||
@@ -527,7 +530,7 @@ class AzureCompletion(BaseLLM):
|
||||
formatted_messages, tools, effective_response_model
|
||||
)
|
||||
|
||||
if self.stream:
|
||||
if self._effective_stream():
|
||||
return self._handle_streaming_completion(
|
||||
completion_params,
|
||||
available_functions,
|
||||
@@ -572,15 +575,18 @@ class AzureCompletion(BaseLLM):
|
||||
Completion response or tool call result
|
||||
"""
|
||||
if self.api == "responses":
|
||||
return await self._responses_delegate.acall(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
response_model=response_model,
|
||||
)
|
||||
with call_stream_override(
|
||||
self._responses_delegate, bool(self._effective_stream())
|
||||
):
|
||||
return await self._responses_delegate.acall(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
response_model=response_model,
|
||||
)
|
||||
|
||||
with llm_call_context():
|
||||
try:
|
||||
@@ -601,7 +607,7 @@ class AzureCompletion(BaseLLM):
|
||||
formatted_messages, tools, effective_response_model
|
||||
)
|
||||
|
||||
if self.stream:
|
||||
if self._effective_stream():
|
||||
return await self._ahandle_streaming_completion(
|
||||
completion_params,
|
||||
available_functions,
|
||||
@@ -639,11 +645,11 @@ class AzureCompletion(BaseLLM):
|
||||
"""
|
||||
params: AzureCompletionParams = {
|
||||
"messages": messages,
|
||||
"stream": self.stream,
|
||||
"stream": bool(self._effective_stream()),
|
||||
}
|
||||
|
||||
model_extras: dict[str, Any] = {}
|
||||
if self.stream:
|
||||
if self._effective_stream():
|
||||
model_extras["stream_options"] = {"include_usage": True}
|
||||
|
||||
if response_model and self.is_openai_model:
|
||||
|
||||
@@ -428,7 +428,7 @@ class BedrockCompletion(BaseLLM):
|
||||
self.additional_model_response_field_paths
|
||||
)
|
||||
|
||||
if self.stream:
|
||||
if self._effective_stream():
|
||||
return self._handle_streaming_converse(
|
||||
formatted_messages,
|
||||
body,
|
||||
@@ -556,7 +556,7 @@ class BedrockCompletion(BaseLLM):
|
||||
self.additional_model_response_field_paths
|
||||
)
|
||||
|
||||
if self.stream:
|
||||
if self._effective_stream():
|
||||
return await self._ahandle_streaming_converse(
|
||||
formatted_messages,
|
||||
body,
|
||||
|
||||
@@ -322,7 +322,7 @@ class GeminiCompletion(BaseLLM):
|
||||
system_instruction, tools, effective_response_model
|
||||
)
|
||||
|
||||
if self.stream:
|
||||
if self._effective_stream():
|
||||
return self._handle_streaming_completion(
|
||||
formatted_content,
|
||||
config,
|
||||
@@ -401,7 +401,7 @@ class GeminiCompletion(BaseLLM):
|
||||
system_instruction, tools, effective_response_model
|
||||
)
|
||||
|
||||
if self.stream:
|
||||
if self._effective_stream():
|
||||
return await self._ahandle_streaming_completion(
|
||||
formatted_content,
|
||||
config,
|
||||
|
||||
@@ -469,7 +469,7 @@ class OpenAICompletion(BaseLLM):
|
||||
messages=messages, tools=tools
|
||||
)
|
||||
|
||||
if self.stream:
|
||||
if self._effective_stream():
|
||||
return self._handle_streaming_completion(
|
||||
params=completion_params,
|
||||
available_functions=available_functions,
|
||||
@@ -564,7 +564,7 @@ class OpenAICompletion(BaseLLM):
|
||||
messages=messages, tools=tools
|
||||
)
|
||||
|
||||
if self.stream:
|
||||
if self._effective_stream():
|
||||
return await self._ahandle_streaming_completion(
|
||||
params=completion_params,
|
||||
available_functions=available_functions,
|
||||
@@ -595,7 +595,7 @@ class OpenAICompletion(BaseLLM):
|
||||
messages=messages, tools=tools, response_model=response_model
|
||||
)
|
||||
|
||||
if self.stream:
|
||||
if self._effective_stream():
|
||||
return self._handle_streaming_responses(
|
||||
params=params,
|
||||
available_functions=available_functions,
|
||||
@@ -626,7 +626,7 @@ class OpenAICompletion(BaseLLM):
|
||||
messages=messages, tools=tools, response_model=response_model
|
||||
)
|
||||
|
||||
if self.stream:
|
||||
if self._effective_stream():
|
||||
return await self._ahandle_streaming_responses(
|
||||
params=params,
|
||||
available_functions=available_functions,
|
||||
@@ -685,7 +685,7 @@ class OpenAICompletion(BaseLLM):
|
||||
if instructions:
|
||||
params["instructions"] = instructions
|
||||
|
||||
if self.stream:
|
||||
if self._effective_stream():
|
||||
params["stream"] = True
|
||||
|
||||
if self.store is not None:
|
||||
@@ -1540,8 +1540,8 @@ class OpenAICompletion(BaseLLM):
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
}
|
||||
if self.stream:
|
||||
params["stream"] = self.stream
|
||||
if self._effective_stream():
|
||||
params["stream"] = self._effective_stream()
|
||||
params["stream_options"] = {"include_usage": True}
|
||||
|
||||
params.update(self.additional_params)
|
||||
|
||||
@@ -13,7 +13,7 @@ from crewai.events.types.flow_events import ConversationMessageAddedEvent
|
||||
from crewai.events.types.llm_events import LLMStreamChunkEvent, LLMThinkingChunkEvent
|
||||
from crewai.events.types.tool_usage_events import ToolUsageStartedEvent
|
||||
from crewai.flow.flow import Flow, start
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.llms.base_llm import BaseLLM, call_stop_override
|
||||
from crewai.types.streaming import StreamFrame
|
||||
|
||||
|
||||
@@ -60,11 +60,22 @@ class FrameFlow(Flow):
|
||||
|
||||
class DirectStreamingLLM(BaseLLM):
|
||||
call_stream_values: ClassVar[list[bool | None]] = []
|
||||
raw_stream_values: ClassVar[list[bool | None]] = []
|
||||
call_instance_ids: ClassVar[list[int]] = []
|
||||
call_stop_values: ClassVar[list[list[str]]] = []
|
||||
|
||||
def call(self, messages: Any, *args: Any, **kwargs: Any) -> str:
|
||||
self.call_stream_values.append(self.stream)
|
||||
self.call_stream_values.append(self._effective_stream())
|
||||
self.raw_stream_values.append(self.stream)
|
||||
self.call_instance_ids.append(id(self))
|
||||
self.call_stop_values.append(list(self.stop_sequences))
|
||||
self._track_token_usage_internal(
|
||||
{
|
||||
"prompt_tokens": 1,
|
||||
"completion_tokens": 2,
|
||||
"total_tokens": 3,
|
||||
}
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
LLMStreamChunkEvent(
|
||||
@@ -182,18 +193,28 @@ def test_flow_streaming_returns_iterable_frame_session() -> None:
|
||||
|
||||
def test_direct_llm_stream_events_scope_and_restore_stream_flag() -> None:
|
||||
DirectStreamingLLM.call_stream_values = []
|
||||
DirectStreamingLLM.raw_stream_values = []
|
||||
DirectStreamingLLM.call_instance_ids = []
|
||||
DirectStreamingLLM.call_stop_values = []
|
||||
llm = DirectStreamingLLM(model="gpt-4o-mini", stream=False)
|
||||
|
||||
with llm.stream_events("hello") as stream:
|
||||
frames = list(stream)
|
||||
with call_stop_override(llm, ["STOP"]):
|
||||
with llm.stream_events("hello") as stream:
|
||||
frames = list(stream)
|
||||
|
||||
assert [frame.content for frame in frames] == ["hel", "lo"]
|
||||
assert frames[0].event["chunk"] == "hel"
|
||||
assert stream.result == "hello"
|
||||
assert llm.stream is False
|
||||
assert DirectStreamingLLM.call_stream_values == [True]
|
||||
assert DirectStreamingLLM.call_instance_ids != [id(llm)]
|
||||
assert DirectStreamingLLM.raw_stream_values == [False]
|
||||
assert DirectStreamingLLM.call_instance_ids == [id(llm)]
|
||||
assert DirectStreamingLLM.call_stop_values == [["STOP"]]
|
||||
usage = llm.get_token_usage_summary()
|
||||
assert usage.total_tokens == 3
|
||||
assert usage.prompt_tokens == 1
|
||||
assert usage.completion_tokens == 2
|
||||
assert usage.successful_requests == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
Reference in New Issue
Block a user