mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 16:48:30 +00:00
refactor: improve token counter implementation
- Fix import sorting in tests - Add docstrings and type validation - Add comprehensive test cases - Add validation for negative token counts Addresses review feedback on #2198 Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
@@ -18,7 +18,18 @@ class TokenProcess:
|
||||
self.total_tokens += tokens
|
||||
|
||||
def sum_cached_prompt_tokens(self, tokens: int | None) -> None:
|
||||
"""
|
||||
Adds the given token count to cached prompt tokens.
|
||||
|
||||
Args:
|
||||
tokens (int | None): Number of tokens to add. None values are ignored.
|
||||
|
||||
Raises:
|
||||
ValueError: If tokens is negative.
|
||||
"""
|
||||
if tokens is not None:
|
||||
if tokens < 0:
|
||||
raise ValueError("Token count cannot be negative")
|
||||
self.cached_prompt_tokens += tokens
|
||||
|
||||
def sum_successful_requests(self, requests: int) -> None:
|
||||
|
||||
@@ -1,17 +1,46 @@
|
||||
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
|
||||
import unittest
|
||||
|
||||
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
|
||||
|
||||
|
||||
class TestTokenProcess(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.token_process = TokenProcess()
|
||||
|
||||
def test_sum_cached_prompt_tokens_with_none(self):
|
||||
"""Test that passing None to sum_cached_prompt_tokens doesn't modify the counter."""
|
||||
initial_tokens = self.token_process.cached_prompt_tokens
|
||||
self.token_process.sum_cached_prompt_tokens(None)
|
||||
self.assertEqual(self.token_process.cached_prompt_tokens, initial_tokens)
|
||||
|
||||
def test_sum_cached_prompt_tokens_with_int(self):
|
||||
"""Test that passing an integer correctly increments the counter."""
|
||||
initial_tokens = self.token_process.cached_prompt_tokens
|
||||
self.token_process.sum_cached_prompt_tokens(5)
|
||||
self.assertEqual(self.token_process.cached_prompt_tokens, initial_tokens + 5)
|
||||
|
||||
def test_sum_cached_prompt_tokens_with_zero(self):
|
||||
"""Test that passing zero doesn't modify the counter."""
|
||||
initial_tokens = self.token_process.cached_prompt_tokens
|
||||
self.token_process.sum_cached_prompt_tokens(0)
|
||||
self.assertEqual(self.token_process.cached_prompt_tokens, initial_tokens)
|
||||
|
||||
def test_sum_cached_prompt_tokens_with_large_number(self):
|
||||
"""Test that the counter works with large numbers."""
|
||||
initial_tokens = self.token_process.cached_prompt_tokens
|
||||
self.token_process.sum_cached_prompt_tokens(1000000)
|
||||
self.assertEqual(self.token_process.cached_prompt_tokens, initial_tokens + 1000000)
|
||||
|
||||
def test_sum_cached_prompt_tokens_multiple_calls(self):
|
||||
"""Test that multiple calls accumulate correctly, ignoring None values."""
|
||||
initial_tokens = self.token_process.cached_prompt_tokens
|
||||
self.token_process.sum_cached_prompt_tokens(5)
|
||||
self.token_process.sum_cached_prompt_tokens(None)
|
||||
self.token_process.sum_cached_prompt_tokens(3)
|
||||
self.assertEqual(self.token_process.cached_prompt_tokens, initial_tokens + 8)
|
||||
|
||||
def test_sum_cached_prompt_tokens_with_negative(self):
|
||||
"""Test that negative values raise ValueError."""
|
||||
with self.assertRaises(ValueError) as context:
|
||||
self.token_process.sum_cached_prompt_tokens(-1)
|
||||
self.assertEqual(str(context.exception), "Token count cannot be negative")
|
||||
|
||||
Reference in New Issue
Block a user