diff --git a/src/crewai/llm.py b/src/crewai/llm.py index 741544662..8227a28e7 100644 --- a/src/crewai/llm.py +++ b/src/crewai/llm.py @@ -963,14 +963,15 @@ class LLM(BaseLLM): """ with suppress_warnings(): callback_types = [type(callback) for callback in callbacks] - for callback in litellm.success_callback[:]: - if type(callback) in callback_types: - litellm.success_callback.remove(callback) - - for callback in litellm._async_success_callback[:]: - if type(callback) in callback_types: - litellm._async_success_callback.remove(callback) - + + litellm.success_callback = [ + cb for cb in litellm.success_callback if type(cb) not in callback_types + ] + + litellm._async_success_callback = [ + cb for cb in litellm._async_success_callback if type(cb) not in callback_types + ] + litellm.callbacks = callbacks def set_env_callbacks(self): diff --git a/tests/test_set_callbacks_handles_removed_callbacks.py b/tests/test_set_callbacks_handles_removed_callbacks.py new file mode 100644 index 000000000..ef665dd57 --- /dev/null +++ b/tests/test_set_callbacks_handles_removed_callbacks.py @@ -0,0 +1,39 @@ +import pytest +import litellm +from typing import Any + +from crewai.llm import LLM + + +def test_set_callbacks_handles_removed_callbacks(): + """Test that set_callbacks handles the case where callbacks are removed during iteration.""" + class CustomCallback: + pass + + original_success_callback = litellm.success_callback + original_async_success_callback = litellm._async_success_callback + + try: + litellm.success_callback = [] + litellm._async_success_callback = [] + + llm = LLM(model="test-model") + + callback1 = CustomCallback() + callback2 = CustomCallback() + litellm.success_callback.append(callback1) + litellm.success_callback.append(callback2) + + new_callback = CustomCallback() + + litellm.success_callback.remove(callback1) + + llm.set_callbacks([new_callback]) + + assert litellm.callbacks == [new_callback] + + assert len([cb for cb in litellm.success_callback if isinstance(cb, CustomCallback)]) == 0 + + finally: + litellm.success_callback = original_success_callback + litellm._async_success_callback = original_async_success_callback