mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-29 18:18:13 +00:00
WIP: need tests but user inputted summarization strategy implemented - handling context window exceeding errors
This commit is contained in:
@@ -0,0 +1,8 @@
|
|||||||
|
from prompt_toolkit.validation import Validator, ValidationError
|
||||||
|
|
||||||
|
|
||||||
|
class YesNoValidator(Validator):
|
||||||
|
def validate(self, document):
|
||||||
|
text = document.text.lower()
|
||||||
|
if text not in ["y", "n", "yes", "no"]:
|
||||||
|
raise ValidationError(message="Please enter Y/N")
|
||||||
@@ -1,6 +1,8 @@
|
|||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
|
from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Union
|
||||||
|
from prompt_toolkit import prompt
|
||||||
|
|
||||||
|
|
||||||
from langchain.agents import AgentExecutor
|
from langchain.agents import AgentExecutor
|
||||||
from langchain.agents.agent import ExceptionTool
|
from langchain.agents.agent import ExceptionTool
|
||||||
@@ -10,17 +12,21 @@ from langchain_core.exceptions import OutputParserException
|
|||||||
from langchain_core.tools import BaseTool
|
from langchain_core.tools import BaseTool
|
||||||
from langchain_core.utils.input import get_color_mapping
|
from langchain_core.utils.input import get_color_mapping
|
||||||
from pydantic import InstanceOf
|
from pydantic import InstanceOf
|
||||||
import tiktoken
|
|
||||||
|
from openai import BadRequestError
|
||||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||||
from langchain.chains.summarize import load_summarize_chain
|
from langchain.chains.summarize import load_summarize_chain
|
||||||
from openai import BadRequestError
|
|
||||||
|
|
||||||
from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin
|
from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin
|
||||||
|
from crewai.agents.agent_builder.utilities.yes_no_cli_validator import YesNoValidator
|
||||||
from crewai.agents.tools_handler import ToolsHandler
|
from crewai.agents.tools_handler import ToolsHandler
|
||||||
|
|
||||||
|
|
||||||
from crewai.tools.tool_usage import ToolUsage, ToolUsageErrorException
|
from crewai.tools.tool_usage import ToolUsage, ToolUsageErrorException
|
||||||
from crewai.utilities import I18N
|
from crewai.utilities import I18N
|
||||||
from crewai.utilities.constants import TRAINING_DATA_FILE
|
from crewai.utilities.constants import TRAINING_DATA_FILE
|
||||||
from crewai.utilities.training_handler import CrewTrainingHandler
|
from crewai.utilities.training_handler import CrewTrainingHandler
|
||||||
|
from crewai.utilities.logger import Logger
|
||||||
|
|
||||||
|
|
||||||
class CrewAgentExecutor(AgentExecutor, CrewAgentExecutorMixin):
|
class CrewAgentExecutor(AgentExecutor, CrewAgentExecutorMixin):
|
||||||
@@ -44,8 +50,8 @@ class CrewAgentExecutor(AgentExecutor, CrewAgentExecutorMixin):
|
|||||||
system_template: Optional[str] = None
|
system_template: Optional[str] = None
|
||||||
prompt_template: Optional[str] = None
|
prompt_template: Optional[str] = None
|
||||||
response_template: Optional[str] = None
|
response_template: Optional[str] = None
|
||||||
retry_summarize: bool = False
|
_logger: Logger = Logger()
|
||||||
retry_summarize_count: int = 2
|
_fit_context_window_strategy: Optional[Literal["summarize"]] = "summarize"
|
||||||
|
|
||||||
def _call(
|
def _call(
|
||||||
self,
|
self,
|
||||||
@@ -126,7 +132,6 @@ class CrewAgentExecutor(AgentExecutor, CrewAgentExecutorMixin):
|
|||||||
|
|
||||||
Override this to take control of how the agent makes and acts on choices.
|
Override this to take control of how the agent makes and acts on choices.
|
||||||
"""
|
"""
|
||||||
for attempt in range(self.retry_summarize_count):
|
|
||||||
try:
|
try:
|
||||||
if self._should_force_answer():
|
if self._should_force_answer():
|
||||||
error = self._i18n.errors("force_final_answer")
|
error = self._i18n.errors("force_final_answer")
|
||||||
@@ -135,39 +140,7 @@ class CrewAgentExecutor(AgentExecutor, CrewAgentExecutorMixin):
|
|||||||
yield AgentStep(action=output, observation=error)
|
yield AgentStep(action=output, observation=error)
|
||||||
return
|
return
|
||||||
|
|
||||||
intermediate_steps = self._prepare_intermediate_steps(
|
intermediate_steps = self._prepare_intermediate_steps(intermediate_steps)
|
||||||
intermediate_steps
|
|
||||||
)
|
|
||||||
if self.retry_summarize:
|
|
||||||
encoding = tiktoken.encoding_for_model(self.llm.model_name)
|
|
||||||
original_token_count = len(
|
|
||||||
encoding.encode(intermediate_steps[0][1])
|
|
||||||
)
|
|
||||||
if original_token_count > 8000:
|
|
||||||
print(
|
|
||||||
"BEFORE AGENT PLAN TOKEN LENGTH",
|
|
||||||
original_token_count,
|
|
||||||
)
|
|
||||||
text = intermediate_steps[0][1]
|
|
||||||
|
|
||||||
text_splitter = RecursiveCharacterTextSplitter(
|
|
||||||
separators=["\n\n", "\n"],
|
|
||||||
chunk_size=8000,
|
|
||||||
chunk_overlap=500,
|
|
||||||
)
|
|
||||||
docs = text_splitter.create_documents([text])
|
|
||||||
print("DOCS", docs)
|
|
||||||
print("DOCS length", len(docs))
|
|
||||||
breakpoint()
|
|
||||||
# TODO: store to vector db - using memgpt like strategy
|
|
||||||
summary_chain = load_summarize_chain(
|
|
||||||
self.llm, chain_type="map_reduce", verbose=True
|
|
||||||
)
|
|
||||||
summary = summary_chain.run(docs)
|
|
||||||
|
|
||||||
print("SUMMARY:", summary)
|
|
||||||
|
|
||||||
intermediate_steps[0] = (intermediate_steps[0][0], summary)
|
|
||||||
|
|
||||||
# Call the LLM to see what to do.
|
# Call the LLM to see what to do.
|
||||||
output = self.agent.plan( # type: ignore # Incompatible types in assignment (expression has type "AgentAction | AgentFinish | list[AgentAction]", variable has type "AgentAction")
|
output = self.agent.plan( # type: ignore # Incompatible types in assignment (expression has type "AgentAction | AgentFinish | list[AgentAction]", variable has type "AgentAction")
|
||||||
@@ -225,31 +198,57 @@ class CrewAgentExecutor(AgentExecutor, CrewAgentExecutorMixin):
|
|||||||
return
|
return
|
||||||
|
|
||||||
except BadRequestError as e:
|
except BadRequestError as e:
|
||||||
print("bad request string str(e)", str(e))
|
print("Bad Request Error", e)
|
||||||
if (
|
if "context_length_exceeded" in str(e):
|
||||||
"context_length_exceeded" in str(e)
|
self._logger.log(
|
||||||
and attempt < self.retry_summarize_count - 1
|
"debug",
|
||||||
):
|
"Context length exceeded. Asking user if they want to use summarize prompt to fit, this will reduce context length.",
|
||||||
print(
|
color="yellow",
|
||||||
f"Context length exceeded. Retrying with summarization (attempt {attempt + 1})..."
|
|
||||||
)
|
)
|
||||||
self.retry_summarize = True
|
user_choice = prompt(
|
||||||
breakpoint()
|
"Context length exceeded. Do you want to summarize the text to fit models context window? (Y/N): ",
|
||||||
continue
|
validator=YesNoValidator(),
|
||||||
|
).lower()
|
||||||
|
if user_choice in ["y", "yes"]:
|
||||||
|
self._logger.log(
|
||||||
|
"debug",
|
||||||
|
"Context length exceeded. Using summarize prompt to fit, this will reduce context length.",
|
||||||
|
color="bold_blue",
|
||||||
|
)
|
||||||
|
intermediate_steps = self._handle_context_length(intermediate_steps)
|
||||||
|
|
||||||
|
output = self.agent.plan( # type: ignore # Incompatible types in assignment (expression has type "AgentAction | AgentFinish | list[AgentAction]", variable has type "AgentAction")
|
||||||
|
intermediate_steps,
|
||||||
|
callbacks=run_manager.get_child() if run_manager else None,
|
||||||
|
**inputs,
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(output, AgentFinish):
|
||||||
|
yield output
|
||||||
|
else:
|
||||||
|
yield AgentStep(action=output, observation=None)
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
self._logger.log(
|
||||||
|
"info",
|
||||||
|
"User chose not to use summarization. Try again with smaller text or use our various RAG tools from crewai_tools",
|
||||||
|
color="red",
|
||||||
|
)
|
||||||
|
raise e
|
||||||
else:
|
else:
|
||||||
print("Error now raising occurred in _iter_next_step:", e)
|
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("Error occurred in _iter_next_step:", e)
|
yield AgentStep(
|
||||||
raise e
|
action=AgentAction("_Exception", str(e), str(e)),
|
||||||
|
observation=str(e),
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
# If the tool chosen is the finishing tool, then we end and return.
|
# If the tool chosen is the finishing tool, then we end and return.
|
||||||
if isinstance(output, AgentFinish):
|
if isinstance(output, AgentFinish):
|
||||||
if self.should_ask_for_human_input:
|
if self.should_ask_for_human_input:
|
||||||
human_feedback = self._ask_human_input(
|
human_feedback = self._ask_human_input(output.return_values["output"])
|
||||||
output.return_values["output"]
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.crew and self.crew._train:
|
if self.crew and self.crew._train:
|
||||||
self._handle_crew_training_output(output, human_feedback)
|
self._handle_crew_training_output(output, human_feedback)
|
||||||
@@ -297,9 +296,7 @@ class CrewAgentExecutor(AgentExecutor, CrewAgentExecutorMixin):
|
|||||||
action=agent_action,
|
action=agent_action,
|
||||||
)
|
)
|
||||||
|
|
||||||
# print("tool_usage", tool_usage)
|
|
||||||
tool_calling = tool_usage.parse(agent_action.log)
|
tool_calling = tool_usage.parse(agent_action.log)
|
||||||
# print("tool_calling", tool_calling)
|
|
||||||
|
|
||||||
if isinstance(tool_calling, ToolUsageErrorException):
|
if isinstance(tool_calling, ToolUsageErrorException):
|
||||||
observation = tool_calling.message
|
observation = tool_calling.message
|
||||||
@@ -313,9 +310,7 @@ class CrewAgentExecutor(AgentExecutor, CrewAgentExecutorMixin):
|
|||||||
else:
|
else:
|
||||||
observation = self._i18n.errors("wrong_tool_name").format(
|
observation = self._i18n.errors("wrong_tool_name").format(
|
||||||
tool=tool_calling.tool_name,
|
tool=tool_calling.tool_name,
|
||||||
tools=", ".join(
|
tools=", ".join([tool.name.casefold() for tool in self.tools]),
|
||||||
[tool.name.casefold() for tool in self.tools]
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
yield AgentStep(action=agent_action, observation=observation)
|
yield AgentStep(action=agent_action, observation=observation)
|
||||||
|
|
||||||
@@ -346,3 +341,44 @@ class CrewAgentExecutor(AgentExecutor, CrewAgentExecutorMixin):
|
|||||||
CrewTrainingHandler(TRAINING_DATA_FILE).append(
|
CrewTrainingHandler(TRAINING_DATA_FILE).append(
|
||||||
self.crew._train_iteration, agent_id, training_data
|
self.crew._train_iteration, agent_id, training_data
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _handle_context_length(
|
||||||
|
self, intermediate_steps: List[Tuple[AgentAction, str]]
|
||||||
|
) -> List[Tuple[AgentAction, str]]:
|
||||||
|
text = intermediate_steps[0][1]
|
||||||
|
original_action = intermediate_steps[0][0]
|
||||||
|
|
||||||
|
text_splitter = RecursiveCharacterTextSplitter(
|
||||||
|
separators=["\n\n", "\n"],
|
||||||
|
chunk_size=8000,
|
||||||
|
chunk_overlap=500,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self._fit_context_window_strategy == "summarize":
|
||||||
|
docs = text_splitter.create_documents([text])
|
||||||
|
self._logger.log(
|
||||||
|
"debug",
|
||||||
|
"Summarizing Content, it is recommended to use a RAG tool",
|
||||||
|
color="bold_blue",
|
||||||
|
)
|
||||||
|
summarize_chain = load_summarize_chain(
|
||||||
|
self.llm, chain_type="map_reduce", verbose=True
|
||||||
|
)
|
||||||
|
summarized_docs = []
|
||||||
|
for doc in docs:
|
||||||
|
summary = summarize_chain.run([doc])
|
||||||
|
summarized_docs.append(summary)
|
||||||
|
|
||||||
|
formatted_results = "\n\n".join(summarized_docs)
|
||||||
|
summary_step = AgentStep(
|
||||||
|
action=AgentAction(
|
||||||
|
tool=original_action.tool,
|
||||||
|
tool_input=original_action.tool_input,
|
||||||
|
log=original_action.log,
|
||||||
|
),
|
||||||
|
observation=formatted_results,
|
||||||
|
)
|
||||||
|
summary_tuple = (summary_step.action, summary_step.observation)
|
||||||
|
return [summary_tuple]
|
||||||
|
|
||||||
|
return intermediate_steps
|
||||||
|
|||||||
Reference in New Issue
Block a user