mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 15:48:29 +00:00
* Simplified Qdrant and ChromaDB client initialization * Refactored factory structure and updated tests accordingly
792 lines
31 KiB
Python
792 lines
31 KiB
Python
"""Tests for QdrantClient implementation."""
|
|
|
|
from unittest.mock import AsyncMock, Mock
|
|
|
|
import pytest
|
|
from qdrant_client import AsyncQdrantClient, QdrantClient as SyncQdrantClient
|
|
|
|
from crewai.rag.core.exceptions import ClientMethodMismatchError
|
|
from crewai.rag.qdrant.client import QdrantClient
|
|
from crewai.rag.types import BaseRecord
|
|
|
|
|
|
@pytest.fixture
def mock_qdrant_client():
    """Provide a Mock restricted to the synchronous Qdrant client's interface."""
    sync_stub = Mock(spec=SyncQdrantClient)
    return sync_stub
|
|
|
|
|
|
@pytest.fixture
def mock_async_qdrant_client():
    """Provide a Mock restricted to the asynchronous Qdrant client's interface."""
    async_stub = Mock(spec=AsyncQdrantClient)
    return async_stub
|
|
|
|
|
|
@pytest.fixture
def client(mock_qdrant_client) -> QdrantClient:
    """Build a QdrantClient wrapper around the sync mock for use in tests.

    The embedding function is a Mock that always yields a fixed
    three-element vector.
    """
    embedder = Mock(return_value=[0.1, 0.2, 0.3])
    return QdrantClient(client=mock_qdrant_client, embedding_function=embedder)
|
|
|
|
|
|
@pytest.fixture
def async_client(mock_async_qdrant_client) -> QdrantClient:
    """Build a QdrantClient wrapper around the async mock for use in tests.

    The embedding function is a Mock that always yields a fixed
    three-element vector.
    """
    embedder = Mock(return_value=[0.1, 0.2, 0.3])
    return QdrantClient(
        client=mock_async_qdrant_client,
        embedding_function=embedder,
    )
|
|
|
|
|
|
class TestQdrantClient:
|
|
"""Test suite for QdrantClient."""
|
|
|
|
def test_create_collection(self, client, mock_qdrant_client):
    """Verify create_collection creates a collection that does not yet exist.

    The backend reports the collection as missing; the wrapper must check
    existence once and issue exactly one create call carrying the
    collection name and a non-None vectors configuration.
    """
    name = "test_collection"
    mock_qdrant_client.collection_exists.return_value = False

    client.create_collection(collection_name=name)

    mock_qdrant_client.collection_exists.assert_called_once_with(name)
    create_mock = mock_qdrant_client.create_collection
    create_mock.assert_called_once()
    sent = create_mock.call_args.kwargs
    assert sent["collection_name"] == name
    assert sent["vectors_config"] is not None
|
|
|
|
def test_create_collection_already_exists(self, client, mock_qdrant_client):
    """Verify create_collection raises ValueError when the collection exists."""
    mock_qdrant_client.collection_exists.return_value = True
    expected_message = "Collection 'test_collection' already exists"

    with pytest.raises(ValueError, match=expected_message):
        client.create_collection(collection_name="test_collection")
|
|
|
|
def test_create_collection_wrong_client_type(self, mock_async_qdrant_client):
    """Verify create_collection raises ClientMethodMismatchError on an async backend."""
    mismatched = QdrantClient(
        client=mock_async_qdrant_client,
        embedding_function=Mock(),
    )
    expected_pattern = r"Method create_collection\(\) requires"

    with pytest.raises(ClientMethodMismatchError, match=expected_pattern):
        mismatched.create_collection(collection_name="test_collection")
|
|
|
|
@pytest.mark.asyncio
async def test_acreate_collection(self, async_client, mock_async_qdrant_client):
    """Test that acreate_collection creates a new collection asynchronously.

    Await-based assertions are used (rather than plain call assertions) so
    the test fails if the wrapper invokes the async backend without
    awaiting the returned coroutine.
    """
    mock_async_qdrant_client.collection_exists = AsyncMock(return_value=False)
    mock_async_qdrant_client.create_collection = AsyncMock()

    await async_client.acreate_collection(collection_name="test_collection")

    mock_async_qdrant_client.collection_exists.assert_awaited_once_with(
        "test_collection"
    )
    mock_async_qdrant_client.create_collection.assert_awaited_once()
    await_args = mock_async_qdrant_client.create_collection.await_args
    assert await_args.kwargs["collection_name"] == "test_collection"
    assert await_args.kwargs["vectors_config"] is not None
