code cleanup, using click for cli asker

Lorenze Jay
2024-08-01 11:55:47 -07:00
parent 9393373ff0
commit 98d139f66c
8 changed files with 92 additions and 79 deletions

poetry.lock (generated, 27 lines changed)
View File

@@ -3933,20 +3933,6 @@ nodeenv = ">=0.11.1"
 pyyaml = ">=5.1"
 virtualenv = ">=20.10.0"

-[[package]]
-name = "prompt-toolkit"
-version = "3.0.47"
-description = "Library for building powerful interactive command lines in Python"
-optional = false
-python-versions = ">=3.7.0"
-files = [
-    {file = "prompt_toolkit-3.0.47-py3-none-any.whl", hash = "sha256:0d7bfa67001d5e39d02c224b663abc33687405033a8c422d0d675a5a13361d10"},
-    {file = "prompt_toolkit-3.0.47.tar.gz", hash = "sha256:1e1b29cb58080b1e69f207c893a1a7bf16d127a5c30c9d17a25a5d77792e5360"},
-]
-
-[package.dependencies]
-wcwidth = "*"
-
 [[package]]
 name = "proto-plus"
 version = "1.24.0"
@@ -5819,17 +5805,6 @@ files = [
 [package.dependencies]
 anyio = ">=3.0.0"

-[[package]]
-name = "wcwidth"
-version = "0.2.13"
-description = "Measures the displayed width of unicode strings in a terminal"
-optional = false
-python-versions = "*"
-files = [
-    {file = "wcwidth-0.2.13-py2.py3-none-any.whl", hash = "sha256:3da69048e4540d84af32131829ff948f1e022c1c6bdb8d6102117aac784f6859"},
-    {file = "wcwidth-0.2.13.tar.gz", hash = "sha256:72ea0c06399eb286d978fdedb6923a9eb47e1c486ce63e9b4e64fc18303972b5"},
-]
-
 [[package]]
 name = "webencodings"
 version = "0.5.1"
@@ -6156,4 +6131,4 @@ tools = ["crewai-tools"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.10,<=3.13"
-content-hash = "a4c662dbad415f52207bd3151541f8be028440fe831a8ea795f6ddc84a03ba42"
+content-hash = "f5ad9babb3c57c405e39232020e8cbfaaeb5c315c2e7c5bb8fdf66792f260343"

pyproject.toml
View File

@@ -29,7 +29,6 @@ jsonref = "^1.1.0"
 agentops = { version = "^0.3.0", optional = true }
 embedchain = "^0.1.114"
 json-repair = "^0.25.2"
-prompt-toolkit = "^3.0.47"

 [tool.poetry.extras]
 tools = ["crewai-tools"]

src/crewai/agents/agent_builder/utilities/yes_no_cli_validator.py (deleted)
View File

@@ -1,8 +0,0 @@
-from prompt_toolkit.validation import Validator, ValidationError
-
-
-class YesNoValidator(Validator):
-    def validate(self, document):
-        text = document.text.lower()
-        if text not in ["y", "n", "yes", "no"]:
-            raise ValidationError(message="Please enter Y/N")
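For context, click's confirmation prompt subsumes everything this deleted validator did: it parses yes/no-style answers, re-prompts on invalid input, and returns a bool. A minimal standalone sketch, not code from this commit:

```python
import click

# click.confirm accepts answers like "y"/"yes"/"n"/"no", re-prompts on
# anything else, and returns a bool -- the hand-rolled YesNoValidator
# above did the same validation through prompt_toolkit's Validator API.
if click.confirm("Proceed?"):
    print("confirmed")
else:
    print("declined")
```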

src/crewai/agents/executor.py
View File

@@ -1,7 +1,7 @@
 import threading
 import time
 from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Union

-from prompt_toolkit import prompt
+import click
 from langchain.agents import AgentExecutor
@@ -13,18 +13,19 @@ from langchain_core.tools import BaseTool
 from langchain_core.utils.input import get_color_mapping
 from pydantic import InstanceOf

 from openai import BadRequestError

 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.chains.summarize import load_summarize_chain

 from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin
-from crewai.agents.agent_builder.utilities.yes_no_cli_validator import YesNoValidator
-
 from crewai.agents.tools_handler import ToolsHandler
 from crewai.tools.tool_usage import ToolUsage, ToolUsageErrorException
 from crewai.utilities import I18N
 from crewai.utilities.constants import TRAINING_DATA_FILE
+from crewai.utilities.exceptions.context_window_exceeding_exception import (
+    LLMContextLengthExceededException,
+)
 from crewai.utilities.training_handler import CrewTrainingHandler
 from crewai.utilities.logger import Logger

@@ -50,7 +51,7 @@ class CrewAgentExecutor(AgentExecutor, CrewAgentExecutorMixin):
     system_template: Optional[str] = None
     prompt_template: Optional[str] = None
     response_template: Optional[str] = None
-    _logger: Logger = Logger()
+    _logger: Logger = Logger(verbose_level=2)
     _fit_context_window_strategy: Optional[Literal["summarize"]] = "summarize"

     def _call(
@@ -197,25 +198,28 @@ class CrewAgentExecutor(AgentExecutor, CrewAgentExecutorMixin):
                     yield AgentStep(action=output, observation=observation)
                     return

         except BadRequestError as e:
             print("Bad Request Error", e)
         except Exception as e:
-            if "context_length_exceeded" in str(e):
+            if LLMContextLengthExceededException(str(e))._is_context_limit_error(
+                str(e)
+            ):
                 self._logger.log(
                     "debug",
                     "Context length exceeded. Asking user if they want to use summarize prompt to fit, this will reduce context length.",
                     color="yellow",
                 )
-                user_choice = prompt(
-                    "Context length exceeded. Do you want to summarize the text to fit models context window? (Y/N): ",
-                    validator=YesNoValidator(),
-                ).lower()
-                if user_choice in ["y", "yes"]:
+                user_choice = click.confirm(
+                    "Context length exceeded. Do you want to summarize the text to fit models context window?"
+                )
+                if user_choice:
                     self._logger.log(
                         "debug",
                         "Context length exceeded. Using summarize prompt to fit, this will reduce context length.",
                         color="bold_blue",
                     )
-                    intermediate_steps = self._handle_context_length(intermediate_steps)
+                    intermediate_steps = self._handle_context_length(
+                        intermediate_steps
+                    )
                     output = self.agent.plan(
                         intermediate_steps,
@@ -229,14 +233,14 @@ class CrewAgentExecutor(AgentExecutor, CrewAgentExecutorMixin):
                     yield AgentStep(action=output, observation=None)
                     return
                 else:
                     self._logger.log(
                         "debug",
                         "Context length exceeded. Consider using smaller text or RAG tools from crewai_tools.",
                         color="red",
                     )
                     raise SystemExit(
                         "Context length exceeded and user opted not to summarize. Consider using smaller text or RAG tools from crewai_tools."
                     )
             else:
                 raise e
         except Exception as e:
             yield AgentStep(
                 action=AgentAction("_Exception", str(e), str(e)),
                 observation=str(e),
@@ -364,8 +368,11 @@ class CrewAgentExecutor(AgentExecutor, CrewAgentExecutorMixin):
             )
         summarized_docs = []
         for doc in docs:
-            summary = summarize_chain.run([doc])
-            summarized_docs.append(summary)
+            summary = summarize_chain.invoke(
+                {"input_documents": [doc]}, return_only_outputs=True
+            )
+            summarized_docs.append(summary["output_text"])

         formatted_results = "\n\n".join(summarized_docs)
         summary_step = AgentStep(
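The heart of the executor change: a prompt_toolkit prompt that returned a raw string, which then had to be validated by the now-deleted YesNoValidator, lowercased, and compared against "y"/"yes", becomes a single click.confirm call that returns a bool. A minimal sketch of the new flow; ask_to_summarize is a hypothetical helper, not part of the codebase:

```python
import click

def ask_to_summarize() -> bool:
    # click.confirm validates input, re-prompts on garbage, and returns
    # True/False directly, so no manual string normalization is needed.
    return click.confirm(
        "Context length exceeded. Do you want to summarize the text "
        "to fit models context window?"
    )

if __name__ == "__main__":
    if ask_to_summarize():
        print("summarizing to fit the context window...")
    else:
        raise SystemExit("User opted not to summarize.")
```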

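The same file also moves the summarization call from the deprecated Chain.run to invoke, which takes a dict of inputs and returns the summary under the "output_text" key. A rough standalone sketch of that call shape; the llm argument, chunk sizes, and chain type are placeholders, not values from this commit:

```python
from langchain.chains.summarize import load_summarize_chain
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.documents import Document

def summarize_text(llm, text: str) -> str:
    # Split the oversized text into chunks, summarize each chunk, and
    # join the partial summaries -- roughly what _handle_context_length
    # does in the diff above.
    splitter = RecursiveCharacterTextSplitter(chunk_size=8000, chunk_overlap=200)
    docs = [Document(page_content=chunk) for chunk in splitter.split_text(text)]
    summarize_chain = load_summarize_chain(llm, chain_type="map_reduce")
    summarized_docs = []
    for doc in docs:
        # invoke returns a dict; return_only_outputs drops the echoed inputs.
        summary = summarize_chain.invoke(
            {"input_documents": [doc]}, return_only_outputs=True
        )
        summarized_docs.append(summary["output_text"])
    return "\n\n".join(summarized_docs)
```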
src/crewai/utilities/__init__.py
View File

@@ -7,6 +7,9 @@ from .parser import YamlParser
 from .printer import Printer
 from .prompts import Prompts
 from .rpm_controller import RPMController
+from .exceptions.context_window_exceeding_exception import (
+    LLMContextLengthExceededException,
+)

 __all__ = [
     "Converter",
@@ -19,4 +22,5 @@ __all__ = [
     "Prompts",
     "RPMController",
     "YamlParser",
+    "LLMContextLengthExceededException",
 ]
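With the re-export above, the exception is importable from the utilities package root as well as from its own module; a one-line usage sketch:

```python
# Either path resolves to the same class after this change:
from crewai.utilities import LLMContextLengthExceededException
```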

src/crewai/utilities/exceptions/context_window_exceeding_exception.py (new)
View File

@@ -0,0 +1,30 @@
+class LLMContextLengthExceededException(Exception):
+    CONTEXT_LIMIT_ERRORS = [
+        "maximum context length",
+        "context length exceeded",
+        "context window full",
+        "too many tokens",
+        "input is too long",
+        "exceeds token limit",
+    ]
+
+    def __init__(self, error_message: str):
+        self.original_error_message = error_message
+        if self._is_context_limit_error(error_message):
+            super().__init__(self._get_error_message())
+        else:
+            raise ValueError(
+                "The provided error message is not related to context length limits."
+            )
+
+    def _is_context_limit_error(self, error_message: str) -> bool:
+        return any(
+            phrase.lower() in error_message.lower()
+            for phrase in self.CONTEXT_LIMIT_ERRORS
+        )
+
+    def _get_error_message(self):
+        return (
+            f"LLM context length exceeded. Original error: {self.original_error_message}\n"
+            "Consider using a smaller input or implementing a text splitting strategy."
+        )
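A short sketch of how the new exception behaves, mirroring the executor's call pattern above; the error strings are made up:

```python
from crewai.utilities.exceptions.context_window_exceeding_exception import (
    LLMContextLengthExceededException,
)

# A message matching one of CONTEXT_LIMIT_ERRORS constructs normally:
err = "This model's maximum context length is 8192 tokens."
exc = LLMContextLengthExceededException(err)
print(exc)  # LLM context length exceeded. Original error: ...

# Constructing it with an unrelated message raises ValueError instead,
# so _is_context_limit_error is only reachable on a matching message:
try:
    LLMContextLengthExceededException("connection reset by peer")
except ValueError as e:
    print(e)
```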

View File

@@ -0,0 +1,8 @@
+class ContextLengthExceeded(Exception):
+    def __init__(self, exceptions):
+        self.exceptions = exceptions
+        super().__init__(self.__str__())
+
+    def __str__(self):
+        error_messages = [str(e) for e in self.exceptions]
+        return f"Multiple BadRequestExceptions occurred: {', '.join(error_messages)}"

tests/agent_test.py
View File

@@ -1024,8 +1024,6 @@ def test_handle_context_length_exceeds_limit():
         goal="test goal",
         backstory="test backstory",
     )
-    task = Task(description="test task", agent=agent, expected_output="test output")
-    # crew = Crew(agents=[agent], tasks=[task])
     original_action = AgentAction(
         tool="test_tool", tool_input="test_input", log="test_log"
     )
@@ -1041,7 +1039,7 @@ def test_handle_context_length_exceeds_limit():
         task=task,
     )
     private_mock.assert_called_once()
-    with patch("crewai.agents.executor.prompt") as mock_prompt:
+    with patch("crewai.agents.executor.click") as mock_prompt:
         mock_prompt.return_value = "y"
         with patch.object(
             CrewAgentExecutor, "_handle_context_length"
@@ -1082,7 +1080,7 @@ def test_handle_context_length_exceeds_limit_cli_no():
         task=task,
     )
     private_mock.assert_called_once()
-    with patch("crewai.agents.executor.prompt") as mock_prompt:
+    with patch("crewai.agents.executor.click") as mock_prompt:
         mock_prompt.return_value = "n"
         pytest.raises(SystemExit)
         with patch.object(
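For reference, a self-contained sketch of driving a click.confirm code path from a test by patching the function itself rather than the whole module; maybe_summarize is a stand-in, not the project's code:

```python
from unittest.mock import patch

import click
import pytest

def maybe_summarize() -> str:
    # Stand-in for the executor's decision point above.
    if click.confirm("Summarize to fit the context window?"):
        return "summarized"
    raise SystemExit("User opted not to summarize.")

def test_user_accepts():
    with patch("click.confirm", return_value=True):
        assert maybe_summarize() == "summarized"

def test_user_declines():
    with patch("click.confirm", return_value=False):
        with pytest.raises(SystemExit):
            maybe_summarize()
```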