From 0976c42c6b3df77f04eacf21e5a77e896f229291 Mon Sep 17 00:00:00 2001
From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com>
Date: Sat, 3 Jan 2026 19:28:08 +0000
Subject: [PATCH] fix: track token usage in litellm non-streaming and async calls
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This fixes GitHub issue #4170 where token usage metrics were not being
updated when using litellm with streaming responses and async calls.

Changes:
- Add token usage tracking to _handle_non_streaming_response
- Add token usage tracking to _ahandle_non_streaming_response
- Add token usage tracking to _ahandle_streaming_response
- Fix sync streaming to track usage in both code paths
- Convert usage objects to dicts before passing to _track_token_usage_internal
- Add comprehensive tests for token usage tracking in all scenarios

Co-Authored-By: João
---
 lib/crewai/src/crewai/llm.py                  |  77 +++-
 .../llms/litellm/test_litellm_token_usage.py  | 369 ++++++++++++++++++
 2 files changed, 441 insertions(+), 5 deletions(-)
 create mode 100644 lib/crewai/tests/llms/litellm/test_litellm_token_usage.py

diff --git a/lib/crewai/src/crewai/llm.py b/lib/crewai/src/crewai/llm.py
index 77053deeb..8b92741f7 100644
--- a/lib/crewai/src/crewai/llm.py
+++ b/lib/crewai/src/crewai/llm.py
@@ -928,7 +928,17 @@ class LLM(BaseLLM):
         if not tool_calls or not available_functions:
             # Track token usage and log callbacks if available in streaming mode
             if usage_info:
-                self._track_token_usage_internal(usage_info)
+                # Convert usage object to dict if needed
+                if hasattr(usage_info, "__dict__"):
+                    usage_dict = {
+                        "prompt_tokens": getattr(usage_info, "prompt_tokens", 0),
+                        "completion_tokens": getattr(usage_info, "completion_tokens", 0),
+                        "total_tokens": getattr(usage_info, "total_tokens", 0),
+                        "cached_tokens": getattr(usage_info, "cached_tokens", 0),
+                    }
+                else:
+                    usage_dict = usage_info
+                self._track_token_usage_internal(usage_dict)
                 self._handle_streaming_callbacks(callbacks, usage_info, last_chunk)

             if response_model and self.is_litellm:
@@ -964,7 +974,17 @@ class LLM(BaseLLM):

         # --- 10) Track token usage and log callbacks if available in streaming mode
         if usage_info:
-            self._track_token_usage_internal(usage_info)
+            # Convert usage object to dict if needed
+            if hasattr(usage_info, "__dict__"):
+                usage_dict = {
+                    "prompt_tokens": getattr(usage_info, "prompt_tokens", 0),
+                    "completion_tokens": getattr(usage_info, "completion_tokens", 0),
+                    "total_tokens": getattr(usage_info, "total_tokens", 0),
+                    "cached_tokens": getattr(usage_info, "cached_tokens", 0),
+                }
+            else:
+                usage_dict = usage_info
+            self._track_token_usage_internal(usage_dict)
             self._handle_streaming_callbacks(callbacks, usage_info, last_chunk)

         # --- 11) Emit completion event and return response
@@ -1173,7 +1193,23 @@ class LLM(BaseLLM):
                 0
             ].message
             text_response = response_message.content or ""
-            # --- 3) Handle callbacks with usage info
+
+            # --- 3a) Track token usage internally
+            usage_info = getattr(response, "usage", None)
+            if usage_info:
+                # Convert usage object to dict if needed
+                if hasattr(usage_info, "__dict__"):
+                    usage_dict = {
+                        "prompt_tokens": getattr(usage_info, "prompt_tokens", 0),
+                        "completion_tokens": getattr(usage_info, "completion_tokens", 0),
+                        "total_tokens": getattr(usage_info, "total_tokens", 0),
+                        "cached_tokens": getattr(usage_info, "cached_tokens", 0),
+                    }
+                else:
+                    usage_dict = usage_info
+                self._track_token_usage_internal(usage_dict)
+
+            # --- 3b) Handle callbacks with usage info
             if callbacks and len(callbacks) > 0:
                 for callback in callbacks:
                     if hasattr(callback, "log_success_event"):
@@ -1293,10 +1329,24 @@ class LLM(BaseLLM):
         ].message
         text_response = response_message.content or ""