|
|
|
|
@pytest.mark.asyncio
async def test_acreate_collection_already_exists(
    self, async_client, mock_async_qdrant_client
):
    """Verify acreate_collection raises ValueError when the collection exists."""
    exists_stub = AsyncMock(return_value=True)
    mock_async_qdrant_client.collection_exists = exists_stub
    expected_message = "Collection 'test_collection' already exists"

    with pytest.raises(ValueError, match=expected_message):
        await async_client.acreate_collection(collection_name="test_collection")
|
|
|
|
@pytest.mark.asyncio
async def test_acreate_collection_wrong_client_type(self, mock_qdrant_client):
    """Verify acreate_collection raises ClientMethodMismatchError on a sync backend."""
    mismatched = QdrantClient(
        client=mock_qdrant_client,
        embedding_function=Mock(),
    )
    expected_pattern = r"Method acreate_collection\(\) requires"

    with pytest.raises(ClientMethodMismatchError, match=expected_pattern):
        await mismatched.acreate_collection(collection_name="test_collection")
|
|
|
|
def test_get_or_create_collection_existing(self, client, mock_qdrant_client):
    """Verify get_or_create_collection returns an existing collection as-is.

    When the backend reports the collection exists, no create call may be
    issued and the backend's collection info is returned to the caller.
    """
    name = "test_collection"
    existing_info = Mock()
    mock_qdrant_client.collection_exists.return_value = True
    mock_qdrant_client.get_collection.return_value = existing_info

    outcome = client.get_or_create_collection(collection_name=name)

    mock_qdrant_client.collection_exists.assert_called_once_with(name)
    mock_qdrant_client.get_collection.assert_called_once_with(name)
    mock_qdrant_client.create_collection.assert_not_called()
    assert outcome == existing_info
|
|
|
|
def test_get_or_create_collection_new(self, client, mock_qdrant_client):
    """Verify get_or_create_collection creates a missing collection.

    When the backend reports the collection is absent, one create call is
    expected, followed by a fetch whose result is returned to the caller.
    """
    name = "test_collection"
    created_info = Mock()
    mock_qdrant_client.collection_exists.return_value = False
    mock_qdrant_client.get_collection.return_value = created_info

    outcome = client.get_or_create_collection(collection_name=name)

    mock_qdrant_client.collection_exists.assert_called_once_with(name)
    mock_qdrant_client.create_collection.assert_called_once()
    mock_qdrant_client.get_collection.assert_called_once_with(name)
    assert outcome == created_info
|
|
|
|
def test_get_or_create_collection_wrong_client_type(self, mock_async_qdrant_client):
    """Verify get_or_create_collection raises ClientMethodMismatchError on an async backend."""
    mismatched = QdrantClient(
        client=mock_async_qdrant_client,
        embedding_function=Mock(),
    )
    expected_pattern = r"Method get_or_create_collection\(\) requires"

    with pytest.raises(ClientMethodMismatchError, match=expected_pattern):
        mismatched.get_or_create_collection(collection_name="test_collection")
|
|
|
|
@pytest.mark.asyncio
async def test_aget_or_create_collection_existing(
    self, async_client, mock_async_qdrant_client
):
    """Test aget_or_create_collection returns existing collection.

    Await-based assertions ensure the existence check and the fetch were
    actually awaited, not merely called. No create call may be issued.
    """
    mock_async_qdrant_client.collection_exists = AsyncMock(return_value=True)
    mock_collection_info = Mock()
    mock_async_qdrant_client.get_collection = AsyncMock(
        return_value=mock_collection_info
    )

    result = await async_client.aget_or_create_collection(
        collection_name="test_collection"
    )

    mock_async_qdrant_client.collection_exists.assert_awaited_once_with(
        "test_collection"
    )
    mock_async_qdrant_client.get_collection.assert_awaited_once_with(
        "test_collection"
    )
    mock_async_qdrant_client.create_collection.assert_not_called()
    assert result == mock_collection_info
