From 4e84b98ac2de24df5931fac700ba3d7f7972303a Mon Sep 17 00:00:00 2001 From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Sat, 22 Feb 2025 01:32:43 +0000 Subject: [PATCH] refactor: improve token counter implementation - Fix import sorting in tests - Add docstrings and type validation - Add comprehensive test cases - Add validation for negative token counts Addresses review feedback on #2198 Co-Authored-By: Joe Moura --- .../utilities/base_token_process.py | 11 +++++++ tests/utilities/test_token_process.py | 31 ++++++++++++++++++- 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/src/crewai/agents/agent_builder/utilities/base_token_process.py b/src/crewai/agents/agent_builder/utilities/base_token_process.py index 9e902123a..50a02023f 100644 --- a/src/crewai/agents/agent_builder/utilities/base_token_process.py +++ b/src/crewai/agents/agent_builder/utilities/base_token_process.py @@ -18,7 +18,18 @@ class TokenProcess: self.total_tokens += tokens def sum_cached_prompt_tokens(self, tokens: int | None) -> None: + """ + Adds the given token count to cached prompt tokens. + + Args: + tokens (int | None): Number of tokens to add. None values are ignored. + + Raises: + ValueError: If tokens is negative. + """ if tokens is not None: + if tokens < 0: + raise ValueError("Token count cannot be negative") self.cached_prompt_tokens += tokens def sum_successful_requests(self, requests: int) -> None: diff --git a/tests/utilities/test_token_process.py b/tests/utilities/test_token_process.py index b9e6cee2d..7fb89dab7 100644 --- a/tests/utilities/test_token_process.py +++ b/tests/utilities/test_token_process.py @@ -1,17 +1,46 @@ -from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess import unittest +from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess + class TestTokenProcess(unittest.TestCase): def setUp(self): self.token_process = TokenProcess() def test_sum_cached_prompt_tokens_with_none(self): + """Test that passing None to sum_cached_prompt_tokens doesn't modify the counter.""" initial_tokens = self.token_process.cached_prompt_tokens self.token_process.sum_cached_prompt_tokens(None) self.assertEqual(self.token_process.cached_prompt_tokens, initial_tokens) def test_sum_cached_prompt_tokens_with_int(self): + """Test that passing an integer correctly increments the counter.""" initial_tokens = self.token_process.cached_prompt_tokens self.token_process.sum_cached_prompt_tokens(5) self.assertEqual(self.token_process.cached_prompt_tokens, initial_tokens + 5) + + def test_sum_cached_prompt_tokens_with_zero(self): + """Test that passing zero doesn't modify the counter.""" + initial_tokens = self.token_process.cached_prompt_tokens + self.token_process.sum_cached_prompt_tokens(0) + self.assertEqual(self.token_process.cached_prompt_tokens, initial_tokens) + + def test_sum_cached_prompt_tokens_with_large_number(self): + """Test that the counter works with large numbers.""" + initial_tokens = self.token_process.cached_prompt_tokens + self.token_process.sum_cached_prompt_tokens(1000000) + self.assertEqual(self.token_process.cached_prompt_tokens, initial_tokens + 1000000) + + def test_sum_cached_prompt_tokens_multiple_calls(self): + """Test that multiple calls accumulate correctly, ignoring None values.""" + initial_tokens = self.token_process.cached_prompt_tokens + self.token_process.sum_cached_prompt_tokens(5) + self.token_process.sum_cached_prompt_tokens(None) + self.token_process.sum_cached_prompt_tokens(3) + self.assertEqual(self.token_process.cached_prompt_tokens, initial_tokens + 8) + + def test_sum_cached_prompt_tokens_with_negative(self): + """Test that negative values raise ValueError.""" + with self.assertRaises(ValueError) as context: + self.token_process.sum_cached_prompt_tokens(-1) + self.assertEqual(str(context.exception), "Token count cannot be negative")