mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-02 07:42:40 +00:00
fix: preserve files during message summarization
This commit is contained in:
@@ -613,13 +613,23 @@ def summarize_messages(
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Summarize messages to fit within context window.
|
"""Summarize messages to fit within context window.
|
||||||
|
|
||||||
|
Preserves any files attached to user messages and re-attaches them to
|
||||||
|
the summarized message. Files from all user messages are merged.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
messages: List of messages to summarize
|
messages: List of messages to summarize (modified in-place)
|
||||||
llm: LLM instance for summarization
|
llm: LLM instance for summarization
|
||||||
callbacks: List of callbacks for LLM
|
callbacks: List of callbacks for LLM
|
||||||
i18n: I18N instance for messages
|
i18n: I18N instance for messages
|
||||||
"""
|
"""
|
||||||
messages_string = " ".join([message["content"] for message in messages]) # type: ignore[misc]
|
preserved_files: dict[str, Any] = {}
|
||||||
|
for msg in messages:
|
||||||
|
if msg.get("role") == "user" and msg.get("files"):
|
||||||
|
preserved_files.update(msg["files"])
|
||||||
|
|
||||||
|
messages_string = " ".join(
|
||||||
|
[str(message.get("content", "")) for message in messages]
|
||||||
|
)
|
||||||
cut_size = llm.get_context_window_size()
|
cut_size = llm.get_context_window_size()
|
||||||
|
|
||||||
messages_groups = [
|
messages_groups = [
|
||||||
@@ -636,7 +646,7 @@ def summarize_messages(
|
|||||||
color="yellow",
|
color="yellow",
|
||||||
)
|
)
|
||||||
|
|
||||||
messages = [
|
summarization_messages = [
|
||||||
format_message_for_llm(
|
format_message_for_llm(
|
||||||
i18n.slice("summarizer_system_message"), role="system"
|
i18n.slice("summarizer_system_message"), role="system"
|
||||||
),
|
),
|
||||||
@@ -645,7 +655,7 @@ def summarize_messages(
|
|||||||
),
|
),
|
||||||
]
|
]
|
||||||
summary = llm.call(
|
summary = llm.call(
|
||||||
messages,
|
summarization_messages,
|
||||||
callbacks=callbacks,
|
callbacks=callbacks,
|
||||||
)
|
)
|
||||||
summarized_contents.append({"content": str(summary)})
|
summarized_contents.append({"content": str(summary)})
|
||||||
@@ -653,11 +663,12 @@ def summarize_messages(
|
|||||||
merged_summary = " ".join(content["content"] for content in summarized_contents)
|
merged_summary = " ".join(content["content"] for content in summarized_contents)
|
||||||
|
|
||||||
messages.clear()
|
messages.clear()
|
||||||
messages.append(
|
summary_message = format_message_for_llm(
|
||||||
format_message_for_llm(
|
i18n.slice("summary").format(merged_summary=merged_summary)
|
||||||
i18n.slice("summary").format(merged_summary=merged_summary)
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
if preserved_files:
|
||||||
|
summary_message["files"] = preserved_files
|
||||||
|
messages.append(summary_message)
|
||||||
|
|
||||||
|
|
||||||
def show_agent_logs(
|
def show_agent_logs(
|
||||||
@@ -859,7 +870,11 @@ def extract_tool_call_info(
|
|||||||
if hasattr(tool_call, "function"):
|
if hasattr(tool_call, "function"):
|
||||||
# OpenAI-style: has .function.name and .function.arguments
|
# OpenAI-style: has .function.name and .function.arguments
|
||||||
call_id = getattr(tool_call, "id", f"call_{id(tool_call)}")
|
call_id = getattr(tool_call, "id", f"call_{id(tool_call)}")
|
||||||
return call_id, sanitize_tool_name(tool_call.function.name), tool_call.function.arguments
|
return (
|
||||||
|
call_id,
|
||||||
|
sanitize_tool_name(tool_call.function.name),
|
||||||
|
tool_call.function.arguments,
|
||||||
|
)
|
||||||
if hasattr(tool_call, "function_call") and tool_call.function_call:
|
if hasattr(tool_call, "function_call") and tool_call.function_call:
|
||||||
# Gemini-style: has .function_call.name and .function_call.args
|
# Gemini-style: has .function_call.name and .function_call.args
|
||||||
call_id = f"call_{id(tool_call)}"
|
call_id = f"call_{id(tool_call)}"
|
||||||
|
|||||||
@@ -3,11 +3,12 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from crewai.tools.base_tool import BaseTool
|
from crewai.tools.base_tool import BaseTool
|
||||||
from crewai.utilities.agent_utils import convert_tools_to_openai_schema
|
from crewai.utilities.agent_utils import convert_tools_to_openai_schema, summarize_messages
|
||||||
|
|
||||||
|
|
||||||
class CalculatorInput(BaseModel):
|
class CalculatorInput(BaseModel):
|
||||||
@@ -211,3 +212,136 @@ class TestConvertToolsToOpenaiSchema:
|
|||||||
# Default value should be preserved
|
# Default value should be preserved
|
||||||
assert "default" in max_results_prop
|
assert "default" in max_results_prop
|
||||||
assert max_results_prop["default"] == 10
|
assert max_results_prop["default"] == 10
|
||||||
|
|
||||||
|
|
||||||
|
class TestSummarizeMessages:
|
||||||
|
"""Tests for summarize_messages function."""
|
||||||
|
|
||||||
|
def test_preserves_files_from_user_messages(self) -> None:
|
||||||
|
"""Test that files attached to user messages are preserved after summarization."""
|
||||||
|
mock_files = {"image.png": MagicMock(), "doc.pdf": MagicMock()}
|
||||||
|
messages: list[dict[str, Any]] = [
|
||||||
|
{"role": "system", "content": "You are a helpful assistant."},
|
||||||
|
{"role": "user", "content": "Analyze this image", "files": mock_files},
|
||||||
|
{"role": "assistant", "content": "I can see the image shows..."},
|
||||||
|
{"role": "user", "content": "What about the colors?"},
|
||||||
|
]
|
||||||
|
|
||||||
|
mock_llm = MagicMock()
|
||||||
|
mock_llm.get_context_window_size.return_value = 1000
|
||||||
|
mock_llm.call.return_value = "Summarized conversation about image analysis."
|
||||||
|
|
||||||
|
mock_i18n = MagicMock()
|
||||||
|
mock_i18n.slice.side_effect = lambda key: {
|
||||||
|
"summarizer_system_message": "Summarize the following.",
|
||||||
|
"summarize_instruction": "Summarize: {group}",
|
||||||
|
"summary": "Summary: {merged_summary}",
|
||||||
|
}.get(key, "")
|
||||||
|
|
||||||
|
summarize_messages(
|
||||||
|
messages=messages,
|
||||||
|
llm=mock_llm,
|
||||||
|
callbacks=[],
|
||||||
|
i18n=mock_i18n,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(messages) == 1
|
||||||
|
assert messages[0]["role"] == "user"
|
||||||
|
assert "files" in messages[0]
|
||||||
|
assert messages[0]["files"] == mock_files
|
||||||
|
|
||||||
|
def test_merges_files_from_multiple_user_messages(self) -> None:
|
||||||
|
"""Test that files from multiple user messages are merged."""
|
||||||
|
file1 = MagicMock()
|
||||||
|
file2 = MagicMock()
|
||||||
|
file3 = MagicMock()
|
||||||
|
messages: list[dict[str, Any]] = [
|
||||||
|
{"role": "user", "content": "First image", "files": {"img1.png": file1}},
|
||||||
|
{"role": "assistant", "content": "I see the first image."},
|
||||||
|
{"role": "user", "content": "Second image", "files": {"img2.png": file2, "doc.pdf": file3}},
|
||||||
|
{"role": "assistant", "content": "I see the second image and document."},
|
||||||
|
]
|
||||||
|
|
||||||
|
mock_llm = MagicMock()
|
||||||
|
mock_llm.get_context_window_size.return_value = 1000
|
||||||
|
mock_llm.call.return_value = "Summarized conversation."
|
||||||
|
|
||||||
|
mock_i18n = MagicMock()
|
||||||
|
mock_i18n.slice.side_effect = lambda key: {
|
||||||
|
"summarizer_system_message": "Summarize the following.",
|
||||||
|
"summarize_instruction": "Summarize: {group}",
|
||||||
|
"summary": "Summary: {merged_summary}",
|
||||||
|
}.get(key, "")
|
||||||
|
|
||||||
|
summarize_messages(
|
||||||
|
messages=messages,
|
||||||
|
llm=mock_llm,
|
||||||
|
callbacks=[],
|
||||||
|
i18n=mock_i18n,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(messages) == 1
|
||||||
|
assert "files" in messages[0]
|
||||||
|
assert messages[0]["files"] == {
|
||||||
|
"img1.png": file1,
|
||||||
|
"img2.png": file2,
|
||||||
|
"doc.pdf": file3,
|
||||||
|
}
|
||||||
|
|
||||||
|
def test_works_without_files(self) -> None:
|
||||||
|
"""Test that summarization works when no files are attached."""
|
||||||
|
messages: list[dict[str, Any]] = [
|
||||||
|
{"role": "user", "content": "Hello"},
|
||||||
|
{"role": "assistant", "content": "Hi there!"},
|
||||||
|
]
|
||||||
|
|
||||||
|
mock_llm = MagicMock()
|
||||||
|
mock_llm.get_context_window_size.return_value = 1000
|
||||||
|
mock_llm.call.return_value = "A greeting exchange."
|
||||||
|
|
||||||
|
mock_i18n = MagicMock()
|
||||||
|
mock_i18n.slice.side_effect = lambda key: {
|
||||||
|
"summarizer_system_message": "Summarize the following.",
|
||||||
|
"summarize_instruction": "Summarize: {group}",
|
||||||
|
"summary": "Summary: {merged_summary}",
|
||||||
|
}.get(key, "")
|
||||||
|
|
||||||
|
summarize_messages(
|
||||||
|
messages=messages,
|
||||||
|
llm=mock_llm,
|
||||||
|
callbacks=[],
|
||||||
|
i18n=mock_i18n,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(messages) == 1
|
||||||
|
assert "files" not in messages[0]
|
||||||
|
|
||||||
|
def test_modifies_original_messages_list(self) -> None:
|
||||||
|
"""Test that the original messages list is modified in-place."""
|
||||||
|
messages: list[dict[str, Any]] = [
|
||||||
|
{"role": "user", "content": "First message"},
|
||||||
|
{"role": "assistant", "content": "Response"},
|
||||||
|
{"role": "user", "content": "Second message"},
|
||||||
|
]
|
||||||
|
original_list_id = id(messages)
|
||||||
|
|
||||||
|
mock_llm = MagicMock()
|
||||||
|
mock_llm.get_context_window_size.return_value = 1000
|
||||||
|
mock_llm.call.return_value = "Summary"
|
||||||
|
|
||||||
|
mock_i18n = MagicMock()
|
||||||
|
mock_i18n.slice.side_effect = lambda key: {
|
||||||
|
"summarizer_system_message": "Summarize.",
|
||||||
|
"summarize_instruction": "Summarize: {group}",
|
||||||
|
"summary": "Summary: {merged_summary}",
|
||||||
|
}.get(key, "")
|
||||||
|
|
||||||
|
summarize_messages(
|
||||||
|
messages=messages,
|
||||||
|
llm=mock_llm,
|
||||||
|
callbacks=[],
|
||||||
|
i18n=mock_i18n,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert id(messages) == original_list_id
|
||||||
|
assert len(messages) == 1
|
||||||
|
|||||||
Reference in New Issue
Block a user