Compare commits

...

2 Commits

Author SHA1 Message Date
Devin AI
bfe3931dcd fix: resolve CI failures for fallback LLM implementation
- Fix type checker error by adding None check before raising last_exception
- Fix ContextWindowExceededError constructor with correct signature (message, model, llm_provider)
- Update auth error test assertion to match new print message format

Co-Authored-By: João <joao@crewai.com>
2025-06-19 06:13:44 +00:00
Devin AI
06e2683fd8 feat: implement fallback LLMs for agent execution
- Add fallback_llms field to Agent class to support multiple LLM fallbacks
- Modify get_llm_response in agent_utils.py to try fallback LLMs when primary fails
- Update CrewAgentExecutor and LiteAgent to pass fallback LLMs to get_llm_response
- Add smart error handling that skips fallbacks for auth errors but tries them for other failures
- Add comprehensive tests covering all fallback scenarios
- Maintain full backward compatibility for agents without fallback LLMs

Addresses GitHub Issue #3032: Support Fallback LLMs for Agent Execution

Co-Authored-By: João <joao@crewai.com>
2025-06-19 06:02:39 +00:00
6 changed files with 373 additions and 21 deletions

View File

@@ -91,6 +91,9 @@ class Agent(BaseAgent):
function_calling_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
description="Language model that will run the agent.", default=None
)
fallback_llms: Optional[List[Union[str, InstanceOf[BaseLLM], Any]]] = Field(
default=None, description="List of fallback language models to try if the primary LLM fails."
)
system_template: Optional[str] = Field(
default=None, description="System format for the agent."
)
@@ -174,6 +177,8 @@ class Agent(BaseAgent):
self.agent_ops_agent_name = self.role
self.llm = create_llm(self.llm)
if self.fallback_llms:
self.fallback_llms = [create_llm(fallback_llm) for fallback_llm in self.fallback_llms]
if self.function_calling_llm and not isinstance(
self.function_calling_llm, BaseLLM
):

View File

@@ -159,6 +159,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
messages=self.messages,
callbacks=self.callbacks,
printer=self._printer,
fallback_llms=getattr(self.agent, 'fallback_llms', None),
)
formatted_answer = process_llm_response(answer, self.use_stop_words)

View File

@@ -526,6 +526,7 @@ class LiteAgent(FlowTrackable, BaseModel):
messages=self._messages,
callbacks=self._callbacks,
printer=self._printer,
fallback_llms=getattr(self, 'fallback_llms', None),
)
# Emit LLM call completed event

View File

def get_llm_response(
    llm: "Union[LLM, BaseLLM]",
    messages: List[Dict[str, str]],
    callbacks: List[Any],
    printer: "Printer",
    fallback_llms: Optional[List[Union["LLM", "BaseLLM"]]] = None,
) -> str:
    """Call the LLM and return its response, trying fallbacks on failure.

    The primary ``llm`` is tried first; when it raises or returns an empty
    answer, each entry of ``fallback_llms`` is tried in order.  Authentication
    errors raised by litellm abort immediately — a bad API key will not be
    fixed by switching models.  When every candidate fails, the last captured
    exception is re-raised.

    Args:
        llm: Primary language model to call.
        messages: Chat messages forwarded to each model's ``call``.
        callbacks: Callback objects forwarded to each model's ``call``.
        printer: Sink for user-facing progress and error messages.
        fallback_llms: Optional ordered list of backup models.

    Returns:
        The first non-empty answer produced by any candidate LLM.

    Raises:
        ValueError: If the final candidate returns a None/empty answer.
        Exception: The last exception captured when every candidate fails,
            or the auth error itself when authentication fails.
    """
    candidates = [llm, *(fallback_llms or [])]
    last_exception: Optional[Exception] = None

    for index, current_llm in enumerate(candidates):
        has_more = index < len(candidates) - 1
        try:
            answer = current_llm.call(
                messages,
                callbacks=callbacks,
            )
        except Exception as e:
            last_exception = e
            label = "Primary LLM" if index == 0 else f"Fallback LLM {index}"
            printer.print(content=f"{label} failed: {e}", color="red")
            # Only litellm exceptions carry provider auth semantics; an auth
            # failure would recur on every fallback, so bail out right away.
            if e.__class__.__module__.startswith("litellm"):
                lowered = str(e).lower()
                if any(
                    term in lowered
                    for term in ("authentication", "api key", "unauthorized", "forbidden")
                ):
                    printer.print(
                        content="Authentication error detected, skipping remaining fallbacks",
                        color="red",
                    )
                    raise  # bare raise preserves the original traceback
            if has_more:
                printer.print(content=f"Trying fallback LLM {index + 1}...", color="yellow")
            continue

        if answer:
            return answer
        # Empty answers are handled OUTSIDE the try block so the ValueError
        # below cannot be swallowed by this function's own exception handler.
        printer.print(content="Received None or empty response from LLM call.", color="red")
        if not has_more:
            raise ValueError("Invalid response from LLM call - None or empty.")
        printer.print(content=f"Trying fallback LLM {index + 1}...", color="yellow")

    printer.print(content="All LLMs failed, raising last exception", color="red")
    if last_exception is not None:
        raise last_exception
    # Unreachable in practice (the last candidate always returns or raises);
    # kept so type checkers see every path terminate.
    raise RuntimeError("All LLMs failed but no exception was captured")