+        # Track token usage internally
+        usage_info = getattr(response, "usage", None)
+        if usage_info:
+            # Convert usage object to dict if needed
+            if hasattr(usage_info, "__dict__"):
+                usage_dict = {
+                    "prompt_tokens": getattr(usage_info, "prompt_tokens", 0),
+                    "completion_tokens": getattr(usage_info, "completion_tokens", 0),
+                    "total_tokens": getattr(usage_info, "total_tokens", 0),
+                    "cached_tokens": getattr(usage_info, "cached_tokens", 0),
+                }
+            else:
+                usage_dict = usage_info
+            self._track_token_usage_internal(usage_dict)
+
         if callbacks and len(callbacks) > 0:
             for callback in callbacks:
                 if hasattr(callback, "log_success_event"):
-                    usage_info = getattr(response, "usage", None)
                     if usage_info:
                         callback.log_success_event(
                             kwargs=params,
@@ -1381,7 +1431,10 @@ class LLM(BaseLLM):
             if not isinstance(chunk.choices, type):
                 choices = chunk.choices

-            if hasattr(chunk, "usage") and chunk.usage is not None:
+            # Try to extract usage information if available
+            if isinstance(chunk, dict) and "usage" in chunk:
+                usage_info = chunk["usage"]
+            elif hasattr(chunk, "usage") and chunk.usage is not None:
                 usage_info = chunk.usage

             if choices and len(choices) > 0:
@@ -1434,6 +1487,20 @@
                 ),
             )

+        # Track token usage internally
+        if usage_info:
+            # Convert usage object to dict if needed
+            if hasattr(usage_info, "__dict__"):
+                usage_dict = {
+                    "prompt_tokens": getattr(usage_info, "prompt_tokens", 0),
+                    "completion_tokens": getattr(usage_info, "completion_tokens", 0),
+                    "total_tokens": getattr(usage_info, "total_tokens", 0),
+                    "cached_tokens": getattr(usage_info, "cached_tokens", 0),
+                }
+            else:
+                usage_dict = usage_info
+            self._track_token_usage_internal(usage_dict)
+
         if callbacks and len(callbacks) > 0 and usage_info:
             for callback in callbacks:
                 if hasattr(callback, "log_success_event"):
