diff --git a/src/crewai/agent.py b/src/crewai/agent.py index b53f30f5a..94b8ba065 100644 --- a/src/crewai/agent.py +++ b/src/crewai/agent.py @@ -1,15 +1,13 @@ -import os import shutil import subprocess from typing import Any, Dict, List, Literal, Optional, Union -from litellm import AuthenticationError as LiteLLMAuthenticationError +from litellm.llms.base_llm.chat.transformation import BaseLLMException from pydantic import Field, InstanceOf, PrivateAttr, model_validator from crewai.agents import CacheHandler from crewai.agents.agent_builder.base_agent import BaseAgent from crewai.agents.crew_agent_executor import CrewAgentExecutor -from crewai.cli.constants import ENV_VARS, LITELLM_PARAMS from crewai.knowledge.knowledge import Knowledge from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context @@ -262,7 +260,7 @@ class Agent(BaseAgent): } )["output"] except Exception as e: - if e.__class__.__module__.startswith("litellm.exceptions"): + if isinstance(e, BaseLLMException): # Do not retry on litellm errors raise e self._times_executed += 1 diff --git a/src/crewai/agents/crew_agent_executor.py b/src/crewai/agents/crew_agent_executor.py index d7bf97795..dd5252ea7 100644 --- a/src/crewai/agents/crew_agent_executor.py +++ b/src/crewai/agents/crew_agent_executor.py @@ -4,6 +4,7 @@ from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional, Union from litellm.exceptions import AuthenticationError as LiteLLMAuthenticationError +from litellm.llms.base_llm.chat.transformation import BaseLLMException from crewai.agents.agent_builder.base_agent import BaseAgent from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin @@ -142,10 +143,10 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): self._invoke_step_callback(formatted_answer) self._append_message(formatted_answer.text, role="assistant") - except OutputParserException as e: - formatted_answer = self._handle_output_parser_exception(e) - except Exception as e: + if isinstance(e, BaseLLMException): + # Stop execution on litellm errors + raise e if self._is_context_length_exceeded(e): self._handle_context_length() continue diff --git a/tests/agent_test.py b/tests/agent_test.py index 527ea3b88..db909ef2f 100644 --- a/tests/agent_test.py +++ b/tests/agent_test.py @@ -1623,7 +1623,7 @@ def test_litellm_auth_error_handling(): agent=agent, ) - # Mock the LLM call to raise LiteLLMAuthenticationError + # Mock the LLM call to raise AuthenticationError with ( patch.object(LLM, "call") as mock_llm_call, pytest.raises(LiteLLMAuthenticationError, match="Invalid API key"), @@ -1639,7 +1639,7 @@ def test_litellm_auth_error_handling(): def test_crew_agent_executor_litellm_auth_error(): """Test that CrewAgentExecutor handles LiteLLM authentication errors by raising them.""" - from litellm import AuthenticationError as LiteLLMAuthenticationError + from litellm.exceptions import AuthenticationError from crewai.agents.tools_handler import ToolsHandler from crewai.utilities import Printer @@ -1672,13 +1672,13 @@ def test_crew_agent_executor_litellm_auth_error(): tools_handler=ToolsHandler(), ) - # Mock the LLM call to raise LiteLLMAuthenticationError + # Mock the LLM call to raise AuthenticationError with ( patch.object(LLM, "call") as mock_llm_call, patch.object(Printer, "print") as mock_printer, - pytest.raises(LiteLLMAuthenticationError, match="Invalid API key"), + pytest.raises(AuthenticationError, match="Invalid API key"), ): - mock_llm_call.side_effect = LiteLLMAuthenticationError( + mock_llm_call.side_effect = AuthenticationError( message="Invalid API key", llm_provider="openai", model="gpt-4" ) executor.invoke( @@ -1711,7 +1711,7 @@ def test_litellm_anthropic_error_handling(): role="test role", goal="test goal", backstory="test backstory", - llm=LLM(model="claude-3.5-sonnet-20240620"), + llm=LLM(model="claude-3.5-sonnet-20240620"), max_retry_limit=0, )