mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 23:58:34 +00:00
Fixing summarization logic
This commit is contained in:
@@ -67,8 +67,10 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
self.ask_for_human_input = False
|
||||
self.messages: List[Dict[str, str]] = []
|
||||
self.iterations = 0
|
||||
self.log_error_after = 3
|
||||
self.have_forced_answer = False
|
||||
self.name_to_tool_map = {tool.name: tool for tool in self.tools}
|
||||
self.llm.stop = self.stop
|
||||
|
||||
def invoke(self, inputs: Dict[str, str]) -> Dict[str, Any]:
|
||||
if "system" in self.prompt:
|
||||
@@ -151,6 +153,11 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
|
||||
except OutputParserException as e:
|
||||
self.messages.append({"role": "user", "content": e.error})
|
||||
if self.iterations > self.log_error_after:
|
||||
self._printer.print(
|
||||
content=f"Error parsing LLM output, agent will retry: {e.error}",
|
||||
color="red",
|
||||
)
|
||||
return self._invoke_loop(formatted_answer)
|
||||
|
||||
except Exception as e:
|
||||
@@ -245,21 +252,21 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
|
||||
def _summarize_messages(self) -> None:
|
||||
messages_groups = []
|
||||
|
||||
for message in self.messages:
|
||||
content = message["content"]
|
||||
for i in range(0, len(content), 5000):
|
||||
messages_groups.append(content[i : i + 5000])
|
||||
cut_size = self.llm.get_context_window_size()
|
||||
for i in range(0, len(content), cut_size):
|
||||
messages_groups.append(content[i : i + cut_size])
|
||||
|
||||
summarized_contents = []
|
||||
for group in messages_groups:
|
||||
summary = self.llm.call(
|
||||
[
|
||||
self._format_msg(
|
||||
self._i18n.slices("summarizer_system_message"), role="system"
|
||||
self._i18n.slice("summarizer_system_message"), role="system"
|
||||
),
|
||||
self._format_msg(
|
||||
self._i18n.errors("sumamrize_instruction").format(group=group),
|
||||
self._i18n.slice("sumamrize_instruction").format(group=group),
|
||||
),
|
||||
],
|
||||
callbacks=self.callbacks,
|
||||
@@ -270,7 +277,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
|
||||
self.messages = [
|
||||
self._format_msg(
|
||||
self._i18n.errors("summary").format(merged_summary=merged_summary)
|
||||
self._i18n.slice("summary").format(merged_summary=merged_summary)
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@@ -1,8 +1,75 @@
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
import logging
|
||||
import warnings
|
||||
import litellm
|
||||
from litellm import get_supported_openai_params
|
||||
|
||||
from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
LLMContextLengthExceededException,
|
||||
)
|
||||
|
||||
import sys
|
||||
import io
|
||||
|
||||
|
||||
class FilteredStream(io.StringIO):
|
||||
def write(self, s):
|
||||
if (
|
||||
"Give Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new"
|
||||
in s
|
||||
or "LiteLLM.Info: If you need to debug this error, use `litellm.set_verbose=True`"
|
||||
in s
|
||||
):
|
||||
return
|
||||
super().write(s)
|
||||
|
||||
|
||||
LLM_CONTEXT_WINDOW_SIZES = {
|
||||
# openai
|
||||
"gpt-4": 8192,
|
||||
"gpt-4o": 128000,
|
||||
"gpt-4o-mini": 128000,
|
||||
"gpt-4-turbo": 128000,
|
||||
"o1-preview": 128000,
|
||||
"o1-mini": 128000,
|
||||
# deepseek
|
||||
"deepseek-chat": 128000,
|
||||
# groq
|
||||
"gemma2-9b-it": 8192,
|
||||
"gemma-7b-it": 8192,
|
||||
"llama3-groq-70b-8192-tool-use-preview": 8192,
|
||||
"llama3-groq-8b-8192-tool-use-preview": 8192,
|
||||
"llama-3.1-70b-versatile": 131072,
|
||||
"llama-3.1-8b-instant": 131072,
|
||||
"llama-3.2-1b-preview": 8192,
|
||||
"llama-3.2-3b-preview": 8192,
|
||||
"llama-3.2-11b-text-preview": 8192,
|
||||
"llama-3.2-90b-text-preview": 8192,
|
||||
"llama3-70b-8192": 8192,
|
||||
"llama3-8b-8192": 8192,
|
||||
"mixtral-8x7b-32768": 32768,
|
||||
}
|
||||
|
||||
|
||||
@contextmanager
|
||||
def suppress_warnings():
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
# Redirect stdout and stderr
|
||||
old_stdout = sys.stdout
|
||||
old_stderr = sys.stderr
|
||||
sys.stdout = FilteredStream()
|
||||
sys.stderr = FilteredStream()
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
# Restore stdout and stderr
|
||||
sys.stdout = old_stdout
|
||||
sys.stderr = old_stderr
|
||||
|
||||
|
||||
class LLM:
|
||||
def __init__(
|
||||
@@ -50,43 +117,50 @@ class LLM:
|
||||
self.kwargs = kwargs
|
||||
|
||||
litellm.drop_params = True
|
||||
litellm.set_verbose = False
|
||||
litellm.callbacks = callbacks
|
||||
|
||||
def call(self, messages: List[Dict[str, str]], callbacks: List[Any] = []) -> str:
|
||||
if callbacks and len(callbacks) > 0:
|
||||
litellm.callbacks = callbacks
|
||||
with suppress_warnings():
|
||||
if callbacks and len(callbacks) > 0:
|
||||
litellm.callbacks = callbacks
|
||||
|
||||
try:
|
||||
params = {
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
"timeout": self.timeout,
|
||||
"temperature": self.temperature,
|
||||
"top_p": self.top_p,
|
||||
"n": self.n,
|
||||
"stop": self.stop,
|
||||
"max_tokens": self.max_tokens or self.max_completion_tokens,
|
||||
"presence_penalty": self.presence_penalty,
|
||||
"frequency_penalty": self.frequency_penalty,
|
||||
"logit_bias": self.logit_bias,
|
||||
"response_format": self.response_format,
|
||||
"seed": self.seed,
|
||||
"logprobs": self.logprobs,
|
||||
"top_logprobs": self.top_logprobs,
|
||||
"api_base": self.base_url,
|
||||
"api_version": self.api_version,
|
||||
"api_key": self.api_key,
|
||||
**self.kwargs,
|
||||
}
|
||||
try:
|
||||
params = {
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
"timeout": self.timeout,
|
||||
"temperature": self.temperature,
|
||||
"top_p": self.top_p,
|
||||
"n": self.n,
|
||||
"stop": self.stop,
|
||||
"max_tokens": self.max_tokens or self.max_completion_tokens,
|
||||
"presence_penalty": self.presence_penalty,
|
||||
"frequency_penalty": self.frequency_penalty,
|
||||
"logit_bias": self.logit_bias,
|
||||
"response_format": self.response_format,
|
||||
"seed": self.seed,
|
||||
"logprobs": self.logprobs,
|
||||
"top_logprobs": self.top_logprobs,
|
||||
"api_base": self.base_url,
|
||||
"api_version": self.api_version,
|
||||
"api_key": self.api_key,
|
||||
"stream": False,
|
||||
**self.kwargs,
|
||||
}
|
||||
|
||||
# Remove None values to avoid passing unnecessary parameters
|
||||
params = {k: v for k, v in params.items() if v is not None}
|
||||
# Remove None values to avoid passing unnecessary parameters
|
||||
params = {k: v for k, v in params.items() if v is not None}
|
||||
|
||||
response = litellm.completion(**params)
|
||||
return response["choices"][0]["message"]["content"]
|
||||
except Exception as e:
|
||||
logging.error(f"LiteLLM call failed: {str(e)}")
|
||||
raise # Re-raise the exception after logging
|
||||
response = litellm.completion(**params)
|
||||
return response["choices"][0]["message"]["content"]
|
||||
except Exception as e:
|
||||
if not LLMContextLengthExceededException(
|
||||
str(e)
|
||||
)._is_context_limit_error(str(e)):
|
||||
logging.error(f"LiteLLM call failed: {str(e)}")
|
||||
|
||||
raise # Re-raise the exception after logging
|
||||
|
||||
def supports_function_calling(self) -> bool:
|
||||
try:
|
||||
@@ -103,3 +177,7 @@ class LLM:
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to get supported params: {str(e)}")
|
||||
return False
|
||||
|
||||
def get_context_window_size(self) -> int:
|
||||
# Only using 75% of the context window size to avoid cutting the message in the middle
|
||||
return int(LLM_CONTEXT_WINDOW_SIZES.get(self.model, 8192) * 0.75)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
class LLMContextLengthExceededException(Exception):
|
||||
CONTEXT_LIMIT_ERRORS = [
|
||||
"expected a string with maximum length",
|
||||
"maximum context length",
|
||||
"context length exceeded",
|
||||
"context_length_exceeded",
|
||||
|
||||
Reference in New Issue
Block a user