From 98d139f66c1e64b340859f1ae9949aba30f28fd1 Mon Sep 17 00:00:00 2001
From: Lorenze Jay
Date: Thu, 1 Aug 2024 11:55:47 -0700
Subject: [PATCH] code cleanup, using click for cli asker

---
 poetry.lock                                   | 27 +-----
 pyproject.toml                                |  1 -
 .../utilities/yes_no_cli_validator.py         |  8 --
 src/crewai/agents/executor.py                 | 87 ++++++++++---------
 src/crewai/utilities/__init__.py              |  4 +
 .../context_window_exceeding_exception.py     | 30 +++++++
 .../exceptions/exception_aggregator.py        |  8 ++
 tests/agent_test.py                           |  6 +-
 8 files changed, 92 insertions(+), 79 deletions(-)
 delete mode 100644 src/crewai/agents/agent_builder/utilities/yes_no_cli_validator.py
 create mode 100644 src/crewai/utilities/exceptions/context_window_exceeding_exception.py
 create mode 100644 src/crewai/utilities/exceptions/exception_aggregator.py

diff --git a/poetry.lock b/poetry.lock
index f357a8bdd..356cd8756 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -3933,20 +3933,6 @@ nodeenv = ">=0.11.1"
 pyyaml = ">=5.1"
 virtualenv = ">=20.10.0"
 
-[[package]]
-name = "prompt-toolkit"
-version = "3.0.47"
-description = "Library for building powerful interactive command lines in Python"
-optional = false
-python-versions = ">=3.7.0"
-files = [
-    {file = "prompt_toolkit-3.0.47-py3-none-any.whl", hash = "sha256:0d7bfa67001d5e39d02c224b663abc33687405033a8c422d0d675a5a13361d10"},
-    {file = "prompt_toolkit-3.0.47.tar.gz", hash = "sha256:1e1b29cb58080b1e69f207c893a1a7bf16d127a5c30c9d17a25a5d77792e5360"},
-]
-
-[package.dependencies]
-wcwidth = "*"
-
 [[package]]
 name = "proto-plus"
 version = "1.24.0"
@@ -5819,17 +5805,6 @@ files = [
 [package.dependencies]
 anyio = ">=3.0.0"
 
-[[package]]
-name = "wcwidth"
-version = "0.2.13"
-description = "Measures the displayed width of unicode strings in a terminal"
-optional = false
-python-versions = "*"
-files = [
-    {file = "wcwidth-0.2.13-py2.py3-none-any.whl", hash = "sha256:3da69048e4540d84af32131829ff948f1e022c1c6bdb8d6102117aac784f6859"},
-    {file = "wcwidth-0.2.13.tar.gz", hash = "sha256:72ea0c06399eb286d978fdedb6923a9eb47e1c486ce63e9b4e64fc18303972b5"},
-]
-
 [[package]]
 name = "webencodings"
 version = "0.5.1"
@@ -6156,4 +6131,4 @@ tools = ["crewai-tools"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.10,<=3.13"
-content-hash = "a4c662dbad415f52207bd3151541f8be028440fe831a8ea795f6ddc84a03ba42"
+content-hash = "f5ad9babb3c57c405e39232020e8cbfaaeb5c315c2e7c5bb8fdf66792f260343"
diff --git a/pyproject.toml b/pyproject.toml
index f10d07f88..a174fc669 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -29,7 +29,6 @@ jsonref = "^1.1.0"
 agentops = { version = "^0.3.0", optional = true }
 embedchain = "^0.1.114"
 json-repair = "^0.25.2"
-prompt-toolkit = "^3.0.47"
 
 [tool.poetry.extras]
 tools = ["crewai-tools"]
diff --git a/src/crewai/agents/agent_builder/utilities/yes_no_cli_validator.py b/src/crewai/agents/agent_builder/utilities/yes_no_cli_validator.py
deleted file mode 100644
index 4f44bee61..000000000
--- a/src/crewai/agents/agent_builder/utilities/yes_no_cli_validator.py
+++ /dev/null
@@ -1,8 +0,0 @@
-from prompt_toolkit.validation import Validator, ValidationError
-
-
-class YesNoValidator(Validator):
-    def validate(self, document):
-        text = document.text.lower()
-        if text not in ["y", "n", "yes", "no"]:
-            raise ValidationError(message="Please enter Y/N")
diff --git a/src/crewai/agents/executor.py b/src/crewai/agents/executor.py
index d35272ed0..115a424ef 100644
--- a/src/crewai/agents/executor.py
+++ b/src/crewai/agents/executor.py
@@ -1,7 +1,7 @@
 import threading
 import time
 from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Union
 
-from prompt_toolkit import prompt
+import click
 
 from langchain.agents import AgentExecutor
@@ -13,18 +13,19 @@
 from langchain_core.tools import BaseTool
 from langchain_core.utils.input import get_color_mapping
 from pydantic import InstanceOf
-from openai import BadRequestError
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.chains.summarize import load_summarize_chain
 
 from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin
-from crewai.agents.agent_builder.utilities.yes_no_cli_validator import YesNoValidator
 from crewai.agents.tools_handler import ToolsHandler
 from crewai.tools.tool_usage import ToolUsage, ToolUsageErrorException
 from crewai.utilities import I18N
 from crewai.utilities.constants import TRAINING_DATA_FILE
+from crewai.utilities.exceptions.context_window_exceeding_exception import (
+    LLMContextLengthExceededException,
+)
 from crewai.utilities.training_handler import CrewTrainingHandler
 from crewai.utilities.logger import Logger
@@ -50,7 +51,7 @@ class CrewAgentExecutor(AgentExecutor, CrewAgentExecutorMixin):
     system_template: Optional[str] = None
     prompt_template: Optional[str] = None
     response_template: Optional[str] = None
-    _logger: Logger = Logger()
+    _logger: Logger = Logger(verbose_level=2)
     _fit_context_window_strategy: Optional[Literal["summarize"]] = "summarize"
 
     def _call(
@@ -197,46 +198,49 @@ class CrewAgentExecutor(AgentExecutor, CrewAgentExecutorMixin):
             yield AgentStep(action=output, observation=observation)
             return
 
-        except BadRequestError as e:
-            print("Bad Request Error", e)
-            if "context_length_exceeded" in str(e):
-                self._logger.log(
-                    "debug",
-                    "Context length exceeded. Asking user if they want to use summarize prompt to fit, this will reduce context length.",
-                    color="yellow",
-                )
-                user_choice = prompt(
-                    "Context length exceeded. Do you want to summarize the text to fit models context window? (Y/N): ",
-                    validator=YesNoValidator(),
-                ).lower()
-                if user_choice in ["y", "yes"]:
+        except Exception as e:
+            if LLMContextLengthExceededException(str(e))._is_context_limit_error(
+                str(e)
+            ):
+                if "context_length_exceeded" in str(e):
                     self._logger.log(
                         "debug",
-                        "Context length exceeded. Using summarize prompt to fit, this will reduce context length.",
-                        color="bold_blue",
+                        "Context length exceeded. Asking user if they want to use summarize prompt to fit, this will reduce context length.",
+                        color="yellow",
                     )
-                    intermediate_steps = self._handle_context_length(intermediate_steps)
-
-                    output = self.agent.plan(
-                        intermediate_steps,
-                        callbacks=run_manager.get_child() if run_manager else None,
-                        **inputs,
+                    user_choice = click.confirm(
+                        "Context length exceeded. Do you want to summarize the text to fit models context window?"
                     )
+                    if user_choice:
+                        self._logger.log(
+                            "debug",
+                            "Context length exceeded. Using summarize prompt to fit, this will reduce context length.",
+                            color="bold_blue",
+                        )
+                        intermediate_steps = self._handle_context_length(
+                            intermediate_steps
+                        )
 
-                    if isinstance(output, AgentFinish):
-                        yield output
+                        output = self.agent.plan(
+                            intermediate_steps,
+                            callbacks=run_manager.get_child() if run_manager else None,
+                            **inputs,
+                        )
+
+                        if isinstance(output, AgentFinish):
+                            yield output
+                        else:
+                            yield AgentStep(action=output, observation=None)
+                        return
                     else:
-                        yield AgentStep(action=output, observation=None)
-                    return
-                else:
-                    raise SystemExit(
-                        "Context length exceeded and user opted not to summarize. Consider using smaller text or RAG tools from crewai_tools."
-                    )
-
-            else:
-                raise e
-
-        except Exception as e:
+                        self._logger.log(
+                            "debug",
+                            "Context length exceeded. Consider using smaller text or RAG tools from crewai_tools.",
+                            color="red",
+                        )
+                        raise SystemExit(
+                            "Context length exceeded and user opted not to summarize. Consider using smaller text or RAG tools from crewai_tools."
+                        )
             yield AgentStep(
                 action=AgentAction("_Exception", str(e), str(e)),
                 observation=str(e),
@@ -364,8 +368,11 @@ class CrewAgentExecutor(AgentExecutor, CrewAgentExecutorMixin):
         )
         summarized_docs = []
         for doc in docs:
-            summary = summarize_chain.run([doc])
-            summarized_docs.append(summary)
+            summary = summarize_chain.invoke(
+                {"input_documents": [doc]}, return_only_outputs=True
+            )
+
+            summarized_docs.append(summary["output_text"])
 
         formatted_results = "\n\n".join(summarized_docs)
         summary_step = AgentStep(
diff --git a/src/crewai/utilities/__init__.py b/src/crewai/utilities/__init__.py
index efe620c2e..f13fce390 100644
--- a/src/crewai/utilities/__init__.py
+++ b/src/crewai/utilities/__init__.py
@@ -7,6 +7,9 @@ from .parser import YamlParser
 from .printer import Printer
 from .prompts import Prompts
 from .rpm_controller import RPMController
+from .exceptions.context_window_exceeding_exception import (
+    LLMContextLengthExceededException,
+)
 
 __all__ = [
     "Converter",
@@ -19,4 +22,5 @@ __all__ = [
     "Prompts",
     "RPMController",
     "YamlParser",
+    "LLMContextLengthExceededException",
 ]
diff --git a/src/crewai/utilities/exceptions/context_window_exceeding_exception.py b/src/crewai/utilities/exceptions/context_window_exceeding_exception.py
new file mode 100644
index 000000000..1149ddb17
--- /dev/null
+++ b/src/crewai/utilities/exceptions/context_window_exceeding_exception.py
@@ -0,0 +1,30 @@
+class LLMContextLengthExceededException(Exception):
+    CONTEXT_LIMIT_ERRORS = [
+        "maximum context length",
+        "context length exceeded",
+        "context window full",
+        "too many tokens",
+        "input is too long",
+        "exceeds token limit",
+    ]
+
+    def __init__(self, error_message: str):
+        self.original_error_message = error_message
+        if self._is_context_limit_error(error_message):
+            super().__init__(self._get_error_message())
+        else:
+            raise ValueError(
+                "The provided error message is not related to context length limits."
+            )
+
+    def _is_context_limit_error(self, error_message: str) -> bool:
+        return any(
+            phrase.lower() in error_message.lower()
+            for phrase in self.CONTEXT_LIMIT_ERRORS
+        )
+
+    def _get_error_message(self):
+        return (
+            f"LLM context length exceeded. Original error: {self.original_error_message}\n"
+            "Consider using a smaller input or implementing a text splitting strategy."
+        )
diff --git a/src/crewai/utilities/exceptions/exception_aggregator.py b/src/crewai/utilities/exceptions/exception_aggregator.py
new file mode 100644
index 000000000..c501d1f23
--- /dev/null
+++ b/src/crewai/utilities/exceptions/exception_aggregator.py
@@ -0,0 +1,8 @@
+class ContextLengthExceeded(Exception):
+    def __init__(self, exceptions):
+        self.exceptions = exceptions
+        super().__init__(self.__str__())
+
+    def __str__(self):
+        error_messages = [str(e) for e in self.exceptions]
+        return f"Multiple BadRequestExceptions occurred: {', '.join(error_messages)}"
diff --git a/tests/agent_test.py b/tests/agent_test.py
index b02003e00..05ef15ba8 100644
--- a/tests/agent_test.py
+++ b/tests/agent_test.py
@@ -1024,8 +1024,6 @@ def test_handle_context_length_exceeds_limit():
         goal="test goal",
         backstory="test backstory",
     )
-    task = Task(description="test task", agent=agent, expected_output="test output")
-    # crew = Crew(agents=[agent], tasks=[task])
     original_action = AgentAction(
         tool="test_tool", tool_input="test_input", log="test_log"
     )
@@ -1041,7 +1039,7 @@
         task=task,
     )
     private_mock.assert_called_once()
-    with patch("crewai.agents.executor.prompt") as mock_prompt:
+    with patch("crewai.agents.executor.click") as mock_prompt:
         mock_prompt.return_value = "y"
         with patch.object(
             CrewAgentExecutor, "_handle_context_length"
@@ -1082,7 +1080,7 @@ def test_handle_context_length_exceeds_limit_cli_no():
         task=task,
     )
     private_mock.assert_called_once()
-    with patch("crewai.agents.executor.prompt") as mock_prompt:
+    with patch("crewai.agents.executor.click") as mock_prompt:
         mock_prompt.return_value = "n"
         pytest.raises(SystemExit)
         with patch.object(
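
The executor change above swaps the prompt_toolkit prompt plus the deleted
YesNoValidator for click.confirm, which re-prompts on invalid input and returns a
bool directly. A minimal standalone sketch of that confirmation flow follows; the
ask_to_summarize helper name is illustrative and not part of the patch:

    import click


    def ask_to_summarize() -> bool:
        # click.confirm keeps re-asking until it gets a yes/no answer and
        # returns True/False, so no custom validator is needed anymore.
        return click.confirm(
            "Context length exceeded. Do you want to summarize the text to fit "
            "models context window?"
        )


    if __name__ == "__main__":
        print("summarizing" if ask_to_summarize() else "aborting")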