Files
crewAI/tests/storage/test_mem0_storage_v2.py
2025-05-07 21:53:07 +00:00

291 lines
10 KiB
Python

import os
from unittest.mock import MagicMock, patch
import pytest
from mem0.client.main import MemoryClient
from mem0.memory.main import Memory
from crewai.agent import Agent
from crewai.crew import Crew
from crewai.memory.storage.mem0_storage import Mem0Storage
from crewai.task import Task
class MockCrew:
    """Lightweight stand-in for ``Crew`` used to build ``Mem0Storage`` in tests.

    Only the attributes these tests rely on are provided: the memory
    configuration dict and a single mocked agent whose role is "Test Agent".
    """

    def __init__(self, memory_config):
        test_agent = MagicMock(role="Test Agent")
        self.agents = [test_agent]
        self.memory_config = memory_config
@pytest.fixture
def mock_mem0_memory_client():
    """Provide a ``MagicMock`` constrained to the ``MemoryClient`` interface.

    Because the mock is created with ``spec=MemoryClient``, reading an
    attribute that ``MemoryClient`` does not define raises ``AttributeError``.
    """
    return MagicMock(spec=MemoryClient)
@pytest.fixture
def mem0_storage_with_v2_api(mock_mem0_memory_client):
    """Build a short-term ``Mem0Storage`` configured for the mem0 v2 API.

    ``MemoryClient.__new__`` is patched while the storage is constructed, so
    any ``MemoryClient`` created during setup is the mock fixture.

    Returns:
        tuple: ``(storage, mock_client)`` so tests can inspect client calls.
    """
    v2_config = {
        "user_id": "test_user",
        "api_key": "ABCDEFGH",
        "version": "v2",  # request the v2 API explicitly
    }
    with patch.object(MemoryClient, "__new__", return_value=mock_mem0_memory_client):
        storage = Mem0Storage(
            type="short_term",
            crew=MockCrew(memory_config={"provider": "mem0", "config": v2_config}),
        )
        return storage, mock_mem0_memory_client
@pytest.fixture
def mem0_storage_with_run_id(mock_mem0_memory_client):
    """Build a short-term v2 ``Mem0Storage`` that also carries a ``run_id``.

    The run id ``"test-session-123"`` is placed in the mem0 config;
    ``MemoryClient.__new__`` is patched during construction so the mock
    fixture stands in for the real client.

    Returns:
        tuple: ``(storage, mock_client)``.
    """
    provider_config = dict(
        user_id="test_user",
        api_key="ABCDEFGH",
        version="v2",
        run_id="test-session-123",  # session-scoped run id under test
    )
    memory_config = dict(provider="mem0", config=provider_config)
    with patch.object(MemoryClient, "__new__", return_value=mock_mem0_memory_client):
        storage = Mem0Storage(type="short_term", crew=MockCrew(memory_config=memory_config))
        return storage, mock_mem0_memory_client
@pytest.fixture
def mem0_storage_with_v1_api(mock_mem0_memory_client):
    """Build a short-term ``Mem0Storage`` pinned to the mem0 v1.1 API.

    ``MemoryClient.__new__`` is patched during construction so the mock
    fixture stands in for the real client.

    Returns:
        tuple: ``(storage, mock_client)``.
    """
    v1_config = {
        "user_id": "test_user",
        "api_key": "ABCDEFGH",
        "version": "v1.1",  # request the legacy v1.1 API explicitly
    }
    with patch.object(MemoryClient, "__new__", return_value=mock_mem0_memory_client):
        crew = MockCrew(memory_config={"provider": "mem0", "config": v1_config})
        storage = Mem0Storage(type="short_term", crew=crew)
        return storage, mock_mem0_memory_client
@pytest.mark.v2_api
def test_mem0_storage_v2_initialization(mem0_storage_with_v2_api):
    """A v2 config without ``run_id`` yields version "v2" and no run id."""
    storage = mem0_storage_with_v2_api[0]
    assert storage.version == "v2"
    assert storage.run_id is None
