mirror of
https://github.com/crewAIInc/crewAI.git
synced 2025-12-29 10:48:29 +00:00
Compare commits
2 Commits
1.2.0
...
devin/1750
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bfe3931dcd | ||
|
|
06e2683fd8 |
@@ -91,6 +91,9 @@ class Agent(BaseAgent):
|
|||||||
function_calling_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
|
function_calling_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
|
||||||
description="Language model that will run the agent.", default=None
|
description="Language model that will run the agent.", default=None
|
||||||
)
|
)
|
||||||
|
fallback_llms: Optional[List[Union[str, InstanceOf[BaseLLM], Any]]] = Field(
|
||||||
|
default=None, description="List of fallback language models to try if the primary LLM fails."
|
||||||
|
)
|
||||||
system_template: Optional[str] = Field(
|
system_template: Optional[str] = Field(
|
||||||
default=None, description="System format for the agent."
|
default=None, description="System format for the agent."
|
||||||
)
|
)
|
||||||
@@ -174,6 +177,8 @@ class Agent(BaseAgent):
|
|||||||
self.agent_ops_agent_name = self.role
|
self.agent_ops_agent_name = self.role
|
||||||
|
|
||||||
self.llm = create_llm(self.llm)
|
self.llm = create_llm(self.llm)
|
||||||
|
if self.fallback_llms:
|
||||||
|
self.fallback_llms = [create_llm(fallback_llm) for fallback_llm in self.fallback_llms]
|
||||||
if self.function_calling_llm and not isinstance(
|
if self.function_calling_llm and not isinstance(
|
||||||
self.function_calling_llm, BaseLLM
|
self.function_calling_llm, BaseLLM
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -159,6 +159,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
|||||||
messages=self.messages,
|
messages=self.messages,
|
||||||
callbacks=self.callbacks,
|
callbacks=self.callbacks,
|
||||||
printer=self._printer,
|
printer=self._printer,
|
||||||
|
fallback_llms=getattr(self.agent, 'fallback_llms', None),
|
||||||
)
|
)
|
||||||
formatted_answer = process_llm_response(answer, self.use_stop_words)
|
formatted_answer = process_llm_response(answer, self.use_stop_words)
|
||||||
|
|
||||||
|
|||||||
@@ -526,6 +526,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
|||||||
messages=self._messages,
|
messages=self._messages,
|
||||||
callbacks=self._callbacks,
|
callbacks=self._callbacks,
|
||||||
printer=self._printer,
|
printer=self._printer,
|
||||||
|
fallback_llms=getattr(self, 'fallback_llms', None),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Emit LLM call completed event
|
# Emit LLM call completed event
|
||||||
|
|||||||
@@ -145,27 +145,52 @@ def get_llm_response(
|
|||||||
messages: List[Dict[str, str]],
|
messages: List[Dict[str, str]],
|
||||||
callbacks: List[Any],
|
callbacks: List[Any],
|
||||||
printer: Printer,
|
printer: Printer,
|
||||||
|
fallback_llms: Optional[List[Union[LLM, BaseLLM]]] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Call the LLM and return the response, handling any invalid responses."""
|
"""Call the LLM and return the response, handling any invalid responses and trying fallbacks if available."""
|
||||||
try:
|
llms_to_try = [llm]
|
||||||
answer = llm.call(
|
if fallback_llms:
|
||||||
messages,
|
llms_to_try.extend(fallback_llms)
|
||||||
callbacks=callbacks,
|
|
||||||
)
|
last_exception = None
|
||||||
except Exception as e:
|
|
||||||
printer.print(
|
for i, current_llm in enumerate(llms_to_try):
|
||||||
content=f"Error during LLM call: {e}",
|
try:
|
||||||
color="red",
|
answer = current_llm.call(
|
||||||
)
|
messages,
|
||||||
raise e
|
callbacks=callbacks,
|
||||||
if not answer:
|
)
|
||||||
printer.print(
|
if not answer:
|
||||||
content="Received None or empty response from LLM call.",
|
error_msg = "Received None or empty response from LLM call."
|
||||||
color="red",
|
printer.print(content=error_msg, color="red")
|
||||||
)
|
if i < len(llms_to_try) - 1:
|
||||||
raise ValueError("Invalid response from LLM call - None or empty.")
|
printer.print(content=f"Trying fallback LLM {i+1}...", color="yellow")
|
||||||
|
continue
|
||||||
return answer
|
else:
|
||||||
|
raise ValueError("Invalid response from LLM call - None or empty.")
|
||||||
|
return answer
|
||||||
|
except Exception as e:
|
||||||
|
last_exception = e
|
||||||
|
if i == 0:
|
||||||
|
printer.print(content=f"Primary LLM failed: {e}", color="red")
|
||||||
|
else:
|
||||||
|
printer.print(content=f"Fallback LLM {i} failed: {e}", color="red")
|
||||||
|
|
||||||
|
if e.__class__.__module__.startswith("litellm"):
|
||||||
|
error_str = str(e).lower()
|
||||||
|
if any(term in error_str for term in ["authentication", "api key", "unauthorized", "forbidden"]):
|
||||||
|
printer.print(content="Authentication error detected, skipping remaining fallbacks", color="red")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
if i < len(llms_to_try) - 1:
|
||||||
|
printer.print(content=f"Trying fallback LLM {i+1}...", color="yellow")
|
||||||
|
continue
|
||||||
|
|
||||||
|
printer.print(content="All LLMs failed, raising last exception", color="red")
|
||||||
|
if last_exception is not None:
|
||||||
|
raise last_exception
|
||||||
|
else:
|
||||||
|
raise RuntimeError("All LLMs failed but no exception was captured")
|
||||||
|
|
||||||
|
|
||||||
def process_llm_response(
|
def process_llm_response(
|
||||||
|
|||||||
@@ -1984,7 +1984,7 @@ def test_crew_agent_executor_litellm_auth_error():
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Verify error handling messages
|
# Verify error handling messages
|
||||||
error_message = f"Error during LLM call: {str(mock_llm_call.side_effect)}"
|
error_message = f"Primary LLM failed: {str(mock_llm_call.side_effect)}"
|
||||||
mock_printer.assert_any_call(
|
mock_printer.assert_any_call(
|
||||||
content=error_message,
|
content=error_message,
|
||||||
color="red",
|
color="red",
|
||||||
|
|||||||
320
tests/test_agent_fallback_llms.py
Normal file
320
tests/test_agent_fallback_llms.py
Normal file
@@ -0,0 +1,320 @@
|
|||||||
|
"""Tests for Agent fallback LLM functionality."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
|
||||||
|
from crewai import Agent, Task
|
||||||
|
from crewai.llm import LLM
|
||||||
|
from crewai.utilities.agent_utils import get_llm_response
|
||||||
|
from crewai.utilities import Printer
|
||||||
|
from litellm.exceptions import AuthenticationError, ContextWindowExceededError
|
||||||
|
|
||||||
|
|
||||||
|
def test_agent_with_fallback_llms_basic():
|
||||||
|
"""Test agent with fallback LLMs when primary fails."""
|
||||||
|
primary_llm = LLM("gpt-4")
|
||||||
|
fallback_llm = LLM("gpt-3.5-turbo")
|
||||||
|
|
||||||
|
agent = Agent(
|
||||||
|
role="Test Agent",
|
||||||
|
goal="Test fallback functionality",
|
||||||
|
backstory="I test fallback LLMs",
|
||||||
|
llm=primary_llm,
|
||||||
|
fallback_llms=[fallback_llm]
|
||||||
|
)
|
||||||
|
|
||||||
|
task = Task(
|
||||||
|
description="Simple test task",
|
||||||
|
expected_output="Test output",
|
||||||
|
agent=agent
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(primary_llm, 'call') as mock_primary, \
|
||||||
|
patch.object(fallback_llm, 'call') as mock_fallback:
|
||||||
|
|
||||||
|
mock_primary.side_effect = Exception("Primary LLM failed")
|
||||||
|
mock_fallback.return_value = "Fallback response"
|
||||||
|
|
||||||
|
result = agent.execute_task(task)
|
||||||
|
|
||||||
|
assert result == "Fallback response"
|
||||||
|
mock_primary.assert_called_once()
|
||||||
|
mock_fallback.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
def test_agent_fallback_llms_multiple():
|
||||||
|
"""Test agent with multiple fallback LLMs."""
|
||||||
|
primary_llm = LLM("gpt-4")
|
||||||
|
fallback1 = LLM("gpt-3.5-turbo")
|
||||||
|
fallback2 = LLM("claude-3-sonnet-20240229")
|
||||||
|
|
||||||
|
agent = Agent(
|
||||||
|
role="Test Agent",
|
||||||
|
goal="Test multiple fallbacks",
|
||||||
|
backstory="I test multiple fallback LLMs",
|
||||||
|
llm=primary_llm,
|
||||||
|
fallback_llms=[fallback1, fallback2]
|
||||||
|
)
|
||||||
|
|
||||||
|
task = Task(
|
||||||
|
description="Test task",
|
||||||
|
expected_output="Test output",
|
||||||
|
agent=agent
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(primary_llm, 'call') as mock_primary, \
|
||||||
|
patch.object(fallback1, 'call') as mock_fallback1, \
|
||||||
|
patch.object(fallback2, 'call') as mock_fallback2:
|
||||||
|
|
||||||
|
mock_primary.side_effect = Exception("Primary failed")
|
||||||
|
mock_fallback1.side_effect = Exception("Fallback 1 failed")
|
||||||
|
mock_fallback2.return_value = "Fallback 2 response"
|
||||||
|
|
||||||
|
result = agent.execute_task(task)
|
||||||
|
|
||||||
|
assert result == "Fallback 2 response"
|
||||||
|
mock_primary.assert_called_once()
|
||||||
|
mock_fallback1.assert_called_once()
|
||||||
|
mock_fallback2.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
def test_agent_fallback_auth_error_skips_fallbacks():
|
||||||
|
"""Test that authentication errors skip fallback attempts."""
|
||||||
|
primary_llm = LLM("gpt-4")
|
||||||
|
fallback_llm = LLM("gpt-3.5-turbo")
|
||||||
|
|
||||||
|
agent = Agent(
|
||||||
|
role="Test Agent",
|
||||||
|
goal="Test auth error handling",
|
||||||
|
backstory="I test auth error handling",
|
||||||
|
llm=primary_llm,
|
||||||
|
fallback_llms=[fallback_llm]
|
||||||
|
)
|
||||||
|
|
||||||
|
task = Task(
|
||||||
|
description="Test task",
|
||||||
|
expected_output="Test output",
|
||||||
|
agent=agent
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(primary_llm, 'call') as mock_primary, \
|
||||||
|
patch.object(fallback_llm, 'call') as mock_fallback, \
|
||||||
|
pytest.raises(AuthenticationError):
|
||||||
|
|
||||||
|
mock_primary.side_effect = AuthenticationError(
|
||||||
|
message="Invalid API key", llm_provider="openai", model="gpt-4"
|
||||||
|
)
|
||||||
|
|
||||||
|
agent.execute_task(task)
|
||||||
|
|
||||||
|
mock_primary.assert_called_once()
|
||||||
|
mock_fallback.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_agent_fallback_context_window_error():
|
||||||
|
"""Test that context window errors try fallbacks."""
|
||||||
|
primary_llm = LLM("gpt-4")
|
||||||
|
fallback_llm = LLM("gpt-3.5-turbo")
|
||||||
|
|
||||||
|
agent = Agent(
|
||||||
|
role="Test Agent",
|
||||||
|
goal="Test context window error handling",
|
||||||
|
backstory="I test context window error handling",
|
||||||
|
llm=primary_llm,
|
||||||
|
fallback_llms=[fallback_llm]
|
||||||
|
)
|
||||||
|
|
||||||
|
task = Task(
|
||||||
|
description="Test task",
|
||||||
|
expected_output="Test output",
|
||||||
|
agent=agent
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(primary_llm, 'call') as mock_primary, \
|
||||||
|
patch.object(fallback_llm, 'call') as mock_fallback:
|
||||||
|
|
||||||
|
mock_primary.side_effect = ContextWindowExceededError(
|
||||||
|
message="Context window exceeded", model="gpt-4", llm_provider="openai"
|
||||||
|
)
|
||||||
|
mock_fallback.return_value = "Fallback response"
|
||||||
|
|
||||||
|
result = agent.execute_task(task)
|
||||||
|
|
||||||
|
assert result == "Fallback response"
|
||||||
|
mock_primary.assert_called_once()
|
||||||
|
mock_fallback.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
def test_agent_all_llms_fail():
|
||||||
|
"""Test behavior when all LLMs fail."""
|
||||||
|
primary_llm = LLM("gpt-4")
|
||||||
|
fallback_llm = LLM("gpt-3.5-turbo")
|
||||||
|
|
||||||
|
agent = Agent(
|
||||||
|
role="Test Agent",
|
||||||
|
goal="Test all LLMs failing",
|
||||||
|
backstory="I test all LLMs failing",
|
||||||
|
llm=primary_llm,
|
||||||
|
fallback_llms=[fallback_llm]
|
||||||
|
)
|
||||||
|
|
||||||
|
task = Task(
|
||||||
|
description="Test task",
|
||||||
|
expected_output="Test output",
|
||||||
|
agent=agent
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(primary_llm, 'call') as mock_primary, \
|
||||||
|
patch.object(fallback_llm, 'call') as mock_fallback, \
|
||||||
|
pytest.raises(Exception, match="Fallback failed"):
|
||||||
|
|
||||||
|
mock_primary.side_effect = Exception("Primary failed")
|
||||||
|
mock_fallback.side_effect = Exception("Fallback failed")
|
||||||
|
|
||||||
|
agent.execute_task(task)
|
||||||
|
|
||||||
|
mock_primary.assert_called_once()
|
||||||
|
mock_fallback.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
def test_agent_backward_compatibility():
|
||||||
|
"""Test that agents without fallback LLMs work as before."""
|
||||||
|
agent = Agent(
|
||||||
|
role="Test Agent",
|
||||||
|
goal="Test backward compatibility",
|
||||||
|
backstory="I test backward compatibility",
|
||||||
|
llm=LLM("gpt-4")
|
||||||
|
)
|
||||||
|
|
||||||
|
task = Task(
|
||||||
|
description="Test task",
|
||||||
|
expected_output="Test output",
|
||||||
|
agent=agent
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(agent.llm, 'call') as mock_llm:
|
||||||
|
mock_llm.return_value = "Primary response"
|
||||||
|
|
||||||
|
result = agent.execute_task(task)
|
||||||
|
|
||||||
|
assert result == "Primary response"
|
||||||
|
mock_llm.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_llm_response_with_fallbacks():
|
||||||
|
"""Test get_llm_response function directly with fallbacks."""
|
||||||
|
primary_llm = MagicMock()
|
||||||
|
fallback_llm = MagicMock()
|
||||||
|
printer = Printer()
|
||||||
|
|
||||||
|
primary_llm.call.side_effect = Exception("Primary failed")
|
||||||
|
fallback_llm.call.return_value = "Fallback success"
|
||||||
|
|
||||||
|
result = get_llm_response(
|
||||||
|
llm=primary_llm,
|
||||||
|
messages=[{"role": "user", "content": "test"}],
|
||||||
|
callbacks=[],
|
||||||
|
printer=printer,
|
||||||
|
fallback_llms=[fallback_llm]
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == "Fallback success"
|
||||||
|
primary_llm.call.assert_called_once()
|
||||||
|
fallback_llm.call.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_llm_response_no_fallbacks():
|
||||||
|
"""Test get_llm_response function without fallbacks (backward compatibility)."""
|
||||||
|
primary_llm = MagicMock()
|
||||||
|
printer = Printer()
|
||||||
|
|
||||||
|
primary_llm.call.return_value = "Primary success"
|
||||||
|
|
||||||
|
result = get_llm_response(
|
||||||
|
llm=primary_llm,
|
||||||
|
messages=[{"role": "user", "content": "test"}],
|
||||||
|
callbacks=[],
|
||||||
|
printer=printer
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result == "Primary success"
|
||||||
|
primary_llm.call.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
def test_agent_fallback_llms_string_initialization():
|
||||||
|
"""Test that fallback LLMs can be initialized with string model names."""
|
||||||
|
agent = Agent(
|
||||||
|
role="Test Agent",
|
||||||
|
goal="Test string initialization",
|
||||||
|
backstory="I test string initialization",
|
||||||
|
llm="gpt-4",
|
||||||
|
fallback_llms=["gpt-3.5-turbo", "claude-3-sonnet-20240229"]
|
||||||
|
)
|
||||||
|
|
||||||
|
assert agent.fallback_llms is not None
|
||||||
|
assert len(agent.fallback_llms) == 2
|
||||||
|
assert hasattr(agent.fallback_llms[0], 'call')
|
||||||
|
assert hasattr(agent.fallback_llms[1], 'call')
|
||||||
|
|
||||||
|
|
||||||
|
def test_agent_primary_success_no_fallback():
|
||||||
|
"""Test that fallback LLMs are not called when primary succeeds."""
|
||||||
|
primary_llm = LLM("gpt-4")
|
||||||
|
fallback_llm = LLM("gpt-3.5-turbo")
|
||||||
|
|
||||||
|
agent = Agent(
|
||||||
|
role="Test Agent",
|
||||||
|
goal="Test primary success",
|
||||||
|
backstory="I test primary success",
|
||||||
|
llm=primary_llm,
|
||||||
|
fallback_llms=[fallback_llm]
|
||||||
|
)
|
||||||
|
|
||||||
|
task = Task(
|
||||||
|
description="Test task",
|
||||||
|
expected_output="Test output",
|
||||||
|
agent=agent
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(primary_llm, 'call') as mock_primary, \
|
||||||
|
patch.object(fallback_llm, 'call') as mock_fallback:
|
||||||
|
|
||||||
|
mock_primary.return_value = "Primary success"
|
||||||
|
|
||||||
|
result = agent.execute_task(task)
|
||||||
|
|
||||||
|
assert result == "Primary success"
|
||||||
|
mock_primary.assert_called_once()
|
||||||
|
mock_fallback.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_agent_empty_response_triggers_fallback():
|
||||||
|
"""Test that empty responses from primary LLM trigger fallback."""
|
||||||
|
primary_llm = LLM("gpt-4")
|
||||||
|
fallback_llm = LLM("gpt-3.5-turbo")
|
||||||
|
|
||||||
|
agent = Agent(
|
||||||
|
role="Test Agent",
|
||||||
|
goal="Test empty response handling",
|
||||||
|
backstory="I test empty response handling",
|
||||||
|
llm=primary_llm,
|
||||||
|
fallback_llms=[fallback_llm]
|
||||||
|
)
|
||||||
|
|
||||||
|
task = Task(
|
||||||
|
description="Test task",
|
||||||
|
expected_output="Test output",
|
||||||
|
agent=agent
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(primary_llm, 'call') as mock_primary, \
|
||||||
|
patch.object(fallback_llm, 'call') as mock_fallback:
|
||||||
|
|
||||||
|
mock_primary.return_value = ""
|
||||||
|
mock_fallback.return_value = "Fallback response"
|
||||||
|
|
||||||
|
result = agent.execute_task(task)
|
||||||
|
|
||||||
|
assert result == "Fallback response"
|
||||||
|
mock_primary.assert_called_once()
|
||||||
|
mock_fallback.assert_called_once()
|
||||||
Reference in New Issue
Block a user