WIP: generated summary from documents split, could also create memgpt approach

This commit is contained in:
Lorenze Jay
2024-07-30 22:50:05 -07:00
parent 149cb1ffa1
commit 62868c00db
3 changed files with 197 additions and 130 deletions

View File

@@ -1,6 +1,7 @@
import time import time
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional
from crewai.memory.entity.entity_memory_item import EntityMemoryItem from crewai.memory.entity.entity_memory_item import EntityMemoryItem
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
from crewai.memory.short_term.short_term_memory_item import ShortTermMemoryItem from crewai.memory.short_term.short_term_memory_item import ShortTermMemoryItem

View File

@@ -10,6 +10,10 @@ from langchain_core.exceptions import OutputParserException
from langchain_core.tools import BaseTool from langchain_core.tools import BaseTool
from langchain_core.utils.input import get_color_mapping from langchain_core.utils.input import get_color_mapping
from pydantic import InstanceOf from pydantic import InstanceOf
import tiktoken
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains.summarize import load_summarize_chain
from openai import BadRequestError
from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin
from crewai.agents.tools_handler import ToolsHandler from crewai.agents.tools_handler import ToolsHandler
@@ -40,6 +44,8 @@ class CrewAgentExecutor(AgentExecutor, CrewAgentExecutorMixin):
system_template: Optional[str] = None system_template: Optional[str] = None
prompt_template: Optional[str] = None prompt_template: Optional[str] = None
response_template: Optional[str] = None response_template: Optional[str] = None
retry_summarize: bool = False
retry_summarize_count: int = 2
def _call( def _call(
self, self,
@@ -120,6 +126,7 @@ class CrewAgentExecutor(AgentExecutor, CrewAgentExecutorMixin):
Override this to take control of how the agent makes and acts on choices. Override this to take control of how the agent makes and acts on choices.
""" """
for attempt in range(self.retry_summarize_count):
try: try:
if self._should_force_answer(): if self._should_force_answer():
error = self._i18n.errors("force_final_answer") error = self._i18n.errors("force_final_answer")
@@ -128,7 +135,39 @@ class CrewAgentExecutor(AgentExecutor, CrewAgentExecutorMixin):
yield AgentStep(action=output, observation=error) yield AgentStep(action=output, observation=error)
return return
intermediate_steps = self._prepare_intermediate_steps(intermediate_steps) intermediate_steps = self._prepare_intermediate_steps(
intermediate_steps
)
if self.retry_summarize:
encoding = tiktoken.encoding_for_model(self.llm.model_name)
original_token_count = len(
encoding.encode(intermediate_steps[0][1])
)
if original_token_count > 8000:
print(
"BEFORE AGENT PLAN TOKEN LENGTH",
original_token_count,
)
text = intermediate_steps[0][1]
text_splitter = RecursiveCharacterTextSplitter(
separators=["\n\n", "\n"],
chunk_size=8000,
chunk_overlap=500,
)
docs = text_splitter.create_documents([text])
print("DOCS", docs)
print("DOCS length", len(docs))
breakpoint()
# TODO: store to vector db - using memgpt like strategy
summary_chain = load_summarize_chain(
self.llm, chain_type="map_reduce", verbose=True
)
summary = summary_chain.run(docs)
print("SUMMARY:", summary)
intermediate_steps[0] = (intermediate_steps[0][0], summary)
# Call the LLM to see what to do. # Call the LLM to see what to do.
output = self.agent.plan( # type: ignore # Incompatible types in assignment (expression has type "AgentAction | AgentFinish | list[AgentAction]", variable has type "AgentAction") output = self.agent.plan( # type: ignore # Incompatible types in assignment (expression has type "AgentAction | AgentFinish | list[AgentAction]", variable has type "AgentAction")
@@ -185,10 +224,32 @@ class CrewAgentExecutor(AgentExecutor, CrewAgentExecutorMixin):
yield AgentStep(action=output, observation=observation) yield AgentStep(action=output, observation=observation)
return return
except BadRequestError as e:
print("bad request string str(e)", str(e))
if (
"context_length_exceeded" in str(e)
and attempt < self.retry_summarize_count - 1
):
print(
f"Context length exceeded. Retrying with summarization (attempt {attempt + 1})..."
)
self.retry_summarize = True
breakpoint()
continue
else:
print("Error now raising occurred in _iter_next_step:", e)
raise e
except Exception as e:
print("Error occurred in _iter_next_step:", e)
raise e
# If the tool chosen is the finishing tool, then we end and return. # If the tool chosen is the finishing tool, then we end and return.
if isinstance(output, AgentFinish): if isinstance(output, AgentFinish):
if self.should_ask_for_human_input: if self.should_ask_for_human_input:
human_feedback = self._ask_human_input(output.return_values["output"]) human_feedback = self._ask_human_input(
output.return_values["output"]
)
if self.crew and self.crew._train: if self.crew and self.crew._train:
self._handle_crew_training_output(output, human_feedback) self._handle_crew_training_output(output, human_feedback)
@@ -235,7 +296,10 @@ class CrewAgentExecutor(AgentExecutor, CrewAgentExecutorMixin):
agent=self.crew_agent, agent=self.crew_agent,
action=agent_action, action=agent_action,
) )
# print("tool_usage", tool_usage)
tool_calling = tool_usage.parse(agent_action.log) tool_calling = tool_usage.parse(agent_action.log)
# print("tool_calling", tool_calling)
if isinstance(tool_calling, ToolUsageErrorException): if isinstance(tool_calling, ToolUsageErrorException):
observation = tool_calling.message observation = tool_calling.message
@@ -249,7 +313,9 @@ class CrewAgentExecutor(AgentExecutor, CrewAgentExecutorMixin):
else: else:
observation = self._i18n.errors("wrong_tool_name").format( observation = self._i18n.errors("wrong_tool_name").format(
tool=tool_calling.tool_name, tool=tool_calling.tool_name,
tools=", ".join([tool.name.casefold() for tool in self.tools]), tools=", ".join(
[tool.name.casefold() for tool in self.tools]
),
) )
yield AgentStep(action=agent_action, observation=observation) yield AgentStep(action=agent_action, observation=observation)

View File

@@ -16,7 +16,7 @@ try:
except ImportError: except ImportError:
agentops = None agentops = None
OPENAI_BIGGER_MODELS = ["gpt-4"] OPENAI_BIGGER_MODELS = ["gpt-4o"]
class ToolUsageErrorException(Exception): class ToolUsageErrorException(Exception):