@pytest.mark.v2_api
def test_mem0_storage_with_run_id_initialization(mem0_storage_with_run_id):
    """A configured ``run_id`` is kept verbatim alongside version "v2"."""
    storage = mem0_storage_with_run_id[0]
    assert storage.version == "v2"
    assert storage.run_id == "test-session-123"
@pytest.mark.v1_api
def test_mem0_storage_v1_initialization(mem0_storage_with_v1_api):
    """A v1.1 config is honoured and leaves ``run_id`` unset."""
    storage = mem0_storage_with_v1_api[0]
    assert storage.version == "v1.1"
    assert storage.run_id is None
@pytest.mark.v2_api
def test_save_method_with_v2_api(mem0_storage_with_v2_api):
    """``save`` on a v2 storage forwards version, agent id and merged metadata.

    No ``run_id`` is configured, so none may reach ``MemoryClient.add``.
    """
    storage, client = mem0_storage_with_v2_api
    client.add = MagicMock()

    storage.save("This is a test memory", {"key": "value"})

    client.add.assert_called_once()
    add_kwargs = client.add.call_args.kwargs
    assert add_kwargs["version"] == "v2"
    assert "run_id" not in add_kwargs
    # The agent role "Test Agent" is expected with its space replaced.
    assert add_kwargs["agent_id"] == "Test_Agent"
    # The storage type is merged into the caller-supplied metadata.
    assert add_kwargs["metadata"] == {"type": "short_term", "key": "value"}
@pytest.mark.v2_api
def test_save_method_with_run_id(mem0_storage_with_run_id):
    """``save`` forwards the configured ``run_id`` to ``MemoryClient.add``."""
    storage, client = mem0_storage_with_run_id
    client.add = MagicMock()

    storage.save("This is a test memory", {"key": "value"})

    client.add.assert_called_once()
    add_kwargs = client.add.call_args.kwargs
    assert add_kwargs["version"] == "v2"
    assert add_kwargs["run_id"] == "test-session-123"
    assert add_kwargs["agent_id"] == "Test_Agent"
    assert add_kwargs["metadata"] == {"type": "short_term", "key": "value"}
@pytest.mark.v2_api
def test_search_method_with_v2_api(mem0_storage_with_v2_api):
    """``search`` on a v2 storage passes query/limit and filters by score.

    The hit scored 0.4 is below the 0.5 threshold and must be dropped.
    No ``run_id`` is configured, so none may reach ``MemoryClient.search``.
    """
    storage, client = mem0_storage_with_v2_api
    client.search = MagicMock(
        return_value={
            "results": [
                {"score": 0.9, "content": "Result 1"},
                {"score": 0.4, "content": "Result 2"},
            ]
        }
    )

    found = storage.search("test query", limit=5, score_threshold=0.5)

    client.search.assert_called_once()
    search_kwargs = client.search.call_args.kwargs
    assert search_kwargs["version"] == "v2"
    assert "run_id" not in search_kwargs
    assert search_kwargs["query"] == "test query"
    assert search_kwargs["limit"] == 5
    assert len(found) == 1
    assert found[0]["content"] == "Result 1"
@pytest.mark.v2_api
def test_search_method_with_run_id(mem0_storage_with_run_id):
    """``search`` forwards the configured ``run_id`` and filters by score."""
    storage, client = mem0_storage_with_run_id
    client.search = MagicMock(
        return_value={
            "results": [
                {"score": 0.9, "content": "Result 1"},
                {"score": 0.4, "content": "Result 2"},
            ]
        }
    )

    found = storage.search("test query", limit=5, score_threshold=0.5)

    client.search.assert_called_once()
    search_kwargs = client.search.call_args.kwargs
    assert search_kwargs["version"] == "v2"
    assert search_kwargs["run_id"] == "test-session-123"
    assert search_kwargs["query"] == "test query"
    assert search_kwargs["limit"] == 5
    assert len(found) == 1
    assert found[0]["content"] == "Result 1"
