import sqlite3
import threading
from concurrent.futures import ThreadPoolExecutor
from types import SimpleNamespace
from unittest.mock import MagicMock, patch

import pytest

from crewai.llm import LLM
from crewai.memory.storage.llm_response_cache_storage import LLMResponseCacheStorage
from crewai.utilities.llm_response_cache_handler import LLMResponseCacheHandler


@pytest.fixture
def handler():
    handler = LLMResponseCacheHandler()
    handler.storage.add = MagicMock()
    handler.storage.get = MagicMock()
    return handler


def create_mock_response(content):
    """Create a properly structured mock response object for litellm.completion."""
    message = SimpleNamespace(content=content)
    choice = SimpleNamespace(message=message)
    response = SimpleNamespace(choices=[choice])
    return response
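

# Usage note: the tests below patch litellm.completion to return this shape,
# and llm.call then reads the content via the same attribute path the helper
# reproduces:
#
#     create_mock_response("hi").choices[0].message.content  # -> "hi"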


@pytest.mark.vcr(filter_headers=["authorization"])
def test_llm_recording(handler):
    handler.start_recording()

    llm = LLM(model="gpt-4o-mini")
    llm.set_response_cache_handler(handler)

    messages = [{"role": "user", "content": "Hello, world!"}]

    with patch("litellm.completion") as mock_completion:
        mock_completion.return_value = create_mock_response("Hello, human!")

        response = llm.call(messages)

        assert response == "Hello, human!"

        handler.storage.add.assert_called_once_with(
            "gpt-4o-mini", messages, "Hello, human!"
        )


@pytest.mark.vcr(filter_headers=["authorization"])
def test_llm_replaying(handler):
    handler.start_replaying()
    handler.storage.get.return_value = "Cached response"

    llm = LLM(model="gpt-4o-mini")
    llm.set_response_cache_handler(handler)

    messages = [{"role": "user", "content": "Hello, world!"}]

    with patch("litellm.completion") as mock_completion:
        response = llm.call(messages)

        assert response == "Cached response"

        mock_completion.assert_not_called()

        handler.storage.get.assert_called_once_with("gpt-4o-mini", messages)


@pytest.mark.vcr(filter_headers=["authorization"])
def test_llm_replay_fallback(handler):
    handler.start_replaying()
    handler.storage.get.return_value = None

    llm = LLM(model="gpt-4o-mini")
    llm.set_response_cache_handler(handler)

    messages = [{"role": "user", "content": "Hello, world!"}]

    with patch("litellm.completion") as mock_completion:
        mock_completion.return_value = create_mock_response("Hello, human!")

        response = llm.call(messages)

        assert response == "Hello, human!"

        mock_completion.assert_called_once()
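

# Taken together, the three tests above cover the full record/replay cycle.
# As a usage sketch, using only the calls exercised above (the first pass
# hits the provider, the second is served from the cache):
#
#     handler = LLMResponseCacheHandler()
#     llm = LLM(model="gpt-4o-mini")
#     llm.set_response_cache_handler(handler)
#
#     handler.start_recording()
#     first = llm.call(messages)   # real completion, stored in the cache
#
#     handler.start_replaying()
#     second = llm.call(messages)  # returned from the cache, no API call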


@pytest.mark.vcr(filter_headers=["authorization"])
def test_cache_error_handling():
    """Test that errors during cache operations are handled gracefully."""
    handler = LLMResponseCacheHandler()

    # Make every storage read and write raise, simulating a broken database.
    handler.storage.add = MagicMock(side_effect=sqlite3.Error("Mock DB error"))
    handler.storage.get = MagicMock(side_effect=sqlite3.Error("Mock DB error"))

    handler.start_recording()

    # Caching a response must not propagate the storage error.
    handler.cache_response("model", [{"role": "user", "content": "test"}], "response")

    handler.start_replaying()

    # A failed lookup should degrade to a cache miss rather than raise.
    assert handler.get_cached_response("model", [{"role": "user", "content": "test"}]) is None


@pytest.mark.vcr(filter_headers=["authorization"])
def test_cache_expiration():
    """Test that cache expiration works correctly."""
    conn = sqlite3.connect(":memory:")
    cursor = conn.cursor()
    cursor.execute(
        """
        CREATE TABLE IF NOT EXISTS llm_response_cache (
            request_hash TEXT PRIMARY KEY,
            model TEXT,
            messages TEXT,
            response TEXT,
            timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
        )
        """
    )
    conn.commit()

    storage = LLMResponseCacheStorage(":memory:")

    # Reuse the single shared connection above: every fresh ":memory:"
    # connection would otherwise open a separate, empty database.
    original_get_connection = storage._get_connection
    storage._get_connection = lambda: conn

    try:
        model = "test-model"
        messages = [{"role": "user", "content": "test"}]
        response = "test response"

        storage.add(model, messages, response)

        assert storage.get(model, messages) == response

        # max_age_days=0 expires everything, including the row just written.
        storage.cleanup_expired_cache(max_age_days=0)

        assert storage.get(model, messages) is None
    finally:
        storage._get_connection = original_get_connection
        conn.close()


@pytest.mark.vcr(filter_headers=["authorization"])
def test_concurrent_cache_access():
    """Test that concurrent cache access works correctly."""
    pytest.skip("SQLite in-memory databases are not shared between threads")

    # A file-backed database would lift that restriction; see the sketch
    # below this test.
    # storage = LLMResponseCacheStorage(temp_db.name)
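

# A minimal sketch of what the concurrent test could look like with a
# file-backed database, which unlike ":memory:" can be shared across threads.
# Left commented out because it rests on assumptions not verified here: that
# LLMResponseCacheStorage accepts a file path (as the commented-out line
# above suggests) and that its add/get methods are safe to call from worker
# threads.
#
# def test_concurrent_cache_access_file_backed(tmp_path):
#     storage = LLMResponseCacheStorage(str(tmp_path / "cache.db"))
#
#     def worker(i):
#         # Each worker writes its own entry, then reads it back.
#         messages = [{"role": "user", "content": f"prompt {i}"}]
#         storage.add("test-model", messages, f"response {i}")
#         return storage.get("test-model", messages)
#
#     with ThreadPoolExecutor(max_workers=8) as executor:
#         results = list(executor.map(worker, range(32)))
#
#     assert results == [f"response {i}" for i in range(32)]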