|
|
|
|
@pytest.mark.asyncio
async def test_aget_or_create_collection_new(
    self, async_client, mock_async_qdrant_client
):
    """Test aget_or_create_collection creates new collection.

    Await-based assertions ensure the existence check, the create call and
    the follow-up fetch were all awaited, not merely called.
    """
    mock_async_qdrant_client.collection_exists = AsyncMock(return_value=False)
    mock_async_qdrant_client.create_collection = AsyncMock()
    mock_collection_info = Mock()
    mock_async_qdrant_client.get_collection = AsyncMock(
        return_value=mock_collection_info
    )

    result = await async_client.aget_or_create_collection(
        collection_name="test_collection"
    )

    mock_async_qdrant_client.collection_exists.assert_awaited_once_with(
        "test_collection"
    )
    mock_async_qdrant_client.create_collection.assert_awaited_once()
    mock_async_qdrant_client.get_collection.assert_awaited_once_with(
        "test_collection"
    )
    assert result == mock_collection_info
|
|
|
|
@pytest.mark.asyncio
async def test_aget_or_create_collection_wrong_client_type(
    self, mock_qdrant_client
):
    """Verify aget_or_create_collection raises ClientMethodMismatchError on a sync backend."""
    mismatched = QdrantClient(
        client=mock_qdrant_client,
        embedding_function=Mock(),
    )
    expected_pattern = r"Method aget_or_create_collection\(\) requires"

    with pytest.raises(ClientMethodMismatchError, match=expected_pattern):
        await mismatched.aget_or_create_collection(
            collection_name="test_collection"
        )
|
|
|
|
def test_add_documents(self, client, mock_qdrant_client):
    """Verify add_documents embeds a record and upserts it as a single point.

    The upserted point must carry the embedding as its vector, and its
    payload must expose the document content plus the metadata keys.
    """
    name = "test_collection"
    vector = [0.1, 0.2, 0.3]
    mock_qdrant_client.collection_exists.return_value = True
    client.embedding_function.return_value = vector

    record: BaseRecord = {
        "content": "Test document",
        "metadata": {"source": "test"},
    }
    docs: list[BaseRecord] = [record]

    client.add_documents(collection_name=name, documents=docs)

    mock_qdrant_client.collection_exists.assert_called_once_with(name)
    client.embedding_function.assert_called_once_with("Test document")
    mock_qdrant_client.upsert.assert_called_once()

    # Inspect the keyword arguments the wrapper forwarded to upsert.
    sent = mock_qdrant_client.upsert.call_args.kwargs
    assert sent["collection_name"] == name
    assert len(sent["points"]) == 1
    stored = sent["points"][0]
    assert stored.vector == [0.1, 0.2, 0.3]
    assert stored.payload["content"] == "Test document"
    assert stored.payload["source"] == "test"
|
|
|
|
def test_add_documents_with_doc_id(self, client, mock_qdrant_client):
    """Verify add_documents keeps a caller-supplied doc_id as the point id."""
    mock_qdrant_client.collection_exists.return_value = True
    client.embedding_function.return_value = [0.1, 0.2, 0.3]

    record: BaseRecord = {
        "doc_id": "custom-id-123",
        "content": "Test document",
        "metadata": {"source": "test"},
    }
    docs: list[BaseRecord] = [record]

    client.add_documents(collection_name="test_collection", documents=docs)

    sent = mock_qdrant_client.upsert.call_args.kwargs
    stored = sent["points"][0]
    assert stored.id == "custom-id-123"
|
|
|
|
def test_add_documents_empty_list(self, client, mock_qdrant_client):
    """Verify add_documents raises ValueError when given no documents."""
    nothing: list[BaseRecord] = []
    expected_message = "Documents list cannot be empty"

    with pytest.raises(ValueError, match=expected_message):
        client.add_documents(collection_name="test_collection", documents=nothing)
|
|
|
|
def test_add_documents_collection_not_exists(self, client, mock_qdrant_client):
    """Verify add_documents raises ValueError when the collection is missing."""
    mock_qdrant_client.collection_exists.return_value = False

    record: BaseRecord = {
        "content": "Test document",
        "metadata": {"source": "test"},
    }
    docs: list[BaseRecord] = [record]
    expected_message = "Collection 'test_collection' does not exist"

    with pytest.raises(ValueError, match=expected_message):
        client.add_documents(collection_name="test_collection", documents=docs)
