WIP: user-inputted summarization strategy implemented (tests still needed) - handling context-window-exceeded errors

This commit is contained in:
Lorenze Jay
2024-07-31 16:29:24 -07:00
parent 62868c00db
commit 579689900a
2 changed files with 229 additions and 185 deletions

View File

@@ -0,0 +1,8 @@
+from prompt_toolkit.validation import Validator, ValidationError
+
+
+class YesNoValidator(Validator):
+    def validate(self, document):
+        text = document.text.lower()
+        if text not in ["y", "n", "yes", "no"]:
+            raise ValidationError(message="Please enter Y/N")
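For orientation, a minimal sketch of how this validator plugs into prompt_toolkit's prompt() call (the executor diff below wires it up the same way); the prompt text and the follow-up branch here are illustrative:

from prompt_toolkit import prompt

from crewai.agents.agent_builder.utilities.yes_no_cli_validator import YesNoValidator

# prompt() re-prompts until validate() returns without raising, so the
# caller only ever receives "y", "n", "yes", or "no" (any casing).
answer = prompt(
    "Summarize to fit the context window? (Y/N): ",
    validator=YesNoValidator(),
).lower()
if answer in ["y", "yes"]:
    print("user opted in to summarization")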

View File

@@ -1,6 +1,8 @@
 import threading
 import time
-from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Union
+
+from prompt_toolkit import prompt

 from langchain.agents import AgentExecutor
 from langchain.agents.agent import ExceptionTool
@@ -10,17 +12,21 @@ from langchain_core.exceptions import OutputParserException
 from langchain_core.tools import BaseTool
 from langchain_core.utils.input import get_color_mapping
 from pydantic import InstanceOf
-import tiktoken
-from openai import BadRequestError
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.chains.summarize import load_summarize_chain
+from openai import BadRequestError

 from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin
+from crewai.agents.agent_builder.utilities.yes_no_cli_validator import YesNoValidator
 from crewai.agents.tools_handler import ToolsHandler
 from crewai.tools.tool_usage import ToolUsage, ToolUsageErrorException
 from crewai.utilities import I18N
 from crewai.utilities.constants import TRAINING_DATA_FILE
 from crewai.utilities.training_handler import CrewTrainingHandler
+from crewai.utilities.logger import Logger


 class CrewAgentExecutor(AgentExecutor, CrewAgentExecutorMixin):
@@ -44,8 +50,8 @@ class CrewAgentExecutor(AgentExecutor, CrewAgentExecutorMixin):
     system_template: Optional[str] = None
     prompt_template: Optional[str] = None
     response_template: Optional[str] = None
-    retry_summarize: bool = False
-    retry_summarize_count: int = 2
+    _logger: Logger = Logger()
+    _fit_context_window_strategy: Optional[Literal["summarize"]] = "summarize"

     def _call(
         self,
@@ -126,7 +132,6 @@ class CrewAgentExecutor(AgentExecutor, CrewAgentExecutorMixin):
         Override this to take control of how the agent makes and acts on choices.
         """

-        for attempt in range(self.retry_summarize_count):
         try:
             if self._should_force_answer():
                 error = self._i18n.errors("force_final_answer")
@@ -135,39 +140,7 @@ class CrewAgentExecutor(AgentExecutor, CrewAgentExecutorMixin):
                 yield AgentStep(action=output, observation=error)
                 return

-            intermediate_steps = self._prepare_intermediate_steps(
-                intermediate_steps
-            )
-
-            if self.retry_summarize:
-                encoding = tiktoken.encoding_for_model(self.llm.model_name)
-                original_token_count = len(
-                    encoding.encode(intermediate_steps[0][1])
-                )
-                if original_token_count > 8000:
-                    print(
-                        "BEFORE AGENT PLAN TOKEN LENGTH",
-                        original_token_count,
-                    )
-                    text = intermediate_steps[0][1]
-                    text_splitter = RecursiveCharacterTextSplitter(
-                        separators=["\n\n", "\n"],
-                        chunk_size=8000,
-                        chunk_overlap=500,
-                    )
-                    docs = text_splitter.create_documents([text])
-                    print("DOCS", docs)
-                    print("DOCS length", len(docs))
-                    breakpoint()
-                    # TODO: store to vector db - using memgpt like strategy
-                    summary_chain = load_summarize_chain(
-                        self.llm, chain_type="map_reduce", verbose=True
-                    )
-                    summary = summary_chain.run(docs)
-                    print("SUMMARY:", summary)
-                    intermediate_steps[0] = (intermediate_steps[0][0], summary)
+            intermediate_steps = self._prepare_intermediate_steps(intermediate_steps)

             # Call the LLM to see what to do.
             output = self.agent.plan(  # type: ignore  # Incompatible types in assignment (expression has type "AgentAction | AgentFinish | list[AgentAction]", variable has type "AgentAction")
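The block removed above gated summarization on a raw token count. For reference, a minimal standalone sketch of that check; the "gpt-4" model name and the 8000-token budget are stand-ins mirroring the deleted code, not values fixed anywhere else in this diff:

import tiktoken

# Encode the observation text with the model's tokenizer and compare
# the token count against a fixed budget, as the removed
# retry_summarize branch did before deciding to summarize.
encoding = tiktoken.encoding_for_model("gpt-4")
observation = "some long tool output " * 2000  # stand-in for intermediate_steps[0][1]
token_count = len(encoding.encode(observation))
if token_count > 8000:
    print(f"{token_count} tokens: over budget, summarization would trigger")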
@@ -225,31 +198,57 @@ class CrewAgentExecutor(AgentExecutor, CrewAgentExecutorMixin):
                     return

         except BadRequestError as e:
-            print("bad request string str(e)", str(e))
-            if (
-                "context_length_exceeded" in str(e)
-                and attempt < self.retry_summarize_count - 1
-            ):
-                print(
-                    f"Context length exceeded. Retrying with summarization (attempt {attempt + 1})..."
-                )
-                self.retry_summarize = True
-                breakpoint()
-                continue
+            print("Bad Request Error", e)
+            if "context_length_exceeded" in str(e):
+                self._logger.log(
+                    "debug",
+                    "Context length exceeded. Asking user if they want to use summarize prompt to fit, this will reduce context length.",
+                    color="yellow",
+                )
+                user_choice = prompt(
+                    "Context length exceeded. Do you want to summarize the text to fit models context window? (Y/N): ",
+                    validator=YesNoValidator(),
+                ).lower()
+                if user_choice in ["y", "yes"]:
+                    self._logger.log(
+                        "debug",
+                        "Context length exceeded. Using summarize prompt to fit, this will reduce context length.",
+                        color="bold_blue",
+                    )
+                    intermediate_steps = self._handle_context_length(intermediate_steps)
+
+                    output = self.agent.plan(  # type: ignore  # Incompatible types in assignment (expression has type "AgentAction | AgentFinish | list[AgentAction]", variable has type "AgentAction")
+                        intermediate_steps,
+                        callbacks=run_manager.get_child() if run_manager else None,
+                        **inputs,
+                    )
+                    if isinstance(output, AgentFinish):
+                        yield output
+                    else:
+                        yield AgentStep(action=output, observation=None)
+                    return
+                else:
+                    self._logger.log(
+                        "info",
+                        "User chose not to use summarization. Try again with smaller text or use our various RAG tools from crewai_tools",
+                        color="red",
+                    )
+                    raise e
             else:
-                print("Error now raising occurred in _iter_next_step:", e)
                 raise e
         except Exception as e:
-            print("Error occurred in _iter_next_step:", e)
-            raise e
+            yield AgentStep(
+                action=AgentAction("_Exception", str(e), str(e)),
+                observation=str(e),
+            )
+            return

         # If the tool chosen is the finishing tool, then we end and return.
         if isinstance(output, AgentFinish):
             if self.should_ask_for_human_input:
-                human_feedback = self._ask_human_input(
-                    output.return_values["output"]
-                )
+                human_feedback = self._ask_human_input(output.return_values["output"])

                 if self.crew and self.crew._train:
                     self._handle_crew_training_output(output, human_feedback)
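A standalone sketch of the detection pattern this hunk relies on, assuming the OpenAI v1 SDK: when a prompt overflows the model's window, the API raises BadRequestError whose string form carries the context_length_exceeded code, so a substring check is enough to branch into the summarize-or-raise path. The model name and oversized message below are illustrative:

from openai import BadRequestError, OpenAI

client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment
try:
    client.chat.completions.create(
        model="gpt-4",
        messages=[{"role": "user", "content": "word " * 500_000}],  # deliberately too long
    )
except BadRequestError as e:
    # The error code surfaces in str(e), e.g. "... 'code': 'context_length_exceeded' ..."
    if "context_length_exceeded" in str(e):
        print("context overflow: ask the user whether to summarize")
    else:
        raise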
@@ -297,9 +296,7 @@ class CrewAgentExecutor(AgentExecutor, CrewAgentExecutorMixin):
                 action=agent_action,
             )

-            # print("tool_usage", tool_usage)
             tool_calling = tool_usage.parse(agent_action.log)
-            # print("tool_calling", tool_calling)

             if isinstance(tool_calling, ToolUsageErrorException):
                 observation = tool_calling.message
@@ -313,9 +310,7 @@ class CrewAgentExecutor(AgentExecutor, CrewAgentExecutorMixin):
             else:
                 observation = self._i18n.errors("wrong_tool_name").format(
                     tool=tool_calling.tool_name,
-                    tools=", ".join(
-                        [tool.name.casefold() for tool in self.tools]
-                    ),
+                    tools=", ".join([tool.name.casefold() for tool in self.tools]),
                 )
             yield AgentStep(action=agent_action, observation=observation)
@@ -346,3 +341,44 @@ class CrewAgentExecutor(AgentExecutor, CrewAgentExecutorMixin):
             CrewTrainingHandler(TRAINING_DATA_FILE).append(
                 self.crew._train_iteration, agent_id, training_data
             )
+
+    def _handle_context_length(
+        self, intermediate_steps: List[Tuple[AgentAction, str]]
+    ) -> List[Tuple[AgentAction, str]]:
+        text = intermediate_steps[0][1]
+        original_action = intermediate_steps[0][0]
+
+        text_splitter = RecursiveCharacterTextSplitter(
+            separators=["\n\n", "\n"],
+            chunk_size=8000,
+            chunk_overlap=500,
+        )
+
+        if self._fit_context_window_strategy == "summarize":
+            docs = text_splitter.create_documents([text])
+
+            self._logger.log(
+                "debug",
+                "Summarizing Content, it is recommended to use a RAG tool",
+                color="bold_blue",
+            )
+
+            summarize_chain = load_summarize_chain(
+                self.llm, chain_type="map_reduce", verbose=True
+            )
+
+            summarized_docs = []
+            for doc in docs:
+                summary = summarize_chain.run([doc])
+                summarized_docs.append(summary)
+
+            formatted_results = "\n\n".join(summarized_docs)
+            summary_step = AgentStep(
+                action=AgentAction(
+                    tool=original_action.tool,
+                    tool_input=original_action.tool_input,
+                    log=original_action.log,
+                ),
+                observation=formatted_results,
+            )
+            summary_tuple = (summary_step.action, summary_step.observation)
+            return [summary_tuple]
+
+        return intermediate_steps
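To test the summarize strategy in isolation, here is a minimal self-contained sketch of the same split-then-summarize flow, assuming ChatOpenAI as the LLM (any LangChain chat model should slot in). One variation from the method above: calling chain.run(docs) over the whole list is the typical map_reduce usage, whereas the method's per-document summarize_chain.run([doc]) loop reduces map_reduce to a plain per-chunk summary.

from langchain.chains.summarize import load_summarize_chain
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_openai import ChatOpenAI  # assumption; any LangChain chat model works

llm = ChatOpenAI(model="gpt-4")  # stand-in for self.llm
long_text = "paragraph of tool output\n\n" * 4000  # stand-in for intermediate_steps[0][1]

# Split on paragraph boundaries first, falling back to single newlines;
# chunk_overlap keeps 500 characters of shared context between chunks.
# Note the splitter counts characters, not tokens, so chunk_size=8000
# is only a rough proxy for the model's context window.
splitter = RecursiveCharacterTextSplitter(
    separators=["\n\n", "\n"], chunk_size=8000, chunk_overlap=500
)
docs = splitter.create_documents([long_text])

# map_reduce summarizes each chunk, then summarizes the summaries.
chain = load_summarize_chain(llm, chain_type="map_reduce")
summary = chain.run(docs)
print(summary)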