Add tests for LiteLLM auth error handling in agent and crew_agent_executor

Co-Authored-By: brandon@crewai.com <brandon@crewai.com>
This commit is contained in:
Devin AI
2025-01-21 20:30:44 +00:00
parent 409892d65f
commit de3e64f51f
2 changed files with 122 additions and 1 deletions

View File

@@ -3,6 +3,7 @@ import re
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, List, Union from typing import Any, Dict, List, Union
from litellm import AuthenticationError as LiteLLMAuthenticationError
from crewai.agents.agent_builder.base_agent import BaseAgent from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin
from crewai.agents.parser import ( from crewai.agents.parser import (
@@ -197,7 +198,19 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
return self._invoke_loop(formatted_answer) return self._invoke_loop(formatted_answer)
except Exception as e: except Exception as e:
if LLMContextLengthExceededException(str(e))._is_context_limit_error( if isinstance(e, LiteLLMAuthenticationError):
self._logger.log(
level="error",
message="Authentication error with litellm occurred. Please check your API key and configuration.",
color="red",
)
self._logger.log(
level="error",
message=f"Error details: {str(e)}",
color="red",
)
raise e
elif LLMContextLengthExceededException(str(e))._is_context_limit_error(
str(e) str(e)
): ):
self._handle_context_length() self._handle_context_length()

View File

@@ -1308,6 +1308,114 @@ def test_llm_call_with_error():
llm.call(messages) llm.call(messages)
@pytest.mark.vcr(filter_headers=["authorization"])
def test_litellm_auth_error_handling():
"""Test that LiteLLM authentication errors are handled correctly and not retried."""
from litellm import AuthenticationError as LiteLLMAuthenticationError
# Create an agent with a mocked LLM and max_retry_limit=0
agent = Agent(
role="test role",
goal="test goal",
backstory="test backstory",
llm=LLM(model="gpt-4"),
max_retry_limit=0, # Disable retries for authentication errors
)
# Create a task
task = Task(
description="Test task",
expected_output="Test output",
agent=agent,
)
# Mock the LLM call to raise LiteLLMAuthenticationError
with (
patch.object(LLM, "call") as mock_llm_call,
pytest.raises(LiteLLMAuthenticationError, match="Invalid API key"),
):
mock_llm_call.side_effect = LiteLLMAuthenticationError(
message="Invalid API key",
llm_provider="openai",
model="gpt-4"
)
agent.execute_task(task)
# Verify the call was only made once (no retries)
mock_llm_call.assert_called_once()
@pytest.mark.vcr(filter_headers=["authorization"])
def test_crew_agent_executor_litellm_auth_error():
"""Test that CrewAgentExecutor properly identifies and handles LiteLLM authentication errors."""
from litellm import AuthenticationError as LiteLLMAuthenticationError
from crewai.utilities import Logger
from crewai.agents.tools_handler import ToolsHandler
# Create an agent and executor with max_retry_limit=0
agent = Agent(
role="test role",
goal="test goal",
backstory="test backstory",
llm=LLM(model="gpt-4"),
max_retry_limit=0, # Disable retries for authentication errors
)
task = Task(
description="Test task",
expected_output="Test output",
agent=agent,
)
# Create executor with all required parameters
executor = CrewAgentExecutor(
agent=agent,
task=task,
llm=agent.llm,
crew=None,
prompt={
"system": "You are a test agent",
"user": "Execute the task: {input}"
},
max_iter=5,
tools=[],
tools_names="",
stop_words=[],
tools_description="",
tools_handler=ToolsHandler(),
)
# Mock the LLM call to raise LiteLLMAuthenticationError
with (
patch.object(LLM, "call") as mock_llm_call,
patch.object(Logger, "log") as mock_logger,
pytest.raises(LiteLLMAuthenticationError, match="Invalid API key"),
):
mock_llm_call.side_effect = LiteLLMAuthenticationError(
message="Invalid API key",
llm_provider="openai",
model="gpt-4"
)
executor.invoke({
"input": "test input",
"tool_names": "", # Required template variable
"tools": "", # Required template variable
})
# Verify error handling
mock_logger.assert_any_call(
level="error",
message="Authentication error with litellm occurred. Please check your API key and configuration.",
color="red",
)
mock_logger.assert_any_call(
level="error",
message="Error details: litellm.AuthenticationError: Invalid API key",
color="red",
)
# Verify the call was only made once (no retries)
mock_llm_call.assert_called_once()
@pytest.mark.vcr(filter_headers=["authorization"]) @pytest.mark.vcr(filter_headers=["authorization"])
def test_handle_context_length_exceeds_limit(): def test_handle_context_length_exceeds_limit():
agent = Agent( agent = Agent(