code cleanup, using click for the CLI prompt

Lorenze Jay
2024-08-01 11:55:47 -07:00
parent 9393373ff0
commit 98d139f66c
8 changed files with 92 additions and 79 deletions

poetry.lock (generated)

@@ -3933,20 +3933,6 @@ nodeenv = ">=0.11.1"
 pyyaml = ">=5.1"
 virtualenv = ">=20.10.0"
 
-[[package]]
-name = "prompt-toolkit"
-version = "3.0.47"
-description = "Library for building powerful interactive command lines in Python"
-optional = false
-python-versions = ">=3.7.0"
-files = [
-    {file = "prompt_toolkit-3.0.47-py3-none-any.whl", hash = "sha256:0d7bfa67001d5e39d02c224b663abc33687405033a8c422d0d675a5a13361d10"},
-    {file = "prompt_toolkit-3.0.47.tar.gz", hash = "sha256:1e1b29cb58080b1e69f207c893a1a7bf16d127a5c30c9d17a25a5d77792e5360"},
-]
-
-[package.dependencies]
-wcwidth = "*"
-
 [[package]]
 name = "proto-plus"
 version = "1.24.0"
@@ -5819,17 +5805,6 @@ files = [
 [package.dependencies]
 anyio = ">=3.0.0"
 
-[[package]]
-name = "wcwidth"
-version = "0.2.13"
-description = "Measures the displayed width of unicode strings in a terminal"
-optional = false
-python-versions = "*"
-files = [
-    {file = "wcwidth-0.2.13-py2.py3-none-any.whl", hash = "sha256:3da69048e4540d84af32131829ff948f1e022c1c6bdb8d6102117aac784f6859"},
-    {file = "wcwidth-0.2.13.tar.gz", hash = "sha256:72ea0c06399eb286d978fdedb6923a9eb47e1c486ce63e9b4e64fc18303972b5"},
-]
-
 [[package]]
 name = "webencodings"
 version = "0.5.1"
@@ -6156,4 +6131,4 @@ tools = ["crewai-tools"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.10,<=3.13"
-content-hash = "a4c662dbad415f52207bd3151541f8be028440fe831a8ea795f6ddc84a03ba42"
+content-hash = "f5ad9babb3c57c405e39232020e8cbfaaeb5c315c2e7c5bb8fdf66792f260343"

pyproject.toml

@@ -29,7 +29,6 @@ jsonref = "^1.1.0"
 agentops = { version = "^0.3.0", optional = true }
 embedchain = "^0.1.114"
 json-repair = "^0.25.2"
-prompt-toolkit = "^3.0.47"
 
 [tool.poetry.extras]
 tools = ["crewai-tools"]

crewai/agents/agent_builder/utilities/yes_no_cli_validator.py (deleted)

@@ -1,8 +0,0 @@
-from prompt_toolkit.validation import Validator, ValidationError
-
-
-class YesNoValidator(Validator):
-    def validate(self, document):
-        text = document.text.lower()
-        if text not in ["y", "n", "yes", "no"]:
-            raise ValidationError(message="Please enter Y/N")

crewai/agents/executor.py

@@ -1,7 +1,7 @@
 import threading
 import time
 from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Union
 
-from prompt_toolkit import prompt
+import click
 
 from langchain.agents import AgentExecutor
@@ -13,18 +13,19 @@ from langchain_core.tools import BaseTool
 from langchain_core.utils.input import get_color_mapping
 from pydantic import InstanceOf
-from openai import BadRequestError
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.chains.summarize import load_summarize_chain
 from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin
-from crewai.agents.agent_builder.utilities.yes_no_cli_validator import YesNoValidator
 from crewai.agents.tools_handler import ToolsHandler
 from crewai.tools.tool_usage import ToolUsage, ToolUsageErrorException
 from crewai.utilities import I18N
 from crewai.utilities.constants import TRAINING_DATA_FILE
+from crewai.utilities.exceptions.context_window_exceeding_exception import (
+    LLMContextLengthExceededException,
+)
 from crewai.utilities.training_handler import CrewTrainingHandler
 from crewai.utilities.logger import Logger
@@ -50,7 +51,7 @@ class CrewAgentExecutor(AgentExecutor, CrewAgentExecutorMixin):
     system_template: Optional[str] = None
     prompt_template: Optional[str] = None
     response_template: Optional[str] = None
-    _logger: Logger = Logger()
+    _logger: Logger = Logger(verbose_level=2)
     _fit_context_window_strategy: Optional[Literal["summarize"]] = "summarize"
 
     def _call(
@@ -197,46 +198,49 @@ class CrewAgentExecutor(AgentExecutor, CrewAgentExecutorMixin):
             yield AgentStep(action=output, observation=observation)
             return
-        except BadRequestError as e:
-            print("Bad Request Error", e)
-            if "context_length_exceeded" in str(e):
-                self._logger.log(
-                    "debug",
-                    "Context length exceeded. Asking user if they want to use summarize prompt to fit, this will reduce context length.",
-                    color="yellow",
-                )
-                user_choice = prompt(
-                    "Context length exceeded. Do you want to summarize the text to fit models context window? (Y/N): ",
-                    validator=YesNoValidator(),
-                ).lower()
-                if user_choice in ["y", "yes"]:
-                    self._logger.log(
-                        "debug",
-                        "Context length exceeded. Using summarize prompt to fit, this will reduce context length.",
-                        color="bold_blue",
-                    )
-                    intermediate_steps = self._handle_context_length(intermediate_steps)
-
-                    output = self.agent.plan(
-                        intermediate_steps,
-                        callbacks=run_manager.get_child() if run_manager else None,
-                        **inputs,
-                    )
-
-                    if isinstance(output, AgentFinish):
-                        yield output
-                    else:
-                        yield AgentStep(action=output, observation=None)
-                    return
-                else:
-                    raise SystemExit(
-                        "Context length exceeded and user opted not to summarize. Consider using smaller text or RAG tools from crewai_tools."
-                    )
-            else:
-                raise e
         except Exception as e:
+            if LLMContextLengthExceededException(str(e))._is_context_limit_error(
+                str(e)
+            ):
+                if "context_length_exceeded" in str(e):
+                    self._logger.log(
+                        "debug",
+                        "Context length exceeded. Asking user if they want to use summarize prompt to fit, this will reduce context length.",
+                        color="yellow",
+                    )
+                    user_choice = click.confirm(
+                        "Context length exceeded. Do you want to summarize the text to fit models context window?"
+                    )
+                    if user_choice:
+                        self._logger.log(
+                            "debug",
+                            "Context length exceeded. Using summarize prompt to fit, this will reduce context length.",
+                            color="bold_blue",
+                        )
+                        intermediate_steps = self._handle_context_length(
+                            intermediate_steps
+                        )
+
+                        output = self.agent.plan(
+                            intermediate_steps,
+                            callbacks=run_manager.get_child() if run_manager else None,
+                            **inputs,
+                        )
+
+                        if isinstance(output, AgentFinish):
+                            yield output
+                        else:
+                            yield AgentStep(action=output, observation=None)
+                        return
+                    else:
+                        self._logger.log(
+                            "debug",
+                            "Context length exceeded. Consider using smaller text or RAG tools from crewai_tools.",
+                            color="red",
+                        )
+                        raise SystemExit(
+                            "Context length exceeded and user opted not to summarize. Consider using smaller text or RAG tools from crewai_tools."
+                        )
+
             yield AgentStep(
                 action=AgentAction("_Exception", str(e), str(e)),
                 observation=str(e),
@@ -364,8 +368,11 @@ class CrewAgentExecutor(AgentExecutor, CrewAgentExecutorMixin):
         )
         summarized_docs = []
         for doc in docs:
-            summary = summarize_chain.run([doc])
-            summarized_docs.append(summary)
+            summary = summarize_chain.invoke(
+                {"input_documents": [doc]}, return_only_outputs=True
+            )
+            summarized_docs.append(summary["output_text"])
 
         formatted_results = "\n\n".join(summarized_docs)
         summary_step = AgentStep(
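For context on the prompt swap above: click.confirm prints the question with a [y/N] suffix, re-asks on invalid input, and returns a bool, so it replaces both the prompt_toolkit prompt() call and the deleted YesNoValidator. A minimal standalone sketch (illustrative only, not the executor code; the question and exit message are taken from the diff, the helper name is made up):

import click

def ask_to_summarize() -> bool:
    # click handles the y/n parsing and returns True or False directly.
    return click.confirm(
        "Context length exceeded. Do you want to summarize the text to fit models context window?"
    )

if __name__ == "__main__":
    if ask_to_summarize():
        print("Would summarize: the executor calls self._handle_context_length(...) here.")
    else:
        raise SystemExit(
            "Context length exceeded and user opted not to summarize. Consider using smaller text or RAG tools from crewai_tools."
        )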

crewai/utilities/__init__.py

@@ -7,6 +7,9 @@ from .parser import YamlParser
 from .printer import Printer
 from .prompts import Prompts
 from .rpm_controller import RPMController
+from .exceptions.context_window_exceeding_exception import (
+    LLMContextLengthExceededException,
+)
 
 __all__ = [
     "Converter",
@@ -19,4 +22,5 @@ __all__ = [
     "Prompts",
     "RPMController",
     "YamlParser",
+    "LLMContextLengthExceededException",
 ]

crewai/utilities/exceptions/context_window_exceeding_exception.py (new file)

@@ -0,0 +1,30 @@
+class LLMContextLengthExceededException(Exception):
+    CONTEXT_LIMIT_ERRORS = [
+        "maximum context length",
+        "context length exceeded",
+        "context window full",
+        "too many tokens",
+        "input is too long",
+        "exceeds token limit",
+    ]
+
+    def __init__(self, error_message: str):
+        self.original_error_message = error_message
+        if self._is_context_limit_error(error_message):
+            super().__init__(self._get_error_message())
+        else:
+            raise ValueError(
+                "The provided error message is not related to context length limits."
+            )
+
+    def _is_context_limit_error(self, error_message: str) -> bool:
+        return any(
+            phrase.lower() in error_message.lower()
+            for phrase in self.CONTEXT_LIMIT_ERRORS
+        )
+
+    def _get_error_message(self):
+        return (
+            f"LLM context length exceeded. Original error: {self.original_error_message}\n"
+            "Consider using a smaller input or implementing a text splitting strategy."
+        )
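A short usage sketch for the new exception (illustrative only; the sample error text is hypothetical, the behavior follows the class above and the re-export added to crewai.utilities):

from crewai.utilities import LLMContextLengthExceededException

# Hypothetical provider error message; it contains the "maximum context length" phrase.
error_text = "This model's maximum context length is 8192 tokens."

exc = LLMContextLengthExceededException(error_text)
print(exc._is_context_limit_error(error_text))  # True
print(exc)  # "LLM context length exceeded. Original error: ..." plus the splitting hint

# A message that matches none of the CONTEXT_LIMIT_ERRORS phrases raises ValueError instead.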


@@ -0,0 +1,8 @@
+class ContextLengthExceeded(Exception):
+    def __init__(self, exceptions):
+        self.exceptions = exceptions
+        super().__init__(self.__str__())
+
+    def __str__(self):
+        error_messages = [str(e) for e in self.exceptions]
+        return f"Multiple BadRequestExceptions occurred: {', '.join(error_messages)}"


@@ -1024,8 +1024,6 @@ def test_handle_context_length_exceeds_limit():
         goal="test goal",
         backstory="test backstory",
     )
-    task = Task(description="test task", agent=agent, expected_output="test output")
-    # crew = Crew(agents=[agent], tasks=[task])
     original_action = AgentAction(
         tool="test_tool", tool_input="test_input", log="test_log"
     )
@@ -1041,7 +1039,7 @@ def test_handle_context_length_exceeds_limit():
             task=task,
         )
         private_mock.assert_called_once()
-    with patch("crewai.agents.executor.prompt") as mock_prompt:
+    with patch("crewai.agents.executor.click") as mock_prompt:
         mock_prompt.return_value = "y"
         with patch.object(
             CrewAgentExecutor, "_handle_context_length"
@@ -1082,7 +1080,7 @@ def test_handle_context_length_exceeds_limit_cli_no():
             task=task,
         )
         private_mock.assert_called_once()
-    with patch("crewai.agents.executor.prompt") as mock_prompt:
+    with patch("crewai.agents.executor.click") as mock_prompt:
         mock_prompt.return_value = "n"
         pytest.raises(SystemExit)
         with patch.object(