Compare commits

...

4 Commits

Author SHA1 Message Date
Devin AI
0a22cbc349 Enhance set_callbacks with improved type hints, error handling, and expanded test coverage
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-04-03 09:44:06 +00:00
Devin AI
4f5d18a2c9 Fix import sorting with ruff --fix
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-04-03 09:41:32 +00:00
Devin AI
f6571f114d Fix import sorting in test file
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-04-03 09:40:38 +00:00
Devin AI
c06eb56cf3 Fix litellm callback removal error (issue #2513)
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-04-03 09:39:10 +00:00
2 changed files with 137 additions and 12 deletions

View File

@@ -956,22 +956,42 @@ class LLM(BaseLLM):
self.context_window_size = int(value * CONTEXT_WINDOW_USAGE_RATIO)
return self.context_window_size
def set_callbacks(self, callbacks: List[Any]):
def set_callbacks(self, callbacks: List[Any]) -> None:
"""
Attempt to keep a single set of callbacks in litellm by removing old
duplicates and adding new ones.
This method safely updates the litellm callback lists by:
1. Identifying the types of new callbacks
2. Filtering out existing callbacks of the same types
3. Setting the new callbacks
Args:
callbacks: List of callback objects to set in litellm
Returns:
None
Note:
Uses list comprehension to avoid "list.remove(x): x not in list" errors
that can occur with direct removal during iteration.
"""
with suppress_warnings():
callback_types = [type(callback) for callback in callbacks]
for callback in litellm.success_callback[:]:
if type(callback) in callback_types:
litellm.success_callback.remove(callback)
for callback in litellm._async_success_callback[:]:
if type(callback) in callback_types:
litellm._async_success_callback.remove(callback)
litellm.callbacks = callbacks
try:
with suppress_warnings():
callback_types = [type(callback) for callback in callbacks]
litellm.success_callback = [
cb for cb in litellm.success_callback if type(cb) not in callback_types
]
litellm._async_success_callback = [
cb for cb in litellm._async_success_callback if type(cb) not in callback_types
]
litellm.callbacks = callbacks
except Exception as e:
logging.error(f"Error setting callbacks: {str(e)}")
raise
def set_env_callbacks(self):
"""

View File

@@ -0,0 +1,105 @@
from typing import Any, List
import litellm
import pytest
from crewai.llm import LLM
class CustomCallback:
"""A simple callback class for testing."""
pass
class DifferentCallback:
"""A different callback class for testing type differentiation."""
pass
@pytest.fixture
def reset_litellm_callbacks():
"""Fixture to reset litellm callbacks after each test."""
original_success_callback = litellm.success_callback
original_async_success_callback = litellm._async_success_callback
yield
litellm.success_callback = original_success_callback
litellm._async_success_callback = original_async_success_callback
def test_set_callbacks_handles_removed_callbacks(reset_litellm_callbacks):
"""Test that set_callbacks handles the case where callbacks are removed during iteration."""
litellm.success_callback = []
litellm._async_success_callback = []
llm = LLM(model="test-model")
callback1 = CustomCallback()
callback2 = CustomCallback()
litellm.success_callback.append(callback1)
litellm.success_callback.append(callback2)
new_callback = CustomCallback()
litellm.success_callback.remove(callback1)
llm.set_callbacks([new_callback])
assert litellm.callbacks == [new_callback]
assert len([cb for cb in litellm.success_callback if isinstance(cb, CustomCallback)]) == 0
@pytest.mark.parametrize("callback_count", [1, 3, 5])
def test_set_callbacks_with_different_sizes(callback_count, reset_litellm_callbacks):
"""Test with various numbers of callbacks."""
litellm.success_callback = []
litellm._async_success_callback = []
llm = LLM(model="test-model")
callbacks = [CustomCallback() for _ in range(callback_count)]
for callback in callbacks:
litellm.success_callback.append(callback)
new_callback = CustomCallback()
llm.set_callbacks([new_callback])
assert litellm.callbacks == [new_callback]
assert len([cb for cb in litellm.success_callback if isinstance(cb, CustomCallback)]) == 0
def test_set_callbacks_with_different_types(reset_litellm_callbacks):
"""Test that callbacks of different types are handled correctly."""
litellm.success_callback = []
litellm._async_success_callback = []
llm = LLM(model="test-model")
custom_callback = CustomCallback()
different_callback = DifferentCallback()
litellm.success_callback.append(custom_callback)
litellm.success_callback.append(different_callback)
llm.set_callbacks([CustomCallback()])
assert any(isinstance(cb, DifferentCallback) for cb in litellm.success_callback)
assert not any(isinstance(cb, CustomCallback) for cb in litellm.success_callback)
def test_set_callbacks_with_empty_list(reset_litellm_callbacks):
"""Test setting callbacks with an empty list."""
litellm.success_callback = []
litellm._async_success_callback = []
llm = LLM(model="test-model")
custom_callback = CustomCallback()
litellm.success_callback.append(custom_callback)
llm.set_callbacks([])
assert litellm.callbacks == []
assert custom_callback in litellm.success_callback