mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-07 15:18:29 +00:00
- Fix TypeAlias annotation in elasticsearch/types.py using TYPE_CHECKING
- Add 'elasticsearch' to _MissingProvider Literal type in base.py
- Remove unused variable in test_client.py
- Add usedforsecurity=False to MD5 hash in config.py for security check

Co-Authored-By: João <joao@crewai.com>
398 lines
19 KiB
Python
398 lines
19 KiB
Python
"""Tests for ElasticsearchClient implementation."""
|
|
|
|
from unittest.mock import AsyncMock, Mock, patch
|
|
|
|
import pytest
|
|
|
|
from crewai.rag.elasticsearch.client import ElasticsearchClient
|
|
from crewai.rag.types import BaseRecord
|
|
from crewai.rag.core.exceptions import ClientMethodMismatchError
|
|
|
|
|
|
@pytest.fixture
def mock_elasticsearch_client():
    """Return a synchronous Elasticsearch client double with canned replies.

    ``indices.exists`` answers ``False`` by default; ``indices.create`` and
    ``indices.delete`` acknowledge, ``indices.get`` returns an empty mapping
    for ``test_index``, ``index`` reports a created document and ``search``
    yields a single hit (``doc1``, score 0.9).
    """
    # The one-hit payload handed back by every ``search`` call.
    search_response = {
        "hits": {
            "hits": [
                {
                    "_id": "doc1",
                    "_score": 0.9,
                    "_source": {
                        "content": "test content",
                        "metadata": {"key": "value"},
                    },
                }
            ]
        }
    }

    es = Mock()
    es.indices = Mock()
    # Dotted keys let configure_mock set return values on nested attributes.
    es.configure_mock(
        **{
            "indices.exists.return_value": False,
            "indices.create.return_value": {"acknowledged": True},
            "indices.get.return_value": {"test_index": {"mappings": {}}},
            "indices.delete.return_value": {"acknowledged": True},
            "index.return_value": {"_id": "test_id", "result": "created"},
            "search.return_value": search_response,
        }
    )
    return es
|
|
|
|
|
|
@pytest.fixture
def mock_async_elasticsearch_client():
    """Return an async Elasticsearch client double with canned replies.

    The container objects are plain ``Mock`` instances, while every callable
    endpoint (``indices.exists/create/get/delete``, ``index``, ``search``) is
    an ``AsyncMock`` so that it can be awaited. Default payloads mirror the
    synchronous fixture: the index does not exist and ``search`` yields a
    single hit (``doc1``, score 0.9).
    """
    # Return value for each awaitable method on ``client.indices``.
    indices_results = {
        "exists": False,
        "create": {"acknowledged": True},
        "get": {"test_index": {"mappings": {}}},
        "delete": {"acknowledged": True},
    }
    # The one-hit payload handed back by every awaited ``search`` call.
    search_response = {
        "hits": {
            "hits": [
                {
                    "_id": "doc1",
                    "_score": 0.9,
                    "_source": {
                        "content": "test content",
                        "metadata": {"key": "value"},
                    },
                }
            ]
        }
    }

    es = Mock()
    es.indices = Mock()
    for method_name, result in indices_results.items():
        setattr(es.indices, method_name, AsyncMock(return_value=result))
    es.index = AsyncMock(return_value={"_id": "test_id", "result": "created"})
    es.search = AsyncMock(return_value=search_response)
    return es
|
|
|
|
|
|
@pytest.fixture
def client(mock_elasticsearch_client) -> ElasticsearchClient:
    """Return an ElasticsearchClient backed by the synchronous mock client.

    The embedding function is a stub that always yields a three-element
    vector, matching the ``vector_dimension=3`` passed to the client.
    """
    embed_stub = Mock(return_value=[0.1, 0.2, 0.3])
    return ElasticsearchClient(
        client=mock_elasticsearch_client,
        embedding_function=embed_stub,
        vector_dimension=3,
        similarity="cosine",
    )
|
|
|
|
|
|
@pytest.fixture
def async_client(mock_async_elasticsearch_client) -> ElasticsearchClient:
    """Return an ElasticsearchClient backed by the async mock client.

    The embedding function is a stub that always yields a three-element
    vector, matching the ``vector_dimension=3`` passed to the client.
    """
    embed_stub = Mock(return_value=[0.1, 0.2, 0.3])
    return ElasticsearchClient(
        client=mock_async_elasticsearch_client,
        embedding_function=embed_stub,
        vector_dimension=3,
        similarity="cosine",
    )
