diff --git a/src/crewai/agents/agent_builder/utilities/base_token_process.py b/src/crewai/agents/agent_builder/utilities/base_token_process.py index 3ce5cfb82..9e902123a 100644 --- a/src/crewai/agents/agent_builder/utilities/base_token_process.py +++ b/src/crewai/agents/agent_builder/utilities/base_token_process.py @@ -17,8 +17,9 @@ class TokenProcess: self.completion_tokens += tokens self.total_tokens += tokens - def sum_cached_prompt_tokens(self, tokens: int) -> None: - self.cached_prompt_tokens += tokens + def sum_cached_prompt_tokens(self, tokens: int | None) -> None: + if tokens is not None: + self.cached_prompt_tokens += tokens def sum_successful_requests(self, requests: int) -> None: self.successful_requests += requests diff --git a/tests/utilities/test_token_process.py b/tests/utilities/test_token_process.py new file mode 100644 index 000000000..50160e5b1 --- /dev/null +++ b/tests/utilities/test_token_process.py @@ -0,0 +1,17 @@ +import unittest +from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess + + +class TestTokenProcess(unittest.TestCase): + def setUp(self): + self.token_process = TokenProcess() + + def test_sum_cached_prompt_tokens_with_none(self): + initial_tokens = self.token_process.cached_prompt_tokens + self.token_process.sum_cached_prompt_tokens(None) + self.assertEqual(self.token_process.cached_prompt_tokens, initial_tokens) + + def test_sum_cached_prompt_tokens_with_int(self): + initial_tokens = self.token_process.cached_prompt_tokens + self.token_process.sum_cached_prompt_tokens(5) + self.assertEqual(self.token_process.cached_prompt_tokens, initial_tokens + 5)