mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 16:18:30 +00:00
- Add fallback_llms field to Agent class to support multiple LLM fallbacks
- Modify get_llm_response in agent_utils.py to try fallback LLMs when primary fails
- Update CrewAgentExecutor and LiteAgent to pass fallback LLMs to get_llm_response
- Add smart error handling that skips fallbacks for auth errors but tries them for other failures
- Add comprehensive tests covering all fallback scenarios
- Maintain full backward compatibility for agents without fallback LLMs

Addresses GitHub Issue #3032: Support Fallback LLMs for Agent Execution

Co-Authored-By: João <joao@crewai.com>
319 lines
9.4 KiB
Python
319 lines
9.4 KiB
Python
"""Tests for Agent fallback LLM functionality."""
|
|
|
|
import pytest
|
|
from unittest.mock import patch, MagicMock
|
|
|
|
from crewai import Agent, Task
|
|
from crewai.llm import LLM
|
|
from crewai.utilities.agent_utils import get_llm_response
|
|
from crewai.utilities import Printer
|
|
from litellm.exceptions import AuthenticationError, ContextWindowExceededError
|
|
|
|
|
|
def test_agent_with_fallback_llms_basic():
    """Check that a single fallback LLM answers when the primary LLM raises."""
    primary = LLM("gpt-4")
    backup = LLM("gpt-3.5-turbo")

    agent = Agent(
        role="Test Agent",
        goal="Test fallback functionality",
        backstory="I test fallback LLMs",
        llm=primary,
        fallback_llms=[backup],
    )
    task = Task(
        description="Simple test task",
        expected_output="Test output",
        agent=agent,
    )

    # The primary mock always raises; the fallback mock returns a canned reply.
    with patch.object(
        primary, 'call', side_effect=Exception("Primary LLM failed")
    ) as primary_call:
        with patch.object(
            backup, 'call', return_value="Fallback response"
        ) as backup_call:
            result = agent.execute_task(task)

    # Mocks keep their call records after the patches are undone.
    assert result == "Fallback response"
    primary_call.assert_called_once()
    backup_call.assert_called_once()
|
|
|
|
|
|
def test_agent_fallback_llms_multiple():
    """Check that the second fallback is used when primary and first fallback raise."""
    primary = LLM("gpt-4")
    first_backup = LLM("gpt-3.5-turbo")
    second_backup = LLM("claude-3-sonnet-20240229")

    agent = Agent(
        role="Test Agent",
        goal="Test multiple fallbacks",
        backstory="I test multiple fallback LLMs",
        llm=primary,
        fallback_llms=[first_backup, second_backup],
    )
    task = Task(
        description="Test task",
        expected_output="Test output",
        agent=agent,
    )

    # Failures are configured up front; only the last LLM in the chain succeeds.
    primary_patch = patch.object(
        primary, 'call', side_effect=Exception("Primary failed")
    )
    first_patch = patch.object(
        first_backup, 'call', side_effect=Exception("Fallback 1 failed")
    )
    second_patch = patch.object(
        second_backup, 'call', return_value="Fallback 2 response"
    )

    with primary_patch as primary_call, \
            first_patch as first_call, \
            second_patch as second_call:
        result = agent.execute_task(task)

    assert result == "Fallback 2 response"
    # Each LLM in the chain is expected to be tried exactly once.
    for recorded_call in (primary_call, first_call, second_call):
        recorded_call.assert_called_once()
