Compare commits

..

6 Commits

Author SHA1 Message Date
Devin AI
e78efb047f style: fix import block formatting with ruff
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-22 01:38:03 +00:00
Devin AI
d58dc08511 style: fix import sorting in base_token_process.py
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-22 01:36:52 +00:00
Devin AI
9ed21c4b0e feat: add logging for None values and improve documentation
- Add logging for None token values
- Improve test documentation and structure
- Fix import sorting in tests

Part of #2198

Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-22 01:36:00 +00:00
Devin AI
4e84b98ac2 refactor: improve token counter implementation
- Fix import sorting in tests
- Add docstrings and type validation
- Add comprehensive test cases
- Add validation for negative token counts

Addresses review feedback on #2198

Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-22 01:32:43 +00:00
Devin AI
9f7f1cdb54 style: fix import sorting in test_token_process.py
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-22 01:30:52 +00:00
Devin AI
3f02d10626 fix: handle None values in token counter
- Update sum_cached_prompt_tokens to handle None values gracefully
- Add unit tests for token counting with None values
- Fixes #2197

Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-22 01:29:47 +00:00
4 changed files with 104 additions and 50 deletions

View File

@@ -1,5 +1,9 @@
import logging
from crewai.types.usage_metrics import UsageMetrics
logger = logging.getLogger(__name__)
class TokenProcess:
def __init__(self) -> None:
@@ -17,7 +21,21 @@ class TokenProcess:
self.completion_tokens += tokens
self.total_tokens += tokens
def sum_cached_prompt_tokens(self, tokens: int) -> None:
def sum_cached_prompt_tokens(self, tokens: int | None) -> None:
"""
Adds the given token count to cached prompt tokens.
Args:
tokens (int | None): Number of tokens to add. None values are ignored.
Raises:
ValueError: If tokens is negative.
"""
if tokens is None:
logger.debug("Received None value for token count")
return
if tokens < 0:
raise ValueError("Token count cannot be negative")
self.cached_prompt_tokens += tokens
def sum_successful_requests(self, requests: int) -> None:

View File

@@ -1278,11 +1278,11 @@ class Crew(BaseModel):
def _reset_all_memories(self) -> None:
"""Reset all available memory systems."""
memory_systems = [
("short term", getattr(self, "_short_term_memory", None)),
("entity", getattr(self, "_entity_memory", None)),
("long term", getattr(self, "_long_term_memory", None)),
("task output", getattr(self, "_task_output_handler", None)),
("knowledge", getattr(self, "knowledge", None)),
("short term", self._short_term_memory),
("entity", self._entity_memory),
("long term", self._long_term_memory),
("task output", self._task_output_handler),
("knowledge", self.knowledge),
]
for name, system in memory_systems:

View File

