Mirror of https://github.com/crewAIInc/crewAI.git, synced 2026-04-25 20:32:36 +00:00.
Compare commits
6 Commits
bugfix/mem
...
devin/1740
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e78efb047f | ||
|
|
d58dc08511 | ||
|
|
9ed21c4b0e | ||
|
|
4e84b98ac2 | ||
|
|
9f7f1cdb54 | ||
|
|
3f02d10626 |
@@ -1,5 +1,9 @@
|
||||
import logging
|
||||
|
||||
from crewai.types.usage_metrics import UsageMetrics
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TokenProcess:
|
||||
def __init__(self) -> None:
|
||||
@@ -17,7 +21,21 @@ class TokenProcess:
|
||||
self.completion_tokens += tokens
|
||||
self.total_tokens += tokens
|
||||
|
||||
def sum_cached_prompt_tokens(self, tokens: int) -> None:
|
||||
def sum_cached_prompt_tokens(self, tokens: int | None) -> None:
|
||||
"""
|
||||
Adds the given token count to cached prompt tokens.
|
||||
|
||||
Args:
|
||||
tokens (int | None): Number of tokens to add. None values are ignored.
|
||||
|
||||
Raises:
|
||||
ValueError: If tokens is negative.
|
||||
"""
|
||||
if tokens is None:
|
||||
logger.debug("Received None value for token count")
|
||||
return
|
||||
if tokens < 0:
|
||||
raise ValueError("Token count cannot be negative")
|
||||
self.cached_prompt_tokens += tokens
|
||||
|
||||
def sum_successful_requests(self, requests: int) -> None:
|
||||
|
||||
@@ -1278,11 +1278,11 @@ class Crew(BaseModel):
|
||||
def _reset_all_memories(self) -> None:
|
||||
"""Reset all available memory systems."""
|
||||
memory_systems = [
|
||||
("short term", getattr(self, "_short_term_memory", None)),
|
||||
("entity", getattr(self, "_entity_memory", None)),
|
||||
("long term", getattr(self, "_long_term_memory", None)),
|
||||
("task output", getattr(self, "_task_output_handler", None)),
|
||||
("knowledge", getattr(self, "knowledge", None)),
|
||||
("short term", self._short_term_memory),
|
||||
("entity", self._entity_memory),
|
||||
("long term", self._long_term_memory),
|
||||
("task output", self._task_output_handler),
|
||||
("knowledge", self.knowledge),
|
||||
]
|
||||
|
||||
for name, system in memory_systems:
|
||||
|
||||
@@ -713,35 +713,16 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
raise TypeError(f"State must be dict or BaseModel, got {type(self._state)}")
|
||||
|
||||
def kickoff(self, inputs: Optional[Dict[str, Any]] = None) -> Any:
|
||||
"""
|
||||
Start the flow execution in a synchronous context.
|
||||
|
||||
This method wraps kickoff_async so that all state initialization and event
|
||||
emission is handled in the asynchronous method.
|
||||
"""
|
||||
|
||||
async def run_flow():
|
||||
return await self.kickoff_async(inputs)
|
||||
|
||||
return asyncio.run(run_flow())
|
||||
|
||||
@init_flow_main_trace
|
||||
async def kickoff_async(self, inputs: Optional[Dict[str, Any]] = None) -> Any:
|
||||
"""
|
||||
Start the flow execution asynchronously.
|
||||
|
||||
This method performs state restoration (if an 'id' is provided and persistence is available)
|
||||
and updates the flow state with any additional inputs. It then emits the FlowStartedEvent,
|
||||
logs the flow startup, and executes all start methods. Once completed, it emits the
|
||||
FlowFinishedEvent and returns the final output.
|
||||
"""Start the flow execution.
|
||||
|
||||
Args:
|
||||
inputs: Optional dictionary containing input values and/or a state ID for restoration.
|
||||
|
||||
Returns:
|
||||
The final output from the flow, which is the result of the last executed method.
|
||||
inputs: Optional dictionary containing input values and potentially a state ID to restore
|
||||
"""
|
||||
if inputs:
|
||||
# Handle state restoration if ID is provided in inputs
|
||||
if inputs and "id" in inputs and self._persistence is not None:
|
||||
restore_uuid = inputs["id"]
|
||||
stored_state = self._persistence.load_state(restore_uuid)
|
||||
|
||||
# Override the id in the state if it exists in inputs
|
||||
if "id" in inputs:
|
||||
if isinstance(self._state, dict):
|
||||
@@ -749,27 +730,24 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
elif isinstance(self._state, BaseModel):
|
||||
setattr(self._state, "id", inputs["id"])
|
||||
|
||||
# If persistence is enabled, attempt to restore the stored state using the provided id.
|
||||
if "id" in inputs and self._persistence is not None:
|
||||
restore_uuid = inputs["id"]
|
||||
stored_state = self._persistence.load_state(restore_uuid)
|
||||
if stored_state:
|
||||
self._log_flow_event(
|
||||
f"Loading flow state from memory for UUID: {restore_uuid}",
|
||||
color="yellow",
|
||||
)
|
||||
self._restore_state(stored_state)
|
||||
else:
|
||||
self._log_flow_event(
|
||||
f"No flow state found for UUID: {restore_uuid}", color="red"
|
||||
)
|
||||
if stored_state:
|
||||
self._log_flow_event(
|
||||
f"Loading flow state from memory for UUID: {restore_uuid}",
|
||||
color="yellow",
|
||||
)
|
||||
# Restore the state
|
||||
self._restore_state(stored_state)
|
||||
else:
|
||||
self._log_flow_event(
|
||||
f"No flow state found for UUID: {restore_uuid}", color="red"
|
||||
)
|
||||
|
||||
# Update state with any additional inputs (ignoring the 'id' key)
|
||||
# Apply any additional inputs after restoration
|
||||
filtered_inputs = {k: v for k, v in inputs.items() if k != "id"}
|
||||
if filtered_inputs:
|
||||
self._initialize_state(filtered_inputs)
|
||||
|
||||
# Emit FlowStartedEvent and log the start of the flow.
|
||||
# Start flow execution
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
FlowStartedEvent(
|
||||
@@ -782,18 +760,27 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
f"Flow started with ID: {self.flow_id}", color="bold_magenta"
|
||||
)
|
||||
|
||||
if inputs is not None and "id" not in inputs:
|
||||
self._initialize_state(inputs)
|
||||
|
||||
async def run_flow():
|
||||
return await self.kickoff_async()
|
||||
|
||||
return asyncio.run(run_flow())
|
||||
|
||||
@init_flow_main_trace
|
||||
async def kickoff_async(self, inputs: Optional[Dict[str, Any]] = None) -> Any:
|
||||
if not self._start_methods:
|
||||
raise ValueError("No start method defined")
|
||||
|
||||
# Execute all start methods concurrently.
|
||||
tasks = [
|
||||
self._execute_start_method(start_method)
|
||||
for start_method in self._start_methods
|
||||
]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
final_output = self._method_outputs[-1] if self._method_outputs else None
|
||||
|
||||
# Emit FlowFinishedEvent after all processing is complete.
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
FlowFinishedEvent(
|
||||
|
||||
49
tests/utilities/test_token_process.py
Normal file
49
tests/utilities/test_token_process.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import unittest
|
||||
|
||||
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
|
||||
|
||||
|
||||
class TestTokenProcess(unittest.TestCase):
    """Unit tests for TokenProcess cached-prompt token accounting."""

    def setUp(self):
        """Create a fresh TokenProcess before every test."""
        self.token_process = TokenProcess()

    def test_sum_cached_prompt_tokens_with_none(self):
        """A None argument must leave the counter untouched."""
        before = self.token_process.cached_prompt_tokens
        self.token_process.sum_cached_prompt_tokens(None)
        self.assertEqual(before, self.token_process.cached_prompt_tokens)

    def test_sum_cached_prompt_tokens_with_int(self):
        """An integer argument increments the counter by that amount."""
        before = self.token_process.cached_prompt_tokens
        self.token_process.sum_cached_prompt_tokens(5)
        self.assertEqual(before + 5, self.token_process.cached_prompt_tokens)

    def test_sum_cached_prompt_tokens_with_zero(self):
        """Adding zero is a no-op on the counter."""
        before = self.token_process.cached_prompt_tokens
        self.token_process.sum_cached_prompt_tokens(0)
        self.assertEqual(before, self.token_process.cached_prompt_tokens)

    def test_sum_cached_prompt_tokens_with_large_number(self):
        """The counter handles large token counts."""
        before = self.token_process.cached_prompt_tokens
        self.token_process.sum_cached_prompt_tokens(1000000)
        self.assertEqual(before + 1000000, self.token_process.cached_prompt_tokens)

    def test_sum_cached_prompt_tokens_multiple_calls(self):
        """Successive calls accumulate; interleaved None values are skipped."""
        before = self.token_process.cached_prompt_tokens
        for amount in (5, None, 3):
            self.token_process.sum_cached_prompt_tokens(amount)
        self.assertEqual(before + 8, self.token_process.cached_prompt_tokens)

    def test_sum_cached_prompt_tokens_with_negative(self):
        """A negative count raises ValueError with the documented message."""
        with self.assertRaises(ValueError) as context:
            self.token_process.sum_cached_prompt_tokens(-1)
        self.assertEqual(str(context.exception), "Token count cannot be negative")
|
||||
Reference in New Issue
Block a user