Compare commits


1 Commit

Author SHA1 Message Date
Devin AI
2ada7aba97 Fix race condition in LLM callback system
This commit fixes a race condition in the LLM callback system where
multiple LLM instances calling set_callbacks concurrently could cause
callbacks to be removed before they fire.

Changes:
- Add class-level RLock (_callback_lock) to LLM class to synchronize
  access to global litellm callbacks
- Wrap callback registration and LLM call execution in the lock for
  both call() and acall() methods
- Use RLock (reentrant lock) to handle recursive calls without deadlock
  (e.g., when retrying with unsupported 'stop' parameter)
- Remove sleep(5) workaround from test_llm_callback_replacement test
- Add new test_llm_callback_lock_prevents_race_condition test to verify
  concurrent callback access is properly synchronized

Fixes #4214

Co-Authored-By: João <joao@crewai.com>
2026-01-10 21:11:01 +00:00
4 changed files with 226 additions and 404 deletions
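The failure mode this commit targets: litellm tracks callbacks in module-level lists, so every LLM instance that calls set_callbacks mutates shared global state. Without synchronization, one thread can replace that list after another thread has registered its handlers but before its completion finishes, so those handlers never fire. A minimal, self-contained sketch of the race and the class-level-lock fix (hypothetical names, not the CrewAI code):

import threading
import time

GLOBAL_CALLBACKS: list[str] = []    # stand-in for litellm's global callback list
_callback_lock = threading.RLock()  # shared across instances, like LLM._callback_lock


def unsafe_call(name: str) -> bool:
    GLOBAL_CALLBACKS[:] = [name]           # set_callbacks(...)
    time.sleep(0.001)                      # the completion call would run here
    return GLOBAL_CALLBACKS == [name]      # another thread may have clobbered it


def safe_call(name: str) -> bool:
    with _callback_lock:                   # serialize access to the shared state
        GLOBAL_CALLBACKS[:] = [name]
        time.sleep(0.001)
        return GLOBAL_CALLBACKS == [name]  # stable while the lock is held


def run(fn) -> int:
    ok: list[bool] = []
    guard = threading.Lock()

    def worker(i: int) -> None:
        result = fn(f"cb-{i}")
        with guard:
            ok.append(result)

    threads = [threading.Thread(target=worker, args=(i,)) for i in range(8)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return sum(ok)


print(run(unsafe_call), "of 8 unsafe calls saw their own callbacks")  # usually < 8
print(run(safe_call), "of 8 safe calls saw their own callbacks")      # always 8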

View File

@@ -341,6 +341,7 @@ class AccumulatedToolArgs(BaseModel):
 class LLM(BaseLLM):
     completion_cost: float | None = None
+    _callback_lock: threading.RLock = threading.RLock()
 
     def __new__(cls, model: str, is_litellm: bool = False, **kwargs: Any) -> LLM:
         """Factory method that routes to native SDK or falls back to LiteLLM.
@@ -1144,7 +1145,7 @@ class LLM(BaseLLM):
             if response_model:
                 params["response_model"] = response_model
 
             response = litellm.completion(**params)
 
             if hasattr(response, "usage") and not isinstance(response.usage, type) and response.usage:
                 usage_info = response.usage
                 self._track_token_usage_internal(usage_info)
@@ -1363,7 +1364,7 @@ class LLM(BaseLLM):
         """
         full_response = ""
         chunk_count = 0
         usage_info = None
 
         accumulated_tool_args: defaultdict[int, AccumulatedToolArgs] = defaultdict(
@@ -1657,78 +1658,92 @@
             raise ValueError("LLM call blocked by before_llm_call hook")
 
         # --- 5) Set up callbacks if provided
+        # Use a class-level lock to synchronize access to global litellm callbacks.
+        # This prevents race conditions when multiple LLM instances call set_callbacks
+        # concurrently, which could cause callbacks to be removed before they fire.
         with suppress_warnings():
-            if callbacks and len(callbacks) > 0:
-                self.set_callbacks(callbacks)
-            try:
-                # --- 6) Prepare parameters for the completion call
-                params = self._prepare_completion_params(messages, tools)
-
-                # --- 7) Make the completion call and handle response
-                if self.stream:
-                    result = self._handle_streaming_response(
-                        params=params,
-                        callbacks=callbacks,
-                        available_functions=available_functions,
-                        from_task=from_task,
-                        from_agent=from_agent,
-                        response_model=response_model,
-                    )
-                else:
-                    result = self._handle_non_streaming_response(
-                        params=params,
-                        callbacks=callbacks,
-                        available_functions=available_functions,
-                        from_task=from_task,
-                        from_agent=from_agent,
-                        response_model=response_model,
-                    )
-
-                if isinstance(result, str):
-                    result = self._invoke_after_llm_call_hooks(
-                        messages, result, from_agent
-                    )
-
-                return result
-            except LLMContextLengthExceededError:
-                # Re-raise LLMContextLengthExceededError as it should be handled
-                # by the CrewAgentExecutor._invoke_loop method, which can then decide
-                # whether to summarize the content or abort based on the respect_context_window flag
-                raise
-            except Exception as e:
-                unsupported_stop = "Unsupported parameter" in str(
-                    e
-                ) and "'stop'" in str(e)
-
-                if unsupported_stop:
-                    if (
-                        "additional_drop_params" in self.additional_params
-                        and isinstance(
-                            self.additional_params["additional_drop_params"], list
-                        )
-                    ):
-                        self.additional_params["additional_drop_params"].append("stop")
-                    else:
-                        self.additional_params = {"additional_drop_params": ["stop"]}
-
-                    logging.info("Retrying LLM call without the unsupported 'stop'")
-                    return self.call(
-                        messages,
-                        tools=tools,
-                        callbacks=callbacks,
-                        available_functions=available_functions,
-                        from_task=from_task,
-                        from_agent=from_agent,
-                        response_model=response_model,
-                    )
-
-                crewai_event_bus.emit(
-                    self,
-                    event=LLMCallFailedEvent(
-                        error=str(e), from_task=from_task, from_agent=from_agent
-                    ),
-                )
-                raise
+            with LLM._callback_lock:
+                if callbacks and len(callbacks) > 0:
+                    self.set_callbacks(callbacks)
+                try:
+                    # --- 6) Prepare parameters for the completion call
+                    params = self._prepare_completion_params(messages, tools)
+
+                    # --- 7) Make the completion call and handle response
+                    if self.stream:
+                        result = self._handle_streaming_response(
+                            params=params,
+                            callbacks=callbacks,
+                            available_functions=available_functions,
+                            from_task=from_task,
+                            from_agent=from_agent,
+                            response_model=response_model,
+                        )
+                    else:
+                        result = self._handle_non_streaming_response(
+                            params=params,
+                            callbacks=callbacks,
+                            available_functions=available_functions,
+                            from_task=from_task,
+                            from_agent=from_agent,
+                            response_model=response_model,
+                        )
+
+                    if isinstance(result, str):
+                        result = self._invoke_after_llm_call_hooks(
+                            messages, result, from_agent
+                        )
+
+                    return result
+                except LLMContextLengthExceededError:
+                    # Re-raise LLMContextLengthExceededError as it should be handled
+                    # by the CrewAgentExecutor._invoke_loop method, which can then decide
+                    # whether to summarize the content or abort based on the respect_context_window flag
+                    raise
+                except Exception as e:
+                    unsupported_stop = "Unsupported parameter" in str(
+                        e
+                    ) and "'stop'" in str(e)
+
+                    if unsupported_stop:
+                        if (
+                            "additional_drop_params" in self.additional_params
+                            and isinstance(
+                                self.additional_params["additional_drop_params"], list
+                            )
+                        ):
+                            self.additional_params["additional_drop_params"].append(
+                                "stop"
+                            )
+                        else:
+                            self.additional_params = {
+                                "additional_drop_params": ["stop"]
+                            }
+                        logging.info(
+                            "Retrying LLM call without the unsupported 'stop'"
+                        )
+                        # The recursive call happens while the lock is still held;
+                        # _callback_lock is a reentrant RLock, so the same thread
+                        # can acquire it again without deadlocking.
+                        return self.call(
+                            messages,
+                            tools=tools,
+                            callbacks=callbacks,
+                            available_functions=available_functions,
+                            from_task=from_task,
+                            from_agent=from_agent,
+                            response_model=response_model,
+                        )
+
+                    crewai_event_bus.emit(
+                        self,
+                        event=LLMCallFailedEvent(
+                            error=str(e), from_task=from_task, from_agent=from_agent
+                        ),
+                    )
+                    raise
 
     async def acall(
         self,
@@ -1790,14 +1805,27 @@
             msg_role: Literal["assistant"] = "assistant"
             message["role"] = msg_role
 
+        # Use a class-level lock to synchronize access to global litellm callbacks.
+        # This prevents race conditions when multiple LLM instances call set_callbacks
+        # concurrently, which could cause callbacks to be removed before they fire.
         with suppress_warnings():
-            if callbacks and len(callbacks) > 0:
-                self.set_callbacks(callbacks)
-            try:
-                params = self._prepare_completion_params(messages, tools)
-                if self.stream:
-                    return await self._ahandle_streaming_response(
+            with LLM._callback_lock:
+                if callbacks and len(callbacks) > 0:
+                    self.set_callbacks(callbacks)
+                try:
+                    params = self._prepare_completion_params(messages, tools)
+                    if self.stream:
+                        return await self._ahandle_streaming_response(
+                            params=params,
+                            callbacks=callbacks,
+                            available_functions=available_functions,
+                            from_task=from_task,
+                            from_agent=from_agent,
+                            response_model=response_model,
+                        )
+                    return await self._ahandle_non_streaming_response(
+                        params=params,
+                        callbacks=callbacks,
+                        available_functions=available_functions,
@@ -1805,52 +1833,49 @@
-                        params=params,
-                        callbacks=callbacks,
-                        available_functions=available_functions,
-                        from_task=from_task,
-                        from_agent=from_agent,
-                        response_model=response_model,
-                    )
-                return await self._ahandle_non_streaming_response(
-                    params=params,
-                    callbacks=callbacks,
-                    available_functions=available_functions,
-                    from_task=from_task,
-                    from_agent=from_agent,
-                    response_model=response_model,
-                )
-            except LLMContextLengthExceededError:
-                raise
-            except Exception as e:
-                unsupported_stop = "Unsupported parameter" in str(
-                    e
-                ) and "'stop'" in str(e)
-                if unsupported_stop:
-                    if (
-                        "additional_drop_params" in self.additional_params
-                        and isinstance(
-                            self.additional_params["additional_drop_params"], list
-                        )
-                    ):
-                        self.additional_params["additional_drop_params"].append("stop")
-                    else:
-                        self.additional_params = {"additional_drop_params": ["stop"]}
-                    logging.info("Retrying LLM call without the unsupported 'stop'")
-                    return await self.acall(
-                        messages,
-                        tools=tools,
-                        callbacks=callbacks,
-                        available_functions=available_functions,
-                        from_task=from_task,
-                        from_agent=from_agent,
-                        response_model=response_model,
-                    )
-                crewai_event_bus.emit(
-                    self,
-                    event=LLMCallFailedEvent(
-                        error=str(e), from_task=from_task, from_agent=from_agent
-                    ),
-                )
-                raise
+                        from_task=from_task,
+                        from_agent=from_agent,
+                        response_model=response_model,
+                    )
+                except LLMContextLengthExceededError:
+                    raise
+                except Exception as e:
+                    unsupported_stop = "Unsupported parameter" in str(
+                        e
+                    ) and "'stop'" in str(e)
+                    if unsupported_stop:
+                        if (
+                            "additional_drop_params" in self.additional_params
+                            and isinstance(
+                                self.additional_params["additional_drop_params"], list
+                            )
+                        ):
+                            self.additional_params["additional_drop_params"].append(
+                                "stop"
+                            )
+                        else:
+                            self.additional_params = {
+                                "additional_drop_params": ["stop"]
+                            }
+                        logging.info(
+                            "Retrying LLM call without the unsupported 'stop'"
+                        )
+                        return await self.acall(
+                            messages,
+                            tools=tools,
+                            callbacks=callbacks,
+                            available_functions=available_functions,
+                            from_task=from_task,
+                            from_agent=from_agent,
+                            response_model=response_model,
+                        )
+                    crewai_event_bus.emit(
+                        self,
+                        event=LLMCallFailedEvent(
+                            error=str(e), from_task=from_task, from_agent=from_agent
+                        ),
+                    )
+                    raise
 
     def _handle_emit_call_events(
         self,
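Why an RLock rather than a plain Lock: the unsupported-'stop' retry above re-enters call() (and acall()) while the same thread still holds the lock. A reentrant lock counts acquisitions per owning thread, so the retry proceeds; a plain Lock would block forever. A quick generic check (plain Python, not CrewAI code):

import threading


def can_reacquire(lock) -> bool:
    # Mirrors the retry path: acquire, then try to acquire again on the same thread.
    with lock:
        acquired = lock.acquire(blocking=False)
        if acquired:
            lock.release()
        return acquired


print(can_reacquire(threading.RLock()))  # True: the owner may re-enter
print(can_reacquire(threading.Lock()))   # False: a second acquire would deadlock

The trade-off is scope: the lock is held for the full duration of the completion call, so concurrent LLM calls serialize on it.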

View File

@@ -2,11 +2,8 @@ from datetime import datetime
 import json
 import os
 import pickle
-import tempfile
-import threading
 from typing import Any, TypedDict
 
-import portalocker
 from typing_extensions import Unpack
@@ -126,15 +123,10 @@ class FileHandler:
 class PickleHandler:
-    """Thread-safe handler for saving and loading data using pickle.
-
-    This class provides thread-safe file operations using portalocker for
-    cross-process file locking and atomic write operations to prevent
-    data corruption during concurrent access.
+    """Handler for saving and loading data using pickle.
 
     Attributes:
         file_path: The path to the pickle file.
-        _lock: Threading lock for thread-safe operations within the same process.
     """
 
     def __init__(self, file_name: str) -> None:
@@ -149,62 +141,34 @@ class PickleHandler:
             file_name += ".pkl"
 
         self.file_path = os.path.join(os.getcwd(), file_name)
-        self._lock = threading.Lock()
 
     def initialize_file(self) -> None:
         """Initialize the file with an empty dictionary and overwrite any existing data."""
         self.save({})
 
     def save(self, data: Any) -> None:
-        """Save the data to the specified file using pickle with thread-safe atomic writes.
-
-        This method uses a two-phase approach for thread safety:
-        1. Threading lock for same-process thread safety
-        2. Atomic write (write to temp file, then rename) for cross-process safety
-           and data integrity
+        """
+        Save the data to the specified file using pickle.
 
         Args:
             data: The data to be saved to the file.
         """
-        with self._lock:
-            dir_name = os.path.dirname(self.file_path) or os.getcwd()
-            fd, temp_path = tempfile.mkstemp(
-                suffix=".pkl.tmp", prefix="pickle_", dir=dir_name
-            )
-            try:
-                with os.fdopen(fd, "wb") as f:
-                    pickle.dump(obj=data, file=f)
-                    f.flush()
-                    os.fsync(f.fileno())
-                os.replace(temp_path, self.file_path)
-            except Exception:
-                if os.path.exists(temp_path):
-                    os.unlink(temp_path)
-                raise
+        with open(self.file_path, "wb") as f:
+            pickle.dump(obj=data, file=f)
 
     def load(self) -> Any:
-        """Load the data from the specified file using pickle with thread-safe locking.
-
-        This method uses portalocker for cross-process read locking to ensure
-        data consistency when multiple processes may be accessing the file.
+        """
+        Load the data from the specified file using pickle.
 
         Returns:
-            The data loaded from the file, or an empty dictionary if the file
-            does not exist or is empty.
+            The data loaded from the file.
         """
-        with self._lock:
-            if (
-                not os.path.exists(self.file_path)
-                or os.path.getsize(self.file_path) == 0
-            ):
-                return {}
-
-            with portalocker.Lock(
-                self.file_path, "rb", flags=portalocker.LOCK_SH
-            ) as file:
-                try:
-                    return pickle.load(file)  # noqa: S301
-                except EOFError:
-                    return {}
-                except Exception:
-                    raise
+        if not os.path.exists(self.file_path) or os.path.getsize(self.file_path) == 0:
+            return {}  # Return an empty dictionary if the file does not exist or is empty
+
+        with open(self.file_path, "rb") as file:
+            try:
+                return pickle.load(file)  # noqa: S301
+            except EOFError:
+                return {}  # Return an empty dictionary if the file is empty or corrupted
+            except Exception:
+                raise  # Raise any other exceptions that occur during loading
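For context on what the simplification above trades away: the removed save() staged the pickle in a temporary file and promoted it with os.replace, which renames atomically on both POSIX and Windows, so a concurrent reader sees either the old file or the new one, never a truncated half-write. Plain open(path, "wb") truncates the file first, opening exactly that window. A condensed reference sketch of the removed pattern (hypothetical helper name):

import os
import pickle
import tempfile
from typing import Any


def atomic_pickle_dump(path: str, data: Any) -> None:
    """Write a pickle so readers see the old or the new file, never a partial one."""
    dir_name = os.path.dirname(path) or os.getcwd()
    fd, temp_path = tempfile.mkstemp(suffix=".pkl.tmp", dir=dir_name)
    try:
        with os.fdopen(fd, "wb") as f:
            pickle.dump(data, f)
            f.flush()
            os.fsync(f.fileno())      # ensure bytes hit disk before the swap
        os.replace(temp_path, path)   # atomic rename, all-or-nothing
    except Exception:
        if os.path.exists(temp_path):
            os.unlink(temp_path)      # clean up the temp file on failure
        raise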

View File

@@ -1,6 +1,6 @@
 import logging
 import os
-from time import sleep
+import threading
 from unittest.mock import MagicMock, patch
 
 from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
@@ -18,9 +18,15 @@ from pydantic import BaseModel
 import pytest
 
 
-# TODO: This test fails without print statement, which makes me think that something is happening asynchronously that we need to eventually fix and dive deeper into at a later date
 @pytest.mark.vcr()
 def test_llm_callback_replacement():
+    """Test that callbacks are properly isolated between LLM instances.
+
+    This test verifies that the race condition fix (using _callback_lock) works
+    correctly. Previously, this test required a sleep(5) workaround because
+    callbacks were being modified globally without synchronization, causing
+    one LLM instance's callbacks to interfere with another's.
+    """
     llm1 = LLM(model="gpt-4o-mini", is_litellm=True)
     llm2 = LLM(model="gpt-4o-mini", is_litellm=True)
@@ -37,7 +43,6 @@ def test_llm_callback_replacement():
         messages=[{"role": "user", "content": "Hello, world from another agent!"}],
         callbacks=[calc_handler_2],
     )
-    sleep(5)
 
     usage_metrics_2 = calc_handler_2.token_cost_process.get_summary()
 
     # The first handler should not have been updated
@@ -46,6 +51,66 @@
     assert usage_metrics_1 == calc_handler_1.token_cost_process.get_summary()
 
 
+def test_llm_callback_lock_prevents_race_condition():
+    """Test that the _callback_lock prevents race conditions in concurrent LLM calls.
+
+    This test verifies that multiple threads can safely call LLM.call() with
+    different callbacks without interfering with each other. The lock ensures
+    that callbacks are properly isolated between concurrent calls.
+    """
+    num_threads = 5
+    results: list[int] = []
+    errors: list[Exception] = []
+    lock = threading.Lock()
+
+    def make_llm_call(thread_id: int, mock_completion: MagicMock) -> None:
+        try:
+            llm = LLM(model="gpt-4o-mini", is_litellm=True)
+            calc_handler = TokenCalcHandler(token_cost_process=TokenProcess())
+
+            mock_message = MagicMock()
+            mock_message.content = f"Response from thread {thread_id}"
+            mock_choice = MagicMock()
+            mock_choice.message = mock_message
+            mock_response = MagicMock()
+            mock_response.choices = [mock_choice]
+            mock_response.usage = {
+                "prompt_tokens": 10,
+                "completion_tokens": 10,
+                "total_tokens": 20,
+            }
+            mock_completion.return_value = mock_response
+
+            llm.call(
+                messages=[{"role": "user", "content": f"Hello from thread {thread_id}"}],
+                callbacks=[calc_handler],
+            )
+
+            usage = calc_handler.token_cost_process.get_summary()
+            with lock:
+                results.append(usage.successful_requests)
+        except Exception as e:
+            with lock:
+                errors.append(e)
+
+    with patch("litellm.completion") as mock_completion:
+        threads = [
+            threading.Thread(target=make_llm_call, args=(i, mock_completion))
+            for i in range(num_threads)
+        ]
+        for t in threads:
+            t.start()
+        for t in threads:
+            t.join()
+
+    assert len(errors) == 0, f"Errors occurred: {errors}"
+    assert len(results) == num_threads
+    assert all(
+        r == 1 for r in results
+    ), f"Expected all callbacks to have 1 successful request, got {results}"
+
+
 @pytest.mark.vcr()
 def test_llm_call_with_string_input():
     llm = LLM(model="gpt-4o-mini")
@@ -989,4 +1054,4 @@ async def test_usage_info_streaming_with_acall():
     assert llm._token_usage["completion_tokens"] > 0
     assert llm._token_usage["total_tokens"] > 0
 
-    assert len(result) > 0
\ No newline at end of file
+    assert len(result) > 0
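One limitation of the new race test worth knowing when extending it: threads started in a loop ramp up at slightly different times, so they may never actually contend for the lock. A common way to force simultaneous entry into the critical section is a threading.Barrier; a sketch of that variant (hypothetical helper, not part of this diff):

import threading
from typing import Callable


def run_simultaneously(worker: Callable[[int], None], num_threads: int = 5) -> None:
    """Start the workers, then release them all at the same instant."""
    barrier = threading.Barrier(num_threads)

    def wrapped(i: int) -> None:
        barrier.wait()  # block until every worker has arrived
        worker(i)

    threads = [threading.Thread(target=wrapped, args=(i,)) for i in range(num_threads)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()

Passing the test's make_llm_call through this wrapper, e.g. run_simultaneously(lambda i: make_llm_call(i, mock_completion)), maximizes the chance that set_callbacks calls overlap when the lock is removed.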

View File

@@ -1,8 +1,6 @@
 import os
-import threading
 import unittest
 import uuid
-from concurrent.futures import ThreadPoolExecutor, as_completed
 
 import pytest
 
 from crewai.utilities.file_handler import PickleHandler
@@ -10,6 +8,7 @@ from crewai.utilities.file_handler import PickleHandler
 class TestPickleHandler(unittest.TestCase):
     def setUp(self):
+        # Use a unique file name for each test to avoid race conditions in parallel test execution
         unique_id = str(uuid.uuid4())
        self.file_name = f"test_data_{unique_id}.pkl"
         self.file_path = os.path.join(os.getcwd(), self.file_name)
@@ -48,234 +47,3 @@
         assert str(exc.value) == "pickle data was truncated"
         assert "<class '_pickle.UnpicklingError'>" == str(exc.type)
-
-
-class TestPickleHandlerThreadSafety(unittest.TestCase):
-    """Tests for thread-safety of PickleHandler operations."""
-
-    def setUp(self):
-        unique_id = str(uuid.uuid4())
-        self.file_name = f"test_thread_safe_{unique_id}.pkl"
-        self.file_path = os.path.join(os.getcwd(), self.file_name)
-        self.handler = PickleHandler(self.file_name)
-
-    def tearDown(self):
-        if os.path.exists(self.file_path):
-            os.remove(self.file_path)
-
-    def test_concurrent_writes_same_handler(self):
-        """Test that concurrent writes from multiple threads using the same handler don't corrupt data."""
-        num_threads = 10
-        num_writes_per_thread = 20
-        errors: list[Exception] = []
-        write_count = 0
-        count_lock = threading.Lock()
-
-        def write_data(thread_id: int) -> None:
-            nonlocal write_count
-            for i in range(num_writes_per_thread):
-                try:
-                    data = {"thread": thread_id, "iteration": i, "data": f"value_{thread_id}_{i}"}
-                    self.handler.save(data)
-                    with count_lock:
-                        write_count += 1
-                except Exception as e:
-                    errors.append(e)
-
-        threads = []
-        for i in range(num_threads):
-            t = threading.Thread(target=write_data, args=(i,))
-            threads.append(t)
-            t.start()
-
-        for t in threads:
-            t.join()
-
-        assert len(errors) == 0, f"Errors occurred during concurrent writes: {errors}"
-        assert write_count == num_threads * num_writes_per_thread
-
-        loaded_data = self.handler.load()
-        assert isinstance(loaded_data, dict)
-        assert "thread" in loaded_data
-        assert "iteration" in loaded_data
-
-    def test_concurrent_reads_same_handler(self):
-        """Test that concurrent reads from multiple threads don't cause issues."""
-        test_data = {"key": "value", "nested": {"a": 1, "b": 2}}
-        self.handler.save(test_data)
-
-        num_threads = 20
-        results: list[dict] = []
-        errors: list[Exception] = []
-        results_lock = threading.Lock()
-
-        def read_data() -> None:
-            try:
-                data = self.handler.load()
-                with results_lock:
-                    results.append(data)
-            except Exception as e:
-                errors.append(e)
-
-        threads = []
-        for _ in range(num_threads):
-            t = threading.Thread(target=read_data)
-            threads.append(t)
-            t.start()
-
-        for t in threads:
-            t.join()
-
-        assert len(errors) == 0, f"Errors occurred during concurrent reads: {errors}"
-        assert len(results) == num_threads
-        for result in results:
-            assert result == test_data
-
-    def test_concurrent_read_write_same_handler(self):
-        """Test that concurrent reads and writes don't corrupt data or cause errors."""
-        initial_data = {"counter": 0}
-        self.handler.save(initial_data)
-
-        num_writers = 5
-        num_readers = 10
-        writes_per_thread = 10
-        reads_per_thread = 20
-        write_errors: list[Exception] = []
-        read_errors: list[Exception] = []
-        read_results: list[dict] = []
-        results_lock = threading.Lock()
-
-        def writer(thread_id: int) -> None:
-            for i in range(writes_per_thread):
-                try:
-                    data = {"writer": thread_id, "write_num": i}
-                    self.handler.save(data)
-                except Exception as e:
-                    write_errors.append(e)
-
-        def reader() -> None:
-            for _ in range(reads_per_thread):
-                try:
-                    data = self.handler.load()
-                    with results_lock:
-                        read_results.append(data)
-                except Exception as e:
-                    read_errors.append(e)
-
-        threads = []
-        for i in range(num_writers):
-            t = threading.Thread(target=writer, args=(i,))
-            threads.append(t)
-        for _ in range(num_readers):
-            t = threading.Thread(target=reader)
-            threads.append(t)
-
-        for t in threads:
-            t.start()
-        for t in threads:
-            t.join()
-
-        assert len(write_errors) == 0, f"Write errors: {write_errors}"
-        assert len(read_errors) == 0, f"Read errors: {read_errors}"
-        for result in read_results:
-            assert isinstance(result, dict)
-
-    def test_atomic_write_no_partial_data(self):
-        """Test that atomic writes prevent partial/corrupted data from being read."""
-        large_data = {"key": "x" * 100000, "numbers": list(range(10000))}
-        num_iterations = 50
-        errors: list[Exception] = []
-        corruption_detected = False
-        corruption_lock = threading.Lock()
-
-        def writer() -> None:
-            for _ in range(num_iterations):
-                try:
-                    self.handler.save(large_data)
-                except Exception as e:
-                    errors.append(e)
-
-        def reader() -> None:
-            nonlocal corruption_detected
-            for _ in range(num_iterations * 2):
-                try:
-                    data = self.handler.load()
-                    if data and data != {} and data != large_data:
-                        with corruption_lock:
-                            corruption_detected = True
-                except Exception as e:
-                    errors.append(e)
-
-        writer_thread = threading.Thread(target=writer)
-        reader_thread = threading.Thread(target=reader)
-        writer_thread.start()
-        reader_thread.start()
-        writer_thread.join()
-        reader_thread.join()
-
-        assert len(errors) == 0, f"Errors occurred: {errors}"
-        assert not corruption_detected, "Partial/corrupted data was read"
-
-    def test_thread_pool_concurrent_operations(self):
-        """Test thread safety using ThreadPoolExecutor for more realistic concurrent access."""
-        num_operations = 100
-        errors: list[Exception] = []
-
-        def operation(op_id: int) -> str:
-            try:
-                if op_id % 3 == 0:
-                    self.handler.save({"op_id": op_id, "type": "write"})
-                    return f"write_{op_id}"
-                else:
-                    data = self.handler.load()
-                    return f"read_{op_id}_{type(data).__name__}"
-            except Exception as e:
-                errors.append(e)
-                return f"error_{op_id}"
-
-        with ThreadPoolExecutor(max_workers=20) as executor:
-            futures = [executor.submit(operation, i) for i in range(num_operations)]
-            results = [f.result() for f in as_completed(futures)]
-
-        assert len(errors) == 0, f"Errors occurred: {errors}"
-        assert len(results) == num_operations
-
-    def test_multiple_handlers_same_file(self):
-        """Test that multiple PickleHandler instances for the same file work correctly."""
-        handler1 = PickleHandler(self.file_name)
-        handler2 = PickleHandler(self.file_name)
-
-        num_operations = 50
-        errors: list[Exception] = []
-
-        def use_handler1() -> None:
-            for i in range(num_operations):
-                try:
-                    handler1.save({"handler": 1, "iteration": i})
-                except Exception as e:
-                    errors.append(e)
-
-        def use_handler2() -> None:
-            for i in range(num_operations):
-                try:
-                    handler2.save({"handler": 2, "iteration": i})
-                except Exception as e:
-                    errors.append(e)
-
-        t1 = threading.Thread(target=use_handler1)
-        t2 = threading.Thread(target=use_handler2)
-        t1.start()
-        t2.start()
-        t1.join()
-        t2.join()
-
-        assert len(errors) == 0, f"Errors occurred: {errors}"
-
-        final_data = self.handler.load()
-        assert isinstance(final_data, dict)
-        assert "handler" in final_data
-        assert "iteration" in final_data
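The deleted suite above exercised cross-process guarantees that came from portalocker in the reverted PickleHandler (see the file_handler diff earlier). For reference, the shared/exclusive file-locking idiom it relied on looks like this (a sketch assuming portalocker is installed; hypothetical function names):

import pickle
from typing import Any

import portalocker


def locked_load(path: str) -> Any:
    # LOCK_SH: concurrent readers are allowed, writers are excluded.
    with portalocker.Lock(path, "rb", flags=portalocker.LOCK_SH) as f:
        return pickle.load(f)  # noqa: S301


def locked_save(path: str, data: Any) -> None:
    # LOCK_EX: exactly one writer at a time, across processes.
    with portalocker.Lock(path, "wb", flags=portalocker.LOCK_EX) as f:
        pickle.dump(data, f)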