This commit is contained in:
Brandon Hancock
2025-01-21 13:51:34 -05:00
parent a21e310d78
commit 002568f2b2
2 changed files with 57 additions and 10 deletions

View File

@@ -3,6 +3,8 @@ import re
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Union
from litellm.exceptions import AuthenticationError as LiteLLMAuthenticationError
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin
from crewai.agents.parser import (
@@ -114,8 +116,10 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
"""
formatted_answer = None
while not isinstance(formatted_answer, AgentFinish):
self.iterations += 1
try:
if self._has_reached_max_iterations():
print("Max iterations reached")
formatted_answer = self._handle_max_iterations_exceeded(
formatted_answer
)
@@ -123,8 +127,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
self._enforce_rpm_limit()
print("Getting LLM response")
answer = self._get_llm_response()
print(f"LLM response: {answer}")
formatted_answer = self._process_llm_response(answer)
if isinstance(formatted_answer, AgentAction):
@@ -142,15 +147,54 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
formatted_answer = self._handle_output_parser_exception(e)
except Exception as e:
print(f"Exception: {e}")
if self._is_context_length_exceeded(e):
self._handle_context_length()
continue
elif self._is_litellm_authentication_error(e):
self._handle_litellm_auth_error(e)
break
else:
self._printer.print(
content=f"Unhandled exception: {e}",
color="red",
)
self._show_logs(formatted_answer)
return formatted_answer
def _is_litellm_authentication_error(self, exception: Exception) -> bool:
    """Return True when *exception* represents a litellm authentication failure.

    Detection uses two strategies, in order:

    1. A direct ``isinstance`` check against litellm's
       ``AuthenticationError`` type (guarded so a falsy/unavailable
       binding is tolerated).
    2. A fallback scan of the exception text for the markers litellm
       embeds in auth failures, for cases where the error was wrapped
       or re-raised as another type.
    """
    if LiteLLMAuthenticationError and isinstance(exception, LiteLLMAuthenticationError):
        return True

    # Fall back to message inspection: litellm prefixes auth errors with
    # its class name and includes a key-specific detail string.
    message = str(exception)
    if "litellm.AuthenticationError" not in message:
        return False
    return (
        "invalid_api_key" in message
        or "Incorrect API key provided" in message
    )
def _handle_litellm_auth_error(self, exception: Exception) -> None:
"""Handle litellm authentication error by informing the user and exiting."""
self._printer.print(
content="Authentication error with litellm occurred. Please check your API key and configuration.",
color="red",
)
self._printer.print(
content=f"Error details: {exception}",
color="red",
)
def _has_reached_max_iterations(self) -> bool:
"""Check if the maximum number of iterations has been reached."""
print(f"Max iterations: {self.max_iter}")
print(f"Iterations: {self.iterations}")
return self.iterations >= self.max_iter
def _enforce_rpm_limit(self) -> None:
@@ -160,10 +204,18 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
def _get_llm_response(self) -> str:
"""Call the LLM and return the response, handling any invalid responses."""
answer = self.llm.call(
self.messages,
callbacks=self.callbacks,
)
try:
answer = self.llm.call(
self.messages,
callbacks=self.callbacks,
)
except Exception as e:
self._printer.print(
content=f"Error during LLM call: {e}",
color="red",
)
# Re-raise the exception to let _invoke_loop handle it
raise e
if not answer:
self._printer.print(
@@ -184,7 +236,6 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
if FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE in e.error:
answer = answer.split("Observation:")[0].strip()
self.iterations += 1
return self._format_answer(answer)
def _handle_agent_action(

View File

@@ -24,12 +24,10 @@ def create_llm(
# 1) If llm_value is already an LLM object, return it directly
if isinstance(llm_value, LLM):
print("LLM value is already an LLM object")
return llm_value
# 2) If llm_value is a string (model name)
if isinstance(llm_value, str):
print("LLM value is a string")
try:
created_llm = LLM(model=llm_value)
return created_llm
@@ -39,12 +37,10 @@ def create_llm(
# 3) If llm_value is None, parse environment variables or use default
if llm_value is None:
print("LLM value is None")
return _llm_via_environment_or_fallback()
# 4) Otherwise, attempt to extract relevant attributes from an unknown object
try:
print("LLM value is an unknown object")
# Extract attributes with explicit types
model = (
getattr(llm_value, "model_name", None)