mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-03 00:02:36 +00:00
Fix race condition in LLM callback system
This commit fixes a race condition in the LLM callback system where multiple LLM instances calling set_callbacks concurrently could cause callbacks to be removed before they fire.

Changes:
- Add a class-level RLock (_callback_lock) to the LLM class to synchronize access to the global litellm callbacks.
- Wrap callback registration and LLM call execution in the lock for both the call() and acall() methods.
- Use an RLock (reentrant lock) so that recursive calls do not deadlock (e.g., when retrying after an unsupported 'stop' parameter).
- Remove the sleep(5) workaround from the test_llm_callback_replacement test.
- Add a new test_llm_callback_lock_prevents_race_condition test to verify that concurrent callback access is properly synchronized.

Fixes #4214

Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import logging
|
||||
import os
|
||||
from time import sleep
|
||||
import threading
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
|
||||
@@ -18,9 +18,15 @@ from pydantic import BaseModel
|
||||
import pytest
|
||||
|
||||
|
||||
# TODO: This test fails without a print statement, which suggests that something is happening asynchronously; we need to dive deeper into it and fix it at a later date.
|
||||
@pytest.mark.vcr()
|
||||
def test_llm_callback_replacement():
|
||||
"""Test that callbacks are properly isolated between LLM instances.
|
||||
|
||||
This test verifies that the race condition fix (using _callback_lock) works
|
||||
correctly. Previously, this test required a sleep(5) workaround because
|
||||
callbacks were being modified globally without synchronization, causing
|
||||
one LLM instance's callbacks to interfere with another's.
|
||||
"""
|
||||
llm1 = LLM(model="gpt-4o-mini", is_litellm=True)
|
||||
llm2 = LLM(model="gpt-4o-mini", is_litellm=True)
|
||||
|
||||
@@ -37,7 +43,6 @@ def test_llm_callback_replacement():
|
||||
messages=[{"role": "user", "content": "Hello, world from another agent!"}],
|
||||
callbacks=[calc_handler_2],
|
||||
)
|
||||
sleep(5)
|
||||
usage_metrics_2 = calc_handler_2.token_cost_process.get_summary()
|
||||
|
||||
# The first handler should not have been updated
|
||||
@@ -46,6 +51,66 @@ def test_llm_callback_replacement():
|
||||
assert usage_metrics_1 == calc_handler_1.token_cost_process.get_summary()
|
||||
|
||||
|
||||
def test_llm_callback_lock_prevents_race_condition():
    """Test that the _callback_lock prevents race conditions in concurrent LLM calls.

    Several threads each build their own LLM and TokenCalcHandler and call
    LLM.call() against a patched ``litellm.completion``. Every handler is
    expected to report exactly one successful request, and no thread may
    raise.

    The mocked completion response is configured once, before any thread
    starts, so the test does not itself race on the shared mock's
    ``return_value``.
    """
    num_threads = 5
    results: list[int] = []
    errors: list[Exception] = []
    # Guards ``results`` and ``errors``, which are appended to from workers.
    lock = threading.Lock()

    def make_llm_call(thread_id: int) -> None:
        try:
            llm = LLM(model="gpt-4o-mini", is_litellm=True)
            calc_handler = TokenCalcHandler(token_cost_process=TokenProcess())

            llm.call(
                messages=[{"role": "user", "content": f"Hello from thread {thread_id}"}],
                callbacks=[calc_handler],
            )

            usage = calc_handler.token_cost_process.get_summary()
            with lock:
                results.append(usage.successful_requests)
        except Exception as e:
            # Exceptions raised in a worker thread do not propagate to the
            # main thread, so collect them and assert on them below.
            with lock:
                errors.append(e)

    with patch("litellm.completion") as mock_completion:
        # Build the canned response up front: the mock is shared by every
        # worker, so assigning return_value inside the threads would let one
        # thread overwrite another's response before it is consumed.
        mock_message = MagicMock()
        mock_message.content = "Response from mocked completion"
        mock_choice = MagicMock()
        mock_choice.message = mock_message
        mock_response = MagicMock()
        mock_response.choices = [mock_choice]
        mock_response.usage = {
            "prompt_tokens": 10,
            "completion_tokens": 10,
            "total_tokens": 20,
        }
        mock_completion.return_value = mock_response

        threads = [
            threading.Thread(target=make_llm_call, args=(i,))
            for i in range(num_threads)
        ]

        for t in threads:
            t.start()
        for t in threads:
            t.join()

    assert not errors, f"Errors occurred: {errors}"
    assert len(results) == num_threads
    assert all(
        r == 1 for r in results
    ), f"Expected all callbacks to have 1 successful request, got {results}"
|
||||
|
||||
|
||||
@pytest.mark.vcr()
|
||||
def test_llm_call_with_string_input():
|
||||
llm = LLM(model="gpt-4o-mini")
|
||||
@@ -989,4 +1054,4 @@ async def test_usage_info_streaming_with_acall():
|
||||
assert llm._token_usage["completion_tokens"] > 0
|
||||
assert llm._token_usage["total_tokens"] > 0
|
||||
|
||||
assert len(result) > 0
|
||||
assert len(result) > 0
|
||||
|
||||
Reference in New Issue
Block a user