|
|
|
|
|
|
def test_agent_fallback_auth_error_skips_fallbacks():
    """Test that authentication errors skip fallback attempts.

    The primary LLM's ``call`` is mocked to raise ``AuthenticationError``.
    The test expects that error to propagate out of ``agent.execute_task``
    and the fallback LLM's ``call`` never to be invoked.
    """
    primary_llm = LLM("gpt-4")
    fallback_llm = LLM("gpt-3.5-turbo")

    agent = Agent(
        role="Test Agent",
        goal="Test auth error handling",
        backstory="I test auth error handling",
        llm=primary_llm,
        fallback_llms=[fallback_llm],
    )

    task = Task(
        description="Test task",
        expected_output="Test output",
        agent=agent,
    )

    with patch.object(primary_llm, 'call') as mock_primary, \
            patch.object(fallback_llm, 'call') as mock_fallback:
        mock_primary.side_effect = AuthenticationError(
            message="Invalid API key", llm_provider="openai", model="gpt-4"
        )

        # Keep only the raising call inside pytest.raises: any statement placed
        # after it in the same block is skipped once the exception propagates,
        # which would silently disable the call-count assertions below.
        with pytest.raises(AuthenticationError):
            agent.execute_task(task)

        mock_primary.assert_called_once()
        mock_fallback.assert_not_called()
|
|
|
|
|
|
def test_agent_fallback_context_window_error():
    """Test that context window errors try fallbacks.

    The primary LLM's ``call`` is mocked to raise
    ``ContextWindowExceededError``; the test expects the fallback LLM's
    response to be returned and each LLM to be called exactly once.
    """
    primary_llm = LLM("gpt-4")
    fallback_llm = LLM("gpt-3.5-turbo")

    agent = Agent(
        role="Test Agent",
        goal="Test context window error handling",
        backstory="I test context window error handling",
        llm=primary_llm,
        fallback_llms=[fallback_llm],
    )

    task = Task(
        description="Test task",
        expected_output="Test output",
        agent=agent,
    )

    with patch.object(primary_llm, 'call') as mock_primary, \
            patch.object(fallback_llm, 'call') as mock_fallback:
        # litellm's ContextWindowExceededError takes model and llm_provider in
        # addition to the message; building it from the message alone raises
        # TypeError before the fallback logic is ever exercised.
        mock_primary.side_effect = ContextWindowExceededError(
            message="Context window exceeded",
            model="gpt-4",
            llm_provider="openai",
        )
        mock_fallback.return_value = "Fallback response"

        result = agent.execute_task(task)

        assert result == "Fallback response"
        mock_primary.assert_called_once()
        mock_fallback.assert_called_once()
|
|
|
|
|
|
def test_agent_all_llms_fail():
    """Test behavior when all LLMs fail.

    Both the primary and the fallback LLM are mocked to raise. The test
    expects the exception surfaced by ``agent.execute_task`` to match
    "Fallback failed" and each LLM to have been called exactly once.
    """
    primary_llm = LLM("gpt-4")
    fallback_llm = LLM("gpt-3.5-turbo")

    agent = Agent(
        role="Test Agent",
        goal="Test all LLMs failing",
        backstory="I test all LLMs failing",
        llm=primary_llm,
        fallback_llms=[fallback_llm],
    )

    task = Task(
        description="Test task",
        expected_output="Test output",
        agent=agent,
    )

    with patch.object(primary_llm, 'call') as mock_primary, \
            patch.object(fallback_llm, 'call') as mock_fallback:
        mock_primary.side_effect = Exception("Primary failed")
        mock_fallback.side_effect = Exception("Fallback failed")

        # Keep only the raising call inside pytest.raises so the call-count
        # assertions below actually execute after the exception is caught.
        with pytest.raises(Exception, match="Fallback failed"):
            agent.execute_task(task)

        mock_primary.assert_called_once()
        mock_fallback.assert_called_once()
|
|
|
|
|
|
def test_agent_backward_compatibility():
    """Check that an agent configured without fallback LLMs still returns the primary reply."""
    agent = Agent(
        role="Test Agent",
        goal="Test backward compatibility",
        backstory="I test backward compatibility",
        llm=LLM("gpt-4"),
    )
    task = Task(
        description="Test task",
        expected_output="Test output",
        agent=agent,
    )

    # Patch the LLM object the agent actually holds, with the reply preset.
    with patch.object(
        agent.llm, 'call', return_value="Primary response"
    ) as llm_call:
        result = agent.execute_task(task)

    assert result == "Primary response"
    llm_call.assert_called_once()