|
|
|
|
def test_add_documents_wrong_client_type(self, mock_async_qdrant_client):
    """Verify add_documents raises ClientMethodMismatchError on an async backend."""
    mismatched = QdrantClient(
        client=mock_async_qdrant_client,
        embedding_function=Mock(),
    )

    record: BaseRecord = {
        "content": "Test document",
        "metadata": {"source": "test"},
    }
    docs: list[BaseRecord] = [record]
    expected_pattern = r"Method add_documents\(\) requires"

    with pytest.raises(ClientMethodMismatchError, match=expected_pattern):
        mismatched.add_documents(collection_name="test_collection", documents=docs)
|
|
|
|
@pytest.mark.asyncio
async def test_aadd_documents(self, async_client, mock_async_qdrant_client):
    """Test that aadd_documents adds documents to collection asynchronously.

    Await-based assertions ensure the existence check and the upsert were
    actually awaited; a plain call assertion would also pass for a
    coroutine that was created but never awaited.
    """
    mock_async_qdrant_client.collection_exists = AsyncMock(return_value=True)
    mock_async_qdrant_client.upsert = AsyncMock()
    async_client.embedding_function.return_value = [0.1, 0.2, 0.3]

    documents: list[BaseRecord] = [
        {
            "content": "Test document",
            "metadata": {"source": "test"},
        }
    ]

    await async_client.aadd_documents(
        collection_name="test_collection", documents=documents
    )

    mock_async_qdrant_client.collection_exists.assert_awaited_once_with(
        "test_collection"
    )
    # The embedding function is a plain (synchronous) Mock.
    async_client.embedding_function.assert_called_once_with("Test document")
    mock_async_qdrant_client.upsert.assert_awaited_once()

    # Check upsert was awaited with correct parameters
    await_args = mock_async_qdrant_client.upsert.await_args
    assert await_args.kwargs["collection_name"] == "test_collection"
    assert len(await_args.kwargs["points"]) == 1
    point = await_args.kwargs["points"][0]
    assert point.vector == [0.1, 0.2, 0.3]
    assert point.payload["content"] == "Test document"
    assert point.payload["source"] == "test"
|
|
|
|
@pytest.mark.asyncio
async def test_aadd_documents_with_doc_id(
    self, async_client, mock_async_qdrant_client
):
    """Verify aadd_documents keeps a caller-supplied doc_id as the point id."""
    mock_async_qdrant_client.collection_exists = AsyncMock(return_value=True)
    upsert_stub = AsyncMock()
    mock_async_qdrant_client.upsert = upsert_stub
    async_client.embedding_function.return_value = [0.1, 0.2, 0.3]

    record: BaseRecord = {
        "doc_id": "custom-id-123",
        "content": "Test document",
        "metadata": {"source": "test"},
    }
    docs: list[BaseRecord] = [record]

    await async_client.aadd_documents(
        collection_name="test_collection",
        documents=docs,
    )

    sent = mock_async_qdrant_client.upsert.call_args.kwargs
    stored = sent["points"][0]
    assert stored.id == "custom-id-123"
|
|
|
|
@pytest.mark.asyncio
async def test_aadd_documents_empty_list(
    self, async_client, mock_async_qdrant_client
):
    """Verify aadd_documents raises ValueError when given no documents."""
    nothing: list[BaseRecord] = []
    expected_message = "Documents list cannot be empty"

    with pytest.raises(ValueError, match=expected_message):
        await async_client.aadd_documents(
            collection_name="test_collection",
            documents=nothing,
        )
|
|
|
|
@pytest.mark.asyncio
async def test_aadd_documents_collection_not_exists(
    self, async_client, mock_async_qdrant_client
):
    """Verify aadd_documents raises ValueError when the collection is missing."""
    exists_stub = AsyncMock(return_value=False)
    mock_async_qdrant_client.collection_exists = exists_stub

    record: BaseRecord = {
        "content": "Test document",
        "metadata": {"source": "test"},
    }
    docs: list[BaseRecord] = [record]
    expected_message = "Collection 'test_collection' does not exist"

    with pytest.raises(ValueError, match=expected_message):
        await async_client.aadd_documents(
            collection_name="test_collection",
            documents=docs,
        )
