mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-20 13:28:13 +00:00
Add addtiional validation and tests. Also, if memory was set to True on a Crew, you were required to have a mem0 key. Fixed this direct dependency.
This commit is contained in:
147
tests/memory/contextual_memory_test.py
Normal file
147
tests/memory/contextual_memory_test.py
Normal file
@@ -0,0 +1,147 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from crewai.memory import EntityMemory, LongTermMemory, ShortTermMemory, UserMemory
|
||||
from crewai.memory.contextual.contextual_memory import ContextualMemory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_memories():
|
||||
return {
|
||||
"stm": MagicMock(spec=ShortTermMemory),
|
||||
"ltm": MagicMock(spec=LongTermMemory),
|
||||
"em": MagicMock(spec=EntityMemory),
|
||||
"um": MagicMock(spec=UserMemory),
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def contextual_memory_mem0(mock_memories):
|
||||
return ContextualMemory(
|
||||
memory_provider="mem0",
|
||||
stm=mock_memories["stm"],
|
||||
ltm=mock_memories["ltm"],
|
||||
em=mock_memories["em"],
|
||||
um=mock_memories["um"],
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def contextual_memory_other(mock_memories):
|
||||
return ContextualMemory(
|
||||
memory_provider="other",
|
||||
stm=mock_memories["stm"],
|
||||
ltm=mock_memories["ltm"],
|
||||
em=mock_memories["em"],
|
||||
um=mock_memories["um"],
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def contextual_memory_none(mock_memories):
|
||||
return ContextualMemory(
|
||||
memory_provider=None,
|
||||
stm=mock_memories["stm"],
|
||||
ltm=mock_memories["ltm"],
|
||||
em=mock_memories["em"],
|
||||
um=mock_memories["um"],
|
||||
)
|
||||
|
||||
|
||||
def test_build_context_for_task_mem0(contextual_memory_mem0, mock_memories):
|
||||
task = MagicMock(description="Test task")
|
||||
context = "Additional context"
|
||||
|
||||
mock_memories["stm"].search.return_value = ["Recent insight"]
|
||||
mock_memories["ltm"].search.return_value = [
|
||||
{"metadata": {"suggestions": ["Historical data"]}}
|
||||
]
|
||||
mock_memories["em"].search.return_value = [{"memory": "Entity memory"}]
|
||||
mock_memories["um"].search.return_value = [{"memory": "User memory"}]
|
||||
|
||||
result = contextual_memory_mem0.build_context_for_task(task, context)
|
||||
|
||||
assert "Recent Insights:" in result
|
||||
assert "Historical Data:" in result
|
||||
assert "Entities:" in result
|
||||
assert "User memories/preferences:" in result
|
||||
|
||||
|
||||
def test_build_context_for_task_other_provider(contextual_memory_other, mock_memories):
|
||||
task = MagicMock(description="Test task")
|
||||
context = "Additional context"
|
||||
|
||||
mock_memories["stm"].search.return_value = ["Recent insight"]
|
||||
mock_memories["ltm"].search.return_value = [
|
||||
{"metadata": {"suggestions": ["Historical data"]}}
|
||||
]
|
||||
mock_memories["em"].search.return_value = [{"context": "Entity context"}]
|
||||
mock_memories["um"].search.return_value = [{"memory": "User memory"}]
|
||||
|
||||
result = contextual_memory_other.build_context_for_task(task, context)
|
||||
|
||||
assert "Recent Insights:" in result
|
||||
assert "Historical Data:" in result
|
||||
assert "Entities:" in result
|
||||
assert "User memories/preferences:" not in result
|
||||
|
||||
|
||||
def test_build_context_for_task_none_provider(contextual_memory_none, mock_memories):
|
||||
task = MagicMock(description="Test task")
|
||||
context = "Additional context"
|
||||
|
||||
mock_memories["stm"].search.return_value = ["Recent insight"]
|
||||
mock_memories["ltm"].search.return_value = [
|
||||
{"metadata": {"suggestions": ["Historical data"]}}
|
||||
]
|
||||
mock_memories["em"].search.return_value = [{"context": "Entity context"}]
|
||||
mock_memories["um"].search.return_value = [{"memory": "User memory"}]
|
||||
|
||||
result = contextual_memory_none.build_context_for_task(task, context)
|
||||
|
||||
assert "Recent Insights:" in result
|
||||
assert "Historical Data:" in result
|
||||
assert "Entities:" in result
|
||||
assert "User memories/preferences:" not in result
|
||||
|
||||
|
||||
def test_fetch_entity_context_mem0(contextual_memory_mem0, mock_memories):
|
||||
mock_memories["em"].search.return_value = [
|
||||
{"memory": "Entity 1"},
|
||||
{"memory": "Entity 2"},
|
||||
]
|
||||
result = contextual_memory_mem0._fetch_entity_context("query")
|
||||
expected_result = "Entities:\n- Entity 1\n- Entity 2"
|
||||
assert result == expected_result
|
||||
|
||||
|
||||
def test_fetch_entity_context_other_provider(contextual_memory_other, mock_memories):
|
||||
mock_memories["em"].search.return_value = [
|
||||
{"context": "Entity 1"},
|
||||
{"context": "Entity 2"},
|
||||
]
|
||||
result = contextual_memory_other._fetch_entity_context("query")
|
||||
expected_result = "Entities:\n- Entity 1\n- Entity 2"
|
||||
assert result == expected_result
|
||||
|
||||
|
||||
def test_user_memories_only_for_mem0(contextual_memory_mem0, mock_memories):
|
||||
mock_memories["um"].search.return_value = [{"memory": "User memory"}]
|
||||
|
||||
# Test for mem0 provider
|
||||
result_mem0 = contextual_memory_mem0._fetch_user_memories("query")
|
||||
assert "User memories/preferences:" in result_mem0
|
||||
assert "User memory" in result_mem0
|
||||
|
||||
# Additional test to ensure user memories are included/excluded in the full context
|
||||
task = MagicMock(description="Test task")
|
||||
context = "Additional context"
|
||||
mock_memories["stm"].search.return_value = ["Recent insight"]
|
||||
mock_memories["ltm"].search.return_value = [
|
||||
{"metadata": {"suggestions": ["Historical data"]}}
|
||||
]
|
||||
mock_memories["em"].search.return_value = [{"memory": "Entity memory"}]
|
||||
|
||||
full_context_mem0 = contextual_memory_mem0.build_context_for_task(task, context)
|
||||
assert "User memories/preferences:" in full_context_mem0
|
||||
assert "User memory" in full_context_mem0
|
||||
119
tests/memory/entity_memory_test.py
Normal file
119
tests/memory/entity_memory_test.py
Normal file
@@ -0,0 +1,119 @@
|
||||
# tests/memory/test_entity_memory.py
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from crewai.memory.entity.entity_memory import EntityMemory
|
||||
from crewai.memory.entity.entity_memory_item import EntityMemoryItem
|
||||
from crewai.memory.storage.mem0_storage import Mem0Storage
|
||||
from crewai.memory.storage.rag_storage import RAGStorage
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_rag_storage():
|
||||
"""Fixture to create a mock RAGStorage instance"""
|
||||
return MagicMock(spec=RAGStorage)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_mem0_storage():
|
||||
"""Fixture to create a mock Mem0Storage instance"""
|
||||
return MagicMock(spec=Mem0Storage)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def entity_memory_rag(mock_rag_storage):
|
||||
"""Fixture to create an EntityMemory instance with RAGStorage"""
|
||||
with patch(
|
||||
"crewai.memory.entity.entity_memory.RAGStorage", return_value=mock_rag_storage
|
||||
):
|
||||
return EntityMemory()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def entity_memory_mem0(mock_mem0_storage):
|
||||
"""Fixture to create an EntityMemory instance with Mem0Storage"""
|
||||
with patch(
|
||||
"crewai.memory.entity.entity_memory.Mem0Storage", return_value=mock_mem0_storage
|
||||
):
|
||||
return EntityMemory(memory_provider="mem0")
|
||||
|
||||
|
||||
def test_save_rag_storage(entity_memory_rag, mock_rag_storage):
|
||||
item = EntityMemoryItem(
|
||||
name="John Doe",
|
||||
type="Person",
|
||||
description="A software engineer",
|
||||
relationships="Works at TechCorp",
|
||||
)
|
||||
entity_memory_rag.save(item)
|
||||
|
||||
expected_data = "John Doe(Person): A software engineer"
|
||||
mock_rag_storage.save.assert_called_once_with(expected_data, item.metadata)
|
||||
|
||||
|
||||
def test_save_mem0_storage(entity_memory_mem0, mock_mem0_storage):
|
||||
item = EntityMemoryItem(
|
||||
name="John Doe",
|
||||
type="Person",
|
||||
description="A software engineer",
|
||||
relationships="Works at TechCorp",
|
||||
)
|
||||
entity_memory_mem0.save(item)
|
||||
|
||||
expected_data = """
|
||||
Remember details about the following entity:
|
||||
Name: John Doe
|
||||
Type: Person
|
||||
Entity Description: A software engineer
|
||||
"""
|
||||
mock_mem0_storage.save.assert_called_once_with(expected_data, item.metadata)
|
||||
|
||||
|
||||
def test_search(entity_memory_rag, mock_rag_storage):
|
||||
query = "software engineer"
|
||||
limit = 5
|
||||
filters = {"type": "Person"}
|
||||
score_threshold = 0.7
|
||||
|
||||
entity_memory_rag.search(query, limit, filters, score_threshold)
|
||||
|
||||
mock_rag_storage.search.assert_called_once_with(
|
||||
query=query, limit=limit, filters=filters, score_threshold=score_threshold
|
||||
)
|
||||
|
||||
|
||||
def test_reset(entity_memory_rag, mock_rag_storage):
|
||||
entity_memory_rag.reset()
|
||||
mock_rag_storage.reset.assert_called_once()
|
||||
|
||||
|
||||
def test_reset_error(entity_memory_rag, mock_rag_storage):
|
||||
mock_rag_storage.reset.side_effect = Exception("Reset error")
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
entity_memory_rag.reset()
|
||||
|
||||
assert (
|
||||
str(exc_info.value)
|
||||
== "An error occurred while resetting the entity memory: Reset error"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("memory_provider", [None, "other"])
|
||||
def test_init_with_rag_storage(memory_provider):
|
||||
with patch("crewai.memory.entity.entity_memory.RAGStorage") as mock_rag_storage:
|
||||
EntityMemory(memory_provider=memory_provider)
|
||||
mock_rag_storage.assert_called_once()
|
||||
|
||||
|
||||
def test_init_with_mem0_storage():
|
||||
with patch("crewai.memory.entity.entity_memory.Mem0Storage") as mock_mem0_storage:
|
||||
EntityMemory(memory_provider="mem0")
|
||||
mock_mem0_storage.assert_called_once()
|
||||
|
||||
|
||||
def test_init_with_custom_storage():
|
||||
custom_storage = MagicMock()
|
||||
entity_memory = EntityMemory(storage=custom_storage)
|
||||
assert entity_memory.storage == custom_storage
|
||||
@@ -1,29 +1,125 @@
|
||||
import pytest
|
||||
# tests/memory/long_term_memory_test.py
|
||||
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from crewai.memory.long_term.long_term_memory import LongTermMemory
|
||||
from crewai.memory.long_term.long_term_memory_item import LongTermMemoryItem
|
||||
from crewai.memory.storage.ltm_sqlite_storage import LTMSQLiteStorage
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def long_term_memory():
|
||||
"""Fixture to create a LongTermMemory instance"""
|
||||
return LongTermMemory()
|
||||
def mock_storage():
|
||||
"""Fixture to create a mock LTMSQLiteStorage instance"""
|
||||
return MagicMock(spec=LTMSQLiteStorage)
|
||||
|
||||
|
||||
def test_save_and_search(long_term_memory):
|
||||
@pytest.fixture
|
||||
def long_term_memory(mock_storage):
|
||||
"""Fixture to create a LongTermMemory instance with mock storage"""
|
||||
return LongTermMemory(storage=mock_storage)
|
||||
|
||||
|
||||
def test_save(long_term_memory, mock_storage):
|
||||
memory = LongTermMemoryItem(
|
||||
agent="test_agent",
|
||||
task="test_task",
|
||||
expected_output="test_output",
|
||||
datetime="test_datetime",
|
||||
datetime="2023-01-01 12:00:00",
|
||||
quality=0.5,
|
||||
metadata={"task": "test_task", "quality": 0.5},
|
||||
metadata={"additional_info": "test_info"},
|
||||
)
|
||||
long_term_memory.save(memory)
|
||||
find = long_term_memory.search("test_task", latest_n=5)[0]
|
||||
assert find["score"] == 0.5
|
||||
assert find["datetime"] == "test_datetime"
|
||||
assert find["metadata"]["agent"] == "test_agent"
|
||||
assert find["metadata"]["quality"] == 0.5
|
||||
assert find["metadata"]["task"] == "test_task"
|
||||
assert find["metadata"]["expected_output"] == "test_output"
|
||||
|
||||
expected_metadata = {
|
||||
"additional_info": "test_info",
|
||||
"agent": "test_agent",
|
||||
"expected_output": "test_output",
|
||||
"quality": 0.5, # Include quality in expected metadata
|
||||
}
|
||||
mock_storage.save.assert_called_once_with(
|
||||
task_description="test_task",
|
||||
score=0.5,
|
||||
metadata=expected_metadata,
|
||||
datetime="2023-01-01 12:00:00",
|
||||
)
|
||||
|
||||
|
||||
def test_search(long_term_memory, mock_storage):
|
||||
mock_storage.load.return_value = [
|
||||
{
|
||||
"metadata": {
|
||||
"agent": "test_agent",
|
||||
"expected_output": "test_output",
|
||||
"task": "test_task",
|
||||
},
|
||||
"datetime": "2023-01-01 12:00:00",
|
||||
"score": 0.5,
|
||||
}
|
||||
]
|
||||
|
||||
result = long_term_memory.search("test_task", latest_n=5)
|
||||
|
||||
mock_storage.load.assert_called_once_with("test_task", 5)
|
||||
assert len(result) == 1
|
||||
assert result[0]["metadata"]["agent"] == "test_agent"
|
||||
assert result[0]["metadata"]["expected_output"] == "test_output"
|
||||
assert result[0]["metadata"]["task"] == "test_task"
|
||||
assert result[0]["datetime"] == "2023-01-01 12:00:00"
|
||||
assert result[0]["score"] == 0.5
|
||||
|
||||
|
||||
def test_save_with_minimal_metadata(long_term_memory, mock_storage):
|
||||
memory = LongTermMemoryItem(
|
||||
agent="minimal_agent",
|
||||
task="minimal_task",
|
||||
expected_output="minimal_output",
|
||||
datetime="2023-01-01 12:00:00",
|
||||
quality=0.3,
|
||||
metadata={},
|
||||
)
|
||||
long_term_memory.save(memory)
|
||||
|
||||
expected_metadata = {
|
||||
"agent": "minimal_agent",
|
||||
"expected_output": "minimal_output",
|
||||
"quality": 0.3, # Include quality in expected metadata
|
||||
}
|
||||
mock_storage.save.assert_called_once_with(
|
||||
task_description="minimal_task",
|
||||
score=0.3,
|
||||
metadata=expected_metadata,
|
||||
datetime="2023-01-01 12:00:00",
|
||||
)
|
||||
|
||||
|
||||
def test_reset(long_term_memory, mock_storage):
|
||||
long_term_memory.reset()
|
||||
mock_storage.reset.assert_called_once()
|
||||
|
||||
|
||||
def test_search_with_no_results(long_term_memory, mock_storage):
|
||||
mock_storage.load.return_value = []
|
||||
result = long_term_memory.search("nonexistent_task")
|
||||
assert result == []
|
||||
|
||||
|
||||
def test_init_with_default_storage():
|
||||
with patch(
|
||||
"crewai.memory.long_term.long_term_memory.LTMSQLiteStorage"
|
||||
) as mock_storage_class:
|
||||
LongTermMemory()
|
||||
mock_storage_class.assert_called_once()
|
||||
|
||||
|
||||
def test_init_with_custom_storage():
|
||||
custom_storage = MagicMock()
|
||||
memory = LongTermMemory(storage=custom_storage)
|
||||
assert memory.storage == custom_storage
|
||||
|
||||
|
||||
@pytest.mark.parametrize("latest_n", [1, 3, 5, 10])
|
||||
def test_search_with_different_latest_n(long_term_memory, mock_storage, latest_n):
|
||||
long_term_memory.search("test_task", latest_n=latest_n)
|
||||
mock_storage.load.assert_called_once_with("test_task", latest_n)
|
||||
|
||||
Reference in New Issue
Block a user