mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 08:38:30 +00:00
Feat/sliding context window (#1042)
* patching for non-gpt model * removal of json_object tool name assignment * fixed issue for smaller models due to instructions prompt * fixing for ollama llama3 models * WIP: generated summary from documents split, could also create memgpt approach * WIP: need tests but user inputted summarization strategy implemented - handling context window exceeding errors * rm extra line * removed type ignores * added tests * handling n to summarize prompt * code cleanup, using click for cli asker * rm not used class * better refactor * reverted poetry lock * reverted poetry.locl * improved context window exceeding exception class
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
import threading
|
||||
import time
|
||||
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Union
|
||||
import click
|
||||
|
||||
|
||||
from langchain.agents import AgentExecutor
|
||||
from langchain.agents.agent import ExceptionTool
|
||||
@@ -11,12 +13,21 @@ from langchain_core.tools import BaseTool
|
||||
from langchain_core.utils.input import get_color_mapping
|
||||
from pydantic import InstanceOf
|
||||
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
from langchain.chains.summarize import load_summarize_chain
|
||||
|
||||
from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin
|
||||
from crewai.agents.tools_handler import ToolsHandler
|
||||
|
||||
|
||||
from crewai.tools.tool_usage import ToolUsage, ToolUsageErrorException
|
||||
from crewai.utilities import I18N
|
||||
from crewai.utilities.constants import TRAINING_DATA_FILE
|
||||
from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
LLMContextLengthExceededException,
|
||||
)
|
||||
from crewai.utilities.training_handler import CrewTrainingHandler
|
||||
from crewai.utilities.logger import Logger
|
||||
|
||||
|
||||
class CrewAgentExecutor(AgentExecutor, CrewAgentExecutorMixin):
|
||||
@@ -40,6 +51,8 @@ class CrewAgentExecutor(AgentExecutor, CrewAgentExecutorMixin):
|
||||
system_template: Optional[str] = None
|
||||
prompt_template: Optional[str] = None
|
||||
response_template: Optional[str] = None
|
||||
_logger: Logger = Logger(verbose_level=2)
|
||||
_fit_context_window_strategy: Optional[Literal["summarize"]] = "summarize"
|
||||
|
||||
def _call(
|
||||
self,
|
||||
@@ -131,7 +144,7 @@ class CrewAgentExecutor(AgentExecutor, CrewAgentExecutorMixin):
|
||||
intermediate_steps = self._prepare_intermediate_steps(intermediate_steps)
|
||||
|
||||
# Call the LLM to see what to do.
|
||||
output = self.agent.plan( # type: ignore # Incompatible types in assignment (expression has type "AgentAction | AgentFinish | list[AgentAction]", variable has type "AgentAction")
|
||||
output = self.agent.plan(
|
||||
intermediate_steps,
|
||||
callbacks=run_manager.get_child() if run_manager else None,
|
||||
**inputs,
|
||||
@@ -185,6 +198,27 @@ class CrewAgentExecutor(AgentExecutor, CrewAgentExecutorMixin):
|
||||
yield AgentStep(action=output, observation=observation)
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
if LLMContextLengthExceededException(str(e))._is_context_limit_error(
|
||||
str(e)
|
||||
):
|
||||
output = self._handle_context_length_error(
|
||||
intermediate_steps, run_manager, inputs
|
||||
)
|
||||
|
||||
if isinstance(output, AgentFinish):
|
||||
yield output
|
||||
elif isinstance(output, list):
|
||||
for step in output:
|
||||
yield step
|
||||
return
|
||||
|
||||
yield AgentStep(
|
||||
action=AgentAction("_Exception", str(e), str(e)),
|
||||
observation=str(e),
|
||||
)
|
||||
return
|
||||
|
||||
# If the tool chosen is the finishing tool, then we end and return.
|
||||
if isinstance(output, AgentFinish):
|
||||
if self.should_ask_for_human_input:
|
||||
@@ -235,6 +269,7 @@ class CrewAgentExecutor(AgentExecutor, CrewAgentExecutorMixin):
|
||||
agent=self.crew_agent,
|
||||
action=agent_action,
|
||||
)
|
||||
|
||||
tool_calling = tool_usage.parse(agent_action.log)
|
||||
|
||||
if isinstance(tool_calling, ToolUsageErrorException):
|
||||
@@ -280,3 +315,91 @@ class CrewAgentExecutor(AgentExecutor, CrewAgentExecutorMixin):
|
||||
CrewTrainingHandler(TRAINING_DATA_FILE).append(
|
||||
self.crew._train_iteration, agent_id, training_data
|
||||
)
|
||||
|
||||
def _handle_context_length(
|
||||
self, intermediate_steps: List[Tuple[AgentAction, str]]
|
||||
) -> List[Tuple[AgentAction, str]]:
|
||||
text = intermediate_steps[0][1]
|
||||
original_action = intermediate_steps[0][0]
|
||||
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
separators=["\n\n", "\n"],
|
||||
chunk_size=8000,
|
||||
chunk_overlap=500,
|
||||
)
|
||||
|
||||
if self._fit_context_window_strategy == "summarize":
|
||||
docs = text_splitter.create_documents([text])
|
||||
self._logger.log(
|
||||
"debug",
|
||||
"Summarizing Content, it is recommended to use a RAG tool",
|
||||
color="bold_blue",
|
||||
)
|
||||
summarize_chain = load_summarize_chain(
|
||||
self.llm, chain_type="map_reduce", verbose=True
|
||||
)
|
||||
summarized_docs = []
|
||||
for doc in docs:
|
||||
summary = summarize_chain.invoke(
|
||||
{"input_documents": [doc]}, return_only_outputs=True
|
||||
)
|
||||
|
||||
summarized_docs.append(summary["output_text"])
|
||||
|
||||
formatted_results = "\n\n".join(summarized_docs)
|
||||
summary_step = AgentStep(
|
||||
action=AgentAction(
|
||||
tool=original_action.tool,
|
||||
tool_input=original_action.tool_input,
|
||||
log=original_action.log,
|
||||
),
|
||||
observation=formatted_results,
|
||||
)
|
||||
summary_tuple = (summary_step.action, summary_step.observation)
|
||||
return [summary_tuple]
|
||||
|
||||
return intermediate_steps
|
||||
|
||||
def _handle_context_length_error(
|
||||
self,
|
||||
intermediate_steps: List[Tuple[AgentAction, str]],
|
||||
run_manager: Optional[CallbackManagerForChainRun],
|
||||
inputs: Dict[str, str],
|
||||
) -> Union[AgentFinish, List[AgentStep]]:
|
||||
self._logger.log(
|
||||
"debug",
|
||||
"Context length exceeded. Asking user if they want to use summarize prompt to fit, this will reduce context length.",
|
||||
color="yellow",
|
||||
)
|
||||
user_choice = click.confirm(
|
||||
"Context length exceeded. Do you want to summarize the text to fit models context window?"
|
||||
)
|
||||
if user_choice:
|
||||
self._logger.log(
|
||||
"debug",
|
||||
"Context length exceeded. Using summarize prompt to fit, this will reduce context length.",
|
||||
color="bold_blue",
|
||||
)
|
||||
intermediate_steps = self._handle_context_length(intermediate_steps)
|
||||
|
||||
output = self.agent.plan(
|
||||
intermediate_steps,
|
||||
callbacks=run_manager.get_child() if run_manager else None,
|
||||
**inputs,
|
||||
)
|
||||
|
||||
if isinstance(output, AgentFinish):
|
||||
return output
|
||||
elif isinstance(output, AgentAction):
|
||||
return [AgentStep(action=output, observation=None)]
|
||||
else:
|
||||
return [AgentStep(action=action, observation=None) for action in output]
|
||||
else:
|
||||
self._logger.log(
|
||||
"debug",
|
||||
"Context length exceeded. Consider using smaller text or RAG tools from crewai_tools.",
|
||||
color="red",
|
||||
)
|
||||
raise SystemExit(
|
||||
"Context length exceeded and user opted not to summarize. Consider using smaller text or RAG tools from crewai_tools."
|
||||
)
|
||||
|
||||
@@ -16,7 +16,7 @@ try:
|
||||
except ImportError:
|
||||
agentops = None
|
||||
|
||||
OPENAI_BIGGER_MODELS = ["gpt-4"]
|
||||
OPENAI_BIGGER_MODELS = ["gpt-4o"]
|
||||
|
||||
|
||||
class ToolUsageErrorException(Exception):
|
||||
|
||||
@@ -7,6 +7,9 @@ from .parser import YamlParser
|
||||
from .printer import Printer
|
||||
from .prompts import Prompts
|
||||
from .rpm_controller import RPMController
|
||||
from .exceptions.context_window_exceeding_exception import (
|
||||
LLMContextLengthExceededException,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Converter",
|
||||
@@ -19,4 +22,5 @@ __all__ = [
|
||||
"Prompts",
|
||||
"RPMController",
|
||||
"YamlParser",
|
||||
"LLMContextLengthExceededException",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
class LLMContextLengthExceededException(Exception):
|
||||
CONTEXT_LIMIT_ERRORS = [
|
||||
"maximum context length",
|
||||
"context length exceeded",
|
||||
"context_length_exceeded",
|
||||
"context window full",
|
||||
"too many tokens",
|
||||
"input is too long",
|
||||
"exceeds token limit",
|
||||
]
|
||||
|
||||
def __init__(self, error_message: str):
|
||||
self.original_error_message = error_message
|
||||
super().__init__(self._get_error_message(error_message))
|
||||
|
||||
def _is_context_limit_error(self, error_message: str) -> bool:
|
||||
return any(
|
||||
phrase.lower() in error_message.lower()
|
||||
for phrase in self.CONTEXT_LIMIT_ERRORS
|
||||
)
|
||||
|
||||
def _get_error_message(self, error_message: str):
|
||||
return (
|
||||
f"LLM context length exceeded. Original error: {error_message}\n"
|
||||
"Consider using a smaller input or implementing a text splitting strategy."
|
||||
)
|
||||
Reference in New Issue
Block a user