mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-15 02:58:30 +00:00
Add addtiional validation and tests. Also, if memory was set to True on a Crew, you were required to have a mem0 key. Fixed this direct dependency.
This commit is contained in:
119
tests/memory/entity_memory_test.py
Normal file
119
tests/memory/entity_memory_test.py
Normal file
@@ -0,0 +1,119 @@
|
||||
# tests/memory/test_entity_memory.py
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from crewai.memory.entity.entity_memory import EntityMemory
|
||||
from crewai.memory.entity.entity_memory_item import EntityMemoryItem
|
||||
from crewai.memory.storage.mem0_storage import Mem0Storage
|
||||
from crewai.memory.storage.rag_storage import RAGStorage
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_rag_storage():
|
||||
"""Fixture to create a mock RAGStorage instance"""
|
||||
return MagicMock(spec=RAGStorage)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_mem0_storage():
|
||||
"""Fixture to create a mock Mem0Storage instance"""
|
||||
return MagicMock(spec=Mem0Storage)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def entity_memory_rag(mock_rag_storage):
|
||||
"""Fixture to create an EntityMemory instance with RAGStorage"""
|
||||
with patch(
|
||||
"crewai.memory.entity.entity_memory.RAGStorage", return_value=mock_rag_storage
|
||||
):
|
||||
return EntityMemory()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def entity_memory_mem0(mock_mem0_storage):
|
||||
"""Fixture to create an EntityMemory instance with Mem0Storage"""
|
||||
with patch(
|
||||
"crewai.memory.entity.entity_memory.Mem0Storage", return_value=mock_mem0_storage
|
||||
):
|
||||
return EntityMemory(memory_provider="mem0")
|
||||
|
||||
|
||||
def test_save_rag_storage(entity_memory_rag, mock_rag_storage):
|
||||
item = EntityMemoryItem(
|
||||
name="John Doe",
|
||||
type="Person",
|
||||
description="A software engineer",
|
||||
relationships="Works at TechCorp",
|
||||
)
|
||||
entity_memory_rag.save(item)
|
||||
|
||||
expected_data = "John Doe(Person): A software engineer"
|
||||
mock_rag_storage.save.assert_called_once_with(expected_data, item.metadata)
|
||||
|
||||
|
||||
def test_save_mem0_storage(entity_memory_mem0, mock_mem0_storage):
|
||||
item = EntityMemoryItem(
|
||||
name="John Doe",
|
||||
type="Person",
|
||||
description="A software engineer",
|
||||
relationships="Works at TechCorp",
|
||||
)
|
||||
entity_memory_mem0.save(item)
|
||||
|
||||
expected_data = """
|
||||
Remember details about the following entity:
|
||||
Name: John Doe
|
||||
Type: Person
|
||||
Entity Description: A software engineer
|
||||
"""
|
||||
mock_mem0_storage.save.assert_called_once_with(expected_data, item.metadata)
|
||||
|
||||
|
||||
def test_search(entity_memory_rag, mock_rag_storage):
|
||||
query = "software engineer"
|
||||
limit = 5
|
||||
filters = {"type": "Person"}
|
||||
score_threshold = 0.7
|
||||
|
||||
entity_memory_rag.search(query, limit, filters, score_threshold)
|
||||
|
||||
mock_rag_storage.search.assert_called_once_with(
|
||||
query=query, limit=limit, filters=filters, score_threshold=score_threshold
|
||||
)
|
||||
|
||||
|
||||
def test_reset(entity_memory_rag, mock_rag_storage):
|
||||
entity_memory_rag.reset()
|
||||
mock_rag_storage.reset.assert_called_once()
|
||||
|
||||
|
||||
def test_reset_error(entity_memory_rag, mock_rag_storage):
|
||||
mock_rag_storage.reset.side_effect = Exception("Reset error")
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
entity_memory_rag.reset()
|
||||
|
||||
assert (
|
||||
str(exc_info.value)
|
||||
== "An error occurred while resetting the entity memory: Reset error"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("memory_provider", [None, "other"])
|
||||
def test_init_with_rag_storage(memory_provider):
|
||||
with patch("crewai.memory.entity.entity_memory.RAGStorage") as mock_rag_storage:
|
||||
EntityMemory(memory_provider=memory_provider)
|
||||
mock_rag_storage.assert_called_once()
|
||||
|
||||
|
||||
def test_init_with_mem0_storage():
|
||||
with patch("crewai.memory.entity.entity_memory.Mem0Storage") as mock_mem0_storage:
|
||||
EntityMemory(memory_provider="mem0")
|
||||
mock_mem0_storage.assert_called_once()
|
||||
|
||||
|
||||
def test_init_with_custom_storage():
|
||||
custom_storage = MagicMock()
|
||||
entity_memory = EntityMemory(storage=custom_storage)
|
||||
assert entity_memory.storage == custom_storage
|
||||
Reference in New Issue
Block a user