def process_llm_response(

View File

@@ -1984,7 +1984,7 @@ def test_crew_agent_executor_litellm_auth_error():
)
# Verify error handling messages
error_message = f"Error during LLM call: {str(mock_llm_call.side_effect)}"
error_message = f"Primary LLM failed: {str(mock_llm_call.side_effect)}"
mock_printer.assert_any_call(
content=error_message,
color="red",

View File

@@ -0,0 +1,320 @@
"""Tests for Agent fallback LLM functionality."""
import pytest
from unittest.mock import patch, MagicMock
from crewai import Agent, Task
from crewai.llm import LLM
from crewai.utilities.agent_utils import get_llm_response
from crewai.utilities import Printer
from litellm.exceptions import AuthenticationError, ContextWindowExceededError
def test_agent_with_fallback_llms_basic():
    """Primary LLM raises; the single configured fallback supplies the answer."""
    primary = LLM("gpt-4")
    backup = LLM("gpt-3.5-turbo")
    agent = Agent(
        role="Test Agent",
        goal="Test fallback functionality",
        backstory="I test fallback LLMs",
        llm=primary,
        fallback_llms=[backup],
    )
    task = Task(
        description="Simple test task",
        expected_output="Test output",
        agent=agent,
    )
    with patch.object(primary, "call") as primary_call:
        with patch.object(backup, "call") as backup_call:
            primary_call.side_effect = Exception("Primary LLM failed")
            backup_call.return_value = "Fallback response"
            assert agent.execute_task(task) == "Fallback response"
            primary_call.assert_called_once()
            backup_call.assert_called_once()
def test_agent_fallback_llms_multiple():
    """Two fallbacks configured; only the second one finally succeeds."""
    models = [LLM("gpt-4"), LLM("gpt-3.5-turbo"), LLM("claude-3-sonnet-20240229")]
    agent = Agent(
        role="Test Agent",
        goal="Test multiple fallbacks",
        backstory="I test multiple fallback LLMs",
        llm=models[0],
        fallback_llms=models[1:],
    )
    task = Task(
        description="Test task",
        expected_output="Test output",
        agent=agent,
    )
    with patch.object(models[0], "call") as call_0, \
            patch.object(models[1], "call") as call_1, \
            patch.object(models[2], "call") as call_2:
        call_0.side_effect = Exception("Primary failed")
        call_1.side_effect = Exception("Fallback 1 failed")
        call_2.return_value = "Fallback 2 response"
        assert agent.execute_task(task) == "Fallback 2 response"
        for mocked in (call_0, call_1, call_2):
            mocked.assert_called_once()
def test_agent_fallback_auth_error_skips_fallbacks():
    """Test that authentication errors skip fallback attempts.

    ``pytest.raises`` is nested around only the call under test; if it were
    part of the same ``with`` statement as the patches, the mock assertions
    would sit after the raising call inside the context body and never run.
    """
    primary_llm = LLM("gpt-4")
    fallback_llm = LLM("gpt-3.5-turbo")
    agent = Agent(
        role="Test Agent",
        goal="Test auth error handling",
        backstory="I test auth error handling",
        llm=primary_llm,
        fallback_llms=[fallback_llm],
    )
    task = Task(
        description="Test task",
        expected_output="Test output",
        agent=agent,
    )
    with patch.object(primary_llm, "call") as mock_primary, \
            patch.object(fallback_llm, "call") as mock_fallback:
        mock_primary.side_effect = AuthenticationError(
            message="Invalid API key", llm_provider="openai", model="gpt-4"
        )
        with pytest.raises(AuthenticationError):
            agent.execute_task(task)
        # These run after pytest.raises has consumed the exception.
        mock_primary.assert_called_once()
        mock_fallback.assert_not_called()
def test_agent_fallback_context_window_error():
    """A context-window error on the primary should still try the fallback."""
    main_model = LLM("gpt-4")
    spare_model = LLM("gpt-3.5-turbo")
    agent = Agent(
        role="Test Agent",
        goal="Test context window error handling",
        backstory="I test context window error handling",
        llm=main_model,
        fallback_llms=[spare_model],
    )
    task = Task(
        description="Test task",
        expected_output="Test output",
        agent=agent,
    )
    overflow = ContextWindowExceededError(
        message="Context window exceeded", model="gpt-4", llm_provider="openai"
    )
    with patch.object(main_model, "call") as main_call:
        with patch.object(spare_model, "call") as spare_call:
            main_call.side_effect = overflow
            spare_call.return_value = "Fallback response"
            assert agent.execute_task(task) == "Fallback response"
            main_call.assert_called_once()
            spare_call.assert_called_once()
def test_agent_all_llms_fail():
    """Test behavior when all LLMs fail.

    ``pytest.raises`` wraps only ``execute_task`` so that the mock
    assertions afterwards actually execute; in the same ``with`` statement
    as the patches they would be unreachable once the call raises.
    """
    primary_llm = LLM("gpt-4")
    fallback_llm = LLM("gpt-3.5-turbo")
    agent = Agent(
        role="Test Agent",
        goal="Test all LLMs failing",
        backstory="I test all LLMs failing",
        llm=primary_llm,
        fallback_llms=[fallback_llm],
    )
    task = Task(
        description="Test task",
        expected_output="Test output",
        agent=agent,
    )
    with patch.object(primary_llm, "call") as mock_primary, \
            patch.object(fallback_llm, "call") as mock_fallback:
        mock_primary.side_effect = Exception("Primary failed")
        mock_fallback.side_effect = Exception("Fallback failed")
        # The last exception raised (the fallback's) is the one propagated.
        with pytest.raises(Exception, match="Fallback failed"):
            agent.execute_task(task)
        mock_primary.assert_called_once()
        mock_fallback.assert_called_once()
def test_agent_backward_compatibility():
    """An agent configured without fallback LLMs behaves exactly as before."""
    agent = Agent(
        role="Test Agent",
        goal="Test backward compatibility",
        backstory="I test backward compatibility",
        llm=LLM("gpt-4"),
    )
    task = Task(
        description="Test task",
        expected_output="Test output",
        agent=agent,
    )
    with patch.object(agent.llm, "call") as only_call:
        only_call.return_value = "Primary response"
        assert agent.execute_task(task) == "Primary response"
        only_call.assert_called_once()
def test_get_llm_response_with_fallbacks():
    """Exercise get_llm_response directly: broken primary, working fallback."""
    broken, working = MagicMock(), MagicMock()
    broken.call.side_effect = Exception("Primary failed")
    working.call.return_value = "Fallback success"
    answer = get_llm_response(
        llm=broken,
        messages=[{"role": "user", "content": "test"}],
        callbacks=[],
        printer=Printer(),
        fallback_llms=[working],
    )
    assert answer == "Fallback success"
    broken.call.assert_called_once()
    working.call.assert_called_once()
def test_get_llm_response_no_fallbacks():
    """Omitting fallback_llms keeps the original single-LLM behavior."""
    solo = MagicMock()
    solo.call.return_value = "Primary success"
    answer = get_llm_response(
        llm=solo,
        messages=[{"role": "user", "content": "test"}],
        callbacks=[],
        printer=Printer(),
    )
    assert answer == "Primary success"
    solo.call.assert_called_once()
def test_agent_fallback_llms_string_initialization():
    """String model names in fallback_llms are resolved to callable LLMs."""
    agent = Agent(
        role="Test Agent",
        goal="Test string initialization",
        backstory="I test string initialization",
        llm="gpt-4",
        fallback_llms=["gpt-3.5-turbo", "claude-3-sonnet-20240229"],
    )
    resolved = agent.fallback_llms
    assert resolved is not None
    assert len(resolved) == 2
    assert all(hasattr(model, "call") for model in resolved)
def test_agent_primary_success_no_fallback():
    """When the primary LLM answers, the fallback must stay untouched."""
    main_model = LLM("gpt-4")
    spare_model = LLM("gpt-3.5-turbo")
    agent = Agent(
        role="Test Agent",
        goal="Test primary success",
        backstory="I test primary success",
        llm=main_model,
        fallback_llms=[spare_model],
    )
    task = Task(
        description="Test task",
        expected_output="Test output",
        agent=agent,
    )
    with patch.object(main_model, "call") as main_call:
        with patch.object(spare_model, "call") as spare_call:
            main_call.return_value = "Primary success"
            assert agent.execute_task(task) == "Primary success"
            main_call.assert_called_once()
            spare_call.assert_not_called()
def test_agent_empty_response_triggers_fallback():
    """An empty primary answer counts as a failure and triggers the fallback."""
    main_model = LLM("gpt-4")
    spare_model = LLM("gpt-3.5-turbo")
    agent = Agent(
        role="Test Agent",
        goal="Test empty response handling",
        backstory="I test empty response handling",
        llm=main_model,
        fallback_llms=[spare_model],
    )
    task = Task(
        description="Test task",
        expected_output="Test output",
        agent=agent,
    )
    with patch.object(main_model, "call") as main_call:
        with patch.object(spare_model, "call") as spare_call:
            main_call.return_value = ""
            spare_call.return_value = "Fallback response"
            assert agent.execute_task(task) == "Fallback response"
            main_call.assert_called_once()
            spare_call.assert_called_once()