WIP: need tests but user inputted summarization strategy implemented - handling context window exceeding errors

This commit is contained in:
Lorenze Jay
2024-07-31 16:29:24 -07:00
parent 62868c00db
commit 579689900a
2 changed files with 229 additions and 185 deletions

View File

@@ -0,0 +1,8 @@
from prompt_toolkit.validation import Validator, ValidationError
class YesNoValidator(Validator):
def validate(self, document):
text = document.text.lower()
if text not in ["y", "n", "yes", "no"]:
raise ValidationError(message="Please enter Y/N")

View File

@@ -1,6 +1,8 @@
import threading
import time
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Union
from prompt_toolkit import prompt
from langchain.agents import AgentExecutor
from langchain.agents.agent import ExceptionTool
@@ -10,17 +12,21 @@ from langchain_core.exceptions import OutputParserException
from langchain_core.tools import BaseTool
from langchain_core.utils.input import get_color_mapping
from pydantic import InstanceOf
import tiktoken
from openai import BadRequestError
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains.summarize import load_summarize_chain
from openai import BadRequestError
from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin
from crewai.agents.agent_builder.utilities.yes_no_cli_validator import YesNoValidator
from crewai.agents.tools_handler import ToolsHandler
from crewai.tools.tool_usage import ToolUsage, ToolUsageErrorException
from crewai.utilities import I18N
from crewai.utilities.constants import TRAINING_DATA_FILE
from crewai.utilities.training_handler import CrewTrainingHandler
from crewai.utilities.logger import Logger
class CrewAgentExecutor(AgentExecutor, CrewAgentExecutorMixin):
@@ -44,8 +50,8 @@ class CrewAgentExecutor(AgentExecutor, CrewAgentExecutorMixin):
system_template: Optional[str] = None
prompt_template: Optional[str] = None
response_template: Optional[str] = None
retry_summarize: bool = False
retry_summarize_count: int = 2
_logger: Logger = Logger()
_fit_context_window_strategy: Optional[Literal["summarize"]] = "summarize"
def _call(
self,
@@ -126,198 +132,187 @@ class CrewAgentExecutor(AgentExecutor, CrewAgentExecutorMixin):
Override this to take control of how the agent makes and acts on choices.
"""
for attempt in range(self.retry_summarize_count):
try:
if self._should_force_answer():
error = self._i18n.errors("force_final_answer")
output = AgentAction("_Exception", error, error)
self.have_forced_answer = True
yield AgentStep(action=output, observation=error)
return
intermediate_steps = self._prepare_intermediate_steps(
intermediate_steps
)
if self.retry_summarize:
encoding = tiktoken.encoding_for_model(self.llm.model_name)
original_token_count = len(
encoding.encode(intermediate_steps[0][1])
)
if original_token_count > 8000:
print(
"BEFORE AGENT PLAN TOKEN LENGTH",
original_token_count,
)
text = intermediate_steps[0][1]
text_splitter = RecursiveCharacterTextSplitter(
separators=["\n\n", "\n"],
chunk_size=8000,
chunk_overlap=500,
)
docs = text_splitter.create_documents([text])
print("DOCS", docs)
print("DOCS length", len(docs))
breakpoint()
# TODO: store to vector db - using memgpt like strategy
summary_chain = load_summarize_chain(
self.llm, chain_type="map_reduce", verbose=True
)
summary = summary_chain.run(docs)
print("SUMMARY:", summary)
intermediate_steps[0] = (intermediate_steps[0][0], summary)
# Call the LLM to see what to do.
output = self.agent.plan( # type: ignore # Incompatible types in assignment (expression has type "AgentAction | AgentFinish | list[AgentAction]", variable has type "AgentAction")
intermediate_steps,
callbacks=run_manager.get_child() if run_manager else None,
**inputs,
)
except OutputParserException as e:
if isinstance(self.handle_parsing_errors, bool):
raise_error = not self.handle_parsing_errors
else:
raise_error = False
if raise_error:
raise ValueError(
"An output parsing error occurred. "
"In order to pass this error back to the agent and have it try "
"again, pass `handle_parsing_errors=True` to the AgentExecutor. "
f"This is the error: {str(e)}"
)
str(e)
if isinstance(self.handle_parsing_errors, bool):
if e.send_to_llm:
observation = f"\n{str(e.observation)}"
str(e.llm_output)
else:
observation = ""
elif isinstance(self.handle_parsing_errors, str):
observation = f"\n{self.handle_parsing_errors}"
elif callable(self.handle_parsing_errors):
observation = f"\n{self.handle_parsing_errors(e)}"
else:
raise ValueError("Got unexpected type of `handle_parsing_errors`")
output = AgentAction("_Exception", observation, "")
if run_manager:
run_manager.on_agent_action(output, color="green")
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
observation = ExceptionTool().run(
output.tool_input,
verbose=False,
color=None,
callbacks=run_manager.get_child() if run_manager else None,
**tool_run_kwargs,
)
if self._should_force_answer():
error = self._i18n.errors("force_final_answer")
output = AgentAction("_Exception", error, error)
yield AgentStep(action=output, observation=error)
return
yield AgentStep(action=output, observation=observation)
try:
if self._should_force_answer():
error = self._i18n.errors("force_final_answer")
output = AgentAction("_Exception", error, error)
self.have_forced_answer = True
yield AgentStep(action=output, observation=error)
return
except BadRequestError as e:
print("bad request string str(e)", str(e))
if (
"context_length_exceeded" in str(e)
and attempt < self.retry_summarize_count - 1
):
print(
f"Context length exceeded. Retrying with summarization (attempt {attempt + 1})..."
)
self.retry_summarize = True
breakpoint()
continue
else:
print("Error now raising occurred in _iter_next_step:", e)
raise e
intermediate_steps = self._prepare_intermediate_steps(intermediate_steps)
except Exception as e:
print("Error occurred in _iter_next_step:", e)
# Call the LLM to see what to do.
output = self.agent.plan( # type: ignore # Incompatible types in assignment (expression has type "AgentAction | AgentFinish | list[AgentAction]", variable has type "AgentAction")
intermediate_steps,
callbacks=run_manager.get_child() if run_manager else None,
**inputs,
)
except OutputParserException as e:
if isinstance(self.handle_parsing_errors, bool):
raise_error = not self.handle_parsing_errors
else:
raise_error = False
if raise_error:
raise ValueError(
"An output parsing error occurred. "
"In order to pass this error back to the agent and have it try "
"again, pass `handle_parsing_errors=True` to the AgentExecutor. "
f"This is the error: {str(e)}"
)
str(e)
if isinstance(self.handle_parsing_errors, bool):
if e.send_to_llm:
observation = f"\n{str(e.observation)}"
str(e.llm_output)
else:
observation = ""
elif isinstance(self.handle_parsing_errors, str):
observation = f"\n{self.handle_parsing_errors}"
elif callable(self.handle_parsing_errors):
observation = f"\n{self.handle_parsing_errors(e)}"
else:
raise ValueError("Got unexpected type of `handle_parsing_errors`")
output = AgentAction("_Exception", observation, "")
if run_manager:
run_manager.on_agent_action(output, color="green")
tool_run_kwargs = self.agent.tool_run_logging_kwargs()
observation = ExceptionTool().run(
output.tool_input,
verbose=False,
color=None,
callbacks=run_manager.get_child() if run_manager else None,
**tool_run_kwargs,
)
if self._should_force_answer():
error = self._i18n.errors("force_final_answer")
output = AgentAction("_Exception", error, error)
yield AgentStep(action=output, observation=error)
return
yield AgentStep(action=output, observation=observation)
return
except BadRequestError as e:
print("Bad Request Error", e)
if "context_length_exceeded" in str(e):
self._logger.log(
"debug",
"Context length exceeded. Asking user if they want to use summarize prompt to fit, this will reduce context length.",
color="yellow",
)
user_choice = prompt(
"Context length exceeded. Do you want to summarize the text to fit models context window? (Y/N): ",
validator=YesNoValidator(),
).lower()
if user_choice in ["y", "yes"]:
self._logger.log(
"debug",
"Context length exceeded. Using summarize prompt to fit, this will reduce context length.",
color="bold_blue",
)
intermediate_steps = self._handle_context_length(intermediate_steps)
output = self.agent.plan( # type: ignore # Incompatible types in assignment (expression has type "AgentAction | AgentFinish | list[AgentAction]", variable has type "AgentAction")
intermediate_steps,
callbacks=run_manager.get_child() if run_manager else None,
**inputs,
)
if isinstance(output, AgentFinish):
yield output
else:
yield AgentStep(action=output, observation=None)
return
else:
self._logger.log(
"info",
"User chose not to use summarization. Try again with smaller text or use our various RAG tools from crewai_tools",
color="red",
)
raise e
else:
raise e
# If the tool chosen is the finishing tool, then we end and return.
if isinstance(output, AgentFinish):
if self.should_ask_for_human_input:
human_feedback = self._ask_human_input(
output.return_values["output"]
)
except Exception as e:
yield AgentStep(
action=AgentAction("_Exception", str(e), str(e)),
observation=str(e),
)
return
if self.crew and self.crew._train:
self._handle_crew_training_output(output, human_feedback)
# If the tool chosen is the finishing tool, then we end and return.
if isinstance(output, AgentFinish):
if self.should_ask_for_human_input:
human_feedback = self._ask_human_input(output.return_values["output"])
# Making sure we only ask for it once, so disabling for the next thought loop
self.should_ask_for_human_input = False
action = AgentAction(
tool="Human Input", tool_input=human_feedback, log=output.log
)
if self.crew and self.crew._train:
self._handle_crew_training_output(output, human_feedback)
yield AgentStep(
action=action,
observation=self._i18n.slice("human_feedback").format(
human_feedback=human_feedback
),
)
return
else:
if self.crew and self.crew._train:
self._handle_crew_training_output(output)
yield output
return
self._create_short_term_memory(output)
actions: List[AgentAction]
actions = [output] if isinstance(output, AgentAction) else output
yield from actions
for agent_action in actions:
if run_manager:
run_manager.on_agent_action(agent_action, color="green")
tool_usage = ToolUsage(
tools_handler=self.tools_handler, # type: ignore # Argument "tools_handler" to "ToolUsage" has incompatible type "ToolsHandler | None"; expected "ToolsHandler"
tools=self.tools, # type: ignore # Argument "tools" to "ToolUsage" has incompatible type "Sequence[BaseTool]"; expected "list[BaseTool]"
original_tools=self.original_tools,
tools_description=self.tools_description,
tools_names=self.tools_names,
function_calling_llm=self.function_calling_llm,
task=self.task,
agent=self.crew_agent,
action=agent_action,
# Making sure we only ask for it once, so disabling for the next thought loop
self.should_ask_for_human_input = False
action = AgentAction(
tool="Human Input", tool_input=human_feedback, log=output.log
)
# print("tool_usage", tool_usage)
tool_calling = tool_usage.parse(agent_action.log)
# print("tool_calling", tool_calling)
yield AgentStep(
action=action,
observation=self._i18n.slice("human_feedback").format(
human_feedback=human_feedback
),
)
return
if isinstance(tool_calling, ToolUsageErrorException):
observation = tool_calling.message
else:
if self.crew and self.crew._train:
self._handle_crew_training_output(output)
yield output
return
self._create_short_term_memory(output)
actions: List[AgentAction]
actions = [output] if isinstance(output, AgentAction) else output
yield from actions
for agent_action in actions:
if run_manager:
run_manager.on_agent_action(agent_action, color="green")
tool_usage = ToolUsage(
tools_handler=self.tools_handler, # type: ignore # Argument "tools_handler" to "ToolUsage" has incompatible type "ToolsHandler | None"; expected "ToolsHandler"
tools=self.tools, # type: ignore # Argument "tools" to "ToolUsage" has incompatible type "Sequence[BaseTool]"; expected "list[BaseTool]"
original_tools=self.original_tools,
tools_description=self.tools_description,
tools_names=self.tools_names,
function_calling_llm=self.function_calling_llm,
task=self.task,
agent=self.crew_agent,
action=agent_action,
)
tool_calling = tool_usage.parse(agent_action.log)
if isinstance(tool_calling, ToolUsageErrorException):
observation = tool_calling.message
else:
if tool_calling.tool_name.casefold().strip() in [
name.casefold().strip() for name in name_to_tool_map
] or tool_calling.tool_name.casefold().replace("_", " ") in [
name.casefold().strip() for name in name_to_tool_map
]:
observation = tool_usage.use(tool_calling, agent_action.log)
else:
if tool_calling.tool_name.casefold().strip() in [
name.casefold().strip() for name in name_to_tool_map
] or tool_calling.tool_name.casefold().replace("_", " ") in [
name.casefold().strip() for name in name_to_tool_map
]:
observation = tool_usage.use(tool_calling, agent_action.log)
else:
observation = self._i18n.errors("wrong_tool_name").format(
tool=tool_calling.tool_name,
tools=", ".join(
[tool.name.casefold() for tool in self.tools]
),
)
yield AgentStep(action=agent_action, observation=observation)
observation = self._i18n.errors("wrong_tool_name").format(
tool=tool_calling.tool_name,
tools=", ".join([tool.name.casefold() for tool in self.tools]),
)
yield AgentStep(action=agent_action, observation=observation)
def _handle_crew_training_output(
self, output: AgentFinish, human_feedback: str | None = None
@@ -346,3 +341,44 @@ class CrewAgentExecutor(AgentExecutor, CrewAgentExecutorMixin):
CrewTrainingHandler(TRAINING_DATA_FILE).append(
self.crew._train_iteration, agent_id, training_data
)
def _handle_context_length(
self, intermediate_steps: List[Tuple[AgentAction, str]]
) -> List[Tuple[AgentAction, str]]:
text = intermediate_steps[0][1]
original_action = intermediate_steps[0][0]
text_splitter = RecursiveCharacterTextSplitter(
separators=["\n\n", "\n"],
chunk_size=8000,
chunk_overlap=500,
)
if self._fit_context_window_strategy == "summarize":
docs = text_splitter.create_documents([text])
self._logger.log(
"debug",
"Summarizing Content, it is recommended to use a RAG tool",
color="bold_blue",
)
summarize_chain = load_summarize_chain(
self.llm, chain_type="map_reduce", verbose=True
)
summarized_docs = []
for doc in docs:
summary = summarize_chain.run([doc])
summarized_docs.append(summary)
formatted_results = "\n\n".join(summarized_docs)
summary_step = AgentStep(
action=AgentAction(
tool=original_action.tool,
tool_input=original_action.tool_input,
log=original_action.log,
),
observation=formatted_results,
)
summary_tuple = (summary_step.action, summary_step.observation)
return [summary_tuple]
return intermediate_steps