@pytest.mark.v2_api
def test_search_method_with_different_result_formats(mem0_storage_with_v2_api):
    """``search`` copes with each response shape the client may return.

    The cases are a ``{"results": [...]}`` dict, a bare list of hits, and an
    unrecognised value (a plain string), which must yield no results. Scores
    below the 0.5 threshold are dropped in every case.
    """
    storage, client = mem0_storage_with_v2_api
    cases = [
        (
            {"results": [{"score": 0.9, "content": "Result 1"}, {"score": 0.4, "content": "Result 2"}]},
            ["Result 1"],
        ),
        (
            [{"score": 0.9, "content": "Result 3"}, {"score": 0.4, "content": "Result 4"}],
            ["Result 3"],
        ),
        ("unexpected format", []),
    ]
    for raw_response, expected_contents in cases:
        client.search = MagicMock(return_value=raw_response)
        found = storage.search("test query", limit=5, score_threshold=0.5)
        assert len(found) == len(expected_contents)
        for hit, expected in zip(found, expected_contents):
            assert hit["content"] == expected
@pytest.mark.parametrize("run_id", [None, "", "test-123", "a" * 256])
@pytest.mark.v2_api
def test_run_id_edge_cases(mock_mem0_memory_client, run_id):
    """Check how edge-case ``run_id`` values are stored and forwarded by ``save``.

    Every parametrized value must be kept verbatim on the storage. Falsy
    values (``None`` and ``""``) must not be forwarded to ``MemoryClient.add``,
    matching the run_id-less v2 fixture. Any non-empty string, including a
    long one, must be forwarded unchanged.
    """
    with patch.object(MemoryClient, "__new__", return_value=mock_mem0_memory_client):
        crew = MockCrew(
            memory_config={
                "provider": "mem0",
                "config": {
                    "user_id": "test_user",
                    "api_key": "ABCDEFGH",
                    "version": "v2",
                    "run_id": run_id,
                },
            }
        )
        mem0_storage = Mem0Storage(type="short_term", crew=crew)
        assert mem0_storage.run_id == run_id

        mock_mem0_memory_client.add = MagicMock()
        mem0_storage.save("test", {})
        add_kwargs = mock_mem0_memory_client.add.call_args[1]
        if run_id:
            assert add_kwargs.get("run_id") == run_id
        else:
            # None must behave like "": no run_id key in the add() call. The
            # original test never called save() for the None case.
            assert "run_id" not in add_kwargs
def test_invalid_version_handling(mock_mem0_memory_client):
    """An unrecognised API version is rejected when the storage is built."""
    bad_memory_config = {
        "provider": "mem0",
        "config": {
            "user_id": "test_user",
            "api_key": "ABCDEFGH",
            "version": "invalid",  # not a supported mem0 API version
        },
    }
    with patch.object(MemoryClient, "__new__", return_value=mock_mem0_memory_client):
        crew = MockCrew(memory_config=bad_memory_config)
        with pytest.raises(ValueError, match="Unsupported version"):
            Mem0Storage(type="short_term", crew=crew)
def test_invalid_run_id_type(mock_mem0_memory_client):
    """A non-string ``run_id`` is rejected when the storage is built."""
    bad_memory_config = {
        "provider": "mem0",
        "config": {
            "user_id": "test_user",
            "api_key": "ABCDEFGH",
            "version": "v2",
            "run_id": 123,  # deliberately an int, not a str
        },
    }
    with patch.object(MemoryClient, "__new__", return_value=mock_mem0_memory_client):
        crew = MockCrew(memory_config=bad_memory_config)
        with pytest.raises(ValueError, match="run_id must be a string"):
            Mem0Storage(type="short_term", crew=crew)