Fixing summarization logic

João Moura
2024-09-26 21:40:11 -03:00
parent 6823f76ff4
commit ac331504e9
3 changed files with 123 additions and 37 deletions

View File

@@ -67,8 +67,10 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
         self.ask_for_human_input = False
         self.messages: List[Dict[str, str]] = []
         self.iterations = 0
+        self.log_error_after = 3
         self.have_forced_answer = False
         self.name_to_tool_map = {tool.name: tool for tool in self.tools}
+        self.llm.stop = self.stop

     def invoke(self, inputs: Dict[str, str]) -> Dict[str, Any]:
         if "system" in self.prompt:
@@ -151,6 +153,11 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
         except OutputParserException as e:
             self.messages.append({"role": "user", "content": e.error})
+            if self.iterations > self.log_error_after:
+                self._printer.print(
+                    content=f"Error parsing LLM output, agent will retry: {e.error}",
+                    color="red",
+                )
             return self._invoke_loop(formatted_answer)

         except Exception as e:
@@ -245,21 +252,21 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
     def _summarize_messages(self) -> None:
         messages_groups = []
         for message in self.messages:
             content = message["content"]
-            for i in range(0, len(content), 5000):
-                messages_groups.append(content[i : i + 5000])
+            cut_size = self.llm.get_context_window_size()
+            for i in range(0, len(content), cut_size):
+                messages_groups.append(content[i : i + cut_size])

         summarized_contents = []
         for group in messages_groups:
             summary = self.llm.call(
                 [
                     self._format_msg(
-                        self._i18n.slices("summarizer_system_message"), role="system"
+                        self._i18n.slice("summarizer_system_message"), role="system"
                     ),
                     self._format_msg(
-                        self._i18n.errors("sumamrize_instruction").format(group=group),
+                        self._i18n.slice("sumamrize_instruction").format(group=group),
                     ),
                 ],
                 callbacks=self.callbacks,
@@ -270,7 +277,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
         self.messages = [
             self._format_msg(
-                self._i18n.errors("summary").format(merged_summary=merged_summary)
+                self._i18n.slice("summary").format(merged_summary=merged_summary)
             )
         ]
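In short, _summarize_messages no longer slices history into fixed 5000-character pieces; it asks the LLM wrapper for its usable context window, chunks by that, summarizes each chunk, and replaces the history with the merged result. A minimal standalone sketch of that flow, assuming a plain {"role", "content"} message shape and using a summarize callable as a stand-in for the llm.call round trip:

# Simplified sketch, not the crewAI implementation: chunk by the usable context
# window, summarize each chunk, and replace the history with the merged summary.
from typing import Callable, Dict, List

def summarize_messages(
    messages: List[Dict[str, str]],
    cut_size: int,                    # e.g. what llm.get_context_window_size() returns
    summarize: Callable[[str], str],  # stand-in for the llm.call round trip
) -> List[Dict[str, str]]:
    groups: List[str] = []
    for message in messages:
        content = message["content"]
        # Every chunk fits inside the usable window, so each summarization call is safe.
        for i in range(0, len(content), cut_size):
            groups.append(content[i : i + cut_size])

    # Summarize each chunk independently, then merge the partial summaries.
    merged_summary = " ".join(summarize(group) for group in groups)

    # The whole conversation collapses into a single summary message.
    return [{"role": "user", "content": merged_summary}]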

View File

@@ -1,8 +1,75 @@
+from contextlib import contextmanager
 from typing import Any, Dict, List, Optional, Union
 import logging
+import warnings

 import litellm
 from litellm import get_supported_openai_params
+from crewai.utilities.exceptions.context_window_exceeding_exception import (
+    LLMContextLengthExceededException,
+)
+
+import sys
+import io
+
+
+class FilteredStream(io.StringIO):
+    def write(self, s):
+        if (
+            "Give Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new"
+            in s
+            or "LiteLLM.Info: If you need to debug this error, use `litellm.set_verbose=True`"
+            in s
+        ):
+            return
+        super().write(s)
+
+
+LLM_CONTEXT_WINDOW_SIZES = {
+    # openai
+    "gpt-4": 8192,
+    "gpt-4o": 128000,
+    "gpt-4o-mini": 128000,
+    "gpt-4-turbo": 128000,
+    "o1-preview": 128000,
+    "o1-mini": 128000,
+    # deepseek
+    "deepseek-chat": 128000,
+    # groq
+    "gemma2-9b-it": 8192,
+    "gemma-7b-it": 8192,
+    "llama3-groq-70b-8192-tool-use-preview": 8192,
+    "llama3-groq-8b-8192-tool-use-preview": 8192,
+    "llama-3.1-70b-versatile": 131072,
+    "llama-3.1-8b-instant": 131072,
+    "llama-3.2-1b-preview": 8192,
+    "llama-3.2-3b-preview": 8192,
+    "llama-3.2-11b-text-preview": 8192,
+    "llama-3.2-90b-text-preview": 8192,
+    "llama3-70b-8192": 8192,
+    "llama3-8b-8192": 8192,
+    "mixtral-8x7b-32768": 32768,
+}
+
+
+@contextmanager
+def suppress_warnings():
+    with warnings.catch_warnings():
+        warnings.filterwarnings("ignore")
+
+        # Redirect stdout and stderr
+        old_stdout = sys.stdout
+        old_stderr = sys.stderr
+        sys.stdout = FilteredStream()
+        sys.stderr = FilteredStream()
+        try:
+            yield
+        finally:
+            # Restore stdout and stderr
+            sys.stdout = old_stdout
+            sys.stderr = old_stderr
+
+
 class LLM:
     def __init__(
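The FilteredStream / suppress_warnings pair exists to keep LiteLLM's console chatter out of crew runs: inside the context manager, Python warnings are ignored and stdout/stderr point at in-memory buffers, and the two LiteLLM banner messages are dropped before they even reach the buffer. A small usage sketch of the assumed behavior, for illustration only:

# Inside the block, prints go to a FilteredStream (a StringIO) instead of the console.
with suppress_warnings():
    # Matches one of the filtered LiteLLM banners, so FilteredStream.write drops it entirely.
    print("LiteLLM.Info: If you need to debug this error, use `litellm.set_verbose=True`")
    # Anything else is captured in the in-memory buffer rather than shown.
    print("other provider noise")
# Once the block exits, the original stdout/stderr are restored.
print("visible again")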
@@ -50,43 +117,50 @@ class LLM:
         self.kwargs = kwargs

         litellm.drop_params = True
+        litellm.set_verbose = False
         litellm.callbacks = callbacks

     def call(self, messages: List[Dict[str, str]], callbacks: List[Any] = []) -> str:
-        if callbacks and len(callbacks) > 0:
-            litellm.callbacks = callbacks
+        with suppress_warnings():
+            if callbacks and len(callbacks) > 0:
+                litellm.callbacks = callbacks

             try:
                 params = {
                     "model": self.model,
                     "messages": messages,
                     "timeout": self.timeout,
                     "temperature": self.temperature,
                     "top_p": self.top_p,
                     "n": self.n,
                     "stop": self.stop,
                     "max_tokens": self.max_tokens or self.max_completion_tokens,
                     "presence_penalty": self.presence_penalty,
                     "frequency_penalty": self.frequency_penalty,
                     "logit_bias": self.logit_bias,
                     "response_format": self.response_format,
                     "seed": self.seed,
                     "logprobs": self.logprobs,
                     "top_logprobs": self.top_logprobs,
                     "api_base": self.base_url,
                     "api_version": self.api_version,
                     "api_key": self.api_key,
+                    "stream": False,
                     **self.kwargs,
                 }

                 # Remove None values to avoid passing unnecessary parameters
                 params = {k: v for k, v in params.items() if v is not None}

                 response = litellm.completion(**params)
                 return response["choices"][0]["message"]["content"]
             except Exception as e:
-            logging.error(f"LiteLLM call failed: {str(e)}")
-            raise  # Re-raise the exception after logging
+                if not LLMContextLengthExceededException(
+                    str(e)
+                )._is_context_limit_error(str(e)):
+                    logging.error(f"LiteLLM call failed: {str(e)}")
+                raise  # Re-raise the exception after logging

     def supports_function_calling(self) -> bool:
         try:
@@ -103,3 +177,7 @@ class LLM:
         except Exception as e:
             logging.error(f"Failed to get supported params: {str(e)}")
             return False
+
+    def get_context_window_size(self) -> int:
+        # Only using 75% of the context window size to avoid cutting the message in the middle
+        return int(LLM_CONTEXT_WINDOW_SIZES.get(self.model, 8192) * 0.75)
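get_context_window_size looks the model up in the hard-coded table above (falling back to 8192) and keeps only 75% of it as the cut size; the executor then uses that value as a character count when slicing message content. A quick worked example with values from the table:

# 75% headroom rule applied to entries from LLM_CONTEXT_WINDOW_SIZES above:
assert int(LLM_CONTEXT_WINDOW_SIZES.get("gpt-4o", 8192) * 0.75) == 96000
# Models not listed fall back to the 8192 default:
assert int(LLM_CONTEXT_WINDOW_SIZES.get("some-unlisted-model", 8192) * 0.75) == 6144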

View File

@@ -1,5 +1,6 @@
 class LLMContextLengthExceededException(Exception):
     CONTEXT_LIMIT_ERRORS = [
+        "expected a string with maximum length",
         "maximum context length",
         "context length exceeded",
         "context_length_exceeded",