diff --git a/src/crewai/llm.py b/src/crewai/llm.py
index 3e1c1d1b6..c413fc5d4 100644
--- a/src/crewai/llm.py
+++ b/src/crewai/llm.py
@@ -37,6 +37,7 @@ with warnings.catch_warnings():
     warnings.simplefilter("ignore", UserWarning)
     import litellm
     from litellm import Choices
+    from litellm.exceptions import ContextWindowExceededError
     from litellm.litellm_core_utils.get_supported_openai_params import (
         get_supported_openai_params,
     )
@@ -597,6 +598,11 @@ class LLM(BaseLLM):
             self._handle_emit_call_events(full_response, LLMCallType.LLM_CALL)
             return full_response
 
+        except ContextWindowExceededError as e:
+            # Catch context window errors from litellm and convert them to our own exception type.
+            # This exception is handled by CrewAgentExecutor._invoke_loop() which can then
+            # decide whether to summarize the content or abort based on the respect_context_window flag.
+            raise LLMContextLengthExceededException(str(e))
         except Exception as e:
             logging.error(f"Error in streaming response: {str(e)}")
             if full_response.strip():
@@ -711,7 +717,16 @@ class LLM(BaseLLM):
             str: The response text
         """
         # --- 1) Make the completion call
-        response = litellm.completion(**params)
+        try:
+            # Attempt to make the completion call, but catch context window errors
+            # and convert them to our own exception type for consistent handling
+            # across the codebase. This allows CrewAgentExecutor to handle context
+            # length issues appropriately.
+            response = litellm.completion(**params)
+        except ContextWindowExceededError as e:
+            # Convert litellm's context window error to our own exception type
+            # for consistent handling in the rest of the codebase
+            raise LLMContextLengthExceededException(str(e))
 
         # --- 2) Extract response message and content
         response_message = cast(Choices, cast(ModelResponse, response).choices)[
@@ -870,15 +885,17 @@ class LLM(BaseLLM):
                     params, callbacks, available_functions
                 )
 
+        except LLMContextLengthExceededException:
+            # Re-raise LLMContextLengthExceededException as it should be handled
+            # by the CrewAgentExecutor._invoke_loop method, which can then decide
+            # whether to summarize the content or abort based on the respect_context_window flag
+            raise
         except Exception as e:
             crewai_event_bus.emit(
                 self,
                 event=LLMCallFailedEvent(error=str(e)),
             )
-            if not LLMContextLengthExceededException(
-                str(e)
-            )._is_context_limit_error(str(e)):
-                logging.error(f"LiteLLM call failed: {str(e)}")
+            logging.error(f"LiteLLM call failed: {str(e)}")
             raise
 
     def _handle_emit_call_events(self, response: Any, call_type: LLMCallType):
diff --git a/tests/llm_test.py b/tests/llm_test.py
index 72371ea93..5db37eedb 100644
--- a/tests/llm_test.py
+++ b/tests/llm_test.py
@@ -373,6 +373,44 @@ def get_weather_tool_schema():
         },
     }
 
+def test_context_window_exceeded_error_handling():
+    """Test that litellm.ContextWindowExceededError is converted to LLMContextLengthExceededException."""
+    from litellm.exceptions import ContextWindowExceededError
+    from crewai.utilities.exceptions.context_window_exceeding_exception import (
+        LLMContextLengthExceededException,
+    )
+
+    llm = LLM(model="gpt-4")
+
+    # Test non-streaming response
+    with patch("litellm.completion") as mock_completion:
+        mock_completion.side_effect = ContextWindowExceededError(
+            "This model's maximum context length is 8192 tokens. However, your messages resulted in 10000 tokens.",
+            model="gpt-4",
+            llm_provider="openai"
+        )
+
+        with pytest.raises(LLMContextLengthExceededException) as excinfo:
+            llm.call("This is a test message")
+
+        assert "context length exceeded" in str(excinfo.value).lower()
+        assert "8192 tokens" in str(excinfo.value)
+
+    # Test streaming response
+    llm = LLM(model="gpt-4", stream=True)
+    with patch("litellm.completion") as mock_completion:
+        mock_completion.side_effect = ContextWindowExceededError(
+            "This model's maximum context length is 8192 tokens. However, your messages resulted in 10000 tokens.",
+            model="gpt-4",
+            llm_provider="openai"
+        )
+
+        with pytest.raises(LLMContextLengthExceededException) as excinfo:
+            llm.call("This is a test message")
+
+        assert "context length exceeded" in str(excinfo.value).lower()
+        assert "8192 tokens" in str(excinfo.value)
+
 
 @pytest.mark.vcr(filter_headers=["authorization"])
 @pytest.fixture