diff --git a/src/crewai/agent.py b/src/crewai/agent.py index c8e34b2e6..060dc09e8 100644 --- a/src/crewai/agent.py +++ b/src/crewai/agent.py @@ -91,6 +91,9 @@ class Agent(BaseAgent): function_calling_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field( description="Language model that will run the agent.", default=None ) + fallback_llms: Optional[List[Union[str, InstanceOf[BaseLLM], Any]]] = Field( + default=None, description="List of fallback language models to try if the primary LLM fails." + ) system_template: Optional[str] = Field( default=None, description="System format for the agent." ) @@ -174,6 +177,8 @@ class Agent(BaseAgent): self.agent_ops_agent_name = self.role self.llm = create_llm(self.llm) + if self.fallback_llms: + self.fallback_llms = [create_llm(fallback_llm) for fallback_llm in self.fallback_llms] if self.function_calling_llm and not isinstance( self.function_calling_llm, BaseLLM ): diff --git a/src/crewai/agents/crew_agent_executor.py b/src/crewai/agents/crew_agent_executor.py index 48099fad9..1f1e17b2a 100644 --- a/src/crewai/agents/crew_agent_executor.py +++ b/src/crewai/agents/crew_agent_executor.py @@ -159,6 +159,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): messages=self.messages, callbacks=self.callbacks, printer=self._printer, + fallback_llms=getattr(self.agent, 'fallback_llms', None), ) formatted_answer = process_llm_response(answer, self.use_stop_words) diff --git a/src/crewai/lite_agent.py b/src/crewai/lite_agent.py index 8dfbfaff8..114e2a760 100644 --- a/src/crewai/lite_agent.py +++ b/src/crewai/lite_agent.py @@ -526,6 +526,7 @@ class LiteAgent(FlowTrackable, BaseModel): messages=self._messages, callbacks=self._callbacks, printer=self._printer, + fallback_llms=getattr(self, 'fallback_llms', None), ) # Emit LLM call completed event diff --git a/src/crewai/utilities/agent_utils.py b/src/crewai/utilities/agent_utils.py index be0abd92c..41f281188 100644 --- a/src/crewai/utilities/agent_utils.py +++ b/src/crewai/utilities/agent_utils.py @@ -145,27 +145,49 @@ def get_llm_response( messages: List[Dict[str, str]], callbacks: List[Any], printer: Printer, + fallback_llms: Optional[List[Union[LLM, BaseLLM]]] = None, ) -> str: - """Call the LLM and return the response, handling any invalid responses.""" - try: - answer = llm.call( - messages, - callbacks=callbacks, - ) - except Exception as e: - printer.print( - content=f"Error during LLM call: {e}", - color="red", - ) - raise e - if not answer: - printer.print( - content="Received None or empty response from LLM call.", - color="red", - ) - raise ValueError("Invalid response from LLM call - None or empty.") - - return answer + """Call the LLM and return the response, handling any invalid responses and trying fallbacks if available.""" + llms_to_try = [llm] + if fallback_llms: + llms_to_try.extend(fallback_llms) + + last_exception = None + + for i, current_llm in enumerate(llms_to_try): + try: + answer = current_llm.call( + messages, + callbacks=callbacks, + ) + if not answer: + error_msg = "Received None or empty response from LLM call." + printer.print(content=error_msg, color="red") + if i < len(llms_to_try) - 1: + printer.print(content=f"Trying fallback LLM {i+1}...", color="yellow") + continue + else: + raise ValueError("Invalid response from LLM call - None or empty.") + return answer + except Exception as e: + last_exception = e + if i == 0: + printer.print(content=f"Primary LLM failed: {e}", color="red") + else: + printer.print(content=f"Fallback LLM {i} failed: {e}", color="red") + + if e.__class__.__module__.startswith("litellm"): + error_str = str(e).lower() + if any(term in error_str for term in ["authentication", "api key", "unauthorized", "forbidden"]): + printer.print(content="Authentication error detected, skipping remaining fallbacks", color="red") + raise e + + if i < len(llms_to_try) - 1: + printer.print(content=f"Trying fallback LLM {i+1}...", color="yellow") + continue + + printer.print(content="All LLMs failed, raising last exception", color="red") + raise last_exception def process_llm_response( diff --git a/tests/test_agent_fallback_llms.py b/tests/test_agent_fallback_llms.py new file mode 100644 index 000000000..6c0d7e733 --- /dev/null +++ b/tests/test_agent_fallback_llms.py @@ -0,0 +1,318 @@ +"""Tests for Agent fallback LLM functionality.""" + +import pytest +from unittest.mock import patch, MagicMock + +from crewai import Agent, Task +from crewai.llm import LLM +from crewai.utilities.agent_utils import get_llm_response +from crewai.utilities import Printer +from litellm.exceptions import AuthenticationError, ContextWindowExceededError + + +def test_agent_with_fallback_llms_basic(): + """Test agent with fallback LLMs when primary fails.""" + primary_llm = LLM("gpt-4") + fallback_llm = LLM("gpt-3.5-turbo") + + agent = Agent( + role="Test Agent", + goal="Test fallback functionality", + backstory="I test fallback LLMs", + llm=primary_llm, + fallback_llms=[fallback_llm] + ) + + task = Task( + description="Simple test task", + expected_output="Test output", + agent=agent + ) + + with patch.object(primary_llm, 'call') as mock_primary, \ + patch.object(fallback_llm, 'call') as mock_fallback: + + mock_primary.side_effect = Exception("Primary LLM failed") + mock_fallback.return_value = "Fallback response" + + result = agent.execute_task(task) + + assert result == "Fallback response" + mock_primary.assert_called_once() + mock_fallback.assert_called_once() + + +def test_agent_fallback_llms_multiple(): + """Test agent with multiple fallback LLMs.""" + primary_llm = LLM("gpt-4") + fallback1 = LLM("gpt-3.5-turbo") + fallback2 = LLM("claude-3-sonnet-20240229") + + agent = Agent( + role="Test Agent", + goal="Test multiple fallbacks", + backstory="I test multiple fallback LLMs", + llm=primary_llm, + fallback_llms=[fallback1, fallback2] + ) + + task = Task( + description="Test task", + expected_output="Test output", + agent=agent + ) + + with patch.object(primary_llm, 'call') as mock_primary, \ + patch.object(fallback1, 'call') as mock_fallback1, \ + patch.object(fallback2, 'call') as mock_fallback2: + + mock_primary.side_effect = Exception("Primary failed") + mock_fallback1.side_effect = Exception("Fallback 1 failed") + mock_fallback2.return_value = "Fallback 2 response" + + result = agent.execute_task(task) + + assert result == "Fallback 2 response" + mock_primary.assert_called_once() + mock_fallback1.assert_called_once() + mock_fallback2.assert_called_once() + + +def test_agent_fallback_auth_error_skips_fallbacks(): + """Test that authentication errors skip fallback attempts.""" + primary_llm = LLM("gpt-4") + fallback_llm = LLM("gpt-3.5-turbo") + + agent = Agent( + role="Test Agent", + goal="Test auth error handling", + backstory="I test auth error handling", + llm=primary_llm, + fallback_llms=[fallback_llm] + ) + + task = Task( + description="Test task", + expected_output="Test output", + agent=agent + ) + + with patch.object(primary_llm, 'call') as mock_primary, \ + patch.object(fallback_llm, 'call') as mock_fallback, \ + pytest.raises(AuthenticationError): + + mock_primary.side_effect = AuthenticationError( + message="Invalid API key", llm_provider="openai", model="gpt-4" + ) + + agent.execute_task(task) + + mock_primary.assert_called_once() + mock_fallback.assert_not_called() + + +def test_agent_fallback_context_window_error(): + """Test that context window errors try fallbacks.""" + primary_llm = LLM("gpt-4") + fallback_llm = LLM("gpt-3.5-turbo") + + agent = Agent( + role="Test Agent", + goal="Test context window error handling", + backstory="I test context window error handling", + llm=primary_llm, + fallback_llms=[fallback_llm] + ) + + task = Task( + description="Test task", + expected_output="Test output", + agent=agent + ) + + with patch.object(primary_llm, 'call') as mock_primary, \ + patch.object(fallback_llm, 'call') as mock_fallback: + + mock_primary.side_effect = ContextWindowExceededError("Context window exceeded") + mock_fallback.return_value = "Fallback response" + + result = agent.execute_task(task) + + assert result == "Fallback response" + mock_primary.assert_called_once() + mock_fallback.assert_called_once() + + +def test_agent_all_llms_fail(): + """Test behavior when all LLMs fail.""" + primary_llm = LLM("gpt-4") + fallback_llm = LLM("gpt-3.5-turbo") + + agent = Agent( + role="Test Agent", + goal="Test all LLMs failing", + backstory="I test all LLMs failing", + llm=primary_llm, + fallback_llms=[fallback_llm] + ) + + task = Task( + description="Test task", + expected_output="Test output", + agent=agent + ) + + with patch.object(primary_llm, 'call') as mock_primary, \ + patch.object(fallback_llm, 'call') as mock_fallback, \ + pytest.raises(Exception, match="Fallback failed"): + + mock_primary.side_effect = Exception("Primary failed") + mock_fallback.side_effect = Exception("Fallback failed") + + agent.execute_task(task) + + mock_primary.assert_called_once() + mock_fallback.assert_called_once() + + +def test_agent_backward_compatibility(): + """Test that agents without fallback LLMs work as before.""" + agent = Agent( + role="Test Agent", + goal="Test backward compatibility", + backstory="I test backward compatibility", + llm=LLM("gpt-4") + ) + + task = Task( + description="Test task", + expected_output="Test output", + agent=agent + ) + + with patch.object(agent.llm, 'call') as mock_llm: + mock_llm.return_value = "Primary response" + + result = agent.execute_task(task) + + assert result == "Primary response" + mock_llm.assert_called_once() + + +def test_get_llm_response_with_fallbacks(): + """Test get_llm_response function directly with fallbacks.""" + primary_llm = MagicMock() + fallback_llm = MagicMock() + printer = Printer() + + primary_llm.call.side_effect = Exception("Primary failed") + fallback_llm.call.return_value = "Fallback success" + + result = get_llm_response( + llm=primary_llm, + messages=[{"role": "user", "content": "test"}], + callbacks=[], + printer=printer, + fallback_llms=[fallback_llm] + ) + + assert result == "Fallback success" + primary_llm.call.assert_called_once() + fallback_llm.call.assert_called_once() + + +def test_get_llm_response_no_fallbacks(): + """Test get_llm_response function without fallbacks (backward compatibility).""" + primary_llm = MagicMock() + printer = Printer() + + primary_llm.call.return_value = "Primary success" + + result = get_llm_response( + llm=primary_llm, + messages=[{"role": "user", "content": "test"}], + callbacks=[], + printer=printer + ) + + assert result == "Primary success" + primary_llm.call.assert_called_once() + + +def test_agent_fallback_llms_string_initialization(): + """Test that fallback LLMs can be initialized with string model names.""" + agent = Agent( + role="Test Agent", + goal="Test string initialization", + backstory="I test string initialization", + llm="gpt-4", + fallback_llms=["gpt-3.5-turbo", "claude-3-sonnet-20240229"] + ) + + assert agent.fallback_llms is not None + assert len(agent.fallback_llms) == 2 + assert hasattr(agent.fallback_llms[0], 'call') + assert hasattr(agent.fallback_llms[1], 'call') + + +def test_agent_primary_success_no_fallback(): + """Test that fallback LLMs are not called when primary succeeds.""" + primary_llm = LLM("gpt-4") + fallback_llm = LLM("gpt-3.5-turbo") + + agent = Agent( + role="Test Agent", + goal="Test primary success", + backstory="I test primary success", + llm=primary_llm, + fallback_llms=[fallback_llm] + ) + + task = Task( + description="Test task", + expected_output="Test output", + agent=agent + ) + + with patch.object(primary_llm, 'call') as mock_primary, \ + patch.object(fallback_llm, 'call') as mock_fallback: + + mock_primary.return_value = "Primary success" + + result = agent.execute_task(task) + + assert result == "Primary success" + mock_primary.assert_called_once() + mock_fallback.assert_not_called() + + +def test_agent_empty_response_triggers_fallback(): + """Test that empty responses from primary LLM trigger fallback.""" + primary_llm = LLM("gpt-4") + fallback_llm = LLM("gpt-3.5-turbo") + + agent = Agent( + role="Test Agent", + goal="Test empty response handling", + backstory="I test empty response handling", + llm=primary_llm, + fallback_llms=[fallback_llm] + ) + + task = Task( + description="Test task", + expected_output="Test output", + agent=agent + ) + + with patch.object(primary_llm, 'call') as mock_primary, \ + patch.object(fallback_llm, 'call') as mock_fallback: + + mock_primary.return_value = "" + mock_fallback.return_value = "Fallback response" + + result = agent.execute_task(task) + + assert result == "Fallback response" + mock_primary.assert_called_once() + mock_fallback.assert_called_once()