mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 00:28:31 +00:00
120 lines
3.6 KiB
Python
120 lines
3.6 KiB
Python
# tests/memory/test_entity_memory.py
|
|
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
from crewai.memory.entity.entity_memory import EntityMemory
|
|
from crewai.memory.entity.entity_memory_item import EntityMemoryItem
|
|
from crewai.memory.storage.mem0_storage import Mem0Storage
|
|
from crewai.memory.storage.rag_storage import RAGStorage
|
|
|
|
|
|
@pytest.fixture
def mock_rag_storage():
    """Provide a MagicMock constrained to the RAGStorage interface."""
    storage_double = MagicMock(spec=RAGStorage)
    return storage_double
|
|
|
|
|
|
@pytest.fixture
def mock_mem0_storage():
    """Provide a MagicMock constrained to the Mem0Storage interface."""
    storage_double = MagicMock(spec=Mem0Storage)
    return storage_double
|
|
|
|
|
|
@pytest.fixture
def entity_memory_rag(mock_rag_storage):
    """Build an EntityMemory while RAGStorage is patched out.

    The patch is only active during construction: any call to the
    ``RAGStorage`` name inside the entity_memory module returns
    ``mock_rag_storage`` instead of a real storage object.
    """
    patch_target = "crewai.memory.entity.entity_memory.RAGStorage"
    with patch(patch_target, return_value=mock_rag_storage):
        memory = EntityMemory()
    return memory
|
|
|
|
|
|
@pytest.fixture
def entity_memory_mem0(mock_mem0_storage):
    """Build a mem0-configured EntityMemory while Mem0Storage is patched out.

    The patch is only active during construction: any call to the
    ``Mem0Storage`` name inside the entity_memory module returns
    ``mock_mem0_storage`` instead of a real storage object.
    """
    patch_target = "crewai.memory.entity.entity_memory.Mem0Storage"
    with patch(patch_target, return_value=mock_mem0_storage):
        memory = EntityMemory(memory_provider="mem0")
    return memory
|
|
|
|
|
|
def test_save_rag_storage(entity_memory_rag, mock_rag_storage):
    """Saving through the RAG-backed memory forwards a one-line summary.

    The storage mock must receive exactly one ``save`` call carrying the
    ``name(type): description`` string and the item's metadata.
    """
    entity = EntityMemoryItem(
        name="John Doe",
        type="Person",
        description="A software engineer",
        relationships="Works at TechCorp",
    )

    entity_memory_rag.save(entity)

    mock_rag_storage.save.assert_called_once_with(
        "John Doe(Person): A software engineer",
        entity.metadata,
    )
|
|
|
|
|
|
def test_save_mem0_storage(entity_memory_mem0, mock_mem0_storage):
    """Saving through the mem0-backed memory forwards a multi-line entity prompt.

    The saved payload is compared line by line with surrounding whitespace
    stripped, so the check does not depend on how the triple-quoted template
    happens to be indented in the production module or in this file.
    """
    item = EntityMemoryItem(
        name="John Doe",
        type="Person",
        description="A software engineer",
        relationships="Works at TechCorp",
    )

    entity_memory_mem0.save(item)

    mock_mem0_storage.save.assert_called_once()
    recorded_call = mock_mem0_storage.save.call_args
    # The storage is expected to be called positionally as (data, metadata)
    # with no keyword arguments.
    assert recorded_call.kwargs == {}
    saved_data, saved_metadata = recorded_call.args

    saved_lines = [line.strip() for line in saved_data.strip().splitlines()]
    assert saved_lines == [
        "Remember details about the following entity:",
        "Name: John Doe",
        "Type: Person",
        "Entity Description: A software engineer",
    ]
    assert saved_metadata == item.metadata
|
|
|
|
|
|
def test_search(entity_memory_rag, mock_rag_storage):
    """Positional search arguments are forwarded to storage as keywords."""
    expected_kwargs = {
        "query": "software engineer",
        "limit": 5,
        "filters": {"type": "Person"},
        "score_threshold": 0.7,
    }

    # Call positionally, in the order: query, limit, filters, score_threshold.
    entity_memory_rag.search(
        expected_kwargs["query"],
        expected_kwargs["limit"],
        expected_kwargs["filters"],
        expected_kwargs["score_threshold"],
    )

    mock_rag_storage.search.assert_called_once_with(**expected_kwargs)
|
|
|
|
|
|
def test_reset(entity_memory_rag, mock_rag_storage):
    """Resetting the memory triggers exactly one reset on the storage."""
    entity_memory_rag.reset()

    reset_spy = mock_rag_storage.reset
    reset_spy.assert_called_once()
|
|
|
|
|
|
def test_reset_error(entity_memory_rag, mock_rag_storage):
    """A storage failure during reset surfaces with a wrapped error message."""
    expected_message = (
        "An error occurred while resetting the entity memory: Reset error"
    )
    mock_rag_storage.reset.side_effect = Exception("Reset error")

    with pytest.raises(Exception) as raised:
        entity_memory_rag.reset()

    assert str(raised.value) == expected_message
|
|
|
|
|
|
@pytest.mark.parametrize("memory_provider", [None, "other"])
def test_init_with_rag_storage(memory_provider):
    """With no provider or an unrecognised one, RAGStorage is constructed once."""
    patch_target = "crewai.memory.entity.entity_memory.RAGStorage"
    with patch(patch_target) as rag_storage_cls:
        EntityMemory(memory_provider=memory_provider)
        rag_storage_cls.assert_called_once()
|
|
|
|
|
|
def test_init_with_mem0_storage():
    """With the "mem0" provider, Mem0Storage is constructed exactly once."""
    patch_target = "crewai.memory.entity.entity_memory.Mem0Storage"
    with patch(patch_target) as mem0_storage_cls:
        EntityMemory(memory_provider="mem0")
        mem0_storage_cls.assert_called_once()
|
|
|
|
|
|
def test_init_with_custom_storage():
    """A storage object passed by the caller is exposed as ``.storage``."""
    supplied_storage = MagicMock()

    memory = EntityMemory(storage=supplied_storage)

    assert memory.storage == supplied_storage
|