Feat/sliding context window (#1042)

* patching for non-gpt model

* removal of json_object tool name assignment

* fixed issue for smaller models due to instructions prompt

* fixing for ollama llama3 models

* WIP: generated summary from documents split, could also create memgpt approach

* WIP: need tests but user inputted summarization strategy implemented - handling context window exceeding errors

* rm extra line

* removed type ignores

* added tests

* handling n to summarize prompt

* code cleanup, using click for cli asker

* rm not used class

* better refactor

* reverted poetry lock

* reverted poetry.locl

* improved context window exceeding exception class
This commit is contained in:
Lorenze Jay
2024-08-01 13:15:50 -07:00
committed by GitHub
parent c93b85ac53
commit 8118b7b7d6
7 changed files with 572 additions and 3 deletions

View File

@@ -1,6 +1,8 @@
import threading
import time
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Union
import click
from langchain.agents import AgentExecutor
from langchain.agents.agent import ExceptionTool
@@ -11,12 +13,21 @@ from langchain_core.tools import BaseTool
from langchain_core.utils.input import get_color_mapping
from pydantic import InstanceOf
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains.summarize import load_summarize_chain
from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin
from crewai.agents.tools_handler import ToolsHandler
from crewai.tools.tool_usage import ToolUsage, ToolUsageErrorException
from crewai.utilities import I18N
from crewai.utilities.constants import TRAINING_DATA_FILE
from crewai.utilities.exceptions.context_window_exceeding_exception import (
LLMContextLengthExceededException,
)
from crewai.utilities.training_handler import CrewTrainingHandler
from crewai.utilities.logger import Logger
class CrewAgentExecutor(AgentExecutor, CrewAgentExecutorMixin):
@@ -40,6 +51,8 @@ class CrewAgentExecutor(AgentExecutor, CrewAgentExecutorMixin):
system_template: Optional[str] = None
prompt_template: Optional[str] = None
response_template: Optional[str] = None
_logger: Logger = Logger(verbose_level=2)
_fit_context_window_strategy: Optional[Literal["summarize"]] = "summarize"
def _call(
self,
@@ -131,7 +144,7 @@ class CrewAgentExecutor(AgentExecutor, CrewAgentExecutorMixin):
intermediate_steps = self._prepare_intermediate_steps(intermediate_steps)
# Call the LLM to see what to do.
output = self.agent.plan( # type: ignore # Incompatible types in assignment (expression has type "AgentAction | AgentFinish | list[AgentAction]", variable has type "AgentAction")
output = self.agent.plan(
intermediate_steps,
callbacks=run_manager.get_child() if run_manager else None,
**inputs,
@@ -185,6 +198,27 @@ class CrewAgentExecutor(AgentExecutor, CrewAgentExecutorMixin):
yield AgentStep(action=output, observation=observation)
return
except Exception as e:
if LLMContextLengthExceededException(str(e))._is_context_limit_error(
str(e)
):
output = self._handle_context_length_error(
intermediate_steps, run_manager, inputs
)
if isinstance(output, AgentFinish):
yield output
elif isinstance(output, list):
for step in output:
yield step
return
yield AgentStep(
action=AgentAction("_Exception", str(e), str(e)),
observation=str(e),
)
return
# If the tool chosen is the finishing tool, then we end and return.
if isinstance(output, AgentFinish):
if self.should_ask_for_human_input:
@@ -235,6 +269,7 @@ class CrewAgentExecutor(AgentExecutor, CrewAgentExecutorMixin):
agent=self.crew_agent,
action=agent_action,
)
tool_calling = tool_usage.parse(agent_action.log)
if isinstance(tool_calling, ToolUsageErrorException):
@@ -280,3 +315,91 @@ class CrewAgentExecutor(AgentExecutor, CrewAgentExecutorMixin):
CrewTrainingHandler(TRAINING_DATA_FILE).append(
self.crew._train_iteration, agent_id, training_data
)
def _handle_context_length(
self, intermediate_steps: List[Tuple[AgentAction, str]]
) -> List[Tuple[AgentAction, str]]:
text = intermediate_steps[0][1]
original_action = intermediate_steps[0][0]
text_splitter = RecursiveCharacterTextSplitter(
separators=["\n\n", "\n"],
chunk_size=8000,
chunk_overlap=500,
)
if self._fit_context_window_strategy == "summarize":
docs = text_splitter.create_documents([text])
self._logger.log(
"debug",
"Summarizing Content, it is recommended to use a RAG tool",
color="bold_blue",
)
summarize_chain = load_summarize_chain(
self.llm, chain_type="map_reduce", verbose=True
)
summarized_docs = []
for doc in docs:
summary = summarize_chain.invoke(
{"input_documents": [doc]}, return_only_outputs=True
)
summarized_docs.append(summary["output_text"])
formatted_results = "\n\n".join(summarized_docs)
summary_step = AgentStep(
action=AgentAction(
tool=original_action.tool,
tool_input=original_action.tool_input,
log=original_action.log,
),
observation=formatted_results,
)
summary_tuple = (summary_step.action, summary_step.observation)
return [summary_tuple]
return intermediate_steps
def _handle_context_length_error(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
run_manager: Optional[CallbackManagerForChainRun],
inputs: Dict[str, str],
) -> Union[AgentFinish, List[AgentStep]]:
self._logger.log(
"debug",
"Context length exceeded. Asking user if they want to use summarize prompt to fit, this will reduce context length.",
color="yellow",
)
user_choice = click.confirm(
"Context length exceeded. Do you want to summarize the text to fit models context window?"
)
if user_choice:
self._logger.log(
"debug",
"Context length exceeded. Using summarize prompt to fit, this will reduce context length.",
color="bold_blue",
)
intermediate_steps = self._handle_context_length(intermediate_steps)
output = self.agent.plan(
intermediate_steps,
callbacks=run_manager.get_child() if run_manager else None,
**inputs,
)
if isinstance(output, AgentFinish):
return output
elif isinstance(output, AgentAction):
return [AgentStep(action=output, observation=None)]
else:
return [AgentStep(action=action, observation=None) for action in output]
else:
self._logger.log(
"debug",
"Context length exceeded. Consider using smaller text or RAG tools from crewai_tools.",
color="red",
)
raise SystemExit(
"Context length exceeded and user opted not to summarize. Consider using smaller text or RAG tools from crewai_tools."
)