Mirror of https://github.com/crewAIInc/crewAI.git (synced 2025-12-15 11:58:31 +00:00)
Fix streaming token usage tracking in OpenAI provider
This commit fixes issue #4056, where token usage always returns 0 when using async streaming crew kickoff.

Root cause: the streaming completion methods (_handle_streaming_completion and _ahandle_streaming_completion) in OpenAICompletion never called _track_token_usage_internal(), unlike the non-streaming methods.

Changes:
- Add stream_options={'include_usage': True} to streaming params so the OpenAI API returns usage information in the final chunk
- Extract and track token usage from the final chunk in sync streaming
- Extract and track token usage from the final chunk in async streaming
- Extract and track token usage from final_completion in response_model paths
- Add an _extract_chunk_token_usage method for ChatCompletionChunk objects
- Add tests to verify streaming token usage tracking works correctly

Co-Authored-By: João <joao@crewai.com>
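For context, the change relies on a documented behavior of the OpenAI Chat Completions API: when stream_options={"include_usage": True} is sent together with stream=True, the server emits one extra final chunk whose choices list is empty and whose usage field is populated. A minimal standalone sketch of that behavior (not part of this commit; it assumes the official openai Python client and an OPENAI_API_KEY in the environment):

from openai import OpenAI

client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment
stream = client.chat.completions.create(
    model="gpt-4o",
    messages=[{"role": "user", "content": "Say hi"}],
    stream=True,
    stream_options={"include_usage": True},
)

usage = None
for chunk in stream:
    if chunk.usage is not None:  # only the final, empty-choices chunk carries usage
        usage = chunk.usage
    if chunk.choices:
        print(chunk.choices[0].delta.content or "", end="")

print()
print(usage)  # CompletionUsage(prompt_tokens=..., completion_tokens=..., total_tokens=...)

The loops added in the diff below follow the same pattern: capture usage from whichever chunk carries it, then hand it to _track_token_usage_internal() once the stream is exhausted.
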
@@ -297,6 +297,7 @@ class OpenAICompletion(BaseLLM):
        }
        if self.stream:
            params["stream"] = self.stream
            params["stream_options"] = {"include_usage": True}

        params.update(self.additional_params)

@@ -545,6 +546,9 @@ class OpenAICompletion(BaseLLM):

            final_completion = stream.get_final_completion()
            if final_completion and final_completion.choices:
                usage = self._extract_openai_token_usage(final_completion)
                self._track_token_usage_internal(usage)

                parsed_result = final_completion.choices[0].message.parsed
                if parsed_result:
                    structured_json = parsed_result.model_dump_json()

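The hunk above covers the response_model path, which goes through the OpenAI SDK's beta streaming helper rather than a raw chunk loop, so usage is read from the final completion object instead of from a chunk. A rough sketch of where that object comes from (not part of this commit; it assumes the official openai client and a hypothetical Pydantic Answer model):

from openai import OpenAI
from pydantic import BaseModel

class Answer(BaseModel):
    text: str

client = OpenAI()
with client.beta.chat.completions.stream(
    model="gpt-4o",
    messages=[{"role": "user", "content": "Say hi"}],
    response_format=Answer,
) as stream:
    for _event in stream:  # consume streamed events
        pass
    final_completion = stream.get_final_completion()

# usage is populated when the server sent a usage chunk (see include_usage above)
print(final_completion.usage)
print(final_completion.choices[0].message.parsed)  # the parsed Answer instance
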
@@ -564,7 +568,11 @@ class OpenAICompletion(BaseLLM):
            self.client.chat.completions.create(**params)
        )

        usage_data: dict[str, Any] | None = None
        for completion_chunk in completion_stream:
            if completion_chunk.usage is not None:
                usage_data = self._extract_chunk_token_usage(completion_chunk)

            if not completion_chunk.choices:
                continue

@@ -593,6 +601,9 @@ class OpenAICompletion(BaseLLM):
                    if tool_call.function and tool_call.function.arguments:
                        tool_calls[call_id]["arguments"] += tool_call.function.arguments

        if usage_data:
            self._track_token_usage_internal(usage_data)

        if tool_calls and available_functions:
            for call_data in tool_calls.values():
                function_name = call_data["name"]

@@ -785,7 +796,11 @@ class OpenAICompletion(BaseLLM):
        ] = await self.async_client.chat.completions.create(**params)

        accumulated_content = ""
        usage_data: dict[str, Any] | None = None
        async for chunk in completion_stream:
            if chunk.usage is not None:
                usage_data = self._extract_chunk_token_usage(chunk)

            if not chunk.choices:
                continue

@@ -800,6 +815,9 @@ class OpenAICompletion(BaseLLM):
                    from_agent=from_agent,
                )

        if usage_data:
            self._track_token_usage_internal(usage_data)

        try:
            parsed_object = response_model.model_validate_json(accumulated_content)
            structured_json = parsed_object.model_dump_json()

@@ -828,7 +846,11 @@ class OpenAICompletion(BaseLLM):
            ChatCompletionChunk
        ] = await self.async_client.chat.completions.create(**params)

        usage_data = None
        async for chunk in stream:
            if chunk.usage is not None:
                usage_data = self._extract_chunk_token_usage(chunk)

            if not chunk.choices:
                continue

@@ -857,6 +879,9 @@ class OpenAICompletion(BaseLLM):
                    if tool_call.function and tool_call.function.arguments:
                        tool_calls[call_id]["arguments"] += tool_call.function.arguments

        if usage_data:
            self._track_token_usage_internal(usage_data)

        if tool_calls and available_functions:
            for call_data in tool_calls.values():
                function_name = call_data["name"]

@@ -955,6 +980,19 @@ class OpenAICompletion(BaseLLM):
            }
        return {"total_tokens": 0}

    def _extract_chunk_token_usage(
        self, chunk: ChatCompletionChunk
    ) -> dict[str, Any]:
        """Extract token usage from an OpenAI ChatCompletionChunk (streaming response)."""
        if hasattr(chunk, "usage") and chunk.usage:
            usage = chunk.usage
            return {
                "prompt_tokens": getattr(usage, "prompt_tokens", 0),
                "completion_tokens": getattr(usage, "completion_tokens", 0),
                "total_tokens": getattr(usage, "total_tokens", 0),
            }
        return {"total_tokens": 0}

    def _format_messages(self, messages: str | list[LLMMessage]) -> list[LLMMessage]:
        """Format messages for OpenAI API."""
        base_formatted = super()._format_messages(messages)

@@ -592,3 +592,133 @@ def test_openai_response_format_none():

    assert isinstance(result, str)
    assert len(result) > 0


def test_openai_streaming_tracks_token_usage():
    """
    Test that streaming mode correctly tracks token usage.

    This test verifies the fix for issue #4056 where token usage was always 0
    when using streaming mode.
    """
    llm = LLM(model="openai/gpt-4o", stream=True)

    # Create mock chunks with usage in the final chunk
    mock_chunk1 = MagicMock()
    mock_chunk1.choices = [MagicMock()]
    mock_chunk1.choices[0].delta = MagicMock()
    mock_chunk1.choices[0].delta.content = "Hello "
    mock_chunk1.choices[0].delta.tool_calls = None
    mock_chunk1.usage = None

    mock_chunk2 = MagicMock()
    mock_chunk2.choices = [MagicMock()]
    mock_chunk2.choices[0].delta = MagicMock()
    mock_chunk2.choices[0].delta.content = "World!"
    mock_chunk2.choices[0].delta.tool_calls = None
    mock_chunk2.usage = None

    # Final chunk with usage information (when stream_options={"include_usage": True})
    mock_chunk3 = MagicMock()
    mock_chunk3.choices = []
    mock_chunk3.usage = MagicMock()
    mock_chunk3.usage.prompt_tokens = 10
    mock_chunk3.usage.completion_tokens = 20
    mock_chunk3.usage.total_tokens = 30

    mock_stream = MagicMock()
    mock_stream.__iter__ = MagicMock(return_value=iter([mock_chunk1, mock_chunk2, mock_chunk3]))

    with patch.object(llm.client.chat.completions, "create", return_value=mock_stream):
        result = llm.call("Hello")

    assert result == "Hello World!"

    # Verify token usage was tracked
    usage = llm.get_token_usage_summary()
    assert usage.prompt_tokens == 10
    assert usage.completion_tokens == 20
    assert usage.total_tokens == 30
    assert usage.successful_requests == 1


def test_openai_streaming_with_response_model_tracks_token_usage():
    """
    Test that streaming with response_model correctly tracks token usage.

    This test verifies the fix for issue #4056 where token usage was always 0
    when using streaming mode with response_model.
    """
    from pydantic import BaseModel

    class TestResponse(BaseModel):
        """Test response model."""

        answer: str
        confidence: float

    llm = LLM(model="openai/gpt-4o", stream=True)

    with patch.object(llm.client.beta.chat.completions, "stream") as mock_stream:
        # Create mock chunks with content.delta event structure
        mock_chunk1 = MagicMock()
        mock_chunk1.type = "content.delta"
        mock_chunk1.delta = '{"answer": "test", '

        mock_chunk2 = MagicMock()
        mock_chunk2.type = "content.delta"
        mock_chunk2.delta = '"confidence": 0.95}'

        # Create mock final completion with parsed result and usage
        mock_parsed = TestResponse(answer="test", confidence=0.95)
        mock_message = MagicMock()
        mock_message.parsed = mock_parsed
        mock_choice = MagicMock()
        mock_choice.message = mock_message
        mock_final_completion = MagicMock()
        mock_final_completion.choices = [mock_choice]
        mock_final_completion.usage = MagicMock()
        mock_final_completion.usage.prompt_tokens = 15
        mock_final_completion.usage.completion_tokens = 25
        mock_final_completion.usage.total_tokens = 40

        # Create mock stream context manager
        mock_stream_obj = MagicMock()
        mock_stream_obj.__enter__ = MagicMock(return_value=mock_stream_obj)
        mock_stream_obj.__exit__ = MagicMock(return_value=None)
        mock_stream_obj.__iter__ = MagicMock(return_value=iter([mock_chunk1, mock_chunk2]))
        mock_stream_obj.get_final_completion = MagicMock(return_value=mock_final_completion)

        mock_stream.return_value = mock_stream_obj

        result = llm.call("Test question", response_model=TestResponse)

        assert result is not None

        # Verify token usage was tracked
        usage = llm.get_token_usage_summary()
        assert usage.prompt_tokens == 15
        assert usage.completion_tokens == 25
        assert usage.total_tokens == 40
        assert usage.successful_requests == 1


def test_openai_streaming_params_include_usage():
    """
    Test that streaming mode includes stream_options with include_usage=True.

    This ensures the OpenAI API will return usage information in the final chunk.
    """
    llm = LLM(model="openai/gpt-4o", stream=True)

    with patch.object(llm.client.chat.completions, "create") as mock_create:
        mock_stream = MagicMock()
        mock_stream.__iter__ = MagicMock(return_value=iter([]))
        mock_create.return_value = mock_stream

        try:
            llm.call("Hello")
        except Exception:
            pass  # We just want to check the call parameters

        # Verify stream_options was included in the API call
        call_kwargs = mock_create.call_args[1]
        assert call_kwargs.get("stream") is True
        assert call_kwargs.get("stream_options") == {"include_usage": True}

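With these tests in place, the user-visible effect of the fix is that a streaming LLM call reports non-zero token usage. A rough usage sketch, assuming from crewai import LLM and the get_token_usage_summary() accessor used in the tests above (field names follow those tests):

from crewai import LLM

llm = LLM(model="openai/gpt-4o", stream=True)
llm.call("Summarize the plot of Hamlet in one sentence.")

summary = llm.get_token_usage_summary()
# Before this fix these were always 0 in streaming mode; now they reflect
# the usage reported in the final streamed chunk.
print(summary.prompt_tokens, summary.completion_tokens, summary.total_tokens)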