WIP to address https://github.com/crewAIInc/crewAI/issues/1934

This WIP teaches CrewAgentExecutor._invoke_loop to recognize litellm authentication failures and stop retrying them: a new _is_litellm_authentication_error check (isinstance plus message sniffing) paired with a _handle_litellm_auth_error reporter, with the LLM call wrapped in a log-and-reraise guard. It also moves the per-loop iteration increment to the top of the loop and adds temporary debug prints, while removing leftover debug prints from create_llm.
src/crewai/agents/crew_agent_executor.py
@@ -3,6 +3,8 @@ import re
 from dataclasses import dataclass
 from typing import Any, Callable, Dict, List, Optional, Union
 
+from litellm.exceptions import AuthenticationError as LiteLLMAuthenticationError
+
 from crewai.agents.agent_builder.base_agent import BaseAgent
 from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin
 from crewai.agents.parser import (
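Reviewer note: this import is unconditional, yet the new _is_litellm_authentication_error helper (added later in this diff) guards on LiteLLMAuthenticationError being truthy. If the intent is to keep working when litellm is absent, a guarded import would make that check meaningful. A minimal sketch of that pattern (an assumption about intent, not what the commit does):

    # Guarded import: degrade gracefully when litellm is not installed,
    # leaving the helper to rely on its message-sniffing fallback.
    try:
        from litellm.exceptions import AuthenticationError as LiteLLMAuthenticationError
    except ImportError:
        LiteLLMAuthenticationError = None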
@@ -114,8 +116,10 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
         """
         formatted_answer = None
         while not isinstance(formatted_answer, AgentFinish):
+            self.iterations += 1
             try:
                 if self._has_reached_max_iterations():
+                    print("Max iterations reached")
                     formatted_answer = self._handle_max_iterations_exceeded(
                         formatted_answer
                     )
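Moving self.iterations += 1 to the top of the loop (its old home in _process_llm_response is deleted in a later hunk) means every pass is counted, including ones that raise before the LLM responds — which is how the loop stops instead of spinning forever on a persistent error. A standalone sketch of the accounting, with hypothetical names:

    MAX_ITER = 3

    def do_one_step() -> bool:
        """Stand-in for one LLM call + parse; pretend it never finishes."""
        return False

    def run_loop() -> int:
        iterations = 0
        finished = False
        while not finished:
            iterations += 1  # counted up front, so failed passes count too
            if iterations >= MAX_ITER:
                print("Max iterations reached")
                break
            finished = do_one_step()
        return iterations

    assert run_loop() == MAX_ITER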
@@ -123,8 +127,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
 
                 self._enforce_rpm_limit()
 
+                print("Getting LLM response")
                 answer = self._get_llm_response()
-
+                print(f"LLM response: {answer}")
                 formatted_answer = self._process_llm_response(answer)
 
                 if isinstance(formatted_answer, AgentAction):
@@ -142,15 +147,54 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
                     formatted_answer = self._handle_output_parser_exception(e)
 
             except Exception as e:
+                print(f"Exception: {e}")
                 if self._is_context_length_exceeded(e):
                     self._handle_context_length()
                     continue
+                elif self._is_litellm_authentication_error(e):
+                    self._handle_litellm_auth_error(e)
+                    break
+                else:
+                    self._printer.print(
+                        content=f"Unhandled exception: {e}",
+                        color="red",
+                    )
 
             self._show_logs(formatted_answer)
         return formatted_answer
 
+    def _is_litellm_authentication_error(self, exception: Exception) -> bool:
+        """Check if the exception is a litellm authentication error."""
+        # Check if the exception is an instance of litellm.AuthenticationError
+        if LiteLLMAuthenticationError and isinstance(
+            exception, LiteLLMAuthenticationError
+        ):
+            return True
+
+        # Alternatively, inspect the exception message
+        error_message = str(exception)
+
+        # Check if 'litellm.AuthenticationError' and 'invalid_api_key' are in the message
+        return "litellm.AuthenticationError" in error_message and (
+            "invalid_api_key" in error_message
+            or "Incorrect API key provided" in error_message
+        )
+
+    def _handle_litellm_auth_error(self, exception: Exception) -> None:
+        """Handle litellm authentication error by informing the user and exiting."""
+        self._printer.print(
+            content="Authentication error with litellm occurred. Please check your API key and configuration.",
+            color="red",
+        )
+        self._printer.print(
+            content=f"Error details: {exception}",
+            color="red",
+        )
+
     def _has_reached_max_iterations(self) -> bool:
         """Check if the maximum number of iterations has been reached."""
+        print(f"Max iterations: {self.max_iter}")
+        print(f"Iterations: {self.iterations}")
         return self.iterations >= self.max_iter
 
     def _enforce_rpm_limit(self) -> None:
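The helper tries isinstance first and falls back to message sniffing, since the auth failure may reach the loop re-wrapped in another exception. A standalone sketch of the fallback path (the sample message imitates typical litellm/OpenAI wording and is illustrative, not captured from a real run):

    def is_auth_error_message(error_message: str) -> bool:
        # Same string heuristic as _is_litellm_authentication_error's fallback.
        return "litellm.AuthenticationError" in error_message and (
            "invalid_api_key" in error_message
            or "Incorrect API key provided" in error_message
        )

    wrapped = RuntimeError(
        "litellm.AuthenticationError: OpenAIException - "
        "Incorrect API key provided: sk-...abc"
    )
    assert is_auth_error_message(str(wrapped))
    assert not is_auth_error_message("litellm.RateLimitError: slow down")

One caveat worth flagging: after the new break, formatted_answer may still not be an AgentFinish when _invoke_loop returns, so callers should be prepared for a partial result.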
@@ -160,10 +204,18 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
 
     def _get_llm_response(self) -> str:
         """Call the LLM and return the response, handling any invalid responses."""
-        answer = self.llm.call(
-            self.messages,
-            callbacks=self.callbacks,
-        )
+        try:
+            answer = self.llm.call(
+                self.messages,
+                callbacks=self.callbacks,
+            )
+        except Exception as e:
+            self._printer.print(
+                content=f"Error during LLM call: {e}",
+                color="red",
+            )
+            # Re-raise the exception to let _invoke_loop handle it
+            raise e
 
         if not answer:
             self._printer.print(
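Wrapping the call in log-and-reraise keeps _invoke_loop the single place that classifies failures. A minimal sketch of the pattern with hypothetical standalone names (a bare raise, used below, preserves the original traceback; raise e also works here since it re-raises the same object):

    def fake_llm_call(messages: list) -> str:
        raise ConnectionError("simulated transport failure")

    def get_llm_response(messages: list) -> str:
        try:
            return fake_llm_call(messages)
        except Exception as e:
            print(f"Error during LLM call: {e}")  # log at the call site
            raise  # let the caller decide how each failure class is handled

    try:
        get_llm_response([{"role": "user", "content": "hi"}])
    except ConnectionError as e:
        print(f"handled upstream: {e}")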
@@ -184,7 +236,6 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
         if FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE in e.error:
             answer = answer.split("Observation:")[0].strip()
 
-        self.iterations += 1
         return self._format_answer(answer)
 
     def _handle_agent_action(
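This removal pairs with the increment added at the top of _invoke_loop: keeping both would double-count iterations (see the loop-accounting sketch after the second hunk).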
src/crewai/utilities/llm_utils.py
@@ -24,12 +24,10 @@ def create_llm(
 
     # 1) If llm_value is already an LLM object, return it directly
     if isinstance(llm_value, LLM):
-        print("LLM value is already an LLM object")
         return llm_value
 
     # 2) If llm_value is a string (model name)
     if isinstance(llm_value, str):
-        print("LLM value is a string")
         try:
             created_llm = LLM(model=llm_value)
             return created_llm
@@ -39,12 +37,10 @@ def create_llm(
 
     # 3) If llm_value is None, parse environment variables or use default
    if llm_value is None:
-        print("LLM value is None")
         return _llm_via_environment_or_fallback()
 
     # 4) Otherwise, attempt to extract relevant attributes from an unknown object
     try:
-        print("LLM value is an unknown object")
         # Extract attributes with explicit types
         model = (
             getattr(llm_value, "model_name", None)
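For context, the dispatch order create_llm walks through, sketched standalone with hypothetical stand-ins for LLM and the environment fallback:

    from typing import Any

    class StubLLM:
        """Hypothetical stand-in for crewai.llm.LLM."""
        def __init__(self, model: str):
            self.model = model

    def create_llm_sketch(llm_value: Any) -> StubLLM:
        if isinstance(llm_value, StubLLM):  # 1) already an LLM object
            return llm_value
        if isinstance(llm_value, str):      # 2) a model-name string
            return StubLLM(model=llm_value)
        if llm_value is None:               # 3) stand-in for _llm_via_environment_or_fallback()
            return StubLLM(model="env-default")
        # 4) unknown object: duck-type the model name, as the real code does
        model = getattr(llm_value, "model_name", None)
        return StubLLM(model=str(model or llm_value))

    assert create_llm_sketch("gpt-4o-mini").model == "gpt-4o-mini"
    assert create_llm_sketch(None).model == "env-default"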