diff --git a/docs/installation.mdx b/docs/installation.mdx index d629c4c80..8abba152a 100644 --- a/docs/installation.mdx +++ b/docs/installation.mdx @@ -139,7 +139,6 @@ Now let's get you set up! 🚀 │ └── __init__.py └── config/ ├── agents.yaml - ├── config.yaml └── tasks.yaml ``` diff --git a/src/crewai/agent.py b/src/crewai/agent.py index 5823ef7f9..3a4d083d4 100644 --- a/src/crewai/agent.py +++ b/src/crewai/agent.py @@ -1,15 +1,12 @@ -import os import shutil import subprocess from typing import Any, Dict, List, Literal, Optional, Union -from litellm import AuthenticationError as LiteLLMAuthenticationError from pydantic import Field, InstanceOf, PrivateAttr, model_validator from crewai.agents import CacheHandler from crewai.agents.agent_builder.base_agent import BaseAgent from crewai.agents.crew_agent_executor import CrewAgentExecutor -from crewai.cli.constants import ENV_VARS, LITELLM_PARAMS from crewai.knowledge.knowledge import Knowledge from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context @@ -262,8 +259,8 @@ class Agent(BaseAgent): } )["output"] except Exception as e: - if isinstance(e, LiteLLMAuthenticationError): - # Do not retry on authentication errors + if e.__class__.__module__.startswith("litellm"): + # Do not retry on litellm errors raise e self._times_executed += 1 if self._times_executed > self.max_retry_limit: diff --git a/src/crewai/agents/crew_agent_executor.py b/src/crewai/agents/crew_agent_executor.py index ee5b9c582..b9797193c 100644 --- a/src/crewai/agents/crew_agent_executor.py +++ b/src/crewai/agents/crew_agent_executor.py @@ -3,8 +3,6 @@ import re from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional, Union -from litellm.exceptions import AuthenticationError as LiteLLMAuthenticationError - from crewai.agents.agent_builder.base_agent import BaseAgent from crewai.agents.agent_builder.base_agent_executor_mixin 
import CrewAgentExecutorMixin from crewai.agents.parser import ( @@ -103,7 +101,12 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): try: formatted_answer = self._invoke_loop() except Exception as e: - raise e + if e.__class__.__module__.startswith("litellm"): + # Do not retry on litellm errors + raise e + else: + self._handle_unknown_error(e) + raise e if self.ask_for_human_input: formatted_answer = self._handle_human_feedback(formatted_answer) @@ -146,36 +149,25 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): formatted_answer = self._handle_output_parser_exception(e) except Exception as e: + if e.__class__.__module__.startswith("litellm"): + # Do not retry on litellm errors + raise e if self._is_context_length_exceeded(e): self._handle_context_length() continue - elif self._is_litellm_authentication_error(e): - self._handle_litellm_auth_error(e) - raise e else: - self._printer.print( - content=f"Unhandled exception: {e}", - color="red", - ) + self._handle_unknown_error(e) + raise e finally: self.iterations += 1 self._show_logs(formatted_answer) return formatted_answer - def _is_litellm_authentication_error(self, exception: Exception) -> bool: - """Check if the exception is a litellm authentication error.""" - if LiteLLMAuthenticationError and isinstance( - exception, LiteLLMAuthenticationError - ): - return True - - return False - - def _handle_litellm_auth_error(self, exception: Exception) -> None: - """Handle litellm authentication error by informing the user and exiting.""" + def _handle_unknown_error(self, exception: Exception) -> None: + """Handle unknown errors by informing the user.""" self._printer.print( - content="Authentication error with litellm occurred. Please check your API key and configuration.", + content="An unknown error occurred. 
Please check the details below.", color="red", ) self._printer.print( diff --git a/tests/agent_test.py b/tests/agent_test.py index 46b20004e..fda47daaf 100644 --- a/tests/agent_test.py +++ b/tests/agent_test.py @@ -1623,7 +1623,7 @@ def test_litellm_auth_error_handling(): agent=agent, ) - # Mock the LLM call to raise LiteLLMAuthenticationError + # Mock the LLM call to raise AuthenticationError with ( patch.object(LLM, "call") as mock_llm_call, pytest.raises(LiteLLMAuthenticationError, match="Invalid API key"), @@ -1638,13 +1638,13 @@ def test_litellm_auth_error_handling(): def test_crew_agent_executor_litellm_auth_error(): - """Test that CrewAgentExecutor properly identifies and handles LiteLLM authentication errors.""" - from litellm import AuthenticationError as LiteLLMAuthenticationError + """Test that CrewAgentExecutor handles LiteLLM authentication errors by raising them.""" + from litellm.exceptions import AuthenticationError from crewai.agents.tools_handler import ToolsHandler from crewai.utilities import Printer - # Create an agent and executor with max_retry_limit=0 + # Create an agent and executor agent = Agent( role="test role", goal="test goal", @@ -1672,13 +1672,13 @@ def test_crew_agent_executor_litellm_auth_error(): tools_handler=ToolsHandler(), ) - # Mock the LLM call to raise LiteLLMAuthenticationError + # Mock the LLM call to raise AuthenticationError with ( patch.object(LLM, "call") as mock_llm_call, patch.object(Printer, "print") as mock_printer, - pytest.raises(LiteLLMAuthenticationError, match="Invalid API key"), + pytest.raises(AuthenticationError) as exc_info, ): - mock_llm_call.side_effect = LiteLLMAuthenticationError( + mock_llm_call.side_effect = AuthenticationError( message="Invalid API key", llm_provider="openai", model="gpt-4" ) executor.invoke( @@ -1689,14 +1689,53 @@ def test_crew_agent_executor_litellm_auth_error(): } ) - # Verify error handling + # Verify error handling messages + error_message = f"Error during LLM call: 
{str(mock_llm_call.side_effect)}" mock_printer.assert_any_call( - content="Authentication error with litellm occurred. Please check your API key and configuration.", - color="red", - ) - mock_printer.assert_any_call( - content="Error details: litellm.AuthenticationError: Invalid API key", + content=error_message, color="red", ) + # Verify the call was only made once (no retries) mock_llm_call.assert_called_once() + + # Assert that the exception was raised and has the expected attributes + assert exc_info.type is AuthenticationError + assert "Invalid API key".lower() in exc_info.value.message.lower() + assert exc_info.value.llm_provider == "openai" + assert exc_info.value.model == "gpt-4" + + +def test_litellm_anthropic_error_handling(): + """Test that AnthropicError from LiteLLM is handled correctly and not retried.""" + from litellm.llms.anthropic.common_utils import AnthropicError + + # Create an agent with a mocked LLM that uses an Anthropic model + agent = Agent( + role="test role", + goal="test goal", + backstory="test backstory", + llm=LLM(model="claude-3-5-sonnet-20240620"), + max_retry_limit=0, + ) + + # Create a task + task = Task( + description="Test task", + expected_output="Test output", + agent=agent, + ) + + # Mock the LLM call to raise AnthropicError + with ( + patch.object(LLM, "call") as mock_llm_call, + pytest.raises(AnthropicError, match="Test Anthropic error"), + ): + mock_llm_call.side_effect = AnthropicError( + status_code=500, + message="Test Anthropic error", + ) + agent.execute_task(task) + + # Verify the LLM call was only made once (no retries) + mock_llm_call.assert_called_once()