|
|
|
|
@pytest.mark.asyncio
async def test_aadd_documents_wrong_client_type(self, mock_qdrant_client):
    """Verify aadd_documents raises ClientMethodMismatchError on a sync backend."""
    mismatched = QdrantClient(
        client=mock_qdrant_client,
        embedding_function=Mock(),
    )

    record: BaseRecord = {
        "content": "Test document",
        "metadata": {"source": "test"},
    }
    docs: list[BaseRecord] = [record]
    expected_pattern = r"Method aadd_documents\(\) requires"

    with pytest.raises(ClientMethodMismatchError, match=expected_pattern):
        await mismatched.aadd_documents(
            collection_name="test_collection",
            documents=docs,
        )
|
|
|
|
def test_search(self, client, mock_qdrant_client):
    """Test that search returns matching documents.

    The backend returns one point with raw score 0.95. The wrapper is
    expected to embed the query, call query_points with the default
    options (limit 10, payload on, vectors off) and return one result whose
    "content" is lifted out of the payload, with the remaining payload keys
    reported as metadata.
    """
    mock_qdrant_client.collection_exists.return_value = True
    client.embedding_function.return_value = [0.1, 0.2, 0.3]

    mock_point = Mock()
    mock_point.id = "doc-123"
    mock_point.payload = {"content": "Test content", "source": "test"}
    mock_point.score = 0.95

    mock_response = Mock()
    mock_response.points = [mock_point]
    mock_qdrant_client.query_points.return_value = mock_response

    results = client.search(collection_name="test_collection", query="test query")

    mock_qdrant_client.collection_exists.assert_called_once_with("test_collection")
    client.embedding_function.assert_called_once_with("test query")
    mock_qdrant_client.query_points.assert_called_once()

    call_args = mock_qdrant_client.query_points.call_args
    assert call_args.kwargs["collection_name"] == "test_collection"
    assert call_args.kwargs["query"] == [0.1, 0.2, 0.3]
    assert call_args.kwargs["limit"] == 10
    assert call_args.kwargs["with_payload"] is True
    assert call_args.kwargs["with_vectors"] is False

    assert len(results) == 1
    assert results[0]["id"] == "doc-123"
    assert results[0]["content"] == "Test content"
    assert results[0]["metadata"] == {"source": "test"}
    # The reported score is a computed float (0.95 -> 0.975; presumably a
    # (score + 1) / 2 normalization -- confirm against the client), so
    # compare with a tolerance instead of exact equality.
    assert results[0]["score"] == pytest.approx(0.975)
|
|
|
|
def test_search_with_filters(self, client, mock_qdrant_client):
    """Verify search turns each metadata_filter entry into a 'must' condition."""
    mock_qdrant_client.collection_exists.return_value = True
    client.embedding_function.return_value = [0.1, 0.2, 0.3]

    empty_response = Mock()
    empty_response.points = []
    mock_qdrant_client.query_points.return_value = empty_response

    wanted = {"category": "tech", "status": "published"}
    client.search(
        collection_name="test_collection",
        query="test query",
        metadata_filter=wanted,
    )

    sent = mock_qdrant_client.query_points.call_args.kwargs
    conditions = sent["query_filter"].must

    def has_condition(key, value):
        # True when some condition targets `key` and matches `value`.
        return any(c.key == key and c.match.value == value for c in conditions)

    assert len(conditions) == 2
    assert has_condition("category", "tech")
    assert has_condition("status", "published")
|
|
|
|
def test_search_with_options(self, client, mock_qdrant_client):
    """Verify search forwards limit and score_threshold to query_points."""
    mock_qdrant_client.collection_exists.return_value = True
    client.embedding_function.return_value = [0.1, 0.2, 0.3]

    empty_response = Mock()
    empty_response.points = []
    mock_qdrant_client.query_points.return_value = empty_response

    client.search(
        collection_name="test_collection",
        query="test query",
        limit=5,
        score_threshold=0.8,
    )

    sent = mock_qdrant_client.query_points.call_args.kwargs
    assert sent["limit"] == 5
    assert sent["score_threshold"] == 0.8
