mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-23 15:18:14 +00:00
fix: preserve files during message summarization
This commit is contained in:
@@ -613,13 +613,23 @@ def summarize_messages(
|
||||
) -> None:
|
||||
"""Summarize messages to fit within context window.
|
||||
|
||||
Preserves any files attached to user messages and re-attaches them to
|
||||
the summarized message. Files from all user messages are merged.
|
||||
|
||||
Args:
|
||||
messages: List of messages to summarize
|
||||
messages: List of messages to summarize (modified in-place)
|
||||
llm: LLM instance for summarization
|
||||
callbacks: List of callbacks for LLM
|
||||
i18n: I18N instance for messages
|
||||
"""
|
||||
messages_string = " ".join([message["content"] for message in messages]) # type: ignore[misc]
|
||||
preserved_files: dict[str, Any] = {}
|
||||
for msg in messages:
|
||||
if msg.get("role") == "user" and msg.get("files"):
|
||||
preserved_files.update(msg["files"])
|
||||
|
||||
messages_string = " ".join(
|
||||
[str(message.get("content", "")) for message in messages]
|
||||
)
|
||||
cut_size = llm.get_context_window_size()
|
||||
|
||||
messages_groups = [
|
||||
@@ -636,7 +646,7 @@ def summarize_messages(
|
||||
color="yellow",
|
||||
)
|
||||
|
||||
messages = [
|
||||
summarization_messages = [
|
||||
format_message_for_llm(
|
||||
i18n.slice("summarizer_system_message"), role="system"
|
||||
),
|
||||
@@ -645,7 +655,7 @@ def summarize_messages(
|
||||
),
|
||||
]
|
||||
summary = llm.call(
|
||||
messages,
|
||||
summarization_messages,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
summarized_contents.append({"content": str(summary)})
|
||||
@@ -653,11 +663,12 @@ def summarize_messages(
|
||||
merged_summary = " ".join(content["content"] for content in summarized_contents)
|
||||
|
||||
messages.clear()
|
||||
messages.append(
|
||||
format_message_for_llm(
|
||||
i18n.slice("summary").format(merged_summary=merged_summary)
|
||||
)
|
||||
summary_message = format_message_for_llm(
|
||||
i18n.slice("summary").format(merged_summary=merged_summary)
|
||||
)
|
||||
if preserved_files:
|
||||
summary_message["files"] = preserved_files
|
||||
messages.append(summary_message)
|
||||
|
||||
|
||||
def show_agent_logs(
|
||||
@@ -859,7 +870,11 @@ def extract_tool_call_info(
|
||||
if hasattr(tool_call, "function"):
|
||||
# OpenAI-style: has .function.name and .function.arguments
|
||||
call_id = getattr(tool_call, "id", f"call_{id(tool_call)}")
|
||||
return call_id, sanitize_tool_name(tool_call.function.name), tool_call.function.arguments
|
||||
return (
|
||||
call_id,
|
||||
sanitize_tool_name(tool_call.function.name),
|
||||
tool_call.function.arguments,
|
||||
)
|
||||
if hasattr(tool_call, "function_call") and tool_call.function_call:
|
||||
# Gemini-style: has .function_call.name and .function_call.args
|
||||
call_id = f"call_{id(tool_call)}"
|
||||
|
||||
@@ -3,11 +3,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.utilities.agent_utils import convert_tools_to_openai_schema
|
||||
from crewai.utilities.agent_utils import convert_tools_to_openai_schema, summarize_messages
|
||||
|
||||
|
||||
class CalculatorInput(BaseModel):
|
||||
@@ -211,3 +212,136 @@ class TestConvertToolsToOpenaiSchema:
|
||||
# Default value should be preserved
|
||||
assert "default" in max_results_prop
|
||||
assert max_results_prop["default"] == 10
|
||||
|
||||
|
||||
class TestSummarizeMessages:
|
||||
"""Tests for summarize_messages function."""
|
||||
|
||||
def test_preserves_files_from_user_messages(self) -> None:
|
||||
"""Test that files attached to user messages are preserved after summarization."""
|
||||
mock_files = {"image.png": MagicMock(), "doc.pdf": MagicMock()}
|
||||
messages: list[dict[str, Any]] = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Analyze this image", "files": mock_files},
|
||||
{"role": "assistant", "content": "I can see the image shows..."},
|
||||
{"role": "user", "content": "What about the colors?"},
|
||||
]
|
||||
|
||||
mock_llm = MagicMock()
|
||||
mock_llm.get_context_window_size.return_value = 1000
|
||||
mock_llm.call.return_value = "Summarized conversation about image analysis."
|
||||
|
||||
mock_i18n = MagicMock()
|
||||
mock_i18n.slice.side_effect = lambda key: {
|
||||
"summarizer_system_message": "Summarize the following.",
|
||||
"summarize_instruction": "Summarize: {group}",
|
||||
"summary": "Summary: {merged_summary}",
|
||||
}.get(key, "")
|
||||
|
||||
summarize_messages(
|
||||
messages=messages,
|
||||
llm=mock_llm,
|
||||
callbacks=[],
|
||||
i18n=mock_i18n,
|
||||
)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert messages[0]["role"] == "user"
|
||||
assert "files" in messages[0]
|
||||
assert messages[0]["files"] == mock_files
|
||||
|
||||
def test_merges_files_from_multiple_user_messages(self) -> None:
|
||||
"""Test that files from multiple user messages are merged."""
|
||||
file1 = MagicMock()
|
||||
file2 = MagicMock()
|
||||
file3 = MagicMock()
|
||||
messages: list[dict[str, Any]] = [
|
||||
{"role": "user", "content": "First image", "files": {"img1.png": file1}},
|
||||
{"role": "assistant", "content": "I see the first image."},
|
||||
{"role": "user", "content": "Second image", "files": {"img2.png": file2, "doc.pdf": file3}},
|
||||
{"role": "assistant", "content": "I see the second image and document."},
|
||||
]
|
||||
|
||||
mock_llm = MagicMock()
|
||||
mock_llm.get_context_window_size.return_value = 1000
|
||||
mock_llm.call.return_value = "Summarized conversation."
|
||||
|
||||
mock_i18n = MagicMock()
|
||||
mock_i18n.slice.side_effect = lambda key: {
|
||||
"summarizer_system_message": "Summarize the following.",
|
||||
"summarize_instruction": "Summarize: {group}",
|
||||
"summary": "Summary: {merged_summary}",
|
||||
}.get(key, "")
|
||||
|
||||
summarize_messages(
|
||||
messages=messages,
|
||||
llm=mock_llm,
|
||||
callbacks=[],
|
||||
i18n=mock_i18n,
|
||||
)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert "files" in messages[0]
|
||||
assert messages[0]["files"] == {
|
||||
"img1.png": file1,
|
||||
"img2.png": file2,
|
||||
"doc.pdf": file3,
|
||||
}
|
||||
|
||||
def test_works_without_files(self) -> None:
|
||||
"""Test that summarization works when no files are attached."""
|
||||
messages: list[dict[str, Any]] = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
]
|
||||
|
||||
mock_llm = MagicMock()
|
||||
mock_llm.get_context_window_size.return_value = 1000
|
||||
mock_llm.call.return_value = "A greeting exchange."
|
||||
|
||||
mock_i18n = MagicMock()
|
||||
mock_i18n.slice.side_effect = lambda key: {
|
||||
"summarizer_system_message": "Summarize the following.",
|
||||
"summarize_instruction": "Summarize: {group}",
|
||||
"summary": "Summary: {merged_summary}",
|
||||
}.get(key, "")
|
||||
|
||||
summarize_messages(
|
||||
messages=messages,
|
||||
llm=mock_llm,
|
||||
callbacks=[],
|
||||
i18n=mock_i18n,
|
||||
)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert "files" not in messages[0]
|
||||
|
||||
def test_modifies_original_messages_list(self) -> None:
|
||||
"""Test that the original messages list is modified in-place."""
|
||||
messages: list[dict[str, Any]] = [
|
||||
{"role": "user", "content": "First message"},
|
||||
{"role": "assistant", "content": "Response"},
|
||||
{"role": "user", "content": "Second message"},
|
||||
]
|
||||
original_list_id = id(messages)
|
||||
|
||||
mock_llm = MagicMock()
|
||||
mock_llm.get_context_window_size.return_value = 1000
|
||||
mock_llm.call.return_value = "Summary"
|
||||
|
||||
mock_i18n = MagicMock()
|
||||
mock_i18n.slice.side_effect = lambda key: {
|
||||
"summarizer_system_message": "Summarize.",
|
||||
"summarize_instruction": "Summarize: {group}",
|
||||
"summary": "Summary: {merged_summary}",
|
||||
}.get(key, "")
|
||||
|
||||
summarize_messages(
|
||||
messages=messages,
|
||||
llm=mock_llm,
|
||||
callbacks=[],
|
||||
i18n=mock_i18n,
|
||||
)
|
||||
|
||||
assert id(messages) == original_list_id
|
||||
assert len(messages) == 1
|
||||
|
||||
Reference in New Issue
Block a user