From f727b3f5e2557fc24b2710588f6c38c3037e9853 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Moura?= Date: Wed, 7 Feb 2024 23:09:36 -0800 Subject: [PATCH] fixing RPM controlelr being set unencessarily --- src/crewai/crew.py | 3 ++- src/crewai/utilities/rpm_controller.py | 7 +++++-- tests/crew_test.py | 19 +++++++++++++++++++ 3 files changed, 26 insertions(+), 3 deletions(-) diff --git a/src/crewai/crew.py b/src/crewai/crew.py index 74bde48c3..16176f7da 100644 --- a/src/crewai/crew.py +++ b/src/crewai/crew.py @@ -124,7 +124,8 @@ class Crew(BaseModel): if self.agents: for agent in self.agents: agent.set_cache_handler(self._cache_handler) - agent.set_rpm_controller(self._rpm_controller) + if self.max_rpm: + agent.set_rpm_controller(self._rpm_controller) return self def _setup_from_config(self): diff --git a/src/crewai/utilities/rpm_controller.py b/src/crewai/utilities/rpm_controller.py index 84d4b641f..761760bf8 100644 --- a/src/crewai/utilities/rpm_controller.py +++ b/src/crewai/utilities/rpm_controller.py @@ -14,12 +14,14 @@ class RPMController(BaseModel): _current_rpm: int = PrivateAttr(default=0) _timer: threading.Timer | None = PrivateAttr(default=None) _lock: threading.Lock = PrivateAttr(default=None) + _shutdown_flag = False @model_validator(mode="after") def reset_counter(self): if self.max_rpm: - self._lock = threading.Lock() - self._reset_request_count() + if not self._shutdown_flag: + self._lock = threading.Lock() + self._reset_request_count() return self def check_or_wait(self): @@ -51,6 +53,7 @@ class RPMController(BaseModel): with self._lock: self._current_rpm = 0 if self._timer: + self._shutdown_flag = True self._timer.cancel() self._timer = threading.Timer(60.0, self._reset_request_count) self._timer.start() diff --git a/tests/crew_test.py b/tests/crew_test.py index 5b5a680a0..57bf715fc 100644 --- a/tests/crew_test.py +++ b/tests/crew_test.py @@ -356,6 +356,25 @@ def test_api_calls_throttling(capsys): moveon.assert_called() +def test_agents_rpm_is_never_set_if_crew_max_RPM_is_not_set(): + agent = Agent( + role="test role", + goal="test goal", + backstory="test backstory", + allow_delegation=False, + verbose=True, + ) + + task = Task( + description="just say hi!", + agent=agent, + ) + + Crew(agents=[agent], tasks=[task], verbose=2) + + assert agent._rpm_controller is None + + def test_async_task_execution(): import threading from unittest.mock import patch