mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 08:08:32 +00:00
feat: qdrant generic client (#3377)
Some checks failed
Notify Downstream / notify-downstream (push) Has been cancelled
Some checks failed
Notify Downstream / notify-downstream (push) Has been cancelled
### Qdrant Client * Add core client with collection, search, and document APIs (sync + async) * Refactor utilities, types, and vector params (default 384-dim) * Improve error handling with `ClientMethodMismatchError` * Add score normalization, async embeddings, and optional `qdrant-client` dep * Expand tests and type safety throughout
This commit is contained in:
793
tests/rag/qdrant/test_client.py
Normal file
793
tests/rag/qdrant/test_client.py
Normal file
@@ -0,0 +1,793 @@
|
||||
"""Tests for QdrantClient implementation."""
|
||||
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import pytest
|
||||
from qdrant_client import AsyncQdrantClient, QdrantClient as SyncQdrantClient
|
||||
|
||||
from crewai.rag.core.exceptions import ClientMethodMismatchError
|
||||
from crewai.rag.qdrant.client import QdrantClient
|
||||
from crewai.rag.types import BaseRecord
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_qdrant_client():
|
||||
"""Create a mock Qdrant client."""
|
||||
return Mock(spec=SyncQdrantClient)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_async_qdrant_client():
|
||||
"""Create a mock async Qdrant client."""
|
||||
return Mock(spec=AsyncQdrantClient)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(mock_qdrant_client) -> QdrantClient:
|
||||
"""Create a QdrantClient instance for testing."""
|
||||
mock_embedding = Mock()
|
||||
mock_embedding.return_value = [0.1, 0.2, 0.3]
|
||||
client = QdrantClient(client=mock_qdrant_client, embedding_function=mock_embedding)
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def async_client(mock_async_qdrant_client) -> QdrantClient:
|
||||
"""Create a QdrantClient instance with async client for testing."""
|
||||
mock_embedding = Mock()
|
||||
mock_embedding.return_value = [0.1, 0.2, 0.3]
|
||||
client = QdrantClient(
|
||||
client=mock_async_qdrant_client, embedding_function=mock_embedding
|
||||
)
|
||||
return client
|
||||
|
||||
|
||||
class TestQdrantClient:
|
||||
"""Test suite for QdrantClient."""
|
||||
|
||||
def test_create_collection(self, client, mock_qdrant_client):
|
||||
"""Test that create_collection creates a new collection."""
|
||||
mock_qdrant_client.collection_exists.return_value = False
|
||||
|
||||
client.create_collection(collection_name="test_collection")
|
||||
|
||||
mock_qdrant_client.collection_exists.assert_called_once_with("test_collection")
|
||||
mock_qdrant_client.create_collection.assert_called_once()
|
||||
call_args = mock_qdrant_client.create_collection.call_args
|
||||
assert call_args.kwargs["collection_name"] == "test_collection"
|
||||
assert call_args.kwargs["vectors_config"] is not None
|
||||
|
||||
def test_create_collection_already_exists(self, client, mock_qdrant_client):
|
||||
"""Test that create_collection raises error if collection exists."""
|
||||
mock_qdrant_client.collection_exists.return_value = True
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match="Collection 'test_collection' already exists"
|
||||
):
|
||||
client.create_collection(collection_name="test_collection")
|
||||
|
||||
def test_create_collection_wrong_client_type(self, mock_async_qdrant_client):
|
||||
"""Test that create_collection raises TypeError for async client."""
|
||||
client = QdrantClient(
|
||||
client=mock_async_qdrant_client, embedding_function=Mock()
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ClientMethodMismatchError, match=r"Method create_collection\(\) requires"
|
||||
):
|
||||
client.create_collection(collection_name="test_collection")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_acreate_collection(self, async_client, mock_async_qdrant_client):
|
||||
"""Test that acreate_collection creates a new collection asynchronously."""
|
||||
mock_async_qdrant_client.collection_exists = AsyncMock(return_value=False)
|
||||
mock_async_qdrant_client.create_collection = AsyncMock()
|
||||
|
||||
await async_client.acreate_collection(collection_name="test_collection")
|
||||
|
||||
mock_async_qdrant_client.collection_exists.assert_called_once_with(
|
||||
"test_collection"
|
||||
)
|
||||
mock_async_qdrant_client.create_collection.assert_called_once()
|
||||
call_args = mock_async_qdrant_client.create_collection.call_args
|
||||
assert call_args.kwargs["collection_name"] == "test_collection"
|
||||
assert call_args.kwargs["vectors_config"] is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_acreate_collection_already_exists(
|
||||
self, async_client, mock_async_qdrant_client
|
||||
):
|
||||
"""Test that acreate_collection raises error if collection exists."""
|
||||
mock_async_qdrant_client.collection_exists = AsyncMock(return_value=True)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match="Collection 'test_collection' already exists"
|
||||
):
|
||||
await async_client.acreate_collection(collection_name="test_collection")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_acreate_collection_wrong_client_type(self, mock_qdrant_client):
|
||||
"""Test that acreate_collection raises TypeError for sync client."""
|
||||
client = QdrantClient(client=mock_qdrant_client, embedding_function=Mock())
|
||||
|
||||
with pytest.raises(
|
||||
ClientMethodMismatchError, match=r"Method acreate_collection\(\) requires"
|
||||
):
|
||||
await client.acreate_collection(collection_name="test_collection")
|
||||
|
||||
def test_get_or_create_collection_existing(self, client, mock_qdrant_client):
|
||||
"""Test get_or_create_collection returns existing collection."""
|
||||
mock_qdrant_client.collection_exists.return_value = True
|
||||
mock_collection_info = Mock()
|
||||
mock_qdrant_client.get_collection.return_value = mock_collection_info
|
||||
|
||||
result = client.get_or_create_collection(collection_name="test_collection")
|
||||
|
||||
mock_qdrant_client.collection_exists.assert_called_once_with("test_collection")
|
||||
mock_qdrant_client.get_collection.assert_called_once_with("test_collection")
|
||||
mock_qdrant_client.create_collection.assert_not_called()
|
||||
assert result == mock_collection_info
|
||||
|
||||
def test_get_or_create_collection_new(self, client, mock_qdrant_client):
|
||||
"""Test get_or_create_collection creates new collection."""
|
||||
mock_qdrant_client.collection_exists.return_value = False
|
||||
mock_collection_info = Mock()
|
||||
mock_qdrant_client.get_collection.return_value = mock_collection_info
|
||||
|
||||
result = client.get_or_create_collection(collection_name="test_collection")
|
||||
|
||||
mock_qdrant_client.collection_exists.assert_called_once_with("test_collection")
|
||||
mock_qdrant_client.create_collection.assert_called_once()
|
||||
mock_qdrant_client.get_collection.assert_called_once_with("test_collection")
|
||||
assert result == mock_collection_info
|
||||
|
||||
def test_get_or_create_collection_wrong_client_type(self, mock_async_qdrant_client):
|
||||
"""Test get_or_create_collection raises TypeError for async client."""
|
||||
client = QdrantClient(
|
||||
client=mock_async_qdrant_client, embedding_function=Mock()
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ClientMethodMismatchError,
|
||||
match=r"Method get_or_create_collection\(\) requires",
|
||||
):
|
||||
client.get_or_create_collection(collection_name="test_collection")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aget_or_create_collection_existing(
|
||||
self, async_client, mock_async_qdrant_client
|
||||
):
|
||||
"""Test aget_or_create_collection returns existing collection."""
|
||||
mock_async_qdrant_client.collection_exists = AsyncMock(return_value=True)
|
||||
mock_collection_info = Mock()
|
||||
mock_async_qdrant_client.get_collection = AsyncMock(
|
||||
return_value=mock_collection_info
|
||||
)
|
||||
|
||||
result = await async_client.aget_or_create_collection(
|
||||
collection_name="test_collection"
|
||||
)
|
||||
|
||||
mock_async_qdrant_client.collection_exists.assert_called_once_with(
|
||||
"test_collection"
|
||||
)
|
||||
mock_async_qdrant_client.get_collection.assert_called_once_with(
|
||||
"test_collection"
|
||||
)
|
||||
mock_async_qdrant_client.create_collection.assert_not_called()
|
||||
assert result == mock_collection_info
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aget_or_create_collection_new(
|
||||
self, async_client, mock_async_qdrant_client
|
||||
):
|
||||
"""Test aget_or_create_collection creates new collection."""
|
||||
mock_async_qdrant_client.collection_exists = AsyncMock(return_value=False)
|
||||
mock_async_qdrant_client.create_collection = AsyncMock()
|
||||
mock_collection_info = Mock()
|
||||
mock_async_qdrant_client.get_collection = AsyncMock(
|
||||
return_value=mock_collection_info
|
||||
)
|
||||
|
||||
result = await async_client.aget_or_create_collection(
|
||||
collection_name="test_collection"
|
||||
)
|
||||
|
||||
mock_async_qdrant_client.collection_exists.assert_called_once_with(
|
||||
"test_collection"
|
||||
)
|
||||
mock_async_qdrant_client.create_collection.assert_called_once()
|
||||
mock_async_qdrant_client.get_collection.assert_called_once_with(
|
||||
"test_collection"
|
||||
)
|
||||
assert result == mock_collection_info
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aget_or_create_collection_wrong_client_type(
|
||||
self, mock_qdrant_client
|
||||
):
|
||||
"""Test aget_or_create_collection raises TypeError for sync client."""
|
||||
client = QdrantClient(client=mock_qdrant_client, embedding_function=Mock())
|
||||
|
||||
with pytest.raises(
|
||||
ClientMethodMismatchError,
|
||||
match=r"Method aget_or_create_collection\(\) requires",
|
||||
):
|
||||
await client.aget_or_create_collection(collection_name="test_collection")
|
||||
|
||||
def test_add_documents(self, client, mock_qdrant_client):
|
||||
"""Test that add_documents adds documents to collection."""
|
||||
mock_qdrant_client.collection_exists.return_value = True
|
||||
client.embedding_function.return_value = [0.1, 0.2, 0.3]
|
||||
|
||||
documents: list[BaseRecord] = [
|
||||
{
|
||||
"content": "Test document",
|
||||
"metadata": {"source": "test"},
|
||||
}
|
||||
]
|
||||
|
||||
client.add_documents(collection_name="test_collection", documents=documents)
|
||||
|
||||
mock_qdrant_client.collection_exists.assert_called_once_with("test_collection")
|
||||
client.embedding_function.assert_called_once_with("Test document")
|
||||
mock_qdrant_client.upsert.assert_called_once()
|
||||
|
||||
# Check upsert was called with correct parameters
|
||||
call_args = mock_qdrant_client.upsert.call_args
|
||||
assert call_args.kwargs["collection_name"] == "test_collection"
|
||||
assert call_args.kwargs["wait"] is True
|
||||
assert len(call_args.kwargs["points"]) == 1
|
||||
point = call_args.kwargs["points"][0]
|
||||
assert point.vector == [0.1, 0.2, 0.3]
|
||||
assert point.payload["content"] == "Test document"
|
||||
assert point.payload["source"] == "test"
|
||||
|
||||
def test_add_documents_with_doc_id(self, client, mock_qdrant_client):
|
||||
"""Test that add_documents uses provided doc_id."""
|
||||
mock_qdrant_client.collection_exists.return_value = True
|
||||
client.embedding_function.return_value = [0.1, 0.2, 0.3]
|
||||
|
||||
documents: list[BaseRecord] = [
|
||||
{
|
||||
"doc_id": "custom-id-123",
|
||||
"content": "Test document",
|
||||
"metadata": {"source": "test"},
|
||||
}
|
||||
]
|
||||
|
||||
client.add_documents(collection_name="test_collection", documents=documents)
|
||||
|
||||
call_args = mock_qdrant_client.upsert.call_args
|
||||
point = call_args.kwargs["points"][0]
|
||||
assert point.id == "custom-id-123"
|
||||
|
||||
def test_add_documents_empty_list(self, client, mock_qdrant_client):
|
||||
"""Test that add_documents raises error for empty documents list."""
|
||||
documents: list[BaseRecord] = []
|
||||
|
||||
with pytest.raises(ValueError, match="Documents list cannot be empty"):
|
||||
client.add_documents(collection_name="test_collection", documents=documents)
|
||||
|
||||
def test_add_documents_collection_not_exists(self, client, mock_qdrant_client):
|
||||
"""Test that add_documents raises error if collection doesn't exist."""
|
||||
mock_qdrant_client.collection_exists.return_value = False
|
||||
|
||||
documents: list[BaseRecord] = [
|
||||
{
|
||||
"content": "Test document",
|
||||
"metadata": {"source": "test"},
|
||||
}
|
||||
]
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match="Collection 'test_collection' does not exist"
|
||||
):
|
||||
client.add_documents(collection_name="test_collection", documents=documents)
|
||||
|
||||
def test_add_documents_wrong_client_type(self, mock_async_qdrant_client):
|
||||
"""Test that add_documents raises TypeError for async client."""
|
||||
client = QdrantClient(
|
||||
client=mock_async_qdrant_client, embedding_function=Mock()
|
||||
)
|
||||
|
||||
documents: list[BaseRecord] = [
|
||||
{
|
||||
"content": "Test document",
|
||||
"metadata": {"source": "test"},
|
||||
}
|
||||
]
|
||||
|
||||
with pytest.raises(
|
||||
ClientMethodMismatchError, match=r"Method add_documents\(\) requires"
|
||||
):
|
||||
client.add_documents(collection_name="test_collection", documents=documents)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aadd_documents(self, async_client, mock_async_qdrant_client):
|
||||
"""Test that aadd_documents adds documents to collection asynchronously."""
|
||||
mock_async_qdrant_client.collection_exists = AsyncMock(return_value=True)
|
||||
mock_async_qdrant_client.upsert = AsyncMock()
|
||||
async_client.embedding_function.return_value = [0.1, 0.2, 0.3]
|
||||
|
||||
documents: list[BaseRecord] = [
|
||||
{
|
||||
"content": "Test document",
|
||||
"metadata": {"source": "test"},
|
||||
}
|
||||
]
|
||||
|
||||
await async_client.aadd_documents(
|
||||
collection_name="test_collection", documents=documents
|
||||
)
|
||||
|
||||
mock_async_qdrant_client.collection_exists.assert_called_once_with(
|
||||
"test_collection"
|
||||
)
|
||||
async_client.embedding_function.assert_called_once_with("Test document")
|
||||
mock_async_qdrant_client.upsert.assert_called_once()
|
||||
|
||||
# Check upsert was called with correct parameters
|
||||
call_args = mock_async_qdrant_client.upsert.call_args
|
||||
assert call_args.kwargs["collection_name"] == "test_collection"
|
||||
assert call_args.kwargs["wait"] is True
|
||||
assert len(call_args.kwargs["points"]) == 1
|
||||
point = call_args.kwargs["points"][0]
|
||||
assert point.vector == [0.1, 0.2, 0.3]
|
||||
assert point.payload["content"] == "Test document"
|
||||
assert point.payload["source"] == "test"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aadd_documents_with_doc_id(
|
||||
self, async_client, mock_async_qdrant_client
|
||||
):
|
||||
"""Test that aadd_documents uses provided doc_id."""
|
||||
mock_async_qdrant_client.collection_exists = AsyncMock(return_value=True)
|
||||
mock_async_qdrant_client.upsert = AsyncMock()
|
||||
async_client.embedding_function.return_value = [0.1, 0.2, 0.3]
|
||||
|
||||
documents: list[BaseRecord] = [
|
||||
{
|
||||
"doc_id": "custom-id-123",
|
||||
"content": "Test document",
|
||||
"metadata": {"source": "test"},
|
||||
}
|
||||
]
|
||||
|
||||
await async_client.aadd_documents(
|
||||
collection_name="test_collection", documents=documents
|
||||
)
|
||||
|
||||
call_args = mock_async_qdrant_client.upsert.call_args
|
||||
point = call_args.kwargs["points"][0]
|
||||
assert point.id == "custom-id-123"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aadd_documents_empty_list(
|
||||
self, async_client, mock_async_qdrant_client
|
||||
):
|
||||
"""Test that aadd_documents raises error for empty documents list."""
|
||||
documents: list[BaseRecord] = []
|
||||
|
||||
with pytest.raises(ValueError, match="Documents list cannot be empty"):
|
||||
await async_client.aadd_documents(
|
||||
collection_name="test_collection", documents=documents
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aadd_documents_collection_not_exists(
|
||||
self, async_client, mock_async_qdrant_client
|
||||
):
|
||||
"""Test that aadd_documents raises error if collection doesn't exist."""
|
||||
mock_async_qdrant_client.collection_exists = AsyncMock(return_value=False)
|
||||
|
||||
documents: list[BaseRecord] = [
|
||||
{
|
||||
"content": "Test document",
|
||||
"metadata": {"source": "test"},
|
||||
}
|
||||
]
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match="Collection 'test_collection' does not exist"
|
||||
):
|
||||
await async_client.aadd_documents(
|
||||
collection_name="test_collection", documents=documents
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aadd_documents_wrong_client_type(self, mock_qdrant_client):
|
||||
"""Test that aadd_documents raises TypeError for sync client."""
|
||||
client = QdrantClient(client=mock_qdrant_client, embedding_function=Mock())
|
||||
|
||||
documents: list[BaseRecord] = [
|
||||
{
|
||||
"content": "Test document",
|
||||
"metadata": {"source": "test"},
|
||||
}
|
||||
]
|
||||
|
||||
with pytest.raises(
|
||||
ClientMethodMismatchError, match=r"Method aadd_documents\(\) requires"
|
||||
):
|
||||
await client.aadd_documents(
|
||||
collection_name="test_collection", documents=documents
|
||||
)
|
||||
|
||||
def test_search(self, client, mock_qdrant_client):
|
||||
"""Test that search returns matching documents."""
|
||||
mock_qdrant_client.collection_exists.return_value = True
|
||||
client.embedding_function.return_value = [0.1, 0.2, 0.3]
|
||||
|
||||
mock_point = Mock()
|
||||
mock_point.id = "doc-123"
|
||||
mock_point.payload = {"content": "Test content", "source": "test"}
|
||||
mock_point.score = 0.95
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.points = [mock_point]
|
||||
mock_qdrant_client.query_points.return_value = mock_response
|
||||
|
||||
results = client.search(collection_name="test_collection", query="test query")
|
||||
|
||||
mock_qdrant_client.collection_exists.assert_called_once_with("test_collection")
|
||||
client.embedding_function.assert_called_once_with("test query")
|
||||
mock_qdrant_client.query_points.assert_called_once()
|
||||
|
||||
call_args = mock_qdrant_client.query_points.call_args
|
||||
assert call_args.kwargs["collection_name"] == "test_collection"
|
||||
assert call_args.kwargs["query"] == [0.1, 0.2, 0.3]
|
||||
assert call_args.kwargs["limit"] == 10
|
||||
assert call_args.kwargs["with_payload"] is True
|
||||
assert call_args.kwargs["with_vectors"] is False
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0]["id"] == "doc-123"
|
||||
assert results[0]["content"] == "Test content"
|
||||
assert results[0]["metadata"] == {"source": "test"}
|
||||
assert results[0]["score"] == 0.975
|
||||
|
||||
def test_search_with_filters(self, client, mock_qdrant_client):
|
||||
"""Test that search applies metadata filters correctly."""
|
||||
mock_qdrant_client.collection_exists.return_value = True
|
||||
client.embedding_function.return_value = [0.1, 0.2, 0.3]
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.points = []
|
||||
mock_qdrant_client.query_points.return_value = mock_response
|
||||
|
||||
client.search(
|
||||
collection_name="test_collection",
|
||||
query="test query",
|
||||
metadata_filter={"category": "tech", "status": "published"},
|
||||
)
|
||||
|
||||
call_args = mock_qdrant_client.query_points.call_args
|
||||
query_filter = call_args.kwargs["query_filter"]
|
||||
assert len(query_filter.must) == 2
|
||||
assert any(
|
||||
cond.key == "category" and cond.match.value == "tech"
|
||||
for cond in query_filter.must
|
||||
)
|
||||
assert any(
|
||||
cond.key == "status" and cond.match.value == "published"
|
||||
for cond in query_filter.must
|
||||
)
|
||||
|
||||
def test_search_with_options(self, client, mock_qdrant_client):
|
||||
"""Test that search applies limit and score_threshold correctly."""
|
||||
mock_qdrant_client.collection_exists.return_value = True
|
||||
client.embedding_function.return_value = [0.1, 0.2, 0.3]
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.points = []
|
||||
mock_qdrant_client.query_points.return_value = mock_response
|
||||
|
||||
client.search(
|
||||
collection_name="test_collection",
|
||||
query="test query",
|
||||
limit=5,
|
||||
score_threshold=0.8,
|
||||
)
|
||||
|
||||
call_args = mock_qdrant_client.query_points.call_args
|
||||
assert call_args.kwargs["limit"] == 5
|
||||
assert call_args.kwargs["score_threshold"] == 0.8
|
||||
|
||||
def test_search_collection_not_exists(self, client, mock_qdrant_client):
|
||||
"""Test that search raises error if collection doesn't exist."""
|
||||
mock_qdrant_client.collection_exists.return_value = False
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match="Collection 'test_collection' does not exist"
|
||||
):
|
||||
client.search(collection_name="test_collection", query="test query")
|
||||
|
||||
def test_search_wrong_client_type(self, mock_async_qdrant_client):
|
||||
"""Test that search raises TypeError for async client."""
|
||||
client = QdrantClient(
|
||||
client=mock_async_qdrant_client, embedding_function=Mock()
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ClientMethodMismatchError, match=r"Method search\(\) requires"
|
||||
):
|
||||
client.search(collection_name="test_collection", query="test query")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_asearch(self, async_client, mock_async_qdrant_client):
|
||||
"""Test that asearch returns matching documents asynchronously."""
|
||||
mock_async_qdrant_client.collection_exists = AsyncMock(return_value=True)
|
||||
async_client.embedding_function.return_value = [0.1, 0.2, 0.3]
|
||||
|
||||
mock_point = Mock()
|
||||
mock_point.id = "doc-123"
|
||||
mock_point.payload = {"content": "Test content", "source": "test"}
|
||||
mock_point.score = 0.95
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.points = [mock_point]
|
||||
mock_async_qdrant_client.query_points = AsyncMock(return_value=mock_response)
|
||||
|
||||
results = await async_client.asearch(
|
||||
collection_name="test_collection", query="test query"
|
||||
)
|
||||
|
||||
mock_async_qdrant_client.collection_exists.assert_called_once_with(
|
||||
"test_collection"
|
||||
)
|
||||
async_client.embedding_function.assert_called_once_with("test query")
|
||||
mock_async_qdrant_client.query_points.assert_called_once()
|
||||
|
||||
call_args = mock_async_qdrant_client.query_points.call_args
|
||||
assert call_args.kwargs["collection_name"] == "test_collection"
|
||||
assert call_args.kwargs["query"] == [0.1, 0.2, 0.3]
|
||||
assert call_args.kwargs["limit"] == 10
|
||||
assert call_args.kwargs["with_payload"] is True
|
||||
assert call_args.kwargs["with_vectors"] is False
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0]["id"] == "doc-123"
|
||||
assert results[0]["content"] == "Test content"
|
||||
assert results[0]["metadata"] == {"source": "test"}
|
||||
assert results[0]["score"] == 0.975
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_asearch_with_filters(self, async_client, mock_async_qdrant_client):
|
||||
"""Test that asearch applies metadata filters correctly."""
|
||||
mock_async_qdrant_client.collection_exists = AsyncMock(return_value=True)
|
||||
async_client.embedding_function.return_value = [0.1, 0.2, 0.3]
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.points = []
|
||||
mock_async_qdrant_client.query_points = AsyncMock(return_value=mock_response)
|
||||
|
||||
await async_client.asearch(
|
||||
collection_name="test_collection",
|
||||
query="test query",
|
||||
metadata_filter={"category": "tech", "status": "published"},
|
||||
)
|
||||
|
||||
call_args = mock_async_qdrant_client.query_points.call_args
|
||||
query_filter = call_args.kwargs["query_filter"]
|
||||
assert len(query_filter.must) == 2
|
||||
assert any(
|
||||
cond.key == "category" and cond.match.value == "tech"
|
||||
for cond in query_filter.must
|
||||
)
|
||||
assert any(
|
||||
cond.key == "status" and cond.match.value == "published"
|
||||
for cond in query_filter.must
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_asearch_collection_not_exists(
|
||||
self, async_client, mock_async_qdrant_client
|
||||
):
|
||||
"""Test that asearch raises error if collection doesn't exist."""
|
||||
mock_async_qdrant_client.collection_exists = AsyncMock(return_value=False)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match="Collection 'test_collection' does not exist"
|
||||
):
|
||||
await async_client.asearch(
|
||||
collection_name="test_collection", query="test query"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_asearch_wrong_client_type(self, mock_qdrant_client):
|
||||
"""Test that asearch raises TypeError for sync client."""
|
||||
client = QdrantClient(client=mock_qdrant_client, embedding_function=Mock())
|
||||
|
||||
with pytest.raises(
|
||||
ClientMethodMismatchError, match=r"Method asearch\(\) requires"
|
||||
):
|
||||
await client.asearch(collection_name="test_collection", query="test query")
|
||||
|
||||
def test_delete_collection(self, client, mock_qdrant_client):
|
||||
"""Test that delete_collection deletes the collection."""
|
||||
mock_qdrant_client.collection_exists.return_value = True
|
||||
|
||||
client.delete_collection(collection_name="test_collection")
|
||||
|
||||
mock_qdrant_client.collection_exists.assert_called_once_with("test_collection")
|
||||
mock_qdrant_client.delete_collection.assert_called_once_with(
|
||||
collection_name="test_collection"
|
||||
)
|
||||
|
||||
def test_delete_collection_not_exists(self, client, mock_qdrant_client):
|
||||
"""Test that delete_collection raises error if collection doesn't exist."""
|
||||
mock_qdrant_client.collection_exists.return_value = False
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match="Collection 'test_collection' does not exist"
|
||||
):
|
||||
client.delete_collection(collection_name="test_collection")
|
||||
|
||||
mock_qdrant_client.collection_exists.assert_called_once_with("test_collection")
|
||||
mock_qdrant_client.delete_collection.assert_not_called()
|
||||
|
||||
def test_delete_collection_wrong_client_type(self, mock_async_qdrant_client):
|
||||
"""Test that delete_collection raises TypeError for async client."""
|
||||
client = QdrantClient(
|
||||
client=mock_async_qdrant_client, embedding_function=Mock()
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ClientMethodMismatchError, match=r"Method delete_collection\(\) requires"
|
||||
):
|
||||
client.delete_collection(collection_name="test_collection")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_adelete_collection(self, async_client, mock_async_qdrant_client):
|
||||
"""Test that adelete_collection deletes the collection asynchronously."""
|
||||
mock_async_qdrant_client.collection_exists = AsyncMock(return_value=True)
|
||||
mock_async_qdrant_client.delete_collection = AsyncMock()
|
||||
|
||||
await async_client.adelete_collection(collection_name="test_collection")
|
||||
|
||||
mock_async_qdrant_client.collection_exists.assert_called_once_with(
|
||||
"test_collection"
|
||||
)
|
||||
mock_async_qdrant_client.delete_collection.assert_called_once_with(
|
||||
collection_name="test_collection"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_adelete_collection_not_exists(
|
||||
self, async_client, mock_async_qdrant_client
|
||||
):
|
||||
"""Test that adelete_collection raises error if collection doesn't exist."""
|
||||
mock_async_qdrant_client.collection_exists = AsyncMock(return_value=False)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match="Collection 'test_collection' does not exist"
|
||||
):
|
||||
await async_client.adelete_collection(collection_name="test_collection")
|
||||
|
||||
mock_async_qdrant_client.collection_exists.assert_called_once_with(
|
||||
"test_collection"
|
||||
)
|
||||
mock_async_qdrant_client.delete_collection.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_adelete_collection_wrong_client_type(self, mock_qdrant_client):
|
||||
"""Test that adelete_collection raises TypeError for sync client."""
|
||||
client = QdrantClient(client=mock_qdrant_client, embedding_function=Mock())
|
||||
|
||||
with pytest.raises(
|
||||
ClientMethodMismatchError, match=r"Method adelete_collection\(\) requires"
|
||||
):
|
||||
await client.adelete_collection(collection_name="test_collection")
|
||||
|
||||
def test_reset(self, client, mock_qdrant_client):
|
||||
"""Test that reset deletes all collections."""
|
||||
mock_collection1 = Mock()
|
||||
mock_collection1.name = "collection1"
|
||||
mock_collection2 = Mock()
|
||||
mock_collection2.name = "collection2"
|
||||
mock_collection3 = Mock()
|
||||
mock_collection3.name = "collection3"
|
||||
|
||||
mock_collections_response = Mock()
|
||||
mock_collections_response.collections = [
|
||||
mock_collection1,
|
||||
mock_collection2,
|
||||
mock_collection3,
|
||||
]
|
||||
mock_qdrant_client.get_collections.return_value = mock_collections_response
|
||||
|
||||
client.reset()
|
||||
|
||||
mock_qdrant_client.get_collections.assert_called_once()
|
||||
assert mock_qdrant_client.delete_collection.call_count == 3
|
||||
mock_qdrant_client.delete_collection.assert_any_call(
|
||||
collection_name="collection1"
|
||||
)
|
||||
mock_qdrant_client.delete_collection.assert_any_call(
|
||||
collection_name="collection2"
|
||||
)
|
||||
mock_qdrant_client.delete_collection.assert_any_call(
|
||||
collection_name="collection3"
|
||||
)
|
||||
|
||||
def test_reset_no_collections(self, client, mock_qdrant_client):
|
||||
"""Test that reset handles no collections gracefully."""
|
||||
mock_collections_response = Mock()
|
||||
mock_collections_response.collections = []
|
||||
mock_qdrant_client.get_collections.return_value = mock_collections_response
|
||||
|
||||
client.reset()
|
||||
|
||||
mock_qdrant_client.get_collections.assert_called_once()
|
||||
mock_qdrant_client.delete_collection.assert_not_called()
|
||||
|
||||
def test_reset_wrong_client_type(self, mock_async_qdrant_client):
|
||||
"""Test that reset raises TypeError for async client."""
|
||||
client = QdrantClient(
|
||||
client=mock_async_qdrant_client, embedding_function=Mock()
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ClientMethodMismatchError, match=r"Method reset\(\) requires"
|
||||
):
|
||||
client.reset()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_areset(self, async_client, mock_async_qdrant_client):
|
||||
"""Test that areset deletes all collections asynchronously."""
|
||||
mock_collection1 = Mock()
|
||||
mock_collection1.name = "collection1"
|
||||
mock_collection2 = Mock()
|
||||
mock_collection2.name = "collection2"
|
||||
mock_collection3 = Mock()
|
||||
mock_collection3.name = "collection3"
|
||||
|
||||
mock_collections_response = Mock()
|
||||
mock_collections_response.collections = [
|
||||
mock_collection1,
|
||||
mock_collection2,
|
||||
mock_collection3,
|
||||
]
|
||||
mock_async_qdrant_client.get_collections = AsyncMock(
|
||||
return_value=mock_collections_response
|
||||
)
|
||||
mock_async_qdrant_client.delete_collection = AsyncMock()
|
||||
|
||||
await async_client.areset()
|
||||
|
||||
mock_async_qdrant_client.get_collections.assert_called_once()
|
||||
assert mock_async_qdrant_client.delete_collection.call_count == 3
|
||||
mock_async_qdrant_client.delete_collection.assert_any_call(
|
||||
collection_name="collection1"
|
||||
)
|
||||
mock_async_qdrant_client.delete_collection.assert_any_call(
|
||||
collection_name="collection2"
|
||||
)
|
||||
mock_async_qdrant_client.delete_collection.assert_any_call(
|
||||
collection_name="collection3"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_areset_no_collections(self, async_client, mock_async_qdrant_client):
|
||||
"""Test that areset handles no collections gracefully."""
|
||||
mock_collections_response = Mock()
|
||||
mock_collections_response.collections = []
|
||||
mock_async_qdrant_client.get_collections = AsyncMock(
|
||||
return_value=mock_collections_response
|
||||
)
|
||||
|
||||
await async_client.areset()
|
||||
|
||||
mock_async_qdrant_client.get_collections.assert_called_once()
|
||||
mock_async_qdrant_client.delete_collection.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_areset_wrong_client_type(self, mock_qdrant_client):
|
||||
"""Test that areset raises TypeError for sync client."""
|
||||
client = QdrantClient(client=mock_qdrant_client, embedding_function=Mock())
|
||||
|
||||
with pytest.raises(
|
||||
ClientMethodMismatchError, match=r"Method areset\(\) requires"
|
||||
):
|
||||
await client.areset()
|
||||
Reference in New Issue
Block a user