Fixing summarization logic

This commit is contained in:
João Moura
2024-09-26 21:40:11 -03:00
parent 6823f76ff4
commit ac331504e9
3 changed files with 123 additions and 37 deletions

View File

@@ -67,8 +67,10 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
self.ask_for_human_input = False
self.messages: List[Dict[str, str]] = []
self.iterations = 0
self.log_error_after = 3
self.have_forced_answer = False
self.name_to_tool_map = {tool.name: tool for tool in self.tools}
self.llm.stop = self.stop
def invoke(self, inputs: Dict[str, str]) -> Dict[str, Any]:
if "system" in self.prompt:
@@ -151,6 +153,11 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
except OutputParserException as e:
self.messages.append({"role": "user", "content": e.error})
if self.iterations > self.log_error_after:
self._printer.print(
content=f"Error parsing LLM output, agent will retry: {e.error}",
color="red",
)
return self._invoke_loop(formatted_answer)
except Exception as e:
@@ -245,21 +252,21 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
def _summarize_messages(self) -> None:
messages_groups = []
for message in self.messages:
content = message["content"]
for i in range(0, len(content), 5000):
messages_groups.append(content[i : i + 5000])
cut_size = self.llm.get_context_window_size()
for i in range(0, len(content), cut_size):
messages_groups.append(content[i : i + cut_size])
summarized_contents = []
for group in messages_groups:
summary = self.llm.call(
[
self._format_msg(
self._i18n.slices("summarizer_system_message"), role="system"
self._i18n.slice("summarizer_system_message"), role="system"
),
self._format_msg(
self._i18n.errors("sumamrize_instruction").format(group=group),
self._i18n.slice("sumamrize_instruction").format(group=group),
),
],
callbacks=self.callbacks,
@@ -270,7 +277,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
self.messages = [
self._format_msg(
self._i18n.errors("summary").format(merged_summary=merged_summary)
self._i18n.slice("summary").format(merged_summary=merged_summary)
)
]

View File

@@ -1,8 +1,75 @@
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Union
import logging
import warnings
import litellm
from litellm import get_supported_openai_params
from crewai.utilities.exceptions.context_window_exceeding_exception import (
LLMContextLengthExceededException,
)
import sys
import io
class FilteredStream(io.StringIO):
    """In-memory text stream that discards known LiteLLM boilerplate output.

    Any chunk passed to ``write`` that contains one of ``_SUPPRESSED_PHRASES``
    is dropped; every other chunk is stored in the underlying ``io.StringIO``
    buffer.
    """

    # Substrings that mark a chunk as noise to be thrown away.
    _SUPPRESSED_PHRASES = (
        "Give Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new",
        "LiteLLM.Info: If you need to debug this error, use `litellm.set_verbose=True`",
    )

    def write(self, s):
        """Store ``s`` in the buffer unless it contains a suppressed phrase.

        Args:
            s: The text chunk to write.

        Returns:
            The number of characters accepted. A suppressed chunk is reported
            as fully consumed (``len(s)``) so that callers relying on the
            ``write`` return value do not treat the drop as a short write.
        """
        if any(phrase in s for phrase in self._SUPPRESSED_PHRASES):
            return len(s)
        return super().write(s)
# Lookup table from model name to context window size, grouped by provider.
# LLM.get_context_window_size() in this file reads it and falls back to 8192
# for model names that are not listed here.
# NOTE(review): values are presumably token counts - confirm against each
# provider's published limits before relying on them.
LLM_CONTEXT_WINDOW_SIZES = {
# openai
"gpt-4": 8192,
"gpt-4o": 128000,
"gpt-4o-mini": 128000,
"gpt-4-turbo": 128000,
"o1-preview": 128000,
"o1-mini": 128000,
# deepseek
"deepseek-chat": 128000,
# groq
"gemma2-9b-it": 8192,
"gemma-7b-it": 8192,
"llama3-groq-70b-8192-tool-use-preview": 8192,
"llama3-groq-8b-8192-tool-use-preview": 8192,
"llama-3.1-70b-versatile": 131072,
"llama-3.1-8b-instant": 131072,
"llama-3.2-1b-preview": 8192,
"llama-3.2-3b-preview": 8192,
"llama-3.2-11b-text-preview": 8192,
"llama-3.2-90b-text-preview": 8192,
"llama3-70b-8192": 8192,
"llama3-8b-8192": 8192,
"mixtral-8x7b-32768": 32768,
}
@contextmanager
def suppress_warnings():
    """Context manager that silences warnings and swaps out stdout/stderr.

    While the managed block runs, every Python warning is ignored and
    ``sys.stdout`` / ``sys.stderr`` are each replaced by a fresh
    ``FilteredStream`` (defined elsewhere in this file), so text written to
    them goes to those stream objects instead of the previous streams.
    The previous streams are put back when the block exits, including when
    it exits with an exception.
    """
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore")
        # Remember the live streams so they can be reinstated afterwards.
        saved_stdout, saved_stderr = sys.stdout, sys.stderr
        sys.stdout, sys.stderr = FilteredStream(), FilteredStream()
        try:
            yield
        finally:
            # Always undo the redirection, even if the body raised.
            sys.stdout, sys.stderr = saved_stdout, saved_stderr
class LLM:
def __init__(
@@ -50,43 +117,50 @@ class LLM:
self.kwargs = kwargs
litellm.drop_params = True
litellm.set_verbose = False
litellm.callbacks = callbacks
def call(
    self, messages: List[Dict[str, str]], callbacks: Optional[List[Any]] = None
) -> str:
    """Send ``messages`` to the configured model via litellm and return the reply.

    Args:
        messages: The chat messages to send, passed to litellm as
            ``messages``.
        callbacks: Optional litellm callbacks. When non-empty they replace
            the module-level ``litellm.callbacks`` setting; when omitted or
            empty that setting is left untouched.

    Returns:
        The ``content`` of the first choice's message in the response.

    Raises:
        Exception: Any error raised while building the request or by
            ``litellm.completion`` is re-raised unchanged. It is logged
            first unless it is recognised as a context-length error.
    """
    with suppress_warnings():
        if callbacks:
            litellm.callbacks = callbacks

        try:
            params = {
                "model": self.model,
                "messages": messages,
                "timeout": self.timeout,
                "temperature": self.temperature,
                "top_p": self.top_p,
                "n": self.n,
                "stop": self.stop,
                "max_tokens": self.max_tokens or self.max_completion_tokens,
                "presence_penalty": self.presence_penalty,
                "frequency_penalty": self.frequency_penalty,
                "logit_bias": self.logit_bias,
                "response_format": self.response_format,
                "seed": self.seed,
                "logprobs": self.logprobs,
                "top_logprobs": self.top_logprobs,
                "api_base": self.base_url,
                "api_version": self.api_version,
                "api_key": self.api_key,
                "stream": False,
                # Extra kwargs come last so they can override the defaults above.
                **self.kwargs,
            }

            # Remove None values to avoid passing unnecessary parameters
            params = {k: v for k, v in params.items() if v is not None}

            response = litellm.completion(**params)
            return response["choices"][0]["message"]["content"]
        except Exception as e:
            error_text = str(e)
            # Context-length errors are not logged here; every other failure
            # is logged before being propagated.
            if not LLMContextLengthExceededException(
                error_text
            )._is_context_limit_error(error_text):
                logging.error(f"LiteLLM call failed: {error_text}")
            raise  # Re-raise the exception after logging
def supports_function_calling(self) -> bool:
try:
@@ -103,3 +177,7 @@ class LLM:
except Exception as e:
logging.error(f"Failed to get supported params: {str(e)}")
return False
def get_context_window_size(self) -> int:
    """Return the usable context window size for ``self.model``.

    The model name is looked up in ``LLM_CONTEXT_WINDOW_SIZES``: first as
    given, then, for names of the form ``"<provider>/<model>"``, by the part
    after the last ``/``, because the table is keyed by bare model names.
    Names that still do not match use a default of 8192.

    Returns:
        75% of the looked-up size, truncated to an int.
    """
    default_size = 8192
    model_name = self.model

    size = LLM_CONTEXT_WINDOW_SIZES.get(model_name)
    if size is None and isinstance(model_name, str) and "/" in model_name:
        # Provider-prefixed names (e.g. "groq/llama3-8b-8192") would otherwise
        # never match the bare keys in the table.
        size = LLM_CONTEXT_WINDOW_SIZES.get(model_name.rsplit("/", 1)[-1])
    if size is None:
        size = default_size

    # Only using 75% of the context window size to avoid cutting the message in the middle
    return int(size * 0.75)

View File

@@ -1,5 +1,6 @@
class LLMContextLengthExceededException(Exception):
CONTEXT_LIMIT_ERRORS = [
"expected a string with maximum length",
"maximum context length",
"context length exceeded",
"context_length_exceeded",