mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-03 00:02:36 +00:00
fix: preserve files during message summarization
This commit is contained in:
@@ -3,11 +3,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.utilities.agent_utils import convert_tools_to_openai_schema
|
||||
from crewai.utilities.agent_utils import convert_tools_to_openai_schema, summarize_messages
|
||||
|
||||
|
||||
class CalculatorInput(BaseModel):
|
||||
@@ -211,3 +212,136 @@ class TestConvertToolsToOpenaiSchema:
|
||||
# Default value should be preserved
|
||||
assert "default" in max_results_prop
|
||||
assert max_results_prop["default"] == 10
|
||||
|
||||
|
||||
class TestSummarizeMessages:
|
||||
"""Tests for summarize_messages function."""
|
||||
|
||||
def test_preserves_files_from_user_messages(self) -> None:
|
||||
"""Test that files attached to user messages are preserved after summarization."""
|
||||
mock_files = {"image.png": MagicMock(), "doc.pdf": MagicMock()}
|
||||
messages: list[dict[str, Any]] = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Analyze this image", "files": mock_files},
|
||||
{"role": "assistant", "content": "I can see the image shows..."},
|
||||
{"role": "user", "content": "What about the colors?"},
|
||||
]
|
||||
|
||||
mock_llm = MagicMock()
|
||||
mock_llm.get_context_window_size.return_value = 1000
|
||||
mock_llm.call.return_value = "Summarized conversation about image analysis."
|
||||
|
||||
mock_i18n = MagicMock()
|
||||
mock_i18n.slice.side_effect = lambda key: {
|
||||
"summarizer_system_message": "Summarize the following.",
|
||||
"summarize_instruction": "Summarize: {group}",
|
||||
"summary": "Summary: {merged_summary}",
|
||||
}.get(key, "")
|
||||
|
||||
summarize_messages(
|
||||
messages=messages,
|
||||
llm=mock_llm,
|
||||
callbacks=[],
|
||||
i18n=mock_i18n,
|
||||
)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert messages[0]["role"] == "user"
|
||||
assert "files" in messages[0]
|
||||
assert messages[0]["files"] == mock_files
|
||||
|
||||
def test_merges_files_from_multiple_user_messages(self) -> None:
|
||||
"""Test that files from multiple user messages are merged."""
|
||||
file1 = MagicMock()
|
||||
file2 = MagicMock()
|
||||
file3 = MagicMock()
|
||||
messages: list[dict[str, Any]] = [
|
||||
{"role": "user", "content": "First image", "files": {"img1.png": file1}},
|
||||
{"role": "assistant", "content": "I see the first image."},
|
||||
{"role": "user", "content": "Second image", "files": {"img2.png": file2, "doc.pdf": file3}},
|
||||
{"role": "assistant", "content": "I see the second image and document."},
|
||||
]
|
||||
|
||||
mock_llm = MagicMock()
|
||||
mock_llm.get_context_window_size.return_value = 1000
|
||||
mock_llm.call.return_value = "Summarized conversation."
|
||||
|
||||
mock_i18n = MagicMock()
|
||||
mock_i18n.slice.side_effect = lambda key: {
|
||||
"summarizer_system_message": "Summarize the following.",
|
||||
"summarize_instruction": "Summarize: {group}",
|
||||
"summary": "Summary: {merged_summary}",
|
||||
}.get(key, "")
|
||||
|
||||
summarize_messages(
|
||||
messages=messages,
|
||||
llm=mock_llm,
|
||||
callbacks=[],
|
||||
i18n=mock_i18n,
|
||||
)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert "files" in messages[0]
|
||||
assert messages[0]["files"] == {
|
||||
"img1.png": file1,
|
||||
"img2.png": file2,
|
||||
"doc.pdf": file3,
|
||||
}
|
||||
|
||||
def test_works_without_files(self) -> None:
|
||||
"""Test that summarization works when no files are attached."""
|
||||
messages: list[dict[str, Any]] = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
]
|
||||
|
||||
mock_llm = MagicMock()
|
||||
mock_llm.get_context_window_size.return_value = 1000
|
||||
mock_llm.call.return_value = "A greeting exchange."
|
||||
|
||||
mock_i18n = MagicMock()
|
||||
mock_i18n.slice.side_effect = lambda key: {
|
||||
"summarizer_system_message": "Summarize the following.",
|
||||
"summarize_instruction": "Summarize: {group}",
|
||||
"summary": "Summary: {merged_summary}",
|
||||
}.get(key, "")
|
||||
|
||||
summarize_messages(
|
||||
messages=messages,
|
||||
llm=mock_llm,
|
||||
callbacks=[],
|
||||
i18n=mock_i18n,
|
||||
)
|
||||
|
||||
assert len(messages) == 1
|
||||
assert "files" not in messages[0]
|
||||
|
||||
def test_modifies_original_messages_list(self) -> None:
|
||||
"""Test that the original messages list is modified in-place."""
|
||||
messages: list[dict[str, Any]] = [
|
||||
{"role": "user", "content": "First message"},
|
||||
{"role": "assistant", "content": "Response"},
|
||||
{"role": "user", "content": "Second message"},
|
||||
]
|
||||
original_list_id = id(messages)
|
||||
|
||||
mock_llm = MagicMock()
|
||||
mock_llm.get_context_window_size.return_value = 1000
|
||||
mock_llm.call.return_value = "Summary"
|
||||
|
||||
mock_i18n = MagicMock()
|
||||
mock_i18n.slice.side_effect = lambda key: {
|
||||
"summarizer_system_message": "Summarize.",
|
||||
"summarize_instruction": "Summarize: {group}",
|
||||
"summary": "Summary: {merged_summary}",
|
||||
}.get(key, "")
|
||||
|
||||
summarize_messages(
|
||||
messages=messages,
|
||||
llm=mock_llm,
|
||||
callbacks=[],
|
||||
i18n=mock_i18n,
|
||||
)
|
||||
|
||||
assert id(messages) == original_list_id
|
||||
assert len(messages) == 1
|
||||
|
||||
Reference in New Issue
Block a user