diff --git a/src/crewai/llm.py b/src/crewai/llm.py index 8227a28e7..b93b2b4b5 100644 --- a/src/crewai/llm.py +++ b/src/crewai/llm.py @@ -956,23 +956,42 @@ class LLM(BaseLLM): self.context_window_size = int(value * CONTEXT_WINDOW_USAGE_RATIO) return self.context_window_size - def set_callbacks(self, callbacks: List[Any]): + def set_callbacks(self, callbacks: List[Any]) -> None: """ Attempt to keep a single set of callbacks in litellm by removing old duplicates and adding new ones. + + This method safely updates the litellm callback lists by: + 1. Identifying the types of new callbacks + 2. Filtering out existing callbacks of the same types + 3. Setting the new callbacks + + Args: + callbacks: List of callback objects to set in litellm + + Returns: + None + + Note: + Uses list comprehension to avoid "list.remove(x): x not in list" errors + that can occur with direct removal during iteration. """ - with suppress_warnings(): - callback_types = [type(callback) for callback in callbacks] - - litellm.success_callback = [ - cb for cb in litellm.success_callback if type(cb) not in callback_types - ] - - litellm._async_success_callback = [ - cb for cb in litellm._async_success_callback if type(cb) not in callback_types - ] - - litellm.callbacks = callbacks + try: + with suppress_warnings(): + callback_types = [type(callback) for callback in callbacks] + + litellm.success_callback = [ + cb for cb in litellm.success_callback if type(cb) not in callback_types + ] + + litellm._async_success_callback = [ + cb for cb in litellm._async_success_callback if type(cb) not in callback_types + ] + + litellm.callbacks = callbacks + except Exception as e: + logging.error(f"Error setting callbacks: {str(e)}") + raise def set_env_callbacks(self): """ diff --git a/tests/test_set_callbacks_handles_removed_callbacks.py b/tests/test_set_callbacks_handles_removed_callbacks.py index 25d221af9..fb67476b3 100644 --- a/tests/test_set_callbacks_handles_removed_callbacks.py +++ b/tests/test_set_callbacks_handles_removed_callbacks.py @@ -1,4 +1,4 @@ -from typing import Any +from typing import Any, List import litellm import pytest @@ -6,35 +6,100 @@ import pytest from crewai.llm import LLM -def test_set_callbacks_handles_removed_callbacks(): - """Test that set_callbacks handles the case where callbacks are removed during iteration.""" - class CustomCallback: - pass +class CustomCallback: + """A simple callback class for testing.""" + pass + +class DifferentCallback: + """A different callback class for testing type differentiation.""" + pass + + +@pytest.fixture +def reset_litellm_callbacks(): + """Fixture to reset litellm callbacks after each test.""" original_success_callback = litellm.success_callback original_async_success_callback = litellm._async_success_callback - try: - litellm.success_callback = [] - litellm._async_success_callback = [] - - llm = LLM(model="test-model") - - callback1 = CustomCallback() - callback2 = CustomCallback() - litellm.success_callback.append(callback1) - litellm.success_callback.append(callback2) - - new_callback = CustomCallback() - - litellm.success_callback.remove(callback1) - - llm.set_callbacks([new_callback]) - - assert litellm.callbacks == [new_callback] - - assert len([cb for cb in litellm.success_callback if isinstance(cb, CustomCallback)]) == 0 + yield - finally: - litellm.success_callback = original_success_callback - litellm._async_success_callback = original_async_success_callback + litellm.success_callback = original_success_callback + litellm._async_success_callback = original_async_success_callback + + +def test_set_callbacks_handles_removed_callbacks(reset_litellm_callbacks): + """Test that set_callbacks handles the case where callbacks are removed during iteration.""" + litellm.success_callback = [] + litellm._async_success_callback = [] + + llm = LLM(model="test-model") + + callback1 = CustomCallback() + callback2 = CustomCallback() + litellm.success_callback.append(callback1) + litellm.success_callback.append(callback2) + + new_callback = CustomCallback() + + litellm.success_callback.remove(callback1) + + llm.set_callbacks([new_callback]) + + assert litellm.callbacks == [new_callback] + assert len([cb for cb in litellm.success_callback if isinstance(cb, CustomCallback)]) == 0 + + +@pytest.mark.parametrize("callback_count", [1, 3, 5]) +def test_set_callbacks_with_different_sizes(callback_count, reset_litellm_callbacks): + """Test with various numbers of callbacks.""" + litellm.success_callback = [] + litellm._async_success_callback = [] + + llm = LLM(model="test-model") + + callbacks = [CustomCallback() for _ in range(callback_count)] + for callback in callbacks: + litellm.success_callback.append(callback) + + new_callback = CustomCallback() + + llm.set_callbacks([new_callback]) + + assert litellm.callbacks == [new_callback] + assert len([cb for cb in litellm.success_callback if isinstance(cb, CustomCallback)]) == 0 + + +def test_set_callbacks_with_different_types(reset_litellm_callbacks): + """Test that callbacks of different types are handled correctly.""" + litellm.success_callback = [] + litellm._async_success_callback = [] + + llm = LLM(model="test-model") + + custom_callback = CustomCallback() + different_callback = DifferentCallback() + + litellm.success_callback.append(custom_callback) + litellm.success_callback.append(different_callback) + + llm.set_callbacks([CustomCallback()]) + + assert any(isinstance(cb, DifferentCallback) for cb in litellm.success_callback) + assert not any(isinstance(cb, CustomCallback) for cb in litellm.success_callback) + + +def test_set_callbacks_with_empty_list(reset_litellm_callbacks): + """Test setting callbacks with an empty list.""" + litellm.success_callback = [] + litellm._async_success_callback = [] + + llm = LLM(model="test-model") + + custom_callback = CustomCallback() + litellm.success_callback.append(custom_callback) + + llm.set_callbacks([]) + + assert litellm.callbacks == [] + assert custom_callback in litellm.success_callback