fix(gemini): include thoughts_token_count in completion tokens

This commit is contained in:
Greyson LaLonde
2026-05-04 21:03:38 +08:00
committed by GitHub
parent f579aa53ae
commit 6494d68ffc
2 changed files with 32 additions and 1 deletions

View File

@@ -1328,9 +1328,11 @@ class GeminiCompletion(BaseLLM):
usage = response.usage_metadata
cached_tokens = getattr(usage, "cached_content_token_count", 0) or 0
thinking_tokens = getattr(usage, "thoughts_token_count", 0) or 0
candidates_tokens = getattr(usage, "candidates_token_count", 0) or 0
result: dict[str, Any] = {
"prompt_token_count": getattr(usage, "prompt_token_count", 0),
"candidates_token_count": getattr(usage, "candidates_token_count", 0),
"candidates_token_count": candidates_tokens,
"completion_tokens": candidates_tokens + thinking_tokens,
"total_token_count": getattr(usage, "total_token_count", 0),
"total_tokens": getattr(usage, "total_token_count", 0),
"cached_prompt_tokens": cached_tokens,

View File

@@ -596,6 +596,35 @@ def test_gemini_token_usage_tracking():
assert usage.total_tokens > 0
def test_gemini_thoughts_tokens_counted_in_completion_and_total():
    """Thinking-model token accounting for Gemini.

    The provider reports reasoning tokens separately as
    ``thoughts_token_count``; they must be folded into ``completion_tokens``
    so that the tracked running totals agree with the API's
    ``total_token_count``.
    """
    from crewai.llms.providers.gemini.completion import GeminiCompletion

    llm = GeminiCompletion(model="gemini-2.0-flash-001")

    # Fake a response whose usage metadata includes 25 "thinking" tokens
    # on top of 50 ordinary candidate tokens (100 prompt + 75 output = 175).
    mock_response = MagicMock()
    mock_response.usage_metadata = MagicMock(
        prompt_token_count=100,
        candidates_token_count=50,
        thoughts_token_count=25,
        total_token_count=175,
        cached_content_token_count=0,
    )

    token_usage = llm._extract_token_usage(mock_response)

    # Candidate count stays raw; completion folds in the thinking tokens.
    assert token_usage["candidates_token_count"] == 50
    assert token_usage["completion_tokens"] == 50 + 25
    assert token_usage["reasoning_tokens"] == 25

    # The aggregated summary must reflect the same arithmetic.
    llm._track_token_usage_internal(token_usage)
    summary = llm.get_token_usage_summary()
    assert summary.prompt_tokens == 100
    assert summary.completion_tokens == 75
    assert summary.total_tokens == 175
    assert summary.reasoning_tokens == 25
@pytest.mark.vcr()
def test_gemini_tool_returning_float():
"""