|
|
|
|
def test_search_collection_not_exists(self, client, mock_qdrant_client):
    """Verify search raises ValueError when the collection is missing."""
    mock_qdrant_client.collection_exists.return_value = False
    expected_message = "Collection 'test_collection' does not exist"

    with pytest.raises(ValueError, match=expected_message):
        client.search(collection_name="test_collection", query="test query")
|
|
|
|
def test_search_wrong_client_type(self, mock_async_qdrant_client):
    """Verify search raises ClientMethodMismatchError on an async backend."""
    mismatched = QdrantClient(
        client=mock_async_qdrant_client,
        embedding_function=Mock(),
    )
    expected_pattern = r"Method search\(\) requires"

    with pytest.raises(ClientMethodMismatchError, match=expected_pattern):
        mismatched.search(collection_name="test_collection", query="test query")
|
|
|
|
@pytest.mark.asyncio
async def test_asearch(self, async_client, mock_async_qdrant_client):
    """Test that asearch returns matching documents asynchronously.

    The backend returns one point with raw score 0.95. Await-based
    assertions ensure the existence check and the query were actually
    awaited. The result must lift "content" out of the payload and report
    the remaining payload keys as metadata.
    """
    mock_async_qdrant_client.collection_exists = AsyncMock(return_value=True)
    async_client.embedding_function.return_value = [0.1, 0.2, 0.3]

    mock_point = Mock()
    mock_point.id = "doc-123"
    mock_point.payload = {"content": "Test content", "source": "test"}
    mock_point.score = 0.95

    mock_response = Mock()
    mock_response.points = [mock_point]
    mock_async_qdrant_client.query_points = AsyncMock(return_value=mock_response)

    results = await async_client.asearch(
        collection_name="test_collection", query="test query"
    )

    mock_async_qdrant_client.collection_exists.assert_awaited_once_with(
        "test_collection"
    )
    # The embedding function is a plain (synchronous) Mock.
    async_client.embedding_function.assert_called_once_with("test query")
    mock_async_qdrant_client.query_points.assert_awaited_once()

    await_args = mock_async_qdrant_client.query_points.await_args
    assert await_args.kwargs["collection_name"] == "test_collection"
    assert await_args.kwargs["query"] == [0.1, 0.2, 0.3]
    assert await_args.kwargs["limit"] == 10
    assert await_args.kwargs["with_payload"] is True
    assert await_args.kwargs["with_vectors"] is False

    assert len(results) == 1
    assert results[0]["id"] == "doc-123"
    assert results[0]["content"] == "Test content"
    assert results[0]["metadata"] == {"source": "test"}
    # The reported score is a computed float (0.95 -> 0.975; presumably a
    # (score + 1) / 2 normalization -- confirm against the client), so
    # compare with a tolerance instead of exact equality.
    assert results[0]["score"] == pytest.approx(0.975)
|
|
|
|
@pytest.mark.asyncio
async def test_asearch_with_filters(self, async_client, mock_async_qdrant_client):
    """Verify asearch turns each metadata_filter entry into a 'must' condition."""
    mock_async_qdrant_client.collection_exists = AsyncMock(return_value=True)
    async_client.embedding_function.return_value = [0.1, 0.2, 0.3]

    empty_response = Mock()
    empty_response.points = []
    mock_async_qdrant_client.query_points = AsyncMock(return_value=empty_response)

    wanted = {"category": "tech", "status": "published"}
    await async_client.asearch(
        collection_name="test_collection",
        query="test query",
        metadata_filter=wanted,
    )

    sent = mock_async_qdrant_client.query_points.call_args.kwargs
    conditions = sent["query_filter"].must

    def has_condition(key, value):
        # True when some condition targets `key` and matches `value`.
        return any(c.key == key and c.match.value == value for c in conditions)

    assert len(conditions) == 2
    assert has_condition("category", "tech")
    assert has_condition("status", "published")
