Fix streaming token usage tracking in OpenAI provider

This commit fixes issue #4056 where token usage always returns 0 when using
async streaming crew kickoff.

Root cause: The streaming completion methods (_handle_streaming_completion and
_ahandle_streaming_completion) in OpenAICompletion never called
_track_token_usage_internal(), unlike the non-streaming methods.

Changes:
- Add stream_options={'include_usage': True} to streaming params so OpenAI API
  returns usage information in the final chunk
- Extract and track token usage from the final chunk in sync streaming
- Extract and track token usage from the final chunk in async streaming
- Extract and track token usage from final_completion in response_model paths
- Add _extract_chunk_token_usage method for ChatCompletionChunk objects
- Add tests to verify streaming token usage tracking works correctly

Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
Devin AI
2025-12-10 08:28:31 +00:00
parent 34b909367b
commit 558fc6eda4
2 changed files with 168 additions and 0 deletions

View File

@@ -297,6 +297,7 @@ class OpenAICompletion(BaseLLM):
}
if self.stream:
params["stream"] = self.stream
params["stream_options"] = {"include_usage": True}
params.update(self.additional_params)
@@ -545,6 +546,9 @@ class OpenAICompletion(BaseLLM):
final_completion = stream.get_final_completion()
if final_completion and final_completion.choices:
usage = self._extract_openai_token_usage(final_completion)
self._track_token_usage_internal(usage)
parsed_result = final_completion.choices[0].message.parsed
if parsed_result:
structured_json = parsed_result.model_dump_json()
@@ -564,7 +568,11 @@ class OpenAICompletion(BaseLLM):
self.client.chat.completions.create(**params)
)
usage_data: dict[str, Any] | None = None
for completion_chunk in completion_stream:
if completion_chunk.usage is not None:
usage_data = self._extract_chunk_token_usage(completion_chunk)
if not completion_chunk.choices:
continue
@@ -593,6 +601,9 @@ class OpenAICompletion(BaseLLM):
if tool_call.function and tool_call.function.arguments:
tool_calls[call_id]["arguments"] += tool_call.function.arguments
if usage_data:
self._track_token_usage_internal(usage_data)
if tool_calls and available_functions:
for call_data in tool_calls.values():
function_name = call_data["name"]
@@ -785,7 +796,11 @@ class OpenAICompletion(BaseLLM):
] = await self.async_client.chat.completions.create(**params)
accumulated_content = ""
usage_data: dict[str, Any] | None = None
async for chunk in completion_stream:
if chunk.usage is not None:
usage_data = self._extract_chunk_token_usage(chunk)
if not chunk.choices:
continue
@@ -800,6 +815,9 @@ class OpenAICompletion(BaseLLM):
from_agent=from_agent,
)
if usage_data:
self._track_token_usage_internal(usage_data)
try:
parsed_object = response_model.model_validate_json(accumulated_content)
structured_json = parsed_object.model_dump_json()
@@ -828,7 +846,11 @@ class OpenAICompletion(BaseLLM):
ChatCompletionChunk
] = await self.async_client.chat.completions.create(**params)
usage_data = None
async for chunk in stream:
if chunk.usage is not None:
usage_data = self._extract_chunk_token_usage(chunk)
if not chunk.choices:
continue
@@ -857,6 +879,9 @@ class OpenAICompletion(BaseLLM):
if tool_call.function and tool_call.function.arguments:
tool_calls[call_id]["arguments"] += tool_call.function.arguments
if usage_data:
self._track_token_usage_internal(usage_data)
if tool_calls and available_functions:
for call_data in tool_calls.values():
function_name = call_data["name"]
@@ -955,6 +980,19 @@ class OpenAICompletion(BaseLLM):
}
return {"total_tokens": 0}
def _extract_chunk_token_usage(
self, chunk: ChatCompletionChunk
) -> dict[str, Any]:
"""Extract token usage from OpenAI ChatCompletionChunk (streaming response)."""
if hasattr(chunk, "usage") and chunk.usage:
usage = chunk.usage
return {
"prompt_tokens": getattr(usage, "prompt_tokens", 0),
"completion_tokens": getattr(usage, "completion_tokens", 0),
"total_tokens": getattr(usage, "total_tokens", 0),
}
return {"total_tokens": 0}
def _format_messages(self, messages: str | list[LLMMessage]) -> list[LLMMessage]:
"""Format messages for OpenAI API."""
base_formatted = super()._format_messages(messages)

View File