diff --git a/lib/crewai/tests/llms/litellm/test_litellm_token_usage.py b/lib/crewai/tests/llms/litellm/test_litellm_token_usage.py
new file mode 100644
index 000000000..ea0c27624
--- /dev/null
+++ b/lib/crewai/tests/llms/litellm/test_litellm_token_usage.py
@@ -0,0 +1,369 @@
+"""Tests for LiteLLM token usage tracking functionality.
+
+These tests verify that token usage metrics are properly tracked for:
+- Non-streaming responses
+- Async non-streaming responses
+- Async streaming responses
+
+This addresses GitHub issue #4170 where token usage metrics were not being
+updated when using litellm with streaming responses and async calls.
+""" + +from collections.abc import AsyncIterator +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from crewai.llm import LLM + + +class MockUsage: + """Mock usage object that mimics litellm's usage response.""" + + def __init__( + self, + prompt_tokens: int = 10, + completion_tokens: int = 20, + total_tokens: int = 30, + ): + self.prompt_tokens = prompt_tokens + self.completion_tokens = completion_tokens + self.total_tokens = total_tokens + + +class MockMessage: + """Mock message object that mimics litellm's message response.""" + + def __init__(self, content: str = "Test response"): + self.content = content + self.tool_calls = None + + +class MockChoice: + """Mock choice object that mimics litellm's choice response.""" + + def __init__(self, content: str = "Test response"): + self.message = MockMessage(content) + + +class MockResponse: + """Mock response object that mimics litellm's completion response.""" + + def __init__( + self, + content: str = "Test response", + prompt_tokens: int = 10, + completion_tokens: int = 20, + ): + self.choices = [MockChoice(content)] + self.usage = MockUsage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ) + + +class MockStreamDelta: + """Mock delta object for streaming responses.""" + + def __init__(self, content: str | None = None): + self.content = content + self.tool_calls = None + + +class MockStreamChoice: + """Mock choice object for streaming responses.""" + + def __init__(self, content: str | None = None): + self.delta = MockStreamDelta(content) + + +class MockStreamChunk: + """Mock chunk object for streaming responses.""" + + def __init__( + self, + content: str | None = None, + usage: MockUsage | None = None, + ): + self.choices = [MockStreamChoice(content)] + self.usage = usage + + +def test_non_streaming_response_tracks_token_usage(): + """Test that non-streaming responses properly track token usage.""" + llm = LLM(model="gpt-4o-mini", is_litellm=True, stream=False) + + mock_response = MockResponse( + content="Hello, world!", + prompt_tokens=15, + completion_tokens=25, + ) + + with patch("litellm.completion", return_value=mock_response): + result = llm.call("Say hello") + + assert result == "Hello, world!" 
+ + # Verify token usage was tracked + usage_summary = llm.get_token_usage_summary() + assert usage_summary.prompt_tokens == 15 + assert usage_summary.completion_tokens == 25 + assert usage_summary.total_tokens == 40 + assert usage_summary.successful_requests == 1 + + +def test_non_streaming_response_accumulates_token_usage(): + """Test that multiple non-streaming calls accumulate token usage.""" + llm = LLM(model="gpt-4o-mini", is_litellm=True, stream=False) + + mock_response1 = MockResponse( + content="First response", + prompt_tokens=10, + completion_tokens=20, + ) + mock_response2 = MockResponse( + content="Second response", + prompt_tokens=15, + completion_tokens=25, + ) + + with patch("litellm.completion") as mock_completion: + mock_completion.return_value = mock_response1 + llm.call("First call") + + mock_completion.return_value = mock_response2 + llm.call("Second call") + + # Verify accumulated token usage + usage_summary = llm.get_token_usage_summary() + assert usage_summary.prompt_tokens == 25 # 10 + 15 + assert usage_summary.completion_tokens == 45 # 20 + 25 + assert usage_summary.total_tokens == 70 # 30 + 40 + assert usage_summary.successful_requests == 2 + + +@pytest.mark.asyncio +async def test_async_non_streaming_response_tracks_token_usage(): + """Test that async non-streaming responses properly track token usage.""" + llm = LLM(model="gpt-4o-mini", is_litellm=True, stream=False) + + mock_response = MockResponse( + content="Async hello!", + prompt_tokens=12, + completion_tokens=18, + ) + + with patch("litellm.acompletion", new_callable=AsyncMock) as mock_acompletion: + mock_acompletion.return_value = mock_response + result = await llm.acall("Say hello async") + + assert result == "Async hello!" + + # Verify token usage was tracked + usage_summary = llm.get_token_usage_summary() + assert usage_summary.prompt_tokens == 12 + assert usage_summary.completion_tokens == 18 + assert usage_summary.total_tokens == 30 + assert usage_summary.successful_requests == 1 + + +@pytest.mark.asyncio +async def test_async_non_streaming_response_accumulates_token_usage(): + """Test that multiple async non-streaming calls accumulate token usage.""" + llm = LLM(model="gpt-4o-mini", is_litellm=True, stream=False) + + mock_response1 = MockResponse( + content="First async response", + prompt_tokens=8, + completion_tokens=12, + ) + mock_response2 = MockResponse( + content="Second async response", + prompt_tokens=10, + completion_tokens=15, + ) + + with patch("litellm.acompletion", new_callable=AsyncMock) as mock_acompletion: + mock_acompletion.return_value = mock_response1 + await llm.acall("First async call") + + mock_acompletion.return_value = mock_response2 + await llm.acall("Second async call") + + # Verify accumulated token usage + usage_summary = llm.get_token_usage_summary() + assert usage_summary.prompt_tokens == 18 # 8 + 10 + assert usage_summary.completion_tokens == 27 # 12 + 15 + assert usage_summary.total_tokens == 45 # 20 + 25 + assert usage_summary.successful_requests == 2 + + +@pytest.mark.asyncio +async def test_async_streaming_response_tracks_token_usage(): + """Test that async streaming responses properly track token usage.""" + llm = LLM(model="gpt-4o-mini", is_litellm=True, stream=True) + + # Create mock streaming chunks + chunks = [ + MockStreamChunk(content="Hello"), + MockStreamChunk(content=", "), + MockStreamChunk(content="world"), + MockStreamChunk(content="!"), + # Final chunk with usage info (this is how litellm typically sends usage) + MockStreamChunk( + content=None, + 
usage=MockUsage(prompt_tokens=20, completion_tokens=30, total_tokens=50), + ), + ] + + async def mock_async_generator() -> AsyncIterator[MockStreamChunk]: + for chunk in chunks: + yield chunk + + with patch("litellm.acompletion", new_callable=AsyncMock) as mock_acompletion: + mock_acompletion.return_value = mock_async_generator() + result = await llm.acall("Say hello streaming") + + assert result == "Hello, world!" + + # Verify token usage was tracked + usage_summary = llm.get_token_usage_summary() + assert usage_summary.prompt_tokens == 20 + assert usage_summary.completion_tokens == 30 + assert usage_summary.total_tokens == 50 + assert usage_summary.successful_requests == 1 + + +@pytest.mark.asyncio +async def test_async_streaming_response_with_dict_usage(): + """Test that async streaming handles dict-based usage info.""" + llm = LLM(model="gpt-4o-mini", is_litellm=True, stream=True) + + # Create mock streaming chunks using dict format + class DictStreamChunk: + def __init__( + self, + content: str | None = None, + usage: dict | None = None, + ): + self.choices = [MockStreamChoice(content)] + # Simulate dict-based usage (some providers return this) + self._usage = usage + + @property + def usage(self) -> MockUsage | None: + if self._usage: + return MockUsage(**self._usage) + return None + + chunks = [ + DictStreamChunk(content="Test"), + DictStreamChunk(content=" response"), + DictStreamChunk( + content=None, + usage={ + "prompt_tokens": 25, + "completion_tokens": 35, + "total_tokens": 60, + }, + ), + ] + + async def mock_async_generator() -> AsyncIterator[DictStreamChunk]: + for chunk in chunks: + yield chunk + + with patch("litellm.acompletion", new_callable=AsyncMock) as mock_acompletion: + mock_acompletion.return_value = mock_async_generator() + result = await llm.acall("Test streaming with dict usage") + + assert result == "Test response" + + # Verify token usage was tracked + usage_summary = llm.get_token_usage_summary() + assert usage_summary.prompt_tokens == 25 + assert usage_summary.completion_tokens == 35 + assert usage_summary.total_tokens == 60 + assert usage_summary.successful_requests == 1 + + +def test_streaming_response_tracks_token_usage(): + """Test that sync streaming responses properly track token usage.""" + llm = LLM(model="gpt-4o-mini", is_litellm=True, stream=True) + + # Create mock streaming chunks + chunks = [ + MockStreamChunk(content="Sync"), + MockStreamChunk(content=" streaming"), + MockStreamChunk(content=" test"), + # Final chunk with usage info + MockStreamChunk( + content=None, + usage=MockUsage(prompt_tokens=18, completion_tokens=22, total_tokens=40), + ), + ] + + with patch("litellm.completion", return_value=iter(chunks)): + result = llm.call("Test sync streaming") + + assert result == "Sync streaming test" + + # Verify token usage was tracked + usage_summary = llm.get_token_usage_summary() + assert usage_summary.prompt_tokens == 18 + assert usage_summary.completion_tokens == 22 + assert usage_summary.total_tokens == 40 + assert usage_summary.successful_requests == 1 + + +def test_token_usage_with_no_usage_info(): + """Test that token usage tracking handles missing usage info gracefully.""" + llm = LLM(model="gpt-4o-mini", is_litellm=True, stream=False) + + # Create mock response without usage info + mock_response = MagicMock() + mock_response.choices = [MockChoice("Response without usage")] + mock_response.usage = None + + with patch("litellm.completion", return_value=mock_response): + result = llm.call("Test without usage") + + assert result == 
"Response without usage" + + # Verify token usage remains at zero + usage_summary = llm.get_token_usage_summary() + assert usage_summary.prompt_tokens == 0 + assert usage_summary.completion_tokens == 0 + assert usage_summary.total_tokens == 0 + assert usage_summary.successful_requests == 0 + + +@pytest.mark.asyncio +async def test_async_streaming_with_no_usage_info(): + """Test that async streaming handles missing usage info gracefully.""" + llm = LLM(model="gpt-4o-mini", is_litellm=True, stream=True) + + # Create mock streaming chunks without usage info + chunks = [ + MockStreamChunk(content="No"), + MockStreamChunk(content=" usage"), + MockStreamChunk(content=" info"), + ] + + async def mock_async_generator() -> AsyncIterator[MockStreamChunk]: + for chunk in chunks: + yield chunk + + with patch("litellm.acompletion", new_callable=AsyncMock) as mock_acompletion: + mock_acompletion.return_value = mock_async_generator() + result = await llm.acall("Test without usage info") + + assert result == "No usage info" + + # Verify token usage remains at zero + usage_summary = llm.get_token_usage_summary() + assert usage_summary.prompt_tokens == 0 + assert usage_summary.completion_tokens == 0 + assert usage_summary.total_tokens == 0 + assert usage_summary.successful_requests == 0