@@ -713,35 +713,16 @@ class Flow(Generic[T], metaclass=FlowMeta):
raise TypeError(f"State must be dict or BaseModel, got {type(self._state)}")
def kickoff(self, inputs: Optional[Dict[str, Any]] = None) -> Any:
"""
Start the flow execution in a synchronous context.
This method wraps kickoff_async so that all state initialization and event
emission is handled in the asynchronous method.
"""
async def run_flow():
return await self.kickoff_async(inputs)
return asyncio.run(run_flow())
@init_flow_main_trace
async def kickoff_async(self, inputs: Optional[Dict[str, Any]] = None) -> Any:
"""
Start the flow execution asynchronously.
This method performs state restoration (if an 'id' is provided and persistence is available)
and updates the flow state with any additional inputs. It then emits the FlowStartedEvent,
logs the flow startup, and executes all start methods. Once completed, it emits the
FlowFinishedEvent and returns the final output.
"""Start the flow execution.
Args:
inputs: Optional dictionary containing input values and/or a state ID for restoration.
Returns:
The final output from the flow, which is the result of the last executed method.
inputs: Optional dictionary containing input values and potentially a state ID to restore
"""
if inputs:
# Handle state restoration if ID is provided in inputs
if inputs and "id" in inputs and self._persistence is not None:
restore_uuid = inputs["id"]
stored_state = self._persistence.load_state(restore_uuid)
# Override the id in the state if it exists in inputs
if "id" in inputs:
if isinstance(self._state, dict):
@@ -749,27 +730,24 @@ class Flow(Generic[T], metaclass=FlowMeta):
elif isinstance(self._state, BaseModel):
setattr(self._state, "id", inputs["id"])
# If persistence is enabled, attempt to restore the stored state using the provided id.
if "id" in inputs and self._persistence is not None:
restore_uuid = inputs["id"]
stored_state = self._persistence.load_state(restore_uuid)
if stored_state:
self._log_flow_event(
f"Loading flow state from memory for UUID: {restore_uuid}",
color="yellow",
)
self._restore_state(stored_state)
else:
self._log_flow_event(
f"No flow state found for UUID: {restore_uuid}", color="red"
)
if stored_state:
self._log_flow_event(
f"Loading flow state from memory for UUID: {restore_uuid}",
color="yellow",
)
# Restore the state
self._restore_state(stored_state)
else:
self._log_flow_event(
f"No flow state found for UUID: {restore_uuid}", color="red"
)
# Update state with any additional inputs (ignoring the 'id' key)
# Apply any additional inputs after restoration
filtered_inputs = {k: v for k, v in inputs.items() if k != "id"}
if filtered_inputs:
self._initialize_state(filtered_inputs)
# Emit FlowStartedEvent and log the start of the flow.
# Start flow execution
crewai_event_bus.emit(
self,
FlowStartedEvent(
@@ -782,18 +760,27 @@ class Flow(Generic[T], metaclass=FlowMeta):
f"Flow started with ID: {self.flow_id}", color="bold_magenta"
)
if inputs is not None and "id" not in inputs:
self._initialize_state(inputs)
async def run_flow():
return await self.kickoff_async()
return asyncio.run(run_flow())
@init_flow_main_trace
async def kickoff_async(self, inputs: Optional[Dict[str, Any]] = None) -> Any:
if not self._start_methods:
raise ValueError("No start method defined")
# Execute all start methods concurrently.
tasks = [
self._execute_start_method(start_method)
for start_method in self._start_methods
]
await asyncio.gather(*tasks)
final_output = self._method_outputs[-1] if self._method_outputs else None
# Emit FlowFinishedEvent after all processing is complete.
crewai_event_bus.emit(
self,
FlowFinishedEvent(

View File

@@ -0,0 +1,49 @@
import unittest
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
class TestTokenProcess(unittest.TestCase):
"""Test suite for TokenProcess class token counting functionality."""
def setUp(self):
"""Initialize a fresh TokenProcess instance before each test."""
self.token_process = TokenProcess()
def test_sum_cached_prompt_tokens_with_none(self):
"""Test that passing None to sum_cached_prompt_tokens doesn't modify the counter."""
initial_tokens = self.token_process.cached_prompt_tokens
self.token_process.sum_cached_prompt_tokens(None)
self.assertEqual(self.token_process.cached_prompt_tokens, initial_tokens)
def test_sum_cached_prompt_tokens_with_int(self):
"""Test that passing an integer correctly increments the counter."""
initial_tokens = self.token_process.cached_prompt_tokens
self.token_process.sum_cached_prompt_tokens(5)
self.assertEqual(self.token_process.cached_prompt_tokens, initial_tokens + 5)
def test_sum_cached_prompt_tokens_with_zero(self):
"""Test that passing zero doesn't modify the counter."""
initial_tokens = self.token_process.cached_prompt_tokens
self.token_process.sum_cached_prompt_tokens(0)
self.assertEqual(self.token_process.cached_prompt_tokens, initial_tokens)
def test_sum_cached_prompt_tokens_with_large_number(self):
"""Test that the counter works with large numbers."""
initial_tokens = self.token_process.cached_prompt_tokens
self.token_process.sum_cached_prompt_tokens(1000000)
self.assertEqual(self.token_process.cached_prompt_tokens, initial_tokens + 1000000)
def test_sum_cached_prompt_tokens_multiple_calls(self):
"""Test that multiple calls accumulate correctly, ignoring None values."""
initial_tokens = self.token_process.cached_prompt_tokens
self.token_process.sum_cached_prompt_tokens(5)
self.token_process.sum_cached_prompt_tokens(None)
self.token_process.sum_cached_prompt_tokens(3)
self.assertEqual(self.token_process.cached_prompt_tokens, initial_tokens + 8)
def test_sum_cached_prompt_tokens_with_negative(self):
"""Test that negative values raise ValueError."""
with self.assertRaises(ValueError) as context:
self.token_process.sum_cached_prompt_tokens(-1)
self.assertEqual(str(context.exception), "Token count cannot be negative")