Compare commits

..

1 Commit

Author SHA1 Message Date
Devin AI
2ada7aba97 Fix race condition in LLM callback system
This commit fixes a race condition in the LLM callback system where
multiple LLM instances calling set_callbacks concurrently could cause
callbacks to be removed before they fire.

Changes:
- Add class-level RLock (_callback_lock) to LLM class to synchronize
  access to global litellm callbacks
- Wrap callback registration and LLM call execution in the lock for
  both call() and acall() methods
- Use RLock (reentrant lock) to handle recursive calls without deadlock
  (e.g., when retrying with unsupported 'stop' parameter)
- Remove sleep(5) workaround from test_llm_callback_replacement test
- Add new test_llm_callback_lock_prevents_race_condition test to verify
  concurrent callback access is properly synchronized

Fixes #4214

Co-Authored-By: João <joao@crewai.com>
2026-01-10 21:11:01 +00:00
4 changed files with 214 additions and 211 deletions
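
Before the diffs, a minimal sketch of the locking pattern the commit message describes: callback registration and the call that fires the callbacks are serialized behind one class-level threading.RLock, so a concurrent set_callbacks from another instance cannot remove a handler before it runs. The GLOBAL_CALLBACKS list, Client class, and do_call method below are illustrative stand-ins, not the crewai or litellm APIs.

import threading

# Illustrative stand-in for litellm's process-global callback list (not the real API).
GLOBAL_CALLBACKS: list[object] = []


class Client:
    # Shared by all instances: concurrent calls from different instances are serialized.
    _callback_lock: threading.RLock = threading.RLock()

    def set_callbacks(self, callbacks: list[object]) -> None:
        # Replacing a global list is racy unless every caller holds the lock.
        GLOBAL_CALLBACKS.clear()
        GLOBAL_CALLBACKS.extend(callbacks)

    def do_call(self, callbacks: list[object]) -> list[object]:
        # Registration and the call that fires the callbacks form one critical section,
        # so another thread cannot swap the callbacks out before they fire.
        with Client._callback_lock:
            self.set_callbacks(callbacks)
            return list(GLOBAL_CALLBACKS)  # stand-in for the completion call

Holding the lock across both steps, rather than around set_callbacks alone, is what makes the register-then-fire sequence atomic with respect to other threads.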

View File

@@ -1,10 +1,10 @@
+from functools import lru_cache
import subprocess


class Repository:
    def __init__(self, path: str = ".") -> None:
        self.path = path
-        self._is_git_repo_cache: bool | None = None

        if not self.is_git_installed():
            raise ValueError("Git is not installed or not found in your PATH.")
@@ -40,26 +40,22 @@ class Repository:
            encoding="utf-8",
        ).strip()

+    @lru_cache(maxsize=None)  # noqa: B019
    def is_git_repo(self) -> bool:
        """Check if the current directory is a git repository.

-        The result is cached at the instance level to avoid redundant checks
-        while allowing proper garbage collection of Repository instances.
+        Notes:
+            - TODO: This method is cached to avoid redundant checks, but using lru_cache on methods can lead to memory leaks
        """
-        if self._is_git_repo_cache is not None:
-            return self._is_git_repo_cache
-
        try:
            subprocess.check_output(
                ["git", "rev-parse", "--is-inside-work-tree"],  # noqa: S607
                cwd=self.path,
                encoding="utf-8",
            )
-            self._is_git_repo_cache = True
+            return True
        except subprocess.CalledProcessError:
-            self._is_git_repo_cache = False
-
-        return self._is_git_repo_cache
+            return False

    def has_uncommitted_changes(self) -> bool:
        """Check if the repository has uncommitted changes."""

View File

@@ -341,6 +341,7 @@ class AccumulatedToolArgs(BaseModel):
class LLM(BaseLLM):
    completion_cost: float | None = None
+    _callback_lock: threading.RLock = threading.RLock()

    def __new__(cls, model: str, is_litellm: bool = False, **kwargs: Any) -> LLM:
        """Factory method that routes to native SDK or falls back to LiteLLM.
@@ -1144,7 +1145,7 @@ class LLM(BaseLLM):
        if response_model:
            params["response_model"] = response_model

        response = litellm.completion(**params)

        if hasattr(response,"usage") and not isinstance(response.usage, type) and response.usage:
            usage_info = response.usage
            self._track_token_usage_internal(usage_info)
@@ -1363,7 +1364,7 @@ class LLM(BaseLLM):
"""
full_response = ""
chunk_count = 0
usage_info = None
accumulated_tool_args: defaultdict[int, AccumulatedToolArgs] = defaultdict(
@@ -1657,78 +1658,92 @@ class LLM(BaseLLM):
            raise ValueError("LLM call blocked by before_llm_call hook")

        # --- 5) Set up callbacks if provided
+        # Use a class-level lock to synchronize access to global litellm callbacks.
+        # This prevents race conditions when multiple LLM instances call set_callbacks
+        # concurrently, which could cause callbacks to be removed before they fire.
        with suppress_warnings():
-            if callbacks and len(callbacks) > 0:
-                self.set_callbacks(callbacks)
-            try:
-                # --- 6) Prepare parameters for the completion call
-                params = self._prepare_completion_params(messages, tools)
-
-                # --- 7) Make the completion call and handle response
-                if self.stream:
-                    result = self._handle_streaming_response(
-                        params=params,
-                        callbacks=callbacks,
-                        available_functions=available_functions,
-                        from_task=from_task,
-                        from_agent=from_agent,
-                        response_model=response_model,
-                    )
-                else:
-                    result = self._handle_non_streaming_response(
-                        params=params,
-                        callbacks=callbacks,
-                        available_functions=available_functions,
-                        from_task=from_task,
-                        from_agent=from_agent,
-                        response_model=response_model,
-                    )
-
-                if isinstance(result, str):
-                    result = self._invoke_after_llm_call_hooks(
-                        messages, result, from_agent
-                    )
-                return result
-            except LLMContextLengthExceededError:
-                # Re-raise LLMContextLengthExceededError as it should be handled
-                # by the CrewAgentExecutor._invoke_loop method, which can then decide
-                # whether to summarize the content or abort based on the respect_context_window flag
-                raise
-            except Exception as e:
-                unsupported_stop = "Unsupported parameter" in str(
-                    e
-                ) and "'stop'" in str(e)
-
-                if unsupported_stop:
-                    if (
-                        "additional_drop_params" in self.additional_params
-                        and isinstance(
-                            self.additional_params["additional_drop_params"], list
-                        )
-                    ):
-                        self.additional_params["additional_drop_params"].append("stop")
-                    else:
-                        self.additional_params = {"additional_drop_params": ["stop"]}
-
-                    logging.info("Retrying LLM call without the unsupported 'stop'")
-
-                    return self.call(
-                        messages,
-                        tools=tools,
-                        callbacks=callbacks,
-                        available_functions=available_functions,
-                        from_task=from_task,
-                        from_agent=from_agent,
-                        response_model=response_model,
-                    )
-
-                crewai_event_bus.emit(
-                    self,
-                    event=LLMCallFailedEvent(
-                        error=str(e), from_task=from_task, from_agent=from_agent
-                    ),
-                )
-                raise
+            with LLM._callback_lock:
+                if callbacks and len(callbacks) > 0:
+                    self.set_callbacks(callbacks)
+                try:
+                    # --- 6) Prepare parameters for the completion call
+                    params = self._prepare_completion_params(messages, tools)
+
+                    # --- 7) Make the completion call and handle response
+                    if self.stream:
+                        result = self._handle_streaming_response(
+                            params=params,
+                            callbacks=callbacks,
+                            available_functions=available_functions,
+                            from_task=from_task,
+                            from_agent=from_agent,
+                            response_model=response_model,
+                        )
+                    else:
+                        result = self._handle_non_streaming_response(
+                            params=params,
+                            callbacks=callbacks,
+                            available_functions=available_functions,
+                            from_task=from_task,
+                            from_agent=from_agent,
+                            response_model=response_model,
+                        )
+
+                    if isinstance(result, str):
+                        result = self._invoke_after_llm_call_hooks(
+                            messages, result, from_agent
+                        )
+                    return result
+                except LLMContextLengthExceededError:
+                    # Re-raise LLMContextLengthExceededError as it should be handled
+                    # by the CrewAgentExecutor._invoke_loop method, which can then decide
+                    # whether to summarize the content or abort based on the respect_context_window flag
+                    raise
+                except Exception as e:
+                    unsupported_stop = "Unsupported parameter" in str(
+                        e
+                    ) and "'stop'" in str(e)
+
+                    if unsupported_stop:
+                        if (
+                            "additional_drop_params" in self.additional_params
+                            and isinstance(
+                                self.additional_params["additional_drop_params"], list
+                            )
+                        ):
+                            self.additional_params["additional_drop_params"].append(
+                                "stop"
+                            )
+                        else:
+                            self.additional_params = {
+                                "additional_drop_params": ["stop"]
+                            }
+
+                        logging.info(
+                            "Retrying LLM call without the unsupported 'stop'"
+                        )
+
+                        # Recursive call happens inside the lock since we're using
+                        # a reentrant-safe pattern (the lock is released when we
+                        # exit the with block, and the recursive call will acquire
+                        # it again)
+                        return self.call(
+                            messages,
+                            tools=tools,
+                            callbacks=callbacks,
+                            available_functions=available_functions,
+                            from_task=from_task,
+                            from_agent=from_agent,
+                            response_model=response_model,
+                        )
+
+                    crewai_event_bus.emit(
+                        self,
+                        event=LLMCallFailedEvent(
+                            error=str(e), from_task=from_task, from_agent=from_agent
+                        ),
+                    )
+                    raise

    async def acall(
        self,
@@ -1790,14 +1805,27 @@ class LLM(BaseLLM):
            msg_role: Literal["assistant"] = "assistant"
            message["role"] = msg_role

+        # Use a class-level lock to synchronize access to global litellm callbacks.
+        # This prevents race conditions when multiple LLM instances call set_callbacks
+        # concurrently, which could cause callbacks to be removed before they fire.
        with suppress_warnings():
-            if callbacks and len(callbacks) > 0:
-                self.set_callbacks(callbacks)
-            try:
-                params = self._prepare_completion_params(messages, tools)
-
-                if self.stream:
-                    return await self._ahandle_streaming_response(
-                        params=params,
-                        callbacks=callbacks,
-                        available_functions=available_functions,
-                        from_task=from_task,
-                        from_agent=from_agent,
-                        response_model=response_model,
-                    )
+            with LLM._callback_lock:
+                if callbacks and len(callbacks) > 0:
+                    self.set_callbacks(callbacks)
+                try:
+                    params = self._prepare_completion_params(messages, tools)
+
+                    if self.stream:
+                        return await self._ahandle_streaming_response(
+                            params=params,
+                            callbacks=callbacks,
+                            available_functions=available_functions,
+                            from_task=from_task,
+                            from_agent=from_agent,
+                            response_model=response_model,
+                        )
@@ -1805,52 +1833,49 @@
-                return await self._ahandle_non_streaming_response(
-                    params=params,
-                    callbacks=callbacks,
-                    available_functions=available_functions,
-                    from_task=from_task,
-                    from_agent=from_agent,
-                    response_model=response_model,
-                )
-            except LLMContextLengthExceededError:
-                raise
-            except Exception as e:
-                unsupported_stop = "Unsupported parameter" in str(
-                    e
-                ) and "'stop'" in str(e)
-
-                if unsupported_stop:
-                    if (
-                        "additional_drop_params" in self.additional_params
-                        and isinstance(
-                            self.additional_params["additional_drop_params"], list
-                        )
-                    ):
-                        self.additional_params["additional_drop_params"].append("stop")
-                    else:
-                        self.additional_params = {"additional_drop_params": ["stop"]}
-
-                    logging.info("Retrying LLM call without the unsupported 'stop'")
-
-                    return await self.acall(
-                        messages,
-                        tools=tools,
-                        callbacks=callbacks,
-                        available_functions=available_functions,
-                        from_task=from_task,
-                        from_agent=from_agent,
-                        response_model=response_model,
-                    )
-
-                crewai_event_bus.emit(
-                    self,
-                    event=LLMCallFailedEvent(
-                        error=str(e), from_task=from_task, from_agent=from_agent
-                    ),
-                )
-                raise
+                    return await self._ahandle_non_streaming_response(
+                        params=params,
+                        callbacks=callbacks,
+                        available_functions=available_functions,
+                        from_task=from_task,
+                        from_agent=from_agent,
+                        response_model=response_model,
+                    )
+                except LLMContextLengthExceededError:
+                    raise
+                except Exception as e:
+                    unsupported_stop = "Unsupported parameter" in str(
+                        e
+                    ) and "'stop'" in str(e)
+
+                    if unsupported_stop:
+                        if (
+                            "additional_drop_params" in self.additional_params
+                            and isinstance(
+                                self.additional_params["additional_drop_params"], list
+                            )
+                        ):
+                            self.additional_params["additional_drop_params"].append(
+                                "stop"
+                            )
+                        else:
+                            self.additional_params = {
+                                "additional_drop_params": ["stop"]
+                            }
+
+                        logging.info(
+                            "Retrying LLM call without the unsupported 'stop'"
+                        )
+
+                        return await self.acall(
+                            messages,
+                            tools=tools,
+                            callbacks=callbacks,
+                            available_functions=available_functions,
+                            from_task=from_task,
+                            from_agent=from_agent,
+                            response_model=response_model,
+                        )
+
+                    crewai_event_bus.emit(
+                        self,
+                        event=LLMCallFailedEvent(
+                            error=str(e), from_task=from_task, from_agent=from_agent
+                        ),
+                    )
+                    raise

    def _handle_emit_call_events(
        self,
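
One detail worth calling out from the hunks above: the retry for the unsupported 'stop' parameter re-enters call()/acall() while the current thread still holds _callback_lock, which is why the commit uses a reentrant threading.RLock rather than a plain Lock. A small, hypothetical sketch of the difference (not crewai code):

import threading


class Retrying:
    # Reentrant, like LLM._callback_lock: the owning thread may acquire it again.
    _lock = threading.RLock()

    def call(self, attempt: int = 0) -> str:
        with Retrying._lock:
            if attempt == 0:
                # Simulates the retry path: we are still inside the locked region.
                # With threading.Lock() this second acquisition would block forever;
                # with RLock it succeeds, and the lock is fully released only when
                # the outermost `with` block exits.
                return self.call(attempt=1)
            return "ok"


print(Retrying().call())  # -> ok

Because the RLock is only released by its outermost holder, other threads still observe callback registration plus the completion call as a single critical section.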

View File

@@ -1,8 +1,4 @@
-import gc
-import weakref
-
import pytest

from crewai.cli.git import Repository
@@ -103,82 +99,3 @@ def test_origin_url(fp, repository):
        stdout="https://github.com/user/repo.git\n",
    )
    assert repository.origin_url() == "https://github.com/user/repo.git"
-
-
-def test_repository_garbage_collection(fp):
-    """Test that Repository instances can be garbage collected.
-
-    This test verifies the fix for the memory leak issue where using
-    @lru_cache on the is_git_repo() method prevented garbage collection
-    of Repository instances.
-    """
-    fp.register(["git", "--version"], stdout="git version 2.30.0\n")
-    fp.register(["git", "rev-parse", "--is-inside-work-tree"], stdout="true\n")
-    fp.register(["git", "fetch"], stdout="")
-
-    repo = Repository(path=".")
-    weak_ref = weakref.ref(repo)
-    assert weak_ref() is not None
-
-    del repo
-    gc.collect()
-
-    assert weak_ref() is None, (
-        "Repository instance was not garbage collected. "
-        "This indicates a memory leak, likely from @lru_cache on instance methods."
-    )
-
-
-def test_is_git_repo_caching(fp):
-    """Test that is_git_repo() result is cached at the instance level.
-
-    This verifies that the instance-level caching works correctly,
-    only calling the subprocess once per instance.
-    """
-    fp.register(["git", "--version"], stdout="git version 2.30.0\n")
-    fp.register(["git", "rev-parse", "--is-inside-work-tree"], stdout="true\n")
-    fp.register(["git", "fetch"], stdout="")
-
-    repo = Repository(path=".")
-    assert repo._is_git_repo_cache is True
-
-    result1 = repo.is_git_repo()
-    result2 = repo.is_git_repo()
-
-    assert result1 is True
-    assert result2 is True
-    assert repo._is_git_repo_cache is True
-
-
-def test_multiple_repository_instances_independent_caches(fp):
-    """Test that multiple Repository instances have independent caches.
-
-    This verifies that the instance-level caching doesn't share state
-    between different Repository instances.
-    """
-    fp.register(["git", "--version"], stdout="git version 2.30.0\n")
-    fp.register(["git", "rev-parse", "--is-inside-work-tree"], stdout="true\n")
-    fp.register(["git", "fetch"], stdout="")
-    fp.register(["git", "--version"], stdout="git version 2.30.0\n")
-    fp.register(["git", "rev-parse", "--is-inside-work-tree"], stdout="true\n")
-    fp.register(["git", "fetch"], stdout="")
-
-    repo1 = Repository(path=".")
-    repo2 = Repository(path=".")
-
-    assert repo1._is_git_repo_cache is True
-    assert repo2._is_git_repo_cache is True
-
-    assert repo1._is_git_repo_cache is not repo2._is_git_repo_cache or (
-        repo1._is_git_repo_cache == repo2._is_git_repo_cache
-    )
-
-    weak_ref1 = weakref.ref(repo1)
-    del repo1
-    gc.collect()
-
-    assert weak_ref1() is None
-    assert repo2._is_git_repo_cache is True

View File

@@ -1,6 +1,6 @@
import logging
import os
-from time import sleep
+import threading
from unittest.mock import MagicMock, patch

from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
@@ -18,9 +18,15 @@ from pydantic import BaseModel
import pytest


-# TODO: This test fails without print statement, which makes me think that something is happening asynchronously that we need to eventually fix and dive deeper into at a later date
@pytest.mark.vcr()
def test_llm_callback_replacement():
+    """Test that callbacks are properly isolated between LLM instances.
+
+    This test verifies that the race condition fix (using _callback_lock) works
+    correctly. Previously, this test required a sleep(5) workaround because
+    callbacks were being modified globally without synchronization, causing
+    one LLM instance's callbacks to interfere with another's.
+    """
    llm1 = LLM(model="gpt-4o-mini", is_litellm=True)
    llm2 = LLM(model="gpt-4o-mini", is_litellm=True)
@@ -37,7 +43,6 @@ def test_llm_callback_replacement():
        messages=[{"role": "user", "content": "Hello, world from another agent!"}],
        callbacks=[calc_handler_2],
    )
-    sleep(5)
    usage_metrics_2 = calc_handler_2.token_cost_process.get_summary()

    # The first handler should not have been updated
@@ -46,6 +51,66 @@ def test_llm_callback_replacement():
    assert usage_metrics_1 == calc_handler_1.token_cost_process.get_summary()


+def test_llm_callback_lock_prevents_race_condition():
+    """Test that the _callback_lock prevents race conditions in concurrent LLM calls.
+
+    This test verifies that multiple threads can safely call LLM.call() with
+    different callbacks without interfering with each other. The lock ensures
+    that callbacks are properly isolated between concurrent calls.
+    """
+    num_threads = 5
+    results: list[int] = []
+    errors: list[Exception] = []
+    lock = threading.Lock()
+
+    def make_llm_call(thread_id: int, mock_completion: MagicMock) -> None:
+        try:
+            llm = LLM(model="gpt-4o-mini", is_litellm=True)
+            calc_handler = TokenCalcHandler(token_cost_process=TokenProcess())
+
+            mock_message = MagicMock()
+            mock_message.content = f"Response from thread {thread_id}"
+            mock_choice = MagicMock()
+            mock_choice.message = mock_message
+            mock_response = MagicMock()
+            mock_response.choices = [mock_choice]
+            mock_response.usage = {
+                "prompt_tokens": 10,
+                "completion_tokens": 10,
+                "total_tokens": 20,
+            }
+            mock_completion.return_value = mock_response
+
+            llm.call(
+                messages=[{"role": "user", "content": f"Hello from thread {thread_id}"}],
+                callbacks=[calc_handler],
+            )
+
+            usage = calc_handler.token_cost_process.get_summary()
+            with lock:
+                results.append(usage.successful_requests)
+        except Exception as e:
+            with lock:
+                errors.append(e)
+
+    with patch("litellm.completion") as mock_completion:
+        threads = [
+            threading.Thread(target=make_llm_call, args=(i, mock_completion))
+            for i in range(num_threads)
+        ]
+        for t in threads:
+            t.start()
+        for t in threads:
+            t.join()
+
+    assert len(errors) == 0, f"Errors occurred: {errors}"
+    assert len(results) == num_threads
+    assert all(
+        r == 1 for r in results
+    ), f"Expected all callbacks to have 1 successful request, got {results}"


@pytest.mark.vcr()
def test_llm_call_with_string_input():
    llm = LLM(model="gpt-4o-mini")
@@ -989,4 +1054,4 @@ async def test_usage_info_streaming_with_acall():
    assert llm._token_usage["completion_tokens"] > 0
    assert llm._token_usage["total_tokens"] > 0

-    assert len(result) > 0
+    assert len(result) > 0