diff --git a/lib/crewai/src/crewai/llm.py b/lib/crewai/src/crewai/llm.py
index 8bc1fe648..bbe20ce77 100644
--- a/lib/crewai/src/crewai/llm.py
+++ b/lib/crewai/src/crewai/llm.py
@@ -341,6 +341,7 @@ class AccumulatedToolArgs(BaseModel):
 
 class LLM(BaseLLM):
     completion_cost: float | None = None
+    _callback_lock: threading.RLock = threading.RLock()
 
     def __new__(cls, model: str, is_litellm: bool = False, **kwargs: Any) -> LLM:
         """Factory method that routes to native SDK or falls back to LiteLLM.
@@ -1144,7 +1145,7 @@ class LLM(BaseLLM):
             if response_model:
                 params["response_model"] = response_model
             response = litellm.completion(**params)
-
+
             if hasattr(response,"usage") and not isinstance(response.usage, type) and response.usage:
                 usage_info = response.usage
                 self._track_token_usage_internal(usage_info)
@@ -1363,7 +1364,7 @@ class LLM(BaseLLM):
         """
         full_response = ""
         chunk_count = 0
-
+
         usage_info = None
 
         accumulated_tool_args: defaultdict[int, AccumulatedToolArgs] = defaultdict(
@@ -1657,78 +1658,92 @@ class LLM(BaseLLM):
             raise ValueError("LLM call blocked by before_llm_call hook")
 
         # --- 5) Set up callbacks if provided
+        # Use a class-level lock to synchronize access to global litellm callbacks.
+        # This prevents race conditions when multiple LLM instances call set_callbacks
+        # concurrently, which could cause callbacks to be removed before they fire.
         with suppress_warnings():
-            if callbacks and len(callbacks) > 0:
-                self.set_callbacks(callbacks)
-            try:
-                # --- 6) Prepare parameters for the completion call
-                params = self._prepare_completion_params(messages, tools)
-                # --- 7) Make the completion call and handle response
-                if self.stream:
-                    result = self._handle_streaming_response(
-                        params=params,
-                        callbacks=callbacks,
-                        available_functions=available_functions,
-                        from_task=from_task,
-                        from_agent=from_agent,
-                        response_model=response_model,
-                    )
-                else:
-                    result = self._handle_non_streaming_response(
-                        params=params,
-                        callbacks=callbacks,
-                        available_functions=available_functions,
-                        from_task=from_task,
-                        from_agent=from_agent,
-                        response_model=response_model,
-                    )
-
-                if isinstance(result, str):
-                    result = self._invoke_after_llm_call_hooks(
-                        messages, result, from_agent
-                    )
-
-                return result
-            except LLMContextLengthExceededError:
-                # Re-raise LLMContextLengthExceededError as it should be handled
-                # by the CrewAgentExecutor._invoke_loop method, which can then decide
-                # whether to summarize the content or abort based on the respect_context_window flag
-                raise
-            except Exception as e:
-                unsupported_stop = "Unsupported parameter" in str(
-                    e
-                ) and "'stop'" in str(e)
-
-                if unsupported_stop:
-                    if (
-                        "additional_drop_params" in self.additional_params
-                        and isinstance(
-                            self.additional_params["additional_drop_params"], list
+            with LLM._callback_lock:
+                if callbacks and len(callbacks) > 0:
+                    self.set_callbacks(callbacks)
+                try:
+                    # --- 6) Prepare parameters for the completion call
+                    params = self._prepare_completion_params(messages, tools)
+                    # --- 7) Make the completion call and handle response
+                    if self.stream:
+                        result = self._handle_streaming_response(
+                            params=params,
+                            callbacks=callbacks,
+                            available_functions=available_functions,
+                            from_task=from_task,
+                            from_agent=from_agent,
+                            response_model=response_model,
                         )
-                    ):
-                        self.additional_params["additional_drop_params"].append("stop")
                     else:
-                        self.additional_params = {"additional_drop_params": ["stop"]}
+                        result = self._handle_non_streaming_response(
+                            params=params,
+                            callbacks=callbacks,
+                            available_functions=available_functions,
+                            from_task=from_task,
+                            from_agent=from_agent,
+                            response_model=response_model,
+                        )
 
-                    logging.info("Retrying LLM call without the unsupported 'stop'")
+                    if isinstance(result, str):
+                        result = self._invoke_after_llm_call_hooks(
+                            messages, result, from_agent
+                        )
 
-                    return self.call(
-                        messages,
-                        tools=tools,
-                        callbacks=callbacks,
-                        available_functions=available_functions,
-                        from_task=from_task,
-                        from_agent=from_agent,
-                        response_model=response_model,
+                    return result
+                except LLMContextLengthExceededError:
+                    # Re-raise LLMContextLengthExceededError as it should be handled
+                    # by the CrewAgentExecutor._invoke_loop method, which can then decide
+                    # whether to summarize the content or abort based on the respect_context_window flag
+                    raise
+                except Exception as e:
+                    unsupported_stop = "Unsupported parameter" in str(
+                        e
+                    ) and "'stop'" in str(e)
+
+                    if unsupported_stop:
+                        if (
+                            "additional_drop_params" in self.additional_params
+                            and isinstance(
+                                self.additional_params["additional_drop_params"], list
+                            )
+                        ):
+                            self.additional_params["additional_drop_params"].append(
+                                "stop"
+                            )
+                        else:
+                            self.additional_params = {
+                                "additional_drop_params": ["stop"]
+                            }
+
+                        logging.info(
+                            "Retrying LLM call without the unsupported 'stop'"
+                        )
+
+                        # The retry re-enters self.call() while this thread still
+                        # holds _callback_lock; because the lock is a reentrant
+                        # RLock, the same thread can safely acquire it again
+                        # without deadlocking.
+                        return self.call(
+                            messages,
+                            tools=tools,
+                            callbacks=callbacks,
+                            available_functions=available_functions,
+                            from_task=from_task,
+                            from_agent=from_agent,
+                            response_model=response_model,
+                        )
+
+                    crewai_event_bus.emit(
+                        self,
+                        event=LLMCallFailedEvent(
+                            error=str(e), from_task=from_task, from_agent=from_agent
+                        ),
                     )
-
-                crewai_event_bus.emit(
-                    self,
-                    event=LLMCallFailedEvent(
-                        error=str(e), from_task=from_task, from_agent=from_agent
-                    ),
-                )
-                raise
+                    raise
 
     async def acall(
         self,
@@ -1790,14 +1805,27 @@ class LLM(BaseLLM):
             msg_role: Literal["assistant"] = "assistant"
             message["role"] = msg_role
 
+        # Use a class-level lock to synchronize access to global litellm callbacks.
+        # This prevents race conditions when multiple LLM instances call set_callbacks
+        # concurrently, which could cause callbacks to be removed before they fire.
         with suppress_warnings():
-            if callbacks and len(callbacks) > 0:
-                self.set_callbacks(callbacks)
-            try:
-                params = self._prepare_completion_params(messages, tools)
+            with LLM._callback_lock:
+                if callbacks and len(callbacks) > 0:
+                    self.set_callbacks(callbacks)
+                try:
+                    params = self._prepare_completion_params(messages, tools)
 
-                if self.stream:
-                    return await self._ahandle_streaming_response(
+                    if self.stream:
+                        return await self._ahandle_streaming_response(
+                            params=params,
+                            callbacks=callbacks,
+                            available_functions=available_functions,
+                            from_task=from_task,
+                            from_agent=from_agent,
+                            response_model=response_model,
+                        )
+
+                    return await self._ahandle_non_streaming_response(
                         params=params,
                         callbacks=callbacks,
                         available_functions=available_functions,
@@ -1805,52 +1833,49 @@ class LLM(BaseLLM):
                         from_agent=from_agent,
                         response_model=response_model,
                     )
+                except LLMContextLengthExceededError:
+                    raise
+                except Exception as e:
+                    unsupported_stop = "Unsupported parameter" in str(
+                        e
+                    ) and "'stop'" in str(e)
 
-                return await self._ahandle_non_streaming_response(
-                    params=params,
-                    callbacks=callbacks,
-                    available_functions=available_functions,
-                    from_task=from_task,
-                    from_agent=from_agent,
-                    response_model=response_model,
-                )
-            except LLMContextLengthExceededError:
-                raise
-            except Exception as e:
-                unsupported_stop = "Unsupported parameter" in str(
-                    e
-                ) and "'stop'" in str(e)
+                    if unsupported_stop:
+                        if (
+                            "additional_drop_params" in self.additional_params
+                            and isinstance(
+                                self.additional_params["additional_drop_params"], list
+                            )
+                        ):
+                            self.additional_params["additional_drop_params"].append(
+                                "stop"
+                            )
+                        else:
+                            self.additional_params = {
+                                "additional_drop_params": ["stop"]
+                            }
 
-                if unsupported_stop:
-                    if (
-                        "additional_drop_params" in self.additional_params
-                        and isinstance(
-                            self.additional_params["additional_drop_params"], list
+                        logging.info(
+                            "Retrying LLM call without the unsupported 'stop'"
                         )
-                    ):
-                        self.additional_params["additional_drop_params"].append("stop")
-                    else:
-                        self.additional_params = {"additional_drop_params": ["stop"]}
 
-                    logging.info("Retrying LLM call without the unsupported 'stop'")
+                        return await self.acall(
+                            messages,
+                            tools=tools,
+                            callbacks=callbacks,
+                            available_functions=available_functions,
+                            from_task=from_task,
+                            from_agent=from_agent,
+                            response_model=response_model,
+                        )
 
-                    return await self.acall(
-                        messages,
-                        tools=tools,
-                        callbacks=callbacks,
-                        available_functions=available_functions,
-                        from_task=from_task,
-                        from_agent=from_agent,
-                        response_model=response_model,
+                    crewai_event_bus.emit(
+                        self,
+                        event=LLMCallFailedEvent(
+                            error=str(e), from_task=from_task, from_agent=from_agent
+                        ),
                     )
-
-                crewai_event_bus.emit(
-                    self,
-                    event=LLMCallFailedEvent(
-                        error=str(e), from_task=from_task, from_agent=from_agent
-                    ),
-                )
-                raise
+                    raise
 
     def _handle_emit_call_events(
         self,
diff --git a/lib/crewai/tests/test_llm.py b/lib/crewai/tests/test_llm.py
index a8b6a7a3f..8419153a5 100644
--- a/lib/crewai/tests/test_llm.py
+++ b/lib/crewai/tests/test_llm.py
@@ -1,6 +1,6 @@
 import logging
 import os
-from time import sleep
+import threading
 from unittest.mock import MagicMock, patch
 
 from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
@@ -18,9 +18,15 @@ from pydantic import BaseModel
 import pytest
 
 
-# TODO: This test fails without print statement, which makes me think that something is happening asynchronously that we need to eventually fix and dive deeper into at a later date
 @pytest.mark.vcr()
 def test_llm_callback_replacement():
+    """Test that callbacks are properly isolated between LLM instances.
+
+    This test verifies that the race condition fix (using _callback_lock) works
+    correctly. Previously, this test required a sleep(5) workaround because
+    callbacks were being modified globally without synchronization, causing
+    one LLM instance's callbacks to interfere with another's.
+    """
     llm1 = LLM(model="gpt-4o-mini", is_litellm=True)
     llm2 = LLM(model="gpt-4o-mini", is_litellm=True)
 
@@ -37,7 +43,6 @@ def test_llm_callback_replacement():
         messages=[{"role": "user", "content": "Hello, world from another agent!"}],
         callbacks=[calc_handler_2],
     )
-    sleep(5)
     usage_metrics_2 = calc_handler_2.token_cost_process.get_summary()
 
     # The first handler should not have been updated
@@ -46,6 +51,66 @@ def test_llm_callback_replacement():
     assert usage_metrics_1 == calc_handler_1.token_cost_process.get_summary()
 
 
+def test_llm_callback_lock_prevents_race_condition():
+    """Test that the _callback_lock prevents race conditions in concurrent LLM calls.
+
+    This test verifies that multiple threads can safely call LLM.call() with
+    different callbacks without interfering with each other. The lock ensures
+    that callbacks are properly isolated between concurrent calls.
+    """
+    num_threads = 5
+    results: list[int] = []
+    errors: list[Exception] = []
+    lock = threading.Lock()
+
+    def make_llm_call(thread_id: int, mock_completion: MagicMock) -> None:
+        try:
+            llm = LLM(model="gpt-4o-mini", is_litellm=True)
+            calc_handler = TokenCalcHandler(token_cost_process=TokenProcess())
+
+            mock_message = MagicMock()
+            mock_message.content = f"Response from thread {thread_id}"
+            mock_choice = MagicMock()
+            mock_choice.message = mock_message
+            mock_response = MagicMock()
+            mock_response.choices = [mock_choice]
+            mock_response.usage = {
+                "prompt_tokens": 10,
+                "completion_tokens": 10,
+                "total_tokens": 20,
+            }
+            mock_completion.return_value = mock_response
+
+            llm.call(
+                messages=[{"role": "user", "content": f"Hello from thread {thread_id}"}],
+                callbacks=[calc_handler],
+            )
+
+            usage = calc_handler.token_cost_process.get_summary()
+            with lock:
+                results.append(usage.successful_requests)
+        except Exception as e:
+            with lock:
+                errors.append(e)
+
+    with patch("litellm.completion") as mock_completion:
+        threads = [
+            threading.Thread(target=make_llm_call, args=(i, mock_completion))
+            for i in range(num_threads)
+        ]
+
+        for t in threads:
+            t.start()
+        for t in threads:
+            t.join()
+
+    assert len(errors) == 0, f"Errors occurred: {errors}"
+    assert len(results) == num_threads
+    assert all(
+        r == 1 for r in results
+    ), f"Expected all callbacks to have 1 successful request, got {results}"
+
+
 @pytest.mark.vcr()
 def test_llm_call_with_string_input():
     llm = LLM(model="gpt-4o-mini")
@@ -989,4 +1054,4 @@ async def test_usage_info_streaming_with_acall():
     assert llm._token_usage["completion_tokens"] > 0
     assert llm._token_usage["total_tokens"] > 0
 
-    assert len(result) > 0
\ No newline at end of file
+    assert len(result) > 0