|
|
|
|
|
|
class TestElasticsearchClient:
    """Exercise the sync and async ElasticsearchClient APIs against mocks.

    Each test patches ``_is_sync_client`` / ``_is_async_client`` in
    ``crewai.rag.elasticsearch.client`` so the mock passed to the client is
    reported as the kind the test needs. ``patch`` injects the two resulting
    mocks positionally (innermost decorator first); they are accepted but not
    otherwise used.
    """

    @patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
    @patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
    def test_create_collection(
        self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client
    ):
        """create_collection checks existence, then creates the index with mappings."""
        indices = mock_elasticsearch_client.indices
        indices.exists.return_value = False

        client.create_collection(collection_name="test_index")

        indices.exists.assert_called_once_with(index="test_index")
        indices.create.assert_called_once()
        create_kwargs = indices.create.call_args.kwargs
        assert create_kwargs["index"] == "test_index"
        assert "mappings" in create_kwargs["body"]

    @patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
    @patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
    def test_create_collection_already_exists(
        self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client
    ):
        """create_collection raises ValueError when the index is already present."""
        mock_elasticsearch_client.indices.exists.return_value = True

        expected_error = pytest.raises(
            ValueError, match="Index 'test_index' already exists"
        )
        with expected_error:
            client.create_collection(collection_name="test_index")

    @patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=False)
    @patch("crewai.rag.elasticsearch.client._is_async_client", return_value=True)
    def test_create_collection_wrong_client_type(
        self, mock_is_async, mock_is_sync, mock_async_elasticsearch_client
    ):
        """The sync create_collection rejects a client reported as async."""
        wrapper = ElasticsearchClient(
            client=mock_async_elasticsearch_client,
            embedding_function=Mock(),
        )

        with pytest.raises(ClientMethodMismatchError):
            wrapper.create_collection(collection_name="test_index")

    @pytest.mark.asyncio
    @patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=False)
    @patch("crewai.rag.elasticsearch.client._is_async_client", return_value=True)
    async def test_acreate_collection(
        self, mock_is_async, mock_is_sync, async_client, mock_async_elasticsearch_client
    ):
        """acreate_collection checks existence, then creates the index with mappings."""
        indices = mock_async_elasticsearch_client.indices
        indices.exists.return_value = False

        await async_client.acreate_collection(collection_name="test_index")

        indices.exists.assert_called_once_with(index="test_index")
        indices.create.assert_called_once()
        create_kwargs = indices.create.call_args.kwargs
        assert create_kwargs["index"] == "test_index"
        assert "mappings" in create_kwargs["body"]

    @pytest.mark.asyncio
    @patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=False)
    @patch("crewai.rag.elasticsearch.client._is_async_client", return_value=True)
    async def test_acreate_collection_already_exists(
        self, mock_is_async, mock_is_sync, async_client, mock_async_elasticsearch_client
    ):
        """acreate_collection raises ValueError when the index is already present."""
        mock_async_elasticsearch_client.indices.exists.return_value = True

        expected_error = pytest.raises(
            ValueError, match="Index 'test_index' already exists"
        )
        with expected_error:
            await async_client.acreate_collection(collection_name="test_index")

    @patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
    @patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
    def test_get_or_create_collection(
        self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client
    ):
        """get_or_create_collection returns the ``indices.get`` payload if it exists."""
        indices = mock_elasticsearch_client.indices
        indices.exists.return_value = True

        index_info = client.get_or_create_collection(collection_name="test_index")

        indices.exists.assert_called_once_with(index="test_index")
        indices.get.assert_called_once_with(index="test_index")
        assert index_info == {"test_index": {"mappings": {}}}

    @patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
    @patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
    def test_get_or_create_collection_creates_new(
        self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client
    ):
        """get_or_create_collection creates a missing index and then fetches it."""
        indices = mock_elasticsearch_client.indices
        indices.exists.return_value = False

        client.get_or_create_collection(collection_name="test_index")

        indices.exists.assert_called_once_with(index="test_index")
        indices.create.assert_called_once()
        indices.get.assert_called_once_with(index="test_index")

    @pytest.mark.asyncio
    @patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=False)
    @patch("crewai.rag.elasticsearch.client._is_async_client", return_value=True)
    async def test_aget_or_create_collection(
        self, mock_is_async, mock_is_sync, async_client, mock_async_elasticsearch_client
    ):
        """aget_or_create_collection returns the ``indices.get`` payload if it exists."""
        indices = mock_async_elasticsearch_client.indices
        indices.exists.return_value = True

        index_info = await async_client.aget_or_create_collection(
            collection_name="test_index"
        )

        indices.exists.assert_called_once_with(index="test_index")
        indices.get.assert_called_once_with(index="test_index")
        assert index_info == {"test_index": {"mappings": {}}}

    @patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
    @patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
    def test_add_documents(
        self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client
    ):
        """add_documents indexes one document into the target index with its content."""
        mock_elasticsearch_client.indices.exists.return_value = True
        record: BaseRecord = {
            "content": "test content",
            "metadata": {"key": "value"},
        }
        documents: list[BaseRecord] = [record]

        client.add_documents(collection_name="test_index", documents=documents)

        mock_elasticsearch_client.indices.exists.assert_called_once_with(
            index="test_index"
        )
        mock_elasticsearch_client.index.assert_called_once()
        index_kwargs = mock_elasticsearch_client.index.call_args.kwargs
        assert index_kwargs["index"] == "test_index"
        assert "body" in index_kwargs
        assert index_kwargs["body"]["content"] == "test content"

    @patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
    @patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
    def test_add_documents_empty_list_raises_error(
        self, mock_is_async, mock_is_sync, client
    ):
        """add_documents raises ValueError when given an empty documents list."""
        with pytest.raises(ValueError, match="Documents list cannot be empty"):
            client.add_documents(collection_name="test_index", documents=[])

    @patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
    @patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
    def test_add_documents_index_not_exists(
        self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client
    ):
        """add_documents raises ValueError when the target index is missing."""
        mock_elasticsearch_client.indices.exists.return_value = False
        documents: list[BaseRecord] = [{"content": "test content"}]

        with pytest.raises(ValueError, match="Index 'test_index' does not exist"):
            client.add_documents(collection_name="test_index", documents=documents)

    @pytest.mark.asyncio
    @patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=False)
    @patch("crewai.rag.elasticsearch.client._is_async_client", return_value=True)
    async def test_aadd_documents(
        self, mock_is_async, mock_is_sync, async_client, mock_async_elasticsearch_client
    ):
        """aadd_documents indexes one document into the target index."""
        mock_async_elasticsearch_client.indices.exists.return_value = True
        record: BaseRecord = {
            "content": "test content",
            "metadata": {"key": "value"},
        }
        documents: list[BaseRecord] = [record]

        await async_client.aadd_documents(
            collection_name="test_index", documents=documents
        )

        mock_async_elasticsearch_client.indices.exists.assert_called_once_with(
            index="test_index"
        )
        mock_async_elasticsearch_client.index.assert_called_once()
        index_kwargs = mock_async_elasticsearch_client.index.call_args.kwargs
        assert index_kwargs["index"] == "test_index"
        assert "body" in index_kwargs

    @patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
    @patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
    def test_search(
        self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client
    ):
        """search queries the index and maps the hit to id, content and score."""
        mock_elasticsearch_client.indices.exists.return_value = True

        results = client.search(
            collection_name="test_index", query="test query", limit=5
        )

        mock_elasticsearch_client.indices.exists.assert_called_once_with(
            index="test_index"
        )
        mock_elasticsearch_client.search.assert_called_once()
        search_kwargs = mock_elasticsearch_client.search.call_args.kwargs
        assert search_kwargs["index"] == "test_index"
        assert "body" in search_kwargs

        assert len(results) == 1
        assert results[0]["id"] == "doc1"
        assert results[0]["content"] == "test content"
        assert results[0]["score"] == 0.9

    @patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
    @patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
    def test_search_with_metadata_filter(
        self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client
    ):
        """search with a metadata filter sends a query containing a ``bool`` clause."""
        mock_elasticsearch_client.indices.exists.return_value = True

        client.search(
            collection_name="test_index",
            query="test query",
            metadata_filter={"key": "value"},
        )

        mock_elasticsearch_client.search.assert_called_once()
        request_body = mock_elasticsearch_client.search.call_args.kwargs["body"]
        assert "bool" in request_body["query"]

    @patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
    @patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
    def test_search_index_not_exists(
        self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client
    ):
        """search raises ValueError when the target index is missing."""
        mock_elasticsearch_client.indices.exists.return_value = False

        with pytest.raises(ValueError, match="Index 'test_index' does not exist"):
            client.search(collection_name="test_index", query="test query")

    @pytest.mark.asyncio
    @patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=False)
    @patch("crewai.rag.elasticsearch.client._is_async_client", return_value=True)
    async def test_asearch(
        self, mock_is_async, mock_is_sync, async_client, mock_async_elasticsearch_client
    ):
        """asearch queries the index and maps the hit to id and content."""
        mock_async_elasticsearch_client.indices.exists.return_value = True

        results = await async_client.asearch(
            collection_name="test_index", query="test query", limit=5
        )

        mock_async_elasticsearch_client.indices.exists.assert_called_once_with(
            index="test_index"
        )
        mock_async_elasticsearch_client.search.assert_called_once()

        assert len(results) == 1
        assert results[0]["id"] == "doc1"
        assert results[0]["content"] == "test content"

    @patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
    @patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
    def test_delete_collection(
        self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client
    ):
        """delete_collection checks existence and then deletes the index."""
        indices = mock_elasticsearch_client.indices
        indices.exists.return_value = True

        client.delete_collection(collection_name="test_index")

        indices.exists.assert_called_once_with(index="test_index")
        indices.delete.assert_called_once_with(index="test_index")

    @patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
    @patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
    def test_delete_collection_not_exists(
        self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client
    ):
        """delete_collection raises ValueError when the index is missing."""
        mock_elasticsearch_client.indices.exists.return_value = False

        with pytest.raises(ValueError, match="Index 'test_index' does not exist"):
            client.delete_collection(collection_name="test_index")

    @pytest.mark.asyncio
    @patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=False)
    @patch("crewai.rag.elasticsearch.client._is_async_client", return_value=True)
    async def test_adelete_collection(
        self, mock_is_async, mock_is_sync, async_client, mock_async_elasticsearch_client
    ):
        """adelete_collection checks existence and then deletes the index."""
        indices = mock_async_elasticsearch_client.indices
        indices.exists.return_value = True

        await async_client.adelete_collection(collection_name="test_index")

        indices.exists.assert_called_once_with(index="test_index")
        indices.delete.assert_called_once_with(index="test_index")

    @patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=True)
    @patch("crewai.rag.elasticsearch.client._is_async_client", return_value=False)
    def test_reset(
        self, mock_is_async, mock_is_sync, client, mock_elasticsearch_client
    ):
        """reset deletes every listed index except the dot-prefixed one."""
        indices = mock_elasticsearch_client.indices
        indices.get.return_value = {
            "test_index": {},
            ".system_index": {},
            "another_index": {},
        }

        client.reset()

        indices.get.assert_called_once_with(index="*")
        assert indices.delete.call_count == 2
        deleted_names = [
            recorded.kwargs["index"] for recorded in indices.delete.call_args_list
        ]
        assert "test_index" in deleted_names
        assert "another_index" in deleted_names
        assert ".system_index" not in deleted_names

    @pytest.mark.asyncio
    @patch("crewai.rag.elasticsearch.client._is_sync_client", return_value=False)
    @patch("crewai.rag.elasticsearch.client._is_async_client", return_value=True)
    async def test_areset(
        self, mock_is_async, mock_is_sync, async_client, mock_async_elasticsearch_client
    ):
        """areset lists all indices and issues exactly two deletions."""
        indices = mock_async_elasticsearch_client.indices
        indices.get.return_value = {
            "test_index": {},
            ".system_index": {},
            "another_index": {},
        }

        await async_client.areset()

        indices.get.assert_called_once_with(index="*")
        assert indices.delete.call_count == 2
|