Compare commits

..

1 Commits

Author SHA1 Message Date
Devin AI
3bfa1c6559 Fix issue #3454: Add proactive context length checking to prevent empty LLM responses
- Add _check_context_length_before_call() method to CrewAgentExecutor
- Proactively check estimated token count before LLM calls in _invoke_loop
- Use character-based estimation (chars / 4) to approximate token count
- Call existing _handle_context_length() when context window would be exceeded
- Add comprehensive tests covering proactive handling and token estimation
- Prevents empty responses from providers like DeepInfra that don't throw exceptions

Co-Authored-By: João <joao@crewai.com>
2025-09-05 16:05:35 +00:00
2 changed files with 91 additions and 123 deletions

View File

@@ -3,7 +3,6 @@ import re
from dataclasses import dataclass
from typing import Any, Dict, List, Union
from litellm import AuthenticationError as LiteLLMAuthenticationError
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin
from crewai.agents.parser import (
@@ -113,6 +112,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
try:
while not isinstance(formatted_answer, AgentFinish):
if not self.request_within_rpm_limit or self.request_within_rpm_limit():
self._check_context_length_before_call()
answer = self.llm.call(
self.messages,
callbacks=self.callbacks,
@@ -198,19 +199,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
return self._invoke_loop(formatted_answer)
except Exception as e:
if isinstance(e, LiteLLMAuthenticationError):
self._logger.log(
level="error",
message="Authentication error with litellm occurred. Please check your API key and configuration.",
color="red",
)
self._logger.log(
level="error",
message=f"Error details: {str(e)}",
color="red",
)
raise e
elif LLMContextLengthExceededException(str(e))._is_context_limit_error(
if LLMContextLengthExceededException(str(e))._is_context_limit_error(
str(e)
):
self._handle_context_length()
@@ -340,6 +329,19 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
)
]
def _check_context_length_before_call(self) -> None:
total_chars = sum(len(msg.get("content", "")) for msg in self.messages)
estimated_tokens = total_chars // 4
context_window_size = self.llm.get_context_window_size()
if estimated_tokens > context_window_size:
self._printer.print(
content=f"Estimated token count ({estimated_tokens}) exceeds context window ({context_window_size}). Handling proactively.",
color="yellow",
)
self._handle_context_length()
def _handle_context_length(self) -> None:
if self.respect_context_window:
self._printer.print(

View File

@@ -1308,115 +1308,6 @@ def test_llm_call_with_error():
llm.call(messages)
@pytest.mark.vcr(filter_headers=["authorization"])
def test_litellm_auth_error_handling():
"""Test that LiteLLM authentication errors are handled correctly and not retried."""
from litellm import AuthenticationError as LiteLLMAuthenticationError
# Create an agent with a mocked LLM and max_retry_limit=0
agent = Agent(
role="test role",
goal="test goal",
backstory="test backstory",
llm=LLM(model="gpt-4"),
max_retry_limit=0, # Disable retries for authentication errors
max_iter=1, # Limit to one iteration to prevent multiple calls
)
# Create a task
task = Task(
description="Test task",
expected_output="Test output",
agent=agent,
)
# Mock the LLM call to raise LiteLLMAuthenticationError
with (
patch.object(LLM, "call") as mock_llm_call,
pytest.raises(LiteLLMAuthenticationError, match="Invalid API key"),
):
mock_llm_call.side_effect = LiteLLMAuthenticationError(
message="Invalid API key",
llm_provider="openai",
model="gpt-4"
)
agent.execute_task(task)
# Verify the call was only made once (no retries)
mock_llm_call.assert_called_once()
@pytest.mark.vcr(filter_headers=["authorization"])
def test_crew_agent_executor_litellm_auth_error():
"""Test that CrewAgentExecutor properly identifies and handles LiteLLM authentication errors."""
from litellm import AuthenticationError as LiteLLMAuthenticationError
from crewai.utilities import Logger
from crewai.agents.tools_handler import ToolsHandler
# Create an agent and executor with max_retry_limit=0
agent = Agent(
role="test role",
goal="test goal",
backstory="test backstory",
llm=LLM(model="gpt-4"),
max_retry_limit=0, # Disable retries for authentication errors
)
task = Task(
description="Test task",
expected_output="Test output",
agent=agent,
)
# Create executor with all required parameters
executor = CrewAgentExecutor(
agent=agent,
task=task,
llm=agent.llm,
crew=None,
prompt={
"system": "You are a test agent",
"user": "Execute the task: {input}"
},
max_iter=5,
tools=[],
tools_names="",
stop_words=[],
tools_description="",
tools_handler=ToolsHandler(),
)
# Mock the LLM call to raise LiteLLMAuthenticationError
with (
patch.object(LLM, "call") as mock_llm_call,
patch.object(Logger, "log") as mock_logger,
pytest.raises(LiteLLMAuthenticationError, match="Invalid API key"),
):
mock_llm_call.side_effect = LiteLLMAuthenticationError(
message="Invalid API key",
llm_provider="openai",
model="gpt-4"
)
executor.invoke({
"input": "test input",
"tool_names": "", # Required template variable
"tools": "", # Required template variable
})
# Verify error handling
mock_logger.assert_any_call(
level="error",
message="Authentication error with litellm occurred. Please check your API key and configuration.",
color="red",
)
mock_logger.assert_any_call(
level="error",
message="Error details: litellm.AuthenticationError: Invalid API key",
color="red",
)
# Verify the call was only made once (no retries)
mock_llm_call.assert_called_once()
@pytest.mark.vcr(filter_headers=["authorization"])
def test_handle_context_length_exceeds_limit():
agent = Agent(
@@ -1734,3 +1625,78 @@ def test_agent_with_knowledge_sources():
# Assert that the agent provides the correct information
assert "red" in result.raw.lower()
def test_proactive_context_length_handling_prevents_empty_response():
"""Test that proactive context length checking prevents empty LLM responses."""
agent = Agent(
role="test role",
goal="test goal",
backstory="test backstory",
sliding_context_window=True,
)
long_input = "This is a very long input that should exceed the context window. " * 1000
with patch.object(agent.llm, 'get_context_window_size', return_value=100):
with patch.object(agent.agent_executor, '_handle_context_length') as mock_handle:
with patch.object(agent.llm, 'call', return_value="Proper response after summarization"):
agent.agent_executor.messages = [
{"role": "user", "content": long_input}
]
task = Task(
description="Process this long input",
expected_output="A response",
agent=agent,
)
result = agent.execute_task(task)
mock_handle.assert_called()
assert result and result.strip() != ""
def test_proactive_context_length_handling_with_no_summarization():
"""Test proactive context length checking when summarization is disabled."""
agent = Agent(
role="test role",
goal="test goal",
backstory="test backstory",
sliding_context_window=False,
)
long_input = "This is a very long input. " * 1000
with patch.object(agent.llm, 'get_context_window_size', return_value=100):
agent.agent_executor.messages = [
{"role": "user", "content": long_input}
]
with pytest.raises(SystemExit):
agent.agent_executor._check_context_length_before_call()
def test_context_length_estimation():
"""Test the token estimation logic."""
agent = Agent(
role="test role",
goal="test goal",
backstory="test backstory",
)
agent.agent_executor.messages = [
{"role": "user", "content": "Short message"},
{"role": "assistant", "content": "Another short message"},
]
with patch.object(agent.llm, 'get_context_window_size', return_value=10):
with patch.object(agent.agent_executor, '_handle_context_length') as mock_handle:
agent.agent_executor._check_context_length_before_call()
mock_handle.assert_not_called()
with patch.object(agent.llm, 'get_context_window_size', return_value=5):
with patch.object(agent.agent_executor, '_handle_context_length') as mock_handle:
agent.agent_executor._check_context_length_before_call()
mock_handle.assert_called()