feat: enrich LLM token tracking with reasoning tokens, cache creation tokens (#5389)

Commit fc6792d067 by Lucas Gomide, 2026-04-10 01:22:27 -03:00 (committed by GitHub)
Parent: 84b1b0a0b0
14 changed files with 405 additions and 12 deletions

View File

@@ -172,6 +172,8 @@ class BaseLLM(BaseModel, ABC):
             "completion_tokens": 0,
             "successful_requests": 0,
             "cached_prompt_tokens": 0,
+            "reasoning_tokens": 0,
+            "cache_creation_tokens": 0,
         }
     )
@@ -808,14 +810,24 @@ class BaseLLM(BaseModel, ABC):
         cached_tokens = (
             usage_data.get("cached_tokens")
             or usage_data.get("cached_prompt_tokens")
+            or usage_data.get("cache_read_input_tokens")
             or 0
         )
+        if not cached_tokens:
+            prompt_details = usage_data.get("prompt_tokens_details")
+            if isinstance(prompt_details, dict):
+                cached_tokens = prompt_details.get("cached_tokens", 0) or 0
+        reasoning_tokens = usage_data.get("reasoning_tokens", 0) or 0
+        cache_creation_tokens = usage_data.get("cache_creation_tokens", 0) or 0
         self._token_usage["prompt_tokens"] += prompt_tokens
         self._token_usage["completion_tokens"] += completion_tokens
         self._token_usage["total_tokens"] += prompt_tokens + completion_tokens
         self._token_usage["successful_requests"] += 1
         self._token_usage["cached_prompt_tokens"] += cached_tokens
+        self._token_usage["reasoning_tokens"] += reasoning_tokens
+        self._token_usage["cache_creation_tokens"] += cache_creation_tokens

     def get_token_usage_summary(self) -> UsageMetrics:
         """Get summary of token usage for this LLM instance.

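Net effect of the base-class change: callers keep the same summary API and simply see two extra counters. A minimal sketch of reading them (the crewai import path and the model string are assumptions for illustration; get_token_usage_summary() and the field names come from this commit):

    from crewai import LLM  # assumed import path

    llm = LLM(model="openai/gpt-4o")  # any provider-backed model
    llm.call("Summarize the release notes in one sentence.")  # real provider call; needs credentials

    summary = llm.get_token_usage_summary()  # returns a UsageMetrics
    print(summary.prompt_tokens, summary.completion_tokens)
    print(summary.reasoning_tokens)       # new counter in this commit
    print(summary.cache_creation_tokens)  # new counter in this commit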
View File

@@ -1704,18 +1704,23 @@ class AnthropicCompletion(BaseLLM):
     def _extract_anthropic_token_usage(
         response: Message | BetaMessage,
     ) -> dict[str, Any]:
-        """Extract token usage from Anthropic response."""
+        """Extract token usage and response metadata from Anthropic response."""
         if hasattr(response, "usage") and response.usage:
             usage = response.usage
             input_tokens = getattr(usage, "input_tokens", 0)
             output_tokens = getattr(usage, "output_tokens", 0)
             cache_read_tokens = getattr(usage, "cache_read_input_tokens", 0) or 0
-            return {
+            cache_creation_tokens = (
+                getattr(usage, "cache_creation_input_tokens", 0) or 0
+            )
+            result: dict[str, Any] = {
                 "input_tokens": input_tokens,
                 "output_tokens": output_tokens,
                 "total_tokens": input_tokens + output_tokens,
                 "cached_prompt_tokens": cache_read_tokens,
+                "cache_creation_tokens": cache_creation_tokens,
             }
+            return result
         return {"total_tokens": 0}

     def supports_multimodal(self) -> bool:

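For background, Anthropic reports prompt caching in two directions: cache_read_input_tokens counts tokens served from an existing cache, while cache_creation_input_tokens counts tokens written while building one. A hedged sketch of the extractor on a stubbed response, mirroring how the tests below call it (SimpleNamespace stands in for the SDK's Message/Usage types; the import path and model string are assumptions):

    from types import SimpleNamespace
    from crewai import LLM  # assumed import path

    llm = LLM(model="anthropic/claude-3-5-sonnet-20241022")
    response = SimpleNamespace(
        usage=SimpleNamespace(
            input_tokens=100,
            output_tokens=50,
            cache_read_input_tokens=30,      # served from an existing cache
            cache_creation_input_tokens=20,  # written while creating the cache
        )
    )

    usage = llm._extract_anthropic_token_usage(response)
    assert usage["cached_prompt_tokens"] == 30   # cache reads
    assert usage["cache_creation_tokens"] == 20  # cache writes
    assert usage["total_tokens"] == 150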
View File

@@ -1076,19 +1076,27 @@ class AzureCompletion(BaseLLM):
     @staticmethod
     def _extract_azure_token_usage(response: ChatCompletions) -> dict[str, Any]:
-        """Extract token usage from Azure response."""
+        """Extract token usage and response metadata from Azure response."""
         if hasattr(response, "usage") and response.usage:
             usage = response.usage
             cached_tokens = 0
             prompt_details = getattr(usage, "prompt_tokens_details", None)
             if prompt_details:
                 cached_tokens = getattr(prompt_details, "cached_tokens", 0) or 0
-            return {
+            reasoning_tokens = 0
+            completion_details = getattr(usage, "completion_tokens_details", None)
+            if completion_details:
+                reasoning_tokens = (
+                    getattr(completion_details, "reasoning_tokens", 0) or 0
+                )
+            result: dict[str, Any] = {
                 "prompt_tokens": getattr(usage, "prompt_tokens", 0),
                 "completion_tokens": getattr(usage, "completion_tokens", 0),
                 "total_tokens": getattr(usage, "total_tokens", 0),
                 "cached_prompt_tokens": cached_tokens,
+                "reasoning_tokens": reasoning_tokens,
             }
+            return result
         return {"total_tokens": 0}

     async def aclose(self) -> None:

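One subtlety in the Azure extractor: the trailing "or 0" after each getattr is not redundant. The guards suggest these detail counters can arrive set to None, and getattr's default only applies when the attribute is missing entirely. A self-contained illustration:

    from types import SimpleNamespace

    details = SimpleNamespace(reasoning_tokens=None)

    # getattr's default does NOT fire when the attribute exists but is None...
    assert getattr(details, "reasoning_tokens", 0) is None

    # ...hence the "or 0" pattern used throughout this commit:
    reasoning_tokens = getattr(details, "reasoning_tokens", 0) or 0
    assert reasoning_tokens == 0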
View File

@@ -2025,11 +2025,18 @@ class BedrockCompletion(BaseLLM):
         input_tokens = usage.get("inputTokens", 0)
         output_tokens = usage.get("outputTokens", 0)
         total_tokens = usage.get("totalTokens", input_tokens + output_tokens)
+        raw_cached = (
+            usage.get("cacheReadInputTokenCount")
+            or usage.get("cacheReadInputTokens")
+            or 0
+        )
+        cached_tokens = raw_cached if isinstance(raw_cached, int) else 0
         self._token_usage["prompt_tokens"] += input_tokens
         self._token_usage["completion_tokens"] += output_tokens
         self._token_usage["total_tokens"] += total_tokens
         self._token_usage["successful_requests"] += 1
+        self._token_usage["cached_prompt_tokens"] += cached_tokens

     def supports_function_calling(self) -> bool:
         """Check if the model supports function calling."""

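Bedrock is the only backend here that also type-guards the counter. The reason: an "or" chain returns its first truthy operand unchanged, so a truthy non-int would make the integer "+=" on the accumulator raise TypeError. A small standalone illustration (the string value is hypothetical):

    usage = {"cacheReadInputTokenCount": "30"}  # hypothetical malformed value

    # The `or` chain passes a truthy non-int straight through...
    raw_cached = (
        usage.get("cacheReadInputTokenCount")
        or usage.get("cacheReadInputTokens")
        or 0
    )
    assert raw_cached == "30"

    # ...so the isinstance guard discards it instead of corrupting the counter.
    cached_tokens = raw_cached if isinstance(raw_cached, int) else 0
    assert cached_tokens == 0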
View File

@@ -1306,17 +1306,20 @@ class GeminiCompletion(BaseLLM):
     @staticmethod
     def _extract_token_usage(response: GenerateContentResponse) -> dict[str, Any]:
-        """Extract token usage from Gemini response."""
+        """Extract token usage and response metadata from Gemini response."""
         if response.usage_metadata:
             usage = response.usage_metadata
             cached_tokens = getattr(usage, "cached_content_token_count", 0) or 0
-            return {
+            thinking_tokens = getattr(usage, "thoughts_token_count", 0) or 0
+            result: dict[str, Any] = {
                 "prompt_token_count": getattr(usage, "prompt_token_count", 0),
                 "candidates_token_count": getattr(usage, "candidates_token_count", 0),
                 "total_token_count": getattr(usage, "total_token_count", 0),
                 "total_tokens": getattr(usage, "total_token_count", 0),
                 "cached_prompt_tokens": cached_tokens,
+                "reasoning_tokens": thinking_tokens,
             }
+            return result
         return {"total_tokens": 0}

     @staticmethod

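Gemini's SDK uses its own field names, so the extractor returns the native keys and the normalized ones side by side: thoughts_token_count is surfaced as reasoning_tokens, and total_token_count is duplicated under total_tokens for the provider-agnostic accumulator. A summary of the mapping (the dict name is illustrative, not part of the code):

    # Gemini usage_metadata attribute -> key in the returned usage dict
    GEMINI_USAGE_KEYS = {
        "prompt_token_count": "prompt_token_count",          # kept as-is
        "candidates_token_count": "candidates_token_count",  # kept as-is
        "total_token_count": "total_tokens",                 # native key kept too
        "cached_content_token_count": "cached_prompt_tokens",
        "thoughts_token_count": "reasoning_tokens",
    }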
View File

@@ -1324,19 +1324,23 @@ class OpenAICompletion(BaseLLM):
     ]

     def _extract_responses_token_usage(self, response: Response) -> dict[str, Any]:
-        """Extract token usage from Responses API response."""
+        """Extract token usage and response metadata from Responses API response."""
         if response.usage:
-            result = {
+            result: dict[str, Any] = {
                 "prompt_tokens": response.usage.input_tokens,
                 "completion_tokens": response.usage.output_tokens,
                 "total_tokens": response.usage.total_tokens,
             }
-            # Extract cached prompt tokens from input_tokens_details
             input_details = getattr(response.usage, "input_tokens_details", None)
             if input_details:
                 result["cached_prompt_tokens"] = (
                     getattr(input_details, "cached_tokens", 0) or 0
                 )
+            output_details = getattr(response.usage, "output_tokens_details", None)
+            if output_details:
+                result["reasoning_tokens"] = (
+                    getattr(output_details, "reasoning_tokens", 0) or 0
+                )
             return result
         return {"total_tokens": 0}
@@ -2307,20 +2311,24 @@ class OpenAICompletion(BaseLLM):
     def _extract_openai_token_usage(
         self, response: ChatCompletion | ChatCompletionChunk
     ) -> dict[str, Any]:
-        """Extract token usage from OpenAI ChatCompletion or ChatCompletionChunk response."""
+        """Extract token usage and response metadata from OpenAI ChatCompletion."""
         if hasattr(response, "usage") and response.usage:
             usage = response.usage
-            result = {
+            result: dict[str, Any] = {
                 "prompt_tokens": getattr(usage, "prompt_tokens", 0),
                 "completion_tokens": getattr(usage, "completion_tokens", 0),
                 "total_tokens": getattr(usage, "total_tokens", 0),
             }
-            # Extract cached prompt tokens from prompt_tokens_details
             prompt_details = getattr(usage, "prompt_tokens_details", None)
             if prompt_details:
                 result["cached_prompt_tokens"] = (
                     getattr(prompt_details, "cached_tokens", 0) or 0
                 )
+            completion_details = getattr(usage, "completion_tokens_details", None)
+            if completion_details:
+                result["reasoning_tokens"] = (
+                    getattr(completion_details, "reasoning_tokens", 0) or 0
+                )
             return result
         return {"total_tokens": 0}

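Note the asymmetry with the Anthropic, Azure, and Gemini extractors above, which always emit the new keys: when the detail objects are absent, both OpenAI paths omit cached_prompt_tokens and reasoning_tokens entirely rather than writing zeros. That is safe because the BaseLLM accumulator at the top of this commit reads the keys with .get(..., 0)-style fallbacks, as this check shows (values are hypothetical):

    # What the OpenAI extractor returns when no detail objects are present:
    usage_data = {"prompt_tokens": 50, "completion_tokens": 30, "total_tokens": 80}

    # Same fallback pattern as the BaseLLM hunk in the first file:
    reasoning_tokens = usage_data.get("reasoning_tokens", 0) or 0
    cache_creation_tokens = usage_data.get("cache_creation_tokens", 0) or 0
    assert reasoning_tokens == 0
    assert cache_creation_tokens == 0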
View File

@@ -29,6 +29,14 @@ class UsageMetrics(BaseModel):
     completion_tokens: int = Field(
         default=0, description="Number of tokens used in completions."
     )
+    reasoning_tokens: int = Field(
+        default=0,
+        description="Number of reasoning/thinking tokens (e.g. OpenAI o-series, Gemini thinking).",
+    )
+    cache_creation_tokens: int = Field(
+        default=0,
+        description="Number of cache creation tokens (e.g. Anthropic cache writes).",
+    )
     successful_requests: int = Field(
         default=0, description="Number of successful requests made."
     )
@@ -43,4 +51,6 @@ class UsageMetrics(BaseModel):
         self.prompt_tokens += usage_metrics.prompt_tokens
         self.cached_prompt_tokens += usage_metrics.cached_prompt_tokens
         self.completion_tokens += usage_metrics.completion_tokens
+        self.reasoning_tokens += usage_metrics.reasoning_tokens
+        self.cache_creation_tokens += usage_metrics.cache_creation_tokens
         self.successful_requests += usage_metrics.successful_requests

View File

@@ -174,3 +174,51 @@ class TestEmitCallCompletedEventPassesUsage:
         event = mock_emit.call_args[1]["event"]
         assert isinstance(event, LLMCallCompletedEvent)
         assert event.usage is None
+
+
+class TestUsageMetricsNewFields:
+    def test_add_usage_metrics_aggregates_reasoning_and_cache_creation(self):
+        from crewai.types.usage_metrics import UsageMetrics
+
+        metrics1 = UsageMetrics(
+            total_tokens=100,
+            prompt_tokens=60,
+            completion_tokens=40,
+            cached_prompt_tokens=10,
+            reasoning_tokens=15,
+            cache_creation_tokens=5,
+            successful_requests=1,
+        )
+        metrics2 = UsageMetrics(
+            total_tokens=200,
+            prompt_tokens=120,
+            completion_tokens=80,
+            cached_prompt_tokens=20,
+            reasoning_tokens=25,
+            cache_creation_tokens=10,
+            successful_requests=1,
+        )
+
+        metrics1.add_usage_metrics(metrics2)
+
+        assert metrics1.total_tokens == 300
+        assert metrics1.prompt_tokens == 180
+        assert metrics1.completion_tokens == 120
+        assert metrics1.cached_prompt_tokens == 30
+        assert metrics1.reasoning_tokens == 40
+        assert metrics1.cache_creation_tokens == 15
+        assert metrics1.successful_requests == 2
+
+    def test_new_fields_default_to_zero(self):
+        from crewai.types.usage_metrics import UsageMetrics
+
+        metrics = UsageMetrics()
+        assert metrics.reasoning_tokens == 0
+        assert metrics.cache_creation_tokens == 0
+
+    def test_model_dump_includes_new_fields(self):
+        from crewai.types.usage_metrics import UsageMetrics
+
+        metrics = UsageMetrics(reasoning_tokens=10, cache_creation_tokens=5)
+        dumped = metrics.model_dump()
+        assert dumped["reasoning_tokens"] == 10
+        assert dumped["cache_creation_tokens"] == 5

View File

@@ -1463,3 +1463,45 @@ def test_tool_search_saves_input_tokens():
         f"Expected tool_search ({usage_search.prompt_tokens}) to use fewer input tokens "
         f"than no search ({usage_no_search.prompt_tokens})"
     )
+
+
+def test_anthropic_cache_creation_tokens_extraction():
+    """Test that cache_creation_input_tokens are extracted from Anthropic responses."""
+    llm = LLM(model="anthropic/claude-3-5-sonnet-20241022")
+
+    mock_response = MagicMock()
+    mock_response.content = [MagicMock(text="test response")]
+    mock_response.usage = MagicMock(
+        input_tokens=100,
+        output_tokens=50,
+        cache_read_input_tokens=30,
+        cache_creation_input_tokens=20,
+    )
+    mock_response.stop_reason = None
+    mock_response.model = None
+
+    usage = llm._extract_anthropic_token_usage(mock_response)
+
+    assert usage["input_tokens"] == 100
+    assert usage["output_tokens"] == 50
+    assert usage["total_tokens"] == 150
+    assert usage["cached_prompt_tokens"] == 30
+    assert usage["cache_creation_tokens"] == 20
+
+
+def test_anthropic_missing_cache_fields_default_to_zero():
+    """Test that missing cache fields default to zero."""
+    llm = LLM(model="anthropic/claude-3-5-sonnet-20241022")
+
+    mock_response = MagicMock()
+    mock_response.content = [MagicMock(text="test response")]
+    # spec limits the mock to these attributes, so the cache fields are
+    # genuinely absent and getattr(..., 0) falls back to zero.
+    mock_response.usage = MagicMock(
+        input_tokens=40,
+        output_tokens=20,
+        spec=["input_tokens", "output_tokens"],
+    )
+
+    usage = llm._extract_anthropic_token_usage(mock_response)
+
+    assert usage["cached_prompt_tokens"] == 0
+    assert usage["cache_creation_tokens"] == 0

View File

@@ -1403,3 +1403,44 @@ def test_azure_stop_words_still_applied_to_regular_responses():
     assert "Observation:" not in result
     assert "Found results" not in result
     assert "I need to search for more information" in result
+
+
+def test_azure_reasoning_tokens_and_cached_tokens():
+    """Test that reasoning_tokens and cached_tokens are extracted from Azure responses."""
+    llm = LLM(model="azure/gpt-4")
+
+    mock_response = MagicMock()
+    mock_response.usage = MagicMock(
+        prompt_tokens=100,
+        completion_tokens=200,
+        total_tokens=300,
+    )
+    mock_response.usage.prompt_tokens_details = MagicMock(cached_tokens=40)
+    mock_response.usage.completion_tokens_details = MagicMock(reasoning_tokens=60)
+
+    usage = llm._extract_azure_token_usage(mock_response)
+
+    assert usage["prompt_tokens"] == 100
+    assert usage["completion_tokens"] == 200
+    assert usage["total_tokens"] == 300
+    assert usage["cached_prompt_tokens"] == 40
+    assert usage["reasoning_tokens"] == 60
+
+
+def test_azure_no_detail_fields():
+    """Test Azure extraction without detail fields."""
+    llm = LLM(model="azure/gpt-4")
+
+    mock_response = MagicMock()
+    mock_response.usage = MagicMock(
+        prompt_tokens=50,
+        completion_tokens=30,
+        total_tokens=80,
+    )
+    mock_response.usage.prompt_tokens_details = None
+    mock_response.usage.completion_tokens_details = None
+
+    usage = llm._extract_azure_token_usage(mock_response)
+
+    assert usage["prompt_tokens"] == 50
+    assert usage["completion_tokens"] == 30
+    assert usage["cached_prompt_tokens"] == 0
+    assert usage["reasoning_tokens"] == 0

View File

@@ -1175,3 +1175,81 @@ def test_bedrock_tool_results_not_merged_across_assistant_messages():
     )
     assert tool_result_messages[0]["content"][0]["toolResult"]["toolUseId"] == "call_a"
     assert tool_result_messages[1]["content"][0]["toolResult"]["toolUseId"] == "call_b"
+
+
+def test_bedrock_cached_token_tracking():
+    """Test that cached tokens (cacheReadInputTokenCount) are tracked for Bedrock."""
+    llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")
+
+    with patch.object(llm._client, 'converse') as mock_converse:
+        mock_response = {
+            'output': {
+                'message': {
+                    'role': 'assistant',
+                    'content': [{'text': 'test response'}]
+                }
+            },
+            'usage': {
+                'inputTokens': 100,
+                'outputTokens': 50,
+                'totalTokens': 150,
+                'cacheReadInputTokenCount': 30,
+            }
+        }
+        mock_converse.return_value = mock_response
+
+        result = llm.call("Hello")
+
+        assert result == "test response"
+        assert llm._token_usage['prompt_tokens'] == 100
+        assert llm._token_usage['completion_tokens'] == 50
+        assert llm._token_usage['total_tokens'] == 150
+        assert llm._token_usage['cached_prompt_tokens'] == 30
+
+
+def test_bedrock_cached_token_alternate_key():
+    """Test that the alternate key cacheReadInputTokens also works."""
+    llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")
+
+    with patch.object(llm._client, 'converse') as mock_converse:
+        mock_response = {
+            'output': {
+                'message': {
+                    'role': 'assistant',
+                    'content': [{'text': 'test response'}]
+                }
+            },
+            'usage': {
+                'inputTokens': 80,
+                'outputTokens': 40,
+                'totalTokens': 120,
+                'cacheReadInputTokens': 25,
+            }
+        }
+        mock_converse.return_value = mock_response
+
+        llm.call("Hello")
+
+        assert llm._token_usage['cached_prompt_tokens'] == 25
+
+
+def test_bedrock_no_cache_tokens_defaults_to_zero():
+    """Test that missing cache token keys default to zero."""
+    llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")
+
+    with patch.object(llm._client, 'converse') as mock_converse:
+        mock_response = {
+            'output': {
+                'message': {
+                    'role': 'assistant',
+                    'content': [{'text': 'test response'}]
+                }
+            },
+            'usage': {
+                'inputTokens': 60,
+                'outputTokens': 30,
+                'totalTokens': 90,
+            }
+        }
+        mock_converse.return_value = mock_response
+
+        llm.call("Hello")
+
+        assert llm._token_usage['cached_prompt_tokens'] == 0

View File

@@ -1190,3 +1190,42 @@ def test_gemini_cached_prompt_tokens_with_tools():
     # cached_prompt_tokens should be populated (may be 0 if Gemini
     # doesn't cache for this particular request, but the field should exist)
     assert usage.cached_prompt_tokens >= 0
+
+
+def test_gemini_reasoning_tokens_extraction():
+    """Test that thoughts_token_count is extracted as reasoning_tokens from Gemini."""
+    llm = LLM(model="google/gemini-2.0-flash-001")
+
+    mock_response = MagicMock()
+    mock_response.usage_metadata = MagicMock(
+        prompt_token_count=100,
+        candidates_token_count=50,
+        total_token_count=150,
+        cached_content_token_count=10,
+        thoughts_token_count=30,
+    )
+
+    usage = llm._extract_token_usage(mock_response)
+
+    assert usage["prompt_token_count"] == 100
+    assert usage["candidates_token_count"] == 50
+    assert usage["total_tokens"] == 150
+    assert usage["cached_prompt_tokens"] == 10
+    assert usage["reasoning_tokens"] == 30
+
+
+def test_gemini_no_thinking_tokens_defaults_to_zero():
+    """Test that missing thoughts_token_count defaults to zero."""
+    llm = LLM(model="google/gemini-2.0-flash-001")
+
+    mock_response = MagicMock()
+    mock_response.usage_metadata = MagicMock(
+        prompt_token_count=80,
+        candidates_token_count=40,
+        total_token_count=120,
+        cached_content_token_count=0,
+        thoughts_token_count=None,
+    )
+    mock_response.candidates = []
+
+    usage = llm._extract_token_usage(mock_response)
+
+    assert usage["reasoning_tokens"] == 0
+    assert usage["cached_prompt_tokens"] == 0

View File

@@ -1929,6 +1929,47 @@ def test_openai_streaming_returns_tool_calls_without_available_functions():
     assert result[0]["type"] == "function"
+
+
+def test_openai_responses_api_reasoning_tokens_extraction():
+    """Test that reasoning_tokens are extracted from Responses API responses."""
+    llm = LLM(model="openai/gpt-4o")
+
+    mock_response = MagicMock()
+    mock_response.usage = MagicMock(
+        input_tokens=100,
+        output_tokens=200,
+        total_tokens=300,
+    )
+    mock_response.usage.input_tokens_details = MagicMock(cached_tokens=25)
+    mock_response.usage.output_tokens_details = MagicMock(reasoning_tokens=80)
+
+    usage = llm._extract_responses_token_usage(mock_response)
+
+    assert usage["prompt_tokens"] == 100
+    assert usage["completion_tokens"] == 200
+    assert usage["total_tokens"] == 300
+    assert usage["cached_prompt_tokens"] == 25
+    assert usage["reasoning_tokens"] == 80
+
+
+def test_openai_responses_api_no_detail_fields_omitted():
+    """Test that reasoning/cached fields are omitted when Responses API details are absent."""
+    llm = LLM(model="openai/gpt-4o")
+
+    mock_response = MagicMock()
+    mock_response.usage = MagicMock(
+        input_tokens=50,
+        output_tokens=30,
+        total_tokens=80,
+    )
+    mock_response.usage.input_tokens_details = None
+    mock_response.usage.output_tokens_details = None
+
+    usage = llm._extract_responses_token_usage(mock_response)
+
+    assert usage["prompt_tokens"] == 50
+    assert usage["completion_tokens"] == 30
+    assert "cached_prompt_tokens" not in usage
+    assert "reasoning_tokens" not in usage
+
+
 @pytest.mark.asyncio
 async def test_openai_async_streaming_returns_tool_calls_without_available_functions():
     """Test that async streaming returns tool calls list when available_functions is None.
@@ -2018,3 +2059,44 @@ async def test_openai_async_streaming_returns_tool_calls_without_available_functions():
     assert result[0]["function"]["arguments"] == '{"expression": "1+1"}'
     assert result[0]["id"] == "call_abc123"
     assert result[0]["type"] == "function"
+
+
+def test_openai_reasoning_tokens_extraction():
+    """Test that reasoning_tokens are extracted from OpenAI o-series responses."""
+    llm = LLM(model="openai/gpt-4o")
+
+    mock_response = MagicMock()
+    mock_response.usage = MagicMock(
+        prompt_tokens=100,
+        completion_tokens=200,
+        total_tokens=300,
+    )
+    mock_response.usage.prompt_tokens_details = MagicMock(cached_tokens=25)
+    mock_response.usage.completion_tokens_details = MagicMock(reasoning_tokens=80)
+
+    usage = llm._extract_openai_token_usage(mock_response)
+
+    assert usage["prompt_tokens"] == 100
+    assert usage["completion_tokens"] == 200
+    assert usage["total_tokens"] == 300
+    assert usage["cached_prompt_tokens"] == 25
+    assert usage["reasoning_tokens"] == 80
+
+
+def test_openai_no_detail_fields_omitted():
+    """Test that reasoning/cached fields are omitted when details are absent."""
+    llm = LLM(model="openai/gpt-4o")
+
+    mock_response = MagicMock()
+    mock_response.usage = MagicMock(
+        prompt_tokens=50,
+        completion_tokens=30,
+        total_tokens=80,
+    )
+    mock_response.usage.prompt_tokens_details = None
+    mock_response.usage.completion_tokens_details = None
+
+    usage = llm._extract_openai_token_usage(mock_response)
+
+    assert usage["prompt_tokens"] == 50
+    assert usage["completion_tokens"] == 30
+    assert "cached_prompt_tokens" not in usage
+    assert "reasoning_tokens" not in usage

View File

@@ -1001,6 +1001,8 @@ def test_usage_info_non_streaming_with_call():
         "completion_tokens": 0,
         "successful_requests": 0,
         "cached_prompt_tokens": 0,
+        "reasoning_tokens": 0,
+        "cache_creation_tokens": 0,
     }
     assert llm.stream is False
@@ -1025,6 +1027,8 @@ def test_usage_info_streaming_with_call():
         "completion_tokens": 0,
         "successful_requests": 0,
         "cached_prompt_tokens": 0,
+        "reasoning_tokens": 0,
+        "cache_creation_tokens": 0,
     }
     assert llm.stream is True
@@ -1056,6 +1060,8 @@ async def test_usage_info_non_streaming_with_acall():
         "completion_tokens": 0,
         "successful_requests": 0,
         "cached_prompt_tokens": 0,
+        "reasoning_tokens": 0,
+        "cache_creation_tokens": 0,
     }

     with patch.object(
@@ -1089,6 +1095,8 @@ async def test_usage_info_non_streaming_with_acall_and_stop():
         "completion_tokens": 0,
         "successful_requests": 0,
         "cached_prompt_tokens": 0,
+        "reasoning_tokens": 0,
+        "cache_creation_tokens": 0,
     }

     with patch.object(
@@ -1121,6 +1129,8 @@ async def test_usage_info_streaming_with_acall():
         "completion_tokens": 0,
         "successful_requests": 0,
        "cached_prompt_tokens": 0,
+        "reasoning_tokens": 0,
+        "cache_creation_tokens": 0,
     }

     with patch.object(