Fix race condition in LLM callback system

This commit fixes a race condition in the LLM callback system where
multiple LLM instances calling set_callbacks concurrently could cause
callbacks to be removed before they fire.

Changes:
- Add class-level RLock (_callback_lock) to LLM class to synchronize
  access to global litellm callbacks
- Wrap callback registration and LLM call execution in the lock for
  both call() and acall() methods
- Use RLock (reentrant lock) to handle recursive calls without deadlock
  (e.g., when retrying with unsupported 'stop' parameter)
- Remove sleep(5) workaround from test_llm_callback_replacement test
- Add new test_llm_callback_lock_prevents_race_condition test to verify
  concurrent callback access is properly synchronized

Fixes #4214

Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
Devin AI
2026-01-10 21:11:01 +00:00
parent d60f7b360d
commit 2ada7aba97
2 changed files with 208 additions and 118 deletions

View File

@@ -341,6 +341,7 @@ class AccumulatedToolArgs(BaseModel):
class LLM(BaseLLM):
completion_cost: float | None = None
_callback_lock: threading.RLock = threading.RLock()
def __new__(cls, model: str, is_litellm: bool = False, **kwargs: Any) -> LLM:
"""Factory method that routes to native SDK or falls back to LiteLLM.
@@ -1144,7 +1145,7 @@ class LLM(BaseLLM):
if response_model:
params["response_model"] = response_model
response = litellm.completion(**params)
if hasattr(response,"usage") and not isinstance(response.usage, type) and response.usage:
usage_info = response.usage
self._track_token_usage_internal(usage_info)
@@ -1363,7 +1364,7 @@ class LLM(BaseLLM):
"""
full_response = ""
chunk_count = 0
usage_info = None
accumulated_tool_args: defaultdict[int, AccumulatedToolArgs] = defaultdict(
@@ -1657,78 +1658,92 @@ class LLM(BaseLLM):
raise ValueError("LLM call blocked by before_llm_call hook")
# --- 5) Set up callbacks if provided
# Use a class-level lock to synchronize access to global litellm callbacks.
# This prevents race conditions when multiple LLM instances call set_callbacks
# concurrently, which could cause callbacks to be removed before they fire.
with suppress_warnings():
if callbacks and len(callbacks) > 0:
self.set_callbacks(callbacks)
try:
# --- 6) Prepare parameters for the completion call
params = self._prepare_completion_params(messages, tools)
# --- 7) Make the completion call and handle response
if self.stream:
result = self._handle_streaming_response(
params=params,
callbacks=callbacks,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
response_model=response_model,
)
else:
result = self._handle_non_streaming_response(
params=params,
callbacks=callbacks,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
response_model=response_model,
)
if isinstance(result, str):
result = self._invoke_after_llm_call_hooks(
messages, result, from_agent
)
return result
except LLMContextLengthExceededError:
# Re-raise LLMContextLengthExceededError as it should be handled
# by the CrewAgentExecutor._invoke_loop method, which can then decide
# whether to summarize the content or abort based on the respect_context_window flag
raise
except Exception as e:
unsupported_stop = "Unsupported parameter" in str(
e
) and "'stop'" in str(e)
if unsupported_stop:
if (
"additional_drop_params" in self.additional_params
and isinstance(
self.additional_params["additional_drop_params"], list
with LLM._callback_lock:
if callbacks and len(callbacks) > 0:
self.set_callbacks(callbacks)
try:
# --- 6) Prepare parameters for the completion call
params = self._prepare_completion_params(messages, tools)
# --- 7) Make the completion call and handle response
if self.stream:
result = self._handle_streaming_response(
params=params,
callbacks=callbacks,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
response_model=response_model,
)
):
self.additional_params["additional_drop_params"].append("stop")
else:
self.additional_params = {"additional_drop_params": ["stop"]}
result = self._handle_non_streaming_response(
params=params,
callbacks=callbacks,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
response_model=response_model,
)
logging.info("Retrying LLM call without the unsupported 'stop'")
if isinstance(result, str):
result = self._invoke_after_llm_call_hooks(
messages, result, from_agent
)
return self.call(
messages,
tools=tools,
callbacks=callbacks,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
response_model=response_model,
return result
except LLMContextLengthExceededError:
# Re-raise LLMContextLengthExceededError as it should be handled
# by the CrewAgentExecutor._invoke_loop method, which can then decide
# whether to summarize the content or abort based on the respect_context_window flag
raise
except Exception as e:
unsupported_stop = "Unsupported parameter" in str(
e
) and "'stop'" in str(e)
if unsupported_stop:
if (
"additional_drop_params" in self.additional_params
and isinstance(
self.additional_params["additional_drop_params"], list
)
):
self.additional_params["additional_drop_params"].append(
"stop"
)
else:
self.additional_params = {
"additional_drop_params": ["stop"]
}
logging.info(
"Retrying LLM call without the unsupported 'stop'"
)
# Recursive call happens inside the lock since we're using
# a reentrant-safe pattern (the lock is released when we
# exit the with block, and the recursive call will acquire
# it again)
return self.call(
messages,
tools=tools,
callbacks=callbacks,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
response_model=response_model,
)
crewai_event_bus.emit(
self,
event=LLMCallFailedEvent(
error=str(e), from_task=from_task, from_agent=from_agent
),
)
crewai_event_bus.emit(
self,
event=LLMCallFailedEvent(
error=str(e), from_task=from_task, from_agent=from_agent
),
)
raise
raise
async def acall(
self,
@@ -1790,14 +1805,27 @@ class LLM(BaseLLM):
msg_role: Literal["assistant"] = "assistant"
message["role"] = msg_role
# Use a class-level lock to synchronize access to global litellm callbacks.
# This prevents race conditions when multiple LLM instances call set_callbacks
# concurrently, which could cause callbacks to be removed before they fire.
with suppress_warnings():
if callbacks and len(callbacks) > 0:
self.set_callbacks(callbacks)
try:
params = self._prepare_completion_params(messages, tools)
with LLM._callback_lock:
if callbacks and len(callbacks) > 0:
self.set_callbacks(callbacks)
try:
params = self._prepare_completion_params(messages, tools)
if self.stream:
return await self._ahandle_streaming_response(
if self.stream:
return await self._ahandle_streaming_response(
params=params,
callbacks=callbacks,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
response_model=response_model,
)
return await self._ahandle_non_streaming_response(
params=params,
callbacks=callbacks,
available_functions=available_functions,
@@ -1805,52 +1833,49 @@ class LLM(BaseLLM):
from_agent=from_agent,
response_model=response_model,
)
except LLMContextLengthExceededError:
raise
except Exception as e:
unsupported_stop = "Unsupported parameter" in str(
e
) and "'stop'" in str(e)
return await self._ahandle_non_streaming_response(
params=params,
callbacks=callbacks,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
response_model=response_model,
)
except LLMContextLengthExceededError:
raise
except Exception as e:
unsupported_stop = "Unsupported parameter" in str(
e
) and "'stop'" in str(e)
if unsupported_stop:
if (
"additional_drop_params" in self.additional_params
and isinstance(
self.additional_params["additional_drop_params"], list
)
):
self.additional_params["additional_drop_params"].append(
"stop"
)
else:
self.additional_params = {
"additional_drop_params": ["stop"]
}
if unsupported_stop:
if (
"additional_drop_params" in self.additional_params
and isinstance(
self.additional_params["additional_drop_params"], list
logging.info(
"Retrying LLM call without the unsupported 'stop'"
)
):
self.additional_params["additional_drop_params"].append("stop")
else:
self.additional_params = {"additional_drop_params": ["stop"]}
logging.info("Retrying LLM call without the unsupported 'stop'")
return await self.acall(
messages,
tools=tools,
callbacks=callbacks,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
response_model=response_model,
)
return await self.acall(
messages,
tools=tools,
callbacks=callbacks,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
response_model=response_model,
crewai_event_bus.emit(
self,
event=LLMCallFailedEvent(
error=str(e), from_task=from_task, from_agent=from_agent
),
)
crewai_event_bus.emit(
self,
event=LLMCallFailedEvent(
error=str(e), from_task=from_task, from_agent=from_agent
),
)
raise
raise
def _handle_emit_call_events(
self,

View File

@@ -1,6 +1,6 @@
import logging
import os
from time import sleep
import threading
from unittest.mock import MagicMock, patch
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
@@ -18,9 +18,15 @@ from pydantic import BaseModel
import pytest
# TODO: This test fails without print statement, which makes me think that something is happening asynchronously that we need to eventually fix and dive deeper into at a later date
@pytest.mark.vcr()
def test_llm_callback_replacement():
"""Test that callbacks are properly isolated between LLM instances.
This test verifies that the race condition fix (using _callback_lock) works
correctly. Previously, this test required a sleep(5) workaround because
callbacks were being modified globally without synchronization, causing
one LLM instance's callbacks to interfere with another's.
"""
llm1 = LLM(model="gpt-4o-mini", is_litellm=True)
llm2 = LLM(model="gpt-4o-mini", is_litellm=True)
@@ -37,7 +43,6 @@ def test_llm_callback_replacement():
messages=[{"role": "user", "content": "Hello, world from another agent!"}],
callbacks=[calc_handler_2],
)
sleep(5)
usage_metrics_2 = calc_handler_2.token_cost_process.get_summary()
# The first handler should not have been updated
@@ -46,6 +51,66 @@ def test_llm_callback_replacement():
assert usage_metrics_1 == calc_handler_1.token_cost_process.get_summary()
def test_llm_callback_lock_prevents_race_condition():
"""Test that the _callback_lock prevents race conditions in concurrent LLM calls.
This test verifies that multiple threads can safely call LLM.call() with
different callbacks without interfering with each other. The lock ensures
that callbacks are properly isolated between concurrent calls.
"""
num_threads = 5
results: list[int] = []
errors: list[Exception] = []
lock = threading.Lock()
def make_llm_call(thread_id: int, mock_completion: MagicMock) -> None:
try:
llm = LLM(model="gpt-4o-mini", is_litellm=True)
calc_handler = TokenCalcHandler(token_cost_process=TokenProcess())
mock_message = MagicMock()
mock_message.content = f"Response from thread {thread_id}"
mock_choice = MagicMock()
mock_choice.message = mock_message
mock_response = MagicMock()
mock_response.choices = [mock_choice]
mock_response.usage = {
"prompt_tokens": 10,
"completion_tokens": 10,
"total_tokens": 20,
}
mock_completion.return_value = mock_response
llm.call(
messages=[{"role": "user", "content": f"Hello from thread {thread_id}"}],
callbacks=[calc_handler],
)
usage = calc_handler.token_cost_process.get_summary()
with lock:
results.append(usage.successful_requests)
except Exception as e:
with lock:
errors.append(e)
with patch("litellm.completion") as mock_completion:
threads = [
threading.Thread(target=make_llm_call, args=(i, mock_completion))
for i in range(num_threads)
]
for t in threads:
t.start()
for t in threads:
t.join()
assert len(errors) == 0, f"Errors occurred: {errors}"
assert len(results) == num_threads
assert all(
r == 1 for r in results
), f"Expected all callbacks to have 1 successful request, got {results}"
@pytest.mark.vcr()
def test_llm_call_with_string_input():
llm = LLM(model="gpt-4o-mini")
@@ -989,4 +1054,4 @@ async def test_usage_info_streaming_with_acall():
assert llm._token_usage["completion_tokens"] > 0
assert llm._token_usage["total_tokens"] > 0
assert len(result) > 0
assert len(result) > 0