|
|
|
|
@pytest.mark.asyncio
async def test_asearch_collection_not_exists(
    self, async_client, mock_async_qdrant_client
):
    """Verify asearch raises ValueError when the collection is missing."""
    exists_stub = AsyncMock(return_value=False)
    mock_async_qdrant_client.collection_exists = exists_stub
    expected_message = "Collection 'test_collection' does not exist"

    with pytest.raises(ValueError, match=expected_message):
        await async_client.asearch(
            collection_name="test_collection",
            query="test query",
        )
|
|
|
|
@pytest.mark.asyncio
async def test_asearch_wrong_client_type(self, mock_qdrant_client):
    """Verify asearch raises ClientMethodMismatchError on a sync backend."""
    mismatched = QdrantClient(
        client=mock_qdrant_client,
        embedding_function=Mock(),
    )
    expected_pattern = r"Method asearch\(\) requires"

    with pytest.raises(ClientMethodMismatchError, match=expected_pattern):
        await mismatched.asearch(
            collection_name="test_collection", query="test query"
        )
|
|
|
|
def test_delete_collection(self, client, mock_qdrant_client):
    """Verify delete_collection removes a collection the backend reports as present."""
    name = "test_collection"
    mock_qdrant_client.collection_exists.return_value = True

    client.delete_collection(collection_name=name)

    mock_qdrant_client.collection_exists.assert_called_once_with(name)
    mock_qdrant_client.delete_collection.assert_called_once_with(
        collection_name=name
    )
|
|
|
|
def test_delete_collection_not_exists(self, client, mock_qdrant_client):
    """Verify delete_collection raises ValueError and deletes nothing when the collection is missing."""
    name = "test_collection"
    mock_qdrant_client.collection_exists.return_value = False
    expected_message = "Collection 'test_collection' does not exist"

    with pytest.raises(ValueError, match=expected_message):
        client.delete_collection(collection_name=name)

    # The existence check ran once; the backend delete was never reached.
    mock_qdrant_client.collection_exists.assert_called_once_with(name)
    mock_qdrant_client.delete_collection.assert_not_called()
|
|
|
|
def test_delete_collection_wrong_client_type(self, mock_async_qdrant_client):
    """Verify delete_collection raises ClientMethodMismatchError on an async backend."""
    mismatched = QdrantClient(
        client=mock_async_qdrant_client,
        embedding_function=Mock(),
    )
    expected_pattern = r"Method delete_collection\(\) requires"

    with pytest.raises(ClientMethodMismatchError, match=expected_pattern):
        mismatched.delete_collection(collection_name="test_collection")
|
|
|
|
@pytest.mark.asyncio
async def test_adelete_collection(self, async_client, mock_async_qdrant_client):
    """Test that adelete_collection deletes the collection asynchronously.

    Await-based assertions ensure the existence check and the delete were
    actually awaited; a delete coroutine that is created but never awaited
    would still satisfy a plain call assertion.
    """
    mock_async_qdrant_client.collection_exists = AsyncMock(return_value=True)
    mock_async_qdrant_client.delete_collection = AsyncMock()

    await async_client.adelete_collection(collection_name="test_collection")

    mock_async_qdrant_client.collection_exists.assert_awaited_once_with(
        "test_collection"
    )
    mock_async_qdrant_client.delete_collection.assert_awaited_once_with(
        collection_name="test_collection"
    )
|
|
|
|
@pytest.mark.asyncio
async def test_adelete_collection_not_exists(
    self, async_client, mock_async_qdrant_client
):
    """Test that adelete_collection raises error if collection doesn't exist.

    The existence check must have been awaited exactly once, and the
    backend delete must never have been invoked.
    """
    mock_async_qdrant_client.collection_exists = AsyncMock(return_value=False)

    with pytest.raises(
        ValueError, match="Collection 'test_collection' does not exist"
    ):
        await async_client.adelete_collection(collection_name="test_collection")

    mock_async_qdrant_client.collection_exists.assert_awaited_once_with(
        "test_collection"
    )
    mock_async_qdrant_client.delete_collection.assert_not_called()
|
|
|
|
@pytest.mark.asyncio
async def test_adelete_collection_wrong_client_type(self, mock_qdrant_client):
    """Verify adelete_collection raises ClientMethodMismatchError on a sync backend."""
    mismatched = QdrantClient(
        client=mock_qdrant_client,
        embedding_function=Mock(),
    )
    expected_pattern = r"Method adelete_collection\(\) requires"

    with pytest.raises(ClientMethodMismatchError, match=expected_pattern):
        await mismatched.adelete_collection(collection_name="test_collection")
