Files
crewAI/tests/rag/qdrant/test_client.py
Greyson LaLonde 4b4a119a9f refactor: simplify rag client initialization (#3401)
* Simplified Qdrant and ChromaDB client initialization
* Refactored factory structure and updated tests accordingly
2025-08-26 08:54:51 -04:00

792 lines
31 KiB
Python

"""Tests for QdrantClient implementation."""
from unittest.mock import AsyncMock, Mock
import pytest
from qdrant_client import AsyncQdrantClient, QdrantClient as SyncQdrantClient
from crewai.rag.core.exceptions import ClientMethodMismatchError
from crewai.rag.qdrant.client import QdrantClient
from crewai.rag.types import BaseRecord
@pytest.fixture
def mock_qdrant_client():
    """Provide a ``Mock`` restricted to the synchronous Qdrant client API."""
    sync_backend = Mock(spec=SyncQdrantClient)
    return sync_backend
@pytest.fixture
def mock_async_qdrant_client():
    """Provide a ``Mock`` restricted to the asynchronous Qdrant client API."""
    async_backend = Mock(spec=AsyncQdrantClient)
    return async_backend
@pytest.fixture
def client(mock_qdrant_client) -> QdrantClient:
    """Build a QdrantClient that wraps the synchronous backend mock.

    The embedding function is a stub that always returns ``[0.1, 0.2, 0.3]``.
    """
    return QdrantClient(
        client=mock_qdrant_client,
        embedding_function=Mock(return_value=[0.1, 0.2, 0.3]),
    )
@pytest.fixture
def async_client(mock_async_qdrant_client) -> QdrantClient:
    """Build a QdrantClient that wraps the asynchronous backend mock.

    The embedding function is a stub that always returns ``[0.1, 0.2, 0.3]``.
    """
    return QdrantClient(
        client=mock_async_qdrant_client,
        embedding_function=Mock(return_value=[0.1, 0.2, 0.3]),
    )
class TestQdrantClient:
    """Test suite for QdrantClient.

    Every operation is exercised in two flavours: the synchronous method
    against the ``SyncQdrantClient`` mock (``client`` fixture) and the
    ``a``-prefixed coroutine against the ``AsyncQdrantClient`` mock
    (``async_client`` fixture). Each operation also has a
    ``*_wrong_client_type`` test asserting that calling it on the mismatched
    backend raises ``ClientMethodMismatchError``.
    """

    # ------------------------------------------------------------------
    # Helpers for building fake Qdrant responses.
    # ------------------------------------------------------------------

    @staticmethod
    def _sample_documents() -> list[BaseRecord]:
        """Return a single document (no doc_id) used by the add tests."""
        return [
            {
                "content": "Test document",
                "metadata": {"source": "test"},
            }
        ]

    @staticmethod
    def _collections_response(*names: str) -> Mock:
        """Build a fake ``get_collections()`` response listing ``names``."""
        collections = []
        for name in names:
            collection = Mock()
            # Assigned after construction on purpose: ``Mock(name=...)`` only
            # sets the mock's repr name, not a readable ``.name`` attribute.
            collection.name = name
            collections.append(collection)
        response = Mock()
        response.collections = collections
        return response

    @staticmethod
    def _query_response(*points: Mock) -> Mock:
        """Build a fake ``query_points()`` response carrying ``points``."""
        response = Mock()
        response.points = list(points)
        return response

    @staticmethod
    def _scored_point() -> Mock:
        """Build one fake scored point with id, payload and raw score 0.95."""
        point = Mock()
        point.id = "doc-123"
        point.payload = {"content": "Test content", "source": "test"}
        point.score = 0.95
        return point

    # ------------------------------------------------------------------
    # create_collection / acreate_collection
    # ------------------------------------------------------------------

    def test_create_collection(self, client, mock_qdrant_client):
        """Test that create_collection creates a new collection."""
        mock_qdrant_client.collection_exists.return_value = False

        client.create_collection(collection_name="test_collection")

        mock_qdrant_client.collection_exists.assert_called_once_with("test_collection")
        mock_qdrant_client.create_collection.assert_called_once()
        call_args = mock_qdrant_client.create_collection.call_args
        assert call_args.kwargs["collection_name"] == "test_collection"
        assert call_args.kwargs["vectors_config"] is not None

    def test_create_collection_already_exists(self, client, mock_qdrant_client):
        """Test that create_collection raises error if collection exists."""
        mock_qdrant_client.collection_exists.return_value = True

        with pytest.raises(
            ValueError, match="Collection 'test_collection' already exists"
        ):
            client.create_collection(collection_name="test_collection")

        # The failed call must not have attempted to (re)create anything.
        mock_qdrant_client.create_collection.assert_not_called()

    def test_create_collection_wrong_client_type(self, async_client):
        """Test that create_collection rejects an async backend."""
        with pytest.raises(
            ClientMethodMismatchError, match=r"Method create_collection\(\) requires"
        ):
            async_client.create_collection(collection_name="test_collection")

    @pytest.mark.asyncio
    async def test_acreate_collection(self, async_client, mock_async_qdrant_client):
        """Test that acreate_collection creates a new collection asynchronously."""
        mock_async_qdrant_client.collection_exists = AsyncMock(return_value=False)
        mock_async_qdrant_client.create_collection = AsyncMock()

        await async_client.acreate_collection(collection_name="test_collection")

        mock_async_qdrant_client.collection_exists.assert_called_once_with(
            "test_collection"
        )
        mock_async_qdrant_client.create_collection.assert_called_once()
        call_args = mock_async_qdrant_client.create_collection.call_args
        assert call_args.kwargs["collection_name"] == "test_collection"
        assert call_args.kwargs["vectors_config"] is not None

    @pytest.mark.asyncio
    async def test_acreate_collection_already_exists(
        self, async_client, mock_async_qdrant_client
    ):
        """Test that acreate_collection raises error if collection exists."""
        mock_async_qdrant_client.collection_exists = AsyncMock(return_value=True)
        mock_async_qdrant_client.create_collection = AsyncMock()

        with pytest.raises(
            ValueError, match="Collection 'test_collection' already exists"
        ):
            await async_client.acreate_collection(collection_name="test_collection")

        mock_async_qdrant_client.create_collection.assert_not_called()

    @pytest.mark.asyncio
    async def test_acreate_collection_wrong_client_type(self, client):
        """Test that acreate_collection rejects a sync backend."""
        with pytest.raises(
            ClientMethodMismatchError, match=r"Method acreate_collection\(\) requires"
        ):
            await client.acreate_collection(collection_name="test_collection")

    # ------------------------------------------------------------------
    # get_or_create_collection / aget_or_create_collection
    # ------------------------------------------------------------------

    def test_get_or_create_collection_existing(self, client, mock_qdrant_client):
        """Test get_or_create_collection returns existing collection."""
        mock_qdrant_client.collection_exists.return_value = True
        mock_collection_info = Mock()
        mock_qdrant_client.get_collection.return_value = mock_collection_info

        result = client.get_or_create_collection(collection_name="test_collection")

        mock_qdrant_client.collection_exists.assert_called_once_with("test_collection")
        mock_qdrant_client.get_collection.assert_called_once_with("test_collection")
        mock_qdrant_client.create_collection.assert_not_called()
        assert result == mock_collection_info

    def test_get_or_create_collection_new(self, client, mock_qdrant_client):
        """Test get_or_create_collection creates new collection."""
        mock_qdrant_client.collection_exists.return_value = False
        mock_collection_info = Mock()
        mock_qdrant_client.get_collection.return_value = mock_collection_info

        result = client.get_or_create_collection(collection_name="test_collection")

        mock_qdrant_client.collection_exists.assert_called_once_with("test_collection")
        mock_qdrant_client.create_collection.assert_called_once()
        mock_qdrant_client.get_collection.assert_called_once_with("test_collection")
        assert result == mock_collection_info

    def test_get_or_create_collection_wrong_client_type(self, async_client):
        """Test that get_or_create_collection rejects an async backend."""
        with pytest.raises(
            ClientMethodMismatchError,
            match=r"Method get_or_create_collection\(\) requires",
        ):
            async_client.get_or_create_collection(collection_name="test_collection")

    @pytest.mark.asyncio
    async def test_aget_or_create_collection_existing(
        self, async_client, mock_async_qdrant_client
    ):
        """Test aget_or_create_collection returns existing collection."""
        mock_async_qdrant_client.collection_exists = AsyncMock(return_value=True)
        mock_collection_info = Mock()
        mock_async_qdrant_client.get_collection = AsyncMock(
            return_value=mock_collection_info
        )

        result = await async_client.aget_or_create_collection(
            collection_name="test_collection"
        )

        mock_async_qdrant_client.collection_exists.assert_called_once_with(
            "test_collection"
        )
        mock_async_qdrant_client.get_collection.assert_called_once_with(
            "test_collection"
        )
        mock_async_qdrant_client.create_collection.assert_not_called()
        assert result == mock_collection_info

    @pytest.mark.asyncio
    async def test_aget_or_create_collection_new(
        self, async_client, mock_async_qdrant_client
    ):
        """Test aget_or_create_collection creates new collection."""
        mock_async_qdrant_client.collection_exists = AsyncMock(return_value=False)
        mock_async_qdrant_client.create_collection = AsyncMock()
        mock_collection_info = Mock()
        mock_async_qdrant_client.get_collection = AsyncMock(
            return_value=mock_collection_info
        )

        result = await async_client.aget_or_create_collection(
            collection_name="test_collection"
        )

        mock_async_qdrant_client.collection_exists.assert_called_once_with(
            "test_collection"
        )
        mock_async_qdrant_client.create_collection.assert_called_once()
        mock_async_qdrant_client.get_collection.assert_called_once_with(
            "test_collection"
        )
        assert result == mock_collection_info

    @pytest.mark.asyncio
    async def test_aget_or_create_collection_wrong_client_type(self, client):
        """Test that aget_or_create_collection rejects a sync backend."""
        with pytest.raises(
            ClientMethodMismatchError,
            match=r"Method aget_or_create_collection\(\) requires",
        ):
            await client.aget_or_create_collection(collection_name="test_collection")

    # ------------------------------------------------------------------
    # add_documents / aadd_documents
    # ------------------------------------------------------------------

    def test_add_documents(self, client, mock_qdrant_client):
        """Test that add_documents adds documents to collection."""
        mock_qdrant_client.collection_exists.return_value = True
        client.embedding_function.return_value = [0.1, 0.2, 0.3]

        client.add_documents(
            collection_name="test_collection", documents=self._sample_documents()
        )

        mock_qdrant_client.collection_exists.assert_called_once_with("test_collection")
        client.embedding_function.assert_called_once_with("Test document")
        mock_qdrant_client.upsert.assert_called_once()
        # Check upsert was called with correct parameters; metadata keys are
        # expected to be flattened into the point payload next to "content".
        call_args = mock_qdrant_client.upsert.call_args
        assert call_args.kwargs["collection_name"] == "test_collection"
        assert len(call_args.kwargs["points"]) == 1
        point = call_args.kwargs["points"][0]
        assert point.vector == [0.1, 0.2, 0.3]
        assert point.payload["content"] == "Test document"
        assert point.payload["source"] == "test"

    def test_add_documents_with_doc_id(self, client, mock_qdrant_client):
        """Test that add_documents uses provided doc_id."""
        mock_qdrant_client.collection_exists.return_value = True
        client.embedding_function.return_value = [0.1, 0.2, 0.3]
        documents: list[BaseRecord] = [
            {
                "doc_id": "custom-id-123",
                "content": "Test document",
                "metadata": {"source": "test"},
            }
        ]

        client.add_documents(collection_name="test_collection", documents=documents)

        call_args = mock_qdrant_client.upsert.call_args
        point = call_args.kwargs["points"][0]
        assert point.id == "custom-id-123"

    def test_add_documents_empty_list(self, client, mock_qdrant_client):
        """Test that add_documents raises error for empty documents list."""
        documents: list[BaseRecord] = []

        with pytest.raises(ValueError, match="Documents list cannot be empty"):
            client.add_documents(collection_name="test_collection", documents=documents)

        mock_qdrant_client.upsert.assert_not_called()

    def test_add_documents_collection_not_exists(self, client, mock_qdrant_client):
        """Test that add_documents raises error if collection doesn't exist."""
        mock_qdrant_client.collection_exists.return_value = False

        with pytest.raises(
            ValueError, match="Collection 'test_collection' does not exist"
        ):
            client.add_documents(
                collection_name="test_collection",
                documents=self._sample_documents(),
            )

        mock_qdrant_client.upsert.assert_not_called()

    def test_add_documents_wrong_client_type(self, async_client):
        """Test that add_documents rejects an async backend."""
        with pytest.raises(
            ClientMethodMismatchError, match=r"Method add_documents\(\) requires"
        ):
            async_client.add_documents(
                collection_name="test_collection",
                documents=self._sample_documents(),
            )

    @pytest.mark.asyncio
    async def test_aadd_documents(self, async_client, mock_async_qdrant_client):
        """Test that aadd_documents adds documents to collection asynchronously."""
        mock_async_qdrant_client.collection_exists = AsyncMock(return_value=True)
        mock_async_qdrant_client.upsert = AsyncMock()
        async_client.embedding_function.return_value = [0.1, 0.2, 0.3]

        await async_client.aadd_documents(
            collection_name="test_collection", documents=self._sample_documents()
        )

        mock_async_qdrant_client.collection_exists.assert_called_once_with(
            "test_collection"
        )
        async_client.embedding_function.assert_called_once_with("Test document")
        mock_async_qdrant_client.upsert.assert_called_once()
        # Check upsert was called with correct parameters.
        call_args = mock_async_qdrant_client.upsert.call_args
        assert call_args.kwargs["collection_name"] == "test_collection"
        assert len(call_args.kwargs["points"]) == 1
        point = call_args.kwargs["points"][0]
        assert point.vector == [0.1, 0.2, 0.3]
        assert point.payload["content"] == "Test document"
        assert point.payload["source"] == "test"

    @pytest.mark.asyncio
    async def test_aadd_documents_with_doc_id(
        self, async_client, mock_async_qdrant_client
    ):
        """Test that aadd_documents uses provided doc_id."""
        mock_async_qdrant_client.collection_exists = AsyncMock(return_value=True)
        mock_async_qdrant_client.upsert = AsyncMock()
        async_client.embedding_function.return_value = [0.1, 0.2, 0.3]
        documents: list[BaseRecord] = [
            {
                "doc_id": "custom-id-123",
                "content": "Test document",
                "metadata": {"source": "test"},
            }
        ]

        await async_client.aadd_documents(
            collection_name="test_collection", documents=documents
        )

        call_args = mock_async_qdrant_client.upsert.call_args
        point = call_args.kwargs["points"][0]
        assert point.id == "custom-id-123"

    @pytest.mark.asyncio
    async def test_aadd_documents_empty_list(
        self, async_client, mock_async_qdrant_client
    ):
        """Test that aadd_documents raises error for empty documents list."""
        documents: list[BaseRecord] = []

        with pytest.raises(ValueError, match="Documents list cannot be empty"):
            await async_client.aadd_documents(
                collection_name="test_collection", documents=documents
            )

        mock_async_qdrant_client.upsert.assert_not_called()

    @pytest.mark.asyncio
    async def test_aadd_documents_collection_not_exists(
        self, async_client, mock_async_qdrant_client
    ):
        """Test that aadd_documents raises error if collection doesn't exist."""
        mock_async_qdrant_client.collection_exists = AsyncMock(return_value=False)
        mock_async_qdrant_client.upsert = AsyncMock()

        with pytest.raises(
            ValueError, match="Collection 'test_collection' does not exist"
        ):
            await async_client.aadd_documents(
                collection_name="test_collection",
                documents=self._sample_documents(),
            )

        mock_async_qdrant_client.upsert.assert_not_called()

    @pytest.mark.asyncio
    async def test_aadd_documents_wrong_client_type(self, client):
        """Test that aadd_documents rejects a sync backend."""
        with pytest.raises(
            ClientMethodMismatchError, match=r"Method aadd_documents\(\) requires"
        ):
            await client.aadd_documents(
                collection_name="test_collection",
                documents=self._sample_documents(),
            )

    # ------------------------------------------------------------------
    # search / asearch
    # ------------------------------------------------------------------

    def test_search(self, client, mock_qdrant_client):
        """Test that search returns matching documents."""
        mock_qdrant_client.collection_exists.return_value = True
        client.embedding_function.return_value = [0.1, 0.2, 0.3]
        mock_qdrant_client.query_points.return_value = self._query_response(
            self._scored_point()
        )

        results = client.search(collection_name="test_collection", query="test query")

        mock_qdrant_client.collection_exists.assert_called_once_with("test_collection")
        client.embedding_function.assert_called_once_with("test query")
        mock_qdrant_client.query_points.assert_called_once()
        call_args = mock_qdrant_client.query_points.call_args
        assert call_args.kwargs["collection_name"] == "test_collection"
        assert call_args.kwargs["query"] == [0.1, 0.2, 0.3]
        assert call_args.kwargs["limit"] == 10
        assert call_args.kwargs["with_payload"] is True
        assert call_args.kwargs["with_vectors"] is False
        assert len(results) == 1
        assert results[0]["id"] == "doc-123"
        assert results[0]["content"] == "Test content"
        assert results[0]["metadata"] == {"source": "test"}
        # Raw score 0.95 is reported as 0.975, which matches a (score + 1) / 2
        # rescale -- NOTE(review): confirm against the client's normalisation.
        assert results[0]["score"] == 0.975

    def test_search_with_filters(self, client, mock_qdrant_client):
        """Test that search applies metadata filters correctly."""
        mock_qdrant_client.collection_exists.return_value = True
        client.embedding_function.return_value = [0.1, 0.2, 0.3]
        mock_qdrant_client.query_points.return_value = self._query_response()

        client.search(
            collection_name="test_collection",
            query="test query",
            metadata_filter={"category": "tech", "status": "published"},
        )

        # Each metadata key becomes one "must" condition in the query filter.
        call_args = mock_qdrant_client.query_points.call_args
        query_filter = call_args.kwargs["query_filter"]
        assert len(query_filter.must) == 2
        assert any(
            cond.key == "category" and cond.match.value == "tech"
            for cond in query_filter.must
        )
        assert any(
            cond.key == "status" and cond.match.value == "published"
            for cond in query_filter.must
        )

    def test_search_with_options(self, client, mock_qdrant_client):
        """Test that search applies limit and score_threshold correctly."""
        mock_qdrant_client.collection_exists.return_value = True
        client.embedding_function.return_value = [0.1, 0.2, 0.3]
        mock_qdrant_client.query_points.return_value = self._query_response()

        client.search(
            collection_name="test_collection",
            query="test query",
            limit=5,
            score_threshold=0.8,
        )

        call_args = mock_qdrant_client.query_points.call_args
        assert call_args.kwargs["limit"] == 5
        assert call_args.kwargs["score_threshold"] == 0.8

    def test_search_collection_not_exists(self, client, mock_qdrant_client):
        """Test that search raises error if collection doesn't exist."""
        mock_qdrant_client.collection_exists.return_value = False

        with pytest.raises(
            ValueError, match="Collection 'test_collection' does not exist"
        ):
            client.search(collection_name="test_collection", query="test query")

        mock_qdrant_client.query_points.assert_not_called()

    def test_search_wrong_client_type(self, async_client):
        """Test that search rejects an async backend."""
        with pytest.raises(
            ClientMethodMismatchError, match=r"Method search\(\) requires"
        ):
            async_client.search(collection_name="test_collection", query="test query")

    @pytest.mark.asyncio
    async def test_asearch(self, async_client, mock_async_qdrant_client):
        """Test that asearch returns matching documents asynchronously."""
        mock_async_qdrant_client.collection_exists = AsyncMock(return_value=True)
        async_client.embedding_function.return_value = [0.1, 0.2, 0.3]
        mock_async_qdrant_client.query_points = AsyncMock(
            return_value=self._query_response(self._scored_point())
        )

        results = await async_client.asearch(
            collection_name="test_collection", query="test query"
        )

        mock_async_qdrant_client.collection_exists.assert_called_once_with(
            "test_collection"
        )
        async_client.embedding_function.assert_called_once_with("test query")
        mock_async_qdrant_client.query_points.assert_called_once()
        call_args = mock_async_qdrant_client.query_points.call_args
        assert call_args.kwargs["collection_name"] == "test_collection"
        assert call_args.kwargs["query"] == [0.1, 0.2, 0.3]
        assert call_args.kwargs["limit"] == 10
        assert call_args.kwargs["with_payload"] is True
        assert call_args.kwargs["with_vectors"] is False
        assert len(results) == 1
        assert results[0]["id"] == "doc-123"
        assert results[0]["content"] == "Test content"
        assert results[0]["metadata"] == {"source": "test"}
        # Same expected rescale as the sync test (0.95 -> 0.975).
        assert results[0]["score"] == 0.975

    @pytest.mark.asyncio
    async def test_asearch_with_filters(self, async_client, mock_async_qdrant_client):
        """Test that asearch applies metadata filters correctly."""
        mock_async_qdrant_client.collection_exists = AsyncMock(return_value=True)
        async_client.embedding_function.return_value = [0.1, 0.2, 0.3]
        mock_async_qdrant_client.query_points = AsyncMock(
            return_value=self._query_response()
        )

        await async_client.asearch(
            collection_name="test_collection",
            query="test query",
            metadata_filter={"category": "tech", "status": "published"},
        )

        call_args = mock_async_qdrant_client.query_points.call_args
        query_filter = call_args.kwargs["query_filter"]
        assert len(query_filter.must) == 2
        assert any(
            cond.key == "category" and cond.match.value == "tech"
            for cond in query_filter.must
        )
        assert any(
            cond.key == "status" and cond.match.value == "published"
            for cond in query_filter.must
        )

    @pytest.mark.asyncio
    async def test_asearch_collection_not_exists(
        self, async_client, mock_async_qdrant_client
    ):
        """Test that asearch raises error if collection doesn't exist."""
        mock_async_qdrant_client.collection_exists = AsyncMock(return_value=False)
        mock_async_qdrant_client.query_points = AsyncMock()

        with pytest.raises(
            ValueError, match="Collection 'test_collection' does not exist"
        ):
            await async_client.asearch(
                collection_name="test_collection", query="test query"
            )

        mock_async_qdrant_client.query_points.assert_not_called()

    @pytest.mark.asyncio
    async def test_asearch_wrong_client_type(self, client):
        """Test that asearch rejects a sync backend."""
        with pytest.raises(
            ClientMethodMismatchError, match=r"Method asearch\(\) requires"
        ):
            await client.asearch(collection_name="test_collection", query="test query")

    # ------------------------------------------------------------------
    # delete_collection / adelete_collection
    # ------------------------------------------------------------------

    def test_delete_collection(self, client, mock_qdrant_client):
        """Test that delete_collection deletes the collection."""
        mock_qdrant_client.collection_exists.return_value = True

        client.delete_collection(collection_name="test_collection")

        mock_qdrant_client.collection_exists.assert_called_once_with("test_collection")
        mock_qdrant_client.delete_collection.assert_called_once_with(
            collection_name="test_collection"
        )

    def test_delete_collection_not_exists(self, client, mock_qdrant_client):
        """Test that delete_collection raises error if collection doesn't exist."""
        mock_qdrant_client.collection_exists.return_value = False

        with pytest.raises(
            ValueError, match="Collection 'test_collection' does not exist"
        ):
            client.delete_collection(collection_name="test_collection")

        mock_qdrant_client.collection_exists.assert_called_once_with("test_collection")
        mock_qdrant_client.delete_collection.assert_not_called()

    def test_delete_collection_wrong_client_type(self, async_client):
        """Test that delete_collection rejects an async backend."""
        with pytest.raises(
            ClientMethodMismatchError, match=r"Method delete_collection\(\) requires"
        ):
            async_client.delete_collection(collection_name="test_collection")

    @pytest.mark.asyncio
    async def test_adelete_collection(self, async_client, mock_async_qdrant_client):
        """Test that adelete_collection deletes the collection asynchronously."""
        mock_async_qdrant_client.collection_exists = AsyncMock(return_value=True)
        mock_async_qdrant_client.delete_collection = AsyncMock()

        await async_client.adelete_collection(collection_name="test_collection")

        mock_async_qdrant_client.collection_exists.assert_called_once_with(
            "test_collection"
        )
        mock_async_qdrant_client.delete_collection.assert_called_once_with(
            collection_name="test_collection"
        )

    @pytest.mark.asyncio
    async def test_adelete_collection_not_exists(
        self, async_client, mock_async_qdrant_client
    ):
        """Test that adelete_collection raises error if collection doesn't exist."""
        mock_async_qdrant_client.collection_exists = AsyncMock(return_value=False)

        with pytest.raises(
            ValueError, match="Collection 'test_collection' does not exist"
        ):
            await async_client.adelete_collection(collection_name="test_collection")

        mock_async_qdrant_client.collection_exists.assert_called_once_with(
            "test_collection"
        )
        mock_async_qdrant_client.delete_collection.assert_not_called()

    @pytest.mark.asyncio
    async def test_adelete_collection_wrong_client_type(self, client):
        """Test that adelete_collection rejects a sync backend."""
        with pytest.raises(
            ClientMethodMismatchError, match=r"Method adelete_collection\(\) requires"
        ):
            await client.adelete_collection(collection_name="test_collection")

    # ------------------------------------------------------------------
    # reset / areset
    # ------------------------------------------------------------------

    def test_reset(self, client, mock_qdrant_client):
        """Test that reset deletes all collections."""
        names = ("collection1", "collection2", "collection3")
        mock_qdrant_client.get_collections.return_value = self._collections_response(
            *names
        )

        client.reset()

        mock_qdrant_client.get_collections.assert_called_once()
        assert mock_qdrant_client.delete_collection.call_count == len(names)
        for name in names:
            mock_qdrant_client.delete_collection.assert_any_call(collection_name=name)

    def test_reset_no_collections(self, client, mock_qdrant_client):
        """Test that reset handles no collections gracefully."""
        mock_qdrant_client.get_collections.return_value = self._collections_response()

        client.reset()

        mock_qdrant_client.get_collections.assert_called_once()
        mock_qdrant_client.delete_collection.assert_not_called()

    def test_reset_wrong_client_type(self, async_client):
        """Test that reset rejects an async backend."""
        with pytest.raises(
            ClientMethodMismatchError, match=r"Method reset\(\) requires"
        ):
            async_client.reset()

    @pytest.mark.asyncio
    async def test_areset(self, async_client, mock_async_qdrant_client):
        """Test that areset deletes all collections asynchronously."""
        names = ("collection1", "collection2", "collection3")
        mock_async_qdrant_client.get_collections = AsyncMock(
            return_value=self._collections_response(*names)
        )
        mock_async_qdrant_client.delete_collection = AsyncMock()

        await async_client.areset()

        mock_async_qdrant_client.get_collections.assert_called_once()
        assert mock_async_qdrant_client.delete_collection.call_count == len(names)
        for name in names:
            mock_async_qdrant_client.delete_collection.assert_any_call(
                collection_name=name
            )

    @pytest.mark.asyncio
    async def test_areset_no_collections(self, async_client, mock_async_qdrant_client):
        """Test that areset handles no collections gracefully."""
        mock_async_qdrant_client.get_collections = AsyncMock(
            return_value=self._collections_response()
        )

        await async_client.areset()

        mock_async_qdrant_client.get_collections.assert_called_once()
        mock_async_qdrant_client.delete_collection.assert_not_called()

    @pytest.mark.asyncio
    async def test_areset_wrong_client_type(self, client):
        """Test that areset rejects a sync backend."""
        with pytest.raises(
            ClientMethodMismatchError, match=r"Method areset\(\) requires"
        ):
            await client.areset()