From 3f02d10626ec6e890ad08aa58d93081bdce3c89b Mon Sep 17 00:00:00 2001 From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Sat, 22 Feb 2025 01:29:47 +0000 Subject: [PATCH] fix: handle None values in token counter - Update sum_cached_prompt_tokens to handle None values gracefully - Add unit tests for token counting with None values - Fixes #2197 Co-Authored-By: Joe Moura --- .../utilities/base_token_process.py | 5 +++-- tests/utilities/test_token_process.py | 17 +++++++++++++++++ 2 files changed, 20 insertions(+), 2 deletions(-) create mode 100644 tests/utilities/test_token_process.py diff --git a/src/crewai/agents/agent_builder/utilities/base_token_process.py b/src/crewai/agents/agent_builder/utilities/base_token_process.py index 3ce5cfb82..9e902123a 100644 --- a/src/crewai/agents/agent_builder/utilities/base_token_process.py +++ b/src/crewai/agents/agent_builder/utilities/base_token_process.py @@ -17,8 +17,9 @@ class TokenProcess: self.completion_tokens += tokens self.total_tokens += tokens - def sum_cached_prompt_tokens(self, tokens: int) -> None: - self.cached_prompt_tokens += tokens + def sum_cached_prompt_tokens(self, tokens: int | None) -> None: + if tokens is not None: + self.cached_prompt_tokens += tokens def sum_successful_requests(self, requests: int) -> None: self.successful_requests += requests diff --git a/tests/utilities/test_token_process.py b/tests/utilities/test_token_process.py new file mode 100644 index 000000000..50160e5b1 --- /dev/null +++ b/tests/utilities/test_token_process.py @@ -0,0 +1,17 @@ +import unittest +from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess + + +class TestTokenProcess(unittest.TestCase): + def setUp(self): + self.token_process = TokenProcess() + + def test_sum_cached_prompt_tokens_with_none(self): + initial_tokens = self.token_process.cached_prompt_tokens + self.token_process.sum_cached_prompt_tokens(None) + self.assertEqual(self.token_process.cached_prompt_tokens, initial_tokens) + + def test_sum_cached_prompt_tokens_with_int(self): + initial_tokens = self.token_process.cached_prompt_tokens + self.token_process.sum_cached_prompt_tokens(5) + self.assertEqual(self.token_process.cached_prompt_tokens, initial_tokens + 5)