|
|
|
|
def test_reset(self, client, mock_qdrant_client):
    """Verify reset deletes every collection the backend lists."""
    names = ["collection1", "collection2", "collection3"]

    listed = []
    for label in names:
        entry = Mock()
        # Assigned after construction: Mock(name=...) would set the mock's
        # own display name rather than a `.name` attribute.
        entry.name = label
        listed.append(entry)

    listing = Mock()
    listing.collections = listed
    mock_qdrant_client.get_collections.return_value = listing

    client.reset()

    mock_qdrant_client.get_collections.assert_called_once()
    assert mock_qdrant_client.delete_collection.call_count == 3
    for label in names:
        mock_qdrant_client.delete_collection.assert_any_call(collection_name=label)
|
|
|
|
def test_reset_no_collections(self, client, mock_qdrant_client):
    """Verify reset issues no deletes when the backend lists no collections."""
    listing = Mock()
    listing.collections = []
    mock_qdrant_client.get_collections.return_value = listing

    client.reset()

    mock_qdrant_client.get_collections.assert_called_once()
    mock_qdrant_client.delete_collection.assert_not_called()
|
|
|
|
def test_reset_wrong_client_type(self, mock_async_qdrant_client):
    """Verify reset raises ClientMethodMismatchError on an async backend."""
    mismatched = QdrantClient(
        client=mock_async_qdrant_client,
        embedding_function=Mock(),
    )
    expected_pattern = r"Method reset\(\) requires"

    with pytest.raises(ClientMethodMismatchError, match=expected_pattern):
        mismatched.reset()
|
|
|
|
@pytest.mark.asyncio
async def test_areset(self, async_client, mock_async_qdrant_client):
    """Test that areset deletes all collections asynchronously.

    Await-based assertions ensure the listing and each delete were actually
    awaited; delete coroutines that are created but never awaited would
    still satisfy plain call-count assertions.
    """
    mock_collection1 = Mock()
    mock_collection1.name = "collection1"
    mock_collection2 = Mock()
    mock_collection2.name = "collection2"
    mock_collection3 = Mock()
    mock_collection3.name = "collection3"

    mock_collections_response = Mock()
    mock_collections_response.collections = [
        mock_collection1,
        mock_collection2,
        mock_collection3,
    ]
    mock_async_qdrant_client.get_collections = AsyncMock(
        return_value=mock_collections_response
    )
    mock_async_qdrant_client.delete_collection = AsyncMock()

    await async_client.areset()

    mock_async_qdrant_client.get_collections.assert_awaited_once()
    assert mock_async_qdrant_client.delete_collection.await_count == 3
    mock_async_qdrant_client.delete_collection.assert_any_await(
        collection_name="collection1"
    )
    mock_async_qdrant_client.delete_collection.assert_any_await(
        collection_name="collection2"
    )
    mock_async_qdrant_client.delete_collection.assert_any_await(
        collection_name="collection3"
    )
|
|
|
|
@pytest.mark.asyncio
async def test_areset_no_collections(self, async_client, mock_async_qdrant_client):
    """Test that areset handles no collections gracefully.

    The listing must have been awaited exactly once, and with nothing
    listed the backend delete must never have been invoked.
    """
    mock_collections_response = Mock()
    mock_collections_response.collections = []
    mock_async_qdrant_client.get_collections = AsyncMock(
        return_value=mock_collections_response
    )

    await async_client.areset()

    mock_async_qdrant_client.get_collections.assert_awaited_once()
    mock_async_qdrant_client.delete_collection.assert_not_called()
|
|
|
|
@pytest.mark.asyncio
async def test_areset_wrong_client_type(self, mock_qdrant_client):
    """Verify areset raises ClientMethodMismatchError on a sync backend."""
    mismatched = QdrantClient(
        client=mock_qdrant_client,
        embedding_function=Mock(),
    )
    expected_pattern = r"Method areset\(\) requires"

    with pytest.raises(ClientMethodMismatchError, match=expected_pattern):
        await mismatched.areset()
|