@@ -592,3 +592,133 @@ def test_openai_response_format_none():
assert isinstance(result, str)
assert len(result) > 0
def test_openai_streaming_tracks_token_usage():
"""
Test that streaming mode correctly tracks token usage.
This test verifies the fix for issue #4056 where token usage was always 0
when using streaming mode.
"""
llm = LLM(model="openai/gpt-4o", stream=True)
# Create mock chunks with usage in the final chunk
mock_chunk1 = MagicMock()
mock_chunk1.choices = [MagicMock()]
mock_chunk1.choices[0].delta = MagicMock()
mock_chunk1.choices[0].delta.content = "Hello "
mock_chunk1.choices[0].delta.tool_calls = None
mock_chunk1.usage = None
mock_chunk2 = MagicMock()
mock_chunk2.choices = [MagicMock()]
mock_chunk2.choices[0].delta = MagicMock()
mock_chunk2.choices[0].delta.content = "World!"
mock_chunk2.choices[0].delta.tool_calls = None
mock_chunk2.usage = None
# Final chunk with usage information (when stream_options={"include_usage": True})
mock_chunk3 = MagicMock()
mock_chunk3.choices = []
mock_chunk3.usage = MagicMock()
mock_chunk3.usage.prompt_tokens = 10
mock_chunk3.usage.completion_tokens = 20
mock_chunk3.usage.total_tokens = 30
mock_stream = MagicMock()
mock_stream.__iter__ = MagicMock(return_value=iter([mock_chunk1, mock_chunk2, mock_chunk3]))
with patch.object(llm.client.chat.completions, "create", return_value=mock_stream):
result = llm.call("Hello")
assert result == "Hello World!"
# Verify token usage was tracked
usage = llm.get_token_usage_summary()
assert usage.prompt_tokens == 10
assert usage.completion_tokens == 20
assert usage.total_tokens == 30
assert usage.successful_requests == 1
def test_openai_streaming_with_response_model_tracks_token_usage():
"""
Test that streaming with response_model correctly tracks token usage.
This test verifies the fix for issue #4056 where token usage was always 0
when using streaming mode with response_model.
"""
from pydantic import BaseModel
class TestResponse(BaseModel):
"""Test response model."""
answer: str
confidence: float
llm = LLM(model="openai/gpt-4o", stream=True)
with patch.object(llm.client.beta.chat.completions, "stream") as mock_stream:
# Create mock chunks with content.delta event structure
mock_chunk1 = MagicMock()
mock_chunk1.type = "content.delta"
mock_chunk1.delta = '{"answer": "test", '
mock_chunk2 = MagicMock()
mock_chunk2.type = "content.delta"
mock_chunk2.delta = '"confidence": 0.95}'
# Create mock final completion with parsed result and usage
mock_parsed = TestResponse(answer="test", confidence=0.95)
mock_message = MagicMock()
mock_message.parsed = mock_parsed
mock_choice = MagicMock()
mock_choice.message = mock_message
mock_final_completion = MagicMock()
mock_final_completion.choices = [mock_choice]
mock_final_completion.usage = MagicMock()
mock_final_completion.usage.prompt_tokens = 15
mock_final_completion.usage.completion_tokens = 25
mock_final_completion.usage.total_tokens = 40
# Create mock stream context manager
mock_stream_obj = MagicMock()
mock_stream_obj.__enter__ = MagicMock(return_value=mock_stream_obj)
mock_stream_obj.__exit__ = MagicMock(return_value=None)
mock_stream_obj.__iter__ = MagicMock(return_value=iter([mock_chunk1, mock_chunk2]))
mock_stream_obj.get_final_completion = MagicMock(return_value=mock_final_completion)
mock_stream.return_value = mock_stream_obj
result = llm.call("Test question", response_model=TestResponse)
assert result is not None
# Verify token usage was tracked
usage = llm.get_token_usage_summary()
assert usage.prompt_tokens == 15
assert usage.completion_tokens == 25
assert usage.total_tokens == 40
assert usage.successful_requests == 1
def test_openai_streaming_params_include_usage():
"""
Test that streaming mode includes stream_options with include_usage=True.
This ensures the OpenAI API will return usage information in the final chunk.
"""
llm = LLM(model="openai/gpt-4o", stream=True)
with patch.object(llm.client.chat.completions, "create") as mock_create:
mock_stream = MagicMock()
mock_stream.__iter__ = MagicMock(return_value=iter([]))
mock_create.return_value = mock_stream
try:
llm.call("Hello")
except Exception:
pass # We just want to check the call parameters
# Verify stream_options was included in the API call
call_kwargs = mock_create.call_args[1]
assert call_kwargs.get("stream") is True
assert call_kwargs.get("stream_options") == {"include_usage": True}