|
|
|
|
|
|
def test_get_llm_response_with_fallbacks():
    """Call get_llm_response directly: a failing primary should yield the fallback reply."""
    failing_llm = MagicMock()
    failing_llm.call.side_effect = Exception("Primary failed")

    backup_llm = MagicMock()
    backup_llm.call.return_value = "Fallback success"

    conversation = [{"role": "user", "content": "test"}]

    result = get_llm_response(
        llm=failing_llm,
        messages=conversation,
        callbacks=[],
        printer=Printer(),
        fallback_llms=[backup_llm],
    )

    assert result == "Fallback success"
    # Both mocks should have been tried exactly once.
    for llm_mock in (failing_llm, backup_llm):
        llm_mock.call.assert_called_once()
|
|
|
|
|
|
def test_get_llm_response_no_fallbacks():
    """Call get_llm_response without the fallback_llms argument (backward compatibility)."""
    only_llm = MagicMock()
    only_llm.call.return_value = "Primary success"

    conversation = [{"role": "user", "content": "test"}]

    # fallback_llms is deliberately omitted to exercise the pre-existing call shape.
    result = get_llm_response(
        llm=only_llm,
        messages=conversation,
        callbacks=[],
        printer=Printer(),
    )

    assert result == "Primary success"
    only_llm.call.assert_called_once()
|
|
|
|
|
|
def test_agent_fallback_llms_string_initialization():
    """Check that fallback LLMs given as model-name strings become objects exposing ``call``."""
    model_names = ["gpt-3.5-turbo", "claude-3-sonnet-20240229"]

    agent = Agent(
        role="Test Agent",
        goal="Test string initialization",
        backstory="I test string initialization",
        llm="gpt-4",
        fallback_llms=model_names,
    )

    assert agent.fallback_llms is not None
    assert len(agent.fallback_llms) == 2
    # Every entry must expose a `call` attribute after initialization.
    for position in (0, 1):
        assert hasattr(agent.fallback_llms[position], 'call')
|
|
|
|
|
|
def test_agent_primary_success_no_fallback():
    """Check that the fallback LLM is left untouched when the primary LLM succeeds."""
    primary = LLM("gpt-4")
    backup = LLM("gpt-3.5-turbo")

    agent = Agent(
        role="Test Agent",
        goal="Test primary success",
        backstory="I test primary success",
        llm=primary,
        fallback_llms=[backup],
    )
    task = Task(
        description="Test task",
        expected_output="Test output",
        agent=agent,
    )

    # The fallback is patched without a configured reply; it should never be hit.
    with patch.object(
        primary, 'call', return_value="Primary success"
    ) as primary_call:
        with patch.object(backup, 'call') as backup_call:
            result = agent.execute_task(task)

    assert result == "Primary success"
    primary_call.assert_called_once()
    backup_call.assert_not_called()
|
|
|
|
|
|
def test_agent_empty_response_triggers_fallback():
    """Check that an empty string from the primary LLM hands over to the fallback."""
    primary = LLM("gpt-4")
    backup = LLM("gpt-3.5-turbo")

    agent = Agent(
        role="Test Agent",
        goal="Test empty response handling",
        backstory="I test empty response handling",
        llm=primary,
        fallback_llms=[backup],
    )
    task = Task(
        description="Test task",
        expected_output="Test output",
        agent=agent,
    )

    # The primary "succeeds" but returns an empty string instead of raising.
    primary_patch = patch.object(primary, 'call', return_value="")
    backup_patch = patch.object(
        backup, 'call', return_value="Fallback response"
    )

    with primary_patch as primary_call, backup_patch as backup_call:
        result = agent.execute_task(task)

    assert result == "Fallback response"
    primary_call.assert_called_once()
    backup_call.assert_called_once()
|