WIP: user-inputted summarization strategy implemented (tests still needed) - handling context-window-exceeded errors

Lorenze Jay
2024-07-31 16:29:24 -07:00
parent 62868c00db
commit 579689900a
2 changed files with 229 additions and 185 deletions

View File

@@ -0,0 +1,8 @@
+from prompt_toolkit.validation import Validator, ValidationError
+
+
+class YesNoValidator(Validator):
+    def validate(self, document):
+        text = document.text.lower()
+        if text not in ["y", "n", "yes", "no"]:
+            raise ValidationError(message="Please enter Y/N")

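For context, a minimal usage sketch of this validator (not part of the diff): prompt_toolkit's prompt() will not accept the input until the validator passes, which is how the executor below gathers a Y/N answer. The prompt text here is illustrative only.

# Illustrative usage of YesNoValidator with prompt_toolkit (assumed prompt text).
from prompt_toolkit import prompt

answer = prompt(
    "Summarize to fit the context window? (Y/N): ",
    validator=YesNoValidator(),  # re-prompts until y/n/yes/no is entered
).lower()
use_summarization = answer in ["y", "yes"]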
View File

@@ -1,6 +1,8 @@
 import threading
 import time
-from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Union
+
+from prompt_toolkit import prompt

 from langchain.agents import AgentExecutor
 from langchain.agents.agent import ExceptionTool
@@ -10,17 +12,21 @@ from langchain_core.exceptions import OutputParserException
 from langchain_core.tools import BaseTool
 from langchain_core.utils.input import get_color_mapping
 from pydantic import InstanceOf
-import tiktoken
-from openai import BadRequestError
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.chains.summarize import load_summarize_chain
+from openai import BadRequestError

 from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin
+from crewai.agents.agent_builder.utilities.yes_no_cli_validator import YesNoValidator
 from crewai.agents.tools_handler import ToolsHandler
 from crewai.tools.tool_usage import ToolUsage, ToolUsageErrorException
 from crewai.utilities import I18N
 from crewai.utilities.constants import TRAINING_DATA_FILE
 from crewai.utilities.training_handler import CrewTrainingHandler
+from crewai.utilities.logger import Logger


 class CrewAgentExecutor(AgentExecutor, CrewAgentExecutorMixin):
@@ -44,8 +50,8 @@ class CrewAgentExecutor(AgentExecutor, CrewAgentExecutorMixin):
     system_template: Optional[str] = None
     prompt_template: Optional[str] = None
     response_template: Optional[str] = None
-    retry_summarize: bool = False
-    retry_summarize_count: int = 2
+    _logger: Logger = Logger()
+    _fit_context_window_strategy: Optional[Literal["summarize"]] = "summarize"

     def _call(
         self,
@@ -126,198 +132,187 @@ class CrewAgentExecutor(AgentExecutor, CrewAgentExecutorMixin):
         Override this to take control of how the agent makes and acts on choices.
         """
-        for attempt in range(self.retry_summarize_count):
-            try:
+        try:
             if self._should_force_answer():
                 error = self._i18n.errors("force_final_answer")
                 output = AgentAction("_Exception", error, error)
                 self.have_forced_answer = True
                 yield AgentStep(action=output, observation=error)
                 return

-            intermediate_steps = self._prepare_intermediate_steps(
-                intermediate_steps
-            )
-
-            if self.retry_summarize:
-                encoding = tiktoken.encoding_for_model(self.llm.model_name)
-                original_token_count = len(
-                    encoding.encode(intermediate_steps[0][1])
-                )
-                if original_token_count > 8000:
-                    print(
-                        "BEFORE AGENT PLAN TOKEN LENGTH",
-                        original_token_count,
-                    )
-                    text = intermediate_steps[0][1]
-                    text_splitter = RecursiveCharacterTextSplitter(
-                        separators=["\n\n", "\n"],
-                        chunk_size=8000,
-                        chunk_overlap=500,
-                    )
-                    docs = text_splitter.create_documents([text])
-                    print("DOCS", docs)
-                    print("DOCS length", len(docs))
-                    breakpoint()
-                    # TODO: store to vector db - using memgpt like strategy
-                    summary_chain = load_summarize_chain(
-                        self.llm, chain_type="map_reduce", verbose=True
-                    )
-                    summary = summary_chain.run(docs)
-                    print("SUMMARY:", summary)
-                    intermediate_steps[0] = (intermediate_steps[0][0], summary)
+            intermediate_steps = self._prepare_intermediate_steps(intermediate_steps)

             # Call the LLM to see what to do.
             output = self.agent.plan(  # type: ignore #  Incompatible types in assignment (expression has type "AgentAction | AgentFinish | list[AgentAction]", variable has type "AgentAction")
                 intermediate_steps,
                 callbacks=run_manager.get_child() if run_manager else None,
                 **inputs,
             )
         except OutputParserException as e:
             if isinstance(self.handle_parsing_errors, bool):
                 raise_error = not self.handle_parsing_errors
             else:
                 raise_error = False
             if raise_error:
                 raise ValueError(
                     "An output parsing error occurred. "
                     "In order to pass this error back to the agent and have it try "
                     "again, pass `handle_parsing_errors=True` to the AgentExecutor. "
                     f"This is the error: {str(e)}"
                 )
             str(e)
             if isinstance(self.handle_parsing_errors, bool):
                 if e.send_to_llm:
                     observation = f"\n{str(e.observation)}"
                     str(e.llm_output)
                 else:
                     observation = ""
             elif isinstance(self.handle_parsing_errors, str):
                 observation = f"\n{self.handle_parsing_errors}"
             elif callable(self.handle_parsing_errors):
                 observation = f"\n{self.handle_parsing_errors(e)}"
             else:
                 raise ValueError("Got unexpected type of `handle_parsing_errors`")
             output = AgentAction("_Exception", observation, "")
             if run_manager:
                 run_manager.on_agent_action(output, color="green")
             tool_run_kwargs = self.agent.tool_run_logging_kwargs()
             observation = ExceptionTool().run(
                 output.tool_input,
                 verbose=False,
                 color=None,
                 callbacks=run_manager.get_child() if run_manager else None,
                 **tool_run_kwargs,
             )

             if self._should_force_answer():
                 error = self._i18n.errors("force_final_answer")
                 output = AgentAction("_Exception", error, error)
                 yield AgentStep(action=output, observation=error)
                 return

             yield AgentStep(action=output, observation=observation)
             return
         except BadRequestError as e:
-            print("bad request string str(e)", str(e))
-            if (
-                "context_length_exceeded" in str(e)
-                and attempt < self.retry_summarize_count - 1
-            ):
-                print(
-                    f"Context length exceeded. Retrying with summarization (attempt {attempt + 1})..."
-                )
-                self.retry_summarize = True
-                breakpoint()
-                continue
-            else:
-                print("Error now raising occurred in _iter_next_step:", e)
-                raise e
+            print("Bad Request Error", e)
+            if "context_length_exceeded" in str(e):
+                self._logger.log(
+                    "debug",
+                    "Context length exceeded. Asking user if they want to use summarize prompt to fit, this will reduce context length.",
+                    color="yellow",
+                )
+                user_choice = prompt(
+                    "Context length exceeded. Do you want to summarize the text to fit models context window? (Y/N): ",
+                    validator=YesNoValidator(),
+                ).lower()
+                if user_choice in ["y", "yes"]:
+                    self._logger.log(
+                        "debug",
+                        "Context length exceeded. Using summarize prompt to fit, this will reduce context length.",
+                        color="bold_blue",
+                    )
+                    intermediate_steps = self._handle_context_length(intermediate_steps)
+
+                    output = self.agent.plan(  # type: ignore #  Incompatible types in assignment (expression has type "AgentAction | AgentFinish | list[AgentAction]", variable has type "AgentAction")
+                        intermediate_steps,
+                        callbacks=run_manager.get_child() if run_manager else None,
+                        **inputs,
+                    )
+
+                    if isinstance(output, AgentFinish):
+                        yield output
+                    else:
+                        yield AgentStep(action=output, observation=None)
+                    return
+                else:
+                    self._logger.log(
+                        "info",
+                        "User chose not to use summarization. Try again with smaller text or use our various RAG tools from crewai_tools",
+                        color="red",
+                    )
+                    raise e
+            else:
+                raise e
         except Exception as e:
-            print("Error occurred in _iter_next_step:", e)
-            raise e
+            yield AgentStep(
+                action=AgentAction("_Exception", str(e), str(e)),
+                observation=str(e),
+            )
+            return

         # If the tool chosen is the finishing tool, then we end and return.
         if isinstance(output, AgentFinish):
             if self.should_ask_for_human_input:
-                human_feedback = self._ask_human_input(
-                    output.return_values["output"]
-                )
+                human_feedback = self._ask_human_input(output.return_values["output"])

                 if self.crew and self.crew._train:
                     self._handle_crew_training_output(output, human_feedback)

                 # Making sure we only ask for it once, so disabling for the next thought loop
                 self.should_ask_for_human_input = False
                 action = AgentAction(
                     tool="Human Input", tool_input=human_feedback, log=output.log
                 )

                 yield AgentStep(
                     action=action,
                     observation=self._i18n.slice("human_feedback").format(
                         human_feedback=human_feedback
                     ),
                 )
                 return

             else:
                 if self.crew and self.crew._train:
                     self._handle_crew_training_output(output)

                 yield output
                 return

         self._create_short_term_memory(output)

         actions: List[AgentAction]
         actions = [output] if isinstance(output, AgentAction) else output
         yield from actions
         for agent_action in actions:
             if run_manager:
                 run_manager.on_agent_action(agent_action, color="green")
             tool_usage = ToolUsage(
                 tools_handler=self.tools_handler,  # type: ignore # Argument "tools_handler" to "ToolUsage" has incompatible type "ToolsHandler | None"; expected "ToolsHandler"
                 tools=self.tools,  # type: ignore # Argument "tools" to "ToolUsage" has incompatible type "Sequence[BaseTool]"; expected "list[BaseTool]"
                 original_tools=self.original_tools,
                 tools_description=self.tools_description,
                 tools_names=self.tools_names,
                 function_calling_llm=self.function_calling_llm,
                 task=self.task,
                 agent=self.crew_agent,
                 action=agent_action,
             )
-            # print("tool_usage", tool_usage)
             tool_calling = tool_usage.parse(agent_action.log)
-            # print("tool_calling", tool_calling)

             if isinstance(tool_calling, ToolUsageErrorException):
                 observation = tool_calling.message
             else:
                 if tool_calling.tool_name.casefold().strip() in [
                     name.casefold().strip() for name in name_to_tool_map
                 ] or tool_calling.tool_name.casefold().replace("_", " ") in [
                     name.casefold().strip() for name in name_to_tool_map
                 ]:
                     observation = tool_usage.use(tool_calling, agent_action.log)
                 else:
                     observation = self._i18n.errors("wrong_tool_name").format(
                         tool=tool_calling.tool_name,
-                        tools=", ".join(
-                            [tool.name.casefold() for tool in self.tools]
-                        ),
+                        tools=", ".join([tool.name.casefold() for tool in self.tools]),
                     )
             yield AgentStep(action=agent_action, observation=observation)

     def _handle_crew_training_output(
         self, output: AgentFinish, human_feedback: str | None = None
@@ -346,3 +341,44 @@ class CrewAgentExecutor(AgentExecutor, CrewAgentExecutorMixin):
             CrewTrainingHandler(TRAINING_DATA_FILE).append(
                 self.crew._train_iteration, agent_id, training_data
             )
+
+    def _handle_context_length(
+        self, intermediate_steps: List[Tuple[AgentAction, str]]
+    ) -> List[Tuple[AgentAction, str]]:
+        text = intermediate_steps[0][1]
+        original_action = intermediate_steps[0][0]
+
+        text_splitter = RecursiveCharacterTextSplitter(
+            separators=["\n\n", "\n"],
+            chunk_size=8000,
+            chunk_overlap=500,
+        )
+
+        if self._fit_context_window_strategy == "summarize":
+            docs = text_splitter.create_documents([text])
+
+            self._logger.log(
+                "debug",
+                "Summarizing Content, it is recommended to use a RAG tool",
+                color="bold_blue",
+            )
+
+            summarize_chain = load_summarize_chain(
+                self.llm, chain_type="map_reduce", verbose=True
+            )
+            summarized_docs = []
+            for doc in docs:
+                summary = summarize_chain.run([doc])
+                summarized_docs.append(summary)
+
+            formatted_results = "\n\n".join(summarized_docs)
+
+            summary_step = AgentStep(
+                action=AgentAction(
+                    tool=original_action.tool,
+                    tool_input=original_action.tool_input,
+                    log=original_action.log,
+                ),
+                observation=formatted_results,
+            )
+            summary_tuple = (summary_step.action, summary_step.observation)
+            return [summary_tuple]
+
+        return intermediate_steps
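
Since the commit message flags that tests are still needed, one possible starting point is a small pytest module for the validator added above. This is only a sketch: the test file path and cases are hypothetical and not part of this commit; it assumes pytest and prompt_toolkit are installed.

# tests/agents/test_yes_no_validator.py (hypothetical path) - sketch of the missing tests.
import pytest
from prompt_toolkit.document import Document
from prompt_toolkit.validation import ValidationError

from crewai.agents.agent_builder.utilities.yes_no_cli_validator import YesNoValidator


@pytest.mark.parametrize("text", ["y", "Y", "n", "no", "YES"])
def test_accepts_yes_no_answers(text):
    # validate() should return without raising for accepted answers (case-insensitive)
    YesNoValidator().validate(Document(text=text))


@pytest.mark.parametrize("text", ["", "maybe", "yep", "0"])
def test_rejects_other_answers(text):
    # anything outside y/n/yes/no should raise ValidationError
    with pytest.raises(ValidationError):
        YesNoValidator().validate(Document(text=text))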