This commit is contained in:
Brandon Hancock
2025-01-21 13:51:34 -05:00
parent a21e310d78
commit 002568f2b2
2 changed files with 57 additions and 10 deletions

View File

@@ -3,6 +3,8 @@ import re
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Union from typing import Any, Callable, Dict, List, Optional, Union
from litellm.exceptions import AuthenticationError as LiteLLMAuthenticationError
from crewai.agents.agent_builder.base_agent import BaseAgent from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin
from crewai.agents.parser import ( from crewai.agents.parser import (
@@ -114,8 +116,10 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
""" """
formatted_answer = None formatted_answer = None
while not isinstance(formatted_answer, AgentFinish): while not isinstance(formatted_answer, AgentFinish):
self.iterations += 1
try: try:
if self._has_reached_max_iterations(): if self._has_reached_max_iterations():
print("Max iterations reached")
formatted_answer = self._handle_max_iterations_exceeded( formatted_answer = self._handle_max_iterations_exceeded(
formatted_answer formatted_answer
) )
@@ -123,8 +127,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
self._enforce_rpm_limit() self._enforce_rpm_limit()
print("Getting LLM response")
answer = self._get_llm_response() answer = self._get_llm_response()
print(f"LLM response: {answer}")
formatted_answer = self._process_llm_response(answer) formatted_answer = self._process_llm_response(answer)
if isinstance(formatted_answer, AgentAction): if isinstance(formatted_answer, AgentAction):
@@ -142,15 +147,54 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
formatted_answer = self._handle_output_parser_exception(e) formatted_answer = self._handle_output_parser_exception(e)
except Exception as e: except Exception as e:
print(f"Exception: {e}")
if self._is_context_length_exceeded(e): if self._is_context_length_exceeded(e):
self._handle_context_length() self._handle_context_length()
continue continue
elif self._is_litellm_authentication_error(e):
self._handle_litellm_auth_error(e)
break
else:
self._printer.print(
content=f"Unhandled exception: {e}",
color="red",
)
self._show_logs(formatted_answer) self._show_logs(formatted_answer)
return formatted_answer return formatted_answer
def _is_litellm_authentication_error(self, exception: Exception) -> bool:
"""Check if the exception is a litellm authentication error."""
# Check if the exception is an instance of litellm.AuthenticationError
if LiteLLMAuthenticationError and isinstance(
exception, LiteLLMAuthenticationError
):
return True
# Alternatively, inspect the exception message
error_message = str(exception)
# Check if 'litellm.AuthenticationError' and 'invalid_api_key' are in the message
return "litellm.AuthenticationError" in error_message and (
"invalid_api_key" in error_message
or "Incorrect API key provided" in error_message
)
def _handle_litellm_auth_error(self, exception: Exception) -> None:
"""Handle litellm authentication error by informing the user and exiting."""
self._printer.print(
content="Authentication error with litellm occurred. Please check your API key and configuration.",
color="red",
)
self._printer.print(
content=f"Error details: {exception}",
color="red",
)
def _has_reached_max_iterations(self) -> bool: def _has_reached_max_iterations(self) -> bool:
"""Check if the maximum number of iterations has been reached.""" """Check if the maximum number of iterations has been reached."""
print(f"Max iterations: {self.max_iter}")
print(f"Iterations: {self.iterations}")
return self.iterations >= self.max_iter return self.iterations >= self.max_iter
def _enforce_rpm_limit(self) -> None: def _enforce_rpm_limit(self) -> None:
@@ -160,10 +204,18 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
def _get_llm_response(self) -> str: def _get_llm_response(self) -> str:
"""Call the LLM and return the response, handling any invalid responses.""" """Call the LLM and return the response, handling any invalid responses."""
answer = self.llm.call( try:
self.messages, answer = self.llm.call(
callbacks=self.callbacks, self.messages,
) callbacks=self.callbacks,
)
except Exception as e:
self._printer.print(
content=f"Error during LLM call: {e}",
color="red",
)
# Re-raise the exception to let _invoke_loop handle it
raise e
if not answer: if not answer:
self._printer.print( self._printer.print(
@@ -184,7 +236,6 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
if FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE in e.error: if FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE in e.error:
answer = answer.split("Observation:")[0].strip() answer = answer.split("Observation:")[0].strip()
self.iterations += 1
return self._format_answer(answer) return self._format_answer(answer)
def _handle_agent_action( def _handle_agent_action(

View File

@@ -24,12 +24,10 @@ def create_llm(
# 1) If llm_value is already an LLM object, return it directly # 1) If llm_value is already an LLM object, return it directly
if isinstance(llm_value, LLM): if isinstance(llm_value, LLM):
print("LLM value is already an LLM object")
return llm_value return llm_value
# 2) If llm_value is a string (model name) # 2) If llm_value is a string (model name)
if isinstance(llm_value, str): if isinstance(llm_value, str):
print("LLM value is a string")
try: try:
created_llm = LLM(model=llm_value) created_llm = LLM(model=llm_value)
return created_llm return created_llm
@@ -39,12 +37,10 @@ def create_llm(
# 3) If llm_value is None, parse environment variables or use default # 3) If llm_value is None, parse environment variables or use default
if llm_value is None: if llm_value is None:
print("LLM value is None")
return _llm_via_environment_or_fallback() return _llm_via_environment_or_fallback()
# 4) Otherwise, attempt to extract relevant attributes from an unknown object # 4) Otherwise, attempt to extract relevant attributes from an unknown object
try: try:
print("LLM value is an unknown object")
# Extract attributes with explicit types # Extract attributes with explicit types
model = ( model = (
getattr(llm_value, "model_name", None) getattr(llm_value, "model_name", None)