From 579689900a45fc2ecd4c637eb325ba591661714e Mon Sep 17 00:00:00 2001
From: Lorenze Jay
Date: Wed, 31 Jul 2024 16:29:24 -0700
Subject: [PATCH] WIP: user-approved summarization strategy for handling
 context-window-exceeded errors (tests still needed)

---
 .../utilities/yes_no_cli_validator.py |   8 +
 src/crewai/agents/executor.py         | 406 ++++++++++--------
 2 files changed, 229 insertions(+), 185 deletions(-)
 create mode 100644 src/crewai/agents/agent_builder/utilities/yes_no_cli_validator.py

diff --git a/src/crewai/agents/agent_builder/utilities/yes_no_cli_validator.py b/src/crewai/agents/agent_builder/utilities/yes_no_cli_validator.py
new file mode 100644
index 000000000..4f44bee61
--- /dev/null
+++ b/src/crewai/agents/agent_builder/utilities/yes_no_cli_validator.py
@@ -0,0 +1,8 @@
+from prompt_toolkit.validation import Validator, ValidationError
+
+
+class YesNoValidator(Validator):
+    def validate(self, document):
+        text = document.text.lower()
+        if text not in ["y", "n", "yes", "no"]:
+            raise ValidationError(message="Please enter Y/N")
diff --git a/src/crewai/agents/executor.py b/src/crewai/agents/executor.py
index 2de498440..517fc2081 100644
--- a/src/crewai/agents/executor.py
+++ b/src/crewai/agents/executor.py
@@ -1,6 +1,8 @@
 import threading
 import time
-from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Union
+from prompt_toolkit import prompt
+
 
 from langchain.agents import AgentExecutor
 from langchain.agents.agent import ExceptionTool
@@ -10,17 +12,21 @@ from langchain_core.exceptions import OutputParserException
 from langchain_core.tools import BaseTool
 from langchain_core.utils.input import get_color_mapping
 from pydantic import InstanceOf
-import tiktoken
+
+from openai import BadRequestError
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.chains.summarize import load_summarize_chain
-from openai import BadRequestError
 
 from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin
+from crewai.agents.agent_builder.utilities.yes_no_cli_validator import YesNoValidator
 from crewai.agents.tools_handler import ToolsHandler
+
+
 from crewai.tools.tool_usage import ToolUsage, ToolUsageErrorException
 from crewai.utilities import I18N
 from crewai.utilities.constants import TRAINING_DATA_FILE
 from crewai.utilities.training_handler import CrewTrainingHandler
+from crewai.utilities.logger import Logger
 
 
 class CrewAgentExecutor(AgentExecutor, CrewAgentExecutorMixin):
@@ -44,8 +50,8 @@ class CrewAgentExecutor(AgentExecutor, CrewAgentExecutorMixin):
     system_template: Optional[str] = None
     prompt_template: Optional[str] = None
     response_template: Optional[str] = None
-    retry_summarize: bool = False
-    retry_summarize_count: int = 2
+    _logger: Logger = Logger()
+    _fit_context_window_strategy: Optional[Literal["summarize"]] = "summarize"
 
     def _call(
         self,
@@ -126,198 +132,187 @@ class CrewAgentExecutor(AgentExecutor, CrewAgentExecutorMixin):
 
         Override this to take control of how the agent makes and acts on choices.
""" - for attempt in range(self.retry_summarize_count): - try: - if self._should_force_answer(): - error = self._i18n.errors("force_final_answer") - output = AgentAction("_Exception", error, error) - self.have_forced_answer = True - yield AgentStep(action=output, observation=error) - return - - intermediate_steps = self._prepare_intermediate_steps( - intermediate_steps - ) - if self.retry_summarize: - encoding = tiktoken.encoding_for_model(self.llm.model_name) - original_token_count = len( - encoding.encode(intermediate_steps[0][1]) - ) - if original_token_count > 8000: - print( - "BEFORE AGENT PLAN TOKEN LENGTH", - original_token_count, - ) - text = intermediate_steps[0][1] - - text_splitter = RecursiveCharacterTextSplitter( - separators=["\n\n", "\n"], - chunk_size=8000, - chunk_overlap=500, - ) - docs = text_splitter.create_documents([text]) - print("DOCS", docs) - print("DOCS length", len(docs)) - breakpoint() - # TODO: store to vector db - using memgpt like strategy - summary_chain = load_summarize_chain( - self.llm, chain_type="map_reduce", verbose=True - ) - summary = summary_chain.run(docs) - - print("SUMMARY:", summary) - - intermediate_steps[0] = (intermediate_steps[0][0], summary) - - # Call the LLM to see what to do. - output = self.agent.plan( # type: ignore # Incompatible types in assignment (expression has type "AgentAction | AgentFinish | list[AgentAction]", variable has type "AgentAction") - intermediate_steps, - callbacks=run_manager.get_child() if run_manager else None, - **inputs, - ) - - except OutputParserException as e: - if isinstance(self.handle_parsing_errors, bool): - raise_error = not self.handle_parsing_errors - else: - raise_error = False - if raise_error: - raise ValueError( - "An output parsing error occurred. " - "In order to pass this error back to the agent and have it try " - "again, pass `handle_parsing_errors=True` to the AgentExecutor. " - f"This is the error: {str(e)}" - ) - str(e) - if isinstance(self.handle_parsing_errors, bool): - if e.send_to_llm: - observation = f"\n{str(e.observation)}" - str(e.llm_output) - else: - observation = "" - elif isinstance(self.handle_parsing_errors, str): - observation = f"\n{self.handle_parsing_errors}" - elif callable(self.handle_parsing_errors): - observation = f"\n{self.handle_parsing_errors(e)}" - else: - raise ValueError("Got unexpected type of `handle_parsing_errors`") - output = AgentAction("_Exception", observation, "") - - if run_manager: - run_manager.on_agent_action(output, color="green") - - tool_run_kwargs = self.agent.tool_run_logging_kwargs() - observation = ExceptionTool().run( - output.tool_input, - verbose=False, - color=None, - callbacks=run_manager.get_child() if run_manager else None, - **tool_run_kwargs, - ) - - if self._should_force_answer(): - error = self._i18n.errors("force_final_answer") - output = AgentAction("_Exception", error, error) - yield AgentStep(action=output, observation=error) - return - - yield AgentStep(action=output, observation=observation) + try: + if self._should_force_answer(): + error = self._i18n.errors("force_final_answer") + output = AgentAction("_Exception", error, error) + self.have_forced_answer = True + yield AgentStep(action=output, observation=error) return - except BadRequestError as e: - print("bad request string str(e)", str(e)) - if ( - "context_length_exceeded" in str(e) - and attempt < self.retry_summarize_count - 1 - ): - print( - f"Context length exceeded. Retrying with summarization (attempt {attempt + 1})..." 
-                    )
-                    self.retry_summarize = True
-                    breakpoint()
-                    continue
-                else:
-                    print("Error now raising occurred in _iter_next_step:", e)
-                    raise e
+            intermediate_steps = self._prepare_intermediate_steps(intermediate_steps)
 
-            except Exception as e:
-                print("Error occurred in _iter_next_step:", e)
+            # Call the LLM to see what to do.
+            output = self.agent.plan(  # type: ignore # Incompatible types in assignment (expression has type "AgentAction | AgentFinish | list[AgentAction]", variable has type "AgentAction")
+                intermediate_steps,
+                callbacks=run_manager.get_child() if run_manager else None,
+                **inputs,
+            )
+
+        except OutputParserException as e:
+            if isinstance(self.handle_parsing_errors, bool):
+                raise_error = not self.handle_parsing_errors
+            else:
+                raise_error = False
+            if raise_error:
+                raise ValueError(
+                    "An output parsing error occurred. "
+                    "In order to pass this error back to the agent and have it try "
+                    "again, pass `handle_parsing_errors=True` to the AgentExecutor. "
+                    f"This is the error: {str(e)}"
+                )
+            str(e)
+            if isinstance(self.handle_parsing_errors, bool):
+                if e.send_to_llm:
+                    observation = f"\n{str(e.observation)}"
+                    str(e.llm_output)
+                else:
+                    observation = ""
+            elif isinstance(self.handle_parsing_errors, str):
+                observation = f"\n{self.handle_parsing_errors}"
+            elif callable(self.handle_parsing_errors):
+                observation = f"\n{self.handle_parsing_errors(e)}"
+            else:
+                raise ValueError("Got unexpected type of `handle_parsing_errors`")
+            output = AgentAction("_Exception", observation, "")
+
+            if run_manager:
+                run_manager.on_agent_action(output, color="green")
+
+            tool_run_kwargs = self.agent.tool_run_logging_kwargs()
+            observation = ExceptionTool().run(
+                output.tool_input,
+                verbose=False,
+                color=None,
+                callbacks=run_manager.get_child() if run_manager else None,
+                **tool_run_kwargs,
+            )
+
+            if self._should_force_answer():
+                error = self._i18n.errors("force_final_answer")
+                output = AgentAction("_Exception", error, error)
+                yield AgentStep(action=output, observation=error)
+                return
+
+            yield AgentStep(action=output, observation=observation)
+            return
+
+        except BadRequestError as e:
+            print("Bad Request Error", e)
+            if "context_length_exceeded" in str(e):
+                self._logger.log(
+                    "debug",
+                    "Context length exceeded. Asking the user whether to use the summarize prompt to fit; this will reduce the context length.",
+                    color="yellow",
+                )
+                user_choice = prompt(
+                    "Context length exceeded. Do you want to summarize the text to fit the model's context window? (Y/N): ",
+                    validator=YesNoValidator(),
+                ).lower()
+                if user_choice in ["y", "yes"]:
+                    self._logger.log(
+                        "debug",
+                        "Context length exceeded. Using the summarize prompt to fit; this will reduce the context length.",
+                        color="bold_blue",
+                    )
+                    intermediate_steps = self._handle_context_length(intermediate_steps)
+
+                    output = self.agent.plan(  # type: ignore # Incompatible types in assignment (expression has type "AgentAction | AgentFinish | list[AgentAction]", variable has type "AgentAction")
+                        intermediate_steps,
+                        callbacks=run_manager.get_child() if run_manager else None,
+                        **inputs,
+                    )
+
+                    if isinstance(output, AgentFinish):
+                        yield output
+                    else:
+                        yield AgentStep(action=output, observation=None)
+                    return
+                else:
+                    self._logger.log(
+                        "info",
+                        "User chose not to use summarization. Try again with smaller text or use one of the RAG tools from crewai_tools.",
+                        color="red",
+                    )
+                    raise e
+            else:
+                raise e
 
-        # If the tool chosen is the finishing tool, then we end and return. 
-            if isinstance(output, AgentFinish):
-                if self.should_ask_for_human_input:
-                    human_feedback = self._ask_human_input(
-                        output.return_values["output"]
-                    )
+        except Exception as e:
+            yield AgentStep(
+                action=AgentAction("_Exception", str(e), str(e)),
+                observation=str(e),
+            )
+            return
 
-                    if self.crew and self.crew._train:
-                        self._handle_crew_training_output(output, human_feedback)
+        # If the tool chosen is the finishing tool, then we end and return.
+        if isinstance(output, AgentFinish):
+            if self.should_ask_for_human_input:
+                human_feedback = self._ask_human_input(output.return_values["output"])
 
-                    # Making sure we only ask for it once, so disabling for the next thought loop
-                    self.should_ask_for_human_input = False
-                    action = AgentAction(
-                        tool="Human Input", tool_input=human_feedback, log=output.log
-                    )
+                if self.crew and self.crew._train:
+                    self._handle_crew_training_output(output, human_feedback)
 
-                    yield AgentStep(
-                        action=action,
-                        observation=self._i18n.slice("human_feedback").format(
-                            human_feedback=human_feedback
-                        ),
-                    )
-                    return
-
-                else:
-                    if self.crew and self.crew._train:
-                        self._handle_crew_training_output(output)
-
-                    yield output
-                    return
-
-            self._create_short_term_memory(output)
-
-            actions: List[AgentAction]
-            actions = [output] if isinstance(output, AgentAction) else output
-            yield from actions
-
-            for agent_action in actions:
-                if run_manager:
-                    run_manager.on_agent_action(agent_action, color="green")
-
-                tool_usage = ToolUsage(
-                    tools_handler=self.tools_handler,  # type: ignore # Argument "tools_handler" to "ToolUsage" has incompatible type "ToolsHandler | None"; expected "ToolsHandler"
-                    tools=self.tools,  # type: ignore # Argument "tools" to "ToolUsage" has incompatible type "Sequence[BaseTool]"; expected "list[BaseTool]"
-                    original_tools=self.original_tools,
-                    tools_description=self.tools_description,
-                    tools_names=self.tools_names,
-                    function_calling_llm=self.function_calling_llm,
-                    task=self.task,
-                    agent=self.crew_agent,
-                    action=agent_action,
+                # Making sure we only ask for it once, so disabling for the next thought loop
+                self.should_ask_for_human_input = False
+                action = AgentAction(
+                    tool="Human Input", tool_input=human_feedback, log=output.log
                 )
-                # print("tool_usage", tool_usage)
-                tool_calling = tool_usage.parse(agent_action.log)
-                # print("tool_calling", tool_calling)
+                yield AgentStep(
+                    action=action,
+                    observation=self._i18n.slice("human_feedback").format(
+                        human_feedback=human_feedback
+                    ),
+                )
+                return
 
-                if isinstance(tool_calling, ToolUsageErrorException):
-                    observation = tool_calling.message
+            else:
+                if self.crew and self.crew._train:
+                    self._handle_crew_training_output(output)
+
+                yield output
+                return
+
+        self._create_short_term_memory(output)
+
+        actions: List[AgentAction]
+        actions = [output] if isinstance(output, AgentAction) else output
+        yield from actions
+
+        for agent_action in actions:
+            if run_manager:
+                run_manager.on_agent_action(agent_action, color="green")
+
+            tool_usage = ToolUsage(
+                tools_handler=self.tools_handler,  # type: ignore # Argument "tools_handler" to "ToolUsage" has incompatible type "ToolsHandler | None"; expected "ToolsHandler"
+                tools=self.tools,  # type: ignore # Argument "tools" to "ToolUsage" has incompatible type "Sequence[BaseTool]"; expected "list[BaseTool]"
+                original_tools=self.original_tools,
+                tools_description=self.tools_description,
+                tools_names=self.tools_names,
+                function_calling_llm=self.function_calling_llm,
+                task=self.task,
+                agent=self.crew_agent,
+                action=agent_action,
+            )
+
+            tool_calling = tool_usage.parse(agent_action.log)
+
+            if isinstance(tool_calling, ToolUsageErrorException):
+                observation = tool_calling.message
+            else:
+                if tool_calling.tool_name.casefold().strip() in [
+                    name.casefold().strip() for name in name_to_tool_map
+                ] or tool_calling.tool_name.casefold().replace("_", " ") in [
+                    name.casefold().strip() for name in name_to_tool_map
+                ]:
+                    observation = tool_usage.use(tool_calling, agent_action.log)
                 else:
-                    if tool_calling.tool_name.casefold().strip() in [
-                        name.casefold().strip() for name in name_to_tool_map
-                    ] or tool_calling.tool_name.casefold().replace("_", " ") in [
-                        name.casefold().strip() for name in name_to_tool_map
-                    ]:
-                        observation = tool_usage.use(tool_calling, agent_action.log)
-                    else:
-                        observation = self._i18n.errors("wrong_tool_name").format(
-                            tool=tool_calling.tool_name,
-                            tools=", ".join(
-                                [tool.name.casefold() for tool in self.tools]
-                            ),
-                        )
-                yield AgentStep(action=agent_action, observation=observation)
+                    observation = self._i18n.errors("wrong_tool_name").format(
+                        tool=tool_calling.tool_name,
+                        tools=", ".join([tool.name.casefold() for tool in self.tools]),
+                    )
+            yield AgentStep(action=agent_action, observation=observation)
 
     def _handle_crew_training_output(
         self, output: AgentFinish, human_feedback: str | None = None
@@ -346,3 +341,44 @@ class CrewAgentExecutor(AgentExecutor, CrewAgentExecutorMixin):
         CrewTrainingHandler(TRAINING_DATA_FILE).append(
             self.crew._train_iteration, agent_id, training_data
         )
+
+    def _handle_context_length(
+        self, intermediate_steps: List[Tuple[AgentAction, str]]
+    ) -> List[Tuple[AgentAction, str]]:
+        text = intermediate_steps[0][1]
+        original_action = intermediate_steps[0][0]
+
+        text_splitter = RecursiveCharacterTextSplitter(
+            separators=["\n\n", "\n"],
+            chunk_size=8000,
+            chunk_overlap=500,
+        )
+
+        if self._fit_context_window_strategy == "summarize":
+            docs = text_splitter.create_documents([text])
+            self._logger.log(
+                "debug",
+                "Summarizing content; it is recommended to use a RAG tool instead.",
+                color="bold_blue",
+            )
+            summarize_chain = load_summarize_chain(
+                self.llm, chain_type="map_reduce", verbose=True
+            )
+            summarized_docs = []
+            for doc in docs:
+                summary = summarize_chain.run([doc])
+                summarized_docs.append(summary)
+
+            formatted_results = "\n\n".join(summarized_docs)
+            summary_step = AgentStep(
+                action=AgentAction(
+                    tool=original_action.tool,
+                    tool_input=original_action.tool_input,
+                    log=original_action.log,
+                ),
+                observation=formatted_results,
+            )
+            summary_tuple = (summary_step.action, summary_step.observation)
+            return [summary_tuple]
+
+        return intermediate_steps
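
For reference, a minimal, self-contained sketch of how the new YesNoValidator in yes_no_cli_validator.py is expected to gate the summarization path through prompt_toolkit. The question text below is illustrative, not the exact string used in executor.py; prompt_toolkit keeps re-prompting until validation passes, so the caller only ever sees a y/n style answer.

from prompt_toolkit import prompt
from prompt_toolkit.validation import ValidationError, Validator


class YesNoValidator(Validator):
    def validate(self, document):
        # Reject anything that is not a y/n style answer; prompt_toolkit
        # re-prompts until this method returns without raising.
        text = document.text.lower()
        if text not in ["y", "n", "yes", "no"]:
            raise ValidationError(message="Please enter Y/N")


if __name__ == "__main__":
    # Requires an interactive terminal.
    answer = prompt(
        "Summarize to fit the context window? (Y/N): ",  # illustrative wording
        validator=YesNoValidator(),
    ).lower()
    print("summarizing" if answer in ["y", "yes"] else "not summarizing")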
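
A standalone sketch of the chunk-and-summarize fallback that _handle_context_length implements. It assumes only that a LangChain-compatible LLM object is passed in; the helper name summarize_to_fit and its default chunk size are illustrative, while the splitter settings and the map_reduce summarize chain mirror what the patch uses.

from langchain.chains.summarize import load_summarize_chain
from langchain.text_splitter import RecursiveCharacterTextSplitter


def summarize_to_fit(llm, text: str, chunk_size: int = 8000) -> str:
    # Split the oversized observation into overlapping chunks.
    splitter = RecursiveCharacterTextSplitter(
        separators=["\n\n", "\n"],
        chunk_size=chunk_size,
        chunk_overlap=500,
    )
    docs = splitter.create_documents([text])

    # Map-reduce summarize each chunk, then join the partial summaries back
    # into one shorter observation string, as the patch does per document.
    chain = load_summarize_chain(llm, chain_type="map_reduce")
    summaries = [chain.run([doc]) for doc in docs]
    return "\n\n".join(summaries)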